1 | #!/usr/bin/env python3
2 | # For the dependencies, see the requirements.txt
3 |
import logging
import re
import shlex
import traceback
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, Namespace
from collections import OrderedDict
from copy import deepcopy
from dataclasses import dataclass, field
from itertools import accumulate
from os import getenv
from pathlib import Path
from subprocess import check_output
from textwrap import dedent
from typing import Any, Iterable, Optional, Pattern, TypedDict, Union

import yaml
from filecache import DAY, filecache
from gql import Client, gql
from gql.transport.requests import RequestsHTTPTransport
from graphql import DocumentNode
23 |
24 |
class DagNode(TypedDict):
    """One job entry of the job-needs DAG (see create_job_needs_dag)."""

    # Names of the jobs this job has to wait for
    needs: set[str]
    # Name of the stage the job belongs to
    stage: str
    # `name` is redundant (the job name is also the DAG key) but is here for retro-compatibility
    name: str
30 |
31 |
# Job name -> DagNode; see create_job_needs_dag function for more details
Dag = dict[str, DagNode]


# Stage name -> names of the jobs in that stage, kept in the order the stages were listed
StageSeq = OrderedDict[str, set[str]]
# Directory of the default "gitlab-token" file: $XDG_CONFIG_HOME, or ~/.config if unset/empty
TOKEN_DIR = Path(getenv("XDG_CONFIG_HOME") or Path.home() / ".config")
38 |
39 |
def get_token_from_default_dir() -> str:
    """
    Return the absolute path of the default GitLab token file, ``TOKEN_DIR / "gitlab-token"``.

    The file is not required to exist. This value is used as an argparse default, which is
    evaluated even when ``--gitlab-token-file`` is given explicitly, so a missing default file
    must not abort the program here. Opening and reading the file is left to the caller.

    Returns:
        The resolved path of the token file, as a string.
    """
    token_file = TOKEN_DIR / "gitlab-token"
    # Path.resolve() is non-strict by default and never raises FileNotFoundError for a missing
    # file, so the former `except FileNotFoundError` around it could never trigger.
    return str(token_file.resolve())
49 |
50 |
def get_project_root_dir() -> Path:
    """
    Return the repository root, i.e. the directory three levels above this script.

    Returns:
        The resolved root directory.

    Raises:
        FileNotFoundError: if that directory has no ``.gitlab-ci.yml``, meaning the script is not
            located where expected inside the repository.
    """
    root_path = Path(__file__).parent.parent.parent.resolve()
    gitlab_file = root_path / ".gitlab-ci.yml"
    # Explicit check instead of `assert`, which is stripped when Python runs with -O
    if not gitlab_file.exists():
        raise FileNotFoundError(
            f"Could not find {gitlab_file}; is this script inside the project repository?"
        )

    return root_path
