Source code for kavalai.llm_clients.embeddings

"""
Copyright 2026 OÜ KAVAL AI (registry code 17393877)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import time
from typing import List, Optional, Tuple

from kavalai.agents.db import ModelCallStat
from kavalai.llm_clients.common import create_model_call_stat, get_model_name
from kavalai.normalizer import Normalizer, get_default_normalizer

Embeddings = List[List[float]]


class BaseEmbeddingClient:
    """Shared contract for the v2 embedding clients.

    A client is tied to a single model when it is constructed (the factory
    separates the ``provider/model`` string beforehand), which is why
    ``compute_embeddings`` receives only the texts. Concrete clients hand
    back the embeddings together with a database-ready
    :class:`~kavalai.agents.db.ModelCallStat` (the ORM row), letting callers
    such as :class:`~kavalai.agents.rag_service.RagService` persist usage
    directly.
    """

    def __init__(self, model: str):
        # Model identifier with the provider prefix already removed.
        self.model = model

    async def compute_embeddings(
        self,
        texts: List[str],
        normalize: bool = False,
        normalizer: Optional[Normalizer] = None,
        **kwargs,
    ) -> Tuple[Embeddings, ModelCallStat]:
        """Embed ``texts``; every concrete client must override this.

        The base implementation performs no work. Its parameters exist so
        that all subclasses share one signature.

        :param texts: Strings to embed.
        :param normalize: Whether the embeddings should be normalized.
        :param normalizer: Normalizer to use when ``normalize`` is true.
        :param kwargs: Provider-specific options.
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement compute_embeddings.")
def _maybe_normalize(
    embeddings: Embeddings, normalize: bool, normalizer: Optional[Normalizer]
) -> Embeddings:
    """Optionally run ``embeddings`` through a normalizer.

    :param embeddings: Embedding vectors to process.
    :param normalize: When false, ``embeddings`` is returned untouched.
    :param normalizer: Normalizer to apply; when ``None``, the one returned by
        ``get_default_normalizer()`` is used instead.
    :returns: The input object itself when ``normalize`` is false, otherwise
        whatever ``transform`` returns for it.
    """
    if normalize:
        # The default normalizer is only looked up when it is actually needed.
        active = get_default_normalizer() if normalizer is None else normalizer
        return active.transform(embeddings)
    return embeddings
class OpenAIEmbeddingClient(BaseEmbeddingClient):
    """OpenAI embeddings (e.g. ``text-embedding-3-small``)."""

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: float = 30.0,
    ):
        """Create the async OpenAI client for ``model``.

        :param model: Model name sent with every embeddings request.
        :param api_key: API key; when falsy, the ``OPENAI_API_KEY``
            environment variable is used.
        :param base_url: Passed through to ``AsyncOpenAI`` unchanged.
        :param timeout: Default timeout, given to the client and reused for
            each request unless a call overrides it.
        """
        super().__init__(model)
        # Imported lazily so the module can be loaded without ``openai``.
        from openai import AsyncOpenAI

        self.timeout = timeout
        self.client = AsyncOpenAI(
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
            base_url=base_url,
            timeout=timeout,
        )

    async def compute_embeddings(
        self,
        texts: List[str],
        normalize: bool = False,
        normalizer: Optional[Normalizer] = None,
        **kwargs,
    ) -> Tuple[Embeddings, ModelCallStat]:
        """Embed ``texts`` with a single ``embeddings.create`` request.

        :param texts: Strings to embed; sent as one batch.
        :param normalize: Whether to run the result through a normalizer.
        :param normalizer: Normalizer to use; falls back to the default one.
        :param kwargs: Extra keyword arguments forwarded to
            ``embeddings.create``. A ``timeout`` entry overrides the
            client-level default for this call only.
        :returns: The embeddings and the call statistics. The recorded
            duration covers the API request only, not normalization.
        """
        # Popping ``timeout`` lets callers override it per request and avoids
        # passing the keyword twice, which would raise ``TypeError``.
        request_timeout = kwargs.pop("timeout", self.timeout)

        start_time = time.perf_counter()
        response = await self.client.embeddings.create(
            input=texts, model=self.model, timeout=request_timeout, **kwargs
        )
        duration = time.perf_counter() - start_time

        embeddings = _maybe_normalize(
            [data.embedding for data in response.data], normalize, normalizer
        )
        # Usage may be absent on the response; report zero tokens in that case.
        total_tokens = response.usage.total_tokens if response.usage else 0
        stats = create_model_call_stat(
            call_type="embedding",
            model=f"openai/{self.model}",
            duration_sections=duration,
            batch_size=len(texts),
            total_tokens=total_tokens,
            response_data=response.model_dump()
            if hasattr(response, "model_dump")
            else response,
        )
        return embeddings, stats
