Coverage for src / agent / persistence.py: 81%

219 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-11 14:30 +0000

1# Copyright 2025-2026 Microsoft Corporation 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15"""Thread persistence for Agent. 

16 

17This module provides functionality to save and load conversation threads, 

18enabling users to maintain conversation history across sessions. 

19 

20Adapted from butler-agent for agent-template architecture. 

21""" 

22 

23import json 

24import logging 

25import re 

26from datetime import datetime 

27from pathlib import Path 

28from typing import Any 

29 

30logger = logging.getLogger(__name__) 

31 

32 

33def _sanitize_conversation_name(name: str) -> str: 

34 """Sanitize conversation name to prevent path traversal attacks. 

35 

36 Args: 

37 name: User-provided conversation name 

38 

39 Returns: 

40 Sanitized name safe for filesystem use 

41 

42 Raises: 

43 ValueError: If name is invalid or unsafe 

44 

45 Example: 

46 >>> _sanitize_conversation_name("my-session") 

47 'my-session' 

48 >>> _sanitize_conversation_name("../etc/passwd") 

49 ValueError: Invalid conversation name 

50 """ 

51 # Trim whitespace 

52 name = name.strip() 

53 

54 # Check length (1-64 characters) 

55 if not name or len(name) > 64: 

56 raise ValueError("Conversation name must be between 1 and 64 characters") 

57 

58 # Check for valid characters: alphanumeric, underscore, dash, dot 

59 if not re.match(r"^[A-Za-z0-9._-]+$", name): 

60 raise ValueError( 

61 "Conversation name can only contain letters, numbers, underscores, dashes, and dots" 

62 ) 

63 

64 # Prevent path traversal attempts 

65 if ".." in name or name.startswith("."): 

66 raise ValueError("Invalid conversation name: path traversal not allowed") 

67 

68 # Prevent reserved names 

69 reserved_names = {"index", "metadata", "con", "prn", "aux", "nul"} 

70 if name.lower() in reserved_names: 

71 raise ValueError(f"Reserved name '{name}' cannot be used") 

72 

73 return name 

74 

75 

76class ThreadPersistence: 

77 """Manage conversation thread serialization and storage. 

78 

79 This class provides thread save/load functionality with: 

80 - Automatic session directory creation (~/.osdu-agent/sessions/) 

81 - Fallback serialization when framework fails 

82 - Context summary generation for session resume 

83 - Session metadata tracking 

84 

85 Example: 

86 >>> persistence = ThreadPersistence() 

87 >>> # Save thread 

88 >>> await persistence.save_thread(thread, "my-session") 

89 >>> # Load thread 

90 >>> thread, context = await persistence.load_thread(agent, "my-session") 

91 """ 

92 

93 def __init__(self, storage_dir: Path | None = None, memory_dir: Path | None = None): 

94 """Initialize persistence manager. 

95 

96 Args: 

97 storage_dir: Directory for storing conversations 

98 (default: ~/.osdu-agent/sessions) 

99 memory_dir: Directory for storing memory state 

100 (default: ~/.osdu-agent/memory) 

101 """ 

102 if storage_dir is None: 

103 storage_dir = Path.home() / ".osdu-agent" / "sessions" 

104 

105 self.storage_dir = Path(storage_dir) 

106 self.storage_dir.mkdir(parents=True, exist_ok=True) 

107 

108 # Memory directory 

109 if memory_dir is None: 

110 memory_dir = Path.home() / ".agent" / "memory" 

111 self.memory_dir = Path(memory_dir) 

112 self.memory_dir.mkdir(parents=True, exist_ok=True) 

113 

114 # Metadata file tracks all conversations 

115 self.metadata_file = self.storage_dir / "index.json" 

116 self._load_metadata() 

117 

118 logger.debug(f"Thread persistence initialized: {self.storage_dir}") 

119 logger.debug(f"Memory persistence initialized: {self.memory_dir}") 

120 

121 def _load_metadata(self) -> None: 

122 """Load conversation metadata index.""" 

123 if self.metadata_file.exists(): 

124 try: 

125 with open(self.metadata_file) as f: 

126 self.metadata = json.load(f) 

127 except Exception as e: 

128 logger.warning(f"Failed to load metadata, starting fresh: {e}") 

129 self.metadata = {"conversations": {}} 

130 else: 

131 self.metadata = {"conversations": {}} 

132 # Create initial metadata file 

133 self._save_metadata() 

134 

135 def _save_metadata(self) -> None: 

136 """Save conversation metadata index.""" 

137 try: 

138 with open(self.metadata_file, "w") as f: 

139 json.dump(self.metadata, f, indent=2) 

140 except Exception as e: 

141 logger.error(f"Failed to save metadata: {e}") 

142 raise 

143 

144 def _generate_context_summary(self, messages: list[dict]) -> str: 

145 """Generate a concise context summary from message history. 

