| # Copyright 2023 The Cobalt Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
"""Updates the required status checks for one or more branches.

Requires PyGithub to run:

  $ pip install PyGithub
"""
| |
import argparse
import os
from typing import Iterable, List, Optional

from github import Github
| |
# Issue a Personal Access Token with 'repo' permission on
# https://github.com/settings/tokens, then either paste it below or export it
# as the GITHUB_TOKEN environment variable. The environment variable is only
# read when the constant is left empty; using it keeps the token out of source
# control.
YOUR_GITHUB_TOKEN = ''
if not YOUR_GITHUB_TOKEN:
  YOUR_GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN', '')
# An explicit check rather than `assert`: asserts are stripped under
# `python -O`, which would let the script continue without a token.
if not YOUR_GITHUB_TOKEN:
  raise SystemExit('YOUR_GITHUB_TOKEN must be set (or export GITHUB_TOKEN).')
| |
# GitHub repository ("owner/name") whose branch protection is updated.
TARGET_REPO = 'youtube/cobalt'

# A check run whose name contains any of these substrings is never made a
# required status check.
EXCLUDED_CHECK_PATTERNS = (
    # Non build/test checks.
    ['feedback/copybara', 'prepare_branch_list', 'cherry_pick',
     'assign-reviewer'] +
    # Coverage and test reports.
    ['linux-coverage', 'codecov', 'on-host-unit-test-report'] +
    # Blackbox, web platform, and unit tests run on-device.
    ['_on_device_'] +
    # Slow and flaky evergreen tests.
    ['evergreen-as-blackbox_test', 'evergreen_test'] +
    # Templated check names.
    ['${{']
)

# Inclusive range of LTS release numbers whose branches are handled when no
# branch is given explicitly. Starting at 19 excludes the rc_11 and COBALT_9
# releases.
MINIMUM_LTS_RELEASE_NUMBER = 19
LATEST_LTS_RELEASE_NUMBER = 24
| |
| |
def get_protected_branches() -> List[str]:
  """Lists the branches handled when no --branch flag is given.

  The result is 'main' followed by one '<N>.lts.1+' branch for every N from
  MINIMUM_LTS_RELEASE_NUMBER through LATEST_LTS_RELEASE_NUMBER, inclusive.
  """
  lts_numbers = range(MINIMUM_LTS_RELEASE_NUMBER,
                      LATEST_LTS_RELEASE_NUMBER + 1)
  return ['main'] + [f'{number}.lts.1+' for number in lts_numbers]
| |
| |
def initialize_repo_connection():
  """Authenticates with YOUR_GITHUB_TOKEN and returns the TARGET_REPO repo."""
  return Github(YOUR_GITHUB_TOKEN).get_repo(TARGET_REPO)
| |
| |
def get_checks_for_branch(repo, branch: str) -> Iterable:
  """Returns the check runs of the most recently merged PR into `branch`.

  Args:
    repo: Repository object, as returned by initialize_repo_connection().
    branch: Name of the base branch whose merged PRs are inspected.

  Returns:
    The check runs reported for the head commit of the latest merged PR, as
    returned by that commit's get_check_runs().

  Raises:
    ValueError: If no merged PR targeting `branch` is found.
  """
  # The 'merged' sort order is not listed in public docs but still works.
  # If this functionality is removed the alternative is to loop through all
  # PRs and use the 'merged_at' property to determine which is the latest one.
  # https://docs.github.com/en/rest/pulls/pulls#list-pull-requests
  prs = repo.get_pulls(
      state='closed', sort='merged', base=branch, direction='desc')

  # Closed PRs also include ones closed without merging; skip those.
  latest_pr = next((pr for pr in prs if pr.merged), None)
  if latest_pr is None:
    # Fail with a descriptive error instead of an AttributeError on None.
    raise ValueError(f'No merged pull request found for branch {branch!r}.')

  latest_pr_commit = repo.get_commit(latest_pr.head.sha)
  return latest_pr_commit.get_check_runs()
| |
| |
def should_include_run(check_run) -> bool:
  """Tells whether `check_run` may become a required status check.

  A run is rejected as soon as its name contains any substring listed in
  EXCLUDED_CHECK_PATTERNS.
  """
  return not any(
      pattern in check_run.name for pattern in EXCLUDED_CHECK_PATTERNS)
