"""
Copyright 2026 OÜ KAVAL AI (registry code 17393877)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import inspect
import json
from loguru import logger
import os
from typing import Any, Dict, List, Optional, Callable, Type
import httpx
from pydantic import BaseModel, create_model
# ``mcp`` is an optional extra: it is not pyodide-compatible and pulls in a
# large dependency tree. Import it lazily so the core library imports without
# it; only MCP tool calls require ``pip install kavalai[mcp]``.
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
except ImportError: # pragma: no cover - exercised in pyodide / minimal installs
ClientSession = None
StdioServerParameters = None
sse_client = None
stdio_client = None
from kavalai.agents.workflow_model import (
McpServer,
RestServer,
WorkflowException,
)
# Timeout, in seconds, passed to ``sse_client`` when opening an HTTP (SSE)
# connection to an MCP server.
SSE_CLIENT_TIMEOUT_SECONDS = 30.0
[docs]
class FunctionKernelException(WorkflowException):
    """Error raised by :class:`FunctionKernel` when registering or calling a tool fails.

    Typical causes are an unknown protocol, a duplicate or unregistered tool,
    a malformed tool URI, arguments that fail validation, and errors surfaced
    by the underlying REST/MCP/Python tool call.
    """
[docs]
class Validator:
    """
    Helps converting and validating tool inputs and outputs.
    Allows only basic data types and pydantic models.
    """

    @staticmethod
    def create_model_from_signature(
        name: str, sig: inspect.Signature, is_input: bool = True
    ) -> Type[BaseModel]:
        """Build a pydantic model for a callable's parameters or return value.

        Args:
            name: Base name; the generated model is called ``{name}_input``
                or ``{name}_output``.
            sig: Signature the model is derived from.
            is_input: ``True`` to model the parameters, ``False`` to model
                the return value.

        Returns:
            For inputs, a model with one field per parameter (unannotated
            parameters become ``Any``; parameters without a default are
            required). For outputs, the return annotation itself when it is
            a ``BaseModel`` subclass, otherwise a model with a single
            ``result`` field.
        """
        if is_input:
            input_fields = {}
            for param_name, p in sig.parameters.items():
                # Compare against the sentinels by identity: ``!=`` would call
                # ``__ne__`` on arbitrary user defaults (array-like objects
                # raise or return element-wise results there).
                annotation = (
                    p.annotation if p.annotation is not inspect.Parameter.empty else Any
                )
                default = p.default if p.default is not inspect.Parameter.empty else ...
                input_fields[param_name] = (annotation, default)
            return create_model(f"{name}_input", **input_fields)

        output_annotation = (
            sig.return_annotation
            if sig.return_annotation is not inspect.Signature.empty
            else Any
        )
        if inspect.isclass(output_annotation) and issubclass(
            output_annotation, BaseModel
        ):
            return output_annotation
        return create_model(f"{name}_output", result=(output_annotation, ...))

    @staticmethod
    def create_model_from_schema(name: str, schema: Dict[str, Any]) -> Type[BaseModel]:
        """Build a pydantic model named ``name`` from a JSON Schema dictionary."""
        return _create_model_from_jsonschema(name, schema)

    @staticmethod
    def validate_arguments(
        model: Type[BaseModel], arguments: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Instantiate ``model`` from ``arguments`` and return its ``model_dump()``.

        Any error raised while constructing the model propagates to the caller.
        """
        return model(**arguments).model_dump()

    @staticmethod
    def cast_result(
        result: Any, target_output_type: Optional[Type], context_info: str = ""
    ) -> Any:
        """Cast result to target output type if it's a Pydantic model.

        Args:
            result: Raw value produced by a tool.
            target_output_type: Desired output type. Anything that is not a
                ``BaseModel`` subclass leaves ``result`` untouched.
            context_info: Prefix used in the warning logged when the
                conversion fails.

        Returns:
            An instance of ``target_output_type`` when the conversion
            succeeds, otherwise ``result`` unchanged.
        """
        if (
            target_output_type
            and inspect.isclass(target_output_type)
            and issubclass(target_output_type, BaseModel)
        ):
            try:
                if isinstance(result, dict):
                    return target_output_type(**result)
                if isinstance(result, BaseModel):
                    if isinstance(result, target_output_type):
                        return result
                    return target_output_type(**result.model_dump())
                fields = target_output_type.model_fields
                if len(fields) == 1:
                    # Single-field models wrap a scalar/list result directly.
                    field_name = next(iter(fields))
                    return target_output_type(**{field_name: result})
                try:
                    return target_output_type(result)
                except Exception:
                    return result
            except Exception as e:
                # Best effort: hand back the raw value rather than failing the call.
                logger.warning(
                    f"{context_info} returned incompatible result for {target_output_type}: {e}"
                )
                return result
        return result
