-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathwarp.py
32 lines (22 loc) · 924 Bytes
/
warp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.nn as nn
import torch.nn.functional as F
def warp(feat, flow, mode='bilinear', padding_mode='zeros'):
"""
warp an image/tensor (im2) back to im1, according to the optical flow im1 --> im2
input flow must be in format (x, y) at every pixel
feat: [B, C, H, W] (im2)
flow: [B, 2, H, W] flow (x, y)
"""
B, C, H, W = feat.size()
# mesh grid
rowv, colv = torch.meshgrid([torch.arange(0.5, H + 0.5), torch.arange(0.5, W + 0.5)])
grid = torch.stack((colv, rowv), dim=0).unsqueeze(0).float().to(feat.device)
grid = grid + flow
# scale grid to [-1,1]
grid_norm_c = 2.0 * grid[:, 0] / W - 1.0
grid_norm_r = 2.0 * grid[:, 1] / H - 1.0
grid_norm = torch.stack((grid_norm_c, grid_norm_r), dim=1)
grid_norm = grid_norm.permute(0, 2, 3, 1)
output = F.grid_sample(feat, grid_norm, mode=mode, padding_mode=padding_mode)
return output