Initial read_gbq implementation (WIP) #1
@@ -0,0 +1,16 @@
name: Linting

on:
  push:
    branches: main
  pull_request:
    branches: main

jobs:
  checks:
    name: "pre-commit hooks"
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
      - uses: pre-commit/[email protected]
@@ -0,0 +1,17 @@
repos:
  - repo: https://github.com/psf/black
    rev: 20.8b1
    hooks:
      - id: black
        language_version: python3
        exclude: versioneer.py
  - repo: https://gitlab.com/pycqa/flake8
    rev: 3.8.3
    hooks:
      - id: flake8
        language_version: python3
  - repo: https://github.com/pycqa/isort
    rev: 5.8.0
    hooks:
      - id: isort
        language_version: python3
@@ -0,0 +1 @@
from .core import read_gbq
@@ -0,0 +1,232 @@
from __future__ import annotations

import logging
from collections.abc import Iterable
from contextlib import contextmanager
from functools import partial

import dask
import dask.dataframe as dd
import pandas as pd
import pyarrow
from google.cloud import bigquery, bigquery_storage


@contextmanager
def bigquery_client(project_id=None, with_storage_api=False):
    """This context manager is a temporary solution until there is an
    upstream solution to handle this.
    See googleapis/google-cloud-python#9457
    and googleapis/gapic-generator-python#575 for reference.
    """
    bq_storage_client = None
    bq_client = bigquery.Client(project_id)
    try:
        if with_storage_api:
            bq_storage_client = bigquery_storage.BigQueryReadClient(
                credentials=bq_client._credentials
            )
            yield bq_client, bq_storage_client
        else:
            yield bq_client
    finally:
        bq_client.close()


def _stream_to_dfs(bqs_client, stream_name, schema, timeout):
    """Given a Storage API client and a stream name, yield all dataframes."""
    return [
        pyarrow.ipc.read_record_batch(
            pyarrow.py_buffer(message.arrow_record_batch.serialized_record_batch),
            schema,
        ).to_pandas()
        for message in bqs_client.read_rows(name=stream_name, offset=0, timeout=timeout)
    ]


@dask.delayed
def _read_rows_arrow(
    *,
    make_create_read_session_request: callable,
    partition_field: str = None,
    project_id: str,
    stream_name: str = None,
    timeout: int,
) -> pd.DataFrame:
    """Read a single batch of rows via BQ Storage API, in Arrow binary format.

    Args:
        project_id: BigQuery project
        make_create_read_session_request: kwargs to pass to `bqs_client.create_read_session`
            as `request`
        partition_field: BigQuery field for partitions, to be used as Dask index col for
            divisions
            NOTE: Please set if specifying `row_restriction` filters in TableReadOptions.
        stream_name: BigQuery Storage API Stream "name".
            NOTE: Please set if reading from Storage API without any `row_restriction`.
            https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream
        NOTE: `partition_field` and `stream_name` kwargs are mutually exclusive.

    Adapted from
    https://github.com/googleapis/python-bigquery-storage/blob/a0fc0af5b4447ce8b50c365d4d081b9443b8490e/google/cloud/bigquery_storage_v1/reader.py.
    """
    with bigquery_client(project_id, with_storage_api=True) as (bq_client, bqs_client):
        session = bqs_client.create_read_session(make_create_read_session_request())
        schema = pyarrow.ipc.read_schema(
            pyarrow.py_buffer(session.arrow_schema.serialized_schema)
        )

        if (partition_field is not None) and (stream_name is not None):
            raise ValueError(
                "The kwargs `partition_field` and `stream_name` are mutually exclusive."
            )

        elif partition_field is not None:
            shards = [
                df
                for stream in session.streams
                for df in _stream_to_dfs(
                    bqs_client, stream.name, schema, timeout=timeout
                )
            ]
            # NOTE: if no rows satisfy the row_restriction, then `shards` will be an empty list
            if len(shards) == 0:
                shards = [schema.empty_table().to_pandas()]
            shards = [shard.set_index(partition_field, drop=True) for shard in shards]

        elif stream_name is not None:
            shards = _stream_to_dfs(bqs_client, stream_name, schema, timeout=timeout)
            # NOTE: BQ Storage API can return empty streams
            if len(shards) == 0:
                shards = [schema.empty_table().to_pandas()]

        else:
            raise NotImplementedError(
                "Please specify either `partition_field` or `stream_name`."
            )

        return pd.concat(shards)


def read_gbq(
    project_id: str,
    dataset_id: str,
    table_id: str,
    partition_field: str = None,
    partitions: Iterable[str] = None,
    row_filter="",
    fields: list[str] = (),
    read_timeout: int = 3600,
):
    """Read a table as a dask dataframe using the BigQuery Storage API via Arrow format.

    If `partition_field` and `partitions` are specified, then the resulting dask dataframe
    will be partitioned along the same boundaries. Otherwise, partitions will be approximately
    balanced according to BigQuery stream allocation logic.

    If `partition_field` is specified but not included in `fields` (either implicitly by requesting
    all fields, or explicitly by inclusion in the list `fields`), then it will still be included
    in the query in order to have it available for dask dataframe indexing.

    Args:
        project_id: BigQuery project
        dataset_id: BigQuery dataset within project
        table_id: BigQuery table within dataset
        partition_field: to specify filters of form "WHERE {partition_field} = ..."
        partitions: all values to select of `partition_field`
        fields: names of the fields (columns) to select (default None to "SELECT *")
        read_timeout: # of seconds an individual read request has before timing out

    Returns:
        dask dataframe

    See https://github.com/dask/dask/issues/3121 for additional context.
    """
    if (partition_field is None) and (partitions is not None):
        raise ValueError("Specified `partitions` without `partition_field`.")

    # If `partition_field` is not part of the `fields` filter, fetch it anyway to be able
    # to set it as the dask dataframe index. We want this in order to have consistent:
    # BQ partitioning + dask divisions + pandas index values
    if (partition_field is not None) and fields and (partition_field not in fields):
        fields = (partition_field, *fields)
    # These read tasks seem to cause deadlocks (or at least long-stuck workers out of touch with
    # the scheduler), particularly when mixed with other tasks that execute C code. Anecdotally,
    # annotating the tasks with a higher priority seems to help (but not fully solve) the issue,
    # at the expense of higher cluster memory usage.
    with bigquery_client(project_id, with_storage_api=True) as (
        bq_client,
        bqs_client,
    ), dask.annotate(priority=1):

Review comment: This annotate is maybe a bad idea; it would be nice to have @jrbourbeau or someone weigh in. Note that we observed this behavior with now fairly old dask and bigquery_storage/pyarrow versions, so I have no idea whether it's still relevant.

Review comment: We would definitely prefer to not have this annotation if possible. Data generation tasks should be *de*-prioritized if anything.
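For reference, a minimal sketch (not part of this PR; `load_chunk` is a made-up stand-in for the per-stream read task) of what a de-prioritizing annotation would look like. Priorities default to 0 and higher values are scheduled earlier, so a negative value pushes these tasks back:

    import dask

    @dask.delayed
    def load_chunk(i):
        # stand-in for a single BigQuery stream read
        return i

    # Annotations attach to the task layers created inside the block, so these
    # four read tasks would run after other default-priority work.
    with dask.annotate(priority=-1):
        chunks = [load_chunk(i) for i in range(4)]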
        table_ref = bq_client.get_table(".".join((dataset_id, table_id)))
        if table_ref.table_type == "VIEW":
            raise TypeError("Table type VIEW not supported")

        # The protobuf types can't be pickled (may be able to tweak w/ copyreg), so instead use a
        # generator func.
        def make_create_read_session_request(row_filter=""):
            return bigquery_storage.types.CreateReadSessionRequest(
                max_stream_count=100,  # 0 -> use as many streams as BQ Storage will provide
                parent=f"projects/{project_id}",
                read_session=bigquery_storage.types.ReadSession(
                    data_format=bigquery_storage.types.DataFormat.ARROW,
                    read_options=bigquery_storage.types.ReadSession.TableReadOptions(
                        row_restriction=row_filter,
                        selected_fields=fields,
                    ),
                    table=table_ref.to_bqstorage(),
                ),
            )

        # Create a read session in order to detect the schema.
        # Read sessions are light weight and will be auto-deleted after 24 hours.
        session = bqs_client.create_read_session(
            make_create_read_session_request(row_filter=row_filter)
        )
        schema = pyarrow.ipc.read_schema(
            pyarrow.py_buffer(session.arrow_schema.serialized_schema)
        )
        meta = schema.empty_table().to_pandas()
        delayed_kwargs = dict(prefix=f"{dataset_id}.{table_id}-")

        if partition_field is not None:
            if row_filter:
                raise ValueError("Cannot pass both `partition_field` and `row_filter`")
            delayed_kwargs["meta"] = meta.set_index(partition_field, drop=True)

            if partitions is None:
                logging.info(
                    "Specified `partition_field` without `partitions`; reading full table."
                )
                partitions = [
                    p
                    for p in bq_client.list_partitions(f"{dataset_id}.{table_id}")
                    if p != "__NULL__"
                ]
                # TODO generalize to ranges (as opposed to discrete values)

            partitions = sorted(partitions)
            delayed_kwargs["divisions"] = (*partitions, partitions[-1])
Review comment: @bnaul I noticed in the example I ran that this line causes the last partition to contain only 1 element, even though that element could have fit into the previous partition. What is the reason you separate the last partition?

Review comment: Not sure why it's not working correctly for you, but the idea is that you need n+1 divisions for n partitions. Seems to work OK here.

Review comment: I wonder if it's related to how the data is originally partitioned. For example, when I read one of the tables of the covid public dataset that I copied into "my_project", I see this:

    from dask_bigquery import read_gbq

    ddf = read_gbq(
        project_id="my_project",
        dataset_id="covid19_public_forecasts",
        table_id="county_14d",
    )

    ddf.map_partitions(len).compute()

Notice the last two partitions...
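As an aside, a small illustration (not from this PR) of the "n+1 divisions for n partitions" convention discussed above: divisions are inclusive index boundaries, so the last partition value is repeated to close the final partition.

    # One dask partition per BigQuery partition value; sorted values become divisions.
    partitions = ["Loudoun", "Teton"]
    divisions = (*partitions, partitions[-1])  # ("Loudoun", "Teton", "Teton")
    # Partition 0 covers index values in ["Loudoun", "Teton"); partition 1 covers
    # ["Teton", "Teton"], i.e. exactly the rows whose index equals "Teton".
    # Hence 2 partitions require 3 division values.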
            row_filters = [
                f'{partition_field} = "{partition_value}"'
                for partition_value in partitions
            ]
            delayed_dfs = [
                _read_rows_arrow(
                    make_create_read_session_request=partial(
                        make_create_read_session_request, row_filter=row_filter
                    ),
                    partition_field=partition_field,
                    project_id=project_id,
                    timeout=read_timeout,
                )
                for row_filter in row_filters
            ]

Review comment: This is great for now, but at some point we may want to use raw task graphs. They're a bit cleaner in a few ways; Delayed is designed more for user code, and if we have the time we prefer to use raw graphs in dev code. For example, in some cases I wouldn't be surprised if each Delayed task produces a single TaskGroup, rather than having all of the tasks in a single TaskGroup. Sure, this will compute just fine, but other features (like the task group visualization, or Coiled telemetry) may be sad.

Review comment: @jrbourbeau and I gave it a try with raw task graphs. We might want to move this to a separate PR.
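For reference, a rough sketch (not part of this PR; the names here are invented, and it assumes the classic dask dataframe APIs of this era) of what a raw-graph construction could look like, with one task per Storage API stream:

    from dask.base import tokenize
    from dask.dataframe.core import new_dd_object
    from dask.highlevelgraph import HighLevelGraph

    def read_stream(stream_name):
        # stand-in for the per-stream read, e.g. the body of `_read_rows_arrow`
        ...

    def streams_to_ddf(stream_names, meta):
        # All tasks share a single name, so they land in a single TaskGroup.
        name = "read-gbq-" + tokenize(stream_names)
        dsk = {
            (name, i): (read_stream, stream_name)
            for i, stream_name in enumerate(stream_names)
        }
        graph = HighLevelGraph.from_collections(name, dsk, dependencies=[])
        divisions = (None,) * (len(stream_names) + 1)  # unknown divisions
        return new_dd_object(graph, name, meta, divisions)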
        else:
            delayed_kwargs["meta"] = meta
            delayed_dfs = [
                _read_rows_arrow(
                    make_create_read_session_request=make_create_read_session_request,
                    project_id=project_id,
                    stream_name=stream.name,
                    timeout=read_timeout,
                )
                for stream in session.streams
            ]

        return dd.from_delayed(dfs=delayed_dfs, **delayed_kwargs)
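To make the intended API concrete, a hypothetical usage sketch of the function above (the project, dataset, table, and column names here are invented):

    import dask.dataframe as dd
    from dask_bigquery import read_gbq

    # One dask partition per value of `partition_field`; that column becomes the
    # index and its sorted values define the divisions.
    ddf = read_gbq(
        project_id="my-project",
        dataset_id="my_dataset",
        table_id="my_table",
        partition_field="county_name",
        partitions=["Loudoun", "Teton"],
        fields=["county_name", "forecast_date"],
    )
    assert isinstance(ddf, dd.DataFrame)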
@@ -0,0 +1,98 @@
import random

import pandas as pd
import pytest
from dask.dataframe.utils import assert_eq
from distributed.utils_test import cluster_fixture  # noqa: F401
from distributed.utils_test import client, loop  # noqa: F401
from google.cloud import bigquery

from dask_bigquery import read_gbq

# These tests are run locally and assume the user is already authenticated.
# They also assume that the user has created a project called dask-bigquery.


@pytest.fixture
def df():
    records = [
        {
            "name": random.choice(["fred", "wilma", "barney", "betty"]),
            "number": random.randint(0, 100),
            "idx": i,
        }
        for i in range(10)
    ]

    yield pd.DataFrame(records)


@pytest.fixture
def dataset(df):
    "Push some data to BigQuery using pandas-gbq"

    with bigquery.Client() as bq_client:
        try:
            bq_client.delete_dataset(
                dataset="dask-bigquery.dataset_test",
                delete_contents=True,
            )
        except:  # noqa: E722
            pass

    # push data to gbq
    pd.DataFrame.to_gbq(
        df,
        destination_table="dataset_test.table_test",
        project_id="dask-bigquery",
        chunksize=5,
        if_exists="append",
    )
    yield "dask-bigquery.dataset_test.table_test"


# test simple read
def test_read_gbq(df, dataset, client):
    """Test simple read of data pushed to BigQuery using pandas-gbq"""
    project_id, dataset_id, table_id = dataset.split(".")

    ddf = read_gbq(project_id=project_id, dataset_id=dataset_id, table_id=table_id)

    assert ddf.columns.tolist() == ["name", "number", "idx"]
    assert len(ddf) == 10
    assert ddf.npartitions == 2
Review comment: Can we verify that the data is actually the same as the data created in the `df` fixture?

    @pytest.fixture
    def df():
        ...

    @pytest.fixture
    def dataset(df):
        ...

    def test_read_gbq(client, dataset, df):
        ddf = read_gbq(...)
        assert_eq(ddf, df)

Maybe there are sorting things that get in the way (is GBQ ordered?). If so then, as you did before:

    assert_eq(ddf.set_index("idx"), df.set_index("idx"))

In general we want to use assert_eq if possible. It runs lots of cleanliness checks on the Dask collection, graph, metadata, and so on.

Review comment: Yes, I get some ordering issues when reading back from gbq, mainly because the default index of what I read back goes from 0 to chunksize-1, where chunksize was chosen when I pushed the pandas dataframe. This was part of the reason I added "idx" as an extra column. But thanks for pointing out the

    assert_eq(ddf.set_index("idx").compute(), df.set_index("idx"))
Review comment: I get this:

____________________________________________________________________ test_read_gbq _____________________________________________________________________
df = name number idx
0 betty 71 0
1 fred 36 1
2 wilma 75 2
3 betty 13 3
4 ... 4
5 fred 74 5
6 wilma 69 6
7 fred 31 7
8 barney 31 8
9 betty 97 9
dataset = 'dask-bigquery.dataset_test.table_test', client = <Client: 'tcp://127.0.0.1:55212' processes=2 threads=2, memory=32.00 GiB>
def test_read_gbq(df, dataset, client):
"""Test simple read of data pushed to BigQuery using pandas-gbq"""
project_id, dataset_id, table_id = dataset.split(".")
ddf = read_gbq(
project_id=project_id, dataset_id=dataset_id, table_id=table_id
)
assert ddf.columns.tolist() == ["name", "number", "idx"]
assert len(ddf) == 10
assert ddf.npartitions == 2
#breakpoint()
> assert assert_eq(ddf.set_index("idx"), df.set_index("idx"))
test_core.py:67:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../../../mambaforge/envs/test_gbq/lib/python3.8/site-packages/dask/dataframe/utils.py:541: in assert_eq
assert_sane_keynames(a)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ddf = Dask DataFrame Structure:
name number
npartitions=2
0 object int64
4 ... ...
9 ... ...
Dask Name: sort_index, 22 tasks
def assert_sane_keynames(ddf):
if not hasattr(ddf, "dask"):
return
for k in ddf.dask.keys():
while isinstance(k, tuple):
k = k[0]
assert isinstance(k, (str, bytes))
assert len(k) < 100
assert " " not in k
> assert k.split("-")[0].isidentifier()
E AssertionError
../../../../../mambaforge/envs/test_gbq/lib/python3.8/site-packages/dask/dataframe/utils.py:621: AssertionError

Review comment: Oh, interesting! Do you have any additional information in the traceback? I'm wondering what ...

Review comment: Also xref dask/dask#8061
Review comment: There is no additional information in the traceback, but ...

Review comment: This is coming from the

    delayed_kwargs = dict(prefix=f"{dataset_id}.{table_id}-")

line earlier in this function. Ultimately we'll want to move away from this, but for now you can use

    delayed_kwargs = {}

instead. That should allow you to also drop the ...
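A quick illustration of why the assert_sane_keynames check trips here: the task-name prefix embeds a ".", so the part of the key name before the first "-" is not a valid Python identifier.

    key = "dataset_test.table_test-0"        # shape of the generated task-key names
    print(key.split("-")[0])                 # 'dataset_test.table_test'
    print(key.split("-")[0].isidentifier())  # False -> assert_sane_keynames raises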
    assert assert_eq(ddf.set_index("idx").compute(), df.set_index("idx"))
# test partitioned data: this test requires a copy of the public dataset
# bigquery-public-data.covid19_public_forecasts.county_14d into the
# project dask-bigquery
@pytest.mark.parametrize(
    "fields",
    ([], ["county_name"], ["county_name", "county_fips_code"]),
    ids=["no_fields", "missing_partition_field", "fields"],
)
def test_read_gbq_partitioning(fields, client):
    partitions = ["Teton", "Loudoun"]
    ddf = read_gbq(
        project_id="dask-bigquery",
        dataset_id="covid19_public_forecasts",
        table_id="county_14d",
        partition_field="county_name",
        partitions=partitions,
        fields=fields,
    )

    assert len(ddf)  # check it's not empty
    loaded = set(ddf.columns) | {ddf.index.name}

    if fields:
        assert loaded == set(fields) | {"county_name"}
    else:  # all columns loaded
        assert loaded >= set(["county_name", "county_fips_code"])

    assert ddf.npartitions == len(partitions)
    assert list(ddf.divisions) == sorted(ddf.divisions)
@@ -0,0 +1,7 @@
dask
distributed
google-cloud-bigquery
google-cloud-bigquery-storage
pandas
pandas-gbq
pyarrow
@@ -0,0 +1,4 @@
[flake8]
exclude = __init__.py
max-line-length = 120
ignore = F811
Review comment: If we use delayed (see comment below) then this name will be the name that shows up in the task stream, progress bars, etc. We may want to make it more clearly GBQ-related, like bigquery_read.