Coverage for src / agent / memory / adapter.py: 69%
39 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-15 17:07 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-15 17:07 +0000
1# Copyright 2025-2026 Microsoft Corporation
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15"""Adapter layer to unify mem0 local (Memory) and cloud (MemoryClient) APIs.
17This module provides a consistent interface regardless of whether mem0 is
18running in local mode (Chroma storage) or cloud mode (mem0.ai service).
19"""
21from abc import ABC, abstractmethod
22from typing import Any
class Mem0Adapter(ABC):
    """Common interface over mem0's local ``Memory`` and cloud ``MemoryClient``.

    Concrete subclasses translate each call into the argument style of the
    mem0 object they wrap. Read methods hand back a plain list of memory
    dicts, so callers never need to know which backend is in use.
    """

    @abstractmethod
    def add(self, messages: list[dict], user_id: str) -> Any:
        """Store ``messages`` as memories belonging to ``user_id``.

        Args:
            messages: Conversation messages to add.
            user_id: Owner of the new memories.

        Returns:
            Whatever the underlying mem0 backend returns for the add call.
        """

    @abstractmethod
    def search(self, query: str, user_id: str, limit: int) -> list[dict]:
        """Find memories of ``user_id`` relevant to ``query``.

        Args:
            query: Free-text search query.
            user_id: Owner whose memories are searched.
            limit: Maximum number of results requested.

        Returns:
            Normalized list of memory dicts.
        """

    @abstractmethod
    def get_all(self, user_id: str) -> list[dict]:
        """Fetch every memory stored for ``user_id``.

        Returns:
            Normalized list of memory dicts.
        """

    @abstractmethod
    def delete_all(self, user_id: str) -> None:
        """Remove all memories stored for ``user_id``."""
class LocalMem0Adapter(Mem0Adapter):
    """Adapter for mem0 local mode using the ``Memory`` class with Chroma storage.

    Every call is forwarded to the wrapped ``Memory`` instance with a
    ``user_id=`` keyword. Results of ``search``/``get_all`` are normalized
    to a plain list of memory dicts.
    """

    def __init__(self, memory: Any):
        """Initialize with a mem0 Memory instance.

        Args:
            memory: mem0.Memory instance configured for local storage
        """
        self._memory = memory

    def add(self, messages: list[dict], user_id: str) -> Any:
        """Add messages using the local Memory API.

        Returns:
            Whatever ``Memory.add`` returns, unmodified.
        """
        return self._memory.add(messages=messages, user_id=user_id)

    def search(self, query: str, user_id: str, limit: int) -> list[dict]:
        """Search using the local Memory API and normalize the results."""
        results = self._memory.search(query=query, user_id=user_id, limit=limit)
        return self._normalize_results(results)

    def get_all(self, user_id: str) -> list[dict]:
        """Get all memories using the local Memory API and normalize the results."""
        results = self._memory.get_all(user_id=user_id)
        return self._normalize_results(results)

    def delete_all(self, user_id: str) -> None:
        """Delete all memories for ``user_id`` using the local Memory API."""
        self._memory.delete_all(user_id=user_id)

    def _normalize_results(self, results: Any) -> list[dict]:
        """Normalize a local Memory response to a list of memory dicts.

        Local Memory returns ``{"results": [...]}``. A bare list is passed
        through unchanged. Anything else yields ``[]``, so callers can always
        iterate the return value. That includes a dict without a
        ``"results"`` key, or a ``"results"`` value that is not a list
        (e.g. ``None``).
        """
        # Unwrap the {"results": [...]} envelope; a missing key gives None.
        items = results.get("results") if isinstance(results, dict) else results
        # Guard the list[dict] contract: never leak None or other non-lists.
        return items if isinstance(items, list) else []
class CloudMem0Adapter(Mem0Adapter):
    """Adapter for mem0 cloud mode using ``MemoryClient``.

    ``add`` passes ``user_id=`` directly. The read and delete calls scope the
    request with a ``filters={"user_id": ...}`` dict instead. Results of
    ``search``/``get_all`` are normalized to a plain list of memory dicts.
    """

    def __init__(self, client: Any):
        """Initialize with a mem0 MemoryClient instance.

        Args:
            client: mem0.MemoryClient instance for cloud API
        """
        self._client = client

    def add(self, messages: list[dict], user_id: str) -> Any:
        """Add messages using the cloud MemoryClient API.

        Returns:
            Whatever ``MemoryClient.add`` returns, unmodified.
        """
        return self._client.add(messages=messages, user_id=user_id)

    def search(self, query: str, user_id: str, limit: int) -> list[dict]:
        """Search using the cloud MemoryClient API with filters."""
        # Cloud API requires a filters dict instead of a user_id parameter,
        # and names the result cap ``top_k`` rather than ``limit``.
        results = self._client.search(
            query=query,
            filters={"user_id": user_id},
            top_k=limit,
        )
        return self._normalize_results(results)

    def get_all(self, user_id: str) -> list[dict]:
        """Get all memories using the cloud MemoryClient API with filters."""
        results = self._client.get_all(filters={"user_id": user_id})
        return self._normalize_results(results)

    def delete_all(self, user_id: str) -> None:
        """Delete all memories for ``user_id`` using the cloud MemoryClient API."""
        # NOTE(review): confirm the installed MemoryClient.delete_all accepts a
        # ``filters`` keyword; some mem0 client versions scope deletion with
        # ``user_id=`` directly. Verify this removes only this user's memories.
        self._client.delete_all(filters={"user_id": user_id})

    def _normalize_results(self, results: Any) -> list[dict]:
        """Normalize a cloud MemoryClient response to a list of memory dicts.

        Cloud responses are handled the same way as local ones:
        ``{"results": [...]}`` is unwrapped and a bare list is passed through.
        Anything else yields ``[]``, so callers can always iterate the return
        value. That includes a dict without a ``"results"`` key, or a
        ``"results"`` value that is not a list (e.g. ``None``).
        """
        # Unwrap the {"results": [...]} envelope; a missing key gives None.
        items = results.get("results") if isinstance(results, dict) else results
        # Guard the list[dict] contract: never leak None or other non-lists.
        return items if isinstance(items, list) else []