language-learning-app/api/app/outbound/postgres/repositories/vocab_repository.py

198 lines
7.4 KiB
Python

import uuid
from datetime import datetime, timezone
from typing import Protocol
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from ..entities.vocab_entities import LearnableWordBankEntryEntity, UserLanguagePairEntity
from ....domain.models.vocab import LearnableWordBankEntry, UserLanguagePair
class VocabRepository(Protocol):
    """Storage port for users' language pairs and learnable word-bank entries.

    Identifiers are accepted as ``uuid.UUID``; results are returned as the
    domain models ``UserLanguagePair`` and ``LearnableWordBankEntry``.
    """

    async def get_or_create_language_pair(
        self, user_id: uuid.UUID, source_lang: str, target_lang: str
    ) -> UserLanguagePair:
        """Return the user's ``source_lang`` -> ``target_lang`` pair, creating it if needed."""

    async def get_language_pair(
        self, language_pair_id: uuid.UUID
    ) -> UserLanguagePair | None:
        """Return the pair with ``language_pair_id``, or ``None`` if there is none."""

    async def add_entry(
        self,
        user_id: uuid.UUID,
        language_pair_id: uuid.UUID,
        surface_text: str,
        entry_pathway: str,
        is_phrase: bool = False,
        sense_id: uuid.UUID | None = None,
        wordform_id: uuid.UUID | None = None,
        source_article_id: uuid.UUID | None = None,
        disambiguation_status: str = "pending",
        pack_entry_id: uuid.UUID | None = None,
    ) -> LearnableWordBankEntry:
        """Persist a new word-bank entry for the user in the given pair and return it."""

    async def get_sense_ids_for_user_in_pair(
        self, user_id: uuid.UUID, language_pair_id: uuid.UUID
    ) -> set[str]:
        """Return, as strings, the sense ids linked to the user's entries in the pair."""

    async def get_entries_for_user(
        self, user_id: uuid.UUID, language_pair_id: uuid.UUID
    ) -> list[LearnableWordBankEntry]:
        """Return the user's word-bank entries for the given language pair."""

    async def set_sense(
        self, entry_id: uuid.UUID, sense_id: uuid.UUID
    ) -> LearnableWordBankEntry:
        """Link the entry to ``sense_id`` and return the updated entry."""

    async def get_entry(self, entry_id: uuid.UUID) -> LearnableWordBankEntry | None:
        """Return the entry with ``entry_id``, or ``None`` if it does not exist."""

    async def get_pending_disambiguation(
        self, user_id: uuid.UUID
    ) -> list[LearnableWordBankEntry]:
        """Return the user's entries whose disambiguation status is still pending."""
def _pair_to_model(entity: UserLanguagePairEntity) -> UserLanguagePair:
    """Convert a ``UserLanguagePairEntity`` row into the ``UserLanguagePair`` domain model.

    Both identifiers are rendered with ``str()``; the language codes are copied as-is.
    """
    pair_id = str(entity.id)
    owner_id = str(entity.user_id)
    return UserLanguagePair(
        id=pair_id,
        user_id=owner_id,
        source_lang=entity.source_lang,
        target_lang=entity.target_lang,
    )
def _optional_str(value):
    """Return ``str(value)`` for a truthy value, otherwise ``None``."""
    return str(value) if value else None


def _entry_to_model(entity: LearnableWordBankEntryEntity) -> LearnableWordBankEntry:
    """Convert a ``LearnableWordBankEntryEntity`` row into the domain model.

    Required identifiers are rendered with ``str()``. Optional references
    (sense, wordform, source article, pack entry) become ``None`` when the
    column value is falsy, and are stringified otherwise.
    """
    return LearnableWordBankEntry(
        id=str(entity.id),
        user_id=str(entity.user_id),
        language_pair_id=str(entity.language_pair_id),
        sense_id=_optional_str(entity.sense_id),
        wordform_id=_optional_str(entity.wordform_id),
        surface_text=entity.surface_text,
        is_phrase=entity.is_phrase,
        entry_pathway=entity.entry_pathway,
        source_article_id=_optional_str(entity.source_article_id),
        disambiguation_status=entity.disambiguation_status,
        pack_entry_id=_optional_str(entity.pack_entry_id),
        created_at=entity.created_at,
    )
