Skip to content

Commit

Permalink
[Data] Include ray user-agent in BigQuery client construction (#49922)
Browse files Browse the repository at this point in the history
  • Loading branch information
tswast authored Jan 18, 2025
1 parent 3cfe8ae commit 3bd3865
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 9 deletions.
6 changes: 3 additions & 3 deletions python/ray/data/_internal/datasource/bigquery_datasink.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import ray
from ray.data._internal.execution.interfaces import TaskContext
from ray.data._internal.datasource import bigquery_datasource
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.util import _check_import
from ray.data.block import Block, BlockAccessor
Expand Down Expand Up @@ -39,13 +40,12 @@ def __init__(

def on_write_start(self) -> None:
from google.api_core import exceptions
from google.cloud import bigquery

if self.project_id is None or self.dataset is None:
raise ValueError("project_id and dataset are required args")

# Set up datasets to write
client = bigquery.Client(project=self.project_id)
client = bigquery_datasource._create_client(project_id=self.project_id)
dataset_id = self.dataset.split(".", 1)[0]
try:
client.get_dataset(dataset_id)
Expand Down Expand Up @@ -77,7 +77,7 @@ def _write_single_block(block: Block, project_id: str, dataset: str) -> None:

block = BlockAccessor.for_block(block).to_arrow()

client = bigquery.Client(project=project_id)
client = bigquery_datasource._create_client(project=project_id)
job_config = bigquery.LoadJobConfig(autodetect=True)
job_config.source_format = bigquery.SourceFormat.PARQUET
job_config.write_disposition = bigquery.WriteDisposition.WRITE_APPEND
Expand Down
50 changes: 44 additions & 6 deletions python/ray/data/_internal/datasource/bigquery_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,45 @@
logger = logging.getLogger(__name__)


def _create_user_agent() -> str:
    """Return the user-agent token that identifies Ray to Google Cloud APIs.

    Returns:
        A string of the form ``ray/<version>``, where ``<version>`` is
        ``ray.__version__``.
    """
    # Imported inside the function rather than at module scope.
    import ray

    user_agent = f"ray/{ray.__version__}"
    return user_agent


def _create_client_info():
    """Build the ``ClientInfo`` handed to the BigQuery client.

    Returns:
        A ``google.api_core.client_info.ClientInfo`` whose ``user_agent`` is
        the value returned by ``_create_user_agent()``.
    """
    from google.api_core import client_info as api_client_info

    agent = _create_user_agent()
    return api_client_info.ClientInfo(user_agent=agent)


def _create_client_info_gapic():
    """Build the GAPIC ``ClientInfo`` handed to the BigQuery Storage read client.

    Returns:
        A ``google.api_core.gapic_v1.client_info.ClientInfo`` whose
        ``user_agent`` is the value returned by ``_create_user_agent()``.
    """
    from google.api_core.gapic_v1 import client_info as gapic_client_info

    agent = _create_user_agent()
    return gapic_client_info.ClientInfo(user_agent=agent)


def _create_client(project_id: str):
from google.cloud import bigquery

return bigquery.Client(
project=project_id,
client_info=_create_client_info(),
)


def _create_read_client():
    """Construct a BigQuery Storage read client that reports the Ray user agent.

    Returns:
        A ``bigquery_storage.BigQueryReadClient`` created with the GAPIC
        ``ClientInfo`` returned by ``_create_client_info_gapic()``.
    """
    from google.cloud import bigquery_storage

    gapic_info = _create_client_info_gapic()
    return bigquery_storage.BigQueryReadClient(client_info=gapic_info)


class BigQueryDatasource(Datasource):
def __init__(
self,
Expand All @@ -30,15 +69,15 @@ def __init__(
)

def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
from google.cloud import bigquery, bigquery_storage
from google.cloud import bigquery_storage

def _read_single_partition(stream) -> Block:
client = bigquery_storage.BigQueryReadClient()
client = _create_read_client()
reader = client.read_rows(stream.name)
return reader.to_arrow()

if self._query:
query_client = bigquery.Client(project=self._project_id)
query_client = _create_client(project_id=self._project_id)
query_job = query_client.query(self._query)
query_job.result()
destination = str(query_job.destination)
Expand All @@ -49,7 +88,7 @@ def _read_single_partition(stream) -> Block:
dataset_id = self._dataset.split(".")[0]
table_id = self._dataset.split(".")[1]

bqs_client = bigquery_storage.BigQueryReadClient()
bqs_client = _create_read_client()
table = f"projects/{self._project_id}/datasets/{dataset_id}/tables/{table_id}"

if parallelism == -1:
Expand Down Expand Up @@ -97,9 +136,8 @@ def estimate_inmemory_data_size(self) -> Optional[int]:

def _validate_dataset_table_exist(self, project_id: str, dataset: str) -> None:
from google.api_core import exceptions
from google.cloud import bigquery

client = bigquery.Client(project=project_id)
client = _create_client(project_id=project_id)
dataset_id = dataset.split(".")[0]
try:
client.get_dataset(dataset_id)
Expand Down

0 comments on commit 3bd3865

Please sign in to comment.