minyoungpark1 / swin_transformer_v2_jax

This project compares the performance of Swin-Transformer v2 implemented in JAX and PyTorch.

Python 100.00%
jax pytorch swin-transformer imagenette

swin_transformer_v2_jax's Introduction

JAX implementation of Swin-Transformer v2

Introduction

This project compares the performance (training/validation speed, plus accuracy as a sanity check) of Swin-Transformer v2 implemented in JAX and PyTorch. All experiments were run in Colab on a Tesla V100-SXM2 GPU. Some Swin-Transformer v2 features have not yet been implemented in the JAX version, or have been omitted, such as absolute positional embedding and the pretrained window size option.

Getting Started

Installation

Since this project was done in Colab, which comes with the deep learning packages (PyTorch, JAX, TensorFlow) pre-installed, installation instructions for those packages are omitted. If you are not using Colab, please install them by following each package's official installation guide.

All the remaining packages can be installed with the following command:

pip install -r requirements.txt
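Outside Colab, a quick way to see which of the three frameworks still need installing is sketched below. This is a hypothetical helper (missing_frameworks is not part of this repository); it only checks whether each import name (torch, jax, tensorflow) resolves, and does not verify versions or GPU support.

```python
import importlib.util

# Import names (not pip package names) of the frameworks Colab ships pre-installed.
FRAMEWORKS = ("torch", "jax", "tensorflow")


def missing_frameworks(names=FRAMEWORKS):
    """Return the names in `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_frameworks()
    print("Missing:", ", ".join(missing) if missing else "none")
```

Anything it reports as missing can be installed from that framework's official guide before installing the remaining requirements.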

Download Imagenette dataset

This project uses the Imagenette dataset, a subset of 10 classes from ImageNet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute). If you want to train or test the models yourself, download the dataset from the official Imagenette repository.

The download is 1.45 GB and contains 9,469 training images and 3,925 validation images.
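After extracting the archive, a simple sanity check is to count the images in each split. The sketch below assumes the extracted folder contains train/ and val/ sub-folders with one directory per class (the layout of the imagenette2 archive); the helper name count_images is hypothetical.

```python
from pathlib import Path

# Assumed layout: <root>/train/<class_dir>/*.JPEG and <root>/val/<class_dir>/*.JPEG
IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}


def count_images(root):
    """Return {split: (num_class_dirs, num_images)} for the train and val splits."""
    summary = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        class_dirs = sorted(d for d in split_dir.iterdir() if d.is_dir())
        num_images = sum(
            1
            for d in class_dirs
            for f in d.iterdir()
            # Only count files whose extension looks like an image.
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
        summary[split] = (len(class_dirs), num_images)
    return summary
```

Assuming the archive was extracted to imagenette2/, count_images("imagenette2") on a complete download would be expected to report 10 classes per split and the image counts quoted above.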

Train Swin-Transformer v2 (PyTorch/JAX)

Experiment & results

With an image size of (256, 256) and batch sizes of 32 and 64, both the JAX and PyTorch implementations took significantly longer during the first epoch, especially during the first iteration (batch). I assume this was caused by GPU memory allocation during the first run. However, PyTorch remained slow even after the first iteration and only sped up once the second epoch started.
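The warm-up effect can be measured per iteration. On the JAX side, one likely contributor besides memory allocation is that the first call to a jax.jit-compiled function (and any call with a new input shape, such as a smaller final batch) traces and compiles it with XLA. Both JAX and CUDA PyTorch also launch GPU work asynchronously, so a timer must wait for results before stopping. The helper below is a hypothetical sketch (time_steps is not part of this repository), assuming a step function that takes one batch and returns its outputs.

```python
import time
from typing import Any, Callable, Iterable, Optional, Tuple


def time_steps(
    step_fn: Callable[[Any], Any],
    batches: Iterable[Any],
    warmup: int = 1,
    sync: Optional[Callable[[Any], Any]] = None,
) -> Tuple[float, float]:
    """Return (mean seconds/step over all steps, mean excluding the first `warmup` steps)."""
    durations = []
    for batch in batches:
        start = time.perf_counter()
        out = step_fn(batch)
        if sync is not None:
            # Block until the asynchronously launched work has finished, so the
            # timer measures the computation and not just its dispatch.
            sync(out)
        durations.append(time.perf_counter() - start)
    if not durations:
        raise ValueError("no batches were timed")
    mean_all = sum(durations) / len(durations)
    rest = durations[warmup:]
    mean_after = sum(rest) / len(rest) if rest else float("nan")
    return mean_all, mean_after
```

For a JAX step, sync=jax.block_until_ready waits on every array in the returned pytree; for PyTorch on a GPU, sync=lambda _: torch.cuda.synchronize() plays the same role. Comparing the two returned means separates the first-iteration cost from steady-state speed.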

Including the first epoch, JAX was 27.3% and 28.0% faster than PyTorch during training with batch sizes 32 and 64, respectively. During validation, JAX was 140.5% and 147.4% faster than PyTorch, respectively.

Excluding the first epoch, the speed differences shrank to 24.9% during training for both batch sizes, and to 132.1% and 137.7% during validation with batch sizes 32 and 64, respectively.
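"X% faster" is not defined explicitly here. Because the validation figures exceed 100%, they cannot mean "X% less time"; a consistent reading is that PyTorch's time divided by JAX's time equals 1 + X/100 (equivalently, JAX's throughput is that multiple of PyTorch's). The sketch below converts the reported percentages under that assumption; the function names are hypothetical.

```python
def implied_time_ratio(percent_faster):
    """PyTorch time / JAX time implied by 'JAX was X% faster' (ratio-of-times reading)."""
    return 1.0 + percent_faster / 100.0


def time_saved_fraction(percent_faster):
    """Fraction of PyTorch's wall-clock time that JAX saves under the same reading."""
    return 1.0 - 1.0 / implied_time_ratio(percent_faster)


# Figures from the results above (batch size 64, first epoch included).
for label, pct in [("train, bs=64", 28.0), ("val, bs=64", 147.4)]:
    print(f"{label}: PyTorch took {implied_time_ratio(pct):.3f}x as long; "
          f"JAX saved {time_saved_fraction(pct):.1%} of the time")
```

Under this reading, "147.4% faster" means validation with PyTorch took about 2.47 times as long, i.e. JAX cut the validation time by roughly 60%.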

Discussion

  • Factors that caused the bottleneck
  • Checkpoints remaining in the Trash folder

TODO

  • Update requirements.txt
  • Try a different task
  • Analyze time consumption step by step

Acknowledgements

This project was inspired by Swin-Transformer and vision_transformer, and some of the code was borrowed from them.

swin_transformer_v2_jax's Issues

log_relative_position_index

Hello, thank you for publishing this repository!
I have a question about the JAX implementation: https://github.com/minyoungpark1/swin_transformer_v2_jax

In the first line of the setup function in the window attention module, there is this line:

        self.log_relative_position_index = log_space_continuous_position_bias(self.window_size)#krowa - not used

As far as I understand, this is one of the novelties of Swin Transformer v2 relative to v1. However, I cannot find anywhere that this log-space relative position encoding is actually used. Why is that?

Thank you!
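For context on this question: in Swin v2 the log-spaced relative coordinates are not added to attention directly. They are the input to a small MLP (the continuous position bias), whose per-head output is gathered for every token pair and added to the attention logits. The NumPy sketch below illustrates that data flow under stated assumptions. It is not this repository's code, the function names are hypothetical, the MLP weights are random stand-ins for learned ones, and it uses the paper's sign(d) * log(1 + |d|) form rather than the extra normalization in the official implementation. The same calls exist in jax.numpy (jnp.sign, jnp.log1p, jnp.meshgrid).

```python
import numpy as np


def log_spaced_relative_coords(window_size):
    """Return a (2*Wh-1, 2*Ww-1, 2) table of log-spaced relative offsets.

    Follows the Swin v2 paper form sign(d) * log(1 + |d|); the official code
    additionally normalizes offsets before a base-2 log, which is omitted here.
    """
    wh, ww = window_size
    dh = np.arange(-(wh - 1), wh, dtype=np.float32)
    dw = np.arange(-(ww - 1), ww, dtype=np.float32)
    table = np.stack(np.meshgrid(dh, dw, indexing="ij"), axis=-1)
    return np.sign(table) * np.log1p(np.abs(table))


def relative_position_index(window_size):
    """Return an (N, N) index into the flattened table, N = Wh * Ww (as in Swin v1)."""
    wh, ww = window_size
    coords = np.stack(np.meshgrid(np.arange(wh), np.arange(ww), indexing="ij"))
    flat = coords.reshape(2, -1)                      # (2, N) token coordinates
    rel = flat[:, :, None] - flat[:, None, :]         # (2, N, N) pairwise offsets
    rel[0] += wh - 1                                  # shift offsets to start at 0
    rel[1] += ww - 1
    return rel[0] * (2 * ww - 1) + rel[1]


def continuous_position_bias(window_size, w1, b1, w2, b2):
    """Per-head bias (num_heads, N, N): 2-layer ReLU MLP over the log-spaced table."""
    coords = log_spaced_relative_coords(window_size).reshape(-1, 2)  # (T, 2)
    hidden = np.maximum(coords @ w1 + b1, 0.0)                       # (T, hidden)
    table = hidden @ w2 + b2                                         # (T, num_heads)
    bias = table[relative_position_index(window_size)]              # (N, N, num_heads)
    # The official Swin v2 code bounds the bias with 16 * sigmoid(.).
    return 16.0 / (1.0 + np.exp(-bias.transpose(2, 0, 1)))


# Example with random stand-in weights (a trained model would learn these).
rng = np.random.default_rng(0)
num_heads, hidden_dim = 3, 512
w1 = rng.normal(0.0, 0.02, size=(2, hidden_dim))
b1 = np.zeros(hidden_dim)
w2 = rng.normal(0.0, 0.02, size=(hidden_dim, num_heads))
b2 = np.zeros(num_heads)
bias = continuous_position_bias((8, 8), w1, b1, w2, b2)
print(bias.shape)  # (3, 64, 64)
```

The resulting (num_heads, N, N) tensor is what gets added to the attention logits before the softmax, in the same place the relative position bias table is used in Swin v1.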
