fix: [api] Pass along length of response from request

This commit is contained in:
wilson 2026-05-03 22:38:17 +01:00
parent 2d5933ff59
commit ac73bd1a04
2 changed files with 87 additions and 22 deletions

View file

@ -1,6 +1,7 @@
import logging import logging
import uuid import uuid
from ...languages import SUPPORTED_LANGUAGES
from ...outbound.anthropic.adventure_prompts import ( from ...outbound.anthropic.adventure_prompts import (
build_conversation_messages, build_conversation_messages,
build_entry_system_prompt, build_entry_system_prompt,
@ -21,8 +22,11 @@ from ...outbound.postgres.repositories.adventure_repository import (
PostgresAdventureRepository, PostgresAdventureRepository,
) )
from ...storage import upload_audio from ...storage import upload_audio
from ..models.adventure import Adventure, AdventureEntry, AdventureEntryPossibleChoiceDecision from ..models.adventure import (
from ...languages import SUPPORTED_LANGUAGES Adventure,
AdventureEntry,
AdventureEntryPossibleChoiceDecision,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -60,6 +64,7 @@ class AdventureService:
setting: list[str], setting: list[str],
vibes: list[str], vibes: list[str],
protagonist: list[str], protagonist: list[str],
entry_word_count_range: list[int],
max_entry_count: int = 6, max_entry_count: int = 6,
) -> tuple[Adventure, AdventureEntry]: ) -> tuple[Adventure, AdventureEntry]:
"""Creates the adventure and a placeholder for the first entry. """Creates the adventure and a placeholder for the first entry.
@ -76,7 +81,10 @@ class AdventureService:
vibes=vibes, vibes=vibes,
protagonist=protagonist, protagonist=protagonist,
max_entry_count=max_entry_count, max_entry_count=max_entry_count,
entry_story_text_target_length={"min": 700, "max": 800}, entry_story_text_target_length={
"min": entry_word_count_range[0],
"max": entry_word_count_range[1],
},
) )
first_entry = await self.entry_repo.create( first_entry = await self.entry_repo.create(
adventure_id=uuid.UUID(adventure.id), adventure_id=uuid.UUID(adventure.id),
@ -154,10 +162,14 @@ class AdventureService:
is_final_entry = current_entry.entry_index + 1 == adventure.max_entry_count is_final_entry = current_entry.entry_index + 1 == adventure.max_entry_count
prior_entries = await self._load_prior_entries_with_metadata( prior_entries = await self._load_prior_entries_with_metadata(
all_entries=[e for e in all_entries if e.entry_index < current_entry.entry_index], all_entries=[
e for e in all_entries if e.entry_index < current_entry.entry_index
],
) )
language_name = SUPPORTED_LANGUAGES.get(adventure.language, adventure.language) language_name = SUPPORTED_LANGUAGES.get(
adventure.language, adventure.language
)
competency = adventure.competencies[0] if adventure.competencies else "B1" competency = adventure.competencies[0] if adventure.competencies else "B1"
system_prompt = build_entry_system_prompt( system_prompt = build_entry_system_prompt(
language_name=language_name, language_name=language_name,
@ -193,10 +205,15 @@ class AdventureService:
if not is_final_entry: if not is_final_entry:
await self.choice_repo.create_many( await self.choice_repo.create_many(
entry_id=entry_id, entry_id=entry_id,
choices=[(i, label, text) for i, (label, text) in enumerate(choices_parsed)], choices=[
(i, label, text)
for i, (label, text) in enumerate(choices_parsed)
],
) )
translated = await self.deepl_client.translate(story_text, adventure.source_language) translated = await self.deepl_client.translate(
story_text, adventure.source_language
)
await self.translation_repo.create( await self.translation_repo.create(
entry_id=entry_id, entry_id=entry_id,
component_type="story_text", component_type="story_text",
@ -218,7 +235,9 @@ class AdventureService:
if is_first_entry: if is_first_entry:
title_system = build_title_system_prompt() title_system = build_title_system_prompt()
title_user = build_title_user_message(story_text, language_name, adventure.genres) title_user = build_title_user_message(
story_text, language_name, adventure.genres
)
title_raw, _ = await self.anthropic_client.complete( title_raw, _ = await self.anthropic_client.complete(
system_prompt=title_system, system_prompt=title_system,
messages=[{"role": "user", "content": title_user}], messages=[{"role": "user", "content": title_user}],
@ -230,13 +249,17 @@ class AdventureService:
) )
new_status = "complete" if is_final_entry else "active" new_status = "complete" if is_final_entry else "active"
await self.adventure_repo.update_status(adventure_id=adventure_id, status=new_status) await self.adventure_repo.update_status(
adventure_id=adventure_id, status=new_status
)
except Exception: except Exception:
logger.exception("Entry pipeline failed for entry %s", entry_id) logger.exception("Entry pipeline failed for entry %s", entry_id)
try: try:
await self.entry_repo.update_status(entry_id=entry_id, status="error") await self.entry_repo.update_status(entry_id=entry_id, status="error")
await self.adventure_repo.update_status(adventure_id=adventure_id, status="error") await self.adventure_repo.update_status(
adventure_id=adventure_id, status="error"
)
except Exception: except Exception:
logger.exception("Failed to mark entry/adventure as error") logger.exception("Failed to mark entry/adventure as error")
@ -258,7 +281,11 @@ class AdventureService:
next_entry = sorted_entries[i + 1] next_entry = sorted_entries[i + 1]
if next_entry.generated_from_choice_id: if next_entry.generated_from_choice_id:
chosen = next( chosen = next(
(c for c in choices if c.id == next_entry.generated_from_choice_id), (
c
for c in choices
if c.id == next_entry.generated_from_choice_id
),
None, None,
) )
if chosen: if chosen:

View file

@ -7,6 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from ... import worker
from ...auth import verify_token from ...auth import verify_token
from ...config import settings from ...config import settings
from ...domain.services.adventure_service import AdventureService from ...domain.services.adventure_service import AdventureService
@ -23,7 +24,6 @@ from ...outbound.postgres.repositories.adventure_repository import (
PostgresAdventureEntryTranslationRepository, PostgresAdventureEntryTranslationRepository,
PostgresAdventureRepository, PostgresAdventureRepository,
) )
from ... import worker
router = APIRouter(prefix="/adventures", tags=["adventures"]) router = APIRouter(prefix="/adventures", tags=["adventures"])
@ -44,8 +44,7 @@ _STUB_ENTRY_RESPONSE = (
"no notes" "no notes"
) )
_STUB_TITLE_RESPONSE = ( _STUB_TITLE_RESPONSE = (
"La Nuit Parisienne\n" "La Nuit Parisienne\nUne aventure mystérieuse dans les rues sombres de Paris."
"Une aventure mystérieuse dans les rues sombres de Paris."
) )
@ -57,7 +56,12 @@ class _StubAnthropicClient:
model: str = "claude-sonnet-4-6", model: str = "claude-sonnet-4-6",
max_tokens: int = 2048, max_tokens: int = 2048,
) -> tuple[str, dict]: ) -> tuple[str, dict]:
usage = {"provider": "stub", "model": "stub", "input_tokens": 0, "output_tokens": 0} usage = {
"provider": "stub",
"model": "stub",
"input_tokens": 0,
"output_tokens": 0,
}
if "game master" in system_prompt.lower(): if "game master" in system_prompt.lower():
return _STUB_ENTRY_RESPONSE, usage return _STUB_ENTRY_RESPONSE, usage
return _STUB_TITLE_RESPONSE, usage return _STUB_TITLE_RESPONSE, usage
@ -67,7 +71,9 @@ class _StubDeepLClient:
def can_translate_to(self, lang: str) -> bool: def can_translate_to(self, lang: str) -> bool:
return True return True
async def translate(self, text: str, to_language: str, context: str | None = None) -> str: async def translate(
self, text: str, to_language: str, context: str | None = None
) -> str:
return f"[STUB] {text[:120]}" return f"[STUB] {text[:120]}"
@ -89,6 +95,7 @@ class _StubGeminiClient:
# Service factory # Service factory
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _make_service(db: AsyncSession) -> AdventureService: def _make_service(db: AsyncSession) -> AdventureService:
if settings.stub_generation: if settings.stub_generation:
anthropic = _StubAnthropicClient() # type: ignore[assignment] anthropic = _StubAnthropicClient() # type: ignore[assignment]
@ -123,6 +130,7 @@ async def _run_entry_pipeline_task(
# Request / response models # Request / response models
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class CreateAdventureRequest(BaseModel): class CreateAdventureRequest(BaseModel):
language: str language: str
source_language: str source_language: str
@ -131,6 +139,7 @@ class CreateAdventureRequest(BaseModel):
setting: list[str] setting: list[str]
vibes: list[str] vibes: list[str]
protagonist: list[str] protagonist: list[str]
entry_word_count_range: str
max_entry_count: int = 6 max_entry_count: int = 6
@ -196,6 +205,7 @@ class EntryDetailResponse(BaseModel):
# Helpers # Helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _to_adventure_response(adventure) -> AdventureResponse: def _to_adventure_response(adventure) -> AdventureResponse:
return AdventureResponse( return AdventureResponse(
id=adventure.id, id=adventure.id,
@ -226,6 +236,7 @@ def _parse_adventure_id(adventure_id: str) -> uuid.UUID:
# Endpoints # Endpoints
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post("", response_model=AdventureResponse, status_code=201) @router.post("", response_model=AdventureResponse, status_code=201)
async def create_adventure( async def create_adventure(
body: CreateAdventureRequest, body: CreateAdventureRequest,
@ -240,13 +251,30 @@ async def create_adventure(
detail=f"Unsupported language '{body.language}'. Supported: {list(SUPPORTED_LANGUAGES)}", detail=f"Unsupported language '{body.language}'. Supported: {list(SUPPORTED_LANGUAGES)}",
) )
deepl_client = DeepLClient(settings.deepl_api_key) if not settings.stub_generation else _StubDeepLClient() # type: ignore[assignment] deepl_client = (
DeepLClient(settings.deepl_api_key)
if not settings.stub_generation
else _StubDeepLClient()
) # type: ignore[assignment]
if not deepl_client.can_translate_to(body.source_language): if not deepl_client.can_translate_to(body.source_language):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Cannot translate to source language '{body.source_language}'", detail=f"Cannot translate to source language '{body.source_language}'",
) )
# Word count is e.g. "100-200 Words", convert to a tuple of ints (100, 200)
try:
word_count_range = tuple(
int(x.strip().split()[0]) for x in body.entry_word_count_range.split("-")
)
if len(word_count_range) != 2 or word_count_range[0] >= word_count_range[1]:
raise ValueError()
except ValueError:
raise HTTPException(
status_code=400,
detail="Invalid entry_word_count_range. Expected format 'min-max Words', e.g. '100-200 Words'",
)
adventure, first_entry = await _make_service(db).create_adventure_for_user( adventure, first_entry = await _make_service(db).create_adventure_for_user(
user_id=user_id, user_id=user_id,
language=body.language, language=body.language,
@ -257,9 +285,12 @@ async def create_adventure(
vibes=body.vibes, vibes=body.vibes,
protagonist=body.protagonist, protagonist=body.protagonist,
max_entry_count=body.max_entry_count, max_entry_count=body.max_entry_count,
entry_word_count_range=word_count_range,
) )
await worker.enqueue( await worker.enqueue(
partial(_run_entry_pipeline_task, uuid.UUID(adventure.id), uuid.UUID(first_entry.id)) partial(
_run_entry_pipeline_task, uuid.UUID(adventure.id), uuid.UUID(first_entry.id)
)
) )
return _to_adventure_response(adventure) return _to_adventure_response(adventure)
@ -281,7 +312,9 @@ async def get_adventure(
token_data: dict = Depends(verify_token), token_data: dict = Depends(verify_token),
) -> AdventureResponse: ) -> AdventureResponse:
user_id = uuid.UUID(token_data["sub"]) user_id = uuid.UUID(token_data["sub"])
adventure = await PostgresAdventureRepository(db).get_by_id(_parse_adventure_id(adventure_id)) adventure = await PostgresAdventureRepository(db).get_by_id(
_parse_adventure_id(adventure_id)
)
if adventure is None or adventure.user_id != str(user_id): if adventure is None or adventure.user_id != str(user_id):
raise HTTPException(status_code=404, detail="Adventure not found") raise HTTPException(status_code=404, detail="Adventure not found")
return _to_adventure_response(adventure) return _to_adventure_response(adventure)
@ -301,7 +334,9 @@ async def delete_adventure(
await repo.soft_delete(uuid.UUID(adventure.id)) await repo.soft_delete(uuid.UUID(adventure.id))
@router.post("/{adventure_id}/decisions", response_model=DecisionResponse, status_code=201) @router.post(
"/{adventure_id}/decisions", response_model=DecisionResponse, status_code=201
)
async def record_decision( async def record_decision(
adventure_id: str, adventure_id: str,
body: CreateDecisionRequest, body: CreateDecisionRequest,
@ -316,7 +351,9 @@ async def record_decision(
raise HTTPException(status_code=400, detail="Invalid choice_id") raise HTTPException(status_code=400, detail="Invalid choice_id")
try: try:
decision, next_entry = await _make_service(db).record_decision_and_prepare_next_entry( decision, next_entry = await _make_service(
db
).record_decision_and_prepare_next_entry(
adventure_id=_parse_adventure_id(adventure_id), adventure_id=_parse_adventure_id(adventure_id),
choice_id=choice_id, choice_id=choice_id,
user_id=user_id, user_id=user_id,
@ -418,7 +455,8 @@ async def get_entry(
story_text=entry.story_text, story_text=entry.story_text,
created_at=entry.created_at.isoformat(), created_at=entry.created_at.isoformat(),
choices=[ choices=[
ChoiceResponse(id=c.id, index=c.index, label=c.label, text=c.text) for c in choices ChoiceResponse(id=c.id, index=c.index, label=c.label, text=c.text)
for c in choices
], ],
translation=translation.translated_text if translation else None, translation=translation.translated_text if translation else None,
audio_file_name=audio.file_name if audio else None, audio_file_name=audio.file_name if audio else None,