language-learning-app/api/app/outbound/postgres/repositories/dictionary_repository.py

214 lines
7.5 KiB
Python

import uuid
from dataclasses import dataclass
from typing import Protocol
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from ....domain.models.dictionary import Lemma, Sense, Wordform
from ..entities.dictionary_entities import (
DictionaryLemmaEntity,
DictionarySenseEntity,
DictionaryWordformEntity,
)
class DictionaryRepository(Protocol):
    """Structural interface for dictionary lookups.

    Implementations resolve lemmas, their senses and their inflected
    wordforms, returning domain models (``Lemma``, ``Sense``, ``Wordform``).
    """

    # NOTE(review): PostgresDictionaryRepository in this module also exposes
    # ``search_senses_by_prefix``, which is not declared on this protocol —
    # confirm whether that omission is intentional.

    async def get_senses_for_headword(
        self, headword: str, language: str
    ) -> list[Sense]:
        """Return the senses of lemmas matching *headword* in *language*."""

    async def get_senses_for_headword_and_pos(
        self, headword: str, language: str, pos_normalised: str
    ) -> list[Sense]:
        """Return senses for *headword* in *language*, limited to one part of speech."""

    async def get_senses_for_lemma(self, lemma_id: uuid.UUID) -> list[Sense]:
        """Return the senses that belong to the lemma identified by *lemma_id*."""

    async def find_senses_by_english_gloss(
        self, text: str, target_lang: str
    ) -> list[Sense]:
        """Return senses in *target_lang* whose gloss matches the English *text*."""

    async def get_sense(self, sense_id: uuid.UUID) -> Sense | None:
        """Return the sense with id *sense_id*, or ``None`` when there is none."""

    async def get_lemma(self, lemma_id: uuid.UUID) -> Lemma | None:
        """Return the lemma with id *lemma_id*, or ``None`` when there is none."""

    async def get_wordforms_by_form(
        self, form: str, language: str
    ) -> list[Wordform]:
        """Return the wordforms in *language* whose surface form matches *form*."""

    async def search_wordforms_by_prefix(
        self, prefix: str, language: str
    ) -> list[Wordform]:
        """Return the wordforms in *language* whose form starts with *prefix*."""

    async def get_wordforms_for_lemma(self, lemma_id: uuid.UUID) -> list[Wordform]:
        """Return the wordforms that belong to the lemma identified by *lemma_id*."""
def _sense_to_model(entity: DictionarySenseEntity) -> Sense:
    """Map a sense ORM entity onto the domain ``Sense`` model.

    Both identifiers are stringified with ``str()``; a falsy ``topics`` or
    ``tags`` value (e.g. ``None``) is replaced by an empty list.
    """
    sense_id = str(entity.id)
    owning_lemma_id = str(entity.lemma_id)
    return Sense(
        id=sense_id,
        lemma_id=owning_lemma_id,
        sense_index=entity.sense_index,
        gloss=entity.gloss,
        topics=entity.topics or [],
        tags=entity.tags or [],
    )
def _lemma_to_model(entity: DictionaryLemmaEntity) -> Lemma:
    """Map a lemma ORM entity onto the domain ``Lemma`` model.

    The identifier is stringified with ``str()``; a falsy ``tags`` value
    (e.g. ``None``) is replaced by an empty list. All other fields are
    copied across unchanged.
    """
    lemma_id = str(entity.id)
    return Lemma(
        id=lemma_id,
        headword=entity.headword,
        language=entity.language,
        pos_raw=entity.pos_raw,
        pos_normalised=entity.pos_normalised,
        gender=entity.gender,
        tags=entity.tags or [],
    )
def _wordform_to_model(entity: DictionaryWordformEntity) -> Wordform:
    """Map a wordform ORM entity onto the domain ``Wordform`` model.

    Both identifiers are stringified with ``str()``; a falsy ``tags`` value
    (e.g. ``None``) is replaced by an empty list.
    """
    wordform_id = str(entity.id)
    owning_lemma_id = str(entity.lemma_id)
    return Wordform(
        id=wordform_id,
        lemma_id=owning_lemma_id,
        form=entity.form,
        tags=entity.tags or [],
    )
