Skip to content

Commit

Permalink
Refactor rules into configurable phases
Browse files Browse the repository at this point in the history
  • Loading branch information
borkaehw committed Oct 25, 2019
1 parent 0f89c21 commit 8f9aa27
Show file tree
Hide file tree
Showing 21 changed files with 937 additions and 565 deletions.
10 changes: 10 additions & 0 deletions scala/private/phases/api.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
def run_phases(ctx, phases):
global_provider = {}
current_provider = struct(**global_provider)
for (name, function) in phases:
new_provider = function(ctx, current_provider)
if new_provider != None:
global_provider[name] = new_provider
current_provider = struct(**global_provider)

return current_provider
38 changes: 38 additions & 0 deletions scala/private/phases/phase_coda.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#
# PHASE: coda
#
# DOCUMENT THIS
#
def phase_binary_coda(ctx, p):
return struct(
executable = p.declare_executable,
coverage = p.compile.coverage,
files = depset([p.declare_executable, ctx.outputs.jar]),
instrumented_files = p.compile.coverage.instrumented_files,
providers = [p.compile.merged_provider, p.collect_jars.jars2labels] + p.compile.coverage.providers,
runfiles = p.runfiles.runfiles,
scala = p.scala_provider,
transitive_rjars = p.compile.rjars, #calling rules need this for the classpath in the launcher
)

def phase_library_coda(ctx, p):
return struct(
files = depset([ctx.outputs.jar] + p.compile.full_jars), # Here is the default output
instrumented_files = p.compile.coverage.instrumented_files,
jars_to_labels = p.collect_jars.jars2labels,
providers = [p.compile.merged_provider, p.collect_jars.jars2labels] + p.compile.coverage.providers,
runfiles = p.runfiles.runfiles,
scala = p.scala_provider,
)

def phase_test_coda(ctx, p):
coverage_runfiles = p.coverage_runfiles.coverage_runfiles
coverage_runfiles.extend(p.write_executable)
return struct(
executable = p.declare_executable,
files = depset([p.declare_executable, ctx.outputs.jar]),
instrumented_files = p.compile.coverage.instrumented_files,
providers = [p.compile.merged_provider, p.collect_jars.jars2labels] + p.compile.coverage.providers,
runfiles = ctx.runfiles(coverage_runfiles, transitive_files = p.runfiles.runfiles.files),
scala = p.scala_provider,
)
71 changes: 71 additions & 0 deletions scala/private/phases/phase_collect_jars.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#
# PHASE: collect jars
#
# DOCUMENT THIS
#
load(
"@io_bazel_rules_scala//scala/private:rule_impls.bzl",
"collect_jars_from_common_ctx",
)

def phase_test_collect_jars(ctx, p):
args = struct(
base_classpath = p.init.scalac_provider.default_classpath + [ctx.attr._scalatest],
extra_runtime_deps = [
ctx.attr._scalatest_reporter,
ctx.attr._scalatest_runner,
],
)
return phase_common_collect_jars(ctx, p, args)

def phase_repl_collect_jars(ctx, p):
args = struct(
base_classpath = p.init.scalac_provider.default_repl_classpath,
)
return phase_common_collect_jars(ctx, p, args)

def phase_macro_library_collect_jars(ctx, p):
args = struct(
base_classpath = p.init.scalac_provider.default_macro_classpath,
)
return phase_common_collect_jars(ctx, p, args)

def phase_junit_test_collect_jars(ctx, p):
args = struct(
extra_deps = [
ctx.attr._junit,
ctx.attr._hamcrest,
ctx.attr.suite_label,
ctx.attr._bazel_test_runner,
],
)
return phase_common_collect_jars(ctx, p, args)

def phase_library_for_plugin_bootstrapping_collect_jars(ctx, p):
args = struct(
unused_dependency_checker_mode = "off",
)
return phase_common_collect_jars(ctx, p, args)

def phase_common_collect_jars(ctx, p, _args = struct()):
return _phase_collect_jars(
ctx,
_args.base_classpath if hasattr(_args, "base_classpath") else p.init.scalac_provider.default_classpath,
_args.extra_deps if hasattr(_args, "extra_deps") else [],
_args.extra_runtime_deps if hasattr(_args, "extra_runtime_deps") else [],
_args.unused_dependency_checker_mode if hasattr(_args, "unused_dependency_checker_mode") else p.unused_deps_checker,
)

