Coverage for src / agent / tools / git.py: 87%
269 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-11 14:30 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-11 14:30 +0000
1# Copyright 2025-2026 Microsoft Corporation
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15"""Git repository tools for safe, sandboxed git operations.
17This module provides structured git tools that enable agents to clone,
18manage, and work with git repositories in a controlled environment.
20Key Features:
21- Repos directory sandboxing with path traversal protection
22- Service name sanitization for security
23- Repository cloning with depth and branch options
24- Status, pull, and branch management operations
25- Remote management (list, add, remove)
27All operations are sandboxed to repos_root (defaults to ~/.osdu-agent/repos).
28"""
30import asyncio
31import logging
32import re
33import subprocess
34from pathlib import Path
35from typing import Annotated, Any
37from pydantic import Field
39from agent.config.schema import AgentSettings
40from agent.tools.toolset import AgentToolset
logger = logging.getLogger(__name__)

# Valid characters for repository names (alphanumeric, hyphens, underscores, dots).
# The name must start with an alphanumeric character.  The pattern is anchored
# with \Z instead of $ because $ also matches just before a trailing newline,
# which would let "repo\n" through validation.
VALID_REPO_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._-]*\Z")

# Git URL patterns for validation.  Each pattern is anchored with \Z for the
# same reason as above: a URL followed by a newline must not be accepted.
GIT_URL_PATTERNS = [
    re.compile(r"^https?://[^\s]+\.git\Z"),  # HTTP/HTTPS with .git
    re.compile(r"^https?://[^\s]+\Z"),  # HTTP/HTTPS without .git
    re.compile(r"^git@[^\s:]+:[^\s]+\.git\Z"),  # SSH with .git
    re.compile(r"^git@[^\s:]+:[^\s]+\Z"),  # SSH without .git
    re.compile(r"^ssh://[^\s]+\Z"),  # SSH URL format
]
57class GitTools(AgentToolset):
58 """Git repository tools for safe, sandboxed operations.
60 This toolset provides git operations with security guarantees:
61 - All paths must be under configured repos_root
62 - Repository names are sanitized
63 - Path traversal attempts are blocked
64 - Operations are isolated using subprocess with cwd
66 Example:
67 >>> from agent.config import load_config
68 >>> settings = load_config()
69 >>> tools = GitTools(settings)
70 >>> result = await tools.list_repositories()
71 >>> print(result)
72 {'success': True, 'result': {'repositories': [...]}}
73 """
def __init__(self, settings: AgentSettings):
    """Set up the toolset and start with an empty repos-root cache.

    Args:
        settings: Agent settings instance, forwarded unchanged to the base toolset
    """
    super().__init__(settings)

    # Populated on first use by _get_repos_root(); None means "not computed yet".
    self._repos_root_cache: Path | None = None
def get_tools(self) -> list:
    """Return the git tool callables exposed by this toolset.

    Returns:
        List of the six bound tool methods, in a fixed order
    """
    tool_names = (
        "clone_repository",
        "list_repositories",
        "get_status",
        "pull_repository",
        "create_branch",
        "manage_remote",
    )
    # Look each tool up on the instance so the returned callables are bound methods.
    return [getattr(self, tool_name) for tool_name in tool_names]
