# language-learning-app/api/app/outbound/postgres/repositories/pack_repository.py

import uuid
from datetime import datetime, timezone
from typing import Protocol
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from ..entities.pack_entities import (
WordBankPackEntity,
WordBankPackEntryEntity,
WordBankPackFlashcardTemplateEntity,
)
from ..entities.vocab_entities import LearnableWordBankEntryEntity
from ....domain.models.pack import WordBankPack, WordBankPackEntry, WordBankPackFlashcardTemplate
class PackRepository(Protocol):
    """Persistence port for word-bank packs, their entries and flashcard templates.

    Identifiers are accepted as ``uuid.UUID``. The Postgres implementation in
    this module returns domain models whose ids are strings.
    """

    async def create_pack(
        self,
        name: str,
        name_target: str,
        description: str,
        description_target: str,
        source_lang: str,
        target_lang: str,
        proficiencies: list[str],
    ) -> WordBankPack:
        """Persist a new pack built from the given fields and return it."""

    async def update_pack(
        self,
        pack_id: uuid.UUID,
        name: str | None = None,
        name_target: str | None = None,
        description: str | None = None,
        description_target: str | None = None,
        proficiencies: list[str] | None = None,
    ) -> WordBankPack:
        """Overwrite only the fields passed as non-``None``; return the updated pack."""

    async def publish_pack(self, pack_id: uuid.UUID) -> WordBankPack:
        """Mark the pack as published and return it."""

    async def get_pack(self, pack_id: uuid.UUID) -> WordBankPack | None:
        """Return the pack with ``pack_id``, or ``None`` if there is none."""

    async def list_packs(
        self,
        source_lang: str | None = None,
        target_lang: str | None = None,
        published_only: bool = False,
    ) -> list[WordBankPack]:
        """Return packs, optionally filtered by language and publication state."""

    async def add_entry(
        self,
        pack_id: uuid.UUID,
        sense_id: uuid.UUID | None,
        surface_text: str,
    ) -> WordBankPackEntry:
        """Attach a new word entry (optionally linked to a sense) to a pack."""

    async def remove_entry(self, entry_id: uuid.UUID) -> None:
        """Delete the pack entry with ``entry_id``, if it exists."""

    async def get_entries_for_pack(self, pack_id: uuid.UUID) -> list[WordBankPackEntry]:
        """Return every entry belonging to ``pack_id``."""

    async def add_flashcard_template(
        self,
        pack_entry_id: uuid.UUID,
        prompt_text: str,
        answer_text: str,
        prompt_context_text: str | None = None,
        answer_context_text: str | None = None,
    ) -> WordBankPackFlashcardTemplate:
        """Attach a new flashcard template to a pack entry."""

    async def remove_flashcard_template(self, template_id: uuid.UUID) -> None:
        """Delete the flashcard template with ``template_id``, if it exists."""

    async def get_templates_for_entry(
        self, pack_entry_id: uuid.UUID
    ) -> list[WordBankPackFlashcardTemplate]:
        """Return the flashcard templates of a single pack entry."""

    async def get_templates_for_entries(
        self, pack_entry_ids: list[uuid.UUID]
    ) -> dict[str, list[WordBankPackFlashcardTemplate]]:
        """Return templates for many entries, grouped by the entry id's string form."""

    async def count_entries_for_pack(self, pack_id: uuid.UUID) -> int:
        """Return how many entries ``pack_id`` contains."""

    async def get_pack_ids_added_by_user(
        self, user_id: uuid.UUID, language_pair_id: uuid.UUID
    ) -> set[str]:
        """Return ids (as strings) of packs the user has added words from."""