def _phase_collect_jars(
ctx,
base_classpath,
extra_deps,
extra_runtime_deps,
unused_dependency_checker_mode):
return collect_jars_from_common_ctx(
ctx,
base_classpath,
extra_deps,
extra_runtime_deps,
unused_dependency_checker_mode == "off",
)
144 changes: 144 additions & 0 deletions scala/private/phases/phase_compile.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#
# PHASE: compile
#
# DOCUMENT THIS
#
load(
"@io_bazel_rules_scala//scala/private:rule_impls.bzl",
"compile_or_empty",
"pack_source_jars",
)

def phase_binary_compile(ctx, p):
args = struct(
unused_dependency_checker_ignored_targets = [
target.label
for target in p.init.scalac_provider.default_classpath +
ctx.attr.unused_dependency_checker_ignored_targets
],
)
return phase_common_compile(ctx, p, args)

def phase_library_compile(ctx, p):
args = struct(
srcjars = p.init.srcjars,
buildijar = True,
unused_dependency_checker_ignored_targets = [
target.label
for target in p.init.scalac_provider.default_classpath + ctx.attr.exports +
ctx.attr.unused_dependency_checker_ignored_targets
],
)
return phase_common_compile(ctx, p, args)

def phase_library_for_plugin_bootstrapping_compile(ctx, p):
args = struct(
buildijar = True,
unused_dependency_checker_ignored_targets = [
target.label
for target in p.init.scalac_provider.default_classpath + ctx.attr.exports
],
unused_dependency_checker_mode = "off",
)
return phase_common_compile(ctx, p, args)

def phase_macro_library_compile(ctx, p):
args = struct(
unused_dependency_checker_ignored_targets = [
target.label
for target in p.init.scalac_provider.default_macro_classpath + ctx.attr.exports +
ctx.attr.unused_dependency_checker_ignored_targets
],
)
return phase_common_compile(ctx, p, args)

def phase_junit_test_compile(ctx, p):
args = struct(
implicit_junit_deps_needed_for_java_compilation = [
ctx.attr._junit,
ctx.attr._hamcrest,
],
unused_dependency_checker_ignored_targets = [
target.label
for target in p.init.scalac_provider.default_classpath +
ctx.attr.unused_dependency_checker_ignored_targets
] + [
ctx.attr._junit.label,
ctx.attr._hamcrest.label,
ctx.attr.suite_label.label,
ctx.attr._bazel_test_runner.label,
],
)
return phase_common_compile(ctx, p, args)

def phase_repl_compile(ctx, p):
args = struct(
unused_dependency_checker_ignored_targets = [
target.label
for target in p.init.scalac_provider.default_repl_classpath +
ctx.attr.unused_dependency_checker_ignored_targets
],
)
return phase_common_compile(ctx, p, args)

def phase_test_compile(ctx, p):
args = struct(
unused_dependency_checker_ignored_targets = [
target.label
for target in p.init.scalac_provider.default_classpath +
ctx.attr.unused_dependency_checker_ignored_targets
],
)
return phase_common_compile(ctx, p, args)

def phase_common_compile(ctx, p, _args = struct()):
return _phase_compile(
ctx,
p,
_args.srcjars if hasattr(_args, "srcjars") else depset(),
_args.buildijar if hasattr(_args, "buildijar") else False,
_args.implicit_junit_deps_needed_for_java_compilation if hasattr(_args, "implicit_junit_deps_needed_for_java_compilation") else [],
_args.unused_dependency_checker_ignored_targets if hasattr(_args, "unused_dependency_checker_ignored_targets") else [],
_args.unused_dependency_checker_mode if hasattr(_args, "unused_dependency_checker_mode") else p.unused_deps_checker,
)

def _phase_compile(
ctx,
p,
srcjars,
buildijar,
implicit_junit_deps_needed_for_java_compilation,
unused_dependency_checker_ignored_targets,
unused_dependency_checker_mode):
manifest = ctx.outputs.manifest
jars = p.collect_jars.compile_jars
rjars = p.collect_jars.transitive_runtime_jars
transitive_compile_jars = p.collect_jars.transitive_compile_jars
jars2labels = p.collect_jars.jars2labels.jars_to_labels
deps_providers = p.collect_jars.deps_providers