99 def _get_repos_root(self) -> Path:
100 """Get and cache repos root from config or default.
102 Priority order:
103 1. Cached value (for performance)
104 2. config.repos_root (from settings.json)
105 3. ~/.osdu-agent/repos (default fallback)
107 Returns:
108 Resolved repos root Path (always returns a valid Path)
109 """
110 if self._repos_root_cache is not None:
111 return self._repos_root_cache
113 repos_root: Path | None = None
115 # Check config first
116 if hasattr(self.config, "repos_root") and self.config.repos_root is not None:
117 repos_root = self.config.repos_root
118 else:
119 # Default to ~/.osdu-agent/repos
120 repos_root = Path.home() / ".osdu-agent" / "repos"
122 # Ensure directory exists
123 repos_root.mkdir(parents=True, exist_ok=True)
125 self._repos_root_cache = repos_root
126 return self._repos_root_cache
def _validate_repo_name(self, name: str) -> dict | None:
    """Validate repository name for security.

    A valid name is non-empty, matches VALID_REPO_NAME_PATTERN in full, and
    contains neither ".." nor a path separator.

    Args:
        name: Repository name to validate

    Returns:
        Error dict if invalid, None if valid
    """
    if not name:
        return self._create_error_response(
            error="empty_repo_name",
            message="Repository name cannot be empty",
        )

    # fullmatch() rather than match(): with match(), a "$"-anchored pattern
    # still accepts a name followed by a newline (e.g. "repo\n").
    if not VALID_REPO_NAME_PATTERN.fullmatch(name):
        return self._create_error_response(
            error="invalid_repo_name",
            message=f"Invalid repository name: {name}. Use alphanumeric, hyphens, underscores, dots only.",
        )

    # Prevent path traversal.  Dots are allowed by the pattern, so ".." has to
    # be rejected explicitly; the separator checks are defence in depth.
    if ".." in name or "/" in name or "\\" in name:
        logger.warning("Path traversal attempt in repo name: %s", name)
        return self._create_error_response(
            error="path_traversal_attempt",
            message=f"Repository name contains invalid characters: {name}",
        )

    return None
159 def _resolve_repo_path(self, repo_name: str) -> dict | Path:
160 """Resolve repository path within repos_root boundaries.
162 Args:
163 repo_name: Name of the repository
165 Returns:
166 Resolved Path object if valid, or error dict if validation fails
167 """
168 # Validate repo name first
169 error = self._validate_repo_name(repo_name)
170 if error:
171 return error
173 repos_root = self._get_repos_root()
174 repo_path = repos_root / repo_name
176 # Verify path is within repos_root after resolution
177 try:
178 resolved = repo_path.resolve()
179 if not resolved.is_relative_to(repos_root):
180 logger.warning(f"Path outside repos_root: {repo_name} -> {resolved}")
181 return self._create_error_response(
182 error="path_outside_repos",
183 message=f"Repository path resolves outside repos directory: {repo_name}",
184 )
185 except (ValueError, TypeError):
186 return self._create_error_response(
187 error="path_outside_repos",
188 message=f"Repository path resolves outside repos directory: {repo_name}",
189 )
191 return resolved
def _validate_git_url(self, url: str) -> dict | None:
    """Validate git URL format.

    The URL is accepted when any pattern in GIT_URL_PATTERNS matches it in
    full.

    Args:
        url: Git URL to validate

    Returns:
        Error dict if invalid, None if valid
    """
    if not url:
        return self._create_error_response(
            error="empty_url",
            message="Git URL cannot be empty",
        )

    # fullmatch() rather than match(): a "$"-anchored pattern used with match()
    # also accepts a URL followed by a newline, which would then be handed to
    # git verbatim.
    if any(pattern.fullmatch(url) for pattern in GIT_URL_PATTERNS):
        return None

    return self._create_error_response(
        error="invalid_git_url",
        message=f"Invalid git URL format: {url}. Use https://, git@, or ssh:// URLs.",
    )
