"""AI model interface and provider abstraction with retry support.
All provider calls are wrapped with an exponential backoff retry to cope
with transient failures. The retry behaviour can be configured via
``retries`` parameters on the public methods.
The :class:`AIModelInterface` constructor accepts optional API keys for
multiple providers. Notably, the ``openai_api_key`` and
``anthropic_api_key`` parameters (or the ``OPENAI_API_KEY`` and
``ANTHROPIC_API_KEY`` environment variables) enable OpenAI and
Anthropic models respectively.
Notes
-----
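This module requires the ``openai`` (a release that provides the Responses
API), ``requests``, ``PyYAML`` and ``huggingface_hub`` packages. The
``google-genai``, ``anthropic``, ``transformers`` and ``diffusers`` packages
are optional and only needed by the providers that use them.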
"""
from __future__ import annotations
import asyncio
import base64
import io
import json
import os
import time
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator
from pathlib import Path
from typing import Any
from openai import AsyncOpenAI, OpenAI
try:
from google import genai
except ImportError: # pragma: no cover - optional dependency
genai = None
try:
import anthropic
except ImportError: # pragma: no cover - optional dependency
anthropic = None
import requests
import yaml
from huggingface_hub import InferenceClient
try:
from transformers import pipeline as hf_local_pipeline
except ImportError as exc: # pragma: no cover - optional dependency
if "sklearn" in repr(exc).lower():
raise ImportError(
"transformers import failed because scikit-learn is missing. "
"Install scikit-learn with `pip install scikit-learn`."
) from exc
hf_local_pipeline = None
try:
from diffusers import StableDiffusionPipeline
except ImportError: # pragma: no cover - optional dependency
StableDiffusionPipeline = None
[docs]
# Prompt accepted by the provider ``generate_text`` methods: either a plain
# string or a list of chat-style message dicts carrying "role" and "content".
PromptType = str | list[dict[str, str]]
def _parse_model_mapping(env_value: str | None) -> dict[str, str]:
"""Parse ``alias:model_id`` pairs from an environment variable."""
mapping: dict[str, str] = {}
if not env_value:
return mapping
for raw_item in env_value.split(","):
alias, _, model_id = raw_item.partition(":")
alias = alias.strip()
model_id = model_id.strip()
if not alias:
continue
mapping[alias] = model_id or alias
return mapping
def _load_yaml_environment(path: str | Path) -> None:
"""Load environment variables from a YAML configuration file if available."""
config_path = Path(path)
if not config_path.exists():
return
try:
parsed = yaml.safe_load(config_path.read_text(encoding="utf-8")) or {}
except Exception as exc: # pragma: no cover - YAML parse errors are rare
raise RuntimeError(
f"Failed to parse configuration file {config_path}: {exc}"
) from exc
if not isinstance(parsed, dict):
return
env_values = parsed.get("environment", parsed)
if not isinstance(env_values, dict):
return
for key, value in env_values.items():
if value is None:
continue
key_str = str(key).strip()
if not key_str:
continue
# Force override any existing environment variables
os.environ[key_str] = str(value).strip()
[docs]
def load_env_file(path: str | Path = Path(".env")) -> None:
    """Load configuration-managed credentials before falling back to ``.env``.

    ``config/coord2region-config.yaml`` (relative to the working directory)
    is loaded first, if it exists, and its values overwrite existing
    environment variables. The ``.env`` file is then read and its entries
    only fill in variables that are not already set.

    Supported ``.env`` syntax:

    * ``KEY=value`` lines (only the first ``=`` splits key from value);
    * an optional shell-style ``export`` prefix (``export KEY=value``);
    * values wrapped in one pair of matching single or double quotes, which
      are removed (``KEY="value"`` -> ``value``);
    * blank lines and ``#`` comment lines, which are ignored.

    Lines without ``=`` or with an empty key are skipped.

    Parameters
    ----------
    path : str or Path, optional
        Path to the environment file to load. Defaults to ``.env``.
    """
    _load_yaml_environment(Path("config") / "coord2region-config.yaml")
    env_path = Path(path)
    if not env_path.exists():
        return
    for line in env_path.read_text(encoding="utf-8").splitlines():
        stripped = line.strip()
        if not stripped or stripped.startswith("#") or "=" not in stripped:
            continue
        key, value = stripped.split("=", 1)
        key = key.strip()
        # Tolerate shell-style ``export KEY=value`` lines.
        if key.startswith("export "):
            key = key[len("export "):].strip()
        if not key:
            # os.environ rejects empty variable names with ValueError.
            continue
        value = value.strip()
        # Drop one pair of matching surrounding quotes so a quoted API key is
        # not stored with its quotes attached.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        os.environ.setdefault(key, value)
[docs]
def huggingface_credentials_present() -> bool:
    """Report whether a Hugging Face token is configured in the environment.

    Returns
    -------
    bool
        ``True`` when ``HUGGINGFACE_API_KEY`` or ``HUGGINGFACEHUB_API_TOKEN``
        is set to a non-empty value, ``False`` otherwise.
    """
    token_vars = ("HUGGINGFACE_API_KEY", "HUGGINGFACEHUB_API_TOKEN")
    return any(os.environ.get(var) for var in token_vars)
[docs]
def pick_first_supported_model(
    ai: AIModelInterface, candidates: Iterable[str]
) -> str | None:
    """Select the earliest candidate model that ``ai`` reports as supported.

    Candidates are checked lazily in iteration order. A candidate whose
    ``ai.supports`` call raises is treated as unsupported rather than
    aborting the search.

    Parameters
    ----------
    ai : AIModelInterface
        Interface used to query model availability.
    candidates : iterable of str
        Candidate model names, most preferred first.

    Returns
    -------
    str or None
        The first supported model name, or ``None`` if none qualifies.
    """

    def _is_supported(name: str) -> bool:
        try:
            return bool(ai.supports(name))
        except Exception:
            return False

    return next((name for name in candidates if _is_supported(name)), None)
[docs]
def build_generation_summary(model: str, response: str, provider: str) -> str:
    """Describe a text generation result as pretty-printed JSON.

    Parameters
    ----------
    model : str
        Model name used for the generation.
    response : str
        Raw text produced by the model.
    provider : str
        Provider label for the selected model.

    Returns
    -------
    str
        JSON object (indented by two spaces) with the keys ``provider``,
        ``model``, ``has_reasoning`` (whether a ``<think>`` tag appears,
        case-insensitively) and ``tokens`` (a whitespace-separated word
        count, used as a rough token estimate).
    """
    reasoning_marker_found = "<think>" in response.lower()
    word_count = len(response.split())
    payload = dict(
        provider=provider,
        model=model,
        has_reasoning=reasoning_marker_found,
        tokens=word_count,
    )
    return json.dumps(payload, indent=2)
def _retry_sync(func, retries: int = 3, base_delay: float = 0.1) -> Any:
"""Retry ``func`` with exponential backoff."""
delay = base_delay
for attempt in range(retries):
try:
return func()
except Exception:
if attempt == retries - 1:
raise
time.sleep(delay)
delay *= 2
async def _retry_async(func, retries: int = 3, base_delay: float = 0.1) -> Any:
"""Asynchronously retry ``func`` with exponential backoff."""
delay = base_delay
for attempt in range(retries):
try:
return await func()
except Exception:
if attempt == retries - 1:
raise
await asyncio.sleep(delay)
delay *= 2
def _retry_stream(func, retries: int = 3, base_delay: float = 0.1) -> Iterator[str]:
"""Retry a streaming function yielding from successive attempts."""
def generator() -> Iterator[str]:
delay = base_delay
for attempt in range(retries):
try:
yield from func()
return
except Exception:
if attempt == retries - 1:
raise
time.sleep(delay)
delay *= 2
return generator()
[docs]
class ModelProvider(ABC):
    """Base class for all model providers.

    See the ``README`` section *Adding a Custom LLM Provider* for
    guidance on implementing subclasses.

    Parameters
    ----------
    models : dict
        Mapping of friendly model names to provider-specific identifiers.
    """

    #: Whether the provider natively supports batching multiple prompts in a
    #: single API call. Subclasses can override this to ``True`` when their
    #: backend exposes such functionality.
    supports_batching: bool = False

    def __init__(self, models: dict[str, str]):
        # ``supports`` checks membership here and subclasses translate
        # friendly names with ``self.models[name]``.
        self.models = models

    def supports(self, model: str) -> bool:
        """Return ``True`` if the provider exposes the requested model."""
        return model in self.models

    @abstractmethod
    def generate_text(self, model: str, prompt: PromptType, max_tokens: int) -> str:
        """Generate text from the given model.

        Parameters
        ----------
        model : str
            Friendly model name (a key of ``self.models``).
        prompt : str or list of dict
            Plain prompt string or a list of chat-style messages.
        max_tokens : int
            Upper bound on the number of tokens to generate.
        """

    async def generate_text_async(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:
        """Asynchronously generate text.

        Providers that expose native async APIs should override this method.
        The default implementation simply delegates to :meth:`generate_text`
        using ``asyncio.to_thread`` to avoid blocking the event loop.
        """
        return await asyncio.to_thread(self.generate_text, model, prompt, max_tokens)

    def stream_generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> Iterator[str]:
        """Yield generated text chunks.

        Providers that support server-side streaming should override this
        method. The base implementation yields the full response in a single
        chunk.
        """
        yield self.generate_text(model, prompt, max_tokens)
class GeminiProvider(ModelProvider):
    """Provider for Google Gemini models via the ``google-genai`` SDK.

    Parameters
    ----------
    api_key : str
        API key used to authenticate with Google GenAI.
    """

    def __init__(self, api_key: str):  # pragma: no cover - network client setup
        if genai is None:
            raise ImportError(
                "Google Gemini support requires the google-genai package. "
                "Install it via `pip install google-genai`."
            )
        models = {
            "gemini-1.0-pro": "gemini-1.0-pro",
            "gemini-1.5-pro": "gemini-1.5-pro",
            "gemini-2.0-flash": "gemini-2.0-flash",
        }
        super().__init__(models)
        self.client = genai.Client(api_key=api_key)

    def _build_request(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> dict[str, Any]:
        """Translate a provider-agnostic prompt into ``generate_content`` kwargs.

        For message lists, user contents are joined with spaces to form the
        request contents and system contents are forwarded as
        ``system_instruction`` (they were previously dropped). ``max_tokens``
        is sent as ``max_output_tokens`` (it was previously ignored).
        """
        config: dict[str, Any] = {"max_output_tokens": max_tokens}
        if isinstance(prompt, list):
            system_text = "\n".join(
                msg["content"] for msg in prompt if msg.get("role") == "system"
            )
            if system_text:
                config["system_instruction"] = system_text
            prompt = " ".join(
                msg["content"] for msg in prompt if msg.get("role") == "user"
            )
        return {
            "model": self.models.get(model, model),
            "contents": [prompt],
            "config": config,
        }

    def generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:  # pragma: no cover - thin wrapper
        response = self.client.models.generate_content(
            **self._build_request(model, prompt, max_tokens)
        )
        return response.text or ""

    async def generate_text_async(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:  # pragma: no cover - thin wrapper
        # google-genai exposes its native async surface under ``client.aio``;
        # fall back to the threaded default if it is unavailable.
        aio = getattr(self.client, "aio", None)
        if aio is None:
            return await super().generate_text_async(model, prompt, max_tokens)
        response = await aio.models.generate_content(
            **self._build_request(model, prompt, max_tokens)
        )
        return response.text or ""

    def stream_generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> Iterator[str]:  # pragma: no cover - thin wrapper
        # ``generate_content`` takes no ``stream`` flag in google-genai;
        # streaming is the separate ``generate_content_stream`` method.
        stream = self.client.models.generate_content_stream(
            **self._build_request(model, prompt, max_tokens)
        )
        for chunk in stream:
            text = getattr(chunk, "text", None)
            if text:
                yield text
class OpenRouterProvider(ModelProvider):
    """Provider for models available via OpenRouter (e.g., DeepSeek).

    Requests go through OpenRouter's OpenAI-compatible Responses endpoint.
    """

    _BASE_URL = "https://openrouter.ai/api/v1"

    def __init__(self, api_key: str):  # pragma: no cover - network client setup
        models = {
            "deepseek-r1": "deepseek/deepseek-r1:free",
            "deepseek-chat-v3-0324": "deepseek/deepseek-chat-v3-0324:free",
        }
        super().__init__(models)
        self.client = OpenAI(api_key=api_key, base_url=self._BASE_URL)
        self.async_client = AsyncOpenAI(api_key=api_key, base_url=self._BASE_URL)

    @staticmethod
    def _as_messages(prompt: PromptType) -> PromptType:
        """Wrap a bare string prompt as a single user message."""
        if isinstance(prompt, str):
            return [{"role": "user", "content": prompt}]
        return prompt

    def generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:  # pragma: no cover
        response = self.client.responses.create(
            model=self.models[model],
            input=self._as_messages(prompt),
            max_output_tokens=max_tokens,
        )
        # ``output_text`` joins every text part of the response. Indexing
        # ``output[0]`` fails when a reasoning model (such as DeepSeek-R1)
        # emits a reasoning item before the message item.
        return response.output_text

    async def generate_text_async(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:  # pragma: no cover - thin wrapper
        response = await self.async_client.responses.create(
            model=self.models[model],
            input=self._as_messages(prompt),
            max_output_tokens=max_tokens,
        )
        return response.output_text

    def stream_generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> Iterator[str]:  # pragma: no cover - thin wrapper
        with self.client.responses.stream(
            model=self.models[model],
            input=self._as_messages(prompt),
            max_output_tokens=max_tokens,
        ) as stream:
            for event in stream:
                if event.type == "response.output_text.delta":
                    yield event.delta
class GroqProvider(ModelProvider):
    """Provider for Groq-hosted OpenAI-compatible models.

    Requests go through Groq's OpenAI-compatible Responses endpoint.
    """

    def __init__(self, api_key: str):  # pragma: no cover - network client setup
        models = {
            "groq-llama-3.1-70b": "llama-3.1-70b-versatile",
            "groq-llama-3.1-8b": "llama-3.1-8b-instant",
        }
        super().__init__(models)
        client_kwargs = {
            "api_key": api_key,
            "base_url": "https://api.groq.com/openai/v1",
        }
        self.client = OpenAI(**client_kwargs)
        self.async_client = AsyncOpenAI(**client_kwargs)

    @staticmethod
    def _as_messages(prompt: PromptType) -> PromptType:
        """Wrap a bare string prompt as a single user message."""
        if isinstance(prompt, str):
            return [{"role": "user", "content": prompt}]
        return prompt

    def generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:  # pragma: no cover - thin wrapper
        response = self.client.responses.create(
            model=self.models[model],
            input=self._as_messages(prompt),
            max_output_tokens=max_tokens,
        )
        # ``output_text`` joins every text part; ``output[0]`` is not
        # guaranteed to be the message item (e.g. reasoning items first).
        return response.output_text

    async def generate_text_async(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:  # pragma: no cover - thin wrapper
        response = await self.async_client.responses.create(
            model=self.models[model],
            input=self._as_messages(prompt),
            max_output_tokens=max_tokens,
        )
        return response.output_text

    def stream_generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> Iterator[str]:  # pragma: no cover - thin wrapper
        with self.client.responses.stream(
            model=self.models[model],
            input=self._as_messages(prompt),
            max_output_tokens=max_tokens,
        ) as stream:
            for event in stream:
                if event.type == "response.output_text.delta":
                    yield event.delta
class DeepSeekProvider(ModelProvider):
    """Provider for DeepSeek's native API.

    DeepSeek's documented OpenAI-compatible surface is *Chat Completions*;
    the Responses API used by the OpenAI provider is not part of it, so all
    requests here go through ``chat.completions``.
    """

    def __init__(self, api_key: str):  # pragma: no cover - network client setup
        models = {
            "deepseek-reasoner": "deepseek-reasoner",
            "deepseek-chat": "deepseek-chat",
        }
        super().__init__(models)
        client_kwargs = {"api_key": api_key, "base_url": "https://api.deepseek.com/v1"}
        self.client = OpenAI(**client_kwargs)
        self.async_client = AsyncOpenAI(**client_kwargs)

    def _chat_request(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> dict[str, Any]:
        """Build Chat Completions kwargs, wrapping a string prompt as a user turn."""
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        else:
            messages = prompt
        return {
            "model": self.models[model],
            "messages": messages,
            "max_tokens": max_tokens,
        }

    @staticmethod
    def _first_choice_text(completion: Any) -> str:
        """Return the first choice's message text, or ``""`` if there is none."""
        if not completion.choices:
            return ""
        return completion.choices[0].message.content or ""

    def generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:  # pragma: no cover - thin wrapper
        completion = self.client.chat.completions.create(
            **self._chat_request(model, prompt, max_tokens)
        )
        return self._first_choice_text(completion)

    async def generate_text_async(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:  # pragma: no cover - thin wrapper
        completion = await self.async_client.chat.completions.create(
            **self._chat_request(model, prompt, max_tokens)
        )
        return self._first_choice_text(completion)

    def stream_generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> Iterator[str]:  # pragma: no cover - thin wrapper
        # The context manager closes the HTTP stream if the consumer stops early.
        with self.client.chat.completions.create(
            **self._chat_request(model, prompt, max_tokens), stream=True
        ) as stream:
            for chunk in stream:
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
class TogetherProvider(ModelProvider):
    """Provider for Together AI models (DeepSeek, Llama, Mixtral, etc.).

    Together's OpenAI-compatible API exposes *Chat Completions*, so requests
    go through ``chat.completions`` rather than the Responses API.
    """

    def __init__(self, api_key: str):  # pragma: no cover - network client setup
        models = {
            "together-deepseek-r1": "deepseek-ai/DeepSeek-R1",
            "together-llama-3.1-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
        }
        super().__init__(models)
        client_kwargs = {"api_key": api_key, "base_url": "https://api.together.ai/v1"}
        self.client = OpenAI(**client_kwargs)
        self.async_client = AsyncOpenAI(**client_kwargs)

    def _chat_request(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> dict[str, Any]:
        """Build Chat Completions kwargs, wrapping a string prompt as a user turn."""
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        else:
            messages = prompt
        return {
            "model": self.models[model],
            "messages": messages,
            "max_tokens": max_tokens,
        }

    @staticmethod
    def _first_choice_text(completion: Any) -> str:
        """Return the first choice's message text, or ``""`` if there is none."""
        if not completion.choices:
            return ""
        return completion.choices[0].message.content or ""

    def generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:  # pragma: no cover - thin wrapper
        completion = self.client.chat.completions.create(
            **self._chat_request(model, prompt, max_tokens)
        )
        return self._first_choice_text(completion)

    async def generate_text_async(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:  # pragma: no cover - thin wrapper
        completion = await self.async_client.chat.completions.create(
            **self._chat_request(model, prompt, max_tokens)
        )
        return self._first_choice_text(completion)

    def stream_generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> Iterator[str]:  # pragma: no cover - thin wrapper
        # The context manager closes the HTTP stream if the consumer stops early.
        with self.client.chat.completions.create(
            **self._chat_request(model, prompt, max_tokens), stream=True
        ) as stream:
            for chunk in stream:
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
class LocalOpenAIProvider(ModelProvider):
    """Provider for self-hosted OpenAI-compatible gateways (vLLM, TGI, Ollama).

    Requests use *Chat Completions*, the OpenAI-compatible endpoint these
    gateways have in common, rather than the newer Responses API.

    Parameters
    ----------
    base_url : str
        Base URL of the gateway (a trailing ``/`` is removed).
    api_key : str, optional
        Gateway API key; ``"EMPTY"`` is sent when none is given.
    models : dict, optional
        Friendly-name to served-model-id mapping. When missing or empty,
        ``{default_model: default_model}`` is used.
    default_model : str, optional
        Model name used when ``models`` is not provided.
    """

    def __init__(
        self,
        *,
        base_url: str,
        api_key: str | None = None,
        models: dict[str, str] | None = None,
        default_model: str = "local-reasoning",
    ):  # pragma: no cover - optional local deployment wrapper
        if not models:
            models = {default_model: default_model}
        super().__init__(models)
        client_kwargs = {
            # The OpenAI client requires a key; local gateways usually ignore it.
            "api_key": api_key or "EMPTY",
            "base_url": base_url.rstrip("/"),
        }
        self.client = OpenAI(**client_kwargs)
        self.async_client = AsyncOpenAI(**client_kwargs)

    def _chat_request(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> dict[str, Any]:
        """Build Chat Completions kwargs, wrapping a string prompt as a user turn."""
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        else:
            messages = prompt
        return {
            "model": self.models[model],
            "messages": messages,
            "max_tokens": max_tokens,
        }

    @staticmethod
    def _first_choice_text(completion: Any) -> str:
        """Return the first choice's message text, or ``""`` if there is none."""
        if not completion.choices:
            return ""
        return completion.choices[0].message.content or ""

    def generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:  # pragma: no cover - thin wrapper
        completion = self.client.chat.completions.create(
            **self._chat_request(model, prompt, max_tokens)
        )
        return self._first_choice_text(completion)

    async def generate_text_async(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:  # pragma: no cover - thin wrapper
        completion = await self.async_client.chat.completions.create(
            **self._chat_request(model, prompt, max_tokens)
        )
        return self._first_choice_text(completion)

    def stream_generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> Iterator[str]:  # pragma: no cover - thin wrapper
        # The context manager closes the HTTP stream if the consumer stops early.
        with self.client.chat.completions.create(
            **self._chat_request(model, prompt, max_tokens), stream=True
        ) as stream:
            for chunk in stream:
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
class OpenAIProvider(ModelProvider):
    """Provider for OpenAI's GPT and image models.

    Text uses the Responses API. Images use the Responses API's
    ``image_generation`` tool for ``gpt-image-1`` and the Images API for
    ``dall-e-2`` / ``dall-e-3``.

    Parameters
    ----------
    api_key : str
        OpenAI API key.
    project : str, optional
        OpenAI project id forwarded to the client when given.
    """

    def __init__(
        self, api_key: str, project: str | None = None
    ):  # pragma: no cover - network client setup
        models = {
            "gpt-4o": "gpt-4o",
            "gpt-4o-mini": "gpt-4o-mini",
            "gpt-4": "gpt-4-turbo-2024-04-09",
            "gpt-image-1": "gpt-4o",  # Uses gpt-4o with image generation tool
            "dall-e-3": "dall-e-3",
            "dall-e-2": "dall-e-2",
        }
        super().__init__(models)
        self._image_models = {"gpt-image-1", "dall-e-3", "dall-e-2"}
        client_kwargs: dict[str, Any] = {"api_key": api_key}
        if project:
            client_kwargs["project"] = project
        self.client = OpenAI(**client_kwargs)
        self.async_client = AsyncOpenAI(**client_kwargs)

    @staticmethod
    def _as_messages(prompt: PromptType) -> PromptType:
        """Wrap a bare string prompt as a single user message."""
        if isinstance(prompt, str):
            return [{"role": "user", "content": prompt}]
        return prompt

    def generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:  # pragma: no cover
        response = self.client.responses.create(
            model=self.models[model],
            input=self._as_messages(prompt),
            max_output_tokens=max_tokens,
        )
        # ``output_text`` joins every text part; ``output[0]`` is not
        # guaranteed to be the message item (reasoning items can come first).
        return response.output_text

    async def generate_text_async(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:  # pragma: no cover - thin wrapper
        response = await self.async_client.responses.create(
            model=self.models[model],
            input=self._as_messages(prompt),
            max_output_tokens=max_tokens,
        )
        return response.output_text

    def stream_generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> Iterator[str]:  # pragma: no cover - thin wrapper
        with self.client.responses.stream(
            model=self.models[model],
            input=self._as_messages(prompt),
            max_output_tokens=max_tokens,
        ) as stream:
            for event in stream:
                if event.type == "response.output_text.delta":
                    yield event.delta

    def generate_image(
        self, model: str, prompt: str, **kwargs: Any
    ) -> bytes:  # pragma: no cover - network image wrapper
        """Generate an image and return its raw bytes.

        Recognised keyword arguments:

        * ``gpt-image-1``: ``quality``, ``size`` and ``background``.
        * ``dall-e-2`` / ``dall-e-3``: ``size`` and ``n``, plus ``quality``
          for ``dall-e-3`` only. Only the first image is returned.

        Other keyword arguments are ignored.

        Raises
        ------
        ValueError
            If ``model`` is not an image model or no image was produced.
        NotImplementedError
            If the installed OpenAI SDK lacks the Responses API required by
            ``gpt-image-1``.
        """
        if model not in self._image_models:
            raise ValueError(f"Model '{model}' is not an image model")
        if model == "gpt-image-1":
            return self._generate_image_via_responses(model, prompt, kwargs)
        return self._generate_image_via_images_api(model, prompt, kwargs)

    def _generate_image_via_responses(
        self, model: str, prompt: str, options: dict[str, Any]
    ) -> bytes:  # pragma: no cover - network image wrapper
        """Run the Responses API ``image_generation`` tool and decode its output."""
        tool_params: dict[str, Any] = {"type": "image_generation"}
        for option in ("quality", "size", "background"):
            if option in options:
                tool_params[option] = options[option]
        # Only the attribute lookup signals an old SDK; AttributeErrors raised
        # while parsing a real response must not be reported as such.
        try:
            responses_api = self.client.responses
        except AttributeError:
            raise NotImplementedError(
                "gpt-image-1 requires the Responses API which is not available"
                " in your OpenAI SDK version. "
                "Please update to the latest OpenAI SDK or use"
                " dall-e-2/dall-e-3 instead."
            ) from None
        response = responses_api.create(
            model=self.models[model],
            input=prompt,
            tools=[tool_params],
        )
        for output in response.output:
            if output.type == "image_generation_call" and output.result:
                # The tool returns the image as a base64-encoded string.
                return base64.b64decode(output.result)
        raise ValueError("No image generated in response")

    def _generate_image_via_images_api(
        self, model: str, prompt: str, options: dict[str, Any]
    ) -> bytes:  # pragma: no cover - network image wrapper
        """Call the Images API for a DALL-E model and decode the first image."""
        model_id = self.models[model]
        image_kwargs: dict[str, Any] = {}
        if "size" in options:
            image_kwargs["size"] = options["size"]
        if "quality" in options and model_id == "dall-e-3":
            image_kwargs["quality"] = options["quality"]
        if "n" in options:
            image_kwargs["n"] = options["n"]
        # Both DALL-E models can return base64 directly, which avoids a second
        # (previously un-timed) HTTP download of a hosted image URL.
        resp = self.client.images.generate(
            model=model_id,
            prompt=prompt,
            response_format="b64_json",
            **image_kwargs,
        )
        return base64.b64decode(resp.data[0].b64_json)
class AnthropicProvider(ModelProvider):
    """Provider for Anthropic's Claude models.

    Parameters
    ----------
    api_key : str
        API key used to authenticate with Anthropic.
    """

    def __init__(self, api_key: str):  # pragma: no cover - network client setup
        if anthropic is None:
            raise ImportError(
                "Anthropic support requires the anthropic package. "
                "Install it via `pip install anthropic`."
            )
        models = {
            "claude-3-haiku": "claude-3-haiku-20240307",
            "claude-3-opus": "claude-3-opus-20240229",
            "claude-image": "claude-3-opus-20240229",
        }
        super().__init__(models)
        self._image_models = {"claude-image"}
        self.client = anthropic.Anthropic(api_key=api_key)

    def generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:  # pragma: no cover
        request: dict[str, Any] = {
            "model": self.models[model],
            "max_tokens": max_tokens,
        }
        if isinstance(prompt, str):
            request["messages"] = [{"role": "user", "content": prompt}]
        else:
            # The Messages API rejects "system" entries inside ``messages``;
            # system instructions belong in the top-level ``system`` field.
            system_parts = [m["content"] for m in prompt if m.get("role") == "system"]
            request["messages"] = [m for m in prompt if m.get("role") != "system"]
            if system_parts:
                request["system"] = "\n\n".join(system_parts)
        response = self.client.messages.create(**request)
        # Join all text blocks; other block types carry no ``text``.
        return "".join(
            getattr(block, "text", "")
            for block in response.content
            if getattr(block, "type", None) == "text"
        )

    def generate_image(
        self, model: str, prompt: str, **kwargs: Any
    ) -> bytes:  # pragma: no cover - network image wrapper
        """Reject image generation, which Anthropic's API does not offer.

        Raises
        ------
        ValueError
            If ``model`` is not the configured image alias.
        NotImplementedError
            Always, for the image alias: the Anthropic SDK has no image
            generation endpoint, so the previous ``client.images`` call
            could only fail with an ``AttributeError``.
        """
        if model not in self._image_models:
            raise ValueError(f"Model '{model}' is not an image model")
        raise NotImplementedError(
            "Anthropic does not provide an image generation API; "
            "use an OpenAI or Hugging Face image model instead."
        )
class HuggingFaceProvider(ModelProvider):
    """Provider using the HuggingFace Inference Hub.

    Text generation is attempted in this order:

    1. The OpenAI-compatible Hugging Face router, only when the model is
       mapped to a router-backed inference provider.
    2. ``huggingface_hub.InferenceClient`` chat completions.
    3. The classic ``api-inference`` HTTP endpoint.

    Parameters
    ----------
    api_key : str
        Hugging Face access token.
    model_providers : dict, optional
        Mapping of friendly model name to inference provider name
        (for example ``"together"``).
    timeout : float, optional
        Request timeout in seconds for every client and HTTP fallback.
    """

    API_URL = "https://api-inference.huggingface.co/models/{model}"

    def __init__(
        self,
        api_key: str,
        *,
        model_providers: dict[str, str] | None = None,
        timeout: float = 60.0,
    ):  # pragma: no cover - network client setup
        models = {
            "distilgpt2": "distilgpt2",
            "deepseek-r1-distill-qwen-14b": (
                "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
            ),
            "deepseek-r1": "deepseek-ai/DeepSeek-R1",
            "gpt-oss-120b": "openai/gpt-oss-120b",
            "llama-3.3-70b-instruct": ("meta-llama/Llama-3.3-70B-Instruct"),
            "stabilityai/stable-diffusion-2": ("stabilityai/stable-diffusion-2"),
            "stabilityai/stable-diffusion-3.5-large": (
                "stabilityai/stable-diffusion-3.5-large"
            ),
            "stabilityai/stable-diffusion-xl-base-1.0": (
                "stabilityai/stable-diffusion-xl-base-1.0"
            ),
        }
        super().__init__(models)
        self.api_key = api_key
        self.model_providers = model_providers or {}
        self._timeout = timeout
        self._provider_clients: dict[str | None, InferenceClient] = {
            None: InferenceClient(token=api_key, timeout=timeout)
        }
        # Apply the same timeout as the other clients instead of the OpenAI
        # SDK default.
        self._router_client = OpenAI(
            api_key=api_key,
            base_url="https://router.huggingface.co/v1",
            timeout=timeout,
        )
        self._router_provider_names = {
            "together",
            "sambanova",
            "novita",
            "fireworks",
            "replicate",
            "groq",
            "cerebras",
            "featherless",
            "hyperbolic",
        }

    @staticmethod
    def _normalize_messages(prompt: PromptType) -> list[dict[str, str]]:
        """Wrap a bare string prompt as a single user message."""
        if isinstance(prompt, str):
            return [{"role": "user", "content": prompt}]
        return prompt

    @staticmethod
    def _extract_text(choice: Any) -> str:
        """Pull the text out of a completion choice (object or dict form).

        Uses ``choice.message.content`` when present, joining list-style
        content parts; otherwise falls back to ``choice.text``. Returns
        ``""`` when nothing usable is found.
        """
        message = getattr(choice, "message", None)
        if message is None and isinstance(choice, dict):
            message = choice.get("message")
        if message is None:
            text = getattr(choice, "text", None)
            if text is None and isinstance(choice, dict):
                text = choice.get("text")
            return text or ""
        content = getattr(message, "content", None)
        if content is None and isinstance(message, dict):
            content = message.get("content")
        if isinstance(content, list):
            parts: list[str] = []
            for item in content:
                if isinstance(item, dict):
                    parts.append(item.get("text", ""))
                else:
                    parts.append(str(item))
            return "".join(parts)
        if content is not None:
            return str(content)
        return ""

    def _get_client(self, provider: str | None) -> InferenceClient:
        """Return a cached ``InferenceClient`` for ``provider`` (``None`` = default)."""
        if provider not in self._provider_clients:
            self._provider_clients[provider] = InferenceClient(
                token=self.api_key, timeout=self._timeout, provider=provider
            )
        return self._provider_clients[provider]

    def _legacy_generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:
        """Generate text through the classic ``api-inference`` HTTP endpoint.

        Message lists are flattened to a single string: system contents
        first, then user contents, separated by a blank line.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        if isinstance(prompt, list):
            user_chunks = [p["content"] for p in prompt if p.get("role") == "user"]
            system_chunks = [p["content"] for p in prompt if p.get("role") == "system"]
            combined_parts: list[str] = []
            if system_chunks:
                combined_parts.append("\n".join(system_chunks))
            if user_chunks:
                combined_parts.append("\n".join(user_chunks))
            prompt_input: str = "\n\n".join(combined_parts)
        else:
            prompt_input = prompt
        data = {
            "inputs": prompt_input,
            "parameters": {
                "max_new_tokens": max_tokens,
                "return_full_text": False,
            },
            "options": {"wait_for_model": True},
        }
        url = self.API_URL.format(model=self.models[model])
        resp = requests.post(url, headers=headers, json=data, timeout=self._timeout)
        resp.raise_for_status()
        result = resp.json()
        if isinstance(result, list) and result:
            generated = result[0]
            if isinstance(generated, dict):
                return generated.get("generated_text", str(generated))
            return str(generated)
        if isinstance(result, dict):
            text = result.get("generated_text")
            if text:
                return text
        return str(result)

    def generate_text(
        self, model: str, prompt: PromptType, max_tokens: int
    ) -> str:  # pragma: no cover - network wrapper
        messages = self._normalize_messages(prompt)
        provider = self.model_providers.get(model)
        if provider and provider.lower() in self._router_provider_names:
            try:
                completion = self._router_client.chat.completions.create(
                    model=f"{self.models[model]}:{provider}",
                    messages=messages,
                    max_tokens=max_tokens,
                )
                if not completion.choices:
                    return ""
                # Same extraction as the InferenceClient path below, so list
                # content parts yield their text rather than their repr.
                return self._extract_text(completion.choices[0])
            except Exception:
                # Best effort: fall back to the inference client below.
                pass
        try:
            client = self._get_client(provider)
            completion = client.chat.completions.create(
                model=self.models[model],
                messages=messages,
                max_tokens=max_tokens,
            )
            if getattr(completion, "choices", None):
                return self._extract_text(completion.choices[0])
            return ""
        except Exception:
            return self._legacy_generate_text(model, prompt, max_tokens)

    def generate_image(
        self, model: str, prompt: str, **kwargs: Any
    ) -> bytes:  # pragma: no cover - network wrapper
        """Generate an image using the HuggingFace Inference API.

        Tries ``InferenceClient.text_to_image`` first (PIL images are encoded
        as PNG) and falls back to the classic binary endpoint when that call
        fails or returns an unexpected type. Extra keyword arguments are
        accepted for interface parity with the other image providers and are
        ignored.
        """
        provider = self.model_providers.get(model)
        try:
            client = self._get_client(provider)
            image = client.text_to_image(prompt, model=self.models[model])
            if hasattr(image, "save"):
                buf = io.BytesIO()
                image.save(buf, format="PNG")
                return buf.getvalue()
            if isinstance(image, bytes):
                return image
        except Exception:
            # Best effort: fall through to the classic binary endpoint.
            pass
        # Reached when the client raised *or* returned an unexpected type
        # (the latter previously made this method return ``None``).
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "image/png",
        }
        data = {"inputs": prompt}
        url = self.API_URL.format(model=self.models[model])
        resp = requests.post(url, headers=headers, json=data, timeout=self._timeout)
        resp.raise_for_status()
        return resp.content
class HuggingFaceLocalProvider(ModelProvider):
"""Provider that runs HuggingFace models locally.
Uses ``transformers`` for text and ``diffusers`` for images. Both text and
image generation can be configured independently by specifying
``text_model`` and/or ``image_model`` when registering the provider. The
heavy model weights are loaded on first use.
Parameters
----------
text_model : str, optional
Local text generation model name.
image_model : str, optional
Local image generation model name.
"""
def __init__(
self,
*,
text_model: str | None = None,
image_model: str | None = None,
): # pragma: no cover - heavy local dependency
models: dict[str, str] = {}
if text_model:
models[text_model] = text_model
if image_model:
models[image_model] = image_model
if not models:
raise ValueError("At least one of text_model or image_model must be set")
if hf_local_pipeline is None:
raise ImportError(
"Local HuggingFace inference requires the transformers package. "
"Install it via `pip install transformers`."
)
if image_model and StableDiffusionPipeline is None:
raise ImportError(
"Local HuggingFace image generation requires the diffusers package. "
"Install it via `pip install diffusers`."
)
super().__init__(models)
self._text_model = text_model
self._image_model = image_model
self._text_generator = None
self._image_pipeline = None
def _ensure_text_pipeline(self) -> None:
if self._text_generator is None:
self._text_generator = hf_local_pipeline(
"text-generation", model=self._text_model
)
def _ensure_image_pipeline(self) -> None:
if self._image_pipeline is None:
self._image_pipeline = StableDiffusionPipeline.from_pretrained(
self._image_model
)
def generate_text(
self, model: str, prompt: PromptType, max_tokens: int
) -> str: # pragma: no cover - optional heavy dependency
if model != self._text_model or not self._text_model:
raise ValueError(f"Model '{model}' is not configured for text generation")
if isinstance(prompt, list):
prompt = " ".join(
msg["content"] for msg in prompt if msg.get("role") == "user"
)
self._ensure_text_pipeline()
result = self._text_generator(prompt, max_new_tokens=max_tokens)
return result[0]["generated_text"]
def generate_image(self, model: str, prompt: str) -> bytes:  # pragma: no cover
    """Render ``prompt`` with the local diffusion model and return PNG bytes.

    Raises
    ------
    ValueError
        If no image model is configured or ``model`` is not that model.
    """
    configured_ok = bool(self._image_model) and model == self._image_model
    if not configured_ok:
        raise ValueError(f"Model '{model}' is not configured for image generation")
    self._ensure_image_pipeline()
    output = self._image_pipeline(prompt)
    first_image = output.images[0]
    with io.BytesIO() as png_buffer:
        first_image.save(png_buffer, format="PNG")
        return png_buffer.getvalue()
class AIModelInterface:
    """Register and dispatch to different AI model providers.

    Each model name exposed by a provider's ``models`` attribute is mapped to
    that provider instance. Registering another provider that exposes the
    same model name replaces the earlier mapping.
    """

    # Provider names accepted by :meth:`register_provider` (and by the
    # ``enabled_providers`` argument / ``AI_MODEL_PROVIDERS`` variable),
    # mapped to their classes. ``__init__`` auto-registers in this order.
    _PROVIDER_CLASSES = {
        "gemini": GeminiProvider,
        "openrouter": OpenRouterProvider,
        "openai": OpenAIProvider,
        "anthropic": AnthropicProvider,
        "huggingface": HuggingFaceProvider,
        "huggingface_local": HuggingFaceLocalProvider,
        "groq": GroqProvider,
        "deepseek": DeepSeekProvider,
        "together": TogetherProvider,
        "local_openai": LocalOpenAIProvider,
    }

    def __init__(
        self,
        gemini_api_key: str | None = None,
        openrouter_api_key: str | None = None,
        openai_api_key: str | None = None,
        openai_project: str | None = None,
        anthropic_api_key: str | None = None,
        huggingface_api_key: str | None = None,
        groq_api_key: str | None = None,
        deepseek_api_key: str | None = None,
        together_api_key: str | None = None,
        local_openai_base_url: str | None = None,
        local_openai_api_key: str | None = None,
        local_openai_models: dict[str, str] | None = None,
        enabled_providers: list[str] | None = None,
    ):
        """Initialise the interface and register available providers.

        The interface accepts optional API keys for different large language
        model providers. The ``openai_api_key`` and ``anthropic_api_key``
        parameters, or their respective ``OPENAI_API_KEY`` and
        ``ANTHROPIC_API_KEY`` environment variables, enable OpenAI and
        Anthropic support.

        Registration is best effort. A provider whose constructor raises is
        skipped, the failure is logged, and the exception is stored in
        :attr:`registration_errors`. ``huggingface_local`` is never
        configured here; register it explicitly with
        :meth:`register_provider`.

        Parameters
        ----------
        gemini_api_key : str, optional
            API key for Google Gemini. Defaults to ``GEMINI_API_KEY``.
        openrouter_api_key : str, optional
            API key for OpenRouter. Defaults to ``OPENROUTER_API_KEY``.
        openai_api_key : str, optional
            API key for OpenAI. Defaults to ``OPENAI_API_KEY`` environment
            variable if not provided.
        openai_project : str, optional
            OpenAI project ID. Defaults to ``OPENAI_PROJECT``. Required when
            the OpenAI key is project-scoped (starts with ``sk-proj-``).
        anthropic_api_key : str, optional
            API key for Anthropic. Defaults to ``ANTHROPIC_API_KEY`` environment
            variable if not provided.
        huggingface_api_key : str, optional
            API key for HuggingFace Inference API. Defaults to
            ``HUGGINGFACE_API_KEY`` or ``HUGGINGFACEHUB_API_TOKEN`` environment
            variables. ``HUGGINGFACE_MODEL_PROVIDERS`` (``alias:model`` pairs)
            is forwarded as ``model_providers`` when set.
        groq_api_key : str, optional
            API key for Groq Cloud. Defaults to ``GROQ_API_KEY`` environment
            variable.
        deepseek_api_key : str, optional
            API key for DeepSeek's native API. Defaults to ``DEEPSEEK_API_KEY``.
        together_api_key : str, optional
            API key for Together AI. Defaults to ``TOGETHER_API_KEY``.
        local_openai_base_url : str, optional
            Base URL for a self-hosted OpenAI-compatible server (vLLM, TGI,
            Ollama). Defaults to ``AI_BASE_URL`` environment variable.
        local_openai_api_key : str, optional
            API key for the self-hosted OpenAI-compatible server. Defaults to
            ``AI_API_KEY`` environment variable.
        local_openai_models : dict, optional
            Mapping of public model aliases to backend IDs for the local
            provider. Defaults to parsing ``AI_MODELS`` environment variable
            (``alias:model`` comma-separated pairs).
        enabled_providers : list[str], optional
            Restrict registration to this subset of provider names. When not
            given, the comma-separated ``AI_MODEL_PROVIDERS`` environment
            variable is used if set; otherwise all providers with available
            credentials are enabled.

        Raises
        ------
        ValueError
            If the OpenAI key is project-scoped (``sk-proj-``) but no project
            ID was supplied.
        """
        import logging

        env_providers = os.environ.get("AI_MODEL_PROVIDERS")
        if enabled_providers is None and env_providers:
            enabled_providers = [
                p.strip() for p in env_providers.split(",") if p.strip()
            ]
        self._providers: dict[str, ModelProvider] = {}
        # Providers whose automatic registration raised, keyed by provider
        # name, so callers can inspect why a provider is missing.
        self.registration_errors: dict[str, Exception] = {}
        provider_configs: dict[str, dict[str, Any]] = {}

        gemini_key = gemini_api_key or os.environ.get("GEMINI_API_KEY")
        if gemini_key:
            provider_configs["gemini"] = {"api_key": gemini_key}

        openrouter_key = openrouter_api_key or os.environ.get("OPENROUTER_API_KEY")
        if openrouter_key:
            provider_configs["openrouter"] = {"api_key": openrouter_key}

        openai_key = openai_api_key or os.environ.get("OPENAI_API_KEY")
        openai_project_value = openai_project or os.environ.get("OPENAI_PROJECT")
        if openai_key:
            if openai_key.startswith("sk-proj-") and not openai_project_value:
                raise ValueError(
                    "OPENAI_API_KEY appears to be a project-scoped key but no "
                    "OPENAI_PROJECT was provided. "
                    "Set the project ID via the openai_project argument or the "
                    "OPENAI_PROJECT environment variable."
                )
            openai_cfg: dict[str, Any] = {"api_key": openai_key}
            if openai_project_value:
                openai_cfg["project"] = openai_project_value
            provider_configs["openai"] = openai_cfg

        anthropic_key = anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY")
        if anthropic_key:
            provider_configs["anthropic"] = {"api_key": anthropic_key}

        huggingface_key = (
            huggingface_api_key
            or os.environ.get("HUGGINGFACE_API_KEY")
            or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
        )
        if huggingface_key:
            hf_provider_map = _parse_model_mapping(
                os.environ.get("HUGGINGFACE_MODEL_PROVIDERS")
            )
            hf_config: dict[str, Any] = {"api_key": huggingface_key}
            if hf_provider_map:
                hf_config["model_providers"] = hf_provider_map
            provider_configs["huggingface"] = hf_config

        groq_key = groq_api_key or os.environ.get("GROQ_API_KEY")
        if groq_key:
            provider_configs["groq"] = {"api_key": groq_key}

        deepseek_key = deepseek_api_key or os.environ.get("DEEPSEEK_API_KEY")
        if deepseek_key:
            provider_configs["deepseek"] = {"api_key": deepseek_key}

        together_key = together_api_key or os.environ.get("TOGETHER_API_KEY")
        if together_key:
            provider_configs["together"] = {"api_key": together_key}

        local_base_url = local_openai_base_url or os.environ.get("AI_BASE_URL")
        if local_base_url:
            local_config: dict[str, Any] = {"base_url": local_base_url}
            local_key = local_openai_api_key or os.environ.get("AI_API_KEY")
            if local_key:
                local_config["api_key"] = local_key
            local_models = local_openai_models or _parse_model_mapping(
                os.environ.get("AI_MODELS")
            )
            if local_models:
                local_config["models"] = local_models
            provider_configs["local_openai"] = local_config

        logger = logging.getLogger(__name__)
        for name in self._PROVIDER_CLASSES:
            if enabled_providers is not None and name not in enabled_providers:
                continue
            config = provider_configs.get(name)
            if not config:
                continue
            try:
                self.register_provider(name, **config)
            except Exception as exc:
                # Best effort: one misconfigured provider (missing optional
                # package, bad key format, ...) must not block the others,
                # but the failure should not vanish silently either.
                self.registration_errors[name] = exc
                logger.warning("Could not register provider %r: %s", name, exc)

    def register_provider(
        self,
        provider: ModelProvider | str,
        *,
        enabled: bool = True,
        **config: Any,
    ) -> None:
        """Register a provider and its models.

        Parameters
        ----------
        provider : ModelProvider | str
            Either an instantiated provider or the name of a provider defined in
            :attr:`_PROVIDER_CLASSES`.
        enabled : bool, optional
            When ``False`` the provider is skipped.
        **config : dict, optional
            Configuration forwarded to the provider constructor when ``provider``
            is given as a string.

        Raises
        ------
        ValueError
            If ``provider`` is a string that is not a known provider name.
        """
        if not enabled:
            return
        if isinstance(provider, str):
            cls = self._PROVIDER_CLASSES.get(provider)
            if cls is None:
                raise ValueError(f"Unknown provider '{provider}'")
            provider_obj = cls(**config)
        else:
            provider_obj = provider
        for model in provider_obj.models:
            self._providers[model] = provider_obj

    def _require_provider(self, model: str) -> ModelProvider:
        """Return the provider registered for ``model`` or raise ``ValueError``."""
        provider = self._providers.get(model)
        if provider is None:
            available = list(self._providers.keys())
            raise ValueError(
                f"Model '{model}' not supported. Available models: {available}"
            )
        return provider

    def supports(self, model: str) -> bool:
        """Return whether ``model`` is registered with any provider."""
        return model in self._providers

    def supports_batching(self, model: str) -> bool:
        """Return whether the provider for ``model`` supports batching.

        Providers without a ``supports_batching`` attribute count as
        non-batching.

        Raises
        ------
        ValueError
            If ``model`` is not registered.
        """
        provider = self._require_provider(model)
        return getattr(provider, "supports_batching", False)

    def generate_text(
        self,
        model: str,
        prompt: PromptType,
        max_tokens: int = 1000,
        retries: int = 3,
    ) -> str:
        """Generate text using a registered model with retry.

        Parameters
        ----------
        model, prompt, max_tokens : see :meth:`ModelProvider.generate_text`
        retries : int
            Number of attempts before raising the final error.

        Raises
        ------
        ValueError
            If ``model`` is not registered.
        RuntimeError
            Wrapping the final provider error once retries are exhausted.
        """
        provider = self._require_provider(model)
        try:
            return _retry_sync(
                lambda: provider.generate_text(model, prompt, max_tokens=max_tokens),
                retries=retries,
            )
        except Exception as e:  # pragma: no cover - simple re-raise
            raise RuntimeError(f"Error generating response with {model}: {e}") from e

    async def generate_text_async(
        self,
        model: str,
        prompt: PromptType,
        max_tokens: int = 1000,
        retries: int = 3,
    ) -> str:
        """Asynchronously generate text using a registered model with retry.

        Parameters
        ----------
        model, prompt, max_tokens : see :meth:`ModelProvider.generate_text`
        retries : int
            Number of attempts before raising the final error.

        Raises
        ------
        ValueError
            If ``model`` is not registered.
        RuntimeError
            Wrapping the final provider error once retries are exhausted.
        """
        provider = self._require_provider(model)
        try:
            return await _retry_async(
                lambda: provider.generate_text_async(
                    model, prompt, max_tokens=max_tokens
                ),
                retries=retries,
            )
        except Exception as e:  # pragma: no cover - simple re-raise
            raise RuntimeError(f"Error generating response with {model}: {e}") from e

    def stream_generate_text(
        self,
        model: str,
        prompt: PromptType,
        max_tokens: int = 1000,
        retries: int = 3,
    ) -> Iterator[str]:
        """Stream generated text chunks from a registered model with retry.

        Parameters
        ----------
        model, prompt, max_tokens : see
            :meth:`ModelProvider.stream_generate_text`
        retries : int
            Number of attempts before raising the final error.

        Raises
        ------
        ValueError
            Immediately, if ``model`` is not registered.
        RuntimeError
            Wrapping a provider error, whether it is raised while the stream
            is created or while it is being consumed.
        """
        provider = self._require_provider(model)
        try:
            stream = _retry_stream(
                lambda: provider.stream_generate_text(
                    model, prompt, max_tokens=max_tokens
                ),
                retries=retries,
            )
        except Exception as e:  # pragma: no cover - simple re-raise
            raise RuntimeError(f"Error generating response with {model}: {e}") from e
        return self._wrap_stream_errors(model, stream)

    @staticmethod
    def _wrap_stream_errors(model: str, stream: Iterable[str]) -> Iterator[str]:
        """Yield from ``stream``, re-raising failures as ``RuntimeError``."""
        # A stream may fail lazily, on iteration. A try/except around
        # creating it cannot see those errors, so wrap the iteration too to
        # match the error contract of the non-streaming methods.
        try:
            yield from stream
        except Exception as e:
            raise RuntimeError(f"Error generating response with {model}: {e}") from e

    def generate_image(
        self,
        model: str,
        prompt: str,
        retries: int = 3,
        **kwargs: Any,
    ) -> bytes:
        """Generate an image using a registered model with retry.

        Extra ``kwargs`` are forwarded to the provider's ``generate_image``.

        Raises
        ------
        ValueError
            If ``model`` is not registered or its provider has no
            ``generate_image`` method.
        RuntimeError
            Wrapping the final provider error once retries are exhausted.
        """
        provider = self._providers.get(model)
        if provider is None or not hasattr(provider, "generate_image"):
            available = [
                m for m, p in self._providers.items() if hasattr(p, "generate_image")
            ]
            raise ValueError(
                f"Model '{model}' not supported for image generation. "
                f"Available image models: {available}"
            )
        try:
            return _retry_sync(
                lambda: provider.generate_image(model, prompt, **kwargs),
                retries=retries,
            )
        except Exception as e:  # pragma: no cover - simple re-raise
            raise RuntimeError(f"Error generating image with {model}: {e}") from e

    def list_available_models(self) -> list[str]:
        """Return the list of registered model names."""
        return list(self._providers.keys())

    def provider_name(self, model: str) -> str:
        """Return the provider class name registered for ``model``.

        Returns ``"UnknownProvider"`` when ``model`` is not registered.
        """
        provider = self._providers.get(model)
        return type(provider).__name__ if provider is not None else "UnknownProvider"
# Public API of this module: the names exported by ``from <module> import *``.
# NOTE(review): load_env_file, huggingface_credentials_present,
# pick_first_supported_model and build_generation_summary are not defined in
# this part of the file — confirm they exist elsewhere in the module, or a
# star-import will fail with AttributeError.
__all__ = [
"AIModelInterface",
"ModelProvider",
"PromptType",
"load_env_file",
"huggingface_credentials_present",
"pick_first_supported_model",
"build_generation_summary",
]