diff --git a/bigquery_storage/google/cloud/bigquery_storage_v1beta1/reader.py b/bigquery_storage/google/cloud/bigquery_storage_v1beta1/reader.py
index b2c4a704afe3..7e07392d0aea 100644
--- a/bigquery_storage/google/cloud/bigquery_storage_v1beta1/reader.py
+++ b/bigquery_storage/google/cloud/bigquery_storage_v1beta1/reader.py
@@ -43,6 +43,16 @@
 
 _STREAM_RESUMPTION_EXCEPTIONS = (google.api_core.exceptions.ServiceUnavailable,)
 
+# The Google API endpoint can unexpectedly close long-running HTTP/2 streams.
+# Unfortunately, this condition is surfaced to the caller as an internal error
+# by gRPC. We don't want to resume on all internal errors, so instead we look
+# for error messages that we know are caused by problems that are safe to
+# reconnect after.
+_STREAM_RESUMPTION_INTERNAL_ERROR_MESSAGES = (
+    # See: https://github.com/googleapis/google-cloud-python/pull/9994
+    "RST_STREAM",
+)
+
 _FASTAVRO_REQUIRED = (
     "fastavro is required to parse ReadRowResponse messages with Avro bytes."
 )
@@ -131,6 +141,13 @@ def __iter__(self):
                     yield message
 
                 return  # Made it through the whole stream.
+            except google.api_core.exceptions.InternalServerError as exc:
+                resumable_error = any(
+                    resumable_message in exc.message
+                    for resumable_message in _STREAM_RESUMPTION_INTERNAL_ERROR_MESSAGES
+                )
+                if not resumable_error:
+                    raise
             except _STREAM_RESUMPTION_EXCEPTIONS:
                 # Transient error, so reconnect to the stream.
                 pass
diff --git a/bigquery_storage/tests/unit/test_reader.py b/bigquery_storage/tests/unit/test_reader.py
index 09d2a6b69503..3d5127522eea 100644
--- a/bigquery_storage/tests/unit/test_reader.py
+++ b/bigquery_storage/tests/unit/test_reader.py
@@ -154,11 +154,13 @@ def _bq_to_arrow_batch_objects(bq_blocks, arrow_schema):
             arrays.append(
                 pyarrow.array(
                     (row[name] for row in block),
-                    type=arrow_schema.field_by_name(name).type,
+                    type=arrow_schema.field(name).type,
                     size=len(block),
                 )
             )
-        arrow_batches.append(pyarrow.RecordBatch.from_arrays(arrays, arrow_schema))
+        arrow_batches.append(
+            pyarrow.RecordBatch.from_arrays(arrays, schema=arrow_schema)
+        )
     return arrow_batches
 
 
@@ -173,6 +175,22 @@ def _bq_to_arrow_batches(bq_blocks, arrow_schema):
     return arrow_batches
 
 
+def _pages_w_nonresumable_internal_error(avro_blocks):
+    for block in avro_blocks:
+        yield block
+    raise google.api_core.exceptions.InternalServerError(
+        "INTERNAL: Got a nonresumable error."
+    )
+
+
+def _pages_w_resumable_internal_error(avro_blocks):
+    for block in avro_blocks:
+        yield block
+    raise google.api_core.exceptions.InternalServerError(
+        "INTERNAL: Received RST_STREAM with error code 2."
+    )
+
+
 def _pages_w_unavailable(pages):
     for page in pages:
         yield page
@@ -363,6 +381,29 @@ def test_rows_w_timeout(class_under_test, mock_client):
     mock_client.read_rows.assert_not_called()
 
 
+def test_rows_w_nonresumable_internal_error(class_under_test, mock_client):
+    bq_columns = [{"name": "int_col", "type": "int64"}]
+    avro_schema = _bq_to_avro_schema(bq_columns)
+    read_session = _generate_avro_read_session(avro_schema)
+    bq_blocks = [[{"int_col": 1024}, {"int_col": 512}], [{"int_col": 256}]]
+    avro_blocks = _pages_w_nonresumable_internal_error(
+        _bq_to_avro_blocks(bq_blocks, avro_schema)
+    )
+
+    stream_position = bigquery_storage_v1beta1.types.StreamPosition(
+        stream={"name": "test"}
+    )
+
+    reader = class_under_test(avro_blocks, mock_client, stream_position, {})
+
+    with pytest.raises(
+        google.api_core.exceptions.InternalServerError, match="nonresumable error"
+    ):
+        list(reader.rows(read_session))
+
+    mock_client.read_rows.assert_not_called()
+
+
 def test_rows_w_reconnect(class_under_test, mock_client):
     bq_columns = [{"name": "int_col", "type": "int64"}]
     avro_schema = _bq_to_avro_schema(bq_columns)
@@ -372,13 +413,18 @@
         [{"int_col": 345}, {"int_col": 456}],
     ]
     avro_blocks_1 = _pages_w_unavailable(_bq_to_avro_blocks(bq_blocks_1, avro_schema))
-    bq_blocks_2 = [[{"int_col": 567}, {"int_col": 789}], [{"int_col": 890}]]
+    bq_blocks_2 = [[{"int_col": 1024}, {"int_col": 512}], [{"int_col": 256}]]
     avro_blocks_2 = _bq_to_avro_blocks(bq_blocks_2, avro_schema)
+    avro_blocks_2 = _pages_w_resumable_internal_error(
+        _bq_to_avro_blocks(bq_blocks_2, avro_schema)
+    )
+    bq_blocks_3 = [[{"int_col": 567}, {"int_col": 789}], [{"int_col": 890}]]
+    avro_blocks_3 = _bq_to_avro_blocks(bq_blocks_3, avro_schema)
 
-    for block in avro_blocks_2:
+    for block in avro_blocks_3:
         block.status.estimated_row_count = 7
 
-    mock_client.read_rows.return_value = avro_blocks_2
+    mock_client.read_rows.side_effect = (avro_blocks_2, avro_blocks_3)
     stream_position = bigquery_storage_v1beta1.types.StreamPosition(
         stream={"name": "test"}
     )
@@ -395,17 +441,24 @@
         itertools.chain(
             itertools.chain.from_iterable(bq_blocks_1),
             itertools.chain.from_iterable(bq_blocks_2),
+            itertools.chain.from_iterable(bq_blocks_3),
         )
     )
 
     assert tuple(got) == expected
     assert got.total_rows == 7
-    mock_client.read_rows.assert_called_once_with(
+    mock_client.read_rows.assert_any_call(
         bigquery_storage_v1beta1.types.StreamPosition(
             stream={"name": "test"}, offset=4
         ),
         metadata={"test-key": "test-value"},
    )
+    mock_client.read_rows.assert_called_with(
+        bigquery_storage_v1beta1.types.StreamPosition(
+            stream={"name": "test"}, offset=7
+        ),
+        metadata={"test-key": "test-value"},
+    )
 
 
 def test_rows_w_reconnect_by_page(class_under_test, mock_client):
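
For context, a minimal standalone sketch of the message-matching check that the reader.py change introduces, assuming only google-api-core is installed. The _is_resumable helper and the RESUMABLE_MESSAGES name are illustrative and not part of the library; the error strings mirror the ones used in the tests above.

import google.api_core.exceptions

# Substrings of gRPC internal-error messages that indicate a dropped stream
# and are therefore safe to reconnect after (mirrors
# _STREAM_RESUMPTION_INTERNAL_ERROR_MESSAGES in reader.py).
RESUMABLE_MESSAGES = ("RST_STREAM",)


def _is_resumable(exc):
    """Return True if this InternalServerError looks like a closed stream."""
    return any(marker in exc.message for marker in RESUMABLE_MESSAGES)


# A stream reset surfaced by gRPC is considered resumable, so the reader
# reconnects at the last committed offset instead of raising.
assert _is_resumable(
    google.api_core.exceptions.InternalServerError(
        "INTERNAL: Received RST_STREAM with error code 2."
    )
)

# Any other internal error is re-raised to the caller, as exercised by
# test_rows_w_nonresumable_internal_error.
assert not _is_resumable(
    google.api_core.exceptions.InternalServerError(
        "INTERNAL: Got a nonresumable error."
    )
)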