Coverage for src/agent/tools/maven.py: 93%
201 statements
coverage.py v7.13.0, created at 2025-12-11 14:30 +0000
1# Copyright 2025-2026 Microsoft Corporation
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15"""Maven tools for dependency version checking and security scanning."""
17import asyncio
18import json
19import logging
20import xml.etree.ElementTree as ET
21from pathlib import Path
22from typing import Annotated, Any
24from pydantic import Field
26from agent.config.schema import AgentSettings
27from agent.services.maven import MavenApiService, MavenCacheService
28from agent.services.maven.api import MavenApiError
29from agent.services.maven.types import (
30 MavenErrorCode,
31 PomAnalysisResult,
32 PomDependency,
33 SecurityScanResult,
34 Vulnerability,
35 VulnerabilitySeverity,
36)
37from agent.tools.toolset import AgentToolset
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class MavenTools(AgentToolset):
    """Maven dependency and security tools.

    Bundles four agent tools (see ``get_tools``): single and batch version
    checks backed by ``MavenApiService``, a Trivy-based vulnerability scan,
    and a ``pom.xml`` analyzer.  Every tool returns the dict built by
    ``_create_success_response`` / ``_create_error_response`` (not defined in
    this class; presumably inherited from ``AgentToolset``).

    NOTE(review): the docstrings of the four public tool methods are kept
    short and unchanged on purpose; tool frameworks commonly surface them as
    the tool description -- confirm against ``AgentToolset``.
    """

    def __init__(self, settings: AgentSettings):
        """Create the cache and API services used by the tools.

        Args:
            settings: Agent settings.  ``settings.agent.maven_cache_ttl`` and
                ``settings.agent.maven_timeout`` are read when present and
                default to 3600 and 30 respectively otherwise.
        """
        super().__init__(settings)
        self._cache = MavenCacheService(
            default_ttl=getattr(settings.agent, "maven_cache_ttl", 3600)
        )
        self._api = MavenApiService(
            cache=self._cache,
            timeout=getattr(settings.agent, "maven_timeout", 30),
        )
        # Result of the one-time "is trivy installed" probe (None = not probed).
        self._trivy_checked = False
        self._trivy_available: bool | None = None

    def get_tools(self) -> list:
        """Return list of Maven tools."""
        return [
            self.check_version,
            self.check_versions_batch,
            self.scan_security,
            self.analyze_pom,
        ]

    @staticmethod
    async def _kill_process(proc: asyncio.subprocess.Process) -> None:
        """Kill *proc* if it is still running and wait for it to exit.

        Cancelling ``proc.wait()`` / ``proc.communicate()`` (which is what a
        ``wait_for`` timeout does) does not stop the child, so it has to be
        killed explicitly to avoid leaving it running in the background.
        """
        if proc.returncode is not None:
            return
        try:
            proc.kill()
        except ProcessLookupError:
            # The process exited between the returncode check and kill().
            return
        await proc.wait()

    @staticmethod
    def _has_updates(check: dict[str, Any]) -> bool:
        """Return True when *check* reports a major, minor or patch update."""
        return any(
            check.get(flag)
            for flag in ("has_major_update", "has_minor_update", "has_patch_update")
        )

    @staticmethod
    def _find_child(elem: ET.Element, tag: str, ns: dict[str, str]) -> ET.Element | None:
        """Find direct child *tag*, trying the ``m:`` namespace first.

        Uses explicit ``is None`` checks: the truth value of an ``Element``
        depends on whether it has children and is deprecated, so
        ``find(...) or find(...)`` must not be used to pick a match.
        """
        found = elem.find(f"m:{tag}", ns)
        if found is None:
            found = elem.find(tag)
        return found

    async def _check_trivy_available(self) -> bool:
        """Check if Trivy is installed (async).

        Runs ``trivy --version`` once and caches the outcome on the instance.
        Trivy counts as available only when the command exits with status 0
        within 10 seconds.

        Returns:
            True when the probe succeeded, False otherwise.
        """
        if self._trivy_checked:
            return self._trivy_available or False

        available = False
        try:
            proc = await asyncio.create_subprocess_exec(
                "trivy",
                "--version",
                stdout=asyncio.subprocess.DEVNULL,
                stderr=asyncio.subprocess.DEVNULL,
            )
        except OSError:
            # Binary missing (FileNotFoundError) or not executable
            # (PermissionError etc.): treat as "not available".
            proc = None

        if proc is not None:
            try:
                await asyncio.wait_for(proc.wait(), timeout=10)
                available = proc.returncode == 0
            except (asyncio.TimeoutError, TimeoutError):
                await self._kill_process(proc)

        # Publish the result only after the probe has finished so a caller
        # running concurrently never reads a half-initialised cache.
        self._trivy_available = available
        self._trivy_checked = True
        return available

    async def check_version(
        self,
        dependency: Annotated[str, Field(description="groupId:artifactId")],
        version: Annotated[str, Field(description="Version to check")],
        packaging: Annotated[str, Field(description="Package type")] = "jar",
        classifier: Annotated[str | None, Field(description="Classifier")] = None,
    ) -> dict:
        """Check dependency version existence and available updates."""
        # Delegates to MavenApiService.check_version and maps failures to
        # error responses: ValueError -> INVALID_COORDINATE, MavenApiError ->
        # the error's own code, anything else -> MAVEN_API_ERROR (logged).
        try:
            result = await self._api.check_version(dependency, version, packaging, classifier)
            return self._create_success_response(
                result=result,
                message=f"Checked {dependency}:{version}",
            )
        except ValueError as e:
            return self._create_error_response(
                error=MavenErrorCode.INVALID_COORDINATE,
                message=str(e),
            )
        except MavenApiError as e:
            return self._create_error_response(
                error=e.error_code.value,
                message=str(e),
            )
        except Exception as e:
            logger.exception("Error checking version")
            return self._create_error_response(
                error=MavenErrorCode.MAVEN_API_ERROR,
                message=f"Error checking version: {e}",
            )

    async def check_versions_batch(
        self,
        dependencies: Annotated[
            list[dict[str, str]],
            Field(description="List of {dependency, version} objects"),
        ],
    ) -> dict:
        """Check multiple dependencies for updates."""
        # Entries are checked one after another.  A failing entry is recorded
        # with status "error" and never aborts the batch, so the overall
        # response is always a success response carrying per-entry results.
        results: list[dict[str, Any]] = []
        success_count = 0
        with_updates = 0

        for dep in dependencies:
            dependency = dep.get("dependency", "")
            version = dep.get("version", "")
            packaging = dep.get("packaging", "jar")
            classifier = dep.get("classifier")

            if not dependency or not version:
                results.append(
                    {
                        "dependency": dependency,
                        "version": version,
                        "status": "error",
                        "error": "Missing dependency or version",
                    }
                )
                continue

            try:
                result = await self._api.check_version(dependency, version, packaging, classifier)
                results.append(
                    {
                        "dependency": dependency,
                        "version": version,
                        "status": "success",
                        "result": result,
                    }
                )
                success_count += 1

                if self._has_updates(result):
                    with_updates += 1

            except Exception as e:
                results.append(
                    {
                        "dependency": dependency,
                        "version": version,
                        "status": "error",
                        "error": str(e),
                    }
                )

        return self._create_success_response(
            result={
                "total": len(dependencies),
                "success": success_count,
                "failed": len(dependencies) - success_count,
                "with_updates": with_updates,
                "results": results,
            },
            message=f"Checked {len(dependencies)} deps: {success_count} ok, {with_updates} updates",
        )

    async def scan_security(
        self,
        path: Annotated[str, Field(description="Path to project or POM")],
        severity_filter: Annotated[
            list[str] | None,
            Field(description="Severities: critical, high, medium, low"),
        ] = None,
    ) -> dict:
        """Scan for vulnerabilities using Trivy."""
        # Flow: probe for trivy -> validate the path (an .xml file, or a
        # directory containing pom.xml) -> run "trivy fs" writing JSON to a
        # temp file -> parse the report into a SecurityScanResult.
        if not await self._check_trivy_available():
            return self._create_error_response(
                error=MavenErrorCode.TRIVY_NOT_AVAILABLE,
                message="Trivy not available. Install: brew install trivy",
            )

        scan_path = Path(path).expanduser().resolve()
        if not scan_path.exists():
            return self._create_error_response(
                error=MavenErrorCode.INVALID_PATH,
                message=f"Path does not exist: {path}",
            )

        if scan_path.is_file() and scan_path.name.endswith(".xml"):
            target = str(scan_path)
        elif scan_path.is_dir():
            if not (scan_path / "pom.xml").exists():
                return self._create_error_response(
                    error=MavenErrorCode.POM_PARSE_ERROR,
                    message=f"No pom.xml in directory: {path}",
                )
            target = str(scan_path)
        else:
            return self._create_error_response(
                error=MavenErrorCode.INVALID_PATH,
                message=f"Invalid path: {path}",
            )

        # None *or* an empty list means "all severities"; an empty list would
        # otherwise hand trivy an empty --severity value.
        if not severity_filter:
            severity_filter = ["critical", "high", "medium", "low"]
        severity_filter = [s.upper() for s in severity_filter]

        try:
            import tempfile

            # Only the file name is needed; trivy writes the report itself.
            with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp:
                output_file = tmp.name

            try:
                # NOTE(review): newer Trivy releases document "--scanners" as
                # the replacement for "--security-checks"; confirm against the
                # Trivy version this project targets before changing it.
                proc = await asyncio.create_subprocess_exec(
                    "trivy",
                    "fs",
                    "--security-checks",
                    "vuln",
                    "--format",
                    "json",
                    "--output",
                    output_file,
                    "--severity",
                    ",".join(severity_filter),
                    target,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                )
                try:
                    _, stderr = await asyncio.wait_for(proc.communicate(), timeout=300)
                except (asyncio.TimeoutError, TimeoutError):
                    # A timeout only cancels the wait; stop the scan itself.
                    await self._kill_process(proc)
                    raise

                if proc.returncode != 0:
                    return self._create_error_response(
                        error=MavenErrorCode.TRIVY_SCAN_FAILED,
                        # errors="replace": never lose the real failure to a
                        # decoding error on non-UTF-8 stderr bytes.
                        message=f"Trivy failed: {stderr.decode(errors='replace')}",
                    )

                # JSON reports are UTF-8; do not depend on the locale encoding.
                with open(output_file, encoding="utf-8") as f:
                    trivy_data = json.load(f)

            finally:
                Path(output_file).unlink(missing_ok=True)

            vulnerabilities = self._process_trivy_results(trivy_data)
            severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
            for vuln in vulnerabilities:
                sev = vuln.severity.value.lower()
                # Severities outside the four buckets are listed but not counted.
                if sev in severity_counts:
                    severity_counts[sev] += 1

            scan_result = SecurityScanResult(
                vulnerabilities_found=len(vulnerabilities) > 0,
                total_vulnerabilities=len(vulnerabilities),
                severity_counts=severity_counts,
                vulnerabilities=vulnerabilities,
                scan_target=target,
                trivy_available=True,
            )

            return self._create_success_response(
                result=scan_result.model_dump(),
                message=f"Found {len(vulnerabilities)} vulnerabilities",
            )

        except (asyncio.TimeoutError, TimeoutError):
            return self._create_error_response(
                error=MavenErrorCode.TRIVY_SCAN_FAILED,
                message="Trivy timed out after 5 minutes",
            )
        except Exception as e:
            logger.exception("Security scan error")
            return self._create_error_response(
                error=MavenErrorCode.TRIVY_SCAN_FAILED,
                message=f"Scan error: {e}",
            )

    def _process_trivy_results(self, trivy_data: dict) -> list[Vulnerability]:
        """Parse Trivy JSON output.

        Walks ``Results[*].Vulnerabilities[*]`` and builds one
        ``Vulnerability`` per entry.  Missing or ``null`` lists are treated
        as empty.

        Args:
            trivy_data: Decoded Trivy JSON report.

        Returns:
            The vulnerabilities in report order.
        """
        vulnerabilities: list[Vulnerability] = []

        for result in trivy_data.get("Results") or []:
            for vuln in result.get("Vulnerabilities") or []:
                # PkgID is split as "group:artifact:version"; when it has
                # fewer parts, fall back to the raw id / InstalledVersion.
                pkg_id = vuln.get("PkgID") or ""
                parts = pkg_id.split(":")
                package_name = f"{parts[0]}:{parts[1]}" if len(parts) >= 2 else pkg_id
                installed_version = (
                    parts[2] if len(parts) >= 3 else vuln.get("InstalledVersion", "")
                )

                # The enum is looked up by lower-case value; unrecognised or
                # missing severities map to UNKNOWN.
                severity_str = str(vuln.get("Severity") or "UNKNOWN").lower()
                try:
                    severity = VulnerabilitySeverity(severity_str)
                except ValueError:
                    severity = VulnerabilitySeverity.UNKNOWN

                vulnerabilities.append(
                    Vulnerability(
                        cve_id=vuln.get("VulnerabilityID", ""),
                        severity=severity,
                        package_name=package_name,
                        installed_version=installed_version,
                        fixed_version=vuln.get("FixedVersion"),
                        description=vuln.get("Description", ""),
                    )
                )

        return vulnerabilities

    async def analyze_pom(
        self,
        path: Annotated[str, Field(description="Path to pom.xml")],
        check_versions: Annotated[bool, Field(description="Check for updates")] = False,
    ) -> dict:
        """Analyze POM file for dependencies and project info."""
        # Extracts coordinates, parent, properties, dependencies,
        # dependencyManagement and modules.  With check_versions=True every
        # dependency with a literal version (not a "${...}" placeholder) is
        # additionally looked up and reported under "version_checks".
        pom_path = Path(path).expanduser().resolve()
        if not pom_path.exists():
            return self._create_error_response(
                error=MavenErrorCode.INVALID_PATH,
                message=f"POM not found: {path}",
            )

        if not pom_path.is_file() or not pom_path.name.endswith(".xml"):
            return self._create_error_response(
                error=MavenErrorCode.POM_PARSE_ERROR,
                message=f"Invalid POM file: {path}",
            )

        try:
            # NOTE(review): xml.etree is not hardened against maliciously
            # constructed XML; confirm POMs passed here are trusted input.
            tree = ET.parse(pom_path)
            root = tree.getroot()
            ns = {"m": "http://maven.apache.org/POM/4.0.0"}

            group_id = self._find_text(root, "groupId", ns)
            artifact_id = self._find_text(root, "artifactId", ns)
            version = self._find_text(root, "version", ns)
            packaging = self._find_text(root, "packaging", ns) or "jar"

            parent: dict[str, str] | None = None
            parent_elem = self._find_child(root, "parent", ns)
            if parent_elem is not None:
                parent = {
                    "group_id": self._find_text(parent_elem, "groupId", ns) or "",
                    "artifact_id": self._find_text(parent_elem, "artifactId", ns) or "",
                    "version": self._find_text(parent_elem, "version", ns) or "",
                }

            properties: dict[str, str] = {}
            props_elem = self._find_child(root, "properties", ns)
            if props_elem is not None:
                # Tags arrive as "{namespace}name"; keep only the name.
                ns_prefix = f"{{{ns.get('m', '')}}}"
                for prop in props_elem:
                    if prop.text:
                        properties[prop.tag.removeprefix(ns_prefix)] = prop.text

            dependencies = self._extract_dependencies(root, "dependencies", ns)
            dep_mgmt = self._extract_dependencies(root, "dependencyManagement/dependencies", ns)

            modules: list[str] = []
            modules_elem = self._find_child(root, "modules", ns)
            if modules_elem is not None:
                # findall returns a list, so "or" here is a plain
                # empty-list fallback to the un-namespaced tag.
                module_elems = modules_elem.findall("m:module", ns) or modules_elem.findall(
                    "module"
                )
                modules = [module.text for module in module_elems if module.text]

            pom_result = PomAnalysisResult(
                pom_path=str(pom_path),
                group_id=group_id,
                artifact_id=artifact_id,
                version=version,
                packaging=packaging,
                parent=parent,
                dependencies=dependencies,
                dependency_management=dep_mgmt,
                properties=properties,
                modules=modules,
            )

            result_dict = pom_result.model_dump()

            if check_versions and dependencies:
                version_checks = []
                for dep in dependencies:
                    # Skip missing versions and unresolved "${property}" ones.
                    if not dep.version or dep.version.startswith("${"):
                        continue
                    try:
                        coord = f"{dep.group_id}:{dep.artifact_id}"
                        check = await self._api.check_version(coord, dep.version)
                        version_checks.append(
                            {
                                "dependency": coord,
                                "current": dep.version,
                                "latest_versions": check.get("latest_versions", {}),
                                "has_updates": self._has_updates(check),
                            }
                        )
                    except Exception as e:
                        # Best effort: a failed lookup only omits that
                        # dependency from "version_checks".
                        logger.debug("Error checking %s: %s", dep, e)

                result_dict["version_checks"] = version_checks

            return self._create_success_response(
                result=result_dict,
                message=f"Analyzed: {artifact_id or 'unknown'}",
            )

        except ET.ParseError as e:
            return self._create_error_response(
                error=MavenErrorCode.POM_PARSE_ERROR,
                message=f"Parse error: {e}",
            )
        except Exception as e:
            logger.exception("POM analysis error")
            return self._create_error_response(
                error=MavenErrorCode.POM_PARSE_ERROR,
                message=f"Analysis error: {e}",
            )

    def _find_text(self, elem: ET.Element, path: str, ns: dict[str, str]) -> str | None:
        """Find text in element with namespace handling.

        Tries ``m:<path>`` first and then the un-namespaced ``<path>``.  A
        match with empty text does not count, so the second lookup is still
        attempted.

        Returns:
            The first non-empty text found, or None.
        """
        for found in (elem.find(f"m:{path}", ns), elem.find(path)):
            if found is not None and found.text:
                return found.text
        return None

    def _extract_dependencies(
        self, root: ET.Element, path: str, ns: dict[str, str]
    ) -> list[PomDependency]:
        """Extract dependencies from POM element.

        Args:
            root: Element to start from.
            path: ``/``-separated chain of child tags leading to the element
                that holds the ``<dependency>`` entries.
            ns: Namespace map providing the ``m`` prefix.

        Returns:
            One ``PomDependency`` per entry that has both a groupId and an
            artifactId; an empty list when any step of *path* is missing.
        """
        dependencies: list[PomDependency] = []

        container = root
        for part in path.split("/"):
            found = self._find_child(container, part, ns)
            if found is None:
                return dependencies
            container = found

        dep_elems = container.findall("m:dependency", ns) or container.findall("dependency")

        for dep in dep_elems:
            group_id = self._find_text(dep, "groupId", ns) or ""
            artifact_id = self._find_text(dep, "artifactId", ns) or ""
            version = self._find_text(dep, "version", ns)
            scope = self._find_text(dep, "scope", ns) or "compile"
            # Only the exact text "true" marks a dependency as optional.
            optional = self._find_text(dep, "optional", ns) == "true"

            if group_id and artifact_id:
                dependencies.append(
                    PomDependency(
                        group_id=group_id,
                        artifact_id=artifact_id,
                        version=version,
                        scope=scope,
                        optional=optional,
                    )
                )

        return dependencies