Skip to content

Commit

Permalink
c
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Jan 31, 2025
1 parent ea1ea5a commit 61b6e1d
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 23 deletions.
25 changes: 24 additions & 1 deletion crates/polars-io/src/catalog/unity/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use polars_core::prelude::PlHashMap;
use polars_core::schema::Schema;
use polars_error::{polars_bail, to_compute_err, PolarsResult};

use super::models::{CatalogInfo, NamespaceInfo, TableInfo};
use super::models::{CatalogInfo, NamespaceInfo, TableCredentials, TableInfo};
use super::utils::{do_request, PageWalker};
use crate::catalog::schema::schema_to_column_info_list;
use crate::catalog::unity::models::{ColumnInfo, DataSourceFormat, TableType};
Expand Down Expand Up @@ -83,6 +83,29 @@ impl CatalogClient {
Ok(out)
}

pub async fn get_table_credentials(
&self,
table_id: &str,
write: bool,
) -> PolarsResult<TableCredentials> {
let bytes = do_request(
self.http_client
.post(format!(
"{}{}",
&self.workspace_url, "/api/2.1/unity-catalog/temporary-table-credentials"
))
.query(&[
("table_id", table_id),
("operation", if write { "READ_WRITE" } else { "READ" }),
]),
)
.await?;

let out: TableCredentials = decode_json_response(&bytes)?;

Ok(out)
}

pub async fn create_catalog(
&self,
catalog_name: &str,
Expand Down
48 changes: 48 additions & 0 deletions crates/polars-io/src/catalog/unity/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,54 @@ impl ColumnTypeJsonType {
}
}

#[derive(Debug, serde::Deserialize)]
pub struct TableCredentials {
pub aws_temp_credentials: Option<TableCredentialsAws>,
pub azure_user_delegation_sas: Option<TableCredentialsAzure>,
pub gcp_oauth_token: Option<TableCredentialsGcp>,
pub expiration_time: i64,
}

impl TableCredentials {
pub fn into_enum(self) -> Option<TableCredentialsVariants> {
if let v @ Some(_) = self.aws_temp_credentials {
v.map(TableCredentialsVariants::Aws)
} else if let v @ Some(_) = self.azure_user_delegation_sas {
v.map(TableCredentialsVariants::Azure)
} else if let v @ Some(_) = self.gcp_oauth_token {
v.map(TableCredentialsVariants::Gcp)
} else {
None
}
}
}

pub enum TableCredentialsVariants {
Aws(TableCredentialsAws),
Azure(TableCredentialsAzure),
Gcp(TableCredentialsGcp),
}

#[derive(Debug, serde::Deserialize)]
pub struct TableCredentialsAws {
pub access_key_id: String,
pub secret_access_key: String,
pub session_token: Option<String>,

#[serde(default)]
pub access_point: Option<String>,
}

#[derive(Debug, serde::Deserialize)]
pub struct TableCredentialsAzure {
pub sas_token: String,
}

#[derive(Debug, serde::Deserialize)]
pub struct TableCredentialsGcp {
pub oauth_token: String,
}

fn null_to_default<'de, T, D>(d: D) -> Result<T, D::Error>
where
T: Default + serde::de::Deserialize<'de>,
Expand Down
79 changes: 72 additions & 7 deletions crates/polars-python/src/catalog/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use polars_io::cloud::credential_provider::PlCredentialProvider;
use polars_io::pl_async;
use pyo3::exceptions::PyValueError;
use pyo3::sync::GILOnceCell;
use pyo3::types::{PyAnyMethods, PyDict, PyList};
use pyo3::types::{PyAnyMethods, PyDict, PyList, PyNone, PyTuple};
use pyo3::{pyclass, pymethods, Bound, IntoPyObject, Py, PyAny, PyObject, PyResult, Python};

use crate::lazyframe::PyLazyFrame;
Expand All @@ -19,7 +19,7 @@ use crate::utils::{to_py_err, EnterPolarsExt};

