diff --git a/providers/src/airflow/providers/google/cloud/operators/bigquery_dts.py b/providers/src/airflow/providers/google/cloud/operators/bigquery_dts.py
index a9da994ff6fb0..bed14add1911f 100644
--- a/providers/src/airflow/providers/google/cloud/operators/bigquery_dts.py
+++ b/providers/src/airflow/providers/google/cloud/operators/bigquery_dts.py
@@ -299,6 +299,7 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
         self.deferrable = deferrable
+        self._transfer_run: dict = {}
 
     @cached_property
     def hook(self) -> BiqQueryDataTransferServiceHook:
@@ -339,12 +340,13 @@ def execute(self, context: Context):
         self.xcom_push(context, key="run_id", value=run_id)
 
         if not self.deferrable:
-            result = self._wait_for_transfer_to_be_done(
+            # Save as attribute for further use by OpenLineage
+            self._transfer_run = self._wait_for_transfer_to_be_done(
                 run_id=run_id,
                 transfer_config_id=transfer_config["config_id"],
             )
             self.log.info("Transfer run %s submitted successfully.", run_id)
-            return result
+            return self._transfer_run
 
         self.defer(
             trigger=BigQueryDataTransferRunTrigger(
@@ -412,4 +414,117 @@ def execute_completed(self, context: Context, event: dict):
                 event["message"],
             )
 
-        return TransferRun.to_dict(transfer_run)
+        # Save as attribute for further use by OpenLineage
+        self._transfer_run = TransferRun.to_dict(transfer_run)
+        return self._transfer_run
+
+    def get_openlineage_facets_on_complete(self, _):
+        """Implement _on_complete as we need a run config to extract information."""
+        from urllib.parse import urlsplit
+
+        from airflow.providers.common.compat.openlineage.facet import Dataset, ErrorMessageRunFacet
+        from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url
+        from airflow.providers.google.cloud.openlineage.utils import (
+            BIGQUERY_NAMESPACE,
+            extract_ds_name_from_gcs_path,
+        )
+        from airflow.providers.openlineage.extractors import OperatorLineage
+        from airflow.providers.openlineage.sqlparser import DatabaseInfo, SQLParser
+
+        if not self._transfer_run:
+            self.log.debug("No BigQuery Data Transfer configuration was found by OpenLineage.")
+            return OperatorLineage()
+
+        data_source_id = self._transfer_run["data_source_id"]
+        dest_dataset_id = self._transfer_run["destination_dataset_id"]
+        params = self._transfer_run["params"]
+
+        input_datasets, output_datasets = [], []
+        run_facets, job_facets = {}, {}
+        if data_source_id in ("google_cloud_storage", "amazon_s3", "azure_blob_storage"):
+            if data_source_id == "google_cloud_storage":
+                bucket, path = _parse_gcs_url(params["data_path_template"])  # gs://bucket...
+                namespace = f"gs://{bucket}"
+                name = extract_ds_name_from_gcs_path(path)
+            elif data_source_id == "amazon_s3":
+                parsed_url = urlsplit(params["data_path"])  # s3://bucket...
+                namespace = f"s3://{parsed_url.netloc}"
+                name = extract_ds_name_from_gcs_path(parsed_url.path)
+            else:  # azure_blob_storage
+                storage_account = params["storage_account"]
+                container = params["container"]
+                namespace = f"abfss://{container}@{storage_account}.dfs.core.windows.net"
+                name = extract_ds_name_from_gcs_path(params["data_path"])
+
+            input_datasets.append(Dataset(namespace=namespace, name=name))
+            dest_table_name = params["destination_table_name_template"]
+            output_datasets.append(
+                Dataset(
+                    namespace=BIGQUERY_NAMESPACE,
+                    name=f"{self.project_id}.{dest_dataset_id}.{dest_table_name}",
+                )
+            )
+        elif data_source_id in ("postgresql", "oracle", "mysql"):
+            scheme = data_source_id if data_source_id != "postgresql" else "postgres"
+            host = params["connector.endpoint.host"]
+            port = params["connector.endpoint.port"]
+
+            for asset in params["assets"]:
+                # MySQL: db/table; Other: db/schema/table;
+                table_name = asset.split("/")[-1]
+
+                input_datasets.append(
+                    Dataset(namespace=f"{scheme}://{host}:{int(port)}", name=asset.replace("/", "."))
+                )
+                output_datasets.append(
+                    Dataset(
+                        namespace=BIGQUERY_NAMESPACE, name=f"{self.project_id}.{dest_dataset_id}.{table_name}"
+                    )
+                )
+        elif data_source_id == "scheduled_query":
+            bq_db_info = DatabaseInfo(
+                scheme="bigquery",
+                authority=None,
+                database=self.project_id,
+            )
+            parser_result = SQLParser("bigquery").generate_openlineage_metadata_from_sql(
+                sql=params["query"],
+                database_info=bq_db_info,
+                database=self.project_id,
+                use_connection=False,
+                hook=None,  # Hook is not used when use_connection=False
+                sqlalchemy_engine=None,
+            )
+            if parser_result.inputs:
+                input_datasets.extend(parser_result.inputs)
+            if parser_result.outputs:
+                output_datasets.extend(parser_result.outputs)
+            if parser_result.job_facets:
+                job_facets = {**job_facets, **parser_result.job_facets}
+            if parser_result.run_facets:
+                run_facets = {**run_facets, **parser_result.run_facets}
+            dest_table_name = params.get("destination_table_name_template")
+            if dest_table_name:
+                output_datasets.append(
+                    Dataset(
+                        namespace=BIGQUERY_NAMESPACE,
+                        name=f"{self.project_id}.{dest_dataset_id}.{dest_table_name}",
+                    )
+                )
+        else:
+            self.log.debug(
+                "BigQuery Data Transfer data_source_id `%s` is not supported by OpenLineage.", data_source_id
+            )
+            return OperatorLineage()
+
+        error_status = self._transfer_run.get("error_status")
+        if error_status and str(error_status["code"]) != "0":
+            run_facets["errorMessage"] = ErrorMessageRunFacet(
+                message=error_status["message"],
+                programmingLanguage="python",
+                stackTrace=str(error_status["details"]),
+            )
+
+        return OperatorLineage(
+            inputs=input_datasets, outputs=output_datasets, job_facets=job_facets, run_facets=run_facets
+        )
diff --git a/providers/tests/google/cloud/operators/test_bigquery_dts.py b/providers/tests/google/cloud/operators/test_bigquery_dts.py
index c9379abddb985..f9188dadd7acb 100644
--- a/providers/tests/google/cloud/operators/test_bigquery_dts.py
+++ b/providers/tests/google/cloud/operators/test_bigquery_dts.py
@@ -19,6 +19,7 @@
 
 from unittest import mock
 
+import pytest
 from google.api_core.gapic_v1.method import DEFAULT
 from google.cloud.bigquery_datatransfer_v1 import StartManualTransferRunsResponse, TransferConfig, TransferRun
 
@@ -156,3 +157,258 @@ def test_defer_mode(self, _, defer_method):
             op.execute({"ti": ti})
 
         defer_method.assert_called_once()
+
+    @pytest.mark.parametrize(
+        ("data_source_id", "params", "expected_input_namespace", "expected_input_name"),
+        (
+            (
+                "google_cloud_storage",
+                {
"data_path_template": "gs://bucket/path/to/file.txt", + "destination_table_name_template": "bq_table", + }, + "gs://bucket", + "path/to/file.txt", + ), + ( + "amazon_s3", + { + "data_path": "s3://bucket/path/to/file.txt", + "destination_table_name_template": "bq_table", + }, + "s3://bucket", + "path/to/file.txt", + ), + ( + "azure_blob_storage", + { + "storage_account": "account_id", + "container": "container_id", + "data_path": "/path/to/file.txt", + "destination_table_name_template": "bq_table", + }, + "abfss://container_id@account_id.dfs.core.windows.net", + "path/to/file.txt", + ), + ), + ) + @mock.patch( + f"{OPERATOR_MODULE_PATH}.BigQueryDataTransferServiceStartTransferRunsOperator" + f"._wait_for_transfer_to_be_done" + ) + @mock.patch(f"{OPERATOR_MODULE_PATH}.BiqQueryDataTransferServiceHook") + def test_get_openlineage_facets_on_complete_with_blob_storage_sources( + self, mock_hook, mock_wait, data_source_id, params, expected_input_namespace, expected_input_name + ): + mock_hook.return_value.start_manual_transfer_runs.return_value = StartManualTransferRunsResponse( + runs=[TransferRun(name=RUN_NAME)] + ) + mock_wait.return_value = { + "error_status": {"code": 0, "message": "", "details": []}, + "data_source_id": data_source_id, + "destination_dataset_id": "bq_dataset", + "params": params, + } + + op = BigQueryDataTransferServiceStartTransferRunsOperator( + transfer_config_id=TRANSFER_CONFIG_ID, task_id="id", project_id=PROJECT_ID + ) + op.execute({"ti": mock.MagicMock()}) + result = op.get_openlineage_facets_on_complete(None) + assert not result.run_facets + assert not result.job_facets + assert len(result.inputs) == 1 + assert len(result.outputs) == 1 + assert result.inputs[0].namespace == expected_input_namespace + assert result.inputs[0].name == expected_input_name + assert result.outputs[0].namespace == "bigquery" + assert result.outputs[0].name == f"{PROJECT_ID}.bq_dataset.bq_table" + + @pytest.mark.parametrize( + ("data_source_id", "params", "expected_input_namespace", "expected_input_names"), + ( + ( + "postgresql", + { + "connector.endpoint.host": "127.0.0.1", + "connector.endpoint.port": 5432, + "assets": [ + "db1/sch1/tb1", + "db2/sch2/tb2", + ], + }, + "postgres://127.0.0.1:5432", + ["db1.sch1.tb1", "db2.sch2.tb2"], + ), + ( + "oracle", + { + "connector.endpoint.host": "127.0.0.1", + "connector.endpoint.port": 1234, + "assets": [ + "db1/sch1/tb1", + "db2/sch2/tb2", + ], + }, + "oracle://127.0.0.1:1234", + ["db1.sch1.tb1", "db2.sch2.tb2"], + ), + ( + "mysql", + { + "connector.endpoint.host": "127.0.0.1", + "connector.endpoint.port": 3306, + "assets": [ + "db1/tb1", + "db2/tb2", + ], + }, + "mysql://127.0.0.1:3306", + ["db1.tb1", "db2.tb2"], + ), + ), + ) + @mock.patch( + f"{OPERATOR_MODULE_PATH}.BigQueryDataTransferServiceStartTransferRunsOperator" + f"._wait_for_transfer_to_be_done" + ) + @mock.patch(f"{OPERATOR_MODULE_PATH}.BiqQueryDataTransferServiceHook") + def test_get_openlineage_facets_on_complete_with_sql_sources( + self, mock_hook, mock_wait, data_source_id, params, expected_input_namespace, expected_input_names + ): + mock_hook.return_value.start_manual_transfer_runs.return_value = StartManualTransferRunsResponse( + runs=[TransferRun(name=RUN_NAME)] + ) + mock_wait.return_value = { + "error_status": {"code": 0, "message": "", "details": []}, + "data_source_id": data_source_id, + "destination_dataset_id": "bq_dataset", + "params": params, + } + + op = BigQueryDataTransferServiceStartTransferRunsOperator( + transfer_config_id=TRANSFER_CONFIG_ID, task_id="id", 
project_id=PROJECT_ID + ) + op.execute({"ti": mock.MagicMock()}) + result = op.get_openlineage_facets_on_complete(None) + assert not result.run_facets + assert not result.job_facets + assert len(result.inputs) == 2 + assert len(result.outputs) == 2 + assert result.inputs[0].namespace == expected_input_namespace + assert result.inputs[0].name == expected_input_names[0] + assert result.inputs[1].namespace == expected_input_namespace + assert result.inputs[1].name == expected_input_names[1] + assert result.outputs[0].namespace == "bigquery" + assert result.outputs[0].name == f"{PROJECT_ID}.bq_dataset.{expected_input_names[0].split('.')[-1]}" + assert result.outputs[1].namespace == "bigquery" + assert result.outputs[1].name == f"{PROJECT_ID}.bq_dataset.{expected_input_names[1].split('.')[-1]}" + + @mock.patch( + f"{OPERATOR_MODULE_PATH}.BigQueryDataTransferServiceStartTransferRunsOperator" + f"._wait_for_transfer_to_be_done" + ) + @mock.patch(f"{OPERATOR_MODULE_PATH}.BiqQueryDataTransferServiceHook") + def test_get_openlineage_facets_on_complete_with_scheduled_query(self, mock_hook, mock_wait): + mock_hook.return_value.start_manual_transfer_runs.return_value = StartManualTransferRunsResponse( + runs=[TransferRun(name=RUN_NAME)] + ) + mock_wait.return_value = { + "error_status": {"code": 0, "message": "", "details": []}, + "data_source_id": "scheduled_query", + "destination_dataset_id": "bq_dataset", + "params": {"query": "SELECT a,b,c from x.y.z;", "destination_table_name_template": "bq_table"}, + } + + op = BigQueryDataTransferServiceStartTransferRunsOperator( + transfer_config_id=TRANSFER_CONFIG_ID, task_id="id", project_id=PROJECT_ID + ) + op.execute({"ti": mock.MagicMock()}) + result = op.get_openlineage_facets_on_complete(None) + assert len(result.job_facets) == 1 + assert result.job_facets["sql"].query == "SELECT a,b,c from x.y.z" + assert not result.run_facets + assert len(result.inputs) == 1 + assert len(result.outputs) == 1 + assert result.inputs[0].namespace == "bigquery" + assert result.inputs[0].name == "x.y.z" + assert result.outputs[0].namespace == "bigquery" + assert result.outputs[0].name == f"{PROJECT_ID}.bq_dataset.bq_table" + + @mock.patch( + f"{OPERATOR_MODULE_PATH}.BigQueryDataTransferServiceStartTransferRunsOperator" + f"._wait_for_transfer_to_be_done" + ) + @mock.patch(f"{OPERATOR_MODULE_PATH}.BiqQueryDataTransferServiceHook") + def test_get_openlineage_facets_on_complete_with_error(self, mock_hook, mock_wait): + mock_hook.return_value.start_manual_transfer_runs.return_value = StartManualTransferRunsResponse( + runs=[TransferRun(name=RUN_NAME)] + ) + mock_wait.return_value = { + "error_status": { + "code": 1, + "message": "Sample message error.", + "details": [{"@type": "123", "field1": "test1"}, {"@type": "456", "field2": "test2"}], + }, + "data_source_id": "google_cloud_storage", + "destination_dataset_id": "bq_dataset", + "params": { + "data_path_template": "gs://bucket/path/to/file.txt", + "destination_table_name_template": "bq_table", + }, + } + + op = BigQueryDataTransferServiceStartTransferRunsOperator( + transfer_config_id=TRANSFER_CONFIG_ID, task_id="id", project_id=PROJECT_ID + ) + op.execute({"ti": mock.MagicMock()}) + result = op.get_openlineage_facets_on_complete(None) + assert not result.job_facets + assert len(result.run_facets) == 1 + assert result.run_facets["errorMessage"].message == "Sample message error." 
+        assert (
+            result.run_facets["errorMessage"].stackTrace
+            == "[{'@type': '123', 'field1': 'test1'}, {'@type': '456', 'field2': 'test2'}]"
+        )
+        assert len(result.inputs) == 1
+        assert len(result.outputs) == 1
+        assert result.inputs[0].namespace == "gs://bucket"
+        assert result.inputs[0].name == "path/to/file.txt"
+        assert result.outputs[0].namespace == "bigquery"
+        assert result.outputs[0].name == f"{PROJECT_ID}.bq_dataset.bq_table"
+
+    @mock.patch(f"{OPERATOR_MODULE_PATH}.BiqQueryDataTransferServiceHook")
+    @mock.patch(f"{OPERATOR_MODULE_PATH}.BigQueryDataTransferServiceStartTransferRunsOperator.defer")
+    def test_get_openlineage_facets_on_complete_deferred(self, mock_defer, mock_hook):
+        mock_hook.return_value.start_manual_transfer_runs.return_value = StartManualTransferRunsResponse(
+            runs=[TransferRun(name=RUN_NAME)]
+        )
+        mock_hook.return_value.get_transfer_run.return_value = TransferRun(
+            {
+                "error_status": {"code": 0, "message": "", "details": []},
+                "data_source_id": "google_cloud_storage",
+                "destination_dataset_id": "bq_dataset",
+                "params": {
+                    "data_path_template": "gs://bucket/path/to/file.txt",
+                    "destination_table_name_template": "bq_table",
+                },
+            }
+        )
+
+        op = BigQueryDataTransferServiceStartTransferRunsOperator(
+            transfer_config_id=TRANSFER_CONFIG_ID, task_id="id", project_id=PROJECT_ID, deferrable=True
+        )
+        op.execute({"ti": mock.MagicMock()})
+        # `defer` is mocked, so `execute_completed` will not be called automatically; we call it manually.
+        op.execute_completed(
+            mock.MagicMock(), {"status": "done", "run_id": 123, "config_id": 321, "message": "msg"}
+        )
+        result = op.get_openlineage_facets_on_complete(None)
+        assert not result.job_facets
+        assert not result.run_facets
+        assert len(result.inputs) == 1
+        assert len(result.outputs) == 1
+        assert result.inputs[0].namespace == "gs://bucket"
+        assert result.inputs[0].name == "path/to/file.txt"
+        assert result.outputs[0].namespace == "bigquery"
+        assert result.outputs[0].name == f"{PROJECT_ID}.bq_dataset.bq_table"
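Illustrative usage, not part of the patch: the minimal DAG sketch below shows how the new lineage method is exercised in practice. The DAG id, project id, and transfer config id are assumptions made up for the example. Because the transfer run is saved on `self._transfer_run` in both the blocking and the deferrable paths, `get_openlineage_facets_on_complete` behaves the same in either mode; for a `google_cloud_storage` transfer like the one in the tests above, it would report an input dataset in the `gs://bucket` namespace and an output dataset in the `bigquery` namespace.

# Minimal sketch only; the DAG id, project, and transfer config id below are assumed values.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.bigquery_dts import (
    BigQueryDataTransferServiceStartTransferRunsOperator,
)

with DAG(dag_id="example_bq_dts_lineage", schedule=None, start_date=datetime(2024, 1, 1)) as dag:
    # With the OpenLineage provider installed and configured, the listener calls
    # get_openlineage_facets_on_complete() after this task finishes and emits the
    # input/output datasets derived from the stored transfer run, e.g. for a GCS source:
    #   input  -> Dataset(namespace="gs://bucket", name="path/to/file.txt")
    #   output -> Dataset(namespace="bigquery", name="my-project.bq_dataset.bq_table")
    start_transfer_runs = BigQueryDataTransferServiceStartTransferRunsOperator(
        task_id="start_transfer_runs",
        project_id="my-project",  # assumed GCP project
        transfer_config_id="config-123",  # assumed BigQuery DTS transfer config id
    )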