diff --git a/commonTools/test/utilities/check-mpi-comm-world-usage.py b/commonTools/test/utilities/check-mpi-comm-world-usage.py index 81a6776f7d3d..827454ee1ab9 100644 --- a/commonTools/test/utilities/check-mpi-comm-world-usage.py +++ b/commonTools/test/utilities/check-mpi-comm-world-usage.py @@ -53,7 +53,6 @@ def parse_diff_output(changed_files): line_counter = 0 for line in lines: if line.startswith("+"): - line_counter += 1 if ( "MPI_COMM_WORLD" in line and not "CHECK: ALLOW MPI_COMM_WORLD" in line @@ -62,23 +61,30 @@ def parse_diff_output(changed_files): # and "CHECK: ALLOW MPI_COMM_WORLD" is not present changed_lines.append(start_line + line_counter) + line_counter += 1 + if changed_lines: files[file_name] = changed_lines return files -def get_changed_files_uncommitted(): - """Get a dictionary of files and their changed lines where MPI_COMM_WORLD was added from uncommitted changes.""" - cmd = ["git", "diff", "-U0", "--ignore-all-space", "HEAD"] - result = subprocess.check_output(cmd).decode("utf-8") - - return parse_diff_output(result) +def get_common_ancestor(target_branch, feature_branch): + cmd = ["git", "merge-base", target_branch, feature_branch] + return subprocess.check_output(cmd).decode("utf-8").strip() -def get_changed_files(start_commit, end_commit): - """Get a dictionary of files and their changed lines between two commits where MPI_COMM_WORLD was added.""" - cmd = ["git", "diff", "-U0", "--ignore-all-space", start_commit, end_commit] +def get_changed_files(target_branch, feature_branch): + """Get a dictionary of files and their changed lines between the common ancestor and feature_branch.""" + start_commit = get_common_ancestor(target_branch, feature_branch) + cmd = [ + "git", + "diff", + "-U0", + "--ignore-all-space", + start_commit, + feature_branch + ] result = subprocess.check_output(cmd).decode("utf-8") return parse_diff_output(result) @@ -109,21 +115,12 @@ def print_occurences(changed_files, title): end_commit = parser.parse_args().head print(f"End commit: {end_commit}") - commited_occurences = get_changed_files(start_commit, end_commit) - uncommited_occurences = get_changed_files_uncommitted() - - mpi_comm_world_detected = commited_occurences or uncommited_occurences + mpi_comm_world_detected = get_changed_files(start_commit, end_commit) if mpi_comm_world_detected: - if commited_occurences: - print_occurences( - commited_occurences, "Detected MPI_COMM_WORLD in the following files:" - ) - if uncommited_occurences: - print_occurences( - uncommited_occurences, - "Detected MPI_COMM_WORLD in the following files (uncommited changes):", - ) + print_occurences( + mpi_comm_world_detected, "Detected MPI_COMM_WORLD in the following files:" + ) sys.exit(1) # Exit with an error code to fail the GitHub Action else: