
inflated_convnets_pytorch's Introduction

Inflated I3D models with ImageNet weight transfer in PyTorch

This repo contains several scripts for inflating 2D networks in PyTorch, following the technique described in the paper Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset by Joao Carreira and Andrew Zisserman.

It provides inflated versions of:

  • ResNet 50, ResNet 101, ResNet 152
  • DenseNet 121, DenseNet 161, DenseNet 169, DenseNet 201

The original (and official!) TensorFlow code inflates the Inception-v1 network and can be found here.

So far, this code supports inflating DenseNet and the ResNets whose basic block is a Bottleneck block (ResNet 50 and deeper), as well as transferring the 2D ImageNet weights.

The 3D network is obtained by going through the layers of the 2D network and inflating them one by one. The utilities for the inflation (which both inflate the layers and transfer the weights) are located in src/inflate.py.
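As an illustration of the weight transfer, the sketch below implements the replicate-and-rescale scheme from the I3D paper in NumPy: a 2D kernel of shape (out, in, kh, kw) is repeated N times along a new time axis and divided by N. It is a framework-agnostic sketch, not the code in src/inflate.py; the function name `inflate_conv_weight_replicate` is hypothetical. In PyTorch, the resulting array could be converted with `torch.from_numpy` and wrapped in `torch.nn.Parameter` to become the weight of an `nn.Conv3d`.

```python
import numpy as np


def inflate_conv_weight_replicate(weight_2d: np.ndarray, time_dim: int) -> np.ndarray:
    """Inflate a 2D conv kernel (out, in, kh, kw) to 3D (out, in, time_dim, kh, kw).

    Hypothetical helper illustrating I3D-style bootstrapping: the 2D kernel is
    copied along a new time axis and divided by time_dim, so that summing the
    3D kernel over time gives back the original 2D kernel.
    """
    if weight_2d.ndim != 4:
        raise ValueError("expected a 2D conv weight of shape (out, in, kh, kw)")
    # Insert the time axis at position 2 and repeat the kernel time_dim times.
    weight_3d = np.repeat(weight_2d[:, :, np.newaxis], time_dim, axis=2)
    # Rescale so that the temporal sum equals the 2D kernel.
    return weight_3d / time_dim


rng = np.random.default_rng(0)
w2d = rng.standard_normal((8, 3, 7, 7))  # toy stem conv: 3 -> 8 channels, 7x7 kernel
w3d = inflate_conv_weight_replicate(w2d, time_dim=3)
print(w3d.shape)  # (8, 3, 3, 7, 7)
print(np.allclose(w3d.sum(axis=2), w2d))  # True
```

The temporal-sum identity is what makes a time-replicated input produce the same response as the 2D network on the original image, at least away from any temporal padding.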

Note that for the ResNet inflation, I use a centered initialization scheme, as presented in Detect-and-Track: Efficient Pose Estimation in Videos. Instead of replicating the kernel along the time dimension and dividing the weights by the temporal kernel size (as described in the original I3D paper), I initialize the time-centered slice of the kernel to the 2D weights and the rest to 0. As a result, the 2D network applied to an image and the matching 3D network applied to a 3D input (obtained by replicating that image along the time dimension) produce the same outputs, up to numerical differences.
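The equivalence claim can be checked on a toy example. The NumPy sketch below builds a centered 3D kernel and compares a naive 2D convolution on an image with a naive 3D convolution on the same image replicated in time. The helper names (`inflate_conv_weight_centered`, `conv2d_valid`, `conv3d_valid`) are hypothetical, the convolutions are unpadded ("valid") cross-correlations rather than PyTorch's layers, and `sliding_window_view` requires NumPy 1.20 or later.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def inflate_conv_weight_centered(weight_2d, time_dim):
    """Centered init: the middle time slice holds the 2D kernel, all others are zero."""
    out_c, in_c, kh, kw = weight_2d.shape
    weight_3d = np.zeros((out_c, in_c, time_dim, kh, kw), dtype=weight_2d.dtype)
    weight_3d[:, :, time_dim // 2] = weight_2d
    return weight_3d


def conv2d_valid(x, w):
    """Naive unpadded cross-correlation. x: (in, H, W), w: (out, in, kh, kw)."""
    windows = sliding_window_view(x, w.shape[2:], axis=(1, 2))  # (in, H', W', kh, kw)
    return np.einsum("ihwab,oiab->ohw", windows, w)


def conv3d_valid(x, w):
    """Naive unpadded cross-correlation. x: (in, D, H, W), w: (out, in, kt, kh, kw)."""
    windows = sliding_window_view(x, w.shape[2:], axis=(1, 2, 3))  # (in, D', H', W', kt, kh, kw)
    return np.einsum("idhwtab,oitab->odhw", windows, w)


rng = np.random.default_rng(0)
w2d = rng.standard_normal((4, 3, 3, 3))  # 3 -> 4 channels, 3x3 kernel
image = rng.standard_normal((3, 8, 8))
clip = np.repeat(image[:, np.newaxis], 5, axis=1)  # (3, 5, 8, 8): image repeated 5 times in time

out_2d = conv2d_valid(image, w2d)  # (4, 6, 6)
out_3d = conv3d_valid(clip, inflate_conv_weight_centered(w2d, time_dim=3))  # (4, 3, 6, 6)

# Every output time step matches the 2D output up to floating-point error.
print(all(np.allclose(out_3d[:, t], out_2d) for t in range(out_3d.shape[1])))  # True
```

In this unpadded toy setting, the replicate-and-rescale kernel from the I3D paper would pass the same check, since its time slices also sum to the 2D kernel. The two schemes differ once the 3D convolution zero-pads the time axis: at the first and last frames, only the centered kernel still reproduces the 2D output, because its single nonzero slice always lines up with the current frame.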

Use it

To inflate a network, run it on a dummy dataset, and compare the final predictions of the original and inflated networks, run:

  • For ResNet 101, for instance: python inflate_resnet.py --resnet_nb 101 (available for ResNet [50|101|152])

  • For DenseNet 121: python inflate_densenet.py --densenet_nb 121 (available for DenseNet [121|161|169|201])

Profiling

Forward pass on a GeForce GTX TITAN Black (6 GB) GPU with batch size 2:

Network        Time (s)
ResNet 50      0.6
ResNet 101     0.8
ResNet 152     1.1
DenseNet 121   2.6

Forward pass on a GeForce GTX TITAN Black (6 GB) GPU with batch size 1:

Network        Time (s)
ResNet 50      0.1
ResNet 101     0.3
ResNet 152     0.5
DenseNet 121   1.3
DenseNet 161   1.8
DenseNet 169   1.5
DenseNet 201   1.7

Note

Another repo with networks pretrained on Kinetics is available here: 3D-Resnets-Pytorch. However, it does not transfer the ImageNet weights, which, in my experience with Inception-v1, did improve the final results.

inflated_convnets_pytorch's People

Contributors

bomri, hassony2


inflated_convnets_pytorch's Issues

benchmark results

Hi

Thank you for providing the code for this.
Did you generate any benchmark results, and how did the performance compare to the original I3D model?

Thanks in advance!
Devraj

Parameter not defined in inflate.py

Hi Yana,
Thank you for sharing this code.
In inflate.py you call Parameter(weights_3d) twice, but Parameter() is not imported or defined in the code. Could you please share this function?
Thank you,
Robert

training with i3d

Hello, have you ever trained with this model?
I found that conv3d may cause a CUDA_INTERNAL_ERROR with cuDNN versions 6.0 through 7.0.
If torch.backends.cudnn.enabled is set to False, everything works.

DenseNet 169 doesn't load the example.

When I try to load the example with densenet 169 I get the following traceback:

```
python inflate_densenet.py --densenet_nb 169

Traceback (most recent call last):
  File "/Users/andimarafioti/Documents/code/inflated_convnets_pytorch/inflate_densenet.py", line 113, in <module>
    run_inflater(args)
  File "/Users/andimarafioti/Documents/code/inflated_convnets_pytorch/inflate_densenet.py", line 43, in run_inflater
    i3densenet = I3DenseNet(
  File "/Users/andimarafioti/Documents/code/inflated_convnets_pytorch/src/i3dense.py", line 15, in __init__
    self.features, transition_nb = inflate_features(
  File "/Users/andimarafioti/Documents/code/inflated_convnets_pytorch/src/i3dense.py", line 121, in inflate_features
    _DenseLayer3d(
  File "/Users/andimarafioti/Documents/code/inflated_convnets_pytorch/src/i3dense.py", line 53, in __init__
    self.add_module('padding.1', pad_time)
  File "/Users/andimarafioti/miniforge3/envs/torchless/lib/python3.9/site-packages/torch/nn/modules/module.py", line 380, in add_module
    raise KeyError("module name can't contain \".\", got: {}".format(name))
KeyError: 'module name can\'t contain ".", got: padding.1'
```
