442 lines
15 KiB
Python
442 lines
15 KiB
Python
import uuid
|
|
from datetime import datetime, timezone
|
|
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from ....domain.models.adventure import (
|
|
Adventure,
|
|
AdventureEntry,
|
|
AdventureEntryAudio,
|
|
AdventureEntryPossibleChoice,
|
|
AdventureEntryPossibleChoiceDecision,
|
|
AdventureEntryTranslation,
|
|
)
|
|
from ..entities.adventure_entities import (
|
|
AdventureEntity,
|
|
AdventureEntryAudioEntity,
|
|
AdventureEntryEntity,
|
|
AdventureEntryPossibleChoiceDecisionEntity,
|
|
AdventureEntryPossibleChoiceEntity,
|
|
AdventureEntryTranslationEntity,
|
|
)
|
|
|
|
|
|
def _to_adventure(e: AdventureEntity) -> Adventure:
    """Convert an ``AdventureEntity`` ORM row into an ``Adventure`` domain object.

    ``id`` and ``user_id`` are rendered as strings; every other field is
    forwarded unchanged under the same attribute name.
    """
    # Attributes whose ORM name equals the domain field name and whose value
    # needs no conversion.
    passthrough = (
        "status",
        "language",
        "source_language",
        "competencies",
        "max_entry_count",
        "entry_story_text_target_length",
        "title",
        "description",
        "plot_summary",
        "genres",
        "setting",
        "vibes",
        "protagonist",
        "created_at",
        "deleted_at",
    )
    copied = {name: getattr(e, name) for name in passthrough}
    return Adventure(id=str(e.id), user_id=str(e.user_id), **copied)
|
|
|
|
|
|
def _to_entry(e: AdventureEntryEntity) -> AdventureEntry:
    """Convert an ``AdventureEntryEntity`` ORM row into an ``AdventureEntry``.

    UUID columns are rendered as strings. ``generated_from_choice_id`` is
    nullable and stays ``None`` when the column is NULL; all remaining fields
    are copied unchanged.
    """
    source_choice_id = e.generated_from_choice_id
    return AdventureEntry(
        id=str(e.id),
        adventure_id=str(e.adventure_id),
        # Explicit NULL check rather than truthiness: only a missing value maps
        # to None, any stored value is stringified.
        generated_from_choice_id=(
            str(source_choice_id) if source_choice_id is not None else None
        ),
        status=e.status,
        entry_index=e.entry_index,
        story_text=e.story_text,
        gamemaster_notes=e.gamemaster_notes,
        llm_data=e.llm_data,
        story_text_linguistic_data=e.story_text_linguistic_data,
        pipeline_timing=e.pipeline_timing,
        created_at=e.created_at,
    )
|
|
|
|
|
|
def _to_choice(e: AdventureEntryPossibleChoiceEntity) -> AdventureEntryPossibleChoice:
    """Convert a possible-choice ORM row into its domain model.

    The row id and the owning entry id are stringified; ``index``, ``label``
    and ``text`` are copied as-is.
    """
    choice_key = str(e.id)
    owning_entry = str(e.entry_id)
    return AdventureEntryPossibleChoice(
        index=e.index,
        label=e.label,
        text=e.text,
        id=choice_key,
        entry_id=owning_entry,
    )
|
|
|
|
|
|
def _to_decision(e: AdventureEntryPossibleChoiceDecisionEntity) -> AdventureEntryPossibleChoiceDecision:
    """Convert a choice-decision ORM row into its domain model.

    The three identifier columns are stringified; ``created_at`` is passed
    through unchanged.
    """
    raw_ids = {"id": e.id, "choice_id": e.choice_id, "user_id": e.user_id}
    return AdventureEntryPossibleChoiceDecision(
        created_at=e.created_at,
        **{field: str(value) for field, value in raw_ids.items()},
    )
|
|
|
|
|
|
def _to_translation(e: AdventureEntryTranslationEntity) -> AdventureEntryTranslation:
    """Convert an entry-translation ORM row into its domain model.

    Identifier columns are stringified; the component type, target language
    and translated text are copied unchanged.
    """
    translation_key = str(e.id)
    owning_entry = str(e.entry_id)
    return AdventureEntryTranslation(
        id=translation_key,
        entry_id=owning_entry,
        component_type=e.component_type,
        target_language=e.target_language,
        translated_text=e.translated_text,
    )
