
sauradip / stale

96 stars · 2 watchers · 8 forks · 26.07 MB

[ECCV 2022] Official PyTorch Implementation of the paper: "Zero-Shot Temporal Action Detection via Vision-Language Prompting"

Home Page: https://sauradip.github.io/project_pages/STALE/

Python 100.00%
action-detection clip prompt-tuning temporal-action-detection temporal-action-localization transformers video-understanding vision-language

stale's Introduction


Zero-Shot Temporal Action Detection via Vision-Language Prompting

Sauradip Nag (1,2,+), Xiatian Zhu (1,3), Yi-Zhe Song (1,2), Tao Xiang (1,2)
(1) CVSSP, University of Surrey, UK   (2) iFlyTek-Surrey Joint Research Center on Artificial Intelligence, UK
(3) Surrey Institute for People-Centred Artificial Intelligence, UK
(+) corresponding author

Accepted to ECCV 2022

Updates

  • (July 2022) We released STALE training and inference code for the ActivityNet v1.3 dataset.
  • (June 2022) STALE is accepted to ECCV 2022.

Summary

  • First prompt-guided framework for the Zero-Shot Temporal Action Detection (ZS-TAD) task.
  • Adapts classification-based CLIP to detection-based TAD using Representation Masking.
  • Transformer-based Cross-Adaptation module to contextualize the classifier using Vision-Language features.
  • Inter-branch consistency learning to ensure the model finds accurate action boundaries.

Abstract

Existing temporal action detection (TAD) methods rely on large training data including segment-level annotations, limited to recognizing previously seen classes alone during inference. Collecting and annotating a large training set for each class of interest is costly and hence unscalable. Zero-shot TAD (ZS-TAD) resolves this obstacle by enabling a pre-trained model to recognize any unseen action classes. Meanwhile, ZS-TAD is also much more challenging with significantly less investigation. Inspired by the success of zero-shot image classification aided by vision-language (ViL) models such as CLIP, we aim to tackle the more complex TAD task. An intuitive method is to integrate an off-the-shelf proposal detector with CLIP-style classification. However, due to the sequential localization (e.g., proposal generation) and classification design, it is prone to localization error propagation. To overcome this problem, in this paper we propose a novel zero-Shot Temporal Action detection model via vision-LanguagE prompting (STALE). Such a novel design effectively eliminates the dependence between localization and classification by breaking the route for error propagation in-between. We further introduce an interaction mechanism between classification and localization for improved optimization. Extensive experiments on standard ZS-TAD video benchmarks show that our STALE significantly outperforms state-of-the-art alternatives. Besides, our model also yields superior results on supervised TAD over recent strong competitors.

Architecture

Getting Started

Requirements

  • Python 3.7
  • PyTorch == 1.9.0 (please make sure your PyTorch version is at least 1.8)
  • NVIDIA GPU
  • Hugging Face Transformers
  • Detectron2

Environment Setup

It is suggested to create a Conda environment and install the following requirements

pip3 install -r requirements.txt
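
For example, the full setup might look like this (the environment name is illustrative; the repo targets Python 3.7):

conda create -n stale python=3.7
conda activate stale
pip3 install -r requirements.txt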

Extra Dependencies

We use the MaskFormer implementation for Representation Masking.

git clone https://github.com/sauradip/STALE.git
cd STALE
git clone https://github.com/facebookresearch/MaskFormer

Follow MaskFormer's installation instructions to install Detectron2 and the other required modules within this same environment if possible. After this step, copy the files from STALE/extra_files into STALE/MaskFormer/mask_former/modeling/transformer/ (a possible command is shown below).
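
Assuming the directory layout produced by the clone commands above, the copy step might look like:

# run from inside the STALE directory
cp extra_files/* MaskFormer/mask_former/modeling/transformer/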

Download Features

Download the video features and update the video/output paths in the config/anet.yaml file (see the sketch after the table below). For now, only the ActivityNet v1.3 config is available. We plan to release the code for the THUMOS14 dataset soon.

Dataset     | Feature Backbone | Pre-Training | Link
ActivityNet | ViT-B/16-CLIP    | CLIP         | Google Drive
THUMOS      | ViT-B/16-CLIP    | CLIP         | Google Drive
ActivityNet | I3D              | Kinetics-400 | Google Drive
THUMOS      | I3D              | Kinetics-400 | Google Drive
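
For reference, the entries to update typically include the annotation/feature paths and the output directories. The sketch below uses placeholder values and the key names that appear in the config excerpt quoted in the issues further down; treat it as an assumption about the layout, not an exhaustive list:

# config/anet.yaml (excerpt, placeholder values)
dataset:
  training:
    video_info_path: "./data/activitynet_annotations/video_info_new.csv"
    video_anno_path: "./data/activitynet_annotations/anet_anno_action.json"
    output_path: "./path/to/train/"
training:
  feature_path: "/path/to/downloaded/features"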

Training Splits

Currently we support the training splits provided by the EfficientPrompt paper. Both the 50% and 75% labeled-data splits are available for training and can be found in STALE/splits.

Model Training

To train STALE from scratch, run the following command. The training configuration can be adjusted in the config/anet.yaml file.

python stale_train.py

Model Inference

We provide pretrained models containing checkpoints for both the 50% and 75% labeled-data splits in the zero-shot setting.

Dataset     | Split (Seen-Unseen) | Feature | Link
ActivityNet | 50%-50%             | CLIP    | ckpt
ActivityNet | 75%-25%             | CLIP    | ckpt

After downloading the checkpoints, set the checkpoint path in the config/anet.yaml file. Model inference can then be performed with the following command:

python stale_inference.py
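
For reference, a checkpoint path appears under the training section of the config excerpt quoted in the issues below; assuming inference reads the same key, the entry might look roughly like this (placeholder value):

# config/anet.yaml (excerpt, placeholder value)
training:
  checkpoint_path: "./path/to/downloaded/checkpoints/"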

Model Evaluation

To evaluate our STALE model, run the following command:

python eval.py

TO-DO Checklist

  • Fix the learnable-prompt issue in Hugging Face Transformers
  • Fix the NaN bug during model training
  • Support for THUMOS14 dataset
  • Enable multi-gpu training

Acknowledgement

Our source code is based on the implementations of DenseCLIP, MaskFormer, and CoOp. We thank the authors for open-sourcing their code.

Citation

If you find this project useful for your research, please use the following BibTeX entry.

@article{nag2022zero,
  title={Zero-shot temporal action detection via vision-language prompting},
  author={Nag, Sauradip and Zhu, Xiatian and Song, Yi-Zhe and Xiang, Tao},
  journal={arXiv e-prints},
  pages={arXiv--2207},
  year={2022}
}

stale's People

Contributors

anirudh257, ed-fish, sauradip


stale's Issues

Reproducibility Issue

Hello @sauradip, for ActivityNet using the CLIP features, I was able to get these results. It would be really helpful if you could let me know whether any changes to the threshold or any other parameter are needed.

Average-mAP: 0.23484338756421028
Detection: average-mAP 23.484, mAP@0.50 40.221, mAP@0.55 36.406, mAP@0.60 33.502, mAP@0.65 30.057, mAP@0.70 26.769, mAP@0.75 22.584, mAP@0.80 18.366, mAP@0.85 14.428, mAP@0.90 8.962, mAP@0.95 3.549

Data splits (zero shot)

In a zero-shot scenario, out of the 10 lists in the data-splits folder, which list was used for performance evaluation?

After obtaining performance results for all 10 lists, did you calculate the average?

Why use only the last segment of the video

In STALE/stale_lib/stale_dataloader.py, I noticed that the following code is meant, for each video, to add all the segments from the corresponding annotation file to label_list.

for j in range(len(labels)):
    tmp_info = labels[j]
    clip_factor = self.temporal_scale / ( corr_sec * (self.num_frame+1) )
    action_start = tmp_info['segment'][0]*clip_factor
    snip_start = max(min(1, tmp_info['segment'][0] / corr_sec), 0)
    action_end = tmp_info['segment'][1]*clip_factor
    snip_end = max(min(1, tmp_info['segment'][1] / corr_sec), 0)
    gt_label = tmp_info["label"]

if action_end - action_start > 1 and gt_label in lbl_dict:
    label_list.append([snip_start,snip_end,gt_label])

But this code actually adds only the last segment, because the if block sits outside the for loop. Is the following version correct?

for j in range(len(labels)):
    tmp_info = labels[j]
    clip_factor = self.temporal_scale / ( corr_sec * (self.num_frame+1) )
    action_start = tmp_info['segment'][0]*clip_factor
    snip_start = max(min(1, tmp_info['segment'][0] / corr_sec), 0)
    action_end = tmp_info['segment'][1]*clip_factor
    snip_end = max(min(1, tmp_info['segment'][1] / corr_sec), 0)
    gt_label = tmp_info["label"]

    if action_end - action_start > 1 and gt_label in lbl_dict:
        label_list.append([snip_start,snip_end,gt_label])

I would greatly appreciate it if you could spare some time to answer my question despite your busy schedule.

How to extract the video features with CLIP visual encoder?

Thanks for your great work! I have a question about how the video features are extracted with the CLIP visual encoder, since I want to extract CLIP features for a new dataset. With I3D or C3D, a video clip of 16 or 32 frames is represented by a single feature vector. With the CLIP visual encoder, do you need to extract features for each of the 16 or 32 frames in a clip? If so, a long video could end up with thousands of frame features. Could you give me some details about how the CLIP features for a video are produced?
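
For context, one common recipe is to sample frames at a fixed rate, encode each frame with the CLIP image encoder, and pool the frame features per snippet. The sketch below illustrates that recipe with the openai/CLIP package; it is an illustration only and not necessarily how the released features were produced:

import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

def encode_snippet(frames):
    # frames: a list of PIL.Image frames sampled from one snippet
    batch = torch.stack([preprocess(f) for f in frames]).to(device)
    with torch.no_grad():
        feats = model.encode_image(batch)   # (num_frames, 512) for ViT-B/16
    return feats.mean(dim=0)                # one 512-d feature per snippet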

Open-set video recognition

Dear author, thank you for publishing your work!
I want to try open-set video recognition.
How would I do this? In other words, how do I get the features from a video?

Link for pretrained model

Hello, can you please share the pretrained-weights link for the ActivityNet 75%-25% zero-shot setting?

Test Script for videos

Dear authors,

Thank you for the amazing work.

Is there a test script we can use to test on any random video? I've noticed that inference and validation are mostly for a predefined benchmark.

Thank you.

question on architecture

Thank you for your great paper.
I have some questions about network architecture and its code implementation of Representation Masking.

  1. In the paper, it is said that an MLP is used to optimize a query per location. In class STALE I can spot a definition of self.mask_MLP but cannot find where it is used. Can you help me?
  2. The inputs to the transformer decoder are the snippet embeddings F_vis and the mask queries N_z. Isn't this N_z similar to the learnable queries in MaskFormer, and shouldn't it correspond to 5/20/100/etc.? Why is it set to 1 in the construction of self.masktrans = TransformerPredictor(..., num_queries=1, ...)?

feature_frame in THUMOS14

In the anet_anno_action.json file, do duration_frame and feature_frame mean the total number of frames in the video and the frames where the action occurs, respectively?

I want to run it on the THUMOS14 dataset, but there is no feature_frame in the THUMOS JSON.

Please tell me how to process feature_frame in THUMOS dataset.
Thank you!

Question about video num of Anet datasets

Nice work, thanks for your publication!

I noticed that the number of ActivityNet I3D features is not consistent with the number of videos (about 20k). Why is there a difference?

if action_end - action_start > 1 and gt_label in lbl_dict

How should I understand the condition action_end - action_start > 1?

corr_sec = float(num_frame) / vid_frame * num_sec
label_list= []
if subset in subset_vid:
    for j in range(len(labels)):
        tmp_info = labels[j]
        clip_factor = self.temporal_scale / ( corr_sec * (self.num_frame+1) )
        action_start = tmp_info['segment'][0]*clip_factor
        snip_start = max(min(1, tmp_info['segment'][0] / corr_sec), 0)
        action_end = tmp_info['segment'][1]*clip_factor
        snip_end = max(min(1, tmp_info['segment'][1] / corr_sec), 0)
        gt_label = tmp_info["label"]

    if action_end - action_start > 1 and gt_label in lbl_dict:
        label_list.append([snip_start,snip_end,gt_label])

[Screenshots: ground-truth annotations of video 3l7quTy4c2s vs. the labels produced by the dataloader]

The first screenshot shows the real labels of video 3l7quTy4c2s, and the second shows the labels after processing by your code. You can see that the missing part really is a labeled segment.

After applying action_end - action_start > 1, I found that short segments in ActivityNet are not included in the calculation. Is this a form of data loss, or is there precedent for doing this?

Question regarding the CLIP based features

Hello @sauradip,

I had a question regarding the CLIP-based features provided for ActivityNet and THUMOS. Are they extracted using the same feature-extraction code provided by the Efficient Prompt paper?

Thanks,
Akshita

Why can Zero Shot be achieved?

Hi, I also work on zero-shot temporal action localization, and I found that if I use a Transformer to model the CLIP video-frame features, I get high mAP on the training set and low mAP on the test set. My guess is that after the Transformer, the CLIP frame features become hard to match with the text features. What is the key to solving this problem?

Can you help me? 😭

RuntimeError: Function 'DivBackward0' returned nan values in its 0th output.

I downloaded the code and dataset and modified only anet.yaml, but I still get this error. Can you help me?

My environment and configuration:

torch                    1.10.1
torchfile                0.1.0
torchnet                 0.0.4
torchvision              0.11.2
dataset:
  num_classes: 200
  split: 75
  training:
    video_info_path: "./data/activitynet_annotations/video_info_new.csv"
    video_anno_path: "./data/activitynet_annotations/anet_anno_action.json"
    num_frame: 5
    output_path: './path/to/train/'

  testing:
    video_info_path: "./data/activitynet_annotations/video_info_new.csv"
    video_anno_path: "./data/activitynet_annotations/anet_anno_action.json"
    num_frame: 5
    output_path: './path/to/test/'

model:
  embedding_head: 4
  # feat_dim: 2048
  feat_dim: 512
  temporal_scale: 100
  clip_pretrain: "O" ## K : KInetics , O : openAI

training:
  batch_size: 100
  learning_rate: 0.00004
  weight_decay: 0.02
  max_epoch: 5
  checkpoint_path: './path/to/output/'
  random_seed: 1
  step: 10
  gamma: 0.3
  feature_path: "/disk/sdd/liuyang/ANet_CLIP"
  num_gpu: 1

loss:
  lambda_1: 0.6
  lambda_2: 0.4

fewshot:
  shot: 0 ## > 0 is few-shot ;  = 0 is zero-shot
  mode: 1 # 1 : base-training 2 : meta-training 3 : meta-testing 4 : no meta-training/ vanilla few-shot
  trimmed: 0 # 0 : untrimmed 1 : trimmed
  episode: 1000
  num_base: 180
  num_test: 20
  ismulti : 1 # 0 : single-instance 1 : multi-instance
  num_way : 4
  meta_class : 1 # # 1: meta-learn classifier 0: vanilla few-shot w/o meta-learning
  meta_mask : 0 # # 1: meta-learn mask 0: vanilla few-shot w/o meta-learning
  trim_support : 1
  num_context : 20

testing:
  cls_thresh: 0.01
  mask_thresh: [0,0.2,0.4,0.6,0.8]
  class_thresh: [0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]
  top_k_snip: 10
  top_k: 500
  nms_thresh: 0.6

pretraining:
  video_transformer: "./path/to/ckpt"
  isPretrain : 0 # 0 : Finetune , 1 : Pretrain
  video_path: "/disk/sdd/liuyang/ANet_CLIP222"
  raw_video: "/path/to/raw/video"
  clip_length: 768
  clip_stride: 8
  emb_dim: 512

demo:
  generated_feat_dir: "./path/to/feature"

Detailed error reporting

=========using KL Loss=and has temperature and * bz==========

Total Number of Learnable Paramters (in M) :  170.715992
No of Gpus using to Train :  1 
 Saving all Checkpoints in path : ./path/to/train/
No of videos in train is 6575
Loading train Video Information ...
No of class 150
100% 9649/9649 [00:01<00:00, 6946.91it/s]
No of videos in validation is 1094
Loading validation Video Information ...
No of class 50
100% 4728/4728 [00:00<00:00, 26635.37it/s]
stale_train.py:118: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
  with autograd.detect_anomaly():
0 torch.Size([100, 512, 100]) torch.Size([100, 100]) torch.Size([100, 100, 100])
/home/ymy/code/ly/STALE-main/MaskFormer/mask_former/modeling/transformer/position_encoding.py:42: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
/home/ymy/miniconda3/envs/py38/lib/python3.8/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.
  warnings.warn(warning.format(ret))
[W python_anomaly_mode.cpp:104] Warning: Error detected in DivBackward0. Traceback of forward call that caused the error:
  File "stale_train.py", line 119, in <module>
    train(train_loader, model, optimizer, epoch,scheduler)
  File "stale_train.py", line 61, in train
    loss = stale_loss(top_br_gt,top_br_pred,bottom_br_gt,bottom_br_pred,action_gt, mask_pred,bot_gt,cls_pred,label_gt,features,"train")
  File "/home/ymy/code/ly/STALE-main/stale_lib/loss_stale.py", line 235, in stale_loss
    red_loss = redundancy_loss(gt_action , pred_action, gt_cls, pred_cls, features)
  File "/home/ymy/code/ly/STALE-main/stale_lib/loss_stale.py", line 215, in redundancy_loss
    sim_loss += (1-cos_sim(top_feat,bot_feat))
  File "/home/ymy/miniconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ymy/miniconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/distance.py", line 77, in forward
    return F.cosine_similarity(x1, x2, self.dim, self.eps)
 (function _print_stack)
Traceback (most recent call last):
  File "stale_train.py", line 119, in <module>
    train(train_loader, model, optimizer, epoch,scheduler)
  File "stale_train.py", line 64, in train
    tot_loss.backward()
  File "/home/ymy/miniconda3/envs/py38/lib/python3.8/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/ymy/miniconda3/envs/py38/lib/python3.8/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Function 'DivBackward0' returned nan values in its 0th output.

About the B-II model

Thank you for your outstanding work. I would like to know more about the B-II model in the paper; can you share the details of this model, and where can I get its code? I ask because I found a big performance gap between I3D features and CLIP features on the THUMOS dataset. Looking forward to your reply.

The Thumos14 clip features

Thank you for your excellent work! I found that the Google Drive link for the THUMOS14 CLIP features is empty. Could you upload them again?

Focal loss with internal softmax

F.cross_entropy in PyTorch applies a LogSoftmax operation on its inputs internally. This means that torch.exp(-ce_loss) may be applying the softmax on the values twice when computing the focal loss. I'm not sure if this will have any significant impact but may be something to note.
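
For reference, the usual pattern for deriving focal loss from F.cross_entropy passes raw logits, so the internal log-softmax is applied exactly once and exp(-ce) recovers the probability of the true class. The sketch below is illustrative and is not taken from the STALE code:

import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0):
    # logits must be raw (un-softmaxed) scores; F.cross_entropy applies
    # log_softmax internally.
    ce = F.cross_entropy(logits, targets, reduction="none")
    pt = torch.exp(-ce)                      # softmax probability of the true class
    return ((1.0 - pt) ** gamma * ce).mean()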

stale_best_score

During inference, if the score is lower than stale_best_score, the label is replaced with the label from this JSON file. How is this JSON file obtained? I also found that without this JSON file the mAP drops to 9 instead of 24.9. Why is that?

Clarifying Details to Reproduce on THUMOS14

While reproducing the accuracy on the THUMOS14 dataset, I found some of your implementation choices confusing. I would really appreciate your clarification so I can reproduce the results.

Q1.
At inference time, segments above the threshold are connected to form one large segment, as the line and example below show. Although this is an effective post-processing step for the ActivityNet dataset, it is not suitable for THUMOS14, which has many short action instances rather than one or two long ones.

https://github.com/sauradip/STALE/blob/main/stale_inference.py#L156

filt_seg_score_int = ndimage.binary_fill_holes(filt_seg_score_int).astype(int).tolist()

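For illustration, binary_fill_holes on a 1-D mask fills every background gap that is not connected to the border, which is why separate short detections end up merged into one long segment (a minimal standalone example, not STALE code):

import numpy as np
from scipy import ndimage

mask = np.array([1, 0, 0, 1, 0, 1], dtype=bool)   # three separate detections
filled = ndimage.binary_fill_holes(mask)
print(filled.astype(int))                         # [1 1 1 1 1 1]: the gaps are filled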

Q2.
In the dataset builder, why do you add 1 to the start indices and subtract 1 from the end indices?

https://github.com/sauradip/STALE/blob/main/stale_lib/stale_dataloader.py#L188

        for idx in range(len(start_id)):
          lbl_id = label_id[idx]
          start_indexes.append(start_id[idx]+1)
          end_indexes.append(end_id[idx]-1)
          tuple_list.append([start_id[idx]+1, end_id[idx]-1,lbl_id])

The annotation of Thumos14

Hi, could you provide the annotations for THUMOS14, or tell us where we can obtain them? Looking forward to your reply!

Questions about the implementation of the IOU

Thanks to the authors for introducing a pre-trained ViL model to the ZS-TAD domain, and especially for making the code publicly available as a reference for the community. I cannot understand the IoU implementation very well: it seems to me that "Aand" should be computed as "min(e1, e2) - max(s1, s2)" and the function should return "Aand / Aor" directly; however, when I do this, the mAP drops. Is this a variant of IoU? I would appreciate your help with this confusion.
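For reference, the plain 1-D temporal IoU described in the question can be written as follows; this is a restatement of the formula in the question, not the repository's implementation:

def segment_iou(s1, e1, s2, e2):
    # Intersection over union of two 1-D segments [s1, e1] and [s2, e2].
    a_and = max(0.0, min(e1, e2) - max(s1, s2))
    a_or = max(e1, e2) - min(s1, s2)
    return a_and / a_or if a_or > 0 else 0.0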
