Skip to content

Commit

Permalink
Add saveimg
Browse files Browse the repository at this point in the history
  • Loading branch information
chizuchizu committed Mar 25, 2021
1 parent 8ea6767 commit 08649ae
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/mnist.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, accuracy_score
from omegaconf import OmegaConf
from torch.utils.data import Dataset, DataLoader
Expand All @@ -9,7 +10,6 @@
base:
api_key_path: "token/token.json"
train_path: "../data/train_400.csv"
seed: 67
dataset:
Expand All @@ -27,7 +27,6 @@
each_weight: 1 # 重み係数
length_weight: 3 # 重みの層の数
multiprocessing: true
"""

cfg = OmegaConf.create(conf)
Expand Down Expand Up @@ -66,7 +65,8 @@ def get_ds(n, seed):


dataset = pd.read_csv(cfg.base.train_path) # .sample(100).values
for i in range(10):
for i in range(6, 10):
init_client(cfg)
cfg.dataset.target = i
train = get_ds(
int(
Expand Down Expand Up @@ -107,3 +107,16 @@ def get_ds(n, seed):
print("ACC:", accuracy_score(label, np.round(pred)))
print("=" * 43)
weight = weight.sum(axis=1) * cfg.model.each_weight

plt.imshow(
weight.reshape(
cfg.dataset.img_size,
cfg.dataset.img_size
)
)
plt.axis('tight')
plt.axis('off')
plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
plt.savefig(
f"../img/MNIST_{cfg.dataset.target}_weight.png"
)

0 comments on commit 08649ae

Please sign in to comment.