|
|
|
|
|
|
def _to_audio(e: AdventureEntryAudioEntity) -> AdventureEntryAudio:
    """Convert an entry-audio ORM row into its domain model.

    Identifier columns are stringified; component type, TTS settings and the
    stored file name are copied unchanged.
    """
    tts_settings = {"tts_provider": e.tts_provider, "tts_options": e.tts_options}
    return AdventureEntryAudio(
        id=str(e.id),
        entry_id=str(e.entry_id),
        component_type=e.component_type,
        file_name=e.file_name,
        **tts_settings,
    )
|
|
|
|
|
|
class PostgresAdventureRepository:
    """Adventure persistence backed by a SQLAlchemy ``AsyncSession``.

    Every mutating method commits the session and then refreshes the affected
    row before mapping it to the ``Adventure`` domain model, so the returned
    object reflects what is stored in the database.
    """

    def __init__(self, db: AsyncSession) -> None:
        self.db = db

    async def _load(self, adventure_id: uuid.UUID) -> AdventureEntity:
        """Fetch the adventure row with ``adventure_id`` (deleted or not).

        Raises:
            sqlalchemy.exc.NoResultFound: if no row has that id (raised by
                ``Result.scalar_one``).
        """
        result = await self.db.execute(
            select(AdventureEntity).where(AdventureEntity.id == adventure_id)
        )
        return result.scalar_one()

    async def _commit_and_map(self, entity: AdventureEntity) -> Adventure:
        """Commit the session, reload ``entity`` from the DB and map it."""
        await self.db.commit()
        await self.db.refresh(entity)
        return _to_adventure(entity)

    async def create(
        self,
        user_id: uuid.UUID,
        language: str,
        source_language: str,
        competencies: list[str],
        genres: list[str],
        setting: list[str],
        vibes: list[str],
        protagonist: list[str],
        max_entry_count: int,
        entry_story_text_target_length: dict,
    ) -> Adventure:
        """Insert a new adventure owned by ``user_id`` and return it.

        Only the arguments are set explicitly; other columns (``id``,
        ``status``, ``created_at``, ...) take whatever defaults the entity or
        database defines. NOTE(review): those defaults are not visible here.
        """
        entity = AdventureEntity(
            user_id=user_id,
            language=language,
            source_language=source_language,
            competencies=competencies,
            genres=genres,
            setting=setting,
            vibes=vibes,
            protagonist=protagonist,
            max_entry_count=max_entry_count,
            entry_story_text_target_length=entry_story_text_target_length,
        )
        self.db.add(entity)
        return await self._commit_and_map(entity)

    async def get_by_id(self, adventure_id: uuid.UUID) -> Adventure | None:
        """Return the adventure with ``adventure_id``, or ``None`` if absent.

        Unlike ``list_for_user``, this does NOT exclude soft-deleted
        adventures; callers must inspect ``deleted_at`` if that matters.
        """
        result = await self.db.execute(
            select(AdventureEntity).where(AdventureEntity.id == adventure_id)
        )
        entity = result.scalar_one_or_none()
        return _to_adventure(entity) if entity is not None else None

    async def list_for_user(self, user_id: uuid.UUID) -> list[Adventure]:
        """Return the user's non-deleted adventures, newest first."""
        result = await self.db.execute(
            select(AdventureEntity)
            .where(AdventureEntity.user_id == user_id, AdventureEntity.deleted_at.is_(None))
            .order_by(AdventureEntity.created_at.desc())
        )
        return [_to_adventure(e) for e in result.scalars().all()]

    async def update_status(self, adventure_id: uuid.UUID, status: str) -> Adventure:
        """Set the adventure's ``status`` and return the updated adventure.

        Raises ``sqlalchemy.exc.NoResultFound`` if the adventure does not exist.
        """
        entity = await self._load(adventure_id)
        entity.status = status
        return await self._commit_and_map(entity)

    async def update_title_and_description(
        self, adventure_id: uuid.UUID, title: str, description: str
    ) -> Adventure:
        """Replace the adventure's ``title`` and ``description``.

        Raises ``sqlalchemy.exc.NoResultFound`` if the adventure does not exist.
        """
        entity = await self._load(adventure_id)
        entity.title = title
        entity.description = description
        return await self._commit_and_map(entity)

    async def soft_delete(self, adventure_id: uuid.UUID) -> Adventure:
        """Mark the adventure deleted by stamping ``deleted_at`` with UTC now.

        Idempotent: an adventure that is already soft-deleted keeps its
        original ``deleted_at``, so repeated calls do not move the recorded
        deletion time. The session is committed either way.

        Raises ``sqlalchemy.exc.NoResultFound`` if the adventure does not exist.
        """
        entity = await self._load(adventure_id)
        if entity.deleted_at is None:
            entity.deleted_at = datetime.now(timezone.utc)
        return await self._commit_and_map(entity)
