Skip to content

Commit

Permalink
add pytorch implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
WangYueFt committed Jun 12, 2019
1 parent 0099c73 commit 2355c94
Show file tree
Hide file tree
Showing 6 changed files with 513 additions and 0 deletions.
87 changes: 87 additions & 0 deletions pytorch/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author: Yue Wang
@Contact: [email protected]
@File: data.py
@Time: 2018/10/13 6:21 PM
"""


import os
import sys
import glob
import h5py
import numpy as np
from torch.utils.data import Dataset


def download():
    """Download and unpack the ModelNet40 HDF5 archive if it is not present.

    The dataset is stored in ``<directory of this file>/data``. Nothing is
    downloaded when ``data/modelnet40_ply_hdf5_2048`` already exists.

    Requires the ``wget`` and ``unzip`` command-line tools.

    Raises:
        subprocess.CalledProcessError: if ``wget`` or ``unzip`` exits with a
            non-zero status.
        OSError: if one of the tools cannot be started.
    """
    # Imported locally so this function does not depend on a file-level import.
    import subprocess

    base_dir = os.path.dirname(os.path.abspath(__file__))
    data_dir = os.path.join(base_dir, 'data')
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
    archive_name = os.path.basename(www)
    dataset_name = os.path.splitext(archive_name)[0]
    if os.path.exists(os.path.join(data_dir, dataset_name)):
        return

    # Work inside the data directory instead of the current working directory,
    # so the result does not depend on where the script was launched from.
    archive_path = os.path.join(data_dir, archive_name)
    try:
        # Argument lists (no shell) and check_call: a failed download or
        # extraction raises here instead of surfacing later as missing data.
        subprocess.check_call(['wget', '-O', archive_path, www])
        subprocess.check_call(['unzip', archive_path, '-d', data_dir])
    finally:
        # Always drop the archive, including a partial one left by a failure.
        if os.path.exists(archive_path):
            os.remove(archive_path)


def load_data(partition):
    """Load every HDF5 shard of one dataset partition into memory.

    Calls ``download()`` first, then reads all files matching
    ``data/modelnet40_ply_hdf5_2048/ply_data_<partition>*.h5`` next to this
    file and concatenates them along the first axis.

    Args:
        partition: string inserted into the file-name pattern
            (the callers in this file pass 'train' or 'test').

    Returns:
        Tuple ``(all_data, all_label)``: the ``'data'`` datasets as float32
        and the ``'label'`` datasets as int64, each concatenated over files.

    Raises:
        ValueError: if no file matches the pattern.
    """
    download()
    base_dir = os.path.dirname(os.path.abspath(__file__))
    pattern = os.path.join(base_dir, 'data', 'modelnet40_ply_hdf5_2048',
                           'ply_data_%s*.h5' % partition)
    all_data = []
    all_label = []
    # glob returns files in arbitrary order; sort so sample indices are stable
    # across runs and machines.
    for h5_name in sorted(glob.glob(pattern)):
        # Explicit read-only mode, and the file is closed even if a read fails.
        with h5py.File(h5_name, 'r') as f:
            all_data.append(f['data'][:].astype('float32'))
            all_label.append(f['label'][:].astype('int64'))
    if not all_data:
        # np.concatenate would raise a ValueError here too, but without
        # saying which files were expected.
        raise ValueError('No HDF5 files found matching %s' % pattern)
    all_data = np.concatenate(all_data, axis=0)
    all_label = np.concatenate(all_label, axis=0)
    return all_data, all_label


def translate_pointcloud(pointcloud):
    """Randomly scale and shift a point cloud.

    Three scale factors are drawn uniformly from [2/3, 3/2) and three offsets
    uniformly from [-0.2, 0.2); both are broadcast against ``pointcloud``
    (assumes its last axis has length 3 -- TODO confirm against callers).

    Args:
        pointcloud: array of points; it is not modified.

    Returns:
        A new float32 array ``pointcloud * scale + shift``.
    """
    scale = np.random.uniform(low=2./3., high=3./2., size=[3])
    shift = np.random.uniform(low=-0.2, high=0.2, size=[3])
    return (pointcloud * scale + shift).astype('float32')


def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
    """Return a copy of a point cloud with clipped Gaussian noise added.

    Unlike an in-place ``+=``, the input array is left untouched, so passing
    a view of a stored dataset sample cannot corrupt the dataset.

    Args:
        pointcloud: 2-D floating-point array of shape (N, C).
        sigma: standard deviation of the noise before clipping.
        clip: noise values are clipped to ``[-clip, clip]``.

    Returns:
        A new array with the same shape and dtype as ``pointcloud``.
    """
    pointcloud = np.asarray(pointcloud)
    N, C = pointcloud.shape
    noise = np.clip(sigma * np.random.randn(N, C), -1 * clip, clip)
    # Cast the noise rather than the sum so the result keeps the input dtype,
    # as the in-place addition did.
    return pointcloud + noise.astype(pointcloud.dtype, copy=False)


class ModelNet40(Dataset):
    """Dataset over the arrays returned by ``load_data`` for one partition.

    Each item is a ``(pointcloud, label)`` pair. Only the first
    ``num_points`` points of a sample are used. For the 'train' partition the
    points are additionally passed through ``translate_pointcloud`` and then
    shuffled along the first axis.
    """

    def __init__(self, num_points, partition='train'):
        """Load the partition into memory and remember the sampling settings."""
        self.num_points = num_points
        self.partition = partition
        self.data, self.label = load_data(partition)

    def __getitem__(self, item):
        """Return the (possibly augmented) point cloud and label at ``item``."""
        cloud = self.data[item][:self.num_points]
        target = self.label[item]
        if self.partition != 'train':
            # No augmentation outside training: the slice is returned as is.
            return cloud, target
        # translate_pointcloud returns a new array, so the in-place shuffle
        # below does not touch self.data.
        cloud = translate_pointcloud(cloud)
        np.random.shuffle(cloud)
        return cloud, target

    def __len__(self):
        """Number of samples in the partition."""
        return len(self.data)


if __name__ == '__main__':
    # Smoke test: build both partitions with 1024 points per sample and print
    # the shape of every training item and of its label.
    train_set = ModelNet40(1024)
    test_set = ModelNet40(1024, 'test')
    for points, target in train_set:
        print(points.shape)
        print(target.shape)
227 changes: 227 additions & 0 deletions pytorch/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author: Yue Wang
@Contact: [email protected]
@File: main.py
@Time: 2018/10/13 10:39 PM
"""