57 |
58 |
@dataclass
class GitlabGQL:
    """
    Small wrapper around a gql ``Client`` talking to a GitLab GraphQL endpoint.

    Query documents are read from ``.gql`` files located next to this script. Results can be
    cached on disk for a day (``filecache(DAY)``) unless caching is disabled per call.
    """

    # Built in __post_init__, hence excluded from the generated __init__
    _transport: Any = field(init=False)
    client: Client = field(init=False)
    url: str = "https://gitlab.freedesktop.org/api/graphql"
    # When empty/None, requests are sent without an Authorization header
    token: Optional[str] = None

    def __post_init__(self) -> None:
        self._setup_gitlab_gql_client()

    def _setup_gitlab_gql_client(self) -> None:
        """Create the HTTP transport (with a bearer token if one is set) and the GraphQL client."""
        # Select your transport with a defined url endpoint
        headers = {}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        self._transport = RequestsHTTPTransport(url=self.url, headers=headers)

        # Create a GraphQL client using the defined transport
        self.client = Client(transport=self._transport, fetch_schema_from_transport=True)

    def query(
        self,
        gql_file: Union[Path, str],
        params: Optional[dict[str, Any]] = None,
        operation_name: Optional[str] = None,
        paginated_key_loc: Iterable[str] = (),
        disable_cache: bool = False,
    ) -> dict[str, Any]:
        """
        Run the GraphQL query stored in ``gql_file`` and return its result.

        Args:
            gql_file: query file, resolved relative to this script's directory.
            params: GraphQL variables for the query; never modified.
            operation_name: operation to execute when the file defines several.
            paginated_key_loc: keys leading to a paginated connection in the response; when
                non-empty, all pages are fetched and their "nodes" concatenated.
            disable_cache: bypass the on-disk cache and always query the server.

        Returns:
            The query result as a dictionary.
        """
        params = {} if params is None else params
        # Materialize once: a generator would be exhausted after the first page and is always
        # truthy even when empty
        key_path = tuple(paginated_key_loc)

        def run_uncached() -> dict[str, Any]:
            if key_path:
                return self._sweep_pages(gql_file, params, operation_name, key_path)
            return self._query(gql_file, params, operation_name)

        if disable_cache:
            return run_uncached()

        try:
            if key_path:
                return self._sweep_pages_cached(gql_file, params, operation_name, key_path)
            return self._query_cached(gql_file, params, operation_name)
        except Exception as ex:
            # Any failure of the cached path is logged (with traceback), the cache is cleared and
            # the query is retried without it.
            logging.exception(f"Cached query failed with {ex}")
            self.invalidate_query_cache()
            logging.error("Cache invalidated, retrying without cache")

        # Only reached when the cached attempt failed. This used to live in a `finally:` block,
        # which re-ran every query uncached and discarded the cached result.
        return run_uncached()

    def _query(
        self,
        gql_file: Union[Path, str],
        params: Optional[dict[str, Any]] = None,
        operation_name: Optional[str] = None,
    ) -> dict[str, Any]:
        """Load the query document from ``gql_file`` and execute it synchronously."""
        source_path: Path = Path(__file__).parent
        pipeline_query_file: Path = source_path / gql_file

        query: DocumentNode = gql(pipeline_query_file.read_text(encoding="utf-8"))

        # Execute the query on the transport
        return self.client.execute_sync(
            query,
            variable_values={} if params is None else params,
            operation_name=operation_name,
        )

    @filecache(DAY)
    def _sweep_pages_cached(self, *args, **kwargs):
        return self._sweep_pages(*args, **kwargs)

    @filecache(DAY)
    def _query_cached(self, *args, **kwargs):
        return self._query(*args, **kwargs)

    def _sweep_pages(
        self, query, params, operation_name=None, paginated_key_loc: Iterable[str] = ()
    ) -> dict[str, Any]:
        """
        Retrieve paginated data from a GraphQL API and concatenate the results into a single
        response.

        Args:
            query: filepath (relative to this script) of the GraphQL query to execute.
            params: query variables. A private copy receives the "cursor" variable, so the
                caller's dict is left untouched.
            operation_name: operation to execute when the query file defines several.
            paginated_key_loc (Iterable[str]): path of keys leading to the paginated field in
                the response. That field must hold a "nodes" list and a "pageInfo" object with
                at least "hasNextPage" and "endCursor".

        Returns:
            The last page's response, with the paginated field's "nodes" replaced by the nodes
            of all pages concatenated in order.
        """
        page_params = dict(params or {})
        key_path = tuple(paginated_key_loc)

        def fetch_page(cursor: str | None = None) -> dict[str, Any]:
            if cursor:
                page_params["cursor"] = cursor
                logging.info(
                    f"Found more than 100 elements, paginating. "
                    f"Current cursor at {cursor}"
                )

            return self._query(query, page_params, operation_name)

        # Execute the initial query
        response: dict[str, Any] = fetch_page()
        all_nodes: list[dict[str, Any]] = []

        # Loop until all pages have been retrieved
        while True:
            partial_field = response
            for key in key_path:
                partial_field = partial_field[key]

            all_nodes += partial_field["nodes"]

            page_info = partial_field["pageInfo"]
            if not page_info["hasNextPage"]:
                break

            response = fetch_page(page_info["endCursor"])

        # Replace the "nodes" field in the last response with the accumulated result
        partial_field["nodes"] = all_nodes
        return response

    def invalidate_query_cache(self) -> None:
        """Clear the on-disk caches backing the ``*_cached`` query wrappers."""
        logging.warning("Invalidating query cache")
        # The cache store lives on the filecache-decorated wrappers; the plain _query and
        # _sweep_pages methods have no `_db`, so clearing those never did anything.
        for cached_method in (self._sweep_pages_cached, self._query_cached):
            try:
                cached_method._db.clear()
            except AttributeError as ex:
                logging.warning(
                    f"Could not invalidate cache, maybe it was not used in {ex.args}?"
                )
214 |
215 |
def insert_early_stage_jobs(stage_sequence: StageSeq, jobs_metadata: Dag) -> Dag:
    """
    Make every job without explicit needs depend on all jobs of the preceding stages.

    This models the stage ordering rule: a job with an empty "needs" set (e.g. the sanity job in
    Mesa MR pipelines) waits for every job of every earlier stage. Jobs that already have needs
    keep them unchanged.

    Args:
        stage_sequence: stage name -> job names, ordered by stage position.
        jobs_metadata: DAG holding each job's direct needs; updated in place.

    Returns:
        ``jobs_metadata`` itself, with every job's "needs" replaced by a new set.

    Raises:
        KeyError: if a job's stage is missing from ``stage_sequence``.
    """
    # jobs_from_early_stages[i] is the union of the jobs of stages 0..i
    jobs_from_early_stages = list(accumulate(stage_sequence.values(), set.union))
    # Stage name -> position, computed once instead of a list search per job
    stage_index = {stage: index for index, stage in enumerate(stage_sequence)}

    for metadata in jobs_metadata.values():
        # Fresh set: never alias the sets held by stage_sequence / the accumulate results
        final_needs: set[str] = set(metadata["needs"])
        if not final_needs:
            index = stage_index[metadata["stage"]]
            if index > 0:
                final_needs |= jobs_from_early_stages[index - 1]
        metadata["needs"] = final_needs

    return jobs_metadata
234 |
235 |
def traverse_dag_needs(jobs_metadata: Dag) -> None:
    """
    Replace, in place, every job's "needs" by the transitive closure of its dependencies.

    Only jobs present in ``jobs_metadata`` are kept. Needs pointing at jobs that are not part of
    the DAG are dropped at every depth, not only among the direct needs; previously a transitive
    reference to such a job raised KeyError.

    Args:
        jobs_metadata: DAG whose "needs" hold direct dependencies; updated in place.
    """
    created_jobs = set(jobs_metadata)
    for metadata in jobs_metadata.values():
        # `&` returns a new set, so the input needs set itself is not mutated
        final_needs: set[str] = metadata["needs"] & created_jobs
        pending = list(final_needs)
        # Worklist traversal: each newly discovered dependency is expanded exactly once
        while pending:
            dependency = pending.pop()
            for transitive in jobs_metadata[dependency]["needs"]:
                if transitive in created_jobs and transitive not in final_needs:
                    final_needs.add(transitive)
                    pending.append(transitive)

        metadata["needs"] = final_needs
249 |
250 |
def extract_stages_and_job_needs(
    pipeline_jobs: dict[str, Any], pipeline_stages: dict[str, Any]
) -> tuple[StageSeq, Dag]:
    """
    Split the pipeline GraphQL payload into a stage sequence and a direct-needs DAG.

    Args:
        pipeline_jobs: connection whose "nodes" are jobs carrying "name", "stage"["name"] and
            "needs"["edges"][i]["node"]["name"].
        pipeline_stages: connection whose "nodes" are stages carrying a "name".

    Returns:
        A pair (stage_sequence, jobs_metadata): every stage, in listed order, mapped to the set
        of its job names (kept to post-process jobs that depend on stage order instead of
        `needs`, e.g. the sanity job), and a DAG whose "needs" are the direct needs only.
    """
    # Seed all stages first, so stages without jobs still appear, in the listed order
    stage_sequence: StageSeq = OrderedDict(
        (stage_node["name"], set()) for stage_node in pipeline_stages["nodes"]
    )

    dag: Dag = {}
    for job_node in pipeline_jobs["nodes"]:
        job_name = job_node["name"]
        stage_name = job_node["stage"]["name"]
        stage_sequence[stage_name].add(job_name)
        dag[job_name] = {
            "name": job_name,
            "stage": stage_name,
            "needs": {edge["node"]["name"] for edge in job_node["needs"]["edges"]},
        }

    return stage_sequence, dag