def _pack_to_model(entity: WordBankPackEntity) -> WordBankPack:
    """Convert a ``WordBankPackEntity`` row into a ``WordBankPack`` domain model.

    The entity's primary key is exposed as a string. ``proficiencies`` is
    copied into a fresh list so the returned model does not share a mutable
    collection with the ORM-managed entity. Otherwise, mutating the model would
    silently change the entity's in-memory state. If the column uses mutation
    tracking, that change could also be flushed on a later commit.
    """
    raw_proficiencies = entity.proficiencies
    return WordBankPack(
        id=str(entity.id),
        name=entity.name,
        name_target=entity.name_target,
        description=entity.description,
        description_target=entity.description_target,
        source_lang=entity.source_lang,
        target_lang=entity.target_lang,
        # Defensive copy. ``None`` is passed through unchanged in case the
        # column is nullable (NOTE(review): confirm against the entity schema).
        proficiencies=list(raw_proficiencies) if raw_proficiencies is not None else None,
        is_published=entity.is_published,
        created_at=entity.created_at,
    )
def _entry_to_model(entity: WordBankPackEntryEntity) -> WordBankPackEntry:
    """Map a ``WordBankPackEntryEntity`` row onto the ``WordBankPackEntry`` domain model.

    UUID columns are stringified. A falsy ``sense_id`` (e.g. ``None``, for an
    entry not linked to a sense) is returned as ``None``.
    """
    raw_sense_id = entity.sense_id
    if raw_sense_id:
        sense_id = str(raw_sense_id)
    else:
        sense_id = None
    return WordBankPackEntry(
        id=str(entity.id),
        pack_id=str(entity.pack_id),
        sense_id=sense_id,
        surface_text=entity.surface_text,
        created_at=entity.created_at,
    )
def _template_to_model(entity: WordBankPackFlashcardTemplateEntity) -> WordBankPackFlashcardTemplate:
    """Map a flashcard-template row onto the ``WordBankPackFlashcardTemplate`` domain model.

    The template id and its owning pack-entry id are converted to strings.
    All text and timestamp fields are copied through unchanged.
    """
    template_id = str(entity.id)
    owning_entry_id = str(entity.pack_entry_id)
    return WordBankPackFlashcardTemplate(
        id=template_id,
        pack_entry_id=owning_entry_id,
        prompt_text=entity.prompt_text,
        prompt_context_text=entity.prompt_context_text,
        answer_text=entity.answer_text,
        answer_context_text=entity.answer_context_text,
        created_at=entity.created_at,
    )
