Coverage for src/agent/middleware.py: 78%

276 statements  

coverage.py v7.13.0, created at 2025-12-11 14:30 +0000

# Copyright 2025-2026 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Middleware functions for the Agent execution pipeline.

This module provides middleware functions for logging and event emission
in the agent request/response pipeline. Middleware allows automatic event
tracking without modifying individual tools.

Middleware Types:
    - Agent-level: Wrap entire agent execution (LLM calls)
    - Function-level: Wrap individual tool calls

Event Emission:
    Middleware emits execution events for display visualization when
    should_show_visualization() returns True. This enables real-time
    execution tree display without coupling tools to display logic.
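
Usage:
    Typical wiring (an illustrative sketch; it mirrors the examples in
    the function docstrings below, and chat_client and tools come from
    your own setup):

    >>> from agent.middleware import create_middleware
    >>> middleware = create_middleware()
    >>> agent = chat_client.create_agent(
    ...     name="Agent",
    ...     instructions="...",
    ...     tools=tools,
    ...     middleware=middleware
    ... )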

"""

import json
import logging
import time
import uuid
from collections.abc import Awaitable, Callable

# TYPE_CHECKING import for forward reference
from typing import TYPE_CHECKING, Any, cast

from agent_framework import (
    AgentRunContext,
    FunctionInvocationContext,
    FunctionMiddleware,
)

from agent.config.manager import load_config
from agent.config.schema import AgentSettings

if TYPE_CHECKING:
    from agent.trace_logger import TraceLogger

logger = logging.getLogger(__name__)

# Global trace logger instance (set by session setup)
_trace_logger: "TraceLogger | None" = None


def set_trace_logger(trace_logger: "TraceLogger | None") -> None:
    """Set the global trace logger instance.

    Args:
        trace_logger: TraceLogger instance or None
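
    Example (illustrative; my_trace_logger stands in for a TraceLogger
    created during session setup):
        >>> set_trace_logger(my_trace_logger)  # enable trace capture
        >>> set_trace_logger(None)             # disable trace capture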

    """
    global _trace_logger
    _trace_logger = trace_logger


def get_trace_logger() -> "TraceLogger | None":
    """Get the current trace logger instance.

    Returns:
        TraceLogger instance or None
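
    Example (illustrative; include_messages is the same flag the
    middleware below checks before capturing full payloads):
        >>> trace_logger = get_trace_logger()
        >>> if trace_logger is not None and trace_logger.include_messages:
        ...     logger.debug("Full payload tracing is enabled")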

    """
    return _trace_logger


def _extract_model_from_config(config: AgentSettings) -> str | None:
    """Extract model name from config based on provider.

    Args:
        config: Agent configuration

    Returns:
        Model name or None
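
    Example (illustrative; the returned field depends on the active
    provider, e.g. config.openai_model when llm_provider is "openai"):
        >>> config = load_config()
        >>> model = _extract_model_from_config(config)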

    """
    provider = config.llm_provider
    if provider == "openai":
        return config.openai_model
    elif provider == "anthropic":
        return config.anthropic_model
    elif provider == "azure":
        return config.azure_openai_deployment
    elif provider == "gemini":
        return config.gemini_model
    elif provider == "github":
        return config.github_model
    elif provider == "local":
        return config.local_model
    elif provider == "foundry":
        return config.azure_model_deployment
    return None


# ============================================================================
# Agent-Level Middleware
# ============================================================================


async def agent_run_logging_middleware(
    context: AgentRunContext,
    next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
    """Log agent execution lifecycle and emit LLM request/response events.

    This middleware:
    - Logs agent execution start/complete
    - Emits LLMRequestEvent before the LLM call
    - Emits LLMResponseEvent after the LLM call with duration
    - Captures trace-level LLM request/response data (if enabled)
    - Only emits events if should_show_visualization() is True

    Args:
        context: Agent run context containing messages and state
        next: Next middleware in chain

    Example:
        >>> middleware = {"agent": [agent_run_logging_middleware]}
        >>> agent = chat_client.create_agent(..., middleware=middleware)
    """
    from agent.display import (
        LLMRequestEvent,
        LLMResponseEvent,
        get_event_emitter,
        should_show_visualization,
    )

    logger.debug("Agent run starting...")

    # Generate request ID for trace logging
    request_id = str(uuid.uuid4())

    # Load config for trace logging if enabled (reused for request and response)
    trace_logger = get_trace_logger()
    config = load_config() if trace_logger else None

    # Emit LLM request event
    llm_event_id = None
    if should_show_visualization():
        message_count = len(context.messages) if hasattr(context, "messages") else 0
        event = LLMRequestEvent(message_count=message_count)
        llm_event_id = event.event_id
        get_event_emitter().emit(event)
        logger.debug(f"Emitted LLM request event with {message_count} messages")

    # Trace logging: Log request data with full context
    if trace_logger:
        try:
            # Extract conversation messages
            messages = []
            if hasattr(context, "messages"):
                for msg in context.messages:
                    if hasattr(msg, "to_dict"):
                        messages.append(msg.to_dict())
                    elif hasattr(msg, "model_dump"):
                        messages.append(msg.model_dump())
                    elif hasattr(msg, "dict"):
                        messages.append(msg.dict())
                    else:
                        messages.append({"content": str(msg)})

            # Get model and provider from config (already loaded at middleware start)
            assert config is not None, "Config should be loaded when trace logger is enabled"
            provider = config.llm_provider
            model = _extract_model_from_config(config)

            # Extract FULL payload info if include_messages is enabled
            system_instructions: str | None = None
            tools_summary: dict[str, Any] | None = None

            if trace_logger.include_messages:
                if hasattr(context, "agent"):
                    agent = context.agent

                    # Get system instructions and tools from chat_options
                    if hasattr(agent, "chat_options"):
                        chat_options = agent.chat_options

                        # Get instructions
                        if hasattr(chat_options, "instructions") and chat_options.instructions:
                            system_instructions = chat_options.instructions

                        # Get tools
                        if hasattr(chat_options, "tools") and chat_options.tools:
                            tools = chat_options.tools
                            tools_data = []
                            for tool in tools:
                                tool_dict = (
                                    tool.to_dict()
                                    if hasattr(tool, "to_dict")
                                    else {"name": str(tool)}
                                )
                                tool_json = json.dumps(tool_dict, default=str)
                                tools_data.append(
                                    {
                                        "name": tool_dict.get("name", "unknown"),
                                        "description": (
                                            tool_dict.get("description", "")[:100]
                                            if tool_dict.get("description")
                                            else ""
                                        ),
                                        "estimated_tokens": len(tool_json) // 4,
                                    }
                                )

                            tools_summary = {
                                "count": len(tools),
                                "tools": tools_data,
                                "total_estimated_tokens": sum(
                                    t["estimated_tokens"] for t in tools_data
                                ),
                            }

            # Log request using TraceLogger
            trace_logger.log_request(
                request_id=request_id,
                messages=messages,
                model=model,
                provider=provider,
                system_instructions=system_instructions,
                tools_summary=tools_summary,
            )

        except Exception as e:
            logger.debug(f"Failed to log trace request: {e}")

    start_time = time.time()

    try:
        await next(context)
        duration = time.time() - start_time
        latency_ms = duration * 1000
        logger.debug("Agent run completed successfully")

        # Emit LLM response event
        if should_show_visualization() and llm_event_id:
            response_event = LLMResponseEvent(duration=duration, event_id=llm_event_id)
            get_event_emitter().emit(response_event)
            logger.debug(f"Emitted LLM response event ({duration:.2f}s)")

        # Trace logging: Log response data with token usage
        if trace_logger:
            try:
                # Extract response content and token usage from result
                response_content = ""
                input_tokens = None
                output_tokens = None
                total_tokens = None

                if hasattr(context, "result") and context.result:
                    result = context.result
                    # Extract content
                    if hasattr(result, "text"):
                        response_content = str(result.text)
                    elif hasattr(result, "content"):
                        response_content = str(result.content)
                    elif hasattr(result, "data"):
                        response_content = str(result.data)
                    else:
                        response_content = str(result)

                    # Extract token usage from result.usage_details
                    if hasattr(result, "usage_details") and result.usage_details:
                        usage = result.usage_details
                        if hasattr(usage, "input_token_count"):
                            input_tokens = usage.input_token_count
                        if hasattr(usage, "output_token_count"):
                            output_tokens = usage.output_token_count
                        if hasattr(usage, "total_token_count"):
                            total_tokens = usage.total_token_count

                # Try to extract token usage from various locations
                # Check thread for usage info (agent-framework stores it in thread messages)
                if hasattr(context, "thread") and context.thread:
                    thread = context.thread
                    if hasattr(thread, "messages") and thread.messages:
                        # Check last message (assistant response) for usage
                        last_msg = thread.messages[-1]

                        # Check for usage in message
                        if hasattr(last_msg, "usage") and last_msg.usage:
                            usage = last_msg.usage
                            if hasattr(usage, "input_token_count"):
                                input_tokens = usage.input_token_count
                            if hasattr(usage, "output_token_count"):
                                output_tokens = usage.output_token_count
                            if hasattr(usage, "total_token_count"):
                                total_tokens = usage.total_token_count

                        # Also check contents for UsageContent
                        if hasattr(last_msg, "contents") and last_msg.contents:
                            for content in last_msg.contents:
                                # UsageContent has a usage attribute
                                if hasattr(content, "usage") and content.usage:
                                    usage = content.usage
                                    if hasattr(usage, "input_token_count"):
                                        input_tokens = usage.input_token_count
                                    if hasattr(usage, "output_token_count"):
                                        output_tokens = usage.output_token_count
                                    if hasattr(usage, "total_token_count"):
                                        total_tokens = usage.total_token_count

                # Also check metadata for usage info
                if hasattr(context, "metadata") and context.metadata:
                    metadata = context.metadata
                    if "usage" in metadata:
                        usage = metadata["usage"]
                        if isinstance(usage, dict):
                            input_tokens = usage.get("input_token_count") or usage.get(
                                "input_tokens"
                            )
                            output_tokens = usage.get("output_token_count") or usage.get(
                                "output_tokens"
                            )
                            total_tokens = usage.get("total_token_count") or usage.get(
                                "total_tokens"
                            )

                # Get model from config (loaded once at middleware start)
                model = _extract_model_from_config(config) if config else None

                trace_logger.log_response(
                    request_id=request_id,
                    response_content=response_content,
                    model=model,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                    total_tokens=total_tokens,
                    latency_ms=latency_ms,
                )
            except Exception as e:
                logger.debug(f"Failed to log trace response: {e}")

    except Exception as e:
        logger.error(f"Agent run failed: {e}")

        # Trace logging: Log error
        if trace_logger:
            try:
                duration = time.time() - start_time
                latency_ms = duration * 1000
                trace_logger.log_response(
                    request_id=request_id,
                    response_content="",
                    latency_ms=latency_ms,
                    error=str(e),
                )
            except Exception as log_err:
                logger.debug(f"Failed to log trace error: {log_err}")

        # Classify and wrap provider exceptions for better error messages
        from agent.cli.error_handler import classify_provider_error

        wrapped = classify_provider_error(e, config)
        if wrapped:
            # Raise wrapped exception with original as cause (preserves stack trace)
            raise wrapped from e
        else:
            # Unknown error, re-raise as-is
            raise


async def agent_observability_middleware(
    context: AgentRunContext,
    next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
    """Track agent execution duration with logging.

    This middleware logs execution duration for all agent runs.
    The agent-framework handles OpenTelemetry instrumentation automatically.

    Args:
        context: Agent run context
        next: Next middleware in chain

    Example:
        >>> middleware = {"agent": [agent_observability_middleware]}
        >>> agent = chat_client.create_agent(..., middleware=middleware)
    """
    start_time = time.time()

    try:
        await next(context)
    finally:
        duration = time.time() - start_time
        logger.info(f"Agent execution took {duration:.2f}s")


# ============================================================================
# Function-Level Middleware
# ============================================================================


async def logging_function_middleware(
    context: FunctionInvocationContext,
    next: Callable,
) -> Any:
    """Log function/tool calls and emit execution events with OpenTelemetry.

    This middleware:
    - Logs tool execution start/complete/error
    - Emits ToolStartEvent before tool execution
    - Emits ToolCompleteEvent on success with a result summary
    - Emits ToolErrorEvent on failure
    - Sets tool context for nested event tracking
    - Creates OpenTelemetry spans for tool execution (when enabled)
    - Only emits events if should_show_visualization() is True

    Args:
        context: Function invocation context with function metadata and arguments
        next: Next middleware in chain

    Returns:
        Result from tool execution

    Example:
        >>> middleware = {"function": [logging_function_middleware]}
        >>> agent = chat_client.create_agent(..., middleware=middleware)
    """
    from agent_framework.observability import OtelAttr, get_meter, get_tracer
    from opentelemetry import trace as ot_trace

    from agent.display import (
        ToolCompleteEvent,
        ToolErrorEvent,
        ToolStartEvent,
        get_current_tool_event_id,
        get_event_emitter,
        set_current_tool_event_id,
        should_show_visualization,
    )
    from agent.observability import get_current_agent_span

    tool_name = context.function.name
    args = context.arguments

    logger.info(f"Tool call: {tool_name}")

    # Emit tool start event (if visualization enabled)
    tool_event_id = None
    parent_id = None
    if should_show_visualization():
        # Get parent event ID for nested tools
        parent_id = get_current_tool_event_id()

        # Convert args to dict for event (using Pydantic v2 model_dump)
        if hasattr(args, "model_dump"):
            args_dict = args.model_dump()
        elif hasattr(args, "dict"):
            # Fallback for Pydantic v1 compatibility
            args_dict = args.dict()
        elif isinstance(args, dict):
            args_dict = args
        else:
            args_dict = {}

        # Sanitize (remove sensitive keys)
        safe_args = {
            k: v
            for k, v in args_dict.items()
            if k not in ["token", "api_key", "password", "secret"]
        }

        event = ToolStartEvent(tool_name=tool_name, arguments=safe_args, parent_id=parent_id)
        tool_event_id = event.event_id
        get_event_emitter().emit(event)

        # Set tool context for child operations (enables nested event display)
        set_current_tool_event_id(tool_event_id)
        logger.debug(f"Set tool context: {tool_name} (event_id: {tool_event_id[:8]}...)")

    # Check if observability is enabled
    config = load_config()
    tracer = get_tracer(__name__) if config.enable_otel else None
    meter = get_meter(__name__) if config.enable_otel else None

    start_time = time.time()

    # Create span context manager (no-op if observability disabled)
    # Prepare parent context for robust nesting
    parent_context = None
    if tracer:
        try:
            current_span = ot_trace.get_current_span()
            # If the current span looks invalid, try the saved agent span
            if current_span is None or not getattr(current_span, "is_recording", lambda: False)():
                saved_agent_span = get_current_agent_span()
                if saved_agent_span is not None:
                    parent_context = ot_trace.set_span_in_context(saved_agent_span)
        except Exception:
            parent_context = None

    span_context = (
        tracer.start_as_current_span(f"tool.{tool_name}", context=parent_context)
        if tracer
        else _noop_context_manager()
    )

    with span_context as span:
        try:
            # Set span attributes if observability enabled
            if span and config.enable_otel:
                # Set GenAI span type
                span.set_attribute("span_type", "GenAI")

                # Set operation name and tool info
                span.set_attribute("gen_ai.operation.name", OtelAttr.TOOL_EXECUTION_OPERATION)
                span.set_attribute(OtelAttr.TOOL_NAME, tool_name)

                # Add tool description if available
                if hasattr(context.function, "description"):
                    span.set_attribute(OtelAttr.TOOL_DESCRIPTION, context.function.description)

                # Add tool parameters if sensitive data enabled
                if config.enable_sensitive_data and args:
                    # Convert args to dict for serialization
                    args_data: dict[str, Any] | str
                    if hasattr(args, "model_dump"):
                        args_data = args.model_dump()
                    elif hasattr(args, "dict"):
                        args_data = args.dict()
                    elif isinstance(args, dict):
                        args_data = args
                    else:
                        args_data = str(args)

                    span.set_attribute(
                        OtelAttr.TOOL_ARGUMENTS,
                        json.dumps(args_data) if isinstance(args_data, dict) else args_data,
                    )

            result = await next(context)
            duration = time.time() - start_time
            logger.info(f"Tool call {tool_name} completed successfully ({duration:.2f}s)")

            # Record metrics if observability enabled
            if meter and config.enable_otel:
                duration_histogram = meter.create_histogram(
                    name="tool.execution.duration",
                    description="Tool execution duration in seconds",
                    unit="s",
                )
                duration_histogram.record(duration, {"tool": tool_name, "status": "success"})

            # Set tool result if sensitive data enabled
            if span and config.enable_otel and config.enable_sensitive_data:
                result_str = json.dumps(result) if isinstance(result, (dict, list)) else str(result)
                # Truncate to 1000 chars to avoid excessive data
                span.set_attribute(OtelAttr.TOOL_RESULT, result_str[:1000])

            # Emit tool complete event
            if should_show_visualization() and tool_event_id:
                # Extract summary from result
                summary = _extract_tool_summary(tool_name, result)
                complete_event = ToolCompleteEvent(
                    tool_name=tool_name,
                    result_summary=summary,
                    duration=duration,
                    event_id=tool_event_id,
                )
                get_event_emitter().emit(complete_event)

            return result
        except Exception as e:
            duration = time.time() - start_time
            logger.error(f"Tool call {tool_name} failed: {e}")

            # Record error metrics if observability enabled
            if meter and config.enable_otel:
                duration_histogram = meter.create_histogram(
                    name="tool.execution.duration",
                    description="Tool execution duration in seconds",
                    unit="s",
                )
                duration_histogram.record(duration, {"tool": tool_name, "status": "error"})

            # Capture exception in span
            if span and config.enable_otel:
                span.record_exception(e)
                from opentelemetry.trace import Status, StatusCode

                span.set_status(Status(StatusCode.ERROR, str(e)))

            # Emit tool error event
            if should_show_visualization() and tool_event_id:
                error_event = ToolErrorEvent(
                    tool_name=tool_name,
                    error_message=str(e),
                    duration=duration,
                    event_id=tool_event_id,
                )
                get_event_emitter().emit(error_event)

            raise
        finally:
            # Clear tool context when exiting the tool (restore parent)
            if should_show_visualization():
                set_current_tool_event_id(parent_id)
                if parent_id:
                    logger.debug("Restored parent tool context")
                else:
                    logger.debug(f"Cleared tool context: {tool_name}")


def _noop_context_manager() -> Any:
    """No-op context manager for when observability is disabled.

    Returns:
        Context manager that does nothing

    Example:
        >>> with _noop_context_manager() as span:
        ...     # span is None, no telemetry recorded
        ...     pass
    """
    from contextlib import nullcontext

    return nullcontext(None)


def _extract_tool_summary(tool_name: str, result: Any) -> str:
    """Extract a human-readable summary from a tool result.

    Attempts to extract a meaningful summary from common result patterns:
    - Dict with a "message" key
    - Dict with a "summary" key
    - String result (truncated to 100 chars)
    - Other types return "Complete"

    Args:
        tool_name: Name of the tool
        result: Tool result

    Returns:
        Brief summary string (max 100 chars)

    Example:
        >>> _extract_tool_summary("read_file", {"message": "File read successfully"})
        'File read successfully'
        >>> _extract_tool_summary("list_directory", "Long result...")
        'Long result...'
    """
    if isinstance(result, dict):
        if "message" in result:
            msg = str(result["message"])
            return msg[:100] if len(msg) > 100 else msg
        elif "summary" in result:
            summary = str(result["summary"])
            return summary[:100] if len(summary) > 100 else summary
    elif isinstance(result, str):
        return result[:100] if len(result) > 100 else result

    return "Complete"


# ============================================================================
# Middleware Factory
# ============================================================================


def create_middleware() -> list:
    """Create default middleware for agent and function levels.

    Returns:
        List of middleware (the framework auto-categorizes by type)

    Note:
        Memory is handled via ContextProvider, not middleware.
        See MemoryContextProvider for memory management.

    Example:
        >>> from agent.middleware import create_middleware
        >>> middleware = create_middleware()
        >>> agent = chat_client.create_agent(
        ...     name="Agent",
        ...     instructions="...",
        ...     tools=tools,
        ...     middleware=middleware
        ... )
    """
    return [
        agent_run_logging_middleware,
        agent_observability_middleware,
        logging_function_middleware,
    ]


# Backward compatibility
def create_function_middleware() -> list[FunctionMiddleware]:
    """Create a list of function middleware (legacy).

    Provided for backward compatibility. Use create_middleware() instead.

    Returns:
        List of function middleware

    Example:
        >>> from agent.middleware import create_function_middleware
        >>> function_mw = create_function_middleware()
    """
    return [cast(FunctionMiddleware, logging_function_middleware)]