271 |
272 |
def create_job_needs_dag(gl_gql: GitlabGQL, params, disable_cache: bool = True) -> Dag:
    """
    Build a Directed Acyclic Graph (DAG) telling, for every job of a pipeline, which jobs it
    has to wait for.

    The DAG is a dictionary keyed by job name. Each value holds the job's "stage", its "name"
    and its "needs". The "needs" set combines:

    - the explicit `needs:` edges reported by GitLab, and
    - for jobs without explicit needs, every job of all earlier stages (stage ordering rule).

    It is then expanded transitively, so indirect dependencies are included as well.

    For example, consider a pipeline with two stages:

    1. build stage: job1; job2 needs job1; job3 needs job2
    2. test stage: job4 needs job2; job5 has no `needs:`

    - job3 waits for job2, which waits for job1, so job3 needs job1 and job2
    - job4 needs job2 and, transitively, job1
    - job5 has no explicit needs, so it waits for every job of the build stage

    The resulting DAG would look like this:

        dag = {
            "job1": {"needs": set(), "stage": "build", "name": "job1"},
            "job2": {"needs": {"job1"}, "stage": "build", "name": "job2"},
            "job3": {"needs": {"job1", "job2"}, "stage": "build", "name": "job3"},
            "job4": {"needs": {"job1", "job2"}, "stage": "test", "name": "job4"},
            "job5": {"needs": {"job1", "job2", "job3"}, "stage": "test", "name": "job5"},
        }

    To access the job needs, one can do:

        dag["job3"]["needs"]

    This will return the set of jobs that job3 needs: {"job1", "job2"}

    Args:
        gl_gql (GitlabGQL): client used to run the ``pipeline_details.gql`` query.
        params (dict): GraphQL variables identifying the pipeline; ``main`` passes
            ``projectPath`` and ``iid``. The exact keys depend on the query file.
        disable_cache (bool): when True (the default), bypass the on-disk query cache and
            always ask GitLab.

    Returns:
        The final DAG representing the job dependencies sourced from needs or stage rules.

    Raises:
        RuntimeError: if the response contains no pipeline for ``params``.
    """
    stages_jobs_gql = gl_gql.query(
        "pipeline_details.gql",
        params=params,
        paginated_key_loc=["project", "pipeline", "jobs"],
        disable_cache=disable_cache,
    )
    # NOTE(review): the paginated query walks project -> pipeline itself, so a null project or
    # pipeline may already fail there before reaching this check.
    pipeline_data = (stages_jobs_gql.get("project") or {}).get("pipeline")
    if not pipeline_data:
        raise RuntimeError(f"Could not find any pipelines for {params}")

    stage_sequence, jobs_metadata = extract_stages_and_job_needs(
        pipeline_data["jobs"], pipeline_data["stages"]
    )
    # Fill the DAG with the job needs from stages that don't have any needs but still need to wait
    # for previous stages
    final_dag = insert_early_stage_jobs(stage_sequence, jobs_metadata)
    # Now that each job has its direct needs filled correctly, expand the "needs" field of each
    # job into its transitive closure
    traverse_dag_needs(final_dag)

    return final_dag
338 |
339 |
def filter_dag(dag: Dag, regex: Pattern) -> Dag:
    """
    Return the entries of ``dag`` whose job name fully matches ``regex``, in ``dag``'s order.

    The selected entries are shared with ``dag``, not copied. Their "needs" sets are left as
    they are, so they may still name jobs that were filtered out.
    """
    # Single pass with an O(1) test per job; the previous version re-sorted the matching set
    # and scanned the resulting list for every job.
    return {job: data for job, data in dag.items() if regex.fullmatch(job)}