| |
| |
def get_required_checks_for_branch(repo, branch: str) -> List[str]:
  """Returns the check names that should be required on `branch`.

  Names come from the check runs of the latest merged PR into `branch`
  (see get_checks_for_branch), minus those rejected by should_include_run.

  Args:
    repo: Repository object, as returned by initialize_repo_connection().
    branch: Name of the branch to compute required checks for.

  Returns:
    Sorted list of unique check run names.
  """
  checks = get_checks_for_branch(repo, branch)
  # Several runs may share a name, so collect names into a set.
  names = {run.name for run in checks if should_include_run(run)}
  # Sort so the result is reproducible: set iteration order for strings can
  # differ between interpreter runs (hash randomization).
  return sorted(names)
| |
| |
def print_checks(repo, branch_name: str, new_checks: List[str],
                 print_unchanged: bool) -> None:
  """Prints how the required checks of `branch_name` would change.

  Compares the branch's currently required status check contexts with
  `new_checks`. Each printed section is a header line, the sorted check
  names, and a blank line:
    * checks to be added (only when there are any),
    * checks to be removed (only when there are any),
    * checks that stay (only when `print_unchanged` is set; printed even if
      the section is empty).

  Args:
    repo: Repository object to read the branch from.
    branch_name: Branch whose required checks are compared.
    new_checks: Proposed required check names.
    print_unchanged: Whether to also list checks present on both sides.
  """
  protection = repo.get_branch(branch_name).get_required_status_checks()
  current = set(protection.contexts)
  proposed = set(new_checks)

  added = proposed - current
  removed = current - proposed
  # (header, names, whether to print the section)
  sections = (
      (f'Required checks to be ADDED for {branch_name}:', added, bool(added)),
      (f'Required checks to be REMOVED for {branch_name}:', removed,
       bool(removed)),
      (f'Required checks that will REMAIN for {branch_name}:',
       current & proposed, print_unchanged),
  )
  for header, names, visible in sections:
    if not visible:
      continue
    print(header)
    for name in sorted(names):
      print(name)
    print()
| |
| |
def update_protection_for_branch(repo, branch: str,
                                 check_names: List[str]) -> None:
  """Replaces the required status check contexts of `branch`.

  Args:
    repo: Repository object holding the branch.
    branch: Name of the branch to update.
    check_names: Check names to set as the branch's required contexts.
  """
  repo.get_branch(branch).edit_required_status_checks(contexts=check_names)
| |
| |
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
  """Parses the command line.

  Args:
    argv: Arguments to parse, excluding the program name. When None,
      argparse reads them from sys.argv.

  Returns:
    The parsed arguments. When no --branch flag is given, `branch` falls
    back to get_protected_branches().
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '-b',
      '--branch',
      action='append',
      help='Branch to update. Can be repeated to update multiple branches.'
      ' Defaults to all protected branches.')
  parser.add_argument(
      '--apply', action='store_true', help='Apply required checks updates.')
  parser.add_argument(
      '--print_unchanged',
      action='store_true',
      help='Also print the checks that will be left unchanged.'
      ' Is a no-op with --apply.')
  args = parser.parse_args(argv)

  # `action='append'` leaves `branch` as None when the flag is never given.
  if not args.branch:
    args.branch = get_protected_branches()

  return args
| |
| |
def main() -> None:
  """Computes required checks per branch, then previews or applies them.

  Without --apply, prints the pending changes for every selected branch.
  With --apply, writes the computed checks to each branch instead.
  """
  args = parse_args()
  repo = initialize_repo_connection()
  dry_run = not args.apply

  if dry_run:
    print('This is a dry-run, printing pending changes only.')

  for branch_name in args.branch:
    checks = get_required_checks_for_branch(repo, branch_name)
    if dry_run:
      print_checks(repo, branch_name, checks, args.print_unchanged)
    else:
      update_protection_for_branch(repo, branch_name, checks)

  if dry_run:
    print('Re-run with --apply to apply the changes.')
| |
| |
if __name__ == '__main__':
  # Run only when executed as a script; importing the module does not call
  # main().
  main()