177 lines
6.5 KiB
Python
177 lines
6.5 KiB
Python
import uuid
from datetime import datetime, timezone
from typing import Protocol

from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from ..entities.vocab_entities import LearnableWordBankEntryEntity, UserLanguagePairEntity
from ....domain.models.vocab import LearnableWordBankEntry, UserLanguagePair
|
|
|
|
|
|
class VocabRepository(Protocol):
    """Persistence port for a user's language pairs and word-bank entries.

    Any object exposing these async methods with matching signatures
    satisfies the protocol structurally; no inheritance is required.
    """

    async def get_or_create_language_pair(
        self,
        user_id: uuid.UUID,
        source_lang: str,
        target_lang: str,
    ) -> UserLanguagePair:
        """Return the pair keyed by (user_id, source_lang, target_lang), creating it if absent."""

    async def get_language_pair(
        self,
        language_pair_id: uuid.UUID,
    ) -> UserLanguagePair | None:
        """Return the pair with ``language_pair_id``, or ``None`` if there is none."""

    async def add_entry(
        self,
        user_id: uuid.UUID,
        language_pair_id: uuid.UUID,
        surface_text: str,
        entry_pathway: str,
        is_phrase: bool = False,
        sense_id: uuid.UUID | None = None,
        wordform_id: uuid.UUID | None = None,
        source_article_id: uuid.UUID | None = None,
        disambiguation_status: str = "pending",
    ) -> LearnableWordBankEntry:
        """Store a new word-bank entry for the user and return it.

        ``disambiguation_status`` defaults to ``"pending"``; the optional
        IDs default to ``None``.
        """

    async def get_entries_for_user(
        self,
        user_id: uuid.UUID,
        language_pair_id: uuid.UUID,
    ) -> list[LearnableWordBankEntry]:
        """Return every entry the user has under one language pair."""

    async def set_sense(
        self,
        entry_id: uuid.UUID,
        sense_id: uuid.UUID,
    ) -> LearnableWordBankEntry:
        """Assign ``sense_id`` to the entry ``entry_id`` and return the updated entry."""

    async def get_entry(
        self,
        entry_id: uuid.UUID,
    ) -> LearnableWordBankEntry | None:
        """Return the entry with ``entry_id``, or ``None`` if there is none."""

    async def get_pending_disambiguation(
        self,
        user_id: uuid.UUID,
    ) -> list[LearnableWordBankEntry]:
        """Return the user's entries whose disambiguation status is still pending."""
|
|
|
|
|
|
def _pair_to_model(entity: UserLanguagePairEntity) -> UserLanguagePair:
    """Map a language-pair ORM row onto the ``UserLanguagePair`` domain model.

    The ``id`` and ``user_id`` UUID columns are rendered as strings; the
    language codes are copied through unchanged.
    """
    pair_id, owner_id = (str(raw) for raw in (entity.id, entity.user_id))
    return UserLanguagePair(
        id=pair_id,
        user_id=owner_id,
        source_lang=entity.source_lang,
        target_lang=entity.target_lang,
    )
|
|
|
|
|
|
def _str_or_none(value: object) -> str | None:
    """Return ``str(value)``, or ``None`` when ``value`` is ``None``."""
    # Explicit ``is None`` test: a truthiness check would also turn any falsy
    # non-None value into ``None`` instead of stringifying it.
    return None if value is None else str(value)


def _entry_to_model(entity: LearnableWordBankEntryEntity) -> LearnableWordBankEntry:
    """Map a word-bank ORM row onto the ``LearnableWordBankEntry`` domain model.

    UUID columns are rendered as strings. The optional ID columns
    (``sense_id``, ``wordform_id``, ``source_article_id``) become ``None``
    when unset. All other fields, including ``created_at``, are copied
    through unchanged.
    """
    return LearnableWordBankEntry(
        id=str(entity.id),
        user_id=str(entity.user_id),
        language_pair_id=str(entity.language_pair_id),
        sense_id=_str_or_none(entity.sense_id),
        wordform_id=_str_or_none(entity.wordform_id),
        surface_text=entity.surface_text,
        is_phrase=entity.is_phrase,
        entry_pathway=entity.entry_pathway,
        source_article_id=_str_or_none(entity.source_article_id),
        disambiguation_status=entity.disambiguation_status,
        created_at=entity.created_at,
    )
