Coverage for src/agent/memory/mem0_store.py: 90%
132 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-11 14:30 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-11 14:30 +0000
1# Copyright 2025-2026 Microsoft Corporation
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15"""Mem0-based semantic memory storage implementation.
17This module provides semantic memory storage using mem0's Python library
18with support for local Chroma storage and cloud mem0.ai service.
19"""
21import asyncio
22import logging
23import re
24from datetime import UTC, datetime
26from agent.config.schema import AgentSettings
27from agent.memory.manager import MemoryManager
28from agent.memory.mem0_utils import create_memory_instance
logger = logging.getLogger(__name__)

# Regexes used to detect (and redact) sensitive data before it is persisted:
# API keys, bearer tokens, and key/token/password assignments. All patterns
# are case-insensitive; matches are replaced with "[REDACTED]" by the store.
SENSITIVE_PATTERNS = [
    re.compile(r"sk[-_][a-zA-Z0-9_]{20,}", re.IGNORECASE),  # API keys (sk_...)
    re.compile(r"bearer\s+[a-zA-Z0-9\-._~+/]+=*", re.IGNORECASE),  # Bearer tokens
    re.compile(r'api[_-]?key["\s:=]+[a-zA-Z0-9\-._~+/]+', re.IGNORECASE),  # API key assignments
    re.compile(r'token["\s:=]+[a-zA-Z0-9\-._~+/]{20,}', re.IGNORECASE),  # Token assignments
    re.compile(r'password["\s:=]+\S+', re.IGNORECASE),  # Password assignments
]
class Mem0Store(MemoryManager):
    """Semantic memory storage using the mem0 Python library.

    Integrates mem0 directly into the agent process, reusing the agent's
    existing LLM configuration. Supports local Chroma storage or cloud mem0.ai.

    Attributes:
        config: Agent configuration with mem0 settings
        memory: mem0.Memory instance
        user_id: User namespace for memory isolation
        namespace: Combined user:project namespace

    Example:
        >>> config = AgentConfig.from_env()
        >>> store = Mem0Store(config)
        >>> await store.add([{"role": "user", "content": "My name is Alice"}])
    """

    def __init__(self, config: AgentSettings):
        """Initialize mem0 store with configuration.

        Args:
            config: Agent configuration with mem0 settings

        Raises:
            ValueError: If mem0 configuration is invalid
            ImportError: If mem0ai or chromadb not installed
        """
        super().__init__(config)

        try:
            self.memory = create_memory_instance(config)
            logger.info("Mem0Store initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize Mem0Store: %s", e)
            raise

        # Set up namespacing for user/project isolation
        self.user_id = config.mem0_user_id or "default-user"

        # Build namespace (user:project or just user)
        if config.mem0_project_id:
            self.namespace = f"{self.user_id}:{config.mem0_project_id}"
        else:
            self.namespace = self.user_id

        logger.debug("Mem0Store namespace: %s", self.namespace)

    def _scrub_sensitive_content(self, content: str) -> tuple[str, bool]:
        """Scrub potential secrets from content before storage.

        Args:
            content: Message content to check

        Returns:
            Tuple of (scrubbed_content, was_modified)
        """
        modified = False

        for pattern in SENSITIVE_PATTERNS:
            if pattern.search(content):
                content = pattern.sub("[REDACTED]", content)
                modified = True

        if modified:
            logger.warning("Detected and redacted potential secrets from content")

        return content, modified

    def _should_save_message(self, msg: dict) -> bool:
        """Check if message should be saved to memory.

        Implements safety gates to prevent storing sensitive data.

        Args:
            msg: Message dict with role, content, and optional metadata

        Returns:
            True if message should be saved, False otherwise
        """
        # Check for explicit opt-out. `or {}` guards against an explicit
        # metadata=None, which would otherwise crash the .get() calls below.
        metadata = msg.get("metadata") or {}
        if metadata.get("save") is False:
            logger.debug("Skipping message with save=false: %s", msg.get("role"))
            return False

        # Check for explicit opt-in (overrides filters)
        if metadata.get("force_save") is True:
            return True

        # Default: only save user and assistant messages
        return msg.get("role", "") in ("user", "assistant")

    @staticmethod
    def _format_memory(raw: dict, *, include_score: bool = False) -> dict:
        """Convert a raw mem0 record into the standardized memory dict.

        Shared by search() and get_all() so both return the same shape.

        Args:
            raw: Raw result dict as returned by mem0
            include_score: Also copy the semantic similarity score

        Returns:
            Standardized memory dict (id, role, content, timestamp, metadata)
        """
        memory = {
            "id": raw.get("id") or raw.get("memory_id"),
            "role": "assistant",  # mem0 doesn't store role separately
            "content": raw.get("memory", ""),
            "timestamp": raw.get("created_at") or raw.get("updated_at", ""),
            "metadata": raw.get("metadata", {}),
        }
        if include_score:
            memory["score"] = raw.get("score")  # Semantic similarity score
        return memory

    async def add(self, messages: list[dict]) -> dict:
        """Add messages to mem0 storage with semantic indexing.

        Messages are automatically indexed with vector embeddings and
        entity extraction for semantic search.

        Args:
            messages: List of message dicts with role and content

        Returns:
            Structured response dict with success status

        Example:
            >>> await store.add([
            ...     {"role": "user", "content": "My name is Alice"},
            ...     {"role": "assistant", "content": "Nice to meet you, Alice!"}
            ... ])
        """
        if not messages:
            return self._create_error_response(
                error="invalid_input", message="No messages provided"
            )

        try:
            # Filter and prepare messages
            messages_to_add = []
            for msg in messages:
                # Validate message structure; content must be a string so the
                # strip/scrub steps below cannot raise AttributeError.
                if (
                    not isinstance(msg, dict)
                    or "role" not in msg
                    or not isinstance(msg.get("content"), str)
                ):
                    logger.warning(f"Skipping invalid message: {msg}")
                    continue

                # Apply safety gates
                if not self._should_save_message(msg):
                    continue

                content = msg["content"].strip()
                if not content:
                    continue

                # Scrub sensitive content; the was-modified flag is unused here
                # because _scrub_sensitive_content already logs a warning.
                scrubbed_content, _ = self._scrub_sensitive_content(content)

                messages_to_add.append({"role": msg["role"], "content": scrubbed_content})

            if not messages_to_add:
                return self._create_success_response(
                    result=[], message="No messages to add after filtering"
                )

            # Add to mem0 (wrapped to avoid blocking event loop)
            await asyncio.to_thread(
                self.memory.add, messages=messages_to_add, user_id=self.namespace
            )

            logger.debug("Added %d messages to mem0", len(messages_to_add))

            return self._create_success_response(
                result=messages_to_add, message=f"Added {len(messages_to_add)} messages to memory"
            )

        except Exception as e:
            logger.error("Error adding messages to mem0: %s", e, exc_info=True)
            return self._create_error_response(
                error="storage_error", message=f"Failed to add messages: {str(e)}"
            )

    async def search(self, query: str, limit: int = 5) -> dict:
        """Search memories by semantic similarity.

        Uses vector embeddings to find semantically similar memories,
        not just keyword matches.

        Args:
            query: Search query string
            limit: Maximum number of results

        Returns:
            Structured response dict with matching memories

        Example:
            >>> result = await store.search("authentication errors", limit=5)
            >>> # Will also find memories about "login failures"
        """
        if not query or not query.strip():
            return self._create_error_response(
                error="invalid_query", message="Search query cannot be empty"
            )

        try:
            # Search mem0 with semantic similarity (off the event loop)
            results = await asyncio.to_thread(
                self.memory.search, query=query, user_id=self.namespace, limit=limit
            )

            # Convert mem0 results to standardized format
            memories = []
            if results and "results" in results:
                memories = [
                    self._format_memory(result, include_score=True)
                    for result in results["results"]
                ]

            logger.debug("Search for '%s' returned %d results", query, len(memories))

            return self._create_success_response(
                result=memories,
                message=f"Found {len(memories)} matching memories for query: {query}",
            )

        except Exception as e:
            logger.error("Error searching mem0: %s", e, exc_info=True)
            return self._create_error_response(
                error="search_error", message=f"Failed to search memories: {str(e)}"
            )

    async def get_all(self) -> dict:
        """Get all memories from storage for current namespace.

        Returns:
            Structured response dict with all memories
        """
        try:
            # Get all memories for user (off the event loop)
            results = await asyncio.to_thread(self.memory.get_all, user_id=self.namespace)

            memories = []
            if results and "results" in results:
                memories = [self._format_memory(result) for result in results["results"]]

            return self._create_success_response(result=memories, message="Retrieved all memories")

        except Exception as e:
            logger.error("Error retrieving all memories from mem0: %s", e, exc_info=True)
            return self._create_error_response(
                error="retrieval_error", message=f"Failed to retrieve memories: {str(e)}"
            )

    async def get_recent(self, limit: int = 10) -> dict:
        """Get recent memories.

        Args:
            limit: Number of recent memories to retrieve

        Returns:
            Structured response dict with recent memories
        """
        try:
            # Get all and sort by timestamp
            all_result = await self.get_all()

            if not all_result.get("success"):
                return all_result

            memories = all_result["result"]

            # Sort by timestamp (most recent first) with proper datetime parsing
            def parse_timestamp(memory: dict) -> datetime:
                """Parse timestamp with fallback to epoch."""
                timestamp_str = memory.get("timestamp", "")
                try:
                    # Try parsing ISO format; normalize trailing Z to an offset
                    dt = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00"))
                    # Ensure timezone-aware for comparison
                    if dt.tzinfo is None:
                        dt = dt.replace(tzinfo=UTC)
                    return dt
                except (ValueError, AttributeError):
                    # Fallback to epoch (oldest possible time) in UTC
                    return datetime.min.replace(tzinfo=UTC)

            sorted_memories = sorted(memories, key=parse_timestamp, reverse=True)

            # Take most recent N
            recent = sorted_memories[:limit]

            return self._create_success_response(
                result=recent, message=f"Retrieved {len(recent)} recent memories"
            )

        except Exception as e:
            logger.error("Error retrieving recent memories: %s", e, exc_info=True)
            return self._create_error_response(
                error="retrieval_error", message=f"Failed to retrieve recent memories: {str(e)}"
            )

    async def clear(self) -> dict:
        """Clear all memories from storage for current namespace.

        Returns:
            Structured response dict with success status
        """
        try:
            # Delete all memories for user (off the event loop)
            await asyncio.to_thread(self.memory.delete_all, user_id=self.namespace)

            logger.info("Cleared all memories for namespace: %s", self.namespace)

            return self._create_success_response(
                result=None, message=f"Cleared all memories for namespace: {self.namespace}"
            )

        except Exception as e:
            logger.error("Error clearing mem0 memories: %s", e, exc_info=True)
            return self._create_error_response(
                error="clear_error", message=f"Failed to clear memories: {str(e)}"
            )

    async def retrieve_for_context(self, messages: list[dict], limit: int = 10) -> dict:
        """Retrieve semantically relevant memories for context injection.

        Overrides default implementation to use mem0's semantic search
        for better relevance matching.

        Args:
            messages: Current conversation messages
            limit: Maximum number of memories to retrieve

        Returns:
            Structured response dict with relevant memories
        """
        # Extract query from latest user message; `or ""` guards against an
        # explicit content=None, which would crash .strip()
        query = None
        for msg in reversed(messages):
            if isinstance(msg, dict) and msg.get("role") == "user":
                query = (msg.get("content") or "").strip()
                break

        # Use semantic search if we have a query
        if query:
            logger.debug("Using semantic search for context: %s...", query[:50])
            return await self.search(query, limit=limit)
        else:
            # Fall back to recent memories
            logger.debug("No query found, using recent memories for context")
            return await self.get_recent(limit=limit)