
temporal-adaptive-module's Introduction

TAM: Temporal Adaptive Module for Video Recognition [arXiv]

@inproceedings{liu2021tam,
  title={TAM: Temporal adaptive module for video recognition},
  author={Liu, Zhaoyang and Wang, Limin and Wu, Wayne and Qian, Chen and Lu, Tong},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={13708--13718},
  year={2021}
}

[NEW!] 2021/07/23 - Our paper has been accepted by ICCV 2021. More pretrained models will be released soon for research purposes. Welcome to follow our work!

[NEW!] 2021/06/01 - Our temporal adaptive module has been integrated into MMAction2! We are glad to see that TAM achieves even higher accuracy with MMAction2 on several datasets.

[NEW!] 2020/10/10 - We have released the code of TAM for research purposes.

Overview

We release the PyTorch code of the Temporal Adaptive Module.

Architecture
The overall architecture of TANet: ResNet-Block vs. TA-Block.


Prerequisites

The code is built with the following libraries:

Data Preparation

Following the TSN and TSM repos, we provide a set of tools (vidtools) to extract frames from videos.

For convenience, the processing of video data can be summarized as follows:

  • Extract frames from videos.

    1. First, clone vidtools:

      git clone https://github.com/liu-zhy/vidtools.git && cd vidtools
    2. Extract frames by running:

      python extract_frames.py VIDEOS_PATH/ \
      -o DATASETS_PATH/frames/ \
      -j 16 --out_ext png
      

      If disk space is limited, we suggest using --out_ext jpg instead.

  • Generate the annotation.

    The annotation usually includes train.txt, val.txt and test.txt (optional). Each *.txt file uses the following format (a sketch for generating these files is given after this list):

    frames/video_1 num_frames label_1
    frames/video_2 num_frames label_2
    frames/video_3 num_frames label_3
    ...
    frames/video_N num_frames label_N
    

    The pre-processed dataset is organized with the following structure:

    datasets
      |_ Kinetics400
        |_ frames
        |  |_ [video_0]
    |  |  |_ img_00001.png
    |  |  |_ img_00002.png
        |  |  |_ ...
        |  |_ [video_1]
        |     |_ img_00001.png
        |     |_ img_00002.png
        |     |_ ...
        |_ annotations
           |_ train.txt
           |_ val.txt
           |_ test.txt (optional)
    
  • Configure the dataset in ops/dataset_configs.py.
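
The annotation files can be produced with a short script. Below is a minimal sketch, not part of this repo: the labels.csv file (mapping each video name to an integer class id), its column layout, and all paths are assumptions you should adapt to your own dataset.

# Minimal sketch: build train.txt in the "frames/video_x num_frames label"
# format described above. labels.csv and the paths are assumptions.
import os
import csv

def build_annotation(frames_root, label_csv, out_txt):
    # label_csv is assumed to contain lines like: video_1,42
    with open(label_csv) as f:
        labels = {row[0]: row[1] for row in csv.reader(f)}

    with open(out_txt, 'w') as out:
        for video in sorted(os.listdir(frames_root)):
            frame_dir = os.path.join(frames_root, video)
            if not os.path.isdir(frame_dir) or video not in labels:
                continue
            # count the extracted frames (named img_XXXXX by extract_frames.py)
            num_frames = len([x for x in os.listdir(frame_dir)
                              if x.startswith('img_')])
            out.write(f'frames/{video} {num_frames} {labels[video]}\n')

if __name__ == '__main__':
    build_annotation('datasets/Kinetics400/frames',
                     'datasets/Kinetics400/labels.csv',
                     'datasets/Kinetics400/annotations/train.txt')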

Model ZOO

Here we provide some off-the-shelf pretrained models. The accuracy might differ slightly from the paper, since the raw Kinetics videos downloaded by different users may not be identical.

| Models  | Datasets     | Resolution | Frames * Crops * Clips | Top-1 | Top-5 | Checkpoints |
|---------|--------------|------------|------------------------|-------|-------|-------------|
| TAM-R50 | Kinetics-400 | 256 * 256  | 8 * 3 * 10             | 76.1% | 92.3% | ckpt        |
| TAM-R50 | Kinetics-400 | 256 * 256  | 16 * 3 * 10            | 76.9% | 92.9% | ckpt        |
| TAM-R50 | Sth-Sth v1   | 224 * 224  | 8 * 1 * 1              | 46.5% | 75.8% | ckpt        |
| TAM-R50 | Sth-Sth v1   | 224 * 224  | 16 * 1 * 1             | 47.6% | 77.7% | ckpt        |
| TAM-R50 | Sth-Sth v2   | 256 * 256  | 8 * 3 * 2              | 62.7% | 88.0% | ckpt        |
| TAM-R50 | Sth-Sth v2   | 256 * 256  | 16 * 3 * 2             | 64.6% | 89.5% | ckpt        |

After downloading the checkpoints and placing them in the target path, you can test TAM with these pretrained weights.
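
If you want to sanity-check a downloaded checkpoint before running the test scripts, a minimal sketch is shown below. The 'state_dict' key and the 'module.' prefix handling are assumptions based on common torch.save conventions; adjust them if the file is organized differently.

# Minimal sketch: inspect a downloaded checkpoint on CPU.
import torch

ckpt = torch.load(
    './checkpoints/kinetics_RGB_resnet50_tam_avg_segment8_e100_dense/ckpt.best.pth.tar',
    map_location='cpu')
# Assumption: weights live under a 'state_dict' key; fall back to the raw dict.
state_dict = ckpt.get('state_dict', ckpt)
# Strip a DataParallel 'module.' prefix if present, then the weights can be
# loaded with model.load_state_dict(state_dict).
state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
print(len(state_dict), 'tensors loaded')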

Testing

For example, to test the downloaded pretrained models on Kinetics, you can run scripts/test_tam_kinetics_rgb_8f.sh. The script tests TAM with the 8-frame setting:

# test TAM on Kinetics-400
python -u test_models.py kinetics \
--weights=./checkpoints/kinetics_RGB_resnet50_tam_avg_segment8_e100_dense/ckpt.best.pth.tar \
--test_segments=8 --test_crops=3 \
--full_res --sample dense-10 --batch_size 8

Note that --sample determines the sampling strategy at test time. Specifically, --sample uniform-N means the model takes N clips uniformly sampled from the video as input, while --sample dense-N means the model takes N densely sampled clips as input.
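
To make the distinction concrete, the sketch below is illustrative only and not the repo's dataset code; the 8-segment clips and the 64-frame dense window are assumptions made for the example.

import numpy as np

def uniform_clips(num_frames, num_clips, num_segments=8):
    # Each clip spreads num_segments frames evenly across the whole video,
    # shifting where inside each chunk the frame is taken for every clip.
    chunk = num_frames / num_segments
    clips = []
    for i in range(num_clips):
        frac = (i + 0.5) / num_clips  # position inside each chunk
        offsets = (np.arange(num_segments) * chunk + frac * chunk).astype(int)
        clips.append(np.minimum(offsets, num_frames - 1))
    return clips

def dense_clips(num_frames, num_clips, num_segments=8, window=64):
    # Each clip densely covers a short window; the windows start at evenly
    # spaced points across the video.
    stride = max(window // num_segments, 1)
    starts = np.linspace(0, max(num_frames - window, 0), num_clips).astype(int)
    return [np.minimum(s + np.arange(num_segments) * stride, num_frames - 1)
            for s in starts]

# e.g. a 300-frame video tested with dense-10 yields 10 clips of 8 frames each
print(dense_clips(300, 10))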

You can also test TAM on Something-Something V2 by running scripts/test_tam_somethingv2_rgb_8f.sh:

# test TAM on Something-Something V2
python -u test_models.py somethingv2 \
--weights=./checkpoints/something_RGB_resnet50_tam_avg_segment8_e50/ckpt.best.pth.tar \
--test_segments=8 --test_crops=3 \
--full_res --sample uniform-2 --batch_size 32

Training

We provide several scripts in this repo for training TAM:

  • To train on Kinetics from ImageNet pretrained models, you can run scripts/train_tam_kinetics_rgb_8f.sh, which contains:

      python -u main.py kinetics RGB --arch resnet50 \
      --num_segments 8 --gd 20 --lr 0.01 --lr_steps 50 75 90 --epochs 100 --batch-size 8 \
      -j 8 --dropout 0.5 --consensus_type=avg --root_log ./checkpoint/this_ckpt \
      --root_model ./checkpoint/this_ckpt --eval-freq=1 --npb \
      --self_conv  --dense_sample --wd 0.0001

    After training, you should obtain a new checkpoint similar to the ones provided above.

  • To train on the Something-Something datasets (V1 & V2), you can run the following commands:

    # train TAM on Something-Something V1
    bash scripts/train_tam_something_rgb_8f.sh
    
    # train TAM on Something-Something V2
    bash scripts/train_tam_somethingv2_rgb_8f.sh


temporal-adaptive-module's Issues

n_segment?

Does n_segment in TAM refer to the number of input video frames?

TAM module in the ResBlock

Hello, thank you very much for your work. Where is the appropriate place to insert the TAM module within a ResBlock in SlowFast?

awesome works

We have tested several of your related works on our own large real-world dataset, and the results are exciting. Respect, bro.

About MobileNetV2 arch

Thank you for your work.
Also, have you implemented a MobileNetV2-TAM architecture? Could you release that code?

Thanks!

About n_segment

n_segment is the number of frames in the video sequence, but what should I do if the number of frames per video is not fixed? Can it be adjusted adaptively?

swapping labels for training data of Something-something v2

Hi Zhaoyang,

thanks for sharing the nice implementations!
I have a question regarding the data processing of Something-something v2.
I notice that for Something-Something V2 data, you hard-code label_transforms to swap the labels of three pairs of classes: 86 and 87, 93 and 94, 166 and 167 (line 458 in ops/models.py). However, this is only done for training, not for validation or testing. I wonder whether this means there are errors in the annotations of the Something-Something V2 training data.

Looking forward to your reply and thanks for the efforts.

Best,

Wei
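
For readers unfamiliar with this kind of remapping, here is a hypothetical sketch of the mechanism the question describes; it is not a copy of ops/models.py.

# Hypothetical illustration: swap three pairs of Something-Something V2
# class ids, and only at training time.
label_transforms = {86: 87, 87: 86, 93: 94, 94: 93, 166: 167, 167: 166}

def transform_label(label, training):
    return label_transforms.get(label, label) if training else label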

test model and get error

Hi, thanks for your awesome work in video recognition and also the release.

I ran the test command but got errors.

CUDA_VISIBLE_DEVICES=1 python -u test_models.py kinetics \
--weights=./checkpoints/kinetics_RGB_resnet50_tam_avg_segment16_e100_dense/ckpt.best.pth.tar \
--test_segments=16 --test_crops=3 \
--full_res --sample dense-10 --batch_size 1

My environment: Python 3.7, PyTorch 1.6.0, CUDA 11.0.
Error log:

  return self.module(*inputs[0], **kwargs[0])
  File "/home/sean/miniconda3/envs/openmmlab/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sean/workspace/temporal-adaptive-module/ops/models.py", line 327, in forward
    output = self.consensus(base_out)
  File "/home/sean/miniconda3/envs/openmmlab/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sean/workspace/temporal-adaptive-module/ops/basic_ops.py", line 46, in forward
    return SegmentConsensus(self.consensus_type, self.dim)(input)
  File "/home/sean/miniconda3/envs/openmmlab/lib/python3.7/site-packages/torch/autograd/function.py", line 149, in __call__
    "Legacy autograd function with non-static forward method is deprecated. "
RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method. (Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)
Exception in thread Thread-1:
Traceback (most recent call last):
  File "/home/sean/miniconda3/envs/openmmlab/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/home/sean/miniconda3/envs/openmmlab/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/sean/miniconda3/envs/openmmlab/lib/python3.7/site-packages/torch/utils/data/_utils/pin_memory.py", line 25, in _pin_memory_loop
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/home/sean/miniconda3/envs/openmmlab/lib/python3.7/multiprocessing/queues.py", line 113, in get
    return _ForkingPickler.loads(res)
  File "/home/sean/miniconda3/envs/openmmlab/lib/python3.7/site-packages/torch/multiprocessing/reductions.py", line 282, in rebuild_storage_fd
    fd = df.detach()
  File "/home/sean/miniconda3/envs/openmmlab/lib/python3.7/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/sean/miniconda3/envs/openmmlab/lib/python3.7/multiprocessing/resource_sharer.py", line 87, in get_connection
    c = Client(address, authkey=process.current_process().authkey)
  File "/home/sean/miniconda3/envs/openmmlab/lib/python3.7/multiprocessing/connection.py", line 492, in Client
    c = SocketClient(address)
  File "/home/sean/miniconda3/envs/openmmlab/lib/python3.7/multiprocessing/connection.py", line 620, in SocketClient
    s.connect(address)
ConnectionRefusedError: [Errno 111] Connection refused

So could you please help me figure it out? Thanks!
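
For context, the RuntimeError above is raised by PyTorch 1.5+ when an old-style autograd Function (non-static forward) is called. A new-style Function uses static forward/backward methods and is invoked via .apply(). The snippet below is only a minimal sketch of that pattern for an averaging consensus, not the repo's actual ops/basic_ops.py.

import torch

class AvgConsensus(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, dim=1):
        # average over the segment dimension, remembering shape for backward
        ctx.dim = dim
        ctx.shape = x.shape
        return x.mean(dim=dim, keepdim=True)

    @staticmethod
    def backward(ctx, grad_output):
        # spread the gradient evenly over the averaged dimension
        grad_in = grad_output.expand(ctx.shape) / ctx.shape[ctx.dim]
        return grad_in, None

# usage: out = AvgConsensus.apply(base_out, 1)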

Learning rate of nn.Linear

Thanks for open-sourcing such great work.
I notice that the learning rate of all linear layers is multiplied by 5, even inside the temporal adaptive modules. I know that a larger learning rate for the last fully connected layer normally brings better performance, but is this a mistake, or does it actually produce better results?
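
For reference, a 5x learning rate on selected layers is usually expressed through optimizer parameter groups. The sketch below is an assumed, simplified version; the function name build_optimizer and the grouping-by-module-type logic are not taken from this repo.

import torch
import torch.nn as nn

def build_optimizer(model, base_lr=0.01, weight_decay=1e-4):
    # split parameters by module type so nn.Linear layers get their own group
    linear_params, other_params = [], []
    for m in model.modules():
        params = list(m.parameters(recurse=False))
        if isinstance(m, nn.Linear):
            linear_params += params
        else:
            other_params += params
    return torch.optim.SGD(
        [{'params': other_params, 'lr': base_lr},
         {'params': linear_params, 'lr': base_lr * 5}],  # 5x lr for nn.Linear
        lr=base_lr, momentum=0.9, weight_decay=weight_decay)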

The pretrained models don't seem to work

Hi, thanks for your great work. I tried the pretrained Sth-Sth v1 8-frame and 16-frame checkpoints and only got 0.5% test accuracy. Maybe there are some mistakes in these checkpoints; could you please check? Or is there anything I need to do with the Sth-Sth v1 frames? I'm using the original frames without resizing them before they are loaded.
