-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPreprocess.py
143 lines (114 loc) · 3.93 KB
/
Preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# ----- import module ----- #
import os
import csv
import numpy as np
import matplotlib.pyplot as plt
# ----- Hyper parameters ----- #
startIndex = 0        # offset into each difference sequence at which windowing begins
L = 5                 # number of consecutive differences in one input window
CorrThreshold = 0.7   # a pair counts as highly correlated when its coefficient is >= this value
# ----- Read files ----- #
# Work relative to the folder that holds 'covid_19.csv'
os.chdir(r"C:\....\Predict_COVID") # alter the directory !!!
# Load the whole CSV as an array of strings.
# newline='' is what the csv module asks for when it is handed a file object,
# so that line endings embedded in quoted fields are interpreted correctly.
# NOTE(review): the column deletion below needs a 2-D array, i.e. every CSV
# row must have the same number of fields -- confirm against the data file.
with open('covid_19.csv', 'r', newline='') as csvFile:
    reader = csv.reader(csvFile, delimiter=',')
    dataTemp = np.array(list(reader))
dataTemp = np.delete(dataTemp, [0, 1, 2], 0) # Delete first 3 rows
dataTemp = np.delete(dataTemp, [1, 2], 1) # Delete cols 1 and 2 (col 0 is kept as the row key)
# Map the first column of every row (the key) to that row's values as floats
# NOTE(review): the slice [1:-1] also leaves out the last column of each row --
# confirm that dropping it is intended.
CSVdict = {}
for row in dataTemp:
    values = [float(cell) for cell in row[1:-1]]
    CSVdict[row[0]] = values
# Build the difference sequence for every key: element k is value[k+1] - value[k]
Seqdict = {}
for country, series in CSVdict.items():
    later = np.array(series[1:])
    earlier = np.array(series[:-1])
    Seqdict[country] = later - earlier  # {"country": Diff_ndarray}
# ----- Find correlated countries ----- #
# Correlation coefficient (CC) between the difference sequences of every
# ordered pair of countries, collected row by row into a square matrix.
C = set()  # countries that belong to at least one highly correlated pair
Corr = []
for nameA, seqA in Seqdict.items():
    row = []
    for nameB, seqB in Seqdict.items():
        # np.corrcoef returns a 2x2 matrix; the off-diagonal entry is the CC of the pair
        coeff = np.corrcoef(seqA, seqB)[0, 1]
        row.append(coeff)
        # a country paired with itself does not count
        if nameA != nameB and coeff >= CorrThreshold:
            C.update((nameA, nameB))
    Corr.append(row)
Corr = np.array(Corr)
# Plot the coefficients of all pairs as a colour mesh
fig, ax = plt.subplots(1, 1)
# Size the grid from the matrix itself instead of a hard-coded country count,
# so the plot keeps working when the CSV gains or loses rows.
# The grid holds the cell corners, hence one more point than Corr per axis
# (same convention as the disabled 6-country demo at the end of this file,
# which uses 0:7).
nCountry = Corr.shape[0]
x, y = np.mgrid[0:nCountry + 1, 0:nCountry + 1]
CorrRev = Corr[::-1, :]  # same matrix with the row order reversed
z = CorrRev
mesh = ax.pcolormesh(x, y, z)
fig.colorbar(mesh)
plt.show()
# ----- Generate the pair (input,target) for modelling ----- #
# Preprocess Input & label
TrainData = []
TrainLabel = []
# C is a set of country-name strings; its iteration order can change from one
# interpreter run to the next (string hashing is randomised), so sort it to
# make the sample order in the saved files reproducible.
for country in sorted(C):
    CountrySeq = Seqdict[country]
    # slide a window of length L over the sequence, starting at startIndex
    for i in range(len(CountrySeq) - L - startIndex):
        begin = startIndex + i
        end = begin + L
        # add input segment
        TrainData.append(CountrySeq[begin:end])
        # add label: compare the value right after the window with the
        # last value inside the window
        if CountrySeq[end] > CountrySeq[end - 1]:
            TrainLabel.append([0, 1]) # increase
        else:
            TrainLabel.append([1, 0]) # no increase (decrease or unchanged)
# Transfer to ndarray
InputData = np.array(TrainData)
Label = np.array(TrainLabel)
np.save("./InputData.npy", InputData)
np.save("./Label.npy", Label)
# ----- save info for Global Map drawing ----- #
# For every country keep its name and the last L entries of its difference
# sequence; both arrays follow the iteration order of Seqdict.
Map_country = np.array(list(Seqdict))
Map_InputData = np.array([seq[-L:] for seq in Seqdict.values()])
np.save("./Map_InputData.npy", Map_InputData)
np.save("./Map_country.npy", Map_country)
"""
# Plot only 6 countries for demo (disabled: this snippet sits inside a string literal and is never executed)
Corr2 = []
PlotCountry = ['Japan', 'Korea, South', 'China', 'France','US','Taiwan*']
for country1 in PlotCountry:
CorrVec = []
for country2 in PlotCountry:
corrValue = np.corrcoef(Seqdict[country1], Seqdict[country2])
CorrVec.append(corrValue[0, 1])
Corr2.append(CorrVec)
Corr2 = np.array(Corr2)
# Plot these coefficients in all the pairs
fig, ax = plt.subplots(1, 1)
x, y = np.mgrid[0:7, 0:7]
CorrRev = Corr2[::-1, :]
z = CorrRev
mesh = ax.pcolormesh(x, y, z)
fig.colorbar(mesh)
plt.show()
fig.canvas.draw()
#for label in ax.xaxis.get_xticklabels():
# label.set_horizontalalignment('right')
for tick in ax.xaxis.get_minor_ticks():
tick.tick1line.set_markersize(0)
tick.tick2line.set_markersize(0)
tick.label1.set_horizontalalignment('right')
labels = PlotCountry
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)
plt.show()
"""