lpf-sgd's Issues
# Some Questions on the Noise Generation
Hi!
Thanks for releasing the code for your AISTATS 2022 paper; the performance is awe-inspiring. However, I am confused about how the noise is generated in the code. It seems to generate the noise tensor-wise, which is a bit different from the claim in the original paper. Is a normalization of the standard deviation of the Gaussian noise missing? In your implementation, a tensor with larger dimensions/kernel size will receive stronger element-wise noise, since it has a larger norm. Could you explain the mechanism behind that?
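To make the scaling concern concrete, here is a small pure-Python sketch comparing two ways of setting the per-element noise standard deviation: proportional to the L2 norm of the whole tensor (tensor-wise, as described above), and a hypothetical variant divided by the square root of the element count. The base scale `sigma`, the all-ones weights and the helper names are assumptions for illustration; none of this is taken from the repository.

```python
import math
import random


def l2_norm(values):
    """Euclidean (L2) norm of a flattened parameter tensor."""
    return math.sqrt(sum(v * v for v in values))


def tensorwise_std(values, sigma):
    """Per-element noise std proportional to the norm of the WHOLE tensor.

    The norm grows with the number of elements, so larger tensors get a
    larger per-element std even when their entries are similar in size.
    """
    return sigma * l2_norm(values)


def normalized_std(values, sigma):
    """Hypothetical alternative: divide by sqrt(numel), i.e. scale by the
    root-mean-square entry, so the std no longer depends on tensor size."""
    return sigma * l2_norm(values) / math.sqrt(len(values))


def sample_noise(values, std, rng):
    """Draw one i.i.d. Gaussian noise value per element."""
    return [rng.gauss(0.0, std) for _ in values]


if __name__ == "__main__":
    sigma = 0.01  # hypothetical base noise scale
    rng = random.Random(0)
    # All-ones "weights" sized like 3x3, 64x3x3 and 64x64x3x3 kernels
    for n in (9, 576, 36864):
        w = [1.0] * n
        noise = sample_noise(w, tensorwise_std(w, sigma), rng)
        print(f"n={n:6d}  tensor-wise std={tensorwise_std(w, sigma):.3f}  "
              f"normalized std={normalized_std(w, sigma):.3f}  "
              f"first noise sample={noise[0]:+.3f}")
```

For entries of magnitude 1, the tensor-wise rule gives a per-element std of σ√n, so a 3×3 kernel (n = 9) gets 3σ while a 64×64×3×3 kernel (n = 36,864) gets 192σ; dividing by √n removes that size dependence.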
# Some issues with the machine translation code
Dear Devansh,
Thanks so much for releasing the code. I have tried to reproduce the machine translation experiments, but I ran into some issues. Could you provide more detailed instructions for preprocessing the data? I assume such a step is required. Besides, I get the following error when running https://github.com/devansh20la/LPF-SGD/tree/master/codes#training-2:
```
Traceback (most recent call last):
  File "/mnt/sda/litao/LPF-SGD/codes/machine_translation/_adamtrain.py", line 174, in <module>
    inputs = translate(model, inputs, bpe_model, device)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/sda/litao/LPF-SGD/codes/machine_translation/utils/train_utils.py", line 117, in translate
    outputs = greedy_decode(model, inputs, inputs_mask, max_len=num_tokens + 10, device=device).flatten()
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/sda/litao/LPF-SGD/codes/machine_translation/utils/train_utils.py", line 91, in greedy_decode
    out = model.decode(ys, memory, tgt_mask)
  File "/mnt/sda/litao/LPF-SGD/codes/machine_translation/models/seq2seq.py", line 84, in decode
    return self.transformer.decoder(self.positional_encoding(
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 369, in forward
    output = mod(output, memory, tgt_mask=tgt_mask,
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 717, in forward
    x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 735, in _mha_block
    x = self.multihead_attn(x, mem, mem,
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 1205, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/functional.py", line 5279, in multi_head_attention_forward
    k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
RuntimeError: shape '[21, 8, 64]' is invalid for input of size 225792
```
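The numbers in this error can be checked by hand. Assuming the model uses d_model = 512 split into 8 heads of 64 dimensions (consistent with `[21, 8, 64]`), the key tensor holds 225792 = 21 × 21 × 512 elements, i.e. the encoder memory appears to have a batch dimension of 21 while the decoder query has batch size 1. One possible cause, which is an assumption rather than a confirmed diagnosis, is that the 21-token source is embedded without a batch dimension and then broadcast against a positional-encoding buffer of shape (max_len, 1, 512). The pure-Python sketch below reproduces only this shape arithmetic with a hand-written broadcasting rule; all shapes and the positional-encoding layout are hypothetical, not read from `seq2seq.py`.

```python
def broadcast_shape(a, b):
    """Shape resulting from NumPy/PyTorch-style broadcasting of shapes a and b."""
    result = []
    # Compare dimensions from the trailing end; missing leading dims count as 1
    for i in range(1, max(len(a), len(b)) + 1):
        x = a[-i] if i <= len(a) else 1
        y = b[-i] if i <= len(b) else 1
        if x != y and x != 1 and y != 1:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


def numel(shape):
    """Number of elements in a tensor of the given shape."""
    n = 1
    for d in shape:
        n *= d
    return n


src_len, d_model, num_heads = 21, 512, 8  # assumed: 8 heads x 64 dims = 512
head_dim = d_model // num_heads
tgt_batch = 1                             # greedy decoding of a single sentence

# Hypothetical cause: source embedded WITHOUT a batch dim -> (src_len, d_model),
# then added to a positional-encoding slice of shape (src_len, 1, d_model).
memory = broadcast_shape((src_len, d_model), (src_len, 1, d_model))

# Attention reshapes keys using the decoder's batch size: (src_len, bsz*heads, head_dim)
requested = (memory[0], tgt_batch * num_heads, head_dim)

print("memory shape:", memory, "elements:", numel(memory))
print("requested view:", requested, "elements:", numel(requested))

# With an explicit batch dimension of 1 the element counts line up again
fixed_memory = broadcast_shape((src_len, 1, d_model), (src_len, 1, d_model))
print("memory with batch dim:", fixed_memory, "elements:", numel(fixed_memory))
```

If this is what happens, giving the source tensor an explicit batch dimension of 1 before encoding would make the memory shape `(21, 1, 512)`, whose 10,752 elements match the requested `[21, 8, 64]` view.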
Could you kindly help me with that?
Many thanks,
Tao
# About the LPF-SGD implementation
Thanks so much for releasing the code. I have several questions about the implementation of LPF-SGD.
- `noise.append(- init_mp - temp)`: what is the reason for using `init_mp` to obtain the noise?
- `mp.grad.add_((-(n**2 + 1) / mp.view(-1).norm().item())*batch_loss.item())`: why do we still need to add this value to the gradient?
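As background for these questions, the sketch below shows the generic perturb, evaluate, restore pattern for estimating the gradient of a Gaussian-smoothed loss E[L(w + ε)], ε ~ N(0, σ²), by Monte Carlo sampling. It uses a toy 1-D loss L(w) = w⁴, whose smoothed gradient 4w³ + 12wσ² is known in closed form. This is an assumption about the general technique, not a reconstruction of the repository's code: it uses a fixed σ rather than norm-scaled noise, and it has no counterpart to the `init_mp` or `(n**2 + 1)` terms asked about above.

```python
import random


def grad(w):
    """Gradient of the toy loss L(w) = w**4."""
    return 4 * w ** 3


def smoothed_grad_mc(params, sigma, num_samples, rng):
    """Monte Carlo estimate of d/dw E[L(w + eps)], eps ~ N(0, sigma**2), at w = params[0].

    Mimics in-place training code: perturb the parameter, evaluate the
    gradient there, then subtract the same stored noise to restore it.
    """
    total = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, sigma)
        params[0] += eps            # perturb the weight in place
        total += grad(params[0])    # gradient at the perturbed point
        params[0] -= eps            # restore the original weight
    return total / num_samples


def smoothed_grad_exact(w, sigma):
    """Derivative of E[(w + eps)**4] = w**4 + 6 w**2 sigma**2 + 3 sigma**4."""
    return 4 * w ** 3 + 12 * w * sigma ** 2


if __name__ == "__main__":
    params, sigma = [1.0], 0.5
    est = smoothed_grad_mc(params, sigma, 200_000, random.Random(0))
    print("plain gradient:   ", grad(params[0]))
    print("MC smoothed grad: ", est)
    print("exact smoothed:   ", smoothed_grad_exact(params[0], sigma))
```

At w = 1 and σ = 0.5 the smoothed gradient is 4 + 3 = 7 versus a plain gradient of 4, which shows how smoothing changes the update; the stored noise is what lets the loop return the weight to its unperturbed value after each sample.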