from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from data import ModelNet40
from model import PointNet, DGCNN
import numpy as np
from torch.utils.data import DataLoader
from util import cal_loss, IOStream
import sklearn.metrics as metrics


def _init_():
    """Create the experiment's checkpoint directories and back up the sources.

    Reads the module-level ``args`` (``args.exp_name``). Creates
    ``checkpoints/<exp_name>/models`` and copies ``main.py``, ``model.py``,
    ``util.py`` and ``data.py`` from the current working directory to
    ``checkpoints/<exp_name>/<file>.backup``.

    A source file that is missing is reported and skipped, so the backup
    stays best-effort and never aborts the run.
    """
    # Imported locally so this function does not depend on a file-level import.
    import shutil

    exp_dir = os.path.join('checkpoints', args.exp_name)
    models_dir = os.path.join(exp_dir, 'models')
    # makedirs creates the intermediate directories as well.
    if not os.path.isdir(models_dir):
        os.makedirs(models_dir)

    # Copy with shutil instead of building `cp` shell commands: an experiment
    # name containing spaces or shell metacharacters is handled correctly.
    for source in ('main.py', 'model.py', 'util.py', 'data.py'):
        if os.path.isfile(source):
            shutil.copy(source, os.path.join(exp_dir, source + '.backup'))
        else:
            print('Warning: cannot back up %s (not found in %s)' % (source, os.getcwd()))

