diff --git a/google/cloud/bigtable/client.py b/google/cloud/bigtable/client.py index 0544bcb78..f75613098 100644 --- a/google/cloud/bigtable/client.py +++ b/google/cloud/bigtable/client.py @@ -55,6 +55,8 @@ from google.cloud.bigtable._helpers import _make_metadata from google.cloud.bigtable._helpers import _convert_retry_deadline +from google.cloud.bigtable.read_modify_write_rules import ReadModifyWriteRule +from google.cloud.bigtable.row_filters import RowFilter from google.cloud.bigtable.row_filters import StripValueTransformerFilter from google.cloud.bigtable.row_filters import CellsRowLimitFilter from google.cloud.bigtable.row_filters import RowFilterChain @@ -62,8 +64,6 @@ if TYPE_CHECKING: from google.cloud.bigtable.mutations_batcher import MutationsBatcher from google.cloud.bigtable import RowKeySamples - from google.cloud.bigtable.row_filters import RowFilter - from google.cloud.bigtable.read_modify_write_rules import ReadModifyWriteRule class BigtableDataClient(ClientWithProject): @@ -770,10 +770,11 @@ async def bulk_mutate_rows( async def check_and_mutate_row( self, row_key: str | bytes, - predicate: RowFilter | None, + predicate: RowFilter | dict[str, Any] | None, + *, true_case_mutations: Mutation | list[Mutation] | None = None, false_case_mutations: Mutation | list[Mutation] | None = None, - operation_timeout: int | float | None = 60, + operation_timeout: int | float | None = 20, ) -> bool: """ Mutates a row atomically based on the output of a predicate filter @@ -807,17 +808,43 @@ async def check_and_mutate_row( Raises: - GoogleAPIError exceptions from grpc call """ - raise NotImplementedError + operation_timeout = operation_timeout or self.default_operation_timeout + if operation_timeout <= 0: + raise ValueError("operation_timeout must be greater than 0") + row_key = row_key.encode("utf-8") if isinstance(row_key, str) else row_key + if true_case_mutations is not None and not isinstance( + true_case_mutations, list + ): + true_case_mutations = [true_case_mutations] + true_case_dict = [m._to_dict() for m in true_case_mutations or []] + if false_case_mutations is not None and not isinstance( + false_case_mutations, list + ): + false_case_mutations = [false_case_mutations] + false_case_dict = [m._to_dict() for m in false_case_mutations or []] + if predicate is not None and not isinstance(predicate, dict): + predicate = predicate.to_dict() + metadata = _make_metadata(self.table_name, self.app_profile_id) + result = await self.client._gapic_client.check_and_mutate_row( + request={ + "predicate_filter": predicate, + "true_mutations": true_case_dict, + "false_mutations": false_case_dict, + "table_name": self.table_name, + "row_key": row_key, + "app_profile_id": self.app_profile_id, + }, + metadata=metadata, + timeout=operation_timeout, + ) + return result.predicate_matched async def read_modify_write_row( self, row_key: str | bytes, - rules: ReadModifyWriteRule - | list[ReadModifyWriteRule] - | dict[str, Any] - | list[dict[str, Any]], + rules: ReadModifyWriteRule | list[ReadModifyWriteRule], *, - operation_timeout: int | float | None = 60, + operation_timeout: int | float | None = 20, ) -> Row: """ Reads and modifies a row atomically according to input ReadModifyWriteRules, @@ -841,7 +868,29 @@ async def read_modify_write_row( Raises: - GoogleAPIError exceptions from grpc call """ - raise NotImplementedError + operation_timeout = operation_timeout or self.default_operation_timeout + row_key = row_key.encode("utf-8") if isinstance(row_key, str) else row_key + if operation_timeout <= 0: + raise ValueError("operation_timeout must be greater than 0") + if rules is not None and not isinstance(rules, list): + rules = [rules] + if not rules: + raise ValueError("rules must contain at least one item") + # concert to dict representation + rules_dict = [rule._to_dict() for rule in rules] + metadata = _make_metadata(self.table_name, self.app_profile_id) + result = await self.client._gapic_client.read_modify_write_row( + request={ + "rules": rules_dict, + "table_name": self.table_name, + "row_key": row_key, + "app_profile_id": self.app_profile_id, + }, + metadata=metadata, + timeout=operation_timeout, + ) + # construct Row from result + return Row._from_pb(result.row) async def close(self): """ diff --git a/google/cloud/bigtable/mutations.py b/google/cloud/bigtable/mutations.py index c72f132c8..fe136f8d9 100644 --- a/google/cloud/bigtable/mutations.py +++ b/google/cloud/bigtable/mutations.py @@ -18,6 +18,8 @@ from dataclasses import dataclass from abc import ABC, abstractmethod +from google.cloud.bigtable.read_modify_write_rules import MAX_INCREMENT_VALUE + # special value for SetCell mutation timestamps. If set, server will assign a timestamp SERVER_SIDE_TIMESTAMP = -1 @@ -99,6 +101,10 @@ def __init__( if isinstance(new_value, str): new_value = new_value.encode() elif isinstance(new_value, int): + if abs(new_value) > MAX_INCREMENT_VALUE: + raise ValueError( + "int values must be between -2**63 and 2**63 (64-bit signed int)" + ) new_value = new_value.to_bytes(8, "big", signed=True) if not isinstance(new_value, bytes): raise TypeError("new_value must be bytes, str, or int") diff --git a/google/cloud/bigtable/read_modify_write_rules.py b/google/cloud/bigtable/read_modify_write_rules.py index cd6b370df..aa282b1a6 100644 --- a/google/cloud/bigtable/read_modify_write_rules.py +++ b/google/cloud/bigtable/read_modify_write_rules.py @@ -14,22 +14,59 @@ # from __future__ import annotations -from dataclasses import dataclass +import abc +# value must fit in 64-bit signed integer +MAX_INCREMENT_VALUE = (1 << 63) - 1 -class ReadModifyWriteRule: - pass + +class ReadModifyWriteRule(abc.ABC): + def __init__(self, family: str, qualifier: bytes | str): + qualifier = ( + qualifier if isinstance(qualifier, bytes) else qualifier.encode("utf-8") + ) + self.family = family + self.qualifier = qualifier + + @abc.abstractmethod + def _to_dict(self): + raise NotImplementedError -@dataclass class IncrementRule(ReadModifyWriteRule): - increment_amount: int - family: str - qualifier: bytes + def __init__(self, family: str, qualifier: bytes | str, increment_amount: int = 1): + if not isinstance(increment_amount, int): + raise TypeError("increment_amount must be an integer") + if abs(increment_amount) > MAX_INCREMENT_VALUE: + raise ValueError( + "increment_amount must be between -2**63 and 2**63 (64-bit signed int)" + ) + super().__init__(family, qualifier) + self.increment_amount = increment_amount + + def _to_dict(self): + return { + "family_name": self.family, + "column_qualifier": self.qualifier, + "increment_amount": self.increment_amount, + } -@dataclass class AppendValueRule(ReadModifyWriteRule): - append_value: bytes - family: str - qualifier: bytes + def __init__(self, family: str, qualifier: bytes | str, append_value: bytes | str): + append_value = ( + append_value.encode("utf-8") + if isinstance(append_value, str) + else append_value + ) + if not isinstance(append_value, bytes): + raise TypeError("append_value must be bytes or str") + super().__init__(family, qualifier) + self.append_value = append_value + + def _to_dict(self): + return { + "family_name": self.family, + "column_qualifier": self.qualifier, + "append_value": self.append_value, + } diff --git a/google/cloud/bigtable/row.py b/google/cloud/bigtable/row.py index a5fb033e6..5fdc1b365 100644 --- a/google/cloud/bigtable/row.py +++ b/google/cloud/bigtable/row.py @@ -18,6 +18,8 @@ from typing import Sequence, Generator, overload, Any from functools import total_ordering +from google.cloud.bigtable_v2.types import Row as RowPB + # Type aliases used internally for readability. _family_type = str _qualifier_type = bytes @@ -72,6 +74,30 @@ def _index( ).append(cell) return self._index_data + @classmethod + def _from_pb(cls, row_pb: RowPB) -> Row: + """ + Creates a row from a protobuf representation + + Row objects are not intended to be created by users. + They are returned by the Bigtable backend. + """ + row_key: bytes = row_pb.key + cell_list: list[Cell] = [] + for family in row_pb.families: + for column in family.columns: + for cell in column.cells: + new_cell = Cell( + value=cell.value, + row_key=row_key, + family=family.name, + qualifier=column.qualifier, + timestamp_micros=cell.timestamp_micros, + labels=list(cell.labels) if cell.labels else None, + ) + cell_list.append(new_cell) + return cls(row_key, cells=cell_list) + def get_cells( self, family: str | None = None, qualifier: str | bytes | None = None ) -> list[Cell]: diff --git a/tests/system/test_system.py b/tests/system/test_system.py index f6730576d..692911b10 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -19,6 +19,8 @@ from google.api_core import retry from google.api_core.exceptions import ClientError +from google.cloud.bigtable.read_modify_write_rules import MAX_INCREMENT_VALUE + TEST_FAMILY = "test-family" TEST_FAMILY_2 = "test-family-2" @@ -245,7 +247,6 @@ async def test_mutation_set_cell(table, temp_rows): mutation = SetCell( family=TEST_FAMILY, qualifier=b"test-qualifier", new_value=expected_value ) - await table.mutate_row(row_key, mutation) # ensure cell is updated @@ -282,6 +283,165 @@ async def test_bulk_mutations_set_cell(client, table, temp_rows): assert (await _retrieve_cell_value(table, row_key)) == expected_value +@pytest.mark.parametrize( + "start,increment,expected", + [ + (0, 0, 0), + (0, 1, 1), + (0, -1, -1), + (1, 0, 1), + (0, -100, -100), + (0, 3000, 3000), + (10, 4, 14), + (MAX_INCREMENT_VALUE, -MAX_INCREMENT_VALUE, 0), + (MAX_INCREMENT_VALUE, 2, -MAX_INCREMENT_VALUE), + (-MAX_INCREMENT_VALUE, -2, MAX_INCREMENT_VALUE), + ], +) +@pytest.mark.asyncio +async def test_read_modify_write_row_increment( + client, table, temp_rows, start, increment, expected +): + """ + test read_modify_write_row + """ + from google.cloud.bigtable.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + await temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) + + rule = IncrementRule(family, qualifier, increment) + result = await table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert int(result[0]) == expected + # ensure that reading from server gives same value + assert (await _retrieve_cell_value(table, row_key)) == result[0].value + + +@pytest.mark.parametrize( + "start,append,expected", + [ + (b"", b"", b""), + ("", "", b""), + (b"abc", b"123", b"abc123"), + (b"abc", "123", b"abc123"), + ("", b"1", b"1"), + (b"abc", "", b"abc"), + (b"hello", b"world", b"helloworld"), + ], +) +@pytest.mark.asyncio +async def test_read_modify_write_row_append( + client, table, temp_rows, start, append, expected +): + """ + test read_modify_write_row + """ + from google.cloud.bigtable.read_modify_write_rules import AppendValueRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + await temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) + + rule = AppendValueRule(family, qualifier, append) + result = await table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert result[0].value == expected + # ensure that reading from server gives same value + assert (await _retrieve_cell_value(table, row_key)) == result[0].value + + +@pytest.mark.asyncio +async def test_read_modify_write_row_chained(client, table, temp_rows): + """ + test read_modify_write_row with multiple rules + """ + from google.cloud.bigtable.read_modify_write_rules import AppendValueRule + from google.cloud.bigtable.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + start_amount = 1 + increment_amount = 10 + await temp_rows.add_row( + row_key, value=start_amount, family=family, qualifier=qualifier + ) + rule = [ + IncrementRule(family, qualifier, increment_amount), + AppendValueRule(family, qualifier, "hello"), + AppendValueRule(family, qualifier, "world"), + AppendValueRule(family, qualifier, "!"), + ] + result = await table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert result[0].family == family + assert result[0].qualifier == qualifier + # result should be a bytes number string for the IncrementRules, followed by the AppendValueRule values + assert ( + result[0].value + == (start_amount + increment_amount).to_bytes(8, "big", signed=True) + + b"helloworld!" + ) + # ensure that reading from server gives same value + assert (await _retrieve_cell_value(table, row_key)) == result[0].value + + +@pytest.mark.parametrize( + "start_val,predicate_range,expected_result", + [ + (1, (0, 2), True), + (-1, (0, 2), False), + ], +) +@pytest.mark.asyncio +async def test_check_and_mutate( + client, table, temp_rows, start_val, predicate_range, expected_result +): + """ + test that check_and_mutate_row works applies the right mutations, and returns the right result + """ + from google.cloud.bigtable.mutations import SetCell + from google.cloud.bigtable.row_filters import ValueRangeFilter + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + + await temp_rows.add_row( + row_key, value=start_val, family=family, qualifier=qualifier + ) + + false_mutation_value = b"false-mutation-value" + false_mutation = SetCell( + family=TEST_FAMILY, qualifier=qualifier, new_value=false_mutation_value + ) + true_mutation_value = b"true-mutation-value" + true_mutation = SetCell( + family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value + ) + predicate = ValueRangeFilter(predicate_range[0], predicate_range[1]) + result = await table.check_and_mutate_row( + row_key, + predicate, + true_case_mutations=true_mutation, + false_case_mutations=false_mutation, + ) + assert result == expected_result + # ensure cell is updated + expected_value = true_mutation_value if expected_result else false_mutation_value + assert (await _retrieve_cell_value(table, row_key)) == expected_value + + @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) @pytest.mark.asyncio async def test_read_rows_stream(table, temp_rows): diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 14da80dae..7009069d1 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -27,6 +27,9 @@ from google.api_core import exceptions as core_exceptions from google.cloud.bigtable.exceptions import InvalidChunk +from google.cloud.bigtable.read_modify_write_rules import IncrementRule +from google.cloud.bigtable.read_modify_write_rules import AppendValueRule + # try/except added for compatibility with python < 3.8 try: from unittest import mock @@ -2023,3 +2026,353 @@ async def test_bulk_mutate_row_metadata(self, include_app_profile): assert "app_profile_id=profile" in goog_metadata else: assert "app_profile_id=" not in goog_metadata + + +class TestCheckAndMutateRow: + def _make_client(self, *args, **kwargs): + from google.cloud.bigtable.client import BigtableDataClient + + return BigtableDataClient(*args, **kwargs) + + @pytest.mark.parametrize("gapic_result", [True, False]) + @pytest.mark.asyncio + async def test_check_and_mutate(self, gapic_result): + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + app_profile = "app_profile_id" + async with self._make_client() as client: + async with client.get_table( + "instance", "table", app_profile_id=app_profile + ) as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=gapic_result + ) + row_key = b"row_key" + predicate = None + true_mutations = [mock.Mock()] + false_mutations = [mock.Mock(), mock.Mock()] + operation_timeout = 0.2 + found = await table.check_and_mutate_row( + row_key, + predicate, + true_case_mutations=true_mutations, + false_case_mutations=false_mutations, + operation_timeout=operation_timeout, + ) + assert found == gapic_result + kwargs = mock_gapic.call_args[1] + request = kwargs["request"] + assert request["table_name"] == table.table_name + assert request["row_key"] == row_key + assert request["predicate_filter"] == predicate + assert request["true_mutations"] == [ + m._to_dict() for m in true_mutations + ] + assert request["false_mutations"] == [ + m._to_dict() for m in false_mutations + ] + assert request["app_profile_id"] == app_profile + assert kwargs["timeout"] == operation_timeout + + @pytest.mark.asyncio + async def test_check_and_mutate_bad_timeout(self): + """Should raise error if operation_timeout < 0""" + async with self._make_client() as client: + async with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + await table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=[mock.Mock()], + false_case_mutations=[], + operation_timeout=-1, + ) + assert str(e.value) == "operation_timeout must be greater than 0" + + @pytest.mark.asyncio + async def test_check_and_mutate_no_mutations(self): + """Requests require either true_case_mutations or false_case_mutations""" + from google.api_core.exceptions import InvalidArgument + + async with self._make_client() as client: + async with client.get_table("instance", "table") as table: + with pytest.raises(InvalidArgument) as e: + await table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=None, + false_case_mutations=None, + ) + assert "No mutations provided" in str(e.value) + + @pytest.mark.asyncio + async def test_check_and_mutate_single_mutations(self): + """if single mutations are passed, they should be internally wrapped in a list""" + from google.cloud.bigtable.mutations import SetCell + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + async with self._make_client() as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + true_mutation = SetCell("family", b"qualifier", b"value") + false_mutation = SetCell("family", b"qualifier", b"value") + await table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=true_mutation, + false_case_mutations=false_mutation, + ) + kwargs = mock_gapic.call_args[1] + request = kwargs["request"] + assert request["true_mutations"] == [true_mutation._to_dict()] + assert request["false_mutations"] == [false_mutation._to_dict()] + + @pytest.mark.asyncio + async def test_check_and_mutate_predicate_object(self): + """predicate object should be converted to dict""" + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + mock_predicate = mock.Mock() + fake_dict = {"fake": "dict"} + mock_predicate.to_dict.return_value = fake_dict + async with self._make_client() as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + await table.check_and_mutate_row( + b"row_key", + mock_predicate, + false_case_mutations=[mock.Mock()], + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["request"]["predicate_filter"] == fake_dict + assert mock_predicate.to_dict.call_count == 1 + + @pytest.mark.asyncio + async def test_check_and_mutate_mutations_parsing(self): + """mutations objects should be converted to dicts""" + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + from google.cloud.bigtable.mutations import DeleteAllFromRow + + mutations = [mock.Mock() for _ in range(5)] + for idx, mutation in enumerate(mutations): + mutation._to_dict.return_value = {"fake": idx} + mutations.append(DeleteAllFromRow()) + async with self._make_client() as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + await table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=mutations[0:2], + false_case_mutations=mutations[2:], + ) + kwargs = mock_gapic.call_args[1]["request"] + assert kwargs["true_mutations"] == [{"fake": 0}, {"fake": 1}] + assert kwargs["false_mutations"] == [ + {"fake": 2}, + {"fake": 3}, + {"fake": 4}, + {"delete_from_row": {}}, + ] + assert all( + mutation._to_dict.call_count == 1 for mutation in mutations[:5] + ) + + @pytest.mark.parametrize("include_app_profile", [True, False]) + @pytest.mark.asyncio + async def test_check_and_mutate_metadata(self, include_app_profile): + """request should attach metadata headers""" + profile = "profile" if include_app_profile else None + async with self._make_client() as client: + async with client.get_table("i", "t", app_profile_id=profile) as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row", AsyncMock() + ) as mock_gapic: + await table.check_and_mutate_row(b"key", mock.Mock()) + kwargs = mock_gapic.call_args_list[0].kwargs + metadata = kwargs["metadata"] + goog_metadata = None + for key, value in metadata: + if key == "x-goog-request-params": + goog_metadata = value + assert goog_metadata is not None, "x-goog-request-params not found" + assert "table_name=" + table.table_name in goog_metadata + if include_app_profile: + assert "app_profile_id=profile" in goog_metadata + else: + assert "app_profile_id=" not in goog_metadata + + +class TestReadModifyWriteRow: + def _make_client(self, *args, **kwargs): + from google.cloud.bigtable.client import BigtableDataClient + + return BigtableDataClient(*args, **kwargs) + + @pytest.mark.parametrize( + "call_rules,expected_rules", + [ + ( + AppendValueRule("f", "c", b"1"), + [AppendValueRule("f", "c", b"1")._to_dict()], + ), + ( + [AppendValueRule("f", "c", b"1")], + [AppendValueRule("f", "c", b"1")._to_dict()], + ), + (IncrementRule("f", "c", 1), [IncrementRule("f", "c", 1)._to_dict()]), + ( + [AppendValueRule("f", "c", b"1"), IncrementRule("f", "c", 1)], + [ + AppendValueRule("f", "c", b"1")._to_dict(), + IncrementRule("f", "c", 1)._to_dict(), + ], + ), + ], + ) + @pytest.mark.asyncio + async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules): + """ + Test that the gapic call is called with given rules + """ + async with self._make_client() as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + await table.read_modify_write_row("key", call_rules) + assert mock_gapic.call_count == 1 + found_kwargs = mock_gapic.call_args_list[0][1] + assert found_kwargs["request"]["rules"] == expected_rules + + @pytest.mark.parametrize("rules", [[], None]) + @pytest.mark.asyncio + async def test_read_modify_write_no_rules(self, rules): + async with self._make_client() as client: + async with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + await table.read_modify_write_row("key", rules=rules) + assert e.value.args[0] == "rules must contain at least one item" + + @pytest.mark.asyncio + async def test_read_modify_write_call_defaults(self): + instance = "instance1" + table_id = "table1" + project = "project1" + row_key = "row_key1" + async with self._make_client(project=project) as client: + async with client.get_table(instance, table_id) as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + await table.read_modify_write_row(row_key, mock.Mock()) + assert mock_gapic.call_count == 1 + found_kwargs = mock_gapic.call_args_list[0][1] + request = found_kwargs["request"] + assert ( + request["table_name"] + == f"projects/{project}/instances/{instance}/tables/{table_id}" + ) + assert request["app_profile_id"] is None + assert request["row_key"] == row_key.encode() + assert found_kwargs["timeout"] > 1 + + @pytest.mark.asyncio + async def test_read_modify_write_call_overrides(self): + row_key = b"row_key1" + expected_timeout = 12345 + profile_id = "profile1" + async with self._make_client() as client: + async with client.get_table( + "instance", "table_id", app_profile_id=profile_id + ) as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + await table.read_modify_write_row( + row_key, + mock.Mock(), + operation_timeout=expected_timeout, + ) + assert mock_gapic.call_count == 1 + found_kwargs = mock_gapic.call_args_list[0][1] + request = found_kwargs["request"] + assert request["app_profile_id"] is profile_id + assert request["row_key"] == row_key + assert found_kwargs["timeout"] == expected_timeout + + @pytest.mark.asyncio + async def test_read_modify_write_string_key(self): + row_key = "string_row_key1" + async with self._make_client() as client: + async with client.get_table("instance", "table_id") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + await table.read_modify_write_row(row_key, mock.Mock()) + assert mock_gapic.call_count == 1 + found_kwargs = mock_gapic.call_args_list[0][1] + assert found_kwargs["request"]["row_key"] == row_key.encode() + + @pytest.mark.asyncio + async def test_read_modify_write_row_building(self): + """ + results from gapic call should be used to construct row + """ + from google.cloud.bigtable.row import Row + from google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse + from google.cloud.bigtable_v2.types import Row as RowPB + + mock_response = ReadModifyWriteRowResponse(row=RowPB()) + async with self._make_client() as client: + async with client.get_table("instance", "table_id") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + with mock.patch.object(Row, "_from_pb") as constructor_mock: + mock_gapic.return_value = mock_response + await table.read_modify_write_row("key", mock.Mock()) + assert constructor_mock.call_count == 1 + constructor_mock.assert_called_once_with(mock_response.row) + + @pytest.mark.parametrize("include_app_profile", [True, False]) + @pytest.mark.asyncio + async def test_read_modify_write_metadata(self, include_app_profile): + """request should attach metadata headers""" + profile = "profile" if include_app_profile else None + async with self._make_client() as client: + async with client.get_table("i", "t", app_profile_id=profile) as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row", AsyncMock() + ) as mock_gapic: + await table.read_modify_write_row("key", mock.Mock()) + kwargs = mock_gapic.call_args_list[0].kwargs + metadata = kwargs["metadata"] + goog_metadata = None + for key, value in metadata: + if key == "x-goog-request-params": + goog_metadata = value + assert goog_metadata is not None, "x-goog-request-params not found" + assert "table_name=" + table.table_name in goog_metadata + if include_app_profile: + assert "app_profile_id=profile" in goog_metadata + else: + assert "app_profile_id=" not in goog_metadata diff --git a/tests/unit/test_mutations.py b/tests/unit/test_mutations.py index 2a376609e..5730c53c9 100644 --- a/tests/unit/test_mutations.py +++ b/tests/unit/test_mutations.py @@ -170,6 +170,17 @@ def _target_class(self): def _make_one(self, *args, **kwargs): return self._target_class()(*args, **kwargs) + @pytest.mark.parametrize("input_val", [2**64, -(2**64)]) + def test_ctor_large_int(self, input_val): + with pytest.raises(ValueError) as e: + self._make_one(family="f", qualifier=b"b", new_value=input_val) + assert "int values must be between" in str(e.value) + + @pytest.mark.parametrize("input_val", ["", "a", "abc", "hello world!"]) + def test_ctor_str_value(self, input_val): + found = self._make_one(family="f", qualifier=b"b", new_value=input_val) + assert found.new_value == input_val.encode("utf-8") + def test_ctor(self): """Ensure constructor sets expected values""" expected_family = "test-family" @@ -194,6 +205,11 @@ def test_ctor_str_inputs(self): assert instance.qualifier == expected_qualifier assert instance.new_value == expected_value + @pytest.mark.parametrize("input_val", [-20, -1, 0, 1, 100, int(2**60)]) + def test_ctor_int_value(self, input_val): + found = self._make_one(family="f", qualifier=b"b", new_value=input_val) + assert found.new_value == input_val.to_bytes(8, "big", signed=True) + @pytest.mark.parametrize( "int_value,expected_bytes", [ @@ -206,7 +222,7 @@ def test_ctor_str_inputs(self): (100, b"\x00\x00\x00\x00\x00\x00\x00d"), ], ) - def test_ctor_int_value(self, int_value, expected_bytes): + def test_ctor_int_value_bytes(self, int_value, expected_bytes): """Test with int value""" expected_family = "test-family" expected_qualifier = b"test-qualifier" diff --git a/tests/unit/test_read_modify_write_rules.py b/tests/unit/test_read_modify_write_rules.py new file mode 100644 index 000000000..02240df6d --- /dev/null +++ b/tests/unit/test_read_modify_write_rules.py @@ -0,0 +1,142 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock +except ImportError: # pragma: NO COVER + import mock # type: ignore + + +class TestBaseReadModifyWriteRule: + def _target_class(self): + from google.cloud.bigtable.read_modify_write_rules import ReadModifyWriteRule + + return ReadModifyWriteRule + + def test_abstract(self): + """should not be able to instantiate""" + with pytest.raises(TypeError): + self._target_class()(family="foo", qualifier=b"bar") + + def test__to_dict(self): + with pytest.raises(NotImplementedError): + self._target_class()._to_dict(mock.Mock()) + + +class TestIncrementRule: + def _target_class(self): + from google.cloud.bigtable.read_modify_write_rules import IncrementRule + + return IncrementRule + + @pytest.mark.parametrize( + "args,expected", + [ + (("fam", b"qual", 1), ("fam", b"qual", 1)), + (("fam", b"qual", -12), ("fam", b"qual", -12)), + (("fam", "qual", 1), ("fam", b"qual", 1)), + (("fam", "qual", 0), ("fam", b"qual", 0)), + (("", "", 0), ("", b"", 0)), + (("f", b"q"), ("f", b"q", 1)), + ], + ) + def test_ctor(self, args, expected): + instance = self._target_class()(*args) + assert instance.family == expected[0] + assert instance.qualifier == expected[1] + assert instance.increment_amount == expected[2] + + @pytest.mark.parametrize("input_amount", [1.1, None, "1", object(), "", b"", b"1"]) + def test_ctor_bad_input(self, input_amount): + with pytest.raises(TypeError) as e: + self._target_class()("fam", b"qual", input_amount) + assert "increment_amount must be an integer" in str(e.value) + + @pytest.mark.parametrize( + "large_value", [2**64, 2**64 + 1, -(2**64), -(2**64) - 1] + ) + def test_ctor_large_values(self, large_value): + with pytest.raises(ValueError) as e: + self._target_class()("fam", b"qual", large_value) + assert "too large" in str(e.value) + + @pytest.mark.parametrize( + "args,expected", + [ + (("fam", b"qual", 1), ("fam", b"qual", 1)), + (("fam", b"qual", -12), ("fam", b"qual", -12)), + (("fam", "qual", 1), ("fam", b"qual", 1)), + (("fam", "qual", 0), ("fam", b"qual", 0)), + (("", "", 0), ("", b"", 0)), + (("f", b"q"), ("f", b"q", 1)), + ], + ) + def test__to_dict(self, args, expected): + instance = self._target_class()(*args) + expected = { + "family_name": expected[0], + "column_qualifier": expected[1], + "increment_amount": expected[2], + } + assert instance._to_dict() == expected + + +class TestAppendValueRule: + def _target_class(self): + from google.cloud.bigtable.read_modify_write_rules import AppendValueRule + + return AppendValueRule + + @pytest.mark.parametrize( + "args,expected", + [ + (("fam", b"qual", b"val"), ("fam", b"qual", b"val")), + (("fam", "qual", b"val"), ("fam", b"qual", b"val")), + (("", "", b""), ("", b"", b"")), + (("f", "q", "str_val"), ("f", b"q", b"str_val")), + (("f", "q", ""), ("f", b"q", b"")), + ], + ) + def test_ctor(self, args, expected): + instance = self._target_class()(*args) + assert instance.family == expected[0] + assert instance.qualifier == expected[1] + assert instance.append_value == expected[2] + + @pytest.mark.parametrize("input_val", [5, 1.1, None, object()]) + def test_ctor_bad_input(self, input_val): + with pytest.raises(TypeError) as e: + self._target_class()("fam", b"qual", input_val) + assert "append_value must be bytes or str" in str(e.value) + + @pytest.mark.parametrize( + "args,expected", + [ + (("fam", b"qual", b"val"), ("fam", b"qual", b"val")), + (("fam", "qual", b"val"), ("fam", b"qual", b"val")), + (("", "", b""), ("", b"", b"")), + ], + ) + def test__to_dict(self, args, expected): + instance = self._target_class()(*args) + expected = { + "family_name": expected[0], + "column_qualifier": expected[1], + "append_value": expected[2], + } + assert instance._to_dict() == expected diff --git a/tests/unit/test_row.py b/tests/unit/test_row.py index 1af09aad9..0413b2889 100644 --- a/tests/unit/test_row.py +++ b/tests/unit/test_row.py @@ -55,6 +55,50 @@ def test_ctor(self): self.assertEqual(list(row_response), cells) self.assertEqual(row_response.row_key, TEST_ROW_KEY) + def test__from_pb(self): + """ + Construct from protobuf. + """ + from google.cloud.bigtable_v2.types import Row as RowPB + from google.cloud.bigtable_v2.types import Family as FamilyPB + from google.cloud.bigtable_v2.types import Column as ColumnPB + from google.cloud.bigtable_v2.types import Cell as CellPB + + row_key = b"row_key" + cells = [ + CellPB( + value=str(i).encode(), + timestamp_micros=TEST_TIMESTAMP, + labels=TEST_LABELS, + ) + for i in range(2) + ] + column = ColumnPB(qualifier=TEST_QUALIFIER, cells=cells) + families_pb = [FamilyPB(name=TEST_FAMILY_ID, columns=[column])] + row_pb = RowPB(key=row_key, families=families_pb) + output = self._get_target_class()._from_pb(row_pb) + self.assertEqual(output.row_key, row_key) + self.assertEqual(len(output), 2) + self.assertEqual(output[0].value, b"0") + self.assertEqual(output[1].value, b"1") + self.assertEqual(output[0].timestamp_micros, TEST_TIMESTAMP) + self.assertEqual(output[0].labels, TEST_LABELS) + assert output[0].row_key == row_key + assert output[0].family == TEST_FAMILY_ID + assert output[0].qualifier == TEST_QUALIFIER + + def test__from_pb_sparse(self): + """ + Construct from minimal protobuf. + """ + from google.cloud.bigtable_v2.types import Row as RowPB + + row_key = b"row_key" + row_pb = RowPB(key=row_key) + output = self._get_target_class()._from_pb(row_pb) + self.assertEqual(output.row_key, row_key) + self.assertEqual(len(output), 0) + def test_get_cells(self): cell_list = [] for family_id in ["1", "2"]: