Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow adding labels when a backport fails #40

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions patchback/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
Misc. utility functions
"""

from __future__ import annotations

from typing import TypeVar

_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")


def strip_nones(mapping: dict[_T1, _T2 | None]) -> dict[_T1, _T2]:
"""
Remove keys set to None from a dictionary

Returns:
A new dictionary instance
"""
return {key: value for key, value in mapping.items() if value is not None}
19 changes: 19 additions & 0 deletions patchback/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DEFAULT_BACKPORT_BRANCH_PREFIX = 'patchback/backports/'
DEFAULT_BACKPORT_LABEL_PREFIX = 'backport-'
DEFAULT_TARGET_BRANCH_PREFIX = ''
DEFAULT_FAILED_BACKPORT_LABEL_PREFIX = 'failed-backport-'


@attr.dataclass
Expand All @@ -30,6 +31,24 @@ class PatchbackConfig:
)
"""Prefix that the older/stable version branch has."""

failed_label_prefix: str | None = attr.ib(default=DEFAULT_FAILED_BACKPORT_LABEL_PREFIX)
gotmax23 marked this conversation as resolved.
Show resolved Hide resolved
"""
Add {failed_label_prefix}-{target_branch} label when backport fails.
{target_branch_prefix} is stripped from {target_branch}.
Set to None to disable adding a label on failure.
"""

@failed_label_prefix.validator
def _v_failed_label_prefix(self, _, value: str | None) -> None:
"""
Ensure backport_label_prefix and failed_label_prefix are different
to avoid infinite loops
"""
if value == self.backport_label_prefix:
raise ValueError(
'failed_label_prefix and backport_label_prefix must be unique values'
)


async def get_patchback_config(
*,
Expand Down
58 changes: 58 additions & 0 deletions patchback/event_handlers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Webhook event handlers."""

from __future__ import annotations

import http
import functools
import logging
import pathlib
import tempfile
from subprocess import CalledProcessError, check_output, check_call
from typing import Any

from anyio import run_in_thread
from gidgethub import BadRequest, ValidationError
Expand All @@ -18,6 +22,7 @@
from .locking_api import LockingAPI
from .config import get_patchback_config
from .github_reporter import PullRequestReporter
from .labels_api import IssueLabelsAPI, RepoLabelsAPI


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -282,6 +287,9 @@ async def on_merge_of_labeled_pr(
repository['pulls_url'],
repository['full_name'],
repository['clone_url'],
repo_config.backport_label_prefix,
repo_config.target_branch_prefix,
repo_config.failed_label_prefix,
)


Expand Down Expand Up @@ -332,6 +340,9 @@ async def on_label_added_to_merged_pr(
repository['pulls_url'],
repository['full_name'],
repository['clone_url'],
repo_config.backport_label_prefix,
repo_config.target_branch_prefix,
repo_config.failed_label_prefix,
)


