language-learning-app/api/app/routers/generation.py

137 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import uuid
from functools import partial
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from ..auth import verify_token
from ..database import get_db, AsyncSessionLocal
from ..models import Job
from ..storage import upload_audio
from ..services import llm, tts, job_repo
from ..services.tts import VOICE_BY_LANGUAGE
from .. import worker
# Router for this module; every route declared on it is served under the
# "/generate" prefix and tagged "generation".
router = APIRouter(prefix="/generate", tags=["generation"])
# Language codes accepted in requests, mapped to the English language name
# that is interpolated into the LLM prompts. Insertion order is preserved
# because the keys are listed verbatim in validation error messages.
SUPPORTED_LANGUAGES: dict[str, str] = {
    "en": "English",
    "fr": "French",
    "es": "Spanish",
    "it": "Italian",
    "de": "German",
}

# Accepted proficiency levels (the prompts refer to them as the CEFR scale):
# bands A, B and C, each with steps 1 and 2.
SUPPORTED_LEVELS = {band + step for band in "ABC" for step in "12"}
class GenerationRequest(BaseModel):
    """Request body describing one text-generation job."""

    # Code of the language the new text is written in; the endpoint checks it
    # against SUPPORTED_LANGUAGES.
    target_language: str
    # Proficiency level for the generated text; the endpoint checks it
    # against SUPPORTED_LEVELS.
    complexity_level: str
    # Source texts the generated content is inspired by.
    input_texts: list[str]
    # Optional topic hint that is added to the prompt and the input summary.
    topic: str | None = None
    # Code of the language the generated text is translated back into.
    source_language: str = "en"
class GenerationResponse(BaseModel):
    """Response body returned once a generation job has been queued."""

    # String form of the created job's id.
    job_id: str
def _build_input_summary(input_texts: list[str], topic: str | None) -> str:
    """Return a short description of the inputs a job was generated from."""
    topic_part = f"Topic: {topic}. " if topic else ""
    # Only the first 300 characters of the space-joined inputs are previewed.
    combined_preview = " ".join(input_texts)[:300]
    return (
        f"{topic_part}Based on {len(input_texts)} input text(s): "
        f"{combined_preview}..."
    )


def _build_generation_prompt(
    language_name: str,
    complexity_level: str,
    input_texts: list[str],
    topic: str | None,
) -> str:
    """Return the LLM prompt that asks for new text in ``language_name``."""
    # Only the first three input texts are included as source material.
    source_material = "\n\n".join(input_texts[:3])
    topic_line = f"\nTopic focus: {topic}" if topic else ""
    return (
        "You are a language learning content creator. "
        f"Using the input provided, you generate engaging realistic text in {language_name} "
        f"at {complexity_level} proficiency level (CEFR scale).\n\n"
        "The text should:\n"
        f"- Be appropriate for a {complexity_level} learner\n"
        "- Maintain a similar tone to the input text. Where appropriate, use idioms\n"
        "- Feel natural and authentic, like content a native speaker would read\n"
        "- Be formatted in markdown with paragraphs and line breaks\n"
        # Fixed: the range separator was missing, so the prompt asked for
        # "200400 words".
        "- Be 200-400 words long\n"
        "- Be inspired by the following source material "
        f"(but written originally in {language_name}):\n\n"
        f"{source_material}"
        f"{topic_line}\n\n"
        f"Respond with ONLY the generated text in {language_name}, "
        "no explanations or translations.\n"
        f"The 'Topic focus' should be a comma-separated list of up to three topics, in {language_name}."
    )


def _build_translation_prompt(language_name: str, from_language: str) -> str:
    """Return the follow-up prompt asking for a translation into ``from_language``."""
    return (
        "You are a helpful assistant that translates text. Translate just the previous summary "
        f"content in {language_name} text you generated based on the input I gave you. Translate "
        f"it back into {from_language}.\n"
        "- Keep the translation as close as possible to the original meaning and tone\n"
        "- Send through only the translated text, no explanations or notes\n"
    )


