fix: [api] Pass along length of response from request
This commit is contained in:
parent
2d5933ff59
commit
ac73bd1a04
2 changed files with 87 additions and 22 deletions
|
|
@ -1,6 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
from ...languages import SUPPORTED_LANGUAGES
|
||||||
from ...outbound.anthropic.adventure_prompts import (
|
from ...outbound.anthropic.adventure_prompts import (
|
||||||
build_conversation_messages,
|
build_conversation_messages,
|
||||||
build_entry_system_prompt,
|
build_entry_system_prompt,
|
||||||
|
|
@ -21,8 +22,11 @@ from ...outbound.postgres.repositories.adventure_repository import (
|
||||||
PostgresAdventureRepository,
|
PostgresAdventureRepository,
|
||||||
)
|
)
|
||||||
from ...storage import upload_audio
|
from ...storage import upload_audio
|
||||||
from ..models.adventure import Adventure, AdventureEntry, AdventureEntryPossibleChoiceDecision
|
from ..models.adventure import (
|
||||||
from ...languages import SUPPORTED_LANGUAGES
|
Adventure,
|
||||||
|
AdventureEntry,
|
||||||
|
AdventureEntryPossibleChoiceDecision,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -60,6 +64,7 @@ class AdventureService:
|
||||||
setting: list[str],
|
setting: list[str],
|
||||||
vibes: list[str],
|
vibes: list[str],
|
||||||
protagonist: list[str],
|
protagonist: list[str],
|
||||||
|
entry_word_count_range: list[int],
|
||||||
max_entry_count: int = 6,
|
max_entry_count: int = 6,
|
||||||
) -> tuple[Adventure, AdventureEntry]:
|
) -> tuple[Adventure, AdventureEntry]:
|
||||||
"""Creates the adventure and a placeholder for the first entry.
|
"""Creates the adventure and a placeholder for the first entry.
|
||||||
|
|
@ -76,7 +81,10 @@ class AdventureService:
|
||||||
vibes=vibes,
|
vibes=vibes,
|
||||||
protagonist=protagonist,
|
protagonist=protagonist,
|
||||||
max_entry_count=max_entry_count,
|
max_entry_count=max_entry_count,
|
||||||
entry_story_text_target_length={"min": 700, "max": 800},
|
entry_story_text_target_length={
|
||||||
|
"min": entry_word_count_range[0],
|
||||||
|
"max": entry_word_count_range[1],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
first_entry = await self.entry_repo.create(
|
first_entry = await self.entry_repo.create(
|
||||||
adventure_id=uuid.UUID(adventure.id),
|
adventure_id=uuid.UUID(adventure.id),
|
||||||
|
|
@ -154,10 +162,14 @@ class AdventureService:
|
||||||
is_final_entry = current_entry.entry_index + 1 == adventure.max_entry_count
|
is_final_entry = current_entry.entry_index + 1 == adventure.max_entry_count
|
||||||
|
|
||||||
prior_entries = await self._load_prior_entries_with_metadata(
|
prior_entries = await self._load_prior_entries_with_metadata(
|
||||||
all_entries=[e for e in all_entries if e.entry_index < current_entry.entry_index],
|
all_entries=[
|
||||||
|
e for e in all_entries if e.entry_index < current_entry.entry_index
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
language_name = SUPPORTED_LANGUAGES.get(adventure.language, adventure.language)
|
language_name = SUPPORTED_LANGUAGES.get(
|
||||||
|
adventure.language, adventure.language
|
||||||
|
)
|
||||||
competency = adventure.competencies[0] if adventure.competencies else "B1"
|
competency = adventure.competencies[0] if adventure.competencies else "B1"
|
||||||
system_prompt = build_entry_system_prompt(
|
system_prompt = build_entry_system_prompt(
|
||||||
language_name=language_name,
|
language_name=language_name,
|
||||||
|
|
@ -193,10 +205,15 @@ class AdventureService:
|
||||||
if not is_final_entry:
|
if not is_final_entry:
|
||||||
await self.choice_repo.create_many(
|
await self.choice_repo.create_many(
|
||||||
entry_id=entry_id,
|
entry_id=entry_id,
|
||||||
choices=[(i, label, text) for i, (label, text) in enumerate(choices_parsed)],
|
choices=[
|
||||||
|
(i, label, text)
|
||||||
|
for i, (label, text) in enumerate(choices_parsed)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
translated = await self.deepl_client.translate(story_text, adventure.source_language)
|
translated = await self.deepl_client.translate(
|
||||||
|
story_text, adventure.source_language
|
||||||
|
)
|
||||||
await self.translation_repo.create(
|
await self.translation_repo.create(
|
||||||
entry_id=entry_id,
|
entry_id=entry_id,
|
||||||
component_type="story_text",
|
component_type="story_text",
|
||||||
|
|
@ -218,7 +235,9 @@ class AdventureService:
|
||||||
|
|
||||||
if is_first_entry:
|
if is_first_entry:
|
||||||
title_system = build_title_system_prompt()
|
title_system = build_title_system_prompt()
|
||||||
title_user = build_title_user_message(story_text, language_name, adventure.genres)
|
title_user = build_title_user_message(
|
||||||
|
story_text, language_name, adventure.genres
|
||||||
|
)
|
||||||
title_raw, _ = await self.anthropic_client.complete(
|
title_raw, _ = await self.anthropic_client.complete(
|
||||||
system_prompt=title_system,
|
system_prompt=title_system,
|
||||||
messages=[{"role": "user", "content": title_user}],
|
messages=[{"role": "user", "content": title_user}],
|
||||||
|
|
@ -230,13 +249,17 @@ class AdventureService:
|
||||||
)
|
)
|
||||||
|
|
||||||
new_status = "complete" if is_final_entry else "active"
|
new_status = "complete" if is_final_entry else "active"
|
||||||
await self.adventure_repo.update_status(adventure_id=adventure_id, status=new_status)
|
await self.adventure_repo.update_status(
|
||||||
|
adventure_id=adventure_id, status=new_status
|
||||||
|
)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Entry pipeline failed for entry %s", entry_id)
|
logger.exception("Entry pipeline failed for entry %s", entry_id)
|
||||||
try:
|
try:
|
||||||
await self.entry_repo.update_status(entry_id=entry_id, status="error")
|
await self.entry_repo.update_status(entry_id=entry_id, status="error")
|
||||||
await self.adventure_repo.update_status(adventure_id=adventure_id, status="error")
|
await self.adventure_repo.update_status(
|
||||||
|
adventure_id=adventure_id, status="error"
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to mark entry/adventure as error")
|
logger.exception("Failed to mark entry/adventure as error")
|
||||||
|
|
||||||
|
|
@ -258,7 +281,11 @@ class AdventureService:
|
||||||
next_entry = sorted_entries[i + 1]
|
next_entry = sorted_entries[i + 1]
|
||||||
if next_entry.generated_from_choice_id:
|
if next_entry.generated_from_choice_id:
|
||||||
chosen = next(
|
chosen = next(
|
||||||
(c for c in choices if c.id == next_entry.generated_from_choice_id),
|
(
|
||||||
|
c
|
||||||
|
for c in choices
|
||||||
|
if c.id == next_entry.generated_from_choice_id
|
||||||
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
if chosen:
|
if chosen:
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from ... import worker
|
||||||
from ...auth import verify_token
|
from ...auth import verify_token
|
||||||
from ...config import settings
|
from ...config import settings
|
||||||
from ...domain.services.adventure_service import AdventureService
|
from ...domain.services.adventure_service import AdventureService
|
||||||
|
|
@ -23,7 +24,6 @@ from ...outbound.postgres.repositories.adventure_repository import (
|
||||||
PostgresAdventureEntryTranslationRepository,
|
PostgresAdventureEntryTranslationRepository,
|
||||||
PostgresAdventureRepository,
|
PostgresAdventureRepository,
|
||||||
)
|
)
|
||||||
from ... import worker
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/adventures", tags=["adventures"])
|
router = APIRouter(prefix="/adventures", tags=["adventures"])
|
||||||
|
|
||||||
|
|
@ -44,8 +44,7 @@ _STUB_ENTRY_RESPONSE = (
|
||||||
"no notes"
|
"no notes"
|
||||||
)
|
)
|
||||||
_STUB_TITLE_RESPONSE = (
|
_STUB_TITLE_RESPONSE = (
|
||||||
"La Nuit Parisienne\n"
|
"La Nuit Parisienne\nUne aventure mystérieuse dans les rues sombres de Paris."
|
||||||
"Une aventure mystérieuse dans les rues sombres de Paris."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -57,7 +56,12 @@ class _StubAnthropicClient:
|
||||||
model: str = "claude-sonnet-4-6",
|
model: str = "claude-sonnet-4-6",
|
||||||
max_tokens: int = 2048,
|
max_tokens: int = 2048,
|
||||||
) -> tuple[str, dict]:
|
) -> tuple[str, dict]:
|
||||||
usage = {"provider": "stub", "model": "stub", "input_tokens": 0, "output_tokens": 0}
|
usage = {
|
||||||
|
"provider": "stub",
|
||||||
|
"model": "stub",
|
||||||
|
"input_tokens": 0,
|
||||||
|
"output_tokens": 0,
|
||||||
|
}
|
||||||
if "game master" in system_prompt.lower():
|
if "game master" in system_prompt.lower():
|
||||||
return _STUB_ENTRY_RESPONSE, usage
|
return _STUB_ENTRY_RESPONSE, usage
|
||||||
return _STUB_TITLE_RESPONSE, usage
|
return _STUB_TITLE_RESPONSE, usage
|
||||||
|
|
@ -67,7 +71,9 @@ class _StubDeepLClient:
|
||||||
def can_translate_to(self, lang: str) -> bool:
|
def can_translate_to(self, lang: str) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def translate(self, text: str, to_language: str, context: str | None = None) -> str:
|
async def translate(
|
||||||
|
self, text: str, to_language: str, context: str | None = None
|
||||||
|
) -> str:
|
||||||
return f"[STUB] {text[:120]}"
|
return f"[STUB] {text[:120]}"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -89,6 +95,7 @@ class _StubGeminiClient:
|
||||||
# Service factory
|
# Service factory
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _make_service(db: AsyncSession) -> AdventureService:
|
def _make_service(db: AsyncSession) -> AdventureService:
|
||||||
if settings.stub_generation:
|
if settings.stub_generation:
|
||||||
anthropic = _StubAnthropicClient() # type: ignore[assignment]
|
anthropic = _StubAnthropicClient() # type: ignore[assignment]
|
||||||
|
|
@ -123,6 +130,7 @@ async def _run_entry_pipeline_task(
|
||||||
# Request / response models
|
# Request / response models
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class CreateAdventureRequest(BaseModel):
|
class CreateAdventureRequest(BaseModel):
|
||||||
language: str
|
language: str
|
||||||
source_language: str
|
source_language: str
|
||||||
|
|
@ -131,6 +139,7 @@ class CreateAdventureRequest(BaseModel):
|
||||||
setting: list[str]
|
setting: list[str]
|
||||||
vibes: list[str]
|
vibes: list[str]
|
||||||
protagonist: list[str]
|
protagonist: list[str]
|
||||||
|
entry_word_count_range: str
|
||||||
max_entry_count: int = 6
|
max_entry_count: int = 6
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -196,6 +205,7 @@ class EntryDetailResponse(BaseModel):
|
||||||
# Helpers
|
# Helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _to_adventure_response(adventure) -> AdventureResponse:
|
def _to_adventure_response(adventure) -> AdventureResponse:
|
||||||
return AdventureResponse(
|
return AdventureResponse(
|
||||||
id=adventure.id,
|
id=adventure.id,
|
||||||
|
|
@ -226,6 +236,7 @@ def _parse_adventure_id(adventure_id: str) -> uuid.UUID:
|
||||||
# Endpoints
|
# Endpoints
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=AdventureResponse, status_code=201)
|
@router.post("", response_model=AdventureResponse, status_code=201)
|
||||||
async def create_adventure(
|
async def create_adventure(
|
||||||
body: CreateAdventureRequest,
|
body: CreateAdventureRequest,
|
||||||
|
|
@ -240,13 +251,30 @@ async def create_adventure(
|
||||||
detail=f"Unsupported language '{body.language}'. Supported: {list(SUPPORTED_LANGUAGES)}",
|
detail=f"Unsupported language '{body.language}'. Supported: {list(SUPPORTED_LANGUAGES)}",
|
||||||
)
|
)
|
||||||
|
|
||||||
deepl_client = DeepLClient(settings.deepl_api_key) if not settings.stub_generation else _StubDeepLClient() # type: ignore[assignment]
|
deepl_client = (
|
||||||
|
DeepLClient(settings.deepl_api_key)
|
||||||
|
if not settings.stub_generation
|
||||||
|
else _StubDeepLClient()
|
||||||
|
) # type: ignore[assignment]
|
||||||
if not deepl_client.can_translate_to(body.source_language):
|
if not deepl_client.can_translate_to(body.source_language):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Cannot translate to source language '{body.source_language}'",
|
detail=f"Cannot translate to source language '{body.source_language}'",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Word count is e.g. "100-200 Words", convert to a tuple of ints (100, 200)
|
||||||
|
try:
|
||||||
|
word_count_range = tuple(
|
||||||
|
int(x.strip().split()[0]) for x in body.entry_word_count_range.split("-")
|
||||||
|
)
|
||||||
|
if len(word_count_range) != 2 or word_count_range[0] >= word_count_range[1]:
|
||||||
|
raise ValueError()
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Invalid entry_word_count_range. Expected format 'min-max Words', e.g. '100-200 Words'",
|
||||||
|
)
|
||||||
|
|
||||||
adventure, first_entry = await _make_service(db).create_adventure_for_user(
|
adventure, first_entry = await _make_service(db).create_adventure_for_user(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
language=body.language,
|
language=body.language,
|
||||||
|
|
@ -257,9 +285,12 @@ async def create_adventure(
|
||||||
vibes=body.vibes,
|
vibes=body.vibes,
|
||||||
protagonist=body.protagonist,
|
protagonist=body.protagonist,
|
||||||
max_entry_count=body.max_entry_count,
|
max_entry_count=body.max_entry_count,
|
||||||
|
entry_word_count_range=word_count_range,
|
||||||
)
|
)
|
||||||
await worker.enqueue(
|
await worker.enqueue(
|
||||||
partial(_run_entry_pipeline_task, uuid.UUID(adventure.id), uuid.UUID(first_entry.id))
|
partial(
|
||||||
|
_run_entry_pipeline_task, uuid.UUID(adventure.id), uuid.UUID(first_entry.id)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return _to_adventure_response(adventure)
|
return _to_adventure_response(adventure)
|
||||||
|
|
||||||
|
|
@ -281,7 +312,9 @@ async def get_adventure(
|
||||||
token_data: dict = Depends(verify_token),
|
token_data: dict = Depends(verify_token),
|
||||||
) -> AdventureResponse:
|
) -> AdventureResponse:
|
||||||
user_id = uuid.UUID(token_data["sub"])
|
user_id = uuid.UUID(token_data["sub"])
|
||||||
adventure = await PostgresAdventureRepository(db).get_by_id(_parse_adventure_id(adventure_id))
|
adventure = await PostgresAdventureRepository(db).get_by_id(
|
||||||
|
_parse_adventure_id(adventure_id)
|
||||||
|
)
|
||||||
if adventure is None or adventure.user_id != str(user_id):
|
if adventure is None or adventure.user_id != str(user_id):
|
||||||
raise HTTPException(status_code=404, detail="Adventure not found")
|
raise HTTPException(status_code=404, detail="Adventure not found")
|
||||||
return _to_adventure_response(adventure)
|
return _to_adventure_response(adventure)
|
||||||
|
|
@ -301,7 +334,9 @@ async def delete_adventure(
|
||||||
await repo.soft_delete(uuid.UUID(adventure.id))
|
await repo.soft_delete(uuid.UUID(adventure.id))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{adventure_id}/decisions", response_model=DecisionResponse, status_code=201)
|
@router.post(
|
||||||
|
"/{adventure_id}/decisions", response_model=DecisionResponse, status_code=201
|
||||||
|
)
|
||||||
async def record_decision(
|
async def record_decision(
|
||||||
adventure_id: str,
|
adventure_id: str,
|
||||||
body: CreateDecisionRequest,
|
body: CreateDecisionRequest,
|
||||||
|
|
@ -316,7 +351,9 @@ async def record_decision(
|
||||||
raise HTTPException(status_code=400, detail="Invalid choice_id")
|
raise HTTPException(status_code=400, detail="Invalid choice_id")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
decision, next_entry = await _make_service(db).record_decision_and_prepare_next_entry(
|
decision, next_entry = await _make_service(
|
||||||
|
db
|
||||||
|
).record_decision_and_prepare_next_entry(
|
||||||
adventure_id=_parse_adventure_id(adventure_id),
|
adventure_id=_parse_adventure_id(adventure_id),
|
||||||
choice_id=choice_id,
|
choice_id=choice_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
@ -418,7 +455,8 @@ async def get_entry(
|
||||||
story_text=entry.story_text,
|
story_text=entry.story_text,
|
||||||
created_at=entry.created_at.isoformat(),
|
created_at=entry.created_at.isoformat(),
|
||||||
choices=[
|
choices=[
|
||||||
ChoiceResponse(id=c.id, index=c.index, label=c.label, text=c.text) for c in choices
|
ChoiceResponse(id=c.id, index=c.index, label=c.label, text=c.text)
|
||||||
|
for c in choices
|
||||||
],
|
],
|
||||||
translation=translation.translated_text if translation else None,
|
translation=translation.translated_text if translation else None,
|
||||||
audio_file_name=audio.file_name if audio else None,
|
audio_file_name=audio.file_name if audio else None,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue