PointCloud 데이터를 이용한 대표적인 모델인 PointNet의 구조와 PyTorch로 구현한 코드이다.
PointNet은 Feature extraction 후 classification / segmentation을 수행할 수 있지만,
본 글에서는 classification을 위한 네트워크만 소개한다.
코드는 가장 star 수가 많은 PyTorch implementation인 아래 Github repo를 참고하여 일부 수정했다.
https://github.com/fxia22/pointnet.pytorch/tree/f0c2430b0b1529e3f76fb5d6cd6ca14be763d975
PointNet
PointNet은 PointCloud 데이터를 Voxel grid 형태로 만들지 않고, PointCloud 형태에서 바로 feature extraction을 하는 모델이다.
PointCloud 데이터에서 feature extraction을 하기 위해서는 다음 세 가지를 고려해야 한다.
1. Input의 순서에 invariant해야 한다.
- [pointA, pointB, pointC] 순서의 input이나, [pointB, pointA, pointC] 순서의 input이나, 모두 같은 object를 가리키는 데이터이기 때문에 항상 결과를 내야 한다는 것이다.
- 다시 말해, 모델이 input permutation에 invariant해야 하며, 이를 위해 symmetric function을 이용해 각 point에서의 정보를 결합해야 한다.
- Symmetric function이란, n개의 vector을 input으로 받아, input order에 invariant한 새로운 vector을 출력하는 함수를 말한다. 더하기, 곱하기 등이 여기에 속한다.
- PointNet에서는 다음과 같은 symmetric function을 정의하여 사용한다.
- $f(\{x_1,..., x_n\})=\gamma \circ g(h(x_1),...,h(x_n))$
- 위 식은 $g$가 symmetric일 시 symmetric function이다.
- $h$: MLP
- $g$: max pooling
- $\gamma$: MLP
- $f(\{x_1,..., x_n\})=\gamma \circ g(h(x_1),...,h(x_n))$
2. Local과 Global한 정보의 결합
- 이 부분은 segmentation network에만 해당되는 항목이라 pass
3. Geometric transformation에 invariant해야 한다.
- Input의 순서 뿐만 아니라, linear transformation에도 invariant해야 한다.
- Geometric transformation에 invariant한 representation을 만들기 위해, canonical space로의 mapping을 위한 affine transformation parameter을 학습한다.
- 여기서 canonical space란, linear transformation을 가해도 변형되지 않는 기저공간이며, 이 공간으로의 매핑을 해주는 transformation을 학습하는 것으로 이해했다.
- transformation을 위한 parameter을 학습하는 네트워크를 T-Net이라 하고, 여기서 학습한 transformation matrix를 input feature에 곱해주는 것으로 mapping을 수행한다 (matrix multiply).
- T-Net은 shared MLP와 maxpooling, fc layer들로 구성되어, NxC 크기의 input을 받아 CxC 크기의 transformation matrix를 출력한다.
- 이를 input에 곱해 transformation을 수행해준다.
- 이를 input image와 중간 feature에 대해 두 번 적용한다.
- 이때 중간 feature의 경우 64x64 size의 transformation matrix를 예측해야 하므로, 차원이 너무 커서 최적화하기 힘들다. 따라서 이 경우에는 regularization term을 추가한다.
- $L_{reg}=||I-AA^T||^2_F$
PyTorch Implementation
main network (PointNetCls)
class PointNetCls(nn.Module):
def __init__(self, num_classes=2):
super(PointNetCls, self).__init__()
self.tnet = TNet(dim=3)
self.mlp1 = mlpblock(3, 64)
self.tnet_feature = TNet(dim=64)
self.mlp2 = nn.Sequential(
mlpblock(64, 128),
mlpblock(128, 1024, act_f=False)
)
self.mlp3 = nn.Sequential(
fcblock(1024, 512),
fcblock(512, 256, dropout_rate=0.3),
nn.Linear(256, num_classes)
)
def forward(self, x):
"""
:input size: (N, n_points, 3)
:output size: (N, num_classes)
"""
x = x.transpose(2, 1) #N, 3, n_points
trans = self.tnet(x) #N, 3, 3
x = torch.bmm(x.transpose(2, 1), trans).transpose(2, 1) #N, 3, n_points
x = self.mlp1(x) #N, 64, n_points
trans_feat = self.tnet_feature(x) #N, 64, 64
x = torch.bmm(x.transpose(2, 1), trans_feat).transpose(2, 1) #N, 64, n_points
x = self.mlp2(x) #N, 1024, n_points
x = torch.max(x, 2, keepdim=False)[0] #N, 1024 (global feature)
x = self.mlp3(x) #N, num_classes
return x, trans_feat
- input feature에 대해 우선 T-Net을 통해 transformation matrix를 계산하고, matrix multiplication을 통해 transformation을 수행한다.
- Shared MLP를 이용해 feature 차원을 3->64로 늘려준다.
- 64차원의 shared mlp에 마찬가지로 T-Net과 matrix multiplication을 이용한 transformation을 수행한다.
- Shared MLP를 이용해 feature 차원을 64->128->1024로 늘려준다.
- Max pooling으로 1024차원의 vector를 추출한다.
- 마지막 MLP로 classification을 수행한다.
mlpblock, fcblock
def mlpblock(in_channels, out_channels, act_f=True):
layers = [
nn.Conv1d(in_channels, out_channels, 1),
nn.BatchNorm1d(out_channels),
]
if act_f:
layers.append(nn.ReLU())
return nn.Sequential(*layers)
def fcblock(in_channels, out_channels, dropout_rate=None):
layers = [
nn.Linear(in_channels, out_channels),
]
if dropout_rate is not None:
layers.append(nn.Dropout(p=dropout_rate))
layers += [
nn.BatchNorm1d(out_channels),
nn.ReLU()
]
return nn.Sequential(*layers)
- PointNet에 사용된 mlpblock과 fcblock이다. Shared MLP는 kernel size=1의 1D Conv layer로 구현되었다.
T-Net
class TNet(nn.Module):
def __init__(self, dim=64):
super(TNet, self).__init__()
self.dim = dim
self.mlp = nn.Sequential(
mlpblock(dim, 64),
mlpblock(64, 128),
mlpblock(128, 1024)
)
self.fc = nn.Sequential(
fcblock(1024, 512),
fcblock(512, 256),
nn.Linear(256, dim*dim)
)
def forward(self, x):
x = self.mlp(x)
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)
x = self.fc(x)
idt = torch.eye(self.dim, dtype=torch.float32).flatten().unsqueeze(0).repeat(x.size()[0], 1)
idt = idt.to(x.device)
x = x + idt
x = x.view(-1, self.dim, self.dim)
return x
- canonical space로의 mapping을 위한 transformation matrix를 계산하는 T-Net이다.
Train
import torch
import torch.nn as nn
def feature_transform_regularizer(trans):
D = trans.size()[1]
I = torch.eye(D)[None, :, :]
I = I.to(trans.device)
loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2)))
return loss
#sample data
points = torch.rand(5, 1024, 3)
target = torch.empty(5, dtype=torch.long).random_(10)
model = PointNetCls(num_classes=10)
loss_f = nn.CrossEntropyLoss()
pred, trans_feat = model(points)
loss = loss_f(pred, target)
loss += feature_transform_regularizer(trans_feat) * 0.001
- feature transform의 regularization을 위한 함수를 처음에 정의하고 사용했다.
- Loss로는 Cross entropy loss를 이용했다.
Dataloader
PointCloud 데이터는 각 sample마다 point의 수가 다 다르지만, batch 단위로 학습을 진행하기 위해서는 각 sample의 point 수를 맞춰 주어야 한다. 이를 위해 n_points를 설정해 두고, 각 sample마다 random하게 sampling을 진행한다.
추출한 point들은 unit sphere로의 normalization을 적용한다.
또한 data augmentation으로 y축을 기준으로 한 random rotation과, Gaussian noise를 이용한 point jittering을 사용했다.
from torch.utils.data import Dataset
import numpy as np
class PointCloudDataset(Dataset):
def __init__(self, npoints=1024):
self.npoints = npoints
...
def __getitem__(self, index):
points = self.point_list[index]
#randomly sample points
choice = np.random.choice(points.shape[0], self.npoints, replace=True)
points = points[choice, :]
#normalize to unit sphere
points = points - np.expand_dims(np.mean(points, axis=0), 0) #center
dist = np.max(np.sqrt(np.sum(points**2, axis=1)), 0)
points = points / dist #scale
points = self.data_augmentation(points)
label = self.label_list[index]
return torch.from_numpy(points).float(), torch.tensor(label)
def data_augmentation(self, points):
theta = np.random.uniform(0, np.pi*2) #0~360
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],[np.sin(theta), np.cos(theta)]])
points[:,[0,2]] = points[:,[0,2]].dot(rotation_matrix) # random rotation
points += np.random.normal(0, 0.02, size=points.shape) # random jitter
return points