diff --git a/speech/google/cloud/speech/_gax.py b/speech/google/cloud/speech/_gax.py index 9437a806b1ce..12b05cc8a9ff 100644 --- a/speech/google/cloud/speech/_gax.py +++ b/speech/google/cloud/speech/_gax.py @@ -14,7 +14,13 @@ """GAX/GAPIC module for managing Speech API requests.""" +from google.longrunning import operations_grpc + from google.cloud.gapic.speech.v1beta1.speech_api import SpeechApi +from google.cloud.grpc.speech.v1beta1.cloud_speech_pb2 import ( + AsyncRecognizeMetadata) +from google.cloud.grpc.speech.v1beta1.cloud_speech_pb2 import ( + AsyncRecognizeResponse) from google.cloud.grpc.speech.v1beta1.cloud_speech_pb2 import SpeechContext from google.cloud.grpc.speech.v1beta1.cloud_speech_pb2 import RecognitionConfig from google.cloud.grpc.speech.v1beta1.cloud_speech_pb2 import RecognitionAudio @@ -23,13 +29,23 @@ from google.cloud.grpc.speech.v1beta1.cloud_speech_pb2 import ( StreamingRecognizeRequest) - from google.cloud.speech.alternative import Alternative +from google.cloud._helpers import make_secure_stub +from google.cloud.connection import DEFAULT_USER_AGENT +from google.cloud.speech.operation import Operation +from google.cloud.operation import register_type + + +OPERATIONS_API_HOST = 'speech.googleapis.com' + +register_type(AsyncRecognizeMetadata) +register_type(AsyncRecognizeResponse) class GAPICSpeechAPI(object): """Manage calls through GAPIC wrappers to the Speech API.""" - def __init__(self): + def __init__(self, client=None): + self._client = client self._gapic_api = SpeechApi() def async_recognize(self, sample, language_code=None, @@ -72,9 +88,26 @@ def async_recognize(self, sample, language_code=None, and phrases. This can also be used to add new words to the vocabulary of the recognizer. - :raises NotImplementedError: Always. + :rtype: :class:`~google.cloud.operation.Opeartion` + :returns: Instance of ``Operation`` to poll for results. """ - raise NotImplementedError + config = RecognitionConfig( + encoding=sample.encoding, sample_rate=sample.sample_rate, + language_code=language_code, max_alternatives=max_alternatives, + profanity_filter=profanity_filter, + speech_context=SpeechContext(phrases=speech_context)) + + audio = RecognitionAudio(content=sample.content, + uri=sample.source_uri) + api = self._gapic_api + response = api.async_recognize(config=config, audio=audio) + + self._client._operations_stub = make_secure_stub( + self._client.connection.credentials, + DEFAULT_USER_AGENT, + operations_grpc.OperationsStub, + OPERATIONS_API_HOST) + return Operation.from_pb(response, self._client) def sync_recognize(self, sample, language_code=None, max_alternatives=None, profanity_filter=None, speech_context=None): diff --git a/speech/google/cloud/speech/client.py b/speech/google/cloud/speech/client.py index 7959a1c05e4c..85d49574e984 100644 --- a/speech/google/cloud/speech/client.py +++ b/speech/google/cloud/speech/client.py @@ -162,7 +162,7 @@ def speech_api(self): """Helper for speech-related API calls.""" if self._speech_api is None: if self._use_gax: - self._speech_api = GAPICSpeechAPI() + self._speech_api = GAPICSpeechAPI(self) else: self._speech_api = _JSONSpeechAPI(self) return self._speech_api diff --git a/speech/setup.py b/speech/setup.py index 536ed0c53782..cd56983412da 100644 --- a/speech/setup.py +++ b/speech/setup.py @@ -52,7 +52,6 @@ REQUIREMENTS = [ 'google-cloud-core >= 0.20.0', 'gapic-google-cloud-speech-v1beta1 >= 0.11.1, < 0.12.0', - 'grpc-google-cloud-speech-v1beta1 >= 0.11.1, < 0.12.0', ] setup( diff --git a/speech/unit_tests/test_client.py b/speech/unit_tests/test_client.py index f26313940fd4..3f03aeca836a 100644 --- a/speech/unit_tests/test_client.py +++ b/speech/unit_tests/test_client.py @@ -199,11 +199,10 @@ def test_sync_recognize_with_empty_results_gax(self): credentials = _Credentials() client = self._makeOne(credentials=credentials, use_gax=True) client.connection = _Connection() + _MockGAPICSpeechAPI._results = [] with self.assertRaises(ValueError): - mock_no_results = _MockGAPICSpeechAPI - mock_no_results._results = [] - with _Monkey(MUT, SpeechApi=mock_no_results): + with _Monkey(MUT, SpeechApi=_MockGAPICSpeechAPI): sample = Sample(source_uri=self.AUDIO_SOURCE_URI, encoding=speech.Encoding.FLAC, sample_rate=self.SAMPLE_RATE) @@ -218,9 +217,7 @@ def test_sync_recognize_with_gax(self): client = self._makeOne(credentials=creds, use_gax=True) client.connection = _Connection() client._speech_api = None - - mock_no_results = _MockGAPICSpeechAPI - mock_no_results._results = [_MockGAPICSyncResult()] + _MockGAPICSpeechAPI._results = [_MockGAPICSyncResult] with _Monkey(MUT, SpeechApi=_MockGAPICSpeechAPI): sample = client.sample(source_uri=self.AUDIO_SOURCE_URI, @@ -277,13 +274,16 @@ def test_async_recognize_with_gax(self): credentials = _Credentials() client = self._makeOne(credentials=credentials) client.connection = _Connection() + client.connection.credentials = credentials sample = client.sample(source_uri=self.AUDIO_SOURCE_URI, encoding=speech.Encoding.LINEAR16, sample_rate=self.SAMPLE_RATE) with _Monkey(MUT, SpeechApi=_MockGAPICSpeechAPI): - with self.assertRaises(NotImplementedError): - client.async_recognize(sample) + operation = client.async_recognize(sample) + + self.assertFalse(operation.complete) + self.assertIsNone(operation.response) def test_speech_api_with_gax(self): from google.cloud.speech import _gax as MUT @@ -321,6 +321,10 @@ class _MockGAPICAlternative(object): confidence = 0.95234356 +class _MockGAPICMetadata(object): + type_url = None + + class _MockGAPICSyncResult(object): alternatives = [_MockGAPICAlternative()] @@ -328,21 +332,29 @@ class _MockGAPICSyncResult(object): class _MockGAPICSpeechResponse(object): error = None endpointer_type = None + name = None + metadata = _MockGAPICMetadata() results = [] result_index = 0 class _MockGAPICSpeechAPI(object): _requests = None - _response = _MockGAPICSpeechResponse() + _response = _MockGAPICSpeechResponse _results = [_MockGAPICSyncResult()] + def async_recognize(self, config, audio): + from google.longrunning.operations_pb2 import Operation + self.config = config + self.audio = audio + operation = Operation() + return operation + def sync_recognize(self, config, audio): self.config = config self.audio = audio - mock_response = self._response - mock_response.results = self._results - return mock_response + self._response.results = self._results + return self._response class _Credentials(object): diff --git a/system_tests/speech.py b/system_tests/speech.py index d6f63f8dc3d1..a3fe45bd1879 100644 --- a/system_tests/speech.py +++ b/system_tests/speech.py @@ -119,7 +119,6 @@ def _make_async_request(self, content=None, source_uri=None, def _check_results(self, results, num_results=1): self.assertEqual(len(results), num_results) - top_result = results[0] self.assertIsInstance(top_result, Alternative) self.assertEqual(top_result.transcript, @@ -153,8 +152,6 @@ def test_sync_recognize_gcs_file(self): self._check_results(result) def test_async_recognize_local_file(self): - if Config.USE_GAX: - self.skipTest('async_recognize gRPC not yet implemented.') with open(AUDIO_FILE, 'rb') as file_obj: content = file_obj.read() @@ -165,8 +162,6 @@ def test_async_recognize_local_file(self): self._check_results(operation.results, 2) def test_async_recognize_gcs_file(self): - if Config.USE_GAX: - self.skipTest('async_recognize gRPC not yet implemented.') bucket_name = Config.TEST_BUCKET.name blob_name = 'hello.wav' blob = Config.TEST_BUCKET.blob(blob_name)