async def compress(
    self,
    task: Task,
    new_checkpoint: Optional[CompressionCheckpoint],
    config: Any,
) -> List[Message]:
"""
Called either to create/update a checkpoint (new_checkpoint != None)
or to do a "transform-only" pass (new_checkpoint == None).
When new_checkpoint is None, we skip re-compression and simply reuse
previously-compressed state. Otherwise, we apply normal compression,
generate a summary, and create a new checkpoint. Handles context
limit errors during summarization by moving messages to the kept set.
"""
    # If no new checkpoint, do transform-only (reuse existing compression)
    if new_checkpoint is None:
        self.logger.info("Performing transform-only operation.")
        return self.transform_only(task)

    # Otherwise, this is an actual compression event
    self.logger.info("Starting compression event.")
    compact_config = config

    all_messages = task.current_task_messages()
    if not all_messages:
        self.logger.info("No messages in task to compress.")
        return []

    # 1) Identify initial critical vs. compressible messages
    critical_messages, compressible_messages = self._get_critical_messages(
        all_messages, compact_config
    )
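    # Critical messages are kept verbatim; compressible ones are candidates
    # for summarization. The split is decided entirely by _get_critical_messages.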
    # Keep track of messages moved due to context limits
    rescued_messages: List[Message] = []

    # 2) Try to generate a summary, handling context limits iteratively
    summary = None
    if compressible_messages:
        messages_to_summarize = list(compressible_messages)  # Work on a copy
        while messages_to_summarize:
            try:
                self.logger.info(f"Attempting summary with {len(messages_to_summarize)} compressible messages.")
                summary = await self._generate_summary(messages_to_summarize, compact_config, task)
                self.logger.debug("Summary generation successful.")
                break  # Summary generated; stop retrying
            except litellm.exceptions.ContextWindowExceededError as e:
                self.logger.warning(f"Context window exceeded during summary generation: {e}")
                # Rescue the *last* message in the batch: it will be kept
                # verbatim instead of summarized, and the summary is retried
                # with the smaller batch.
                rescued_message = messages_to_summarize.pop()
                rescued_messages.append(rescued_message)
                self.logger.debug(
                    f"Removed message {rescued_message.id} from summarization batch due to "
                    "context limit. It will be kept. Retrying summary..."
                )
            except Exception as e:
                # Any other summarization failure is treated as non-retryable
                self.logger.error(f"Unexpected error generating summary: {e}", exc_info=True)
                summary = None
                new_checkpoint.metadata["summary_error"] = str(e)
                break
        else:
            # Runs only when the loop drained messages_to_summarize without
            # breaking, i.e. every message was rescued due to context limits.
            self.logger.warning("Exhausted all compressible messages trying to fit context window. No summary generated.")
            summary = None
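        # Example of the retry behavior: with 10 compressible messages, an
        # overflow on the first attempt rescues message 10 and retries with 9,
        # then 8, and so on, until a summary fits or the batch is empty.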
    # 3) Reconcile the critical and compressible sets after any rescues
    critical_messages.extend(rescued_messages)

    # Messages that were identified as compressible but *not* rescued are the
    # ones actually covered by the summary (if one was generated).
    original_compressible_ids = {m.id for m in compressible_messages}
    rescued_message_ids = {m.id for m in rescued_messages}
    final_compressed_message_ids = original_compressible_ids - rescued_message_ids

    # Final kept message IDs: original critical messages plus rescues
    final_kept_message_ids = {m.id for m in critical_messages}

    # 4) Update checkpoint metadata with the final state. Use update() rather
    # than reassignment so keys set earlier (e.g. "summary_error") survive.
    new_checkpoint.kept_message_ids = list(final_kept_message_ids)
    new_checkpoint.compressed_message_ids = list(final_compressed_message_ids)
    new_checkpoint.metadata.update({
        "compressed_count": len(final_compressed_message_ids),
        "kept_count": len(final_kept_message_ids),
        "rescued_due_to_context_limit": len(rescued_message_ids),
    })
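    # At this point new_checkpoint.metadata contains, with illustrative values:
    #     {"compressed_count": 12, "kept_count": 5, "rescued_due_to_context_limit": 2}
    # plus "summary_error" if an unexpected summarization error occurred.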
    # 5) Add a compression message to the task if a summary was generated
    if summary and final_compressed_message_ids:
        self.logger.debug("Summary generated, adding compression message to task.")
        new_checkpoint.metadata["summary_text"] = summary
        new_checkpoint.metadata["summary_length"] = len(summary)
        try:
            # Determine start/end indices spanning the *original* range of
            # compressible messages, even if some were later rescued
            message_id_to_index = {msg.id: i for i, msg in enumerate(all_messages)}
            indices = [message_id_to_index.get(msg.id, -1) for msg in compressible_messages]
            valid_indices = [idx for idx in indices if idx >= 0]
            if valid_indices:
                start_idx = min(valid_indices)
                end_idx = max(valid_indices)
                task.add_compression_message(
                    start_idx=start_idx,
                    end_idx=end_idx,
                    reason="Compressed messages to maintain context",
                    strategies=[self.strategy_type.value],
                    # Report the IDs actually compressed and kept
                    compressed_message_ids=list(final_compressed_message_ids),
                    kept_message_ids=list(final_kept_message_ids),
                    metadata={
                        "summary_text": summary,
                        "summary_length": len(summary),
                        # Counts reflect the final state after rescues
                        "compressed_count": len(final_compressed_message_ids),
                        "initially_compressible_count": len(compressible_messages),
                        "rescued_count": len(rescued_message_ids),
                    },
                )
            else:
                self.logger.warning("Could not determine valid indices for compression message.")
        except Exception as e:
            self.logger.error(f"Error creating or adding compression message: {e}", exc_info=True)
            # Record the failure on the checkpoint without overwriting an earlier error
            if "summary_error" not in new_checkpoint.metadata:
                new_checkpoint.metadata["compression_message_error"] = str(e)
    elif not summary and final_compressed_message_ids:
        # Compression was intended, but no summary could be generated (error
        # or context limits). Record a placeholder; no compression message is
        # added to the task in this case.
        self.logger.warning("Messages were identified for compression, but no summary was generated.")
        new_checkpoint.metadata["summary_text"] = f"[Failed to summarize {len(final_compressed_message_ids)} messages]"
    # 6) Log final stats, counting only non-compression messages from the
    # original list so the totals are accurate
    original_non_compression_msgs = [m for m in all_messages if m.role != MessageRole.compression]
    total_count = len(original_non_compression_msgs)
    self.logger.info(
        f"Compact summarization finished: Kept {len(final_kept_message_ids)}/{total_count} original messages. "
        f"Successfully compressed {len(final_compressed_message_ids)} messages. "
        f"Rescued {len(rescued_message_ids)} messages due to context limits."
    )
    # 7) Return the final "kept" messages (original critical + rescued),
    # excluding compression-role messages, in their original order
    final_kept_messages: List[Message] = [
        msg
        for msg in all_messages
        if msg.id in final_kept_message_ids and msg.role != MessageRole.compression
    ]

    # Defensive check: make sure every message intended to be kept is
    # returned, even if it was somehow absent from all_messages.
    final_kept_ids_set = {m.id for m in final_kept_messages}
    missing_kept = [m for m in critical_messages if m.id not in final_kept_ids_set]
    if missing_kept:
        self.logger.warning(f"Adding {len(missing_kept)} kept messages that were missed in primary list construction.")
        final_kept_messages.extend(missing_kept)  # Appending breaks strict ordering but guarantees completeness

    return final_kept_messages
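
# Usage sketch (hypothetical caller; `strategy`, `task`, `config`, and a
# zero-argument CompressionCheckpoint() are illustrative assumptions, not a
# verified API):
#
#     checkpoint = CompressionCheckpoint()
#     kept = await strategy.compress(task, checkpoint, config)
#     # On success, checkpoint.metadata["summary_text"] holds the summary and
#     # `kept` holds the messages that survived compression.
#
#     # Transform-only pass: reuse previously compressed state instead.
#     kept = await strategy.compress(task, None, config)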