327 lines
11 KiB
Python
327 lines
11 KiB
Python
import uuid
|
|
from datetime import datetime, timezone
|
|
from typing import Protocol
|
|
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from ..entities.pack_entities import (
|
|
WordBankPackEntity,
|
|
WordBankPackEntryEntity,
|
|
WordBankPackFlashcardTemplateEntity,
|
|
)
|
|
from ..entities.vocab_entities import LearnableWordBankEntryEntity
|
|
from ....domain.models.pack import WordBankPack, WordBankPackEntry, WordBankPackFlashcardTemplate
|
|
|
|
|
|
class PackRepository(Protocol):
    """Structural interface for storing and querying word-bank packs.

    Covers packs themselves, the entries inside a pack, and the flashcard
    templates attached to individual entries.
    """

    async def create_pack(
        self,
        name: str,
        name_target: str,
        description: str,
        description_target: str,
        source_lang: str,
        target_lang: str,
        proficiencies: list[str],
    ) -> WordBankPack:
        """Create a pack from the given metadata and return it."""

    async def update_pack(
        self,
        pack_id: uuid.UUID,
        name: str | None = None,
        name_target: str | None = None,
        description: str | None = None,
        description_target: str | None = None,
        proficiencies: list[str] | None = None,
    ) -> WordBankPack:
        """Update the pack identified by ``pack_id`` and return it.

        All metadata arguments are optional and default to ``None``.
        """

    async def publish_pack(self, pack_id: uuid.UUID) -> WordBankPack:
        """Publish the pack identified by ``pack_id`` and return it."""

    async def get_pack(self, pack_id: uuid.UUID) -> WordBankPack | None:
        """Return the pack identified by ``pack_id``, or ``None``."""

    async def list_packs(
        self,
        source_lang: str | None = None,
        target_lang: str | None = None,
        published_only: bool = False,
    ) -> list[WordBankPack]:
        """Return packs, optionally narrowed by language and publication."""

    async def add_entry(
        self,
        pack_id: uuid.UUID,
        sense_id: uuid.UUID | None,
        surface_text: str,
    ) -> WordBankPackEntry:
        """Add an entry to the pack ``pack_id`` and return it."""

    async def remove_entry(self, entry_id: uuid.UUID) -> None:
        """Remove the pack entry identified by ``entry_id``."""

    async def get_entries_for_pack(
        self, pack_id: uuid.UUID
    ) -> list[WordBankPackEntry]:
        """Return the entries belonging to the pack ``pack_id``."""

    async def add_flashcard_template(
        self,
        pack_entry_id: uuid.UUID,
        prompt_text: str,
        answer_text: str,
        prompt_context_text: str | None = None,
        answer_context_text: str | None = None,
    ) -> WordBankPackFlashcardTemplate:
        """Attach a flashcard template to a pack entry and return it."""

    async def remove_flashcard_template(self, template_id: uuid.UUID) -> None:
        """Remove the flashcard template identified by ``template_id``."""

    async def get_templates_for_entry(
        self, pack_entry_id: uuid.UUID
    ) -> list[WordBankPackFlashcardTemplate]:
        """Return the flashcard templates of a single pack entry."""

    async def get_templates_for_entries(
        self, pack_entry_ids: list[uuid.UUID]
    ) -> dict[str, list[WordBankPackFlashcardTemplate]]:
        """Return flashcard templates for several pack entries at once.

        The result is a dict with string keys mapping to template lists.
        """

    async def count_entries_for_pack(self, pack_id: uuid.UUID) -> int:
        """Return how many entries the pack ``pack_id`` contains."""

    async def get_pack_ids_added_by_user(
        self, user_id: uuid.UUID, language_pair_id: uuid.UUID
    ) -> set[str]:
        """Return pack ids (as strings) for a user and language pair."""
