-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
28 lines (24 loc) · 1.37 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import argparse
import warnings

import torch
# Command-line configuration for the GTM-SM example.
#
# NOTE: everything below runs at import time. Importing this module parses
# sys.argv, seeds torch's random number generators, and publishes `parser`,
# `args`, `device` and `kwargs` as module-level names.
parser = argparse.ArgumentParser(description='GTM-SM Example')
parser.add_argument('--batch-size', type=int, default=16, metavar='N',
                    help='input batch size for training (default: 16)')
parser.add_argument('--epochs', type=int, default=25, metavar='N',
                    help='number of epochs to train (default: 25)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=2018, metavar='S',
                    help='random seed (default: 2018)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training '
                         'status (default: 50)')
parser.add_argument('--save-interval', type=int, default=1, metavar='N',
                    help='how many epochs to wait before saving model status '
                         '(default: 1)')
# NOTE(review): parsed as int, so a fractional clip norm cannot be given on
# the command line -- confirm whether that restriction is intended.
parser.add_argument('--gradient-clip', type=int, default=10, metavar='N',
                    help='the maximum norm of the gradient will be used '
                         '(default: 10)')

# Parsing happens on import, so sys.argv may belong to a different program
# (a test runner, a notebook kernel, another launcher). parse_args() would
# terminate that host process on the first option it does not know;
# parse_known_args() keeps the module importable, and the warning keeps
# mistyped options visible instead of dropping them silently.
args, _unrecognized = parser.parse_known_args()
if _unrecognized:
    warnings.warn(
        'ignoring unrecognized command-line arguments: '
        + ' '.join(_unrecognized)
    )

# CUDA is used only when it was not disabled on the command line AND torch
# reports a usable CUDA runtime.
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed the torch RNG (and every CUDA device's RNG when CUDA is in use).
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed_all(args.seed)

device = torch.device("cuda" if args.cuda else "cpu")
# Extra keyword arguments that are only populated when running on CUDA;
# presumably forwarded to torch DataLoader constructors -- verify at call sites.
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}