Skip to content

Edge case clean up #80

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 160 additions & 61 deletions src/stack_pr/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import logging
import os
import re
import subprocess
import sys
from dataclasses import dataclass
from functools import cache
Expand Down Expand Up @@ -399,19 +400,40 @@ def last(ref: str, sep: str = "/") -> str:
return ref.rsplit(sep, 1)[-1]


# TODO: Move to 'modular.utils.git'
class GitAncestryError(RuntimeError):
"""Error raised when git ancestry check fails."""


def is_ancestor(commit1: str, commit2: str, *, verbose: bool) -> bool:
"""
Returns true if 'commit1' is an ancestor of 'commit2'.

Raises:
GitAncestryError: If git command fails for reasons other than ancestry check.
"""
# TODO: We need to check returncode of this command more carefully, as the
# command simply might fail (rc != 0 and rc != 1).
p = run_shell_command(
["git", "merge-base", "--is-ancestor", commit1, commit2],
check=False,
quiet=not verbose,
)
return p.returncode == 0
error_code = None

try:
p = run_shell_command(
["git", "merge-base", "--is-ancestor", commit1, commit2],
check=False,
quiet=not verbose,
)
if p.returncode == 0:
return True
if p.returncode == 1:
return False

# Store error code for later
error_code = p.returncode
except subprocess.SubprocessError as e:
raise GitAncestryError(f"Failed to determine ancestry relationship: {e}") from e

# Handle error code outside the try block
if error_code is None:
# This should never happen, but just in case
raise GitAncestryError("Unexpected error in git ancestry check")
raise GitAncestryError(f"Git ancestry check failed with code {error_code}")


def is_repo_clean() -> bool:
Expand Down Expand Up @@ -456,59 +478,74 @@ def set_base_branches(st: list[StackEntry], target: str) -> None:

def verify(st: list[StackEntry], *, check_base: bool = False) -> None:
log(h("Verifying stack info"))
for index, e in enumerate(st):
if e.has_missing_info():
error(ERROR_STACKINFO_MISSING.format(**locals()))
for index, entry in enumerate(st):
if entry.has_missing_info():
error(ERROR_STACKINFO_MISSING.format(e=entry))
raise RuntimeError

if len(e.pr.split("/")) == 0 or not last(e.pr).isnumeric():
error(ERROR_STACKINFO_BAD_LINK.format(**locals()))
if len(entry.pr.split("/")) == 0 or not last(entry.pr).isnumeric():
error(ERROR_STACKINFO_BAD_LINK.format(e=entry))
raise RuntimeError

ghinfo = get_command_output(
[
"gh",
"pr",
"view",
e.pr,
"--json",
"baseRefName,headRefName,number,state,body,title,url,mergeStateStatus",
]
)
d = json.loads(ghinfo)
for required_field in ["state", "number", "baseRefName", "headRefName"]:
if required_field not in d:
error(ERROR_STACKINFO_MALFORMED_RESPONSE.format(**locals()))
try:
ghinfo = get_command_output(
[
"gh",
"pr",
"view",
entry.pr,
"--json",
"baseRefName,headRefName,number,state,body,title,url,mergeStateStatus",
]
)

try:
d = json.loads(ghinfo)
except json.JSONDecodeError as e:
error(f"Failed to parse JSON response from GitHub: {ghinfo}")
raise RuntimeError("Invalid JSON response from GitHub") from e

for required_field in ["state", "number", "baseRefName", "headRefName"]:
if required_field not in d:
error(
ERROR_STACKINFO_MALFORMED_RESPONSE.format(
e=entry, required_field=required_field, d=d
)
)
raise RuntimeError

if d["state"] != "OPEN":
error(ERROR_STACKINFO_PR_NOT_OPEN.format(e=entry, d=d))
raise RuntimeError

if d["state"] != "OPEN":
error(ERROR_STACKINFO_PR_NOT_OPEN.format(**locals()))
raise RuntimeError
if int(last(entry.pr)) != d["number"]:
error(ERROR_STACKINFO_PR_NUMBER_MISMATCH.format(e=entry, d=d))
raise RuntimeError

if int(last(e.pr)) != d["number"]:
error(ERROR_STACKINFO_PR_NUMBER_MISMATCH.format(**locals()))
raise RuntimeError
if entry.head != d["headRefName"]:
error(ERROR_STACKINFO_PR_HEAD_MISMATCH.format(e=entry, d=d))
raise RuntimeError

