Agent Module

The Agent module contains the AzadAgent class, which is the central component of the Azad system. It manages tasks, handles tool calls, and coordinates the execution of the AI agent.

Overview

The AzadAgent class is a ProtocolHandler that manages the execution of tasks. It:

  • Handles incoming requests from clients
  • Manages task state and execution
  • Coordinates tool calls and responses
  • Communicates with the AI network

The agent uses a request/response pattern for most operations, with streaming responses for operations that produce multiple events over time.

Key Concepts

Task Management

The agent manages tasks through several methods:

  • load_task(): Loads a task and its configuration
  • get_loaded_task(): Returns the currently loaded task
  • step_task(): Executes a single step of the task

Tool Execution

The agent handles tool execution through:

  • Tool calls from the language model
  • Approval from the user (if required)
  • Execution of the tool
  • Return of the result to the language model
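
The four stages above can be sketched as a single function. This is a minimal illustration, not Azad's implementation: the TOOLS registry, the REQUIRES_APPROVAL set, and handle_tool_call are hypothetical names, and the call dictionary only borrows the tool_name, tool_call_id, and args keys that appear in the one-shot response format.

```python
from typing import Any, Callable, Dict

# Hypothetical registry of callable tools (not part of Azad).
TOOLS: Dict[str, Callable[..., Any]] = {"add": lambda a, b: a + b}
# Hypothetical set of tools that need explicit user approval.
REQUIRES_APPROVAL = {"add"}


def handle_tool_call(call: Dict[str, Any],
                     approve: Callable[[Dict[str, Any]], bool]) -> Dict[str, Any]:
    """Run one tool call and build the result handed back to the model."""
    name = call["tool_name"]
    call_id = call["tool_call_id"]

    # 1. The tool call arrives from the language model.
    if name not in TOOLS:
        return {"tool_call_id": call_id, "error": f"unknown tool: {name}"}

    # 2. Ask the user for approval if this tool requires it.
    if name in REQUIRES_APPROVAL and not approve(call):
        return {"tool_call_id": call_id, "error": "rejected by user"}

    # 3. Execute the tool with the model-supplied arguments.
    result = TOOLS[name](**call["args"])

    # 4. Return the result, tagged with the call id, to the model.
    return {"tool_call_id": call_id, "result": result}
```

Passing the approval step in as a callable keeps the flow testable: an interactive client can prompt the user, while a test can pass a function that always approves or always rejects.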

Server Tools

The agent supports server-side tools that run on the server without requiring user approval for each step. These tools:

  • Run asynchronously in the background
  • Report progress through callbacks
  • Access server resources and environments
  • Return structured results to the agent
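
To make the callback-based progress reporting concrete, here is a minimal sketch using only asyncio. The slow_tool coroutine and its payload keys are assumptions chosen to resemble the step_callback usage in run_server_tool; a real server tool would subclass the project's own base class instead.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Tuple

StepCallback = Callable[[Dict[str, Any]], Awaitable[None]]


async def slow_tool(args: Dict[str, Any], step_callback: StepCallback) -> Dict[str, Any]:
    """Hypothetical tool: reports each step through the callback."""
    for step in range(args["steps"]):
        await asyncio.sleep(0)  # stand-in for real work
        await step_callback({"step_number": step, "status": "progress"})
    # Structured result returned to the caller when the tool finishes.
    return {"status": "done", "steps": args["steps"]}


async def run_in_background() -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    updates: List[Dict[str, Any]] = []

    async def on_step(payload: Dict[str, Any]) -> None:
        # Collect progress updates; a server would forward them to the client.
        updates.append(payload)

    # The tool runs as a background task while the caller stays responsive.
    task = asyncio.create_task(slow_tool({"steps": 3}, on_step))
    result = await task
    return updates, result
```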

Daemon Mode

The agent can run in daemon mode, which:

  • Starts a WebSocket server
  • Listens for client connections
  • Handles client requests
  • Manages the agent lifecycle

Implementation Details

The AzadAgent class implements the ProtocolHandler interface, which allows it to handle network protocol messages. It uses:

  • @request_response decorators for methods that return a single response
  • @request_stream decorators for methods that return a stream of responses
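
The difference between the two styles can be shown without the Azad decorators themselves. The following is a minimal sketch, assuming a hypothetical EchoHandler class with ping and count methods (none of these names come from Azad): a single-response method is a plain coroutine, while a streaming method is an async generator.

```python
import asyncio
from typing import Any, AsyncIterator, Dict, List, Tuple


class EchoHandler:
    """Hypothetical handler used only to contrast the two call styles."""

    async def ping(self, text: str) -> Dict[str, Any]:
        # Request/response style: one call, one awaited result.
        return {"message": f"pong: {text}"}

    async def count(self, n: int) -> AsyncIterator[Dict[str, Any]]:
        # Request/stream style: the caller receives events as they are produced.
        for i in range(n):
            await asyncio.sleep(0)
            yield {"step": i}


async def demo() -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    handler = EchoHandler()
    single = await handler.ping("hi")
    events = [event async for event in handler.count(3)]
    return single, events
```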

The agent maintains state for:

  • The current task
  • The environment
  • The task configuration
  • Active server tools

Usage Example

Here's a simplified example of how to use the AzadAgent:

# Create an agent
agent = AzadAgent()

# Load a task
await agent.load_task(task, task_config)

# Execute a step (step_task also requires the iteration index; 0 is a placeholder)
async for event in agent.step_task(task, task_config, iteration=0):
    # Process the event
    handle_event(event)

# Run a server tool
async for event in agent.run_server_tool("tool_name", {"param": "value"}):
    # Process the event
    handle_event(event)

Important Considerations

When working with the AzadAgent:

  1. Task State: The agent maintains the state of the current task. Make sure to load a task before executing steps.

  2. Environment Initialization: The agent initializes the environment when a task is loaded. Make sure the protocol is set before initializing the environment.

  3. Cancellation: Use the abort_step() method to cancel a running step.

  4. Server Tools: Server tools run asynchronously in the background. Use the _cancel_all_server_tools() method to cancel all running server tools.

  5. Daemon Mode: Use the enter_daemon_mode() class method to start the agent in daemon mode.

API Reference

azad.agent

Classes

SingleRequest

Bases: BaseModel

Attributes
task instance-attribute
task: Task[None]
model_config class-attribute instance-attribute
model_config = ConfigDict(arbitrary_types_allowed=True)

ResponseDataOneShot

Bases: BaseModel

Attributes
tool_calls class-attribute instance-attribute
tool_calls: list[Dict[str, Any]] = []
usage class-attribute instance-attribute
usage: Optional[Dict[str, Any]] = None
model_config class-attribute instance-attribute
model_config = ConfigDict(arbitrary_types_allowed=True)

AzadAgent

AzadAgent(config=None)

Bases: ProtocolHandler

Source code in azad/agent.py
def __init__(self, config=None):
    self.request_map = {}
    self.logger.info("created AzadAgent")
Attributes
logger class-attribute instance-attribute
logger = getLogger(__name__)
protocol instance-attribute
protocol: Optional[Protocol] = None
Functions
step_task async
step_task(task: Task, task_config: TaskConfig, iteration: int) -> AsyncIterator[AINetworkEventUnion]
Source code in azad/agent.py
@request_stream
async def step_task(self, task: Task, task_config: TaskConfig, iteration: int) -> AsyncIterator[ainetwork_types.AINetworkEventUnion]:
    cancellation_token = asyncio.Event()
    request_id = generate_untagged_request_id(self.request_map)

    event_queue: asyncio.Queue[ainetwork_types.AINetworkEventUnion | None] = asyncio.Queue()
    loop = asyncio.get_event_loop()

    # Define emit_event within the scope to capture the queue
    def emit_event_sync(event: ainetwork_types.AINetworkEventUnion | None) -> None:
        try:
            self.logger.debug(f"Putting event in queue: {event.__class__.__name__ if event else 'None'}")
            # Use call_soon_threadsafe if loop_stepper might run in another thread
            # loop.call_soon_threadsafe(event_queue.put_nowait, event)
            # If loop_stepper runs in the same event loop, put_nowait is fine
            event_queue.put_nowait(event)
        except Exception as e:
             self.logger.error(f"Error putting event in queue: {e}", exc_info=True)

    # Create the environment
    environment = Environment(self.protocol) # type: ignore
    # Create the task with explicit type annotation
    step_task = asyncio.create_task(
        loop_stepper.step(
            task=task, # Now guaranteed not None
            task_config=task_config,
            environment=environment,
            # Corrected: Use the defined sync emitter
            emit_event=emit_event_sync,
            iteration=iteration,
            cancellation_token=cancellation_token
        )
    )

    self.request_map[request_id] = SingleRequest(
        request_id=request_id,
        cancellation_token=cancellation_token,
        task=step_task
    )

    # Log to help trace the cancellation token lifecycle
    self.logger.debug(f"Created new request with ID: {request_id} and cancellation token: {cancellation_token} and async io task id: {id(step_task)}")

    # Emit the request ID to the event queue
    emit_event_sync(
        ainetwork_types.NetworkRequestAck(
            request_id=request_id,
        )
    )
    # on cancellation token set we want to send ainetwork_types.NetworkAbortRequestAck to the client
    # and then let the task finish

    # Set up monitoring for the cancellation token
    async def monitor_cancellation():
        try:
            # Wait for the cancellation token to be set
            await cancellation_token.wait()
            self.logger.info(f"Cancellation token set for request {request_id}, sending abort acknowledgment")
            # Send the acknowledgment to the client
            emit_event_sync(
                ainetwork_types.NetworkAbortRequestAck()
            )
            # We don't raise an exception here, allowing the task to finish gracefully
        except Exception as e:
            self.logger.error(f"Error in cancellation monitor: {e}", exc_info=True)

    # Start the monitoring task
    cancellation_monitor = asyncio.create_task(monitor_cancellation())


    while True:
        evt = await event_queue.get()
        # Check for None *before* yielding
        if evt is None:
            event_queue.task_done() # Mark task as done for the queue
            break
        yield evt # Yield only if not None
        if isinstance(evt, ainetwork_types.StepEndEvent):
            event_queue.task_done() # Mark task as done for the queue
            break
        event_queue.task_done() # Mark task as done for the queue
    self.logger.debug("Step task done before await")
    try:
        await step_task
        self.logger.debug("Step task done after await")
    except asyncio.CancelledError:
        self.logger.info("Step task await cancelled.")
    except Exception as e:
        self.logger.error(f"Error occurred during step task execution: {e}", exc_info=True)
        # Optionally yield an error event back to the client
        yield ainetwork_types.NetworkConnectionInterrupted(
            message=f"Error during step execution: {e}",
            errName=type(e).__name__,
            origErrName=type(e).__name__
        )
    finally:
        # Clean up the cancellation monitor task
        if 'cancellation_monitor' in locals():
            if not cancellation_monitor.done():
                cancellation_monitor.cancel()
            try:
                await asyncio.wait_for(cancellation_monitor, timeout=1.0)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                pass
            except Exception as e:
                self.logger.error(f"Error cleaning up cancellation monitor: {e}", exc_info=True)
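
The core of step_task is a bridge between a background task that emits events through a synchronous callback and an async generator that yields them to the client, with None acting as an end-of-stream sentinel. A stripped-down sketch of that pattern follows; producer, stream_events, and the string events are placeholders, not Azad types.

```python
import asyncio
from typing import AsyncIterator, Callable, List, Optional


async def producer(emit: Callable[[Optional[str]], None]) -> None:
    """Stand-in for the stepper: emits events, then a None sentinel."""
    for name in ("start", "token", "end"):
        await asyncio.sleep(0)
        emit(name)
    emit(None)  # tells the consumer that no more events will arrive


async def stream_events() -> AsyncIterator[str]:
    queue: asyncio.Queue = asyncio.Queue()
    # put_nowait is safe here because producer runs on the same event loop.
    task = asyncio.create_task(producer(queue.put_nowait))
    while True:
        event = await queue.get()
        if event is None:  # check the sentinel before yielding
            break
        yield event
    # Await the task so any exception it raised surfaces here.
    await task


async def collect() -> List[str]:
    return [event async for event in stream_events()]
```

If the producer ran in another thread, the emit callback would need loop.call_soon_threadsafe instead of calling put_nowait directly, as the commented-out line in the listing above suggests.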
abort_step async
abort_step(request_id: str) -> BaseResponse

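Aborts the running agent step identified by request_id. The method first sets the request's cancellation token, then cancels the underlying asyncio task and waits up to 5 seconds for it to finish.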

Source code in azad/agent.py
@request_response
async def abort_step(self,request_id: str) -> BaseResponse:
    """Aborts the currently running agent step."""
    self.logger.info("Abort step called.")

    # Capture current task and token state at start of abort
    if request_id not in self.request_map:
        self.logger.info(f"Request ID {request_id} not found in request map.")
        return MessageResponse(message="Request ID not found.")

    target_task = self.request_map[request_id].task
    target_token = self.request_map[request_id].cancellation_token

    # Add extra diagnostic info about what we're trying to cancel
    if target_task:
        self.logger.info(f"Setting cancellation token for task id {id(target_task)}")
    else:
        self.logger.info("No active step task to abort")
        return MessageResponse(message="No active step task to abort.")

    try:
        # Signal cancellation via the event first
        if target_token:
            self.logger.info(f"Setting cancellation token (id: {id(target_token)}) to abort LLM stream")
            target_token.set()

        # Then cancel the task (this will properly cleanup)
        if target_task and not target_task.done():
            self.logger.info(f"Cancelling step task id {id(target_task)}")
            target_task.cancel()

            try:
                # Wait for the task to acknowledge cancellation
                await asyncio.wait_for(target_task, timeout=5.0)
                self.logger.info(f"Step task id {id(target_task)} cancellation processed (awaited).")
            except asyncio.CancelledError:
                self.logger.info(f"Step task id {id(target_task)} was successfully cancelled (threw CancelledError).")
            except asyncio.TimeoutError:
                self.logger.warning(f"Timeout waiting for step task id {id(target_task)} cancellation confirmation.")
            except Exception as e:
                 self.logger.error(f"Error during step task id {id(target_task)} cancellation: {e}", exc_info=True)

            return MessageResponse(message="Abort signal sent and cancellation initiated.")
        else:
            # This happens if task is already done or was already canceled
            reason = "already done" if target_task and target_task.done() else "already canceled"
            self.logger.info(f"No active step task to cancel or it has {reason}.")
            return MessageResponse(message=f"No active step task to abort or it has {reason}.")
    except Exception as e:
        self.logger.error(f"Error in abort_step: {e}", exc_info=True)
        # Corrected: Use 'error' parameter for ErrorResponse
        return ErrorResponse(error=f"Failed to abort step: {e}")
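
The two-stage abort used by abort_step (signal an asyncio.Event, then cancel the task and wait with a timeout) can be sketched in isolation. The worker, abort, and demo functions below are hypothetical and only illustrate the ordering of the steps.

```python
import asyncio
from typing import List, Tuple


async def worker(token: asyncio.Event, log: List[str]) -> None:
    """Hypothetical step: polls the token and records how it stopped."""
    try:
        while not token.is_set():
            await asyncio.sleep(0.01)
        log.append("saw token")
    except asyncio.CancelledError:
        log.append("cancelled")
        raise  # re-raise so the task is marked as cancelled


async def abort(task: asyncio.Task, token: asyncio.Event, timeout: float = 1.0) -> str:
    token.set()  # cooperative signal first
    if task.done():
        return "already done"
    task.cancel()  # then a hard cancel
    try:
        await asyncio.wait_for(task, timeout=timeout)
    except asyncio.CancelledError:
        return "cancelled"
    except asyncio.TimeoutError:
        return "timeout"
    return "finished"


async def demo() -> Tuple[str, List[str]]:
    log: List[str] = []
    token = asyncio.Event()
    task = asyncio.create_task(worker(token, log))
    await asyncio.sleep(0.02)  # let the worker start
    outcome = await abort(task, token)
    return outcome, log
```

Because the token is set and the task is cancelled in the same synchronous stretch, the worker in this sketch observes the cancellation rather than the token. Code that should shut down gracefully on the token alone would skip the immediate cancel or delay it.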
request_response_oneshot async
request_response_oneshot(request: OneShotRequestUnion) -> BaseResponse

Handles a non-streaming one-shot request to the AI network using litellm directly. The blocking litellm.completion call runs in a worker thread via asyncio.to_thread, so the event loop is not blocked while waiting for the response.

Source code in azad/agent.py
@request_response
async def request_response_oneshot(
    self,
    request: ainetwork_types.OneShotRequestUnion,
) -> BaseResponse:
    """
    Handles a non-streaming one-shot request to the AI network using litellm directly.
    Makes a synchronous call to litellm and returns the response.
    """
    try:
        # Extract parameters from the request
        model_id = request.model_id
        messages = request.params.messages
        api_key = request.model_api_key
        temperature = request.temperature if request.params else 1.0
        max_tokens = request.max_tokens if request.params else None




        # Convert messages to litellm format
        litellm_messages = []
        for msg in messages:
            # Extract text content from message parts based on type
            content = ""
            if hasattr(msg, 'content') and isinstance(msg.content, list):
                text_parts = []
                for part in msg.content:
                    # Use isinstance for proper type checking
                    if isinstance(part, TextPart):
                        text_parts.append(part.text)
                    elif isinstance(part, ReasoningPart):
                        text_parts.append(part.reasoning)
                    elif isinstance(part, InformationalPart) and part.details:
                        text_parts.append(part.details)
                content = "\n".join(filter(None, text_parts))

            litellm_msg = {
                "role": msg.role,
                "content": content
            }

            # Handle AssistantMessage with tool calls
            if isinstance(msg, AssistantMessage) and msg.content:
                tool_calls = []
                for part in msg.content:
                    if isinstance(part, ToolCallPart):
                        # Convert args to JSON string if it's a dict
                        args_str = json.dumps(part.args) if isinstance(part.args, dict) else str(part.args)
                        tool_calls.append({
                            "id": part.tool_call_id,
                            "type": "function",
                            "function": {
                                "name": part.tool_name,
                                "arguments": args_str
                            }
                        })
                if tool_calls:
                    litellm_msg["tool_calls"] = tool_calls

            # Handle ToolMessage with tool call id
            elif isinstance(msg, ToolMessage) and msg.content:
                for part in msg.content:
                    if isinstance(part, ToolResultPart):
                        litellm_msg["tool_call_id"] = part.tool_call_id
                        # Extract result content
                        if part.result:
                            litellm_msg["content"] = json.dumps(part.result) if isinstance(part.result, dict) else str(part.result)
                        break

            litellm_messages.append(litellm_msg)

        # Prepare kwargs for litellm
        kwargs = {
            "model": model_id,
            "messages": litellm_messages,
            "temperature": temperature,
        }

        if api_key:
            kwargs["api_key"] = api_key
        if max_tokens:
            kwargs["max_tokens"] = max_tokens

        # Make synchronous call to litellm
        # Note: Using asyncio.to_thread to run sync function in async context
        response = await asyncio.to_thread(litellm.completion, **kwargs)

        # Extract response data - litellm returns a ModelResponse object
        if not response:
            return ErrorResponse(error="No response from model")

        # Get choices from response
        choices = getattr(response, 'choices', None)
        if not choices or len(choices) == 0:
            return ErrorResponse(error="No choices in response from model")

        choice = choices[0]

        # Extract message from choice
        message = getattr(choice, 'message', None)
        if not message:
            # Try alternative attribute names
            message = getattr(choice, 'delta', None) or choice

        # Extract content
        content = ""
        if hasattr(message, 'content') and message.content is not None:
            content = str(message.content)
        elif hasattr(choice, 'text') and choice.text is not None:
            content = str(choice.text)

        # Format tool calls if present
        tool_calls = []
        tool_calls_attr = getattr(message, 'tool_calls', None)
        if tool_calls_attr:
            for tc in tool_calls_attr:
                tc_id = getattr(tc, 'id', '')
                tc_function = getattr(tc, 'function', None)
                if tc_function:
                    tool_calls.append({
                        "tool_name": getattr(tc_function, 'name', ''),
                        "tool_call_id": tc_id,
                        "args": getattr(tc_function, 'arguments', '{}')
                    })

        # Build response data
        response_data = {
            "content": content,
            "tool_calls": tool_calls,
        }

        # Add usage information if available
        usage = getattr(response, 'usage', None)
        if usage:
            response_data["usage"] = {
                "cost": getattr(usage, 'total_cost', 0.0),
                "input_tokens": getattr(usage, 'prompt_tokens', 0),
                "output_tokens": getattr(usage, 'completion_tokens', 0),
            }


        return DataResponse(data=response_data)

    except Exception as e:
        self.logger.error(f"Error in oneshot_request: {e}", exc_info=True)
        return ErrorResponse(error=f"Failed to process oneshot request: {str(e)}")
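
The message conversion step is the least obvious part of this method, so here is a reduced sketch of it. The Text and ToolCall dataclasses are simplified stand-ins for the project's TextPart and ToolCallPart, and to_chat_message is a hypothetical helper; only the shape of the output dictionary mirrors the listing above.

```python
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class Text:
    """Stand-in for a text message part."""
    text: str


@dataclass
class ToolCall:
    """Stand-in for a tool-call message part."""
    tool_call_id: str
    tool_name: str
    args: Dict[str, Any]


def to_chat_message(role: str, parts: List[Union[Text, ToolCall]]) -> Dict[str, Any]:
    # Join all non-empty text parts into a single content string.
    texts = [p.text for p in parts if isinstance(p, Text)]
    message: Dict[str, Any] = {"role": role, "content": "\n".join(filter(None, texts))}

    # Tool calls go into a separate list, with arguments encoded as JSON text.
    calls = [
        {
            "id": p.tool_call_id,
            "type": "function",
            "function": {"name": p.tool_name, "arguments": json.dumps(p.args)},
        }
        for p in parts
        if isinstance(p, ToolCall)
    ]
    if calls:
        message["tool_calls"] = calls
    return message
```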
upload_failed_edit async
upload_failed_edit(task: Task, task_config: TaskConfig, file_content: str) -> BaseResponse

Schedules the upload of a failed edit to run as a background job.

Source code in azad/agent.py
@request_response
async def upload_failed_edit(self, task: Task, task_config: TaskConfig, file_content: str) -> BaseResponse:
    """
    Schedules the upload of a failed edit to run as a background job.
    """
    try:
        # Schedule the background task
        asyncio.create_task(
            self._record_failed_edit_in_background(task, task_config, file_content)
        )
        # Return immediately
        return MessageResponse(message="Failed edit recording has been scheduled to run in the background.")
    except Exception as e:
        self.logger.error(f"Error scheduling failed edit upload: {e}", exc_info=True)
        return ErrorResponse(error=f"Failed to schedule failed edit upload: {str(e)}")
render_prompt async
render_prompt(task: Task, task_config: TaskConfig) -> BaseResponse

Renders the prompt messages that would be sent to the LLM without actually sending them.

Source code in azad/agent.py
@request_response
async def render_prompt(self, task: Task, task_config: TaskConfig) -> BaseResponse:
    """
    Renders the prompt messages that would be sent to the LLM without actually sending them.
    """
    try:
        if self.protocol is None:
            return ErrorResponse(error="Protocol is not available.")

        # Get the current task config part, which holds tool metadata and dialect info
        task_config_part = task.current_task_config()
        tool_metadata = task_config_part.tool_metadata

        # Determine the dialect
        dialect_name = task_config_part.dialect_name or "xml"  # Fallback to xml
        dialect_options = task_config_part.dialect_options or {}

        # Use the registry's create method to properly instantiate the dialect with its config
        dialect = PromptDialectRegistry.create(dialect_name, **dialect_options)

        # Prepare prompt data
        prompt_data = PromptData(
            messages=task.messages,
            task_config=task_config_part,
            dyanmic_task_config=task_config,
            tool_metadata=tool_metadata,
            current_assistant_id="assistant",  # Placeholder, not critical for rendering
        )

        # Format messages
        formatted_messages = dialect.format_messages(prompt_data, rules_path=None)

        # Pydantic models in the data will be automatically serialized to dicts
        return DataResponse(data={"messages": formatted_messages})

    except Exception as e:
        self.logger.error(f"Error in render_prompt: {e}", exc_info=True)
        return ErrorResponse(error=f"Failed to render prompt: {str(e)}")
run_server_tool async
run_server_tool(tool_name: str, args: Dict[str, Any], task_config: Optional[TaskConfig] = None) -> AsyncIterator[AINetworkEventUnion]

Runs a server-side tool initiated by the client.

Source code in azad/agent.py
@request_stream
async def run_server_tool(self, tool_name: str, args: Dict[str, Any],
    task_config: Optional[TaskConfig] = None, # Optional task config
    ) -> AsyncIterator[ainetwork_types.AINetworkEventUnion]:
    """Runs a server-side tool initiated by the client."""
    server_tool_run_id = nanoid.generate()
    self.logger.info(f"Received request to run server tool '{tool_name}' with ID: {server_tool_run_id}")

    tool_class = SERVER_TOOL_REGISTRY.get(tool_name)
    if not tool_class:
        self.logger.error(f"Server tool '{tool_name}' not found in registry.")
        # Corrected: Use keyword arguments for constructor and add step_number
        yield ainetwork_types.ServerToolUpdateEvent(
            server_tool_run_id=server_tool_run_id,
            tool_name=tool_name,
            step_number=None, # Explicitly add optional field
            status="error",
            data={"error": f"Tool '{tool_name}' not found."}
        )
        return

    # Prepare config for the tool - potentially merge global/task config
    # For now, just pass task_config if available, otherwise empty dict
    tool_config = task_config.model_dump() if task_config else {}

    tool_instance: ServerToolBase | None = None
    tool_task: asyncio.Task | None = None

    try:
        # Instantiate the tool
        tool_instance = tool_class(config=tool_config, protocol=self.protocol) # type: ignore

        # Create a queue for communicating between the server tool and our stream
        event_queue: asyncio.Queue[ainetwork_types.AINetworkEventUnion] = asyncio.Queue()

        # Create a regular async function that the server tool can await
        async def server_tool_step_handler(update_payload: Dict[str, Any]) -> None:
            # Create the event from the payload
            event_data = ainetwork_types.ServerToolUpdateEvent(
                server_tool_run_id=update_payload.get("server_tool_run_id", server_tool_run_id),
                tool_name=update_payload.get("tool_name", tool_name),
                step_number=update_payload.get("step_number"),
                status=update_payload.get("status", "progress"), # Default to progress
                data=update_payload.get("data", {})
            )
            # Put the event in the queue to be yielded to the client
            await event_queue.put(event_data)

        # Create and track the background task for the tool run
        tool_task = asyncio.create_task(
            tool_instance.run(
                args=args,
                step_callback=server_tool_step_handler, # Pass the regular async function
                server_tool_run_id=server_tool_run_id
            )
        )
        self.logger.info(f"Started server tool '{tool_name}' run ID: {server_tool_run_id}")

        # Yield events from the queue back to the client
        try:
            # Process events from the queue while the tool task is running
            while True:
                # Use a small timeout to periodically check if the tool task is done
                try:
                    # Wait for an event with a timeout
                    event = await asyncio.wait_for(event_queue.get(), 0.1)
                    # Yield the event to the client
                    yield event
                    event_queue.task_done()
                except asyncio.TimeoutError:
                    # Check if the tool task is done
                    if tool_task.done():
                        # Get the result (or re-raise any exception)
                        final_result = await tool_task
                        break
                    # Otherwise continue waiting for events
                    continue

            # Drain any remaining events in the queue
            while not event_queue.empty():
                event = await event_queue.get()
                yield event
                event_queue.task_done()

            # Log completion
            self.logger.info(f"Server tool '{tool_name}' run ID {server_tool_run_id} finished.")
        except Exception as e:
            self.logger.error(f"Error processing events from '{tool_name}' (ID: {server_tool_run_id}): {e}", exc_info=True)
            # If the tool task is still running, cancel it
            if not tool_task.done():
                tool_task.cancel()
            # Re-raise the exception to be caught by the outer try/except
            raise


    except Exception as e:
        self.logger.error(f"Error running server tool '{tool_name}' (ID: {server_tool_run_id}): {e}", exc_info=True)
        # Yield a final error event if the exception occurred outside the tool's run() try/except
        # Corrected: Use keyword arguments for constructor and add step_number
        yield ainetwork_types.ServerToolUpdateEvent(
            server_tool_run_id=server_tool_run_id,
            tool_name=tool_name,
            step_number=None, # Explicitly add optional field
            status="error",
            data={"error": f"Failed to run tool: {e}", "traceback": traceback.format_exc()}
        )
    finally:
        self.logger.debug(f"Cleaned up tracking for server tool run ID: {server_tool_run_id}")
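
The event loop in run_server_tool polls a queue with a short timeout, checks whether the background task has finished on each timeout, and drains leftover events before returning. A compact sketch of that loop, with a hypothetical tool coroutine and string events in place of ServerToolUpdateEvent:

```python
import asyncio
from typing import AsyncIterator, List


async def tool(queue: asyncio.Queue) -> str:
    """Hypothetical tool: pushes progress events, then returns a result."""
    for i in range(3):
        await asyncio.sleep(0.01)
        await queue.put(f"progress {i}")
    return "ok"


async def stream_tool_events(poll: float = 0.05) -> AsyncIterator[str]:
    queue: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(tool(queue))
    while True:
        try:
            # Wait briefly for an event so the task state is checked regularly.
            event = await asyncio.wait_for(queue.get(), poll)
        except asyncio.TimeoutError:
            if task.done():
                break
            continue
        yield event
    # Drain anything that arrived between the last poll and task completion.
    while not queue.empty():
        yield queue.get_nowait()
    # task.result() re-raises here if the tool failed.
    yield f"final: {task.result()}"


async def collect() -> List[str]:
    return [event async for event in stream_tool_events()]
```

Polling adds up to one timeout interval of latency after the tool finishes. A sentinel value placed on the queue when the tool completes is a common alternative that avoids the delay, at the cost of requiring the tool wrapper to always enqueue it.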
enter_daemon_mode async classmethod
enter_daemon_mode(config: GlobalConfig) -> None
Source code in azad/agent.py
@classmethod
async def enter_daemon_mode(cls, config: GlobalConfig) -> None:
    protocol: WebSocketNetworkProtocol | None = None
    try:
        print(f"Starting daemon mode with config: {config}")
        # Create protocol with self as handler and callbacks
        protocol = WebSocketNetworkProtocol(handler_class=cls)
        # Start protocol server
        await protocol.start_server(config.port, config.interface)
    except Exception as e:
         cls.logger.critical(f"Failed to start daemon mode: {e}", exc_info=True)
         # Perform any necessary cleanup before exiting
         if protocol:
             await protocol.stop_server(2) # Assuming stop_server exists and cleans up
    finally:
         # Ensure cleanup happens even if start_server runs indefinitely until interrupted
         cls.logger.info("Daemon mode shutting down.")
         if protocol:
             # Cancel any running server tools managed by handler instances
             # This requires accessing handler instances, which might be complex depending on protocol design
             # For now, assume protocol.stop_server handles graceful shutdown if possible
             await protocol.stop_server(2)

Functions

generate_untagged_request_id

generate_untagged_request_id(request_map: Dict[str, SingleRequest]) -> str

Generates a unique request ID that is not already in use.

Source code in azad/agent.py
def generate_untagged_request_id(request_map: Dict[str, SingleRequest]) -> str:
    """Generates a unique request ID that is not already in use."""
    while True:
        request_id = uuid.uuid4().hex
        if request_id not in request_map:
            return request_id