class GeminiEmbeddingClient(BaseEmbeddingClient):
    """Google Gemini embeddings."""

    def __init__(self, model: str, api_key: Optional[str] = None):
        """Create the Gemini client for ``model``.

        :param model: Model identifier bound to this client.
        :param api_key: API key; when falsy, the ``GEMINI_API_KEY``
            environment variable is used.
        """
        super().__init__(model)
        # Imported lazily so the module can be loaded without ``google-genai``.
        from google import genai

        resolved_key = api_key or os.getenv("GEMINI_API_KEY")
        self.client = genai.Client(api_key=resolved_key)

    async def compute_embeddings(
        self,
        texts: List[str],
        normalize: bool = False,
        normalizer: Optional[Normalizer] = None,
        **kwargs,
    ) -> Tuple[Embeddings, ModelCallStat]:
        """Embed ``texts`` with a single ``embed_content`` request.

        :param texts: Strings to embed; sent as one batch.
        :param normalize: Whether to run the result through a normalizer.
        :param normalizer: Normalizer to use; falls back to the default one.
        :param kwargs: Used to build the ``EmbedContentConfig`` for the call.
        :returns: The embeddings and the call statistics. Token usage is
            always recorded as ``0`` and the duration excludes normalization.
        """
        from google.genai import types

        start_time = time.perf_counter()
        model_name = get_model_name(self.model)
        embed_config = types.EmbedContentConfig(**kwargs)
        response = await self.client.aio.models.embed_content(
            model=model_name,
            contents=texts,
            config=embed_config,
        )
        duration = time.perf_counter() - start_time

        vectors = [item.values for item in response.embeddings]
        vectors = _maybe_normalize(vectors, normalize, normalizer)

        stats = create_model_call_stat(
            call_type="embedding",
            model=f"gemini/{model_name}",
            duration_sections=duration,
            batch_size=len(texts),
            total_tokens=0,
        )
        return vectors, stats
class OllamaEmbeddingClient(BaseEmbeddingClient):
    """Ollama (local) embeddings."""

    def __init__(self, model: str, host: Optional[str] = None, timeout: float = 30.0):
        """Create the async Ollama client for ``model``.

        :param model: Model identifier bound to this client.
        :param host: Server URL; when falsy, ``OLLAMA_HOST`` is used, and
            ``http://localhost:11434`` if that variable is unset.
        :param timeout: Timeout handed to ``ollama.AsyncClient``.
        """
        super().__init__(model)
        # Imported lazily so the module can be loaded without ``ollama``.
        import ollama

        self.client = ollama.AsyncClient(
            host=host or os.getenv("OLLAMA_HOST", "http://localhost:11434"),
            timeout=timeout,
        )

    async def compute_embeddings(
        self,
        texts: List[str],
        normalize: bool = False,
        normalizer: Optional[Normalizer] = None,
        **kwargs,
    ) -> Tuple[Embeddings, ModelCallStat]:
        """Embed ``texts`` one at a time, awaiting each request in turn.

        :param texts: Strings to embed; one ``embed`` request is made per item.
        :param normalize: Whether to run the result through a normalizer.
        :param normalizer: Normalizer to use; falls back to the default one.
        :param kwargs: Extra keyword arguments forwarded to every ``embed`` call.
        :returns: The embeddings and the call statistics. The token count is
            the sum of the reported ``prompt_eval_count`` values, and the
            duration includes normalization.
        """
        start_time = time.perf_counter()
        model_name = get_model_name(self.model)

        embeddings: Embeddings = []
        total_prompt_tokens = 0
        for text in texts:
            response = await self.client.embed(model=model_name, input=text, **kwargs)
            # Use ``or`` instead of a ``get`` default: a field that exists but
            # holds ``None`` bypasses the default, and ``extend(None)`` or
            # ``+= None`` would raise ``TypeError``.
            embeddings.extend(response.get("embeddings") or [])
            total_prompt_tokens += response.get("prompt_eval_count") or 0

        embeddings = _maybe_normalize(embeddings, normalize, normalizer)
        duration = time.perf_counter() - start_time

        stats = create_model_call_stat(
            call_type="embedding",
            model=f"ollama/{model_name}",
            duration_sections=duration,
            batch_size=len(texts),
            total_tokens=total_prompt_tokens,
        )
        return embeddings, stats
