diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index f98e21dd1dd5..344d8b43b799 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -99,6 +99,7 @@ use url::Url; use crate::catalog::listing_schema::ListingSchemaProvider; use crate::datasource::object_store::ObjectStoreUrl; +use crate::execution::memory_pool::MemoryPool; use uuid::Uuid; use super::options::{ @@ -1961,6 +1962,11 @@ impl TaskContext { self.task_id.clone() } + /// Return the [`MemoryPool`] associated with this [TaskContext] + pub fn memory_pool(&self) -> &Arc { + &self.runtime.memory_pool + } + /// Return the [RuntimeEnv] associated with this [TaskContext] pub fn runtime_env(&self) -> Arc { self.runtime.clone() @@ -2026,6 +2032,7 @@ mod tests { use super::*; use crate::assert_batches_eq; use crate::execution::context::QueryPlanner; + use crate::execution::memory_pool::MemoryConsumer; use crate::execution::runtime_env::RuntimeConfig; use crate::physical_plan::expressions::AvgAccumulator; use crate::test; @@ -2047,24 +2054,27 @@ mod tests { #[tokio::test] async fn shared_memory_and_disk_manager() { // Demonstrate the ability to share DiskManager and - // MemoryManager between two different executions. + // MemoryPool between two different executions. let ctx1 = SessionContext::new(); // configure with same memory / disk manager - let memory_manager = ctx1.runtime_env().memory_manager.clone(); + let memory_pool = ctx1.runtime_env().memory_pool.clone(); + + let mut reservation = MemoryConsumer::new("test").register(&memory_pool); + reservation.grow(100); + let disk_manager = ctx1.runtime_env().disk_manager.clone(); let ctx2 = SessionContext::with_config_rt(SessionConfig::new(), ctx1.runtime_env()); - assert!(std::ptr::eq( - Arc::as_ptr(&memory_manager), - Arc::as_ptr(&ctx1.runtime_env().memory_manager) - )); - assert!(std::ptr::eq( - Arc::as_ptr(&memory_manager), - Arc::as_ptr(&ctx2.runtime_env().memory_manager) - )); + assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 100); + assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 100); + + drop(reservation); + + assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 0); + assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 0); assert!(std::ptr::eq( Arc::as_ptr(&disk_manager), diff --git a/datafusion/core/src/execution/memory_manager/mod.rs b/datafusion/core/src/execution/memory_manager/mod.rs deleted file mode 100644 index c3ff444ebc27..000000000000 --- a/datafusion/core/src/execution/memory_manager/mod.rs +++ /dev/null @@ -1,664 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
Manages all available memory during query execution - -use crate::error::{DataFusionError, Result}; -use async_trait::async_trait; -use hashbrown::HashSet; -use log::{debug, warn}; -use parking_lot::{Condvar, Mutex}; -use std::fmt; -use std::fmt::{Debug, Display, Formatter}; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use std::time::{Duration, Instant}; - -pub mod proxy; - -static CONSUMER_ID: AtomicUsize = AtomicUsize::new(0); - -#[derive(Debug, Clone)] -/// Configuration information for memory management -pub enum MemoryManagerConfig { - /// Use the existing [MemoryManager] - Existing(Arc), - - /// Create a new [MemoryManager] that will use up to some - /// fraction of total system memory. - New { - /// Max execution memory allowed for DataFusion. Defaults to - /// `usize::MAX`, which will not attempt to limit the memory - /// used during plan execution. - max_memory: usize, - - /// The fraction of `max_memory` that the memory manager will - /// use for execution. - /// - /// The purpose of this config is to set aside memory for - /// untracked data structures, and imprecise size estimation - /// during memory acquisition. Defaults to 0.7 - memory_fraction: f64, - }, -} - -impl Default for MemoryManagerConfig { - fn default() -> Self { - Self::New { - max_memory: usize::MAX, - memory_fraction: 0.7, - } - } -} - -impl MemoryManagerConfig { - /// Create a new memory [MemoryManager] with no limit on the - /// memory used - pub fn new() -> Self { - Default::default() - } - - /// Create a configuration based on an existing [MemoryManager] - pub fn new_existing(existing: Arc) -> Self { - Self::Existing(existing) - } - - /// Create a new [MemoryManager] with a `max_memory` and `fraction` - pub fn try_new_limit(max_memory: usize, memory_fraction: f64) -> Result { - if max_memory == 0 { - return Err(DataFusionError::Plan(format!( - "invalid max_memory. Expected greater than 0, got {}", - max_memory - ))); - } - if !(memory_fraction > 0f64 && memory_fraction <= 1f64) { - return Err(DataFusionError::Plan(format!( - "invalid fraction. Expected greater than 0 and less than 1.0, got {}", - memory_fraction - ))); - } - - Ok(Self::New { - max_memory, - memory_fraction, - }) - } - - /// return the maximum size of the memory, in bytes, this config will allow - fn pool_size(&self) -> usize { - match self { - MemoryManagerConfig::Existing(existing) => existing.pool_size, - MemoryManagerConfig::New { - max_memory, - memory_fraction, - } => (*max_memory as f64 * *memory_fraction) as usize, - } - } -} - -fn next_id() -> usize { - CONSUMER_ID.fetch_add(1, Ordering::SeqCst) -} - -/// Type of the memory consumer -pub enum ConsumerType { - /// consumers that can grow its memory usage by requesting more from the memory manager or - /// shrinks its memory usage when we can no more assign available memory to it. - /// Examples are spillable sorter, spillable hashmap, etc. - Requesting, - /// consumers that are not spillable, counting in for only tracking purpose. 
- Tracking, -} - -#[derive(Clone, Debug, Hash, Eq, PartialEq)] -/// Id that uniquely identifies a Memory Consumer -pub struct MemoryConsumerId { - /// partition the consumer belongs to - pub partition_id: usize, - /// unique id - pub id: usize, -} - -impl MemoryConsumerId { - /// Auto incremented new Id - pub fn new(partition_id: usize) -> Self { - let id = next_id(); - Self { partition_id, id } - } -} - -impl Display for MemoryConsumerId { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "{}:{}", self.partition_id, self.id) - } -} - -#[async_trait] -/// A memory consumer that either takes up memory (of type `ConsumerType::Tracking`) -/// or grows/shrinks memory usage based on available memory (of type `ConsumerType::Requesting`). -pub trait MemoryConsumer: Send + Sync { - /// Display name of the consumer - fn name(&self) -> String; - - /// Unique id of the consumer - fn id(&self) -> &MemoryConsumerId; - - /// Ptr to MemoryManager - fn memory_manager(&self) -> Arc; - - /// Partition that the consumer belongs to - fn partition_id(&self) -> usize { - self.id().partition_id - } - - /// Type of the consumer - fn type_(&self) -> &ConsumerType; - - /// Grow memory by `required` to buffer more data in memory, - /// this may trigger spill before grow when the memory threshold is - /// reached for this consumer. - async fn try_grow(&self, required: usize) -> Result<()> { - let current = self.mem_used(); - debug!( - "trying to acquire {} whiling holding {} from consumer {}", - human_readable_size(required), - human_readable_size(current), - self.id(), - ); - - let can_grow_directly = - self.memory_manager().can_grow_directly(required, current); - if !can_grow_directly { - debug!( - "Failed to grow memory of {} directly from consumer {}, spilling first ...", - human_readable_size(required), - self.id() - ); - let freed = self.spill().await?; - self.memory_manager() - .record_free_then_acquire(freed, required); - } - Ok(()) - } - - /// Grow without spilling to the disk. It grows the memory directly - /// so it should be only used when the consumer already allocated the - /// memory and it is safe to grow without spilling. - fn grow(&self, required: usize) { - self.memory_manager().record_free_then_acquire(0, required); - } - - /// Return `freed` memory to the memory manager, - /// may wake up other requesters waiting for their minimum memory quota. - fn shrink(&self, freed: usize) { - self.memory_manager().record_free(freed); - } - - /// Spill in-memory buffers to disk, free memory, return the previous used - async fn spill(&self) -> Result; - - /// Current memory used by this consumer - fn mem_used(&self) -> usize; -} - -impl Debug for dyn MemoryConsumer { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!( - f, - "{}[{}]: {}", - self.name(), - self.id(), - human_readable_size(self.mem_used()) - ) - } -} - -impl Display for dyn MemoryConsumer { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "{}[{}]", self.name(), self.id(),) - } -} - -/* -The memory management architecture is the following: - -1. User designates max execution memory by setting RuntimeConfig.max_memory and RuntimeConfig.memory_fraction (float64 between 0..1). - The actual max memory DataFusion could use `pool_size = max_memory * memory_fraction`. -2. The entities that take up memory during its execution are called 'Memory Consumers'. Operators or others are encouraged to - register themselves to the memory manager and report its usage through `mem_used()`. -3. 
There are two kinds of consumers: - - 'Requesting' consumers that would acquire memory during its execution and release memory through `spill` if no more memory is available. - - 'Tracking' consumers that exist for reporting purposes to provide a more accurate memory usage estimation for memory consumers. -4. Requesting and tracking consumers share the pool. Each controlling consumer could acquire a maximum of - (pool_size - all_tracking_used) / active_num_controlling_consumers. - - Memory Space for the DataFusion Lib / Process of `pool_size` - ┌──────────────────────────────────────────────z─────────────────────────────┐ - │ z │ - │ z │ - │ Requesting z Tracking │ - │ Memory Consumers z Memory Consumers │ - │ z │ - │ z │ - └──────────────────────────────────────────────z─────────────────────────────┘ -*/ - -/// Manage memory usage during physical plan execution -#[derive(Debug)] -pub struct MemoryManager { - requesters: Arc>>, - pool_size: usize, - requesters_total: Arc>, - trackers_total: AtomicUsize, - cv: Condvar, -} - -impl MemoryManager { - /// Create new memory manager based on the configuration - #[allow(clippy::mutex_atomic)] - pub fn new(config: MemoryManagerConfig) -> Arc { - let pool_size = config.pool_size(); - - match config { - MemoryManagerConfig::Existing(manager) => manager, - MemoryManagerConfig::New { .. } => { - debug!( - "Creating memory manager with initial size {}", - human_readable_size(pool_size) - ); - - Arc::new(Self { - requesters: Arc::new(Mutex::new(HashSet::new())), - pool_size, - requesters_total: Arc::new(Mutex::new(0)), - trackers_total: AtomicUsize::new(0), - cv: Condvar::new(), - }) - } - } - } - - fn get_tracker_total(&self) -> usize { - self.trackers_total.load(Ordering::SeqCst) - } - - pub(crate) fn grow_tracker_usage(&self, delta: usize) { - self.trackers_total.fetch_add(delta, Ordering::SeqCst); - } - - pub(crate) fn shrink_tracker_usage(&self, delta: usize) { - let update = - self.trackers_total - .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |x| { - if x >= delta { - Some(x - delta) - } else { - None - } - }); - update.unwrap_or_else(|_| { - panic!( - "Tracker total memory shrink by {} underflow, current value is ", - delta - ) - }); - } - - /// Return the total memory usage for all requesters - pub fn get_requester_total(&self) -> usize { - *self.requesters_total.lock() - } - - /// Register a new memory requester - pub(crate) fn register_requester(&self, requester_id: &MemoryConsumerId) { - self.requesters.lock().insert(requester_id.clone()); - } - - fn max_mem_for_requesters(&self) -> usize { - let trk_total = self.get_tracker_total(); - self.pool_size.saturating_sub(trk_total) - } - - /// Grow memory attempt from a consumer, return if we could grant that much to it - fn can_grow_directly(&self, required: usize, current: usize) -> bool { - let num_rqt = self.requesters.lock().len(); - let mut rqt_current_used = self.requesters_total.lock(); - let mut rqt_max = self.max_mem_for_requesters(); - - let granted; - loop { - let max_per_rqt = rqt_max / num_rqt; - let min_per_rqt = max_per_rqt / 2; - - if required + current >= max_per_rqt { - granted = false; - break; - } - - let remaining = rqt_max.checked_sub(*rqt_current_used).unwrap_or_default(); - if remaining >= required { - granted = true; - *rqt_current_used += required; - break; - } else if current < min_per_rqt { - // if we cannot acquire at lease 1/2n memory, just wait for others - // to spill instead spill self frequently with limited total mem - debug!( - "Cannot acquire a minimum 
amount of {} memory from the manager of total {}, waiting for others to spill ...", - human_readable_size(min_per_rqt), human_readable_size(self.pool_size)); - let now = Instant::now(); - self.cv.wait(&mut rqt_current_used); - let elapsed = now.elapsed(); - if elapsed > Duration::from_secs(10) { - warn!("Elapsed on waiting for spilling: {:.2?}", elapsed); - } - } else { - granted = false; - break; - } - - rqt_max = self.max_mem_for_requesters(); - } - - granted - } - - fn record_free_then_acquire(&self, freed: usize, acquired: usize) { - let mut requesters_total = self.requesters_total.lock(); - debug!( - "free_then_acquire: total {}, freed {}, acquired {}", - human_readable_size(*requesters_total), - human_readable_size(freed), - human_readable_size(acquired) - ); - assert!(*requesters_total >= freed); - *requesters_total -= freed; - *requesters_total += acquired; - self.cv.notify_all(); - } - - fn record_free(&self, freed: usize) { - let mut requesters_total = self.requesters_total.lock(); - debug!( - "free: total {}, freed {}", - human_readable_size(*requesters_total), - human_readable_size(freed) - ); - assert!(*requesters_total >= freed); - *requesters_total -= freed; - self.cv.notify_all(); - } - - /// Drop a memory consumer and reclaim the memory - pub(crate) fn drop_consumer(&self, id: &MemoryConsumerId, mem_used: usize) { - // find in requesters first - { - let mut requesters = self.requesters.lock(); - if requesters.remove(id) { - let mut total = self.requesters_total.lock(); - assert!(*total >= mem_used); - *total -= mem_used; - self.cv.notify_all(); - return; - } - } - self.shrink_tracker_usage(mem_used); - self.cv.notify_all(); - } -} - -impl Display for MemoryManager { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, - "MemoryManager usage statistics: total {}, trackers used {}, total {} requesters used: {}", - human_readable_size(self.pool_size), - human_readable_size(self.get_tracker_total()), - self.requesters.lock().len(), - human_readable_size(self.get_requester_total()), - ) - } -} - -const TB: u64 = 1 << 40; -const GB: u64 = 1 << 30; -const MB: u64 = 1 << 20; -const KB: u64 = 1 << 10; - -/// Present size in human readable form -pub fn human_readable_size(size: usize) -> String { - let size = size as u64; - let (value, unit) = { - if size >= 2 * TB { - (size as f64 / TB as f64, "TB") - } else if size >= 2 * GB { - (size as f64 / GB as f64, "GB") - } else if size >= 2 * MB { - (size as f64 / MB as f64, "MB") - } else if size >= 2 * KB { - (size as f64 / KB as f64, "KB") - } else { - (size as f64, "B") - } - }; - format!("{:.1} {}", value, unit) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::error::Result; - use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; - use crate::execution::MemoryConsumer; - use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MemTrackingMetrics}; - use async_trait::async_trait; - use std::sync::atomic::{AtomicUsize, Ordering}; - use std::sync::Arc; - - struct DummyRequester { - id: MemoryConsumerId, - runtime: Arc, - spills: AtomicUsize, - mem_used: AtomicUsize, - } - - impl DummyRequester { - fn new(partition: usize, runtime: Arc) -> Self { - Self { - id: MemoryConsumerId::new(partition), - runtime, - spills: AtomicUsize::new(0), - mem_used: AtomicUsize::new(0), - } - } - - async fn do_with_mem(&self, grow: usize) -> Result<()> { - self.try_grow(grow).await?; - self.mem_used.fetch_add(grow, Ordering::SeqCst); - Ok(()) - } - - fn get_spills(&self) -> usize { - self.spills.load(Ordering::SeqCst) - 
} - } - - #[async_trait] - impl MemoryConsumer for DummyRequester { - fn name(&self) -> String { - "dummy".to_owned() - } - - fn id(&self) -> &MemoryConsumerId { - &self.id - } - - fn memory_manager(&self) -> Arc { - self.runtime.memory_manager.clone() - } - - fn type_(&self) -> &ConsumerType { - &ConsumerType::Requesting - } - - async fn spill(&self) -> Result { - self.spills.fetch_add(1, Ordering::SeqCst); - let used = self.mem_used.swap(0, Ordering::SeqCst); - Ok(used) - } - - fn mem_used(&self) -> usize { - self.mem_used.load(Ordering::SeqCst) - } - } - - struct DummyTracker { - id: MemoryConsumerId, - runtime: Arc, - mem_used: usize, - } - - impl DummyTracker { - fn new(partition: usize, runtime: Arc, mem_used: usize) -> Self { - runtime.grow_tracker_usage(mem_used); - Self { - id: MemoryConsumerId::new(partition), - runtime, - mem_used, - } - } - } - - #[async_trait] - impl MemoryConsumer for DummyTracker { - fn name(&self) -> String { - "dummy".to_owned() - } - - fn id(&self) -> &MemoryConsumerId { - &self.id - } - - fn memory_manager(&self) -> Arc { - self.runtime.memory_manager.clone() - } - - fn type_(&self) -> &ConsumerType { - &ConsumerType::Tracking - } - - async fn spill(&self) -> Result { - Ok(0) - } - - fn mem_used(&self) -> usize { - self.mem_used - } - } - - #[tokio::test] - async fn basic_functionalities() { - let config = RuntimeConfig::new() - .with_memory_manager(MemoryManagerConfig::try_new_limit(100, 1.0).unwrap()); - let runtime = Arc::new(RuntimeEnv::new(config).unwrap()); - - DummyTracker::new(0, runtime.clone(), 5); - assert_eq!(runtime.memory_manager.get_tracker_total(), 5); - - let tracker1 = DummyTracker::new(0, runtime.clone(), 10); - assert_eq!(runtime.memory_manager.get_tracker_total(), 15); - - DummyTracker::new(0, runtime.clone(), 15); - assert_eq!(runtime.memory_manager.get_tracker_total(), 30); - - runtime.drop_consumer(tracker1.id(), tracker1.mem_used); - assert_eq!(runtime.memory_manager.get_tracker_total(), 20); - - // MemTrackingMetrics as an easy way to track memory - let ms = ExecutionPlanMetricsSet::new(); - let tracking_metric = MemTrackingMetrics::new_with_rt(&ms, 0, runtime.clone()); - tracking_metric.init_mem_used(15); - assert_eq!(runtime.memory_manager.get_tracker_total(), 35); - - drop(tracking_metric); - assert_eq!(runtime.memory_manager.get_tracker_total(), 20); - - let requester1 = DummyRequester::new(0, runtime.clone()); - runtime.register_requester(requester1.id()); - - // first requester entered, should be able to use any of the remaining 80 - requester1.do_with_mem(40).await.unwrap(); - requester1.do_with_mem(10).await.unwrap(); - assert_eq!(requester1.get_spills(), 0); - assert_eq!(requester1.mem_used(), 50); - assert_eq!(*runtime.memory_manager.requesters_total.lock(), 50); - - let requester2 = DummyRequester::new(0, runtime.clone()); - runtime.register_requester(requester2.id()); - - requester2.do_with_mem(20).await.unwrap(); - requester2.do_with_mem(30).await.unwrap(); - assert_eq!(requester2.get_spills(), 1); - assert_eq!(requester2.mem_used(), 30); - - requester1.do_with_mem(10).await.unwrap(); - assert_eq!(requester1.get_spills(), 1); - assert_eq!(requester1.mem_used(), 10); - - assert_eq!(*runtime.memory_manager.requesters_total.lock(), 40); - } - - #[tokio::test] - #[should_panic(expected = "invalid max_memory. Expected greater than 0, got 0")] - async fn test_try_new_with_limit_0() { - MemoryManagerConfig::try_new_limit(0, 1.0).unwrap(); - } - - #[tokio::test] - #[should_panic( - expected = "invalid fraction. 
Expected greater than 0 and less than 1.0, got -9.6" - )] - async fn test_try_new_with_limit_neg_fraction() { - MemoryManagerConfig::try_new_limit(100, -9.6).unwrap(); - } - - #[tokio::test] - #[should_panic( - expected = "invalid fraction. Expected greater than 0 and less than 1.0, got 9.6" - )] - async fn test_try_new_with_limit_too_large() { - MemoryManagerConfig::try_new_limit(100, 9.6).unwrap(); - } - - #[tokio::test] - async fn test_try_new_with_limit_pool_size() { - let config = MemoryManagerConfig::try_new_limit(100, 0.5).unwrap(); - assert_eq!(config.pool_size(), 50); - - let config = MemoryManagerConfig::try_new_limit(100000, 0.1).unwrap(); - assert_eq!(config.pool_size(), 10000); - } - - #[tokio::test] - async fn test_memory_manager_underflow() { - let config = MemoryManagerConfig::try_new_limit(100, 0.5).unwrap(); - let manager = MemoryManager::new(config); - manager.grow_tracker_usage(100); - - manager.register_requester(&MemoryConsumerId::new(1)); - assert!(!manager.can_grow_directly(20, 0)); - } -} diff --git a/datafusion/core/src/execution/memory_pool/mod.rs b/datafusion/core/src/execution/memory_pool/mod.rs new file mode 100644 index 000000000000..6369cda4d149 --- /dev/null +++ b/datafusion/core/src/execution/memory_pool/mod.rs @@ -0,0 +1,222 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Manages all available memory during query execution + +use crate::error::Result; +use std::sync::Arc; + +mod pool; +pub mod proxy; + +pub use pool::*; + +/// The pool of memory on which [`MemoryReservation`] record their memory reservations +pub trait MemoryPool: Send + Sync + std::fmt::Debug { + /// Registers a new [`MemoryConsumer`] + /// + /// Note: Subsequent calls to [`Self::grow`] must be made to reserve memory + fn register(&self, _consumer: &MemoryConsumer) {} + + /// Records the destruction of a [`MemoryReservation`] with [`MemoryConsumer`] + /// + /// Note: Prior calls to [`Self::shrink`] must be made to free any reserved memory + fn unregister(&self, _consumer: &MemoryConsumer) {} + + /// Infallibly grow the provided `reservation` by `additional` bytes + /// + /// This must always succeed + fn grow(&self, reservation: &MemoryReservation, additional: usize); + + /// Infallibly shrink the provided `reservation` by `shrink` bytes + fn shrink(&self, reservation: &MemoryReservation, shrink: usize); + + /// Attempt to grow the provided `reservation` by `additional` bytes + /// + /// On error the `allocation` will not be increased in size + fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()>; + + /// Return the total amount of memory reserved + fn reserved(&self) -> usize; +} + +/// A memory consumer that can be tracked by [`MemoryReservation`] in a [`MemoryPool`] +#[derive(Debug)] +pub struct MemoryConsumer { + name: String, + can_spill: bool, +} + +impl MemoryConsumer { + /// Create a new empty [`MemoryConsumer`] that can be grown using [`MemoryReservation`] + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + can_spill: false, + } + } + + /// Set whether this allocation can be spilled to disk + pub fn with_can_spill(self, can_spill: bool) -> Self { + Self { can_spill, ..self } + } + + /// Returns true if this allocation can spill to disk + pub fn can_spill(&self) -> bool { + self.can_spill + } + + /// Returns the name associated with this allocation + pub fn name(&self) -> &str { + &self.name + } + + /// Registers this [`MemoryConsumer`] with the provided [`MemoryPool`] returning + /// a [`MemoryReservation`] that can be used to grow or shrink the memory reservation + pub fn register(self, pool: &Arc) -> MemoryReservation { + pool.register(&self); + MemoryReservation { + consumer: self, + size: 0, + policy: Arc::clone(pool), + } + } +} + +/// A [`MemoryReservation`] tracks a reservation of memory in a [`MemoryPool`] +/// that is freed back to the pool on drop +#[derive(Debug)] +pub struct MemoryReservation { + consumer: MemoryConsumer, + size: usize, + policy: Arc, +} + +impl MemoryReservation { + /// Returns the size of this reservation in bytes + pub fn size(&self) -> usize { + self.size + } + + /// Frees all bytes from this reservation returning the number of bytes freed + pub fn free(&mut self) -> usize { + let size = self.size; + if size != 0 { + self.shrink(size) + } + size + } + + /// Frees `capacity` bytes from this reservation + /// + /// # Panics + /// + /// Panics if `capacity` exceeds [`Self::size`] + pub fn shrink(&mut self, capacity: usize) { + let new_size = self.size.checked_sub(capacity).unwrap(); + self.policy.shrink(self, capacity); + self.size = new_size + } + + /// Sets the size of this reservation to `capacity` + pub fn resize(&mut self, capacity: usize) { + use std::cmp::Ordering; + match capacity.cmp(&self.size) { + Ordering::Greater => self.grow(capacity - self.size), + Ordering::Less => 
self.shrink(self.size - capacity), + _ => {} + } + } + + /// Increase the size of this reservation by `capacity` bytes + pub fn grow(&mut self, capacity: usize) { + self.policy.grow(self, capacity); + self.size += capacity; + } + + /// Try to increase the size of this reservation by `capacity` bytes + pub fn try_grow(&mut self, capacity: usize) -> Result<()> { + self.policy.try_grow(self, capacity)?; + self.size += capacity; + Ok(()) + } +} + +impl Drop for MemoryReservation { + fn drop(&mut self) { + self.free(); + self.policy.unregister(&self.consumer); + } +} + +const TB: u64 = 1 << 40; +const GB: u64 = 1 << 30; +const MB: u64 = 1 << 20; +const KB: u64 = 1 << 10; + +/// Present size in human readable form +pub fn human_readable_size(size: usize) -> String { + let size = size as u64; + let (value, unit) = { + if size >= 2 * TB { + (size as f64 / TB as f64, "TB") + } else if size >= 2 * GB { + (size as f64 / GB as f64, "GB") + } else if size >= 2 * MB { + (size as f64 / MB as f64, "MB") + } else if size >= 2 * KB { + (size as f64 / KB as f64, "KB") + } else { + (size as f64, "B") + } + }; + format!("{:.1} {}", value, unit) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_memory_pool_underflow() { + let pool = Arc::new(GreedyMemoryPool::new(50)) as _; + let mut a1 = MemoryConsumer::new("a1").register(&pool); + assert_eq!(pool.reserved(), 0); + + a1.grow(100); + assert_eq!(pool.reserved(), 100); + + assert_eq!(a1.free(), 100); + assert_eq!(pool.reserved(), 0); + + a1.try_grow(100).unwrap_err(); + assert_eq!(pool.reserved(), 0); + + a1.try_grow(30).unwrap(); + assert_eq!(pool.reserved(), 30); + + let mut a2 = MemoryConsumer::new("a2").register(&pool); + a2.try_grow(25).unwrap_err(); + assert_eq!(pool.reserved(), 30); + + drop(a1); + assert_eq!(pool.reserved(), 0); + + a2.try_grow(25).unwrap(); + assert_eq!(pool.reserved(), 25); + } +} diff --git a/datafusion/core/src/execution/memory_pool/pool.rs b/datafusion/core/src/execution/memory_pool/pool.rs new file mode 100644 index 000000000000..5d28629be9c2 --- /dev/null +++ b/datafusion/core/src/execution/memory_pool/pool.rs @@ -0,0 +1,285 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
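+
+//! Built-in [`MemoryPool`] implementations: [`UnboundedMemoryPool`],
+//! [`GreedyMemoryPool`] and [`FairSpillPool`].
+//!
+//! A minimal usage sketch of the reservation API defined in `memory_pool/mod.rs`
+//! (the consumer name and byte sizes below are illustrative only):
+//!
+//! ```ignore
+//! use std::sync::Arc;
+//! use datafusion::execution::memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool};
+//!
+//! // A pool that grants reservations first-come first-serve, up to 1024 bytes
+//! let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(1024));
+//!
+//! // Register a consumer, obtaining a reservation that can grow and shrink
+//! let mut reservation = MemoryConsumer::new("example").register(&pool);
+//! reservation.try_grow(512).unwrap();
+//! assert_eq!(pool.reserved(), 512);
+//!
+//! // Memory is returned to the pool when shrunk, or when the reservation drops
+//! reservation.shrink(512);
+//! assert_eq!(pool.reserved(), 0);
+//! ```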
+ +use crate::execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; +use datafusion_common::{DataFusionError, Result}; +use parking_lot::Mutex; +use std::sync::atomic::{AtomicUsize, Ordering}; + +/// A [`MemoryPool`] that enforces no limit +#[derive(Debug, Default)] +pub struct UnboundedMemoryPool { + used: AtomicUsize, +} + +impl MemoryPool for UnboundedMemoryPool { + fn grow(&self, _reservation: &MemoryReservation, additional: usize) { + self.used.fetch_add(additional, Ordering::Relaxed); + } + + fn shrink(&self, _reservation: &MemoryReservation, shrink: usize) { + self.used.fetch_sub(shrink, Ordering::Relaxed); + } + + fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()> { + self.grow(reservation, additional); + Ok(()) + } + + fn reserved(&self) -> usize { + self.used.load(Ordering::Relaxed) + } +} + +/// A [`MemoryPool`] that implements a greedy first-come first-serve limit +#[derive(Debug)] +pub struct GreedyMemoryPool { + pool_size: usize, + used: AtomicUsize, +} + +impl GreedyMemoryPool { + /// Allocate up to `limit` bytes + pub fn new(pool_size: usize) -> Self { + Self { + pool_size, + used: AtomicUsize::new(0), + } + } +} + +impl MemoryPool for GreedyMemoryPool { + fn grow(&self, _reservation: &MemoryReservation, additional: usize) { + self.used.fetch_add(additional, Ordering::Relaxed); + } + + fn shrink(&self, _reservation: &MemoryReservation, shrink: usize) { + self.used.fetch_sub(shrink, Ordering::Relaxed); + } + + fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()> { + self.used + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |used| { + let new_used = used + additional; + (new_used <= self.pool_size).then_some(new_used) + }) + .map_err(|used| { + insufficient_capacity_err(reservation, additional, self.pool_size - used) + })?; + Ok(()) + } + + fn reserved(&self) -> usize { + self.used.load(Ordering::Relaxed) + } +} + +/// A [`MemoryPool`] that prevents spillable reservations from using more than +/// an even fraction of the available memory sans any unspillable reservations +/// (i.e. 
`(pool_size - unspillable_memory) / num_spillable_reservations`) +/// +/// ┌───────────────────────z──────────────────────z───────────────┐ +/// │ z z │ +/// │ z z │ +/// │ Spillable z Unspillable z Free │ +/// │ Memory z Memory z Memory │ +/// │ z z │ +/// │ z z │ +/// └───────────────────────z──────────────────────z───────────────┘ +/// +/// Unspillable memory is allocated in a first-come, first-serve fashion +#[derive(Debug)] +pub struct FairSpillPool { + /// The total memory limit + pool_size: usize, + + state: Mutex, +} + +#[derive(Debug)] +struct FairSpillPoolState { + /// The number of consumers that can spill + num_spill: usize, + + /// The total amount of memory reserved that can be spilled + spillable: usize, + + /// The total amount of memory reserved by consumers that cannot spill + unspillable: usize, +} + +impl FairSpillPool { + /// Allocate up to `limit` bytes + pub fn new(pool_size: usize) -> Self { + Self { + pool_size, + state: Mutex::new(FairSpillPoolState { + num_spill: 0, + spillable: 0, + unspillable: 0, + }), + } + } +} + +impl MemoryPool for FairSpillPool { + fn register(&self, consumer: &MemoryConsumer) { + if consumer.can_spill { + self.state.lock().num_spill += 1; + } + } + + fn unregister(&self, consumer: &MemoryConsumer) { + if consumer.can_spill { + self.state.lock().num_spill -= 1; + } + } + + fn grow(&self, reservation: &MemoryReservation, additional: usize) { + let mut state = self.state.lock(); + match reservation.consumer.can_spill { + true => state.spillable += additional, + false => state.unspillable += additional, + } + } + + fn shrink(&self, reservation: &MemoryReservation, shrink: usize) { + let mut state = self.state.lock(); + match reservation.consumer.can_spill { + true => state.spillable -= shrink, + false => state.unspillable -= shrink, + } + } + + fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()> { + let mut state = self.state.lock(); + + match reservation.consumer.can_spill { + true => { + // The total amount of memory available to spilling consumers + let spill_available = self.pool_size.saturating_sub(state.unspillable); + + // No spiller may use more than their fraction of the memory available + let available = spill_available + .checked_div(state.num_spill) + .unwrap_or(spill_available); + + if reservation.size + additional > available { + return Err(insufficient_capacity_err( + reservation, + additional, + available, + )); + } + state.spillable += additional; + } + false => { + let available = self + .pool_size + .saturating_sub(state.unspillable + state.unspillable); + + if available < additional { + return Err(insufficient_capacity_err( + reservation, + additional, + available, + )); + } + state.unspillable += additional; + } + } + Ok(()) + } + + fn reserved(&self) -> usize { + let state = self.state.lock(); + state.spillable + state.unspillable + } +} + +fn insufficient_capacity_err( + reservation: &MemoryReservation, + additional: usize, + available: usize, +) -> DataFusionError { + DataFusionError::ResourcesExhausted(format!("Failed to allocate additional {} bytes for {} with {} bytes already allocated - maximum available is {}", additional, reservation.consumer.name, reservation.size, available)) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + #[test] + fn test_fair() { + let pool = Arc::new(FairSpillPool::new(100)) as _; + + let mut r1 = MemoryConsumer::new("unspillable").register(&pool); + // Can grow beyond capacity of pool + r1.grow(2000); + assert_eq!(pool.reserved(), 
2000); + + let mut r2 = MemoryConsumer::new("s1") + .with_can_spill(true) + .register(&pool); + // Can grow beyond capacity of pool + r2.grow(2000); + + assert_eq!(pool.reserved(), 4000); + + let err = r2.try_grow(1).unwrap_err().to_string(); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for s1 with 2000 bytes already allocated - maximum available is 0"); + + let err = r2.try_grow(1).unwrap_err().to_string(); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for s1 with 2000 bytes already allocated - maximum available is 0"); + + r1.shrink(1990); + r2.shrink(2000); + + assert_eq!(pool.reserved(), 10); + + r1.try_grow(10).unwrap(); + assert_eq!(pool.reserved(), 20); + + // Can grow a2 to 80 as only spilling consumer + r2.try_grow(80).unwrap(); + assert_eq!(pool.reserved(), 100); + + r2.shrink(70); + + assert_eq!(r1.size(), 20); + assert_eq!(r2.size(), 10); + assert_eq!(pool.reserved(), 30); + + let mut r3 = MemoryConsumer::new("s2") + .with_can_spill(true) + .register(&pool); + + let err = r3.try_grow(70).unwrap_err().to_string(); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for s2 with 0 bytes already allocated - maximum available is 40"); + + //Shrinking a2 to zero doesn't allow a3 to allocate more than 45 + r2.free(); + let err = r3.try_grow(70).unwrap_err().to_string(); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for s2 with 0 bytes already allocated - maximum available is 40"); + + // But dropping a2 does + drop(r2); + assert_eq!(pool.reserved(), 20); + r3.try_grow(80).unwrap(); + } +} diff --git a/datafusion/core/src/execution/memory_manager/proxy.rs b/datafusion/core/src/execution/memory_pool/proxy.rs similarity index 60% rename from datafusion/core/src/execution/memory_manager/proxy.rs rename to datafusion/core/src/execution/memory_pool/proxy.rs index 2a5bd2507357..43532f9a81f1 100644 --- a/datafusion/core/src/execution/memory_manager/proxy.rs +++ b/datafusion/core/src/execution/memory_pool/proxy.rs @@ -16,96 +16,9 @@ // under the License. //! Utilities that help with tracking of memory allocations. -use std::sync::Arc; -use async_trait::async_trait; -use datafusion_common::DataFusionError; use hashbrown::raw::{Bucket, RawTable}; -use super::{ConsumerType, MemoryConsumer, MemoryConsumerId, MemoryManager}; - -/// Accounting proxy for memory usage. -/// -/// This is helpful when calculating memory usage on the actual data structure is expensive but it is easy to track -/// allocations while processing data. -/// -/// This consumer will NEVER spill. -pub struct MemoryConsumerProxy { - /// Name - name: String, - - /// Consumer ID. - id: MemoryConsumerId, - - /// Linked memory manager. - memory_manager: Arc, - - /// Currently used size in bytes. - used: usize, -} - -impl MemoryConsumerProxy { - /// Create new proxy consumer and register it with the given memory manager. - pub fn new( - name: impl Into, - id: MemoryConsumerId, - memory_manager: Arc, - ) -> Self { - memory_manager.register_requester(&id); - - Self { - name: name.into(), - id, - memory_manager, - used: 0, - } - } - - /// Try to allocate given amount of memory. 
- pub async fn alloc(&mut self, bytes: usize) -> Result<(), DataFusionError> { - self.try_grow(bytes).await?; - self.used = self.used.checked_add(bytes).expect("overflow"); - Ok(()) - } -} - -#[async_trait] -impl MemoryConsumer for MemoryConsumerProxy { - fn name(&self) -> String { - self.name.clone() - } - - fn id(&self) -> &crate::execution::MemoryConsumerId { - &self.id - } - - fn memory_manager(&self) -> Arc { - Arc::clone(&self.memory_manager) - } - - fn type_(&self) -> &ConsumerType { - &ConsumerType::Tracking - } - - async fn spill(&self) -> Result { - Err(DataFusionError::ResourcesExhausted(format!( - "Cannot spill {}", - self.name - ))) - } - - fn mem_used(&self) -> usize { - self.used - } -} - -impl Drop for MemoryConsumerProxy { - fn drop(&mut self) { - self.memory_manager - .drop_consumer(self.id(), self.mem_used()); - } -} - /// Extension trait for [`Vec`] to account for allocations. pub trait VecAllocExt { /// Item type. diff --git a/datafusion/core/src/execution/mod.rs b/datafusion/core/src/execution/mod.rs index 024980dee059..5eb859df9304 100644 --- a/datafusion/core/src/execution/mod.rs +++ b/datafusion/core/src/execution/mod.rs @@ -42,13 +42,10 @@ pub mod context; pub mod disk_manager; -pub mod memory_manager; +pub mod memory_pool; pub mod options; pub mod registry; pub mod runtime_env; pub use disk_manager::DiskManager; -pub use memory_manager::{ - human_readable_size, MemoryConsumer, MemoryConsumerId, MemoryManager, -}; pub use registry::FunctionRegistry; diff --git a/datafusion/core/src/execution/runtime_env.rs b/datafusion/core/src/execution/runtime_env.rs index 64da4a103b16..d559e7c7fa35 100644 --- a/datafusion/core/src/execution/runtime_env.rs +++ b/datafusion/core/src/execution/runtime_env.rs @@ -20,16 +20,14 @@ use crate::{ error::Result, - execution::{ - disk_manager::{DiskManager, DiskManagerConfig}, - memory_manager::{MemoryConsumerId, MemoryManager, MemoryManagerConfig}, - }, + execution::disk_manager::{DiskManager, DiskManagerConfig}, }; use std::collections::HashMap; use crate::datasource::datasource::TableProviderFactory; use crate::datasource::listing_table_factory::ListingTableFactory; use crate::datasource::object_store::ObjectStoreRegistry; +use crate::execution::memory_pool::{GreedyMemoryPool, MemoryPool, UnboundedMemoryPool}; use datafusion_common::DataFusionError; use object_store::ObjectStore; use std::fmt::{Debug, Formatter}; @@ -41,7 +39,7 @@ use url::Url; /// Execution runtime environment. 
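+///
+/// A sketch of wiring a custom [`MemoryPool`] into a `SessionContext` via the
+/// runtime environment (the 10 MB limit below is an arbitrary example value):
+///
+/// ```ignore
+/// use std::sync::Arc;
+/// use datafusion::execution::memory_pool::GreedyMemoryPool;
+/// use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
+/// use datafusion::prelude::{SessionConfig, SessionContext};
+///
+/// // Limit tracked memory for all operators in this runtime to 10 MB
+/// let config = RuntimeConfig::new()
+///     .with_memory_pool(Arc::new(GreedyMemoryPool::new(10 * 1024 * 1024)));
+/// let runtime = Arc::new(RuntimeEnv::new(config).unwrap());
+/// let ctx = SessionContext::with_config_rt(SessionConfig::new(), runtime);
+/// ```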
pub struct RuntimeEnv { /// Runtime memory management - pub memory_manager: Arc, + pub memory_pool: Arc, /// Manage temporary files during query execution pub disk_manager: Arc, /// Object Store Registry @@ -60,40 +58,23 @@ impl RuntimeEnv { /// Create env based on configuration pub fn new(config: RuntimeConfig) -> Result { let RuntimeConfig { - memory_manager, + memory_pool, disk_manager, object_store_registry, table_factories, } = config; + let memory_pool = + memory_pool.unwrap_or_else(|| Arc::new(UnboundedMemoryPool::default())); + Ok(Self { - memory_manager: MemoryManager::new(memory_manager), + memory_pool, disk_manager: DiskManager::try_new(disk_manager)?, object_store_registry, table_factories, }) } - /// Register the consumer to get it tracked - pub fn register_requester(&self, id: &MemoryConsumerId) { - self.memory_manager.register_requester(id); - } - - /// Drop the consumer from get tracked, reclaim memory - pub fn drop_consumer(&self, id: &MemoryConsumerId, mem_used: usize) { - self.memory_manager.drop_consumer(id, mem_used) - } - - /// Grow tracker memory of `delta` - pub fn grow_tracker_usage(&self, delta: usize) { - self.memory_manager.grow_tracker_usage(delta) - } - - /// Shrink tracker memory of `delta` - pub fn shrink_tracker_usage(&self, delta: usize) { - self.memory_manager.shrink_tracker_usage(delta) - } - /// Registers a custom `ObjectStore` to be used when accessing a /// specific scheme and host. This allows DataFusion to create /// external tables from urls that do not have built in support @@ -142,8 +123,10 @@ impl Default for RuntimeEnv { pub struct RuntimeConfig { /// DiskManager to manage temporary disk file usage pub disk_manager: DiskManagerConfig, - /// MemoryManager to limit access to memory - pub memory_manager: MemoryManagerConfig, + /// [`MemoryPool`] from which to allocate memory + /// + /// Defaults to using an [`UnboundedMemoryPool`] if `None` + pub memory_pool: Option>, /// ObjectStoreRegistry to get object store based on url pub object_store_registry: Arc, /// Custom table factories for things like deltalake that are not part of core datafusion @@ -172,9 +155,9 @@ impl RuntimeConfig { self } - /// Customize memory manager - pub fn with_memory_manager(mut self, memory_manager: MemoryManagerConfig) -> Self { - self.memory_manager = memory_manager; + /// Customize memory policy + pub fn with_memory_pool(mut self, memory_pool: Arc) -> Self { + self.memory_pool = Some(memory_pool); self } @@ -199,11 +182,12 @@ impl RuntimeConfig { /// Specify the total memory to use while running the DataFusion /// plan to `max_memory * memory_fraction` in bytes. /// + /// This defaults to using [`GreedyMemoryPool`] + /// /// Note DataFusion does not yet respect this limit in all cases. 
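+    ///
+    /// For example, a hypothetical `with_memory_limit(1000, 0.8)` is equivalent to
+    /// `with_memory_pool(Arc::new(GreedyMemoryPool::new(800)))`, since the pool size
+    /// is computed as `max_memory * memory_fraction = 1000 * 0.8 = 800` bytes.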
pub fn with_memory_limit(self, max_memory: usize, memory_fraction: f64) -> Self { - self.with_memory_manager( - MemoryManagerConfig::try_new_limit(max_memory, memory_fraction).unwrap(), - ) + let pool_size = (max_memory as f64 * memory_fraction) as usize; + self.with_memory_pool(Arc::new(GreedyMemoryPool::new(pool_size))) } /// Use the specified path to create any needed temporary files diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 30bf6dc7eb7d..09b6c6691f72 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -212,6 +212,7 @@ /// DataFusion crate version pub const DATAFUSION_VERSION: &str = env!("CARGO_PKG_VERSION"); +extern crate core; extern crate sqlparser; pub mod avro_to_arrow; diff --git a/datafusion/core/src/physical_plan/aggregates/hash.rs b/datafusion/core/src/physical_plan/aggregates/hash.rs index 4d19330808a7..64b21ecf9a6b 100644 --- a/datafusion/core/src/physical_plan/aggregates/hash.rs +++ b/datafusion/core/src/physical_plan/aggregates/hash.rs @@ -29,10 +29,7 @@ use futures::stream::{Stream, StreamExt}; use crate::error::Result; use crate::execution::context::TaskContext; -use crate::execution::memory_manager::proxy::{ - MemoryConsumerProxy, RawTableAllocExt, VecAllocExt, -}; -use crate::execution::MemoryConsumerId; +use crate::execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; use crate::physical_plan::aggregates::{ evaluate_group_by, evaluate_many, AccumulatorItem, AggregateMode, PhysicalGroupBy, }; @@ -42,6 +39,7 @@ use crate::physical_plan::{aggregates, AggregateExpr, PhysicalExpr}; use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream}; use crate::scalar::ScalarValue; +use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation}; use arrow::{array::ArrayRef, compute, compute::cast}; use arrow::{ array::{Array, UInt32Builder}, @@ -126,6 +124,10 @@ impl GroupedHashAggregateStream { timer.done(); + let reservation = + MemoryConsumer::new(format!("GroupedHashAggregateStream[{}]", partition)) + .register(context.memory_pool()); + let inner = GroupedHashAggregateStreamInner { schema: Arc::clone(&schema), mode, @@ -135,11 +137,7 @@ impl GroupedHashAggregateStream { baseline_metrics, aggregate_expressions, accumulators: Some(Accumulators { - memory_consumer: MemoryConsumerProxy::new( - "GroupBy Hash Accumulators", - MemoryConsumerId::new(partition), - Arc::clone(&context.runtime_env().memory_manager), - ), + reservation, map: RawTable::with_capacity(0), group_states: Vec::with_capacity(0), }), @@ -175,15 +173,10 @@ impl GroupedHashAggregateStream { // allocate memory // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with // overshooting a bit. Also this means we either store the whole record batch or not. 
- let result = match result { - Ok(allocated) => { - accumulators.memory_consumer.alloc(allocated).await - } - Err(e) => Err(e), - }; - - match result { - Ok(()) => continue, + match result.and_then(|allocated| { + accumulators.reservation.try_grow(allocated) + }) { + Ok(_) => continue, Err(e) => Err(ArrowError::ExternalError(Box::new(e))), } } @@ -445,7 +438,7 @@ struct GroupState { /// The state of all the groups struct Accumulators { - memory_consumer: MemoryConsumerProxy, + reservation: MemoryReservation, /// Logically maps group values to an index in `group_states` /// diff --git a/datafusion/core/src/physical_plan/aggregates/no_grouping.rs b/datafusion/core/src/physical_plan/aggregates/no_grouping.rs index 64cc4f569c8c..8a312abafd9b 100644 --- a/datafusion/core/src/physical_plan/aggregates/no_grouping.rs +++ b/datafusion/core/src/physical_plan/aggregates/no_grouping.rs @@ -18,8 +18,6 @@ //! Aggregate without grouping columns use crate::execution::context::TaskContext; -use crate::execution::memory_manager::proxy::MemoryConsumerProxy; -use crate::execution::MemoryConsumerId; use crate::physical_plan::aggregates::{ aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem, AggregateMode, @@ -35,6 +33,7 @@ use futures::stream::BoxStream; use std::sync::Arc; use std::task::{Context, Poll}; +use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation}; use futures::stream::{Stream, StreamExt}; /// stream struct for aggregation without grouping columns @@ -55,7 +54,7 @@ struct AggregateStreamInner { baseline_metrics: BaselineMetrics, aggregate_expressions: Vec>>, accumulators: Vec, - memory_consumer: MemoryConsumerProxy, + reservation: MemoryReservation, finished: bool, } @@ -69,14 +68,12 @@ impl AggregateStream { baseline_metrics: BaselineMetrics, context: Arc, partition: usize, - ) -> datafusion_common::Result { + ) -> Result { let aggregate_expressions = aggregate_expressions(&aggr_expr, &mode, 0)?; let accumulators = create_accumulators(&aggr_expr)?; - let memory_consumer = MemoryConsumerProxy::new( - "GroupBy None Accumulators", - MemoryConsumerId::new(partition), - Arc::clone(&context.runtime_env().memory_manager), - ); + + let reservation = MemoryConsumer::new(format!("AggregateStream[{}]", partition)) + .register(context.memory_pool()); let inner = AggregateStreamInner { schema: Arc::clone(&schema), @@ -85,7 +82,7 @@ impl AggregateStream { baseline_metrics, aggregate_expressions, accumulators, - memory_consumer, + reservation, finished: false, }; let stream = futures::stream::unfold(inner, |mut this| async move { @@ -111,12 +108,9 @@ impl AggregateStream { // allocate memory // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with // overshooting a bit. Also this means we either store the whole record batch or not. 
- let result = match result { - Ok(allocated) => this.memory_consumer.alloc(allocated).await, - Err(e) => Err(e), - }; - - match result { + match result + .and_then(|allocated| this.reservation.try_grow(allocated)) + { Ok(_) => continue, Err(e) => Err(ArrowError::ExternalError(Box::new(e))), } diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs index c73fa3da0c2e..e769397871ef 100644 --- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs +++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs @@ -27,10 +27,7 @@ use futures::stream::{Stream, StreamExt}; use crate::error::Result; use crate::execution::context::TaskContext; -use crate::execution::memory_manager::proxy::{ - MemoryConsumerProxy, RawTableAllocExt, VecAllocExt, -}; -use crate::execution::MemoryConsumerId; +use crate::execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; use crate::physical_plan::aggregates::{ evaluate_group_by, evaluate_many, group_schema, AccumulatorItemV2, AggregateMode, PhysicalGroupBy, @@ -40,6 +37,7 @@ use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput}; use crate::physical_plan::{aggregates, AggregateExpr, PhysicalExpr}; use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream}; +use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation}; use arrow::compute::cast; use arrow::datatypes::Schema; use arrow::{array::ArrayRef, compute}; @@ -141,13 +139,12 @@ impl GroupedHashAggregateStreamV2 { let aggr_schema = aggr_state_schema(&aggr_expr)?; let aggr_layout = Arc::new(RowLayout::new(&aggr_schema, RowType::WordAligned)); + let reservation = + MemoryConsumer::new(format!("GroupedHashAggregateStreamV2[{}]", partition)) + .register(context.memory_pool()); let aggr_state = AggregationState { - memory_consumer: MemoryConsumerProxy::new( - "GroupBy Hash (Row) AggregationState", - MemoryConsumerId::new(partition), - Arc::clone(&context.runtime_env().memory_manager), - ), + reservation, map: RawTable::with_capacity(0), group_states: Vec::with_capacity(0), }; @@ -196,15 +193,10 @@ impl GroupedHashAggregateStreamV2 { // allocate memory // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with // overshooting a bit. Also this means we either store the whole record batch or not. 
- let result = match result { - Ok(allocated) => { - this.aggr_state.memory_consumer.alloc(allocated).await - } - Err(e) => Err(e), - }; - - match result { - Ok(()) => continue, + match result.and_then(|allocated| { + this.aggr_state.reservation.try_grow(allocated) + }) { + Ok(_) => continue, Err(e) => Err(ArrowError::ExternalError(Box::new(e))), } } @@ -465,7 +457,7 @@ struct RowGroupState { /// The state of all the groups struct AggregationState { - memory_consumer: MemoryConsumerProxy, + reservation: MemoryReservation, /// Logically maps group values to an index in `group_states` /// diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs index b4db3a32b522..b29dc0cb8c11 100644 --- a/datafusion/core/src/physical_plan/common.rs +++ b/datafusion/core/src/physical_plan/common.rs @@ -51,7 +51,7 @@ impl SizedRecordBatchStream { pub fn new( schema: SchemaRef, batches: Vec>, - metrics: MemTrackingMetrics, + mut metrics: MemTrackingMetrics, ) -> Self { let size = batches.iter().map(|b| batch_byte_size(b)).sum::(); metrics.init_mem_used(size); diff --git a/datafusion/core/src/physical_plan/explain.rs b/datafusion/core/src/physical_plan/explain.rs index ac350b1837e8..077ed8dcc461 100644 --- a/datafusion/core/src/physical_plan/explain.rs +++ b/datafusion/core/src/physical_plan/explain.rs @@ -152,7 +152,8 @@ impl ExecutionPlan for ExplainExec { )?; let metrics = ExecutionPlanMetricsSet::new(); - let tracking_metrics = MemTrackingMetrics::new(&metrics, partition); + let tracking_metrics = + MemTrackingMetrics::new(&metrics, context.memory_pool(), partition); debug!( "Before returning SizedRecordBatch in ExplainExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); diff --git a/datafusion/core/src/physical_plan/metrics/composite.rs b/datafusion/core/src/physical_plan/metrics/composite.rs index cd4d5c38a9ec..3c257805d2c5 100644 --- a/datafusion/core/src/physical_plan/metrics/composite.rs +++ b/datafusion/core/src/physical_plan/metrics/composite.rs @@ -17,7 +17,7 @@ //! Metrics common for complex operators with multiple steps. -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::memory_pool::MemoryPool; use crate::physical_plan::metrics::tracker::MemTrackingMetrics; use crate::physical_plan::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricValue, MetricsSet, Time, @@ -32,7 +32,7 @@ use std::time::Duration; /// Collects all metrics during a complex operation, which is composed of multiple steps and /// each stage reports its statistics separately. /// Give sort as an example, when the dataset is more significant than available memory, it will report -/// multiple in-mem sort metrics and final merge-sort metrics from `SortPreservingMergeStream`. +/// multiple in-mem sort metrics and final merge-sort metrics from `SortPreservingMergeStream`. /// Therefore, We need a separation of metrics for which are final metrics (for output_rows accumulation), /// and which are intermediate metrics that we only account for elapsed_compute time. 
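+///
+/// A sketch of how an operator might use the two kinds of tracking metrics,
+/// assuming it already holds a `CompositeMetricsSet`, a partition index and a
+/// reference to the session [`MemoryPool`]:
+///
+/// ```ignore
+/// // metrics for each intermediate step (e.g. a per-spill in-memory sort)
+/// let intermediate = metrics_set.new_intermediate_tracking(partition, &pool);
+/// // metrics for the final merge step, whose output rows are reported
+/// let final_metrics = metrics_set.new_final_tracking(partition, &pool);
+/// ```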
pub struct CompositeMetricsSet { @@ -69,18 +69,18 @@ impl CompositeMetricsSet { pub fn new_intermediate_tracking( &self, partition: usize, - runtime: Arc, + pool: &Arc, ) -> MemTrackingMetrics { - MemTrackingMetrics::new_with_rt(&self.mid, partition, runtime) + MemTrackingMetrics::new(&self.mid, pool, partition) } /// create a new final memory tracking metrics pub fn new_final_tracking( &self, partition: usize, - runtime: Arc, + pool: &Arc, ) -> MemTrackingMetrics { - MemTrackingMetrics::new_with_rt(&self.final_, partition, runtime) + MemTrackingMetrics::new(&self.final_, pool, partition) } fn merge_compute_time(&self, dest: &Time) { diff --git a/datafusion/core/src/physical_plan/metrics/tracker.rs b/datafusion/core/src/physical_plan/metrics/tracker.rs index d8017b95ae8d..c61398c65810 100644 --- a/datafusion/core/src/physical_plan/metrics/tracker.rs +++ b/datafusion/core/src/physical_plan/metrics/tracker.rs @@ -17,51 +17,37 @@ //! Metrics with memory usage tracking capability -use crate::execution::runtime_env::RuntimeEnv; -use crate::execution::MemoryConsumerId; use crate::physical_plan::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, Time, }; use std::sync::Arc; use std::task::Poll; +use crate::execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; use arrow::{error::ArrowError, record_batch::RecordBatch}; -/// Simplified version of tracking memory consumer, -/// see also: [`Tracking`](crate::execution::memory_manager::ConsumerType::Tracking) -/// -/// You could use this to replace [BaselineMetrics], report the memory, -/// and get the memory usage bookkeeping in the memory manager easily. +/// Wraps a [`BaselineMetrics`] and records memory usage on a [`MemoryReservation`] #[derive(Debug)] pub struct MemTrackingMetrics { - id: MemoryConsumerId, - runtime: Option>, + reservation: MemoryReservation, metrics: BaselineMetrics, } /// Delegates most of the metrics functionalities to the inner BaselineMetrics, /// intercept memory metrics functionalities and do memory manager bookkeeping. 
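+///
+/// A usage sketch (the pool, partition index and sizes below are illustrative):
+///
+/// ```ignore
+/// use std::sync::Arc;
+/// use datafusion::execution::memory_pool::{MemoryPool, UnboundedMemoryPool};
+/// use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MemTrackingMetrics};
+///
+/// let metrics = ExecutionPlanMetricsSet::new();
+/// let pool: Arc<dyn MemoryPool> = Arc::new(UnboundedMemoryPool::default());
+///
+/// // registers a reservation named "MemTrackingMetrics[0]" with the pool
+/// let mut tracking = MemTrackingMetrics::new(&metrics, &pool, 0);
+/// tracking.init_mem_used(100);
+/// assert_eq!(pool.reserved(), 100);
+///
+/// // dropping the metrics frees the reservation back to the pool
+/// drop(tracking);
+/// assert_eq!(pool.reserved(), 0);
+/// ```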
impl MemTrackingMetrics { - /// Create metrics similar to [BaselineMetrics] - pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { - let id = MemoryConsumerId::new(partition); - Self { - id, - runtime: None, - metrics: BaselineMetrics::new(metrics, partition), - } - } - - /// Create memory tracking metrics with reference to runtime - pub fn new_with_rt( + /// Create memory tracking metrics with reference to memory manager + pub fn new( metrics: &ExecutionPlanMetricsSet, + pool: &Arc, partition: usize, - runtime: Arc, ) -> Self { - let id = MemoryConsumerId::new(partition); + let reservation = + MemoryConsumer::new(format!("MemTrackingMetrics[{}]", partition)) + .register(pool); + Self { - id, - runtime: Some(runtime), + reservation, metrics: BaselineMetrics::new(metrics, partition), } } @@ -77,11 +63,9 @@ impl MemTrackingMetrics { } /// setup initial memory usage and register it with memory manager - pub fn init_mem_used(&self, size: usize) { + pub fn init_mem_used(&mut self, size: usize) { self.metrics.mem_used().set(size); - if let Some(rt) = self.runtime.as_ref() { - rt.memory_manager.grow_tracker_usage(size); - } + self.reservation.resize(size) } /// return the metric for the total number of output rows produced @@ -118,14 +102,3 @@ impl MemTrackingMetrics { self.metrics.record_poll(poll) } } - -impl Drop for MemTrackingMetrics { - fn drop(&mut self) { - self.metrics.try_done(); - if self.mem_used() != 0 { - if let Some(rt) = self.runtime.as_ref() { - rt.drop_consumer(&self.id, self.mem_used()); - } - } - } -} diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs index b6c37d109859..85eca5450849 100644 --- a/datafusion/core/src/physical_plan/sorts/sort.rs +++ b/datafusion/core/src/physical_plan/sorts/sort.rs @@ -21,8 +21,8 @@ use crate::error::{DataFusionError, Result}; use crate::execution::context::TaskContext; -use crate::execution::memory_manager::{ - human_readable_size, ConsumerType, MemoryConsumer, MemoryConsumerId, MemoryManager, +use crate::execution::memory_pool::{ + human_readable_size, MemoryConsumer, MemoryReservation, }; use crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::common::{batch_byte_size, IPCWriter, SizedRecordBatchStream}; @@ -45,9 +45,7 @@ use arrow::datatypes::SchemaRef; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::ipc::reader::FileReader; use arrow::record_batch::RecordBatch; -use async_trait::async_trait; use datafusion_physical_expr::EquivalenceProperties; -use futures::lock::Mutex; use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt}; use log::{debug, error}; use std::any::Any; @@ -73,10 +71,9 @@ use tokio::task; /// buffer the batch in memory, go to 1. /// 3. when input is exhausted, merge all in memory batches and spills to get a total order. 
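+///
+/// Step 2 is driven by the [`MemoryReservation`] held by this sorter; a rough
+/// sketch of the accounting performed in `insert_batch` below:
+///
+/// ```ignore
+/// // try to reserve space for the incoming batch; if the pool is exhausted,
+/// // spill the buffered batches to disk, then retry the reservation
+/// if self.reservation.try_grow(size).is_err() {
+///     self.spill().await?;
+///     self.reservation.try_grow(size)?;
+/// }
+/// ```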
struct ExternalSorter { - id: MemoryConsumerId, schema: SchemaRef, - in_mem_batches: Mutex>, - spills: Mutex>, + in_mem_batches: Vec, + spills: Vec, /// Sort expressions expr: Vec, session_config: Arc, @@ -84,6 +81,8 @@ struct ExternalSorter { metrics_set: CompositeMetricsSet, metrics: BaselineMetrics, fetch: Option, + reservation: MemoryReservation, + partition_id: usize, } impl ExternalSorter { @@ -97,31 +96,40 @@ impl ExternalSorter { fetch: Option, ) -> Self { let metrics = metrics_set.new_intermediate_baseline(partition_id); + + let reservation = + MemoryConsumer::new(format!("ExternalSorter[{}]", partition_id)) + .with_can_spill(true) + .register(&runtime.memory_pool); + Self { - id: MemoryConsumerId::new(partition_id), schema, - in_mem_batches: Mutex::new(vec![]), - spills: Mutex::new(vec![]), + in_mem_batches: vec![], + spills: vec![], expr, session_config, runtime, metrics_set, metrics, fetch, + reservation, + partition_id, } } async fn insert_batch( - &self, + &mut self, input: RecordBatch, tracking_metrics: &MemTrackingMetrics, ) -> Result<()> { if input.num_rows() > 0 { let size = batch_byte_size(&input); - debug!("Inserting {} rows of {} bytes", input.num_rows(), size); - self.try_grow(size).await?; + if self.reservation.try_grow(size).is_err() { + self.spill().await?; + self.reservation.try_grow(size)? + } + self.metrics.mem_used().add(size); - let mut in_mem_batches = self.in_mem_batches.lock().await; // NB timer records time taken on drop, so there are no // calls to `timer.done()` below. let _timer = tracking_metrics.elapsed_compute().timer(); @@ -136,61 +144,56 @@ impl ExternalSorter { // We don't have to call try_grow here, since we have already used the // memory (so spilling right here wouldn't help at all for the current // operation). But we still have to record it so that other requesters - // would know about this unexpected increase in memory consuption. + // would know about this unexpected increase in memory consumption. let new_size_delta = new_size - size; - self.grow(new_size_delta); + self.reservation.grow(new_size_delta); self.metrics.mem_used().add(new_size_delta); } Ordering::Less => { let size_delta = size - new_size; - self.shrink(size_delta); + self.reservation.shrink(size_delta); self.metrics.mem_used().sub(size_delta); } Ordering::Equal => {} } - in_mem_batches.push(partial); + self.in_mem_batches.push(partial); } Ok(()) } - async fn spilled_before(&self) -> bool { - let spills = self.spills.lock().await; - !spills.is_empty() + fn spilled_before(&self) -> bool { + !self.spills.is_empty() } /// MergeSort in mem batches as well as spills into total order with `SortPreservingMergeStream`. 
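To exercise this spill path end to end, the runtime can be handed a bounded pool, as the updated order_spill_fuzz test later in this diff does. A sketch under those assumptions, with an arbitrary 10 MiB budget:

```rust
use std::sync::Arc;
use datafusion::execution::memory_pool::GreedyMemoryPool;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::prelude::{SessionConfig, SessionContext};

fn main() {
    // Sorts that buffer more than ~10 MiB will fail try_grow and spill to disk
    // (the default DiskManager is still enabled here).
    let runtime_config = RuntimeConfig::new()
        .with_memory_pool(Arc::new(GreedyMemoryPool::new(10 * 1024 * 1024)));
    let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap());

    let ctx = SessionContext::with_config_rt(SessionConfig::new(), runtime);

    // After a query completes, every reservation should have been released.
    assert_eq!(ctx.runtime_env().memory_pool.reserved(), 0);
}
```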
- async fn sort(&self) -> Result { - let partition = self.partition_id(); + fn sort(&mut self) -> Result { let batch_size = self.session_config.batch_size(); - let mut in_mem_batches = self.in_mem_batches.lock().await; - if self.spilled_before().await { + if self.spilled_before() { let tracking_metrics = self .metrics_set - .new_intermediate_tracking(partition, self.runtime.clone()); + .new_intermediate_tracking(self.partition_id, &self.runtime.memory_pool); let mut streams: Vec = vec![]; - if in_mem_batches.len() > 0 { + if !self.in_mem_batches.is_empty() { let in_mem_stream = in_mem_partial_sort( - &mut in_mem_batches, + &mut self.in_mem_batches, self.schema.clone(), &self.expr, batch_size, tracking_metrics, self.fetch, )?; - let prev_used = self.free_all_memory(); + let prev_used = self.reservation.free(); streams.push(SortedStream::new(in_mem_stream, prev_used)); } - let mut spills = self.spills.lock().await; - - for spill in spills.drain(..) { + for spill in self.spills.drain(..) { let stream = read_spill_as_stream(spill, self.schema.clone())?; streams.push(SortedStream::new(stream, 0)); } let tracking_metrics = self .metrics_set - .new_final_tracking(partition, self.runtime.clone()); + .new_final_tracking(self.partition_id, &self.runtime.memory_pool); Ok(Box::pin(SortPreservingMergeStream::new_from_streams( streams, self.schema.clone(), @@ -198,12 +201,12 @@ impl ExternalSorter { tracking_metrics, self.session_config.batch_size(), )?)) - } else if in_mem_batches.len() > 0 { + } else if !self.in_mem_batches.is_empty() { let tracking_metrics = self .metrics_set - .new_final_tracking(partition, self.runtime.clone()); + .new_final_tracking(self.partition_id, &self.runtime.memory_pool); let result = in_mem_partial_sort( - &mut in_mem_batches, + &mut self.in_mem_batches, self.schema.clone(), &self.expr, batch_size, @@ -211,19 +214,13 @@ impl ExternalSorter { self.fetch, ); // Report to the memory manager we are no longer using memory - self.free_all_memory(); + self.reservation.free(); result } else { Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone()))) } } - fn free_all_memory(&self) -> usize { - let used = self.metrics.mem_used().set(0); - self.shrink(used); - used - } - fn used(&self) -> usize { self.metrics.mem_used().value() } @@ -235,66 +232,22 @@ impl ExternalSorter { fn spill_count(&self) -> usize { self.metrics.spill_count().value() } -} - -impl Debug for ExternalSorter { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("ExternalSorter") - .field("id", &self.id()) - .field("memory_used", &self.used()) - .field("spilled_bytes", &self.spilled_bytes()) - .field("spill_count", &self.spill_count()) - .finish() - } -} -impl Drop for ExternalSorter { - fn drop(&mut self) { - self.runtime.drop_consumer(self.id(), self.used()); - } -} - -#[async_trait] -impl MemoryConsumer for ExternalSorter { - fn name(&self) -> String { - "ExternalSorter".to_owned() - } - - fn id(&self) -> &MemoryConsumerId { - &self.id - } - - fn memory_manager(&self) -> Arc { - self.runtime.memory_manager.clone() - } - - fn type_(&self) -> &ConsumerType { - &ConsumerType::Requesting - } - - async fn spill(&self) -> Result { - let partition = self.partition_id(); - let mut in_mem_batches = self.in_mem_batches.lock().await; + async fn spill(&mut self) -> Result { // we could always get a chance to free some memory as long as we are holding some - if in_mem_batches.len() == 0 { + if self.in_mem_batches.is_empty() { return Ok(0); } - debug!( - "{}[{}] spilling sort data of {} to disk while 
inserting ({} time(s) so far)", - self.name(), - self.id(), - self.used(), - self.spill_count() - ); + debug!("Spilling sort data of ExternalSorter to disk whilst inserting"); let tracking_metrics = self .metrics_set - .new_intermediate_tracking(partition, self.runtime.clone()); + .new_intermediate_tracking(self.partition_id, &self.runtime.memory_pool); let spillfile = self.runtime.disk_manager.create_tmp_file("Sorting")?; let stream = in_mem_partial_sort( - &mut in_mem_batches, + &mut self.in_mem_batches, self.schema.clone(), &self.expr, self.session_config.batch_size(), @@ -304,15 +257,21 @@ impl MemoryConsumer for ExternalSorter { spill_partial_sorted_stream(&mut stream?, spillfile.path(), self.schema.clone()) .await?; - let mut spills = self.spills.lock().await; + self.reservation.free(); let used = self.metrics.mem_used().set(0); self.metrics.record_spill(used); - spills.push(spillfile); + self.spills.push(spillfile); Ok(used) } +} - fn mem_used(&self) -> usize { - self.metrics.mem_used().value() +impl Debug for ExternalSorter { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("ExternalSorter") + .field("memory_used", &self.used()) + .field("spilled_bytes", &self.spilled_bytes()) + .field("spill_count", &self.spill_count()) + .finish() } } @@ -528,7 +487,7 @@ impl SortedSizedRecordBatchStream { schema: SchemaRef, batches: Vec, sorted_iter: SortedIterator, - metrics: MemTrackingMetrics, + mut metrics: MemTrackingMetrics, ) -> Self { let size = batches.iter().map(batch_byte_size).sum::() + sorted_iter.memory_size(); @@ -911,8 +870,8 @@ async fn do_sort( ); let schema = input.schema(); let tracking_metrics = - metrics_set.new_intermediate_tracking(partition_id, context.runtime_env()); - let sorter = ExternalSorter::new( + metrics_set.new_intermediate_tracking(partition_id, context.memory_pool()); + let mut sorter = ExternalSorter::new( partition_id, schema.clone(), expr, @@ -921,12 +880,11 @@ async fn do_sort( context.runtime_env(), fetch, ); - context.runtime_env().register_requester(sorter.id()); while let Some(batch) = input.next().await { let batch = batch?; sorter.insert_batch(batch, &tracking_metrics).await?; } - let result = sorter.sort().await; + let result = sorter.sort(); debug!( "End do_sort for partition {} of context session_id {} and task_id {:?}", partition_id, @@ -1005,10 +963,7 @@ mod tests { assert_eq!(c7.value(c7.len() - 1), 254,); assert_eq!( - session_ctx - .runtime_env() - .memory_manager - .get_requester_total(), + session_ctx.runtime_env().memory_pool.reserved(), 0, "The sort should have returned all memory used back to the memory manager" ); @@ -1077,10 +1032,7 @@ mod tests { assert_eq!(c7.value(c7.len() - 1), 254,); assert_eq!( - session_ctx - .runtime_env() - .memory_manager - .get_requester_total(), + session_ctx.runtime_env().memory_pool.reserved(), 0, "The sort should have returned all memory used back to the memory manager" ); @@ -1100,7 +1052,7 @@ mod tests { // all the batches we are processing, we expect it to spill. (None, true), // When we have a limit however, the buffered size of batches should fit in memory - // since it is much lover than the total size of the input batch. + // since it is much lower than the total size of the input batch. 
(Some(1), false), ]; @@ -1331,10 +1283,7 @@ mod tests { assert_strong_count_converges_to_zero(refs).await; assert_eq!( - session_ctx - .runtime_env() - .memory_manager - .get_requester_total(), + session_ctx.runtime_env().memory_pool.reserved(), 0, "The sort should have returned all memory used back to the memory manager" ); diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs index 212c4c955b32..f069cc5b007c 100644 --- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs +++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs @@ -169,7 +169,8 @@ impl ExecutionPlan for SortPreservingMergeExec { ))); } - let tracking_metrics = MemTrackingMetrics::new(&self.metrics, partition); + let tracking_metrics = + MemTrackingMetrics::new(&self.metrics, context.memory_pool(), partition); let input_partitions = self.input.output_partitioning().partition_count(); debug!( @@ -342,7 +343,7 @@ impl SortPreservingMergeStream { streams: Vec, schema: SchemaRef, expressions: &[PhysicalSortExpr], - tracking_metrics: MemTrackingMetrics, + mut tracking_metrics: MemTrackingMetrics, batch_size: usize, ) -> Result { let stream_count = streams.len(); @@ -1258,7 +1259,8 @@ mod tests { } let metrics = ExecutionPlanMetricsSet::new(); - let tracking_metrics = MemTrackingMetrics::new(&metrics, 0); + let tracking_metrics = + MemTrackingMetrics::new(&metrics, task_ctx.memory_pool(), 0); let merge_stream = SortPreservingMergeStream::new_from_streams( streams, diff --git a/datafusion/core/tests/memory_limit.rs b/datafusion/core/tests/memory_limit.rs index 20ad555d66e1..91d66e884623 100644 --- a/datafusion/core/tests/memory_limit.rs +++ b/datafusion/core/tests/memory_limit.rs @@ -19,14 +19,13 @@ use std::sync::Arc; -use arrow::record_batch::RecordBatch; use datafusion::datasource::MemTable; use datafusion::execution::disk_manager::DiskManagerConfig; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_common::assert_contains; use datafusion::prelude::{SessionConfig, SessionContext}; -use test_utils::{stagger_batch, AccessLogGenerator}; +use test_utils::AccessLogGenerator; #[cfg(test)] #[ctor::ctor] @@ -39,6 +38,7 @@ async fn oom_sort() { run_limit_test( "select * from t order by host DESC", "Resources exhausted: Memory Exhausted while Sorting (DiskManager is disabled)", + 200_000, ) .await } @@ -47,7 +47,8 @@ async fn oom_sort() { async fn group_by_none() { run_limit_test( "select median(image) from t", - "Resources exhausted: Cannot spill GroupBy None Accumulators", + "Resources exhausted: Failed to allocate additional", + 20_000, ) .await } @@ -56,7 +57,8 @@ async fn group_by_none() { async fn group_by_row_hash() { run_limit_test( "select count(*) from t GROUP BY response_bytes", - "Resources exhausted: Cannot spill GroupBy Hash (Row) AggregationState", + "Resources exhausted: Failed to allocate additional", + 2_000, ) .await } @@ -66,23 +68,21 @@ async fn group_by_hash() { run_limit_test( // group by dict column "select count(*) from t GROUP BY service, host, pod, container", - "Resources exhausted: Cannot spill GroupBy Hash Accumulators", + "Resources exhausted: Failed to allocate additional", + 1_000, ) .await } /// 50 byte memory limit -const MEMORY_LIMIT_BYTES: usize = 50; const MEMORY_FRACTION: f64 = 0.95; /// runs the specified query against 1000 rows with a 50 /// byte memory limit and no disk manager enabled. 
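The reworked run_limit_test below combines a hard memory limit with a disabled DiskManager, so an over-budget operator surfaces a "Resources exhausted" error rather than spilling. A sketch of just that configuration, assuming limit and fraction values in the same range the tests use:

```rust
use std::sync::Arc;
use datafusion::execution::disk_manager::DiskManagerConfig;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::prelude::{SessionConfig, SessionContext};

fn main() {
    // No spilling allowed, and only ~2 kB of execution memory (95% of it usable).
    let rt_config = RuntimeConfig::new()
        .with_disk_manager(DiskManagerConfig::Disabled)
        .with_memory_limit(2_000, 0.95);

    let runtime = Arc::new(RuntimeEnv::new(rt_config).unwrap());
    let _ctx = SessionContext::with_config_rt(SessionConfig::new(), runtime);

    // Queries run on _ctx that need more memory than this fail with a
    // "Resources exhausted: Failed to allocate additional ..." error.
}
```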
-async fn run_limit_test(query: &str, expected_error: &str) { - let generator = AccessLogGenerator::new().with_row_limit(Some(1000)); - - let batches: Vec = generator - // split up into more than one batch, as the size limit in sort is not enforced until the second batch - .flat_map(stagger_batch) +async fn run_limit_test(query: &str, expected_error: &str, memory_limit: usize) { + let batches: Vec<_> = AccessLogGenerator::new() + .with_row_limit(1000) + .with_max_batch_size(50) .collect(); let table = MemTable::try_new(batches[0].schema(), vec![batches]).unwrap(); @@ -91,7 +91,7 @@ async fn run_limit_test(query: &str, expected_error: &str) { // do not allow spilling .with_disk_manager(DiskManagerConfig::Disabled) // Only allow 50 bytes - .with_memory_limit(MEMORY_LIMIT_BYTES, MEMORY_FRACTION); + .with_memory_limit(memory_limit, MEMORY_FRACTION); let runtime = RuntimeEnv::new(rt_config).unwrap(); diff --git a/datafusion/core/tests/order_spill_fuzz.rs b/datafusion/core/tests/order_spill_fuzz.rs index cc700d5d2cb7..923b44e2681b 100644 --- a/datafusion/core/tests/order_spill_fuzz.rs +++ b/datafusion/core/tests/order_spill_fuzz.rs @@ -22,7 +22,7 @@ use arrow::{ compute::SortOptions, record_batch::RecordBatch, }; -use datafusion::execution::memory_manager::MemoryManagerConfig; +use datafusion::execution::memory_pool::GreedyMemoryPool; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::physical_plan::expressions::{col, PhysicalSortExpr}; use datafusion::physical_plan::memory::MemoryExec; @@ -31,18 +31,18 @@ use datafusion::physical_plan::{collect, ExecutionPlan}; use datafusion::prelude::{SessionConfig, SessionContext}; use rand::Rng; use std::sync::Arc; -use test_utils::{batches_to_vec, partitions_to_sorted_vec, stagger_batch_with_seed}; +use test_utils::{batches_to_vec, partitions_to_sorted_vec}; #[tokio::test] #[cfg_attr(tarpaulin, ignore)] async fn test_sort_1k_mem() { - run_sort(1024, vec![(5, false), (2000, true), (1000000, true)]).await + run_sort(10240, vec![(5, false), (20000, true), (1000000, true)]).await } #[tokio::test] #[cfg_attr(tarpaulin, ignore)] async fn test_sort_100k_mem() { - run_sort(102400, vec![(5, false), (2000, false), (1000000, true)]).await + run_sort(102400, vec![(5, false), (20000, false), (1000000, true)]).await } #[tokio::test] @@ -76,9 +76,8 @@ async fn run_sort(pool_size: usize, size_spill: Vec<(usize, bool)>) { let exec = MemoryExec::try_new(&input, schema, None).unwrap(); let sort = Arc::new(SortExec::try_new(sort, Arc::new(exec), None).unwrap()); - let runtime_config = RuntimeConfig::new().with_memory_manager( - MemoryManagerConfig::try_new_limit(pool_size, 1.0).unwrap(), - ); + let runtime_config = RuntimeConfig::new() + .with_memory_pool(Arc::new(GreedyMemoryPool::new(pool_size))); let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap()); let session_ctx = SessionContext::with_config_rt(SessionConfig::new(), runtime); @@ -95,12 +94,9 @@ async fn run_sort(pool_size: usize, size_spill: Vec<(usize, bool)>) { } assert_eq!( - session_ctx - .runtime_env() - .memory_manager - .get_requester_total(), + session_ctx.runtime_env().memory_pool.reserved(), 0, - "The sort should have returned all memory used back to the memory manager" + "The sort should have returned all memory used back to the memory pool" ); assert_eq!(expected, actual, "failure in @ pool_size {}", pool_size); } @@ -110,12 +106,23 @@ async fn run_sort(pool_size: usize, size_spill: Vec<(usize, bool)>) { /// with randomized i32 content fn 
make_staggered_batches(len: usize) -> Vec { let mut rng = rand::thread_rng(); - let mut input: Vec = vec![0; len]; - rng.fill(&mut input[..]); - let input = Int32Array::from_iter_values(input.into_iter()); - - // split into several record batches - let batch = - RecordBatch::try_from_iter(vec![("x", Arc::new(input) as ArrayRef)]).unwrap(); - stagger_batch_with_seed(batch, 42) + let max_batch = 1024; + + let mut batches = vec![]; + let mut remaining = len; + while remaining != 0 { + let to_read = rng.gen_range(0..=remaining.min(max_batch)); + remaining -= to_read; + + batches.push( + RecordBatch::try_from_iter(vec![( + "x", + Arc::new(Int32Array::from_iter_values( + std::iter::from_fn(|| Some(rng.gen())).take(to_read), + )) as ArrayRef, + )]) + .unwrap(), + ) + } + batches } diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index 999becafd0a5..fc74d7ded7e7 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -57,7 +57,7 @@ async fn single_file() { let tempdir = TempDir::new().unwrap(); - let generator = AccessLogGenerator::new().with_row_limit(Some(NUM_ROWS)); + let generator = AccessLogGenerator::new().with_row_limit(NUM_ROWS); // default properties let props = WriterProperties::builder().build(); @@ -236,7 +236,7 @@ async fn single_file() { async fn single_file_small_data_pages() { let tempdir = TempDir::new().unwrap(); - let generator = AccessLogGenerator::new().with_row_limit(Some(NUM_ROWS)); + let generator = AccessLogGenerator::new().with_row_limit(NUM_ROWS); // set the max page rows with arbitrary sizes 8311 to increase // effectiveness of page filtering diff --git a/datafusion/core/tests/provider_filter_pushdown.rs b/datafusion/core/tests/provider_filter_pushdown.rs index 7276820f6f59..13160fd52e1a 100644 --- a/datafusion/core/tests/provider_filter_pushdown.rs +++ b/datafusion/core/tests/provider_filter_pushdown.rs @@ -90,10 +90,11 @@ impl ExecutionPlan for CustomPlan { fn execute( &self, partition: usize, - _context: Arc, + context: Arc, ) -> Result { let metrics = ExecutionPlanMetricsSet::new(); - let tracking_metrics = MemTrackingMetrics::new(&metrics, partition); + let tracking_metrics = + MemTrackingMetrics::new(&metrics, context.memory_pool(), partition); Ok(Box::pin(SizedRecordBatchStream::new( self.schema(), self.batches.clone(), diff --git a/test-utils/src/data_gen.rs b/test-utils/src/data_gen.rs index c82d56ef21f9..19db65400a17 100644 --- a/test-utils/src/data_gen.rs +++ b/test-utils/src/data_gen.rs @@ -78,6 +78,13 @@ impl BatchBuilder { ])) } + fn is_finished(&self) -> bool { + self.row_limit + .as_ref() + .map(|x| *x <= self.row_count) + .unwrap_or_default() + } + fn append(&mut self, rng: &mut StdRng, host: &str, service: &str) { let num_pods = rng.gen_range(1..15); let pods = generate_sorted_strings(rng, num_pods, 30..40); @@ -91,6 +98,10 @@ impl BatchBuilder { let num_entries = rng.gen_range(1024..8192); for i in 0..num_entries { + if self.is_finished() { + return; + } + let time = i as i64 * 1024; self.append_row(rng, host, &pod, service, &container, &image, time); } @@ -109,12 +120,6 @@ impl BatchBuilder { image: &str, time: i64, ) { - // skip if over limit - if let Some(limit) = self.row_limit { - if self.row_count >= limit { - return; - } - } self.row_count += 1; let methods = &["GET", "PUT", "POST", "HEAD", "PATCH", "DELETE"]; @@ -237,7 +242,9 @@ pub struct AccessLogGenerator { rng: StdRng, host_idx: usize, /// optional number of rows 
produced - row_limit: Option, + row_limit: usize, + /// maximum rows per batch + max_batch_size: usize, /// How many rows have been returned so far row_count: usize, } @@ -259,7 +266,8 @@ impl AccessLogGenerator { schema: BatchBuilder::schema(), host_idx: 0, rng: StdRng::from_seed(seed), - row_limit: None, + row_limit: usize::MAX, + max_batch_size: usize::MAX, row_count: 0, } } @@ -269,8 +277,14 @@ impl AccessLogGenerator { self.schema.clone() } + /// Limit the maximum batch size + pub fn with_max_batch_size(mut self, batch_size: usize) -> Self { + self.max_batch_size = batch_size; + self + } + /// Return up to row_limit rows; - pub fn with_row_limit(mut self, row_limit: Option) -> Self { + pub fn with_row_limit(mut self, row_limit: usize) -> Self { self.row_limit = row_limit; self } @@ -280,15 +294,13 @@ impl Iterator for AccessLogGenerator { type Item = RecordBatch; fn next(&mut self) -> Option { - // if we have a limit and have passed it, stop generating - if let Some(limit) = self.row_limit { - if self.row_count >= limit { - return None; - } + if self.row_count == self.row_limit { + return None; } - let mut builder = BatchBuilder::default() - .with_row_limit(self.row_limit.map(|limit| limit - self.row_count)); + let mut builder = BatchBuilder::default().with_row_limit(Some( + self.max_batch_size.min(self.row_limit - self.row_count), + )); let host = format!( "i-{:016x}.ec2.internal", @@ -300,19 +312,14 @@ impl Iterator for AccessLogGenerator { if self.rng.gen_bool(0.5) { continue; } + if builder.is_finished() { + break; + } builder.append(&mut self.rng, &host, service); } let batch = builder.finish(Arc::clone(&self.schema)); - // limit batch if needed to stay under row limit - let batch = if let Some(limit) = self.row_limit { - let num_rows = limit - self.row_count; - batch.slice(0, num_rows) - } else { - batch - }; - self.row_count += batch.num_rows(); Some(batch) }
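Finally, a usage sketch of the updated AccessLogGenerator builder in test-utils, which now takes a plain usize row limit plus an explicit maximum batch size; the values mirror what run_limit_test passes and the assertions follow from the generator code above:

```rust
use test_utils::AccessLogGenerator;

fn main() {
    // Up to 1000 rows, emitted in batches of at most 50 rows each.
    let batches: Vec<_> = AccessLogGenerator::new()
        .with_row_limit(1000)
        .with_max_batch_size(50)
        .collect();

    assert!(batches.iter().all(|batch| batch.num_rows() <= 50));
    assert!(batches.iter().map(|batch| batch.num_rows()).sum::<usize>() <= 1000);
}
```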