def train(args, io):
    """Train the selected model on ModelNet40 and keep the best checkpoint.

    Every epoch runs one pass over the training loader followed by one pass
    over the test loader, logs loss / accuracy / balanced accuracy for both,
    and saves the model's ``state_dict`` to
    ``checkpoints/<exp_name>/models/model.t7`` whenever the test accuracy is
    at least as good as the best seen so far.

    Args:
        args: parsed command-line namespace. Reads ``num_points``,
            ``batch_size``, ``test_batch_size``, ``cuda``, ``model``,
            ``use_sgd``, ``lr``, ``momentum``, ``epochs`` and ``exp_name``;
            it is also handed to the model constructor.
        io: logger object exposing ``cprint(str)``.

    Raises:
        Exception: if ``args.model`` is neither 'pointnet' nor 'dgcnn'.
    """
    train_loader = DataLoader(ModelNet40(partition='train', num_points=args.num_points), num_workers=8,
                              batch_size=args.batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=8,
                             batch_size=args.test_batch_size, shuffle=True, drop_last=False)

    device = torch.device("cuda" if args.cuda else "cpu")

    # Build the requested architecture.
    if args.model == 'pointnet':
        model = PointNet(args).to(device)
    elif args.model == 'dgcnn':
        model = DGCNN(args).to(device)
    else:
        raise Exception("Not implemented")
    print(str(model))

    model = nn.DataParallel(model)
    print("Let's use", torch.cuda.device_count(), "GPUs!")

    if args.use_sgd:
        print("Use SGD")
        # SGD starts from a learning rate 100x larger than args.lr.
        opt = optim.SGD(model.parameters(), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4)
    else:
        print("Use Adam")
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)

    scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr)

    criterion = cal_loss

    best_test_acc = 0
    for epoch in range(args.epochs):
        ####################
        # Train
        ####################
        train_loss = 0.0
        count = 0.0
        model.train()
        train_pred = []
        train_true = []
        for data, label in train_loader:
            # reshape(-1) instead of squeeze(): squeeze() would also drop the
            # batch axis of a single-sample batch and yield a 0-d tensor.
            data, label = data.to(device), label.to(device).reshape(-1)
            # Swap the last two axes before feeding the model.
            data = data.permute(0, 2, 1)
            batch_size = data.size()[0]
            opt.zero_grad()
            logits = model(data)
            loss = criterion(logits, label)
            loss.backward()
            opt.step()
            preds = logits.max(dim=1)[1]
            count += batch_size
            # Weight by batch size so the epoch figure is a per-sample mean.
            train_loss += loss.item() * batch_size
            train_true.append(label.cpu().numpy())
            train_pred.append(preds.detach().cpu().numpy())
        train_true = np.concatenate(train_true)
        train_pred = np.concatenate(train_pred)
        outstr = 'Train %d, loss: %.6f, train acc: %.6f, train avg acc: %.6f' % (epoch,
                                                                                 train_loss*1.0/count,
                                                                                 metrics.accuracy_score(
                                                                                     train_true, train_pred),
                                                                                 metrics.balanced_accuracy_score(
                                                                                     train_true, train_pred))
        io.cprint(outstr)

        # Step the schedule only after this epoch's optimizer updates.
        # Stepping first (PyTorch >= 1.1) skips the initial learning rate.
        scheduler.step()

        ####################
        # Test
        ####################
        test_loss = 0.0
        count = 0.0
        model.eval()
        test_pred = []
        test_true = []
        # No gradients are needed for evaluation; this avoids building the
        # autograd graph and the memory that comes with it.
        with torch.no_grad():
            for data, label in test_loader:
                data, label = data.to(device), label.to(device).reshape(-1)
                data = data.permute(0, 2, 1)
                batch_size = data.size()[0]
                logits = model(data)
                loss = criterion(logits, label)
                preds = logits.max(dim=1)[1]
                count += batch_size
                test_loss += loss.item() * batch_size
                test_true.append(label.cpu().numpy())
                test_pred.append(preds.detach().cpu().numpy())
        test_true = np.concatenate(test_true)
        test_pred = np.concatenate(test_pred)
        test_acc = metrics.accuracy_score(test_true, test_pred)
        avg_per_class_acc = metrics.balanced_accuracy_score(test_true, test_pred)
        outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f' % (epoch,
                                                                              test_loss*1.0/count,
                                                                              test_acc,
                                                                              avg_per_class_acc)
        io.cprint(outstr)
        # ">=" keeps the most recent checkpoint among equally good epochs.
        if test_acc >= best_test_acc:
            best_test_acc = test_acc
            torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % args.exp_name)