217 def _is_git_repo(self, path: Path) -> bool:
218 """Check if path is a git repository.
220 Args:
221 path: Path to check
223 Returns:
224 True if path contains .git directory
225 """
226 return (path / ".git").is_dir()
228 async def _run_git_command(
229 self,
230 args: list[str],
231 cwd: Path | None = None,
232 timeout: int | None = None,
233 ) -> tuple[bool, str, str]:
234 """Run git command in subprocess.
236 Args:
237 args: Git command arguments (without 'git' prefix)
238 cwd: Working directory for command
239 timeout: Command timeout in seconds
241 Returns:
242 Tuple of (success, stdout, stderr)
243 """
244 if timeout is None:
245 timeout = self.config.git_clone_timeout
247 cmd = ["git"] + args
249 try:
250 # Run in subprocess to isolate from main process
251 loop = asyncio.get_event_loop()
252 result = await loop.run_in_executor(
253 None,
254 lambda: subprocess.run(
255 cmd,
256 cwd=cwd,
257 capture_output=True,
258 text=True,
259 timeout=timeout,
260 ),
261 )
263 success = result.returncode == 0
264 return success, result.stdout, result.stderr
266 except subprocess.TimeoutExpired:
267 return False, "", f"Command timed out after {timeout} seconds"
268 except FileNotFoundError:
269 return False, "", "Git is not installed or not in PATH"
270 except Exception as e:
271 return False, "", f"Error running git command: {str(e)}"
async def clone_repository(
    self,
    url: Annotated[str, Field(description="Git repository URL (https:// or git@)")],
    name: Annotated[
        str | None, Field(description="Local repository name (defaults to repo name from URL)")
    ] = None,
    branch: Annotated[str | None, Field(description="Branch to clone")] = None,
    depth: Annotated[
        int | None, Field(description="Shallow clone depth (e.g., 1 for latest only)")
    ] = None,
) -> dict:
    """Clone git repository into repos directory. Supports depth and branch options. Returns clone path and branch."""
    # NOTE(review): the docstring above is presumably surfaced as the tool
    # description, so it is kept verbatim; details live in comments instead.
    #
    # Returns an error response when the URL or name is invalid, when the
    # target path already exists, or when `git clone` fails; otherwise a
    # success response with name, path, url, branch and the requested depth.
    import shutil  # local import: only needed to clean up a failed clone

    # Validate URL
    url_error = self._validate_git_url(url)
    if url_error:
        return url_error

    # Extract repo name from URL if not provided
    if name is None:
        # Handle both ".git" suffix and bare URLs, for "/"-separated paths and
        # the scp-like "git@host:repo" form.
        url_path = url.rstrip("/").removesuffix(".git")
        name = url_path.split("/")[-1].split(":")[-1]

    # Resolve and validate target path
    repo_path = self._resolve_repo_path(name)
    if isinstance(repo_path, dict):
        return repo_path  # Error response

    # Check if already exists
    if repo_path.exists():
        if self._is_git_repo(repo_path):
            return self._create_error_response(
                error="repo_exists",
                message=f"Repository already exists: {name}. Use pull_repository to update.",
            )
        return self._create_error_response(
            error="path_exists",
            message=f"Path exists but is not a git repository: {name}",
        )

    # Build clone command
    args = ["clone"]

    # Non-positive depths are ignored and a full clone is performed.
    if depth is not None and depth > 0:
        args.extend(["--depth", str(depth)])

    if branch:
        args.extend(["--branch", branch])

    # "--" ends option parsing, so the URL and target path can never be
    # interpreted as git options.
    args.extend(["--", url, str(repo_path)])

    # Execute clone
    success, _, stderr = await self._run_git_command(args)

    if not success:
        # The target did not exist before this call (checked above), so
        # anything present now is a partial clone, e.g. from a git process
        # killed on timeout.  Remove it (best effort); otherwise every retry
        # would be rejected with path_exists/repo_exists.
        if repo_path.exists():
            shutil.rmtree(repo_path, ignore_errors=True)
        return self._create_error_response(
            error="clone_failed",
            message=f"Failed to clone repository: {stderr.strip()}",
        )

    # Get current branch
    branch_success, branch_out, _ = await self._run_git_command(
        ["rev-parse", "--abbrev-ref", "HEAD"],
        cwd=repo_path,
    )

    current_branch = branch_out.strip() if branch_success else "unknown"

    result: dict[str, Any] = {
        "name": name,
        "path": str(repo_path),
        "url": url,
        "branch": current_branch,
        "depth": depth,
    }

    return self._create_success_response(
        result=result,
        message=f"Cloned repository {name} to {repo_path}",
    )
