Skip to content

Commit

Permalink
[Python] Allow lambda function in bigtable handler to build a custom …
Browse files Browse the repository at this point in the history
…row key (#30974)

* allow lambda function in bigtable handler

* enable postcommit

* add exception case, test

* Update sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py

Co-authored-by: Danny McCormick <[email protected]>

---------

Co-authored-by: Danny McCormick <[email protected]>
  • Loading branch information
riteshghorse and damccorm authored Apr 16, 2024
1 parent db585b7 commit 1a26ead
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 4 deletions.
1 change: 1 addition & 0 deletions .github/trigger_files/beam_PostCommit_Python.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run"
}

1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
## New Features / Improvements

* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Bigtable enrichment handler now accepts a custom function to build a composite row key. (Python) ([#30974](https://github.com/apache/beam/issues/30974)).

## Breaking Changes

Expand Down
26 changes: 22 additions & 4 deletions sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#
import logging
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional

Expand All @@ -33,6 +34,8 @@
'BigTableEnrichmentHandler',
]

RowKeyFn = Callable[[beam.Row], bytes]

_LOGGER = logging.getLogger(__name__)


Expand All @@ -53,6 +56,8 @@ class BigTableEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]):
See https://cloud.google.com/bigtable/docs/app-profiles for more details.
encoding (str): encoding type to convert the string to bytes and vice-versa
from BigTable. Default is `utf-8`.
row_key_fn: a function that takes the input row and returns the row key
as `bytes`. It is used to build/extract the row key for Bigtable.
Specify exactly one of `row_key` or `row_key_fn`.
exception_level: an `enum.Enum` value from
``apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel``
to set the level when an empty row is returned from the BigTable query.
Expand All @@ -67,11 +72,12 @@ def __init__(
project_id: str,
instance_id: str,
table_id: str,
row_key: str,
row_key: str = "",
row_filter: Optional[RowFilter] = CellsColumnLimitFilter(1),
*,
app_profile_id: str = None, # type: ignore[assignment]
encoding: str = 'utf-8',
row_key_fn: Optional[RowKeyFn] = None,
exception_level: ExceptionLevel = ExceptionLevel.WARN,
include_timestamp: bool = False,
):
Expand All @@ -82,8 +88,15 @@ def __init__(
self._row_filter = row_filter
self._app_profile_id = app_profile_id
self._encoding = encoding
self._row_key_fn = row_key_fn
self._exception_level = exception_level
self._include_timestamp = include_timestamp
if ((not self._row_key_fn and not self._row_key) or
bool(self._row_key_fn and self._row_key)):
raise ValueError(
"Please specify exactly one of `row_key` or a lambda "
"function with `row_key_fn` to extract the row key "
"from the input row.")

def __enter__(self):
"""connect to the Google BigTable cluster."""
Expand All @@ -105,9 +118,12 @@ def __call__(self, request: beam.Row, *args, **kwargs):
response_dict: Dict[str, Any] = {}
row_key_str: str = ""
try:
request_dict = request._asdict()
row_key_str = str(request_dict[self._row_key])
row_key = row_key_str.encode(self._encoding)
if self._row_key_fn:
row_key = self._row_key_fn(request)
else:
request_dict = request._asdict()
row_key_str = str(request_dict[self._row_key])
row_key = row_key_str.encode(self._encoding)
row = self._table.read_row(row_key, filter_=self._row_filter)
if row:
for cf_id, cf_v in row.cells.items():
Expand Down Expand Up @@ -148,4 +164,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
def get_cache_key(self, request: beam.Row) -> str:
  """Builds the cache key string for ``request``.

  The key is formatted with the row key, since it is unique to a request
  made to `Bigtable`. When a custom ``row_key_fn`` is configured, its
  result for the request is used; otherwise the value of the ``row_key``
  field of the request is used.
  """
  key_fn = self._row_key_fn
  if not key_fn:
    field_name = self._row_key
    return "%s: %s" % (field_name, request._asdict()[field_name])
  return "row_key: %s" % str(key_fn(request))
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@
_LOGGER = logging.getLogger(__name__)


def _row_key_fn(request: beam.Row) -> bytes:
  """Returns the request's ``product_id``, stringified and UTF-8 encoded."""
  product_id = request.product_id  # type: ignore[attr-defined]
  return str(product_id).encode('utf-8')


class ValidateResponse(beam.DoFn):
"""ValidateResponse validates if a PCollection of `beam.Row`
has the required fields."""
Expand Down Expand Up @@ -426,6 +431,29 @@ def test_bigtable_enrichment_with_redis(self):
expected_enriched_fields)))
BigTableEnrichmentHandler.__call__ = actual

def test_bigtable_enrichment_with_lambda(self):
  """Runs the enrichment pipeline with a handler that is configured with
  ``row_key_fn`` (and no ``row_key``) and validates the response fields."""
  expected_fields = [
      'sale_id', 'customer_id', 'product_id', 'quantity', 'product'
  ]
  expected_enriched_fields = {
      'product': ['product_id', 'product_name', 'product_stock'],
  }
  handler = BigTableEnrichmentHandler(
      project_id=self.project_id,
      instance_id=self.instance_id,
      table_id=self.table_id,
      row_key_fn=_row_key_fn)
  with TestPipeline(is_integration_test=True) as test_pipeline:
    requests = test_pipeline | "Create" >> beam.Create(self.req)
    enriched = requests | "Enrich W/ BigTable" >> Enrichment(handler)
    validator = ValidateResponse(
        len(expected_fields), expected_fields, expected_enriched_fields)
    _ = enriched | "Validate Response" >> beam.ParDo(validator)


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

from parameterized import parameterized

try:
from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler
from apache_beam.transforms.enrichment_handlers.bigtable_it_test import _row_key_fn
except ImportError:
raise unittest.SkipTest('Bigtable test dependencies are not installed.')


class TestBigTableEnrichmentHandler(unittest.TestCase):
  """Unit tests for constructor argument validation of
  ``BigTableEnrichmentHandler``."""
  @parameterized.expand([('product_id', _row_key_fn), ('', None)])
  def test_bigtable_enrichment_invalid_args(self, row_key, row_key_fn):
    """Expects a ``ValueError`` when both ``row_key`` and ``row_key_fn``
    are given, and when neither is given."""
    handler_kwargs = dict(
        project_id='apache-beam-testing',
        instance_id='beam-test',
        table_id='bigtable-enrichment-test',
        row_key=row_key,
        row_key_fn=row_key_fn)
    with self.assertRaises(ValueError):
      BigTableEnrichmentHandler(**handler_kwargs)


if __name__ == '__main__':
unittest.main()

0 comments on commit 1a26ead

Please sign in to comment.