-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
145 lines (131 loc) · 4.63 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import copy
import json
import os
import random
from collections import Counter

import numpy as np
import torch
def seed_everything(seed=None, reproducibility=True):
    '''
    Initialize the random seed for random, numpy, torch, CUDA and cuDNN.

    Args:
        seed (int, optional): random seed; if None, a fresh seed is drawn
            from the OS entropy source.
        reproducibility (bool): if True, force cuDNN into deterministic
            mode (slower but repeatable); if False, enable cuDNN
            auto-tuning (faster but non-deterministic).
    '''
    if seed is None:
        # Original code called an undefined _select_seed_randomly();
        # draw a 32-bit seed from os.urandom instead (valid range for
        # PYTHONHASHSEED and the torch/numpy seeders).
        seed = int.from_bytes(os.urandom(4), "big")
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    # Safe on CPU-only builds: CUDA seeding is queued lazily.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if reproducibility:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
# Entity-label codes of this NER schema mapped to their Chinese meanings
# (the dataset annotates theft-case entities; see load_raw below).
LABEL_MEANING_MAP = {
    "NHCS": "犯罪嫌疑人",  # criminal suspect
    "NHVI": "受害人",      # victim
    "NCSM": "被盗货币",    # stolen money
    "NCGV": "物品价值",    # value of goods
    "NCSP": "盗窃获利",    # profit from theft
    "NASI": "被盗物品",    # stolen items
    "NATS": "作案工具",    # tools used in the crime
    "NT": "时间",          # time
    "NS": "地点",          # location
    "NO": "组织机构",      # organization
}
# Inverse lookup: Chinese meaning -> label code.
MEANING_LABEL_MAP = {v: k for k, v in LABEL_MEANING_MAP.items()}
def load_raw(filepath):
    """Load a JSON-lines annotation file into a list of sample dicts.

    Args:
        filepath (str): path to a file with one JSON object per line,
            each holding at least ``id`` and ``context`` and optionally
            ``entities`` (label + ";"-separated, end-exclusive spans).

    Returns:
        list[dict]: one dict per line with keys ``id``, ``text``,
        ``sent_start`` (0), ``sent_end`` (len of text) and ``entities``,
        where each entity is ``(label, start, end_inclusive, span_text)``.
    """
    raw = []
    # Explicit UTF-8: the data contains Chinese text and the platform
    # default encoding is not guaranteed to decode it.
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:  # iterate the file lazily instead of readlines()
            record = json.loads(line)
            context = record["context"]
            r = {
                "id": record["id"],
                "text": context,
                "sent_start": 0,
                "sent_end": len(context),
                "entities": [],
            }
            for entity in record.get("entities", []):
                for span in entity["span"]:
                    start, end = span.split(";")
                    start, end = int(start), int(end)
                    # Source spans are end-exclusive; store end-inclusive.
                    r["entities"].append((
                        entity["label"],
                        start, end - 1,
                        context[start:end],
                    ))
            raw.append(r)
    return raw
def count_entity_labels(samples):
    """Tally entity occurrences across samples, keyed by the label's
    Chinese meaning (via LABEL_MEANING_MAP)."""
    return Counter(
        LABEL_MEANING_MAP[entity[0]]
        for sample in samples
        for entity in sample["entities"]
    )
def save_samples(filename, samples):
    """Write samples to *filename* as JSON lines (one object per line).

    Args:
        filename (str): output path.
        samples (iterable): JSON-serializable objects.
    """
    # ensure_ascii=False emits raw Chinese characters, so the file must be
    # opened as UTF-8 explicitly; the platform default encoding (e.g. a
    # legacy locale on Windows) may fail to encode them.
    with open(filename, "w", encoding="utf-8") as f:
        for sample in samples:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")
def save_groundtruths(filename, groundtruths):
    """Write ground-truth records to *filename* as JSON lines.

    Args:
        filename (str): output path.
        groundtruths (iterable): JSON-serializable objects.
    """
    # Same rationale as save_samples: ensure_ascii=False requires an
    # explicit UTF-8 encoding to be portable across locales.
    with open(filename, "w", encoding="utf-8") as f:
        for gt in groundtruths:
            f.write(json.dumps(gt, ensure_ascii=False) + "\n")
def add_context(ordered_samples, context_window):
    """Pad each sample's text up to *context_window* characters using text
    from neighbouring samples, and shift its entity offsets accordingly.

    Args:
        ordered_samples (list[dict]): samples in document order, each with
            keys ``text``, ``sent_start``, ``sent_end`` and ``entities``
            (``(label, start, end_inclusive, span_text)`` tuples).
        context_window (int): target text length; <= 0 disables padding and
            returns *ordered_samples* unchanged.

    Returns:
        list[dict]: a deep copy with padded texts, shifted sentence bounds
        and shifted entity offsets; the input list is never mutated.
    """
    if context_window <= 0:
        return ordered_samples
    # Deep copy so the caller's samples are left untouched.
    # (fix: `copy` was used here without ever being imported)
    samples = copy.deepcopy(ordered_samples)
    for i in range(len(samples)):
        # NOTE(review): sample 0 is skipped entirely, so it never receives
        # right context either — presumably intentional; confirm with callers.
        if i == 0: continue
        text = samples[i]["text"]
        add_left = (context_window - len(text)) // 2
        add_right = (context_window - len(text)) - add_left
        sent_start, sent_end = samples[i]["sent_start"], samples[i]["sent_end"]
        # Prepend context from preceding samples (tail chars first).
        j = i - 1
        while j >= 0 and add_left > 0:
            context_to_add = samples[j]["text"][-add_left:]
            text = context_to_add + text
            add_left -= len(context_to_add)
            # Everything in this sample shifts right by the prepended length.
            sent_start += len(context_to_add)
            sent_end += len(context_to_add)
            j -= 1
        # Append context from following samples (head chars first).
        j = i + 1
        while j < len(samples) and add_right > 0:
            context_to_add = samples[j]["text"][:add_right]
            text = text + context_to_add
            add_right -= len(context_to_add)
            j += 1
        # Shift entity offsets by the amount of prepended text and verify
        # each span still matches the characters it points at.
        entities = []
        for label, start, end, span_text in samples[i]["entities"]:
            start += sent_start; end += sent_start
            span_text_new = text[start: end + 1]
            assert span_text_new == span_text, "Error"
            entities.append((label, start, end, span_text))
        samples[i]["text"] = text
        samples[i]["sent_start"] = sent_start
        samples[i]["sent_end"] = sent_end
        samples[i]["entities"] = entities
    return samples
def get_ner_tags(entities, seq_len):
    """Render entities as a BIO tag sequence of length *seq_len*.

    Entities whose span falls outside [0, seq_len), is inverted, or whose
    boundary positions are already tagged are silently dropped.
    """
    tags = ["O"] * seq_len
    for entity in entities:
        label, start, end = entity[:3]
        in_bounds = 0 <= start <= end < seq_len
        if not in_bounds or tags[start] != "O" or tags[end] != "O":
            continue
        tags[start] = f"B-{label}"
        tags[start + 1: end + 1] = [f"I-{label}"] * (end - start)
    return tags
if __name__ == "__main__":
    # Smoke test: parse the stage-one information-extraction dataset
    # (return value discarded; raises on any malformed line).
    load_raw("./data/信息抽取_第一阶段/xxcq_small.json")