class FastEmbedClient(BaseEmbeddingClient):
    """Local embeddings via FastEmbed / ONNX Runtime (no API key)."""

    def __init__(
        self,
        model: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        **kwargs,
    ):
        """Record the settings; the FastEmbed model itself is created lazily.

        :param model: Model name given to ``TextEmbedding``.
        :param cache_dir: Forwarded to ``TextEmbedding`` as ``cache_dir``.
        :param threads: Forwarded to ``TextEmbedding`` as ``threads``.
        :param kwargs: Additional ``TextEmbedding`` constructor arguments.
        """
        super().__init__(model)
        self._embedding_model = None
        self.init_kwargs = kwargs
        self.threads = threads
        self.cache_dir = cache_dir

    def _get_model(self):
        """Return the ``TextEmbedding`` instance, creating it on first use."""
        if self._embedding_model is not None:
            return self._embedding_model

        # Imported lazily so the module can be loaded without ``fastembed``.
        from fastembed import TextEmbedding

        self._embedding_model = TextEmbedding(
            model_name=self.model,
            cache_dir=self.cache_dir,
            threads=self.threads,
            **self.init_kwargs,
        )
        return self._embedding_model

    async def compute_embeddings(
        self,
        texts: List[str],
        normalize: bool = False,
        normalizer: Optional[Normalizer] = None,
        **kwargs,
    ) -> Tuple[Embeddings, ModelCallStat]:
        """Embed ``texts`` with the local FastEmbed model.

        :param texts: Strings to embed.
        :param normalize: Whether to run the result through a normalizer.
        :param normalizer: Normalizer to use; falls back to the default one.
        :param kwargs: Extra keyword arguments forwarded to ``embed``.
        :returns: The embeddings and the call statistics. The duration spans
            model creation (on the first call), embedding and normalization.
            The stat is recorded with no token count, zero cost and ``USD``
            as currency.
        """
        start_time = time.perf_counter()
        embedder = self._get_model()
        vectors = [vector.tolist() for vector in embedder.embed(texts, **kwargs)]
        vectors = _maybe_normalize(vectors, normalize, normalizer)
        elapsed = time.perf_counter() - start_time

        stats = create_model_call_stat(
            call_type="embedding",
            model=f"fastembed/{get_model_name(self.model)}",
            duration_sections=elapsed,
            batch_size=len(texts),
            total_tokens=None,  # FastEmbed does not expose token counts.
            cost=0.0,
        )
        stats.currency = "USD"
        return vectors, stats
