
Comments (9)

lucidrains commented on July 17, 2024

pip install pytorch-custom-utils

from pytorch_custom_utils import get_adam_optimizer

from x-transformers.

lucidrains commented on July 17, 2024

@pfeatherstone and yeah, typically you just filter out any parameters with ndims <= 1, however, i've also heard from some researchers that it doesn't matter, ymmv

this is out of scope for this repository though; i recommend you just read some papers and decide for yourself
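The "filter out any parameters with ndims <= 1" rule above can be sketched in a few lines; the helper name and toy model here are illustrative, not from either repo:

```python
import torch
from torch import nn

# Sketch of the ndim heuristic: weight matrices (ndim >= 2) get weight decay,
# while biases and norm gains (ndim <= 1) do not.
def split_by_ndim(model: nn.Module, weight_decay: float = 0.1):
    decay, no_decay = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        (decay if p.ndim >= 2 else no_decay).append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# Toy model: one Linear (weight 2D, bias 1D) and one LayerNorm (two 1D params)
model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
groups = split_by_ndim(model)
```

The resulting groups can be passed straight to `torch.optim.AdamW(groups, lr=...)`, since per-group `weight_decay` overrides the optimizer default.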


pfeatherstone commented on July 17, 2024

@lucidrains Thank you. It looks like you are doing what nanogpt is doing. That does mean you are decaying normalization weights. I'll have a fiddle. Sorry if this is out of scope.


pfeatherstone commented on July 17, 2024

Currently i'm using:

import torch
from torch import nn
from x_transformers.x_transformers import ScaleNorm, RMSNorm  # x-transformers' norm layers

def createOptimizer(model: torch.nn.Module, betas=(0.9, 0.95), lr=0.001, decay=0.1):
    # Norm layers of any kind, embeddings, biases, and memory tokens are excluded from weight decay
    blacklistModules = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) + (nn.Embedding, ScaleNorm, RMSNorm)
    blacklistNames   = ["bias", "memory_tokens", 'mem_k', 'mem_v']
    decay_params   = []
    nodecay_params = []
    for module_name, module in model.named_modules():  # was self.named_modules(); this is a free function
        for param_name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue  # frozen params get no optimizer group at all
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if any(substr in fullname for substr in blacklistNames) or isinstance(module, blacklistModules):
                nodecay_params.append(param)
            else:
                decay_params.append(param)

    ndecayed            = len(decay_params)
    nnodecayed          = len(nodecay_params)
    ntotal              = len(list(filter(lambda p: p.requires_grad, model.parameters())))
    assert ndecayed + nnodecayed == ntotal, f"bad split: {ndecayed} + {nnodecayed} != {ntotal}"
    optim_groups = [
        {'params': decay_params,   'weight_decay': decay},
        {'params': nodecay_params, 'weight_decay': 0.0}
    ]
    # fused AdamW needs CUDA params on older pytorch versions, so only enable it when available
    optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=betas, fused=torch.cuda.is_available())
    return optimizer

I've put memory tokens in the blacklist, i.e. the parameters that don't decay. Not sure if that's correct. I'm treating layers like ScaleNorm and RMSNorm the same as the other pytorch normalization layers, so their parameters also don't decay.
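The module/name-based split above can be checked standalone on a toy model; this sketch keeps only the torch built-ins from the blacklist (ScaleNorm/RMSNorm omitted so it runs without x-transformers), and the model is purely illustrative:

```python
import torch
from torch import nn

# Same blacklist idea as the function above, restricted to torch built-ins
blacklist_modules = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) + (nn.Embedding,)
blacklist_names = ("bias",)

# Toy model: embedding, linear, layernorm
model = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 8), nn.LayerNorm(8))

decay, no_decay = [], []
for module_name, module in model.named_modules():
    for param_name, param in module.named_parameters(recurse=False):
        fullname = f"{module_name}.{param_name}" if module_name else param_name
        if any(s in fullname for s in blacklist_names) or isinstance(module, blacklist_modules):
            no_decay.append(fullname)
        else:
            decay.append(fullname)
```

Only the Linear weight lands in the decay group; the embedding table, the bias, and both LayerNorm parameters are excluded, which matches the intent described above.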


pfeatherstone commented on July 17, 2024

Basically, i've only just started playing with optimizers and found that they have a massive influence on convergence rate and stability. Duh.


pfeatherstone commented on July 17, 2024

Can anybody think of any other layers/parameters that shouldn't decay?


lucidrains commented on July 17, 2024

@pfeatherstone just use https://github.com/lucidrains/pytorch-custom-utils/blob/main/pytorch_custom_utils/get_adam_optimizer.py#L15, it will suit 95% of your optimizer needs


lucidrains commented on July 17, 2024

@pfeatherstone or hop on eleutherai and consult the crowd intelligence there. everyone has their own opinions about optimizers


lucidrains commented on July 17, 2024

@pfeatherstone well, it isn't that i'm doing what Karpathy is doing; we are both following an early practice from the original transformer training at Brain. however, whether it really matters, or is just passed-down superstition, is still up for a future research paper to decide

