diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index 53cb0a9ab..4561c1e35 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -27,6 +27,9 @@ authors = ["Apache Arrow "] edition = "2018" build = "build.rs" +[package.metadata.docs.rs] +rustc-args = ["--cfg", "docsrs"] + [features] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = ["datafusion/force_hash_collisions"] diff --git a/ballista/rust/core/build.rs b/ballista/rust/core/build.rs index c2acde108..ab5d050d0 100644 --- a/ballista/rust/core/build.rs +++ b/ballista/rust/core/build.rs @@ -16,6 +16,10 @@ // under the License. fn main() -> Result<(), String> { + use std::io::Write; + + let out = std::path::PathBuf::from(std::env::var("OUT_DIR").unwrap()); + // for use in docker build where file changes can be wonky println!("cargo:rerun-if-env-changed=FORCE_REBUILD"); @@ -26,5 +30,22 @@ fn main() -> Result<(), String> { tonic_build::configure() .extern_path(".datafusion", "::datafusion_proto::protobuf") .compile(&["proto/ballista.proto"], &["proto"]) - .map_err(|e| format!("protobuf compilation failed: {}", e)) + .map_err(|e| format!("protobuf compilation failed: {}", e))?; + + // TODO: undo when resolved: https://github.com/intellij-rust/intellij-rust/issues/9402 + #[cfg(feature = "docsrs")] + let path = out.join("ballista.rs"); + #[cfg(not(feature = "docsrs"))] + let path = "src/serde/generated/ballista.rs"; + + let code = std::fs::read_to_string(out.join("ballista.protobuf.rs")).unwrap(); + let mut file = std::fs::OpenOptions::new() + .write(true) + .truncate(true) + .create(true) + .open(path) + .unwrap(); + file.write_all(code.as_str().as_ref()).unwrap(); + + Ok(()) } diff --git a/ballista/rust/core/src/client.rs b/ballista/rust/core/src/client.rs index dfe2003fb..61c19c643 100644 --- a/ballista/rust/core/src/client.rs +++ b/ballista/rust/core/src/client.rs @@ -26,7 +26,6 @@ use std::{ }; use crate::error::{ballista_error, BallistaError, Result}; -use crate::serde::protobuf::{self}; use crate::serde::scheduler::Action; use arrow_flight::utils::flight_data_to_arrow_batch; @@ -39,6 +38,7 @@ use datafusion::arrow::{ record_batch::RecordBatch, }; +use crate::serde::protobuf; use crate::utils::create_grpc_client_connection; use datafusion::physical_plan::{RecordBatchStream, SendableRecordBatchStream}; use futures::{Stream, StreamExt}; diff --git a/ballista/rust/core/src/serde/generated/.gitignore b/ballista/rust/core/src/serde/generated/.gitignore new file mode 100644 index 000000000..42eb8bcd5 --- /dev/null +++ b/ballista/rust/core/src/serde/generated/.gitignore @@ -0,0 +1,4 @@ +* + +!.gitignore +!mod.rs diff --git a/ballista/rust/core/src/serde/generated/mod.rs b/ballista/rust/core/src/serde/generated/mod.rs new file mode 100644 index 000000000..d399ad3b7 --- /dev/null +++ b/ballista/rust/core/src/serde/generated/mod.rs @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// include the generated protobuf source as a submodule +#[allow(clippy::all)] +#[rustfmt::skip] +#[cfg(not(docsrs))] +pub mod ballista; + +#[cfg(docsrs)] +#[allow(clippy::all)] +pub mod ballista { + include!(concat!(env!("OUT_DIR"), "/ballista.rs")); +} diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index 8e2bc8c9b..1e3be74b3 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -33,12 +33,9 @@ use std::marker::PhantomData; use std::sync::Arc; use std::{convert::TryInto, io::Cursor}; -// include the generated protobuf source as a submodule -#[allow(clippy::all)] -pub mod protobuf { - include!(concat!(env!("OUT_DIR"), "/ballista.protobuf.rs")); -} +pub use generated::ballista as protobuf; +pub mod generated; pub mod physical_plan; pub mod scheduler; @@ -167,7 +164,9 @@ impl BallistaCodec {{ if let Some(field) = $PB.as_ref() { - Ok(field.try_into()?) + Ok(field + .try_into() + .map_err(|_| proto_error("Failed to convert!"))?) } else { Err(proto_error("Missing required field in protobuf")) } diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 4b4bea5ce..af52f3d2e 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -42,7 +42,7 @@ use datafusion::physical_plan::{ColumnStatistics, PhysicalExpr, Statistics}; use object_store::path::Path; use object_store::ObjectMeta; -use protobuf::physical_expr_node::ExprType; +use crate::serde::protobuf::physical_expr_node::ExprType; use crate::convert_required; use crate::error::BallistaError; diff --git a/ballista/rust/scheduler/src/scheduler_server/grpc.rs b/ballista/rust/scheduler/src/scheduler_server/grpc.rs index f57262eb2..3b72b5a53 100644 --- a/ballista/rust/scheduler/src/scheduler_server/grpc.rs +++ b/ballista/rust/scheduler/src/scheduler_server/grpc.rs @@ -17,6 +17,7 @@ use ballista_core::config::{BallistaConfig, TaskSchedulingPolicy}; use ballista_core::serde::protobuf::execute_query_params::{OptionalSessionId, Query}; +use std::convert::TryInto; use ballista_core::serde::protobuf::executor_registration::OptionalHost; use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpc; @@ -39,8 +40,6 @@ use datafusion_proto::logical_plan::AsLogicalPlan; use futures::TryStreamExt; use log::{debug, error, info, warn}; -// use http_body::Body; -use std::convert::TryInto; use std::ops::Deref; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH};