import uuid
from datetime import datetime, timezone
from typing import Protocol

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from ..entities.pack_entities import (
    WordBankPackEntity,
    WordBankPackEntryEntity,
    WordBankPackFlashcardTemplateEntity,
)
from ..entities.vocab_entities import LearnableWordBankEntryEntity
from ....domain.models.pack import WordBankPack, WordBankPackEntry, WordBankPackFlashcardTemplate


class PackRepository(Protocol):
    """Persistence interface for word-bank packs, their entries and flashcard templates."""

    async def create_pack(
        self,
        name: str,
        name_target: str,
        description: str,
        description_target: str,
        source_lang: str,
        target_lang: str,
        proficiencies: list[str],
    ) -> WordBankPack: ...

    async def update_pack(
        self,
        pack_id: uuid.UUID,
        name: str | None = None,
        name_target: str | None = None,
        description: str | None = None,
        description_target: str | None = None,
        proficiencies: list[str] | None = None,
    ) -> WordBankPack: ...

    async def publish_pack(self, pack_id: uuid.UUID) -> WordBankPack: ...

    async def get_pack(self, pack_id: uuid.UUID) -> WordBankPack | None: ...

    async def list_packs(
        self,
        source_lang: str | None = None,
        target_lang: str | None = None,
        published_only: bool = False,
    ) -> list[WordBankPack]: ...

    async def add_entry(
        self,
        pack_id: uuid.UUID,
        sense_id: uuid.UUID | None,
        surface_text: str,
    ) -> WordBankPackEntry: ...

    async def remove_entry(self, entry_id: uuid.UUID) -> None: ...

    async def get_entries_for_pack(self, pack_id: uuid.UUID) -> list[WordBankPackEntry]: ...

    async def add_flashcard_template(
        self,
        pack_entry_id: uuid.UUID,
        prompt_text: str,
        answer_text: str,
        prompt_context_text: str | None = None,
        answer_context_text: str | None = None,
    ) -> WordBankPackFlashcardTemplate: ...

    async def remove_flashcard_template(self, template_id: uuid.UUID) -> None: ...

    async def get_templates_for_entry(
        self, pack_entry_id: uuid.UUID
    ) -> list[WordBankPackFlashcardTemplate]: ...

    async def get_templates_for_entries(
        self, pack_entry_ids: list[uuid.UUID]
    ) -> dict[str, list[WordBankPackFlashcardTemplate]]: ...

    async def count_entries_for_pack(self, pack_id: uuid.UUID) -> int: ...

    async def get_pack_ids_added_by_user(
        self, user_id: uuid.UUID, language_pair_id: uuid.UUID
    ) -> set[str]: ...


def _pack_to_model(entity: WordBankPackEntity) -> WordBankPack:
    """Convert a pack ORM entity into the domain model (ids rendered as strings)."""
    return WordBankPack(
        id=str(entity.id),
        name=entity.name,
        name_target=entity.name_target,
        description=entity.description,
        description_target=entity.description_target,
        source_lang=entity.source_lang,
        target_lang=entity.target_lang,
        proficiencies=entity.proficiencies,
        is_published=entity.is_published,
        created_at=entity.created_at,
    )


def _entry_to_model(entity: WordBankPackEntryEntity) -> WordBankPackEntry:
    """Convert a pack-entry ORM entity into the domain model.

    ``sense_id`` is optional; it maps to ``None`` when the entity has none.
    """
    return WordBankPackEntry(
        id=str(entity.id),
        pack_id=str(entity.pack_id),
        sense_id=str(entity.sense_id) if entity.sense_id is not None else None,
        surface_text=entity.surface_text,
        created_at=entity.created_at,
    )


def _template_to_model(entity: WordBankPackFlashcardTemplateEntity) -> WordBankPackFlashcardTemplate:
    """Convert a flashcard-template ORM entity into the domain model."""
    return WordBankPackFlashcardTemplate(
        id=str(entity.id),
        pack_entry_id=str(entity.pack_entry_id),
        prompt_text=entity.prompt_text,
        answer_text=entity.answer_text,
        prompt_context_text=entity.prompt_context_text,
        answer_context_text=entity.answer_context_text,
        created_at=entity.created_at,
    )