[docs]
class FunctionKernel:
    """
    Registry and executor for tools (REST, MCP, Python).
    Also takes care of conversions and of MCP session management.
    """

    def __init__(self):
        # REST: registered servers plus per-server tool definitions.
        self.rest_servers: Dict[str, RestServer] = {}
        self.rest_tool_definitions: Dict[str, Dict[str, ToolDefinition]] = {}

        # Python: callables and their generated definitions, keyed by tool name.
        self.python_tools: Dict[str, Callable] = {}
        self.python_tool_definitions: Dict[str, ToolDefinition] = {}

        # MCP: registered servers plus per-server tool definitions.
        self.mcp_servers: Dict[str, McpServer] = {}
        self.mcp_tool_definitions: Dict[str, Dict[str, ToolDefinition]] = {}

        # MCP session management: open sessions by server name, and every
        # entered context manager that has to be exited on close().
        self.mcp_sessions: Dict[str, "ClientSession"] = {}
        self.mcp_cleanups: List[Any] = []
[docs]
def register_rest_server(self, server: RestServer):
    """Register a REST server under its ``name``.

    Raises:
        FunctionKernelException: If a REST server with the same name is
            already registered.
    """
    already_known = server.name in self.rest_servers
    if not already_known:
        self.rest_servers[server.name] = server
        return
    raise FunctionKernelException(
        f"REST server '{server.name}' is already registered."
    )
[docs]
def register_mcp_server(self, server: McpServer):
    """Register an MCP server under its ``name``.

    Raises:
        FunctionKernelException: If an MCP server with the same name is
            already registered.
    """
    already_known = server.name in self.mcp_servers
    if not already_known:
        self.mcp_servers[server.name] = server
        return
    raise FunctionKernelException(
        f"MCP server '{server.name}' is already registered."
    )
def _generate_python_tool_definition(
    self, name: str, func: Callable
) -> ToolDefinition:
    """Derive a ``ToolDefinition`` for a Python callable.

    The input and output models are generated from the callable's signature;
    the description is its docstring, or an empty string when it has none.
    """
    signature = inspect.signature(func)
    input_model = Validator.create_model_from_signature(
        name, signature, is_input=True
    )
    output_model = Validator.create_model_from_signature(
        name, signature, is_input=False
    )
    doc_text = inspect.getdoc(func)
    return ToolDefinition(
        name=name,
        description=doc_text if doc_text else "",
        input_model=input_model,
        output_model=output_model,
    )
async def _handle_rest_call(
    self,
    server_name: str,
    tool_name: str,
    arguments: Dict[str, Any],
    output_type: Optional[type] = None,
    **kwargs,
) -> Any:
    """Handle REST tool calls with validation if definition exists.

    Args:
        server_name: Name of the registered REST server.
        tool_name: Tool (endpoint path) on that server.
        arguments: Arguments supplied by the caller.
        output_type: Optional type to cast the response to; falls back to
            the definition's output model when a definition exists.
        **kwargs: ``method`` selects the HTTP verb (default ``"get"``).

    Returns:
        Whatever ``_call_rest_tool`` returns for the request.

    Raises:
        FunctionKernelException: If the arguments do not validate against
            the tool definition's input model.
    """
    method = kwargs.get("method", "get")
    definition = self.rest_tool_definitions.get(server_name, {}).get(tool_name)
    if definition is None:
        # No definition: pass the arguments through unvalidated.
        return await self._call_rest_tool(
            server_name, tool_name, arguments, method, output_type
        )

    # The description may be a JSON document carrying the HTTP method. A
    # plain-text description is normal, so a parse failure is not an error.
    try:
        desc_data = json.loads(definition.description)
    except (TypeError, ValueError):
        desc_data = None
    if isinstance(desc_data, dict):
        declared_method = desc_data.get("method")
        # Only a non-empty string can be an HTTP verb; ignore anything else.
        if isinstance(declared_method, str) and declared_method:
            method = declared_method

    # Validate input; report failures the same way the MCP and Python paths do.
    try:
        validated_args = Validator.validate_arguments(
            definition.input_model, arguments
        )
    except Exception as e:
        raise FunctionKernelException(
            f"REST tool '{tool_name}' on server '{server_name}' argument validation failed: {e}"
        ) from e

    return await self._call_rest_tool(
        server_name,
        tool_name,
        validated_args,
        method,
        output_type or definition.output_model,
    )
