
knight's Issues

Question about the prefix parameter in the generate() method

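from transformers import GPT2LMHeadModel

def caption_generation(image_feature, model: GPT2LMHeadModel, tokenizer, device):
    # Prompt of five "prefix" words; the patched forward() quoted below overwrites the first five token embeddings
    text = "prefix prefix prefix prefix prefix:"
    inputs = tokenizer(text, return_tensors="pt")
    # Beam search (5 beams, no sampling), max_length=40; `prefix` is the undocumented argument in question
    output = model.generate(inputs["input_ids"].to(device), max_length=40, prefix=image_feature, do_sample=False, num_beams=5)[0]
    output = tokenizer.decode(output)
    # Keep the text between the prompt's ':' and the first '.'
    return output.split(':')[1].split('.')[0].lower()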
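The last line of caption_generation() recovers the caption from the decoded string: the decoded output still begins with the prompt, so everything after the first ':' and before the first '.' is kept. A minimal sketch of that step on a plain string follows; the function name extract_caption and the sample text are made up for illustration, and the extra .strip() (not in the original) removes the leading space that split(':') leaves behind.

```python
def extract_caption(decoded: str) -> str:
    # Same steps as the original: text after the first ':' up to the first '.', lowercased.
    caption = decoded.split(':')[1].split('.')[0].lower()
    # Hypothetical addition: drop the space that follows the ':' in the prompt.
    return caption.strip()


sample = "prefix prefix prefix prefix prefix: A man riding a horse. A dog"
print(repr(extract_caption(sample)))  # 'a man riding a horse'
```

Without the .strip(), the original expression would return ' a man riding a horse' with a leading space, which may matter when captions are compared as strings.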

The code above passes a prefix argument to model.generate(), but I could not find any explanation of a prefix parameter in the Hugging Face documentation.

In modeling_gpt2.py, I found the following code:

def forward(
    ...
    prefix: Optional[torch.FloatTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
    ...

and:

...
if inputs_embeds is None:
    inputs_embeds = self.wte(input_ids)
if prefix is not None:
    # Broadcast the prefix to (batch, 5, hidden) and overwrite the first 5 token embeddings
    prefix = prefix.expand(inputs_embeds.shape[0], 5, inputs_embeds.shape[2])
    inputs_embeds = torch.cat((prefix, inputs_embeds[:, 5:, :]), dim = 1)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
...
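The patched lines overwrite the embeddings of the first five positions (presumably the five "prefix" placeholder tokens in the prompt) with the image feature, leaving the sequence length unchanged. Because expand() only broadcasts, a feature of shape (1, hidden) ends up copied into all five positions. The standalone sketch below reproduces just that splice on random tensors; the function name splice_prefix and the tensor sizes are hypothetical, and in the real model the feature width would have to equal the GPT-2 embedding width.

```python
import torch


def splice_prefix(inputs_embeds: torch.Tensor, prefix: torch.Tensor, n_prefix: int = 5) -> torch.Tensor:
    """Replace the first n_prefix token embeddings with `prefix`, mirroring the patched forward()."""
    batch, _, dim = inputs_embeds.shape
    # expand() broadcasts without copying: (1, dim) -> (batch, n_prefix, dim)
    prefix = prefix.expand(batch, n_prefix, dim)
    return torch.cat((prefix, inputs_embeds[:, n_prefix:, :]), dim=1)


embeds = torch.randn(2, 8, 16)   # (batch, seq_len, hidden); sizes are arbitrary
feature = torch.randn(1, 16)     # one feature already sized to the hidden width (assumption)
out = splice_prefix(embeds, feature)
print(out.shape)                 # torch.Size([2, 8, 16])
```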

Was this part added by the author as a modification? Looking forward to your reply.

What does coco_test.txt in the dataset contain? I don't see it in the COCO 2014 image captioning dataset I downloaded from the official website.

    # Load the COCO val2014 caption annotations
    json_path = "./data/COCO/captions_val2014.json"
    with open(json_path, 'r') as f:
        json_labels = json.load(f)
    annotations = json_labels["annotations"]
    images = json_labels["images"]
    images_path = "./data/COCO/image/"

    # Map each image file name to its COCO image id
    image_dict = dict()
    for image in images:
        image_dict[image["file_name"]] = image["id"]

    # Each line of coco_test.txt is read as one image file name
    with open("./data/COCO/coco_test.txt") as image_names_data:
        image_names = image_names_data.readlines()

    # Encode every listed image with CLIP
    image_features = []
    for image_info in image_names:
        image_file = image_info.strip()
        image_id = image_dict[image_file]
        image_path = images_path + image_file
        ori_image = Image.open(image_path)
        image = preprocess(ori_image).unsqueeze(0).to(device)
        image_feature = clip_model.encode_image(image)
        image_features.append(image_feature)

    image_features = torch.cat(image_features)
    torch.save(image_features, "./feature/COCO/image_features.pkl")

Since I don't understand what coco_test.txt is, I can't follow this code. If it is only reading the images, shouldn't joining file_name with image_folder_name be enough?
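Judging only from the code above, coco_test.txt appears to be a list of image file names, one per line, that selects which val2014 images get encoded; each name is also looked up in captions_val2014.json to obtain its image id. Joining the folder and the file name is indeed all that is needed to open an image. The sketch below shows that lookup on its own with the standard library; load_image_list, the file names, and the tiny sample files are hypothetical.

```python
import json
import os
import tempfile


def load_image_list(json_path, list_path, image_dir):
    """Return (image_id, image_path) pairs for the file names listed in list_path."""
    with open(json_path) as f:
        images = json.load(f)["images"]
    # Map file_name -> image id, as image_dict does in the original code.
    id_by_name = {img["file_name"]: img["id"] for img in images}
    pairs = []
    with open(list_path) as f:
        for line in f:
            name = line.strip()
            if name:  # skip blank lines, e.g. a trailing empty line
                pairs.append((id_by_name[name], os.path.join(image_dir, name)))
    return pairs


# Tiny made-up inputs written to a temporary directory.
tmp = tempfile.mkdtemp()
with open(os.path.join(tmp, "captions.json"), "w") as f:
    json.dump({"images": [{"file_name": "example_0001.jpg", "id": 42}]}, f)
with open(os.path.join(tmp, "test.txt"), "w") as f:
    f.write("example_0001.jpg\n")

print(load_image_list(os.path.join(tmp, "captions.json"), os.path.join(tmp, "test.txt"), "images"))
```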

python run_image_captioning.py --dataset flickr

Hello, thank you for sharing the code.

I get the following error when using Flickr30k on Colab:

!python run_image_captioning.py --dataset flickr
FileNotFoundError: [Errno 2] No such file or directory: './feature/Flickr/nibers.npy'

About running run_image_captioning.py --dataset coco

Thanks for the amazing work.
When I run
python run_image_captioning.py --dataset coco
the following error occurs:
Traceback (most recent call last):
  File "run_image_captioning.py", line 148, in <module>
    main(args)
  File "run_image_captioning.py", line 97, in main
    output = GPT_model(**token, labels = token["input_ids"], prefix = batch_caption_feature)
  File "/home/boyang/anaconda3/envs/knight/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'prefix'
https://github.com/junyangwang0410/Knight/blob/e03dc2e340abcf418aba711acc300946145a0b08/run_image_captioning.py#LL97C25-L97C25
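The TypeError means that the GPT2LMHeadModel.forward() actually being called declares no prefix parameter. One way to check this directly is to inspect its signature. The helper accepts_kwarg below is a hypothetical sketch built on the standard inspect module; the transformers check only runs if the package can be imported.

```python
import inspect


def accepts_kwarg(fn, name: str) -> bool:
    """Return True if `fn` declares a parameter called `name`."""
    return name in inspect.signature(fn).parameters


try:
    from transformers import GPT2LMHeadModel
    # On an unpatched install this is expected to print False (assumption based on the TypeError).
    print("forward() accepts prefix:", accepts_kwarg(GPT2LMHeadModel.forward, "prefix"))
except ImportError:
    print("transformers is not installed")
```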

My conda environment is as follows:

Name                      Version      Build           Channel
_libgcc_mutex             0.1          main
_openmp_mutex             5.1          1_gnu
ca-certificates           2023.01.10   h06a4308_0
certifi                   2022.12.7    py38h06a4308_0
charset-normalizer        3.1.0        pypi_0          pypi
click                     8.1.3        pypi_0          pypi
clip                      1.0          pypi_0          pypi
cmake                     3.26.0       pypi_0          pypi
colorlog                  6.7.0        pypi_0          pypi
contourpy                 1.0.7        pypi_0          pypi
cycler                    0.11.0       pypi_0          pypi
filelock                  3.10.0       pypi_0          pypi
fonttools                 4.39.3       pypi_0          pypi
ftfy                      6.1.1        pypi_0          pypi
huggingface-hub           0.13.2       pypi_0          pypi
idna                      3.4          pypi_0          pypi
importlib-resources       5.12.0       pypi_0          pypi
jinja2                    3.1.2        pypi_0          pypi
joblib                    1.2.0        pypi_0          pypi
kiwisolver                1.4.4        pypi_0          pypi
ld_impl_linux-64          2.38         h1181459_1
libffi                    3.4.2        h6a678d5_6
libgcc-ng                 11.2.0       h1234567_1
libgomp                   11.2.0       h1234567_1
libstdcxx-ng              11.2.0       h1234567_1
lit                       15.0.7       pypi_0          pypi
markupsafe                2.1.2        pypi_0          pypi
matplotlib                3.7.1        pypi_0          pypi
mpmath                    1.3.0        pypi_0          pypi
ncurses                   6.4          h6a678d5_0
networkx                  3.0          pypi_0          pypi
nltk                      3.8.1        pypi_0          pypi
numpy                     1.24.2       pypi_0          pypi
nvidia-cublas-cu11        11.10.3.66   pypi_0          pypi
nvidia-cuda-cupti-cu11    11.7.101     pypi_0          pypi
nvidia-cuda-nvrtc-cu11    11.7.99      pypi_0          pypi
nvidia-cuda-runtime-cu11  11.7.99      pypi_0          pypi
nvidia-cudnn-cu11         8.5.0.96     pypi_0          pypi
nvidia-cufft-cu11         10.9.0.58    pypi_0          pypi
nvidia-curand-cu11        10.2.10.91   pypi_0          pypi
nvidia-cusolver-cu11      11.4.0.1     pypi_0          pypi
nvidia-cusparse-cu11      11.7.4.91    pypi_0          pypi
nvidia-nccl-cu11          2.14.3       pypi_0          pypi
nvidia-nvtx-cu11          11.7.91      pypi_0          pypi
openssl                   1.1.1t       h7f8727e_0
packaging                 23.0         pypi_0          pypi
pillow                    9.4.0        pypi_0          pypi
pip                       23.0.1       py38h06a4308_0
pycocoevalcap             1.2          pypi_0          pypi
pycocotools               2.0.6        pypi_0          pypi
pyparsing                 3.0.9        pypi_0          pypi
python                    3.8.16       h7a1cb2a_3
python-dateutil           2.8.2        pypi_0          pypi
pyyaml                    6.0          pypi_0          pypi
readline                  8.2          h5eee18b_0
regex                     2023.5.5     pypi_0          pypi
requests                  2.28.2       pypi_0          pypi
rouge                     1.0.1        pypi_0          pypi
scikit-learn              1.2.2        pypi_0          pypi
scipy                     1.10.1       pypi_0          pypi
setuptools                65.6.3       py38h06a4308_0
six                       1.16.0       pypi_0          pypi
sqlite                    3.41.1       h5eee18b_0
sympy                     1.11.1       pypi_0          pypi
threadpoolctl             3.1.0        pypi_0          pypi
tk                        8.6.12       h1ccaba5_0
tokenizers                0.12.1       pypi_0          pypi
torch                     2.0.0        pypi_0          pypi
torchvision               0.15.1       pypi_0          pypi
tqdm                      4.65.0       pypi_0          pypi
transformers              4.21.2       pypi_0          pypi
triton                    2.0.0        pypi_0          pypi
typing-extensions         4.5.0        pypi_0          pypi
urllib3                   1.26.15      pypi_0          pypi
wcwidth                   0.2.6        pypi_0          pypi
wheel                     0.38.4       py38h06a4308_0
xz                        5.2.10       h5eee18b_1
zipp                      3.15.0       pypi_0          pypi
zlib                      1.2.13       h5eee18b_0
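The list shows transformers 4.21.2 installed from PyPI, which presumably ships the unmodified modeling_gpt2.py rather than the version with the prefix patch quoted in the first issue. If patching the installed library is not an option, one possible workaround (an assumption, not the repository's own method) is to do the same splice outside the model and pass inputs_embeds, which the stock GPT-2 forward() does accept. The sketch uses a tiny randomly initialised model so it runs offline; the model sizes, the 5-position prefix, and the (batch, 1, hidden) shape of batch_caption_feature are assumptions.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Tiny random model so the sketch runs without downloading weights (sizes are arbitrary).
config = GPT2Config(vocab_size=100, n_positions=32, n_embd=16, n_layer=1, n_head=2)
model = GPT2LMHeadModel(config)

input_ids = torch.randint(0, config.vocab_size, (2, 10))   # (batch, seq_len)
batch_caption_feature = torch.randn(2, 1, config.n_embd)   # assumed: one feature per sample
n_prefix = 5

# Same splice as the patched forward(): overwrite the first n_prefix token embeddings.
inputs_embeds = model.transformer.wte(input_ids)
prefix = batch_caption_feature.expand(-1, n_prefix, -1)
inputs_embeds = torch.cat((prefix, inputs_embeds[:, n_prefix:, :]), dim=1)

# Stock forward() accepts inputs_embeds; labels still come from the original token ids.
output = model(inputs_embeds=inputs_embeds, labels=input_ids)
print(output.loss.item(), output.logits.shape)
```

This only covers the training-style forward call from line 97 of the traceback; generation would need its own handling.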