class PostgresPackRepository:
    """``PackRepository`` implementation backed by an SQLAlchemy ``AsyncSession``.

    Every mutating method commits the session before returning.
    """

    def __init__(self, db: AsyncSession) -> None:
        self.db = db

    async def create_pack(
        self,
        name: str,
        name_target: str,
        description: str,
        description_target: str,
        source_lang: str,
        target_lang: str,
        proficiencies: list[str],
    ) -> WordBankPack:
        """Insert a new pack stamped with the current UTC time and return it."""
        entity = WordBankPackEntity(
            name=name,
            name_target=name_target,
            description=description,
            description_target=description_target,
            source_lang=source_lang,
            target_lang=target_lang,
            proficiencies=proficiencies,
            created_at=datetime.now(timezone.utc),
        )
        self.db.add(entity)
        await self.db.commit()
        # Reload DB-populated columns (e.g. the generated id) after commit.
        await self.db.refresh(entity)
        return _pack_to_model(entity)

    async def update_pack(
        self,
        pack_id: uuid.UUID,
        name: str | None = None,
        name_target: str | None = None,
        description: str | None = None,
        description_target: str | None = None,
        proficiencies: list[str] | None = None,
    ) -> WordBankPack:
        """Overwrite only the fields that are not ``None`` and return the updated pack.

        Raises:
            sqlalchemy.exc.NoResultFound: if no pack has ``pack_id``
                (from ``Result.scalar_one``).
        """
        result = await self.db.execute(
            select(WordBankPackEntity).where(WordBankPackEntity.id == pack_id)
        )
        entity = result.scalar_one()
        if name is not None:
            entity.name = name
        if name_target is not None:
            entity.name_target = name_target
        if description is not None:
            entity.description = description
        if description_target is not None:
            entity.description_target = description_target
        if proficiencies is not None:
            entity.proficiencies = proficiencies
        await self.db.commit()
        await self.db.refresh(entity)
        return _pack_to_model(entity)

    async def publish_pack(self, pack_id: uuid.UUID) -> WordBankPack:
        """Mark a pack as published and return it.

        Raises:
            sqlalchemy.exc.NoResultFound: if no pack has ``pack_id``.
        """
        result = await self.db.execute(
            select(WordBankPackEntity).where(WordBankPackEntity.id == pack_id)
        )
        entity = result.scalar_one()
        entity.is_published = True
        await self.db.commit()
        await self.db.refresh(entity)
        return _pack_to_model(entity)

    async def get_pack(self, pack_id: uuid.UUID) -> WordBankPack | None:
        """Return the pack with ``pack_id``, or ``None`` if it does not exist."""
        result = await self.db.execute(
            select(WordBankPackEntity).where(WordBankPackEntity.id == pack_id)
        )
        entity = result.scalar_one_or_none()
        return _pack_to_model(entity) if entity is not None else None

    async def list_packs(
        self,
        source_lang: str | None = None,
        target_lang: str | None = None,
        published_only: bool = False,
    ) -> list[WordBankPack]:
        """List packs, newest first, optionally filtered.

        Empty/``None`` language arguments disable the corresponding filter.
        """
        query = select(WordBankPackEntity)
        if source_lang:
            query = query.where(WordBankPackEntity.source_lang == source_lang)
        if target_lang:
            query = query.where(WordBankPackEntity.target_lang == target_lang)
        if published_only:
            query = query.where(WordBankPackEntity.is_published.is_(True))
        query = query.order_by(WordBankPackEntity.created_at.desc())
        result = await self.db.execute(query)
        return [_pack_to_model(e) for e in result.scalars().all()]

    async def add_entry(
        self,
        pack_id: uuid.UUID,
        sense_id: uuid.UUID | None,
        surface_text: str,
    ) -> WordBankPackEntry:
        """Insert an entry into a pack and return it."""
        entity = WordBankPackEntryEntity(
            pack_id=pack_id,
            sense_id=sense_id,
            surface_text=surface_text,
            created_at=datetime.now(timezone.utc),
        )
        self.db.add(entity)
        await self.db.commit()
        await self.db.refresh(entity)
        return _entry_to_model(entity)

    async def remove_entry(self, entry_id: uuid.UUID) -> None:
        """Delete the entry with ``entry_id``; a missing entry is a no-op."""
        result = await self.db.execute(
            select(WordBankPackEntryEntity).where(WordBankPackEntryEntity.id == entry_id)
        )
        entity = result.scalar_one_or_none()
        if entity is None:
            return
        await self.db.delete(entity)
        await self.db.commit()

    async def get_entries_for_pack(self, pack_id: uuid.UUID) -> list[WordBankPackEntry]:
        """Return a pack's entries, oldest first."""
        result = await self.db.execute(
            select(WordBankPackEntryEntity)
            .where(WordBankPackEntryEntity.pack_id == pack_id)
            .order_by(WordBankPackEntryEntity.created_at.asc())
        )
        return [_entry_to_model(e) for e in result.scalars().all()]

    async def add_flashcard_template(
        self,
        pack_entry_id: uuid.UUID,
        prompt_text: str,
        answer_text: str,
        prompt_context_text: str | None = None,
        answer_context_text: str | None = None,
    ) -> WordBankPackFlashcardTemplate:
        """Attach a flashcard template to a pack entry and return it."""
        entity = WordBankPackFlashcardTemplateEntity(
            pack_entry_id=pack_entry_id,
            prompt_text=prompt_text,
            answer_text=answer_text,
            prompt_context_text=prompt_context_text,
            answer_context_text=answer_context_text,
            created_at=datetime.now(timezone.utc),
        )
        self.db.add(entity)
        await self.db.commit()
        await self.db.refresh(entity)
        return _template_to_model(entity)

    async def remove_flashcard_template(self, template_id: uuid.UUID) -> None:
        """Delete the template with ``template_id``; a missing template is a no-op."""
        result = await self.db.execute(
            select(WordBankPackFlashcardTemplateEntity).where(
                WordBankPackFlashcardTemplateEntity.id == template_id
            )
        )
        entity = result.scalar_one_or_none()
        if entity is None:
            return
        await self.db.delete(entity)
        await self.db.commit()

    async def get_templates_for_entry(
        self, pack_entry_id: uuid.UUID
    ) -> list[WordBankPackFlashcardTemplate]:
        """Return the templates of one pack entry, oldest first."""
        result = await self.db.execute(
            select(WordBankPackFlashcardTemplateEntity)
            .where(WordBankPackFlashcardTemplateEntity.pack_entry_id == pack_entry_id)
            .order_by(WordBankPackFlashcardTemplateEntity.created_at.asc())
        )
        return [_template_to_model(e) for e in result.scalars().all()]

    async def get_templates_for_entries(
        self, pack_entry_ids: list[uuid.UUID]
    ) -> dict[str, list[WordBankPackFlashcardTemplate]]:
        """Batch-load templates for many entries in a single query.

        Returns a dict keyed by ``str(pack_entry_id)``; each list is ordered
        oldest first. Entries without templates are absent from the dict.
        An empty input returns ``{}`` without querying.
        """
        if not pack_entry_ids:
            return {}
        result = await self.db.execute(
            select(WordBankPackFlashcardTemplateEntity)
            .where(WordBankPackFlashcardTemplateEntity.pack_entry_id.in_(pack_entry_ids))
            .order_by(WordBankPackFlashcardTemplateEntity.created_at.asc())
        )
        grouped: dict[str, list[WordBankPackFlashcardTemplate]] = {}
        for entity in result.scalars().all():
            grouped.setdefault(str(entity.pack_entry_id), []).append(
                _template_to_model(entity)
            )
        return grouped

    async def count_entries_for_pack(self, pack_id: uuid.UUID) -> int:
        """Return the number of entries in a pack."""
        result = await self.db.execute(
            select(func.count())
            # Explicit FROM rather than relying on inference from the WHERE clause.
            .select_from(WordBankPackEntryEntity)
            .where(WordBankPackEntryEntity.pack_id == pack_id)
        )
        return result.scalar_one()

    async def get_pack_ids_added_by_user(
        self, user_id: uuid.UUID, language_pair_id: uuid.UUID
    ) -> set[str]:
        """Return ids (as strings) of packs the user has added entries from.

        A pack counts when at least one of its entries is referenced by a
        ``LearnableWordBankEntryEntity`` row of this user and language pair.
        """
        result = await self.db.execute(
            select(WordBankPackEntryEntity.pack_id)
            .join(
                LearnableWordBankEntryEntity,
                LearnableWordBankEntryEntity.pack_entry_id == WordBankPackEntryEntity.id,
            )
            .where(
                LearnableWordBankEntryEntity.user_id == user_id,
                LearnableWordBankEntryEntity.language_pair_id == language_pair_id,
            )
            .distinct()
        )
        return {str(pack_id) for pack_id in result.scalars().all()}