# Escape character used for every LIKE/ILIKE pattern built from caller input.
_LIKE_ESCAPE = "\\"


def _escape_like(value: str) -> str:
    """Escape LIKE/ILIKE metacharacters so *value* is matched literally.

    The escape character itself is doubled first, so the escapes inserted
    for ``%`` and ``_`` afterwards are not themselves escaped again.
    """
    return (
        value.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


class PostgresDictionaryRepository:
    """Dictionary lookups executed through an async SQLAlchemy session."""

    def __init__(self, db: AsyncSession) -> None:
        """Store the session that every query in this repository runs on."""
        self.db = db

    async def get_senses_for_headword(
        self, headword: str, language: str
    ) -> list[Sense]:
        """Return senses of lemmas matching *headword* and *language*.

        Both comparisons are exact equality; results are ordered by
        ``sense_index``.
        """
        stmt = (
            select(DictionarySenseEntity)
            .join(
                DictionaryLemmaEntity,
                DictionarySenseEntity.lemma_id == DictionaryLemmaEntity.id,
            )
            .where(
                DictionaryLemmaEntity.headword == headword,
                DictionaryLemmaEntity.language == language,
            )
            .order_by(DictionarySenseEntity.sense_index)
        )
        result = await self.db.execute(stmt)
        return [_sense_to_model(row) for row in result.scalars().all()]

    async def find_senses_by_english_gloss(
        self, text: str, target_lang: str
    ) -> list[Sense]:
        """EN→target direction: find senses whose gloss equals the English text.

        The comparison is a case-insensitive exact match on the gloss column,
        filtered to the target language via the joined lemma row. Results are
        ordered by ``sense_index``.
        """
        # ILIKE gives the case-insensitivity; escaping the input keeps any
        # '%', '_' or escape character in *text* from acting as a wildcard,
        # so the match stays exact as documented.
        gloss_matches = DictionarySenseEntity.gloss.ilike(
            _escape_like(text), escape=_LIKE_ESCAPE
        )
        stmt = (
            select(DictionarySenseEntity)
            .join(
                DictionaryLemmaEntity,
                DictionarySenseEntity.lemma_id == DictionaryLemmaEntity.id,
            )
            .where(
                gloss_matches,
                DictionaryLemmaEntity.language == target_lang,
            )
            .order_by(DictionarySenseEntity.sense_index)
        )
        result = await self.db.execute(stmt)
        return [_sense_to_model(row) for row in result.scalars().all()]

    async def get_sense(self, sense_id: uuid.UUID) -> Sense | None:
        """Return the sense with id *sense_id*, or ``None`` if no row matches."""
        result = await self.db.execute(
            select(DictionarySenseEntity).where(DictionarySenseEntity.id == sense_id)
        )
        entity = result.scalar_one_or_none()
        if entity is None:
            return None
        return _sense_to_model(entity)

    async def get_lemma(self, lemma_id: uuid.UUID) -> Lemma | None:
        """Return the lemma with id *lemma_id*, or ``None`` if no row matches."""
        result = await self.db.execute(
            select(DictionaryLemmaEntity).where(DictionaryLemmaEntity.id == lemma_id)
        )
        entity = result.scalar_one_or_none()
        if entity is None:
            return None
        return _lemma_to_model(entity)

    async def get_senses_for_headword_and_pos(
        self, headword: str, language: str, pos_normalised: str
    ) -> list[Sense]:
        """Return senses for *headword*/*language* limited to one part of speech.

        All three comparisons are exact equality; results are ordered by
        ``sense_index``.
        """
        stmt = (
            select(DictionarySenseEntity)
            .join(
                DictionaryLemmaEntity,
                DictionarySenseEntity.lemma_id == DictionaryLemmaEntity.id,
            )
            .where(
                DictionaryLemmaEntity.headword == headword,
                DictionaryLemmaEntity.language == language,
                DictionaryLemmaEntity.pos_normalised == pos_normalised,
            )
            .order_by(DictionarySenseEntity.sense_index)
        )
        result = await self.db.execute(stmt)
        return [_sense_to_model(row) for row in result.scalars().all()]

    async def get_senses_for_lemma(self, lemma_id: uuid.UUID) -> list[Sense]:
        """Return the senses of lemma *lemma_id*, ordered by ``sense_index``."""
        stmt = (
            select(DictionarySenseEntity)
            .where(DictionarySenseEntity.lemma_id == lemma_id)
            .order_by(DictionarySenseEntity.sense_index)
        )
        result = await self.db.execute(stmt)
        return [_sense_to_model(row) for row in result.scalars().all()]

    async def get_wordforms_by_form(self, form: str, language: str) -> list[Wordform]:
        """Return wordforms whose form equals *form* for lemmas in *language*."""
        stmt = (
            select(DictionaryWordformEntity)
            .join(
                DictionaryLemmaEntity,
                DictionaryWordformEntity.lemma_id == DictionaryLemmaEntity.id,
            )
            .where(
                DictionaryWordformEntity.form == form,
                DictionaryLemmaEntity.language == language,
            )
        )
        result = await self.db.execute(stmt)
        return [_wordform_to_model(row) for row in result.scalars().all()]

    async def search_wordforms_by_prefix(
        self, prefix: str, language: str
    ) -> list[Wordform]:
        """Return wordforms in *language* whose form starts with *prefix*.

        Both the stored form and the prefix are passed through the SQL
        ``unaccent()`` function and compared with ILIKE, so the match is
        case-insensitive and applied to the unaccented text.
        """
        # Escape the caller's text before appending our own trailing '%':
        # otherwise a '%' or '_' inside *prefix* would be treated as a
        # wildcard and could match far more than the intended prefix.
        pattern = func.unaccent(_escape_like(prefix)) + "%"
        stmt = (
            select(DictionaryWordformEntity)
            .join(
                DictionaryLemmaEntity,
                DictionaryWordformEntity.lemma_id == DictionaryLemmaEntity.id,
            )
            .where(
                func.unaccent(DictionaryWordformEntity.form).ilike(
                    pattern, escape=_LIKE_ESCAPE
                ),
                DictionaryLemmaEntity.language == language,
            )
        )
        result = await self.db.execute(stmt)
        return [_wordform_to_model(row) for row in result.scalars().all()]

    async def search_senses_by_prefix(
        self, prefix: str, lang: str
    ) -> list[tuple[Sense, Lemma]]:
        """Return ``(sense, lemma)`` pairs in *lang* whose gloss starts with *prefix*.

        The gloss comparison uses ILIKE and is therefore case-insensitive.
        """
        # A trailing '%' is what makes this a prefix search; without it the
        # ILIKE only matches glosses equal to *prefix*. Appending it is
        # harmless for a caller that already passes "text%".
        # NOTE(review): wildcards inside *prefix* are deliberately left
        # unescaped in case existing callers pass a ready-made pattern —
        # confirm against callers, then escape as the other searches do.
        stmt = (
            select(DictionarySenseEntity, DictionaryLemmaEntity)
            .join(
                DictionaryLemmaEntity,
                DictionarySenseEntity.lemma_id == DictionaryLemmaEntity.id,
            )
            .where(
                DictionarySenseEntity.gloss.ilike(prefix + "%"),
                DictionaryLemmaEntity.language == lang,
            )
        )
        result = await self.db.execute(stmt)
        # Each row carries the two selected entities in select() order.
        return [
            (_sense_to_model(sense_entity), _lemma_to_model(lemma_entity))
            for sense_entity, lemma_entity in result.all()
        ]

    async def get_wordforms_for_lemma(self, lemma_id: uuid.UUID) -> list[Wordform]:
        """Return every wordform that belongs to lemma *lemma_id*."""
        stmt = select(DictionaryWordformEntity).where(
            DictionaryWordformEntity.lemma_id == lemma_id
        )
        result = await self.db.execute(stmt)
        return [_wordform_to_model(row) for row in result.scalars().all()]