azad.ainetwork.ainetwork_utils Module

azad.ainetwork.ainetwork_utils

Utility functions for AINetwork.

This module provides utility functions for the AINetwork class, including:

  • Content extraction from LLM response chunks
  • Response stream processing
  • Usage statistics extraction and event creation
  • Queue-based asynchronous processing

Attributes

logger module-attribute

logger = getLogger(__name__)

VALID_FINISH_REASONS module-attribute

VALID_FINISH_REASONS = {'stop', 'length', 'content-filter', 'tool-calls', 'error', 'other', 'unknown'}

Functions

extract_content_from_chunk

extract_content_from_chunk(chunk: ModelResponseStream) -> Optional[Tuple[Optional[str], Optional[str], Any, Optional[str], Optional[str], bool]]

Extract content from a single chunk.

Parameters:

  • chunk (ModelResponseStream) –

    The chunk to process

Returns:

  • Optional[Tuple[Optional[str], Optional[str], Any, Optional[str], Optional[str], bool]] –

    Tuple of (normal_content, reasoning_content, tool_call, reasoning_signature, reasoning_redacted, has_finish_reason), or None if the chunk has no content

Source code in azad/ainetwork/ainetwork_utils.py
def extract_content_from_chunk(
    chunk: litellm.types.utils.ModelResponseStream
) -> Optional[Tuple[Optional[str], Optional[str], Any, Optional[str], Optional[str], bool]]:
    """
    Extract content from a single chunk.

    Args:
        chunk: The chunk to process

    Returns:
        Optional tuple of (normal_content, reasoning_content, tool_call, 
                          reasoning_signature, reasoning_redacted, has_finish_reason)
        or None if the chunk has no content
    """
    # Skip invalid chunks
    if not chunk or not hasattr(chunk, 'choices') or not chunk.choices:
        return None

    choice = chunk.choices[0]

    # Default values
    normal_content: Optional[str] = None
    reasoning_content: Optional[str] = None
    has_finish: bool = False
    tool_call: Optional[litellm.types.utils.ChatCompletionDeltaToolCall] = None
    reasoning_signature: Optional[str] = None
    reasoning_redacted: Optional[str] = None

    # Check for finish reason
    if choice.finish_reason is not None:
        has_finish = True

    # Process delta content (newer format)
    if hasattr(choice, 'delta') and choice.delta:
        delta = choice.delta

        # Extract thinking blocks and signature if available
        if hasattr(delta, "thinking_blocks") and delta.thinking_blocks:
            first_thinking_block = delta.thinking_blocks[0]
            if 'signature' in first_thinking_block:
                reasoning_signature = first_thinking_block['signature'] # type: ignore
                logger.debug(f"Reasoning Signature from delta: {reasoning_signature}")

        # Check for reasoning content in delta
        if hasattr(delta, 'reasoning_content') and delta.reasoning_content:
            reasoning_content = delta.reasoning_content
            logger.debug(f"Reasoning content from delta: {reasoning_content}")

        # Check for regular content in delta
        elif hasattr(delta, 'content') and delta.content:
            normal_content = delta.content
            logger.debug(f"Content from delta: {normal_content}")

        # Check for tool calls in delta
        elif hasattr(delta, 'tool_calls') and delta.tool_calls:
            tool_call = delta.tool_calls[0]
            logger.debug(f"Tool call from delta: {tool_call}")

    # Return the extracted content
    return normal_content, reasoning_content, tool_call, reasoning_signature, reasoning_redacted, has_finish
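
A minimal usage sketch. The chunk below is a hypothetical stand-in built with SimpleNamespace rather than a real litellm ModelResponseStream, just to show how the returned tuple is unpacked:

from types import SimpleNamespace

# Hypothetical chunk shaped like a streaming delta (illustration only)
chunk = SimpleNamespace(
    choices=[SimpleNamespace(
        finish_reason=None,
        delta=SimpleNamespace(content="Hello", reasoning_content=None, tool_calls=None),
    )]
)

result = extract_content_from_chunk(chunk)  # type: ignore[arg-type]
if result is not None:
    normal, reasoning, tool_call, signature, redacted, finished = result
    # normal == "Hello"; the remaining fields are None/False for this chunk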

process_response_stream async

process_response_stream(response_stream: CustomStreamWrapper, parser: DialectParser, tool_call_id: Optional[str]) -> AsyncGenerator[AINetworkEventUnion, None]

Process the streaming response, accumulating reasoning content and yielding events.

Parameters:

  • response_stream (CustomStreamWrapper) –

    The streaming response from the LLM

  • parser (DialectParser) –

    The dialect parser

  • tool_call_id (Optional[str]) –

    Optional ID for tool calls

Yields:

  • AINetworkEventUnion –

    Network events from the processed response

Source code in azad/ainetwork/ainetwork_utils.py
async def process_response_stream(
    response_stream: CustomStreamWrapper,
    parser: DialectParser,
    tool_call_id: Optional[str]
) -> AsyncGenerator[ainetwork_types.AINetworkEventUnion, None]:
    """
    Process the streaming response, accumulating reasoning content and yielding events.

    Args:
        response_stream: The streaming response from the LLM
        parser: The dialect parser
        tool_call_id: Optional ID for tool calls

    Yields:
        Network events from the processed response
    """
    # Initialize state
    first_token_emitted = False
    reasoning_active = False
    text_active = False
    is_in_tool_call: Union[Literal[False], Literal["native"], Literal["feed"]] = False
    reasoning_buffer: List[str] = []

    # We'll use a queue to manage chunks to avoid blocking
    chunk_queue: asyncio.Queue = asyncio.Queue()


    # Create a task to read chunks and put them in the queue
    async def reader_task():
        try:
            async for chunk in response_stream:
                await chunk_queue.put(chunk)
            # Signal end of stream
            await chunk_queue.put(None)
        except asyncio.CancelledError:
            # Handle task cancellation gracefully
            logger.debug("Reader task cancelled")
            raise
        except Exception as e:
            # Log detailed error information for debugging
            error_type = type(e).__name__
            logger.exception(f"Error in stream reader task: {error_type} - {str(e)}")

            # Put the error in the queue for the main task to handle
            await chunk_queue.put(e)

    # Start the reader task with proper error handling
    reader = asyncio.create_task(reader_task())
    logger.debug("Started stream reader task")

    try:
        # Process chunks from the queue
        while True:
            # Get the next chunk
            chunk = await chunk_queue.get()
            # Handle end of stream
            if chunk is None:
                break

            # Handle errors
            if isinstance(chunk, Exception):
                raise chunk

            # Process the chunk with error handling
            try:
                content_info = extract_content_from_chunk(chunk)

                # Skip empty chunks
                if content_info is None:
                    continue

                normal_content, reasoning_content, tool_call_delta, reasoning_signature, reasoning_redacted, has_finish = content_info
            except Exception as e:
                logger.error(f"Error extracting content from chunk: {type(e).__name__} - {str(e)}")
                # Skip this problematic chunk and continue processing
                continue

            # Emit first token event if needed
            if not first_token_emitted and (normal_content or reasoning_content):
                first_token_emitted = True
                yield ainetwork_types.AINetworkFirstToken()

            if reasoning_signature:
                yield ainetwork_types.AIEventReasoningChunk(content="", signature=reasoning_signature)
            # Process reasoning content - accumulate in buffer and emit events
            if reasoning_content:
                # Start reasoning section if not already started
                if not reasoning_active:
                    reasoning_active = True
                    yield ainetwork_types.AINetworkEventReasoningStart(
                        startedAt=round(time.time() * 1000),
                    )

                # Add to reasoning buffer and emit chunk events in real-time
                reasoning_buffer.append(reasoning_content)

                # Emit reasoning content immediately for real-time processing
                yield ainetwork_types.AIEventReasoningChunk(content=reasoning_content, signature=reasoning_signature)

            # Process normal content
            if normal_content:
                # End reasoning section if active
                if reasoning_active:
                    # Clear buffer and end reasoning section
                    reasoning_buffer = []
                    reasoning_active = False
                    yield ainetwork_types.AINetworkEventReasoningEnd(
                        endedAt=round(time.time() * 1000),
                    )

                # Start text section if not active
                if not text_active:
                    text_active = True
                    yield ainetwork_types.AINetworkEventTextStart()

                # Feed normal content to parser with error handling
                try:
                    normal_content_bytes = normal_content.encode('utf-8')

                    # HTML-like content takes the same parser path; it is only
                    # logged so closing-tag parsing issues are easier to trace
                    if "<" in normal_content and ">" in normal_content:
                        logger.debug("Detected HTML-like content, ensuring proper parsing")

                    for event in parser.feed(normal_content_bytes):
                        try:
                            # Add tool_call_id to events that need it
                            if tool_call_id is not None and isinstance(event, (
                                ainetwork_types.AINetworkEventToolName,
                                ainetwork_types.AINetworkEventToolReady
                            )):
                                event.tool_call_id = tool_call_id
                            yield event
                        except Exception as e:
                            # Log error but continue processing other events
                            logger.error(f"Error processing parser event: {type(e).__name__} - {str(e)}")
                except Exception as e:
                    # Log parser error but continue with other chunks
                    logger.error(f"Error in parser feed: {type(e).__name__} - {str(e)}")

            if tool_call_delta:
                # End reasoning section if active
                if reasoning_active:
                    # Clear buffer and end reasoning section
                    reasoning_buffer = []
                    reasoning_active = False
                    yield ainetwork_types.AINetworkEventReasoningEnd(
                        endedAt=round(time.time() * 1000),
                    )
                is_in_tool_call = "native"
                # feed the tool call delta to the parser with error handling
                try:
                    for event in parser.feed_tool_call_delta(tool_call_delta):
                        try:
                            # Add tool_call_id to events that need it
                            if tool_call_id is not None and isinstance(event, (
                                ainetwork_types.AINetworkEventToolName,
                                ainetwork_types.AINetworkEventToolReady
                            )):
                                event.tool_call_id = tool_call_id
                            yield event
                        except Exception as e:
                            # Log error but continue processing other events
                            logger.error(f"Error processing tool call event: {type(e).__name__} - {str(e)}")
                except Exception as e:
                    # Log parser error but continue with other chunks
                    logger.error(f"Error in tool call parser feed: {type(e).__name__} - {str(e)}")
            # Handle finish reason
            if has_finish:
                # If we're in reasoning mode, end it properly
                if reasoning_active:
                    reasoning_buffer = []
                    reasoning_active = False
                    yield ainetwork_types.AINetworkEventReasoningEnd(
                        endedAt=round(time.time() * 1000)
                    )
                if is_in_tool_call == "native":
                    is_in_tool_call = False

                # End text section if active
                if text_active:
                    text_active = False
                    yield ainetwork_types.AINetworkEventTextEnd()
                break
    finally:
        # Wait for the reader task to complete
        try:
            await reader
            logger.debug("Stream reader task completed")
            # Check whether the parser detected a multi-tool response
            if hasattr(parser, "is_multi_tool_response") and parser.is_multi_tool_response():
                tool_calls = parser.get_detected_tool_calls()
                if len(tool_calls) > 1:
                    # Log multi-tool detection
                    tool_names = [tc["tool_name"] for tc in tool_calls]
                    logger.info(f"Detected multiple tool calls: {len(tool_calls)} tools - {tool_names}")

                    # Signal multi-tool start - this should be Start, not End to avoid premature completion
                    yield ainetwork_types.AINetworkEventMultiToolStart()

                    # Log the tool calls being processed
                    for i, tc in enumerate(tool_calls):
                        logger.info(f"Multi-tool #{i+1}: {tc['tool_name']} with {len(tc['args'])} parameters")

                        # Generate unique tool_call_id if not provided
                        tool_id = tc.get('tool_call_id')
                        tool_call_id = str(tool_id) if tool_id is not None else str(uuid.uuid4())

                        # Explicitly emit a AINetworkEventToolName event for this tool
                        yield ainetwork_types.AINetworkEventToolName(
                            tool_name=tc['tool_name'],
                            tool_call_id=tool_call_id
                        )

                        # Emit AINetworkEventToolReady event for each tool with isolated parameters
                        yield ainetwork_types.AINetworkEventToolReady(
                            tool_name=tc['tool_name'],
                            args=dict(tc['args']),  # shallow copy keeps this tool's args isolated
                            tool_call_id=tool_call_id,
                        )

                        # Log a short preview of each string parameter for debugging
                        for k, v in tc['args'].items():
                            if isinstance(v, str):
                                # Log only a preview of the value to avoid excessive logging
                                preview = v[:50] + "..." if len(v) > 50 else v
                                logger.debug(f"Tool {i+1} param {k}: {preview}")

                    # Emit multi-tool end event with correct count to signal LLM that all tools are defined
                    yield ainetwork_types.AINetworkEventMultiToolEnd(
                        tool_count=len(tool_calls)
                    )

                    # Special processing for HTML content in multi-tool calls
                    for tc in tool_calls:
                        # Check for HTML content in any parameter
                        for param_name, param_value in tc.get('args', {}).items():
                            if isinstance(param_value, str) and '<' in param_value and '>' in param_value:
                                logger.info(f"Detected potential HTML content in tool {tc['tool_name']}, param {param_name}")

                                # Check if HTML content is incomplete (missing closing tags)
                                if '<html' in param_value.lower() and '</html>' not in param_value.lower():
                                    logger.warning(f"HTML content appears incomplete in {tc['tool_name']}")
                                    # We still proceed as the content may be valid enough for use

                    # Also emit a tool request summary log for better UI tracking
                    tool_names = [tc["tool_name"] for tc in tool_calls]
                    logger.log(LogLevel.PROD, "", extra={
                        "event": "multi_tool_request",
                        "tool_count": len(tool_calls),
                        "tool_names": tool_names,
                        "is_final": True
                    })

                    # Final shallow copy so each tool's args dict is a distinct object
                    for tc in tool_calls:
                        tc["args"] = dict(tc["args"])



        except Exception as e:
            logger.error(f"Error checking for multiple tool calls: {type(e).__name__} - {str(e)}")

calculate_tokens_from_usage

calculate_tokens_from_usage(usage: Dict[str, Any])

Derive token counts (output, input, cache write, cache read, reasoning, and current context size) from a LiteLLM usage payload.

Source code in azad/ainetwork/ainetwork_utils.py
def calculate_tokens_from_usage(usage: Dict[str, Any]):
    usage_obj = litellm.Usage(**usage)

    reasoning_tokens = 0
    cache_read_tokens = 0
    if usage_obj.completion_tokens_details:
        completion_tokens_details = usage_obj.completion_tokens_details
        reasoning_tokens = completion_tokens_details.reasoning_tokens if completion_tokens_details.reasoning_tokens else 0

    if usage_obj.prompt_tokens_details:
        prompt_tokens_details = usage_obj.prompt_tokens_details
        cache_read_tokens = prompt_tokens_details.cached_tokens if prompt_tokens_details.cached_tokens else 0

    if cache_read_tokens == 0:
        try:
            cache_read_tokens = usage_obj.cache_read_input_tokens
        except AttributeError:
            cache_read_tokens = 0

    output_tokens = usage_obj.completion_tokens if usage_obj.completion_tokens else 0
    input_tokens = usage_obj.prompt_tokens - cache_read_tokens if usage_obj.prompt_tokens else 0
    try:
        cache_write_tokens = usage_obj.cache_creation_input_tokens
    except AttributeError:
        cache_write_tokens = 0
    current_context_count = input_tokens + cache_write_tokens + cache_read_tokens - reasoning_tokens

    return output_tokens, input_tokens, cache_write_tokens, cache_read_tokens, reasoning_tokens, current_context_count
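
As a worked sketch, assuming litellm.Usage accepts these nested detail fields as plain dicts, a cached OpenAI-style payload unpacks like this:

usage = {
    "prompt_tokens": 1200,
    "completion_tokens": 300,
    "prompt_tokens_details": {"cached_tokens": 200},
    "completion_tokens_details": {"reasoning_tokens": 50},
}

(output_tokens, input_tokens, cache_write_tokens,
 cache_read_tokens, reasoning_tokens, current_context_count) = calculate_tokens_from_usage(usage)

# output_tokens == 300
# input_tokens == 1000 (prompt_tokens minus the 200 cached tokens)
# cache_write_tokens == 0 (assuming no cache_creation_input_tokens attribute)
# cache_read_tokens == 200, reasoning_tokens == 50
# current_context_count == 1000 + 0 + 200 - 50 == 1150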

create_connection_ended_event async

create_connection_ended_event(response_data: Dict[str, Any], is_kodu_provider: bool) -> AINetworkConnectionEnded

Create a connection ended event with usage statistics.

Parameters:

  • response_data (Dict[str, Any]) –

    Usage data from the LLM

  • is_kodu_provider (bool) –

    Whether the call was through Kodu tunnel

Returns:

  • AINetworkConnectionEnded –

    A connection ended event with usage statistics

Source code in azad/ainetwork/ainetwork_utils.py
async def create_connection_ended_event(response_data: Dict[str, Any], is_kodu_provider: bool) -> ainetwork_types.AINetworkConnectionEnded:
    """
    Create a connection ended event with usage statistics.

    Args:
        response_data: Usage data from the LLM
        is_kodu_provider: Whether the call was through Kodu tunnel

    Returns:
        A connection ended event with usage statistics
    """
    # Extract cost information safely
    cost = 0.0
    try:
        if response_data.get("usage") is not None and response_data["usage"].get("cost") is not None:
            cost_value = response_data["usage"]["cost"]
        else:
            cost_value = response_data.get('response_cost', 0)
        cost = float(cost_value) if cost_value is not None else 0.0
        cost = max(0.0, cost)  # Ensure cost is non-negative
    except (ValueError, TypeError) as e:
        logger.warning(f"Error converting cost: {e}")

    # Extract usage information safely (handling different provider structures)
    output_tokens = 0
    input_tokens = 0
    cache_write_tokens = None
    cache_read_tokens = None
    reasoning_tokens = None
    current_context_count = None

    try:
        # Try usage from the streaming response (Anthropic structure); litellm normalizes this structure across providers.
        if 'async_complete_streaming_response' in response_data and 'usage' in response_data['async_complete_streaming_response']:
            usage = response_data['async_complete_streaming_response']['usage']
            (output_tokens,
             input_tokens,
             cache_write_tokens,
             cache_read_tokens,
             reasoning_tokens,
             current_context_count) = calculate_tokens_from_usage(usage)

        # Try other common structures (OpenAI, etc.)
        elif 'usage' in response_data:
            usage = response_data['usage']
            (output_tokens,
             input_tokens,
             cache_write_tokens,
             cache_read_tokens,
             reasoning_tokens,
             current_context_count) = calculate_tokens_from_usage(usage)

        # If response object directly contains usage fields
        else:
            input_tokens = response_data.get('prompt_tokens', 0)
            output_tokens = response_data.get('completion_tokens', 0)

        # Ensure values are integers
        input_tokens = int(input_tokens) if input_tokens is not None else 0
        output_tokens = int(output_tokens) if output_tokens is not None else 0
        if cache_read_tokens is not None:
            cache_read_tokens = int(cache_read_tokens)
        if cache_write_tokens is not None:
            cache_write_tokens = int(cache_write_tokens)

    except Exception as e:
        logger.warning(f"Error extracting usage information: {e}")
        logger.debug(f"Raw usage data: {response_data}")

    finish_reason_str: Optional[str] = None
    raw_response_content: Optional[str] = None

    choices = response_data.get("choices", [])
    if choices and isinstance(choices, list) and len(choices) > 0:
        first_choice = choices[0]
        if isinstance(first_choice, dict):
            finish_reason_str = first_choice.get("finish_reason")
            message = first_choice.get("message", {})
            if isinstance(message, dict):
                raw_response_content_dict = message.get("content")
                if isinstance(raw_response_content_dict, dict) and 'text' in raw_response_content_dict:
                    raw_response_content = raw_response_content_dict['text']
                elif isinstance(raw_response_content_dict, str):
                    raw_response_content = raw_response_content_dict

    finish_reason = _map_finish_reason(finish_reason_str)


    # Handle kodu request cost calculation and credit deduction
    if is_kodu_provider:
        try:
            # Get the cost margin from environment settings (falls back to 1, i.e. no markup)
            from ..db_models import db
            from ..env_settings import settings

            cost_margin = settings.COST_MARGIN or 1

            # Apply the cost margin
            cost = cost * cost_margin

            # Deduct the cost from the user's credits in the database using SQLAlchemy

            # Check if deployed and has DB access
            if settings.FLY_PROCESS_GROUP is not None and settings.DATABASE_URL is not None and settings.TURSO_AUTH_TOKEN is not None:
                # Get the message and inference ID
                inference_id = response_data.get('id', str(uuid.uuid4()))

                # Extract model and usage information from usage_data
                model = response_data.get('model', '')
                usage = response_data.get('usage', {})

                # Update user credits and record transaction using SQLAlchemy with API key
                api_key = response_data.get('original_api_key')

                started_at = response_data.get('created')

                if not started_at:
                    # Fallback to current time if 'created' is not available
                    started_at = int(time.time() * 1000)
                    logger.warning("No 'created' timestamp found in response, using current time for started_at.")
                else:
                    started_at = int(started_at) * 1000  # Convert to milliseconds

                if not api_key:
                    logger.warning("No API key found in response, skipping credit deduction.")
                    raise ValueError("No API key found in response data for credit deduction.")

                success = db.update_user_credits_and_record(
                    cost=cost,
                    started_at=started_at,
                    inference_id=inference_id,
                    model=model,
                    usage=usage,
                    api_key=api_key
                )

            else:
                logger.warning("Not deployed or no DB access. Skipping credit deduction.")
        except Exception as e:
            logger.error(f"Error handling kodu request: {str(e)}")

    return ainetwork_types.AINetworkConnectionEnded(
        cost=cost,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        cache_read_tokens=cache_read_tokens,
        cache_write_tokens=cache_write_tokens,
        generation_id=response_data.get('id', None),
        reasoning_tokens=reasoning_tokens,
        current_context_count=current_context_count,
        finish_reason=finish_reason,
        raw_response_content=raw_response_content
    )
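
A minimal call-site sketch for a non-Kodu response; every field value here is hypothetical:

import asyncio

response_data = {
    "id": "gen_123",  # hypothetical generation id
    "usage": {"prompt_tokens": 1200, "completion_tokens": 300},
    "response_cost": 0.0042,
    "choices": [{"finish_reason": "stop", "message": {"content": "done"}}],
}

event = asyncio.run(create_connection_ended_event(response_data, is_kodu_provider=False))
# event.cost == 0.0042, token counts come from the usage block, and
# event.finish_reason is mapped through _map_finish_reason("stop")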

delay async

delay(ms: int)

Simple asynchronous delay.

Source code in azad/ainetwork/ainetwork_utils.py
async def delay(ms: int):
    """Simple asynchronous delay."""
    await asyncio.sleep(ms / 1000.0)

with_retry async

with_retry(fn: Callable[[], Awaitable[Any]], retries: int = 5, initial_delay: int = 250) -> Any

Retry an async callable with incremental backoff.

Source code in azad/ainetwork/ainetwork_utils.py
async def with_retry(fn: Callable[[], Awaitable[Any]], retries: int = 5, initial_delay: int = 250) -> Any:
    """Async retry decorator."""
    last_error: Optional[Exception] = None
    for i in range(retries):
        try:
            return await fn()
        except Exception as error:
            last_error = error
            logger.warning(f"Retry attempt {i+1}/{retries} failed for {fn.__name__ if hasattr(fn, '__name__') else 'anonymous function'}: {error}")
            if i < retries - 1:
                wait_time = (i + 1) * initial_delay
                logger.info(f"Retrying in {wait_time}ms...")
                await delay(wait_time) # Incremental delay
            else:
                logger.error(f"All {retries} retry attempts failed.")
    # If loop finishes, raise the last recorded error
    if last_error is not None:
        raise last_error
    else:
        # Should not happen if retries > 0, but defensively handle
        raise RuntimeError("Retry loop finished without success or error.")
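
For example, retries=3 with initial_delay=250 waits 250 ms and then 500 ms between attempts. A minimal sketch with a hypothetical flaky coroutine:

import asyncio

attempts = 0

async def flaky() -> str:
    # Hypothetical operation (illustration): fails twice, then succeeds
    global attempts
    attempts += 1
    if attempts < 3:
        raise RuntimeError("transient failure")
    return "ok"

result = asyncio.run(with_retry(flaky, retries=3, initial_delay=250))
# Succeeds on the third attempt, after waiting 250 ms and then 500 ms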

get_openrouter_generation_data_py async

get_openrouter_generation_data_py(generation_id: str, api_key: str) -> Optional[Dict[str, Any]]

Fetches generation data from OpenRouter API with retry logic.

Parameters:

  • generation_id (str) –

    The ID of the generation from OpenRouter.

  • api_key (str) –

    The OpenRouter API key.

Returns:

  • Optional[Dict[str, Any]]

    The 'data' dictionary from the OpenRouter response, or None if fetching fails.

Source code in azad/ainetwork/ainetwork_utils.py
async def get_openrouter_generation_data_py(generation_id: str, api_key: str) -> Optional[Dict[str, Any]]:
    """
    Fetches generation data from OpenRouter API with retry logic.

    Args:
        generation_id: The ID of the generation from OpenRouter.
        api_key: The OpenRouter API key.

    Returns:
        The 'data' dictionary from the OpenRouter response, or None if fetching fails.
    """
    url = f"https://openrouter.ai/api/v1/generation?id={generation_id}"
    headers = {"Authorization": f"Bearer {api_key}"}

    async def fetch_data():
        # Use a context manager for the client
        async with httpx.AsyncClient(timeout=20.0) as client: # Added timeout
            logger.info(f"Attempting to fetch OpenRouter data from: {url}")
            response = await client.get(url, headers=headers)
            # Raise exceptions for bad status codes (4xx or 5xx)
            response.raise_for_status()
            try:
                response_data = response.json()
                if 'data' in response_data:
                    logger.info(f"Successfully fetched OpenRouter data for ID {generation_id}")
                    return response_data['data']
                else:
                    logger.warning(f"OpenRouter response for ID {generation_id} missing 'data' field: {response_data}")
                    # Treat missing 'data' as a failure for retry purposes
                    raise ValueError("OpenRouter response missing 'data' field")
            except json.JSONDecodeError as json_err:
                 logger.error(f"Failed to decode JSON from OpenRouter for ID {generation_id}: {json_err}")
                 raise # Re-raise to trigger retry

    try:
        # Wrap the fetch_data call with the retry logic
        generation_data = await with_retry(fetch_data, retries=3, initial_delay=500) # Reduced retries slightly
        return generation_data
    except Exception as e:
        logger.error(f"Failed to fetch OpenRouter generation data for ID {generation_id} after retries: {type(e).__name__} - {e}")
        return None
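
A hedged call-site sketch; the generation id and API key are placeholders, and the exact fields in the returned dictionary depend on OpenRouter's generation endpoint:

import asyncio

async def main():
    data = await get_openrouter_generation_data_py(
        generation_id="gen-abc123",  # placeholder id
        api_key="sk-or-...",         # placeholder key
    )
    if data is not None:
        # e.g. cost and token fields, if present in the response
        print(data.get("total_cost"), data.get("tokens_prompt"))

asyncio.run(main())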