|
|
|
|
|
|
def _pack_to_model(entity: WordBankPackEntity) -> WordBankPack:
    """Convert a pack ORM entity into a ``WordBankPack`` domain model.

    The entity id is rendered with ``str``; all other fields are passed
    through unchanged.
    """
    fields = {
        "id": str(entity.id),
        "name": entity.name,
        "name_target": entity.name_target,
        "description": entity.description,
        "description_target": entity.description_target,
        "source_lang": entity.source_lang,
        "target_lang": entity.target_lang,
        "proficiencies": entity.proficiencies,
        "is_published": entity.is_published,
        "created_at": entity.created_at,
    }
    return WordBankPack(**fields)
|
|
|
|
|
|
def _entry_to_model(entity: WordBankPackEntryEntity) -> WordBankPackEntry:
    """Convert a pack-entry ORM entity into a ``WordBankPackEntry``.

    ``id`` and ``pack_id`` are rendered with ``str``. ``sense_id`` is
    rendered with ``str`` when truthy and is ``None`` otherwise.
    """
    entry_id = str(entity.id)
    pack_id = str(entity.pack_id)

    sense_id = None
    if entity.sense_id:
        sense_id = str(entity.sense_id)

    return WordBankPackEntry(
        id=entry_id,
        pack_id=pack_id,
        sense_id=sense_id,
        surface_text=entity.surface_text,
        created_at=entity.created_at,
    )
|
|
|
|
|
|
def _template_to_model(
    entity: WordBankPackFlashcardTemplateEntity,
) -> WordBankPackFlashcardTemplate:
    """Convert a flashcard-template ORM entity into its domain model.

    ``id`` and ``pack_entry_id`` are rendered with ``str``; the text
    fields and ``created_at`` are passed through unchanged.
    """
    fields = {
        "id": str(entity.id),
        "pack_entry_id": str(entity.pack_entry_id),
        "prompt_text": entity.prompt_text,
        "answer_text": entity.answer_text,
        "prompt_context_text": entity.prompt_context_text,
        "answer_context_text": entity.answer_context_text,
        "created_at": entity.created_at,
    }
    return WordBankPackFlashcardTemplate(**fields)
