Coverage for src / agent / persistence.py: 81%
219 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-11 14:30 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-11 14:30 +0000
1# Copyright 2025-2026 Microsoft Corporation
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15"""Thread persistence for Agent.
17This module provides functionality to save and load conversation threads,
18enabling users to maintain conversation history across sessions.
20Adapted from butler-agent for agent-template architecture.
21"""
23import json
24import logging
25import re
26from datetime import datetime
27from pathlib import Path
28from typing import Any
# Module-level logger, named after this module; the persistence code in this file logs through it.
logger = logging.getLogger(__name__)
33def _sanitize_conversation_name(name: str) -> str:
34 """Sanitize conversation name to prevent path traversal attacks.
36 Args:
37 name: User-provided conversation name
39 Returns:
40 Sanitized name safe for filesystem use
42 Raises:
43 ValueError: If name is invalid or unsafe
45 Example:
46 >>> _sanitize_conversation_name("my-session")
47 'my-session'
48 >>> _sanitize_conversation_name("../etc/passwd")
49 ValueError: Invalid conversation name
50 """
51 # Trim whitespace
52 name = name.strip()
54 # Check length (1-64 characters)
55 if not name or len(name) > 64:
56 raise ValueError("Conversation name must be between 1 and 64 characters")
58 # Check for valid characters: alphanumeric, underscore, dash, dot
59 if not re.match(r"^[A-Za-z0-9._-]+$", name):
60 raise ValueError(
61 "Conversation name can only contain letters, numbers, underscores, dashes, and dots"
62 )
64 # Prevent path traversal attempts
65 if ".." in name or name.startswith("."):
66 raise ValueError("Invalid conversation name: path traversal not allowed")
68 # Prevent reserved names
69 reserved_names = {"index", "metadata", "con", "prn", "aux", "nul"}
70 if name.lower() in reserved_names:
71 raise ValueError(f"Reserved name '{name}' cannot be used")
73 return name
76class ThreadPersistence:
77 """Manage conversation thread serialization and storage.
79 This class provides thread save/load functionality with:
80 - Automatic session directory creation (~/.osdu-agent/sessions/)
81 - Fallback serialization when framework fails
82 - Context summary generation for session resume
83 - Session metadata tracking
85 Example:
86 >>> persistence = ThreadPersistence()
87 >>> # Save thread
88 >>> await persistence.save_thread(thread, "my-session")
89 >>> # Load thread
90 >>> thread, context = await persistence.load_thread(agent, "my-session")
91 """
93 def __init__(self, storage_dir: Path | None = None, memory_dir: Path | None = None):
94 """Initialize persistence manager.
96 Args:
97 storage_dir: Directory for storing conversations
98 (default: ~/.osdu-agent/sessions)
99 memory_dir: Directory for storing memory state
100 (default: ~/.osdu-agent/memory)
101 """
102 if storage_dir is None:
103 storage_dir = Path.home() / ".osdu-agent" / "sessions"
105 self.storage_dir = Path(storage_dir)
106 self.storage_dir.mkdir(parents=True, exist_ok=True)
108 # Memory directory
109 if memory_dir is None:
110 memory_dir = Path.home() / ".agent" / "memory"
111 self.memory_dir = Path(memory_dir)
112 self.memory_dir.mkdir(parents=True, exist_ok=True)
114 # Metadata file tracks all conversations
115 self.metadata_file = self.storage_dir / "index.json"
116 self._load_metadata()
118 logger.debug(f"Thread persistence initialized: {self.storage_dir}")
119 logger.debug(f"Memory persistence initialized: {self.memory_dir}")
121 def _load_metadata(self) -> None:
122 """Load conversation metadata index."""
123 if self.metadata_file.exists():
124 try:
125 with open(self.metadata_file) as f:
126 self.metadata = json.load(f)
127 except Exception as e:
128 logger.warning(f"Failed to load metadata, starting fresh: {e}")
129 self.metadata = {"conversations": {}}
130 else:
131 self.metadata = {"conversations": {}}
132 # Create initial metadata file
133 self._save_metadata()
135 def _save_metadata(self) -> None:
136 """Save conversation metadata index."""
137 try:
138 with open(self.metadata_file, "w") as f:
139 json.dump(self.metadata, f, indent=2)
140 except Exception as e:
141 logger.error(f"Failed to save metadata: {e}")
142 raise
144 def _generate_context_summary(self, messages: list[dict]) -> str:
145 """Generate a concise context summary from message history.
147 Args:
148 messages: List of message dicts with role and content
150 Returns:
151 Context summary string for AI
153 Example:
154 >>> messages = [{"role": "user", "content": "Hello"}]
155 >>> summary = persistence._generate_context_summary(messages)
156 >>> "resuming" in summary.lower()
157 True
158 """
159 if not messages:
160 return "Empty session - no previous context."
162 summary_parts = ["You are resuming a previous session. Here's what happened:\n"]
164 # Track key information
165 user_requests = []
166 tools_called = []
168 for msg in messages:
169 role = msg.get("role", "")
170 content = msg.get("content", "")
172 if role == "user":
173 user_requests.append(content[:200]) # Truncate long messages
175 # Track tool calls
176 if "tool_calls" in msg:
177 for tc in msg["tool_calls"]:
178 tools_called.append(tc.get("name", "unknown"))
180 # Build summary
181 if user_requests:
182 summary_parts.append("User requests:")
183 for i, req in enumerate(user_requests[:5], 1): # Max 5
184 summary_parts.append(f"{i}. {req}")
185 summary_parts.append("")
187 if tools_called:
188 summary_parts.append(f"Tools used: {', '.join(set(tools_called))}")
190 summary_parts.append(f"\nTotal conversation: {len(messages)} messages exchanged.")
191 summary_parts.append(
192 "\nPlease continue helping the user based on this context. "
193 "If the user asks about previous actions, you can reference the above."
194 )
196 return "\n".join(summary_parts)
198 async def _fallback_serialize(self, thread: Any) -> dict:
199 """Fallback serialization when thread.serialize() fails.
201 Manually extracts messages and converts them to JSON-serializable format.
203 Args:
204 thread: AgentThread to serialize
206 Returns:
207 Dictionary with serialized thread data
208 """
209 messages_data = []
211 # Extract messages from message_store (Agent Framework pattern)
212 messages = []
213 if hasattr(thread, "message_store") and thread.message_store:
214 try:
215 messages = await thread.message_store.list_messages()
216 logger.debug(f"Extracted {len(messages)} messages from message_store")
217 except Exception as e:
218 logger.warning(f"Failed to list messages from store: {e}")
220 if messages:
221 for msg in messages:
222 # Extract role (might be a Role enum object, convert to string)
223 role = getattr(msg, "role", "unknown")
224 msg_dict: dict[str, Any] = {"role": str(role) if role else "unknown"}
226 # Extract message content
227 if hasattr(msg, "text"):
228 msg_dict["content"] = str(msg.text)
229 elif hasattr(msg, "content"):
230 # Content can be string or list of content blocks
231 content = msg.content
232 if isinstance(content, str):
233 msg_dict["content"] = content
234 elif isinstance(content, list):
235 # Join content blocks
236 msg_dict["content"] = " ".join(str(block) for block in content)
237 else:
238 msg_dict["content"] = str(content)
239 else:
240 msg_dict["content"] = str(msg)
242 # Extract tool calls if present
243 if hasattr(msg, "tool_calls") and msg.tool_calls:
244 tool_calls_data = []
245 for tc in msg.tool_calls:
246 tool_call = {
247 "name": str(getattr(tc, "name", "unknown")),
248 "arguments": str(getattr(tc, "arguments", "")),
249 }
250 tool_calls_data.append(tool_call)
251 msg_dict["tool_calls"] = tool_calls_data
253 messages_data.append(msg_dict)
255 return {
256 "messages": messages_data,
257 "metadata": {"fallback": True, "version": "1.0"},
258 }
260 async def save_thread(
261 self,
262 thread: Any,
263 name: str,
264 description: str | None = None,
265 messages: list[dict] | None = None,
266 ) -> Path:
267 """Save a conversation thread.
269 Args:
270 thread: AgentThread to serialize (can be None for providers without thread support)
271 name: Name for this conversation
272 description: Optional description
273 messages: Optional list of message dicts for manual tracking
274 (used when thread is None or doesn't support serialization)
276 Returns:
277 Path to saved conversation file
279 Raises:
280 ValueError: If name is invalid or unsafe
281 Exception: If serialization or save fails
283 Example:
284 >>> persistence = ThreadPersistence()
285 >>> path = await persistence.save_thread(thread, "session-1")
286 >>> path.exists()
287 True
288 """
289 # Sanitize name for security
290 safe_name = _sanitize_conversation_name(name)
291 logger.info(f"Saving conversation '{safe_name}'...")
293 # Extract first message for preview
294 first_message = ""
295 message_count = 0
296 if hasattr(thread, "message_store") and thread.message_store:
297 try:
298 messages = await thread.message_store.list_messages()
299 if messages is not None:
300 message_count = len(messages)
302 # Try to get first user message
303 for msg in messages:
304 if hasattr(msg, "role") and str(msg.role) == "user":
305 if hasattr(msg, "text"):
306 first_message = str(msg.text)[:100]
307 elif hasattr(msg, "content"):
308 content = msg.content
309 if isinstance(content, str):
310 first_message = content[:100]
311 else:
312 first_message = str(content)[:100]
313 break
314 except Exception as e:
315 logger.warning(f"Failed to extract first message: {e}")
317 # Serialize thread
318 serialized = None
319 try:
320 # Try framework serialization first
321 if thread and hasattr(thread, "serialize"):
322 serialized = await thread.serialize()
323 logger.debug("Used framework serialization")
324 except Exception as e:
325 logger.warning(f"Framework serialization failed: {e}, using fallback")
327 # Fallback to manual serialization if needed
328 if serialized is None:
329 # Use provided messages if available, otherwise extract from thread
330 if messages:
331 logger.debug(
332 f"Using manually tracked messages for fallback serialization ({len(messages)} messages)"
333 )
334 serialized = {
335 "messages": messages,
336 "metadata": {"fallback": True, "version": "1.0", "manual_tracking": True},
337 }
338 else:
339 serialized = await self._fallback_serialize(thread)
340 logger.debug("Used fallback serialization from thread")
342 # Build conversation data
343 conversation_data = {
344 "name": safe_name,
345 "description": description or "",
346 "created_at": datetime.now().isoformat(),
347 "updated_at": datetime.now().isoformat(),
348 "message_count": message_count,
349 "first_message": first_message,
350 "thread": serialized,
351 }
353 # Save to file
354 file_path = self.storage_dir / f"{safe_name}.json"
355 with open(file_path, "w") as f:
356 json.dump(conversation_data, f, indent=2)
358 # Update metadata
359 self.metadata["conversations"][safe_name] = {
360 "name": safe_name,
361 "description": description or "",
362 "created_at": conversation_data["created_at"],
363 "updated_at": conversation_data["updated_at"],
364 "message_count": message_count,
365 "first_message": first_message,
366 }
367 self._save_metadata()
369 logger.info(f"Saved conversation to {file_path}")
370 return file_path
372 async def load_thread(
373 self, agent: Any, name: str, show_history: bool = True
374 ) -> tuple[Any, str | None]:
375 """Load a conversation thread.
377 Args:
378 agent: Agent instance to create new thread
379 name: Name of conversation to load
380 show_history: Whether to display conversation history (default: True)
382 Returns:
383 Tuple of (thread, context_summary)
384 - thread: Loaded or new thread
385 - context_summary: Summary for AI if using fallback, None otherwise
387 Raises:
388 FileNotFoundError: If conversation doesn't exist
389 Exception: If deserialization fails
391 Example:
392 >>> thread, context = await persistence.load_thread(agent, "session-1")
393 >>> thread is not None
394 True
395 """
396 safe_name = _sanitize_conversation_name(name)
397 file_path = self.storage_dir / f"{safe_name}.json"
399 if not file_path.exists():
400 raise FileNotFoundError(f"Conversation '{safe_name}' not found")
402 logger.info(f"Loading conversation '{safe_name}'...")
404 with open(file_path) as f:
405 data = json.load(f)
407 thread_data = data["thread"]
409 # Extract messages for context (may or may not display)
410 messages = thread_data.get("messages", [])
412 # Display conversation history only if requested
413 # When memory is enabled, we suppress this since memory handles context
414 if show_history and messages:
415 from rich.console import Console
416 from rich.markdown import Markdown
418 console = Console()
420 # Small header to indicate resuming (subtle, not intrusive)
421 console.print(
422 f"\n[dim italic]Resuming session ({len(messages)} messages)[/dim italic]\n"
423 )
425 # Display in the same format as live conversation for consistency
426 for _i, msg in enumerate(messages):
427 role = msg.get("role", "unknown")
428 content = msg.get("content", "")
430 if role == "user":
431 # Match the prompt format exactly: > user_message
432 console.print(f"> {content}")
433 console.print() # Blank line after user input (like live conversation)
434 elif role == "assistant":
435 # Match the agent response format (markdown rendering, no label)
436 console.print(Markdown(content))
437 # Don't add blank line - let CLI handle spacing
439 # Check if fallback serialization was used
440 if thread_data.get("metadata", {}).get("fallback"):
441 logger.info("Loading fallback-serialized session")
443 # Generate context summary for AI
444 context_summary = self._generate_context_summary(messages)
446 # Create new thread (can't deserialize fallback format)
447 thread = (
448 agent.chat_client.create_thread()
449 if hasattr(agent.chat_client, "create_thread")
450 else None
451 )
453 return thread, context_summary
454 else:
455 # Try framework deserialization
456 try:
457 if hasattr(agent.chat_client, "deserialize_thread"):
458 thread = await agent.chat_client.deserialize_thread(thread_data)
459 logger.info("Successfully deserialized thread using framework")
460 # Thread has full context, no summary needed
461 return thread, None
462 else:
463 # Framework doesn't support deserialization, use fallback
464 logger.warning("Framework doesn't support deserialization, using fallback")
465 context_summary = self._generate_context_summary(messages)
466 thread = (
467 agent.chat_client.create_thread()
468 if hasattr(agent.chat_client, "create_thread")
469 else None
470 )
471 return thread, context_summary
472 except Exception as e:
473 logger.error(f"Deserialization failed: {e}, using fallback")
474 context_summary = self._generate_context_summary(messages)
475 thread = (
476 agent.chat_client.create_thread()
477 if hasattr(agent.chat_client, "create_thread")
478 else None
479 )
480 return thread, context_summary
482 def list_sessions(self) -> list[dict]:
483 """List all saved conversation sessions.
485 Returns:
486 List of session metadata dicts
488 Example:
489 >>> sessions = persistence.list_sessions()
490 >>> all("name" in s for s in sessions)
491 True
492 """
493 return list(self.metadata.get("conversations", {}).values())
495 def delete_session(self, name: str) -> None:
496 """Delete a conversation session.
498 Args:
499 name: Name of session to delete
501 Raises:
502 FileNotFoundError: If session doesn't exist
504 Example:
505 >>> persistence.delete_session("session-1")
506 """
507 safe_name = _sanitize_conversation_name(name)
508 file_path = self.storage_dir / f"{safe_name}.json"
510 if not file_path.exists():
511 raise FileNotFoundError(f"Session '{safe_name}' not found")
513 # Delete file
514 file_path.unlink()
516 # Remove from metadata
517 if safe_name in self.metadata.get("conversations", {}):
518 del self.metadata["conversations"][safe_name]
519 self._save_metadata()
521 logger.info(f"Deleted session '{safe_name}'")
523 async def save_memory_state(self, session_name: str, memory_data: list[dict]) -> Path:
524 """Save memory state for a session.
526 Args:
527 session_name: Name of the session
528 memory_data: List of memory entries to save
530 Returns:
531 Path to saved memory file
533 Raises:
534 ValueError: If session_name is invalid
536 Example:
537 >>> memories = [{"role": "user", "content": "Hello"}]
538 >>> path = await persistence.save_memory_state("session-1", memories)
539 """
540 from agent.memory.persistence import MemoryPersistence
542 safe_name = _sanitize_conversation_name(session_name)
543 memory_persistence = MemoryPersistence(storage_dir=self.memory_dir)
545 # Get memory file path
546 memory_path = memory_persistence.get_memory_path(safe_name)
548 # Save memory state
549 await memory_persistence.save(memory_data, memory_path)
551 # Update session metadata to track memory
552 if safe_name in self.metadata.get("conversations", {}):
553 self.metadata["conversations"][safe_name]["has_memory"] = True
554 self.metadata["conversations"][safe_name]["memory_count"] = len(memory_data)
555 self._save_metadata()
557 return memory_path
559 async def load_memory_state(self, session_name: str) -> list[dict] | None:
560 """Load memory state for a session.
562 Args:
563 session_name: Name of the session
565 Returns:
566 List of memory entries or None if no memory exists
568 Raises:
569 ValueError: If session_name is invalid
571 Example:
572 >>> memories = await persistence.load_memory_state("session-1")
573 """
574 from agent.memory.persistence import MemoryPersistence
576 safe_name = _sanitize_conversation_name(session_name)
577 memory_persistence = MemoryPersistence(storage_dir=self.memory_dir)
579 # Get memory file path
580 memory_path = memory_persistence.get_memory_path(safe_name)
582 # Load memory state
583 return await memory_persistence.load(memory_path)