[docs]
async def close(self):
    """Cleanup all MCP sessions.

    Exits every registered context manager in reverse order of registration.
    Errors raised while exiting are logged and do not stop the remaining
    cleanups. Each entry is removed before it is exited and the session
    cache is always cleared, so an interrupted ``close()`` (for example by
    cancellation) can be called again without exiting anything twice.
    """
    try:
        while self.mcp_cleanups:
            # pop() walks the list last-in-first-out and forgets the entry
            # before awaiting, so it is never exited a second time.
            resource = self.mcp_cleanups.pop()
            # Some are async context managers, some are ClientSessions
            if not hasattr(resource, "__aexit__"):
                continue
            try:
                await resource.__aexit__(None, None, None)
            except Exception as e:
                logger.warning(f"Error during MCP cleanup: {e}")
    finally:
        # Never leave references to sessions that are (being) torn down.
        self.mcp_sessions.clear()
async def _call_rest_tool(
    self,
    server_name: str,
    tool: str,
    arguments: Dict[str, Any],
    method: str = "get",
    output_type: Optional[type] = None,
) -> Any:
    """Perform the HTTP request for a REST tool and cast its JSON response.

    Args:
        server_name: Name of a registered REST server.
        tool: Path appended to the server URL.
        arguments: Sent as a JSON body for POST/PUT/PATCH, otherwise as
            query parameters.
        method: HTTP verb, case-insensitive.
        output_type: Optional type handed to ``Validator.cast_result``.

    Returns:
        The decoded JSON response, cast to ``output_type`` where possible.
        An empty response body (e.g. ``204 No Content``) is treated as
        ``None`` instead of failing JSON decoding.

    Raises:
        FunctionKernelException: If the server is not registered or no URL
            can be resolved for it.
    """
    if server_name not in self.rest_servers:
        raise FunctionKernelException(
            f"REST server '{server_name}' not registered."
        )
    server = self.rest_servers[server_name]

    # An explicit URL wins; otherwise fall back to the configured env var.
    url = server.url
    if not url and server.url_env:
        url = os.environ.get(server.url_env)
    if not url:
        raise FunctionKernelException(
            f"URL for REST server '{server_name}' not found."
        )

    # Credentials are only used when both env vars are configured and set.
    auth = None
    if server.username_env and server.password_env:
        username = os.environ.get(server.username_env)
        password = os.environ.get(server.password_env)
        if username and password:
            auth = (username, password)

    request_kwargs: Dict[str, Any] = {"timeout": 60.0}
    if method.lower() in ("post", "put", "patch"):
        request_kwargs["json"] = arguments
    else:
        request_kwargs["params"] = arguments
    full_url = f"{url.rstrip('/')}/{tool.lstrip('/')}"

    async with httpx.AsyncClient(auth=auth) as client:
        logger.info(f"Calling {method.upper()} {full_url}")
        response = await client.request(method.upper(), full_url, **request_kwargs)
        response.raise_for_status()
        # Successful responses may legitimately have no body; decoding an
        # empty payload would raise, so map it to None.
        result_data = response.json() if response.content.strip() else None
        return Validator.cast_result(
            result_data, output_type, f"REST tool '{server_name}.{tool}'"
        )
