From 4adbedfbdc48c68c55bb5873fbe175cfbe9d5831 Mon Sep 17 00:00:00 2001 From: Qishuai Zhong Date: Fri, 19 Nov 2021 21:29:36 +0800 Subject: [PATCH 1/6] reformat the structure and tests --- .gitattributes | 1 - Makefile | 4 +- dbpunctuator/__init__.py | 7 +++ .../data_process}/__init__.py | 0 .../data_process}/additional_data_process.py | 0 .../data_process}/data_cleanning.py | 2 - .../data_process}/data_process.py | 39 ++++++++------ .../inference}/__init__.py | 0 .../inference}/inference_interface.py | 25 ++++++--- .../inference}/inference_pipeline.py | 29 ++++++---- .../training}/__init__.py | 0 .../training}/dataset.py | 0 {training => dbpunctuator/training}/train.py | 4 +- {utils => dbpunctuator/utils}/__init__.py | 0 {utils => dbpunctuator/utils}/constant.py | 2 + {utils => dbpunctuator/utils}/upload_model.py | 0 {utils => dbpunctuator/utils}/utils.py | 0 examples/data_sample.py | 2 +- examples/inference_sample.py | 31 +++++++++-- examples/train_sample.py | 4 +- models/tag2id.json | 3 -- notes.md | 14 +++++ setup.py | 5 +- tests/common.py | 53 +++++++++++++++++++ tests/test_dataprocess.py | 29 ++++++++++ tests/test_inference.py | 40 +++++++------- 26 files changed, 224 insertions(+), 70 deletions(-) create mode 100644 dbpunctuator/__init__.py rename {data_process => dbpunctuator/data_process}/__init__.py (100%) rename {data_process => dbpunctuator/data_process}/additional_data_process.py (100%) rename {data_process => dbpunctuator/data_process}/data_cleanning.py (94%) rename {data_process => dbpunctuator/data_process}/data_process.py (74%) rename {inference => dbpunctuator/inference}/__init__.py (100%) rename {inference => dbpunctuator/inference}/inference_interface.py (86%) rename {inference => dbpunctuator/inference}/inference_pipeline.py (87%) rename {training => dbpunctuator/training}/__init__.py (100%) rename {training => dbpunctuator/training}/dataset.py (100%) rename {training => dbpunctuator/training}/train.py (99%) rename {utils => dbpunctuator/utils}/__init__.py (100%) rename {utils => dbpunctuator/utils}/constant.py (85%) rename {utils => dbpunctuator/utils}/upload_model.py (100%) rename {utils => dbpunctuator/utils}/utils.py (100%) delete mode 100644 models/tag2id.json create mode 100644 notes.md create mode 100644 tests/common.py create mode 100644 tests/test_dataprocess.py diff --git a/.gitattributes b/.gitattributes index 56abdca..3b70f7a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,2 @@ -models/tag2id.json filter=lfs diff=lfs merge=lfs -text models/punctuator/config.json filter=lfs diff=lfs merge=lfs -text models/punctuator/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text diff --git a/Makefile b/Makefile index 497bbf1..4d585f9 100644 --- a/Makefile +++ b/Makefile @@ -1,10 +1,10 @@ -PY_SOURCE_FILES=data_process/ training/ inference/ utils/ examples/ #this can be modified to include more files +PY_SOURCE_FILES=dbpunctuator/ examples/ tests/ #this can be modified to include more files install: package pip install -e .[dev] test: - pytest tests -vv + pytest tests -vv -s clean: rm -rf build/ dist/ *.egg-info .pytest_cache diff --git a/dbpunctuator/__init__.py b/dbpunctuator/__init__.py new file mode 100644 index 0000000..fa728cd --- /dev/null +++ b/dbpunctuator/__init__.py @@ -0,0 +1,7 @@ +import logging + +from .utils.utils import register_logger + +# setup library logging +logger = logging.getLogger(__name__) +register_logger(logger) diff --git a/data_process/__init__.py b/dbpunctuator/data_process/__init__.py similarity index 100% rename from data_process/__init__.py rename to dbpunctuator/data_process/__init__.py diff --git a/data_process/additional_data_process.py b/dbpunctuator/data_process/additional_data_process.py similarity index 100% rename from data_process/additional_data_process.py rename to dbpunctuator/data_process/additional_data_process.py diff --git a/data_process/data_cleanning.py b/dbpunctuator/data_process/data_cleanning.py similarity index 94% rename from data_process/data_cleanning.py rename to dbpunctuator/data_process/data_cleanning.py index 2be54b7..b7d8b28 100644 --- a/data_process/data_cleanning.py +++ b/dbpunctuator/data_process/data_cleanning.py @@ -4,8 +4,6 @@ tqdm.pandas() -default_kept_punctuations = {",", ".", "?", "!"} - def dataframe_data_cleaning( df, target_col, kept_punctuations, additional_to_remove, *special_cleaning_funcs diff --git a/data_process/data_process.py b/dbpunctuator/data_process/data_process.py similarity index 74% rename from data_process/data_process.py rename to dbpunctuator/data_process/data_process.py index 968b451..40af753 100644 --- a/data_process/data_process.py +++ b/dbpunctuator/data_process/data_process.py @@ -1,8 +1,13 @@ +import logging + import pandas as pd from tqdm import tqdm -from data_process.data_cleanning import dataframe_data_cleaning -from utils.constant import DEFAULT_NER_MAPPING, DIGIT_MASK +from dbpunctuator.utils.constant import DEFAULT_NER_MAPPING, DIGIT_MASK + +from .data_cleanning import dataframe_data_cleaning + +logger = logging.getLogger(__name__) def cleanup_data_from_csv( @@ -11,7 +16,7 @@ def cleanup_data_from_csv( output_file_path, ner_mapping=DEFAULT_NER_MAPPING, additional_to_remove=[], - special_cleaning_funcs=[], + *special_cleaning_funcs, ): """clean up training data from csv file @@ -21,13 +26,14 @@ def cleanup_data_from_csv( output_file_path (string): path of cleaned data ner_mapping (dict, optional): NER mapping of punctuation marks. Defaults to utils.constant.DEFAULT_NER_MAPPING additional_to_remove (list, optional): additional special characters to remove, default [] - special_cleaning_funcs (List[func], optional): additional cleaning funcs to apply to csv data, default [] + *special_cleaning_funcs (funcs, optional): additional cleaning funcs to apply to csv data """ dataframe = pd.read_csv(csv_path) additional_to_remove = ["—"] kept_punctuations = set(ner_mapping.keys()) + logger.info("clean up original data") result_df = dataframe_data_cleaning( - dataframe[1500:], + dataframe, target_col, kept_punctuations, additional_to_remove, @@ -44,8 +50,8 @@ def cleanup_data_from_csv( def process_line(line, ner_mapping=DEFAULT_NER_MAPPING): text_list = line.split() - word_list = [] token_list = [] + tag_list = [] # clean up puncs in the beginning of the text latest_word = text_list.pop(0) while latest_word in ner_mapping: @@ -59,23 +65,23 @@ def process_line(line, ner_mapping=DEFAULT_NER_MAPPING): if not latest_is_punc: latest_token = ner_mapping[word] latest_is_punc = True - word_list.append(latest_word) - token_list.append(latest_token) + token_list.append(latest_word) + tag_list.append(latest_token) else: pass else: if not latest_is_punc: - word_list.append(latest_word) - token_list.append(latest_token) + token_list.append(latest_word) + tag_list.append(latest_token) latest_is_punc = False if word.isdigit(): word = DIGIT_MASK latest_word = word latest_token = "O" if not latest_is_punc: - word_list.append(latest_word) - token_list.append(latest_token) - return word_list, token_list + token_list.append(latest_word) + tag_list.append(latest_token) + return token_list, tag_list def generate_training_data(cleaned_data_path, training_data_path): @@ -85,12 +91,13 @@ def generate_training_data(cleaned_data_path, training_data_path): cleaned_data_path (string): path of cleaned data training_data_path (string): path of generated training data """ + logger.info("generate training data") with open(cleaned_data_path, "r") as data_file: lines = data_file.readlines() with open(training_data_path, "w+") as training_data_file: pbar = tqdm(lines) for line in pbar: - words, tokens = process_line(line) - for word, token in zip(words, tokens): - training_data_file.write("%s\t%s\n" % (word, token)) + tokens, tags = process_line(line) + for token, tag in zip(tokens, tags): + training_data_file.write("%s\t%s\n" % (token, tag)) pbar.close() diff --git a/inference/__init__.py b/dbpunctuator/inference/__init__.py similarity index 100% rename from inference/__init__.py rename to dbpunctuator/inference/__init__.py diff --git a/inference/inference_interface.py b/dbpunctuator/inference/inference_interface.py similarity index 86% rename from inference/inference_interface.py rename to dbpunctuator/inference/inference_interface.py index 5174c67..41a2922 100644 --- a/inference/inference_interface.py +++ b/dbpunctuator/inference/inference_interface.py @@ -5,13 +5,14 @@ import threading from threading import Thread from time import sleep -from typing import List +from typing import List, Tuple -from inference.inference_pipeline import InferenceServer -from utils.utils import register_logger +from .inference_pipeline import InferenceServer + +# from utils.utils import register_logger logger = logging.getLogger(__name__) -register_logger(logger) +# register_logger(logger) class InferenceClient: @@ -112,10 +113,20 @@ def _run(self, check_interval): logger.info("terminate the punctuator") # self.server_process.terminate() - def punctuation(self, inputs: List[str]): + def punctuation(self, inputs: List[str]) -> Tuple[List[str], List[List]]: + """Do punctuation of inputs + + Args: + inputs (List[str]): list of plain text (no punctuated text) + + Returns: + Tuple[List[str], List[List]]: tuple of outputs. + First is the list of punctuated text + Second is the list of labels + """ try: - outputs = self.client.punctuation(inputs) - return outputs + outputs_tuple = self.client.punctuation(inputs) + return outputs_tuple except Exception as err: logger.error(f"error doing punctuation with details {str(err)}") return None diff --git a/inference/inference_pipeline.py b/dbpunctuator/inference/inference_pipeline.py similarity index 87% rename from inference/inference_pipeline.py rename to dbpunctuator/inference/inference_pipeline.py index 80a23eb..3f8bb90 100644 --- a/inference/inference_pipeline.py +++ b/dbpunctuator/inference/inference_pipeline.py @@ -9,11 +9,9 @@ from pydantic import BaseModel from transformers import DistilBertForTokenClassification, DistilBertTokenizerFast -from utils.constant import DIGIT_MASK, TAG_PUNCTUATOR_MAP -from utils.utils import register_logger +from dbpunctuator.utils.constant import DEFAULT_TAG_ID, DIGIT_MASK, TAG_PUNCTUATOR_MAP logger = logging.getLogger(__name__) -register_logger(logger) def verbose(attr_to_log): @@ -38,12 +36,12 @@ class InferenceArguments(BaseModel): Args: model_name_or_path(str): name or path of pre-trained model tokenizer_name(str): name of pretrained tokenizer - tag2id_storage_path(str): tag2id storage path + tag2id_storage_path(str): tag2id storage path, default None. If None, DEFAULT_TAG_ID will be used. """ model_name_or_path: str tokenizer_name: str - tag2id_storage_path: str + tag2id_storage_path: str = None # whole pipeline running in the seperate process, provide a function for user to call, use socket for communication @@ -61,8 +59,12 @@ def __init__(self, inference_arguments, verbose=False): self.classifer = DistilBertForTokenClassification.from_pretrained( inference_arguments.model_name_or_path ) - with open(inference_arguments.tag2id_storage_path, "r") as fp: - tag2id = json.load(fp) + if inference_arguments.tag2id_storage_path: + with open(inference_arguments.tag2id_storage_path, "r") as fp: + tag2id = json.load(fp) + self.id2tag = {id: tag for tag, id in tag2id.items()} + else: + tag2id = DEFAULT_TAG_ID self.id2tag = {id: tag for tag, id in tag2id.items()} self._reset_values() @@ -109,14 +111,18 @@ def classify(self): def post_process(self): reduce_ignored_marks = self.marks >= 0 + self.outputs_labels = [] for pred, reduce_ignored, tokens, digit_index in zip( self.argmax_preds, reduce_ignored_marks, self.all_tokens, self.digit_indexes ): next_upper = True true_pred = pred[reduce_ignored] + result_text = "" + output_labels = [] for id, (index, token) in zip(true_pred, enumerate(tokens)): tag = self.id2tag[id] + output_labels.append(tag) if index in digit_index: token = digit_index[index] if next_upper: @@ -124,12 +130,15 @@ def post_process(self): punctuator, next_upper = TAG_PUNCTUATOR_MAP[tag] result_text += token + punctuator self.outputs.append(result_text.strip()) + self.outputs_labels.append(output_labels) return self def punctuation(self, inputs): self._reset_values() - return self.pre_process(inputs).tokenize().classify().post_process().outputs + self.pre_process(inputs).tokenize().classify().post_process() + + return self.outputs, self.outputs_labels def _mark_ignored_tokens(self, offset_mapping): samples = [] @@ -170,8 +179,8 @@ def __init__( def punctuation(self): try: inputs = self.conn.recv() - outputs = self.inference_pipeline.punctuation(inputs) - self.conn.send(outputs) + outputs_tuple = self.inference_pipeline.punctuation(inputs) + self.conn.send(outputs_tuple) except OSError as err: logger.warning(f"error receiving inputs: {err}") except struct.error as err: diff --git a/training/__init__.py b/dbpunctuator/training/__init__.py similarity index 100% rename from training/__init__.py rename to dbpunctuator/training/__init__.py diff --git a/training/dataset.py b/dbpunctuator/training/dataset.py similarity index 100% rename from training/dataset.py rename to dbpunctuator/training/dataset.py diff --git a/training/train.py b/dbpunctuator/training/train.py similarity index 99% rename from training/train.py rename to dbpunctuator/training/train.py index 0fa8106..8013cf3 100644 --- a/training/train.py +++ b/dbpunctuator/training/train.py @@ -16,7 +16,7 @@ DistilBertTokenizerFast, ) -from training.dataset import generate_tag_ids, read_data, train_test_split +from .dataset import generate_tag_ids, read_data, train_test_split logger = logging.getLogger(__name__) @@ -68,8 +68,8 @@ def load_training_data(self): ) ( self.train_texts, - self.val_texts, self.train_tags, + self.val_texts, self.val_tags, ) = train_test_split(texts, tags, test_size=self.arguments.split_rate) self.tag2id, self.id2tag = generate_tag_ids(tag_docs=tags) diff --git a/utils/__init__.py b/dbpunctuator/utils/__init__.py similarity index 100% rename from utils/__init__.py rename to dbpunctuator/utils/__init__.py diff --git a/utils/constant.py b/dbpunctuator/utils/constant.py similarity index 85% rename from utils/constant.py rename to dbpunctuator/utils/constant.py index a3e5aed..d3eaaf4 100644 --- a/utils/constant.py +++ b/dbpunctuator/utils/constant.py @@ -16,3 +16,5 @@ NUM_BYTE_LENGTH = 2 LENGTH_BYTE_LENGTH = 4 + +DEFAULT_TAG_ID = {"E": 0, "O": 1, "P": 2, "C": 3, "Q": 4} diff --git a/utils/upload_model.py b/dbpunctuator/utils/upload_model.py similarity index 100% rename from utils/upload_model.py rename to dbpunctuator/utils/upload_model.py diff --git a/utils/utils.py b/dbpunctuator/utils/utils.py similarity index 100% rename from utils/utils.py rename to dbpunctuator/utils/utils.py diff --git a/examples/data_sample.py b/examples/data_sample.py index 41a9b85..f762d71 100644 --- a/examples/data_sample.py +++ b/examples/data_sample.py @@ -1,4 +1,4 @@ -from data_process import ( +from dbpunctuator.data_process import ( cleanup_data_from_csv, generate_training_data, remove_brackets_text, diff --git a/examples/inference_sample.py b/examples/inference_sample.py index fbbfbfd..c56450f 100644 --- a/examples/inference_sample.py +++ b/examples/inference_sample.py @@ -1,7 +1,7 @@ import logging -from inference import Inference, InferenceArguments -from utils.utils import register_logger +from dbpunctuator.inference import Inference, InferenceArguments +from dbpunctuator.utils.utils import register_logger logger = logging.getLogger(__name__) register_logger(logger) @@ -11,7 +11,6 @@ args = InferenceArguments( model_name_or_path="Qishuai/distilbert_punctuator_en", tokenizer_name="distilbert-base-uncased", - tag2id_storage_path="models/tag2id.json", ) inference = Inference(inference_args=args, verbose=True) @@ -28,3 +27,29 @@ "great thank you sir here is an additional promo code 5566", ] logger.info(f"testing result {inference.punctuation(test_texts_2)}") + + long_test_text = [ + """ + the two most likely largest inventions of our generation are the internet and the mobile phone + theyve changed the world however largely to our surprise they also turned out to be the perfect tools for the surveillance state + it turned out that the capability to collect data + information and connections about basically any of us and all of us is exactly what weve been hearing throughout of the summer + through revelations and leaks about western intelligence agencies mostly u s intelligence agencies + watching over the rest of the world weve heard about these starting with the revelations from june 6 + edward snowden started leaking information top secret classified information + from the u s intelligence agencies and we started learning about things like prism and xkeyscore and others + and these are examples of the kinds of programs u s + intelligence agencies are running right now against the whole rest of the world + and if you look back about the forecasts on surveillance by george orwell + well it turns out that george orwell was an optimist + we are right now seeing a much larger scale of tracking of individual citizens than he could have ever imagined + and this here is the infamous nsa data center in utah due to be opened very soon + it will be both a supercomputing center and a data storage center + you could basically imagine it has a large hall filled with hard drives storing data they are collecting + and its a pretty big building how big well i can give you the numbers 140 000 square meters + but that doesnt really tell you very much maybe its better to imagine it as a comparison + you think about the largest ikea store youve ever been in this is five times larger + how many hard drives can you fit in an ikea store right its pretty big + """ # noqa: E501 + ] + logger.info(f"testing result {inference.punctuation(long_test_text)}") diff --git a/examples/train_sample.py b/examples/train_sample.py index b6a001e..9c7dbc9 100644 --- a/examples/train_sample.py +++ b/examples/train_sample.py @@ -1,5 +1,5 @@ -from training import TrainingArguments, TrainingPipeline -from utils.utils import register_logger +from dbpunctuator.training import TrainingArguments, TrainingPipeline +from dbpunctuator.utils.utils import register_logger if __name__ == "__main__": register_logger() diff --git a/models/tag2id.json b/models/tag2id.json deleted file mode 100644 index cbb5344..0000000 --- a/models/tag2id.json +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:6ccf266811d1f037c96c1b6d9829d929a5cb913c5a895fcf0763ca7109c180f4 -size 62 diff --git a/notes.md b/notes.md new file mode 100644 index 0000000..69046c2 --- /dev/null +++ b/notes.md @@ -0,0 +1,14 @@ +# Notes + +## TODO v0.2.0 +[ ] add async inference + [ ] set up batching inference server with mosec + [ ] run mosec in subprocess + [ ] provide two interface + [ ] python interface + [ ] within async inference func, await http request to mosec server and get result + [ ] response to user + [ ] http interface, user directly call mosec server + +## TODO v1.0.0 +[ ] build own batching backend to replace mosec to provide a main socket connection to python to receive request --> avoid http transportation of data \ No newline at end of file diff --git a/setup.py b/setup.py index 1238473..7482791 100644 --- a/setup.py +++ b/setup.py @@ -12,13 +12,13 @@ setup( name="distilbert-punctuator", - version="0.1.0", + version="0.1.1", description="A small seq2seq punctuator tool based on DistilBERT", long_description=readme, long_description_content_type="text/markdown", author="Zhong Qishuai", author_email="ferdinandzhong@gmail.com", - url="https://https://github.com/FerdinandZhong/punctuator", + url="https://github.com/FerdinandZhong/punctuator", packages=find_packages(exclude=["tests*", "example*"]), classifiers=[ "Programming Language :: Python :: 3 :: Only", @@ -38,6 +38,7 @@ "black>=20.8b1", "isort>=5.6", "autoflake>=1.4", + "pandas>=1.3.4" ], }, zip_safe=False, diff --git a/tests/common.py b/tests/common.py new file mode 100644 index 0000000..40d19db --- /dev/null +++ b/tests/common.py @@ -0,0 +1,53 @@ +import pandas as pd +import pytest + +from dbpunctuator.data_process import remove_brackets_text +from dbpunctuator.data_process.data_cleanning import dataframe_data_cleaning +from dbpunctuator.data_process.data_process import process_line +from dbpunctuator.utils.constant import DEFAULT_NER_MAPPING + +punctuations = list(DEFAULT_NER_MAPPING.keys()) + + +@pytest.fixture(scope="module") +def cleaned_data(): + test_df = pd.DataFrame( + { + "test": [ + "Good morning. How are you?(Laughter)It's been great, hasn't it?", # noqa: E501 + "This is a test without special character.", + "'Hello', who @ are * you (music)?", + "i don't know what's going on...", + "my number is 54785564.", + """ + We shot the scene without a single rehearsal", beatty said. + As usual the director insisted on a rehearsal, but I convinced him the best opportunity for a realistic battle would be when the two animals first met. + You simply can't rehearsal a scene like that. Hence the cameramen were ready and the fight was a real one, unfaked... + And claw to claw, fang to fang battle between natrual enimies of the cat family proved conclusively that the fighting prowess of the lion is superior to that of the tiger according to beatty the tiger lost the battle after a terrific struggle. + We used a untamed tiger for the battle scene because we figured a good fight was a likely to ensue, the trainer continued. + That tiger never before been in a cage with a lion. + Nearly a score of movie stars watched the battle and the majority of them bet on the tiger. + I had no idea which would win, but I figured sultan had an even chance, though lions are gang fighters and a tiger is supposed to be invinceable in a single-handed battle with almost any animal. + My reasons for giving the lion an even chance was that I knew that when one takes a hold with his teeth it is in a vital spot, while a tiger sinks his teeth and hangs on whereever he grabs first. + Thats exactly why tommy lost the fight. While the tiger is simply hanging on to a shoulder, the lion was manuvering into position to get his enemys throat, all the while using his blade-like claws to great advantage, from now on I'll bet on the lion. + """, # noqa: E501 + ] + } + ) + + cleaned_df = dataframe_data_cleaning( + test_df, "test", set(DEFAULT_NER_MAPPING.keys()), [], remove_brackets_text + ) + + return cleaned_df["test"].tolist() + + +@pytest.fixture(scope="module") +def processed_data(cleaned_data): + all_tokens = [] + all_tags = [] + for line in cleaned_data: + tokens, tags = process_line(line) + all_tokens.append(tokens) + all_tags.append(tags) + return all_tokens, all_tags diff --git a/tests/test_dataprocess.py b/tests/test_dataprocess.py new file mode 100644 index 0000000..8b1d08d --- /dev/null +++ b/tests/test_dataprocess.py @@ -0,0 +1,29 @@ +import re + +import pytest + +from dbpunctuator.utils.constant import DEFAULT_NER_MAPPING +from tests.common import cleaned_data, processed_data # noqa: F401 + +punctuations = list(DEFAULT_NER_MAPPING.keys()) + + +@pytest.mark.usefixtures("cleaned_data") +def test_data_cleaning(cleaned_data): # noqa: F811 + checking_regex = r"\([^()]*\)" + for line in cleaned_data: + bracketed_texts = re.findall(checking_regex, line) + assert len(bracketed_texts) == 0 + + +@pytest.mark.usefixtures("processed_data") +def test_training_data_generation(processed_data): # noqa: F811 + for tokens, tags in zip(*processed_data): + last_token_is_punct = False + for token, tag in zip(tokens, tags): + assert not token.isdigit() + if last_token_is_punct: + assert token not in punctuations + if token in punctuations: + assert tag != "O" + last_token_is_punct = True diff --git a/tests/test_inference.py b/tests/test_inference.py index 4de6bb2..b4a8b72 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1,30 +1,32 @@ -from inference import InferenceArguments, Inference +import numpy as np import pytest +from dbpunctuator.inference import Inference, InferenceArguments +from tests.common import cleaned_data, processed_data # noqa: F401 testing_args = InferenceArguments( - model_name_or_path="models/punctuator", - tokenizer_name="distilbert-base-uncased", - tag2id_storage_path="models/tag2id.json" + model_name_or_path="models/punctuator", tokenizer_name="distilbert-base-uncased" ) -batch_test_text = [ - "how are you its been ten years since we met in shanghai im really happy to meet you again whats your current phone number", - "my number is 82732212", -] -long_test_text = [ - "the two most likely largest inventions of our generation are the internet and the mobile phone theyve changed the world however largely to our surprise they also turned out to be the perfect tools for the surveillance state it turned out that the capability to collect data information and connections about basically any of us and all of us is exactly what weve been hearing throughout of the summer through revelations and leaks about western intelligence agencies mostly u s intelligence agencies watching over the rest of the world weve heard about these starting with the revelations from june 6 edward snowden started leaking information top secret classified information from the u s intelligence agencies and we started learning about things like prism and xkeyscore and others and these are examples of the kinds of programs u s intelligence agencies are running right now against the whole rest of the world and if you look back about the forecasts on surveillance by george orwell well it turns out that george orwell was an optimist we are right now seeing a much larger scale of tracking of individual citizens than he could have ever imagined and this here is the infamous nsa data center in utah due to be opened very soon it will be both a supercomputing center and a data storage center you could basically imagine it has a large hall filled with hard drives storing data they are collecting and its a pretty big building how big well i can give you the numbers 140 000 square meters but that doesnt really tell you very much maybe its better to imagine it as a comparison you think about the largest ikea store youve ever been in this is five times larger how many hard drives can you fit in an ikea store right its pretty big" -] +def accuracy(prediction_labels, true_labels): + return round( + np.sum(prediction_labels == true_labels) / prediction_labels.shape[0], 3 + ) -test_text_list = [batch_test_text, long_test_text] -@pytest.mark.parametrize("test_text", test_text_list) -def test_inference(test_text): +@pytest.mark.usefixtures("processed_data") +def test_inference(processed_data): # noqa: F811 + test_texts = [" ".join(token_list) for token_list in processed_data[0]] inference = Inference(testing_args) - results = inference.punctuation(test_text) - assert len(results) == len(test_text) - for result in results: - assert result[0].isupper() - inference.terminate() + results_text, results_labels = inference.punctuation(test_texts) + assert len(results_text) == len(test_texts) + for result_text, result_labels, true_labels in zip( + results_text, results_labels, processed_data[1] + ): + assert result_text[0].isupper() + acc = accuracy(np.array(result_labels), np.array(true_labels)) + print(f"output text: '{result_text}' with accuracy: {acc}") + assert acc >= 0.8 + inference.terminate() From 9dbc7fc76f85721850570ee9435f602544e5e564 Mon Sep 17 00:00:00 2001 From: Qishuai Zhong Date: Fri, 19 Nov 2021 22:27:02 +0800 Subject: [PATCH 2/6] modify readme --- .github/workflows/check.yml | 5 +--- .github/workflows/model.yml | 2 +- README.md | 25 +++++++++++++++++++- dbpunctuator/inference/inference_pipeline.py | 15 ++++++------ 4 files changed, 33 insertions(+), 14 deletions(-) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 3078811..15fad75 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -7,11 +7,8 @@ on: - main paths: - '.github/workflows/check.yml' - - 'data_process/**' - - 'inference/**' + - 'dbpunctuator/**' - 'models/**' - - 'training/**' - - 'utils/**' - 'tests/**' - 'examples/**' - 'setup.py' diff --git a/.github/workflows/model.yml b/.github/workflows/model.yml index 253977d..417a8c5 100644 --- a/.github/workflows/model.yml +++ b/.github/workflows/model.yml @@ -28,4 +28,4 @@ jobs: env: HUGGINGFACE_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }} run: | - python utils/upload_model.py --fine_tuned_model_path=${{ github.event.inputs.fine_tuned_model_path }} + python dbpunctuator/utils/upload_model.py --fine_tuned_model_path=${{ github.event.inputs.fine_tuned_model_path }} diff --git a/README.md b/README.md index 55a105f..c49703a 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,20 @@ Component for providing a training pipeline for fine-tuning a pretrained `Distil ### Example `examples/train_sample.py` +### Training_arguments: +Arguments required for the training pipeline. + +`data_file_path(str)`: path of training data +`model_name(str)`: name or path of pre-trained model +`tokenizer_name(str)`: name of pretrained tokenizer +`split_rate(float)`: train and validation split rate +`sequence_length(int)`: sequence length of one sample +`epoch(int)`: number of epoch +`batch_size(int)`: batch size +`model_storage_path(str)`: fine-tuned model storage path +`tag2id_storage_path(str)`: tag2id storage path +`addtional_model_config(Optional[Dict])`: additional configuration for model + ## Inference Component for providing an inference interface for user to use punctuator. @@ -39,4 +53,13 @@ Therefore user can initialize an inference object and call its `punctuation` fun There is a `graceful shutdown` methodology for the punctuator, hence user dosen't need to worry about the shutting-down. ### Example -`examples/inference_sample.py` \ No newline at end of file +`examples/inference_sample.py` + +### Inference_arguments +Arguments required for the inference pipeline. + +`model_name_or_path(str)`: name or path of pre-trained model +`tokenizer_name(str)`: name of pretrained tokenizer +`tag2id_storage_path(Optional[str])`: tag2id storage path. If None, DEFAULT_TAG_ID will be used. + +`DEFAULT_TAG_ID`: {"E": 0, "O": 1, "P": 2, "C": 3, "Q": 4} \ No newline at end of file diff --git a/dbpunctuator/inference/inference_pipeline.py b/dbpunctuator/inference/inference_pipeline.py index 3f8bb90..138324b 100644 --- a/dbpunctuator/inference/inference_pipeline.py +++ b/dbpunctuator/inference/inference_pipeline.py @@ -1,8 +1,8 @@ import json import logging -import struct from functools import wraps from itertools import filterfalse +from typing import Optional import numpy as np import torch @@ -36,12 +36,14 @@ class InferenceArguments(BaseModel): Args: model_name_or_path(str): name or path of pre-trained model tokenizer_name(str): name of pretrained tokenizer - tag2id_storage_path(str): tag2id storage path, default None. If None, DEFAULT_TAG_ID will be used. + tag2id_storage_path(Optional[str]): tag2id storage path. If None, DEFAULT_TAG_ID will be used. + + DEFAULT_TAG_ID: {"E": 0, "O": 1, "P": 2, "C": 3, "Q": 4} """ model_name_or_path: str tokenizer_name: str - tag2id_storage_path: str = None + tag2id_storage_path: Optional[str] # whole pipeline running in the seperate process, provide a function for user to call, use socket for communication @@ -175,7 +177,6 @@ def __init__( self.termination = termination self.check_interval = check_interval - # data structure: |num|length|text|length|text... def punctuation(self): try: inputs = self.conn.recv() @@ -183,8 +184,6 @@ def punctuation(self): self.conn.send(outputs_tuple) except OSError as err: logger.warning(f"error receiving inputs: {err}") - except struct.error as err: - logger.warning(f"struct unpack error: {err}") def run(self): assert self.inference_pipeline, "no inference pipeline set up" @@ -199,8 +198,8 @@ def run(self): if self.termination.is_set(): logger.info("termination is set") break - except (struct.error, OSError) as err: - logger.warning(f"struct unpack error: {err}") + except OSError as err: + logger.warning(f"sending output error: {err}") raise err except KeyboardInterrupt: logger.warning("punctuator shut down by keyboard interrupt") From 2404f8359ae1e7350b66304a6088bd37dec912bb Mon Sep 17 00:00:00 2001 From: Qishuai Zhong Date: Fri, 19 Nov 2021 22:37:32 +0800 Subject: [PATCH 3/6] replace lfs in testing --- tests/test_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index b4a8b72..681275d 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -5,7 +5,7 @@ from tests.common import cleaned_data, processed_data # noqa: F401 testing_args = InferenceArguments( - model_name_or_path="models/punctuator", tokenizer_name="distilbert-base-uncased" + model_name_or_path="Qishuai/distilbert_punctuator_en", tokenizer_name="distilbert-base-uncased" ) From 303fe3796be980706f2b91917cd83725085881d7 Mon Sep 17 00:00:00 2001 From: Qishuai Zhong Date: Fri, 19 Nov 2021 22:41:28 +0800 Subject: [PATCH 4/6] remove lfs for checking and package workflows --- .github/workflows/check.yml | 4 ---- .github/workflows/package.yml | 2 -- 2 files changed, 6 deletions(-) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 15fad75..e7f39eb 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -19,8 +19,6 @@ jobs: timeout-minutes: 5 steps: - uses: actions/checkout@v2 - with: - lfs: true - uses: actions/setup-python@v1 with: python-version: 3.8 @@ -41,8 +39,6 @@ jobs: steps: - uses: actions/checkout@v2 - with: - lfs: true - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v1 with: diff --git a/.github/workflows/package.yml b/.github/workflows/package.yml index f97146c..c1a5d74 100644 --- a/.github/workflows/package.yml +++ b/.github/workflows/package.yml @@ -10,8 +10,6 @@ jobs: timeout-minutes: 5 steps: - uses: actions/checkout@v2 - with: - lfs: true - uses: actions/setup-python@v1 with: python-version: 3.8 From 7cb77fbb1c8d436b8a0a8394edde467b4e5528c0 Mon Sep 17 00:00:00 2001 From: Qishuai Zhong Date: Fri, 19 Nov 2021 22:57:27 +0800 Subject: [PATCH 5/6] modify pandas version --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 7482791..99a2e56 100644 --- a/setup.py +++ b/setup.py @@ -31,14 +31,14 @@ python_requires=">3.6", install_requires=requires, extras_require={ - "data_process": ["pandas>=1.3.4"], + "data_process": ["pandas>=1.1.0"], "dev": [ "pytest>=6", "flake8>=3.8", "black>=20.8b1", "isort>=5.6", "autoflake>=1.4", - "pandas>=1.3.4" + "pandas>=1.1.0" ], }, zip_safe=False, From 165aa2b2ed430c8c3a70c89f5cceb7cf36f05c58 Mon Sep 17 00:00:00 2001 From: Qishuai Zhong Date: Fri, 19 Nov 2021 23:02:47 +0800 Subject: [PATCH 6/6] formmating --- tests/test_inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index 681275d..972655b 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -5,7 +5,8 @@ from tests.common import cleaned_data, processed_data # noqa: F401 testing_args = InferenceArguments( - model_name_or_path="Qishuai/distilbert_punctuator_en", tokenizer_name="distilbert-base-uncased" + model_name_or_path="Qishuai/distilbert_punctuator_en", + tokenizer_name="distilbert-base-uncased", )