diff --git a/src/stack_pr/cli.py b/src/stack_pr/cli.py index d068998..7043851 100755 --- a/src/stack_pr/cli.py +++ b/src/stack_pr/cli.py @@ -55,6 +55,7 @@ import logging import os import re +import subprocess import sys from dataclasses import dataclass from functools import cache @@ -399,19 +400,40 @@ def last(ref: str, sep: str = "/") -> str: return ref.rsplit(sep, 1)[-1] -# TODO: Move to 'modular.utils.git' +class GitAncestryError(RuntimeError): + """Error raised when git ancestry check fails.""" + + def is_ancestor(commit1: str, commit2: str, *, verbose: bool) -> bool: """ Returns true if 'commit1' is an ancestor of 'commit2'. + + Raises: + GitAncestryError: If git command fails for reasons other than ancestry check. """ - # TODO: We need to check returncode of this command more carefully, as the - # command simply might fail (rc != 0 and rc != 1). - p = run_shell_command( - ["git", "merge-base", "--is-ancestor", commit1, commit2], - check=False, - quiet=not verbose, - ) - return p.returncode == 0 + error_code = None + + try: + p = run_shell_command( + ["git", "merge-base", "--is-ancestor", commit1, commit2], + check=False, + quiet=not verbose, + ) + if p.returncode == 0: + return True + if p.returncode == 1: + return False + + # Store error code for later + error_code = p.returncode + except subprocess.SubprocessError as e: + raise GitAncestryError(f"Failed to determine ancestry relationship: {e}") from e + + # Handle error code outside the try block + if error_code is None: + # This should never happen, but just in case + raise GitAncestryError("Unexpected error in git ancestry check") + raise GitAncestryError(f"Git ancestry check failed with code {error_code}") def is_repo_clean() -> bool: @@ -456,59 +478,74 @@ def set_base_branches(st: list[StackEntry], target: str) -> None: def verify(st: list[StackEntry], *, check_base: bool = False) -> None: log(h("Verifying stack info")) - for index, e in enumerate(st): - if e.has_missing_info(): - error(ERROR_STACKINFO_MISSING.format(**locals())) + for index, entry in enumerate(st): + if entry.has_missing_info(): + error(ERROR_STACKINFO_MISSING.format(e=entry)) raise RuntimeError - if len(e.pr.split("/")) == 0 or not last(e.pr).isnumeric(): - error(ERROR_STACKINFO_BAD_LINK.format(**locals())) + if len(entry.pr.split("/")) == 0 or not last(entry.pr).isnumeric(): + error(ERROR_STACKINFO_BAD_LINK.format(e=entry)) raise RuntimeError - ghinfo = get_command_output( - [ - "gh", - "pr", - "view", - e.pr, - "--json", - "baseRefName,headRefName,number,state,body,title,url,mergeStateStatus", - ] - ) - d = json.loads(ghinfo) - for required_field in ["state", "number", "baseRefName", "headRefName"]: - if required_field not in d: - error(ERROR_STACKINFO_MALFORMED_RESPONSE.format(**locals())) + try: + ghinfo = get_command_output( + [ + "gh", + "pr", + "view", + entry.pr, + "--json", + "baseRefName,headRefName,number,state,body,title,url,mergeStateStatus", + ] + ) + + try: + d = json.loads(ghinfo) + except json.JSONDecodeError as e: + error(f"Failed to parse JSON response from GitHub: {ghinfo}") + raise RuntimeError("Invalid JSON response from GitHub") from e + + for required_field in ["state", "number", "baseRefName", "headRefName"]: + if required_field not in d: + error( + ERROR_STACKINFO_MALFORMED_RESPONSE.format( + e=entry, required_field=required_field, d=d + ) + ) + raise RuntimeError + + if d["state"] != "OPEN": + error(ERROR_STACKINFO_PR_NOT_OPEN.format(e=entry, d=d)) raise RuntimeError - if d["state"] != "OPEN": - error(ERROR_STACKINFO_PR_NOT_OPEN.format(**locals())) - raise RuntimeError + if int(last(entry.pr)) != d["number"]: + error(ERROR_STACKINFO_PR_NUMBER_MISMATCH.format(e=entry, d=d)) + raise RuntimeError - if int(last(e.pr)) != d["number"]: - error(ERROR_STACKINFO_PR_NUMBER_MISMATCH.format(**locals())) - raise RuntimeError + if entry.head != d["headRefName"]: + error(ERROR_STACKINFO_PR_HEAD_MISMATCH.format(e=entry, d=d)) + raise RuntimeError - if e.head != d["headRefName"]: - error(ERROR_STACKINFO_PR_HEAD_MISMATCH.format(**locals())) - raise RuntimeError + # 'Base' branch might diverge when the stack is modified (e.g. when a + # new commit is added to the middle of the stack). It is not an issue + # if we're updating the stack (i.e. in 'submit'), but it is an issue if + # we are trying to land it. + if check_base and entry.base != d["baseRefName"]: + error(ERROR_STACKINFO_PR_BASE_MISMATCH.format(e=entry, d=d)) + raise RuntimeError - # 'Base' branch might diverge when the stack is modified (e.g. when a - # new commit is added to the middle of the stack). It is not an issue - # if we're updating the stack (i.e. in 'submit'), but it is an issue if - # we are trying to land it. - if check_base and e.base != d["baseRefName"]: - error(ERROR_STACKINFO_PR_BASE_MISMATCH.format(**locals())) - raise RuntimeError + # The first entry on the stack needs to be actually mergeable on GitHub. + if ( + check_base + and index == 0 + and d["mergeStateStatus"] not in ["CLEAN", "UNKNOWN", "UNSTABLE"] + ): + error(ERROR_STACKINFO_PR_NOT_MERGEABLE.format(e=entry, d=d)) + raise RuntimeError - # The first entry on the stack needs to be actually mergeable on GitHub. - if ( - check_base - and index == 0 - and d["mergeStateStatus"] not in ["CLEAN", "UNKNOWN", "UNSTABLE"] - ): - error(ERROR_STACKINFO_PR_NOT_MERGEABLE.format(**locals())) - raise RuntimeError + except subprocess.CalledProcessError as exc: + error(f"Failed to get PR information from GitHub: {exc}") + raise RuntimeError("GitHub API request failed") from exc def print_stack(st: list[StackEntry], *, links: bool, level: int = 1) -> None: @@ -603,10 +640,39 @@ def get_taken_branch_ids(refs: list[str], branch_name_template: str) -> list[int def generate_available_branch_name(refs: list[str], branch_name_template: str) -> str: + """Generate an available branch name that doesn't conflict with existing branches. + + This function handles potential race conditions by using an ID higher than + the current maximum. + + Args: + refs: List of existing branch references + branch_name_template: Template for the branch name + + Returns: + A branch name that doesn't conflict with existing branches + """ + max_attempts = 100 branch_ids = get_taken_branch_ids(refs, branch_name_template) max_ref_num = max(branch_ids) if branch_ids else 0 new_branch_id = max_ref_num + 1 - return generate_branch_name(branch_name_template, new_branch_id) + + # Safety check: verify the new branch name doesn't already exist + new_branch_name = generate_branch_name(branch_name_template, new_branch_id) + attempts = 0 + while any( + ref.endswith(f"/{new_branch_name}") or ref == new_branch_name for ref in refs + ): + # Increment and try again if there's a conflict + new_branch_id += 1 + new_branch_name = generate_branch_name(branch_name_template, new_branch_id) + attempts += 1 + if attempts > max_attempts: # Prevent infinite loops + raise RuntimeError( + "Unable to generate a unique branch name after 100 attempts" + ) + + return new_branch_name def get_available_branch_name(remote: str, branch_name_template: str) -> str: @@ -955,7 +1021,7 @@ def command_submit( return if (draft_bitmask is not None) and (len(draft_bitmask) != len(st)): - log(h("Draft bitmask passed to 'submit' doesn't match number of PRs!")) + error("Draft bitmask passed to 'submit' doesn't match number of PRs!") return # Create local branches and initialize base and head fields in the stack @@ -1121,6 +1187,19 @@ def delete_remote_branches( cmd.extend([f":{branch}" for branch in remote_branches_to_delete]) run_shell_command(cmd, check=False, quiet=not verbose) + # Close associated PRs as mentioned in the docstring + for e in st: + if e.has_pr(): + try: + run_shell_command( + ["gh", "pr", "close", e.pr, "--delete-branch=false"], + check=False, + quiet=not verbose, + ) + log(f"Closed PR {e.pr}", level=1) + except Exception as exc: # noqa: BLE001 + log(f"Failed to close PR {e.pr}: {exc}", level=1) + # ===----------------------------------------------------------------------=== # # Entry point for 'land' command @@ -1467,7 +1546,7 @@ def load_config(config_file: str) -> configparser.ConfigParser: return config -def main() -> None: # noqa: PLR0912 +def main() -> None: # noqa: PLR0912, PLR0915 config_file = os.getenv("STACKPR_CONFIG", ".stack-pr.cfg") config = load_config(config_file) @@ -1490,9 +1569,17 @@ def main() -> None: # noqa: PLR0912 current_branch = get_current_branch_name() get_branch_name_base(common_args.branch_name_template) + stashed = False try: if args.command in ["submit", "export"] and args.stash: - run_shell_command(["git", "stash", "save"], quiet=not common_args.verbose) + # Check if there's anything to stash first + if not is_repo_clean(): + run_shell_command( + ["git", "stash", "save"], quiet=not common_args.verbose + ) + stashed = True + else: + log("No changes to stash", level=1) if args.command != "view" and not is_repo_clean(): error(ERROR_REPO_DIRTY) @@ -1518,15 +1605,27 @@ def main() -> None: # noqa: PLR0912 return except Exception as exc: # If something failed, checkout the original branch - run_shell_command( - ["git", "checkout", current_branch], quiet=not common_args.verbose - ) + try: + run_shell_command( + ["git", "checkout", current_branch], quiet=not common_args.verbose + ) + except Exception as checkout_error: # noqa: BLE001 + error(f"Failed to checkout original branch: {checkout_error}") if isinstance(exc, SubprocessError): print_cmd_failure_details(exc) raise finally: - if args.command in ["submit", "export"] and args.stash: - run_shell_command(["git", "stash", "pop"], quiet=not common_args.verbose) + # Only try to pop the stash if we actually stashed something + if stashed and args.command in ["submit", "export"]: + try: + run_shell_command( + ["git", "stash", "pop"], quiet=not common_args.verbose + ) + except Exception as stash_error: # noqa: BLE001 + error(f"Failed to pop stashed changes: {stash_error}") + error( + "Your changes are still in the stash. Run 'git stash pop' to retrieve them." + ) if __name__ == "__main__": diff --git a/src/stack_pr/git.py b/src/stack_pr/git.py index 3cb6ce3..c1c3eb7 100644 --- a/src/stack_pr/git.py +++ b/src/stack_pr/git.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re import string import subprocess from collections.abc import Sequence @@ -132,20 +131,26 @@ def get_uncommitted_changes( return changes -# TODO: enforce this as a module dependency def check_gh_installed() -> None: - """Check if the gh tool is installed. + """Check if the gh tool is installed and authenticated. Raises: - GitError if gh is not available. + GitError: If gh is not available or not authenticated. """ - try: - run_shell_command(["gh"], capture_output=True, quiet=False) + # Check if gh is installed + run_shell_command(["gh", "--version"], capture_output=True, quiet=False) + + # Check if gh is authenticated + auth_status = get_command_output(["gh", "auth", "status"], check=False) + if "You are not logged into any GitHub hosts" in auth_status: + raise GitError( + "'gh' is not authenticated. Please run 'gh auth login' to authenticate." + ) except subprocess.CalledProcessError as err: raise GitError( - "'gh' is not installed. Please visit https://cli.github.com/ for" - " installation instuctions." + "'gh' is not installed or not accessible. Please visit https://cli.github.com/ for" + " installation instructions." ) from err @@ -175,12 +180,17 @@ def get_gh_username() -> str: ] ) - # Extract the login name. - m = re.search(r"\"login\":\"(.*?)\"", user_query) - if not m: - raise GitError("Unable to find current github user name") + # Parse JSON response properly + import json - return m.group(1) + try: + response = json.loads(user_query) + login = response.get("data", {}).get("viewer", {}).get("login") + if not login: + raise GitError("Unable to find current github user name") + return str(login) # Ensure we return a string + except json.JSONDecodeError as e: + raise GitError("Invalid response from GitHub API") from e def get_changed_files( diff --git a/src/stack_pr/shell_commands.py b/src/stack_pr/shell_commands.py index ac12e30..2f7b629 100644 --- a/src/stack_pr/shell_commands.py +++ b/src/stack_pr/shell_commands.py @@ -5,18 +5,17 @@ from collections.abc import Iterable from logging import getLogger from pathlib import Path +from typing import Any, Union -if sys.version_info >= (3, 13): - # Unpack moved to typing - from typing import Any, Union -else: - from typing import Union - - from typing_extensions import Any +# For Python versions that don't have typing.Unpack yet (pre-3.13), +# we import from typing_extensions instead +if sys.version_info < (3, 13): + from typing_extensions import Unpack # noqa: F401 logger = getLogger(__name__) +# Define type for shell commands, using Iterable for improved compatibility ShellCommand = Iterable[Union[str, Path]]