
liruihui / pointaugment


Code for PointAugment: an Auto-Augmentation Framework for Point Cloud Classification, CVPR 2020 (Oral)

License: Other

Python 100.00%

pointaugment's People

Contributors

liruihui


pointaugment's Issues

How to augment custom data?

Hi! Could you please give some directions?
After training the model on the ModelNet40 dataset,
how can I use the Augmentor to produce augmented point clouds from my own point cloud data?
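A minimal preprocessing sketch for this question, assuming the augmentor expects the same kind of input it saw during training: a channel-first (B, 3, N) batch with N = 1024 points (num_points=1024 in the Namespace printed in the NameError report on this page) plus a noise vector of length noise_dim=1024. The helper name `prepare_custom_cloud`, the unit-sphere normalization, and the Gaussian noise are assumptions for illustration, not taken from the repository.

```python
import numpy as np

NUM_POINTS = 1024  # num_points=1024 in the printed training arguments
NOISE_DIM = 1024   # noise_dim=1024 in the printed training arguments


def prepare_custom_cloud(points, num_points=NUM_POINTS, noise_dim=NOISE_DIM, seed=0):
    """Hypothetical helper: (M, 3) xyz array -> (1, 3, num_points) batch + (1, noise_dim) noise."""
    rng = np.random.default_rng(seed)
    points = np.asarray(points, dtype=np.float32)
    # Sample a fixed number of points; sample with replacement if the cloud is too small.
    replace = points.shape[0] < num_points
    idx = rng.choice(points.shape[0], size=num_points, replace=replace)
    pts = points[idx]
    # Assumed normalization: center at the origin and scale into the unit sphere.
    pts = pts - pts.mean(axis=0)
    scale = np.linalg.norm(pts, axis=1).max()
    if scale > 0:
        pts = pts / scale
    # Channel-first layout (B, C, N), matching the pt[:, :3, :] indexing used in the repository.
    batch = np.ascontiguousarray(pts.T[np.newaxis, :, :])
    # Assumed noise distribution: standard Gaussian.
    noise = rng.standard_normal((1, noise_dim)).astype(np.float32)
    return batch, noise
```

The arrays can then be converted with `torch.from_numpy(...)` and fed to a trained augmentor; how the repository's Augmentor combines the points with the noise is not shown on this page.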

NameError: name 'cls_pc_raw' is not defined

Hello, when I run this code I get the following error: "NameError: name 'cls_pc_raw' is not defined". I'm not sure if it helps, but here is the full debug output.

checkpoints: log/pointnet_cls/20210827-1421
PARAMETER ...
Namespace(augment=False, batch_size=24, data_dir='ModelNet40', decay_rate=0.0001, epoch=250, epoch_per_save=5, learning_rate=0.001, learning_rate_a=0.001, log_dir='log/pointnet_cls/20210827-1421', lr_decay=0.5, model_name='pointnet', no_decay=False, noise_dim=1024, num_points=1024, optimizer='Adam', pretrain=None, restore=False, use_normal=False, y_rotated=True)
Load dataset ...
The number of training data is: 9840
The number of test data is: 2468
No existing Augment, starting training from scratch...
Epoch 1 (1/250):
0% 0/410 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "train_PA.py", line 29, in <module>
    model.train()
  File "/content/PointAugment/Augment/model.py", line 162, in train
    aug_feat, ispn=ispn)
  File "/content/PointAugment/Common/loss_utils.py", line 63, in cls_loss
    parameters = torch.max(torch.tensor(NUM).cuda(), torch.exp(1.0-cls_pc_raw)**2).cuda()
NameError: name 'cls_pc_raw' is not defined

Bugs

PointAugment/Common/loss_utils.py", line 63, in cls_loss
    parameters = torch.max(torch.tensor(NUM).cuda(), torch.exp(1.0-cls_pc_raw)**2).cuda()
NameError: name 'cls_pc_raw' is not defined

How do you achieve end-to-end?

Hello, liruihui. Thank you for your contribution!
I noticed that in your paper, training is split into two parts: one updates the classifier and the other updates the augmentor. Is the training end-to-end because you use the loss function in Equation 6 to update the classifier?

Generated samples

Where can I find the generated new samples after training?
I want to use both old and new samples to train pointnet to see how much the accuracy improves.
Thanks.
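One way to run that comparison, sketched under assumptions: given augmented copies produced by some augmentation function, concatenate them with the originals and repeat the labels. Here `rotate_about_y` is a hypothetical stand-in for the trained augmentor, `build_mixed_dataset` is a hypothetical helper, and the (B, N, 3) array layout is assumed.

```python
import numpy as np


def rotate_about_y(points, rng):
    """Hypothetical stand-in augmentation (NOT the trained augmentor).

    Applies a random rotation about the y axis to each cloud. points: (B, N, 3).
    """
    angles = rng.uniform(0.0, 2.0 * np.pi, size=points.shape[0])
    c, s = np.cos(angles), np.sin(angles)
    rot = np.zeros((points.shape[0], 3, 3))
    rot[:, 0, 0] = c
    rot[:, 0, 2] = s
    rot[:, 1, 1] = 1.0
    rot[:, 2, 0] = -s
    rot[:, 2, 2] = c
    # out[b, n, i] = sum_j rot[b, i, j] * points[b, n, j], i.e. R @ p for every point.
    return np.einsum("bnj,bij->bni", points, rot)


def build_mixed_dataset(points, labels, augment_fn, seed=0):
    """Concatenate original clouds with augmented copies; labels are repeated."""
    rng = np.random.default_rng(seed)
    augmented = augment_fn(points, rng)
    mixed_points = np.concatenate([points, augmented], axis=0)
    mixed_labels = np.concatenate([labels, labels], axis=0)
    return mixed_points, mixed_labels
```

Swapping in real augmentor outputs only requires an `augment_fn` with the same signature; the mixed arrays can then be used to train PointNet on both original and augmented samples.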

Classification Results

I downloaded the code and trained the pointnet model from scratch as per the instructions in the README file.

However, the best accuracy I get is only 0.890600 (Best Accuracy: 0.890600).

Please help us reproduce the results reported in the paper.

Thanks in advance.

NameError: name 'cls_pc_raw' is not defined

To fix the issue, change lines 56 and 57 in Common/loss_utils.py from

cls_pc, _ = cal_loss_raw(pred, gold)
cls_aug, _ = cal_loss_raw(pred_aug, gold)

to
cls_pc, cls_pc_raw = cal_loss_raw(pred, gold)
cls_aug, cls_aug_raw = cal_loss_raw(pred_aug, gold)
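Why this fix works: `cal_loss_raw` returns two values, and discarding the second one with `_` leaves `cls_pc_raw` undefined when line 63 builds the per-sample weights. A NumPy sketch of that two-value contract follows; the plain softmax cross-entropy body and the value of `NUM` are assumptions (the repository's loss may differ, for example by using label smoothing).

```python
import numpy as np


def cal_loss_raw(pred, gold):
    """Sketch: return (mean loss, per-sample losses). pred: (B, K) logits, gold: (B,) labels."""
    shifted = pred - pred.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    log_prob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    loss_raw = -log_prob[np.arange(pred.shape[0]), gold]  # per-sample cross-entropy
    return loss_raw.mean(), loss_raw


NUM = 1.0  # hypothetical value; the real constant is defined in loss_utils.py

pred = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
gold = np.array([0, 1])
cls_pc, cls_pc_raw = cal_loss_raw(pred, gold)
# Per-sample weights mirroring line 63: max(NUM, exp(1 - loss)^2), elementwise.
parameters = np.maximum(NUM, np.exp(1.0 - cls_pc_raw) ** 2)
```

The weighting needs the per-sample vector rather than the batch mean, which is why the second return value must be kept.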

A RuntimeError

Hello, when I run your code, I get the following error while computing the classifier loss:
"RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [256, 4]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!"

Question about Augmentor

This is nice work that really inspired me. I have a question, though.

I don't understand the idea of generating a 3×3 rotation matrix from a quaternion vector in the augmentor. What is the purpose of the function batch_quat_to_rotmat? The idea behind it is not described in the paper. Looking forward to your reply. Thank you very much.

import torch


def batch_quat_to_rotmat(q, out=None):
    # q: (B, 4) quaternions ordered (w, x, y, z); they need not be unit length
    B = q.size(0)
    if out is None:
        out = q.new_empty(B, 3, 3)
    # squared quaternion 2-norm; s = 2 / |q|^2 normalizes the quaternion implicitly
    norm_sq = torch.sum(q.pow(2), 1)
    s = 2 / norm_sq
    # clamped squared norm, returned as a second output
    s_ = torch.clamp(norm_sq, 2.0 / 3.0, 3.0 / 2.0)
    # outer product q q^T, so h[:, i, j] = q_i * q_j
    h = torch.bmm(q.unsqueeze(2), q.unsqueeze(1))
    out[:, 0, 0] = (1 - (h[:, 2, 2] + h[:, 3, 3]).mul(s))  # .mul(s_)
    out[:, 0, 1] = (h[:, 1, 2] - h[:, 3, 0]).mul(s)
    out[:, 0, 2] = (h[:, 1, 3] + h[:, 2, 0]).mul(s)
    out[:, 1, 0] = (h[:, 1, 2] + h[:, 3, 0]).mul(s)
    out[:, 1, 1] = (1 - (h[:, 1, 1] + h[:, 3, 3]).mul(s))  # .mul(s_)
    out[:, 1, 2] = (h[:, 2, 3] - h[:, 1, 0]).mul(s)
    out[:, 2, 0] = (h[:, 1, 3] - h[:, 2, 0]).mul(s)
    out[:, 2, 1] = (h[:, 2, 3] + h[:, 1, 0]).mul(s)
    out[:, 2, 2] = (1 - (h[:, 1, 1] + h[:, 2, 2]).mul(s))  # .mul(s_)
    return out, s_
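A single-quaternion NumPy version of the same formula may help with the question. One plausible motivation (not confirmed by the authors here): because s = 2 / |q|^2 normalizes implicitly, any nonzero 4-vector regressed by the augmentor maps to a valid rotation, so the network never has to enforce orthonormality itself. The name `quat_to_rotmat` is hypothetical; the element layout copies batch_quat_to_rotmat.

```python
import numpy as np


def quat_to_rotmat(q):
    """Rotation matrix from a (possibly unnormalized) quaternion q = (w, x, y, z)."""
    q = np.asarray(q, dtype=np.float64)
    s = 2.0 / np.dot(q, q)  # 2 / |q|^2 absorbs the normalization
    h = np.outer(q, q)      # h[i, j] = q_i * q_j
    return np.array([
        [1 - s * (h[2, 2] + h[3, 3]), s * (h[1, 2] - h[3, 0]), s * (h[1, 3] + h[2, 0])],
        [s * (h[1, 2] + h[3, 0]), 1 - s * (h[1, 1] + h[3, 3]), s * (h[2, 3] - h[1, 0])],
        [s * (h[1, 3] - h[2, 0]), s * (h[2, 3] + h[1, 0]), 1 - s * (h[1, 1] + h[2, 2])],
    ])


# An arbitrary, unnormalized 4-vector still yields a proper rotation (R R^T = I, det R = 1).
R = quat_to_rotmat([0.3, -1.2, 0.7, 2.0])
# A 90-degree rotation about z: w = cos(45°), z = sin(45°); it maps the x axis to the y axis.
Rz = quat_to_rotmat([np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)])
```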

Runtime Error

When I run the code, I get:
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 6 and 4 in dimension 0 at /opt/conda/conda-bld/pytorch_1565272279342/work/aten/src/THC/generic/THCTensorMath.cu:71

Question about the SR16 dataset.

In Table 1, SR16 has 36148 models for training and 5165 for testing, but SR16 contains more than 50000 models. Where are the missing models?

PointAugment short explanation in the blog post

Thank you for your contribution!

I hope you will publish the code soon. I enjoyed the simplicity and effectiveness of your paper and wrote a short blog post summarizing its main ideas. Some people may find it interesting.

blog_post

Question about the normal feature

In your code

raw_pt = pt[:,:3,:].contiguous()
normal = pt[:,3:,:].transpose(1, 2).contiguous() if C > 3 else None

So I was wondering: do you use the normal features, in addition to the xyz coordinates, to further boost performance?
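A NumPy mirror of the two quoted lines, to show what they do with the channel-first (B, C, N) layout: the first three channels are xyz, and any remaining channels (normals when C = 6) are split off and transposed to (B, N, C - 3). Whether normals are present depends on the data; the Namespace printed in the NameError report on this page shows use_normal=False. The helper name `split_xyz_normal` is hypothetical.

```python
import numpy as np


def split_xyz_normal(pt):
    """pt: (B, C, N). Returns xyz (B, 3, N) and normals (B, N, C - 3), or None if C == 3."""
    C = pt.shape[1]
    raw_pt = np.ascontiguousarray(pt[:, :3, :])
    normal = np.ascontiguousarray(pt[:, 3:, :].transpose(0, 2, 1)) if C > 3 else None
    return raw_pt, normal
```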

about the code

Hi, @liruihui ,

Congratulations on your paper's acceptance to the upcoming CVPR conference. When will the code be released?

THX!
