# language-learning-app/api/app/outbound/postgres/repositories/pack_repository.py
#
# Listing metadata: 331 lines, 11 KiB, Python.

import uuid
from datetime import datetime, timezone
from typing import Protocol
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from ..entities.pack_entities import (
WordBankPackEntity,
WordBankPackEntryEntity,
WordBankPackFlashcardTemplateEntity,
)
from ..entities.vocab_entities import LearnableWordBankEntryEntity
from ....domain.models.pack import WordBankPack, WordBankPackEntry, WordBankPackFlashcardTemplate
class PackRepository(Protocol):
    """Persistence port for word-bank packs, their entries, and flashcard templates.

    Identifiers are accepted as ``uuid.UUID`` and returned inside domain models
    (``WordBankPack``, ``WordBankPackEntry``, ``WordBankPackFlashcardTemplate``).
    """

    # ---- packs -------------------------------------------------------------

    async def create_pack(
        self,
        name: str,
        name_target: str,
        description: str,
        description_target: str,
        source_lang: str,
        target_lang: str,
        proficiencies: list[str],
    ) -> WordBankPack:
        """Persist a new pack and return it as a domain model."""

    async def update_pack(
        self,
        pack_id: uuid.UUID,
        name: str | None = None,
        name_target: str | None = None,
        description: str | None = None,
        description_target: str | None = None,
        proficiencies: list[str] | None = None,
    ) -> WordBankPack:
        """Change the given fields of an existing pack; ``None`` means "leave unchanged"."""

    async def publish_pack(self, pack_id: uuid.UUID) -> WordBankPack:
        """Mark an existing pack as published and return the updated pack."""

    async def get_pack(self, pack_id: uuid.UUID) -> WordBankPack | None:
        """Return the pack with ``pack_id``, or ``None`` if there is no such pack."""

    async def list_packs(
        self,
        source_lang: str | None = None,
        target_lang: str | None = None,
        published_only: bool = False,
    ) -> list[WordBankPack]:
        """Return packs, optionally filtered by language pair and publication state."""

    # ---- pack entries ------------------------------------------------------

    async def add_entry(
        self,
        pack_id: uuid.UUID,
        sense_id: uuid.UUID | None,
        surface_text: str,
    ) -> WordBankPackEntry:
        """Add a word entry (optionally linked to a sense) to a pack."""

    async def remove_entry(self, entry_id: uuid.UUID) -> None:
        """Delete a pack entry by id."""

    async def get_entries_for_pack(self, pack_id: uuid.UUID) -> list[WordBankPackEntry]:
        """Return every entry belonging to ``pack_id``."""

    # ---- flashcard templates -----------------------------------------------

    async def add_flashcard_template(
        self,
        pack_entry_id: uuid.UUID,
        card_direction: str,
        prompt_text: str,
        answer_text: str,
        prompt_context_text: str | None = None,
        answer_context_text: str | None = None,
    ) -> WordBankPackFlashcardTemplate:
        """Attach a flashcard template to a pack entry."""

    async def remove_flashcard_template(self, template_id: uuid.UUID) -> None:
        """Delete a flashcard template by id."""

    async def get_templates_for_entry(
        self, pack_entry_id: uuid.UUID
    ) -> list[WordBankPackFlashcardTemplate]:
        """Return the templates attached to one pack entry."""

    async def get_templates_for_entries(
        self, pack_entry_ids: list[uuid.UUID]
    ) -> dict[str, list[WordBankPackFlashcardTemplate]]:
        """Return templates for several entries, grouped by the entry id as a string."""

    # ---- aggregates / per-user queries -------------------------------------

    async def count_entries_for_pack(self, pack_id: uuid.UUID) -> int:
        """Return the number of entries in ``pack_id``."""

    async def get_pack_ids_added_by_user(
        self, user_id: uuid.UUID, language_pair_id: uuid.UUID
    ) -> set[str]:
        """Return (as strings) the ids of packs the user has added entries from."""
