From f4d1db0e1bad81c1f0fc565bfbc38b1d5be9bdae Mon Sep 17 00:00:00 2001
From: Peter Lamut <inbox@peterlamut.com>
Date: Wed, 22 Jan 2020 13:49:18 +0000
Subject: [PATCH] test(bigquery): add tests for concatenating categorical
 columns

---
 bigquery/tests/unit/test_table.py | 168 ++++++++++++++++++++++++++++++
 1 file changed, 168 insertions(+)

diff --git a/bigquery/tests/unit/test_table.py b/bigquery/tests/unit/test_table.py
index 6e8958cdc46c..079ec6e000d3 100644
--- a/bigquery/tests/unit/test_table.py
+++ b/bigquery/tests/unit/test_table.py
@@ -3242,6 +3242,174 @@ def test_to_dataframe_w_bqstorage_snapshot(self):
         with pytest.raises(ValueError):
             row_iterator.to_dataframe(bqstorage_client)
 
+    @unittest.skipIf(pandas is None, "Requires `pandas`")
+    @unittest.skipIf(
+        bigquery_storage_v1beta1 is None, "Requires `google-cloud-bigquery-storage`"
+    )
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    def test_to_dataframe_concat_categorical_dtype_w_pyarrow(self):
+        from google.cloud.bigquery import schema
+        from google.cloud.bigquery import table as mut
+        from google.cloud.bigquery_storage_v1beta1 import reader
+
+        arrow_fields = [
+            # Not alphabetical to test column order.
+            pyarrow.field("col_str", pyarrow.utf8()),
+            # The backend returns strings, and without other info, pyarrow contains
+            # string data in categorical columns, too (and not maybe the Dictionary
+            # type that corresponds to pandas.Categorical).
+            pyarrow.field("col_category", pyarrow.utf8()),
+        ]
+        arrow_schema = pyarrow.schema(arrow_fields)
+
+        # create a mock BQ storage client
+        bqstorage_client = mock.create_autospec(
+            bigquery_storage_v1beta1.BigQueryStorageClient
+        )
+        bqstorage_client.transport = mock.create_autospec(
+            big_query_storage_grpc_transport.BigQueryStorageGrpcTransport
+        )
+        session = bigquery_storage_v1beta1.types.ReadSession(
+            streams=[{"name": "/projects/proj/dataset/dset/tables/tbl/streams/1234"}],
+            arrow_schema={"serialized_schema": arrow_schema.serialize().to_pybytes()},
+        )
+        bqstorage_client.create_read_session.return_value = session
+
+        mock_rowstream = mock.create_autospec(reader.ReadRowsStream)
+        bqstorage_client.read_rows.return_value = mock_rowstream
+
+        # prepare the iterator over mocked rows
+        mock_rows = mock.create_autospec(reader.ReadRowsIterable)
+        mock_rowstream.rows.return_value = mock_rows
+        page_items = [
+            [
+                pyarrow.array(["foo", "bar", "baz"]),  # col_str
+                pyarrow.array(["low", "medium", "low"]),  # col_category
+            ],
+            [
+                pyarrow.array(["foo_page2", "bar_page2", "baz_page2"]),  # col_str
+                pyarrow.array(["medium", "high", "low"]),  # col_category
+            ],
+        ]
+
+        mock_pages = []
+
+        for record_list in page_items:
+            page_record_batch = pyarrow.RecordBatch.from_arrays(
+                record_list, schema=arrow_schema
+            )
+            mock_page = mock.create_autospec(reader.ReadRowsPage)
+            mock_page.to_arrow.return_value = page_record_batch
+            mock_pages.append(mock_page)
+
+        type(mock_rows).pages = mock.PropertyMock(return_value=mock_pages)
+
+        schema = [
+            schema.SchemaField("col_str", "IGNORED"),
+            schema.SchemaField("col_category", "IGNORED"),
+        ]
+
+        row_iterator = mut.RowIterator(
+            _mock_client(),
+            None,  # api_request: ignored
+            None,  # path: ignored
+            schema,
+            table=mut.TableReference.from_string("proj.dset.tbl"),
+            selected_fields=schema,
+        )
+
+        # run the method under test
+        got = row_iterator.to_dataframe(
+            bqstorage_client=bqstorage_client,
+            dtypes={
+                "col_category": pandas.core.dtypes.dtypes.CategoricalDtype(
+                    categories=["low", "medium", "high"], ordered=False,
+                ),
+            },
+        )
+
+        # Are the columns in the expected order?
+        column_names = ["col_str", "col_category"]
+        self.assertEqual(list(got), column_names)
+
+        # Have expected number of rows?
+        total_pages = len(mock_pages)  # we have a single stream, thus these two equal
+        total_rows = len(page_items[0][0]) * total_pages
+        self.assertEqual(len(got.index), total_rows)
+
+        # Are column types correct?
+        expected_dtypes = [
+            pandas.core.dtypes.dtypes.np.dtype("O"),  # the default for string data
+            pandas.core.dtypes.dtypes.CategoricalDtype(
+                categories=["low", "medium", "high"], ordered=False,
+            ),
+        ]
+        self.assertEqual(list(got.dtypes), expected_dtypes)
+
+        # And the data in the categorical column?
+        self.assertEqual(
+            list(got["col_category"]),
+            ["low", "medium", "low", "medium", "high", "low"],
+        )
+
+        # Don't close the client if it was passed in.
+        bqstorage_client.transport.channel.close.assert_not_called()
+
+    @unittest.skipIf(pandas is None, "Requires `pandas`")
+    def test_to_dataframe_concat_categorical_dtype_wo_pyarrow(self):
+        from google.cloud.bigquery.schema import SchemaField
+
+        schema = [
+            SchemaField("col_str", "STRING"),
+            SchemaField("col_category", "STRING"),
+        ]
+        row_data = [
+            [u"foo", u"low"],
+            [u"bar", u"medium"],
+            [u"baz", u"low"],
+            [u"foo_page2", u"medium"],
+            [u"bar_page2", u"high"],
+            [u"baz_page2", u"low"],
+        ]
+        path = "/foo"
+
+        rows = [{"f": [{"v": field} for field in row]} for row in row_data[:3]]
+        rows_page2 = [{"f": [{"v": field} for field in row]} for row in row_data[3:]]
+        api_request = mock.Mock(
+            side_effect=[{"rows": rows, "pageToken": "NEXTPAGE"}, {"rows": rows_page2}]
+        )
+
+        row_iterator = self._make_one(_mock_client(), api_request, path, schema)
+
+        with mock.patch("google.cloud.bigquery.table.pyarrow", None):
+            got = row_iterator.to_dataframe(
+                dtypes={
+                    "col_category": pandas.core.dtypes.dtypes.CategoricalDtype(
+                        categories=["low", "medium", "high"], ordered=False,
+                    ),
+                },
+            )
+
+        self.assertIsInstance(got, pandas.DataFrame)
+        self.assertEqual(len(got), 6)  # verify the number of rows
+        expected_columns = [field.name for field in schema]
+        self.assertEqual(list(got), expected_columns)  # verify the column names
+
+        # Are column types correct?
+        expected_dtypes = [
+            pandas.core.dtypes.dtypes.np.dtype("O"),  # the default for string data
+            pandas.core.dtypes.dtypes.CategoricalDtype(
+                categories=["low", "medium", "high"], ordered=False,
+            ),
+        ]
+        self.assertEqual(list(got.dtypes), expected_dtypes)
+
+        # And the data in the categorical column?
+        self.assertEqual(
+            list(got["col_category"]),
+            ["low", "medium", "low", "medium", "high", "low"],
+        )
+
 
 class TestPartitionRange(unittest.TestCase):
     def _get_target_class(self):