Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding AST support to python clients #94

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions riva/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
print_streaming,
sleep_audio_length,
add_endpoint_parameters_to_config,
add_ast_parameters_to_config,
)
from riva.client.auth import Auth
from riva.client.nlp import (
Expand Down
18 changes: 18 additions & 0 deletions riva/client/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,24 @@ def add_asr_config_argparse_parameters(
type=float,
help="Threshold value for likelihood of blanks before detecting end of utterance",
)
parser.add_argument(
"--source-language",
default="",
type=str,
help="Language of the audio file",
)
parser.add_argument(
"--task",
default="transcribe",
type=str,
help="Task for the model (transcribe/translate)",
)
parser.add_argument(
"--target-language",
default="",
type=str,
help="Target language for translation",
)
rmittal-github marked this conversation as resolved.
Show resolved Hide resolved
return parser


Expand Down
15 changes: 15 additions & 0 deletions riva/client/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def add_speaker_diarization_to_config(
diarization_config = rasr.SpeakerDiarizationConfig(enable_speaker_diarization=True)
inner_config.diarization_config.CopyFrom(diarization_config)


def add_endpoint_parameters_to_config(
config: Union[rasr.RecognitionConfig, rasr.EndpointingConfig],
start_history: int,
Expand Down Expand Up @@ -152,6 +153,20 @@ def add_endpoint_parameters_to_config(
inner_config.endpointing_config.CopyFrom(endpointing_config)


def add_ast_parameters_to_config(
config: Union[rasr.RecognitionConfig, rasr.EndpointingConfig],
rmittal-github marked this conversation as resolved.
Show resolved Hide resolved
source_language: str,
target_language: str,
task: str,
) -> None:
if not source_language:
return
inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config
inner_config.custom_configuration["source_language"] = source_language
inner_config.custom_configuration["target_language"] = target_language
inner_config.custom_configuration["task"] = task


PRINT_STREAMING_ADDITIONAL_INFO_MODES = ['no', 'time', 'confidence']


Expand Down
18 changes: 12 additions & 6 deletions scripts/asr/transcribe_file_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,20 @@ def main() -> None:
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization)
riva.client.add_endpoint_parameters_to_config(
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
rmittal-github marked this conversation as resolved.
Show resolved Hide resolved
args.stop_threshold,
args.stop_threshold_eou
)
)
riva.client.add_ast_parameters_to_config(
config,
args.source_language,
args.target_language,
args.task
)
with args.input_file.open('rb') as fh:
data = fh.read()
try:
Expand Down
2 changes: 2 additions & 0 deletions scripts/tts/talk.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ def main() -> None:
sound_stream(resp.audio)
if out_f is not None:
out_f.writeframesraw(resp.audio)
except Exception as e:
print(e.details())
finally:
if out_f is not None:
out_f.close()
Expand Down