|
|
|
|
|
|
class PostgresVocabRepository:
    """``AsyncSession``-backed implementation of the vocabulary repository.

    Every query runs on the session passed to the constructor. Rows are
    converted to domain models with ``_pair_to_model`` / ``_entry_to_model``
    before being returned.

    Transaction behavior differs per method:
    ``get_or_create_language_pair`` only flushes, while ``add_entry`` and
    ``set_sense`` commit the session. Those commits also persist anything
    else pending on it.
    """

    # Values written to and filtered on ``disambiguation_status``.
    _STATUS_PENDING = "pending"
    _STATUS_RESOLVED = "resolved"

    def __init__(self, db: AsyncSession) -> None:
        self.db = db

    async def _find_language_pair(
        self, user_id: uuid.UUID, source_lang: str, target_lang: str
    ) -> UserLanguagePairEntity | None:
        """Return the pair row matching all three keys, or ``None``."""
        result = await self.db.execute(
            select(UserLanguagePairEntity).where(
                UserLanguagePairEntity.user_id == user_id,
                UserLanguagePairEntity.source_lang == source_lang,
                UserLanguagePairEntity.target_lang == target_lang,
            )
        )
        return result.scalar_one_or_none()

    async def get_or_create_language_pair(
        self, user_id: uuid.UUID, source_lang: str, target_lang: str
    ) -> UserLanguagePair:
        """Return the user's pair for these languages, creating it if absent.

        A newly created pair is flushed but NOT committed. It becomes durable
        only when the session is committed later, by the caller or by a
        committing method such as ``add_entry``.

        Raises:
            sqlalchemy.exc.IntegrityError: the insert violated a constraint
                and no matching pair exists afterwards.
        """
        entity = await self._find_language_pair(user_id, source_lang, target_lang)
        if entity is not None:
            return _pair_to_model(entity)

        entity = UserLanguagePairEntity(
            user_id=user_id,
            source_lang=source_lang,
            target_lang=target_lang,
        )
        try:
            # Insert inside a SAVEPOINT. If another request created the same
            # pair between the SELECT above and this INSERT, only the
            # savepoint is rolled back and the outer transaction stays usable.
            # Leaving the block releases the savepoint, which flushes the insert.
            async with self.db.begin_nested():
                self.db.add(entity)
        except IntegrityError:
            # NOTE(review): this recovery assumes a unique constraint on
            # (user_id, source_lang, target_lang) — confirm against the schema.
            existing = await self._find_language_pair(user_id, source_lang, target_lang)
            if existing is None:
                # A different constraint failed; surface the original error.
                raise
            entity = existing
        return _pair_to_model(entity)

    async def get_language_pair(self, language_pair_id: uuid.UUID) -> UserLanguagePair | None:
        """Return the pair with ``language_pair_id``, or ``None`` if it does not exist."""
        result = await self.db.execute(
            select(UserLanguagePairEntity).where(UserLanguagePairEntity.id == language_pair_id)
        )
        entity = result.scalar_one_or_none()
        return _pair_to_model(entity) if entity is not None else None

    async def add_entry(
        self,
        user_id: uuid.UUID,
        language_pair_id: uuid.UUID,
        surface_text: str,
        entry_pathway: str,
        is_phrase: bool = False,
        sense_id: uuid.UUID | None = None,
        wordform_id: uuid.UUID | None = None,
        source_article_id: uuid.UUID | None = None,
        disambiguation_status: str = "pending",
    ) -> LearnableWordBankEntry:
        """Insert a word-bank entry, commit the session, and return the stored entry.

        ``created_at`` is stamped here with the current UTC time.
        """
        entity = LearnableWordBankEntryEntity(
            user_id=user_id,
            language_pair_id=language_pair_id,
            surface_text=surface_text,
            entry_pathway=entry_pathway,
            is_phrase=is_phrase,
            sense_id=sense_id,
            wordform_id=wordform_id,
            source_article_id=source_article_id,
            disambiguation_status=disambiguation_status,
            created_at=datetime.now(timezone.utc),
        )
        self.db.add(entity)
        await self.db.commit()
        # Reload the row before mapping. Depending on session configuration,
        # the commit may have expired its attributes, and async sessions
        # cannot lazy-load them implicitly.
        await self.db.refresh(entity)
        return _entry_to_model(entity)

    async def get_entries_for_user(
        self, user_id: uuid.UUID, language_pair_id: uuid.UUID
    ) -> list[LearnableWordBankEntry]:
        """Return the user's entries for one language pair, newest ``created_at`` first."""
        result = await self.db.execute(
            select(LearnableWordBankEntryEntity)
            .where(
                LearnableWordBankEntryEntity.user_id == user_id,
                LearnableWordBankEntryEntity.language_pair_id == language_pair_id,
            )
            .order_by(LearnableWordBankEntryEntity.created_at.desc())
        )
        return [_entry_to_model(e) for e in result.scalars().all()]

    async def set_sense(
        self, entry_id: uuid.UUID, sense_id: uuid.UUID
    ) -> LearnableWordBankEntry:
        """Assign ``sense_id`` to an entry, mark it resolved, commit, and return it.

        Raises:
            sqlalchemy.exc.NoResultFound: no entry has ``entry_id``.
        """
        result = await self.db.execute(
            select(LearnableWordBankEntryEntity).where(
                LearnableWordBankEntryEntity.id == entry_id
            )
        )
        entity = result.scalar_one()
        entity.sense_id = sense_id
        entity.disambiguation_status = self._STATUS_RESOLVED
        await self.db.commit()
        # Same reason as in add_entry: reload before mapping.
        await self.db.refresh(entity)
        return _entry_to_model(entity)

    async def get_entry(self, entry_id: uuid.UUID) -> LearnableWordBankEntry | None:
        """Return the entry with ``entry_id``, or ``None`` if it does not exist."""
        result = await self.db.execute(
            select(LearnableWordBankEntryEntity).where(
                LearnableWordBankEntryEntity.id == entry_id
            )
        )
        entity = result.scalar_one_or_none()
        return _entry_to_model(entity) if entity is not None else None

    async def get_pending_disambiguation(self, user_id: uuid.UUID) -> list[LearnableWordBankEntry]:
        """Return the user's pending entries, newest ``created_at`` first.

        Results span all of the user's language pairs; only ``user_id`` and
        the status are filtered on.
        """
        result = await self.db.execute(
            select(LearnableWordBankEntryEntity)
            .where(
                LearnableWordBankEntryEntity.user_id == user_id,
                LearnableWordBankEntryEntity.disambiguation_status == self._STATUS_PENDING,
            )
            .order_by(LearnableWordBankEntryEntity.created_at.desc())
        )
        return [_entry_to_model(e) for e in result.scalars().all()]
|