343 |
344 |
def print_dag(dag: Dag) -> None:
    """
    Print every job as "name:", then a tab-indented, space-separated list of its needs, then a
    blank line.

    Needs are sorted so the output is identical between runs; raw set iteration order of
    strings changes with hash randomization.
    """
    for job, data in dag.items():
        print(f"{job}:")
        print(f"\t{' '.join(sorted(data['needs']))}")
        print()
350 |
351 |
def fetch_merged_yaml(gl_gql: GitlabGQL, params) -> dict[str, Any]:
    """
    Ask GitLab for the merged CI configuration obtained by including the local .gitlab-ci.yml.

    Args:
        gl_gql: client used to run ``job_details.gql``.
        params: query variables (``main`` passes ``projectPath`` and ``sha``). Not modified:
            the "content" variable is added to a copy.

    Returns:
        The merged YAML parsed with ``yaml.safe_load``.

    Raises:
        ValueError: if the response carries no merged YAML. The query cache is invalidated
            before raising.
    """
    query_params = dict(params)
    query_params["content"] = dedent("""\
        include:
          - local: .gitlab-ci.yml
        """)
    raw_response = gl_gql.query("job_details.gql", query_params)
    # Tolerate a null ciConfig so that the intended ValueError below is raised
    ci_config = raw_response.get("ciConfig") or {}
    if merged_yaml := ci_config.get("mergedYaml"):
        return yaml.safe_load(merged_yaml)

    gl_gql.invalidate_query_cache()
    raise ValueError(
        """
        Could not fetch any content for merged YAML,
        please verify if the git SHA exists in remote.
        Maybe you forgot to `git push`? """
    )
368 |
369 |
def recursive_fill(job, relationship_field, target_data, acc_data: dict, merged_yaml):
    """
    Merge ``job[target_data]`` with the same field inherited through ``relationship_field``.

    Relatives named by ``job[relationship_field]`` (a single name or a list of names, looked up
    in ``merged_yaml``) are merged first and recursively. The job's own data is applied last,
    so values defined closer to ``job`` override inherited ones; among several relatives, later
    ones override earlier ones.

    Args:
        job: mapping of the job (or template) to resolve.
        relationship_field: key naming the relatives, e.g. "extends".
        target_data: key of the mapping to accumulate, e.g. "variables".
        acc_data: data accumulated so far; not modified.
        merged_yaml: the whole configuration, used to look relatives up by name.

    Returns:
        A new dict with the accumulated data.
    """
    acc_data = dict(acc_data)
    relatives = job.get(relationship_field) or []
    if isinstance(relatives, str):
        relatives = [relatives]

    for relative in relatives:
        parent_job = merged_yaml[relative]
        # Forward all arguments; the call previously passed only three of them
        acc_data = recursive_fill(
            parent_job, relationship_field, target_data, acc_data, merged_yaml
        )

    # `or {}` also covers a key that is present but null
    acc_data |= job.get(target_data) or {}

    return acc_data
382 |
383 |
def get_variables(job, merged_yaml, project_path, sha) -> dict[str, str]:
    """
    Compute the variables a job would see, with references between them expanded.

    Sources are layered with increasing precedence:
    1. the ``variables`` of ``.gitlab-ci/image-tags.yml`` in the repository,
    2. the top-level ``variables`` of ``merged_yaml``,
    3. the job's own ``variables``,
    4. a few CI_* variables derived from the arguments.

    Args:
        job: the job's mapping from the merged CI YAML.
        merged_yaml: the whole merged CI configuration.
        project_path: "namespace/project" path, e.g. "mesa/mesa".
        sha: commit SHA exposed as CI_COMMIT_SHA.

    Returns:
        The merged variables, after repeated ``$VAR`` / ``${VAR}`` expansion passes.
    """
    image_tags_file = get_project_root_dir() / ".gitlab-ci" / "image-tags.yml"
    image_tags = yaml.safe_load(image_tags_file.read_text())

    variables = image_tags["variables"]
    # `or {}` tolerates a missing or null `variables:` section
    variables |= merged_yaml.get("variables") or {}
    variables |= job.get("variables") or {}
    variables["CI_PROJECT_PATH"] = project_path
    # Last path component, which is also correct for nested groups ("group/sub/project")
    variables["CI_PROJECT_NAME"] = project_path.rsplit("/", 1)[-1]
    variables["CI_REGISTRY_IMAGE"] = "registry.freedesktop.org/${CI_PROJECT_PATH}"
    variables["CI_COMMIT_SHA"] = sha

    # Repeat expansion passes until one makes no change
    while recurse_among_variables_space(variables):
        pass

    return variables