macro_rules! pydict_insert_keys {
($dict:expr, {$a:expr}) => {
$dict.set_item(stringify!($a), $a).unwrap();
$dict.set_item(stringify!($a), $a)?;
};

($dict:expr, {$a:expr, $($args:expr),+}) => {
Expand Down Expand Up @@ -166,6 +166,73 @@ impl PyCatalogClient {
table_info_to_pyobject(py, table_info).map(|x| x.into())
}

#[pyo3(signature = (table_id, write))]
pub fn get_table_credentials(
&self,
py: Python,
table_id: &str,
write: bool,
) -> PyResult<PyObject> {
let table_credentials = py
.enter_polars(|| {
pl_async::get_runtime()
.block_on_potential_spawn(self.client().get_table_credentials(table_id, write))
})
.map_err(to_py_err)?;

let expiry = table_credentials.expiration_time;

let credentials = PyDict::new(py);
// Keys in here are intended to be injected into `storage_options` from the Python side.
// Note this currently really only exists for `aws_endpoint_url`.
let storage_update_options = PyDict::new(py);

{
use polars_io::catalog::unity::models::{
TableCredentialsAws, TableCredentialsAzure, TableCredentialsGcp,
TableCredentialsVariants,
};
use TableCredentialsVariants::*;

match table_credentials.into_enum() {
Some(Aws(TableCredentialsAws {
access_key_id,
secret_access_key,
session_token,
access_point,
})) => {
credentials.set_item("aws_access_key_id", access_key_id)?;
credentials.set_item("aws_secret_access_key", secret_access_key)?;

if let Some(session_token) = session_token {
credentials.set_item("aws_session_token", session_token)?;
}

if let Some(access_point) = access_point {
storage_update_options.set_item("aws_endpoint_url", access_point)?;
}
},
Some(Azure(TableCredentialsAzure { sas_token })) => {
credentials.set_item("sas_token", sas_token)?;
},
Some(Gcp(TableCredentialsGcp { oauth_token })) => {
credentials.set_item("bearer_token", oauth_token)?;
},
None => {},
}
}

let credentials = if credentials.len()? > 0 {
credentials.into_any()
} else {
PyNone::get(py).as_any().clone()
};
let storage_update_options = storage_update_options.into_any();
let expiry = expiry.into_pyobject(py)?.into_any();

Ok(PyTuple::new(py, [credentials, storage_update_options, expiry])?.into())
}

#[pyo3(signature = (catalog_name, namespace, table_name, cloud_options, credential_provider, retries))]
pub fn scan_table(
&self,
Expand Down Expand Up @@ -456,6 +523,8 @@ fn table_info_to_pyobject(py: Python, table_info: TableInfo) -> PyResult<Bound<'
updated_by,
} = table_info;

let column_info_cls = COLUMN_INFO_CLS.get(py).unwrap().bind(py);

let columns = columns
.map(|columns| {
columns
Expand Down Expand Up @@ -486,11 +555,7 @@ fn table_info_to_pyobject(py: Python, table_info: TableInfo) -> PyResult<Bound<'
partition_index,
});

COLUMN_INFO_CLS
.get(py)
.unwrap()
.bind(py)
.call((), Some(&dict))
column_info_cls.call((), Some(&dict))
},
)
.collect::<PyResult<Vec<_>>>()
Expand Down
55 changes: 40 additions & 15 deletions py-polars/polars/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

from polars._typing import SchemaDict
from polars.datatypes.classes import DataType
from polars.io.cloud import CredentialProviderFunction
from polars.io.cloud import (
CredentialProviderFunction,
CredentialProviderFunctionReturn,
)
from polars.lazyframe import LazyFrame


Expand Down Expand Up @@ -59,8 +62,9 @@ def __init__(
bearer_token == "auto"
# For security, in "auto" mode, only retrieve/use the token if:
# * We are running inside a Databricks environment
# * The `workspace_url` is pointing to Databricks
# * The `workspace_url` is pointing to Databricks and uses HTTPS
and "DATABRICKS_RUNTIME_VERSION" in os.environ
and workspace_url.startswith("https://")
and (
workspace_url.removeprefix("https://")
.split("/", 1)[0]
Expand Down Expand Up @@ -137,6 +141,11 @@ def get_table_info(
"""
return self._client.get_table_info(catalog_name, namespace, table_name)

def _get_table_credentials(
self, table_id: str, *, write: bool
) -> tuple[dict[str, str] | None, dict[str, str], int]:
return self._client.get_table_credentials(table_id=table_id, write=write)

def scan_table(
self,
catalog_name: str,
Expand Down Expand Up @@ -201,18 +210,29 @@ def scan_table(
table_info, "scan table"
)

credential_provider = CatalogCredentialProvider(
self, table_info.table_id, write=False
)

_, storage_update_options, _ = self._get_table_credentials(
table_info.table_id, write=False
)

if storage_options is not None or storage_update_options is not None:
storage_options = {
**(storage_options or {}),
**(storage_update_options or {}),
}

if data_source_format in ["DELTA", "DELTASHARING"]:
from polars.io.delta import scan_delta

if credential_provider is not None and credential_provider != "auto":
msg = "credential_provider when scanning DELTA"
raise NotImplementedError(msg)

return scan_delta(
storage_location,
version=delta_table_version,
delta_table_options=delta_table_options,
storage_options=storage_options,
credential_provider=credential_provider,
)

if delta_table_version is not None:
Expand All @@ -229,15 +249,6 @@ def scan_table(
)
raise ValueError(msg)

from polars.io.cloud.credential_provider import _maybe_init_credential_provider

credential_provider = _maybe_init_credential_provider(
credential_provider,
storage_location,
storage_options,
"Catalog.scan_table",
)

if storage_options:
storage_options = list(storage_options.items()) # type: ignore[assignment]
else:
Expand Down Expand Up @@ -457,6 +468,20 @@ def _get_databricks_token(cls) -> str:
return m["DefaultCredentials"]()(m["Config"]())()["Authorization"][7:]


class CatalogCredentialProvider:
def __init__(self, catalog: Catalog, table_id: str, *, write: bool):
self.catalog = catalog
self.table_id = table_id
self.write = write

def __call__(self) -> CredentialProviderFunctionReturn:
creds, _, expiry = self.catalog._get_table_credentials(
self.table_id, write=self.write
)

return creds, expiry


def _extract_location_and_data_format(
table_info: TableInfo, operation: str
) -> tuple[str, DataSourceFormat]:
Expand Down

0 comments on commit 61b6e1d

Please sign in to comment.