diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs
index fb6759a195de..ee4e262fa449 100644
--- a/datafusion/core/src/datasource/file_format/csv.rs
+++ b/datafusion/core/src/datasource/file_format/csv.rs
@@ -19,16 +19,17 @@
 use std::any::Any;
+use std::collections::HashSet;
 use std::sync::Arc;
 
-use arrow::datatypes::Schema;
+use arrow::datatypes::{DataType, Field, Schema};
 use arrow::{self, datatypes::SchemaRef};
 use async_trait::async_trait;
 use bytes::Buf;
 use datafusion_common::DataFusionError;
-use futures::TryFutureExt;
+use futures::{pin_mut, StreamExt, TryStreamExt};
 use object_store::{ObjectMeta, ObjectStore};
 
 use super::FileFormat;
@@ -37,7 +38,9 @@ use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD;
 use crate::error::Result;
 use crate::execution::context::SessionState;
 use crate::logical_expr::Expr;
-use crate::physical_plan::file_format::{CsvExec, FileScanConfig};
+use crate::physical_plan::file_format::{
+    newline_delimited_stream, CsvExec, FileScanConfig,
+};
 use crate::physical_plan::ExecutionPlan;
 use crate::physical_plan::Statistics;
 
@@ -122,27 +125,75 @@ impl FileFormat for CsvFormat {
 
         let mut records_to_read = self.schema_infer_max_rec.unwrap_or(usize::MAX);
 
-        for object in objects {
-            let data = store
+        'iterating_objects: for object in objects {
+            // stream to only read as many rows as needed into memory
+            let stream = store
                 .get(&object.location)
-                .and_then(|r| r.bytes())
-                .await
-                .map_err(|e| DataFusionError::External(Box::new(e)))?;
-
-            let decoder = self.file_compression_type.convert_read(data.reader())?;
-            let (schema, records_read) = arrow::csv::reader::infer_reader_schema(
-                decoder,
-                self.delimiter,
-                Some(records_to_read),
-                self.has_header,
-            )?;
-            schemas.push(schema.clone());
-            if records_read == 0 {
-                continue;
+                .await?
+                .into_stream()
+                .map_err(|e| DataFusionError::External(Box::new(e)));
+            let stream = newline_delimited_stream(stream);
+            pin_mut!(stream);
+
+            let mut column_names = vec![];
+            let mut column_type_possibilities = vec![];
+            let mut first_chunk = true;
+
+            'reading_object: while let Some(data) = stream.next().await.transpose()? {
+                let (Schema { fields, .. }, records_read) =
+                    arrow::csv::reader::infer_reader_schema(
+                        self.file_compression_type.convert_read(data.reader())?,
+                        self.delimiter,
+                        Some(records_to_read),
+                        // only consider header for first chunk
+                        self.has_header && first_chunk,
+                    )?;
+                records_to_read -= records_read;
+
+                if first_chunk {
+                    // set up initial structures for recording inferred schema across chunks
+                    (column_names, column_type_possibilities) = fields
+                        .into_iter()
+                        .map(|field| {
+                            let mut possibilities = HashSet::new();
+                            if records_read > 0 {
+                                // at least 1 data row read, record the inferred datatype
+                                possibilities.insert(field.data_type().clone());
+                            }
+                            (field.name().clone(), possibilities)
+                        })
+                        .unzip();
+                    first_chunk = false;
+                } else {
+                    if fields.len() != column_type_possibilities.len() {
+                        return Err(DataFusionError::Execution(
+                            format!(
+                                "Encountered unequal lengths between records on CSV file whilst inferring schema. \
+                                 Expected {} records, found {} records",
+                                column_type_possibilities.len(),
+                                fields.len()
+                            )
+                        ));
+                    }
+
+                    column_type_possibilities.iter_mut().zip(fields).for_each(
+                        |(possibilities, field)| {
+                            possibilities.insert(field.data_type().clone());
+                        },
+                    );
+                }
+
+                if records_to_read == 0 {
+                    break 'reading_object;
+                }
             }
-            records_to_read -= records_read;
+
+            schemas.push(build_schema_helper(
+                column_names,
+                &column_type_possibilities,
+            ));
             if records_to_read == 0 {
-                break;
+                break 'iterating_objects;
             }
         }
 
@@ -176,14 +227,50 @@ impl FileFormat for CsvFormat {
     }
 }
 
+fn build_schema_helper(names: Vec<String>, types: &[HashSet<DataType>]) -> Schema {
+    let fields = names
+        .into_iter()
+        .zip(types)
+        .map(|(field_name, data_type_possibilities)| {
+            // ripped from arrow::csv::reader::infer_reader_schema_with_csv_options
+            // determine data type based on possible types
+            // if there are incompatible types, use DataType::Utf8
+            match data_type_possibilities.len() {
+                1 => Field::new(
+                    field_name,
+                    data_type_possibilities.iter().next().unwrap().clone(),
+                    true,
+                ),
+                2 => {
+                    if data_type_possibilities.contains(&DataType::Int64)
+                        && data_type_possibilities.contains(&DataType::Float64)
+                    {
+                        // we have an integer and double, fall down to double
+                        Field::new(field_name, DataType::Float64, true)
+                    } else {
+                        // default to Utf8 for conflicting datatypes (e.g. bool and int)
+                        Field::new(field_name, DataType::Utf8, true)
+                    }
+                }
+                _ => Field::new(field_name, DataType::Utf8, true),
+            }
+        })
+        .collect();
+    Schema::new(fields)
+}
+
 #[cfg(test)]
 mod tests {
     use super::super::test_util::scan_format;
     use super::*;
+    use crate::datasource::file_format::test_util::VariableStream;
     use crate::physical_plan::collect;
     use crate::prelude::{SessionConfig, SessionContext};
+    use bytes::Bytes;
+    use chrono::DateTime;
     use datafusion_common::cast::as_string_array;
     use futures::StreamExt;
+    use object_store::path::Path;
 
     #[tokio::test]
     async fn read_small_batches() -> Result<()> {
@@ -291,6 +378,57 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_infer_schema_stream() -> Result<()> {
+        let session_ctx = SessionContext::new();
+        let state = session_ctx.state();
+        let variable_object_store =
+            Arc::new(VariableStream::new(Bytes::from("1,2,3,4,5\n"), 200));
+        let object_meta = ObjectMeta {
+            location: Path::parse("/")?,
+            last_modified: DateTime::default(),
+            size: usize::MAX,
+        };
+
+        let num_rows_to_read = 100;
+        let csv_format = CsvFormat {
+            has_header: false,
+            schema_infer_max_rec: Some(num_rows_to_read),
+            ..Default::default()
+        };
+        let inferred_schema = csv_format
+            .infer_schema(
+                &state,
+                &(variable_object_store.clone() as Arc<dyn ObjectStore>),
+                &[object_meta],
+            )
+            .await?;
+
+        let actual_fields: Vec<_> = inferred_schema
+            .fields()
+            .iter()
+            .map(|f| format!("{}: {:?}", f.name(), f.data_type()))
+            .collect();
+        assert_eq!(
+            vec![
+                "column_1: Int64",
+                "column_2: Int64",
+                "column_3: Int64",
+                "column_4: Int64",
+                "column_5: Int64"
+            ],
+            actual_fields
+        );
+        // ensure that CSV schema inference does not try to read the entire file;
+        // it should only read as many rows as configured on the CsvFormat
+        assert_eq!(
+            num_rows_to_read,
+            variable_object_store.get_iterations_detected()
+        );
+
+        Ok(())
+    }
+
     async fn get_exec(
         state: &SessionState,
         file_name: &str,
diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs
index ee9795a77563..6b377a0fc9c0 100644
--- a/datafusion/core/src/datasource/file_format/mod.rs
+++ b/datafusion/core/src/datasource/file_format/mod.rs
@@ -87,11 +87,20 @@ pub trait FileFormat: Send + Sync + fmt::Debug {
 
 #[cfg(test)]
 pub(crate) mod test_util {
+    use std::ops::Range;
+    use std::sync::Mutex;
+
     use super::*;
     use crate::datasource::listing::PartitionedFile;
     use crate::datasource::object_store::ObjectStoreUrl;
     use crate::test::object_store::local_unpartitioned_file;
+    use bytes::Bytes;
+    use futures::stream::BoxStream;
+    use futures::StreamExt;
     use object_store::local::LocalFileSystem;
+    use object_store::path::Path;
+    use object_store::{GetResult, ListResult, MultipartId};
+    use tokio::io::AsyncWrite;
 
     pub async fn scan_format(
         state: &SessionState,
@@ -136,4 +145,121 @@ pub(crate) mod test_util {
             .await?;
         Ok(exec)
     }
+
+    /// Mock ObjectStore to provide a variable stream of bytes on get
+    /// Able to keep track of how many iterations of the provided bytes were repeated
+    #[derive(Debug)]
+    pub struct VariableStream {
+        bytes_to_repeat: Bytes,
+        max_iterations: usize,
+        iterations_detected: Arc<Mutex<usize>>,
+    }
+
+    impl std::fmt::Display for VariableStream {
+        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+            write!(f, "VariableStream")
+        }
+    }
+
+    #[async_trait]
+    impl ObjectStore for VariableStream {
+        async fn put(&self, _location: &Path, _bytes: Bytes) -> object_store::Result<()> {
+            unimplemented!()
+        }
+
+        async fn put_multipart(
+            &self,
+            _location: &Path,
+        ) -> object_store::Result<(MultipartId, Box<dyn AsyncWrite + Unpin + Send>)>
+        {
+            unimplemented!()
+        }
+
+        async fn abort_multipart(
+            &self,
+            _location: &Path,
+            _multipart_id: &MultipartId,
+        ) -> object_store::Result<()> {
+            unimplemented!()
+        }
+
+        async fn get(&self, _location: &Path) -> object_store::Result<GetResult> {
+            let bytes = self.bytes_to_repeat.clone();
+            let arc = self.iterations_detected.clone();
+            Ok(GetResult::Stream(
+                futures::stream::repeat_with(move || {
+                    let arc_inner = arc.clone();
+                    *arc_inner.lock().unwrap() += 1;
+                    Ok(bytes.clone())
+                })
+                .take(self.max_iterations)
+                .boxed(),
+            ))
+        }
+
+        async fn get_range(
+            &self,
+            _location: &Path,
+            _range: Range<usize>,
+        ) -> object_store::Result<Bytes> {
+            unimplemented!()
+        }
+
+        async fn get_ranges(
+            &self,
+            _location: &Path,
+            _ranges: &[Range<usize>],
+        ) -> object_store::Result<Vec<Bytes>> {
+            unimplemented!()
+        }
+
+        async fn head(&self, _location: &Path) -> object_store::Result<ObjectMeta> {
+            unimplemented!()
+        }
+
+        async fn delete(&self, _location: &Path) -> object_store::Result<()> {
+            unimplemented!()
+        }
+
+        async fn list(
+            &self,
+            _prefix: Option<&Path>,
+        ) -> object_store::Result<BoxStream<'_, object_store::Result<ObjectMeta>>>
+        {
+            unimplemented!()
+        }
+
+        async fn list_with_delimiter(
+            &self,
+            _prefix: Option<&Path>,
+        ) -> object_store::Result<ListResult> {
+            unimplemented!()
+        }
+
+        async fn copy(&self, _from: &Path, _to: &Path) -> object_store::Result<()> {
+            unimplemented!()
+        }
+
+        async fn copy_if_not_exists(
+            &self,
+            _from: &Path,
+            _to: &Path,
+        ) -> object_store::Result<()> {
+            unimplemented!()
+        }
+    }
+
+    impl VariableStream {
+        pub fn new(bytes_to_repeat: Bytes, max_iterations: usize) -> Self {
+            Self {
+                bytes_to_repeat,
+                max_iterations,
+                iterations_detected: Arc::new(Mutex::new(0)),
+            }
+        }
+
+        pub fn get_iterations_detected(&self) -> usize {
+            *self.iterations_detected.lock().unwrap()
+        }
+    }
 }
diff --git a/datafusion/core/src/physical_plan/file_format/mod.rs b/datafusion/core/src/physical_plan/file_format/mod.rs
index 7eb9730c92ac..058ef1394d5b 100644
--- a/datafusion/core/src/physical_plan/file_format/mod.rs
+++ b/datafusion/core/src/physical_plan/file_format/mod.rs
@@ -28,6 +28,7 @@ mod parquet;
 
 pub(crate) use self::csv::plan_to_csv;
 pub use self::csv::CsvExec;
+pub(crate) use self::delimited_stream::newline_delimited_stream;
 pub(crate) use self::parquet::plan_to_parquet;
 pub use self::parquet::{ParquetExec, ParquetFileMetrics, ParquetFileReaderFactory};
 use arrow::{