|
|
|
|
|
|
class PostgresAdventureEntryRepository:
    """Persistence for the individual story entries of an adventure.

    Mutating methods commit the ``AsyncSession`` and refresh the row before
    returning it as an ``AdventureEntry`` domain object.
    """

    def __init__(self, db: AsyncSession) -> None:
        self.db = db

    async def _require(self, entry_id: uuid.UUID) -> AdventureEntryEntity:
        """Load exactly one entry row; ``Result.scalar_one`` raises otherwise."""
        stmt = select(AdventureEntryEntity).where(AdventureEntryEntity.id == entry_id)
        return (await self.db.execute(stmt)).scalar_one()

    async def _save(self, row: AdventureEntryEntity) -> AdventureEntry:
        """Commit pending changes, reload ``row`` and map it to the domain model."""
        session = self.db
        await session.commit()
        await session.refresh(row)
        return _to_entry(row)

    async def create(
        self,
        adventure_id: uuid.UUID,
        entry_index: int,
        generated_from_choice_id: uuid.UUID | None,
    ) -> AdventureEntry:
        """Insert a new entry at ``entry_index`` within ``adventure_id``.

        ``generated_from_choice_id`` may be ``None``; otherwise it links the
        entry to the choice it was generated from.
        """
        row = AdventureEntryEntity(
            adventure_id=adventure_id,
            entry_index=entry_index,
            generated_from_choice_id=generated_from_choice_id,
        )
        self.db.add(row)
        return await self._save(row)

    async def get_by_id(self, entry_id: uuid.UUID) -> AdventureEntry | None:
        """Return the entry with ``entry_id``, or ``None`` when it does not exist."""
        stmt = select(AdventureEntryEntity).where(AdventureEntryEntity.id == entry_id)
        row = (await self.db.execute(stmt)).scalar_one_or_none()
        if not row:
            return None
        return _to_entry(row)

    async def list_for_adventure(self, adventure_id: uuid.UUID) -> list[AdventureEntry]:
        """Return all entries of ``adventure_id`` in ascending ``entry_index`` order."""
        stmt = (
            select(AdventureEntryEntity)
            .where(AdventureEntryEntity.adventure_id == adventure_id)
            .order_by(AdventureEntryEntity.entry_index.asc())
        )
        rows = (await self.db.execute(stmt)).scalars().all()
        return list(map(_to_entry, rows))

    async def update_content(
        self,
        entry_id: uuid.UUID,
        story_text: str,
        gamemaster_notes: str,
        llm_data: dict,
        status: str,
    ) -> AdventureEntry:
        """Store generated story content and the new ``status`` on an entry.

        Raises ``sqlalchemy.exc.NoResultFound`` if the entry does not exist.
        """
        row = await self._require(entry_id)
        row.story_text = story_text
        row.gamemaster_notes = gamemaster_notes
        row.llm_data = llm_data
        row.status = status
        return await self._save(row)

    async def update_status(self, entry_id: uuid.UUID, status: str) -> AdventureEntry:
        """Set only the entry's ``status``.

        Raises ``sqlalchemy.exc.NoResultFound`` if the entry does not exist.
        """
        row = await self._require(entry_id)
        row.status = status
        return await self._save(row)

    async def update_linguistic_data(
        self,
        entry_id: uuid.UUID,
        story_text_linguistic_data: dict,
        pipeline_timing: dict,
    ) -> AdventureEntry:
        """Attach linguistic analysis output and pipeline timing to an entry.

        Raises ``sqlalchemy.exc.NoResultFound`` if the entry does not exist.
        """
        row = await self._require(entry_id)
        row.story_text_linguistic_data = story_text_linguistic_data
        row.pipeline_timing = pipeline_timing
        return await self._save(row)

    async def count_complete(self, adventure_id: uuid.UUID) -> int:
        """Return how many entries of ``adventure_id`` have status ``"complete"``."""
        stmt = (
            select(func.count())
            .select_from(AdventureEntryEntity)
            .where(
                AdventureEntryEntity.adventure_id == adventure_id,
                AdventureEntryEntity.status == "complete",
            )
        )
        return (await self.db.execute(stmt)).scalar_one()
|
|
|
|
|
|
class PostgresAdventureEntryChoiceRepository:
    """Persistence for the possible choices offered at the end of an entry."""

    def __init__(self, db: AsyncSession) -> None:
        self.db = db

    async def create_many(
        self,
        entry_id: uuid.UUID,
        choices: list[tuple[int, str, str]],  # (index, label, text)
    ) -> list[AdventureEntryPossibleChoice]:
        """Insert one choice row per ``(index, label, text)`` tuple for ``entry_id``.

        All rows are committed together, then each one is refreshed and the
        domain models are returned in the same order as ``choices``.
        """
        rows = [
            AdventureEntryPossibleChoiceEntity(
                entry_id=entry_id, index=idx, label=lbl, text=body
            )
            for idx, lbl, body in choices
        ]
        self.db.add_all(rows)
        await self.db.commit()
        for row in rows:
            await self.db.refresh(row)
        return list(map(_to_choice, rows))

    async def get_by_id(self, choice_id: uuid.UUID) -> AdventureEntryPossibleChoice | None:
        """Return the choice with ``choice_id``, or ``None`` when it does not exist."""
        stmt = select(AdventureEntryPossibleChoiceEntity).where(
            AdventureEntryPossibleChoiceEntity.id == choice_id
        )
        row = (await self.db.execute(stmt)).scalar_one_or_none()
        if not row:
            return None
        return _to_choice(row)

    async def list_for_entry(self, entry_id: uuid.UUID) -> list[AdventureEntryPossibleChoice]:
        """Return the choices of ``entry_id`` sorted by ascending ``index``."""
        stmt = (
            select(AdventureEntryPossibleChoiceEntity)
            .where(AdventureEntryPossibleChoiceEntity.entry_id == entry_id)
            .order_by(AdventureEntryPossibleChoiceEntity.index.asc())
        )
        rows = (await self.db.execute(stmt)).scalars().all()
        return list(map(_to_choice, rows))
|
|
|
|
|
|
class PostgresAdventureEntryDecisionRepository:
    """Persistence for the choices users actually picked (decisions)."""

    def __init__(self, db: AsyncSession) -> None:
        self.db = db

    async def create(
        self, choice_id: uuid.UUID, user_id: uuid.UUID
    ) -> AdventureEntryPossibleChoiceDecision:
        """Record that ``user_id`` picked ``choice_id``; commits and returns the row."""
        entity = AdventureEntryPossibleChoiceDecisionEntity(
            choice_id=choice_id, user_id=user_id
        )
        self.db.add(entity)
        await self.db.commit()
        await self.db.refresh(entity)
        return _to_decision(entity)

    async def get_for_entry_and_user(
        self, entry_id: uuid.UUID, user_id: uuid.UUID
    ) -> AdventureEntryPossibleChoiceDecision | None:
        """Return the decision ``user_id`` made among the choices of ``entry_id``.

        Returns ``None`` if the user has not decided yet. NOTE(review):
        ``scalar_one_or_none`` raises ``MultipleResultsFound`` if more than one
        decision exists for the same entry and user; this presumably relies on
        a uniqueness guarantee elsewhere — confirm.
        """
        result = await self.db.execute(
            select(AdventureEntryPossibleChoiceDecisionEntity)
            .join(
                AdventureEntryPossibleChoiceEntity,
                AdventureEntryPossibleChoiceDecisionEntity.choice_id
                == AdventureEntryPossibleChoiceEntity.id,
            )
            .where(
                AdventureEntryPossibleChoiceEntity.entry_id == entry_id,
                AdventureEntryPossibleChoiceDecisionEntity.user_id == user_id,
            )
        )
        entity = result.scalar_one_or_none()
        return _to_decision(entity) if entity is not None else None

    async def list_for_adventure_and_user(
        self, adventure_id: uuid.UUID, user_id: uuid.UUID
    ) -> list[AdventureEntryPossibleChoiceDecision]:
        """Return every decision ``user_id`` made across the entries of ``adventure_id``.

        Results are ordered by the ``entry_index`` of the entry each decision
        belongs to, then by the decision's ``created_at``. Without an explicit
        ``ORDER BY`` the database may return rows in any order, which made the
        result order unstable between calls.
        """
        result = await self.db.execute(
            select(AdventureEntryPossibleChoiceDecisionEntity)
            .join(
                AdventureEntryPossibleChoiceEntity,
                AdventureEntryPossibleChoiceDecisionEntity.choice_id
                == AdventureEntryPossibleChoiceEntity.id,
            )
            .join(
                AdventureEntryEntity,
                AdventureEntryPossibleChoiceEntity.entry_id == AdventureEntryEntity.id,
            )
            .where(
                AdventureEntryEntity.adventure_id == adventure_id,
                AdventureEntryPossibleChoiceDecisionEntity.user_id == user_id,
            )
            .order_by(
                AdventureEntryEntity.entry_index.asc(),
                AdventureEntryPossibleChoiceDecisionEntity.created_at.asc(),
            )
        )
        return [_to_decision(e) for e in result.scalars().all()]