class PostgresVocabRepository:
    """``VocabRepository`` implementation backed by a SQLAlchemy ``AsyncSession``.

    Write methods end the unit of work differently: ``add_entry`` and
    ``set_sense`` commit the session, whereas ``get_or_create_language_pair``
    only flushes and leaves committing to the caller (or to a later commit on
    the same session).
    """

    def __init__(self, db: AsyncSession) -> None:
        self.db = db

    async def get_or_create_language_pair(
        self, user_id: uuid.UUID, source_lang: str, target_lang: str
    ) -> UserLanguagePair:
        """Return the user's ``source_lang`` -> ``target_lang`` pair, inserting it if missing.

        A newly created row is flushed but NOT committed; it becomes durable
        only when the session is committed later.
        """
        # NOTE(review): check-then-insert is not atomic. Two concurrent callers
        # can both miss the SELECT and both INSERT. Depending on whether the
        # schema has a unique constraint on (user_id, source_lang, target_lang),
        # that either raises IntegrityError or leaves duplicates that make
        # scalar_one_or_none() raise MultipleResultsFound later. Confirm the
        # schema.
        result = await self.db.execute(
            select(UserLanguagePairEntity).where(
                UserLanguagePairEntity.user_id == user_id,
                UserLanguagePairEntity.source_lang == source_lang,
                UserLanguagePairEntity.target_lang == target_lang,
            )
        )
        entity = result.scalar_one_or_none()
        if entity is None:
            entity = UserLanguagePairEntity(
                user_id=user_id,
                source_lang=source_lang,
                target_lang=target_lang,
            )
            self.db.add(entity)
            # Emit the INSERT now, presumably so generated columns such as the
            # id are populated before mapping to the domain model.
            await self.db.flush()
        return _pair_to_model(entity)

    async def get_language_pair(self, language_pair_id: uuid.UUID) -> UserLanguagePair | None:
        """Return the language pair with ``language_pair_id``, or ``None`` if absent."""
        result = await self.db.execute(
            select(UserLanguagePairEntity).where(UserLanguagePairEntity.id == language_pair_id)
        )
        entity = result.scalar_one_or_none()
        return _pair_to_model(entity) if entity else None

    async def add_entry(
        self,
        user_id: uuid.UUID,
        language_pair_id: uuid.UUID,
        surface_text: str,
        entry_pathway: str,
        is_phrase: bool = False,
        sense_id: uuid.UUID | None = None,
        wordform_id: uuid.UUID | None = None,
        source_article_id: uuid.UUID | None = None,
        disambiguation_status: str = "pending",
        pack_entry_id: uuid.UUID | None = None,
    ) -> LearnableWordBankEntry:
        """Insert a new word-bank entry, commit the session, and return the stored entry.

        ``created_at`` is stamped with the current UTC time.
        """
        entity = LearnableWordBankEntryEntity(
            user_id=user_id,
            language_pair_id=language_pair_id,
            surface_text=surface_text,
            entry_pathway=entry_pathway,
            is_phrase=is_phrase,
            sense_id=sense_id,
            wordform_id=wordform_id,
            source_article_id=source_article_id,
            disambiguation_status=disambiguation_status,
            pack_entry_id=pack_entry_id,
            created_at=datetime.now(timezone.utc),
        )
        self.db.add(entity)
        await self.db.commit()
        # Reload attributes after commit; with expire_on_commit enabled they
        # would otherwise be expired and lazily loaded during mapping.
        await self.db.refresh(entity)
        return _entry_to_model(entity)

    async def get_sense_ids_for_user_in_pair(
        self, user_id: uuid.UUID, language_pair_id: uuid.UUID
    ) -> set[str]:
        """Return the distinct non-null ``sense_id`` values (as strings) of the user's entries in the pair."""
        result = await self.db.execute(
            select(LearnableWordBankEntryEntity.sense_id).where(
                LearnableWordBankEntryEntity.user_id == user_id,
                LearnableWordBankEntryEntity.language_pair_id == language_pair_id,
                LearnableWordBankEntryEntity.sense_id.is_not(None),
            )
        )
        return {str(sense_id) for sense_id in result.scalars().all()}

    async def get_entries_for_user(
        self, user_id: uuid.UUID, language_pair_id: uuid.UUID
    ) -> list[LearnableWordBankEntry]:
        """Return all of the user's entries in the given pair, newest ``created_at`` first."""
        result = await self.db.execute(
            select(LearnableWordBankEntryEntity)
            .where(
                LearnableWordBankEntryEntity.user_id == user_id,
                LearnableWordBankEntryEntity.language_pair_id == language_pair_id,
            )
            .order_by(LearnableWordBankEntryEntity.created_at.desc())
        )
        return [_entry_to_model(e) for e in result.scalars().all()]

    async def set_sense(
        self, entry_id: uuid.UUID, sense_id: uuid.UUID
    ) -> LearnableWordBankEntry:
        """Assign ``sense_id`` to the entry, mark it ``"resolved"``, commit, and return it.

        Raises:
            sqlalchemy.exc.NoResultFound: no entry has ``entry_id`` (from ``scalar_one()``).
        """
        result = await self.db.execute(
            select(LearnableWordBankEntryEntity).where(
                LearnableWordBankEntryEntity.id == entry_id
            )
        )
        entity = result.scalar_one()
        entity.sense_id = sense_id
        entity.disambiguation_status = "resolved"
        await self.db.commit()
        # Reload post-commit state before mapping (see add_entry).
        await self.db.refresh(entity)
        return _entry_to_model(entity)

    async def get_entry(self, entry_id: uuid.UUID) -> LearnableWordBankEntry | None:
        """Return the entry with ``entry_id``, or ``None`` if it does not exist."""
        result = await self.db.execute(
            select(LearnableWordBankEntryEntity).where(
                LearnableWordBankEntryEntity.id == entry_id
            )
        )
        entity = result.scalar_one_or_none()
        return _entry_to_model(entity) if entity else None

    async def get_pending_disambiguation(self, user_id: uuid.UUID) -> list[LearnableWordBankEntry]:
        """Return the user's ``"pending"`` entries across ALL language pairs, newest first."""
        result = await self.db.execute(
            select(LearnableWordBankEntryEntity)
            .where(
                LearnableWordBankEntryEntity.user_id == user_id,
                LearnableWordBankEntryEntity.disambiguation_status == "pending",
            )
            .order_by(LearnableWordBankEntryEntity.created_at.desc())
        )
        return [_entry_to_model(e) for e in result.scalars().all()]