Coverage for src/agent/providers/gemini/chat_client.py: 94% (80 statements)
coverage.py v7.13.0, created at 2025-12-11 14:30 +0000

# Copyright 2025-2026 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15"""
16Google Gemini chat client implementation.
18This module provides a custom GeminiChatClient that integrates Google's Gemini
19models with the Microsoft Agent Framework by extending BaseChatClient.
20"""

import logging
from collections.abc import AsyncIterator
from typing import Any

from agent_framework import (
    BaseChatClient,
    ChatMessage,
    ChatOptions,
    ChatResponse,
    ChatResponseUpdate,
    use_function_invocation,
)
from google import genai

from .types import (
    extract_usage_metadata,
    from_gemini_message,
    to_gemini_message,
    to_gemini_tools,
)

logger = logging.getLogger(__name__)


@use_function_invocation
class GeminiChatClient(BaseChatClient):
48 """Chat client for Google Gemini models.
50 This client extends BaseChatClient to provide integration with Google's
51 Gemini API, supporting both API key authentication (Gemini Developer API)
52 and Vertex AI authentication (Google Cloud Platform).
54 Args:
55 model_id: Gemini model name (e.g., "gemini-2.0-flash-exp", "gemini-2.5-pro")
56 api_key: Gemini API key for developer API authentication (optional)
57 project_id: GCP project ID for Vertex AI authentication (optional)
58 location: GCP location for Vertex AI (e.g., "us-central1") (optional)
59 use_vertexai: Whether to use Vertex AI authentication (default: False)
61 Example:
62 >>> # Using API key
63 >>> client = GeminiChatClient(
64 ... model_id="gemini-2.0-flash-exp",
65 ... api_key="your-api-key"
66 ... )
68 >>> # Using Vertex AI
69 >>> client = GeminiChatClient(
70 ... model_id="gemini-2.5-pro",
71 ... project_id="your-gcp-project",
72 ... location="us-central1",
73 ... use_vertexai=True
74 ... )
75 """

    # OpenTelemetry provider name for tracing
    OTEL_PROVIDER_NAME = "gemini"

    def __init__(
        self,
        model_id: str,
        api_key: str | None = None,
        project_id: str | None = None,
        location: str | None = None,
        use_vertexai: bool = False,
    ):
88 """Initialize GeminiChatClient with authentication credentials.
90 Args:
91 model_id: Gemini model name
92 api_key: API key for Gemini Developer API
93 project_id: GCP project ID for Vertex AI
94 location: GCP location for Vertex AI
95 use_vertexai: Use Vertex AI authentication
97 Raises:
98 ValueError: If required credentials are missing
99 """
        super().__init__()

        self.model_id = model_id
        self.use_vertexai = use_vertexai

        # Initialize Google Gen AI client based on authentication method
        if use_vertexai:
            # Vertex AI authentication (uses Google Cloud default credentials)
            if not project_id or not location:
                raise ValueError("Vertex AI authentication requires both project_id and location")
            self.client = genai.Client(
                vertexai=True,
                project=project_id,
                location=location,
            )
            logger.info(
                f"Initialized Gemini client with Vertex AI (project={project_id}, location={location})"
            )
        else:
            # API key authentication
            if not api_key:
                raise ValueError("API key authentication requires api_key")
            self.client = genai.Client(api_key=api_key)
            logger.info("Initialized Gemini client with API key")

    def _build_call_id_mapping(self, messages: list[ChatMessage]) -> dict[str, str]:
        """Build a mapping from function call_id to function name.

        Args:
            messages: List of chat messages

        Returns:
            Dictionary mapping call_id to function name
        """
        call_id_to_name: dict[str, str] = {}
        for m in messages:
            for c in getattr(m, "contents", []) or []:
                try:
                    # Lazy import to avoid a circular import with the framework types
                    from agent_framework import FunctionCallContent

                    if isinstance(c, FunctionCallContent):
                        call_id_to_name[c.call_id] = c.name
                except Exception:
                    # Not all contents are FunctionCallContent; ignore anything
                    # that fails the type check
                    pass
        return call_id_to_name
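
    # Illustration (hypothetical values, not part of the original module): a
    # history containing FunctionCallContent(call_id="call_1", name="get_weather")
    # yields {"call_1": "get_weather"}. This lets to_gemini_message() attach the
    # function name to the matching function-result message, since Gemini
    # function_response parts are keyed by name rather than by call ID.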

    def _prepare_options(
        self, messages: list[ChatMessage], chat_options: ChatOptions | None = None
    ) -> dict[str, Any]:
        """Prepare generation config from ChatOptions.

        Args:
            messages: List of chat messages
            chat_options: Optional chat configuration

        Returns:
            Dictionary with Gemini generation configuration
        """
        config: dict[str, Any] = {}

        if chat_options:
            # Map temperature
            if chat_options.temperature is not None:
                config["temperature"] = chat_options.temperature

            # Map max_tokens
            if chat_options.max_tokens is not None:
                config["max_output_tokens"] = chat_options.max_tokens

            # Map top_p
            if chat_options.top_p is not None:
                config["top_p"] = chat_options.top_p

            # Handle tools/functions (pass via config for google-genai)
            tools = chat_options.tools() if callable(chat_options.tools) else chat_options.tools
            if tools:
                config["tools"] = to_gemini_tools(tools)  # type: ignore[arg-type]

        return config
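
    # Illustration (hypothetical values): with temperature=0.2, max_tokens=1024,
    # and top_p=0.9 set on ChatOptions, the method above returns
    #     {"temperature": 0.2, "max_output_tokens": 1024, "top_p": 0.9}
    # plus a "tools" entry when tools are configured. The google-genai SDK
    # accepts such a plain dict wherever a types.GenerateContentConfig is
    # expected, coercing it internally.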

    def _handle_gemini_error(self, error: Exception) -> Exception:
        """Map Gemini SDK exceptions to agent-framework exceptions.

        Args:
            error: Exception from the Gemini SDK

        Returns:
            Mapped exception for agent-framework
        """
        # For now, pass through the original exception
        # In the future, we can map specific Gemini exceptions to framework exceptions
        import traceback

        logger.error(f"Gemini API error: {error}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return error
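
    # A possible refinement (sketch only, not implemented here): google-genai
    # raises subclasses of google.genai.errors.APIError, so this method could
    # import google.genai.errors and branch on isinstance(error, APIError) to
    # translate client vs. server errors before the pass-through above.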

    async def _inner_get_response(  # type: ignore[override]
        self,
        *,
        messages: list[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> ChatResponse:
        """Get a non-streaming response from the Gemini API.

        This method is required by BaseChatClient and handles synchronous
        chat completions.

        Args:
            messages: List of chat messages
            chat_options: Chat configuration

        Returns:
            ChatResponse with the model's reply

        Raises:
            Exception: If the API call fails
        """
        try:
            # Build mapping from function call_id to function name
            call_id_to_name = self._build_call_id_mapping(messages)

            # Convert messages to Gemini format
            gemini_messages = [to_gemini_message(msg, call_id_to_name) for msg in messages]

            # Prepare generation config (contains tools when provided)
            config = self._prepare_options(messages, chat_options)

            # Build contents for the Gemini API
            # Note: Gemini expects a list of content objects
            contents = []
            for msg in gemini_messages:
                contents.append({"role": msg["role"], "parts": msg["parts"]})

            # Call the Gemini API (synchronous; this blocks the event loop, so
            # a future revision could switch to the async client.aio surface)
            response = self.client.models.generate_content(
                model=self.model_id,
                contents=contents,
                config=config if config else None,  # type: ignore[arg-type]
            )

            # Convert the response to a ChatResponse
            chat_message = from_gemini_message(response)

            # Extract usage metadata
            usage = extract_usage_metadata(response)

            # Create ChatResponse with usage_details (let the class handle dict -> UsageDetails)
            return ChatResponse(messages=[chat_message], usage_details=usage or None)  # type: ignore[arg-type]

        except Exception as e:
            raise self._handle_gemini_error(e)

    async def _inner_get_streaming_response(  # type: ignore[override]
        self,
        *,
        messages: list[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> AsyncIterator[ChatResponseUpdate]:
        """Get a streaming response from the Gemini API.

        This method is required by BaseChatClient and handles streaming
        chat completions, yielding response chunks as they arrive.

        Args:
            messages: List of chat messages
            chat_options: Chat configuration

        Yields:
            ChatResponseUpdate objects with response chunks

        Raises:
            Exception: If the API call fails
        """
        try:
            # Build mapping from function call_id to function name
            call_id_to_name = self._build_call_id_mapping(messages)

            # Convert messages to Gemini format
            gemini_messages = [to_gemini_message(msg, call_id_to_name) for msg in messages]

            # Prepare generation config (contains tools when provided)
            config = self._prepare_options(messages, chat_options)

            # Build contents for the Gemini API
            contents = []
            for msg in gemini_messages:
                contents.append({"role": msg["role"], "parts": msg["parts"]})

            # Call the Gemini API with streaming (the sync iterator blocks the
            # event loop between chunks; client.aio offers an async variant)
            stream = self.client.models.generate_content_stream(
                model=self.model_id,
                contents=contents,
                config=config if config else None,  # type: ignore[arg-type]
            )

            # Yield text chunks as they arrive
            for chunk in stream:
                if hasattr(chunk, "text") and chunk.text:
                    yield ChatResponseUpdate(
                        text=chunk.text,
                        role="assistant",
                    )

        except Exception as e:
            raise self._handle_gemini_error(e)
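

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the covered module).
    # Assumes a valid key in the GEMINI_API_KEY environment variable and that
    # BaseChatClient exposes get_response() as the public entry point wrapping
    # _inner_get_response(); the model name is an example value.
    import asyncio
    import os

    async def _demo() -> None:
        client = GeminiChatClient(
            model_id="gemini-2.0-flash-exp",
            api_key=os.environ["GEMINI_API_KEY"],
        )
        response = await client.get_response(
            messages=[ChatMessage(role="user", text="Say hello in one sentence.")]
        )
        print(response.messages[0].text)

    asyncio.run(_demo())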