Comments (2)

laomeng0703 commented on July 24, 2024

Hi, I think I know how to fix this.
As shown below, in the cls_loss function of the original loss_utils.py file, the second return values of cal_loss_raw are discarded with underscores instead of being assigned to cls_pc_raw and cls_aug_raw, so those two names are undefined when they are used later in the function. I think replacing the two underscores with cls_pc_raw and cls_aug_raw is enough.

def cls_loss(pred, pred_aug, gold, pc_tran, aug_tran, pc_feat, aug_feat, ispn = True):

    mse_fn = torch.nn.MSELoss(reduce=True, size_average=True)
    cls_pc, _ = cal_loss_raw(pred, gold)
    cls_aug, _ = cal_loss_raw(pred_aug, gold)
  
    if ispn:
        cls_pc = cls_pc + 0.001*mat_loss(pc_tran)
        cls_aug = cls_aug + 0.001*mat_loss(aug_tran)
    feat_diff = 10.0*mse_fn(pc_feat,aug_feat)
    parameters = torch.max(torch.tensor(NUM).cuda(), torch.exp(1.0-cls_pc_raw)**2).cuda()
    cls_diff = (torch.abs(cls_pc_raw - cls_aug_raw) * (parameters*2)).mean()
    cls_loss = cls_pc + cls_aug  + feat_diff# + cls_diff
    return cls_loss
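
To illustrate why the underscores matter, here is a minimal sketch that runs without PyTorch. `cal_loss_raw_stub` is a hypothetical stand-in, not the real `cal_loss_raw`: it only mimics the fact that two values are returned, and what those values mean (a mean loss plus per-sample losses) is an assumption made for the example.

```python
def cal_loss_raw_stub(pred, gold):
    """Hypothetical stand-in: returns (mean loss, per-sample losses)."""
    per_sample = [abs(p - g) for p, g in zip(pred, gold)]
    return sum(per_sample) / len(per_sample), per_sample


def broken(pred, gold):
    # The second return value is discarded with `_`,
    # so `cls_pc_raw` is never bound in this scope.
    cls_pc, _ = cal_loss_raw_stub(pred, gold)
    return cls_pc_raw  # raises NameError when executed


def fixed(pred, gold):
    # Binding the second value to a name makes it available later.
    cls_pc, cls_pc_raw = cal_loss_raw_stub(pred, gold)
    return cls_pc_raw


try:
    broken([1.0, 2.0], [1.0, 4.0])
    broken_error = None
except NameError as exc:
    broken_error = exc

fixed_result = fixed([1.0, 2.0], [1.0, 4.0])
```

The same thing happens in cls_loss: the two lines that read cls_pc_raw and cls_aug_raw are evaluated even though cls_diff is commented out of the final sum, so the function fails on those lines until both names are assigned.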

from pointaugment.

kkmm54 commented on July 24, 2024

```
def cls_loss(pred, pred_aug, gold, pc_tran, aug_tran, pc_feat, aug_feat, ispn = True):

    mse_fn = torch.nn.MSELoss(reduce=True, size_average=True)
    cls_pc, cls_pc_raw = cal_loss_raw(pred, gold)
    cls_aug, cls_aug_raw = cal_loss_raw(pred_aug, gold)

    if ispn:
        cls_pc = cls_pc + 0.001*mat_loss(pc_tran)
        cls_aug = cls_aug + 0.001*mat_loss(aug_tran)
    feat_diff = 10.0*mse_fn(pc_feat,aug_feat)
    parameters = torch.max(torch.tensor(NUM).cuda(), torch.exp(1.0-cls_pc_raw)**2).cuda()
    cls_diff = (torch.abs(cls_pc_raw - cls_aug_raw) * (parameters*2)).mean()
    cls_loss = cls_pc + cls_aug  + feat_diff# + cls_diff
    return cls_loss
```