400 |
401 |
# Based on: https://stackoverflow.com/a/2158532/1079223
def flatten(xs):
    """
    Lazily yield the leaves of an arbitrarily nested iterable, depth first.

    Strings and bytes are treated as leaves rather than being iterated into.
    """
    # Explicit stack of partially consumed iterators instead of recursive generators
    iterators = [iter(xs)]
    while iterators:
        for item in iterators[-1]:
            if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
                # Descend; the parent iterator resumes where it stopped once this one is done
                iterators.append(iter(item))
                break
            yield item
        else:
            iterators.pop()
409 |
410 |
def get_full_script(job) -> list[str]:
    """
    Concatenate a job's before_script, script and after_script into one list of lines.

    Each section starts with a "# <section>" header line and ends with an empty line.

    Args:
        job: the job's mapping from the merged CI YAML.

    Returns:
        The script lines, with nested command lists flattened.
    """
    script: list[str] = []
    for script_part in ("before_script", "script", "after_script"):
        script.append(f"# {script_part}")
        # `or []` tolerates a missing or null section
        commands = job.get(script_part) or []
        # A section given as one command string must stay a single line; flattening a bare
        # string would emit it character by character.
        if isinstance(commands, str):
            commands = [commands]
        script.extend(flatten(commands))
        script.append("")

    return script
420 |
421 |
def recurse_among_variables_space(var_graph) -> bool:
    """
    Run one in-place expansion pass over variable references in ``var_graph``.

    Every ``$NAME`` or ``${NAME}`` occurrence whose NAME is another key of ``var_graph`` is
    replaced by that variable's current value. References to unknown names, and a variable's
    references to itself, are left untouched. get_variables repeats this pass until it returns
    False.

    Args:
        var_graph: mapping of variable name to value. Values are scanned as ``str(value)`` and
            are only rewritten (as strings) when a substitution actually changed them.

    Returns:
        True if any value changed during this pass.
    """
    # Match the whole identifier, so "$A" never substitutes inside "$AB"
    reference = re.compile(r"\$\{(\w+)\}|\$(\w+)")
    updated = False
    for var, value in var_graph.items():

        def resolve(match: re.Match, _current: str = var) -> str:
            name = match.group(1) or match.group(2)
            # Self-references would otherwise grow forever ("x$A" -> "xx$A" -> ...)
            if name == _current or name not in var_graph:
                return match.group(0)
            return str(var_graph[name])

        old_value = str(value)
        new_value = reference.sub(resolve, old_value)
        # Compare the variable's own old and new value (previously the dependency's value was
        # compared), so chains like A -> B -> C keep expanding until fully resolved.
        if new_value != old_value:
            var_graph[var] = new_value
            updated = True

    return updated
441 |
442 |
def print_job_final_definition(job_name, merged_yaml, project_path, sha):
    """
    Print a shell-oriented view of a job: exported variables, full script and container image.

    Args:
        job_name: key of the job in ``merged_yaml``.
        merged_yaml: the whole merged CI configuration.
        project_path: "namespace/project" path used for CI_* variables.
        sha: commit SHA used for CI_COMMIT_SHA.
    """
    job = merged_yaml[job_name]
    variables = get_variables(job, merged_yaml, project_path, sha)

    print("# --------- variables ---------------")
    for var, value in sorted(variables.items()):
        # shlex.quote produces POSIX-shell quoting. repr() produced Python quoting, which the
        # shell misreads for values containing quotes, "$" or newlines.
        print(f"export {var}={shlex.quote(str(value))}")

    # TODO: Recurse into needs to get full script
    # TODO: maybe create a extra yaml file to avoid too much rework
    script = get_full_script(job)
    print()
    print()
    print("# --------- full script ---------------")
    print("\n".join(script))

    if image := variables.get("MESA_IMAGE"):
        print()
        print()
        print("# --------- container image ---------------")
        print(image)
