From b94b3a92af1852a8ba72d4f095269cac356cfa29 Mon Sep 17 00:00:00 2001 From: sarane22 <118975230+sarane22@users.noreply.github.com> Date: Tue, 3 Sep 2024 15:59:16 +0530 Subject: [PATCH] Support custom_configuration param in ASR clients (#94) * Passing AST param through custom_configuration * Added exception handling for TTS talk.py * Exposing custom-configuration to cli * Updating function name to add_custom_configuration_to_config * Updating help message --------- Co-authored-by: mohnishparmar <109233781+mohnishparmar@users.noreply.github.com> --- riva/client/__init__.py | 1 + riva/client/argparse_utils.py | 6 ++++++ riva/client/asr.py | 19 ++++++++++++++++++- scripts/asr/riva_streaming_asr_client.py | 4 ++++ scripts/asr/transcribe_file.py | 4 ++++ scripts/asr/transcribe_file_offline.py | 6 +++++- scripts/asr/transcribe_mic.py | 4 ++++ scripts/tts/talk.py | 2 ++ 8 files changed, 44 insertions(+), 2 deletions(-) diff --git a/riva/client/__init__.py b/riva/client/__init__.py index ef28d13..7656bd6 100644 --- a/riva/client/__init__.py +++ b/riva/client/__init__.py @@ -12,6 +12,7 @@ print_streaming, sleep_audio_length, add_endpoint_parameters_to_config, + add_custom_configuration_to_config, ) from riva.client.auth import Auth from riva.client.nlp import ( diff --git a/riva/client/argparse_utils.py b/riva/client/argparse_utils.py index 7f46713..abd1c96 100644 --- a/riva/client/argparse_utils.py +++ b/riva/client/argparse_utils.py @@ -85,6 +85,12 @@ def add_asr_config_argparse_parameters( type=float, help="Threshold value for likelihood of blanks before detecting end of utterance", ) + parser.add_argument( + "--custom-configuration", + default="", + type=str, + help="Custom configurations to be sent to the server as key value pairs ", + ) return parser diff --git a/riva/client/asr.py b/riva/client/asr.py index 5095fd6..59a7f59 100644 --- a/riva/client/asr.py +++ b/riva/client/asr.py @@ -123,8 +123,9 @@ def add_speaker_diarization_to_config( diarization_config = 
def add_custom_configuration_to_config(
    config: Union[rasr.StreamingRecognitionConfig, rasr.RecognitionConfig],
    custom_configuration: str,
) -> None:
    """Copy ``key:value`` pairs from a string into the config's ``custom_configuration`` map.

    ``custom_configuration`` is a comma-separated list of ``key:value`` pairs,
    e.g. ``"key1:value1,key2:value2"``. Every space character is removed before
    parsing, so ``"a : b, c : d"`` is read as ``"a:b,c:d"``.

    The pairs are stored on the ``RecognitionConfig``: ``config`` itself when it
    is a ``RecognitionConfig``, otherwise ``config.config``.

    Args:
        config: The recognition config to modify in place.
        custom_configuration: The pairs to add. An empty, whitespace-only or
            ``None`` value leaves ``config`` untouched.

    Raises:
        ValueError: If any comma-separated item does not consist of exactly one
            key and one value separated by a single ``:``, or if its key is
            empty. Nothing is written to ``config`` in that case.
    """
    # Guard first so that ``None`` (e.g. from a library caller with no CLI
    # default) is treated like the empty string instead of crashing on .strip().
    if not custom_configuration:
        return
    custom_configuration = custom_configuration.strip().replace(" ", "")
    if not custom_configuration:
        return
    inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config
    # Validate every pair before touching the config, so a malformed item
    # cannot leave the config partially updated.
    parsed_pairs = []
    for pair in custom_configuration.split(","):
        key_value = pair.split(":")
        # An empty key cannot name a server-side setting; reject it together
        # with items that have too few or too many ``:`` separators.
        if len(key_value) != 2 or not key_value[0]:
            raise ValueError(f"Invalid key:value pair {key_value}")
        parsed_pairs.append(key_value)
    for key, value in parsed_pairs:
        inner_config.custom_configuration[key] = value
args.stop_threshold, args.stop_threshold_eou ) + riva.client.add_custom_configuration_to_config( + config, + args.custom_configuration + ) sound_callback = None try: if args.play_audio or args.output_device is not None: diff --git a/scripts/asr/transcribe_file_offline.py b/scripts/asr/transcribe_file_offline.py index 5dcda00..22f6c13 100644 --- a/scripts/asr/transcribe_file_offline.py +++ b/scripts/asr/transcribe_file_offline.py @@ -46,7 +46,11 @@ def main() -> None: args.stop_history_eou, args.stop_threshold, args.stop_threshold_eou - ) + ) + riva.client.add_custom_configuration_to_config( + config, + args.custom_configuration + ) with args.input_file.open('rb') as fh: data = fh.read() try: diff --git a/scripts/asr/transcribe_mic.py b/scripts/asr/transcribe_mic.py index 5e21ebe..517f2c9 100644 --- a/scripts/asr/transcribe_mic.py +++ b/scripts/asr/transcribe_mic.py @@ -67,6 +67,10 @@ def main() -> None: args.stop_threshold, args.stop_threshold_eou ) + riva.client.add_custom_configuration_to_config( + config, + args.custom_configuration + ) with riva.client.audio_io.MicrophoneStream( args.sample_rate_hz, args.file_streaming_chunk, diff --git a/scripts/tts/talk.py b/scripts/tts/talk.py index 909467b..80bc78b 100644 --- a/scripts/tts/talk.py +++ b/scripts/tts/talk.py @@ -157,6 +157,8 @@ def main() -> None: sound_stream(resp.audio) if out_f is not None: out_f.writeframesraw(resp.audio) + except Exception as e: + print(e.details()) finally: if out_f is not None: out_f.close()