class BrowserEmbeddingClient(BaseEmbeddingClient):
    """In-browser embeddings via the WebLLM bridge (Pyodide only, no API key).

    Mirrors :class:`~kavalai.llm_clients.browser_client.BrowserLLMClient`:
    inference happens inside the page through ``window.kavalBrowserLLM``, here
    via its async ``embed`` function::

        window.kavalBrowserLLM.embed(requestJson) -> Promise<resultJson>

    where ``requestJson`` is a JSON string of ``{model, input}`` (``input`` is
    the list of texts) and ``resultJson`` is a JSON string of either
    ``{embeddings, usage}`` or ``{error}``. The model is downloaded once and
    cached by the browser — no API key, no provider account, no CORS.

    Use it through ``make_embedding_client("browser/<model-id>")``;
    ``<model-id>`` is passed verbatim to the bridge (e.g. a WebLLM embedding id
    like ``snowflake-arctic-embed-m-q0f32-MLC-b4``).
    """

    async def compute_embeddings(
        self,
        texts: List[str],
        normalize: bool = False,
        normalizer: Optional[Normalizer] = None,
        **kwargs,
    ) -> Tuple[Embeddings, ModelCallStat]:
        """Embed ``texts`` through the browser bridge.

        :param texts: Strings to embed; sent to the bridge as one request.
        :param normalize: Whether to run the result through a normalizer.
        :param normalizer: Normalizer to use; falls back to the default one.
        :param kwargs: Accepted for interface compatibility; not forwarded.
        :returns: The embeddings and the call statistics. The token count is
            ``usage.total_tokens``, else ``usage.prompt_tokens``, else ``0``.
            The duration includes normalization.
        :raises LlmClientException: If the bridge call fails, the reply is not
            a JSON object, or the reply carries an ``error`` entry.
        """
        import json

        from kavalai.llm_clients.base_client import LlmClientException
        from kavalai.llm_clients.browser_client import get_browser_bridge

        start_time = time.perf_counter()
        bridge = get_browser_bridge()
        request = {"model": self.model, "input": list(texts)}
        try:
            # ``bridge.embed`` resolves to a JS Promise; awaiting it in Pyodide
            # yields the resolved JSON string.
            raw = await bridge.embed(json.dumps(request))
        except Exception as exc:  # JsException or anything the bridge throws.
            raise LlmClientException(
                f"In-browser embedding call failed: {exc}"
            ) from exc

        # A malformed reply is a bridge failure too, so report it with the same
        # exception type instead of leaking JSONDecodeError / TypeError.
        try:
            data = json.loads(raw)
        except (TypeError, ValueError) as exc:
            raise LlmClientException(
                f"In-browser embedding returned invalid JSON: {exc}"
            ) from exc
        if not isinstance(data, dict):
            raise LlmClientException(
                "In-browser embedding returned an unexpected payload of type "
                f"{type(data).__name__}"
            )

        if data.get("error"):
            raise LlmClientException(f"In-browser embedding error: {data['error']}")

        embeddings = _maybe_normalize(
            data.get("embeddings") or [], normalize, normalizer
        )
        usage = data.get("usage") or {}
        total_tokens = usage.get("total_tokens") or usage.get("prompt_tokens") or 0
        duration = time.perf_counter() - start_time

        stats = create_model_call_stat(
            call_type="embedding",
            model=f"browser/{self.model}",
            duration_sections=duration,
            batch_size=len(texts),
            total_tokens=total_tokens,
        )
        return embeddings, stats
def make_embedding_client(model: str) -> BaseEmbeddingClient:
    """Build the v2 embedding client named by a ``provider/model`` string.

    Recognised providers are ``openai``, ``gemini``, ``ollama``,
    ``fastembed`` and ``browser``. Only the first slash separates provider
    from model, so the model part may contain further slashes (e.g.
    ``fastembed/BAAI/bge-small-en-v1.5``). The ``browser`` provider runs
    entirely client-side via a WebLLM bridge (Pyodide only) and needs no API
    key.

    :param model: Specification of the form ``provider/model``.
    :returns: A client instance bound to the model part of ``model``.
    :raises ValueError: If ``model`` contains no slash or names an
        unsupported provider.
    """
    provider, separator, model_name = model.partition("/")
    if not separator:
        raise ValueError(f"Embedding model must be 'provider/model', got '{model}'.")

    if provider == "openai":
        return OpenAIEmbeddingClient(model_name, api_key=os.getenv("OPENAI_API_KEY"))
    elif provider == "gemini":
        return GeminiEmbeddingClient(model_name, api_key=os.getenv("GEMINI_API_KEY"))
    elif provider == "ollama":
        return OllamaEmbeddingClient(model_name, host=os.getenv("OLLAMA_HOST"))
    elif provider == "fastembed":
        cache_dir = os.getenv("FASTEMBED_CACHE_DIR")
        raw_threads = os.getenv("FASTEMBED_THREADS")
        # An unset or empty FASTEMBED_THREADS means no explicit thread count.
        thread_count = int(raw_threads) if raw_threads else None
        return FastEmbedClient(model_name, cache_dir=cache_dir, threads=thread_count)
    elif provider == "browser":
        return BrowserEmbeddingClient(model_name)

    raise ValueError(f"Unsupported embedding provider: '{provider}'.")