Skip to content

azad.compression.strategies.compact Module

azad.compression.strategies.compact

Compact Summarization compression strategy.

This module provides an AI-assisted compression strategy that preserves critical messages (like the first task messages and recent exchanges), while creating a compact summary of the rest to maintain context.

Attributes

Classes

CompactSummarizationStrategy

CompactSummarizationStrategy()

Bases: CompressionStrategy

Source code in azad/compression/strategies/compact.py
def __init__(self):
    """Initialise per-instance state: a module-named logger and an empty AI network slot."""
    # No network is attached at construction time; this method only stores None here.
    self.ai_network = None
    # Logger is named after the defining module so records can be filtered by module.
    self.logger = logging.getLogger(__name__)
Attributes
logger instance-attribute
logger = getLogger(__name__)
Functions
compress async
compress(task: Task, new_checkpoint: Optional[CompressionCheckpoint], config: Any) -> List[Message]

Called either to create/update a checkpoint (new_checkpoint != None) or to do a "transform-only" pass (new_checkpoint == None).

When new_checkpoint is None, we skip re-compression and simply reuse previously-compressed state. Otherwise, we apply normal compression, generate a summary, and create a new checkpoint. Handles context limit errors during summarization by moving messages to the kept set.

Source code in azad/compression/strategies/compact.py
async def compress(self,
                   task: Task,
                   new_checkpoint: Optional[CompressionCheckpoint],
                   config: Any
                  ) -> List[Message]:
    """
    Compress a task's messages, or reuse a previous compression.

    Called either to create/update a checkpoint (new_checkpoint != None)
    or to do a "transform-only" pass (new_checkpoint == None).

    When new_checkpoint is None, no re-compression happens and the result of
    ``self.transform_only(task)`` is returned. Otherwise the messages are
    split into critical and compressible sets, a summary of the compressible
    set is requested, and ``new_checkpoint`` is updated in place
    (``kept_message_ids``, ``compressed_message_ids``, ``metadata``). If the
    summary request exceeds the model's context window, messages are moved
    one at a time from the end of the summarization batch into the kept set
    and the request is retried.

    Args:
        task: Task whose ``current_task_messages()`` are compressed. When a
            summary is produced, ``task.add_compression_message`` is called.
        new_checkpoint: Checkpoint to populate, or None for transform-only.
        config: Strategy configuration, forwarded unchanged to
            ``_get_critical_messages`` and ``_generate_summary``.

    Returns:
        The kept messages (critical plus rescued) with compression-role
        messages excluded, in task order. An empty list if the task has no
        messages.
    """
    # If no new checkpoint, do transform-only (reuse existing compression)
    if new_checkpoint is None:
        self.logger.info("Performing transform-only operation.")
        return self.transform_only(task)

    # Otherwise, this is an actual compression event:
    self.logger.info("Starting compression event.")
    compact_config = config
    all_messages = task.current_task_messages()
    if not all_messages:
        self.logger.info("No messages in task to compress.")
        return []

    # 1) Identify initial critical vs. compressible messages
    critical_messages, compressible_messages = self._get_critical_messages(
        all_messages, compact_config
    )

    # Messages pulled out of the summarization batch because of context limits
    rescued_messages: List[Message] = []
    summary = None
    # Held in a local so it survives the metadata rebuild in step 4.
    summary_error: Optional[str] = None

    # 2) Try to generate a summary, shrinking the batch on context overflow
    messages_to_summarize = list(compressible_messages) if compressible_messages else []
    while messages_to_summarize:
        try:
            self.logger.info(
                "Attempting summary with %s compressible messages.",
                len(messages_to_summarize),
            )
            summary = await self._generate_summary(messages_to_summarize, compact_config, task)
        except litellm.exceptions.ContextWindowExceededError as e:
            self.logger.warning("Context window exceeded during summary generation: %s", e)
            # Drop the *last* message of the batch; it will be kept verbatim.
            rescued_message = messages_to_summarize.pop()
            rescued_messages.append(rescued_message)
            self.logger.debug(
                "Removed message %s from summarization batch due to context limit. "
                "It will be kept. Retrying summary...",
                rescued_message.id,
            )
            if not messages_to_summarize:
                # The loop condition ends the loop; summary is still None.
                self.logger.warning(
                    "No compressible messages left after attempting to handle "
                    "context limits. No summary will be generated."
                )
        except Exception as e:
            # Any other failure aborts summarization; the error is recorded below.
            self.logger.error("Unexpected error generating summary: %s", e, exc_info=True)
            summary = None
            summary_error = str(e)
            break
        else:
            self.logger.debug("Summary generation successful.")
            break

    # 3) Rescued messages join the kept set; everything else that was
    #    compressible counts as compressed.
    critical_messages.extend(rescued_messages)
    rescued_message_ids = {m.id for m in rescued_messages}
    final_compressed_message_ids = {m.id for m in compressible_messages} - rescued_message_ids
    final_kept_message_ids = {m.id for m in critical_messages}

    # 4) Update checkpoint with final state
    new_checkpoint.kept_message_ids = list(final_kept_message_ids)
    new_checkpoint.compressed_message_ids = list(final_compressed_message_ids)
    new_checkpoint.metadata = {
        "compressed_count": len(final_compressed_message_ids),
        "kept_count": len(final_kept_message_ids),
        "rescued_due_to_context_limit": len(rescued_message_ids),
    }
    if summary_error is not None:
        # Written after the rebuild above so the error is not discarded.
        new_checkpoint.metadata["summary_error"] = summary_error

    # 5) Add compression message to Task if a summary was generated
    if summary and final_compressed_message_ids:
        self.logger.debug("Summary generated, adding compression message to task.")
        new_checkpoint.metadata["summary_text"] = summary
        new_checkpoint.metadata["summary_length"] = len(summary)

        try:
            # The span covers the *original* compressible range, even if some
            # of those messages were rescued.
            message_id_to_index = {msg.id: i for i, msg in enumerate(all_messages)}
            valid_indices = [
                message_id_to_index[msg.id]
                for msg in compressible_messages
                if msg.id in message_id_to_index
            ]

            if valid_indices:
                task.add_compression_message(
                    start_idx=min(valid_indices),
                    end_idx=max(valid_indices),
                    reason="Compressed messages to maintain context",
                    strategies=[self.strategy_type.value],
                    # Report IDs actually compressed and kept
                    compressed_message_ids=list(final_compressed_message_ids),
                    kept_message_ids=list(final_kept_message_ids),
                    metadata={
                        "summary_text": summary,
                        "summary_length": len(summary),
                        # Report counts based on final state
                        "compressed_count": len(final_compressed_message_ids),
                        "initially_compressible_count": len(compressible_messages),
                        "rescued_count": len(rescued_message_ids),
                    },
                )
            else:
                self.logger.warning("Could not determine valid indices for compression message.")

        except Exception as e:
            self.logger.error("Error creating or adding compression message: %s", e, exc_info=True)
            # Record the failure on the checkpoint even though the task was not modified.
            if "summary_error" not in new_checkpoint.metadata:
                new_checkpoint.metadata["compression_message_error"] = str(e)
    elif not summary and final_compressed_message_ids:
        # Compression was intended but no summary exists (error or context limits).
        # No compression message is added to the task in this case.
        self.logger.warning("Messages were identified for compression, but no summary was generated.")
        new_checkpoint.metadata["summary_text"] = (
            f"[Failed to summarize {len(final_compressed_message_ids)} messages]"
        )

    # 6) Log final stats; the total counts only non-compression messages
    total_count = sum(1 for m in all_messages if m.role != MessageRole.compression)
    self.logger.info(
        "Compact summarization finished: Kept %s/%s original messages. "
        "Successfully compressed %s messages. "
        "Rescued %s messages due to context limits.",
        len(final_kept_message_ids),
        total_count,
        len(final_compressed_message_ids),
        len(rescued_message_ids),
    )

    # 7) Return the kept messages in task order, excluding compression-role messages
    final_kept_messages: List[Message] = [
        msg
        for msg in all_messages
        if msg.id in final_kept_message_ids and msg.role != MessageRole.compression
    ]

    # Safety net for kept messages that were not found in all_messages. The
    # role filter is repeated here so compression-role messages excluded above
    # are not re-added by this fallback.
    returned_ids = {m.id for m in final_kept_messages}
    missing_kept = [
        m
        for m in critical_messages
        if m.id not in returned_ids and m.role != MessageRole.compression
    ]
    if missing_kept:
        self.logger.warning(
            "Adding %s kept messages that were missed in primary list construction.",
            len(missing_kept),
        )
        # Appending may not preserve order, but guarantees they are returned.
        final_kept_messages.extend(missing_kept)

    return final_kept_messages
transform_only
transform_only(task: Task) -> List[Message]

When no new checkpoint is provided, we simply reuse whatever compression checkpoints already exist. We do NOT create or update any compression messages. We just figure out which messages are kept vs. compressed based on previously recorded checkpoints, and then return them in the order:

1) All kept messages up to the checkpoint
2) An informational summary (if one exists)
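3) The remaining kept messages (after the checkpoint).

Note: when a summary exists, the implementation shown below returns only the informational summary followed by the kept messages after the checkpoint; kept messages from before the checkpoint are not included in the returned list.
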
Source code in azad/compression/strategies/compact.py
def transform_only(self, task: Task) -> List[Message]:
    """
    Rebuild the message list from compression checkpoints already on the task.

    No compression messages are created or updated. Every non-compression
    message whose id appears in any checkpoint's ``compressed_message_ids``
    is treated as compressed; all others are kept.

    Returns:
        - ``[]`` if the task has no messages.
        - All non-compression messages, unchanged, if there are no
          compression messages.
        - The kept messages sorted by ``started_ts`` if no checkpoint
          carries a summary text.
        - Otherwise a new informational summary message followed by the
          kept messages whose ``started_ts`` is later than the latest
          compressed message.

    NOTE(review): in the last case, kept messages at or before the latest
    compressed timestamp are NOT returned. An earlier description of this
    method listed them ahead of the summary — confirm which is intended.
    """
    all_messages = task.current_task_messages()
    if not all_messages:
        return []

    # Separate compression messages from normal messages.
    compression_msgs: List[Message] = [m for m in all_messages if m.role == MessageRole.compression]
    normal_msgs: List[Message] = [m for m in all_messages if m.role != MessageRole.compression]

    # If there are no compression checkpoints, keep all normal messages as-is.
    if not compression_msgs:
        self.logger.debug("[transform_only] No existing checkpoints. Keeping all messages.")
        return normal_msgs

    # Collect every id recorded as compressed once, so the per-message check
    # below is a set lookup instead of a rescan of all checkpoints.
    checkpointed_ids = set()
    for comp_msg in compression_msgs:
        for part in comp_msg.content:
            if isinstance(part, CompressionPart) and part.compressed_message_ids:
                checkpointed_ids.update(part.compressed_message_ids)

    compressed_message_ids = {m.id for m in normal_msgs if m.id in checkpointed_ids}
    kept_message_ids = {m.id for m in normal_msgs if m.id not in checkpointed_ids}

    self.logger.debug(
        "[transform_only] Reusing existing checkpoints. "
        "Kept %s/%s messages, compressed %s.",
        len(kept_message_ids),
        len(normal_msgs),
        len(compressed_message_ids),
    )

    # Take the summary text from the most recent compression message that has one.
    summary_text = None
    for comp_msg in reversed(compression_msgs):
        for part in comp_msg.content:
            if not (isinstance(part, CompressionPart) and part.metadata and isinstance(part.metadata, dict)):
                continue
            # `or {}` tolerates keys that are present but set to None.
            checkpoint_data = part.metadata.get('checkpoint') or {}
            checkpoint_meta = checkpoint_data.get('metadata') or {}
            candidate = checkpoint_meta.get('summary_text')
            if candidate:
                summary_text = candidate
                break
        if summary_text:
            break

    # Kept messages only, in chronological order by start timestamp.
    kept_messages = [m for m in normal_msgs if m.id in kept_message_ids]
    kept_messages.sort(key=lambda x: x.started_ts)

    # If there is no summary, just return the kept messages in order.
    if not summary_text:
        return kept_messages

    # The checkpoint boundary is the latest started_ts among compressed
    # messages; with none present, 0 is used as the boundary.
    latest_ts = max(
        (m.started_ts for m in normal_msgs if m.id in compressed_message_ids),
        default=0,
    )
    post_summary: List[Message] = [m for m in kept_messages if m.started_ts > latest_ts]

    # Create an informational message for the summary
    informational_part = InformationalPart(
        is_visible_ai=True,
        is_visible_ui=False,
        informational_type="summary",
        details=summary_text,
        additional_data=None
    )

    ts = int(time.time())
    new_message = InformationalMessage(
        task_id=task.id,
        started_ts=ts,
        finished_ts=ts,  # Informational is instantaneous
        content=[informational_part]
    )
    # Return them in the order: [informational summary] -> [post-summary kept]
    return [new_message] + post_summary