out = compile_or_empty(
ctx,
manifest,
jars,
srcjars,
buildijar,
transitive_compile_jars,
jars2labels,
implicit_junit_deps_needed_for_java_compilation,
unused_dependency_checker_mode,
unused_dependency_checker_ignored_targets,
deps_providers,
)

return struct(
class_jar = out.class_jar,
coverage = out.coverage,
full_jars = out.full_jars,
ijar = out.ijar,
ijars = out.ijars,
rjars = depset(out.full_jars, transitive = [rjars]),
java_jar = out.java_jar,
source_jars = pack_source_jars(ctx) + out.source_jars,
merged_provider = out.merged_provider,
)
28 changes: 28 additions & 0 deletions scala/private/phases/phase_coverage_runfiles.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#
# PHASE: coverage runfiles
#
# DOCUMENT THIS
#
load(
"@io_bazel_rules_scala//scala/private:coverage_replacements_provider.bzl",
_coverage_replacements_provider = "coverage_replacements_provider",
)

def phase_coverage_runfiles(ctx, p):
coverage_runfiles = []
rjars = p.compile.rjars
if ctx.configuration.coverage_enabled and _coverage_replacements_provider.is_enabled(ctx):
coverage_replacements = _coverage_replacements_provider.from_ctx(
ctx,
base = p.compile.coverage.replacements,
).replacements

rjars = depset([
coverage_replacements[jar] if jar in coverage_replacements else jar
for jar in rjars.to_list()
])
coverage_runfiles = ctx.files._jacocorunner + ctx.files._lcov_merger + coverage_replacements.values()
return struct(
coverage_runfiles = coverage_runfiles,
rjars = rjars,
)
12 changes: 12 additions & 0 deletions scala/private/phases/phase_declare_executable.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#
# PHASE: declare executable
#
# DOCUMENT THIS
#
load(
"@io_bazel_rules_scala//scala/private:rule_impls.bzl",
"declare_executable",
)

def phase_declare_executable(ctx, p):
return declare_executable(ctx)
39 changes: 39 additions & 0 deletions scala/private/phases/phase_init.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#
# PHASE: init
#
# DOCUMENT THIS
#
load(
"@io_bazel_rules_scala//scala/private:rule_impls.bzl",
"get_scalac_provider",
)
load(
"@io_bazel_rules_scala//scala/private:common.bzl",
"collect_jars",
"collect_srcjars",
"write_manifest",
)

def phase_library_init(ctx, p):
# This will be used to pick up srcjars from non-scala library
# targets (like thrift code generation)
srcjars = collect_srcjars(ctx.attr.deps)

# Add information from exports (is key that AFTER all build actions/runfiles analysis)
# Since after, will not show up in deploy_jar or old jars runfiles
# Notice that compile_jars is intentionally transitive for exports
exports_jars = collect_jars(ctx.attr.exports)

args = phase_common_init(ctx, p)

return struct(
srcjars = srcjars,
exports_jars = exports_jars,
scalac_provider = args.scalac_provider,
)

def phase_common_init(ctx, p):
write_manifest(ctx)
return struct(
scalac_provider = get_scalac_provider(ctx),
)
46 changes: 46 additions & 0 deletions scala/private/phases/phase_java_wrapper.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#
# PHASE: java wrapper
#
# DOCUMENT THIS
#
load(
"@io_bazel_rules_scala//scala/private:rule_impls.bzl",
"write_java_wrapper",
)

def phase_repl_java_wrapper(ctx, p):
args = struct(
args = " ".join(ctx.attr.scalacopts),
wrapper_preamble = """
# save stty like in bin/scala
saved_stty=$(stty -g 2>/dev/null)
if [[ ! $? ]]; then
saved_stty=""
fi
function finish() {
if [[ "$saved_stty" != "" ]]; then
stty $saved_stty
saved_stty=""
fi
}
trap finish EXIT
""",
)
return phase_common_java_wrapper(ctx, p, args)

def phase_common_java_wrapper(ctx, p, _args = struct()):
return _phase_java_wrapper(
ctx,
_args.args if hasattr(_args, "args") else "",
_args.wrapper_preamble if hasattr(_args, "wrapper_preamble") else "",
)

def _phase_java_wrapper(
ctx,
args,
wrapper_preamble):
return write_java_wrapper(
ctx,
args,
wrapper_preamble,
)
Loading

0 comments on commit 8f9aa27

Please sign in to comment.