
nvlabs / al-mdn

Official PyTorch implementation of "Active Learning for Deep Object Detection via Probabilistic Modeling" (ICCV 2021)

Home Page: https://openaccess.thecvf.com/content/ICCV2021/html/Choi_Active_Learning_for_Deep_Object_Detection_via_Probabilistic_Modeling_ICCV_2021_paper.html

License: Other

Python 97.93% Shell 2.07%
active-learning object-detection deep-learning

al-mdn's People

Contributors

jwchoi384


al-mdn's Issues

Uncertainties for different classes / class imbalance problem

Hello, thank you for the paper and the repo. I was wondering how I can deal with class imbalance during the active learning loop. Do you think the model will choose more samples from a class with fewer images, or will it be the other way around? Which part of the code should I tweak if I want to prioritize some classes during the active learning cycle? I really appreciate any help you can provide.
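In the repo, sample selection is driven by per-box uncertainty scores aggregated per image during the active-learning step; a minimal, hypothetical sketch of reweighting that aggregation per class (the `class_weight` dict and `image_score` helper below are assumptions for illustration, not repo code) could look like this:

```python
# Hypothetical sketch: bias active-learning selection toward chosen classes.
# Assumes you already have, per unlabeled image, an uncertainty value broken
# down by predicted class; a higher class_weight means "label this class sooner".
class_weight = {"car": 1.0, "pedestrian": 3.0}   # illustrative priorities

def image_score(per_class_uncertainty):
    # per_class_uncertainty: {class_name: max uncertainty for that class in the image}
    return sum(class_weight.get(c, 1.0) * u for c, u in per_class_uncertainty.items())

pool = {
    "img_001.jpg": {"car": 0.2, "pedestrian": 0.9},
    "img_002.jpg": {"car": 0.8},
}
ranked = sorted(pool, key=lambda k: image_score(pool[k]), reverse=True)
print(ranked)   # images to send for labeling first
```

Whether selection skews toward rare or frequent classes without such a weight is an empirical question for a given dataset.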

Stop training after first iteration

$ CUDA_VISIBLE_DEVICES='0,1' python train_ssd_gmm_supervised_learning.py
C:\Users\fi42\Active_learning\AL-MDN\layers\modules\l2norm.py:20: UserWarning: nn.init.constant is now deprecated in favor of nn.init.constant_.
init.constant(self.weight,self.gamma)
Loading base network...
Initializing weights...
train_ssd_gmm_supervised_learning.py:225: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.
init.xavier_uniform(param)
Training SSD on: VOC0712
Using the specified args:
Namespace(basenet='vgg16_reducedfc.pth', batch_size=32, cuda=True, dataset='VOC300', dataset_root='C:\Users\fi42\data/VOCdevkit/', gamma=0.1, id=1, lr=0.001, momentum=0.9, num_workers=8, resume=None, save_folder='weights/', start_iter=0, visdom=False, weight_decay=0.0005)
C:\Users\fi42\Active_learning\AL-MDN\utils\augmentations.py:240: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
mode = random.choice(self.sample_options)
[the same VisibleDeprecationWarning is repeated several more times]
C:\Users\fi42\Anaconda3\envs\py36\lib\site-packages\torch\cuda\nccl.py:24: UserWarning: PyTorch is not compiled with NCCL support
warnings.warn('PyTorch is not compiled with NCCL support')
timer: 2174.4121 sec.
iter 0 || Loss: 29.8597 || loss: 29.8597 , loss_c: 20.1772 , loss_l: 9.6825 , lr : 0.0000


I am training on the VOC2007 dataset, but it stopped after iteration 0 without throwing any error. I didn't change anything in the code, and it took a very long time even to start. What might be the issue?
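A hedged guess rather than a confirmed diagnosis: on Windows, DataLoader workers are spawned rather than forked, so a script whose top-level code creates the DataLoader can hang or exit silently when `num_workers > 0` unless the entry point is guarded. Two cheap things to try are `num_workers=0` and wrapping the script body as below (the `main()` wrapper is illustrative, not existing repo code):

```python
import torch.utils.data as data

def main():
    # toy dataset stands in for the VOC loader; the point is only the guard below
    loader = data.DataLoader(list(range(10)), batch_size=2, num_workers=2)
    for batch in loader:
        print(batch)

# On Windows, worker processes re-import this module, so the training loop
# must only run in the main process.
if __name__ == "__main__":
    main()
```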

Issue in running

init.constant(self.weight, self.gamma)
Finished loading model!
Traceback (most recent call last):
File "C:\Users\fi42\Rony\AL-MDN\eval_voc.py", line 439, in
mean_ap = test_net(args.save_folder, net, args.cuda, dataset,
File "C:\Users\fi42\Rony\AL-MDN\eval_voc.py", line 385, in test_net
detections = net(x).data
File "C:\Users\fi42\anaconda3\envs\ACT\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\fi42\anaconda3\envs\ACT\lib\site-packages\torch\nn\parallel\data_parallel.py", line 166, in forward
return self.module(*inputs[0], **kwargs[0])
File "C:\Users\fi42\anaconda3\envs\ACT\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\fi42\Rony\AL-MDN\ssd_gmm.py", line 272, in forward
output = self.detect(
File "C:\Users\fi42\anaconda3\envs\ACT\lib\site-packages\torch\autograd\function.py", line 150, in call
raise RuntimeError(
RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method. (Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)

Can you let me know the exact versions of PyTorch, Python, and CUDA that you used? I tried several but kept getting this error.
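The RuntimeError in the traceback comes from PyTorch 1.5+ rejecting autograd Functions whose `forward` is an instance method; the detection step invoked via `self.detect(...)` is written in that legacy style. The snippet below is only a self-contained illustration of the migration the error message asks for, not a patch for `ssd_gmm.py`:

```python
import torch

# Legacy style (what triggers the error when called as fn()(x) on PyTorch >= 1.5):
#
# class Scale(torch.autograd.Function):
#     def forward(self, x):          # non-static forward -> deprecated
#         return x * 2

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

x = torch.randn(3, requires_grad=True)
y = Scale.apply(x)        # new-style functions are invoked via .apply()
y.sum().backward()
print(x.grad)             # gradient of 2 everywhere
```

Alternatively, pinning an older PyTorch version (around 1.4 or earlier, which still allows legacy functions) avoids touching the code, though the authors' exact tested versions are not stated here.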

diverges during training on my own dataset

Hello,

I succeeded in reproducing your results on your dataset.
However, on my own traffic light dataset, which contains many small objects, the localization heads start to diverge. I tried turning the learning rate down to 1e-7 or less; the classification head converges and works on the test set, but the localization loss diverges around 5 for the four localizers (and becomes NaN after that), and the predictions on the test set are nonsense.

Any suggestions?
Thanks in advance.
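Not an answer from the authors, but two mitigations commonly tried when a mixture-density localization loss drifts to NaN, sketched below (whether they help on a small-object dataset is an assumption to verify; variable names are illustrative):

```python
import torch

# 1) Clamp the predicted (log-)variance before it enters the negative
#    log-likelihood so exp(log_var) cannot explode or vanish.
log_var = torch.randn(8, 4)                      # stand-in for a loc-variance head output
log_var = log_var.clamp(min=-10.0, max=10.0)

# 2) Clip gradients every optimization step.
model = torch.nn.Linear(10, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss = model(torch.randn(8, 10)).pow(2).mean()   # placeholder loss
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
optimizer.step()
```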

Error in VOC Evaluation script

Traceback (most recent call last):
  File "eval_voc.py", line 440, in <module>
    thresh=args.confidence_threshold)
  File "eval_voc.py", line 384, in test_net
    detections = net(x).data
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 166, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/MDN/ssd_gmm.py", line 296, in forward
    conf_pi_4.view(conf_var_4.size(0), -1, 1)
  File "/opt/conda/lib/python3.7/site-packages/torch/autograd/function.py", line 151, in __call__
    "Legacy autograd function with non-static forward method is deprecated. "
RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method. (Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)

Pytorch Version - 1.9.0

About Eq.5

[image: screenshot of Eq. 5 from the paper]

Very sorry to disturb you; I would like to ask why there is an add operation in Eq. 5.
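For readers without the screenshot: the expression in question appears to be a Gaussian-mixture negative log-likelihood, and the "add" is most likely a small constant added inside the logarithm for numerical stability. A hedged reconstruction (the exact notation in the paper may differ):

```latex
\mathcal{L}_{\mathrm{loc}}
  = -\log\!\left( \sum_{k=1}^{K} \pi_k\,
      \mathcal{N}\!\left(\hat{x} \mid \mu_k, \Sigma_k\right) + \epsilon \right)
```

Without the \epsilon term, the loss would be -\log(0) = \infty whenever all mixture components assign near-zero density to the target.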

Training stops at the last iteration (iter 119999)

I get this error at the end of the training phase:
RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method.
The error occurs in /AL-MDN-main/utils/test_voc.py, line 348, in test_net:
detections = net(x).data

My PyTorch version is 1.10.1+cu111.

Why use the reparameterization trick for the classification loss computation?

Hi author,
Thanks for your great work. After reading the paper, I have a question regarding the computation of the classification loss.

It's clear that, for the localization loss, you regress the mean of the GMM with respect to the offset from the anchor to the GT box, use the variance term to predict the uncertainty of the offset prediction, and optimize this with log-likelihood maximization.

However, for classification, as far as I understand, you do the optimization differently: you treat the input as a random variable, use the reparameterization trick to sample the class-specific random variable from the learned GMM, and finally compute the BCE loss between the GT and the reparameterized random variable.

My question is: why do it this way? Can we just compute the classification loss in a similar way to the localization one, e.g. maximizing the likelihood of positive and negative samples given the predicted mean and variance of the GMM, like N(GT_pos | mu_p, Sigma_p) and N(GT_neg | mu_p, Sigma_p), where mu_p and Sigma_p are computed by the network?

I hope my puzzle can be considered.

Best regards

A question about the mixture weight π of the GMM

We have trained a model with this method on our dataset. In the test phase, we find that the π value of one of the four classification components of the GMM is close to 1, while the rest of the π values are very small (close to 0). Is this correct? Do you see similar behavior? If not, what are the π values of the four classification components when testing? Looking forward to your reply, thanks!
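A hedged sketch (illustrative variable names, not the repo's) of how mixture weights over K classification components are typically obtained and inspected at test time; a softmaxed π that is nearly one-hot simply means a single Gaussian component dominates that prior box, which is not by itself a sign of a broken model:

```python
import torch
import torch.nn.functional as F

K = 4
pi_logits = torch.randn(100, K)        # stand-in for the raw conf_pi outputs per prior box
pi = F.softmax(pi_logits, dim=-1)      # mixture weights, sum to 1 per prior box
print(pi.mean(dim=0))                  # average weight of each component over the priors
```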

How to use GMM in Faster RCNN?

Thanks for your great work! As you mentioned in your paper, the GMM heads also work well with Faster R-CNN. Could you share your Faster R-CNN code? I want to run some experiments on a two-stage detector.

RuntimeError: Error(s) in loading state_dict for DataParallel:

python eval_coco.py --dataset_root /coco --trained_model weights/vgg16_reducedfc.pth
81
:20: UserWarning: nn.init.constant is now deprecated in favor of nn.init.constant_.
init.constant(self.weight,self.gamma)
True
Loading weight: weights/vgg16_reducedfc.pth
odict_keys(['0.weight', '0.bias', '2.weight', '2.bias', '5.weight', '5.bias', '7.weight', '7.bias', '10.weight', '10.bias', '12.weight', '12.bias', '14.weight', '14.bias', '17.weight', '17.bias', '19.weight', '19.bias', '21.weight', '21.bias', '24.weight', '24.bias', '26.weight', '26.bias', '28.weight', '28.bias', '31.weight', '31.bias', '33.weight', '33.bias'])
Traceback (most recent call last):
File "eval_coco.py", line 188, in
net.load_state_dict(ckp['weight'] if 'weight' in ckp.keys() else ckp)
File "/lib/python3.7/site-packages/torch/nn/modules/module.py", line 830, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.vgg.0.weight", "module.vgg.0.bias", "module.vgg.2.weight", "module.vgg.2.bias", "module.vgg.5.weight", "module.vgg.5.bias", "module.vgg.7.weight", "module.vgg.7.bias", "module.vgg.10.weight", "module.vgg.10.bias", "module.vgg.12.weight", "module.vgg.12.bias", "module.vgg.14.weight", "module.vgg.14.bias", "module.vgg.17.weight", "module.vgg.17.bias", "module.vgg.19.weight", "module.vgg.19.bias", "module.vgg.21.weight", "module.vgg.21.bias", "module.vgg.24.weight", "module.vgg.24.bias", "module.vgg.26.weight", "module.vgg.26.bias", "module.vgg.28.weight", "module.vgg.28.bias", "module.vgg.31.weight", "module.vgg.31.bias", "module.vgg.33.weight", "module.vgg.33.bias", "module.L2Norm.weight", "module.extras.0.weight", "module.extras.0.bias", "module.extras.1.weight", "module.extras.1.bias", "module.extras.2.weight", "module.extras.2.bias", "module.extras.3.weight", "module.extras.3.bias", "module.extras.4.weight", "module.extras.4.bias", "module.extras.5.weight", "module.extras.5.bias", "module.extras.6.weight", "module.extras.6.bias", "module.extras.7.weight", "module.extras.7.bias", "module.loc_mu_1.0.weight", "module.loc_mu_1.0.bias", "module.loc_mu_1.1.weight", "module.loc_mu_1.1.bias", "module.loc_mu_1.2.weight", "module.loc_mu_1.2.bias", "module.loc_mu_1.3.weight", "module.loc_mu_1.3.bias", "module.loc_mu_1.4.weight", "module.loc_mu_1.4.bias", "module.loc_mu_1.5.weight", "module.loc_mu_1.5.bias", "module.loc_var_1.0.weight", "module.loc_var_1.0.bias", "module.loc_var_1.1.weight", "module.loc_var_1.1.bias", "module.loc_var_1.2.weight", "module.loc_var_1.2.bias", "module.loc_var_1.3.weight", "module.loc_var_1.3.bias", "module.loc_var_1.4.weight", "module.loc_var_1.4.bias", "module.loc_var_1.5.weight", "module.loc_var_1.5.bias", "module.loc_pi_1.0.weight", "module.loc_pi_1.0.bias", "module.loc_pi_1.1.weight", "module.loc_pi_1.1.bias", "module.loc_pi_1.2.weight", "module.loc_pi_1.2.bias", "module.loc_pi_1.3.weight", "module.loc_pi_1.3.bias", "module.loc_pi_1.4.weight", "module.loc_pi_1.4.bias", "module.loc_pi_1.5.weight", "module.loc_pi_1.5.bias", "module.loc_mu_2.0.weight", "module.loc_mu_2.0.bias", "module.loc_mu_2.1.weight", "module.loc_mu_2.1.bias", "module.loc_mu_2.2.weight", "module.loc_mu_2.2.bias", "module.loc_mu_2.3.weight", "module.loc_mu_2.3.bias", "module.loc_mu_2.4.weight", "module.loc_mu_2.4.bias", "module.loc_mu_2.5.weight", "module.loc_mu_2.5.bias", "module.loc_var_2.0.weight", "module.loc_var_2.0.bias", "module.loc_var_2.1.weight", "module.loc_var_2.1.bias", "module.loc_var_2.2.weight", "module.loc_var_2.2.bias", "module.loc_var_2.3.weight", "module.loc_var_2.3.bias", "module.loc_var_2.4.weight", "module.loc_var_2.4.bias", "module.loc_var_2.5.weight", "module.loc_var_2.5.bias", "module.loc_pi_2.0.weight", "module.loc_pi_2.0.bias", "module.loc_pi_2.1.weight", "module.loc_pi_2.1.bias", "module.loc_pi_2.2.weight", "module.loc_pi_2.2.bias", "module.loc_pi_2.3.weight", "module.loc_pi_2.3.bias", "module.loc_pi_2.4.weight", "module.loc_pi_2.4.bias", "module.loc_pi_2.5.weight", "module.loc_pi_2.5.bias", "module.loc_mu_3.0.weight", "module.loc_mu_3.0.bias", "module.loc_mu_3.1.weight", "module.loc_mu_3.1.bias", "module.loc_mu_3.2.weight", "module.loc_mu_3.2.bias", "module.loc_mu_3.3.weight", "module.loc_mu_3.3.bias", "module.loc_mu_3.4.weight", "module.loc_mu_3.4.bias", "module.loc_mu_3.5.weight", "module.loc_mu_3.5.bias", "module.loc_var_3.0.weight", "module.loc_var_3.0.bias", "module.loc_var_3.1.weight", "module.loc_var_3.1.bias", 
"module.loc_var_3.2.weight", "module.loc_var_3.2.bias", "module.loc_var_3.3.weight", "module.loc_var_3.3.bias", "module.loc_var_3.4.weight", "module.loc_var_3.4.bias", "module.loc_var_3.5.weight", "module.loc_var_3.5.bias", "module.loc_pi_3.0.weight", "module.loc_pi_3.0.bias", "module.loc_pi_3.1.weight", "module.loc_pi_3.1.bias", "module.loc_pi_3.2.weight", "module.loc_pi_3.2.bias", "module.loc_pi_3.3.weight", "module.loc_pi_3.3.bias", "module.loc_pi_3.4.weight", "module.loc_pi_3.4.bias", "module.loc_pi_3.5.weight", "module.loc_pi_3.5.bias", "module.loc_mu_4.0.weight", "module.loc_mu_4.0.bias", "module.loc_mu_4.1.weight", "module.loc_mu_4.1.bias", "module.loc_mu_4.2.weight", "module.loc_mu_4.2.bias", "module.loc_mu_4.3.weight", "module.loc_mu_4.3.bias", "module.loc_mu_4.4.weight", "module.loc_mu_4.4.bias", "module.loc_mu_4.5.weight", "module.loc_mu_4.5.bias", "module.loc_var_4.0.weight", "module.loc_var_4.0.bias", "module.loc_var_4.1.weight", "module.loc_var_4.1.bias", "module.loc_var_4.2.weight", "module.loc_var_4.2.bias", "module.loc_var_4.3.weight", "module.loc_var_4.3.bias", "module.loc_var_4.4.weight", "module.loc_var_4.4.bias", "module.loc_var_4.5.weight", "module.loc_var_4.5.bias", "module.loc_pi_4.0.weight", "module.loc_pi_4.0.bias", "module.loc_pi_4.1.weight", "module.loc_pi_4.1.bias", "module.loc_pi_4.2.weight", "module.loc_pi_4.2.bias", "module.loc_pi_4.3.weight", "module.loc_pi_4.3.bias", "module.loc_pi_4.4.weight", "module.loc_pi_4.4.bias", "module.loc_pi_4.5.weight", "module.loc_pi_4.5.bias", "module.conf_mu_1.0.weight", "module.conf_mu_1.0.bias", "module.conf_mu_1.1.weight", "module.conf_mu_1.1.bias", "module.conf_mu_1.2.weight", "module.conf_mu_1.2.bias", "module.conf_mu_1.3.weight", "module.conf_mu_1.3.bias", "module.conf_mu_1.4.weight", "module.conf_mu_1.4.bias", "module.conf_mu_1.5.weight", "module.conf_mu_1.5.bias", "module.conf_var_1.0.weight", "module.conf_var_1.0.bias", "module.conf_var_1.1.weight", "module.conf_var_1.1.bias", "module.conf_var_1.2.weight", "module.conf_var_1.2.bias", "module.conf_var_1.3.weight", "module.conf_var_1.3.bias", "module.conf_var_1.4.weight", "module.conf_var_1.4.bias", "module.conf_var_1.5.weight", "module.conf_var_1.5.bias", "module.conf_pi_1.0.weight", "module.conf_pi_1.0.bias", "module.conf_pi_1.1.weight", "module.conf_pi_1.1.bias", "module.conf_pi_1.2.weight", "module.conf_pi_1.2.bias", "module.conf_pi_1.3.weight", "module.conf_pi_1.3.bias", "module.conf_pi_1.4.weight", "module.conf_pi_1.4.bias", "module.conf_pi_1.5.weight", "module.conf_pi_1.5.bias", "module.conf_mu_2.0.weight", "module.conf_mu_2.0.bias", "module.conf_mu_2.1.weight", "module.conf_mu_2.1.bias", "module.conf_mu_2.2.weight", "module.conf_mu_2.2.bias", "module.conf_mu_2.3.weight", "module.conf_mu_2.3.bias", "module.conf_mu_2.4.weight", "module.conf_mu_2.4.bias", "module.conf_mu_2.5.weight", "module.conf_mu_2.5.bias", "module.conf_var_2.0.weight", "module.conf_var_2.0.bias", "module.conf_var_2.1.weight", "module.conf_var_2.1.bias", "module.conf_var_2.2.weight", "module.conf_var_2.2.bias", "module.conf_var_2.3.weight", "module.conf_var_2.3.bias", "module.conf_var_2.4.weight", "module.conf_var_2.4.bias", "module.conf_var_2.5.weight", "module.conf_var_2.5.bias", "module.conf_pi_2.0.weight", "module.conf_pi_2.0.bias", "module.conf_pi_2.1.weight", "module.conf_pi_2.1.bias", "module.conf_pi_2.2.weight", "module.conf_pi_2.2.bias", "module.conf_pi_2.3.weight", "module.conf_pi_2.3.bias", "module.conf_pi_2.4.weight", "module.conf_pi_2.4.bias", "module.conf_pi_2.5.weight", 
"module.conf_pi_2.5.bias", "module.conf_mu_3.0.weight", "module.conf_mu_3.0.bias", "module.conf_mu_3.1.weight", "module.conf_mu_3.1.bias", "module.conf_mu_3.2.weight", "module.conf_mu_3.2.bias", "module.conf_mu_3.3.weight", "module.conf_mu_3.3.bias", "module.conf_mu_3.4.weight", "module.conf_mu_3.4.bias", "module.conf_mu_3.5.weight", "module.conf_mu_3.5.bias", "module.conf_var_3.0.weight", "module.conf_var_3.0.bias", "module.conf_var_3.1.weight", "module.conf_var_3.1.bias", "module.conf_var_3.2.weight", "module.conf_var_3.2.bias", "module.conf_var_3.3.weight", "module.conf_var_3.3.bias", "module.conf_var_3.4.weight", "module.conf_var_3.4.bias", "module.conf_var_3.5.weight", "module.conf_var_3.5.bias", "module.conf_pi_3.0.weight", "module.conf_pi_3.0.bias", "module.conf_pi_3.1.weight", "module.conf_pi_3.1.bias", "module.conf_pi_3.2.weight", "module.conf_pi_3.2.bias", "module.conf_pi_3.3.weight", "module.conf_pi_3.3.bias", "module.conf_pi_3.4.weight", "module.conf_pi_3.4.bias", "module.conf_pi_3.5.weight", "module.conf_pi_3.5.bias", "module.conf_mu_4.0.weight", "module.conf_mu_4.0.bias", "module.conf_mu_4.1.weight", "module.conf_mu_4.1.bias", "module.conf_mu_4.2.weight", "module.conf_mu_4.2.bias", "module.conf_mu_4.3.weight", "module.conf_mu_4.3.bias", "module.conf_mu_4.4.weight", "module.conf_mu_4.4.bias", "module.conf_mu_4.5.weight", "module.conf_mu_4.5.bias", "module.conf_var_4.0.weight", "module.conf_var_4.0.bias", "module.conf_var_4.1.weight", "module.conf_var_4.1.bias", "module.conf_var_4.2.weight", "module.conf_var_4.2.bias", "module.conf_var_4.3.weight", "module.conf_var_4.3.bias", "module.conf_var_4.4.weight", "module.conf_var_4.4.bias", "module.conf_var_4.5.weight", "module.conf_var_4.5.bias", "module.conf_pi_4.0.weight", "module.conf_pi_4.0.bias", "module.conf_pi_4.1.weight", "module.conf_pi_4.1.bias", "module.conf_pi_4.2.weight", "module.conf_pi_4.2.bias", "module.conf_pi_4.3.weight", "module.conf_pi_4.3.bias", "module.conf_pi_4.4.weight", "module.conf_pi_4.4.bias", "module.conf_pi_4.5.weight", "module.conf_pi_4.5.bias".
Unexpected key(s) in state_dict: "0.weight", "0.bias", "2.weight", "2.bias", "5.weight", "5.bias", "7.weight", "7.bias", "10.weight", "10.bias", "12.weight", "12.bias", "14.weight", "14.bias", "17.weight", "17.bias", "19.weight", "19.bias", "21.weight", "21.bias", "24.weight", "24.bias", "26.weight", "26.bias", "28.weight", "28.bias", "31.weight", "31.bias", "33.weight", "33.bias".
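A hedged reading of the error: vgg16_reducedfc.pth contains only VGG backbone weights (keys like "0.weight"), not a full AL-MDN detector checkpoint, so it cannot be loaded into the DataParallel-wrapped model that eval_coco.py builds (which expects keys like "module.vgg.0.weight", "module.loc_mu_1.*", ...). Either pass a fully trained detector checkpoint to --trained_model, or load the backbone only, roughly as below (the `build_ssd_gmm` constructor and its arguments are assumptions based on the ssd.pytorch layout; adjust to whatever ssd_gmm.py actually exposes):

```python
import torch
from ssd_gmm import build_ssd_gmm      # assumed constructor name

# Build the detector (COCO: 80 classes + background assumed) and load only the
# VGG backbone weights; the GMM heads stay randomly initialized, as in training.
net = build_ssd_gmm('test', 300, 81)
vgg_weights = torch.load('weights/vgg16_reducedfc.pth', map_location='cpu')
net.vgg.load_state_dict(vgg_weights)
```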

How to train on a custom dataset?

Could someone please enumerate the steps needed to train these models on a custom dataset? I can get my data into the PASCAL VOC format.

Thank you very much!

How can I apply this algorithm to a detector with focal loss?

Hi, I am very interested in your work and want to apply this algorithm in my own project. With focal loss, a commonly used loss function, the classification output differs from the cross-entropy case: the classification layer has as many outputs as there are classes, not the number of classes plus one. How should the loss function in this paper be changed to work with a focal-loss-style output? Thanks a lot.
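For reference only (not something the repo provides), a standard sigmoid focal loss illustrating the shape difference described above: the head outputs `num_classes` logits with no extra background column, and background boxes are encoded as all-zero target rows. How to plug this into AL-MDN's GMM-sampled classification outputs would still have to be worked out:

```python
import torch
import torch.nn.functional as F

def sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    # logits, targets: (N, num_classes); background = all-zero target row
    p = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * ce).mean()

logits = torch.randn(8, 20)                      # 20 classes, no background column
targets = torch.zeros(8, 20)
targets[0, 3] = 1.0                              # one positive box of class 3
print(sigmoid_focal_loss(logits, targets))
```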

Autograd issue during evaluation

/home/bridgei2i/.local/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory
warn(f"Failed to load image Python extension: {e}")
/Datadisk/AILabs/ComputerVision/Shoeb/AL-MDN/layers/modules/l2norm.py:20: UserWarning: nn.init.constant is now deprecated in favor of nn.init.constant_.
init.constant(self.weight,self.gamma)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
/home/bridgei2i/.local/lib/python3.10/site-packages/torch/nn/functional.py:780: UserWarning: Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.
warnings.warn("Note that order of the arguments: ceil_mode and return_indices will change"
Traceback (most recent call last):
File "/Datadisk/AILabs/ComputerVision/Shoeb/AL-MDN/demo.py", line 70, in
y = net(xx)
File "/home/bridgei2i/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/bridgei2i/.local/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/home/bridgei2i/.local/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/home/bridgei2i/.local/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
output.reraise()
File "/home/bridgei2i/.local/lib/python3.10/site-packages/torch/_utils.py", line 457, in reraise
raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/home/bridgei2i/.local/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/home/bridgei2i/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/Datadisk/AILabs/ComputerVision/Shoeb/AL-MDN/ssd_gmm.py", line 271, in forward
output = self.detect(
File "/home/bridgei2i/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 315, in call
raise RuntimeError(
RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method. (Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)

How to output the aleatoric and epistemic uncertainties associated with the class

Greetings,

I was trying to print out the uncertainties just the way they are shown in Figure 3 of the paper:

[image: per-detection uncertainty values from Figure 3 of the paper]

Where should I tweak the code, so that I can output those 4 uncertainty values for each image?

In /layers/functions/detection_gmm.py there is an output variable which should contain the uncertainties, but I wasn't able to understand the output. Is there anything I'm missing or misunderstanding?
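A hedged sketch (illustrative variable names, not the repo's) of the standard mixture-density decomposition of predictive uncertainty into aleatoric and epistemic parts, which is the kind of quantity Figure 3 reports; in the repo these would be computed from the per-component mu / var / pi tensors that feed the output of detection_gmm.py (where exactly to hook in is an assumption):

```python
import torch

K = 4
pi = torch.softmax(torch.randn(K), dim=0)   # mixture weights
mu = torch.randn(K)                         # component means
var = torch.rand(K)                         # component variances (sigma^2)

mean = (pi * mu).sum()
aleatoric = (pi * var).sum()                # expected within-component variance
epistemic = (pi * (mu - mean) ** 2).sum()   # spread of the component means
print(aleatoric.item(), epistemic.item())
```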
