Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid importing Pandas until we actually use it. #1825

Merged
merged 1 commit on Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 32 additions & 26 deletions tiledb/multirange_indexing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib.util
import json
import time
import weakref
Expand All @@ -8,6 +9,7 @@
from dataclasses import dataclass
from numbers import Real
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterator,
Expand All @@ -30,19 +32,14 @@
from .query_condition import QueryCondition
from .subarray import Subarray

current_timer: ContextVar[str] = ContextVar("timer_scope")

try:
if TYPE_CHECKING:
# We don't want to import these eagerly since importing Pandas in particular
# can add around half a second of import time even if we never use it.
import pandas
import pyarrow
from pyarrow import Table
except ImportError:
pyarrow = Table = None

try:
import pandas as pd
from pandas import DataFrame
except ImportError:
DataFrame = None

current_timer: ContextVar[str] = ContextVar("timer_scope")


# sentinel value to denote selecting an empty range
Expand Down Expand Up @@ -373,8 +370,12 @@ def __init__(
# we need to use a Query in order to get coords for a dense array
if not query:
query = QueryProxy(array, coords=True)
if use_arrow is None:
use_arrow = pyarrow is not None
use_arrow = (
bool(importlib.util.find_spec("pyarrow"))
if use_arrow is None
else use_arrow
)

# TODO: currently there is lack of support for Arrow list types. This prevents
# multi-value attributes, asides from strings, from being queried properly.
# Until list attributes are supported in core, error with a clear message.
Expand All @@ -390,12 +391,15 @@ def __init__(
)
super().__init__(array, query, use_arrow, preload_metadata=True)

def _run_query(self) -> Union[DataFrame, Table]:
def _run_query(self) -> Union["pandas.DataFrame", "pyarrow.Table"]:
import pandas
import pyarrow

if self.pyquery is not None:
self.pyquery.submit()

if self.pyquery is None:
df = DataFrame(self._empty_results)
df = pandas.DataFrame(self._empty_results)
elif self.use_arrow:
with timing("buffer_conversion_time"):
table = self.pyquery._buffers_to_pa_table()
Expand All @@ -417,14 +421,14 @@ def _run_query(self) -> Union[DataFrame, Table]:
# converting all integers with NULLs to float64:
# https://arrow.apache.org/docs/python/pandas.html#arrow-pandas-conversion
extended_dtype_mapping = {
pyarrow.int8(): pd.Int8Dtype(),
pyarrow.int16(): pd.Int16Dtype(),
pyarrow.int32(): pd.Int32Dtype(),
pyarrow.int64(): pd.Int64Dtype(),
pyarrow.uint8(): pd.UInt8Dtype(),
pyarrow.uint16(): pd.UInt16Dtype(),
pyarrow.uint32(): pd.UInt32Dtype(),
pyarrow.uint64(): pd.UInt64Dtype(),
pyarrow.int8(): pandas.Int8Dtype(),
pyarrow.int16(): pandas.Int16Dtype(),
pyarrow.int32(): pandas.Int32Dtype(),
pyarrow.int64(): pandas.Int64Dtype(),
pyarrow.uint8(): pandas.UInt8Dtype(),
pyarrow.uint16(): pandas.UInt16Dtype(),
pyarrow.uint32(): pandas.UInt32Dtype(),
pyarrow.uint64(): pandas.UInt64Dtype(),
}
dtype = extended_dtype_mapping[pa_attr.type]
else:
Expand Down Expand Up @@ -463,7 +467,7 @@ def _run_query(self) -> Union[DataFrame, Table]:

df = table.to_pandas()
else:
df = DataFrame(_get_pyquery_results(self.pyquery, self.array.schema))
df = pandas.DataFrame(_get_pyquery_results(self.pyquery, self.array.schema))

with timing("pandas_index_update_time"):
return _update_df_from_meta(df, self.array.meta, self.query.index_col)
Expand Down Expand Up @@ -663,8 +667,10 @@ def _get_empty_results(


def _update_df_from_meta(
df: DataFrame, array_meta: Metadata, index_col: Union[List[str], bool, None] = True
) -> DataFrame:
df: "pandas.DataFrame",
array_meta: Metadata,
index_col: Union[List[str], bool, None] = True,
) -> "pandas.DataFrame":
col_dtypes = {}
if "__pandas_attribute_repr" in array_meta:
attr_dtypes = json.loads(array_meta["__pandas_attribute_repr"])
Expand Down
13 changes: 13 additions & 0 deletions tiledb/tests/test_basic_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import ast
import subprocess
import sys


def test_dont_import_pandas() -> None:
    """Ensure that `import tiledb` on its own does not eagerly pull in Pandas."""
    # Spawn a pristine interpreter so the resulting module table reflects only
    # what importing tiledb itself loads, then print that table so the parent
    # process can parse it back with ast.literal_eval.
    command = (sys.executable, "-c", "import sys, tiledb; print(list(sys.modules))")
    loaded_modules = ast.literal_eval(subprocess.check_output(command).decode())
    assert "pandas" not in loaded_modules