Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

简化了模型代码 #5

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# compilation and distribution
__pycache__
_ext
*.pyc
*.so
*.egg-info/
build/
dist/
# pytorch/python/numpy formats
*.pth
*.pkl
*.npy

# ipython/jupyter notebooks
**/.ipynb_checkpoints/

# Editor temporaries
*.swn
*.swo
*.swp
*~

# Pycharm editor settings
.idea

# vscode editor settings
.vscode

# MacOS
.DS_Store

# project dirs
data
output
/.idea/
log
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ This repository contains the code for our cvpr 2020 paper: [Unsupervised Learnin
Current Code is tested on ubuntu16.04 with cuda9, python3.6, torch 1.1.0 and torchvision 0.3.0.
We use a [pytorch version of pointnet++](https://github.com/erikwijmans/Pointnet2_PyTorch) in our pipeline.
```
conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=10.0 -c pytorch
pip install -r requirements.txt
cd pointnet2
python setup.py build_ext --inplace
Expand Down
46 changes: 12 additions & 34 deletions dataset/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
unicode_literals,
)
import sys

sys.path.append("..")
import torch
import numpy as np
Expand Down Expand Up @@ -34,14 +35,14 @@ def angle_axis_tensor(angle, axis):

# yapf: disable
cross_prod_mat = torch.Tensor([[0.0, -u[2], u[1]],
[u[2], 0.0, -u[0]],
[-u[1], u[0], 0.0]]).type(torch.FloatTensor)
[u[2], 0.0, -u[0]],
[-u[1], u[0], 0.0]]).type(torch.FloatTensor)

R = cosval * torch.eye(3).type(torch.FloatTensor) + sinval * cross_prod_mat + (1.0 - cosval) * torch.ger(u, u)


return R


def angle_axis(angle, axis):
# type: (float, np.ndarray) -> float
r"""Returns a 4x4 rotation matrix that performs a rotation around axis by angle
Expand All @@ -63,8 +64,8 @@ def angle_axis(angle, axis):

# yapf: disable
cross_prod_mat = np.array([[0.0, -u[2], u[1]],
[u[2], 0.0, -u[0]],
[-u[1], u[0], 0.0]])
[u[2], 0.0, -u[0]],
[-u[1], u[0], 0.0]])

R = torch.from_numpy(
cosval * np.eye(3)
Expand Down Expand Up @@ -145,20 +146,18 @@ def __init__(self, std=0.01, clip=0.05):
def __call__(self, points):
jittered_data = (
points.new(points.size(0), 3)
.normal_(mean=0.0, std=self.std)
.clamp_(-self.clip, self.clip)
.normal_(mean=0.0, std=self.std)
.clamp_(-self.clip, self.clip)
)
points[:, 0:3] += jittered_data
return points


class PointcloudNormalize(object):
def __init__(self, max_size=1.0):

self.max_size = max_size

def __call__(self, points):

points_max, _ = torch.max(points, dim=0)
points_min, _ = torch.min(points, dim=0)
points_center = (points_max + points_min) / 2
Expand All @@ -167,10 +166,10 @@ def __call__(self, points):
points = points / max_radius * self.max_size / 2.0
return points


class PointcloudRandomPermutation(object):

def __call__(self, points):

num = points.shape[0]
idxs = torch.randperm(num).type(torch.LongTensor)
points = torch.index_select(points, 0, idxs).clone()
Expand Down Expand Up @@ -208,19 +207,13 @@ def __call__(self, points):
return torch.from_numpy(pc).float()







class PointcloudTranslate(object):
def __init__(self, translation=np.array([0.0, 0.1, 0.0])):
'''
:param translation: pytorch tensor, translation vector(x,y,z)
'''
self.translation = torch.from_numpy(translation)


def __call__(self, points):
'''

Expand All @@ -242,7 +235,6 @@ def __init__(self, scaler):
self.scaler = scaler

def __call__(self, points):

respoints = points * self.scaler
return respoints

Expand All @@ -254,7 +246,6 @@ def __init__(self, angle_in_degree=np.pi, axis=np.array([0.0, 1.0, 0.0]), is_cud
self.rotation_matrix_t = angle_axis(self.angle_in_degree, self.axis).t()

def __call__(self, points):

'''
:param points: ... , num_of_points, 3
:return: points after rotate
Expand All @@ -267,8 +258,7 @@ def __call__(self, points):
return tpoints



def GenPointcloudRandomTransformFunction(max_rot_angle=2*np.pi):
def GenPointcloudRandomTransformFunction(max_rot_angle=2 * np.pi):
scale_lo = 0.8
scale_hi = 1.25
scaler = np.random.uniform(scale_lo, scale_hi)
Expand All @@ -284,7 +274,7 @@ def GenPointcloudRandomTransformFunction(max_rot_angle=2*np.pi):
return trans_func


def AddTransformsToBatchPoints(points, num_of_trans, max_rot_angle=2*np.pi):
def AddTransformsToBatchPoints(points, num_of_trans, max_rot_angle=2 * np.pi):
'''

:param points:bn, num_of_points, 3
Expand Down Expand Up @@ -324,13 +314,10 @@ def __call__(self, points):
else:
tmp_rot = self.rot_mats


transed_poitns = torch.transpose(torch.matmul(tmp_rot, torch.transpose(points, 1, 2)), 1, 2)
return transed_poitns




def AddPCATransformsToBatchPoints(points, num_of_trans):
trans_points_all = None
rot_mats_all = None
Expand All @@ -347,7 +334,7 @@ def AddPCATransformsToBatchPoints(points, num_of_trans):
np.random.shuffle(tmp_idx)
pca_axis = pca_axis_raw[tmp_idx, :]
tmp_sign = np.random.randint(2, size=2)
tmp_sign[tmp_sign==0] = -1
tmp_sign[tmp_sign == 0] = -1
pca_axis[0, :] = pca_axis[0, :] * tmp_sign[0]
pca_axis[1, :] = pca_axis[1, :] * tmp_sign[1]
pca_axis[2, :] = np.cross(pca_axis[0, :], pca_axis[1, :])
Expand Down Expand Up @@ -382,13 +369,4 @@ def AddPCATransformsToBatchPoints(points, num_of_trans):
trans_func = PointcloudRotateFuns(rot_mats_all[ti, :, :, :])
transfunc_list.append(trans_func)


return trans_points_all, rot_mats_all, transfunc_list








24 changes: 24 additions & 0 deletions models/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch.nn as nn
import numpy as np
from abc import abstractmethod


class BaseModel(nn.Module):
"""
Base class for all models
"""
@abstractmethod
def forward(self, *inputs):
"""
Forward pass logic
:return: Model output
"""
raise NotImplementedError

def __str__(self):
"""
Model prints with number of trainable parameters
"""
model_parameters = filter(lambda p: p.requires_grad, self.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
return super().__str__() + '\nTrainable parameters: {}'.format(params)
90 changes: 41 additions & 49 deletions models/pointnet2_structure_point_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
)
import torch
import torch.nn as nn
import etw_pytorch_utils as pt_utils
from pointnet2.utils.pointnet2_modules import PointnetSAModuleMSG
from models import chamfer_distance
from models.base import BaseModel


class ComputeLoss3d(nn.Module):
class ComputeLoss3d(BaseModel):
def __init__(self):
super(ComputeLoss3d, self).__init__()

Expand All @@ -22,7 +22,8 @@ def __init__(self):
self.consistent_loss = None
self.cd_loss = None

def forward(self, gt_points, structure_points, transed_gt_points=None, transed_structure_points=None, trans_func_list=None):
def forward(self, gt_points, structure_points, transed_gt_points=None, transed_structure_points=None,
trans_func_list=None):

gt_points = gt_points.cuda()
structure_points = structure_points.cuda()
Expand All @@ -39,8 +40,9 @@ def forward(self, gt_points, structure_points, transed_gt_points=None, transed_s
transed_structure_points = transed_structure_points.cuda()
transed_gt_points = transed_gt_points.cuda()
trans_num = transed_structure_points.shape[0]
self.cd_loss = self.cd_loss + self.cd_loss_fun(transed_structure_points.view(trans_num * batch_size, stpts_num, dim),
transed_gt_points.view(trans_num * batch_size, pts_num, dim))
self.cd_loss = self.cd_loss + self.cd_loss_fun(
transed_structure_points.view(trans_num * batch_size, stpts_num, dim),
transed_gt_points.view(trans_num * batch_size, pts_num, dim))
self.consistent_loss = None
for i in range(0, trans_num):
tmp_structure_points = trans_func_list[i](structure_points)
Expand All @@ -53,7 +55,6 @@ def forward(self, gt_points, structure_points, transed_gt_points=None, transed_s
self.consistent_loss = self.consistent_loss + tmp_consistent_loss
self.consistent_loss = self.consistent_loss / trans_num * 1000


self.cd_loss = self.cd_loss / (trans_num + 1)

self.loss = self.cd_loss
Expand All @@ -69,8 +70,28 @@ def get_consistent_loss(self):
return self.consistent_loss


class Pointnet2StructurePointNet(nn.Module):
class Conv1dProbLayer(BaseModel):
def __init__(self, in_channels, out_channels, out=False, kernel_size=1, dropout=0.2):
super().__init__()
self.out = out
self.dropout_conv_bn_layer = nn.Sequential(
nn.Dropout(dropout),
nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size),
nn.BatchNorm1d(num_features=out_channels),
)
self.relu = nn.ReLU()
self.softmax = nn.Softmax(dim=2)

def forward(self, x):
x = self.dropout_conv_bn_layer(x)
if self.out:
x = self.softmax(x)
else:
x = self.relu(x)
return x


class Pointnet2StructurePointNet(BaseModel):
def __init__(self, num_structure_points, input_channels=3, use_xyz=True):
super(Pointnet2StructurePointNet, self).__init__()
self.point_dim = 3
Expand Down Expand Up @@ -110,49 +131,26 @@ def __init__(self, num_structure_points, input_channels=3, use_xyz=True):

conv1d_stpts_prob_modules = []
if num_structure_points <= 128 + 256 + 256:
conv1d_stpts_prob_modules.append(nn.Dropout(0.2))
conv1d_stpts_prob_modules.append(nn.Conv1d(in_channels=128 + 256 + 256, out_channels=512, kernel_size=1))
conv1d_stpts_prob_modules.append(nn.BatchNorm1d(512))
conv1d_stpts_prob_modules.append(nn.ReLU())
conv1d_stpts_prob_modules.append(Conv1dProbLayer(128 + 256 + 256, 512))
in_channels = 512
while in_channels >= self.num_structure_points * 2:
out_channels = int(in_channels / 2)
conv1d_stpts_prob_modules.append(nn.Dropout(0.2))
conv1d_stpts_prob_modules.append(nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1))
conv1d_stpts_prob_modules.append(nn.BatchNorm1d(out_channels))
conv1d_stpts_prob_modules.append(nn.ReLU())
conv1d_stpts_prob_modules.append(Conv1dProbLayer(in_channels, out_channels))
in_channels = out_channels
conv1d_stpts_prob_modules.append(Conv1dProbLayer(in_channels, self.num_structure_points, True))

conv1d_stpts_prob_modules.append(nn.Dropout(0.2))
conv1d_stpts_prob_modules.append(nn.Conv1d(in_channels=in_channels, out_channels=self.num_structure_points, kernel_size=1))

conv1d_stpts_prob_modules.append(nn.BatchNorm1d(self.num_structure_points))
conv1d_stpts_prob_modules.append(nn.Softmax(dim=2))
else:
conv1d_stpts_prob_modules.append(nn.Dropout(0.2))
conv1d_stpts_prob_modules.append(nn.Conv1d(in_channels=128 + 256 + 256, out_channels=1024, kernel_size=1))
conv1d_stpts_prob_modules.append(nn.BatchNorm1d(1024))
conv1d_stpts_prob_modules.append(nn.ReLU())

in_channels = 1024
while in_channels <= self.num_structure_points / 2:
conv1d_stpts_prob_modules.append(Conv1dProbLayer(128 + 256 + 256, 1024))
while in_channels < self.num_structure_points / 2: # change <= to <
out_channels = int(in_channels * 2)
conv1d_stpts_prob_modules.append(nn.Dropout(0.2))
conv1d_stpts_prob_modules.append(nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1))
conv1d_stpts_prob_modules.append(nn.BatchNorm1d(out_channels))
conv1d_stpts_prob_modules.append(nn.ReLU())
conv1d_stpts_prob_modules.append(Conv1dProbLayer(in_channels, out_channels))
in_channels = out_channels

conv1d_stpts_prob_modules.append(nn.Dropout(0.2))
conv1d_stpts_prob_modules.append(nn.Conv1d(in_channels=in_channels, out_channels=self.num_structure_points, kernel_size=1))

conv1d_stpts_prob_modules.append(nn.BatchNorm1d(self.num_structure_points))
conv1d_stpts_prob_modules.append(nn.Softmax(dim=2))

conv1d_stpts_prob_modules.append(Conv1dProbLayer(in_channels, self.num_structure_points, True))
self.conv1d_stpts_prob = nn.Sequential(*conv1d_stpts_prob_modules)

def _break_up_pc(self, pc):
xyz = pc[..., 0:3].contiguous()
xyz = pc[..., 0:3].contiguous() # 取点
features = pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None
return xyz, features

Expand All @@ -169,20 +167,14 @@ def forward(self, pointcloud, return_weighted_feature=False):
xyz, features = module(xyz, features)

self.stpts_prob_map = self.conv1d_stpts_prob(features)

weighted_xyz = torch.sum(self.stpts_prob_map[:, :, :, None] * xyz[:, None, :, :], dim=2)
if return_weighted_feature:
weighted_features = torch.sum(self.stpts_prob_map[:, None, :, :] * features[:, :, None, :], dim=3)
# print("prob:{},xyz:{},weighted_xyz:{},features:{}".format(self.stpts_prob_map.shape,
# xyz.shape,
# weighted_xyz.shape,
# features.shape))

if return_weighted_feature:
weighted_features = torch.sum(self.stpts_prob_map[:, None, :, :] * features[:, :, None, :], dim=3)
return weighted_xyz, weighted_features
else:

return weighted_xyz







Loading