Skip to content

Commit

Permalink
Add ASR endpointing stop_threshold_eou parameter (#83)
Browse files Browse the repository at this point in the history
* Exposing the 'stop_historu_eou_th' parameter

* updating submodule

* Updating param name

* Updating help for VAD param

* Adding check for stop_threshold_eou

* Updating proto branch

* updating the submodule
  • Loading branch information
sarane22 authored Jun 27, 2024
1 parent 80e5f04 commit 0a75015
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[submodule "common"]
path = common
url = https://github.com/nvidia-riva/common.git
branch = main
branch = release/2.16.0
2 changes: 1 addition & 1 deletion common
16 changes: 11 additions & 5 deletions riva/client/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def add_asr_config_argparse_parameters(
"--start-history",
default=-1,
type=int,
help="Value to detect and initiate start of speech utterance",
help="Value (in milliseconds) to detect and initiate start of speech utterance",
)
parser.add_argument(
"--start-threshold",
Expand All @@ -64,19 +64,25 @@ def add_asr_config_argparse_parameters(
"--stop-history",
default=-1,
type=int,
help="Value to reset the endpoint detection history",
help="Value (in milliseconds) to detect end of utterance and reset decoder",
)
parser.add_argument(
"--stop-threshold",
default=-1.0,
type=float,
help="Threshold value for detecting the end of speech utterance",
)
parser.add_argument(
"--stop-history-eou",
default=-1,
type=int,
help="Value to determine the response history for endpoint detection",
help="Value (in milliseconds) to detect end of utterance for the 1st pass and generate an intermediate final transcript",
)
parser.add_argument(
"--stop-threshold",
"--stop-threshold-eou",
default=-1.0,
type=float,
help="Threshold value for detecting the end of speech utterance",
help="Threshold value for likelihood of blanks before detecting end of utterance",
)
return parser

Expand Down
5 changes: 4 additions & 1 deletion riva/client/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ def add_endpoint_parameters_to_config(
stop_history: int,
stop_history_eou: int,
stop_threshold: float,
stop_threshold_eou: float,
) -> None:
if not (start_history > 0 or start_threshold > 0 or stop_history > 0 or stop_history_eou > 0 or stop_threshold > 0):
if not (start_history > 0 or start_threshold > 0 or stop_history > 0 or stop_history_eou > 0 or stop_threshold > 0 or stop_threshold_eou > 0):
return

inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config
Expand All @@ -146,6 +147,8 @@ def add_endpoint_parameters_to_config(
endpointing_config.stop_history_eou = stop_history_eou
if stop_threshold > 0:
endpointing_config.stop_threshold = stop_threshold
if stop_threshold_eou > 0:
endpointing_config.stop_threshold_eou = stop_threshold_eou
inner_config.endpointing_config.CopyFrom(endpointing_config)


Expand Down
13 changes: 7 additions & 6 deletions scripts/asr/riva_streaming_asr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,13 @@ def streaming_transcription_worker(
interim_results=True,
)
riva.client.add_endpoint_parameters_to_config(
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
args.stop_threshold
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
args.stop_threshold,
args.stop_threshold_eou
)
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
for _ in range(args.num_iterations):
Expand Down
13 changes: 7 additions & 6 deletions scripts/asr/transcribe_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,13 @@ def main() -> None:
)
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
riva.client.add_endpoint_parameters_to_config(
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
args.stop_threshold
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
args.stop_threshold,
args.stop_threshold_eou
)
sound_callback = None
try:
Expand Down
13 changes: 7 additions & 6 deletions scripts/asr/transcribe_file_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@ def main() -> None:
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization)
riva.client.add_endpoint_parameters_to_config(
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
args.stop_threshold
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
args.stop_threshold,
args.stop_threshold_eou
)
with args.input_file.open('rb') as fh:
data = fh.read()
Expand Down
13 changes: 7 additions & 6 deletions scripts/asr/transcribe_mic.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,13 @@ def main() -> None:
)
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
riva.client.add_endpoint_parameters_to_config(
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
args.stop_threshold
config,
args.start_history,
args.start_threshold,
args.stop_history,
args.stop_history_eou,
args.stop_threshold,
args.stop_threshold_eou
)
with riva.client.audio_io.MicrophoneStream(
args.sample_rate_hz,
Expand Down

0 comments on commit 0a75015

Please sign in to comment.