Coverage for src / agent / memory / mem0_store.py: 90%

132 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-11 14:30 +0000

1# Copyright 2025-2026 Microsoft Corporation 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15"""Mem0-based semantic memory storage implementation. 

16 

17This module provides semantic memory storage using mem0's Python library 

18with support for local Chroma storage and cloud mem0.ai service. 

19""" 

20 

21import asyncio 

22import logging 

23import re 

24from datetime import UTC, datetime 

25 

26from agent.config.schema import AgentSettings 

27from agent.memory.manager import MemoryManager 

28from agent.memory.mem0_utils import create_memory_instance 

29 

30logger = logging.getLogger(__name__) 

31 

# Patterns for detecting sensitive data (API keys, tokens, passwords).
# Every match is replaced with "[REDACTED]" before content is stored.
SENSITIVE_PATTERNS = [
    # "sk-"/"sk_" prefixed API keys. The character class includes "-" so that
    # hyphen-segmented keys (e.g. "sk-proj-...") are matched in full, and the
    # lookbehind keeps the wider class from firing inside ordinary words such
    # as "risk-management-...".
    re.compile(r"(?<![a-zA-Z0-9])sk[-_][a-zA-Z0-9_\-]{20,}", re.IGNORECASE),
    re.compile(r"bearer\s+[a-zA-Z0-9\-._~+/]+=*", re.IGNORECASE),  # Bearer tokens
    re.compile(r'api[_-]?key["\s:=]+[a-zA-Z0-9\-._~+/]+', re.IGNORECASE),  # API key assignments
    re.compile(r'token["\s:=]+[a-zA-Z0-9\-._~+/]{20,}', re.IGNORECASE),  # Token assignments
    re.compile(r'password["\s:=]+\S+', re.IGNORECASE),  # Password assignments
]

40 

41 

class Mem0Store(MemoryManager):
    """Semantic memory storage using mem0 Python library.

    Integrates mem0 directly into the agent process, reusing the agent's
    existing LLM configuration. Supports local Chroma storage or cloud mem0.ai.

    All blocking mem0 calls are dispatched through ``asyncio.to_thread`` so
    the event loop is not blocked. Public coroutines do not raise on storage
    failures; they return the structured success/error dicts built by the
    base class helpers.

    Attributes:
        config: Agent configuration with mem0 settings
        memory: mem0.Memory instance
        user_id: User namespace for memory isolation
        namespace: Combined user:project namespace

    Example:
        >>> store = Mem0Store(settings)  # settings: an AgentSettings instance
        >>> await store.add([{"role": "user", "content": "My name is Alice"}])
    """

    def __init__(self, config: AgentSettings):
        """Initialize mem0 store with configuration.

        Args:
            config: Agent configuration with mem0 settings

        Raises:
            ValueError: If mem0 configuration is invalid
            ImportError: If mem0ai or chromadb not installed
        """
        super().__init__(config)

        try:
            self.memory = create_memory_instance(config)
            logger.info("Mem0Store initialized successfully")
        except Exception as e:
            # Log for diagnostics, then let the caller see the original error.
            logger.error(f"Failed to initialize Mem0Store: {e}")
            raise

        # Set up namespacing for user/project isolation
        self.user_id = config.mem0_user_id or "default-user"

        # Build namespace (user:project or just user)
        if config.mem0_project_id:
            self.namespace = f"{self.user_id}:{config.mem0_project_id}"
        else:
            self.namespace = self.user_id

        logger.debug(f"Mem0Store namespace: {self.namespace}")

    def _scrub_sensitive_content(self, content: str) -> tuple[str, bool]:
        """Scrub potential secrets from content before storage.

        Args:
            content: Message content to check

        Returns:
            Tuple of (scrubbed_content, was_modified)
        """
        modified = False

        for pattern in SENSITIVE_PATTERNS:
            # subn replaces and counts in a single pass over the text.
            content, hits = pattern.subn("[REDACTED]", content)
            if hits:
                modified = True

        if modified:
            logger.warning("Detected and redacted potential secrets from content")

        return content, modified

    def _should_save_message(self, msg: dict) -> bool:
        """Check if message should be saved to memory.

        Implements safety gates to prevent storing sensitive data.

        Args:
            msg: Message dict with role, content, and optional metadata

        Returns:
            True if message should be saved, False otherwise
        """
        # A missing, None, or non-dict metadata value is treated as "no flags"
        # so one malformed message cannot raise and abort a whole batch.
        metadata = msg.get("metadata")
        if not isinstance(metadata, dict):
            metadata = {}

        # Check for explicit opt-out
        if metadata.get("save") is False:
            logger.debug(f"Skipping message with save=false: {msg.get('role')}")
            return False

        # Check for explicit opt-in (overrides filters)
        if metadata.get("force_save") is True:
            return True

        # Default: only save user and assistant messages
        role = msg.get("role", "")
        return role in ("user", "assistant")

    @staticmethod
    def _format_results(results, *, include_score: bool = False) -> list[dict]:
        """Convert a mem0 result payload into the standardized memory dicts.

        Args:
            results: Payload returned by a mem0 call. A dict carrying a
                "results" list is expected; a bare list is tolerated too.
            include_score: Whether to copy the similarity "score" field.

        Returns:
            List of memory dicts (id, role, content, timestamp, metadata and,
            optionally, score). Empty when the payload has no usable entries.
        """
        if isinstance(results, dict):
            items = results.get("results") or []
        elif isinstance(results, list):
            items = results
        else:
            items = []

        memories = []
        for item in items:
            if not isinstance(item, dict):
                continue
            memory = {
                "id": item.get("id") or item.get("memory_id"),
                "role": "assistant",  # mem0 doesn't store role separately
                "content": item.get("memory", ""),
                "timestamp": item.get("created_at") or item.get("updated_at", ""),
                "metadata": item.get("metadata", {}),
            }
            if include_score:
                memory["score"] = item.get("score")  # Semantic similarity score
            memories.append(memory)
        return memories

    @staticmethod
    def _parse_timestamp(memory: dict) -> datetime:
        """Return a timezone-aware sort key for a memory's timestamp.

        Naive timestamps are assumed to be UTC. Anything that cannot be
        parsed sorts as the oldest possible time (``datetime.min`` in UTC).
        """
        raw = memory.get("timestamp", "")
        oldest = datetime.min.replace(tzinfo=UTC)

        if isinstance(raw, datetime):
            parsed = raw
        else:
            try:
                # fromisoformat on older Pythons rejects a trailing "Z".
                parsed = datetime.fromisoformat(raw.replace("Z", "+00:00"))
            except (ValueError, AttributeError, TypeError):
                return oldest

        # Ensure timezone-aware for comparison
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=UTC)
        return parsed

    async def add(self, messages: list[dict]) -> dict:
        """Add messages to mem0 storage with semantic indexing.

        Messages are automatically indexed with vector embeddings and
        entity extraction for semantic search. Malformed entries, entries
        rejected by the safety gates, and entries with empty or non-string
        content are skipped; the remaining content is scrubbed of secrets
        before it is handed to mem0.

        Args:
            messages: List of message dicts with role and content

        Returns:
            Structured response dict with success status

        Example:
            >>> await store.add([
            ...     {"role": "user", "content": "My name is Alice"},
            ...     {"role": "assistant", "content": "Nice to meet you, Alice!"}
            ... ])
        """
        if not messages:
            return self._create_error_response(
                error="invalid_input", message="No messages provided"
            )

        try:
            # Filter and prepare messages
            messages_to_add = []
            for msg in messages:
                # Validate message structure. The raw message is deliberately
                # not logged: it has not been scrubbed yet.
                if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
                    logger.warning(
                        "Skipping invalid message: expected dict with 'role' and "
                        f"'content', got {type(msg).__name__}"
                    )
                    continue

                # Apply safety gates
                if not self._should_save_message(msg):
                    continue

                raw_content = msg["content"]
                if not isinstance(raw_content, str):
                    logger.warning(
                        "Skipping message with non-text content of type "
                        f"{type(raw_content).__name__}"
                    )
                    continue

                content = raw_content.strip()
                if not content:
                    continue

                # Scrub sensitive content
                scrubbed_content, _ = self._scrub_sensitive_content(content)

                messages_to_add.append({"role": msg["role"], "content": scrubbed_content})

            if not messages_to_add:
                return self._create_success_response(
                    result=[], message="No messages to add after filtering"
                )

            # Add to mem0 (wrapped to avoid blocking event loop)
            await asyncio.to_thread(
                self.memory.add, messages=messages_to_add, user_id=self.namespace
            )

            logger.debug(f"Added {len(messages_to_add)} messages to mem0")

            return self._create_success_response(
                result=messages_to_add, message=f"Added {len(messages_to_add)} messages to memory"
            )

        except Exception as e:
            logger.error(f"Error adding messages to mem0: {e}", exc_info=True)
            return self._create_error_response(
                error="storage_error", message=f"Failed to add messages: {str(e)}"
            )

    async def search(self, query: str, limit: int = 5) -> dict:
        """Search memories by semantic similarity.

        Uses vector embeddings to find semantically similar memories,
        not just keyword matches.

        Args:
            query: Search query string
            limit: Maximum number of results

        Returns:
            Structured response dict with matching memories

        Example:
            >>> result = await store.search("authentication errors", limit=5)
            >>> # Will also find memories about "login failures"
        """
        if not query or not query.strip():
            return self._create_error_response(
                error="invalid_query", message="Search query cannot be empty"
            )

        try:
            # Search mem0 with semantic similarity
            results = await asyncio.to_thread(
                self.memory.search, query=query, user_id=self.namespace, limit=limit
            )

            # Convert mem0 results to standardized format
            memories = self._format_results(results, include_score=True)

            logger.debug(f"Search for '{query}' returned {len(memories)} results")

            return self._create_success_response(
                result=memories,
                message=f"Found {len(memories)} matching memories for query: {query}",
            )

        except Exception as e:
            logger.error(f"Error searching mem0: {e}", exc_info=True)
            return self._create_error_response(
                error="search_error", message=f"Failed to search memories: {str(e)}"
            )

    async def get_all(self) -> dict:
        """Get all memories from storage for current namespace.

        Returns:
            Structured response dict with all memories
        """
        try:
            # Get all memories for user
            results = await asyncio.to_thread(self.memory.get_all, user_id=self.namespace)

            memories = self._format_results(results)

            return self._create_success_response(result=memories, message="Retrieved all memories")

        except Exception as e:
            logger.error(f"Error retrieving all memories from mem0: {e}", exc_info=True)
            return self._create_error_response(
                error="retrieval_error", message=f"Failed to retrieve memories: {str(e)}"
            )

    async def get_recent(self, limit: int = 10) -> dict:
        """Get recent memories.

        Fetches every memory in the namespace via ``get_all`` and sorts them
        newest-first by timestamp; entries with unparseable timestamps sort
        last.

        Args:
            limit: Number of recent memories to retrieve. Values below zero
                are treated as zero.

        Returns:
            Structured response dict with recent memories. If ``get_all``
            fails, its error response is returned unchanged.
        """
        try:
            # Get all and sort by timestamp
            all_result = await self.get_all()

            if not all_result.get("success"):
                return all_result

            memories = all_result["result"]

            # Sort by timestamp (most recent first) with proper datetime parsing
            sorted_memories = sorted(memories, key=self._parse_timestamp, reverse=True)

            # Take most recent N. Clamp so a negative limit cannot turn the
            # slice into "everything except the last N".
            recent = sorted_memories[: max(limit, 0)]

            return self._create_success_response(
                result=recent, message=f"Retrieved {len(recent)} recent memories"
            )

        except Exception as e:
            logger.error(f"Error retrieving recent memories: {e}", exc_info=True)
            return self._create_error_response(
                error="retrieval_error", message=f"Failed to retrieve recent memories: {str(e)}"
            )

    async def clear(self) -> dict:
        """Clear all memories from storage for current namespace.

        Returns:
            Structured response dict with success status
        """
        try:
            # Delete all memories for user
            await asyncio.to_thread(self.memory.delete_all, user_id=self.namespace)

            logger.info(f"Cleared all memories for namespace: {self.namespace}")

            return self._create_success_response(
                result=None, message=f"Cleared all memories for namespace: {self.namespace}"
            )

        except Exception as e:
            logger.error(f"Error clearing mem0 memories: {e}", exc_info=True)
            return self._create_error_response(
                error="clear_error", message=f"Failed to clear memories: {str(e)}"
            )

    async def retrieve_for_context(self, messages: list[dict], limit: int = 10) -> dict:
        """Retrieve semantically relevant memories for context injection.

        Overrides default implementation to use mem0's semantic search
        for better relevance matching. The query is the text of the most
        recent user message; when that is missing, empty, or not a string,
        the most recent memories are returned instead.

        Args:
            messages: Current conversation messages
            limit: Maximum number of memories to retrieve

        Returns:
            Structured response dict with relevant memories
        """
        # Extract query from latest user message
        query = None
        for msg in reversed(messages or []):
            if isinstance(msg, dict) and msg.get("role") == "user":
                content = msg.get("content", "")
                # Non-text content cannot serve as a search query.
                query = content.strip() if isinstance(content, str) else ""
                break

        # Use semantic search if we have a query
        if query:
            logger.debug(f"Using semantic search for context: {query[:50]}...")
            return await self.search(query, limit=limit)

        # Fall back to recent memories
        logger.debug("No query found, using recent memories for context")
        return await self.get_recent(limit=limit)