-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path eval.py
39 lines (28 loc) · 1.02 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import glob
import hydra
import torch
import pytorch_lightning as pl
from learned_gradient_descent.data import create_dataset
from learned_gradient_descent.models.baseline import LGDModel
class CustomDataModule(pl.LightningDataModule):
    """Minimal LightningDataModule that only serves a prebuilt test set.

    No train/val loaders are defined; ``test_dataloader`` returns the object
    handed to the constructor without wrapping it.
    """

    def __init__(self, testset):
        """Keep a reference to *testset* for ``test_dataloader``.

        Args:
            testset: The object to hand to Lightning for the test stage.
                NOTE(review): no ``DataLoader`` is constructed here, so this
                presumably already is (or behaves like) one -- confirm what
                ``create_dataset`` returns.
        """
        super().__init__()
        self.testset = testset

    def test_dataloader(self):
        """Return the stored test set unchanged."""
        stored_testset = self.testset
        return stored_testset
@hydra.main(config_path="configs", config_name="eval")
def main(opt):
    """Evaluate the most recently modified checkpoint on the test split.

    Hydra composes ``opt`` from the ``eval`` config under ``configs/`` (plus
    any command-line overrides). This function:

    - seeds the RNGs and makes cuDNN deterministic;
    - builds the test set from ``opt.data.test``;
    - restores an ``LGDModel`` from the newest ``ckpts/*.ckpt`` file, ordered
      by modification time;
    - runs ``Trainer.test`` on one GPU.

    Args:
        opt: Hydra config object; the fields read here are ``seed``,
            ``data.test`` and ``model``.

    Raises:
        FileNotFoundError: If no file matches ``ckpts/*.ckpt``.
    """
    # Reproducibility: fixed seeds plus deterministic, non-autotuned cuDNN.
    pl.seed_everything(opt.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    testset = create_dataset(opt.data.test)
    datamodule = CustomDataModule(testset)

    # NOTE(review): the glob is resolved against the *current* working
    # directory. Depending on Hydra version/config, Hydra may chdir into its
    # run output directory before calling main -- confirm ``ckpts/`` is
    # expected there.
    checkpoints = sorted(glob.glob("ckpts/*.ckpt"), key=os.path.getmtime)
    print("checkpoints", checkpoints)
    if not checkpoints:
        # Fail with an actionable message instead of an opaque IndexError
        # from ``checkpoints[-1]`` below.
        raise FileNotFoundError(
            f"No checkpoint matching 'ckpts/*.ckpt' found in {os.getcwd()}"
        )

    # Newest checkpoint wins (list is sorted by mtime ascending).
    # strict=False tolerates missing/unexpected state-dict keys; extra kwargs
    # to load_from_checkpoint (here ``opt``) are forwarded to the model's
    # constructor.
    model = LGDModel.load_from_checkpoint(checkpoints[-1], strict=False, opt=opt.model)
    model.eval()

    # ``devices`` replaces the ``gpus`` Trainer argument, which was
    # deprecated in Lightning 1.7 and removed in 2.0.
    trainer = pl.Trainer(accelerator="gpu", devices=1)
    trainer.test(model, datamodule=datamodule)
if __name__ == '__main__':
    # Script entry point: the @hydra.main decorator builds the config that is
    # passed to main, so it is invoked without arguments here.
    main()