Skip to content

Commit

Permalink
Let .bzl files record their usages of repo mapping
Browse files Browse the repository at this point in the history
In the same vein as #20742, we record all repo mapping entries used during the load of a .bzl file too, including any of its `load()` statements and calls to `Label()` that contain an apparent repo name.

See #20721 (comment) for a more detailed explanation for this change, and the test cases in this commit for more potential triggers.

Fixes #20721

Closes #20830.

PiperOrigin-RevId: 597351525
Change-Id: I8f6ed297b81d55f7476a93bdc6668e1e1dcbe536
  • Loading branch information
Wyverald committed Jan 10, 2024
1 parent ce993c4 commit 9c768c6
Show file tree
Hide file tree
Showing 10 changed files with 358 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -348,21 +348,23 @@ private static boolean didRepoMappingsChange(
.map(RepositoryMappingValue::key)
.collect(toImmutableSet()));
if (env.valuesMissing()) {
// This shouldn't really happen, since the RepositoryMappingValues of any recorded repos
// should have already been requested by the time we load the .bzl for the extension. And this
// method is only called if the transitive .bzl digest hasn't changed.
// However, we pretend it could happen anyway because we're good citizens.
// This likely means that one of the 'source repos' in the recorded mapping entries is no
// longer there.
throw new NeedsSkyframeRestartException();
}
for (Table.Cell<RepositoryName, String, RepositoryName> cell : recordedRepoMappings.cellSet()) {
RepositoryMappingValue repoMappingValue =
(RepositoryMappingValue) result.get(RepositoryMappingValue.key(cell.getRowKey()));
if (repoMappingValue == null) {
// Again, this shouldn't happen. But anyway.
throw new NeedsSkyframeRestartException();
}
if (!cell.getValue()
.equals(repoMappingValue.getRepositoryMapping().get(cell.getColumnKey()))) {
// Very importantly, `repoMappingValue` here could be for a repo that's no longer existent in
// the dep graph. See
// bazel_lockfile_test.testExtensionRepoMappingChange_sourceRepoNoLongerExistent for a test
// case.
if (repoMappingValue.equals(RepositoryMappingValue.NOT_FOUND_VALUE)
|| !cell.getValue()
.equals(repoMappingValue.getRepositoryMapping().get(cell.getColumnKey()))) {
// Wee woo wee woo -- diff detected!
return true;
}
Expand Down Expand Up @@ -806,20 +808,19 @@ private RegularRunnableExtension loadRegularRunnableExtension(
if (envVars == null) {
return null;
}
return new RegularRunnableExtension(
BazelModuleContext.of(bzlLoadValue.getModule()), extension, envVars);
return new RegularRunnableExtension(bzlLoadValue, extension, envVars);
}

private final class RegularRunnableExtension implements RunnableExtension {
private final BazelModuleContext bazelModuleContext;
private final BzlLoadValue bzlLoadValue;
private final ModuleExtension extension;
private final ImmutableMap<String, String> envVars;

RegularRunnableExtension(
BazelModuleContext bazelModuleContext,
BzlLoadValue bzlLoadValue,
ModuleExtension extension,
ImmutableMap<String, String> envVars) {
this.bazelModuleContext = bazelModuleContext;
this.bzlLoadValue = bzlLoadValue;
this.extension = extension;
this.envVars = envVars;
}
Expand All @@ -838,7 +839,7 @@ public ImmutableMap<String, String> getEnvVars() {

@Override
public byte[] getBzlTransitiveDigest() {
return bazelModuleContext.bzlTransitiveDigest();
return BazelModuleContext.of(bzlLoadValue.getModule()).bzlTransitiveDigest();
}

@Nullable
Expand All @@ -853,12 +854,13 @@ public RunModuleExtensionResult run(
new ModuleExtensionEvalStarlarkThreadContext(
usagesValue.getExtensionUniqueName() + "~",
extensionId.getBzlFileLabel().getPackageIdentifier(),
bazelModuleContext.repoMapping(),
BazelModuleContext.of(bzlLoadValue.getModule()).repoMapping(),
directories,
env.getListener());
ModuleExtensionContext moduleContext;
Optional<ModuleExtensionMetadata> moduleExtensionMetadata;
var repoMappingRecorder = new Label.RepoMappingRecorder();
repoMappingRecorder.mergeEntries(bzlLoadValue.getRecordedRepoMappings());
try (Mutability mu =
Mutability.create("module extension", usagesValue.getExtensionUniqueName())) {
StarlarkThread thread = new StarlarkThread(mu, starlarkSemantics);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ public static final class RepoMappingRecorder {
/** {@code <fromRepo, apparentRepoName, canonicalRepoName> } */
Table<RepositoryName, String, RepositoryName> entries = HashBasedTable.create();

public void mergeEntries(Table<RepositoryName, String, RepositoryName> entries) {
this.entries.putAll(entries);
}

public ImmutableTable<RepositoryName, String, RepositoryName> recordedEntries() {
return ImmutableTable.<RepositoryName, String, RepositoryName>builder()
.orderRowsBy(Comparator.comparing(RepositoryName::getName))
Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/google/devtools/build/lib/skyframe/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2190,6 +2190,7 @@ java_library(
"//src/main/java/com/google/devtools/build/skyframe:skyframe-objects",
"//third_party:auto_value",
"//third_party:guava",
"//third_party:jsr305",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,7 @@ private BzlLoadValue computeInternalWithCompiledBzl(
if (repoMapping == null) {
return null;
}
Label.RepoMappingRecorder repoMappingRecorder = new Label.RepoMappingRecorder();
ImmutableList<Pair<String, Location>> programLoads = getLoadsFromProgram(prog);
ImmutableList<Label> loadLabels =
getLoadLabels(
Expand All @@ -773,7 +774,8 @@ private BzlLoadValue computeInternalWithCompiledBzl(
pkg,
repoMapping,
key.isSclDialect(),
isSclFlagEnabled);
isSclFlagEnabled,
repoMappingRecorder);
if (loadLabels == null) {
throw new BzlLoadFailedException(
String.format(
Expand Down Expand Up @@ -824,6 +826,7 @@ private BzlLoadValue computeInternalWithCompiledBzl(
BzlLoadValue v = loadValues.get(i++);
loadMap.put(load.first, v.getModule()); // dups ok
fp.addBytes(v.getTransitiveDigest());
repoMappingRecorder.mergeEntries(v.getRecordedRepoMappings());
}

// Retrieve predeclared symbols and complete the digest computation.
Expand Down Expand Up @@ -865,7 +868,14 @@ private BzlLoadValue computeInternalWithCompiledBzl(
// caching BzlLoadValues. Note that executing the code mutates the Module and
// BzlInitThreadContext.
executeBzlFile(
prog, label, module, loadMap, context, builtins.starlarkSemantics, env.getListener());
prog,
label,
module,
loadMap,
context,
builtins.starlarkSemantics,
env.getListener(),
repoMappingRecorder);

BzlVisibility bzlVisibility = context.getBzlVisibility();
if (bzlVisibility == null) {
Expand All @@ -874,7 +884,8 @@ private BzlLoadValue computeInternalWithCompiledBzl(
// We save load visibility in the BzlLoadValue rather than the BazelModuleContext because
// visibility doesn't need to be introspected by any Starlark builtin methods, and because the
// alternative would mean mutating or overwriting the BazelModuleContext after evaluation.
return new BzlLoadValue(module, transitiveDigest, bzlVisibility);
return new BzlLoadValue(
module, transitiveDigest, bzlVisibility, repoMappingRecorder.recordedEntries());
}

@Nullable
Expand Down Expand Up @@ -1061,7 +1072,8 @@ private static ImmutableList<Label> getLoadLabels(
PackageIdentifier base,
RepositoryMapping repoMapping,
boolean withinSclDialect,
boolean isSclFlagEnabled) {
boolean isSclFlagEnabled,
@Nullable Label.RepoMappingRecorder repoMappingRecorder) {
boolean ok = true;

ImmutableList.Builder<Label> loadLabels = ImmutableList.builderWithExpectedSize(loads.size());
Expand All @@ -1074,7 +1086,8 @@ private static ImmutableList<Label> getLoadLabels(
throw new LabelSyntaxException("in .scl files, load labels must begin with \"//\"");
}
Label label =
Label.parseWithPackageContext(unparsedLabel, PackageContext.of(base, repoMapping));
Label.parseWithPackageContext(
unparsedLabel, PackageContext.of(base, repoMapping), repoMappingRecorder);
checkValidLoadLabel(
label,
/* fromBuiltinsRepo= */ StarlarkBuiltinsValue.isBuiltinsRepo(base.getRepository()),
Expand Down Expand Up @@ -1109,7 +1122,8 @@ static ImmutableList<Label> getLoadLabels(
repoMapping,
/* withinSclDialect= */ false,
/* isSclFlagEnabled= */ starlarkSemantics.getBool(
BuildLanguageOptions.EXPERIMENTAL_ENABLE_SCL_DIALECT));
BuildLanguageOptions.EXPERIMENTAL_ENABLE_SCL_DIALECT),
/* repoMappingRecorder= */ null);
}

/** Extracts load statements from compiled program (see {@link #getLoadLabels}). */
Expand Down Expand Up @@ -1340,11 +1354,15 @@ private static void executeBzlFile(
Map<String, Module> loadedModules,
BzlInitThreadContext context,
StarlarkSemantics starlarkSemantics,
ExtendedEventHandler skyframeEventHandler)
ExtendedEventHandler skyframeEventHandler,
Label.RepoMappingRecorder repoMappingRecorder)
throws BzlLoadFailedException, InterruptedException {
try (Mutability mu = Mutability.create("loading", label)) {
StarlarkThread thread = new StarlarkThread(mu, starlarkSemantics);
thread.setLoader(loadedModules::get);
// This is needed so that any calls to `Label()` will have its used repo mapping entries
// recorded. See #20721 for more details.
thread.setThreadLocal(Label.RepoMappingRecorder.class, repoMappingRecorder);

// Wrap the skyframe event handler to listen for starlark errors.
AtomicBoolean sawStarlarkError = new AtomicBoolean(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableTable;
import com.google.devtools.build.lib.cmdline.Label;
import com.google.devtools.build.lib.cmdline.RepositoryName;
import com.google.devtools.build.lib.concurrent.ThreadSafety.Immutable;
Expand Down Expand Up @@ -52,12 +53,18 @@ public class BzlLoadValue implements SkyValue {
// from the Module as client data?
private final byte[] transitiveDigest; // of .bzl file and load dependencies
private final BzlVisibility bzlVisibility;
private final ImmutableTable<RepositoryName, String, RepositoryName> recordedRepoMappings;

@VisibleForTesting
public BzlLoadValue(Module module, byte[] transitiveDigest, BzlVisibility bzlVisibility) {
public BzlLoadValue(
Module module,
byte[] transitiveDigest,
BzlVisibility bzlVisibility,
ImmutableTable<RepositoryName, String, RepositoryName> recordedRepoMappings) {
this.module = checkNotNull(module);
this.transitiveDigest = checkNotNull(transitiveDigest);
this.bzlVisibility = checkNotNull(bzlVisibility);
this.recordedRepoMappings = checkNotNull(recordedRepoMappings);
}

/** Returns the .bzl module. */
Expand All @@ -75,6 +82,14 @@ public BzlVisibility getBzlVisibility() {
return bzlVisibility;
}

/**
* Returns the repo mapping entries used to laod this bzl file. Stored for correctness across
* Bazel server restarts.
*/
public ImmutableTable<RepositoryName, String, RepositoryName> getRecordedRepoMappings() {
return recordedRepoMappings;
}

private static final SkyKeyInterner<Key> keyInterner = SkyKey.newInterner();

/** SkyKey for a Starlark load. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.google.devtools.build.skyframe.SkyValue;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nullable;

/**
* A value that represents the 'mappings' of an external Bazel workspace, as defined in the main
Expand Down Expand Up @@ -57,6 +58,9 @@ public abstract class RepositoryMappingValue implements SkyValue {
public static final RepositoryMappingValue VALUE_FOR_ROOT_MODULE_WITHOUT_REPOS =
RepositoryMappingValue.createForWorkspaceRepo(RepositoryMapping.ALWAYS_FALLBACK);

public static final RepositoryMappingValue NOT_FOUND_VALUE =
RepositoryMappingValue.createForWorkspaceRepo(null);

/**
* Returns a {@link RepositoryMappingValue} for a repo defined in MODULE.bazel, which has an
* associated module.
Expand All @@ -80,6 +84,8 @@ public static RepositoryMappingValue createForWorkspaceRepo(RepositoryMapping re
repositoryMapping, Optional.empty(), Optional.empty());
}

/** The actual repo mapping. Will be null if the requested repo doesn't exist. */
@Nullable
public abstract RepositoryMapping getRepositoryMapping();

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Tables;
import com.google.devtools.build.lib.analysis.util.BuildViewTestCase;
import com.google.devtools.build.lib.clock.BlazeClock;
import com.google.devtools.build.lib.cmdline.BazelModuleContext;
import com.google.devtools.build.lib.cmdline.Label;
import com.google.devtools.build.lib.cmdline.RepositoryName;
import com.google.devtools.build.lib.packages.RuleVisibility;
import com.google.devtools.build.lib.packages.semantics.BuildLanguageOptions;
import com.google.devtools.build.lib.pkgcache.PackageOptions;
Expand Down Expand Up @@ -1035,7 +1037,7 @@ public void testLoadBzlFileFromWorkspaceWithRemapping() throws Exception {

scratch.file("/y/WORKSPACE");
scratch.file("/y/BUILD");
scratch.file("/y/y.bzl", "y_symbol = 5");
scratch.file("/y/y.bzl", "l = Label('@z//:z')", "y_symbol = 5");

scratch.file("/a/WORKSPACE");
scratch.file("/a/BUILD");
Expand All @@ -1051,8 +1053,13 @@ public void testLoadBzlFileFromWorkspaceWithRemapping() throws Exception {
SkyframeExecutorTestUtils.evaluate(
getSkyframeExecutor(), skyKey, /*keepGoing=*/ false, reporter);

assertThat(result.get(skyKey).getModule().getGlobals())
.containsEntry("a_symbol", StarlarkInt.of(5));
var bzlLoadValue = result.get(skyKey);
assertThat(bzlLoadValue.getModule().getGlobals()).containsEntry("a_symbol", StarlarkInt.of(5));
assertThat(bzlLoadValue.getRecordedRepoMappings().cellSet())
.containsExactly(
Tables.immutableCell(RepositoryName.create("a"), "x", RepositoryName.create("y")),
Tables.immutableCell(RepositoryName.create("y"), "z", RepositoryName.create("z")))
.inOrder();
}

@Test
Expand All @@ -1072,6 +1079,7 @@ public void testLoadBzlFileFromBzlmod() throws Exception {
fooDir.getRelative("test.bzl").getPathString(),
// Also test that bzlmod .bzl files can load .scl files.
"load('@bar_alias//:test.scl', 'haha')",
"l = Label('@foo//:whatever')",
"hoho = haha");
Path barDir = moduleRoot.getRelative("bar~2.0");
scratch.file(barDir.getRelative("WORKSPACE").getPathString());
Expand All @@ -1084,8 +1092,15 @@ public void testLoadBzlFileFromBzlmod() throws Exception {
getSkyframeExecutor(), skyKey, /*keepGoing=*/ false, reporter);

assertThatEvaluationResult(result).hasNoError();
assertThat(result.get(skyKey).getModule().getGlobals())
.containsEntry("hoho", StarlarkInt.of(5));
var bzlLoadValue = result.get(skyKey);
assertThat(bzlLoadValue.getModule().getGlobals()).containsEntry("hoho", StarlarkInt.of(5));
assertThat(bzlLoadValue.getRecordedRepoMappings().cellSet())
.containsExactly(
Tables.immutableCell(
RepositoryName.create("foo~1.0"), "bar_alias", RepositoryName.create("bar~2.0")),
Tables.immutableCell(
RepositoryName.create("foo~1.0"), "foo", RepositoryName.create("foo~1.0")))
.inOrder();
// Note that we're not testing the case of a non-registry override using @bazel_tools here, but
// that is incredibly hard to set up in a unit test. So we should just rely on integration tests
// for that.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ java_test(
size = "small",
srcs = ["BzlLoadValueCodecTest.java"],
deps = [
"//src/main/java/com/google/devtools/build/lib/cmdline",
"//src/main/java/com/google/devtools/build/lib/packages",
"//src/main/java/com/google/devtools/build/lib/skyframe:bzl_load_value",
"//src/main/java/com/google/devtools/build/lib/skyframe/serialization/testutils",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import static java.nio.charset.StandardCharsets.ISO_8859_1;

import com.google.common.collect.ImmutableClassToInstanceMap;
import com.google.common.collect.ImmutableTable;
import com.google.devtools.build.lib.cmdline.RepositoryName;
import com.google.devtools.build.lib.packages.BzlVisibility;
import com.google.devtools.build.lib.skyframe.BzlLoadValue;
import com.google.devtools.build.lib.skyframe.serialization.testutils.SerializationTester;
Expand All @@ -29,6 +31,10 @@
/** Tests for {@link BzlLoadValue} serialization. */
@RunWith(JUnit4.class)
public class BzlLoadValueCodecTest {
private static final ImmutableTable<RepositoryName, String, RepositoryName> SOME_TABLE =
ImmutableTable.of(
RepositoryName.createUnvalidated("foo"), "bar", RepositoryName.createUnvalidated("quux"));

@Test
public void objectCodecTests() throws Exception {
Module module = Module.create();
Expand All @@ -37,13 +43,14 @@ public void objectCodecTests() throws Exception {
module.setGlobal("c", 3);
byte[] digest = "dummy".getBytes(ISO_8859_1);

new SerializationTester(new BzlLoadValue(module, digest, BzlVisibility.PUBLIC))
new SerializationTester(new BzlLoadValue(module, digest, BzlVisibility.PUBLIC, SOME_TABLE))
.setVerificationFunction(
(SerializationTester.VerificationFunction<BzlLoadValue>)
(x, y) -> {
if (!java.util.Arrays.equals(x.getTransitiveDigest(), y.getTransitiveDigest())) {
throw new AssertionError("unequal digests after serialization");
}
assertThat(x.getRecordedRepoMappings()).isEqualTo(y.getRecordedRepoMappings());
})
.runTestsWithoutStableSerializationCheck();
}
Expand All @@ -64,6 +71,6 @@ private static BzlLoadValue makeBLV(String name, Object value) {
module.setGlobal(name, value);

byte[] digest = "dummy".getBytes(ISO_8859_1);
return new BzlLoadValue(module, digest, BzlVisibility.PUBLIC);
return new BzlLoadValue(module, digest, BzlVisibility.PUBLIC, SOME_TABLE);
}
}
Loading

0 comments on commit 9c768c6

Please sign in to comment.