Coverage for src/agent/memory/adapter.py: 69%

39 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-15 17:07 +0000

1# Copyright 2025-2026 Microsoft Corporation 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15"""Adapter layer to unify mem0 local (Memory) and cloud (MemoryClient) APIs. 

16 

17This module provides a consistent interface regardless of whether mem0 is 

18running in local mode (Chroma storage) or cloud mode (mem0.ai service). 

19""" 

20 

21from abc import ABC, abstractmethod 

22from typing import Any 

23 

24 

class Mem0Adapter(ABC):
    """Abstract adapter for mem0 operations.

    Provides a unified interface that hides the API differences between
    mem0's local ``Memory`` class and cloud ``MemoryClient`` class.
    Subclasses translate each call into their backend's keyword arguments;
    read operations (``search`` and ``get_all``) always return a plain list
    of result dicts via the shared ``_normalize_results`` helper.
    """

    @abstractmethod
    def add(self, messages: list[dict], user_id: str) -> Any:
        """Add messages to memory storage.

        Args:
            messages: Message dicts, forwarded unchanged to the backend.
            user_id: Identifier the stored memories are scoped to.

        Returns:
            The backend's raw response, passed through unchanged.
        """

    @abstractmethod
    def search(self, query: str, user_id: str, limit: int) -> list[dict]:
        """Search a user's memories and return normalized results.

        Args:
            query: Search query string.
            user_id: Identifier whose memories are searched.
            limit: Result-count cap forwarded to the backend (``limit`` in
                local mode, ``top_k`` in cloud mode).

        Returns:
            Result dicts as a plain list (empty if the response is not
            recognized).
        """

    @abstractmethod
    def get_all(self, user_id: str) -> list[dict]:
        """Get all of a user's memories as a normalized list of dicts."""

    @abstractmethod
    def delete_all(self, user_id: str) -> None:
        """Delete all memories for a user."""

    def _normalize_results(self, results: Any) -> list[dict]:
        """Normalize a backend response to a list of result dicts.

        Both mem0 backends wrap results as ``{"results": [...]}``; a bare
        list is also accepted and returned unchanged. Anything else --
        including an envelope whose ``"results"`` value is not a list
        (e.g. ``None``) -- yields an empty list, so callers can always
        iterate the return value safely.

        Args:
            results: Raw response from a mem0 ``search``/``get_all`` call.

        Returns:
            The list of result dicts, or ``[]`` for unrecognized shapes.
        """
        if isinstance(results, dict):
            # Unwrap the envelope; a missing key becomes None -> [] below.
            results = results.get("results")
        # Guard the declared list[dict] contract: never leak None/other types.
        return results if isinstance(results, list) else []


class LocalMem0Adapter(Mem0Adapter):
    """Adapter for mem0 local mode using Memory class with Chroma storage."""

    def __init__(self, memory: Any):
        """Initialize with a mem0 Memory instance.

        Args:
            memory: mem0.Memory instance configured for local storage
        """
        self._memory = memory

    def add(self, messages: list[dict], user_id: str) -> Any:
        """Add messages using the local Memory API (``user_id`` keyword)."""
        return self._memory.add(messages=messages, user_id=user_id)

    def search(self, query: str, user_id: str, limit: int) -> list[dict]:
        """Search using the local Memory API and normalize results."""
        results = self._memory.search(query=query, user_id=user_id, limit=limit)
        return self._normalize_results(results)

    def get_all(self, user_id: str) -> list[dict]:
        """Get all using the local Memory API and normalize results."""
        results = self._memory.get_all(user_id=user_id)
        return self._normalize_results(results)

    def delete_all(self, user_id: str) -> None:
        """Delete all using the local Memory API."""
        self._memory.delete_all(user_id=user_id)


class CloudMem0Adapter(Mem0Adapter):
    """Adapter for mem0 cloud mode using MemoryClient."""

    def __init__(self, client: Any):
        """Initialize with a mem0 MemoryClient instance.

        Args:
            client: mem0.MemoryClient instance for cloud API
        """
        self._client = client

    def add(self, messages: list[dict], user_id: str) -> Any:
        """Add messages using the cloud MemoryClient API."""
        return self._client.add(messages=messages, user_id=user_id)

    def search(self, query: str, user_id: str, limit: int) -> list[dict]:
        """Search using the cloud MemoryClient API with filters."""
        # Cloud API requires a filters dict instead of a user_id parameter,
        # and names the result cap top_k rather than limit.
        results = self._client.search(
            query=query,
            filters={"user_id": user_id},
            top_k=limit,
        )
        return self._normalize_results(results)

    def get_all(self, user_id: str) -> list[dict]:
        """Get all using the cloud MemoryClient API with filters."""
        results = self._client.get_all(filters={"user_id": user_id})
        return self._normalize_results(results)

    def delete_all(self, user_id: str) -> None:
        """Delete all using the cloud MemoryClient API with filters."""
        self._client.delete_all(filters={"user_id": user_id})