async def list_repositories(self) -> dict:
    """List local repositories with current branch and clean/modified status."""
    # Returns a success response whose result holds repos_root, the list of
    # repositories (name, path, branch, is_clean) and their count, or an error
    # response when the repos directory cannot be read.

    # Resolve the root before comparing: each entry is resolved below, and a
    # resolved entry is never "relative to" an unresolved root (relative path,
    # "..", or symlinked directory), which would silently skip every repo.
    repos_root = self._get_repos_root().resolve()

    repositories: list[dict[str, Any]] = []

    try:
        # Sorted so the listing order is deterministic across calls.
        for entry in sorted(repos_root.iterdir()):
            if not entry.is_dir():
                continue

            # Security: Resolve symlinks and verify path stays within repos_root
            try:
                resolved_entry = entry.resolve()
                if not resolved_entry.is_relative_to(repos_root):
                    logger.warning(
                        "Skipping entry outside repos_root: %s -> %s",
                        entry.name,
                        resolved_entry,
                    )
                    continue
            except (ValueError, OSError) as e:
                logger.warning("Skipping entry with resolution error: %s: %s", entry.name, e)
                continue

            if not self._is_git_repo(resolved_entry):
                continue

            # Get branch
            branch_success, branch_out, _ = await self._run_git_command(
                ["rev-parse", "--abbrev-ref", "HEAD"],
                cwd=resolved_entry,
                timeout=10,
            )
            branch = branch_out.strip() if branch_success else "unknown"

            # Check if clean: a failed status call is never reported as clean.
            status_success, status_out, _ = await self._run_git_command(
                ["status", "--porcelain"],
                cwd=resolved_entry,
                timeout=10,
            )
            is_clean = status_success and not status_out.strip()

            repositories.append(
                {
                    "name": entry.name,
                    "path": str(resolved_entry),
                    "branch": branch,
                    "is_clean": is_clean,
                }
            )

    except PermissionError:
        return self._create_error_response(
            error="permission_denied",
            message=f"Permission denied reading repos directory: {repos_root}",
        )
    except OSError as e:
        return self._create_error_response(
            error="os_error",
            message=f"Error listing repositories: {str(e)}",
        )

    result = {
        "repos_root": str(repos_root),
        "repositories": repositories,
        "count": len(repositories),
    }

    return self._create_success_response(
        result=result,
        message=f"Found {len(repositories)} repositories",
    )
async def get_status(
    self,
    name: Annotated[str, Field(description="Repository name")],
) -> dict:
    """Get git status for repository. Shows branch, tracking, staged/unstaged changes, untracked files."""
    # NOTE(review): the docstring above is presumably surfaced as the tool
    # description, so it is kept verbatim; details live in comments instead.
    #
    # Returns an error response when the repository cannot be resolved, does
    # not exist or is not a git repo; otherwise a success response with branch,
    # tracking branch, ahead/behind counts and the staged / unstaged /
    # untracked / conflicted file lists.

    # Resolve repo path
    repo_path = self._resolve_repo_path(name)
    if isinstance(repo_path, dict):
        return repo_path

    # Check if exists and is git repo
    if not repo_path.exists():
        return self._create_error_response(
            error="repo_not_found",
            message=f"Repository not found: {name}",
        )

    if not self._is_git_repo(repo_path):
        return self._create_error_response(
            error="not_a_repo",
            message=f"Path is not a git repository: {name}",
        )

    # Get current branch
    branch_success, branch_out, _ = await self._run_git_command(
        ["rev-parse", "--abbrev-ref", "HEAD"],
        cwd=repo_path,
    )
    branch = branch_out.strip() if branch_success else "unknown"

    # Get tracking branch (fails, hence None, when no upstream is configured)
    tracking_success, tracking_out, _ = await self._run_git_command(
        ["rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"],
        cwd=repo_path,
    )
    tracking_branch = tracking_out.strip() if tracking_success else None

    # Get ahead/behind counts
    ahead = 0
    behind = 0
    if tracking_branch:
        counts_success, counts_out, _ = await self._run_git_command(
            ["rev-list", "--left-right", "--count", "HEAD...@{u}"],
            cwd=repo_path,
        )
        if counts_success:
            parts = counts_out.split()
            # Guard against unexpected output instead of letting int() raise.
            if len(parts) == 2 and all(part.isdigit() for part in parts):
                ahead = int(parts[0])
                behind = int(parts[1])

    # Get status (porcelain for parsing)
    status_success, status_out, _ = await self._run_git_command(
        ["status", "--porcelain"],
        cwd=repo_path,
    )

    staged: list[str] = []
    unstaged: list[str] = []
    untracked: list[str] = []
    conflicted: list[str] = []

    # Two-letter codes git uses for unmerged (conflicted) paths.
    unmerged_codes = {"DD", "AU", "UD", "UA", "DU", "AA", "UU"}

    if status_success:
        for line in status_out.splitlines():
            # Each entry is "XY <path>": X = index state, Y = worktree state.
            if len(line) < 3:
                continue

            index_status = line[0]
            worktree_status = line[1]
            code = line[:2]
            filename = line[3:]

            # Untracked files
            if code == "??":
                untracked.append(filename)
                continue

            # Unmerged paths: previously "UU" entries matched no category, so
            # a repository in the middle of a conflict was reported as clean.
            if code in unmerged_codes:
                conflicted.append(filename)
                continue

            # Staged changes (index modified); "T" is a type change
            if index_status in "MADRCT":
                staged.append(filename)

            # Unstaged changes (worktree modified); "T" is a type change
            if worktree_status in "MDT":
                unstaged.append(filename)

    # A failed status call must not be reported as a clean tree (this matches
    # how list_repositories computes is_clean).
    is_clean = status_success and not (staged or unstaged or untracked or conflicted)

    result: dict[str, Any] = {
        "name": name,
        "path": str(repo_path),
        "branch": branch,
        "tracking_branch": tracking_branch,
        "ahead": ahead,
        "behind": behind,
        "is_clean": is_clean,
        "staged": staged,
        "unstaged": unstaged,
        "untracked": untracked,
        "conflicted": conflicted,
    }

    return self._create_success_response(
        result=result,
        message=f"Status for {name}: branch={branch}, clean={is_clean}",
    )
