Merge pull request #3 from FerdinandZhong/develop
restructure for v0.1.1
FerdinandZhong authored Nov 19, 2021
2 parents 378e77c + 165aa2b commit 53d0fcc
Showing 30 changed files with 256 additions and 88 deletions.
1 change: 0 additions & 1 deletion .gitattributes
@@ -1,3 +1,2 @@
models/tag2id.json filter=lfs diff=lfs merge=lfs -text
models/punctuator/config.json filter=lfs diff=lfs merge=lfs -text
models/punctuator/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
9 changes: 1 addition & 8 deletions .github/workflows/check.yml
@@ -7,11 +7,8 @@ on:
- main
paths:
- '.github/workflows/check.yml'
- 'data_process/**'
- 'inference/**'
- 'dbpunctuator/**'
- 'models/**'
- 'training/**'
- 'utils/**'
- 'tests/**'
- 'examples/**'
- 'setup.py'
@@ -22,8 +19,6 @@ jobs:
timeout-minutes: 5
steps:
- uses: actions/checkout@v2
with:
lfs: true
- uses: actions/setup-python@v1
with:
python-version: 3.8
@@ -44,8 +39,6 @@ jobs:

steps:
- uses: actions/checkout@v2
with:
lfs: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
2 changes: 1 addition & 1 deletion .github/workflows/model.yml
@@ -28,4 +28,4 @@ jobs:
env:
HUGGINGFACE_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
run: |
python utils/upload_model.py --fine_tuned_model_path=${{ github.event.inputs.fine_tuned_model_path }}
python dbpunctuator/utils/upload_model.py --fine_tuned_model_path=${{ github.event.inputs.fine_tuned_model_path }}
2 changes: 0 additions & 2 deletions .github/workflows/package.yml
@@ -10,8 +10,6 @@ jobs:
timeout-minutes: 5
steps:
- uses: actions/checkout@v2
with:
lfs: true
- uses: actions/setup-python@v1
with:
python-version: 3.8
4 changes: 2 additions & 2 deletions Makefile
@@ -1,10 +1,10 @@
PY_SOURCE_FILES=data_process/ training/ inference/ utils/ examples/ #this can be modified to include more files
PY_SOURCE_FILES=dbpunctuator/ examples/ tests/ #this can be modified to include more files

install: package
pip install -e .[dev]

test:
pytest tests -vv
pytest tests -vv -s

clean:
rm -rf build/ dist/ *.egg-info .pytest_cache
25 changes: 24 additions & 1 deletion README.md
@@ -22,6 +22,20 @@ Component for providing a training pipeline for fine-tuning a pretrained `Distil
### Example
`examples/train_sample.py`

### Training_arguments:
Arguments required for the training pipeline; a usage sketch follows the list.

`data_file_path(str)`: path of training data
`model_name(str)`: name or path of pre-trained model
`tokenizer_name(str)`: name of pretrained tokenizer
`split_rate(float)`: train and validation split rate
`sequence_length(int)`: sequence length of one sample
`epoch(int)`: number of epochs
`batch_size(int)`: batch size
`model_storage_path(str)`: fine-tuned model storage path
`tag2id_storage_path(str)`: tag2id storage path
`addtional_model_config(Optional[Dict])`: additional configuration for model
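
A minimal sketch of wiring these arguments into a training run. The `TrainingArguments`/`TrainingPipeline` names and the `run()` entry point are assumptions, not confirmed by this diff (see `examples/train_sample.py` for the actual usage); the file paths are only illustrative.

```python
# Hedged sketch only: class/function names (TrainingArguments, TrainingPipeline,
# run) are assumptions; see examples/train_sample.py for the real entry point.
from dbpunctuator.training import TrainingArguments, TrainingPipeline

training_args = TrainingArguments(
    data_file_path="training_data/token_tag_data.txt",  # hypothetical path
    model_name="distilbert-base-uncased",
    tokenizer_name="distilbert-base-uncased",
    split_rate=0.25,
    sequence_length=256,
    epoch=20,
    batch_size=16,
    model_storage_path="models/punctuator",
    tag2id_storage_path="models/tag2id.json",
    addtional_model_config={"dropout": 0.2},  # argument name as documented above
)

TrainingPipeline(training_args).run()
```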

## Inference
Component providing an inference interface for the user to use the punctuator.

@@ -39,4 +53,13 @@ Therefore user can initialize an inference object and call its `punctuation` fun
There is a `graceful shutdown` methodology for the punctuator, hence the user doesn't need to worry about shutting it down.

### Example
`examples/inference_sample.py`
`examples/inference_sample.py`

### Inference_arguments
Arguments required for the inference pipeline; a usage sketch follows the list.

`model_name_or_path(str)`: name or path of pre-trained model
`tokenizer_name(str)`: name of pretrained tokenizer
`tag2id_storage_path(Optional[str])`: tag2id storage path. If None, DEFAULT_TAG_ID will be used.

`DEFAULT_TAG_ID`: {"E": 0, "O": 1, "P": 2, "C": 3, "Q": 4}
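
A minimal sketch of initializing an inference object and calling `punctuation`. The `Inference` wrapper name and import path are assumptions; `InferenceArguments` and the returned `(texts, labels)` tuple come from this change, and `examples/inference_sample.py` shows the actual usage.

```python
# Hedged sketch only: the Inference wrapper name, import path, and constructor
# signature are assumptions; InferenceArguments fields and the tuple return
# shape are taken from this change.
from dbpunctuator.inference import Inference, InferenceArguments

args = InferenceArguments(
    model_name_or_path="models/punctuator",
    tokenizer_name="distilbert-base-uncased",
    tag2id_storage_path=None,  # None -> fall back to DEFAULT_TAG_ID
)

punctuator = Inference(args)  # constructor signature assumed

texts, labels = punctuator.punctuation(
    ["how are you its been a long time since we last met"]
)
print(texts)   # punctuated, capitalized sentences
print(labels)  # per-token tag lists, tags taken from DEFAULT_TAG_ID keys
```
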
7 changes: 7 additions & 0 deletions dbpunctuator/__init__.py
@@ -0,0 +1,7 @@
import logging

from .utils.utils import register_logger

# setup library logging
logger = logging.getLogger(__name__)
register_logger(logger)
File renamed without changes.
File renamed without changes.
@@ -4,8 +4,6 @@

tqdm.pandas()

default_kept_punctuations = {",", ".", "?", "!"}


def dataframe_data_cleaning(
df, target_col, kept_punctuations, additional_to_remove, *special_cleaning_funcs
@@ -1,8 +1,13 @@
import logging

import pandas as pd
from tqdm import tqdm

from data_process.data_cleanning import dataframe_data_cleaning
from utils.constant import DEFAULT_NER_MAPPING, DIGIT_MASK
from dbpunctuator.utils.constant import DEFAULT_NER_MAPPING, DIGIT_MASK

from .data_cleanning import dataframe_data_cleaning

logger = logging.getLogger(__name__)


def cleanup_data_from_csv(
@@ -11,7 +16,7 @@ def cleanup_data_from_csv(
output_file_path,
ner_mapping=DEFAULT_NER_MAPPING,
additional_to_remove=[],
special_cleaning_funcs=[],
*special_cleaning_funcs,
):
"""clean up training data from csv file
@@ -21,13 +26,14 @@ def cleanup_data_from_csv(
output_file_path (string): path of cleaned data
ner_mapping (dict, optional): NER mapping of punctuation marks. Defaults to utils.constant.DEFAULT_NER_MAPPING
additional_to_remove (list, optional): additional special characters to remove, default []
special_cleaning_funcs (List[func], optional): additional cleaning funcs to apply to csv data, default []
*special_cleaning_funcs (funcs, optional): additional cleaning funcs to apply to csv data
"""
dataframe = pd.read_csv(csv_path)
additional_to_remove = ["—"]
kept_punctuations = set(ner_mapping.keys())
logger.info("clean up original data")
result_df = dataframe_data_cleaning(
dataframe[1500:],
dataframe,
target_col,
kept_punctuations,
additional_to_remove,
@@ -44,8 +50,8 @@ def cleanup_data_from_csv(

def process_line(line, ner_mapping=DEFAULT_NER_MAPPING):
text_list = line.split()
word_list = []
token_list = []
tag_list = []
# clean up puncs in the beginning of the text
latest_word = text_list.pop(0)
while latest_word in ner_mapping:
@@ -59,23 +65,23 @@ def process_line(line, ner_mapping=DEFAULT_NER_MAPPING):
if not latest_is_punc:
latest_token = ner_mapping[word]
latest_is_punc = True
word_list.append(latest_word)
token_list.append(latest_token)
token_list.append(latest_word)
tag_list.append(latest_token)
else:
pass
else:
if not latest_is_punc:
word_list.append(latest_word)
token_list.append(latest_token)
token_list.append(latest_word)
tag_list.append(latest_token)
latest_is_punc = False
if word.isdigit():
word = DIGIT_MASK
latest_word = word
latest_token = "O"
if not latest_is_punc:
word_list.append(latest_word)
token_list.append(latest_token)
return word_list, token_list
token_list.append(latest_word)
tag_list.append(latest_token)
return token_list, tag_list


def generate_training_data(cleaned_data_path, training_data_path):
@@ -85,12 +91,13 @@ def generate_training_data(cleaned_data_path, training_data_path):
cleaned_data_path (string): path of cleaned data
training_data_path (string): path of generated training data
"""
logger.info("generate training data")
with open(cleaned_data_path, "r") as data_file:
lines = data_file.readlines()
with open(training_data_path, "w+") as training_data_file:
pbar = tqdm(lines)
for line in pbar:
words, tokens = process_line(line)
for word, token in zip(words, tokens):
training_data_file.write("%s\t%s\n" % (word, token))
tokens, tags = process_line(line)
for token, tag in zip(tokens, tags):
training_data_file.write("%s\t%s\n" % (token, tag))
pbar.close()
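
For context, a sketch of calling the two helpers changed above with the new `*special_cleaning_funcs` varargs. The import path and the `csv_path`/`target_col` positional names are inferred from this diff; the file paths and the cleaning func are only illustrative.

```python
# Hedged sketch only: import path and positional argument names are inferred
# from this diff; paths and the extra cleaning func are illustrative.
from dbpunctuator.data_process import cleanup_data_from_csv, generate_training_data
from dbpunctuator.utils.constant import DEFAULT_NER_MAPPING


def strip_brackets(text: str) -> str:
    """Example special cleaning func, now passed as trailing varargs."""
    return text.replace("(", "").replace(")", "")


cleanup_data_from_csv(
    "raw/ted_talks.csv",         # hypothetical input csv
    "transcript",                # hypothetical target column
    "cleaned/cleaned_text.txt",  # cleaned-data output path
    DEFAULT_NER_MAPPING,
    ["—"],                       # additional characters to remove
    strip_brackets,              # *special_cleaning_funcs (was a keyword list before)
)

# write token\ttag pairs, one per line, for the training pipeline
generate_training_data("cleaned/cleaned_text.txt", "training_data/token_tag_data.txt")
```
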
File renamed without changes.
@@ -5,13 +5,14 @@
import threading
from threading import Thread
from time import sleep
from typing import List
from typing import List, Tuple

from inference.inference_pipeline import InferenceServer
from utils.utils import register_logger
from .inference_pipeline import InferenceServer

# from utils.utils import register_logger

logger = logging.getLogger(__name__)
register_logger(logger)
# register_logger(logger)


class InferenceClient:
@@ -112,10 +113,20 @@ def _run(self, check_interval):
logger.info("terminate the punctuator")
# self.server_process.terminate()

def punctuation(self, inputs: List[str]):
def punctuation(self, inputs: List[str]) -> Tuple[List[str], List[List]]:
"""Do punctuation of inputs
Args:
inputs (List[str]): list of plain text (not yet punctuated)
Returns:
Tuple[List[str], List[List]]: tuple of outputs.
First is the list of punctuated text
Second is the list of labels
"""
try:
outputs = self.client.punctuation(inputs)
return outputs
outputs_tuple = self.client.punctuation(inputs)
return outputs_tuple
except Exception as err:
logger.error(f"error doing punctuation with details {str(err)}")
return None
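
A small sketch of the call-site change implied by the new return type; the `Inference` wrapper construction repeats the assumptions from the README example above, and only the tuple unpacking reflects this hunk.

```python
# Hedged sketch: punctuation() now returns a (texts, labels) tuple instead of
# a single list of outputs, so call sites unpack two values. The Inference
# wrapper and import path are the same assumptions as in the README example.
from dbpunctuator.inference import Inference, InferenceArguments

punctuator = Inference(
    InferenceArguments(
        model_name_or_path="models/punctuator",
        tokenizer_name="distilbert-base-uncased",
        tag2id_storage_path=None,
    )
)

# before this change: outputs = punctuator.punctuation(inputs)
texts, labels = punctuator.punctuation(["how long will it take to finish"])
for sentence, tags in zip(texts, labels):
    print(sentence)  # punctuated, capitalized text
    print(tags)      # per-token tags such as "O", "C", "P"
```
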
@@ -1,19 +1,17 @@
import json
import logging
import struct
from functools import wraps
from itertools import filterfalse
from typing import Optional

import numpy as np
import torch
from pydantic import BaseModel
from transformers import DistilBertForTokenClassification, DistilBertTokenizerFast

from utils.constant import DIGIT_MASK, TAG_PUNCTUATOR_MAP
from utils.utils import register_logger
from dbpunctuator.utils.constant import DEFAULT_TAG_ID, DIGIT_MASK, TAG_PUNCTUATOR_MAP

logger = logging.getLogger(__name__)
register_logger(logger)


def verbose(attr_to_log):
@@ -38,12 +36,14 @@ class InferenceArguments(BaseModel):
Args:
model_name_or_path(str): name or path of pre-trained model
tokenizer_name(str): name of pretrained tokenizer
tag2id_storage_path(str): tag2id storage path
tag2id_storage_path(Optional[str]): tag2id storage path. If None, DEFAULT_TAG_ID will be used.
DEFAULT_TAG_ID: {"E": 0, "O": 1, "P": 2, "C": 3, "Q": 4}
"""

model_name_or_path: str
tokenizer_name: str
tag2id_storage_path: str
tag2id_storage_path: Optional[str]


# whole pipeline running in the separate process, provide a function for user to call, use socket for communication
@@ -61,8 +61,12 @@ def __init__(self, inference_arguments, verbose=False):
self.classifer = DistilBertForTokenClassification.from_pretrained(
inference_arguments.model_name_or_path
)
with open(inference_arguments.tag2id_storage_path, "r") as fp:
tag2id = json.load(fp)
if inference_arguments.tag2id_storage_path:
with open(inference_arguments.tag2id_storage_path, "r") as fp:
tag2id = json.load(fp)
self.id2tag = {id: tag for tag, id in tag2id.items()}
else:
tag2id = DEFAULT_TAG_ID
self.id2tag = {id: tag for tag, id in tag2id.items()}

self._reset_values()
@@ -109,27 +113,34 @@ def classify(self):
def post_process(self):
reduce_ignored_marks = self.marks >= 0

self.outputs_labels = []
for pred, reduce_ignored, tokens, digit_index in zip(
self.argmax_preds, reduce_ignored_marks, self.all_tokens, self.digit_indexes
):
next_upper = True
true_pred = pred[reduce_ignored]

result_text = ""
output_labels = []
for id, (index, token) in zip(true_pred, enumerate(tokens)):
tag = self.id2tag[id]
output_labels.append(tag)
if index in digit_index:
token = digit_index[index]
if next_upper:
token = token.capitalize()
punctuator, next_upper = TAG_PUNCTUATOR_MAP[tag]
result_text += token + punctuator
self.outputs.append(result_text.strip())
self.outputs_labels.append(output_labels)

return self

def punctuation(self, inputs):
self._reset_values()
return self.pre_process(inputs).tokenize().classify().post_process().outputs
self.pre_process(inputs).tokenize().classify().post_process()

return self.outputs, self.outputs_labels

def _mark_ignored_tokens(self, offset_mapping):
samples = []
@@ -166,16 +177,13 @@ def __init__(
self.termination = termination
self.check_interval = check_interval

# data structure: |num|length|text|length|text...
def punctuation(self):
try:
inputs = self.conn.recv()
outputs = self.inference_pipeline.punctuation(inputs)
self.conn.send(outputs)
outputs_tuple = self.inference_pipeline.punctuation(inputs)
self.conn.send(outputs_tuple)
except OSError as err:
logger.warning(f"error receiving inputs: {err}")
except struct.error as err:
logger.warning(f"struct unpack error: {err}")

def run(self):
assert self.inference_pipeline, "no inference pipeline set up"
@@ -190,8 +198,8 @@ def run(self):
if self.termination.is_set():
logger.info("termination is set")
break
except (struct.error, OSError) as err:
logger.warning(f"struct unpack error: {err}")
except OSError as err:
logger.warning(f"sending output error: {err}")
raise err
except KeyboardInterrupt:
logger.warning("punctuator shut down by keyboard interrupt")
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions training/train.py → dbpunctuator/training/train.py
@@ -16,7 +16,7 @@
DistilBertTokenizerFast,
)

from training.dataset import generate_tag_ids, read_data, train_test_split
from .dataset import generate_tag_ids, read_data, train_test_split

logger = logging.getLogger(__name__)

@@ -68,8 +68,8 @@ def load_training_data(self):
)
(
self.train_texts,
self.val_texts,
self.train_tags,
self.val_texts,
self.val_tags,
) = train_test_split(texts, tags, test_size=self.arguments.split_rate)
self.tag2id, self.id2tag = generate_tag_ids(tag_docs=tags)
File renamed without changes.