async def _get_mcp_session(self, server_name: str) -> "ClientSession":
    """Return the MCP session for ``server_name``, opening it on first use.

    A cached session is returned when one exists. Otherwise the server is
    reached through ``sse_client`` when it has a URL (directly or via an
    environment variable) and through ``stdio_client`` otherwise. The
    transport and the session are recorded in ``mcp_cleanups``, the session
    is initialized and cached, and the server's tool definitions are
    refreshed.

    Raises:
        FunctionKernelException: If the optional ``mcp`` dependency is
            missing, the server is not registered, or neither a URL nor a
            command can be resolved.
    """
    if ClientSession is None:
        raise FunctionKernelException(
            "MCP support requires the optional 'mcp' dependency. "
            "Install it with: pip install kavalai[mcp]"
        )

    if server_name in self.mcp_sessions:
        return self.mcp_sessions[server_name]
    if server_name not in self.mcp_servers:
        raise FunctionKernelException(f"MCP server '{server_name}' not registered.")
    config = self.mcp_servers[server_name]

    if config.url or config.url_env:
        # HTTP transport: explicit URL first, environment variable second.
        url = config.url or (
            os.environ.get(config.url_env) if config.url_env else None
        )
        if not url:
            raise FunctionKernelException(
                f"URL for MCP server '{server_name}' not found."
            )
        logger.info(f"Connecting to HTTP MCP server {server_name} at {url}")
        transport = sse_client(url, timeout=SSE_CLIENT_TIMEOUT_SECONDS)
    else:
        # stdio transport: explicit command first, environment variable second.
        command = config.command or (
            os.environ.get(config.command_env) if config.command_env else None
        )
        if not command:
            raise FunctionKernelException(
                f"Command for MCP server '{server_name}' not found."
            )
        server_params = StdioServerParameters(
            command=command,
            args=config.args,
            env={**os.environ, **config.env},
        )
        logger.info(f"Connecting to stdio MCP server {server_name}")
        transport = stdio_client(server_params)

    # Both transports are entered the same way and remembered for close().
    read, write = await transport.__aenter__()
    self.mcp_cleanups.append(transport)

    session = ClientSession(read, write)
    await session.__aenter__()
    self.mcp_cleanups.append(session)
    await session.initialize()
    self.mcp_sessions[server_name] = session

    # Fetch and store tool definitions
    await self._refresh_mcp_tool_definitions(server_name, session)
    return session
async def _refresh_mcp_tool_definitions(
    self, server_name: str, session: "ClientSession"
):
    """Fetch the server's tool list and rebuild its ``ToolDefinition`` map.

    Best effort: failures are logged as warnings and never raised. A tool
    whose definition cannot be built is skipped on its own, so one
    malformed tool does not discard the definitions of the others.

    Args:
        server_name: Key under which the definitions are stored.
        session: Open MCP session used to list the tools.
    """
    try:
        tools_result = await session.list_tools()
        tools = tools_result.tools
    except Exception as e:
        logger.warning(f"Could not refresh tools for MCP server {server_name}: {e}")
        return

    definitions = {}
    for tool in tools:
        try:
            # The tool's input schema is converted into a pydantic model.
            input_model = Validator.create_model_from_schema(
                f"{server_name}_{tool.name}_input", tool.inputSchema
            )
            # No output schema is read from the tool list, so the output
            # model is a generic single ``result`` field.
            output_model = create_model(
                f"{server_name}_{tool.name}_output", result=(Any, ...)
            )
            definitions[tool.name] = ToolDefinition(
                name=tool.name,
                # Normalise a missing description the same way Python
                # tool definitions do.
                description=tool.description or "",
                input_model=input_model,
                output_model=output_model,
            )
        except Exception as e:
            tool_label = getattr(tool, "name", "<unknown>")
            logger.warning(
                f"Could not refresh tools for MCP server {server_name}: "
                f"skipping tool '{tool_label}': {e}"
            )
    self.mcp_tool_definitions[server_name] = definitions