async def pull_repository(
    self,
    name: Annotated[str, Field(description="Repository name")],
    remote: Annotated[str | None, Field(description="Remote name")] = None,
) -> dict:
    """Pull latest changes from remote. Returns update status and any conflicts."""
    # NOTE(review): the docstring above is presumably surfaced as the tool
    # description, so it is kept verbatim; details live in comments instead.
    #
    # Returns an error response when the repository or remote is invalid, or
    # when fetch/pull fails (merge conflicts are reported as "merge_conflict");
    # otherwise a success response with the remote used and whether anything
    # was updated.

    # Resolve repo path
    repo_path = self._resolve_repo_path(name)
    if isinstance(repo_path, dict):
        return repo_path

    # Check if exists and is git repo
    if not repo_path.exists():
        return self._create_error_response(
            error="repo_not_found",
            message=f"Repository not found: {name}",
        )

    if not self._is_git_repo(repo_path):
        return self._create_error_response(
            error="not_a_repo",
            message=f"Path is not a git repository: {name}",
        )

    # Use default remote if not specified
    if remote is None:
        remote = self.config.git_default_remote

    # The remote is passed to git as a positional argument; a value starting
    # with "-" would be parsed as an option (e.g. --upload-pack=...), so it is
    # rejected up front.
    if not remote or remote.startswith("-"):
        return self._create_error_response(
            error="invalid_remote_name",
            message=f"Invalid remote name: {remote}",
        )

    # Fetch first to get remote info
    fetch_success, _, fetch_err = await self._run_git_command(
        ["fetch", remote],
        cwd=repo_path,
    )

    if not fetch_success:
        return self._create_error_response(
            error="fetch_failed",
            message=f"Failed to fetch from remote '{remote}': {fetch_err.strip()}",
        )

    # Pull changes
    success, stdout, stderr = await self._run_git_command(["pull", remote], cwd=repo_path)

    if not success:
        # Check if it's a conflict
        if "conflict" in stderr.lower() or "conflict" in stdout.lower():
            return self._create_error_response(
                error="merge_conflict",
                message=f"Merge conflict during pull: {stderr.strip()}",
            )
        return self._create_error_response(
            error="pull_failed",
            message=f"Failed to pull: {stderr.strip()}",
        )

    # Parse output for update info.  Older git releases print
    # "Already up-to-date." (hyphenated), so both spellings are recognised.
    up_to_date_markers = ("Already up to date", "Already up-to-date")
    updated = not any(marker in stdout for marker in up_to_date_markers)

    result: dict[str, Any] = {
        "name": name,
        "path": str(repo_path),
        "remote": remote,
        "updated": updated,
        "message": stdout.strip() if stdout else "Pull completed",
    }

    return self._create_success_response(
        result=result,
        message=f"Pulled {name} from {remote}: {'updated' if updated else 'already up to date'}",
    )