146 

147 Args: 

148 messages: List of message dicts with role and content 

149 

150 Returns: 

151 Context summary string for AI 

152 

153 Example: 

154 >>> messages = [{"role": "user", "content": "Hello"}] 

155 >>> summary = persistence._generate_context_summary(messages) 

156 >>> "resuming" in summary.lower() 

157 True 

158 """ 

159 if not messages: 

160 return "Empty session - no previous context." 

161 

162 summary_parts = ["You are resuming a previous session. Here's what happened:\n"] 

163 

164 # Track key information 

165 user_requests = [] 

166 tools_called = [] 

167 

168 for msg in messages: 

169 role = msg.get("role", "") 

170 content = msg.get("content", "") 

171 

172 if role == "user": 

173 user_requests.append(content[:200]) # Truncate long messages 

174 

175 # Track tool calls 

176 if "tool_calls" in msg: 

177 for tc in msg["tool_calls"]: 

178 tools_called.append(tc.get("name", "unknown")) 

179 

180 # Build summary 

181 if user_requests: 

182 summary_parts.append("User requests:") 

183 for i, req in enumerate(user_requests[:5], 1): # Max 5 

184 summary_parts.append(f"{i}. {req}") 

185 summary_parts.append("") 

186 

187 if tools_called: 

188 summary_parts.append(f"Tools used: {', '.join(set(tools_called))}") 

189 

190 summary_parts.append(f"\nTotal conversation: {len(messages)} messages exchanged.") 

191 summary_parts.append( 

192 "\nPlease continue helping the user based on this context. " 

193 "If the user asks about previous actions, you can reference the above." 

194 ) 

195 

196 return "\n".join(summary_parts) 

197 

198 async def _fallback_serialize(self, thread: Any) -> dict: 

199 """Fallback serialization when thread.serialize() fails. 

200 

201 Manually extracts messages and converts them to JSON-serializable format. 

202 

203 Args: 

204 thread: AgentThread to serialize 

205 

206 Returns: 

207 Dictionary with serialized thread data 

