Skip to content

Commit

Permalink
fix branch protection fetch
Browse files Browse the repository at this point in the history
  • Loading branch information
chdsbd committed Jun 20, 2024
1 parent 0065e64 commit 20f3809
Showing 1 changed file with 64 additions and 46 deletions.
110 changes: 64 additions & 46 deletions bot/kodiak/queries/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,60 @@ class GraphQLResponse(TypedDict, total=False):
query (
$owner: String!,
$repo: String!,
$baseRef: String!,
$rootConfigFileExpression: String!,
$githubConfigFileExpression: String!,
$orgRootConfigFileExpression: String,
$orgGithubConfigFileExpression: String
) {
repository(owner: $owner, name: $repo) {
refs(refPrefix: "refs/heads/", query: $baseRef, first: 1) {
edges {
node {
name
branchProtectionRule {
requiresStatusChecks
requiredStatusCheckContexts
requiresCodeOwnerReviews
requiresStrictStatusChecks
requiresLinearHistory
requiresConversationResolution
requiresCommitSignatures
restrictsPushes
requiredApprovingReviewCount
requireLastPushApproval
requiredStatusChecks {
app {
slug
}
context
}
pushAllowances(first: 100) {
nodes {
actor {
... on Team {
name
}
... on Actor {
login
}
... on User {
login
}
... on App {
databaseId
}
}
}
}
}
}
}
}
rootConfigFile: object(expression: $rootConfigFileExpression) {
... on Blob {
text
Expand Down Expand Up @@ -156,30 +204,6 @@ def get_event_info_query(
return """
query GetEventInfo($owner: String!, $repo: String!, $PRNumber: Int!) {
repository(owner: $owner, name: $repo) {
branchProtectionRules(first: 100) {
nodes {
matchingRefs(first: 100) {
nodes {
name
}
}
requiresStatusChecks
requiredStatusCheckContexts
requiresStrictStatusChecks
requiresCommitSignatures
%(requiresConversationResolution)s
restrictsPushes
pushAllowances(first: 100) {
nodes {
actor {
... on App {
databaseId
}
}
}
}
}
}
mergeCommitAllowed
rebaseMergeAllowed
squashMergeAllowed
Expand Down Expand Up @@ -606,29 +630,20 @@ def get_sha(*, pr: Dict[str, Any]) -> Optional[str]:
return None


def get_branch_protection_dicts(*, repo: Dict[str, Any]) -> List[Dict[str, Any]]:
try:
return cast(List[Dict[str, Any]], repo["branchProtectionRules"]["nodes"])
except (KeyError, TypeError):
return []


def get_branch_protection(
*, repo: Dict[str, Any], ref_name: str
*, config_response: Dict[str, Any], ref_name: str
) -> Optional[BranchProtectionRule]:
for rule in get_branch_protection_dicts(repo=repo):
try:
branchProtectionRule = config_response['repository']['refs']['edges'][0]['node']['branchProtectionRule']
try:
nodes = rule["matchingRefs"]["nodes"]
except (KeyError, TypeError):
nodes = []
for node in nodes:
if node["name"] == ref_name:
try:
return BranchProtectionRule.parse_obj(rule)
except ValueError:
logger.warning("Could not parse branch protection", exc_info=True)
return None
return None
return BranchProtectionRule.parse_obj(branchProtectionRule)
except ValueError:
logger.warning("Could not parse branch protection", exc_info=True)
return None
except (KeyError, TypeError):
return None


def get_review_requests_dicts(*, pr: Dict[str, Any]) -> List[Dict[str, Any]]:
Expand Down Expand Up @@ -793,6 +808,7 @@ class CfgInfo:
parsed: V1 | pydantic.ValidationError | toml.TomlDecodeError
text: str
file_expression: str
branch_protection: Optional[BranchProtectionRule] = None


@dataclass
Expand Down Expand Up @@ -977,6 +993,7 @@ async def get_config_for_ref(
variables=dict(
owner=self.owner,
repo=self.repo,
baseRef=ref,
rootConfigFileExpression=repo_root_config_expression,
githubConfigFileExpression=repo_github_config_expression,
orgRootConfigFileExpression=org_root_config_expression,
Expand All @@ -990,6 +1007,9 @@ async def get_config_for_ref(
if data is None:
self.log.error("could not fetch default branch name", res=res)
return None


branch_protection = get_branch_protection(config_response=data,ref_name=ref)

parsed_config = parse_config(data)
if parsed_config is None:
Expand All @@ -1013,6 +1033,7 @@ def get_file_expression() -> str:
parsed=V1.parse_toml(parsed_config.text),
text=parsed_config.text,
file_expression=get_file_expression(),
branch_protection=branch_protection,
)

async def get_event_info(self, pr_number: int) -> Optional[EventInfoResponse]:
Expand Down Expand Up @@ -1109,9 +1130,6 @@ async def get_event_info(self, pr_number: int) -> Optional[EventInfoResponse]:
if cfg is None:
log.info("no config found")
return None
branch_protection = get_branch_protection(
repo=repository, ref_name=pr.baseRefName
)

all_reviews = get_reviews(pr=pull_request)
bot_reviews = self.get_bot_reviews(reviews=all_reviews)
Expand All @@ -1128,7 +1146,7 @@ async def get_event_info(self, pr_number: int) -> Optional[EventInfoResponse]:
is_private=repository.get("isPrivate") is True,
),
subscription=subscription,
branch_protection=branch_protection,
branch_protection=cfg.branch_protection,
review_requests=get_requested_reviews(pr=pull_request),
bot_reviews=bot_reviews,
status_contexts=get_status_contexts(pr=pull_request),
Expand Down

0 comments on commit 20f3809

Please sign in to comment.