Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 73 additions & 4 deletions src/strands/session/s3_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast

import boto3
Expand Down Expand Up @@ -50,6 +51,7 @@ def __init__(
boto_session: Optional[boto3.Session] = None,
boto_client_config: Optional[BotocoreConfig] = None,
region_name: Optional[str] = None,
max_parallel_reads: int = 1,
**kwargs: Any,
):
"""Initialize S3SessionManager with S3 storage.
Expand All @@ -62,11 +64,20 @@ def __init__(
boto_session: Optional boto3 session
boto_client_config: Optional boto3 client configuration
region_name: AWS region for S3 storage
max_parallel_reads: Maximum number of parallel S3 read operations for list_messages().
Defaults to 1 (sequential) for backward compatibility and safety.
Set to a higher value (e.g., 10) for better performance with many messages.
Can be overridden per-call via list_messages() kwargs.
**kwargs: Additional keyword arguments for future extensibility.
"""
self.bucket = bucket
self.prefix = prefix

# Validate max_parallel_reads
if not isinstance(max_parallel_reads, int) or max_parallel_reads < 1:
raise ValueError(f"max_parallel_reads must be a positive integer, got {max_parallel_reads}")
self.max_parallel_reads = max_parallel_reads

session = boto_session or boto3.Session(region_name=region_name)

# Add strands-agents to the request user agent
Expand Down Expand Up @@ -259,7 +270,24 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio
def list_messages(
self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any
) -> List[SessionMessage]:
"""List messages for an agent with pagination from S3."""
"""List messages for an agent with pagination from S3.

Args:
session_id: ID of the session
agent_id: ID of the agent
limit: Optional limit on number of messages to return
offset: Optional offset for pagination
**kwargs: Additional keyword arguments. Supports:
max_parallel_reads: Override the instance-level max_parallel_reads setting
for this call only.

Returns:
List of SessionMessage objects, sorted by message_id.

Raises:
ValueError: If max_parallel_reads override is not a positive integer.
SessionException: If S3 error occurs during message retrieval.
"""
messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/"
try:
paginator = self.client.get_paginator("list_objects_v2")
Expand Down Expand Up @@ -287,10 +315,51 @@ def list_messages(
else:
message_keys = message_keys[offset:]

# Load only the required message objects
# Load message objects in parallel for better performance
messages: List[SessionMessage] = []
for key in message_keys:
message_data = self._read_s3_object(key)
if not message_keys:
return messages

# Use ThreadPoolExecutor to fetch messages concurrently
# Allow per-call override of max_parallel_reads via kwargs, otherwise use instance default
max_parallel_reads_override = kwargs.get("max_parallel_reads")
if max_parallel_reads_override is not None:
if not isinstance(max_parallel_reads_override, int) or max_parallel_reads_override < 1:
raise ValueError(
f"max_parallel_reads must be a positive integer, got {max_parallel_reads_override}"
)
max_parallel_reads_value = max_parallel_reads_override
else:
# Instance default was already validated in __init__, no need to check again
max_parallel_reads_value = self.max_parallel_reads

max_workers = min(max_parallel_reads_value, len(message_keys))

with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all read tasks
future_to_key = {executor.submit(self._read_s3_object, key): key for key in message_keys}

# Create a mapping from key to index to maintain order
key_to_index = {key: idx for idx, key in enumerate(message_keys)}

# Initialize results list with None placeholders to maintain order
results: List[Optional[Dict[str, Any]]] = [None] * len(message_keys)

# Process results as they complete
for future in as_completed(future_to_key):
key = future_to_key[future]
try:
message_data = future.result()
# Store result at the correct index to maintain order
results[key_to_index[key]] = message_data
except Exception as e:
# Log error but continue processing other messages
# Individual failures shouldn't stop the entire operation
logger.warning("key=<%s> | failed to read message from s3", key, exc_info=e)

# Convert results to SessionMessage objects, filtering out None values
# If SessionMessage.from_dict fails, let it propagate - data corruption should be visible
for message_data in results:
if message_data:
messages.append(SessionMessage.from_dict(message_data))

Expand Down
Loading