def test(args, io):
    """Evaluate a saved checkpoint on the ModelNet40 test partition.

    Builds the architecture named by ``args.model``, loads the weights from
    ``args.model_path`` and logs overall accuracy and balanced (per-class
    average) accuracy through ``io``.

    Args:
        args: parsed command-line namespace. Reads ``num_points``,
            ``test_batch_size``, ``cuda``, ``model`` and ``model_path``; it
            is also handed to the model constructor.
        io: logger object exposing ``cprint(str)``.

    Raises:
        Exception: if ``args.model`` is neither 'pointnet' nor 'dgcnn'.
    """
    test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points),
                             batch_size=args.test_batch_size, shuffle=True, drop_last=False)

    device = torch.device("cuda" if args.cuda else "cpu")

    # Build the same architecture that train() would build for this args.model,
    # so a PointNet checkpoint is not loaded into a DGCNN.
    if args.model == 'pointnet':
        model = PointNet(args).to(device)
    elif args.model == 'dgcnn':
        model = DGCNN(args).to(device)
    else:
        raise Exception("Not implemented")
    # train() saves the state_dict of the DataParallel wrapper, so wrap before
    # loading to get matching parameter names.
    model = nn.DataParallel(model)
    # map_location lets a checkpoint written on a GPU be loaded on a CPU-only host.
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model = model.eval()
    test_true = []
    test_pred = []
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for data, label in test_loader:
            # reshape(-1) instead of squeeze(): squeeze() would also drop the
            # batch axis of a single-sample batch and yield a 0-d tensor.
            data, label = data.to(device), label.to(device).reshape(-1)
            # Swap the last two axes before feeding the model.
            data = data.permute(0, 2, 1)
            logits = model(data)
            preds = logits.max(dim=1)[1]
            test_true.append(label.cpu().numpy())
            test_pred.append(preds.detach().cpu().numpy())
    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    test_acc = metrics.accuracy_score(test_true, test_pred)
    avg_per_class_acc = metrics.balanced_accuracy_score(test_true, test_pred)
    outstr = 'Test :: test acc: %.6f, test avg acc: %.6f'%(test_acc, avg_per_class_acc)
    io.cprint(outstr)


def str2bool(value):
    """Convert a command-line string to a bool for argparse.

    ``type=bool`` treats every non-empty string as True, so ``--use_sgd False``
    would enable SGD. This converter understands the usual spellings.

    Args:
        value: a bool (returned unchanged) or a string such as 'true', 'False',
            '1', '0', 'yes', 'no', 'on', 'off'. The empty string maps to False.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got %r' % (value,))


if __name__ == "__main__":
    # Training settings
    parser = argparse.ArgumentParser(description='Point Cloud Recognition')
    parser.add_argument('--exp_name', type=str, default='exp', metavar='N',
                        help='Name of the experiment')
    parser.add_argument('--model', type=str, default='dgcnn', metavar='N',
                        choices=['pointnet', 'dgcnn'],
                        help='Model to use, [pointnet, dgcnn]')
    parser.add_argument('--dataset', type=str, default='modelnet40', metavar='N',
                        choices=['modelnet40'])
    parser.add_argument('--batch_size', type=int, default=32, metavar='batch_size',
                        help='Size of batch)')
    parser.add_argument('--test_batch_size', type=int, default=16, metavar='batch_size',
                        help='Size of batch)')
    parser.add_argument('--epochs', type=int, default=250, metavar='N',
                        help='number of episode to train ')
    # Boolean options use str2bool: with type=bool, "--flag False" parses as True.
    parser.add_argument('--use_sgd', type=str2bool, default=True,
                        help='Use SGD')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001, 0.1 if using sgd)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no_cuda', type=str2bool, default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--eval', type=str2bool, default=False,
                        help='evaluate the model')
    parser.add_argument('--num_points', type=int, default=1024,
                        help='num of points to use')
    parser.add_argument('--dropout', type=float, default=0.5,
                        help='dropout rate')
    parser.add_argument('--emb_dims', type=int, default=1024, metavar='N',
                        help='Dimension of embeddings')
    parser.add_argument('--k', type=int, default=20, metavar='N',
                        help='Num of nearest neighbors to use')
    parser.add_argument('--model_path', type=str, default='', metavar='N',
                        help='Pretrained model path')
    # Must stay a module-level name: _init_() reads the global `args`.
    args = parser.parse_args()

    _init_()

    io = IOStream('checkpoints/' + args.exp_name + '/run.log')
    io.cprint(str(args))

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    # The data augmentation draws from NumPy's global generator, so seed it
    # as well for the --seed option to cover it.
    np.random.seed(args.seed)
    if args.cuda:
        io.cprint(
            'Using GPU : ' + str(torch.cuda.current_device()) + ' from ' + str(torch.cuda.device_count()) + ' devices')
        torch.cuda.manual_seed(args.seed)
    else:
        io.cprint('Using CPU')

    if not args.eval:
        train(args, io)
    else:
        test(args, io)
Loading

0 comments on commit 2355c94

Please sign in to comment.