|
|
|
|
|
|
class PostgresAdventureEntryTranslationRepository:
    """Persistence for translated text of individual entry components."""

    def __init__(self, db: AsyncSession) -> None:
        self.db = db

    async def create(
        self,
        entry_id: uuid.UUID,
        component_type: str,
        target_language: str,
        translated_text: str,
    ) -> AdventureEntryTranslation:
        """Store a translation of ``component_type`` of ``entry_id`` into ``target_language``.

        Commits the session and returns the refreshed row as a domain model.
        """
        row = AdventureEntryTranslationEntity(
            entry_id=entry_id,
            component_type=component_type,
            target_language=target_language,
            translated_text=translated_text,
        )
        session = self.db
        session.add(row)
        await session.commit()
        await session.refresh(row)
        return _to_translation(row)

    async def get_for_entry(
        self, entry_id: uuid.UUID, component_type: str, target_language: str
    ) -> AdventureEntryTranslation | None:
        """Look up the translation for ``(entry_id, component_type, target_language)``.

        Returns ``None`` when no such translation exists. NOTE(review):
        ``scalar_one_or_none`` raises if several rows match; presumably the
        triple is unique — confirm against the schema.
        """
        stmt = select(AdventureEntryTranslationEntity).where(
            AdventureEntryTranslationEntity.entry_id == entry_id,
            AdventureEntryTranslationEntity.component_type == component_type,
            AdventureEntryTranslationEntity.target_language == target_language,
        )
        row = (await self.db.execute(stmt)).scalar_one_or_none()
        if not row:
            return None
        return _to_translation(row)
|
|
|
|
|
|
class PostgresAdventureEntryAudioRepository:
    """Persistence for synthesized (TTS) audio of individual entry components."""

    def __init__(self, db: AsyncSession) -> None:
        self.db = db

    async def create(
        self,
        entry_id: uuid.UUID,
        component_type: str,
        tts_provider: str,
        tts_options: dict,
        file_name: str,
    ) -> AdventureEntryAudio:
        """Record an audio file generated for ``component_type`` of ``entry_id``.

        Commits the session and returns the refreshed row as a domain model.
        """
        row = AdventureEntryAudioEntity(
            entry_id=entry_id,
            component_type=component_type,
            tts_provider=tts_provider,
            tts_options=tts_options,
            file_name=file_name,
        )
        session = self.db
        session.add(row)
        await session.commit()
        await session.refresh(row)
        return _to_audio(row)

    async def get_for_entry(
        self, entry_id: uuid.UUID, component_type: str
    ) -> AdventureEntryAudio | None:
        """Return the audio row for ``(entry_id, component_type)``, or ``None``.

        NOTE(review): ``scalar_one_or_none`` raises if several rows match;
        presumably the pair is unique — confirm against the schema.
        """
        stmt = select(AdventureEntryAudioEntity).where(
            AdventureEntryAudioEntity.entry_id == entry_id,
            AdventureEntryAudioEntity.component_type == component_type,
        )
        row = (await self.db.execute(stmt)).scalar_one_or_none()
        if not row:
            return None
        return _to_audio(row)
|