From 21808367d02b5b7fcf35b3c7520224c819879aec Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 22 Nov 2021 09:28:01 -0600 Subject: [PATCH] fix: `to_gbq` allows strings for DATE and floats for NUMERIC with `api_method="load_parquet"` (#423) deps: require pandas 0.24+ and db-dtypes for TIME/DATE extension dtypes (#423) --- .circleci/config.yml | 2 +- .coveragerc | 2 +- ....2.conda => requirements-3.7-0.24.2.conda} | 1 + ci/requirements-3.9-NIGHTLY.conda | 1 + noxfile.py | 8 +- owlbot.py | 6 +- pandas_gbq/load.py | 52 +++++++ setup.py | 4 +- testing/constraints-3.7.txt | 3 +- tests/system/test_gbq.py | 21 +-- tests/system/test_to_gbq.py | 139 ++++++++++++------ tests/unit/test_load.py | 120 ++++++++++++++- 12 files changed, 279 insertions(+), 80 deletions(-) rename ci/{requirements-3.7-0.23.2.conda => requirements-3.7-0.24.2.conda} (89%) diff --git a/.circleci/config.yml b/.circleci/config.yml index ec4d7448..4c378b3f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -10,7 +10,7 @@ jobs: - image: continuumio/miniconda3 environment: PYTHON: "3.7" - PANDAS: "0.23.2" + PANDAS: "0.24.2" steps: - checkout - run: ci/config_auth.sh diff --git a/.coveragerc b/.coveragerc index 61285af5..ba50bf32 100644 --- a/.coveragerc +++ b/.coveragerc @@ -22,7 +22,7 @@ omit = google/cloud/__init__.py [report] -fail_under = 86 +fail_under = 88 show_missing = True exclude_lines = # Re-enable the standard pragma diff --git a/ci/requirements-3.7-0.23.2.conda b/ci/requirements-3.7-0.24.2.conda similarity index 89% rename from ci/requirements-3.7-0.23.2.conda rename to ci/requirements-3.7-0.24.2.conda index 1da6d226..82f4e7b9 100644 --- a/ci/requirements-3.7-0.23.2.conda +++ b/ci/requirements-3.7-0.24.2.conda @@ -1,5 +1,6 @@ codecov coverage +db-dtypes==0.3.0 fastavro flake8 numpy==1.16.6 diff --git a/ci/requirements-3.9-NIGHTLY.conda b/ci/requirements-3.9-NIGHTLY.conda index ccaa87e5..5a3e9fb7 100644 --- a/ci/requirements-3.9-NIGHTLY.conda +++ b/ci/requirements-3.9-NIGHTLY.conda @@ -1,3 +1,4 @@ +db-dtypes pydata-google-auth google-cloud-bigquery google-cloud-bigquery-storage diff --git a/noxfile.py b/noxfile.py index 825daf18..2feeccdc 100644 --- a/noxfile.py +++ b/noxfile.py @@ -146,11 +146,7 @@ def system(session): # Install all test dependencies, then install this package into the # virtualenv's dist-packages. session.install("mock", "pytest", "google-cloud-testutils", "-c", constraints_path) - if session.python == "3.9": - extras = "[tqdm,db-dtypes]" - else: - extras = "[tqdm]" - session.install("-e", f".{extras}", "-c", constraints_path) + session.install("-e", ".[tqdm]", "-c", constraints_path) # Run py.test against the system tests. if system_test_exists: @@ -179,7 +175,7 @@ def cover(session): test runs (not system test runs), and then erases coverage data. 
""" session.install("coverage", "pytest-cov") - session.run("coverage", "report", "--show-missing", "--fail-under=86") + session.run("coverage", "report", "--show-missing", "--fail-under=88") session.run("coverage", "erase") diff --git a/owlbot.py b/owlbot.py index 71679dd4..c69d54de 100644 --- a/owlbot.py +++ b/owlbot.py @@ -29,16 +29,12 @@ # ---------------------------------------------------------------------------- extras = ["tqdm"] -extras_by_python = { - "3.9": ["tqdm", "db-dtypes"], -} templated_files = common.py_library( unit_test_python_versions=["3.7", "3.8", "3.9", "3.10"], system_test_python_versions=["3.7", "3.8", "3.9", "3.10"], - cov_level=86, + cov_level=88, unit_test_extras=extras, system_test_extras=extras, - system_test_extras_by_python=extras_by_python, intersphinx_dependencies={ "pandas": "https://pandas.pydata.org/pandas-docs/stable/", "pydata-google-auth": "https://pydata-google-auth.readthedocs.io/en/latest/", diff --git a/pandas_gbq/load.py b/pandas_gbq/load.py index 69210e41..5422402e 100644 --- a/pandas_gbq/load.py +++ b/pandas_gbq/load.py @@ -4,9 +4,11 @@ """Helper methods for loading data into BigQuery""" +import decimal import io from typing import Any, Callable, Dict, List, Optional +import db_dtypes import pandas import pyarrow.lib from google.cloud import bigquery @@ -56,6 +58,55 @@ def split_dataframe(dataframe, chunksize=None): yield remaining_rows, chunk +def cast_dataframe_for_parquet( + dataframe: pandas.DataFrame, schema: Optional[Dict[str, Any]], +) -> pandas.DataFrame: + """Cast columns to needed dtype when writing parquet files. + + See: https://github.com/googleapis/python-bigquery-pandas/issues/421 + """ + + columns = schema.get("fields", []) + + # Protect against an explicit None in the dictionary. + columns = columns if columns is not None else [] + + for column in columns: + # Schema can be a superset of the columns in the dataframe, so ignore + # columns that aren't present. + column_name = column.get("name") + if column_name not in dataframe.columns: + continue + + # Skip array columns for now. Potentially casting the elements of the + # array would be possible, but not worth the effort until there is + # demand for it. + if column.get("mode", "NULLABLE").upper() == "REPEATED": + continue + + column_type = column.get("type", "").upper() + if ( + column_type == "DATE" + # Use extension dtype first so that it uses the correct equality operator. + and db_dtypes.DateDtype() != dataframe[column_name].dtype + ): + # Construct converted column manually, because I can't use + # .astype() with DateDtype. 
With .astype(), I get the error: + # + # TypeError: Cannot interpret '' as a data type + cast_column = pandas.Series( + dataframe[column_name], dtype=db_dtypes.DateDtype() + ) + elif column_type in {"NUMERIC", "DECIMAL", "BIGNUMERIC", "BIGDECIMAL"}: + cast_column = dataframe[column_name].map(decimal.Decimal) + else: + cast_column = None + + if cast_column is not None: + dataframe = dataframe.assign(**{column_name: cast_column}) + return dataframe + + def load_parquet( client: bigquery.Client, dataframe: pandas.DataFrame, @@ -70,6 +121,7 @@ def load_parquet( if schema is not None: schema = pandas_gbq.schema.remove_policy_tags(schema) job_config.schema = pandas_gbq.schema.to_google_cloud_bigquery(schema) + dataframe = cast_dataframe_for_parquet(dataframe, schema) try: client.load_table_from_dataframe( diff --git a/setup.py b/setup.py index 876bd4c0..28c81eee 100644 --- a/setup.py +++ b/setup.py @@ -23,8 +23,9 @@ release_status = "Development Status :: 4 - Beta" dependencies = [ "setuptools", + "db-dtypes >=0.3.0,<2.0.0", "numpy>=1.16.6", - "pandas>=0.23.2", + "pandas>=0.24.2", "pyarrow >=3.0.0, <7.0dev", "pydata-google-auth", "google-auth", @@ -35,7 +36,6 @@ ] extras = { "tqdm": "tqdm>=4.23.0", - "db-dtypes": "db-dtypes >=0.3.0,<2.0.0", } # Setup boilerplate below this line. diff --git a/testing/constraints-3.7.txt b/testing/constraints-3.7.txt index 7c67d275..7920656a 100644 --- a/testing/constraints-3.7.txt +++ b/testing/constraints-3.7.txt @@ -5,12 +5,13 @@ # # e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev", # Then this file should have foo==1.14.0 +db-dtypes==0.3.0 google-auth==1.4.1 google-auth-oauthlib==0.0.1 google-cloud-bigquery==1.11.1 google-cloud-bigquery-storage==1.1.0 numpy==1.16.6 -pandas==0.23.2 +pandas==0.24.2 pyarrow==3.0.0 pydata-google-auth==0.1.2 tqdm==4.23.0 diff --git a/tests/system/test_gbq.py b/tests/system/test_gbq.py index a8d6bd0d..f268a85d 100644 --- a/tests/system/test_gbq.py +++ b/tests/system/test_gbq.py @@ -26,8 +26,6 @@ TABLE_ID = "new_test" PANDAS_VERSION = pkg_resources.parse_version(pandas.__version__) -NULLABLE_INT_PANDAS_VERSION = pkg_resources.parse_version("0.24.0") -NULLABLE_INT_MESSAGE = "Require pandas 0.24+ in order to use nullable integer type." 
def test_imports(): @@ -173,9 +171,6 @@ def test_should_properly_handle_valid_integers(self, project_id): tm.assert_frame_equal(df, DataFrame({"valid_integer": [3]})) def test_should_properly_handle_nullable_integers(self, project_id): - if PANDAS_VERSION < NULLABLE_INT_PANDAS_VERSION: - pytest.skip(msg=NULLABLE_INT_MESSAGE) - query = """SELECT * FROM UNNEST([1, NULL]) AS nullable_integer """ @@ -188,9 +183,7 @@ def test_should_properly_handle_nullable_integers(self, project_id): ) tm.assert_frame_equal( df, - DataFrame( - {"nullable_integer": pandas.Series([1, pandas.NA], dtype="Int64")} - ), + DataFrame({"nullable_integer": pandas.Series([1, None], dtype="Int64")}), ) def test_should_properly_handle_valid_longs(self, project_id): @@ -204,9 +197,6 @@ def test_should_properly_handle_valid_longs(self, project_id): tm.assert_frame_equal(df, DataFrame({"valid_long": [1 << 62]})) def test_should_properly_handle_nullable_longs(self, project_id): - if PANDAS_VERSION < NULLABLE_INT_PANDAS_VERSION: - pytest.skip(msg=NULLABLE_INT_MESSAGE) - query = """SELECT * FROM UNNEST([1 << 62, NULL]) AS nullable_long """ @@ -219,15 +209,10 @@ def test_should_properly_handle_nullable_longs(self, project_id): ) tm.assert_frame_equal( df, - DataFrame( - {"nullable_long": pandas.Series([1 << 62, pandas.NA], dtype="Int64")} - ), + DataFrame({"nullable_long": pandas.Series([1 << 62, None], dtype="Int64")}), ) def test_should_properly_handle_null_integers(self, project_id): - if PANDAS_VERSION < NULLABLE_INT_PANDAS_VERSION: - pytest.skip(msg=NULLABLE_INT_MESSAGE) - query = "SELECT CAST(NULL AS INT64) AS null_integer" df = gbq.read_gbq( query, @@ -237,7 +222,7 @@ def test_should_properly_handle_null_integers(self, project_id): dtypes={"null_integer": "Int64"}, ) tm.assert_frame_equal( - df, DataFrame({"null_integer": pandas.Series([pandas.NA], dtype="Int64")}), + df, DataFrame({"null_integer": pandas.Series([None], dtype="Int64")}), ) def test_should_properly_handle_valid_floats(self, project_id): diff --git a/tests/system/test_to_gbq.py b/tests/system/test_to_gbq.py index 4f315a77..4421f3be 100644 --- a/tests/system/test_to_gbq.py +++ b/tests/system/test_to_gbq.py @@ -2,23 +2,22 @@ # Use of this source code is governed by a BSD-style # license that can be found in the LICENSE file. +import datetime +import decimal +import collections import functools import random +import db_dtypes import pandas import pandas.testing import pytest -try: - import db_dtypes -except ImportError: - db_dtypes = None - pytest.importorskip("google.cloud.bigquery", minversion="1.24.0") -@pytest.fixture(params=["default", "load_parquet", "load_csv"]) +@pytest.fixture(params=["load_parquet", "load_csv"]) def api_method(request): return request.param @@ -32,13 +31,20 @@ def method_under_test(credentials, project_id): ) +SeriesRoundTripTestCase = collections.namedtuple( + "SeriesRoundTripTestCase", + ["input_series", "api_methods"], + defaults=[None, {"load_csv", "load_parquet"}], +) + + @pytest.mark.parametrize( - ["input_series", "skip_csv"], + ["input_series", "api_methods"], [ # Ensure that 64-bit floating point numbers are unchanged. 
# See: https://github.com/pydata/pandas-gbq/issues/326 - ( - pandas.Series( + SeriesRoundTripTestCase( + input_series=pandas.Series( [ 0.14285714285714285, 0.4406779661016949, @@ -51,10 +57,9 @@ def method_under_test(credentials, project_id): ], name="test_col", ), - False, ), - ( - pandas.Series( + SeriesRoundTripTestCase( + input_series=pandas.Series( [ "abc", "defg", @@ -66,10 +71,9 @@ def method_under_test(credentials, project_id): ], name="test_col", ), - False, ), - ( - pandas.Series( + SeriesRoundTripTestCase( + input_series=pandas.Series( [ "abc", "defg", @@ -81,7 +85,13 @@ def method_under_test(credentials, project_id): ], name="empty_strings", ), - True, + # BigQuery CSV loader uses empty string as the "null marker" by + # default. Potentially one could choose a rarely used character or + # string as the null marker to disambiguate null from empty string, + # but then that string couldn't be loaded. + # TODO: Revist when custom load job configuration is supported. + # https://github.com/googleapis/python-bigquery-pandas/issues/425 + api_methods={"load_parquet"}, ), ], ) @@ -91,10 +101,10 @@ def test_series_round_trip( bigquery_client, input_series, api_method, - skip_csv, + api_methods, ): - if api_method == "load_csv" and skip_csv: - pytest.skip("Loading with CSV not supported.") + if api_method not in api_methods: + pytest.skip(f"{api_method} not supported.") table_id = f"{random_dataset_id}.round_trip_{random.randrange(1_000_000)}" input_series = input_series.sort_values().reset_index(drop=True) df = pandas.DataFrame( @@ -111,60 +121,99 @@ def test_series_round_trip( ) +DataFrameRoundTripTestCase = collections.namedtuple( + "DataFrameRoundTripTestCase", + ["input_df", "expected_df", "table_schema", "api_methods"], + defaults=[None, None, [], {"load_csv", "load_parquet"}], +) + DATAFRAME_ROUND_TRIPS = [ # Ensure that a DATE column can be written with datetime64[ns] dtype # data. See: # https://github.com/googleapis/python-bigquery-pandas/issues/362 - ( - pandas.DataFrame( + DataFrameRoundTripTestCase( + input_df=pandas.DataFrame( { + "row_num": [0, 1, 2], "date_col": pandas.Series( ["2021-04-17", "1999-12-31", "2038-01-19"], dtype="datetime64[ns]", ), } ), - [{"name": "date_col", "type": "DATE"}], - True, + table_schema=[{"name": "date_col", "type": "DATE"}], + # Skip CSV because the pandas CSV writer includes time when writing + # datetime64 values. + api_methods={"load_parquet"}, + ), + DataFrameRoundTripTestCase( + input_df=pandas.DataFrame( + { + "row_num": [0, 1, 2], + "date_col": pandas.Series( + ["2021-04-17", "1999-12-31", "2038-01-19"], + dtype=db_dtypes.DateDtype(), + ), + } + ), + table_schema=[{"name": "date_col", "type": "DATE"}], + ), + # Loading a DATE column should work for string objects. See: + # https://github.com/googleapis/python-bigquery-pandas/issues/421 + DataFrameRoundTripTestCase( + input_df=pandas.DataFrame( + {"row_num": [123], "date_col": ["2021-12-12"]}, + columns=["row_num", "date_col"], + ), + expected_df=pandas.DataFrame( + {"row_num": [123], "date_col": [datetime.date(2021, 12, 12)]}, + columns=["row_num", "date_col"], + ), + table_schema=[ + {"name": "row_num", "type": "INTEGER"}, + {"name": "date_col", "type": "DATE"}, + ], + ), + # Loading a NUMERIC column should work for floating point objects. 
See: + # https://github.com/googleapis/python-bigquery-pandas/issues/421 + DataFrameRoundTripTestCase( + input_df=pandas.DataFrame( + {"row_num": [123], "num_col": [1.25]}, columns=["row_num", "num_col"], + ), + expected_df=pandas.DataFrame( + {"row_num": [123], "num_col": [decimal.Decimal("1.25")]}, + columns=["row_num", "num_col"], + ), + table_schema=[ + {"name": "row_num", "type": "INTEGER"}, + {"name": "num_col", "type": "NUMERIC"}, + ], ), ] -if db_dtypes is not None: - DATAFRAME_ROUND_TRIPS.append( - ( - pandas.DataFrame( - { - "date_col": pandas.Series( - ["2021-04-17", "1999-12-31", "2038-01-19"], dtype="dbdate", - ), - } - ), - [{"name": "date_col", "type": "DATE"}], - False, - ) - ) @pytest.mark.parametrize( - ["input_df", "table_schema", "skip_csv"], DATAFRAME_ROUND_TRIPS + ["input_df", "expected_df", "table_schema", "api_methods"], DATAFRAME_ROUND_TRIPS ) def test_dataframe_round_trip_with_table_schema( method_under_test, random_dataset_id, bigquery_client, input_df, + expected_df, table_schema, api_method, - skip_csv, + api_methods, ): - if api_method == "load_csv" and skip_csv: - pytest.skip("Loading with CSV not supported.") + if api_method not in api_methods: + pytest.skip(f"{api_method} not supported.") + if expected_df is None: + expected_df = input_df table_id = f"{random_dataset_id}.round_trip_w_schema_{random.randrange(1_000_000)}" - input_df["row_num"] = input_df.index - input_df.sort_values("row_num", inplace=True) method_under_test( input_df, table_id, table_schema=table_schema, api_method=api_method ) round_trip = bigquery_client.list_rows(table_id).to_dataframe( - dtypes=dict(zip(input_df.columns, input_df.dtypes)) + dtypes=dict(zip(expected_df.columns, expected_df.dtypes)) ) round_trip.sort_values("row_num", inplace=True) - pandas.testing.assert_frame_equal(input_df, round_trip) + pandas.testing.assert_frame_equal(expected_df, round_trip) diff --git a/tests/unit/test_load.py b/tests/unit/test_load.py index a32d2d9e..8e18cfb9 100644 --- a/tests/unit/test_load.py +++ b/tests/unit/test_load.py @@ -4,12 +4,16 @@ # -*- coding: utf-8 -*- -import textwrap +import datetime +import decimal from io import StringIO +import textwrap from unittest import mock +import db_dtypes import numpy import pandas +import pandas.testing import pytest from pandas_gbq.features import FEATURES @@ -137,3 +141,117 @@ def test_load_chunks_omits_policy_tags( def test_load_chunks_with_invalid_api_method(): with pytest.raises(ValueError, match="Got unexpected api_method:"): load.load_chunks(None, None, None, api_method="not_a_thing") + + +@pytest.mark.parametrize( + ("numeric_type",), + ( + ("NUMERIC",), + ("DECIMAL",), + ("BIGNUMERIC",), + ("BIGDECIMAL",), + ("numeric",), + ("decimal",), + ("bignumeric",), + ("bigdecimal",), + ), +) +def test_cast_dataframe_for_parquet_w_float_numeric(numeric_type): + dataframe = pandas.DataFrame( + { + "row_num": [0, 1, 2], + "num_col": pandas.Series( + # Very much not recommend as the whole point of NUMERIC is to + # be more accurate than a floating point number, but tested to + # keep compatibility with CSV-based uploads. See: + # https://github.com/googleapis/python-bigquery-pandas/issues/421 + [1.25, -1.25, 42.5], + dtype="float64", + ), + "row_num_2": [0, 1, 2], + }, + # Use multiple columns to ensure column order is maintained. 
+ columns=["row_num", "num_col", "row_num_2"], + ) + schema = { + "fields": [ + {"name": "num_col", "type": numeric_type}, + {"name": "not_in_df", "type": "IGNORED"}, + ] + } + result = load.cast_dataframe_for_parquet(dataframe, schema) + expected = pandas.DataFrame( + { + "row_num": [0, 1, 2], + "num_col": pandas.Series( + [decimal.Decimal(1.25), decimal.Decimal(-1.25), decimal.Decimal(42.5)], + dtype="object", + ), + "row_num_2": [0, 1, 2], + }, + columns=["row_num", "num_col", "row_num_2"], + ) + pandas.testing.assert_frame_equal(result, expected) + + +def test_cast_dataframe_for_parquet_w_string_date(): + dataframe = pandas.DataFrame( + { + "row_num": [0, 1, 2], + "date_col": pandas.Series( + ["2021-04-17", "1999-12-31", "2038-01-19"], dtype="object", + ), + "row_num_2": [0, 1, 2], + }, + # Use multiple columns to ensure column order is maintained. + columns=["row_num", "date_col", "row_num_2"], + ) + schema = { + "fields": [ + {"name": "date_col", "type": "DATE"}, + {"name": "not_in_df", "type": "IGNORED"}, + ] + } + result = load.cast_dataframe_for_parquet(dataframe, schema) + expected = pandas.DataFrame( + { + "row_num": [0, 1, 2], + "date_col": pandas.Series( + ["2021-04-17", "1999-12-31", "2038-01-19"], dtype=db_dtypes.DateDtype(), + ), + "row_num_2": [0, 1, 2], + }, + columns=["row_num", "date_col", "row_num_2"], + ) + pandas.testing.assert_frame_equal(result, expected) + + +def test_cast_dataframe_for_parquet_ignores_repeated_fields(): + dataframe = pandas.DataFrame( + { + "row_num": [0, 1, 2], + "repeated_col": pandas.Series( + [ + [datetime.date(2021, 4, 17)], + [datetime.date(199, 12, 31)], + [datetime.date(2038, 1, 19)], + ], + dtype="object", + ), + "row_num_2": [0, 1, 2], + }, + # Use multiple columns to ensure column order is maintained. + columns=["row_num", "repeated_col", "row_num_2"], + ) + expected = dataframe.copy() + schema = {"fields": [{"name": "repeated_col", "type": "DATE", "mode": "REPEATED"}]} + result = load.cast_dataframe_for_parquet(dataframe, schema) + pandas.testing.assert_frame_equal(result, expected) + + +def test_cast_dataframe_for_parquet_w_null_fields(): + dataframe = pandas.DataFrame({"int_col": [0, 1, 2], "str_col": ["a", "b", "c"]}) + expected = dataframe.copy() + schema = {"fields": None} + result = load.cast_dataframe_for_parquet(dataframe, schema) + pandas.testing.assert_frame_equal(result, expected)