Comments (9)
pip install pytorch-custom-utils
from pytorch_custom_utils import get_adam_optimizer
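A hedged usage sketch (model here is a placeholder for your own torch.nn.Module, and the lr/wd keyword names are assumptions; check the linked source later in this thread for the exact signature):

# hypothetical usage; keyword names may differ, see get_adam_optimizer.py in the repo
optimizer = get_adam_optimizer(model.parameters(), lr=3e-4, wd=1e-2)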
@pfeatherstone and yeah, typically you just filter out any parameters with ndim <= 1; however, i've also heard from some researchers that it doesn't matter, ymmv
this is out of scope for this repository though; i recommend you just read some papers and decide for yourself
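For reference, a minimal sketch of that ndim heuristic (model and the hyperparameter values are placeholders, not anything prescribed in this thread):

import torch

# decay 2d+ params (linear / embedding weights), skip 1d and 0d params
# (biases, norm gains), per the heuristic described above
params = [p for p in model.parameters() if p.requires_grad]
decay_params = [p for p in params if p.ndim >= 2]
nodecay_params = [p for p in params if p.ndim <= 1]

optimizer = torch.optim.AdamW([
    {'params': decay_params, 'weight_decay': 0.1},
    {'params': nodecay_params, 'weight_decay': 0.0}
], lr=3e-4)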
@lucidrains Thank you. It looks like you are doing what nanogpt is doing. That does mean you are decaying normalization weights. I'll have a fiddle. Sorry if this is out of scope.
Currently I'm using:

import torch
from torch import nn
from x_transformers.x_transformers import ScaleNorm, RMSNorm  # adjust to wherever these live in your codebase

def createOptimizer(model: torch.nn.Module, betas=(0.9, 0.95), lr=0.001, decay=0.1):
    # any parameter owned by a normalization layer or an embedding should not decay
    blacklistModules = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) + (nn.Embedding, ScaleNorm, RMSNorm)
    # parameters matched by name that should not decay either
    blacklistNames = ["bias", "memory_tokens", "mem_k", "mem_v"]
    decay_params = []
    nodecay_params = []
    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue  # skip frozen params so the sanity check below holds
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if any(substr in fullname for substr in blacklistNames) or isinstance(module, blacklistModules):
                nodecay_params.append(param)
            else:
                decay_params.append(param)
    ndecayed = len(decay_params)
    nnodecayed = len(nodecay_params)
    ntotal = len([p for p in model.parameters() if p.requires_grad])
    assert ndecayed + nnodecayed == ntotal, f"bad split: {ndecayed} + {nnodecayed} != {ntotal}"
    optim_groups = [
        {'params': decay_params, 'weight_decay': decay},
        {'params': nodecay_params, 'weight_decay': 0.0}
    ]
    optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=betas, fused=True)
    return optimizer
I've put memory tokens in the blacklist, i.e. parameters that don't decay. Not sure if that's correct. Layers like ScaleNorm and RMSNorm I'm treating like the other pytorch normalization layers, so their parameters also don't decay.
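In case it helps, a hypothetical invocation of the function above (model being any torch.nn.Module, e.g. an x-transformers TransformerWrapper; the values are placeholders):

# hypothetical usage, values are placeholders
opt = createOptimizer(model, betas=(0.9, 0.95), lr=3e-4, decay=0.1)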
Basically, I've only just started playing with optimizers and found that they have a massive influence on convergence rate and stability. Duh.
Can anybody think of any other layers/parameters that shouldn't decay?
@pfeatherstone just use https://github.com/lucidrains/pytorch-custom-utils/blob/main/pytorch_custom_utils/get_adam_optimizer.py#L15 and it will suit 95% of your optimizer needs
@pfeatherstone or hop on EleutherAI and consult the crowd intelligence there. everyone has their own opinions about optimizers
@pfeatherstone well, it isn't that i'm doing what Karpathy is doing; we are both following an early practice from the original transformer training at Google Brain. however, whether it really matters, or is just passed-down superstition, is still for a future research paper to decide