Skip to content

Commit

Permalink
Support custom_configuration param in ASR clients (#94)
Browse files Browse the repository at this point in the history
* Passing AST param through custom_configuration

* Added exception handling for TTS talk.py

* Exposing custom-configurtion to cli

* Updating function name to add_custom_configuration_to_config

* Updating help message

---------

Co-authored-by: mohnishparmar <[email protected]>
  • Loading branch information
sarane22 and mohnishparmar authored Sep 3, 2024
1 parent c789e98 commit b94b3a9
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 2 deletions.
1 change: 1 addition & 0 deletions riva/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
print_streaming,
sleep_audio_length,
add_endpoint_parameters_to_config,
add_custom_configuration_to_config,
)
from riva.client.auth import Auth
from riva.client.nlp import (
Expand Down
6 changes: 6 additions & 0 deletions riva/client/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ def add_asr_config_argparse_parameters(
type=float,
help="Threshold value for likelihood of blanks before detecting end of utterance",
)
parser.add_argument(
"--custom-configuration",
default="",
type=str,
help="Custom configurations to be sent to the server as key value pairs <key:value,key:value,...>",
)
return parser


Expand Down
19 changes: 18 additions & 1 deletion riva/client/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,9 @@ def add_speaker_diarization_to_config(
diarization_config = rasr.SpeakerDiarizationConfig(enable_speaker_diarization=True)
inner_config.diarization_config.CopyFrom(diarization_config)


def add_endpoint_parameters_to_config(
config: Union[rasr.RecognitionConfig, rasr.EndpointingConfig],
config: Union[rasr.StreamingRecognitionConfig, rasr.RecognitionConfig],
start_history: int,
start_threshold: float,
stop_history: int,
Expand Down Expand Up @@ -152,6 +153,22 @@ def add_endpoint_parameters_to_config(
inner_config.endpointing_config.CopyFrom(endpointing_config)


def add_custom_configuration_to_config(
config: Union[rasr.StreamingRecognitionConfig, rasr.RecognitionConfig],
custom_configuration: str,
) -> None:
custom_configuration = custom_configuration.strip().replace(" ", "")
if not custom_configuration:
return
inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config
for pair in custom_configuration.split(","):
key_value = pair.split(":")
if len(key_value) == 2:
inner_config.custom_configuration[key_value[0]] = key_value[1]
else:
raise ValueError(f"Invalid key:value pair {key_value}")


PRINT_STREAMING_ADDITIONAL_INFO_MODES = ['no', 'time', 'confidence']


Expand Down
4 changes: 4 additions & 0 deletions scripts/asr/riva_streaming_asr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def streaming_transcription_worker(
args.stop_threshold,
args.stop_threshold_eou
)
riva.client.add_custom_configuration_to_config(
config,
args.custom_configuration
)
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
for _ in range(args.num_iterations):
with riva.client.AudioChunkFileIterator(
Expand Down
4 changes: 4 additions & 0 deletions scripts/asr/transcribe_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ def main() -> None:
args.stop_threshold,
args.stop_threshold_eou
)
riva.client.add_custom_configuration_to_config(
config,
args.custom_configuration
)
sound_callback = None
try:
if args.play_audio or args.output_device is not None:
Expand Down
6 changes: 5 additions & 1 deletion scripts/asr/transcribe_file_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ def main() -> None:
args.stop_history_eou,
args.stop_threshold,
args.stop_threshold_eou
)
)
riva.client.add_custom_configuration_to_config(
config,
args.custom_configuration
)
with args.input_file.open('rb') as fh:
data = fh.read()
try:
Expand Down
4 changes: 4 additions & 0 deletions scripts/asr/transcribe_mic.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def main() -> None:
args.stop_threshold,
args.stop_threshold_eou
)
riva.client.add_custom_configuration_to_config(
config,
args.custom_configuration
)
with riva.client.audio_io.MicrophoneStream(
args.sample_rate_hz,
args.file_streaming_chunk,
Expand Down
2 changes: 2 additions & 0 deletions scripts/tts/talk.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ def main() -> None:
sound_stream(resp.audio)
if out_f is not None:
out_f.writeframesraw(resp.audio)
except Exception as e:
print(e.details())
finally:
if out_f is not None:
out_f.close()
Expand Down

0 comments on commit b94b3a9

Please sign in to comment.