117 lines
3.5 KiB
Python
117 lines
3.5 KiB
Python
import asyncio
|
|
import io
|
|
import json
|
|
import wave
|
|
|
|
from google import genai
|
|
from google.genai import types as genai_types
|
|
|
|
from app.domain.models.gen_ai import GenAiChatMessage
|
|
|
|
|
|
def _pcm_to_wav(
    pcm_data: bytes,
    sample_rate: int = 24000,
    *,
    channels: int = 1,
    sample_width: int = 2,
) -> bytes:
    """Wrap raw PCM sample bytes in an in-memory WAV container.

    Args:
        pcm_data: Raw PCM frames. The bytes are written as-is, so they must
            already be in the sample format WAV expects for ``sample_width``
            (e.g. signed little-endian for 16-bit).
        sample_rate: Frames per second written to the WAV header.
        channels: Number of interleaved channels (keyword-only; default mono).
        sample_width: Bytes per sample (keyword-only; default 2 = 16-bit).

    Returns:
        The complete WAV file (RIFF header + data) as ``bytes``.
    """
    buf = io.BytesIO()
    # ``wave`` finalizes the header sizes when the writer is closed, so the
    # buffer must only be read after the ``with`` block exits.
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(sample_width)
        wf.setframerate(sample_rate)
        wf.writeframes(pcm_data)
    return buf.getvalue()
