
AI Network Module

The AI Network module manages communication with language model providers, processes streaming responses, and handles context window exceeded errors with compression.

Overview

The AINetwork class is the main component of this module. It:

  • Sends requests to language models via LiteLLM
  • Processes streaming responses
  • Manages context window exceeded errors with compression
  • Emits standardized network events

This module is critical for the agent's ability to communicate with language models and handle their responses.

Key Concepts

Request/Response Flow

The AI Network follows a specific flow for making requests to language models, sketched in code after the list:

  1. Format the prompt data using the specified dialect
  2. Send the request to the language model
  3. Process the streaming response
  4. Emit events based on the response
  5. Handle any errors that occur
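
In code, this flow maps naturally onto an async generator. The following is a minimal sketch, not the actual implementation; the helper names format_prompt, send_request, and parse_chunk are illustrative placeholders (make_request(), shown in the API reference below, is the real entry point):

from typing import Any, AsyncGenerator, Callable, Dict, List


async def request_flow(
    format_prompt: Callable[[], List[Dict[str, Any]]],
    send_request: Callable[[List[Dict[str, Any]]], AsyncGenerator[str, None]],
    parse_chunk: Callable[[str], Dict[str, Any]],
) -> AsyncGenerator[Dict[str, Any], None]:
    try:
        messages = format_prompt()                   # 1. format the prompt via the dialect
        async for chunk in send_request(messages):   # 2./3. send the request and stream the response
            yield parse_chunk(chunk)                 # 4. emit an event per parsed chunk
    except Exception as exc:                         # 5. surface errors as events
        yield {"event": "connection_failed", "error": str(exc)}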

Compression Handling

When a context window exceeded error occurs, the AI Network (see the retry sketch after this list):

  1. Uses a compression handler to compress the message history
  2. Retries the request with the compressed messages
  3. Emits compression events to inform the client
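
A simplified sketch of this retry loop, assuming compression shrinks the history enough to fit; compress_history stands in for the Compressor/CompressionHandler machinery used by make_request():

import litellm


async def request_with_compression(send, messages, compress_history, max_retries=3):
    for _ in range(max_retries):
        try:
            return await send(messages)  # attempt the request
        except litellm.exceptions.ContextWindowExceededError:
            # Compress the message history, then retry with the smaller context
            messages = await compress_history(messages)
    raise RuntimeError("request still exceeds the context window after compression")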

Event Emission

The AI Network emits the following events during the request/response cycle (a consumer-side dispatch sketch follows the list):

  • NetworkConnectionAttempt: Emitted before the request is sent, carrying the model name and assistant ID
  • AINetworkConnectionEstablished: Emitted once the language model accepts the request and streaming begins
  • NetworkConnectionFailed: Emitted when the request cannot be completed
  • NetworkConnectionInterrupted: Emitted when an error occurs while processing the response stream
  • AINetworkEventContentComplete: Emitted after the full response stream has been processed
  • AINetworkConnectionEnded: Emitted at the end of the request, with cost and token-usage statistics
  • AINetworkEventCompressionNeeded: Emitted when the context window is exceeded but compression is disabled
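
On the consumer side, events are typically dispatched by type. A sketch, assuming the event classes can be imported as shown (the exact import path for the event types is an assumption):

from azad.ainetwork import ainetwork_types  # assumed import path


async def consume(event_stream):
    async for event in event_stream:
        match event:
            case ainetwork_types.NetworkConnectionAttempt():
                print(f"connecting to {event.model_name} ...")
            case ainetwork_types.AINetworkConnectionEnded():
                print(f"done: {event.input_tokens} in / {event.output_tokens} out, cost {event.cost}")
            case ainetwork_types.NetworkConnectionFailed() | ainetwork_types.NetworkConnectionInterrupted():
                print(f"error: {event.message} ({event.errName})")
            case _:
                pass  # content deltas, completion markers, compression events, ...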

Dialect Integration

The AI Network works with the dialect system, as sketched after this list, to:

  1. Format prompts for the language model
  2. Parse responses from the language model
  3. Extract tool calls and other content
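
These are the dialect calls that appear in make_request() (quoted in full under the API reference); the wrapper function here is only illustrative:

def prepare_with_dialect(dialect, prompt_data):
    # 1. Format prompts for the language model
    tools_schema = dialect.format_tools_schema(prompt_data.tool_metadata)
    messages = dialect.format_messages(prompt_data, None)  # second argument: optional prompt-rules path

    # 2./3. Parse responses and extract tool calls / content via a dialect parser
    parser = dialect.create_parser(prompt_data)

    native_tools = dialect.is_native_toolcalling()  # whether the model supports native tool calls
    return messages, tools_schema, parser, native_tools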

Implementation Details

The make_request() method is the main entry point for the AI Network. It:

  1. Initializes a compression handler
  2. Formats the prompt data using the specified dialect
  3. Sends the request to the language model
  4. Processes the streaming response
  5. Emits events based on the response
  6. Handles any errors that occur

The method returns an async generator that yields network events.

Usage Example

Here's a simplified example of how to use AINetwork (the prompt data, task, dialect, and config objects are assumed to already exist):

import asyncio

async def run_request() -> None:
    # Create an AI Network
    ai_network = AINetwork()

    # Make a request (returns an async generator of network events)
    event_stream = ai_network.make_request(
        prompt_data=prompt_data,
        task=task,
        dialect=dialect,
        config=network_config,
        assistant_id=assistant_id,
    )

    # Process the events
    async for event in event_stream:
        # Handle the event
        handle_event(event)

# Run inside an event loop, e.g. asyncio.run(run_request())

Important Considerations

When working with the AINetwork:

  1. Compression: The AI Network automatically recovers from context window exceeded errors by compressing the message history. Make sure the task's compression configuration (in particular whether it is enabled and its token limit) is set appropriately.

  2. Event Handling: All output arrives as events on a single stream, including content, completion markers, usage statistics, and errors, so consume and dispatch every event type you care about.

  3. Error Handling: Request and streaming errors are reported as NetworkConnectionFailed and NetworkConnectionInterrupted events rather than raised exceptions, so handle them in your event loop to let the agent recover and continue execution.

  4. Resource Cleanup: Ensure that resources are properly cleaned up, especially when cancelling a request; closing the event stream lets make_request() finish its own cleanup (see the sketch below).
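
A sketch of points 2 through 4, assuming the event stream is consumed directly; closing the async generator on the way out lets make_request() run its own cleanup:

async def run_until_done(ai_network, handle_event, **request_kwargs):
    event_stream = ai_network.make_request(**request_kwargs)
    try:
        async for event in event_stream:
            handle_event(event)  # application-specific dispatch
    finally:
        # Runs on normal completion and on cancellation alike; closing the
        # generator releases the underlying stream/connection.
        await event_stream.aclose()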

API Reference

azad.ainetwork.network

AINetwork manages communication with language model providers.

This module provides the main AINetwork class that handles:

  1. Communicating with language models via LiteLLM
  2. Processing streaming responses
  3. Managing context window exceeded errors with compression
  4. Error handling and usage statistics


Classes

AsyncCompletionCallback

AsyncCompletionCallback(future: Future, expected_request_id: str)

Callback class to handle async completion events. It tracks whether the cache was hit and sets the result of a future when the callback is invoked. It includes a unique request ID to ensure it only processes events for its specific request.

Source code in azad/ainetwork/network.py
def __init__(self, future: asyncio.Future, expected_request_id: str):
    self.logger = logging.getLogger(__name__)
    self.future = future
    self.expected_request_id = expected_request_id
Attributes
logger instance-attribute
logger = getLogger(__name__)
expected_request_id instance-attribute
expected_request_id = expected_request_id
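
The snippet below illustrates the general pattern the class follows: resolve a future only for the matching request ID, then await that future with a timeout. The on_complete method name and its arguments are illustrative, not the actual LiteLLM callback hook.

import asyncio
from typing import Any, Dict


class IllustrativeUsageCallback:
    def __init__(self, future: asyncio.Future, expected_request_id: str) -> None:
        self.future = future
        self.expected_request_id = expected_request_id

    def on_complete(self, request_id: str, usage: Dict[str, Any]) -> None:
        if request_id != self.expected_request_id:
            return  # ignore events that belong to other in-flight requests
        if not self.future.done():
            self.future.set_result(usage)

# Consumer side (mirroring make_request):
#     usage_future: asyncio.Future[Dict[str, Any]] = asyncio.Future()
#     callback = IllustrativeUsageCallback(usage_future, unique_request_id)
#     ... register the callback and make the request ...
#     usage_stats = await asyncio.wait_for(usage_future, timeout=5.0)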

AINetwork

AINetwork()

AINetwork handles communication with language model providers, processes streaming responses, and emits standardized network events for consumption by other system components.

This class is responsible for:

  1. Sending requests to language models via LiteLLM
  2. Processing streaming responses
  3. Managing context window exceeded errors with compression
  4. Emitting properly sequenced events
  5. Error management and reporting

Initialize the AINetwork.

This constructor initializes the logger and attempts to load custom model settings from a JSON file named "azad_model_cost_map.json" in the current directory.

The custom model settings allow overriding litellm's default model configurations, such as pricing, context window sizes, etc. This enables faster updates to model configurations than waiting for litellm's official updates.

If the file doesn't exist or there's an error loading it, the default litellm settings will be used.

Example JSON format:

{
    "model_name": {
        "litellm_provider": "openai",
        "max_tokens": 8192,
        "input_cost_per_token": 0.0001,
        "output_cost_per_token": 0.0002
    }
}

Source code in azad/ainetwork/network.py
def __init__(self) -> None:
    """
    Initialize the AINetwork.

    This constructor initializes the logger and attempts to load custom model settings
    from a JSON file named "azad_model_cost_map.json" in the current directory.

    The custom model settings allow overriding litellm's default model configurations,
    such as pricing, context window sizes, etc. This enables faster updates to model
    configurations than waiting for litellm's official updates.

    If the file doesn't exist or there's an error loading it, the default litellm
    settings will be used.

    Example JSON format:
    {
        "model_name": {
            "litellm_provider": "openai",
            "max_tokens": 8192,
            "input_cost_per_token": 0.0001,
            "output_cost_per_token": 0.0002
        }
    }
    """
    self.logger = logging.getLogger(__name__)
Attributes
logger instance-attribute
logger = getLogger(__name__)
Functions
make_request async
make_request(prompt_data: PromptData, task: Task, dialect: Dialect, config: NetworkConfig, assistant_id: str, tool_call_id: Optional[str] = None) -> AsyncGenerator[AINetworkEventUnion, None]

Make a request to the language model and yield network events from the response.

Parameters:

  • prompt_data (PromptData) –

    The prompt data to send

  • task (Task) –

    The task whose configuration and message history are used (for example by the compression handler)

  • dialect (Dialect) –

    The dialect to use for formatting and parsing

  • config (NetworkConfig) –

    Network configuration

  • assistant_id (str) –

    ID for the assistant

  • tool_call_id (Optional[str], default: None ) –

    Optional ID for the tool call

Yields:

  • AINetworkEventUnion –

    Network events from the processed response

Source code in azad/ainetwork/network.py
async def make_request(
    self,
    prompt_data: PromptData, 
    task: Task,
    dialect: Dialect,
    config: ainetwork_types.NetworkConfig,
    assistant_id: str,
    tool_call_id: Optional[str] = None
) -> AsyncGenerator[ainetwork_types.AINetworkEventUnion, None]:
    """
    Make a request to the language model and yield network events from the response.

    Args:
        prompt_data: The prompt data to send
        dialect: The dialect to use for formatting and parsing
        config: Network configuration
        assistant_id: ID for the assistant
        tool_call_id: Optional ID for the tool call

    Yields:
        Network events from the processed response
    """
    # Initialize the compression handler with the task
    compression_strategy = CompressionStrategyRegistry.get_strategy(CompressionStrategyType.COMPACT_SUMMARIZATION)
    compression_config = CompactSummarizationConfig(api_key=config.api_key, model_name=config.model)
    compressor = Compressor(strategy=compression_strategy, config=compression_config)
    compressor.load(task)
    compression_handler = CompressionHandler(compressor)

    is_native_tool_call = dialect.is_native_toolcalling()
    format_tool_schema = dialect.format_tools_schema(prompt_data.tool_metadata)

    unique_request_id = str(uuid.uuid4())

    # Signal connection attempt
    yield ainetwork_types.NetworkConnectionAttempt(model_name=config.model, assistant_id=assistant_id)

    # Setup usage tracking
    usage_future: asyncio.Future[Dict[str, Any]] = asyncio.Future()

    # Setup callback for tracking usage, now with the unique request ID
    completion_callback = AsyncCompletionCallback(usage_future, unique_request_id)
    # Use the context manager to safely manage callbacks
    async with managed_callbacks(success_callback=completion_callback, failure_callback=completion_callback):
        # Prepare the LLM request parameters
        try:
            response = None

            # Configure thinking parameter for Anthropic models
            thinking_param: Optional[litellm.types.llms.anthropic.AnthropicThinkingParam] = None
            reasoning_effort: Optional[Literal['low', 'medium', 'high']] = None

            if config.reasoning_effort is not None and config.reasoning_effort in ['low', 'medium', 'high']:
                reasoning_effort = config.reasoning_effort
                match reasoning_effort:
                    case "low":
                        thinking_param = litellm.types.llms.anthropic.AnthropicThinkingParam(
                            type="enabled",
                            budget_tokens=litellm.constants.DEFAULT_REASONING_EFFORT_LOW_THINKING_BUDGET,
                        )
                    case "medium":
                        thinking_param = litellm.types.llms.anthropic.AnthropicThinkingParam(
                            type="enabled",
                            budget_tokens=litellm.constants.DEFAULT_REASONING_EFFORT_MEDIUM_THINKING_BUDGET,
                        )
                    case "high":
                        thinking_param = litellm.types.llms.anthropic.AnthropicThinkingParam(
                            type="enabled",
                            budget_tokens=litellm.constants.DEFAULT_REASONING_EFFORT_HIGH_THINKING_BUDGET,
                        )


            task_config = task.current_task_config()
            token_limit = task_config.compression_config.token_limit
            is_kodu_provider = "kodu" in task_config.model_name

            # Attempt the request with compression retries
            while True:
                try:
                    # First load the compressor context boundaries
                    current_context_window = await compressor.transform()
                    # Update the prompt data with the context boundaries list
                    prompt_data.override_messages(current_context_window)
                    # Format the messages using the dialect
                    xml_with_search_rules_path = "xml.with.search.rules.prompt" if task_config.enable_search else None
                    messages: List[Dict[str, Any]] = dialect.format_messages(prompt_data, xml_with_search_rules_path)

                    estimated_token_count = estimate_token_count(messages, config.model)
                    if estimated_token_count > token_limit:
                        self.logger.warning(f"Estimated token count ({estimated_token_count}) exceeds limit ({token_limit})")
                        raise litellm.exceptions.ContextWindowExceededError(
                            message=f"Estimated token count ({estimated_token_count}) exceeds configured limit ({token_limit})",
                            model=config.model,
                            llm_provider="unknown"
                        )

                    # Enable parallel tool calls for large context models
                    parallel_tool_calls = prompt_data.task_config.enable_parallel_tools

                    # Prepare additional parameters for Gemini models
                    additional_params: Dict[str, Any] = {"metadata": {"internal_request_id": unique_request_id}}
                    if config.enable_search:
                        if is_native_tool_call:
                            additional_params["tools"] = format_tool_schema + [{"googleSearch": {}}]
                        else:
                            additional_params["tools"] = [{"googleSearch": {}}]

                        self.logger.debug(f"Enabling search/grounding for {config.model}")

                    model_name = config.model
                    user_api_key = config.api_key
                    if is_kodu_provider:
                        additional_params["original_api_key"] = config.api_key
                        additional_params["is_kodu_provider"] = is_kodu_provider
                        model_name = config.model.replace("kodu", "openrouter")
                        user_api_key = settings.OPENROUTER_KEY

                    if litellm.supports_reasoning(model=model_name) == True:
                        additional_params["allowed_openai_params"] = ['reasoning_effort']
                    is_anthropic_provider = "anthropic" in model_name
                    if is_anthropic_provider:
                        additional_params["allowed_openai_params"] = []

                    is_openrouter_provider = "openrouter" in model_name
                    if is_openrouter_provider:
                        additional_params["allowed_openai_params"] = ['reasoning_effort', 'reasoning']
                        if reasoning_effort:
                            additional_params["reasoning"] = {
                                "effort": reasoning_effort,
                            }
                            reasoning_effort = None  # Reset to avoid passing it again

                    is_sonnet_or_opus = "sonnet" in model_name or "opus" in model_name
                    if is_sonnet_or_opus:
                        # check if reasoning_effort is set, if so, we need to make sure we don't overflow the max tokens
                        if reasoning_effort:
                            high_budget = litellm.constants.DEFAULT_REASONING_EFFORT_HIGH_THINKING_BUDGET
                            medium_budget = litellm.constants.DEFAULT_REASONING_EFFORT_MEDIUM_THINKING_BUDGET
                            low_budget = litellm.constants.DEFAULT_REASONING_EFFORT_LOW_THINKING_BUDGET
                            current_budget = high_budget if reasoning_effort == "high" else (
                                medium_budget if reasoning_effort == "medium" else low_budget
                            )
                            # config.max_tokens cannot be below the current budget if this is the case we need to increase it above the budget
                            # The budget_tokens parameter determines the maximum number of tokens Claude is allowed to use for its internal reasoning process. In Claude 4 models, this limit applies to full thinking tokens, and not to the summarized output. Larger budgets can improve response quality by enabling more thorough analysis for complex problems, although Claude may not use the entire budget allocated, especially at ranges above 32k.
                            # budget_tokens must be set to a value less than max_tokens. 
                            if config.max_tokens is not None and isinstance(config.max_tokens, int) and config.max_tokens < current_budget:
                                self.logger.warning(
                                    f"Config max_tokens ({config.max_tokens}) is below the current budget ({current_budget}). "
                                    "Increasing max_tokens to match the budget."
                                )
                                config.max_tokens = current_budget + 2000

                    if prompt_data.dyanmic_task_config.extra_params:
                        merge_into_additional_params(prompt_data.dyanmic_task_config.extra_params, additional_params)

                    # Make the request to the LLM
                    response = await acompletion(
                        base_url=config.api_base,
                        model=model_name,
                        messages=messages,
                        api_key=user_api_key,
                        max_tokens=config.max_tokens if (thinking_param and isinstance(config.max_tokens, int) and config.max_tokens > 0) or reasoning_effort else None,
                        stream=True,
                        thinking=thinking_param,
                        reasoning_effort=reasoning_effort,
                        function_call=None,

                        stream_options={"include_usage": True},
                        # parallel_tool_calls=parallel_tool_calls if is_native_tool_call else None,
                        temperature=1 if thinking_param else 0.1,
                        **additional_params,

                    )
                    # Successfully got a response, break the retry loop
                    break

                except litellm.exceptions.ContextWindowExceededError as e:
                    # Check if compression is enabled in the task config
                    compression_enabled = task_config.compression_config.enabled

                    # If compression is disabled, yield error and return
                    if not compression_enabled:
                        self.logger.warning("Context window exceeded but compression is disabled in task configuration")
                        yield ainetwork_types.AINetworkEventCompressionNeeded(
                            message="Context window exceeded and compression is disabled in task configuration. Please reduce your input or enable compression.",
                            errName="CompressionDisabledError",
                            origErrName="ContextWindowExceededError",
                            token_count=estimated_token_count,
                            token_limit=token_limit
                        )
                        return

                    # If compression is enabled, try to handle with compression
                    try:
                        # Process compression events
                        async for event in compression_handler.handle_context_exceeded(
                            task=task,
                            error_message=str(e),
                        ):
                            yield event

                        self.logger.info("Retrying request with compressed messages")                            
                        # Continue to the next iteration of the loop to retry
                        continue

                    except Exception as compression_error:
                        # If compression handling itself fails, log and raise
                        self.logger.error(f"Error during compression handling: {str(compression_error)}")
                        raise

                except Exception as e:
                    # Handle other connection failures
                    err = ainetwork_errors.create(e)
                    yield ainetwork_types.NetworkConnectionFailed(
                        message=err.error_message,
                        errName=err.__class__.__name__,
                        origErrName=e.__class__.__name__ if e else None
                    )
                    return

            # Signal successful connection
            yield ainetwork_types.AINetworkConnectionEstablished()

            # Process the response stream
            try:
                # Create parser for the dialect
                parser = dialect.create_parser(prompt_data)

                # Process the response stream and yield events
                async for event in ainetwork_utils.process_response_stream(
                    response, # type: ignore
                    parser,
                    tool_call_id
                ):
                    yield event

                # Signal content completion
                yield ainetwork_types.AINetworkEventContentComplete()

                # Get usage statistics
                try:
                    usage_stats = await asyncio.wait_for(usage_future, timeout=5.0)
                    yield await ainetwork_utils.create_connection_ended_event(usage_stats, is_kodu_provider)
                except asyncio.TimeoutError:
                    self.logger.warning("Timeout waiting for usage statistics")
                    yield ainetwork_types.AINetworkConnectionEnded(
                        cost=0, 
                        input_tokens=0, 
                        output_tokens=0,
                        cache_read_tokens=None,
                        cache_write_tokens=None,
                        finish_reason="stop",
                        raw_response_content=None
                    )

            except Exception as e:
                # Handle any errors during stream processing
                network_error = ainetwork_errors.create(e)
                yield ainetwork_types.NetworkConnectionInterrupted(
                    message=network_error.error_message,
                    errName=network_error.__class__.__name__,
                    origErrName=e.__class__.__name__ if e else None
                )

        except Exception as e:
            # Handle any uncaught exceptions
            err = ainetwork_errors.create(e)
            yield ainetwork_types.NetworkConnectionFailed(
                message=err.error_message,
                errName=err.__class__.__name__,
                origErrName=e.__class__.__name__ if e else None
            )
            return

Functions

estimate_token_count

estimate_token_count(messages: List[Dict[str, Any]], model: str) -> int

Estimate the token count of messages using tiktoken.

Parameters:

  • messages (List[Dict[str, Any]]) –

    List of formatted messages to count tokens for

  • model (str) –

    The model to use for token counting

Returns:

  • int

    Estimated token count

Source code in azad/ainetwork/network.py
def estimate_token_count(messages: List[Dict[str, Any]], model: str) -> int:
    """
    Estimate the token count of messages using tiktoken.

    Args:
        messages: List of formatted messages to count tokens for
        model: The model to use for token counting

    Returns:
        Estimated token count
    """
    try:
        enc = tiktoken.encoding_for_model(model)
    except KeyError:
        enc = tiktoken.get_encoding("cl100k_base")

    num_tokens = 3

    for message in messages:
        num_tokens += 3

        for key, value in message.items():
            if isinstance(value, str):
                num_tokens += len(enc.encode(value))

                if key == "name":
                    num_tokens += 1

    return num_tokens
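
For example, counting tokens for a short conversation (the module path is taken from the API reference above; the exact count depends on which tiktoken encoding matches the model name):

from azad.ainetwork.network import estimate_token_count

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the AI Network module."},
]
print(estimate_token_count(messages, model="gpt-4"))  # a small two-digit count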
