Coverage for src / agent / middleware.py: 78%
276 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-11 14:30 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-11 14:30 +0000
1# Copyright 2025-2026 Microsoft Corporation
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15"""Middleware functions for Agent execution pipeline.
17This module provides middleware functions for logging and event emission
18in the agent request/response pipeline. Middleware allows automatic event
19tracking without modifying individual tools.
21Middleware Types:
22 - Agent-level: Wrap entire agent execution (LLM calls)
23 - Function-level: Wrap individual tool calls
25Event Emission:
26 Middleware emits execution events for display visualization when
27 should_show_visualization() returns True. This enables real-time
28 execution tree display without coupling tools to display logic.
29"""
31import json
32import logging
33import time
34import uuid
35from collections.abc import Awaitable, Callable
37# TYPE_CHECKING import for forward reference
38from typing import TYPE_CHECKING, Any, cast
40from agent_framework import (
41 AgentRunContext,
42 FunctionInvocationContext,
43 FunctionMiddleware,
44)
46from agent.config.manager import load_config
47from agent.config.schema import AgentSettings
49if TYPE_CHECKING:
50 from agent.trace_logger import TraceLogger
52logger = logging.getLogger(__name__)
54# Global trace logger instance (set by session setup)
55_trace_logger: "TraceLogger | None" = None
def set_trace_logger(trace_logger: "TraceLogger | None") -> None:
    """Install *trace_logger* as the process-wide trace logger.

    Subsequent calls to get_trace_logger() (e.g. from the middleware in this
    module) return the instance installed here.

    Args:
        trace_logger: TraceLogger instance to install, or None to disable
            trace logging.
    """
    global _trace_logger
    _trace_logger = trace_logger
def get_trace_logger() -> "TraceLogger | None":
    """Return the currently installed global trace logger.

    Returns:
        The TraceLogger previously set via set_trace_logger(), or None when
        trace logging is disabled.
    """
    return _trace_logger