464 |
465 |
def from_sha_to_pipeline_iid(gl_gql: GitlabGQL, params) -> str:
    """
    Return the iid of the first pipeline node returned by ``pipeline_utils.gql`` for ``params``.

    Args:
        gl_gql: client used to run the query.
        params: query variables (``main`` passes ``projectPath`` and ``sha``).

    Returns:
        The pipeline iid.

    Raises:
        RuntimeError: if the response contains no pipeline, instead of a bare IndexError.
    """
    result = gl_gql.query("pipeline_utils.gql", params)

    project = result.get("project") or {}
    pipelines = (project.get("pipelines") or {}).get("nodes") or []
    if not pipelines:
        raise RuntimeError(f"Could not find any pipelines for {params}")

    return pipelines[0]["iid"]
470 |
471 |
def parse_args() -> Namespace:
    """
    Parse the command line and load the GitLab token.

    Returns:
        The parsed arguments, plus ``gitlab_token`` holding the stripped content of the file
        given by ``--gitlab-token-file``. An unreadable token file is reported as a usage error
        (argparse exits with status 2) instead of a traceback.
    """
    parser = ArgumentParser(
        formatter_class=ArgumentDefaultsHelpFormatter,
        description="CLI and library with utility functions to debug jobs via Gitlab GraphQL",
        epilog=f"""Example:
        {Path(__file__).name} --print-dag""",
    )
    parser.add_argument("-pp", "--project-path", type=str, default="mesa/mesa")
    parser.add_argument("--sha", "--rev", type=str, default='HEAD')
    parser.add_argument(
        "--regex",
        type=str,
        required=False,
        help="Regex pattern for the job name to be considered",
    )
    mutex_group_print = parser.add_mutually_exclusive_group()
    mutex_group_print.add_argument(
        "--print-dag",
        action="store_true",
        help="Print job needs DAG",
    )
    mutex_group_print.add_argument(
        "--print-merged-yaml",
        action="store_true",
        help="Print the resulting YAML for the specific SHA",
    )
    mutex_group_print.add_argument(
        "--print-job-manifest",
        metavar='JOB_NAME',
        type=str,
        help="Print the resulting job data"
    )
    parser.add_argument(
        "--gitlab-token-file",
        type=str,
        default=get_token_from_default_dir(),
        help="file containing the GitLab token, read from $XDG_CONFIG_HOME/gitlab-token "
        "by default",
    )

    args = parser.parse_args()
    try:
        args.gitlab_token = Path(args.gitlab_token_file).read_text().strip()
    except OSError as ex:
        # parser.error prints the usage and exits with status 2
        parser.error(f"could not read GitLab token file {args.gitlab_token_file!r}: {ex}")
    return args
514 |
515 |
def main():
    """Command-line entry point: resolve the requested revision, then run the chosen reports."""
    args = parse_args()
    gl_gql = GitlabGQL(token=args.gitlab_token)
    project_path = args.project_path

    sha = check_output(['git', 'rev-parse', args.sha]).decode('ascii').strip()

    if args.print_dag:
        pipeline_iid = from_sha_to_pipeline_iid(
            gl_gql, {"projectPath": project_path, "sha": sha}
        )
        dag = create_job_needs_dag(
            gl_gql, {"projectPath": project_path, "iid": pipeline_iid}, disable_cache=True
        )
        if args.regex:
            dag = filter_dag(dag, re.compile(args.regex))
        print_dag(dag)

    # The remaining reports both need the merged CI configuration
    if not (args.print_merged_yaml or args.print_job_manifest):
        return

    merged_yaml = fetch_merged_yaml(gl_gql, {"projectPath": project_path, "sha": sha})

    if args.print_merged_yaml:
        print(yaml.dump(merged_yaml, indent=2))

    if args.print_job_manifest:
        print_job_final_definition(args.print_job_manifest, merged_yaml, project_path, sha)
545 |
546 |
if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing the module (library use) does not
    main()