def _pack_to_model(entity: WordBankPackEntity) -> WordBankPack:
    """Translate a pack ORM row into the ``WordBankPack`` domain object.

    Only the primary key needs conversion (UUID -> ``str``); every other
    column listed below is copied across verbatim.
    """
    passthrough_columns = (
        "name",
        "name_target",
        "description",
        "description_target",
        "source_lang",
        "target_lang",
        "proficiencies",
        "is_published",
        "created_at",
    )
    copied = {column: getattr(entity, column) for column in passthrough_columns}
    return WordBankPack(id=str(entity.id), **copied)
def _entry_to_model(entity: WordBankPackEntryEntity) -> WordBankPackEntry:
    """Map a pack-entry ORM row onto the ``WordBankPackEntry`` domain model.

    ``id`` and ``pack_id`` are stringified. ``sense_id`` may be NULL (entries
    can be added with ``sense_id=None``), so it is stringified only when
    present and otherwise passed through as ``None``.
    """
    return WordBankPackEntry(
        id=str(entity.id),
        pack_id=str(entity.pack_id),
        # Explicit None check states the intent (nullable column) instead of
        # relying on the truthiness of whatever value type the column yields.
        sense_id=str(entity.sense_id) if entity.sense_id is not None else None,
        surface_text=entity.surface_text,
        created_at=entity.created_at,
    )
def _template_to_model(entity: WordBankPackFlashcardTemplateEntity) -> WordBankPackFlashcardTemplate:
    """Map a flashcard-template ORM row onto ``WordBankPackFlashcardTemplate``.

    The two UUID columns (``id`` and ``pack_entry_id``) are rendered as
    strings; direction, prompt/answer text, optional context text and the
    creation timestamp are passed through unchanged.
    """
    template_id = str(entity.id)
    owning_entry_id = str(entity.pack_entry_id)
    return WordBankPackFlashcardTemplate(
        id=template_id,
        pack_entry_id=owning_entry_id,
        card_direction=entity.card_direction,
        prompt_text=entity.prompt_text,
        prompt_context_text=entity.prompt_context_text,
        answer_text=entity.answer_text,
        answer_context_text=entity.answer_context_text,
        created_at=entity.created_at,
    )