async def _call_mcp_tool(
    self,
    server_name: str,
    tool: str,
    arguments: Dict[str, Any],
    output_type: Optional[type] = None,
) -> Any:
    """Call a tool on an MCP server and cast its first text result.

    Args:
        server_name: Name of a registered MCP server.
        tool: Tool name on that server.
        arguments: Arguments supplied by the caller; validated against the
            tool's input model when a definition is known.
        output_type: Optional type to cast the result to; falls back to the
            definition's output model.

    Returns:
        The first content item carrying ``text``, JSON-decoded when possible
        (raw text otherwise), passed through ``Validator.cast_result``.
        ``None`` is used when no content item has text.

    Raises:
        FunctionKernelException: If argument validation fails or the server
            reports an error for the call.
    """
    session = await self._get_mcp_session(server_name)

    # Looked up after the session exists, because opening a session is what
    # populates the definitions.
    definition = self.mcp_tool_definitions.get(server_name, {}).get(tool)

    if definition is not None:
        supplied_keys = set(arguments)
        try:
            validated = Validator.validate_arguments(
                definition.input_model, arguments
            )
        except Exception as e:
            raise FunctionKernelException(
                f"MCP tool '{tool}' on server '{server_name}' argument validation failed: {e}"
            ) from e
        # The dumped model lists every declared field, so optional
        # parameters the caller omitted come back as None. Forward only
        # what was actually supplied instead of sending explicit nulls.
        arguments = {
            key: value for key, value in validated.items() if key in supplied_keys
        }

    logger.info(f"Calling MCP tool {server_name}/{tool}")
    response = await session.call_tool(tool, arguments=arguments)
    if response.isError:
        raise FunctionKernelException(
            f"MCP tool '{tool}' on server '{server_name}' failed: {response.content}"
        )

    # Only the first text-bearing content item is used.
    result_data = None
    for content in response.content:
        if hasattr(content, "text"):
            try:
                result_data = json.loads(content.text)
            except (json.JSONDecodeError, TypeError):
                result_data = content.text
            break

    # Convert output using output_type or definition's output_model
    target_output_type = output_type
    if not target_output_type and definition is not None:
        target_output_type = definition.output_model
    return Validator.cast_result(
        result_data, target_output_type, f"MCP tool '{server_name}.{tool}'"
    )
async def _call_python_tool(
    self,
    python_tool: str,
    arguments: Dict[str, Any],
    output_type: Optional[type] = None,
) -> Any:
    """Validate arguments for a registered Python tool, run it, cast the result.

    Args:
        python_tool: Name the callable was registered under.
        arguments: Arguments supplied by the caller.
        output_type: Optional type to cast the result to; falls back to the
            definition's output model.

    Returns:
        The tool's return value passed through ``Validator.cast_result``.
        Awaitable results are awaited, whether the tool is an ``async def``
        function or a plain callable that returns an awaitable.

    Raises:
        FunctionKernelException: If the tool or its definition is missing,
            the arguments fail validation or do not fit the signature, or
            the tool itself raises.
    """
    if python_tool not in self.python_tools:
        raise FunctionKernelException(
            f"Python tool '{python_tool}' not registered."
        )
    func = self.python_tools[python_tool]
    definition = self.python_tool_definitions.get(python_tool)
    if definition is None:
        # Report the real problem instead of an AttributeError disguised as
        # a validation failure further down.
        raise FunctionKernelException(
            f"Python tool '{python_tool}' has no tool definition."
        )

    # Validate arguments using input model
    try:
        sig = inspect.signature(func)
        call_args = Validator.validate_arguments(definition.input_model, arguments)
        # The dumped arguments hold plain dicts; rebuild model instances for
        # parameters annotated with a pydantic model. Non-dict values (e.g.
        # a None default) are passed through untouched.
        for param_name, p in sig.parameters.items():
            value = call_args.get(param_name)
            if (
                isinstance(value, dict)
                and isinstance(p.annotation, type)
                and issubclass(p.annotation, BaseModel)
            ):
                call_args[param_name] = p.annotation(**value)
    except Exception as e:
        raise FunctionKernelException(
            f"Python tool '{python_tool}' argument validation failed: {e}"
        ) from e

    try:
        bound_args = sig.bind(**call_args)
    except TypeError as e:
        raise FunctionKernelException(
            f"Python tool '{python_tool}' signature mismatch: {e}"
        ) from e

    try:
        result = func(*bound_args.args, **bound_args.kwargs)
        # Covers ``async def`` tools as well as wrappers (partials, callable
        # objects) that merely return an awaitable.
        if inspect.isawaitable(result):
            result = await result
    except Exception as e:
        logger.exception(f"Error executing python_tool '{python_tool}'")
        raise FunctionKernelException(
            f"Error executing python_tool '{python_tool}': {e}"
        ) from e

    target_output_type = output_type or definition.output_model
    return Validator.cast_result(
        result, target_output_type, f"Python tool '{python_tool}'"
    )
