forked from dsgiitr/AIKavach
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsample.py
24 lines (20 loc) · 824 Bytes
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torch.nn as nn
from packaging_class import FinishedModel
import argparse
class DenoisedModel(nn.Module):
    """Chain a saved denoiser and a saved classifier into one module.

    ``forward(x)`` returns ``classifier(denoiser(x))``.

    Each file is loaded with ``torch.load`` and must unpickle to a callable
    object, because ``forward`` calls it directly. Presumably this is a whole
    ``nn.Module`` rather than a ``state_dict``; confirm against the scripts
    that produce these files.

    Args:
        denoiser_path: Path to the saved denoiser object.
        classifier_path: Path to the saved classifier object.
        map_location: Optional ``torch.load`` ``map_location`` (for example
            ``"cpu"``) used to remap tensor storages. This lets a checkpoint
            saved on a GPU be loaded on a machine without one. ``None``
            (the default) keeps the devices recorded in the file.
    """

    def __init__(self, denoiser_path, classifier_path, map_location=None):
        super().__init__()
        self.denoiser = self._load_module(denoiser_path, map_location)
        self.classifier = self._load_module(classifier_path, map_location)

    @staticmethod
    def _load_module(path, map_location):
        """Unpickle and return the full object stored at *path*."""
        # weights_only=False is required because the files hold whole pickled
        # objects, which the weights_only=True default (PyTorch >= 2.6) refuses
        # to load. The keyword itself needs PyTorch >= 1.13.
        # SECURITY: this unpickles arbitrary objects, so only load trusted files.
        return torch.load(path, map_location=map_location, weights_only=False)

    def forward(self, x):
        """Return the classifier's output on the denoised input ``x``."""
        return self.classifier(self.denoiser(x))
def parse_args(argv=None):
    """Parse the command-line options for building the packaged model.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``denoiser_path``, ``classifier_path``
        and ``final_path``.

    Raises:
        SystemExit: If any required option is missing (argparse exits with
            status 2).
    """
    parser = argparse.ArgumentParser(
        description="Chain a saved denoiser and classifier, wrap the result "
        "in FinishedModel, and save it with torch.save.",
    )
    # All three paths are mandatory. Without them, None would reach
    # torch.load/torch.save and fail with an obscure error.
    parser.add_argument('--denoiser_path', type=str, required=True,
                        help="file containing the saved denoiser")
    parser.add_argument('--classifier_path', type=str, required=True,
                        help="file containing the saved classifier")
    parser.add_argument('--final_path', type=str, required=True,
                        help="output file for the packaged model")
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    denoised_model = DenoisedModel(args.denoiser_path, args.classifier_path)
    final_model = FinishedModel(denoised_model)
    torch.save(final_model, args.final_path)