async def _run_generation(job_id: uuid.UUID, request: GenerationRequest) -> None:
    """Run one generation job: LLM text, back-translation, then audio.

    Opens its own database session, loads the job, and records progress
    through ``job_repo``. Any exception raised after the job is marked as
    processing is caught and stored on the job via ``job_repo.mark_failed``
    rather than propagated.

    Args:
        job_id: Primary key of the ``Job`` row to process.
        request: The original request describing what to generate.
    """
    async with AsyncSessionLocal() as db:
        # NOTE(review): db.get returns None when no row matches; confirm that
        # job_repo.mark_processing copes with that or that it cannot happen.
        job = await db.get(Job, job_id)
        await job_repo.mark_processing(db, job)
        try:
            # Unknown language codes raise KeyError here and fail the job.
            from_language = SUPPORTED_LANGUAGES[request.source_language]
            language_name = SUPPORTED_LANGUAGES[request.target_language]

            input_summary = _build_input_summary(request.input_texts, request.topic)
            prompt = _build_generation_prompt(
                language_name,
                request.complexity_level,
                request.input_texts,
                request.topic,
            )
            generated_text = await llm.generate_text(prompt)

            translate_prompt = _build_translation_prompt(language_name, from_language)
            translated_text = await llm.translate_text(
                prompt, generated_text, translate_prompt
            )

            # Save LLM results before attempting TTS so they're preserved on failure
            await job_repo.save_llm_results(
                db, job, generated_text, translated_text, input_summary[:500]
            )

            # Fall back to the "Kore" voice when the language has no entry.
            voice = VOICE_BY_LANGUAGE.get(request.target_language, "Kore")
            wav_bytes = await tts.generate_audio(generated_text, voice)
            audio_key = f"audio/{job_id}.wav"
            # NOTE(review): upload_audio is called without await; if it does
            # blocking I/O it will stall the event loop — confirm.
            upload_audio(audio_key, wav_bytes)
            await job_repo.mark_succeeded(db, job, audio_key)
        except Exception as exc:
            # Boundary of the background task: record the failure on the job.
            await job_repo.mark_failed(db, job, str(exc))
@router.post("", response_model=GenerationResponse, status_code=202)
async def create_generation_job(
    request: GenerationRequest,
    db: AsyncSession = Depends(get_db),
    token_data: dict = Depends(verify_token),
) -> GenerationResponse:
    """Validate the request, persist a new job and queue it for processing.

    Args:
        request: Parsed request body.
        db: Database session supplied by the ``get_db`` dependency.
        token_data: Result of the ``verify_token`` dependency; its ``"sub"``
            entry is parsed as the user's UUID.

    Returns:
        A response carrying the new job's id as a string (sent with HTTP 202).

    Raises:
        HTTPException: 400 when the target language, source language or
            complexity level is not supported.
    """
    if request.target_language not in SUPPORTED_LANGUAGES:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported language '{request.target_language}'. "
            f"Supported: {list(SUPPORTED_LANGUAGES)}",
        )
    # The background task looks the source language up in SUPPORTED_LANGUAGES
    # too, so reject unknown codes here instead of accepting a job that can
    # only fail later with an opaque KeyError message.
    if request.source_language not in SUPPORTED_LANGUAGES:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported source language '{request.source_language}'. "
            f"Supported: {list(SUPPORTED_LANGUAGES)}",
        )
    if request.complexity_level not in SUPPORTED_LEVELS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported level '{request.complexity_level}'. "
            f"Supported: {sorted(SUPPORTED_LEVELS)}",
        )

    job = Job(
        user_id=uuid.UUID(token_data["sub"]),
        source_language=request.source_language,
        target_language=request.target_language,
        complexity_level=request.complexity_level,
    )
    db.add(job)
    await db.commit()
    # Reload the row after the commit before reading job.id below.
    await db.refresh(job)

    # Bind the arguments now; the worker later invokes the partial itself.
    await worker.enqueue(partial(_run_generation, job.id, request))
    return GenerationResponse(job_id=str(job.id))