async def create_branch(
    self,
    name: Annotated[str, Field(description="Repository name")],
    branch_name: Annotated[str, Field(description="New branch name")],
    checkout: Annotated[bool, Field(description="Checkout the new branch")] = True,
    source_branch: Annotated[
        str | None, Field(description="Source branch to create from")
    ] = None,
) -> dict:
    """Create branch in repository. Optional checkout and source branch. Returns branch name."""
    # NOTE(review): the docstring above is presumably surfaced as the tool
    # description, so it is kept verbatim; details live in comments instead.
    #
    # Returns an error response when the repository, branch name or source
    # branch is invalid, when the branch already exists, or when git fails;
    # otherwise a success response describing the new branch.

    # Resolve repo path
    repo_path = self._resolve_repo_path(name)
    if isinstance(repo_path, dict):
        return repo_path

    # Check if exists and is git repo
    if not repo_path.exists():
        return self._create_error_response(
            error="repo_not_found",
            message=f"Repository not found: {name}",
        )

    if not self._is_git_repo(repo_path):
        return self._create_error_response(
            error="not_a_repo",
            message=f"Path is not a git repository: {name}",
        )

    # Validate branch name.  fullmatch() is used because match() with a
    # "$"-anchored pattern would accept a name followed by a newline.
    if not branch_name or not re.fullmatch(r"[a-zA-Z0-9][a-zA-Z0-9._/-]*", branch_name):
        return self._create_error_response(
            error="invalid_branch_name",
            message=f"Invalid branch name: {branch_name}",
        )

    # The source branch is passed to `git checkout` as a positional argument;
    # a value starting with "-" would be parsed as an option.
    if source_branch and source_branch.startswith("-"):
        return self._create_error_response(
            error="invalid_source_branch",
            message=f"Invalid source branch: {source_branch}",
        )

    # Check if branch already exists.  The lookup is restricted to
    # refs/heads/ so that a tag, or a name that happens to abbreviate a commit
    # hash, is not mistaken for an existing branch.
    exists_success, _, _ = await self._run_git_command(
        ["rev-parse", "--verify", "--quiet", f"refs/heads/{branch_name}"],
        cwd=repo_path,
    )

    if exists_success:
        return self._create_error_response(
            error="branch_exists",
            message=f"Branch already exists: {branch_name}",
        )

    # If source branch specified, checkout it first.
    # NOTE(review): with checkout=False this still leaves the working tree on
    # source_branch afterwards - confirm whether callers rely on that.
    if source_branch:
        checkout_success, _, checkout_err = await self._run_git_command(
            ["checkout", source_branch],
            cwd=repo_path,
        )
        if not checkout_success:
            return self._create_error_response(
                error="source_checkout_failed",
                message=f"Failed to checkout source branch '{source_branch}': {checkout_err.strip()}",
            )

    # Create branch
    if checkout:
        args = ["checkout", "-b", branch_name]
    else:
        args = ["branch", branch_name]

    success, _, stderr = await self._run_git_command(args, cwd=repo_path)

    if not success:
        return self._create_error_response(
            error="create_branch_failed",
            message=f"Failed to create branch: {stderr.strip()}",
        )

    # Get current branch after operation
    current_success, current_out, _ = await self._run_git_command(
        ["rev-parse", "--abbrev-ref", "HEAD"],
        cwd=repo_path,
    )
    current_branch = current_out.strip() if current_success else "unknown"

    result: dict[str, Any] = {
        "name": name,
        "branch_name": branch_name,
        "checked_out": checkout,
        "current_branch": current_branch,
        "source_branch": source_branch,
    }

    return self._create_success_response(
        result=result,
        message=f"Created branch '{branch_name}' in {name}",
    )