if e.head != d["headRefName"]:
error(ERROR_STACKINFO_PR_HEAD_MISMATCH.format(**locals()))
raise RuntimeError
# 'Base' branch might diverge when the stack is modified (e.g. when a
# new commit is added to the middle of the stack). It is not an issue
# if we're updating the stack (i.e. in 'submit'), but it is an issue if
# we are trying to land it.
if check_base and entry.base != d["baseRefName"]:
error(ERROR_STACKINFO_PR_BASE_MISMATCH.format(e=entry, d=d))
raise RuntimeError

# 'Base' branch might diverge when the stack is modified (e.g. when a
# new commit is added to the middle of the stack). It is not an issue
# if we're updating the stack (i.e. in 'submit'), but it is an issue if
# we are trying to land it.
if check_base and e.base != d["baseRefName"]:
error(ERROR_STACKINFO_PR_BASE_MISMATCH.format(**locals()))
raise RuntimeError
# The first entry on the stack needs to be actually mergeable on GitHub.
if (
check_base
and index == 0
and d["mergeStateStatus"] not in ["CLEAN", "UNKNOWN", "UNSTABLE"]
):
error(ERROR_STACKINFO_PR_NOT_MERGEABLE.format(e=entry, d=d))
raise RuntimeError

# The first entry on the stack needs to be actually mergeable on GitHub.
if (
check_base
and index == 0
and d["mergeStateStatus"] not in ["CLEAN", "UNKNOWN", "UNSTABLE"]
):
error(ERROR_STACKINFO_PR_NOT_MERGEABLE.format(**locals()))
raise RuntimeError
except subprocess.CalledProcessError as exc:
error(f"Failed to get PR information from GitHub: {exc}")
raise RuntimeError("GitHub API request failed") from exc