[docs]
def get_output_model(self, tool_uri: str) -> Type[BaseModel]:
    """Look up the tool registered under ``tool_uri`` and return its output model."""
    definition = self.get_tool_definition(tool_uri)
    return definition.output_model
def _strip_metadata(arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Strip metadata from arguments."""
return {
k: v for k, v in arguments.items() if k not in ("__line__", "__file_path__")
}
def _parse_tool_uri(tool_uri: str) -> tuple[str, str]:
"""Parse tool URI into protocol and path."""
if "://" not in tool_uri:
raise FunctionKernelException(
f"Invalid tool URI format: '{tool_uri}'. Expected protocol://[name|module].function_name"
)
protocol, path = tool_uri.split("://", 1)
return protocol, path
def _create_model_from_jsonschema(name: str, schema: Dict[str, Any]) -> Type[BaseModel]:
    """Very basic JSON Schema to Pydantic model conversion.

    Only the top-level ``properties`` and ``required`` keywords are read.
    Each property becomes one field:

    * a string ``type`` is mapped to the matching Python type, unknown or
      missing types to ``Any``;
    * a list ``type`` such as ``["string", "null"]`` maps to the single
      non-null member (``Optional`` when ``"null"`` is listed); any other
      combination maps to ``Any``;
    * a property whose schema is not a dict (e.g. a boolean schema) maps to
      ``Any``.

    Properties listed in ``required`` have no default; all others default
    to ``None``.
    """
    type_map = {
        "string": str,
        "number": float,
        "integer": int,
        "boolean": bool,
        "object": dict,
        "array": list,
    }
    # ``or`` also covers keys that are present but null.
    properties = schema.get("properties") or {}
    required = schema.get("required") or []

    fields = {}
    for prop_name, prop_schema in properties.items():
        declared = prop_schema.get("type") if isinstance(prop_schema, dict) else None
        python_type: Any = Any
        if isinstance(declared, str):
            python_type = type_map.get(declared, Any)
        elif isinstance(declared, list):
            # A list is unhashable and must not be used as a dict key.
            non_null = [t for t in declared if t != "null"]
            if (
                len(non_null) == 1
                and isinstance(non_null[0], str)
                and non_null[0] in type_map
            ):
                python_type = type_map[non_null[0]]
                if len(non_null) != len(declared):
                    python_type = Optional[python_type]
        default = ... if prop_name in required else None
        fields[prop_name] = (python_type, default)
    return create_model(name, **fields)
def _get_schema(model: Type[BaseModel]) -> Dict[str, Any]:
    """Return the model's JSON schema without its top-level ``title`` and ``type``.

    Dropping these keys keeps the schema more compact for the LLM.
    """
    schema_dict = model.model_json_schema()
    for noisy_key in ("title", "type"):
        schema_dict.pop(noisy_key, None)
    return schema_dict
def _is_tool_allowed(name: str, allowed_tools: Optional[List[str]]) -> bool:
"""Check if a tool name is allowed."""
# If allowed_tools is not given, by default we allow calling all tools.
if allowed_tools is None:
return True
if name in allowed_tools:
return True
# Checks for dynamic rest/mcp server and if the server is allowed
# e.g. rest://server.* or mcp://server.*
if "." in name:
server_prefix = name.split(".")[0] + ".*"
if server_prefix in allowed_tools:
return True
return False
def _add_tool_to_list(
    tools_list: List[Dict[str, Any]],
    name: str,
    description: str,
    input_model: Type[BaseModel],
    allowed_tools: Optional[List[str]],
):
    """Append a tool entry to ``tools_list`` when the tool is allowed.

    The entry carries the tool's ``name``, ``description`` and the schema of
    ``input_model`` under ``inputSchema``. Tools that are not allowed leave
    the list unchanged.
    """
    if not _is_tool_allowed(name, allowed_tools):
        return
    entry = {
        "name": name,
        "description": description,
        "inputSchema": _get_schema(input_model),
    }
    tools_list.append(entry)