
Official PyTorch repository for "TR-DETR: Task-Reciprocal Transformer for Joint Moment Retrieval and Highlight Detection" (AAAI 2024 Paper)


tr-detr's Introduction

TR-DETR: Task-Reciprocal Transformer for Joint Moment Retrieval and Highlight Detection (AAAI 2024 Paper)

by Hao Sun*1, Mingyao Zhou*1, Wenjing Chen†2, Wei Xie†1

1 Central China Normal University, 2 Hubei University of Technology. * Equal contribution, † Corresponding authors.

[Paper]


Prerequisites

0. Clone this repo

1. Prepare datasets

QVHighlights: Download the official feature files for the QVHighlights dataset from Moment-DETR.

Download moment_detr_features.tar.gz (8GB) and extract it under the '../features' directory. You can change the data directory by modifying 'feat_root' in the shell scripts under the 'tr_detr/scripts/' directory.

tar -xf path/to/moment_detr_features.tar.gz
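For reference, each script under 'tr_detr/scripts/' points at the feature directory through a variable along these lines (the exact line and default may differ per script; check before relying on it):

feat_root=../features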

TVSum: Download the feature files for the TVSum dataset from UMT.

Download TVSum (69.1MB) and either extract it under the '../features/tvsum/' directory or change 'feat_root' in the TVSum shell scripts under 'tr_detr/scripts/tvsum/'; an example extraction command is shown below.
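For example (the archive file name below is a placeholder; use whatever UMT's link actually downloads):

mkdir -p ../features/tvsum
tar -xf path/to/tvsum_features.tar.gz -C ../features/tvsum/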

2. Install dependencies. Python version 3.7 is required.

pip install -r requirements.txt

requirements.txt also includes some extra libraries; it will be cleaned up soon. For an anaconda setup, please refer to the official Moment-DETR GitHub.
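If you prefer a clean environment, a minimal sketch assuming anaconda (the environment name tr_detr is illustrative):

conda create -n tr_detr python=3.7
conda activate tr_detr
pip install -r requirements.txt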

QVHighlights

Training

Training with video only or with video + audio can be run with the shell scripts below:

bash tr_detr/scripts/train.sh 
bash tr_detr/scripts/train_audio.sh 

The best validation accuracy is obtained at the last epoch.

Inference, Evaluation, and CodaLab Submission for QVHighlights

Once the model is trained, hl_val_submission.jsonl and hl_test_submission.jsonl can be generated by running inference.sh:

bash tr_detr/scripts/inference.sh results/{direc}/model_best.ckpt 'val'
bash tr_detr/scripts/inference.sh results/{direc}/model_best.ckpt 'test'

where {direc} is the directory containing the saved checkpoint. For more details on submission, see standalone_eval/README.md.

TVSum

Training with video only or with video + audio can be run with the shell scripts below:

bash tr_detr/scripts/tvsum/train_tvsum.sh 
bash tr_detr/scripts/tvsum/train_tvsum_audio.sh 

Best results are stored in 'results_[domain_name]/best_metric.jsonl'.

Cite TR-DETR

If you find this repository useful, please use the following entry for citation.

@inproceedings{sun_zhou2024tr,
  title={Tr-detr: Task-reciprocal transformer for joint moment retrieval and highlight detection},
  author={Sun, Hao and Zhou, Mingyao and Chen, Wenjing and Xie, Wei},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={38},
  number={5},
  pages={4998--5007},
  year={2024}
}

LICENSE

The annotation files and many parts of the implementation are borrowed from Moment-DETR and QD-DETR. Following them, our code is also released under the MIT license.


tr-detr's Issues

charades_sta "pos_mask"

Hello! Thank you for the great work! In the charades_sta_train_tvr_format.jsonl dataset, the annotation file does not contain a 'relevant_clip_ids' field. How can I modify the __getitem__ method of the StartEndDataset class below so that it still generates model_inputs["pos_mask"]?

import numpy as np
import torch
from torch.utils.data import Dataset


class StartEndDataset(Dataset):

    def __getitem__(self, index):
        meta = self.data[index]

        model_inputs = dict()
        model_inputs["query_feat"] = self._get_query_feat_by_qid(meta["qid"])  # (Dq, ) or (Lq, Dq)
        if self.use_video:
            model_inputs["video_feat"] = self._get_video_feat_by_vid(meta["vid"])  # (Lv, Dv)
            ctx_l = len(model_inputs["video_feat"])
        else:
            ctx_l = self.max_v_l

        if self.use_tef:
            tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
            tef_ed = tef_st + 1.0 / ctx_l
            tef = torch.stack([tef_st, tef_ed], dim=1)  # (Lv, 2)
            if self.use_video:
                model_inputs["video_feat"] = torch.cat(
                    [model_inputs["video_feat"], tef], dim=1)  # (Lv, Dv+2)
            else:
                model_inputs["video_feat"] = tef

        if self.load_labels:
            if self.dset_name == 'tvsum':
                max_l = ctx_l // 2

                # aggregate per-clip annotator scores; labels start from 1, so minus 1
                meta_label = meta['label']
                agg_scores = np.sum(meta_label - np.ones_like(meta_label), axis=-1)[:ctx_l]
                sort_indices = np.argsort(agg_scores)  # increasing
                pos_idx = torch.tensor(sort_indices[max_l:])  # top half of clips by score

                mask = torch.zeros(ctx_l)

                if pos_idx.max() >= len(mask):
                    new_mask = torch.zeros(int(pos_idx.max()) + 1)
                    new_mask[pos_idx] = 1
                    new_mask[:len(mask)] = mask
                    mask = new_mask
                else:
                    mask[pos_idx] = 1

                model_inputs["pos_mask"] = mask

                # use a plain Python list so the set difference compares ints, not tensors
                neg_idx = torch.tensor(list(set(range(ctx_l)) - set(pos_idx.tolist())))

                # pad the index lists to length ctx_l with -2 so they can be batched
                pad_tensor = torch.ones(ctx_l) * -2
                pad_tensor[:len(pos_idx)] = pos_idx
                model_inputs["pos_idx"] = pad_tensor

                pad_tensor = torch.ones(ctx_l) * -2
                pad_tensor[:len(neg_idx)] = neg_idx
                model_inputs["neg_idx"] = pad_tensor

                model_inputs["span_labels"] = torch.tensor([[0., 0.]])
                model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \
                    self.get_saliency_labels_all_tvsum(meta_label, ctx_l)
            else:
                pos_idx = torch.tensor(meta['relevant_clip_ids'])
                mask = torch.zeros(ctx_l)

                if pos_idx.max() >= len(mask):
                    new_mask = torch.zeros(int(pos_idx.max()) + 1)
                    new_mask[pos_idx] = 1
                    new_mask[:len(mask)] = mask
                    mask = new_mask
                else:
                    mask[pos_idx] = 1

                model_inputs["pos_mask"] = mask

                model_inputs["span_labels"] = self.get_span_labels(meta["relevant_windows"], ctx_l)  # (#windows, 2)
                if "subs_train" not in self.data_path:
                    model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \
                        self.get_saliency_labels_all(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l)
                else:
                    model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \
                        self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], ctx_l)  # only one gt

        return dict(meta=meta, model_inputs=model_inputs)
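One possible workaround (not an official fix): the TVR-format Charades-STA annotations do carry 'relevant_windows' in seconds, so clip indices can be derived from each window and the clip length, then turned into pos_mask the same way the QVHighlights branch does. A minimal sketch, where clip_len=1.0 second is an assumption; use whatever clip length your Charades-STA features were extracted with:

import math

import torch


def pos_mask_from_windows(relevant_windows, ctx_l, clip_len=1.0):
    # Build a 0/1 mask over ctx_l clips from [start_sec, end_sec] windows.
    # clip_len is the duration covered by one feature clip (1.0 s assumed).
    mask = torch.zeros(ctx_l)
    for st, ed in relevant_windows:
        st_idx = int(st // clip_len)
        ed_idx = min(math.ceil(ed / clip_len), ctx_l)
        mask[st_idx:ed_idx] = 1
    return mask

Inside __getitem__, in place of the 'relevant_clip_ids' lookup, one could then set model_inputs["pos_mask"] = pos_mask_from_windows(meta["relevant_windows"], ctx_l).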

When loading the dataset I keep getting a file-not-found error. My QVHighlights dataset is placed under the root directory, and running inference.sh never gets through. Could you help me figure out what needs to be adjusted?

Traceback (most recent call last):
  File "tr_detr/inference.py", line 423, in <module>
    start_inference(split=split, splitfile=splitfile)
  File "tr_detr/inference.py", line 413, in start_inference
    eval_epoch(model, eval_dataset, opt, save_submission_filename, criterion=criterion)
  File "tr_detr/inference.py", line 299, in eval_epoch
    submission, eval_loss_meters = get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer)
  File "tr_detr/inference.py", line 254, in get_eval_res
    eval_res, eval_loss_meters = compute_mr_results(model, eval_loader, opt, epoch_i, criterion, tb_writer)  # list(dict)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "tr_detr/inference.py", line 185, in compute_mr_results
    for batch in tqdm(eval_loader, desc="compute st ed scores"):
  File "/root/miniconda3/lib/python3.8/site-packages/tqdm/std.py", line 1185, in __iter__
    for obj in iterable:
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
    return self._process_data(data)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
    data.reraise()
  File "/root/miniconda3/lib/python3.8/site-packages/torch/_utils.py", line 425, in reraise
    raise self.exc_type(msg)
FileNotFoundError: Caught FileNotFoundError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/root/autodl-tmp/TR-DETR/tr_detr/start_end_dataset.py", line 90, in __getitem__
    model_inputs["query_feat"] = self._get_query_feat_by_qid(meta["qid"])  # (Dq, ) or (Lq, Dq)
  File "/root/autodl-tmp/TR-DETR/tr_detr/start_end_dataset.py", line 324, in _get_query_feat_by_qid
    q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32)
  File "/root/miniconda3/lib/python3.8/site-packages/numpy/lib/npyio.py", line 417, in load
    fid = stack.enter_context(open(os_fspath(file), "rb"))
FileNotFoundError: [Errno 2] No such file or directory: '.../qid2579.npz'
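The loader is building a per-query feature path like qid2579.npz under the configured text-feature directory and not finding it, which usually means the feature paths set in inference.sh don't match where the archive was extracted (the README expects '../features'). A quick diagnostic sketch; the clip_text_features subdirectory name is an assumption based on the Moment-DETR feature package, so match it to your actual layout:

from pathlib import Path

# Assumed layout of the extracted Moment-DETR feature package; adjust
# feat_root to whatever your tr_detr/scripts/inference.sh actually points at.
feat_root = Path("../features")
q_feat_dir = feat_root / "clip_text_features"  # assumed text-feature subdir

qid = 2579  # the qid from the error message
q_feat_path = q_feat_dir / f"qid{qid}.npz"
print(q_feat_path, "exists:", q_feat_path.exists())

If the file exists but under a different root, update 'feat_root' in the shell script (or move the features) rather than editing the Python code.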

The dataset

Hello, the download links provided for the two datasets in the paper report that the files no longer exist. Could you provide updated links or an alternative way to access these datasets? Thank you very much for your assistance and understanding.