Expand All @@ -348,6 +359,9 @@ async def process_pr_backport_labels(
backport_branch_prefix,
pr_api_url, repo_slug,
git_url,
backport_label_prefix: str,
target_branch_prefix: str,
failed_label_prefix: str | None,
) -> None:
gh_api = RUNTIME_CONTEXT.app_installation_client
checks_api = ChecksAPI(
Expand All @@ -366,6 +380,17 @@ async def process_pr_backport_labels(
locking_api=locking_api,
branch_name=target_branch,
)
labels_api = RepoLabelsAPI(api=gh_api, repo_slug=repo_slug)
issue_labels_api = IssueLabelsAPI(api=gh_api, repo_slug=repo_slug, number=pr_number)
failed_label_cb = functools.partial(
add_failure_label,
labels_api=labels_api,
issue_labels_api=issue_labels_api,
backport_label_prefix=backport_label_prefix,
failed_label_prefix=failed_label_prefix,
target_branch_prefix=target_branch_prefix,
target_branch=target_branch,
)

await pr_reporter.start_reporting(pr_head_sha, pr_number, pr_merge_commit)

Expand Down Expand Up @@ -396,6 +421,7 @@ async def process_pr_backport_labels(
subtitle='💔 cherry-picking failed — target branch does not exist',
summary=f'❌ {lu_err!s}',
)
await failed_label_cb()
return
except ValueError as val_err:
logger.info(
Expand All @@ -409,6 +435,7 @@ async def process_pr_backport_labels(
text=manual_backport_guide,
summary=f'❌ {val_err!s}',
)
await failed_label_cb()
return
except PermissionError as perm_err:
logger.info(
Expand All @@ -423,6 +450,7 @@ async def process_pr_backport_labels(
text=manual_backport_guide,
summary=f'❌ {perm_err!s}',
)
await failed_label_cb()
return
else:
logger.info('Backport PR branch: `%s`', backport_pr_branch)
Expand Down Expand Up @@ -461,6 +489,7 @@ async def process_pr_backport_labels(
text=manual_backport_guide,
summary=f'❌ {backport_pr_branch_msg}\n\n{val_err!s}',
)
await failed_label_cb()
return
except BadRequest as bad_req_err:
if (
Expand All @@ -480,6 +509,7 @@ async def process_pr_backport_labels(
text=manual_backport_guide,
summary=f'❌ {backport_pr_branch_msg}\n\n{bad_req_err!s}',
)
await failed_label_cb()
return
else:
logger.info('Created a PR @ %s', pr_resp['html_url'])
Expand All @@ -490,3 +520,31 @@ async def process_pr_backport_labels(
text=f'Backported as {pr_resp["html_url"]}',
summary=f'✅ {backport_pr_branch_msg!s}',
)

async def add_failure_label(
*,
labels_api: RepoLabelsAPI,
issue_labels_api: IssueLabelsAPI,
backport_label_prefix: str,
failed_label_prefix: str | None,
target_branch_prefix: str,
target_branch: str,
) -> dict[str, Any] | None:
if failed_label_prefix is None:
return None
stripped_branch = target_branch[len(target_branch_prefix):]
label: str = failed_label_prefix + stripped_branch
backport_label = backport_label_prefix + stripped_branch
# Create label if it doesn't exist
try:
await labels_api.get_label(label)
except BadRequest as exc:
if exc.status_code != 404:
raise
await labels_api.create_label(
label, f'Failed to backport PR to {target_branch}.'
)
# Add failed label
await issue_labels_api.add_labels(label)
# Delete backport label
await issue_labels_api.remove_label(backport_label)
Comment on lines +549 to +550
Copy link
Author

@gotmax23 gotmax23 Oct 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm debating whether it makes sense to also drop the existing backport label on failures. Wdyt?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A-la #12? That'd need to be a separate PR.

52 changes: 52 additions & 0 deletions patchback/labels_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
Wrappers around repository and issue label APIs
"""

from __future__ import annotations

from collections.abc import AsyncIterator
from typing import Any

from gidgethub.abc import GitHubAPI

from patchback._utils import strip_nones


class RepoLabelsAPI:
def __init__(self, *, api: GitHubAPI, repo_slug: str) -> None:
self._api: GitHubAPI = api
self._labels_api = f"/repos/{repo_slug}/labels"

async def create_label(
self, name: str, description: str | None = None, color: str | None = None
) -> dict[str, Any]:
return await self._api.post(
self._labels_api,
data=strip_nones(
{"name": name, "description": description, "color": color}
),
)

def list_labels(self) -> AsyncIterator[dict[str, Any]]:
return self._api.getiter(self._labels_api)

async def get_label(self, name: str) -> dict[str, Any]:
return await self._api.getitem(f"{self._labels_api}/{name}")


class IssueLabelsAPI:
def __init__(self, *, api: GitHubAPI, repo_slug: str, number: int) -> None:
self._api: GitHubAPI = api
self._issue_labels_api = f"/repos/{repo_slug}/issues/{number}/labels"

async def add_labels(self, *labels: str) -> list[dict[str, Any]]:
return await self._api.post(
self._issue_labels_api,
data={"labels": labels},
)

def list_labels(self) -> AsyncIterator[dict[str, Any]]:
return self._api.getiter(self._issue_labels_api)

async def remove_label(self, label: str) -> None:
return await self._api.delete(f"{self._issue_labels_api}/{label}")