Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Scala import java info #517

Closed
wants to merge 7 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 49 additions & 48 deletions scala/scala_import.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
#if you change make sure to manually re-import an intellij project and see imports
#are resolved (not red) and clickable
def _scala_import_impl(ctx):
target_data = _code_jars_and_intellij_metadata_from(ctx.attr.jars)
(current_target_compile_jars, intellij_metadata) = (target_data.code_jars, target_data.intellij_metadata)
current_jars = depset(current_target_compile_jars)
exports = _collect(ctx.attr.exports)
transitive_runtime_jars = _collect_runtime(ctx.attr.runtime_deps)
jars = _collect(ctx.attr.deps)
jars2labels = {}
_collect_labels(ctx.attr.deps, jars2labels)
_collect_labels(ctx.attr.exports, jars2labels) #untested
_add_labels_of_current_code_jars(depset(transitive=[current_jars, exports.compile_jars]), ctx.label, jars2labels) #last to override the label of the export compile jars to the current target
intellij_metadata = _intellij_metadata_from(ctx.attr.jars)

jars_provider = _jars_to_provider(ctx.attr.jars)

deps_provider = _labels_to_provider(ctx.attr.deps)
runtime_deps_provider = _labels_to_provider(ctx.attr.runtime_deps)
exports_provider = _labels_to_provider(ctx.attr.exports)

jars2labels = _collect_jar_labels(ctx)
_add_labels_of_current_code_jars(depset(transitive=[jars_provider.compile_jars, exports_provider.compile_jars]), ctx.label, jars2labels) #last to override the label of the export compile jars to the current target

return struct(
scala = struct(
outputs = struct (
Expand All @@ -20,27 +21,48 @@ def _scala_import_impl(ctx):
),
jars_to_labels = jars2labels,
providers = [
_create_provider(current_jars, transitive_runtime_jars, jars, exports)
java_common.merge([jars_provider, deps_provider, runtime_deps_provider, exports_provider])
],
)
def _create_provider(current_target_compile_jars, transitive_runtime_jars, jars, exports):
return java_common.create_provider(
use_ijar = False,
compile_time_jars = depset(transitive = [current_target_compile_jars, exports.compile_jars]),
transitive_compile_time_jars = depset(transitive = [jars.transitive_compile_jars, current_target_compile_jars, exports.transitive_compile_jars]) ,
transitive_runtime_jars = depset(transitive = [transitive_runtime_jars, jars.transitive_runtime_jars, current_target_compile_jars, exports.transitive_runtime_jars]) ,
)

def _jars_to_provider(jars):
providers = []
for jar in jars:
if JavaInfo in jar:
fail("jars must contain only jar files")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we have a test for this negative case?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and also the error message is about what it should contain but the check is about whether or not this has JavaInfo.
Isn't there a mix? Maybe have a failure here to say targets of jvm_rules aren't allowed here and add another check below to see that the files are only jar files


code_jars = _filter_out_non_code_jars(jar.files)

for code_jar in code_jars:
providers.append(_jar_to_provider(code_jar))

return java_common.merge(providers)

def _jar_to_provider(jar):
return JavaInfo(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you building the JavaInfo like this (without the deps, exports,runtime_deps)?
I think that's not the correct semantics

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point there is only a jar file - not a JavaInfo

output_jar = jar,
compile_jar = jar,
)

def _collect_jar_labels(ctx):
jars2labels = {}
_collect_labels(ctx.attr.deps, jars2labels)
_collect_labels(ctx.attr.exports, jars2labels) #untested
return jars2labels

def _collect_labels(deps, jars2labels):
for dep_target in deps:
java_provider = dep_target[JavaInfo]
_transitively_accumulate_labels(dep_target, java_provider, jars2labels)

def _add_labels_of_current_code_jars(code_jars, label, jars2labels):
for jar in code_jars.to_list():
jars2labels[jar.path] = label

def _code_jars_and_intellij_metadata_from(jars):
code_jars = []
def _intellij_metadata_from(jars):
intellij_metadata = []
for jar in jars:
current_jar_code_jars = _filter_out_non_code_jars(jar.files)
code_jars += current_jar_code_jars
for current_class_jar in current_jar_code_jars: #intellij, untested
intellij_metadata.append(struct(
ijar = None,
Expand All @@ -49,34 +71,21 @@ def _code_jars_and_intellij_metadata_from(jars):
source_jars = [],
)
)
return struct(code_jars = code_jars, intellij_metadata = intellij_metadata)
return intellij_metadata

def _filter_out_non_code_jars(files):
return [file for file in files.to_list() if not _is_source_jar(file)]

def _is_source_jar(file):
return file.basename.endswith("-sources.jar")

# TODO: it seems this could be reworked to use java_common.merge
def _collect(deps):
transitive_compile_jars = []
runtime_jars = []
compile_jars = []

for dep_target in deps:
java_provider = dep_target[JavaInfo]
compile_jars.append(java_provider.compile_jars)
transitive_compile_jars.append(java_provider.transitive_compile_time_jars)
runtime_jars.append(java_provider.transitive_runtime_jars)
def _labels_to_provider(labels):
providers = []
for label in labels:
providers.append(label[JavaInfo])

return struct(transitive_runtime_jars = depset(transitive = runtime_jars),
transitive_compile_jars = depset(transitive = transitive_compile_jars),
compile_jars = depset(transitive = compile_jars))
return java_common.merge(providers)

def _collect_labels(deps, jars2labels):
for dep_target in deps:
java_provider = dep_target[JavaInfo]
_transitively_accumulate_labels(dep_target, java_provider,jars2labels)

def _transitively_accumulate_labels(dep_target, java_provider, jars2labels):
if hasattr(dep_target, "jars_to_labels"):
Expand All @@ -85,14 +94,6 @@ def _transitively_accumulate_labels(dep_target, java_provider, jars2labels):
for jar in java_provider.compile_jars.to_list():
jars2labels[jar.path] = dep_target.label

def _collect_runtime(runtime_deps):
jar_deps = []
for dep_target in runtime_deps:
java_provider = dep_target[JavaInfo]
jar_deps.append(java_provider.transitive_runtime_jars)

return depset(transitive = jar_deps)

scala_import = rule(
implementation=_scala_import_impl,
attrs={
Expand Down