77def _extract_model_from_config(config: AgentSettings) -> str | None:
78 """Extract model name from config based on provider.
80 Args:
81 config: Agent configuration
83 Returns:
84 Model name or None
85 """
86 provider = config.llm_provider
87 if provider == "openai":
88 return config.openai_model
89 elif provider == "anthropic":
90 return config.anthropic_model
91 elif provider == "azure":
92 return config.azure_openai_deployment
93 elif provider == "gemini":
94 return config.gemini_model
95 elif provider == "github":
96 return config.github_model
97 elif provider == "local":
98 return config.local_model
99 elif provider == "foundry":
100 return config.azure_model_deployment
101 return None
104# ============================================================================
105# Agent-Level Middleware
106# ============================================================================
async def agent_run_logging_middleware(
    context: AgentRunContext,
    next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
    """Log agent execution lifecycle and emit LLM request/response events.

    This middleware:
    - Logs agent execution start/complete
    - Emits LLMRequestEvent before the LLM call
    - Emits LLMResponseEvent after the LLM call with duration
    - Captures trace-level LLM request/response data (if a trace logger is set)
    - Only emits display events if should_show_visualization() is True
    - Classifies provider exceptions into wrapped errors before re-raising

    Args:
        context: Agent run context containing messages and state
        next: Next middleware in chain

    Raises:
        Exception: Re-raises any error from the wrapped call, possibly wrapped
            by classify_provider_error() with the original as ``__cause__``.

    Example:
        >>> middleware = {"agent": [agent_run_logging_middleware]}
        >>> agent = chat_client.create_agent(..., middleware=middleware)
    """
    # Imported here (not at module top) to avoid an import cycle with the
    # display package.
    from agent.display import (
        LLMRequestEvent,
        LLMResponseEvent,
        get_event_emitter,
        should_show_visualization,
    )

    logger.debug("Agent run starting...")

    # Request ID correlates the trace-log request record with its response.
    request_id = str(uuid.uuid4())

    # Load config for trace logging if enabled (reused for request and response)
    # NOTE(review): config stays None when tracing is disabled, yet it is later
    # passed to classify_provider_error() in the error path — confirm the
    # classifier accepts None.
    trace_logger = get_trace_logger()
    config = load_config() if trace_logger else None

    # Emit LLM request event (visualization only)
    llm_event_id = None
    if should_show_visualization():
        message_count = len(context.messages) if hasattr(context, "messages") else 0
        event = LLMRequestEvent(message_count=message_count)
        llm_event_id = event.event_id
        get_event_emitter().emit(event)
        logger.debug(f"Emitted LLM request event with {message_count} messages")

    # Trace logging: log request data with full context. Best-effort — any
    # failure here is logged at debug level and never blocks the agent run.
    if trace_logger:
        try:
            # Extract conversation messages, tolerating several message shapes
            messages = []
            if hasattr(context, "messages"):
                for msg in context.messages:
                    if hasattr(msg, "to_dict"):
                        messages.append(msg.to_dict())
                    elif hasattr(msg, "model_dump"):
                        # Pydantic v2 models
                        messages.append(msg.model_dump())
                    elif hasattr(msg, "dict"):
                        # Pydantic v1 fallback
                        messages.append(msg.dict())
                    else:
                        messages.append({"content": str(msg)})

            # Get model and provider from config (already loaded at middleware start)
            assert config is not None, "Config should be loaded when trace logger is enabled"
            provider = config.llm_provider
            model = _extract_model_from_config(config)

            # Extract FULL payload info if include_messages is enabled
            system_instructions: str | None = None
            tools_summary: dict[str, Any] | None = None

            if trace_logger.include_messages:
                if hasattr(context, "agent"):
                    agent = context.agent

                    # Get system instructions and tools from chat_options
                    if hasattr(agent, "chat_options"):
                        chat_options = agent.chat_options

                        # Get instructions
                        if hasattr(chat_options, "instructions") and chat_options.instructions:
                            system_instructions = chat_options.instructions

                        # Get tools
                        if hasattr(chat_options, "tools") and chat_options.tools:
                            tools = chat_options.tools
                            tools_data = []
                            for tool in tools:
                                tool_dict = (
                                    tool.to_dict()
                                    if hasattr(tool, "to_dict")
                                    else {"name": str(tool)}
                                )
                                tool_json = json.dumps(tool_dict, default=str)
                                tools_data.append(
                                    {
                                        "name": tool_dict.get("name", "unknown"),
                                        "description": (
                                            tool_dict.get("description", "")[:100]
                                            if tool_dict.get("description")
                                            else ""
                                        ),
                                        # Rough heuristic: ~4 characters per token
                                        "estimated_tokens": len(tool_json) // 4,
                                    }
                                )

                            tools_summary = {
                                "count": len(tools),
                                "tools": tools_data,
                                "total_estimated_tokens": sum(
                                    t["estimated_tokens"] for t in tools_data
                                ),
                            }

            # Log request using TraceLogger
            trace_logger.log_request(
                request_id=request_id,
                messages=messages,
                model=model,
                provider=provider,
                system_instructions=system_instructions,
                tools_summary=tools_summary,
            )

        except Exception as e:
            logger.debug(f"Failed to log trace request: {e}")

    start_time = time.time()

    try:
        await next(context)
        duration = time.time() - start_time
        latency_ms = duration * 1000
        logger.debug("Agent run completed successfully")

        # Emit LLM response event, paired with the request event via event_id
        if should_show_visualization() and llm_event_id:
            response_event = LLMResponseEvent(duration=duration, event_id=llm_event_id)
            get_event_emitter().emit(response_event)
            logger.debug(f"Emitted LLM response event ({duration:.2f}s)")

        # Trace logging: log response data with token usage (best-effort)
        if trace_logger:
            try:
                # Extract response content and token usage from result
                response_content = ""
                input_tokens = None
                output_tokens = None
                total_tokens = None

                if hasattr(context, "result") and context.result:
                    result = context.result
                    # Extract content from whichever attribute the result exposes
                    if hasattr(result, "text"):
                        response_content = str(result.text)
                    elif hasattr(result, "content"):
                        response_content = str(result.content)
                    elif hasattr(result, "data"):
                        response_content = str(result.data)
                    else:
                        response_content = str(result)

                    # Extract token usage from result.usage_details
                    if hasattr(result, "usage_details") and result.usage_details:
                        usage = result.usage_details
                        if hasattr(usage, "input_token_count"):
                            input_tokens = usage.input_token_count
                        if hasattr(usage, "output_token_count"):
                            output_tokens = usage.output_token_count
                        if hasattr(usage, "total_token_count"):
                            total_tokens = usage.total_token_count

                # Try to extract token usage from various locations
                # Check thread for usage info (agent-framework stores it in thread messages)
                # NOTE(review): the thread/metadata checks below overwrite values
                # already taken from result.usage_details — confirm later sources
                # are intended to take precedence.
                if hasattr(context, "thread") and context.thread:
                    thread = context.thread
                    if hasattr(thread, "messages") and thread.messages:
                        # Check last message (assistant response) for usage
                        last_msg = thread.messages[-1]

                        # Check for usage in message
                        if hasattr(last_msg, "usage") and last_msg.usage:
                            usage = last_msg.usage
                            if hasattr(usage, "input_token_count"):
                                input_tokens = usage.input_token_count
                            if hasattr(usage, "output_token_count"):
                                output_tokens = usage.output_token_count
                            if hasattr(usage, "total_token_count"):
                                total_tokens = usage.total_token_count

                        # Also check contents for UsageContent
                        if hasattr(last_msg, "contents") and last_msg.contents:
                            for content in last_msg.contents:
                                # UsageContent has usage attribute
                                if hasattr(content, "usage") and content.usage:
                                    usage = content.usage
                                    if hasattr(usage, "input_token_count"):
                                        input_tokens = usage.input_token_count
                                    if hasattr(usage, "output_token_count"):
                                        output_tokens = usage.output_token_count
                                    if hasattr(usage, "total_token_count"):
                                        total_tokens = usage.total_token_count

                # Also check metadata for usage info
                if hasattr(context, "metadata") and context.metadata:
                    metadata = context.metadata
                    if "usage" in metadata:
                        usage = metadata["usage"]
                        if isinstance(usage, dict):
                            input_tokens = usage.get("input_token_count") or usage.get(
                                "input_tokens"
                            )
                            output_tokens = usage.get("output_token_count") or usage.get(
                                "output_tokens"
                            )
                            total_tokens = usage.get("total_token_count") or usage.get(
                                "total_tokens"
                            )

                # Get model from config (loaded once at middleware start)
                model = _extract_model_from_config(config) if config else None

                trace_logger.log_response(
                    request_id=request_id,
                    response_content=response_content,
                    model=model,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                    total_tokens=total_tokens,
                    latency_ms=latency_ms,
                )
            except Exception as e:
                logger.debug(f"Failed to log trace response: {e}")

    except Exception as e:
        logger.error(f"Agent run failed: {e}")

        # Trace logging: record the failure with empty content and the error text
        if trace_logger:
            try:
                duration = time.time() - start_time
                latency_ms = duration * 1000
                trace_logger.log_response(
                    request_id=request_id,
                    response_content="",
                    latency_ms=latency_ms,
                    error=str(e),
                )
            except Exception as log_err:
                logger.debug(f"Failed to log trace error: {log_err}")

        # Classify and wrap provider exceptions for better error messages.
        # Imported lazily to avoid a circular import with the CLI package.
        from agent.cli.error_handler import classify_provider_error

        wrapped = classify_provider_error(e, config)
        if wrapped:
            # Raise wrapped exception with original as cause (preserves stack trace)
            raise wrapped from e
        else:
            # Unknown error, re-raise as-is
            raise
async def agent_observability_middleware(
    context: AgentRunContext,
    next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
    """Measure and log wall-clock duration of each agent run.

    The duration is logged even when the wrapped call raises; any exception
    propagates unchanged. OpenTelemetry instrumentation itself is handled by
    the agent-framework, so this middleware only performs lightweight logging.

    Args:
        context: Agent run context
        next: Next middleware in chain

    Example:
        >>> middleware = {"agent": [agent_observability_middleware]}
        >>> agent = chat_client.create_agent(..., middleware=middleware)
    """
    started_at = time.time()
    try:
        await next(context)
    finally:
        elapsed = time.time() - started_at
        logger.info(f"Agent execution took {elapsed:.2f}s")
398# ============================================================================
399# Function-Level Middleware
400# ============================================================================
async def logging_function_middleware(
    context: FunctionInvocationContext,
    next: Callable,
) -> Any:
    """Middleware to log function/tool calls and emit execution events with OpenTelemetry.

    This middleware:
    - Logs tool execution start/complete/error
    - Emits ToolStartEvent before tool execution
    - Emits ToolCompleteEvent on success with result summary
    - Emits ToolErrorEvent on failure
    - Sets tool context for nested event tracking (restored in ``finally``)
    - Creates OpenTelemetry spans for tool execution (when enabled)
    - Only emits events if should_show_visualization() is True

    Args:
        context: Function invocation context with function metadata and arguments
        next: Next middleware in chain

    Returns:
        Result from tool execution

    Raises:
        Exception: Re-raises any tool error after recording metrics, span
            status, and emitting a ToolErrorEvent.

    Example:
        >>> middleware = {"function": [logging_function_middleware]}
        >>> agent = chat_client.create_agent(..., middleware=middleware)
    """
    # Imported lazily: observability/display modules are only needed when this
    # middleware actually runs.
    from agent_framework.observability import OtelAttr, get_meter, get_tracer

    # NOTE(review): this ot_trace import is shadowed by the identical import
    # inside the try block below and appears otherwise unused here.
    from opentelemetry import trace as ot_trace

    from agent.display import (
        ToolCompleteEvent,
        ToolErrorEvent,
        ToolStartEvent,
        get_current_tool_event_id,
        get_event_emitter,
        set_current_tool_event_id,
        should_show_visualization,
    )
    from agent.observability import get_current_agent_span

    tool_name = context.function.name
    args = context.arguments

    logger.info(f"Tool call: {tool_name}")

    # Emit tool start event (if visualization enabled)
    tool_event_id = None
    parent_id = None
    if should_show_visualization():
        # Get parent event ID for nested tools
        parent_id = get_current_tool_event_id()

        # Convert args to dict for event (using Pydantic v2 model_dump)
        if hasattr(args, "model_dump"):
            args_dict = args.model_dump()
        elif hasattr(args, "dict"):
            # Fallback for Pydantic v1 compatibility
            args_dict = args.dict()
        elif isinstance(args, dict):
            args_dict = args
        else:
            args_dict = {}

        # Sanitize (remove sensitive keys so secrets never reach the display)
        safe_args = {
            k: v
            for k, v in args_dict.items()
            if k not in ["token", "api_key", "password", "secret"]
        }

        event = ToolStartEvent(tool_name=tool_name, arguments=safe_args, parent_id=parent_id)
        tool_event_id = event.event_id
        get_event_emitter().emit(event)

        # Set tool context for child operations (enables nested event display)
        set_current_tool_event_id(tool_event_id)
        logger.debug(f"Set tool context: {tool_name} (event_id: {tool_event_id[:8]}...)")

    # Check if observability is enabled
    config = load_config()
    tracer = get_tracer(__name__) if config.enable_otel else None
    meter = get_meter(__name__) if config.enable_otel else None

    start_time = time.time()

    # Create span context manager (no-op if observability disabled)
    # Prepare parent context for robust nesting
    parent_context = None
    if tracer:
        try:
            from opentelemetry import trace as ot_trace

            current_span = ot_trace.get_current_span()
            # If current span looks invalid, try the saved agent span
            if current_span is None or not getattr(current_span, "is_recording", lambda: False)():
                saved_agent_span = get_current_agent_span()
                if saved_agent_span is not None:
                    parent_context = ot_trace.set_span_in_context(saved_agent_span)
        except Exception:
            # Best-effort: fall back to the default (current) context
            parent_context = None

    span_context = (
        tracer.start_as_current_span(f"tool.{tool_name}", context=parent_context)
        if tracer
        else _noop_context_manager()
    )

    with span_context as span:
        try:
            # Set span attributes if observability enabled
            if span and config.enable_otel:
                # Set GenAI span type
                span.set_attribute("span_type", "GenAI")

                # Set operation name and tool info
                span.set_attribute("gen_ai.operation.name", OtelAttr.TOOL_EXECUTION_OPERATION)
                span.set_attribute(OtelAttr.TOOL_NAME, tool_name)

                # Add tool description if available
                if hasattr(context.function, "description"):
                    span.set_attribute(OtelAttr.TOOL_DESCRIPTION, context.function.description)

                # Add tool parameters only when sensitive data capture is enabled
                if config.enable_sensitive_data and args:
                    # Convert args to dict for serialization
                    args_data: dict[str, Any] | str
                    if hasattr(args, "model_dump"):
                        args_data = args.model_dump()
                    elif hasattr(args, "dict"):
                        args_data = args.dict()
                    elif isinstance(args, dict):
                        args_data = args
                    else:
                        args_data = str(args)

                    span.set_attribute(
                        OtelAttr.TOOL_ARGUMENTS,
                        json.dumps(args_data) if isinstance(args_data, dict) else args_data,
                    )

            result = await next(context)
            duration = time.time() - start_time
            logger.info(f"Tool call {tool_name} completed successfully ({duration:.2f}s)")

            # Record metrics if observability enabled
            if meter and config.enable_otel:
                duration_histogram = meter.create_histogram(
                    name="tool.execution.duration",
                    description="Tool execution duration in seconds",
                    unit="s",
                )
                duration_histogram.record(duration, {"tool": tool_name, "status": "success"})

            # Set tool result only when sensitive data capture is enabled
            if span and config.enable_otel and config.enable_sensitive_data:
                result_str = json.dumps(result) if isinstance(result, (dict, list)) else str(result)
                # Truncate to 1000 chars to avoid excessive data
                span.set_attribute(OtelAttr.TOOL_RESULT, result_str[:1000])

            # Emit tool complete event
            if should_show_visualization() and tool_event_id:
                # Extract summary from result
                summary = _extract_tool_summary(tool_name, result)
                complete_event = ToolCompleteEvent(
                    tool_name=tool_name,
                    result_summary=summary,
                    duration=duration,
                    event_id=tool_event_id,
                )
                get_event_emitter().emit(complete_event)

            return result
        except Exception as e:
            duration = time.time() - start_time
            logger.error(f"Tool call {tool_name} failed: {e}")

            # Record error metrics if observability enabled
            if meter and config.enable_otel:
                duration_histogram = meter.create_histogram(
                    name="tool.execution.duration",
                    description="Tool execution duration in seconds",
                    unit="s",
                )
                duration_histogram.record(duration, {"tool": tool_name, "status": "error"})

            # Capture exception in span
            if span and config.enable_otel:
                span.record_exception(e)
                from opentelemetry.trace import Status, StatusCode

                span.set_status(Status(StatusCode.ERROR, str(e)))

            # Emit tool error event
            if should_show_visualization() and tool_event_id:
                error_event = ToolErrorEvent(
                    tool_name=tool_name,
                    error_message=str(e),
                    duration=duration,
                    event_id=tool_event_id,
                )
                get_event_emitter().emit(error_event)

            raise
        finally:
            # Clear tool context when exiting tool (restore parent)
            if should_show_visualization():
                set_current_tool_event_id(parent_id)
                if parent_id:
                    logger.debug("Restored parent tool context")
                else:
                    logger.debug(f"Cleared tool context: {tool_name}")
616def _noop_context_manager() -> Any:
617 """No-op context manager for when observability is disabled.
619 Returns:
620 Context manager that does nothing
622 Example:
623 >>> with _noop_context_manager() as span:
624 ... # span is None, no telemetry recorded
625 ... pass
626 """
627 from contextlib import nullcontext
629 return nullcontext(None)
632def _extract_tool_summary(tool_name: str, result: Any) -> str:
633 """Extract human-readable summary from tool result.
635 Attempts to extract meaningful summary from common result patterns:
636 - Dict with "message" key
637 - Dict with "summary" key
638 - String result (truncated to 100 chars)
639 - Other types return "Complete"
641 Args:
642 tool_name: Name of the tool
643 result: Tool result
645 Returns:
646 Brief summary string (max 100 chars)
648 Example:
649 >>> _extract_tool_summary("read_file", {"message": "File read successfully"})
650 'File read successfully'
651 >>> _extract_tool_summary("list_directory", "Long result...")
652 'Long result...' # (truncated if > 100 chars)
653 """
654 if isinstance(result, dict):
655 if "message" in result:
656 msg = str(result["message"])
657 return msg[:100] if len(msg) > 100 else msg
658 elif "summary" in result:
659 summary = str(result["summary"])
660 return summary[:100] if len(summary) > 100 else summary
661 elif isinstance(result, str):
662 return result[:100] if len(result) > 100 else result
664 return "Complete"
667# ============================================================================
668# Middleware Factory
669# ============================================================================
def create_middleware() -> list:
    """Build the default middleware stack for agent and function levels.

    The returned list mixes agent-level and function-level middleware; the
    framework categorizes each entry automatically by its signature.

    Returns:
        List of middleware callables.

    Note:
        Memory is handled via ContextProvider, not middleware.
        See MemoryContextProvider for memory management.

    Example:
        >>> from agent.middleware import create_middleware
        >>> middleware = create_middleware()
        >>> agent = chat_client.create_agent(
        ...     name="Agent",
        ...     instructions="...",
        ...     tools=tools,
        ...     middleware=middleware
        ... )
    """
    middleware_stack = [
        agent_run_logging_middleware,
        agent_observability_middleware,
        logging_function_middleware,
    ]
    return middleware_stack
699# Backward compatibility
def create_function_middleware() -> list[FunctionMiddleware]:
    """Create the function-level middleware list (legacy API).

    Retained for backward compatibility only; new code should call
    create_middleware() instead.

    Returns:
        Single-element list containing the logging function middleware.

    Example:
        >>> from agent.middleware import create_function_middleware
        >>> function_mw = create_function_middleware()
    """
    middleware = cast(FunctionMiddleware, logging_function_middleware)
    return [middleware]