|
|
|
|
|
|
# Target-language code -> prebuilt TTS voice name, looked up by
# GeminiClient.get_voice_by_language. French and English share "Kore".
VOICE_BY_LANGUAGE: dict[str, str] = dict(
    fr="Kore",
    es="Charon",
    it="Aoede",
    de="Fenrir",
    en="Kore",
)
|
|
|
|
|
|
class GeminiClient:
    """Communicate with Google's Gemini LLM.

    Offers two async operations built on the ``google.genai`` SDK:
    multi-turn text completion (:meth:`complete`) and text-to-speech
    (:meth:`generate_audio`). The SDK calls used here are synchronous, so
    both run them in a worker thread via :func:`asyncio.to_thread` to keep
    the event loop responsive.
    """

    # Sample rate assumed for TTS audio when the response does not state one.
    _DEFAULT_TTS_SAMPLE_RATE = 24000

    def __init__(self, api_key: str):
        self._api_key = api_key
        # Built lazily by _get_client() and reused, instead of constructing
        # a fresh SDK client on every request.
        self._client = None

    @classmethod
    def new(cls, api_key: str) -> "GeminiClient":
        """Alternate constructor; equivalent to ``cls(api_key)``."""
        # ``cls`` rather than a hard-coded class name so subclasses construct
        # instances of themselves.
        return cls(api_key)

    def _get_client(self) -> "genai.Client":
        """Return the shared ``genai.Client``, creating it on first use."""
        if self._client is None:
            self._client = genai.Client(api_key=self._api_key)
        return self._client

    def get_voice_by_language(self, target_language: str) -> str:
        """Return the TTS voice name configured for ``target_language``.

        Raises:
            ValueError: if ``VOICE_BY_LANGUAGE`` has no voice for the code.
        """
        voice = VOICE_BY_LANGUAGE.get(target_language)
        if not voice:
            raise ValueError(f"No voice found for language: {target_language}")
        return voice

    def _make_gemini_messages(
        self, messages: list[GenAiChatMessage]
    ) -> list[genai_types.Content]:
        """Convert chat messages into Gemini ``Content`` objects.

        A message whose ``actor`` is ``"user"`` gets the ``"user"`` role;
        every other actor is sent with the ``"model"`` role. Each message
        becomes a single text part.
        """
        return [
            genai_types.Content(
                role="user" if message.actor == "user" else "model",
                parts=[genai_types.Part.from_text(text=message.content)],
            )
            for message in messages
        ]

    # Backward-compatible alias for the original, misspelled method name.
    _make_gemini_messags = _make_gemini_messages

    async def complete(
        self,
        system_prompt: str,
        messages: list[GenAiChatMessage],
        model: str = "gemini-3.1-flash-lite",
        max_tokens: int = 2048,
        temperature: float = 1.5,
        top_p: float = 0.95,
    ) -> tuple[str, dict]:
        """Generate the model's next reply to a conversation.

        Args:
            system_prompt: Sent as the request's ``system_instruction``.
            messages: Conversation history, oldest first.
            model: Gemini model name.
            max_tokens: ``max_output_tokens`` for the request.
            temperature: Sampling temperature (default keeps the previous 1.5).
            top_p: Nucleus-sampling cutoff (default keeps the previous 0.95).

        Returns:
            ``(text, metadata)`` where ``metadata`` holds ``"model"`` and
            ``"total_token_count"`` (``None`` if usage metadata is missing).

        Raises:
            RuntimeError: if the response contains no text (e.g. it was
                blocked or stopped before producing output).
        """
        contents = self._make_gemini_messages(messages)
        config = genai_types.GenerateContentConfig(
            system_instruction=system_prompt,
            temperature=temperature,
            top_p=top_p,
            max_output_tokens=max_tokens,
        )
        client = self._get_client()

        # generate_content is a blocking call; run it off the event loop,
        # exactly as generate_audio does.
        response = await asyncio.to_thread(
            client.models.generate_content,
            model=model,
            contents=contents,
            config=config,
        )

        # ``response.text`` joins the text parts of the first candidate and is
        # None when there are no candidates/parts, instead of raising
        # IndexError the way direct ``candidates[0].content.parts[0]`` does.
        response_text = response.text
        if response_text is None:
            raise RuntimeError(
                f"Gemini returned no text for model {model!r} "
                f"(finish_reason={self._finish_reason(response)!r})"
            )

        usage = response.usage_metadata
        response_metadata = {
            "model": model,
            "total_token_count": usage.total_token_count if usage is not None else None,
        }
        return response_text, response_metadata

    async def generate_audio(self, text: str, voice: str) -> bytes:
        """Generate TTS audio and return WAV bytes.

        Args:
            text: Text to synthesize.
            voice: Prebuilt voice name (see :meth:`get_voice_by_language`).

        Returns:
            A complete WAV file built from the returned PCM audio.

        Raises:
            RuntimeError: if the response contains no inline audio data.
        """
        config = genai_types.GenerateContentConfig(
            response_modalities=["AUDIO"],
            speech_config=genai_types.SpeechConfig(
                voice_config=genai_types.VoiceConfig(
                    prebuilt_voice_config=genai_types.PrebuiltVoiceConfig(
                        voice_name=voice,
                    )
                )
            ),
        )

        def _call() -> bytes:
            response = self._get_client().models.generate_content(
                model="gemini-2.5-flash-preview-tts",
                contents=text,
                config=config,
            )
            audio = self._first_inline_data(response)
            sample_rate = self._sample_rate_from_mime_type(audio.mime_type)
            return _pcm_to_wav(audio.data, sample_rate=sample_rate)

        return await asyncio.to_thread(_call)

    @staticmethod
    def _finish_reason(response):
        """Return the first candidate's ``finish_reason``, or None if there is none."""
        candidates = response.candidates or []
        return candidates[0].finish_reason if candidates else None

    @classmethod
    def _first_inline_data(cls, response):
        """Return the first non-empty ``inline_data`` blob of the first candidate.

        Raises:
            RuntimeError: if no part of the first candidate carries data.
        """
        candidates = response.candidates or []
        parts = []
        if candidates and candidates[0].content is not None:
            parts = candidates[0].content.parts or []
        for part in parts:
            if part.inline_data is not None and part.inline_data.data:
                return part.inline_data
        raise RuntimeError(
            "Gemini TTS returned no audio data "
            f"(finish_reason={cls._finish_reason(response)!r})"
        )

    @classmethod
    def _sample_rate_from_mime_type(cls, mime_type):
        """Extract the ``rate=`` parameter from a MIME type string.

        NOTE(review): assumes a form like ``audio/L16;codec=pcm;rate=24000``;
        falls back to the previously hard-coded 24000 Hz when the MIME type is
        missing or carries no numeric rate.
        """
        if mime_type:
            for param in mime_type.split(";")[1:]:
                key, _, value = param.strip().partition("=")
                value = value.strip()
                if key.strip().lower() == "rate" and value.isdigit():
                    return int(value)
        return cls._DEFAULT_TTS_SAMPLE_RATE
|