Coverage for src/agent/providers/gemini/chat_client.py: 94% (80 statements)
coverage.py v7.13.0, created at 2025-12-11 14:30 +0000

# Copyright 2025-2026 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Google Gemini chat client implementation.

This module provides a custom GeminiChatClient that integrates Google's Gemini
models with the Microsoft Agent Framework by extending BaseChatClient.
"""

import logging
from collections.abc import AsyncIterator
from typing import Any

from agent_framework import (
    BaseChatClient,
    ChatMessage,
    ChatOptions,
    ChatResponse,
    ChatResponseUpdate,
    use_function_invocation,
)
from google import genai

from .types import (
    extract_usage_metadata,
    from_gemini_message,
    to_gemini_message,
    to_gemini_tools,
)

logger = logging.getLogger(__name__)


@use_function_invocation
class GeminiChatClient(BaseChatClient):
    """Chat client for Google Gemini models.

    This client extends BaseChatClient to provide integration with Google's
    Gemini API, supporting both API key authentication (Gemini Developer API)
    and Vertex AI authentication (Google Cloud Platform).

    Args:
        model_id: Gemini model name (e.g., "gemini-2.0-flash-exp", "gemini-2.5-pro")
        api_key: Gemini API key for developer API authentication (optional)
        project_id: GCP project ID for Vertex AI authentication (optional)
        location: GCP location for Vertex AI (e.g., "us-central1") (optional)
        use_vertexai: Whether to use Vertex AI authentication (default: False)

    Example:
        >>> # Using an API key
        >>> client = GeminiChatClient(
        ...     model_id="gemini-2.0-flash-exp",
        ...     api_key="your-api-key"
        ... )

        >>> # Using Vertex AI
        >>> client = GeminiChatClient(
        ...     model_id="gemini-2.5-pro",
        ...     project_id="your-gcp-project",
        ...     location="us-central1",
        ...     use_vertexai=True
        ... )
    """

    # OpenTelemetry provider name for tracing
    OTEL_PROVIDER_NAME = "gemini"

    def __init__(
        self,
        model_id: str,
        api_key: str | None = None,
        project_id: str | None = None,
        location: str | None = None,
        use_vertexai: bool = False,
    ):
        """Initialize GeminiChatClient with authentication credentials.

        Args:
            model_id: Gemini model name
            api_key: API key for the Gemini Developer API
            project_id: GCP project ID for Vertex AI
            location: GCP location for Vertex AI
            use_vertexai: Use Vertex AI authentication

        Raises:
            ValueError: If required credentials are missing
        """
        super().__init__()

        self.model_id = model_id
        self.use_vertexai = use_vertexai

        # Initialize the Google Gen AI client based on the authentication method
        if use_vertexai:
            # Vertex AI authentication (uses Google Cloud default credentials)
            if not project_id or not location:
                raise ValueError("Vertex AI authentication requires both project_id and location")
            self.client = genai.Client(
                vertexai=True,
                project=project_id,
                location=location,
            )
            logger.info(
                f"Initialized Gemini client with Vertex AI (project={project_id}, location={location})"
            )
        else:
            # API key authentication
            if not api_key:
                raise ValueError("API key authentication requires api_key")
            self.client = genai.Client(api_key=api_key)
            logger.info("Initialized Gemini client with API key")

    def _build_call_id_mapping(self, messages: list[ChatMessage]) -> dict[str, str]:
        """Build a mapping from function call_id to function name.

        Args:
            messages: List of chat messages

        Returns:
            Dictionary mapping call_id to function name
        """
        call_id_to_name: dict[str, str] = {}
        for m in messages:
            for c in getattr(m, "contents", []) or []:
                try:
                    # Lazy import to avoid circular imports
                    from agent_framework import FunctionCallContent

                    if isinstance(c, FunctionCallContent):
                        call_id_to_name[c.call_id] = c.name
                except Exception:
                    # Ignore errors during the FunctionCallContent type check;
                    # not all contents are FunctionCallContent
                    pass
        return call_id_to_name
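
    # Illustrative shape of the mapping built above (hypothetical call IDs
    # and tool names, not from this module):
    #   message contents holding FunctionCallContent(call_id="call_42", name="get_weather")
    #   -> {"call_42": "get_weather"}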

    def _prepare_options(
        self, messages: list[ChatMessage], chat_options: ChatOptions | None = None
    ) -> dict[str, Any]:
        """Prepare a Gemini generation config from ChatOptions.

        Args:
            messages: List of chat messages (not currently used)
            chat_options: Optional chat configuration

        Returns:
            Dictionary with Gemini generation configuration
        """
        config: dict[str, Any] = {}

        if chat_options:
            # Map temperature
            if chat_options.temperature is not None:
                config["temperature"] = chat_options.temperature

            # Map max_tokens to Gemini's max_output_tokens
            if chat_options.max_tokens is not None:
                config["max_output_tokens"] = chat_options.max_tokens

            # Map top_p
            if chat_options.top_p is not None:
                config["top_p"] = chat_options.top_p

            # Handle tools/functions (passed via config for google-genai)
            tools = chat_options.tools() if callable(chat_options.tools) else chat_options.tools
            if tools:
                config["tools"] = to_gemini_tools(tools)  # type: ignore[arg-type]

        return config
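
    # Illustrative input/output of _prepare_options (hypothetical values):
    #   ChatOptions(temperature=0.2, max_tokens=1024, top_p=0.9)
    #   -> {"temperature": 0.2, "max_output_tokens": 1024, "top_p": 0.9}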

    def _handle_gemini_error(self, error: Exception) -> Exception:
        """Map Gemini SDK exceptions to agent-framework exceptions.

        Args:
            error: Exception from the Gemini SDK

        Returns:
            Mapped exception for agent-framework
        """
        # For now, pass through the original exception.
        # In the future, we can map specific Gemini exceptions to framework exceptions.
        import traceback

        logger.error(f"Gemini API error: {error}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return error

    async def _inner_get_response(  # type: ignore[override]
        self,
        *,
        messages: list[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> ChatResponse:
        """Get a non-streaming response from the Gemini API.

        This method is required by BaseChatClient and handles non-streaming
        chat completions.

        Args:
            messages: List of chat messages
            chat_options: Chat configuration

        Returns:
            ChatResponse with the model's reply

        Raises:
            Exception: If the API call fails
        """
        try:
            # Build a mapping from function call_id to function name
            call_id_to_name = self._build_call_id_mapping(messages)

            # Convert messages to Gemini format
            gemini_messages = [to_gemini_message(msg, call_id_to_name) for msg in messages]

            # Prepare the generation config (contains tools when provided)
            config = self._prepare_options(messages, chat_options)

            # Build contents for the Gemini API.
            # Note: Gemini expects a list of content objects.
            contents = []
            for msg in gemini_messages:
                contents.append({"role": msg["role"], "parts": msg["parts"]})

            # Call the Gemini API (synchronous SDK call; blocks until the
            # full response is available)
            response = self.client.models.generate_content(
                model=self.model_id,
                contents=contents,
                config=config if config else None,  # type: ignore[arg-type]
            )

            # Convert the response to a ChatMessage
            chat_message = from_gemini_message(response)

            # Extract usage metadata
            usage = extract_usage_metadata(response)

            # Create the ChatResponse with usage_details (the class handles
            # the dict -> UsageDetails conversion)
            return ChatResponse(messages=[chat_message], usage_details=usage or None)  # type: ignore[arg-type]

        except Exception as e:
            raise self._handle_gemini_error(e)

    async def _inner_get_streaming_response(  # type: ignore[override]
        self,
        *,
        messages: list[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> AsyncIterator[ChatResponseUpdate]:
        """Get a streaming response from the Gemini API.

        This method is required by BaseChatClient and handles streaming
        chat completions, yielding response chunks as they arrive.

        Args:
            messages: List of chat messages
            chat_options: Chat configuration

        Yields:
            ChatResponseUpdate objects with response chunks

        Raises:
            Exception: If the API call fails
        """
        try:
            # Build a mapping from function call_id to function name
            call_id_to_name = self._build_call_id_mapping(messages)

            # Convert messages to Gemini format
            gemini_messages = [to_gemini_message(msg, call_id_to_name) for msg in messages]

            # Prepare the generation config (contains tools when provided)
            config = self._prepare_options(messages, chat_options)

            # Build contents for the Gemini API
            contents = []
            for msg in gemini_messages:
                contents.append({"role": msg["role"], "parts": msg["parts"]})

            # Call the Gemini API with streaming (the SDK returns a synchronous
            # iterator, so chunks are consumed with a plain for loop)
            stream = self.client.models.generate_content_stream(
                model=self.model_id,
                contents=contents,
                config=config if config else None,  # type: ignore[arg-type]
            )

            # Yield chunks as they arrive
            for chunk in stream:
                if hasattr(chunk, "text") and chunk.text:
                    yield ChatResponseUpdate(
                        text=chunk.text,
                        role="assistant",
                    )

        except Exception as e:
            raise self._handle_gemini_error(e)
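

# --- Usage sketch (illustrative, hedged) -------------------------------------
# A minimal sketch of driving this client through BaseChatClient's public
# surface. Assumptions not shown in this module: agent_framework exposes
# get_response()/get_streaming_response() wrappers around the _inner_* hooks,
# and ChatMessage accepts role/text keyword arguments; "your-api-key" is a
# placeholder. Adjust to the actual framework API if these differ.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        client = GeminiChatClient(
            model_id="gemini-2.0-flash-exp",
            api_key="your-api-key",  # placeholder credential
        )
        # Non-streaming: one complete ChatResponse
        reply = await client.get_response(
            messages=[ChatMessage(role="user", text="Hello!")]
        )
        print(reply.messages[0].text)
        # Streaming: consume ChatResponseUpdate chunks as they arrive
        async for update in client.get_streaming_response(
            messages=[ChatMessage(role="user", text="Write a haiku about test coverage.")]
        ):
            print(update.text or "", end="", flush=True)

    asyncio.run(_demo())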