diff --git a/pandas_gbq/gbq.py b/pandas_gbq/gbq.py
index 0edac95d..1157c37b 100644
--- a/pandas_gbq/gbq.py
+++ b/pandas_gbq/gbq.py
@@ -549,6 +549,7 @@ def load_data(
         schema=None,
         progress_bar=True,
         api_method: str = "load_parquet",
+        billing_project: Optional[str] = None,
     ):
         from pandas_gbq import load

@@ -563,6 +564,7 @@ def load_data(
                 schema=schema,
                 location=self.location,
                 api_method=api_method,
+                billing_project=billing_project,
             )
             if progress_bar and tqdm:
                 chunks = tqdm.tqdm(chunks)
@@ -575,8 +577,8 @@ def load_data(
         except self.http_error as ex:
             self.process_http_error(ex)

-    def delete_and_recreate_table(self, dataset_id, table_id, table_schema):
-        table = _Table(self.project_id, dataset_id, credentials=self.credentials)
+    def delete_and_recreate_table(self, project_id, dataset_id, table_id, table_schema):
+        table = _Table(project_id, dataset_id, credentials=self.credentials)
         table.delete(table_id)
         table.create(table_id, table_schema)

@@ -1113,7 +1115,9 @@ def to_gbq(
                 "'append' or 'replace' data."
             )
         elif if_exists == "replace":
-            connector.delete_and_recreate_table(dataset_id, table_id, table_schema)
+            connector.delete_and_recreate_table(
+                project_id_table, dataset_id, table_id, table_schema
+            )
         else:
             if not pandas_gbq.schema.schema_is_subset(original_schema, table_schema):
                 raise InvalidSchema(
@@ -1142,6 +1146,7 @@ def to_gbq(
         schema=table_schema,
         progress_bar=progress_bar,
         api_method=api_method,
+        billing_project=project_id,
     )


diff --git a/pandas_gbq/load.py b/pandas_gbq/load.py
index 588a6719..e52952f2 100644
--- a/pandas_gbq/load.py
+++ b/pandas_gbq/load.py
@@ -114,6 +114,7 @@ def load_parquet(
     destination_table_ref: bigquery.TableReference,
     location: Optional[str],
     schema: Optional[Dict[str, Any]],
+    billing_project: Optional[str] = None,
 ):
     job_config = bigquery.LoadJobConfig()
     job_config.write_disposition = "WRITE_APPEND"
@@ -126,7 +127,11 @@

     try:
         client.load_table_from_dataframe(
-            dataframe, destination_table_ref, job_config=job_config, location=location,
+            dataframe,
+            destination_table_ref,
+            job_config=job_config,
+            location=location,
+            project=billing_project,
         ).result()
     except pyarrow.lib.ArrowInvalid as exc:
         raise exceptions.ConversionError(
@@ -162,6 +167,7 @@ def load_csv_from_dataframe(
     destination_table_ref: bigquery.TableReference,
     location: Optional[str],
     chunksize: Optional[int],
     schema: Optional[Dict[str, Any]],
+    billing_project: Optional[str] = None,
 ):
     bq_schema = None
@@ -171,7 +177,11 @@

     def load_chunk(chunk, job_config):
         client.load_table_from_dataframe(
-            chunk, destination_table_ref, job_config=job_config, location=location,
+            chunk,
+            destination_table_ref,
+            job_config=job_config,
+            location=location,
+            project=billing_project,
         ).result()

     return load_csv(dataframe, chunksize, bq_schema, load_chunk)
@@ -184,6 +194,7 @@ def load_csv_from_file(
     location: Optional[str],
     chunksize: Optional[int],
     schema: Optional[Dict[str, Any]],
+    billing_project: Optional[str] = None,
 ):
     """Manually encode a DataFrame to CSV and use the buffer in a load job.

@@ -204,6 +215,7 @@ def load_chunk(chunk, job_config):
                 destination_table_ref,
                 job_config=job_config,
                 location=location,
+                project=billing_project,
             ).result()
         finally:
             chunk_buffer.close()
@@ -219,19 +231,39 @@ def load_chunks(
     schema=None,
     location=None,
     api_method="load_parquet",
+    billing_project: Optional[str] = None,
 ):
     if api_method == "load_parquet":
-        load_parquet(client, dataframe, destination_table_ref, location, schema)
+        load_parquet(
+            client,
+            dataframe,
+            destination_table_ref,
+            location,
+            schema,
+            billing_project=billing_project,
+        )
         # TODO: yield progress depending on result() with timeout
         return [0]
     elif api_method == "load_csv":
         if FEATURES.bigquery_has_from_dataframe_with_csv:
             return load_csv_from_dataframe(
-                client, dataframe, destination_table_ref, location, chunksize, schema
+                client,
+                dataframe,
+                destination_table_ref,
+                location,
+                chunksize,
+                schema,
+                billing_project=billing_project,
             )
         else:
             return load_csv_from_file(
-                client, dataframe, destination_table_ref, location, chunksize, schema
+                client,
+                dataframe,
+                destination_table_ref,
+                location,
+                chunksize,
+                schema,
+                billing_project=billing_project,
             )
     else:
         raise ValueError(
diff --git a/tests/unit/test_to_gbq.py b/tests/unit/test_to_gbq.py
index 22c542f1..a2fa800c 100644
--- a/tests/unit/test_to_gbq.py
+++ b/tests/unit/test_to_gbq.py
@@ -131,6 +131,46 @@ def test_to_gbq_with_if_exists_replace(mock_bigquery_client):
     assert mock_bigquery_client.create_table.called


+def test_to_gbq_with_if_exists_replace_cross_project(
+    mock_bigquery_client, expected_load_method
+):
+    mock_bigquery_client.get_table.side_effect = (
+        # Initial check
+        google.cloud.bigquery.Table("data-project.my_dataset.my_table"),
+        # Recreate check
+        google.api_core.exceptions.NotFound("my_table"),
+    )
+    gbq.to_gbq(
+        DataFrame([[1]]),
+        "data-project.my_dataset.my_table",
+        project_id="billing-project",
+        if_exists="replace",
+    )
+    # TODO: We can avoid these API calls by using write disposition in the load
+    # job. See: https://github.com/googleapis/python-bigquery-pandas/issues/118
+    assert mock_bigquery_client.delete_table.called
+    args, _ = mock_bigquery_client.delete_table.call_args
+    table_delete: google.cloud.bigquery.TableReference = args[0]
+    assert table_delete.project == "data-project"
+    assert table_delete.dataset_id == "my_dataset"
+    assert table_delete.table_id == "my_table"
+    assert mock_bigquery_client.create_table.called
+    args, _ = mock_bigquery_client.create_table.call_args
+    table_create: google.cloud.bigquery.TableReference = args[0]
+    assert table_create.project == "data-project"
+    assert table_create.dataset_id == "my_dataset"
+    assert table_create.table_id == "my_table"
+
+    # Check that billing project and destination table are set correctly.
+    expected_load_method.assert_called_once()
+    load_args, load_kwargs = expected_load_method.call_args
+    table_destination = load_args[1]
+    assert table_destination.project == "data-project"
+    assert table_destination.dataset_id == "my_dataset"
+    assert table_destination.table_id == "my_table"
+    assert load_kwargs["project"] == "billing-project"
+
+
 def test_to_gbq_with_if_exists_unknown():
     with pytest.raises(ValueError):
         gbq.to_gbq(
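For context, the caller-facing behavior exercised by the new unit test looks like the sketch below. This is illustrative only, not part of the patch: the project and table names mirror the test fixtures, and actually running it requires credentials with access to both projects.

```python
import pandas
import pandas_gbq

df = pandas.DataFrame({"my_col": [1, 2, 3]})

# The destination table lives in "data-project", while the load job is
# created in (and billed to) "billing-project", passed as project_id.
# Both project names are placeholders taken from the test above.
pandas_gbq.to_gbq(
    df,
    "data-project.my_dataset.my_table",
    project_id="billing-project",
    if_exists="replace",
)
```

Before this change, the destination's project was also used when deleting and recreating the table, so `if_exists="replace"` broke whenever the billing project differed from the table's project.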