def print_stack(st: list[StackEntry], *, links: bool, level: int = 1) -> None:
Expand Down Expand Up @@ -603,10 +640,39 @@ def get_taken_branch_ids(refs: list[str], branch_name_template: str) -> list[int


def generate_available_branch_name(refs: list[str], branch_name_template: str) -> str:
"""Generate an available branch name that doesn't conflict with existing branches.

This function handles potential race conditions by using an ID higher than
the current maximum.

Args:
refs: List of existing branch references
branch_name_template: Template for the branch name

Returns:
A branch name that doesn't conflict with existing branches
"""
max_attempts = 100
branch_ids = get_taken_branch_ids(refs, branch_name_template)
max_ref_num = max(branch_ids) if branch_ids else 0
new_branch_id = max_ref_num + 1
return generate_branch_name(branch_name_template, new_branch_id)

# Safety check: verify the new branch name doesn't already exist
new_branch_name = generate_branch_name(branch_name_template, new_branch_id)
attempts = 0
while any(
ref.endswith(f"/{new_branch_name}") or ref == new_branch_name for ref in refs
):
# Increment and try again if there's a conflict
new_branch_id += 1
new_branch_name = generate_branch_name(branch_name_template, new_branch_id)
attempts += 1
if attempts > max_attempts: # Prevent infinite loops
raise RuntimeError(
"Unable to generate a unique branch name after 100 attempts"
)

return new_branch_name


def get_available_branch_name(remote: str, branch_name_template: str) -> str:
Expand Down Expand Up @@ -955,7 +1021,7 @@ def command_submit(
return

if (draft_bitmask is not None) and (len(draft_bitmask) != len(st)):
log(h("Draft bitmask passed to 'submit' doesn't match number of PRs!"))
error("Draft bitmask passed to 'submit' doesn't match number of PRs!")
return

# Create local branches and initialize base and head fields in the stack
Expand Down Expand Up @@ -1121,6 +1187,19 @@ def delete_remote_branches(
cmd.extend([f":{branch}" for branch in remote_branches_to_delete])
run_shell_command(cmd, check=False, quiet=not verbose)

# Close associated PRs as mentioned in the docstring
for e in st:
if e.has_pr():
try:
run_shell_command(
["gh", "pr", "close", e.pr, "--delete-branch=false"],
check=False,
quiet=not verbose,
)
log(f"Closed PR {e.pr}", level=1)
except Exception as exc: # noqa: BLE001
log(f"Failed to close PR {e.pr}: {exc}", level=1)


# ===----------------------------------------------------------------------=== #
# Entry point for 'land' command
Expand Down Expand Up @@ -1467,7 +1546,7 @@ def load_config(config_file: str) -> configparser.ConfigParser:
return config


def main() -> None: # noqa: PLR0912
def main() -> None: # noqa: PLR0912, PLR0915
config_file = os.getenv("STACKPR_CONFIG", ".stack-pr.cfg")
config = load_config(config_file)

Expand All @@ -1490,9 +1569,17 @@ def main() -> None: # noqa: PLR0912

current_branch = get_current_branch_name()
get_branch_name_base(common_args.branch_name_template)
stashed = False
try:
if args.command in ["submit", "export"] and args.stash:
run_shell_command(["git", "stash", "save"], quiet=not common_args.verbose)
# Check if there's anything to stash first
if not is_repo_clean():
run_shell_command(
["git", "stash", "save"], quiet=not common_args.verbose
)
stashed = True
else:
log("No changes to stash", level=1)

if args.command != "view" and not is_repo_clean():
error(ERROR_REPO_DIRTY)
Expand All @@ -1518,15 +1605,27 @@ def main() -> None: # noqa: PLR0912
return
except Exception as exc:
# If something failed, checkout the original branch
run_shell_command(
["git", "checkout", current_branch], quiet=not common_args.verbose
)
try:
run_shell_command(
["git", "checkout", current_branch], quiet=not common_args.verbose
)
except Exception as checkout_error: # noqa: BLE001
error(f"Failed to checkout original branch: {checkout_error}")
if isinstance(exc, SubprocessError):
print_cmd_failure_details(exc)
raise
finally:
if args.command in ["submit", "export"] and args.stash:
run_shell_command(["git", "stash", "pop"], quiet=not common_args.verbose)
# Only try to pop the stash if we actually stashed something
if stashed and args.command in ["submit", "export"]:
try:
run_shell_command(
["git", "stash", "pop"], quiet=not common_args.verbose
)
except Exception as stash_error: # noqa: BLE001
error(f"Failed to pop stashed changes: {stash_error}")
error(
"Your changes are still in the stash. Run 'git stash pop' to retrieve them."
)


if __name__ == "__main__":
Expand Down
36 changes: 23 additions & 13 deletions src/stack_pr/git.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import re
import string
import subprocess
from collections.abc import Sequence
Expand Down Expand Up @@ -132,20 +131,26 @@ def get_uncommitted_changes(
return changes


# TODO: enforce this as a module dependency
def check_gh_installed() -> None:
"""Check if the gh tool is installed.
"""Check if the gh tool is installed and authenticated.

Raises:
GitError if gh is not available.
GitError: If gh is not available or not authenticated.
"""

try:
run_shell_command(["gh"], capture_output=True, quiet=False)
# Check if gh is installed
run_shell_command(["gh", "--version"], capture_output=True, quiet=False)

# Check if gh is authenticated
auth_status = get_command_output(["gh", "auth", "status"], check=False)
if "You are not logged into any GitHub hosts" in auth_status:
raise GitError(
"'gh' is not authenticated. Please run 'gh auth login' to authenticate."
)
except subprocess.CalledProcessError as err:
raise GitError(
"'gh' is not installed. Please visit https://cli.github.com/ for"
" installation instuctions."
"'gh' is not installed or not accessible. Please visit https://cli.github.com/ for"
" installation instructions."
) from err


Expand Down Expand Up @@ -175,12 +180,17 @@ def get_gh_username() -> str:
]
)

# Extract the login name.
m = re.search(r"\"login\":\"(.*?)\"", user_query)
if not m:
raise GitError("Unable to find current github user name")
# Parse JSON response properly
import json

return m.group(1)
try:
response = json.loads(user_query)
login = response.get("data", {}).get("viewer", {}).get("login")
if not login:
raise GitError("Unable to find current github user name")
return str(login) # Ensure we return a string
except json.JSONDecodeError as e:
raise GitError("Invalid response from GitHub API") from e


def get_changed_files(
Expand Down
13 changes: 6 additions & 7 deletions src/stack_pr/shell_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@
from collections.abc import Iterable
from logging import getLogger
from pathlib import Path
from typing import Any, Union

if sys.version_info >= (3, 13):
# Unpack moved to typing
from typing import Any, Union
else:
from typing import Union

from typing_extensions import Any
# For Python versions that don't have typing.Unpack yet (pre-3.13),
# we import from typing_extensions instead
if sys.version_info < (3, 13):
from typing_extensions import Unpack # noqa: F401


logger = getLogger(__name__)

# Define type for shell commands, using Iterable for improved compatibility
ShellCommand = Iterable[Union[str, Path]]


Expand Down