language-learning-app/api/app/outbound/anthropic/anthropic_client.py

135 lines
4.1 KiB
Python
Raw Normal View History

import asyncio
from random import random
from typing import Any, Callable, Coroutine
import anthropic
from app.domain.ai_prompts.summarise_article_ai_prompt import (
summarise_article_system_prompt,
)
from app.domain.models.gen_ai import GenAiChatMessage
# Retry policy used by AnthropicClient.retry: at most _MAX_RETRIES extra
# attempts after the first call. The computed backoff starts at _BASE_DELAY
# seconds, doubles per attempt and is capped at _MAX_DELAY seconds.
_MAX_RETRIES: int = 4
_BASE_DELAY: float = 1.0
_MAX_DELAY: float = 60.0

# SDK exceptions treated as transient (rate limiting, server-side errors,
# timeouts and connection failures); any other exception is raised at once.
_ANTHROPIC_RETRYABLE: tuple[type[Exception], ...] = (
    anthropic.RateLimitError,
    anthropic.InternalServerError,
    anthropic.APITimeoutError,
    anthropic.APIConnectionError,
)
class AnthropicClient:
    """Async-facing wrapper around the synchronous Anthropic Messages client.

    Every request runs the blocking SDK call in a worker thread via
    ``asyncio.to_thread`` so the event loop is not blocked while waiting.
    """

    # Model used when a caller does not request a specific one.
    _DEFAULT_MODEL = "claude-sonnet-4-6"

    def __init__(self, api_key: str):
        self._client = anthropic.Anthropic(api_key=api_key)

    @classmethod
    def new(cls, api_key: str) -> "AnthropicClient":
        """Alternate constructor; equivalent to ``AnthropicClient(api_key)``."""
        return cls(api_key)

    @classmethod
    async def retry(
        cls,
        callable_function: Callable[..., Coroutine[Any, Any, Any]],
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Await ``callable_function(*args, **kwargs)``, retrying transient errors.

        Exceptions listed in ``_ANTHROPIC_RETRYABLE`` trigger up to
        ``_MAX_RETRIES`` further attempts, separated by the delay from
        ``_retry_delay``. Any other exception propagates immediately.

        Returns:
            Whatever the awaited coroutine returns on its first success.

        Raises:
            The last retryable exception once all attempts are exhausted.
        """
        for attempt in range(_MAX_RETRIES + 1):
            try:
                return await callable_function(*args, **kwargs)
            except _ANTHROPIC_RETRYABLE as exception:
                if attempt == _MAX_RETRIES:
                    raise
                await asyncio.sleep(cls._retry_delay(exception, attempt))

    @classmethod
    def _retry_delay(cls, exception: Exception, attempt: int) -> float:
        """Return the number of seconds to wait after failed attempt ``attempt``.

        A valid ``retry-after`` header on a rate-limit error is honoured as-is.
        It is not jittered, because shrinking it would retry before the server
        allows. Otherwise, exponential backoff capped at ``_MAX_DELAY`` is used,
        with +/-20% jitter so that concurrent callers do not retry in lockstep.
        """
        if isinstance(exception, anthropic.RateLimitError):
            # httpx responses expose their header mapping as ``headers``.
            server_delay = cls._parse_retry_after(
                exception.response.headers.get("retry-after")
            )
            if server_delay is not None:
                return server_delay
        backoff = min(_BASE_DELAY * (2**attempt), _MAX_DELAY)
        # ``random`` is the function imported from the random module; the
        # jitter factor is uniform in [0.8, 1.2).
        return backoff * (0.8 + random() * 0.4)

    @staticmethod
    def _parse_retry_after(raw: str | None) -> float | None:
        """Parse a delta-seconds ``retry-after`` value.

        Returns None when the header is absent, is not a number (e.g. the
        HTTP-date form), or is negative, NaN or infinite. The caller then
        falls back to computed backoff.
        """
        if raw is None:
            return None
        try:
            seconds = float(raw)
        except ValueError:
            return None
        # Chained comparison rejects negatives, NaN and +inf in one test.
        if not 0.0 <= seconds < float("inf"):
            return None
        return seconds

    @staticmethod
    def _response_text(message: Any) -> str:
        """Concatenate the text of every text block in a Messages API response.

        Non-text blocks are skipped. An empty ``content`` list yields "" rather
        than an ``IndexError``.
        """
        return "".join(
            block.text for block in message.content if block.type == "text"
        )

    def _create_prompt_summarise_text(
        self,
        source_material: str,
    ) -> str:
        """Build the user-turn prompt that carries the material to summarise."""
        return f"Source material follows: \n\n{source_material}"

    def _messages_to_anthropic_messages(
        self, messages: list[GenAiChatMessage]
    ) -> list[dict]:
        """Convert domain chat messages to Anthropic ``{"role", "content"}`` dicts."""
        return [
            {"role": message.actor, "content": message.content}
            for message in messages
        ]

    async def complete(
        self,
        system_prompt: str,
        messages: list[GenAiChatMessage],
        model: str = _DEFAULT_MODEL,
        max_tokens: int = 2048,
    ) -> tuple[str, dict]:
        """Generic text completion.

        Returns (response_text, usage_dict), where usage_dict contains the
        provider, the model name, and token counts for cost tracking.
        """

        def _call() -> tuple[str, dict]:
            message = self._client.messages.create(
                model=model,
                max_tokens=max_tokens,
                system=system_prompt,
                messages=self._messages_to_anthropic_messages(messages),
            )
            usage = {
                "provider": "anthropic",
                "model": model,
                "input_tokens": message.usage.input_tokens,
                "output_tokens": message.usage.output_tokens,
            }
            return self._response_text(message), usage

        return await asyncio.to_thread(_call)

    async def create_summary_article(
        self,
        content_to_summarise: str,
        complexity_level: str,
        to_language: str,
        length_preference: str = "200-400 words",
    ) -> str:
        """Generate the text and title for a summary article using Anthropic.

        The system prompt is built from ``to_language``, ``complexity_level``
        and ``length_preference``. The source material is sent as a single
        user message. Returns the model's text output.
        """

        def _call() -> str:
            message = self._client.messages.create(
                model=self._DEFAULT_MODEL,
                max_tokens=1024,
                system=summarise_article_system_prompt(
                    to_language=to_language,
                    complexity_level=complexity_level,
                    length_preference=length_preference,
                ),
                messages=[
                    {
                        "role": "user",
                        "content": self._create_prompt_summarise_text(
                            content_to_summarise
                        ),
                    }
                ],
            )
            return self._response_text(message)

        return await asyncio.to_thread(_call)