Extending Azad

This guide explains how to extend Azad with new features, such as adding new tools, dialects, or compression strategies.

Adding New Tools

Azad supports two types of tools:

  1. Client-side tools: These are executed by the client (e.g., VS Code extension) and require user approval.
  2. Server-side tools: These are executed by the server and don't require user approval.
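The split between the two kinds of tools can be illustrated with a small dispatcher. This is a minimal sketch, not Azad's implementation: `route_tool_call`, the dict of server tools, and the `ask_user_approval` callback are hypothetical names. Server-side tools run immediately. Client-side tools are only handed off to the client after the user approves them.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict

# Hypothetical types: a server tool is modeled as an async function of its args,
# and approval as an async yes/no question to the user.
ServerTool = Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
ApprovalFn = Callable[[str, Dict[str, Any]], Awaitable[bool]]


async def route_tool_call(
    tool_name: str,
    args: Dict[str, Any],
    server_tools: Dict[str, ServerTool],
    ask_user_approval: ApprovalFn,
) -> Dict[str, Any]:
    """Run server-side tools directly; gate client-side tools behind user approval."""
    if tool_name in server_tools:
        # Server-side tool: executed by the server with no approval step.
        return await server_tools[tool_name](args)
    # Client-side tool: ask the user first, then hand the call off to the client.
    if not await ask_user_approval(tool_name, args):
        return {"error": f"User rejected {tool_name}"}
    return {"forwarded_to_client": tool_name, "args": args}


# Example tool and approval policies used in the demo below.
async def echo_tool(args: Dict[str, Any]) -> Dict[str, Any]:
    return {"echo": args}


async def always_deny(tool_name: str, args: Dict[str, Any]) -> bool:
    return False


async def always_approve(tool_name: str, args: Dict[str, Any]) -> bool:
    return True


if __name__ == "__main__":
    tools = {"echo": echo_tool}
    # The server tool runs even though the approval policy denies everything.
    print(asyncio.run(route_tool_call("echo", {"x": 1}, tools, always_deny)))
    print(asyncio.run(route_tool_call("write_file", {"path": "a.txt"}, tools, always_deny)))
```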

Creating a Server-Side Tool

To create a new server-side tool, you need to:

  1. Create a new class that inherits from ServerToolBase in the azad/server_tools directory.
  2. Register the tool in the SERVER_TOOL_REGISTRY.
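Registration is essentially a name-to-class lookup. The example at the end of this section registers MyTool with a plain item assignment, `SERVER_TOOL_REGISTRY["my_tool"] = MyTool`. The sketch below models that with an ordinary dict. It also adds a hypothetical `register_tool` decorator that rejects duplicate names, plus a `create_tool` helper. `ServerToolBase` here is a stand-in with the constructor signature used in the example, not the real base class.

```python
from typing import Any, Callable, Dict, Optional, Type


class ServerToolBase:
    """Hypothetical stand-in for azad.server_tools.base.ServerToolBase."""

    def __init__(self, config: Dict[str, Any], protocol: Optional[Any] = None):
        self.config = config
        self.protocol = protocol


# Stand-in for SERVER_TOOL_REGISTRY: maps a tool name to its class.
SERVER_TOOL_REGISTRY: Dict[str, Type[ServerToolBase]] = {}


def register_tool(name: str) -> Callable[[Type[ServerToolBase]], Type[ServerToolBase]]:
    """Hypothetical class decorator that registers a tool under `name`."""
    def decorator(cls: Type[ServerToolBase]) -> Type[ServerToolBase]:
        if name in SERVER_TOOL_REGISTRY:
            raise ValueError(f"Tool {name!r} is already registered")
        SERVER_TOOL_REGISTRY[name] = cls
        return cls
    return decorator


def create_tool(name: str, config: Dict[str, Any]) -> ServerToolBase:
    """Look up a registered tool class by name and instantiate it."""
    try:
        tool_cls = SERVER_TOOL_REGISTRY[name]
    except KeyError:
        raise KeyError(f"Unknown server tool: {name!r}") from None
    return tool_cls(config)


@register_tool("my_tool")
class MyTool(ServerToolBase):
    """Equivalent to SERVER_TOOL_REGISTRY["my_tool"] = MyTool."""
```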

Here's an example of a simple server-side tool:

from typing import Any, Awaitable, Callable, Dict, Optional
import asyncio
import logging
from azad.server_tools.base import ServerToolBase

# Type of the progress callback passed to run(); each call is awaited
StepCallback = Callable[[Dict[str, Any]], Awaitable[Any]]

class MyTool(ServerToolBase):
    """A simple server-side tool."""

    def __init__(self, config: Dict[str, Any], protocol: Optional[Any] = None):
        super().__init__(config, protocol)
        self.logger = logging.getLogger(__name__)

    async def run(self, args: Dict[str, Any], step_callback: StepCallback, server_tool_run_id: str) -> Dict[str, Any]:
        """Run the tool with the given arguments."""
        # Report initial progress
        await step_callback({
            "server_tool_run_id": server_tool_run_id,
            "tool_name": "my_tool",
            "step_number": 1,
            "status": "progress",
            "data": {"message": "Starting my tool..."}
        })

        # Do some work
        await asyncio.sleep(1)  # Simulate work

        # Report completion
        await step_callback({
            "server_tool_run_id": server_tool_run_id,
            "tool_name": "my_tool",
            "step_number": 2,
            "status": "complete",
            "data": {"result": "Tool execution complete!"}
        })

        # Return the final result
        return {"result": "Tool execution complete!"}

# Register the tool
from azad.server_tools import SERVER_TOOL_REGISTRY
SERVER_TOOL_REGISTRY["my_tool"] = MyTool
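Because `run()` reports progress only through `step_callback`, you can test a tool's progress reporting without a server by passing in a callback that records every step. The harness below is a sketch. `CountingTool` is a hypothetical tool with the same `run()` signature as MyTool. `validate_steps` checks conventions that the MyTool example happens to follow (consecutive step numbers starting at 1, one run id throughout, a single final `"complete"` step). These are not documented requirements.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Tuple

StepCallback = Callable[[Dict[str, Any]], Awaitable[None]]


class CountingTool:
    """Hypothetical tool with MyTool's run() signature: one "progress" step
    per item, then a final "complete" step."""

    async def run(self, args: Dict[str, Any], step_callback: StepCallback,
                  server_tool_run_id: str) -> Dict[str, Any]:
        items = args.get("items", [])
        for number, item in enumerate(items, start=1):
            await step_callback({"server_tool_run_id": server_tool_run_id,
                                 "tool_name": "counting_tool", "step_number": number,
                                 "status": "progress", "data": {"item": item}})
        await step_callback({"server_tool_run_id": server_tool_run_id,
                             "tool_name": "counting_tool", "step_number": len(items) + 1,
                             "status": "complete", "data": {"count": len(items)}})
        return {"count": len(items)}


async def run_and_collect(tool: Any, args: Dict[str, Any],
                          run_id: str) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Run a tool with a callback that records every reported step."""
    steps: List[Dict[str, Any]] = []

    async def collect(step: Dict[str, Any]) -> None:
        steps.append(step)

    result = await tool.run(args, collect, run_id)
    return result, steps


def validate_steps(steps: List[Dict[str, Any]], run_id: str) -> bool:
    """Check the conventions followed by the MyTool example."""
    if not steps:
        return False
    statuses = [s["status"] for s in steps]
    return (
        [s["step_number"] for s in steps] == list(range(1, len(steps) + 1))
        and all(s["server_tool_run_id"] == run_id for s in steps)
        and statuses[-1] == "complete"
        and "complete" not in statuses[:-1]
    )


if __name__ == "__main__":
    result, steps = asyncio.run(run_and_collect(CountingTool(), {"items": ["a", "b"]}, "run-1"))
    print(result, validate_steps(steps, "run-1"))
```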

Defining Tool Metadata

For the AI to use your tool, you need to define its metadata in the task configuration:

from azad.type_definitions import ToolMetadata, ParameterMetadata, ToolExample

my_tool_metadata = ToolMetadata(
    name="my_tool",
    description="A simple tool that does something useful.",
    parameters={
        "param1": ParameterMetadata(description="The first parameter", streamable=False),
        "param2": ParameterMetadata(description="The second parameter", streamable=True),
    },
    required_parameters=["param1"],
    examples=[
        ToolExample(
            explanation="Example usage of my_tool",
            parameters={"param1": "value1", "param2": "value2"}
        )
    ],
    is_task_entry=False,
    is_task_exit=False,
    supports_parallel=False
)

# Add this to your task configuration
task_config.tool_metadata.append(my_tool_metadata)
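It is easy to let `parameters`, `required_parameters`, and `examples` drift apart. When that happens, the model can be shown an example that contradicts the tool's own rules. The sketch below checks for such mismatches. The dataclasses are hypothetical stand-ins that only mirror the field names used above; the real classes in `azad.type_definitions` may differ. `check_metadata` is an illustrative helper and not part of Azad.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


# Hypothetical stand-ins mirroring the field names in the example above.
@dataclass
class ParameterMetadata:
    description: str
    streamable: bool = False


@dataclass
class ToolExample:
    explanation: str
    parameters: Dict[str, Any]


@dataclass
class ToolMetadata:
    name: str
    description: str
    parameters: Dict[str, ParameterMetadata]
    required_parameters: List[str] = field(default_factory=list)
    examples: List[ToolExample] = field(default_factory=list)


def check_metadata(meta: ToolMetadata) -> List[str]:
    """Return inconsistencies between parameters, required_parameters and examples."""
    problems: List[str] = []
    declared = set(meta.parameters)
    # Every required parameter must be declared.
    for name in meta.required_parameters:
        if name not in declared:
            problems.append(f"required parameter {name!r} is not declared")
    # Every example must supply the required parameters and only declared ones.
    for i, example in enumerate(meta.examples):
        for name in meta.required_parameters:
            if name not in example.parameters:
                problems.append(f"example {i} is missing required parameter {name!r}")
        for name in example.parameters:
            if name not in declared:
                problems.append(f"example {i} uses undeclared parameter {name!r}")
    return problems
```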

Adding New Dialects

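Dialects in Azad define how the agent communicates with the language model. A dialect determines how tool-calling rules and examples are written into the prompt, how conversation history is rendered into messages, and how tool calls are parsed out of the model's response. To add a new dialect: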

  1. Create a new directory in azad/prompts/dialects for your dialect.
  2. Create a dialect.py file with a class that inherits from Dialect.
  3. Create a parser.py file with a class that inherits from DialectParser.
  4. Register your dialect in the PromptDialectRegistry.

Here's an example of a simple dialect:

# azad/prompts/dialects/my_dialect/dialect.py
from typing import Optional, Dict, List
from azad.prompts.base_dialect import Dialect, DialectParser, PromptData
from azad.mind_map import Message, MessageRole, ToolCallPart, ToolResultPart

class MyDialectConfig:
    """Configuration for MyDialect."""
    pass

class MyDialect(Dialect):
    """A simple dialect for AI agent communication."""
    config_cls = MyDialectConfig

    def __init__(self, config: MyDialectConfig):
        self.config = config

    def format_dialect_rules(self, prompt_data: PromptData) -> str:
        """Return the dialect rules."""
        return """
        # My Dialect Rules

        When you want to use a tool, format your request like this:

        USE_TOOL: tool_name
        PARAM1: value1
        PARAM2: value2
        END_TOOL
        """

    def format_example(self, tool_name: str, parameters: dict) -> str:
        """Format an example tool call."""
        params = "\n".join([f"{k.upper()}: {v}" for k, v in parameters.items()])
        return f"USE_TOOL: {tool_name}\n{params}\nEND_TOOL"

    def format_tool_call(self, tool_call: ToolCallPart) -> str:
        """Format a tool call."""
        params = "\n".join([f"{k.upper()}: {v}" for k, v in tool_call.args.items()])
        return f"USE_TOOL: {tool_call.tool_name}\n{params}\nEND_TOOL"

    def format_tool_result(self, tool_result: ToolResultPart) -> str:
        """Format a tool result."""
        return f"TOOL_RESULT: {tool_result.tool_name}\n{tool_result.result}\nEND_RESULT"

    def format_history_item(self, item: Message) -> Optional[dict]:
        """Convert a Message to a LiteLLM message."""
        if item.role == MessageRole.user:
            return {"role": "user", "content": self._format_user_content(item)}
        elif item.role == MessageRole.assistant:
            return {"role": "assistant", "content": self._format_assistant_content(item)}
        elif item.role == MessageRole.tool:
            return {"role": "user", "content": self._format_tool_content(item)}
        elif item.role == MessageRole.system:
            return {"role": "system", "content": self._format_system_content(item)}
        return None

    def _format_user_content(self, item: Message) -> str:
        """Format user message content."""
        return " ".join(part.text for part in item.content if part.type == "text")

    def _format_assistant_content(self, item: Message) -> str:
        """Format assistant message content."""
        text = " ".join(part.text for part in item.content if part.type == "text")
        tools = "\n".join(self.format_tool_call(tc) for tc in item.content if tc.type == "toolCall")
        return f"{text}\n{tools}" if tools else text

    def _format_tool_content(self, item: Message) -> str:
        """Format tool message content."""
        return "\n".join(self.format_tool_result(tr) for tr in item.content if tr.type == "toolResult")

    def _format_system_content(self, item: Message) -> str:
        """Format system message content."""
        return " ".join(part.text for part in item.content if part.type == "text")

    def create_parser(self, prompt_data: PromptData) -> DialectParser:
        """Create a parser for this dialect."""
        from .parser import MyDialectParser
        return MyDialectParser(self.config)

    def is_native_toolcalling(self) -> bool:
        """Return whether this dialect supports native tool calling."""
        return False

# azad/prompts/dialects/my_dialect/parser.py
import codecs
import re
from typing import List
from azad.prompts.base_dialect import DialectParser
from azad.ainetwork.types import (
    AINetworkEventUnion,
    AINetworkEventTextChunk,
    AINetworkEventToolName,
    AINetworkEventParameterStart,
    AINetworkEventParameterChunk,
    AINetworkEventParameterEnd,
)

TOOL_MARKER = "USE_TOOL:"
# One complete tool call: "USE_TOOL: name", the parameter lines, then "END_TOOL"
TOOL_BLOCK_RE = re.compile(r"USE_TOOL:\s*(\w+)\s*(.*?)END_TOOL", re.DOTALL)
# One "KEY: value" parameter; a value runs until the next "KEY: " line or the end
PARAM_RE = re.compile(r"^(\w+): (.*?)(?=\n\w+: |\Z)", re.MULTILINE | re.DOTALL)

class MyDialectParser(DialectParser):
    """Parser for MyDialect.

    Plain text is emitted as soon as it arrives. A tool call is buffered until
    its END_TOOL line has been received, then emitted as a burst of events.
    Anything still in self.buffer when the stream ends (a partial marker or an
    unterminated tool call) has not been emitted.
    """

    def __init__(self, config):
        self.config = config
        self.buffer = ""
        # Incremental decoder: multi-byte UTF-8 characters may be split across chunks
        self.decoder = codecs.getincrementaldecoder("utf-8")()

    def feed(self, data: bytes) -> List[AINetworkEventUnion]:
        """Parse incoming data and emit events."""
        self.buffer += self.decoder.decode(data)
        events: List[AINetworkEventUnion] = []

        while self.buffer:
            start = self.buffer.find(TOOL_MARKER)
            if start == -1:
                # No tool call: emit the text, but hold back a trailing fragment
                # that could be the beginning of a "USE_TOOL:" marker
                hold = max((i for i in range(1, len(TOOL_MARKER))
                            if self.buffer.endswith(TOOL_MARKER[:i])), default=0)
                split = len(self.buffer) - hold
                if split:
                    events.append(AINetworkEventTextChunk(content=self.buffer[:split]))
                self.buffer = self.buffer[split:]
                break

            # Emit any text that precedes the tool call
            if start > 0:
                events.append(AINetworkEventTextChunk(content=self.buffer[:start]))
                self.buffer = self.buffer[start:]

            match = TOOL_BLOCK_RE.match(self.buffer)
            if match is None:
                # Tool call still incomplete: wait for more data
                break

            events.append(AINetworkEventToolName(tool_name=match.group(1), tool_call_id=None))
            for param in PARAM_RE.finditer(match.group(2)):
                name, value = param.group(1).lower(), param.group(2).strip()
                events.append(AINetworkEventParameterStart(parameter=name))
                events.append(AINetworkEventParameterChunk(parameter=name, content=value))
                events.append(AINetworkEventParameterEnd(parameter=name))

            # Drop the processed tool call and continue with the rest of the buffer
            self.buffer = self.buffer[match.end():]

        return events
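The formatter and the parser must agree on the wire format. A quick way to check that they do is a round trip: render a call, parse it back, and compare the results. The standalone sketch below reimplements the USE_TOOL/END_TOOL format with two hypothetical functions, `format_tool_call` and `parse_tool_calls`, so it needs no Azad imports. Unlike a streaming parser, `parse_tool_calls` works on complete text. Parameter values come back as strings with lowercased keys.

```python
import re
from typing import Any, Dict, List, Tuple

# One complete tool call, and one "KEY: value" parameter inside it.
TOOL_BLOCK_RE = re.compile(r"USE_TOOL:\s*(\w+)\s*(.*?)END_TOOL", re.DOTALL)
PARAM_RE = re.compile(r"^(\w+): (.*?)(?=\n\w+: |\Z)", re.MULTILINE | re.DOTALL)


def format_tool_call(tool_name: str, args: Dict[str, Any]) -> str:
    """Render a tool call in the USE_TOOL / KEY: value / END_TOOL format."""
    params = "\n".join(f"{k.upper()}: {v}" for k, v in args.items())
    return f"USE_TOOL: {tool_name}\n{params}\nEND_TOOL"


def parse_tool_calls(text: str) -> Tuple[str, List[Tuple[str, Dict[str, str]]]]:
    """Split complete model output into plain text and (tool_name, params) pairs."""
    calls = []
    for match in TOOL_BLOCK_RE.finditer(text):
        params = {m.group(1).lower(): m.group(2).strip()
                  for m in PARAM_RE.finditer(match.group(2))}
        calls.append((match.group(1), params))
    plain = TOOL_BLOCK_RE.sub("", text).strip()
    return plain, calls


if __name__ == "__main__":
    call = format_tool_call("read_file", {"path": "a.txt", "max_lines": 10})
    print(parse_tool_calls("Reading now.\n" + call + "\nDone."))
```

A value that spans several lines survives the round trip, provided none of its lines looks like `KEY: ` on its own.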

Adding New Compression Strategies

Compression strategies in Azad keep long conversations within the model's context window by reducing a task's message history to the messages that should be kept. To add a new compression strategy:

  1. Create a new class that inherits from CompressionStrategy in the azad/compression/strategies directory.
  2. Register the strategy in the CompressionStrategyRegistry.

Here's an example of a simple compression strategy:

from typing import List, Optional
from azad.compression.core import CompressionStrategy, CompressionStrategyType, CompressionConfig, CompressionCheckpoint
from azad.mind_map import Task, Message

# This example assumes a MY_STRATEGY member has been added to CompressionStrategyType
class MyCompressionConfig(CompressionConfig):
    """Configuration for MyCompressionStrategy."""
    strategy_type: str = CompressionStrategyType.MY_STRATEGY
    max_messages: int = 10

class MyCompressionStrategy(CompressionStrategy):
    """A simple compression strategy that keeps only the last N messages."""

    @property
    def strategy_type(self) -> CompressionStrategyType:
        """Get the type of this compression strategy."""
        return CompressionStrategyType.MY_STRATEGY

    async def compress(self, task: Task, new_checkpoint: Optional[CompressionCheckpoint], config: MyCompressionConfig) -> List[Message]:
        """Compress the messages in a task."""
        # Get the current task messages
        current_messages = task.current_task_messages()

        # If we have no more messages than the limit, no compression is needed
        if len(current_messages) <= config.max_messages:
            return current_messages

        # Split at an explicit index: current_messages[-0:] would keep every message
        cut = len(current_messages) - config.max_messages
        compressed_messages = current_messages[:cut]
        kept_messages = current_messages[cut:]

        # If we're creating a new checkpoint, update it
        if new_checkpoint is not None:
            # Record which messages were compressed and which were kept
            new_checkpoint.compressed_message_ids = [msg.id for msg in compressed_messages]
            new_checkpoint.kept_message_ids = [msg.id for msg in kept_messages]

            # Add metadata if needed
            new_checkpoint.metadata = {
                "strategy": self.strategy_type,
                "max_messages": config.max_messages,
                "compressed_count": len(compressed_messages),
                "kept_count": len(kept_messages)
            }

        return kept_messages

# Register the strategy
from azad.compression.registry import CompressionStrategyRegistry
CompressionStrategyRegistry.register(CompressionStrategyType.MY_STRATEGY, MyCompressionStrategy)
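At its core, the keep-last-N strategy is a pure list operation, which makes it easy to test on its own. The sketch below uses hypothetical stand-ins for `Message` and `CompressionCheckpoint` that carry only the fields the strategy touches. It splits at an explicit index because in Python `messages[-n:]` returns the whole list when `n` is 0. A naive slice would therefore keep everything instead of nothing.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class Message:
    """Hypothetical stand-in for azad.mind_map.Message (only the id matters here)."""
    id: str


@dataclass
class Checkpoint:
    """Hypothetical stand-in for CompressionCheckpoint."""
    compressed_message_ids: List[str] = field(default_factory=list)
    kept_message_ids: List[str] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)


def keep_last_n(messages: List[Message], max_messages: int,
                checkpoint: Optional[Checkpoint] = None) -> List[Message]:
    """Keep the newest max_messages messages and record the split in the checkpoint."""
    if max_messages < 0:
        raise ValueError("max_messages must be non-negative")
    cut = max(len(messages) - max_messages, 0)
    kept, dropped = messages[cut:], messages[:cut]
    # Like the strategy above, only touch the checkpoint when something was compressed.
    if checkpoint is not None and dropped:
        checkpoint.compressed_message_ids = [m.id for m in dropped]
        checkpoint.kept_message_ids = [m.id for m in kept]
        checkpoint.metadata = {"max_messages": max_messages,
                               "compressed_count": len(dropped),
                               "kept_count": len(kept)}
    return kept
```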

Extending the Agent

To extend the AzadAgent class with new functionality:

  1. Subclass AzadAgent and add your new methods.
  2. Use the @request_response or @request_stream decorators for network communication.

Here's an example:

from typing import Dict, Any
from azad.agent import AzadAgent
from azad.slipstream.base import request_response
from azad.network.interfaces import MessageResponse, ErrorResponse

class MyExtendedAgent(AzadAgent):
    """Extended version of AzadAgent with additional functionality."""

    @request_response
    async def my_custom_method(self, param1: str, param2: int) -> MessageResponse | ErrorResponse:
        """A custom method that does something useful."""
        try:
            # Do something with the parameters
            result = f"Processed {param1} with value {param2}"

            # Return a success response
            return MessageResponse(message=result)
        except Exception as e:
            # Return an error response
            return ErrorResponse(error=f"Failed to process: {e}")
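From the caller's side, a method written this way does not raise for failures it handles. It returns either a success object or an error object, and the caller branches on the type. The sketch below shows that pattern without a network connection. `MessageResponse`, `ErrorResponse`, and `BaseAgent` are hypothetical stand-ins, and the `@request_response` decorator is left out because its behavior is not described here.

```python
import asyncio
from dataclasses import dataclass
from typing import Union


# Hypothetical stand-ins for azad.network.interfaces.MessageResponse / ErrorResponse.
@dataclass
class MessageResponse:
    message: str


@dataclass
class ErrorResponse:
    error: str


Response = Union[MessageResponse, ErrorResponse]


class BaseAgent:
    """Hypothetical stand-in for AzadAgent."""


class MyExtendedAgent(BaseAgent):
    async def my_custom_method(self, param1: str, param2: int) -> Response:
        try:
            if param2 < 0:
                raise ValueError("param2 must be non-negative")
            return MessageResponse(message=f"Processed {param1} with value {param2}")
        except Exception as e:
            # Convert the failure into a typed error instead of letting it propagate.
            return ErrorResponse(error=f"Failed to process: {e}")


def describe(response: Response) -> str:
    """Branch on the response type rather than catching exceptions."""
    if isinstance(response, ErrorResponse):
        return f"error: {response.error}"
    return f"ok: {response.message}"


if __name__ == "__main__":
    agent = MyExtendedAgent()
    print(describe(asyncio.run(agent.my_custom_method("x", 2))))
    print(describe(asyncio.run(agent.my_custom_method("x", -1))))
```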

Contributing to Azad

If you've developed a useful extension to Azad, consider contributing it back to the project:

  1. Fork the repository on GitHub.
  2. Create a new branch for your feature.
  3. Add your code following the project's coding style.
  4. Add tests for your new functionality.
  5. Submit a pull request with a clear description of your changes.

For more information on contributing, see the project's GitHub repository.