From cf276babd19dd738d6055032592f48bab4594742 Mon Sep 17 00:00:00 2001 From: Anurag <158568080+atomer-nvidia@users.noreply.github.com> Date: Tue, 12 Mar 2024 15:31:31 +0530 Subject: [PATCH 1/2] fix(tts): Added zero shot parameters to talk.py (#69) --- riva/client/tts.py | 4 ++-- scripts/tts/talk.py | 11 +++++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/riva/client/tts.py b/riva/client/tts.py index 370373a..71e84ae 100644 --- a/riva/client/tts.py +++ b/riva/client/tts.py @@ -74,7 +74,7 @@ def synthesize( if audio_prompt_file is not None: with wave.open(str(audio_prompt_file), 'rb') as wf: rate = wf.getframerate() - req.zero_shot_data.sample_rate = rate + req.zero_shot_data.sample_rate_hz = rate with audio_prompt_file.open('rb') as wav_f: audio_data = wav_f.read() req.zero_shot_data.audio_prompt = audio_data @@ -131,7 +131,7 @@ def synthesize_online( if audio_prompt_file is not None: with wave.open(str(audio_prompt_file), 'rb') as wf: rate = wf.getframerate() - req.zero_shot_data.sample_rate = rate + req.zero_shot_data.sample_rate_hz = rate with audio_prompt_file.open('rb') as wav_f: audio_data = wav_f.read() req.zero_shot_data.audio_prompt = audio_data diff --git a/scripts/tts/talk.py b/scripts/tts/talk.py index 99cedd0..dd76595 100644 --- a/scripts/tts/talk.py +++ b/scripts/tts/talk.py @@ -22,7 +22,12 @@ def parse_args() -> argparse.Namespace: "based on parameter `--language-code`.", ) parser.add_argument("--text", type=str, required=True, help="Text input to synthesize.") + parser.add_argument( + "--audio_prompt_file", + type=Path, + help="An input audio prompt (.wav) file for zero shot model. This is required to do zero shot inferencing.") parser.add_argument("-o", "--output", type=Path, help="Output file .wav file to write synthesized audio.") + parser.add_argument("--quality", type=int, help="Number of times decoder should be run on the output audio. A higher number improves quality of the produced output but introduces latencies.") parser.add_argument( "--play-audio", action="store_true", @@ -81,7 +86,8 @@ def main() -> None: start = time.time() if args.stream: responses = service.synthesize_online( - args.text, args.voice, args.language_code, sample_rate_hz=args.sample_rate_hz + args.text, args.voice, args.language_code, sample_rate_hz=args.sample_rate_hz, + audio_prompt_file=args.audio_prompt_file, quality=20 if args.quality is None else args.quality ) first = True for resp in responses: @@ -95,7 +101,8 @@ def main() -> None: out_f.writeframesraw(resp.audio) else: resp = service.synthesize( - args.text, args.voice, args.language_code, sample_rate_hz=args.sample_rate_hz + args.text, args.voice, args.language_code, sample_rate_hz=args.sample_rate_hz, + audio_prompt_file=args.audio_prompt_file, quality=20 if args.quality is None else args.quality ) stop = time.time() print(f"Time spent: {(stop - start):.3f}s") From deea856b47b5e36ceb98ecf5b45da98b35606a40 Mon Sep 17 00:00:00 2001 From: rmittal-github <61574997+rmittal-github@users.noreply.github.com> Date: Mon, 18 Mar 2024 18:07:44 +0530 Subject: [PATCH 2/2] Remove deprecated NLP client examples (#70) * Remove deprecated NLP client examples * update SHA of common repo --- common | 2 +- scripts/nlp/eval_intent_slot.py | 359 ------------------ scripts/nlp/intentslot_client.py | 77 ---- scripts/nlp/ner_client.py | 54 --- scripts/nlp/qa_client.py | 42 -- scripts/nlp/text_classify_client.py | 29 -- .../update_intent_slot_test_data_format.py | 56 --- 7 files changed, 1 insertion(+), 618 deletions(-) delete mode 100644 scripts/nlp/eval_intent_slot.py delete mode 100644 scripts/nlp/intentslot_client.py delete mode 100644 scripts/nlp/ner_client.py delete mode 100644 scripts/nlp/qa_client.py delete mode 100644 scripts/nlp/text_classify_client.py delete mode 100644 scripts/nlp/update_intent_slot_test_data_format.py diff --git a/common b/common index 2d2cc96..9f192f6 160000 --- a/common +++ b/common @@ -1 +1 @@ -Subproject commit 2d2cc96597c8d30d3fd10f8f584e672efe5d2d10 +Subproject commit 9f192f67f56fcc916ae6c17329394a71c780b0fb diff --git a/scripts/nlp/eval_intent_slot.py b/scripts/nlp/eval_intent_slot.py deleted file mode 100644 index a1e5933..0000000 --- a/scripts/nlp/eval_intent_slot.py +++ /dev/null @@ -1,359 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: MIT - -import argparse -import csv -import itertools -import os.path -import warnings -from pathlib import Path -from typing import Dict, List, NewType, Optional, Tuple, Union - -from sklearn.metrics import classification_report -from sklearn.preprocessing import LabelEncoder -from transformers import BertTokenizer, PreTrainedTokenizerBase - -import riva.client -from riva.client.argparse_utils import add_connection_argparse_parameters - - -def combine_subwords(tokens: List[str]) -> List[str]: - """ - This function combines subwords into single word - - Args: - tokens (:obj:`List[str]`): a list tokens generated using BERT tokenizer which may have subwords - separated by "##". - - Returns: - :obj:`List[str]`: a list of tokens which do not contain "##". Instead such tokens are concatenated with - preceding tokens. - """ - combine_tokens = [] - total_tokens = len(tokens) - idx = 0 - - while idx < total_tokens: - ct = tokens[idx] - token = "" - if ct.startswith("##"): - # remove last token as it needs to be combine with current token - token += combine_tokens.pop(-1) - token += ct.strip("##") - idx += 1 - while idx < total_tokens: - ct = tokens[idx] - if ct.startswith("##"): - token += ct.strip("##") - else: - idx = idx - 1 # put back the token - break - idx += 1 - else: - token = ct - combine_tokens.append(token) - idx += 1 - - # print("combine_tokens=", combine_tokens) - return combine_tokens - - -SlotsType = NewType('SlotsType', List[Dict[str, Union[int, str]]]) - - -def read_tsv_file(input_file: Union[str, os.PathLike]) -> List[Dict[str, Union[str, SlotsType]]]: - """ - Reads .tsv file ``input_file`` with test data in format - ``` - TABTAB - := , - := :: - ``` - Args: - input_file (:obj:`Union[str, os.PathLike]`): a path to an input file - - Returns: - :obj:`List[Dict[str, Union[str, List[Dict[str, Union[int, str]]]]]]`: a list of examples for testing. Each - example has format: - ``` - { - "intent": , - "slots": { - "start": , - "end": , - "slot_name": , - }, - "query": , - } - ``` - """ - content = [] - input_file = Path(input_file).expanduser() - with input_file.open() as f: - reader = csv.reader(f, delimiter='\t') - for row_i, row in enumerate(reader): - row_content = {'intent': row[0]} - slots = [] - if row[1]: - for slot_str in row[1].split(','): - start, end, slot_name = slot_str.split(':') - slots.append({'start': int(start), 'end': int(end), 'name': slot_name}) - slots = sorted(slots, key=lambda x: x['start']) - for i in range(len(slots) - 1): - if slots[i]['end'] > slots[i + 1]['start']: - raise ValueError( - f"Slots {slots[i]} and {slots[i + 1]} from row {row_i} (starting from 0) from file " - f"{input_file} overlap." - ) - row_content['slots'] = slots - row_content['query'] = row[2] - content.append(row_content) - return content - - -def tokenize_with_alignment( - query: str, tokenizer: PreTrainedTokenizerBase -) -> Tuple[List[str], List[Optional[int]], List[Optional[int]], List[Tuple[int, int]]]: - """ - Tokenizes a query :param:`query` using tokenizer :param:`tokenizer`, combines subwords, and calculates slices of - tokens in the query. - - Args: - query (:obj:`str`): an input query. - tokenizer (:obj:`PreTrainedTokenizerBase`): a HuggingFace tokenizer used for tokenizing :param:`query`. - - Returns: - :obj:`tuple`: a tuple containing 3 lists of identical length and 4th list which length can differ - from the first 3: - - - tokens (:obj:`List[str]`): a list of tokens acquired from :param:`query`. - - starts (:obj:`List[Optional[int]]`): a list of slice starts (slices used for extracting tokens from - :param:`query`). If a token is UNK, then a corresponding :obj:`starts` element is :obj:`None`. - - ends (:obj:`List[Optional[int]]`): a list of slice ends (slices used for extracting tokens from - :param:`query`). If a token is UNK, then a corresponding :obj:`ends` element is :obj:`None`. - - unk_zones (:obj:`List[Tuple[int, int]]`): a tuple with slices which show position of UNK tokens and - spaces surrounding UNK tokens. - - Raises: - :obj:`RuntimeError`: if a token is not found in a query. - """ - tokenized_query = tokenizer.tokenize(query) - tokens = combine_subwords(tokenized_query) - starts, ends, unk_zones = [], [], [] - pos_in_query = 0 - unk_zone_start = None - for token_i, token in enumerate(tokens): - if token == tokenizer.unk_token: - if unk_zone_start is None: - unk_zone_start = pos_in_query - starts.append(None) - ends.append(None) - else: - while pos_in_query < len(query) and query[pos_in_query: pos_in_query + len(token)] != token: - pos_in_query += 1 - if pos_in_query >= len(query): - raise RuntimeError( - f"Tokenization of a query '{query}' lead to removal of token '{token}'. Tokens: {tokens}." - ) - if unk_zone_start is not None: - unk_zones.append((unk_zone_start, pos_in_query)) - unk_zone_start = None - starts.append(pos_in_query) - pos_in_query += len(token) - ends.append(pos_in_query) - return tokens, starts, ends, unk_zones - - -def slots_to_bio( - queries: List[str], - slots: List[SlotsType], - tokenizer: Optional[PreTrainedTokenizerBase] = None, - require_correct_slots: bool = True -) -> List[List[str]]: - """ - Creates BIO markup for queries in :param:`queries` according slots described in :param:`slots`. - - Args: - queries (:obj:`List[str]`): a list of input queries - slots (:obj:`List[List[Dict[str, Union[int, str]]]]`): a list of slots for all queries. Slots for a query is a - list of dictionaries with keys :obj:`"start"`, :obj:`"end"`, :obj:`"name"`. :obj:`"start"` and :obj:`"end"` - if used as slice start and end for corresponding give a slot text. - tokenizer (:obj:`PreTrainedTokenizerBase`, `optional`): a tokenizer used for queries tokenization. - If :obj:`None`, then `"bert-base-cased"` is used. - require_correct_slots (:obj:`bool`, defaults to :obj:`True`): if :obj:`True`, then matching of tokens and - slot spans is checked and an error is raised if there is no match. Set this to :obj:`True` if you prepare - ground truth and to :obj:`False` if you prepare predictions. - - Returns: - :obj:`List[List[str]]`: a BIO markup for queries. - """ - if tokenizer is None: - tokenizer = BertTokenizer.from_pretrained('bert-base-cased') - bio: List[List[str]] = [] - for query_idx, (query, query_slots) in enumerate(zip(queries, slots)): - tokens, starts, ends, unk_zones = tokenize_with_alignment(query, tokenizer) - query_bio = ['O'] * len(tokens) - for slot in query_slots: - if slot['end'] <= slot['start']: - if require_correct_slots: - raise ValueError( - f"Slot '{slot['name']}' end offset {slot['end']} cannot be less or equal to slot start offset " - f"{slot['start']} in query '{query}' with query index {query_idx}. " - f"The error can occur if test data mark up is wrong." - ) - else: - continue - slot_start_token_idx, slot_end_token_idx = None, None - for token_i, start in enumerate(starts): - if start == slot['start']: - slot_start_token_idx = token_i - query_bio[slot_start_token_idx] = f'B-{slot["name"]}' - break - if slot_start_token_idx is None: - if require_correct_slots: - raise ValueError( - f"Could not find a beginning of slot {slot} in a query '{query}'. Acquired tokens: {tokens}. " - f"Aligned token beginning offsets: {starts}. Aligned token ending offsets: " - f"{ends}. An error occurred during processing of {query_idx}th query. This error " - f"can appear if query mark up is broken." - ) - else: - continue - found_end = False - for token_i, end in enumerate(ends): - if end == slot['end']: - found_end = True - for j in range(slot_start_token_idx + 1, token_i + 1): - query_bio[j] = f'I-{slot["name"]}' - if not found_end and require_correct_slots: - raise ValueError( - f"Could not find an end of slot {slot} in a query '{query}'. Acquired tokens: {tokens}. " - f"Aligned token beginning offsets: {starts}. Aligned token ending offsets: " - f"{ends}. An error occurred during processing of {query_idx}th query. This error " - f"can appear if query mark up is broken." - ) - bio.append(query_bio) - return bio - - -def pack_slots_to_dict_format( - slots: List[List[str]], starts: List[List[int]], ends: List[List[int]] -) -> List[SlotsType]: - output: List[SlotsType] = [] - for query_slots, query_starts, query_ends in zip(slots, starts, ends): - output.append( - [ - {'start': start, 'end': end + 1, 'name': slot} - for start, end, slot in zip(query_starts, query_ends, query_slots) - ] - ) - return output - - -def slots_classification_report( - y_true: List[List[str]], y_pred: List[List[str]], output_dict: bool -) -> Union[str, Dict[str, Dict[str, Union[int, float]]]]: - encoder = LabelEncoder() - all_slots = list({ele for row in y_true for ele in row}.union({ele for row in y_pred for ele in row})) - encoder.fit(all_slots) - y_true, y_pred = list(itertools.chain(*y_true)), list(itertools.chain(*y_pred)) - y_truth = encoder.transform(y_true) - y_pred = encoder.transform(y_pred) - target_names = encoder.classes_ - return classification_report(y_truth, y_pred, target_names=target_names, output_dict=output_dict) - - -def intent_slots_classification_report( - input_file: Path, - nlp_service: riva.client.NLPService, - model: str, - batch_size: int, - language_code: str, - output_dict: bool, - max_async_requests_to_queue: int, -) -> Union[ - Tuple[str, str], - Tuple[Dict[str, Dict[str, Union[int, float]]]], Dict[str, Dict[str, Union[int, float]]] -]: - test_data = read_tsv_file(input_file) - queries = [elem['query'] for elem in test_data] - tokens, slots, _, token_starts, token_ends = riva.client.nlp.classify_tokens_batch( - nlp_service, queries, model, batch_size, language_code, max_async_requests_to_queue - ) - intents, _ = riva.client.nlp.classify_text_batch( - nlp_service, queries, model, batch_size, language_code, max_async_requests_to_queue - ) - intent_report = classification_report([elem['intent'] for elem in test_data], intents, output_dict=output_dict) - ground_truth_bio = slots_to_bio(queries, [elem['slots'] for elem in test_data]) - predicted_bio = slots_to_bio( - queries, pack_slots_to_dict_format(slots, token_starts, token_ends), require_correct_slots=False - ) - per_label_slot_report = slots_classification_report(ground_truth_bio, predicted_bio, output_dict=output_dict) - return intent_report, per_label_slot_report - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Program to print classification reports for intent and slot test data.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument("--model", default="riva_intent_weather", type=str, help="Model on TRTIS to execute") - parser.add_argument( - "--input-file", - type=Path, - required=True, - help="A path to an input .tsv file. An input file has to be in a format TABTAB. " - " field contains several comma separated slots, e.g.: ,. If there are no slots, then " - " is an empty string. Each slot has a format :: where and " - "are start and end of a slice applied to a query to get a slot, e.g. in an a sample " - "'0:4:animalcats are nice' `start=0`, `end=4`, `query='cats are nice'` " - "and slot `animal='cats'` is acquired by `query[start:end]`." - "`data/nlp_test_metrics/weather.fixed.eval.tsv` is an example of a correct input file.", - ) - parser.add_argument("--language-code", default='en-US', help="A language of a model.") - parser.add_argument( - "--batch-size", - type=int, - default=1, - help="How many examples are sent to server in one request. Currently only `1` is supported.", - ) - parser.add_argument( - "--max-async-requests-to-queue", - type=int, - default=500, - help="If greater than 0, then data is processed in async manner. Up to`--max-async-requests-to-queue` " - "requests are asynchronous requests are sent and then the program will wait for results. When results are " - "returned, new `--max-async-requests-to-queue` are sent.", - ) - parser = add_connection_argparse_parameters(parser) - args = parser.parse_args() - if args.max_async_requests_to_queue < 0: - parser.error( - f"Parameter `--max-async-requests-to-queue` has not negative, whereas {args.max_async_requests_to_queue} " - f"was given." - ) - if args.batch_size > 1: - warnings.warn("Batch size > 1 is not supported because spans may be calculated incorrectly.") - args.input_file = args.input_file.expanduser() - return args - - -def main() -> None: - args = parse_args() - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) - service = riva.client.NLPService(auth) - intent_report, slot_report = intent_slots_classification_report( - args.input_file, - service, - args.model, - args.batch_size, - args.language_code, - output_dict=False, - max_async_requests_to_queue=args.max_async_requests_to_queue - ) - print(intent_report) - print(slot_report) - - -if __name__ == "__main__": - main() diff --git a/scripts/nlp/intentslot_client.py b/scripts/nlp/intentslot_client.py deleted file mode 100644 index 7580fa9..0000000 --- a/scripts/nlp/intentslot_client.py +++ /dev/null @@ -1,77 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: MIT - -import argparse -import time -from typing import List - -import riva.client -from riva.client.argparse_utils import add_connection_argparse_parameters - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Client app to run intent slot on Riva", formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - parser.add_argument( - "--model", default="riva_intent_weather", help="Model on Riva Server to execute." - ) - parser.add_argument("--query", default="What is the weather tomorrow?", help="Input Query") - parser.add_argument( - "--interactive", - action='store_true', - help="If this option is set, then `--query` argument is ignored and the script suggests user to enter " - "queries to standard input.", - ) - parser = add_connection_argparse_parameters(parser) - return parser.parse_args() - - -def pretty_print_result( - intent: str, intent_score: float, slots: List[str], tokens: List[str], slot_scores: List[float], duration: float -) -> None: - print(f"Inference complete in {duration * 1000:.4f} ms") - print("Intent:", intent) - print("Intent Score:", intent_score) - print("Slots:", slots) - print("Slots Scores:", slot_scores) - if len(tokens) > 0: - print("Combined: ", end="") - for token, slot in zip(tokens, slots): - print(f"{token}{f'({slot})' if slot != 'O' else ''}", end=" ") - print("\n") - - -def main() -> None: - args = parse_args() - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) - service = riva.client.NLPService(auth) - if args.interactive: - while True: - query = input("Enter a query: ") - start = time.time() - intents, intent_confidences = riva.client.extract_most_probable_text_class_and_confidence( - service.classify_text(input_strings=query, model_name=args.model) - ) - tokens, slots, slot_confidences, _, _ = riva.client.extract_most_probable_token_classification_predictions( - service.classify_tokens(input_strings=query, model_name=args.model) - ) - end = time.time() - pretty_print_result( - intents[0], intent_confidences[0], slots[0], tokens[0], slot_confidences[0], end - start - ) - else: - intents, intent_confidences = riva.client.extract_most_probable_text_class_and_confidence( - service.classify_text(input_strings=args.query, model_name=args.model) - ) - tokens, slots, slot_confidences, _, _ = riva.client.extract_most_probable_token_classification_predictions( - service.classify_tokens(input_strings=args.query, model_name=args.model) - ) - results = [ - (intents[i], intent_confidences[i], slots[i], tokens[i], slot_confidences[i]) for i in range(len(slots)) - ] - print(results) - - -if __name__ == '__main__': - main() diff --git a/scripts/nlp/ner_client.py b/scripts/nlp/ner_client.py deleted file mode 100644 index 014cd8e..0000000 --- a/scripts/nlp/ner_client.py +++ /dev/null @@ -1,54 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: MIT - -import argparse - -import riva.client -from riva.client.argparse_utils import add_connection_argparse_parameters - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Client app to run NER on Riva", formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - parser.add_argument("--model", default="riva_ner", help="Model on Riva Server to execute.") - parser.add_argument( - "--query", nargs="+", default=["Where is San Francisco?", "Jensen Huang is the CEO of NVIDIA Corporation."] - ) - parser.add_argument( - "--test", - default="label", - choices=['label', 'span_start', 'span_end'], - help="What info will be printed to STDOUT. If 'label', then a class of an entity will be printed. " - "If 'span_start', then indices of first characters of entities are printed. For example, for a query " - "'cats are nice' if an entity is 'cats', then 'span_start' is 0. If 'span_end', then indices of " - "first characters following entities are printed. For example, for the query 'cats are nice' for entity " - "'cats' 'span_end' is 4.", - ) - parser = add_connection_argparse_parameters(parser) - return parser.parse_args() - - -def main() -> None: - args = parse_args() - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) - service = riva.client.NLPService(auth) - tokens, slots, slot_confidences, starts, ends = riva.client.extract_most_probable_token_classification_predictions( - service.classify_tokens(input_strings=args.query, model_name=args.model) - ) - test_mode = args.test - if test_mode == "label": - print(slots) - elif test_mode == "span_start": - print(starts) - elif test_mode == "span_end": - print(ends) - else: - raise ValueError( - f"Testing mode '{test_mode}' is not supported. Supported testing modes are: 'label', 'span_start', " - f"'span_end'" - ) - - -if __name__ == '__main__': - main() diff --git a/scripts/nlp/qa_client.py b/scripts/nlp/qa_client.py deleted file mode 100644 index f14040e..0000000 --- a/scripts/nlp/qa_client.py +++ /dev/null @@ -1,42 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: MIT - -import argparse - -import riva.client -from riva.client.argparse_utils import add_connection_argparse_parameters - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Riva Question Answering client sample.", formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - parser.add_argument( - "--query", default="How much carbon dioxide was released in 2005?", help="Query for the QA API." - ) - parser.add_argument( - "--context", - default="In 2010 the Amazon rainforest experienced another severe drought, in some ways more extreme than the " - "2005 drought. The affected region was approximate 1,160,000 square miles (3,000,000 km2) of " - "rainforest, compared to 734,000 square miles (1,900,000 km2) in 2005. The 2010 drought had three " - "epicenters where vegetation died off, whereas in 2005 the drought was focused on the southwestern " - "part. The findings were published in the journal Science. In a typical year the Amazon absorbs 1.5 " - "gigatons of carbon dioxide; during 2005 instead 5 gigatons were released and in 2010 8 gigatons were " - "released.", - help="Context for the QA API", - ) - parser = add_connection_argparse_parameters(parser) - return parser.parse_args() - - -def main() -> None: - args = parse_args() - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) - service = riva.client.NLPService(auth) - resp = service.natural_query(args.query, args.context) - print(resp) - - -if __name__ == "__main__": - main() - diff --git a/scripts/nlp/text_classify_client.py b/scripts/nlp/text_classify_client.py deleted file mode 100644 index c554d1f..0000000 --- a/scripts/nlp/text_classify_client.py +++ /dev/null @@ -1,29 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: MIT - -import argparse - -import riva.client -from riva.client.argparse_utils import add_connection_argparse_parameters - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Client app to run Text Classification on Riva.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument("--model", default="riva_text_classification_domain", help="Model on Riva Server to execute.") - parser.add_argument("--query", default="How much sun does california get?", help="An input query.") - parser = add_connection_argparse_parameters(parser) - return parser.parse_args() - - -def main() -> None: - args = parse_args() - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) - service = riva.client.NLPService(auth) - print(riva.client.nlp.extract_most_probable_text_class_and_confidence(service.classify_text(args.query, args.model))) - - -if __name__ == '__main__': - main() diff --git a/scripts/nlp/update_intent_slot_test_data_format.py b/scripts/nlp/update_intent_slot_test_data_format.py deleted file mode 100644 index ffa1315..0000000 --- a/scripts/nlp/update_intent_slot_test_data_format.py +++ /dev/null @@ -1,56 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: MIT - -import argparse -from pathlib import Path - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - description="Transforms old style file for intent classification and entities classification " - "to new style format. Old style is 'TABTAB' where is in format " - "'BOS EOS'. This script keeps only in and removes auxiliary " - " field and BOS and EOS." - ) - parser.add_argument("--input-file", type=Path, help="A path to an input .tsv file.", required=True) - parser.add_argument("--output-file", type=Path, help="A path to an output .tsv file.", required=True) - args = parser.parse_args() - args.input_file = args.input_file.expanduser() - args.output_file = args.output_file.expanduser() - return args - - -def main() -> None: - args = parse_args() - with args.input_file.open() as in_f, args.output_file.open('w') as out_f: - for line_i, line in enumerate(in_f): - intent, slots, query = line.split('\t') - words = query.split() - new_query = ' '.join(words[2:-1]) - if slots: - slots = slots.split(',') - new_slots = [] - offset = len(words[0]) + len(words[1]) + 2 - for slot in slots: - try: - start, end, name = slot.split(':') - except ValueError: - print(slot) - print(line_i) - raise - start, end = int(start), int(end) - if start < offset or end < offset: - raise ValueError( - f"Slot borders start={start}, end={end} in line {line_i} in file {args.input_file}" - ) - slot = ':'.join([str(start - offset), str(end - offset), name]) - new_slots.append(slot) - new_slots = ','.join(new_slots) - else: - new_slots = '' - out_f.write('\t'.join([intent, new_slots, new_query]) + '\n') - - -if __name__ == "__main__": - main()