#!/usr/bin/env python3
# tinygrad implementation of https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
# https://myrtle.ai/learn/how-to-train-your-resnet-8-bag-of-tricks/
# https://siboehm.com/articles/22/CUDA-MMM
# TODO: gelu is causing nans!
import time
import numpy as np
from datasets import fetch_cifar
from tinygrad import nn
from tinygrad.state import get_parameters, get_state_dict
from tinygrad.nn import optim
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
from tinygrad.ops import GlobalCounters
num_classes = 10
# TODO: eval won't work with track_running_stats=False
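
# ConvGroup: conv -> max_pool2d -> batchnorm -> relu, followed (when short=False) by two more
# conv+bn+relu layers whose output is scaled by a squeeze-and-excitation gate and added to a residual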
class ConvGroup:
  def __init__(self, channels_in, channels_out, short, se=True):
    self.short, self.se = short, se and not short
    self.conv = [nn.Conv2d(channels_in if i == 0 else channels_out, channels_out, kernel_size=3, padding=1, bias=False) for i in range(1 if short else 3)]
    self.norm = [nn.BatchNorm2d(channels_out, track_running_stats=False, eps=1e-12, momentum=0.8) for _ in range(1 if short else 3)]
    if self.se: self.se1, self.se2 = nn.Linear(channels_out, channels_out//16), nn.Linear(channels_out//16, channels_out)

  def __call__(self, x):
    x = self.conv[0](x).max_pool2d(2)
    x = self.norm[0](x).relu()
    if self.short: return x
    residual = x
    mult = self.se2(self.se1(residual.mean((2,3))).relu()).sigmoid().reshape(x.shape[0], x.shape[1], 1, 1) if self.se else 1.0
    x = self.norm[1](self.conv[1](x)).relu()
    x = self.norm[2](self.conv[2](x) * mult).relu()
    return x + residual

class SpeedyResNet:
  def __init__(self):
    # TODO: add whitening
    self.net = [
      nn.Conv2d(3, 64, kernel_size=1),
      nn.BatchNorm2d(64, track_running_stats=False, eps=1e-12, momentum=0.8),
      lambda x: x.relu(),
      ConvGroup(64, 128, short=False),
      ConvGroup(128, 256, short=True),
      ConvGroup(256, 512, short=False),
      lambda x: x.max((2,3)),
      nn.Linear(512, num_classes, bias=False)
    ]

  # note, pytorch just uses https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html instead of log_softmax
  def __call__(self, x): return x.sequential(self.net).log_softmax()
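
# TinyJit records the kernels launched by the decorated function and replays them on later
# calls with the new input buffers, skipping most of the python-side work per training step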
from tinygrad.jit import TinyJit
@TinyJit
def train_step_jitted(model, optimizer, X, Y):
  out = model(X)
  loss = out.mul(Y).mean()
  if not getenv("DISABLE_BACKWARD"):
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  return loss.realize()
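
# fetch_batch encodes labels as -num_classes at the true class and 0 elsewhere, so
# out.mul(Y).mean() over the log_softmax output above reduces to the usual mean
# negative log-likelihood (cross-entropy) loss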
def fetch_batch(X_train, Y_train, BS):
  # fetch a batch
  samp = np.random.randint(0, X_train.shape[0], size=(BS))
  Y = np.zeros((BS, num_classes), np.float32)
  Y[range(BS),Y_train[samp]] = -1.0*num_classes
  X = Tensor(X_train[samp])
  Y = Tensor(Y.reshape(BS, num_classes))
  return X.realize(), Y.realize()
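
# behavior is controlled via env vars: BS, STEPS, FAKEDATA (random data instead of CIFAR),
# TORCHWEIGHTS (copy initial weights from the torch model), ADAM (Adam instead of SGD),
# DISABLE_BACKWARD (forward pass only)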
def train_cifar():
  Tensor.training = True
  BS, STEPS = getenv("BS", 512), getenv("STEPS", 10)
  if getenv("FAKEDATA"):
    N = 2048
    X_train = np.random.default_rng().standard_normal(size=(N, 3, 32, 32), dtype=np.float32)
    Y_train = np.random.randint(0,10,size=(N), dtype=np.int32)
    X_test, Y_test = X_train, Y_train
  else:
    X_train, Y_train = fetch_cifar(train=True)
    X_test, Y_test = fetch_cifar(train=False)
  print(X_train.shape, Y_train.shape)
  Xt, Yt = fetch_batch(X_test, Y_test, BS=BS)
  model = SpeedyResNet()

  # init weights with torch
  # TODO: it doesn't learn with the tinygrad weights, likely since kaiming init
  if getenv("TORCHWEIGHTS"):
    from examples.hlb_cifar10_torch import SpeedyResNet as SpeedyResNetTorch
    torch_model = SpeedyResNetTorch()
    model_state_dict = get_state_dict(model)
    torch_state_dict = torch_model.state_dict()
    for k,v in torch_state_dict.items():
      old_mean_std = model_state_dict[k].mean().numpy(), model_state_dict[k].std().numpy()
      model_state_dict[k].assign(Tensor(v.detach().numpy())).realize()
      new_mean_std = model_state_dict[k].mean().numpy(), model_state_dict[k].std().numpy()
      print(f"initted {k:40s} {str(model_state_dict[k].shape):20s} from torch mean:{old_mean_std[0]:8.5f} -> {new_mean_std[0]:8.5f} std:{old_mean_std[1]:8.5f} -> {new_mean_std[1]:8.5f}")
    exit(0)

  if getenv("ADAM"):
    optimizer = optim.Adam(get_parameters(model), lr=Tensor([0.001]).realize())
  else:
    optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.85, nesterov=True)

  # 97 steps in 2 seconds = 20ms / step
  # step is 1163.42 GOPS = 56 TFLOPS!!!, 41% of max 136
  # 4 seconds for tfloat32 ~ 28 TFLOPS, 41% of max 68
  # 6.4 seconds for float32 ~ 17 TFLOPS, 50% of max 34.1
  # 4.7 seconds for float32 w/o channels last. 24 TFLOPS. we get 50ms then i'll be happy. only 64x off
  # https://www.anandtech.com/show/16727/nvidia-announces-geforce-rtx-3080-ti-3070-ti-upgraded-cards-coming-in-june
  # 136 TFLOPS is the theoretical max w float16 on 3080 Ti
  X, Y = fetch_batch(X_train, Y_train, BS=BS)
  for i in range(max(1, STEPS)):
    if i%10 == 0 and STEPS != 1:
      # use training batchnorm (and no_grad would change the kernels)
      out = model(Xt)
      outs = out.numpy().argmax(axis=1)
      loss = (out * Yt).mean().numpy()
      correct = outs == Yt.numpy().argmin(axis=1)
      print(f"eval {sum(correct)}/{len(correct)} {sum(correct)/len(correct)*100.0:.2f}%, {loss:7.2f} val_loss")
    if STEPS == 0: break
    GlobalCounters.reset()
    st = time.monotonic()
    loss = train_step_jitted(model, optimizer, X, Y)
    et = time.monotonic()
    X, Y = fetch_batch(X_train, Y_train, BS=BS)  # do this here
    loss_cpu = loss.numpy()
    cl = time.monotonic()
    print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
  #train(model, X, Y, optimizer, steps=X.shape[0]//BS, BS=BS)
  #evaluate(model, X_test, Y_test)
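
# example invocation (assuming this file sits at examples/hlb_cifar10.py in the tinygrad repo,
# as the examples.hlb_cifar10_torch import above suggests):
#   STEPS=100 BS=512 python3 examples/hlb_cifar10.py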
if __name__ == "__main__":
  train_cifar()