208 """ 

209 messages_data = [] 

210 

211 # Extract messages from message_store (Agent Framework pattern) 

212 messages = [] 

213 if hasattr(thread, "message_store") and thread.message_store: 

214 try: 

215 messages = await thread.message_store.list_messages() 

216 logger.debug(f"Extracted {len(messages)} messages from message_store") 

217 except Exception as e: 

218 logger.warning(f"Failed to list messages from store: {e}") 

219 

220 if messages: 

221 for msg in messages: 

222 # Extract role (might be a Role enum object, convert to string) 

223 role = getattr(msg, "role", "unknown") 

224 msg_dict: dict[str, Any] = {"role": str(role) if role else "unknown"} 

225 

226 # Extract message content 

227 if hasattr(msg, "text"): 

228 msg_dict["content"] = str(msg.text) 

229 elif hasattr(msg, "content"): 

230 # Content can be string or list of content blocks 

231 content = msg.content 

232 if isinstance(content, str): 

233 msg_dict["content"] = content 

234 elif isinstance(content, list): 

235 # Join content blocks 

236 msg_dict["content"] = " ".join(str(block) for block in content) 

237 else: 

238 msg_dict["content"] = str(content) 

239 else: 

240 msg_dict["content"] = str(msg) 

241 

242 # Extract tool calls if present 

243 if hasattr(msg, "tool_calls") and msg.tool_calls: 

244 tool_calls_data = [] 

245 for tc in msg.tool_calls: 

246 tool_call = { 

247 "name": str(getattr(tc, "name", "unknown")), 

248 "arguments": str(getattr(tc, "arguments", "")), 

249 } 

250 tool_calls_data.append(tool_call) 

251 msg_dict["tool_calls"] = tool_calls_data 

252 

253 messages_data.append(msg_dict) 

254 

255 return { 

256 "messages": messages_data, 

257 "metadata": {"fallback": True, "version": "1.0"}, 

258 } 

259 

260 async def save_thread( 

261 self, 

262 thread: Any, 

263 name: str, 

264 description: str | None = None, 

265 messages: list[dict] | None = None, 

266 ) -> Path: 

267 """Save a conversation thread. 

268 

269 Args: 

270 thread: AgentThread to serialize (can be None for providers without thread support) 

271 name: Name for this conversation 

272 description: Optional description 

273 messages: Optional list of message dicts for manual tracking 

