Source code for kavalai.workflow.engine

"""
Copyright 2026 OÜ KAVAL AI (registry code 17393877)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import importlib
import os
import time
from typing import Any, Callable, Optional
from uuid import uuid4

import yaml
from loguru import logger
from pydantic import BaseModel, ValidationError

from kavalai.agents.schema_parser import SchemaParser
from kavalai.agents.run_context import RunContext
from kavalai.agents.utils import to_plain
from kavalai.agents.agent import Agent
from kavalai.workflow import clients as client_factory_module
from kavalai.workflow.expressions import evaluate_bool, evaluate_value
from kavalai.workflow.models import (
    AgentNode,
    EndNode,
    FunctionNode,
    IfNode,
    LLMNode,
    Node,
    SwitchNode,
    WorkflowGraph,
)
from kavalai.workflow.state import WorkflowState
from kavalai.workflow.storage.base import DataStorage
from kavalai.workflow.tasklog.base import TaskLogger, TokenAccumulator
from kavalai.agents.workflow_model import WorkflowException
from kavalai.functionkernel import FunctionKernel, pythontool
from kavalai.llm_clients.base_client import BaseLlmClient, ChatHistory, ChatMessage

ClientFactory = Callable[..., BaseLlmClient]

DEFAULT_MAX_NODE_VISITS = 1000


[docs] def make_prompt(prompt: str, input_data: dict) -> str: """Combine a rendered prompt with resolved input data into a system message.""" pieces = [prompt] if input_data: pieces.append("INPUT DATA:") for key, value in input_data.items(): if isinstance(value, BaseModel): value = value.model_dump_json() pieces.append(f"{key}:{value}") return "\n".join(pieces)
[docs] class WorkflowEngine: """Executes a v2 :class:`WorkflowGraph` as a DAG / state machine. The engine walks the graph from the start node, following transitions and evaluating branch nodes, until it reaches an end node. Each node's result is stored in the run context; the serialized :class:`WorkflowState` is checkpointed to ``storage`` after every node and per-node debug data flows to ``task_logger``. Parameters ========== graph: WorkflowGraph The parsed workflow definition. storage: Optional[DataStorage] Persistence backend for agents/sessions/runs/chat/state. task_logger: Optional[TaskLogger] Backend for per-node debug data and model statistics. client_factory: Optional[ClientFactory] Factory ``(model, parameters, stats_receiver) -> BaseLlmClient`` used to build LLM clients. Defaults to the provider factory; inject a fake for offline testing. max_node_visits: int Safety cap on total node executions to guard against infinite loops. """ def __init__( self, graph: WorkflowGraph, *, storage: Optional[DataStorage] = None, task_logger: Optional[TaskLogger] = None, client_factory: Optional[ClientFactory] = None, data_models: Optional[dict[str, type[BaseModel]]] = None, max_node_visits: int = DEFAULT_MAX_NODE_VISITS, ): self.graph = graph self.storage = storage self.task_logger = task_logger self.client_factory = client_factory or client_factory_module.make_client self.max_node_visits = max_node_visits # Per-run token aggregator; recreated for each run() so totals don't leak. self._token_stats = TokenAccumulator(task_logger) # Data types are usually JSON-schema fragments compiled to Pydantic models # by the SchemaParser. ``data_models`` lets callers (e.g. the # WorkflowBuilder's ``data_model``) supply ready-made Pydantic models # directly; those names are used as-is and skip the parser. overrides = data_models or {} to_parse = {k: v for k, v in graph.data_types.items() if k not in overrides} self.parser = SchemaParser(to_parse) self.models = self.parser.parse_all() self.models.update(overrides) self.node_map = graph.node_map # Build the function kernel and register declared servers / tools, reusing # the v1 registration approach. self.kernel = FunctionKernel() for server in graph.rest_servers: self.kernel.register_rest_server(server) for server in graph.mcp_servers: self.kernel.register_mcp_server(server) for func_config in graph.python_functions: module_path, func_name = func_config.path.rsplit(".", 1) module = importlib.import_module(module_path) func = getattr(module, func_name) if not getattr(func, "_is_kavalai_tool", False): func = pythontool(func) self.kernel.register_python_tool(func_config.name, func) # ------------------------------------------------------------------ loaders
[docs] @classmethod def from_yaml(cls, yaml_string: str, **kwargs) -> "WorkflowEngine": """Build an engine from a YAML workflow definition string.""" try: data = yaml.load(yaml_string, Loader=yaml.SafeLoader) # nosec B506 graph = WorkflowGraph(**data) except ValidationError as e: raise WorkflowException(f"Workflow validation failed: {e}") from e return cls(graph, **kwargs)
[docs] @classmethod def from_yaml_path(cls, yaml_path: str, **kwargs) -> "WorkflowEngine": """Build an engine from a YAML workflow definition file.""" with open(yaml_path, "r") as f: return cls.from_yaml(f.read(), **kwargs)
[docs] @classmethod def from_dict(cls, data: dict, **kwargs) -> "WorkflowEngine": """Build an engine from a parsed workflow definition dict.""" try: graph = WorkflowGraph(**data) except ValidationError as e: raise WorkflowException(f"Workflow validation failed: {e}") from e return cls(graph, **kwargs)
# ------------------------------------------------------------------- helpers
[docs] def get_data_type(self, name: Optional[str]): if not name: return None return self.models.get(name)
def _resolve_model(self, node_model: Optional[str]) -> str: model = ( node_model or self.graph.llm_model or os.environ.get("KAVALAI_DEFAULT_LLM_MODEL") ) if not model: raise WorkflowException( "No LLM model configured (set node.llm_model, graph.llm_model " "or KAVALAI_DEFAULT_LLM_MODEL)." ) return model def _make_llm_client( self, node_model: Optional[str], llm_kwargs: dict, agent_id: Optional[str] ) -> BaseLlmClient: model = self._resolve_model(node_model) merged = dict(self.graph.llm_kwargs) merged.update(llm_kwargs or {}) parameters = client_factory_module.build_parameters(merged) # The accumulator tallies tokens for the whole run and forwards each call # to the task logger (when configured). self._token_stats.agent_id = agent_id return self.client_factory(model, parameters, self._token_stats) # --------------------------------------------------------------------- nodes async def _run_llm_node(self, node: LLMNode, run_context: RunContext) -> None: input_data = await run_context.prepare_tool_inputs(node) rendered_prompt = await run_context.render_prompt(node.prompt) text = make_prompt(rendered_prompt, input_data) messages = [ChatMessage(role="system", content=text)] if node.use_history and self.storage and run_context.session_id: history = await self.storage.get_chat_history(str(run_context.session_id)) for msg in history: messages.append(ChatMessage(role=msg.role, content=msg.content)) agent_id = str(run_context.agent_id) if run_context.agent_id else None client = self._make_llm_client(node.llm_model, node.llm_kwargs, agent_id) start = time.perf_counter() response = await client.chat_completions( chat_history=ChatHistory(messages=messages), response_model=self.get_data_type(node.output), ) duration = time.perf_counter() - start run_context.data[node.output] = response self._log_node( run_context, node, inputs=input_data, output=response, prompt=text, duration=duration, ) async def _run_agent_node(self, node: AgentNode, run_context: RunContext) -> None: input_data = await run_context.prepare_tool_inputs(node) rendered_prompt = await run_context.render_prompt(node.prompt) agent_id = str(run_context.agent_id) if run_context.agent_id else None client = self._make_llm_client(node.llm_model, node.llm_kwargs, agent_id) agent = Agent(llm_client=client, kernel=self.kernel, run_context=run_context) start = time.perf_counter() result = await agent.prompt( prompt=rendered_prompt, response_model=self.get_data_type(node.output), max_steps=node.max_steps, ) duration = time.perf_counter() - start run_context.data[node.output] = result self._log_node( run_context, node, inputs=input_data, output=result, prompt=rendered_prompt, duration=duration, ) async def _run_function_node( self, node: FunctionNode, run_context: RunContext ) -> None: inputs = await run_context.prepare_tool_inputs(node) output_type = self.get_data_type(node.output) call_kwargs: dict[str, Any] = {} if node.tool.startswith("rest://"): call_kwargs["method"] = node.method start = time.perf_counter() result = await self.kernel.call_tool( tool_uri=node.tool, arguments=inputs, output_type=output_type, **call_kwargs, ) duration = time.perf_counter() - start run_context.data[node.output] = result self._log_node( run_context, node, inputs=inputs, output=result, duration=duration, ) def _log_node( self, run_context: RunContext, node: Node, *, inputs: Optional[dict], output: Any, prompt: Optional[str] = None, duration: float, ) -> None: if not self.task_logger: return self.task_logger.log_node( run_id=str(run_context.run_id) if run_context.run_id else None, session_id=str(run_context.session_id) if run_context.session_id else None, agent_id=str(run_context.agent_id) if run_context.agent_id else None, node_name=node.name, node_type=node.type, inputs=to_plain(inputs) if inputs else inputs, output=to_plain(output) if output is not None else None, prompt=prompt, duration=duration, ) def _next_node(self, node: Node, run_context: RunContext) -> Optional[str]: """Return the name of the next node to execute, or None at an end node.""" if isinstance(node, EndNode): return None if isinstance(node, IfNode): return ( node.then if evaluate_bool(node.condition, run_context.data) else node.else_ ) if isinstance(node, SwitchNode): value = evaluate_value(node.expr, run_context.data) return node.cases.get(value, node.default) return node.next async def _execute_node(self, node: Node, run_context: RunContext) -> None: """Run a side-effecting node (branch nodes are pure routing).""" if isinstance(node, LLMNode): await self._run_llm_node(node, run_context) elif isinstance(node, AgentNode): await self._run_agent_node(node, run_context) elif isinstance(node, FunctionNode): await self._run_function_node(node, run_context) # start / if / switch / end nodes have no side effects here. # ----------------------------------------------------------------------- run
[docs] async def run( self, input_data: dict, *, session_id: Optional[str] = None, external_id: Optional[str] = None, ) -> WorkflowState: """Execute the workflow for ``input_data`` and return the final state.""" invocation_id = uuid4().hex[:8] # Fresh token aggregator so totals never leak between runs. self._token_stats = TokenAccumulator(self.task_logger) parsed_input = self.get_data_type("input")(**input_data) run_context = RunContext() run_context.data["input"] = parsed_input run_context.templates = {t.name: t.value for t in self.graph.templates} state = WorkflowState( workflow_name=self.graph.name, status="running", input_data=to_plain(input_data), invocation_id=invocation_id, ) # Bind the invocation id onto every log record emitted during the run — # the engine, the agent loop and the LLM clients — so an entire # invocation can be grepped out of the logs by its id. with logger.contextualize(invocation_id=invocation_id): logger.info(f"[{invocation_id}] Starting workflow '{self.graph.name}'") if self.storage: handle = await self.storage.initialize_run( workflow_name=self.graph.name, description=self.graph.description, input_schema=self.graph.data_types.get("input"), output_schema=self.graph.data_types.get("output"), workflow=self.graph.model_dump(), session_id=session_id, external_id=external_id, input_data=to_plain(input_data), ) run_context.agent_id = handle.agent_id run_context.session_id = handle.session_id run_context.run_id = handle.run_id state.agent_id = handle.agent_id state.session_id = handle.session_id state.run_id = handle.run_id user_message = getattr(parsed_input, "user_message", str(input_data)) await self.storage.add_chat_message( agent_id=handle.agent_id, session_id=handle.session_id, run_id=handle.run_id, role="user", content=user_message, ) try: await self._walk(run_context, state) except WorkflowException: raise except Exception as e: state.status = "failed" state.error = str(e) raise WorkflowException(e) from e finally: await self.kernel.close() # Record and report token usage, then persist the final state # (including the totals) regardless of success or failure. state.token_usage = self._token_stats.summary() if self.task_logger: await self.task_logger.flush() await self._checkpoint(run_context, state) self._log_token_usage(invocation_id) return state
def _log_token_usage(self, invocation_id: str) -> None: """Log the aggregate model token usage for the run.""" s = self._token_stats logger.info( f"[{invocation_id}] Workflow '{self.graph.name}' token usage: " f"{s.model_calls} model call(s), {s.total_tokens} tokens " f"(prompt={s.prompt_tokens}, completion={s.completion_tokens})" ) async def _walk(self, run_context: RunContext, state: WorkflowState) -> None: current: Optional[str] = self.graph.start visits = 0 while current is not None: node = self.node_map[current] visits += 1 if visits > self.max_node_visits: raise WorkflowException( f"Exceeded max node visits ({self.max_node_visits}); " "the workflow may contain an infinite loop." ) state.current_node = node.name await self._execute_node(node, run_context) state.trace.append(node.name) state.data = to_plain(run_context.data) await self._checkpoint(run_context, state) if isinstance(node, EndNode): await self._finish(node, run_context, state) return current = self._next_node(node, run_context) # A non-end node with no outgoing transition (switch with no default match). raise WorkflowException( f"Workflow halted at node '{state.current_node}' with no next node " "and without reaching an end node." ) async def _finish( self, node: EndNode, run_context: RunContext, state: WorkflowState ) -> None: output_value = run_context.data.get(node.output) output_data = to_plain(output_value) if output_value is not None else None state.output_data = output_data state.status = "completed" if self.storage and run_context.run_id: await self.storage.update_run( str(run_context.run_id), output_data=output_data, context=to_plain(run_context.data), ) agent_response = getattr(output_value, "agent_response", "") await self.storage.add_chat_message( agent_id=str(run_context.agent_id), session_id=str(run_context.session_id), run_id=str(run_context.run_id), role="assistant", content=agent_response, ) await self._checkpoint(run_context, state) logger.info( f"[{state.invocation_id}] Workflow '{self.graph.name}' completed " f"(session={state.session_id})" ) async def _checkpoint(self, run_context: RunContext, state: WorkflowState) -> None: if self.storage and run_context.run_id: await self.storage.save_state(str(run_context.run_id), state)