Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remote: Proactively close the ZstdInputStream in ZstdDecompressingOutputStream. #15061

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/main/java/com/google/devtools/build/lib/remote/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ java_library(
"//src/main/java/com/google/devtools/build/lib/vfs:pathfragment",
"//src/main/java/com/google/devtools/common/options",
"//src/main/protobuf:failure_details_java_proto",
"//third_party:apache_commons_compress",
"//third_party:auth",
"//third_party:caffeine",
"//third_party:flogger",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.flogger.GoogleLogger;
import com.google.common.io.CountingOutputStream;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
Expand Down Expand Up @@ -67,10 +68,8 @@
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.commons.compress.utils.CountingOutputStream;

/** A RemoteActionCache implementation that uses gRPC calls to a remote cache server. */
@ThreadSafe
Expand Down Expand Up @@ -303,7 +302,7 @@ public ListenableFuture<Void> uploadActionResult(
public ListenableFuture<Void> downloadBlob(
RemoteActionExecutionContext context, Digest digest, OutputStream out) {
if (digest.getSizeBytes() == 0) {
return Futures.immediateFuture(null);
return Futures.immediateVoidFuture();
}

@Nullable Supplier<Digest> digestSupplier = null;
Expand All @@ -313,26 +312,14 @@ public ListenableFuture<Void> downloadBlob(
out = digestOut;
}

CountingOutputStream outputStream;
if (options.cacheCompression) {
try {
outputStream = new ZstdDecompressingOutputStream(out);
} catch (IOException e) {
return Futures.immediateFailedFuture(e);
}
} else {
outputStream = new CountingOutputStream(out);
}

return downloadBlob(context, digest, outputStream, digestSupplier);
return downloadBlob(context, digest, new CountingOutputStream(out), digestSupplier);
}

private ListenableFuture<Void> downloadBlob(
RemoteActionExecutionContext context,
Digest digest,
CountingOutputStream out,
@Nullable Supplier<Digest> digestSupplier) {
AtomicLong offset = new AtomicLong(0);
ProgressiveBackoff progressiveBackoff = new ProgressiveBackoff(retrier::newBackoff);
ListenableFuture<Long> downloadFuture =
Utils.refreshIfUnauthenticatedAsync(
Expand All @@ -343,7 +330,6 @@ private ListenableFuture<Void> downloadBlob(
channel ->
requestRead(
context,
offset,
progressiveBackoff,
digest,
out,
Expand All @@ -370,20 +356,25 @@ public static String getResourceName(String instanceName, Digest digest, boolean

private ListenableFuture<Long> requestRead(
RemoteActionExecutionContext context,
AtomicLong offset,
ProgressiveBackoff progressiveBackoff,
Digest digest,
CountingOutputStream out,
CountingOutputStream rawOut,
@Nullable Supplier<Digest> digestSupplier,
Channel channel) {
String resourceName =
getResourceName(options.remoteInstanceName, digest, options.cacheCompression);
SettableFuture<Long> future = SettableFuture.create();
OutputStream out;
try {
out = options.cacheCompression ? new ZstdDecompressingOutputStream(rawOut) : rawOut;
} catch (IOException e) {
return Futures.immediateFailedFuture(e);
}
bsAsyncStub(context, channel)
.read(
ReadRequest.newBuilder()
.setResourceName(resourceName)
.setReadOffset(offset.get())
.setReadOffset(rawOut.getCount())
.build(),
new StreamObserver<ReadResponse>() {

Expand All @@ -392,7 +383,6 @@ public void onNext(ReadResponse readResponse) {
ByteString data = readResponse.getData();
try {
data.writeTo(out);
offset.set(out.getBytesWritten());
} catch (IOException e) {
// Cancel the call.
throw new RuntimeException(e);
Expand All @@ -403,14 +393,15 @@ public void onNext(ReadResponse readResponse) {

@Override
public void onError(Throwable t) {
if (offset.get() == digest.getSizeBytes()) {
if (rawOut.getCount() == digest.getSizeBytes()) {
// If the file was fully downloaded, it doesn't matter if there was an error at
// the end of the stream.
logger.atInfo().withCause(t).log(
"ignoring error because file was fully received");
onCompleted();
return;
}
releaseOut();
Status status = Status.fromThrowable(t);
if (status.getCode() == Status.Code.NOT_FOUND) {
future.setException(new CacheNotFoundException(digest));
Expand All @@ -426,12 +417,24 @@ public void onCompleted() {
Utils.verifyBlobContents(digest, digestSupplier.get());
}
out.flush();
future.set(offset.get());
future.set(rawOut.getCount());
} catch (IOException e) {
future.setException(e);
} catch (RuntimeException e) {
logger.atWarning().withCause(e).log("Unexpected exception");
future.setException(e);
} finally {
releaseOut();
}
}

private void releaseOut() {
if (out instanceof ZstdDecompressingOutputStream) {
try {
((ZstdDecompressingOutputStream) out).closeShallow();
} catch (IOException e) {
logger.atWarning().withCause(e).log("failed to cleanly close output stream");
}
}
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ java_library(
name = "zstd",
srcs = glob(["*.java"]),
deps = [
"//third_party:apache_commons_compress",
"//third_party:guava",
"//third_party/protobuf:protobuf_java",
"@zstd-jni",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,35 +13,35 @@
// limitations under the License.
package com.google.devtools.build.lib.remote.zstd;

import com.github.luben.zstd.ZstdInputStream;
import com.github.luben.zstd.ZstdInputStreamNoFinalizer;
import com.google.protobuf.ByteString;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import org.apache.commons.compress.utils.CountingOutputStream;

/** A {@link CountingOutputStream} that use zstd to decompress the content. */
public class ZstdDecompressingOutputStream extends CountingOutputStream {
/** An {@link OutputStream} that use zstd to decompress the content. */
public final class ZstdDecompressingOutputStream extends OutputStream {
private final OutputStream out;
private ByteArrayInputStream inner;
private final ZstdInputStream zis;
private final ZstdInputStreamNoFinalizer zis;

public ZstdDecompressingOutputStream(OutputStream out) throws IOException {
super(out);
this.out = out;
zis =
new ZstdInputStream(
new InputStream() {
@Override
public int read() {
return inner.read();
}

@Override
public int read(byte[] b, int off, int len) {
return inner.read(b, off, len);
}
});
zis.setContinuous(true);
new ZstdInputStreamNoFinalizer(
new InputStream() {
@Override
public int read() {
return inner.read();
}

@Override
public int read(byte[] b, int off, int len) {
return inner.read(b, off, len);
}
})
.setContinuous(true);
}

@Override
Expand All @@ -58,6 +58,19 @@ public void write(byte[] b) throws IOException {
public void write(byte[] b, int off, int len) throws IOException {
inner = new ByteArrayInputStream(b, off, len);
byte[] data = ByteString.readFrom(zis).toByteArray();
super.write(data, 0, data.length);
out.write(data, 0, data.length);
}

@Override
public void close() throws IOException {
closeShallow();
out.close();
}

/**
* Free resources related to decompression without closing the underlying {@link OutputStream}.
*/
public void closeShallow() throws IOException {
zis.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@

import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.mockito.ArgumentMatchers.any;

import build.bazel.remote.execution.v2.Digest;
import com.github.luben.zstd.Zstd;
import com.google.bytestream.ByteStreamGrpc.ByteStreamImplBase;
import com.google.bytestream.ByteStreamProto.ReadRequest;
import com.google.bytestream.ByteStreamProto.ReadResponse;
import com.google.devtools.build.lib.remote.Retrier.Backoff;
import com.google.devtools.build.lib.remote.options.RemoteOptions;
import com.google.devtools.common.options.Options;
import com.google.protobuf.ByteString;
Expand All @@ -31,38 +29,50 @@
import java.io.IOException;
import java.util.Arrays;
import org.junit.Test;
import org.mockito.Mockito;

/** Extra tests for {@link GrpcCacheClient} that are not tested internally. */
public class GrpcCacheClientTestExtra extends GrpcCacheClientTest {

@Test
public void compressedDownloadBlobIsRetriedWithProgress()
throws IOException, InterruptedException {
Backoff mockBackoff = Mockito.mock(Backoff.class);
RemoteOptions options = Options.getDefaults(RemoteOptions.class);
options.cacheCompression = true;
final GrpcCacheClient client = newClient(options, () -> mockBackoff);
final GrpcCacheClient client = newClient(options);
final Digest digest = DIGEST_UTIL.computeAsUtf8("abcdefg");
ByteString blob = ByteString.copyFrom(Zstd.compress("abcdefg".getBytes(UTF_8)));
ByteString chunk1 = ByteString.copyFrom(Zstd.compress("abc".getBytes(UTF_8)));
ByteString chunk2 = ByteString.copyFrom(Zstd.compress("def".getBytes(UTF_8)));
ByteString chunk3 = ByteString.copyFrom(Zstd.compress("g".getBytes(UTF_8)));
serviceRegistry.addService(
new ByteStreamImplBase() {
private boolean first = true;

@Override
public void read(ReadRequest request, StreamObserver<ReadResponse> responseObserver) {
assertThat(request.getResourceName().contains(digest.getHash())).isTrue();
int off = (int) request.getReadOffset();
// Zstd header size is 9 bytes
ByteString data = off == 0 ? blob.substring(0, 9 + 1) : blob.substring(9 + off);
responseObserver.onNext(ReadResponse.newBuilder().setData(data).build());
if (off == 0) {
if (first) {
first = false;
responseObserver.onError(Status.DEADLINE_EXCEEDED.asException());
} else {
responseObserver.onCompleted();
return;
}
switch (Math.toIntExact(request.getReadOffset())) {
case 0:
responseObserver.onNext(ReadResponse.newBuilder().setData(chunk1).build());
break;
case 3:
responseObserver.onNext(ReadResponse.newBuilder().setData(chunk2).build());
break;
case 6:
responseObserver.onNext(ReadResponse.newBuilder().setData(chunk3).build());
responseObserver.onCompleted();
return;
default:
throw new IllegalStateException("unexpected offset " + request.getReadOffset());
}
responseObserver.onError(Status.DEADLINE_EXCEEDED.asException());
}
});
assertThat(new String(downloadBlob(context, client, digest), UTF_8)).isEqualTo("abcdefg");
Mockito.verify(mockBackoff, Mockito.never()).nextDelayMillis(any(Exception.class));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ public void bytesWrittenMatchesDecompressedBytes() throws IOException {
for (byte b : compressed.toByteArray()) {
zdos.write(b);
zdos.flush();
assertThat(zdos.getBytesWritten()).isEqualTo(decompressed.toByteArray().length);
}
assertThat(decompressed.toByteArray()).isEqualTo(data);
}
Expand Down