|
|
|
|
|
|
class PostgresPackRepository:
    """Pack repository backed by a SQLAlchemy ``AsyncSession``.

    Methods that add, change, or delete rows commit the session themselves
    before returning. Results are returned as domain models, not entities.
    """

    def __init__(self, db: AsyncSession) -> None:
        """Store the session used for every query issued by this object."""
        self.db = db

    async def create_pack(
        self,
        name: str,
        name_target: str,
        description: str,
        description_target: str,
        source_lang: str,
        target_lang: str,
        proficiencies: list[str],
    ) -> WordBankPack:
        """Insert a new pack, commit, and return it as a domain model.

        ``created_at`` is set to the current UTC time.
        """
        pack_row = WordBankPackEntity(
            name=name,
            name_target=name_target,
            description=description,
            description_target=description_target,
            source_lang=source_lang,
            target_lang=target_lang,
            proficiencies=proficiencies,
            created_at=datetime.now(timezone.utc),
        )
        self.db.add(pack_row)
        await self.db.commit()
        # Reload the row after the commit before reading its attributes.
        await self.db.refresh(pack_row)
        return _pack_to_model(pack_row)

    async def update_pack(
        self,
        pack_id: uuid.UUID,
        name: str | None = None,
        name_target: str | None = None,
        description: str | None = None,
        description_target: str | None = None,
        proficiencies: list[str] | None = None,
    ) -> WordBankPack:
        """Apply the non-``None`` arguments to a pack, commit, and return it.

        Arguments left as ``None`` keep their stored value. The lookup uses
        ``scalar_one()``, so a missing pack raises instead of returning.
        """
        stmt = select(WordBankPackEntity).where(WordBankPackEntity.id == pack_id)
        pack_row = (await self.db.execute(stmt)).scalar_one()

        requested = {
            "name": name,
            "name_target": name_target,
            "description": description,
            "description_target": description_target,
            "proficiencies": proficiencies,
        }
        for attribute, value in requested.items():
            if value is not None:
                setattr(pack_row, attribute, value)

        await self.db.commit()
        await self.db.refresh(pack_row)
        return _pack_to_model(pack_row)

    async def publish_pack(self, pack_id: uuid.UUID) -> WordBankPack:
        """Set ``is_published`` to ``True`` on a pack, commit, and return it.

        The lookup uses ``scalar_one()``, so a missing pack raises.
        """
        stmt = select(WordBankPackEntity).where(WordBankPackEntity.id == pack_id)
        pack_row = (await self.db.execute(stmt)).scalar_one()
        pack_row.is_published = True
        await self.db.commit()
        await self.db.refresh(pack_row)
        return _pack_to_model(pack_row)

    async def get_pack(self, pack_id: uuid.UUID) -> WordBankPack | None:
        """Return the pack with ``pack_id``, or ``None`` if no row matches."""
        stmt = select(WordBankPackEntity).where(WordBankPackEntity.id == pack_id)
        pack_row = (await self.db.execute(stmt)).scalar_one_or_none()
        if not pack_row:
            return None
        return _pack_to_model(pack_row)

    async def list_packs(
        self,
        source_lang: str | None = None,
        target_lang: str | None = None,
        published_only: bool = False,
    ) -> list[WordBankPack]:
        """Return packs ordered by ``created_at`` descending.

        A language filter is applied only when its argument is truthy;
        ``published_only`` restricts the result to published packs.
        """
        criteria = []
        if source_lang:
            criteria.append(WordBankPackEntity.source_lang == source_lang)
        if target_lang:
            criteria.append(WordBankPackEntity.target_lang == target_lang)
        if published_only:
            criteria.append(WordBankPackEntity.is_published.is_(True))

        stmt = select(WordBankPackEntity)
        for criterion in criteria:
            stmt = stmt.where(criterion)
        stmt = stmt.order_by(WordBankPackEntity.created_at.desc())

        pack_rows = (await self.db.execute(stmt)).scalars().all()
        return [_pack_to_model(pack_row) for pack_row in pack_rows]

    async def add_entry(
        self,
        pack_id: uuid.UUID,
        sense_id: uuid.UUID | None,
        surface_text: str,
    ) -> WordBankPackEntry:
        """Insert an entry into a pack, commit, and return it.

        ``created_at`` is set to the current UTC time.
        """
        entry_row = WordBankPackEntryEntity(
            pack_id=pack_id,
            sense_id=sense_id,
            surface_text=surface_text,
            created_at=datetime.now(timezone.utc),
        )
        self.db.add(entry_row)
        await self.db.commit()
        await self.db.refresh(entry_row)
        return _entry_to_model(entry_row)

    async def remove_entry(self, entry_id: uuid.UUID) -> None:
        """Delete the entry with ``entry_id`` and commit, if it exists."""
        stmt = select(WordBankPackEntryEntity).where(
            WordBankPackEntryEntity.id == entry_id
        )
        entry_row = (await self.db.execute(stmt)).scalar_one_or_none()
        if not entry_row:
            return
        await self.db.delete(entry_row)
        await self.db.commit()

    async def get_entries_for_pack(
        self, pack_id: uuid.UUID
    ) -> list[WordBankPackEntry]:
        """Return a pack's entries ordered by ``created_at`` ascending."""
        stmt = (
            select(WordBankPackEntryEntity)
            .where(WordBankPackEntryEntity.pack_id == pack_id)
            .order_by(WordBankPackEntryEntity.created_at.asc())
        )
        entry_rows = (await self.db.execute(stmt)).scalars().all()
        return [_entry_to_model(entry_row) for entry_row in entry_rows]

    async def add_flashcard_template(
        self,
        pack_entry_id: uuid.UUID,
        prompt_text: str,
        answer_text: str,
        prompt_context_text: str | None = None,
        answer_context_text: str | None = None,
    ) -> WordBankPackFlashcardTemplate:
        """Insert a flashcard template for an entry, commit, and return it.

        ``created_at`` is set to the current UTC time.
        """
        template_row = WordBankPackFlashcardTemplateEntity(
            pack_entry_id=pack_entry_id,
            prompt_text=prompt_text,
            answer_text=answer_text,
            prompt_context_text=prompt_context_text,
            answer_context_text=answer_context_text,
            created_at=datetime.now(timezone.utc),
        )
        self.db.add(template_row)
        await self.db.commit()
        await self.db.refresh(template_row)
        return _template_to_model(template_row)

    async def remove_flashcard_template(self, template_id: uuid.UUID) -> None:
        """Delete the template with ``template_id`` and commit, if it exists."""
        stmt = select(WordBankPackFlashcardTemplateEntity).where(
            WordBankPackFlashcardTemplateEntity.id == template_id
        )
        template_row = (await self.db.execute(stmt)).scalar_one_or_none()
        if not template_row:
            return
        await self.db.delete(template_row)
        await self.db.commit()

    async def get_templates_for_entry(
        self, pack_entry_id: uuid.UUID
    ) -> list[WordBankPackFlashcardTemplate]:
        """Return one entry's templates ordered by ``created_at`` ascending."""
        stmt = (
            select(WordBankPackFlashcardTemplateEntity)
            .where(WordBankPackFlashcardTemplateEntity.pack_entry_id == pack_entry_id)
            .order_by(WordBankPackFlashcardTemplateEntity.created_at.asc())
        )
        template_rows = (await self.db.execute(stmt)).scalars().all()
        return [_template_to_model(template_row) for template_row in template_rows]

    async def get_templates_for_entries(
        self, pack_entry_ids: list[uuid.UUID]
    ) -> dict[str, list[WordBankPackFlashcardTemplate]]:
        """Return templates for many entries, grouped by entry id.

        Keys are ``str(pack_entry_id)``. Entries without templates have no
        key. An empty ``pack_entry_ids`` returns ``{}`` without querying.
        Each list is ordered by ``created_at`` ascending.
        """
        if not pack_entry_ids:
            return {}

        stmt = (
            select(WordBankPackFlashcardTemplateEntity)
            .where(WordBankPackFlashcardTemplateEntity.pack_entry_id.in_(pack_entry_ids))
            .order_by(WordBankPackFlashcardTemplateEntity.created_at.asc())
        )
        template_rows = (await self.db.execute(stmt)).scalars().all()

        by_entry: dict[str, list[WordBankPackFlashcardTemplate]] = {}
        for template_row in template_rows:
            entry_key = str(template_row.pack_entry_id)
            if entry_key not in by_entry:
                by_entry[entry_key] = []
            by_entry[entry_key].append(_template_to_model(template_row))
        return by_entry

    async def count_entries_for_pack(self, pack_id: uuid.UUID) -> int:
        """Return the number of entry rows whose ``pack_id`` matches."""
        stmt = select(func.count()).where(WordBankPackEntryEntity.pack_id == pack_id)
        outcome = await self.db.execute(stmt)
        return outcome.scalar_one()

    async def get_pack_ids_added_by_user(
        self, user_id: uuid.UUID, language_pair_id: uuid.UUID
    ) -> set[str]:
        """Return distinct pack ids, as strings, linked to a user's word bank.

        A pack is included when at least one of its entries is referenced
        by a ``LearnableWordBankEntryEntity`` row matching both ``user_id``
        and ``language_pair_id``.
        """
        stmt = (
            select(WordBankPackEntryEntity.pack_id)
            .join(
                LearnableWordBankEntryEntity,
                LearnableWordBankEntryEntity.pack_entry_id == WordBankPackEntryEntity.id,
            )
            .where(
                LearnableWordBankEntryEntity.user_id == user_id,
                LearnableWordBankEntryEntity.language_pair_id == language_pair_id,
            )
            .distinct()
        )
        pack_ids = (await self.db.execute(stmt)).scalars().all()
        return set(map(str, pack_ids))
|