test.py

import jieba
import jieba.analyse  # extract_tags lives in the analyse submodule and must be imported explicitly
from torchtext import data
from torchtext import vocab
import pandas as pd
import numpy as np
import torch
import codecs
import torch.nn.functional as F
import dill  # kept from the original file; unused in the code shown below

# Note: data.Field / data.TabularDataset belong to the legacy torchtext API
# (torchtext < 0.9, or torchtext.legacy in 0.9-0.11).

# Load stop words from a plain-text file, one word per line.
def load_stopwords():
    stopwords = []
    with codecs.open('tmp/stopwords.txt', 'r', encoding='utf-8') as f:
        for line in f:
            stopwords.append(line.strip())
    return stopwords
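
# The list above is never applied anywhere in this file. One way to put it to
# work is to hand the same file to jieba's keyword ranker; set_stop_words is a
# real jieba API, but wiring it up here is an assumption, not something the
# original code does:
#   jieba.analyse.set_stop_words('tmp/stopwords.txt')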

# Extract the top keywords from a text with jieba's TF-IDF ranking
# (jieba.analyse.textrank is an alternative ranker), keeping only noun/verb
# POS tags, and return them as one space-joined string.
def extract_keywords(s):
    tags = jieba.analyse.extract_tags(s, topK=35, withWeight=True,
                                      allowPOS=('n', 'v', 'nt', 'vn'))
    res = ''
    for item in tags:
        res += ' ' + item[0]
    # print(res)  # debug output from the original
    return res
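
# Hypothetical usage (the input sentence is made up for illustration; the
# exact keywords depend on jieba's dictionary and IDF table):
#   keywords = extract_keywords('人工智能技术推动教育行业发展')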

# Container for raw keyword strings, integer labels, and label names over the
# five news categories.
class Data():
    def __init__(self):
        self.data = []
        self.target = []
        self.target_names = []
        # Categories: society, politics, health, technology, education
        self.class_name = ['社会', '时政', '健康', '科技', '教育']

    def load_data(self, path):
        df = pd.read_csv(path)  # pass nrows=30 here to load a small sample
        for i in range(len(df)):
            tmp = extract_keywords(df['title'][i] + ' ' + df['content'][i])
            self.data.append(tmp)
            self.target.append(self.class_name.index(df['category'][i]))
            self.target_names.append(df['category'][i])
        self.target = np.asarray(self.target)
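
# Hypothetical usage, assuming a CSV with the 'title', 'content', and
# 'category' columns that load_data reads (the path here is an assumption):
#   dataset = Data()
#   dataset.load_data('./data/train.csv')
#   print(dataset.target[:5], dataset.target_names[:5])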

# Pretrained word vectors, presumably the Sogou News SGNS embeddings from the
# Chinese-Word-Vectors project.
embedding = './data/sgns.sogou.word'

# Tokenize Chinese text into a word list with jieba.
def tokenizer(s):
    return jieba.lcut(s)

# Load the train/val/test CSV splits, build the vocabularies (unknown words
# initialized from a normal distribution), and return bucketed iterators.
def load_news(config, text_field, band_field):
    fields = {
        'text': ('text', text_field),
        'label': ('label', band_field),
    }
    word_vectors = vocab.Vectors(config.embedding_file)
    train, val, test = data.TabularDataset.splits(
        path=config.data_path, train='train.csv', validation='val.csv',
        test='test.csv', format='csv', fields=fields)
    print("the size of train: {}, dev: {}, test: {}".format(
        len(train.examples), len(val.examples), len(test.examples)))
    text_field.build_vocab(train, val, test, max_size=config.n_vocab,
                           vectors=word_vectors, unk_init=torch.Tensor.normal_)
    band_field.build_vocab(train, val, test)
    train_iter, val_iter, test_iter = data.BucketIterator.splits(
        (train, val, test),
        batch_sizes=(config.batch_size, config.batch_size, config.batch_size),
        sort=False, sort_within_batch=False, shuffle=False,
        device=config.device)
    return train_iter, val_iter, test_iter
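
# Minimal config object so the module-level code below can run. The attribute
# names are taken from how load_news uses them; the concrete values (paths,
# vocab cap, batch size) are assumptions, not from the original file.
class Config:
    embedding_file = embedding  # pretrained vectors defined above
    data_path = './data/'       # directory holding train/val/test.csv (assumed)
    n_vocab = 50000             # vocabulary size cap (assumed)
    batch_size = 64             # assumed
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

config = Config()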

# Build the fields: tokenized Chinese text (padded/truncated to 512 tokens,
# with lengths kept for packing) and integer labels; the 'label' column is
# expected to already hold integer class ids.
text_field = data.Field(tokenize=tokenizer, include_lengths=True, fix_length=512)
band_field = data.Field(sequential=False, use_vocab=False, batch_first=True,
                        dtype=torch.int64, preprocessing=data.Pipeline(lambda x: int(x)))
train_iterator, val_iterator, test_iterator = load_news(config, text_field, band_field)
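
# Hedged sketch of consuming the iterators: with include_lengths=True, the
# legacy torchtext API yields batch.text as a (token_ids, lengths) pair.
#   for batch in train_iterator:
#       tokens, lengths = batch.text   # tokens: [512, batch_size]
#       labels = batch.label           # int64 tensor of shape [batch_size]
#       ...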