diff --git a/presto-main/src/main/java/com/facebook/presto/execution/MemoryRevokingScheduler.java b/presto-main/src/main/java/com/facebook/presto/execution/MemoryRevokingScheduler.java index fee09793ca320..62d14c92bec97 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/MemoryRevokingScheduler.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/MemoryRevokingScheduler.java @@ -22,6 +22,7 @@ import com.facebook.presto.operator.OperatorContext; import com.facebook.presto.operator.PipelineContext; import com.facebook.presto.operator.TaskContext; +import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.memory.MemoryPoolId; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.analyzer.FeaturesConfig.TaskSpillingStrategy; @@ -64,7 +65,7 @@ public class MemoryRevokingScheduler private final double memoryRevokingTarget; private final TaskSpillingStrategy spillingStrategy; - private final MemoryPoolListener memoryPoolListener = MemoryPoolListener.onMemoryReserved(this::onMemoryReserved); + private final MemoryPoolListener memoryPoolListener = this::onMemoryReserved; @Nullable private ScheduledFuture scheduledFuture; @@ -152,7 +153,7 @@ void registerPoolListeners() memoryPools.forEach(memoryPool -> memoryPool.addListener(memoryPoolListener)); } - private void onMemoryReserved(MemoryPool memoryPool) + private void onMemoryReserved(MemoryPool memoryPool, QueryId queryId, long queryMemoryReservation) { try { if (!memoryRevokingNeeded(memoryPool)) { diff --git a/presto-main/src/main/java/com/facebook/presto/memory/MemoryPool.java b/presto-main/src/main/java/com/facebook/presto/memory/MemoryPool.java index d1b2b0f0da10b..8cede15d3e7c3 100644 --- a/presto-main/src/main/java/com/facebook/presto/memory/MemoryPool.java +++ b/presto-main/src/main/java/com/facebook/presto/memory/MemoryPool.java @@ -145,13 +145,14 @@ public ListenableFuture reserve(QueryId queryId, String allocationTag, long b } } - onMemoryReserved(); + onMemoryReserved(queryId); return result; } - private void onMemoryReserved() + private void onMemoryReserved(QueryId queryId) { - listeners.forEach(listener -> listener.onMemoryReserved(this)); + long totalMemoryReservation = queryMemoryReservations.getOrDefault(queryId, 0L) + queryMemoryRevocableReservations.getOrDefault(queryId, 0L); + listeners.forEach(listener -> listener.onMemoryReserved(this, queryId, totalMemoryReservation)); } public void onTaskMemoryReserved(TaskId taskId) @@ -181,7 +182,7 @@ public ListenableFuture reserveRevocable(QueryId queryId, long bytes) } } - onMemoryReserved(); + onMemoryReserved(queryId); return result; } @@ -202,7 +203,7 @@ public boolean tryReserve(QueryId queryId, String allocationTag, long bytes) } } - onMemoryReserved(); + onMemoryReserved(queryId); return true; } diff --git a/presto-main/src/main/java/com/facebook/presto/memory/MemoryPoolListener.java b/presto-main/src/main/java/com/facebook/presto/memory/MemoryPoolListener.java index 5307d21b0ce91..a2c3590842acc 100644 --- a/presto-main/src/main/java/com/facebook/presto/memory/MemoryPoolListener.java +++ b/presto-main/src/main/java/com/facebook/presto/memory/MemoryPoolListener.java @@ -13,22 +13,17 @@ */ package com.facebook.presto.memory; -import java.util.function.Consumer; +import com.facebook.presto.spi.QueryId; +@FunctionalInterface public interface MemoryPoolListener { /** * Invoked when memory reservation completes successfully. * - * @param memoryPool the {@link MemoryPool} where the reservation took place + * @param memoryPool the {@link MemoryPool} where the reservation took place + * @param queryId the {@link QueryId} of the query that reserved the memory + * @param queryMemoryReservation the total amount of memory reserved by the query (revocable and regular) */ - void onMemoryReserved(MemoryPool memoryPool); - - /** - * Creates {@link MemoryPoolListener} implementing {@link #onMemoryReserved(MemoryPool)} only. - */ - static MemoryPoolListener onMemoryReserved(Consumer action) - { - return action::accept; - } + void onMemoryReserved(MemoryPool memoryPool, QueryId queryId, long queryMemoryReservation); } diff --git a/presto-main/src/test/java/com/facebook/presto/memory/TestMemoryPools.java b/presto-main/src/test/java/com/facebook/presto/memory/TestMemoryPools.java index 94f93703814a4..e3a08a02a4d59 100644 --- a/presto-main/src/test/java/com/facebook/presto/memory/TestMemoryPools.java +++ b/presto-main/src/test/java/com/facebook/presto/memory/TestMemoryPools.java @@ -165,10 +165,10 @@ public void testNotifyListenerOnMemoryReserved() setupConsumeRevocableMemory(ONE_BYTE, 10); AtomicReference notifiedPool = new AtomicReference<>(); AtomicLong notifiedBytes = new AtomicLong(); - userPool.addListener(MemoryPoolListener.onMemoryReserved(pool -> { + userPool.addListener((pool, queryId, memoryReservation) -> { notifiedPool.set(pool); notifiedBytes.set(pool.getReservedBytes()); - })); + }); userPool.reserve(fakeQueryId, "test", 3); assertEquals(notifiedPool.get(), userPool);