language-learning-app/api/app/outbound/gemini/gemini_client.py
wilson fecb5839ea
Some checks failed
/ test (push) Has been cancelled
feats: use Procrastinate for persistent jobs; try using Gemini for text
generation
2026-05-27 18:45:52 +01:00

117 lines
3.5 KiB
Python

import asyncio
import io
import json
import wave
from google import genai
from google.genai import types as genai_types
from app.domain.models.gen_ai import GenAiChatMessage
def _pcm_to_wav(
    pcm_data: bytes,
    sample_rate: int = 24000,
    *,
    channels: int = 1,
    sample_width: int = 2,
) -> bytes:
    """Wrap raw PCM audio in a WAV container and return the file bytes.

    Args:
        pcm_data: Raw, headerless PCM frames.
        sample_rate: Frame rate written to the WAV header.
        channels: Number of interleaved channels written to the header
            (defaults to 1, mono).
        sample_width: Bytes per sample written to the header (defaults to
            2, i.e. 16-bit samples).

    Returns:
        The complete WAV file contents, header included.
    """
    buf = io.BytesIO()
    # The wave writer finalises the header when the context exits, so the
    # buffer is only read after the ``with`` block has closed.
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(sample_width)
        wf.setframerate(sample_rate)
        wf.writeframes(pcm_data)
    return buf.getvalue()
# Voice name to use for each supported target-language code; looked up by
# GeminiClient.get_voice_by_language.
VOICE_BY_LANGUAGE: dict[str, str] = dict(
    (
        ("fr", "Kore"),
        ("es", "Charon"),
        ("it", "Aoede"),
        ("de", "Fenrir"),
        ("en", "Kore"),
    )
)
class GeminiClient:
    """Communicate with Google's Gemini LLM.

    The ``google.genai`` calls used here are synchronous, so both async
    methods run them in a worker thread via ``asyncio.to_thread`` to avoid
    blocking the event loop.
    """

    def __init__(self, api_key: str):
        self._api_key = api_key

    @classmethod
    def new(cls, api_key: str) -> "GeminiClient":
        """Build a client; uses ``cls`` so subclasses get their own type."""
        return cls(api_key)

    def get_voice_by_language(self, target_language: str) -> str:
        """Return the voice registered for ``target_language``.

        Raises:
            ValueError: If ``VOICE_BY_LANGUAGE`` has no non-empty entry for
                the language.
        """
        possible_voice = VOICE_BY_LANGUAGE.get(target_language)
        if not possible_voice:
            raise ValueError(f"No voice found for language: {target_language}")
        return possible_voice

    def _make_gemini_messages(
        self, messages: list[GenAiChatMessage]
    ) -> list[genai_types.Content]:
        """Convert chat messages into Gemini ``Content`` objects.

        A message whose ``actor`` is ``"user"`` keeps the ``"user"`` role;
        every other actor is sent with the ``"model"`` role.
        """
        return [
            genai_types.Content(
                role="user" if message.actor == "user" else "model",
                parts=[genai_types.Part.from_text(text=message.content)],
            )
            for message in messages
        ]

    # Backward-compatible alias for the original (misspelled) method name.
    _make_gemini_messags = _make_gemini_messages

    @staticmethod
    def _first_part(
        response: genai_types.GenerateContentResponse,
    ) -> genai_types.Part:
        """Return the first part of the first candidate in ``response``.

        Raises:
            ValueError: If the response has no candidates, or the first
                candidate has no content or no parts.
        """
        candidates = response.candidates or []
        content = candidates[0].content if candidates else None
        parts = (content.parts if content is not None else None) or []
        if not parts:
            raise ValueError("Gemini response contained no content parts")
        return parts[0]

    async def complete(
        self,
        system_prompt: str,
        messages: list[GenAiChatMessage],
        model: str = "gemini-3.1-flash-lite",
        max_tokens: int = 2048,
    ) -> tuple[str, dict]:
        """Generate a text completion for a chat history.

        Args:
            system_prompt: Passed to Gemini as the system instruction.
            messages: Chat history, converted with ``_make_gemini_messages``.
            model: Gemini model name to call.
            max_tokens: Upper bound on generated output tokens.

        Returns:
            ``(text, metadata)`` where ``metadata`` holds the ``"model"``
            name and the ``"total_token_count"`` reported by the API
            (``None`` when the response carries no usage metadata).

        Raises:
            ValueError: If the response has no content parts or the first
                part carries no text.
        """
        contents = self._make_gemini_messages(messages)
        config = genai_types.GenerateContentConfig(
            system_instruction=system_prompt,
            temperature=1.5,
            top_p=0.95,
            max_output_tokens=max_tokens,
        )

        def _call() -> genai_types.GenerateContentResponse:
            client = genai.Client(api_key=self._api_key)
            return client.models.generate_content(
                model=model,
                contents=contents,
                config=config,
            )

        # generate_content blocks on network I/O; keep it off the event loop.
        response = await asyncio.to_thread(_call)

        response_text = self._first_part(response).text
        if response_text is None:
            raise ValueError("Gemini response did not contain text")

        usage = response.usage_metadata
        response_metadata = {
            "model": model,
            "total_token_count": (
                usage.total_token_count if usage is not None else None
            ),
        }
        return response_text, response_metadata

    async def generate_audio(self, text: str, voice: str) -> bytes:
        """Generate TTS audio and return WAV bytes.

        Args:
            text: Text to be spoken.
            voice: Prebuilt voice name passed to the speech configuration.

        Raises:
            ValueError: If the response has no content parts or the first
                part carries no inline audio data.
        """

        def _call() -> bytes:
            client = genai.Client(api_key=self._api_key)
            response = client.models.generate_content(
                model="gemini-2.5-flash-preview-tts",
                contents=text,
                config=genai_types.GenerateContentConfig(
                    response_modalities=["AUDIO"],
                    speech_config=genai_types.SpeechConfig(
                        voice_config=genai_types.VoiceConfig(
                            prebuilt_voice_config=genai_types.PrebuiltVoiceConfig(
                                voice_name=voice,
                            )
                        )
                    ),
                ),
            )
            inline_data = self._first_part(response).inline_data
            if inline_data is None or inline_data.data is None:
                raise ValueError("Gemini response did not contain audio data")
            return _pcm_to_wav(inline_data.data)

        # The SDK call and WAV encoding are blocking; run them in a thread.
        return await asyncio.to_thread(_call)