274 (used when thread is None or doesn't support serialization) 

275 

276 Returns: 

277 Path to saved conversation file 

278 

279 Raises: 

280 ValueError: If name is invalid or unsafe 

281 Exception: If serialization or save fails 

282 

283 Example: 

284 >>> persistence = ThreadPersistence() 

285 >>> path = await persistence.save_thread(thread, "session-1") 

286 >>> path.exists() 

287 True 

288 """ 

289 # Sanitize name for security 

290 safe_name = _sanitize_conversation_name(name) 

291 logger.info(f"Saving conversation '{safe_name}'...") 

292 

293 # Extract first message for preview 

294 first_message = "" 

295 message_count = 0 

296 if hasattr(thread, "message_store") and thread.message_store: 

297 try: 

298 messages = await thread.message_store.list_messages() 

299 if messages is not None: 

300 message_count = len(messages) 

301 

302 # Try to get first user message 

303 for msg in messages: 

304 if hasattr(msg, "role") and str(msg.role) == "user": 

305 if hasattr(msg, "text"): 

306 first_message = str(msg.text)[:100] 

307 elif hasattr(msg, "content"): 

308 content = msg.content 

309 if isinstance(content, str): 

310 first_message = content[:100] 

311 else: 

312 first_message = str(content)[:100] 

313 break 

314 except Exception as e: 

315 logger.warning(f"Failed to extract first message: {e}") 

316 

317 # Serialize thread 

318 serialized = None 

319 try: 

320 # Try framework serialization first 

321 if thread and hasattr(thread, "serialize"): 

322 serialized = await thread.serialize() 

323 logger.debug("Used framework serialization") 

324 except Exception as e: 

325 logger.warning(f"Framework serialization failed: {e}, using fallback") 

326 

327 # Fallback to manual serialization if needed 

328 if serialized is None: 

329 # Use provided messages if available, otherwise extract from thread 

330 if messages: 

331 logger.debug( 

332 f"Using manually tracked messages for fallback serialization ({len(messages)} messages)" 

333 ) 

334 serialized = { 

335 "messages": messages, 

336 "metadata": {"fallback": True, "version": "1.0", "manual_tracking": True}, 

337 } 

338 else: 

339 serialized = await self._fallback_serialize(thread) 

340 logger.debug("Used fallback serialization from thread") 

341 

342 # Build conversation data 

343 conversation_data = { 

344 "name": safe_name, 

345 "description": description or "", 

346 "created_at": datetime.now().isoformat(), 

347 "updated_at": datetime.now().isoformat(), 

348 "message_count": message_count, 

349 "first_message": first_message, 

350 "thread": serialized, 

351 } 

352 

353 # Save to file 

354 file_path = self.storage_dir / f"{safe_name}.json" 

355 with open(file_path, "w") as f: 

356 json.dump(conversation_data, f, indent=2) 

357 

358 # Update metadata 

359 self.metadata["conversations"][safe_name] = { 

360 "name": safe_name, 

361 "description": description or "", 

362 "created_at": conversation_data["created_at"], 

363 "updated_at": conversation_data["updated_at"], 

364 "message_count": message_count, 

365 "first_message": first_message, 

366 } 

367 self._save_metadata() 

368 

369 logger.info(f"Saved conversation to {file_path}") 

370 return file_path 

371 

372 async def load_thread( 

373 self, agent: Any, name: str, show_history: bool = True 

374 ) -> tuple[Any, str | None]: 

375 """Load a conversation thread. 

376 

377 Args: 

378 agent: Agent instance to create new thread 

379 name: Name of conversation to load 

380 show_history: Whether to display conversation history (default: True) 

381 

382 Returns: 

383 Tuple of (thread, context_summary) 

384 - thread: Loaded or new thread 

385 - context_summary: Summary for AI if using fallback, None otherwise 

386 

387 Raises: 

388 FileNotFoundError: If conversation doesn't exist 

389 Exception: If deserialization fails 

390 

391 Example: 

392 >>> thread, context = await persistence.load_thread(agent, "session-1") 

393 >>> thread is not None 

394 True 

395 """ 

396 safe_name = _sanitize_conversation_name(name) 

397 file_path = self.storage_dir / f"{safe_name}.json" 

398 

399 if not file_path.exists(): 

400 raise FileNotFoundError(f"Conversation '{safe_name}' not found") 

401 

402 logger.info(f"Loading conversation '{safe_name}'...") 

403 

404 with open(file_path) as f: 

405 data = json.load(f) 

406 

407 thread_data = data["thread"] 

408 

409 # Extract messages for context (may or may not display) 

410 messages = thread_data.get("messages", []) 

411 

412 # Display conversation history only if requested 

413 # When memory is enabled, we suppress this since memory handles context 

414 if show_history and messages: 

415 from rich.console import Console 

416 from rich.markdown import Markdown 

417 

418 console = Console() 

419 

420 # Small header to indicate resuming (subtle, not intrusive) 

421 console.print( 

422 f"\n[dim italic]Resuming session ({len(messages)} messages)[/dim italic]\n" 

423 ) 

424 

425 # Display in the same format as live conversation for consistency 

426 for _i, msg in enumerate(messages): 

427 role = msg.get("role", "unknown") 

428 content = msg.get("content", "") 

429 

430 if role == "user": 

431 # Match the prompt format exactly: > user_message 

432 console.print(f"> {content}") 

433 console.print() # Blank line after user input (like live conversation) 

434 elif role == "assistant": 

435 # Match the agent response format (markdown rendering, no label) 

436 console.print(Markdown(content)) 

437 # Don't add blank line - let CLI handle spacing 

438 

439 # Check if fallback serialization was used 

440 if thread_data.get("metadata", {}).get("fallback"): 

441 logger.info("Loading fallback-serialized session") 

442 

443 # Generate context summary for AI 

444 context_summary = self._generate_context_summary(messages) 

445 

446 # Create new thread (can't deserialize fallback format) 

447 thread = ( 

448 agent.chat_client.create_thread() 

449 if hasattr(agent.chat_client, "create_thread") 

450 else None 

451 ) 

452 

453 return thread, context_summary 

454 else: 

455 # Try framework deserialization 

456 try: 

457 if hasattr(agent.chat_client, "deserialize_thread"): 

458 thread = await agent.chat_client.deserialize_thread(thread_data) 

459 logger.info("Successfully deserialized thread using framework") 

460 # Thread has full context, no summary needed 

461 return thread, None 

462 else: 

463 # Framework doesn't support deserialization, use fallback 

464 logger.warning("Framework doesn't support deserialization, using fallback") 

465 context_summary = self._generate_context_summary(messages) 

466 thread = ( 

467 agent.chat_client.create_thread() 

468 if hasattr(agent.chat_client, "create_thread") 

469 else None 

470 ) 

471 return thread, context_summary 

472 except Exception as e: 

473 logger.error(f"Deserialization failed: {e}, using fallback") 

474 context_summary = self._generate_context_summary(messages) 

475 thread = ( 

476 agent.chat_client.create_thread() 

477 if hasattr(agent.chat_client, "create_thread") 

478 else None 

479 ) 

480 return thread, context_summary 

481 

482 def list_sessions(self) -> list[dict]: 

483 """List all saved conversation sessions. 

484 

485 Returns: 

486 List of session metadata dicts 

487 

488 Example: 

489 >>> sessions = persistence.list_sessions() 

490 >>> all("name" in s for s in sessions) 

491 True 

492 """ 

493 return list(self.metadata.get("conversations", {}).values()) 

494 

495 def delete_session(self, name: str) -> None: 

496 """Delete a conversation session. 

497 

498 Args: 

499 name: Name of session to delete 

500 

501 Raises: 

502 FileNotFoundError: If session doesn't exist 

503 

504 Example: 

505 >>> persistence.delete_session("session-1") 

506 """ 

507 safe_name = _sanitize_conversation_name(name) 

508 file_path = self.storage_dir / f"{safe_name}.json" 

509 

510 if not file_path.exists(): 

511 raise FileNotFoundError(f"Session '{safe_name}' not found") 

512 

513 # Delete file 

514 file_path.unlink() 

515 

516 # Remove from metadata 

517 if safe_name in self.metadata.get("conversations", {}): 

518 del self.metadata["conversations"][safe_name] 

519 self._save_metadata() 

520 

521 logger.info(f"Deleted session '{safe_name}'") 

522 

523 async def save_memory_state(self, session_name: str, memory_data: list[dict]) -> Path: 

524 """Save memory state for a session. 

525 

526 Args: 

527 session_name: Name of the session 

528 memory_data: List of memory entries to save 

529 

530 Returns: 

531 Path to saved memory file 

532 

533 Raises: 

534 ValueError: If session_name is invalid 

535 

536 Example: 

537 >>> memories = [{"role": "user", "content": "Hello"}] 

538 >>> path = await persistence.save_memory_state("session-1", memories) 

539 """ 

540 from agent.memory.persistence import MemoryPersistence 

541 

542 safe_name = _sanitize_conversation_name(session_name) 

543 memory_persistence = MemoryPersistence(storage_dir=self.memory_dir) 

544 

545 # Get memory file path 

546 memory_path = memory_persistence.get_memory_path(safe_name) 

547 

548 # Save memory state 

549 await memory_persistence.save(memory_data, memory_path) 

550 

551 # Update session metadata to track memory 

552 if safe_name in self.metadata.get("conversations", {}): 

553 self.metadata["conversations"][safe_name]["has_memory"] = True 

554 self.metadata["conversations"][safe_name]["memory_count"] = len(memory_data) 

555 self._save_metadata() 

556 

557 return memory_path 

558 

559 async def load_memory_state(self, session_name: str) -> list[dict] | None: 

560 """Load memory state for a session. 

561 

562 Args: 

563 session_name: Name of the session 

564 

565 Returns: 

566 List of memory entries or None if no memory exists 

567 

568 Raises: 

569 ValueError: If session_name is invalid 

570 

571 Example: 

572 >>> memories = await persistence.load_memory_state("session-1") 

573 """ 

574 from agent.memory.persistence import MemoryPersistence 

575 

576 safe_name = _sanitize_conversation_name(session_name) 

577 memory_persistence = MemoryPersistence(storage_dir=self.memory_dir) 

578 

579 # Get memory file path 

580 memory_path = memory_persistence.get_memory_path(safe_name) 

581 

582 # Load memory state 

583 return await memory_persistence.load(memory_path)