-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfeature_plotter.py
45 lines (35 loc) · 1.32 KB
/
feature_plotter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import util
def plot_object_columns(X: pd.DataFrame, target: str = 'SalePrice',
                        out_dir: str = './house_prices/assets/plots/by_raw/'):
    """Save one box plot of ``target`` for every object-dtype column of ``X``.

    Missing values in a column are shown under the category label ``'None'``.
    Each figure is written to ``<out_dir>/<column>_box.png``. ``X`` itself is
    not modified.

    Args:
        X: Frame holding the object columns and the ``target`` column.
        target: Name of the column drawn on the y axis.
        out_dir: Directory the PNG files are written to; created if missing.
    """
    # Create the output directory once instead of on every iteration.
    os.makedirs(out_dir, exist_ok=True)
    target_values = X[target]
    for i, name in enumerate(X.select_dtypes('object').columns):
        # Fill a local copy of the column so the caller's frame is untouched.
        categories = X[name].fillna('None')
        # seaborn >= 0.12 only accepts x/y as keywords; the old positional
        # form raises a TypeError.
        sns.boxplot(x=categories, y=target_values)
        plt.savefig(os.path.join(out_dir, f'{name}_box.png'))
        # Reset the shared current figure before drawing the next column.
        plt.clf()
        print(f'finished plot {i}')
def plot_numeric_columns(X: pd.DataFrame, target: str = 'SalePrice',
                         out_dir: str = './house_prices/assets/plots/by_raw/'):
    """Save one regression plot of ``target`` for every non-object column.

    Every column whose dtype is not ``object`` (including ``target`` itself)
    is drawn against ``target`` with ``seaborn.regplot``. Each figure is
    written to ``<out_dir>/<column>_reg.png``.

    Args:
        X: Frame holding the columns to plot and the ``target`` column.
        target: Name of the column drawn on the y axis.
        out_dir: Directory the PNG files are written to; created if missing.
    """
    # Create the output directory once instead of on every iteration.
    os.makedirs(out_dir, exist_ok=True)
    target_values = X[target]
    for i, name in enumerate(X.select_dtypes(exclude='object').columns):
        # seaborn >= 0.12 only accepts x/y as keywords; the old positional
        # form raises a TypeError.
        sns.regplot(x=X[name], y=target_values)
        plt.xlabel(name)
        plt.savefig(os.path.join(out_dir, f'{name}_reg.png'))
        # Reset the shared current figure before drawing the next column.
        plt.clf()
        print(f'finished plot {i}')
def main():
    """Read the datasets and write per-column plots for the training frame.

    ``util.read_datasets`` returns a pair; only the first element (the
    training frame) is plotted, the second is ignored.
    """
    train_df, _unused = util.read_datasets('./house_prices/assets/input/')
    print(f'amount of columns: {len(train_df.columns)}')
    plot_numeric_columns(train_df)
    plot_object_columns(train_df)


if __name__ == '__main__':
    main()