class PostgresPackRepository:
    """``PackRepository`` implementation on top of an SQLAlchemy ``AsyncSession``.

    Every mutating method commits the session itself and, where it returns a
    model, refreshes the touched row first, so each call is a self-contained
    unit of work. Ordered listings use ``id`` as a secondary sort key so rows
    that share a ``created_at`` value come back in a stable order.
    """

    def __init__(self, db: AsyncSession) -> None:
        """Keep the async session used for every query and commit."""
        self.db = db

    # ---- packs -------------------------------------------------------------

    async def create_pack(
        self,
        name: str,
        name_target: str,
        description: str,
        description_target: str,
        source_lang: str,
        target_lang: str,
        proficiencies: list[str],
    ) -> WordBankPack:
        """Insert a new pack stamped with the current UTC time and return it."""
        entity = WordBankPackEntity(
            name=name,
            name_target=name_target,
            description=description,
            description_target=description_target,
            source_lang=source_lang,
            target_lang=target_lang,
            proficiencies=proficiencies,
            created_at=datetime.now(timezone.utc),
        )
        self.db.add(entity)
        await self.db.commit()
        # Reload the row so attributes not set above (e.g. id, column
        # defaults such as is_published) are populated after the commit.
        await self.db.refresh(entity)
        return _pack_to_model(entity)

    async def update_pack(
        self,
        pack_id: uuid.UUID,
        name: str | None = None,
        name_target: str | None = None,
        description: str | None = None,
        description_target: str | None = None,
        proficiencies: list[str] | None = None,
    ) -> WordBankPack:
        """Overwrite the non-``None`` fields of an existing pack and return it.

        Raises:
            sqlalchemy.exc.NoResultFound: if no pack has ``pack_id``
                (raised by ``Result.scalar_one()``).
        """
        result = await self.db.execute(
            select(WordBankPackEntity).where(WordBankPackEntity.id == pack_id)
        )
        entity = result.scalar_one()
        if name is not None:
            entity.name = name
        if name_target is not None:
            entity.name_target = name_target
        if description is not None:
            entity.description = description
        if description_target is not None:
            entity.description_target = description_target
        if proficiencies is not None:
            entity.proficiencies = proficiencies
        await self.db.commit()
        await self.db.refresh(entity)
        return _pack_to_model(entity)

    async def publish_pack(self, pack_id: uuid.UUID) -> WordBankPack:
        """Set ``is_published`` on an existing pack and return it.

        Raises:
            sqlalchemy.exc.NoResultFound: if no pack has ``pack_id``.
        """
        result = await self.db.execute(
            select(WordBankPackEntity).where(WordBankPackEntity.id == pack_id)
        )
        entity = result.scalar_one()
        entity.is_published = True
        await self.db.commit()
        await self.db.refresh(entity)
        return _pack_to_model(entity)

    async def get_pack(self, pack_id: uuid.UUID) -> WordBankPack | None:
        """Return the pack with ``pack_id``, or ``None`` when it does not exist."""
        result = await self.db.execute(
            select(WordBankPackEntity).where(WordBankPackEntity.id == pack_id)
        )
        entity = result.scalar_one_or_none()
        return _pack_to_model(entity) if entity is not None else None

    async def list_packs(
        self,
        source_lang: str | None = None,
        target_lang: str | None = None,
        published_only: bool = False,
    ) -> list[WordBankPack]:
        """Return packs newest-first, optionally filtered.

        Language filters are applied only when truthy, so both ``None`` and
        ``""`` mean "no filter". ``published_only`` restricts to packs whose
        ``is_published`` is true.
        """
        query = select(WordBankPackEntity)
        if source_lang:
            query = query.where(WordBankPackEntity.source_lang == source_lang)
        if target_lang:
            query = query.where(WordBankPackEntity.target_lang == target_lang)
        if published_only:
            query = query.where(WordBankPackEntity.is_published.is_(True))
        query = query.order_by(
            WordBankPackEntity.created_at.desc(),
            WordBankPackEntity.id,  # tie-breaker for equal timestamps
        )
        result = await self.db.execute(query)
        return [_pack_to_model(e) for e in result.scalars().all()]

    # ---- pack entries ------------------------------------------------------

    async def add_entry(
        self,
        pack_id: uuid.UUID,
        sense_id: uuid.UUID | None,
        surface_text: str,
    ) -> WordBankPackEntry:
        """Insert an entry into ``pack_id`` (``sense_id`` may be ``None``)."""
        entity = WordBankPackEntryEntity(
            pack_id=pack_id,
            sense_id=sense_id,
            surface_text=surface_text,
            created_at=datetime.now(timezone.utc),
        )
        self.db.add(entity)
        await self.db.commit()
        await self.db.refresh(entity)
        return _entry_to_model(entity)

    async def remove_entry(self, entry_id: uuid.UUID) -> None:
        """Delete the entry with ``entry_id``; silently does nothing if absent."""
        result = await self.db.execute(
            select(WordBankPackEntryEntity).where(WordBankPackEntryEntity.id == entry_id)
        )
        entity = result.scalar_one_or_none()
        if entity is not None:
            # ORM-level delete (not a bulk DELETE) so any mapper-configured
            # cascades on the entity still apply.
            await self.db.delete(entity)
            await self.db.commit()

    async def get_entries_for_pack(self, pack_id: uuid.UUID) -> list[WordBankPackEntry]:
        """Return the entries of ``pack_id``, oldest first."""
        result = await self.db.execute(
            select(WordBankPackEntryEntity)
            .where(WordBankPackEntryEntity.pack_id == pack_id)
            .order_by(
                WordBankPackEntryEntity.created_at.asc(),
                WordBankPackEntryEntity.id,  # tie-breaker for equal timestamps
            )
        )
        return [_entry_to_model(e) for e in result.scalars().all()]

    # ---- flashcard templates -----------------------------------------------

    async def add_flashcard_template(
        self,
        pack_entry_id: uuid.UUID,
        card_direction: str,
        prompt_text: str,
        answer_text: str,
        prompt_context_text: str | None = None,
        answer_context_text: str | None = None,
    ) -> WordBankPackFlashcardTemplate:
        """Insert a flashcard template for ``pack_entry_id`` and return it."""
        entity = WordBankPackFlashcardTemplateEntity(
            pack_entry_id=pack_entry_id,
            card_direction=card_direction,
            prompt_text=prompt_text,
            answer_text=answer_text,
            prompt_context_text=prompt_context_text,
            answer_context_text=answer_context_text,
            created_at=datetime.now(timezone.utc),
        )
        self.db.add(entity)
        await self.db.commit()
        await self.db.refresh(entity)
        return _template_to_model(entity)

    async def remove_flashcard_template(self, template_id: uuid.UUID) -> None:
        """Delete the template with ``template_id``; does nothing if absent."""
        result = await self.db.execute(
            select(WordBankPackFlashcardTemplateEntity).where(
                WordBankPackFlashcardTemplateEntity.id == template_id
            )
        )
        entity = result.scalar_one_or_none()
        if entity is not None:
            await self.db.delete(entity)
            await self.db.commit()

    async def get_templates_for_entry(
        self, pack_entry_id: uuid.UUID
    ) -> list[WordBankPackFlashcardTemplate]:
        """Return the templates of one pack entry, oldest first."""
        result = await self.db.execute(
            select(WordBankPackFlashcardTemplateEntity)
            .where(WordBankPackFlashcardTemplateEntity.pack_entry_id == pack_entry_id)
            .order_by(
                WordBankPackFlashcardTemplateEntity.created_at.asc(),
                WordBankPackFlashcardTemplateEntity.id,
            )
        )
        return [_template_to_model(e) for e in result.scalars().all()]

    async def get_templates_for_entries(
        self, pack_entry_ids: list[uuid.UUID]
    ) -> dict[str, list[WordBankPackFlashcardTemplate]]:
        """Fetch templates for many entries in one query, grouped per entry.

        Returns a dict keyed by ``str(pack_entry_id)``; each list is oldest
        first. Entries without templates have no key in the result. An empty
        input short-circuits to ``{}`` without touching the database.
        """
        if not pack_entry_ids:
            return {}
        result = await self.db.execute(
            select(WordBankPackFlashcardTemplateEntity)
            .where(WordBankPackFlashcardTemplateEntity.pack_entry_id.in_(pack_entry_ids))
            .order_by(
                WordBankPackFlashcardTemplateEntity.created_at.asc(),
                WordBankPackFlashcardTemplateEntity.id,
            )
        )
        grouped: dict[str, list[WordBankPackFlashcardTemplate]] = {}
        for entity in result.scalars().all():
            grouped.setdefault(str(entity.pack_entry_id), []).append(
                _template_to_model(entity)
            )
        return grouped

    # ---- aggregates / per-user queries -------------------------------------

    async def count_entries_for_pack(self, pack_id: uuid.UUID) -> int:
        """Return the number of entries belonging to ``pack_id``."""
        result = await self.db.execute(
            select(func.count())
            # State the counted table explicitly instead of relying on the
            # FROM clause being inferred from the WHERE criterion.
            .select_from(WordBankPackEntryEntity)
            .where(WordBankPackEntryEntity.pack_id == pack_id)
        )
        return result.scalar_one()

    async def get_pack_ids_added_by_user(
        self, user_id: uuid.UUID, language_pair_id: uuid.UUID
    ) -> set[str]:
        """Return ids (as strings) of packs the user has learnable entries from.

        A pack counts when at least one ``LearnableWordBankEntryEntity`` for
        this user and language pair references one of the pack's entries.
        """
        result = await self.db.execute(
            select(WordBankPackEntryEntity.pack_id)
            .join(
                LearnableWordBankEntryEntity,
                LearnableWordBankEntryEntity.pack_entry_id == WordBankPackEntryEntity.id,
            )
            .where(
                LearnableWordBankEntryEntity.user_id == user_id,
                LearnableWordBankEntryEntity.language_pair_id == language_pair_id,
            )
            .distinct()
        )
        return {str(pack_id) for pack_id in result.scalars().all()}