class PostgresPackRepository:
    """SQLAlchemy (asyncio) implementation of ``PackRepository``.

    Every mutating method commits the session it was given before returning.
    Note that this also commits any other pending work on that session. All
    results are converted to domain models before being returned.
    """

    def __init__(self, db: AsyncSession) -> None:
        # The session is supplied by the caller; this class never closes it.
        self.db = db

    async def _commit_and_refresh(self, entity: object) -> None:
        """Commit the session, then reload ``entity``'s attributes from the database."""
        await self.db.commit()
        # Reload explicitly after the commit. Under SQLAlchemy's default
        # expire_on_commit=True the commit expires loaded attributes. Any
        # database-generated values are also picked up here, so mapping the
        # entity afterwards needs no implicit lazy load.
        await self.db.refresh(entity)

    async def _load_pack(self, pack_id: uuid.UUID) -> WordBankPackEntity:
        """Fetch the pack row for ``pack_id``.

        Raises:
            sqlalchemy.exc.NoResultFound: via ``Result.scalar_one()`` when no
                pack has that id.
        """
        result = await self.db.execute(
            select(WordBankPackEntity).where(WordBankPackEntity.id == pack_id)
        )
        return result.scalar_one()

    async def create_pack(
        self,
        name: str,
        name_target: str,
        description: str,
        description_target: str,
        source_lang: str,
        target_lang: str,
        proficiencies: list[str],
    ) -> WordBankPack:
        """Insert a new pack stamped with the current UTC time and return it.

        ``is_published`` is not set here; it is left to the entity/column default.
        """
        entity = WordBankPackEntity(
            name=name,
            name_target=name_target,
            description=description,
            description_target=description_target,
            source_lang=source_lang,
            target_lang=target_lang,
            proficiencies=proficiencies,
            created_at=datetime.now(timezone.utc),
        )
        self.db.add(entity)
        await self._commit_and_refresh(entity)
        return _pack_to_model(entity)

    async def update_pack(
        self,
        pack_id: uuid.UUID,
        name: str | None = None,
        name_target: str | None = None,
        description: str | None = None,
        description_target: str | None = None,
        proficiencies: list[str] | None = None,
    ) -> WordBankPack:
        """Overwrite only the fields passed as non-``None`` and return the pack.

        Empty strings and empty lists are applied; only ``None`` means "keep".

        Raises:
            sqlalchemy.exc.NoResultFound: if no pack has ``pack_id``.
        """
        entity = await self._load_pack(pack_id)
        if name is not None:
            entity.name = name
        if name_target is not None:
            entity.name_target = name_target
        if description is not None:
            entity.description = description
        if description_target is not None:
            entity.description_target = description_target
        if proficiencies is not None:
            entity.proficiencies = proficiencies
        await self._commit_and_refresh(entity)
        return _pack_to_model(entity)

    async def publish_pack(self, pack_id: uuid.UUID) -> WordBankPack:
        """Set ``is_published`` on the pack and return it.

        Raises:
            sqlalchemy.exc.NoResultFound: if no pack has ``pack_id``.
        """
        entity = await self._load_pack(pack_id)
        entity.is_published = True
        await self._commit_and_refresh(entity)
        return _pack_to_model(entity)

    async def get_pack(self, pack_id: uuid.UUID) -> WordBankPack | None:
        """Return the pack with ``pack_id``, or ``None`` when it does not exist."""
        result = await self.db.execute(
            select(WordBankPackEntity).where(WordBankPackEntity.id == pack_id)
        )
        entity = result.scalar_one_or_none()
        return _pack_to_model(entity) if entity else None

    async def list_packs(
        self,
        source_lang: str | None = None,
        target_lang: str | None = None,
        published_only: bool = False,
    ) -> list[WordBankPack]:
        """Return packs, newest first.

        Filters are combined with AND. A falsy ``source_lang``/``target_lang``
        (``None`` or ``""``) disables that filter. ``published_only`` keeps
        only packs whose ``is_published`` is true.
        """
        query = select(WordBankPackEntity)
        if source_lang:
            query = query.where(WordBankPackEntity.source_lang == source_lang)
        if target_lang:
            query = query.where(WordBankPackEntity.target_lang == target_lang)
        if published_only:
            query = query.where(WordBankPackEntity.is_published.is_(True))
        query = query.order_by(WordBankPackEntity.created_at.desc())
        result = await self.db.execute(query)
        return [_pack_to_model(e) for e in result.scalars().all()]

    async def add_entry(
        self,
        pack_id: uuid.UUID,
        sense_id: uuid.UUID | None,
        surface_text: str,
    ) -> WordBankPackEntry:
        """Insert a new entry into ``pack_id`` (``sense_id`` may be ``None``) and return it."""
        entity = WordBankPackEntryEntity(
            pack_id=pack_id,
            sense_id=sense_id,
            surface_text=surface_text,
            created_at=datetime.now(timezone.utc),
        )
        self.db.add(entity)
        await self._commit_and_refresh(entity)
        return _entry_to_model(entity)

    async def remove_entry(self, entry_id: uuid.UUID) -> None:
        """Delete the entry with ``entry_id``; a missing entry is a silent no-op.

        The session is committed in both cases.
        """
        result = await self.db.execute(
            select(WordBankPackEntryEntity).where(WordBankPackEntryEntity.id == entry_id)
        )
        entity = result.scalar_one_or_none()
        if entity:
            await self.db.delete(entity)
        await self.db.commit()

    async def get_entries_for_pack(self, pack_id: uuid.UUID) -> list[WordBankPackEntry]:
        """Return the entries of ``pack_id``, oldest first."""
        result = await self.db.execute(
            select(WordBankPackEntryEntity)
            .where(WordBankPackEntryEntity.pack_id == pack_id)
            .order_by(WordBankPackEntryEntity.created_at.asc())
        )
        return [_entry_to_model(e) for e in result.scalars().all()]

    async def add_flashcard_template(
        self,
        pack_entry_id: uuid.UUID,
        prompt_text: str,
        answer_text: str,
        prompt_context_text: str | None = None,
        answer_context_text: str | None = None,
    ) -> WordBankPackFlashcardTemplate:
        """Insert a flashcard template for ``pack_entry_id`` and return it."""
        entity = WordBankPackFlashcardTemplateEntity(
            pack_entry_id=pack_entry_id,
            prompt_text=prompt_text,
            answer_text=answer_text,
            prompt_context_text=prompt_context_text,
            answer_context_text=answer_context_text,
            created_at=datetime.now(timezone.utc),
        )
        self.db.add(entity)
        await self._commit_and_refresh(entity)
        return _template_to_model(entity)

    async def remove_flashcard_template(self, template_id: uuid.UUID) -> None:
        """Delete the template with ``template_id``; a missing template is a silent no-op.

        The session is committed in both cases.
        """
        result = await self.db.execute(
            select(WordBankPackFlashcardTemplateEntity).where(
                WordBankPackFlashcardTemplateEntity.id == template_id
            )
        )
        entity = result.scalar_one_or_none()
        if entity:
            await self.db.delete(entity)
        await self.db.commit()

    async def get_templates_for_entry(
        self, pack_entry_id: uuid.UUID
    ) -> list[WordBankPackFlashcardTemplate]:
        """Return the templates of one pack entry, oldest first."""
        result = await self.db.execute(
            select(WordBankPackFlashcardTemplateEntity)
            .where(WordBankPackFlashcardTemplateEntity.pack_entry_id == pack_entry_id)
            .order_by(WordBankPackFlashcardTemplateEntity.created_at.asc())
        )
        return [_template_to_model(e) for e in result.scalars().all()]

    async def get_templates_for_entries(
        self, pack_entry_ids: list[uuid.UUID]
    ) -> dict[str, list[WordBankPackFlashcardTemplate]]:
        """Fetch templates for many entries in one query, grouped by entry.

        Keys are ``str(pack_entry_id)``. Each list is oldest first. Entries
        with no templates are absent from the result. An empty input
        short-circuits to ``{}`` without querying.
        """
        if not pack_entry_ids:
            return {}
        result = await self.db.execute(
            select(WordBankPackFlashcardTemplateEntity)
            .where(WordBankPackFlashcardTemplateEntity.pack_entry_id.in_(pack_entry_ids))
            .order_by(WordBankPackFlashcardTemplateEntity.created_at.asc())
        )
        grouped: dict[str, list[WordBankPackFlashcardTemplate]] = {}
        for entity in result.scalars().all():
            grouped.setdefault(str(entity.pack_entry_id), []).append(
                _template_to_model(entity)
            )
        return grouped

    async def count_entries_for_pack(self, pack_id: uuid.UUID) -> int:
        """Return the number of entries in ``pack_id`` (0 for an unknown pack)."""
        result = await self.db.execute(
            # Name the FROM explicitly instead of relying on SQLAlchemy to
            # infer it from the WHERE clause. This is the documented count idiom.
            select(func.count())
            .select_from(WordBankPackEntryEntity)
            .where(WordBankPackEntryEntity.pack_id == pack_id)
        )
        return result.scalar_one()

    async def get_pack_ids_added_by_user(
        self, user_id: uuid.UUID, language_pair_id: uuid.UUID
    ) -> set[str]:
        """Return ids (as strings) of packs the user has added words from.

        A pack counts if at least one of the user's learnable word-bank
        entries for ``language_pair_id`` references one of the pack's entries.
        """
        result = await self.db.execute(
            select(WordBankPackEntryEntity.pack_id)
            .join(
                LearnableWordBankEntryEntity,
                LearnableWordBankEntryEntity.pack_entry_id == WordBankPackEntryEntity.id,
            )
            .where(
                LearnableWordBankEntryEntity.user_id == user_id,
                LearnableWordBankEntryEntity.language_pair_id == language_pair_id,
            )
            .distinct()
        )
        return {str(pack_id) for pack_id in result.scalars().all()}