-
Notifications
You must be signed in to change notification settings - Fork 68
/
Copy pathcolab_utils.py
145 lines (117 loc) · 4.55 KB
/
colab_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""Utilities to make life easier when working with Google Colab.
Warning: This module must be imported from Colab, otherwise it will crash.
"""
import collections
import collections.abc
import gc
import pathlib
import sys

import matplotlib
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import PIL
import PIL.Image
import torch
from google.colab import files
from torchvision import transforms
# Render every matplotlib animation as HTML5 video by default so that Colab
# can display it inline.
matplotlib.rcParams["animation.html"] = "html5"


def _to_cpu_copy(tensor):
    """Return a CPU copy of ``tensor`` with a leading size-1 dim squeezed out."""
    return tensor.cpu().clone().squeeze(0)


# Turns a tensor (on any device) into a PIL image: copy to CPU, drop a leading
# size-1 dimension if present, then convert with ToPILImage.
_IMAGE_UNLOADER = transforms.Compose(
    [
        transforms.Lambda(_to_cpu_copy),
        transforms.ToPILImage(),
    ]
)
def get_device():
    """Pick the best torch device available in this runtime.

    Returns:
        ``torch.device("cuda")`` when CUDA is available, otherwise
        ``torch.device("cpu")``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def empty_gpu_cache(*names):
    """Tries to free GPU memory without needing to restart the Colab session.

    Deletes each named global variable, runs a full garbage collection, and
    then calls ``torch.cuda.empty_cache()``.

    Args:
        *names: Names of global variables to delete. Defaults to
            ``"model"``, ``"optimizer"`` and ``"model_trainer"`` when no name
            is given. Names that are not defined are skipped.
    """
    names = names or ["model", "optimizer", "model_trainer"]
    # globals() is *this* module's namespace. Variables created in a notebook
    # cell live in the __main__ module's namespace, so clean that one too
    # (the module globals are still cleaned for backward compatibility).
    namespaces = [globals()]
    main_module = sys.modules.get("__main__")
    if main_module is not None and vars(main_module) is not globals():
        namespaces.append(vars(main_module))
    for name in names:
        for namespace in namespaces:
            namespace.pop(name, None)
    # A full collection (the default, generation 2) also processes the younger
    # generations, so one call replaces collecting generations 0, 1 and 2.
    gc.collect()
    torch.cuda.empty_cache()
def upload_files(directory="/tmp"):
    """Creates a widget to upload files from your local machine to Colab.

    Each uploaded file is saved as ``<directory>/<file_name>``; with the
    default argument that is ``/tmp/<file_name>``.

    Args:
        directory: Destination directory. Created (with parents) if missing.

    Returns:
        A list containing the path, as a string, of every file written.
    """
    destination = pathlib.Path(directory)
    destination.mkdir(parents=True, exist_ok=True)
    saved_paths = []
    for name, data in files.upload().items():
        # Keep only the final path component of the browser-supplied name so
        # that something like "../x" cannot escape the destination directory.
        target = destination / pathlib.PurePath(name).name
        target.write_bytes(data)
        saved_paths.append(str(target))
    return saved_paths
def load_image(path, size=None, remove_alpha_channel=True):
    """Loads an image from the given path as a torch.Tensor.

    Args:
        path: The path to the image to load; passed to ``PIL.Image.open``.
        size: Either None, an integer, or a pair of integers. An integer ``s``
            is treated as ``(s, s)``. If not None, the image is resized to the
            given size before being converted to a tensor.
        remove_alpha_channel: If True, keeps only the first three channels
            (``image[:3]``), which drops the alpha channel of an RGBA image.
            Images with at most three channels are returned unchanged.

    Returns:
        The loaded image as a ``torch.float`` tensor, channels first
        (as produced by ``transforms.ToTensor``).

    Raises:
        AssertionError: If ``size`` is a sequence without exactly 2 items.
    """
    steps = []
    if size is not None:
        # collections.Sequence was removed in Python 3.10; the abstract base
        # class lives in collections.abc.
        if not isinstance(size, collections.abc.Sequence):
            size = (size, size)
        assert len(size) == 2, "'size' must either be a scalar or contain 2 items"
        steps.append(transforms.Resize(size))
    steps.append(transforms.ToTensor())
    image_loader = transforms.Compose(steps)
    # The context manager closes the underlying file once the transforms have
    # read the pixel data.
    with PIL.Image.open(path) as pil_image:
        image = image_loader(pil_image)
    if remove_alpha_channel:
        image = image[:3, :, :]
    return image.to(torch.float)
def imshow(batch_or_tensor, title=None, figsize=None, **kwargs):
    """Displays a tensor, or a batch of tensors, as an image with Matplotlib.

    Args:
        batch_or_tensor: A tensor with up to 4 dimensions. Missing leading
            dimensions are filled with size-1 axes so it is read as
            (N, C, H, W). When N > 1 the images are placed side by side in a
            single horizontal strip.
        title: The title for the rendered image, passed to ``plt.title``.
        figsize: The size (in inches) of the figure, passed to ``plt.figure``.
        **kwargs: Extra keyword arguments forwarded to ``plt.imshow``.
    """
    batch = batch_or_tensor
    while batch.ndim < 4:
        batch = batch.unsqueeze(0)
    _, channels, height, _ = batch.shape
    # (N, C, H, W) -> (C, H, N, W) -> (C, H, N * W): images concatenated
    # along the width axis.
    strip = batch.permute(1, 2, 0, 3).reshape(channels, height, -1)
    picture = _IMAGE_UNLOADER(strip)
    plt.figure(figsize=figsize)
    plt.title(title)
    plt.axis("off")
    plt.imshow(picture, **kwargs)
def animate(frames, figsize=None, fps=24):
    """Renders the given frames together into an animation.

    Args:
        frames: Either a list, iterator, or generator of images in
            torch.Tensor format. Each frame is converted with
            ``_IMAGE_UNLOADER`` before being drawn.
        figsize: The display size for the animation; passed to Matplotlib.
        fps: The number of frames to render per second (i.e. frames per second).

    Returns:
        The Matplotlib ``FuncAnimation`` object. Its figure is closed before
        returning, presumably so that only the animation (and not a static
        copy of the figure) is displayed inline.
    """
    fig = plt.figure(figsize=figsize)
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1)
    plt.axis("off")
    # 'imshow' does not accept None or an empty list, so start from a fake
    # 2x2 image whose data animate_fn replaces. Pin the color limits to the
    # 0-255 range of 8-bit PIL images: autoscaling on the all-zero placeholder
    # would give vmin == vmax == 0 and render every single-channel frame as a
    # single flat color.
    image = plt.imshow([[0, 0], [0, 0]], vmin=0, vmax=255)

    def animate_fn(frame):
        frame = _IMAGE_UNLOADER(frame)
        # imshow fixed the extent to the 2x2 placeholder; resize it to the
        # real frame (PIL size is (width, height)) so non-square frames keep
        # their aspect ratio instead of being stretched to a square.
        width, height = frame.size
        image.set_extent((-0.5, width - 0.5, height - 0.5, -0.5))
        image.set_data(frame)
        return (image,)

    anim = animation.FuncAnimation(
        fig,
        animate_fn,
        frames=frames,
        interval=1000 / fps,
        blit=True,
        # Caching frames causes OOMs in Colab when there are a lot of frames or
        # the size of individual frames is large.
        cache_frame_data=False,
    )
    # Close the figure through the public handle rather than the private
    # anim._fig attribute (they refer to the same figure).
    plt.close(fig)
    return anim