async def manage_remote(
    self,
    name: Annotated[str, Field(description="Repository name")],
    action: Annotated[str, Field(description="Action: list, add, or remove")],
    remote_name: Annotated[
        str | None, Field(description="Remote name (required for add/remove)")
    ] = None,
    remote_url: Annotated[
        str | None, Field(description="Remote URL (required for add)")
    ] = None,
) -> dict:
    """Manage git remotes: list, add, or remove. Validates URL format. Returns remote configuration."""
    # NOTE(review): the docstring above is presumably surfaced as the tool
    # description, so it is kept verbatim; details live in comments instead.
    #
    # Returns an error response when the repository, action, remote name or
    # URL is invalid or when git fails; otherwise a success response describing
    # the remotes (list) or the change made (add / remove).

    # Resolve repo path
    repo_path = self._resolve_repo_path(name)
    if isinstance(repo_path, dict):
        return repo_path

    # Check if exists and is git repo
    if not repo_path.exists():
        return self._create_error_response(
            error="repo_not_found",
            message=f"Repository not found: {name}",
        )

    if not self._is_git_repo(repo_path):
        return self._create_error_response(
            error="not_a_repo",
            message=f"Path is not a git repository: {name}",
        )

    # Validate action
    valid_actions = ["list", "add", "remove"]
    if action not in valid_actions:
        return self._create_error_response(
            error="invalid_action",
            message=f"Invalid action: {action}. Valid actions: {', '.join(valid_actions)}",
        )

    if action == "list":
        # Get verbose remote list
        success, stdout, stderr = await self._run_git_command(
            ["remote", "-v"],
            cwd=repo_path,
        )

        if not success:
            return self._create_error_response(
                error="list_remotes_failed",
                message=f"Failed to list remotes: {stderr.strip()}",
            )

        # Each line looks like "<name> <url> (fetch|push)"; group the URLs per
        # remote, keyed by the type found in the parentheses.
        remotes: dict[str, dict[str, str]] = {}
        for line in stdout.splitlines():
            parts = line.split()
            if len(parts) >= 2:
                rname = parts[0]
                rurl = parts[1]
                rtype = parts[2].strip("()") if len(parts) > 2 else "unknown"
                remotes.setdefault(rname, {})[rtype] = rurl

        result: dict[str, Any] = {
            "name": name,
            "remotes": remotes,
            "count": len(remotes),
        }

        return self._create_success_response(
            result=result,
            message=f"Listed {len(remotes)} remotes for {name}",
        )

    # "add" and "remove" both need a remote name.
    if not remote_name:
        return self._create_error_response(
            error="missing_remote_name",
            message=f"Remote name is required for {action} action",
        )

    # The remote name is passed to git as a positional argument; a value
    # starting with "-" would be parsed as an option, so reject it.
    if remote_name.startswith("-"):
        return self._create_error_response(
            error="invalid_remote_name",
            message=f"Invalid remote name: {remote_name}",
        )

    if action == "add":
        if not remote_url:
            return self._create_error_response(
                error="missing_remote_url",
                message="Remote URL is required for add action",
            )

        # Validate URL
        url_error = self._validate_git_url(remote_url)
        if url_error:
            return url_error

        # Add remote
        success, _, stderr = await self._run_git_command(
            ["remote", "add", remote_name, remote_url],
            cwd=repo_path,
        )

        if not success:
            return self._create_error_response(
                error="add_remote_failed",
                message=f"Failed to add remote: {stderr.strip()}",
            )

        result = {
            "name": name,
            "action": "add",
            "remote_name": remote_name,
            "remote_url": remote_url,
        }

        return self._create_success_response(
            result=result,
            message=f"Added remote '{remote_name}' to {name}",
        )

    # Remove remote (the only action left after "list" and "add")
    success, _, stderr = await self._run_git_command(
        ["remote", "remove", remote_name],
        cwd=repo_path,
    )

    if not success:
        return self._create_error_response(
            error="remove_remote_failed",
            message=f"Failed to remove remote: {stderr.strip()}",
        )

    result = {
        "name": name,
        "action": "remove",
        "remote_name": remote_name,
    }

    return self._create_success_response(
        result=result,
        message=f"Removed remote '{remote_name}' from {name}",
    )