leondgarse / keras_mlp

Keras implementation of mlp-mixer, ResMLP, gmlp. imagenet/imagenet21k weights reloaded.

License: MIT License


keras_mlp's Introduction

Keras_mlp


Usage

  • This repo can be installed as a pip package.
    pip install -U git+https://github.com/leondgarse/keras_mlp
    or just git clone it.
    git clone https://github.com/leondgarse/keras_mlp.git
    cd keras_mlp && pip install .
  • Basic usage
    import keras_mlp
    # Will download and load `imagenet` pretrained weights.
    # Model weights are loaded with `by_name=True, skip_mismatch=True`.
    mm = keras_mlp.MLPMixerB16(num_classes=1000, pretrained="imagenet")
    
    # Run prediction
    import tensorflow as tf
    from tensorflow import keras
    from skimage.data import chelsea # Chelsea the cat
    imm = keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='tf') # mode="tf" or "torch"
    pred = mm(tf.expand_dims(tf.image.resize(imm, mm.input_shape[1:3]), 0)).numpy()
    print(keras.applications.imagenet_utils.decode_predictions(pred)[0])
    # [('n02124075', 'Egyptian_cat', 0.9568315), ('n02123045', 'tabby', 0.017994137), ...]
    For "imagenet21k" pre-trained models, actual num_classes is 21843.
  • Exclude the model's top layers by setting num_classes=0; see the custom-head sketch at the end of this list.
    import keras_mlp
    mm = keras_mlp.ResMLP_B24(num_classes=0, pretrained="imagenet22k")
    print(mm.output_shape)
    # (None, 784, 768)
    
    mm.save('resmlp_b24_imagenet22k-notop.h5')
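  • Attaching a custom head (illustrative sketch). The no-top model above outputs token features of shape (None, 784, 768); one common approach is to pool over the token axis and add a new classifier. The pooling layer, class count, and compile settings below are assumptions for illustration, not part of keras_mlp.
    import keras_mlp
    from tensorflow import keras

    # Backbone without its classification top; output shape is (None, 784, 768).
    backbone = keras_mlp.ResMLP_B24(num_classes=0, pretrained="imagenet22k")

    # Hypothetical fine-tuning head: average over the token dimension, then classify.
    features = keras.layers.GlobalAveragePooling1D()(backbone.output)  # -> (None, 768)
    outputs = keras.layers.Dense(10, activation="softmax")(features)   # e.g. 10 target classes
    model = keras.Model(backbone.input, outputs)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["acc"])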

MLP mixer

  • PDF 2105.01601 MLP-Mixer: An all-MLP Architecture for Vision.

  • Github google-research/vision_transformer.

  • Model Top1 Acc is the accuracy on ImageNet-1K of the model pre-trained on JFT-300M, as reported in the paper.

    Model       | Params | Top1 Acc | ImageNet        | Imagenet21k        | ImageNet SAM
    MLPMixerS32 | 19.1M  | 68.70    |                 |                    |
    MLPMixerS16 | 18.5M  | 73.83    |                 |                    |
    MLPMixerB32 | 60.3M  | 75.53    |                 |                    | b32_imagenet_sam.h5
    MLPMixerB16 | 59.9M  | 80.00    | b16_imagenet.h5 | b16_imagenet21k.h5 | b16_imagenet_sam.h5
    MLPMixerL32 | 206.9M | 80.67    |                 |                    |
    MLPMixerL16 | 208.2M | 84.82    | l16_imagenet.h5 | l16_imagenet21k.h5 |
    - input 448 | 208.2M | 86.78    |                 |                    |
    MLPMixerH14 | 432.3M | 86.32    |                 |                    |
    - input 448 | 432.3M | 87.94    |                 |                    |

    Specification        | S/32  | S/16  | B/32  | B/16  | L/32  | L/16  | H/14
    Number of layers     | 8     | 8     | 12    | 12    | 24    | 24    | 32
    Patch resolution P×P | 32×32 | 16×16 | 32×32 | 16×16 | 32×32 | 16×16 | 14×14
    Hidden size C        | 512   | 512   | 768   | 768   | 1024  | 1024  | 1280
    Sequence length S    | 49    | 196   | 49    | 196   | 49    | 196   | 256
    MLP dimension DC     | 2048  | 2048  | 3072  | 3072  | 4096  | 4096  | 5120
    MLP dimension DS     | 256   | 256   | 384   | 384   | 512   | 512   | 640
  • The pretrained parameter accepts one of [None, "imagenet", "imagenet21k", "imagenet_sam"]. Default is "imagenet".

  • Pre-training details

    • We pre-train all models using Adam with β1 = 0.9, β2 = 0.999, and batch size 4096, using weight decay and gradient clipping at global norm 1.
    • We use a linear learning rate warmup of 10k steps followed by linear decay (see the schedule sketch below).
    • We pre-train all models at resolution 224.
    • For JFT-300M, we pre-process images by applying the cropping technique from Szegedy et al. [44] in addition to random horizontal flipping.
    • For ImageNet and ImageNet-21k, we employ additional data augmentation and regularization techniques.
    • In particular, we use RandAugment [12], mixup [56], dropout [42], and stochastic depth [19].
    • This set of techniques was inspired by the timm library [52] and Touvron et al. [46].
    • More details on these hyperparameters are provided in Supplementary B.
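  • Learning rate schedule sketch. The schedule described above (linear warmup for 10k steps, then linear decay) can be written as a Keras LearningRateSchedule. This is an illustrative re-implementation of the paper's description, not code from this repo; the base learning rate and total step count below are placeholders.
    import tensorflow as tf

    class WarmupLinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
        # Linear warmup to base_lr over warmup_steps, then linear decay to 0 at total_steps.
        def __init__(self, base_lr, warmup_steps=10_000, total_steps=300_000):
            self.base_lr = base_lr
            self.warmup_steps = warmup_steps
            self.total_steps = total_steps

        def __call__(self, step):
            step = tf.cast(step, tf.float32)
            warmup = self.base_lr * step / self.warmup_steps
            decay = self.base_lr * (self.total_steps - step) / (self.total_steps - self.warmup_steps)
            return tf.where(step < self.warmup_steps, warmup, tf.maximum(decay, 0.0))

    # Adam settings from the paper: β1 = 0.9, β2 = 0.999, gradient clipping at global norm 1.
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=WarmupLinearDecay(base_lr=1e-3),  # base_lr is a placeholder, not from the paper
        beta_1=0.9, beta_2=0.999, global_clipnorm=1.0,
    )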

ResMLP

GMLP



keras_mlp's Issues

gelu in mlp_mixer is not called

As the title says, the gelu activation in mlp_mixer is never actually applied.
Line 11 of mlp_mixer.py defaults activation to None:

def mixer_block(inputs, tokens_mlp_dim, channels_mlp_dim=None, activation=None, name=None):

which should actually be:

def mixer_block(inputs, tokens_mlp_dim, channels_mlp_dim=None, activation="gelu", name=None):

I noticed this because the model wasn't converging well on CIFAR10, even with pretrained weights loaded from google-research/vision_transformer.

Everything worked fine after changing it to activation="gelu".
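
For context, here is a minimal sketch of a Mixer block in the structure described by the MLP-Mixer paper, with the activation applied between the two dense layers of each MLP. This is illustrative only, not the repo's exact mixer_block implementation; the helper names are made up.

from tensorflow import keras

def mlp_block(x, hidden_dim, activation="gelu"):
    # Two dense layers with the activation in between; output width matches the input.
    out_dim = x.shape[-1]
    x = keras.layers.Dense(hidden_dim)(x)
    x = keras.layers.Activation(activation)(x)  # the gelu this issue is about
    return keras.layers.Dense(out_dim)(x)

def mixer_block(inputs, tokens_mlp_dim, channels_mlp_dim, activation="gelu"):
    # Token mixing: transpose so the MLP acts across patches (tokens).
    x = keras.layers.LayerNormalization()(inputs)
    x = keras.layers.Permute((2, 1))(x)
    x = mlp_block(x, tokens_mlp_dim, activation)
    x = keras.layers.Permute((2, 1))(x)
    tokens = keras.layers.Add()([inputs, x])

    # Channel mixing: the MLP acts across the channel dimension.
    y = keras.layers.LayerNormalization()(tokens)
    y = mlp_block(y, channels_mlp_dim, activation)
    return keras.layers.Add()([tokens, y])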
