
knight's Introduction

README

Paper (Accepted by IJCAI 2023)

1. Installing

$ pip install -r requirements.txt
$ pip install git+https://github.com/openai/CLIP.git

2. Data Preparation

Download the images and videos of each dataset from the web.

The data directory should look like this:

./data/
  ├──./COCO/
  |   ├──./image/					#images of the test split
  |   ├──captions_val2014.json		#annotation of test split
  |   ├──coco_test.txt				#test split of Karpathy
  ├──./Flickr/
  |   ├──./image/					#images in dataset
  |   ├──dataset_flickr30k.json		#annotation
  ├──./MSRVTT/
  |   ├──./video/					#videos in dataset
  |   ├──./frames/					#keyframes
  |   ├──train_val_videodatainfo.json	#annotation
  ├──./MSVD/
  |   ├──./video/					#videos in dataset
  |   ├──./frames/					#keyframes
  |   ├──caption.txt				#annotation
  |   ├──train_list.txt				#train split
  |   ├──test_list.txt				#test split
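Before running the preparation scripts, it can help to confirm that the files are where the layout above expects them. The sketch below is an illustrative helper, not part of the repository: `EXPECTED_LAYOUT` and `missing_entries` are hypothetical names, and the expected entries are simply transcribed from the tree above.

```python
from pathlib import Path

# Hypothetical checklist transcribed from the data layout above.
# Entries ending in "/" must be directories; all others must be files.
EXPECTED_LAYOUT = {
    "COCO": ["image/", "captions_val2014.json", "coco_test.txt"],
    "Flickr": ["image/", "dataset_flickr30k.json"],
    "MSRVTT": ["video/", "frames/", "train_val_videodatainfo.json"],
    "MSVD": ["video/", "frames/", "caption.txt", "train_list.txt", "test_list.txt"],
}


def missing_entries(data_root: str, dataset: str) -> list:
    """Return the expected paths that are absent for one dataset."""
    base = Path(data_root) / dataset
    missing = []
    for entry in EXPECTED_LAYOUT[dataset]:
        path = base / entry.rstrip("/")
        present = path.is_dir() if entry.endswith("/") else path.is_file()
        if not present:
            missing.append(str(path))
    return missing


if __name__ == "__main__":
    for name in EXPECTED_LAYOUT:
        gaps = missing_entries("./data", name)
        print(f"{name}: {'OK' if not gaps else 'missing ' + ', '.join(gaps)}")
```

Running it from the repository root prints one line per dataset, so a missing annotation file shows up before the preparation script fails on it.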

After preparing the data, run the following command to generate the data files required by the training scripts:

python data_prepare_{dataset name}.py

dataset name = {coco, flickr, msrvtt, msvd}

3. Run

Image Captioning

python run_image_captioning.py --dataset {dataset name}

dataset name = {coco, flickr}

Video Captioning

python run_video_captioning.py --dataset {dataset name}

dataset name = {msrvtt, msvd}

The default save path for checkpoints is ./checkpoint/{dataset name}, and the default save path for caption files is ./output/{dataset name}, where dataset name = {coco, flickr, msrvtt, msvd}.
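The preparation and run steps can be chained per dataset. The sketch below is a hypothetical driver (`RUN_SCRIPT`, `pipeline_commands` and `run_pipeline` are illustrative names, not repository code) that only uses the script names and dataset names given above. Note that the README writes the output folders as ./checkpoint/{dataset name}, while the demo and evaluation examples use capitalized folders such as ./checkpoint/COCO/, so check which spelling the scripts actually create.

```python
import subprocess
import sys

# Which captioning script handles each dataset, per the Run section.
RUN_SCRIPT = {
    "coco": "run_image_captioning.py",
    "flickr": "run_image_captioning.py",
    "msrvtt": "run_video_captioning.py",
    "msvd": "run_video_captioning.py",
}


def pipeline_commands(dataset: str) -> list:
    """Build the data-preparation and captioning commands for one dataset."""
    if dataset not in RUN_SCRIPT:
        raise ValueError(f"unknown dataset {dataset!r}; expected one of {sorted(RUN_SCRIPT)}")
    return [
        [sys.executable, f"data_prepare_{dataset}.py"],
        [sys.executable, RUN_SCRIPT[dataset], "--dataset", dataset],
    ]


def run_pipeline(dataset: str) -> None:
    for cmd in pipeline_commands(dataset):
        # check=True stops the pipeline as soon as one step fails.
        subprocess.run(cmd, check=True)
    # Default locations as stated in the README (folder case may differ).
    print(f"checkpoints: ./checkpoint/{dataset}, captions: ./output/{dataset}")
```

For example, `run_pipeline("msvd")` would run `data_prepare_msvd.py` and then `run_video_captioning.py --dataset msvd`, while an unsupported name fails early with a `ValueError` instead of inside a script.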

4. Evaluation

We provide the reference captions and the results reported in the paper under ./output/{dataset name}/.

For example:

python evalution.py \
--ref ./output/COCO/reference_COCO.json \
--gts ./output/COCO/result_COCO.json
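Caption metrics are usually computed on two dictionaries keyed by image id: references as a list of strings, and one generated caption per id. The sketch below shows that grouping step under a stated assumption: it treats both JSON files as COCO-results-style lists of `{"image_id": ..., "caption": ...}` records, which this README does not document. `load_captions` and `align` are hypothetical helper names.

```python
import json
from collections import defaultdict


def load_captions(path: str) -> dict:
    """Group {"image_id": ..., "caption": ...} records by image id.

    ASSUMPTION: the record layout mirrors the COCO results format; the real
    schema of reference_COCO.json / result_COCO.json may differ.
    """
    with open(path, "r", encoding="utf-8") as f:
        records = json.load(f)
    grouped = defaultdict(list)
    for record in records:
        grouped[record["image_id"]].append(record["caption"])
    return dict(grouped)


def align(gts: dict, res: dict) -> tuple:
    """Keep ids present in both dicts, with a single hypothesis per id."""
    shared = sorted(gts.keys() & res.keys())
    return {i: gts[i] for i in shared}, {i: res[i][:1] for i in shared}
```

Dictionaries of this shape are what the pycocoevalcap scorers accept, e.g. `Bleu(4).compute_score(gts, res)` from `pycocoevalcap.bleu.bleu`; pycocoevalcap appears in the environment listed in the issues below.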

5. Demo

Obtain the checkpoints by following the steps above and place them in ./checkpoint/COCO/ as follows:

./checkpoint/
  ├──./COCO/
  |   ├──decoder_coco.pth
  |   ├──map_coco.pth	

Then run demo.ipynb.

knight's People

Contributors

junyangwang0410


knight's Issues

python run_image_captioning.py --dataset flickr

Hello, thank you for sharing code.

I get the following error when using Flickr30k on Colab:

!python run_image_captioning.py --dataset flickr
FileNotFoundError: [Errno 2] No such file or directory: './feature/Flickr/nibers.npy'

About running run_image_captioning.py --dataset coco

Thanks for the amazing work.
When I run
python run_image_captioning.py --dataset coco
the following error occurs:
Traceback (most recent call last):
  File "run_image_captioning.py", line 148, in <module>
    main(args)
  File "run_image_captioning.py", line 97, in main
    output = GPT_model(**token, labels = token["input_ids"], prefix = batch_caption_feature)
  File "/home/boyang/anaconda3/envs/knight/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'prefix'
https://github.com/junyangwang0410/Knight/blob/e03dc2e340abcf418aba711acc300946145a0b08/run_image_captioning.py#LL97C25-L97C25

My conda environment is as follows:

Name Version Build Channel

_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
ca-certificates 2023.01.10 h06a4308_0
certifi 2022.12.7 py38h06a4308_0
charset-normalizer 3.1.0 pypi_0 pypi
click 8.1.3 pypi_0 pypi
clip 1.0 pypi_0 pypi
cmake 3.26.0 pypi_0 pypi
colorlog 6.7.0 pypi_0 pypi
contourpy 1.0.7 pypi_0 pypi
cycler 0.11.0 pypi_0 pypi
filelock 3.10.0 pypi_0 pypi
fonttools 4.39.3 pypi_0 pypi
ftfy 6.1.1 pypi_0 pypi
huggingface-hub 0.13.2 pypi_0 pypi
idna 3.4 pypi_0 pypi
importlib-resources 5.12.0 pypi_0 pypi
jinja2 3.1.2 pypi_0 pypi
joblib 1.2.0 pypi_0 pypi
kiwisolver 1.4.4 pypi_0 pypi
ld_impl_linux-64 2.38 h1181459_1
libffi 3.4.2 h6a678d5_6
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libstdcxx-ng 11.2.0 h1234567_1
lit 15.0.7 pypi_0 pypi
markupsafe 2.1.2 pypi_0 pypi
matplotlib 3.7.1 pypi_0 pypi
mpmath 1.3.0 pypi_0 pypi
ncurses 6.4 h6a678d5_0
networkx 3.0 pypi_0 pypi
nltk 3.8.1 pypi_0 pypi
numpy 1.24.2 pypi_0 pypi
nvidia-cublas-cu11 11.10.3.66 pypi_0 pypi
nvidia-cuda-cupti-cu11 11.7.101 pypi_0 pypi
nvidia-cuda-nvrtc-cu11 11.7.99 pypi_0 pypi
nvidia-cuda-runtime-cu11 11.7.99 pypi_0 pypi
nvidia-cudnn-cu11 8.5.0.96 pypi_0 pypi
nvidia-cufft-cu11 10.9.0.58 pypi_0 pypi
nvidia-curand-cu11 10.2.10.91 pypi_0 pypi
nvidia-cusolver-cu11 11.4.0.1 pypi_0 pypi
nvidia-cusparse-cu11 11.7.4.91 pypi_0 pypi
nvidia-nccl-cu11 2.14.3 pypi_0 pypi
nvidia-nvtx-cu11 11.7.91 pypi_0 pypi
openssl 1.1.1t h7f8727e_0
packaging 23.0 pypi_0 pypi
pillow 9.4.0 pypi_0 pypi
pip 23.0.1 py38h06a4308_0
pycocoevalcap 1.2 pypi_0 pypi
pycocotools 2.0.6 pypi_0 pypi
pyparsing 3.0.9 pypi_0 pypi
python 3.8.16 h7a1cb2a_3
python-dateutil 2.8.2 pypi_0 pypi
pyyaml 6.0 pypi_0 pypi
readline 8.2 h5eee18b_0
regex 2023.5.5 pypi_0 pypi
requests 2.28.2 pypi_0 pypi
rouge 1.0.1 pypi_0 pypi
scikit-learn 1.2.2 pypi_0 pypi
scipy 1.10.1 pypi_0 pypi
setuptools 65.6.3 py38h06a4308_0
six 1.16.0 pypi_0 pypi
sqlite 3.41.1 h5eee18b_0
sympy 1.11.1 pypi_0 pypi
threadpoolctl 3.1.0 pypi_0 pypi
tk 8.6.12 h1ccaba5_0
tokenizers 0.12.1 pypi_0 pypi
torch 2.0.0 pypi_0 pypi
torchvision 0.15.1 pypi_0 pypi
tqdm 4.65.0 pypi_0 pypi
transformers 4.21.2 pypi_0 pypi
triton 2.0.0 pypi_0 pypi
typing-extensions 4.5.0 pypi_0 pypi
urllib3 1.26.15 pypi_0 pypi
wcwidth 0.2.6 pypi_0 pypi
wheel 0.38.4 py38h06a4308_0
xz 5.2.10 h5eee18b_1
zipp 3.15.0 pypi_0 pypi
zlib 1.2.13 h5eee18b_0
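The TypeError means the `forward()` of the installed GPT-2 model has no `prefix` parameter. The stock Hugging Face GPT-2 classes do not define one; the last issue on this page quotes a `modeling_gpt2.py` whose `forward()` does, which suggests the code expects a patched GPT-2. A quick way to check which version is being imported is to inspect the signature. The sketch below uses the standard-library `inspect` module; `declares_param` and the two stand-in functions are hypothetical illustrations.

```python
import inspect


def declares_param(fn, name: str) -> bool:
    """Return True if `fn` explicitly declares a parameter called `name`."""
    return name in inspect.signature(fn).parameters


# Stand-ins for a stock forward() and a patched one that accepts `prefix`.
def stock_forward(input_ids=None, labels=None):
    return input_ids


def patched_forward(input_ids=None, labels=None, prefix=None):
    return input_ids


print(declares_param(stock_forward, "prefix"))
print(declares_param(patched_forward, "prefix"))
```

Applied to the real class, `declares_param(GPT2LMHeadModel.forward, "prefix")` (after `from transformers import GPT2LMHeadModel`) should return False on an unmodified transformers install, which is consistent with the traceback above.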

What does coco_test.txt in the dataset contain? I didn't see this file in the COCO 2014 image captioning dataset I downloaded from the official website.

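    import json

    import torch
    from PIL import Image

    # Hypothetical wrapper: the enclosing function is not shown in the original
    # snippet, so the CLIP model, its preprocess transform and the device are
    # passed in (e.g. the model and transform returned by clip.load).
    def extract_coco_test_features(clip_model, preprocess, device):
        json_path = "./data/COCO/captions_val2014.json"
        with open(json_path, "r") as f:
            json_labels = json.load(f)
        annotations = json_labels["annotations"]
        images = json_labels["images"]
        images_path = "./data/COCO/image/"

        # Map each image file name to its COCO image id.
        image_dict = dict()
        for image in images:
            image_dict[image["file_name"]] = image["id"]

        # coco_test.txt lists one image file name per line.
        with open("./data/COCO/coco_test.txt") as image_names_data:
            image_names = image_names_data.readlines()

        image_features = []
        with torch.no_grad():
            for image_info in image_names:
                image_file = image_info.strip()
                image_id = image_dict[image_file]
                image_path = images_path + image_file
                ori_image = Image.open(image_path)
                image = preprocess(ori_image).unsqueeze(0).to(device)
                # Encode each test image with CLIP's image encoder.
                image_feature = clip_model.encode_image(image)
                image_features.append(image_feature)

        image_features = torch.cat(image_features)
        torch.save(image_features, "./feature/COCO/image_features.pkl")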

Since I don't know what coco_test.txt is, I couldn't follow this code. If it only needs to read the images, shouldn't it be enough to join file_name with the image folder name?
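According to the data layout in the README, coco_test.txt holds the Karpathy test split, and the snippet reads it as one file name per line. The Karpathy test split is a 5,000-image subset of val2014, so the file restricts feature extraction to those images instead of every image in the folder; joining the folder with every file_name would encode the whole val2014 set. The selection step can be sketched as follows (`select_test_images` is a hypothetical name; the string concatenation mirrors the snippet).

```python
def select_test_images(annotation: dict, test_lines: list, image_dir: str) -> list:
    """Map test-split file names to (image_id, path) pairs.

    `annotation` is the parsed captions_val2014.json; `test_lines` are the raw
    lines of coco_test.txt, assumed to hold one file name per line.
    """
    # file_name -> COCO image id, built from the "images" list.
    id_by_name = {img["file_name"]: img["id"] for img in annotation["images"]}
    selected = []
    for line in test_lines:
        name = line.strip()
        if not name:
            continue  # tolerate a trailing blank line
        selected.append((id_by_name[name], image_dir + name))
    return selected
```

The order of coco_test.txt is preserved, so the saved feature tensor rows line up with the file's lines.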

Question about the prefix argument of the generate() method

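from transformers import GPT2LMHeadModel

def caption_generation(image_feature, model: GPT2LMHeadModel, tokenizer, device):
    text = "prefix prefix prefix prefix prefix:"  # 5 placeholder words + ":"
    inputs = tokenizer(text, return_tensors="pt")
    # `prefix` needs the modified GPT-2 forward() quoted below
    output = model.generate(inputs["input_ids"].to(device), max_length=40,
                            prefix=image_feature, do_sample=False, num_beams=5)[0]
    output = tokenizer.decode(output)
    # keep the text between ":" and the first ".", lowercased
    return output.split(':')[1].split('.')[0].lower()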

The model.generate() call in the code above uses a prefix argument, but I couldn't find any explanation of it in the Hugging Face documentation.

In modeling_gpt2.py, I found the following code:

def forward(
        ...
        prefix: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        ...

and:

...
if inputs_embeds is None:
    inputs_embeds = self.wte(input_ids)
if prefix != None:
    prefix = prefix.expand(inputs_embeds.shape[0], 5, inputs_embeds.shape[2])
    inputs_embeds = torch.cat((prefix, inputs_embeds[:, 5:, :]), dim = 1)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
...

So this addition must be the author's own modification, right? Looking forward to your reply.
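The quoted forward() fragment shows how the prefix is used: the placeholder text "prefix prefix prefix prefix prefix:" reserves the first five positions (assuming each "prefix" word maps to one GPT-2 token, which the hard-coded 5 implies), and the patched code overwrites those five token embeddings with the image-derived prefix, broadcast across the batch, before the position embeddings are added. The shape logic is re-expressed below in NumPy, with `np.broadcast_to` standing in for `Tensor.expand` so it runs without torch; `splice_prefix` and `N_PREFIX` are illustrative names.

```python
import numpy as np

N_PREFIX = 5  # matches the hard-coded 5 in the quoted forward()


def splice_prefix(prefix: np.ndarray, inputs_embeds: np.ndarray) -> np.ndarray:
    """Replace the first N_PREFIX token embeddings with the prefix.

    prefix:        broadcastable to (batch, N_PREFIX, hidden), e.g. (1, 5, hidden)
    inputs_embeds: (batch, seq_len, hidden)
    """
    batch, _, hidden = inputs_embeds.shape
    # Like Tensor.expand: reuse one prefix for every item in the batch.
    prefix = np.broadcast_to(prefix, (batch, N_PREFIX, hidden))
    return np.concatenate((prefix, inputs_embeds[:, N_PREFIX:, :]), axis=1)


emb = np.zeros((2, 6, 4))  # 5 placeholder positions + ":" for a batch of 2
out = splice_prefix(np.ones((1, 5, 4)), emb)
print(out.shape)
```

The sequence length is unchanged: the placeholders only reserve slots, and the text after them (here ":") keeps its original embedding.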
