How to make a test? (issue in adaptive, open)

inkzk commented on August 21, 2024
How to make a test?

from adaptive.

Comments (4)

yufengm commented on August 21, 2024

Please refer to the call

cider = coco_eval(adaptive, args, epoch)

to run a test on the validation dataset.
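
For readers who want to see what such a validation pass roughly involves, here is a minimal sketch. What `coco_eval` does internally is not shown in this thread, so everything below is an assumption for illustration: `ids_to_caption`, `build_results`, and the `generate_ids` callable are hypothetical names, and the fake model stands in for the real sampler. The sketch turns sampled word ids into a caption and collects one `{"image_id", "caption"}` entry per image, which is the layout COCO caption scoring tools expect for a results file.

```python
import json


def ids_to_caption(word_ids, idx2word, end_token='<end>', start_token='<start>'):
    """Join word ids into a sentence, stopping at the end token."""
    words = []
    for word_id in word_ids:
        word = idx2word[int(word_id)]
        if word == end_token:
            break
        if word != start_token:
            words.append(word)
    return ' '.join(words)


def build_results(samples, generate_ids, idx2word):
    """Build COCO-style caption results from (image_id, image) pairs.

    generate_ids is a hypothetical callable: image -> list of word ids.
    """
    results = []
    for image_id, image in samples:
        word_ids = generate_ids(image)
        results.append({'image_id': image_id,
                        'caption': ids_to_caption(word_ids, idx2word)})
    return results


# Toy vocabulary and a fake "model" that always returns the same ids.
idx2word = {0: '<start>', 1: '<end>', 2: 'a', 3: 'dog', 4: 'runs'}
fake_model = lambda image: [0, 2, 3, 4, 1, 1]

results = build_results([(42, None)], fake_model, idx2word)
print(json.dumps(results))  # [{"image_id": 42, "caption": "a dog runs"}]
```

Unlike the script posted later in this thread, the sketch drops the `<end>` token from the caption, since scoring tools compare plain sentences.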

17000432 commented on August 21, 2024

Can you share the test code? I tried to write the test section myself, but ran into some problems.

# -*- coding: utf-8 -*-
import torch
import matplotlib.pyplot as plt
import numpy as np
import argparse
import pickle
import os
from torchvision import transforms
from build_vocab import Vocabulary
from adaptive import AttentiveCNN ,Decoder ,Encoder2Decoder
from PIL import Image

from torch.autograd import Variable

# Device configuration

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = torch.cuda.set_device(0)
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use the second GPU

def load_image(image_path, transform=None):
    image = Image.open(image_path)
    image = image.resize([224, 224], Image.LANCZOS)
    #image = Variable(image)
    if transform is not None:
        image = transform(image).unsqueeze(0)
    return image

def main(args):
    # Image preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build models
    encoder = AttentiveCNN(args.embed_size, args.hidden_size).eval()  # eval mode (batchnorm uses moving mean/variance)

    encoder2decoder = Encoder2Decoder(args.embed_size, len(vocab), args.hidden_size)
    # Load the trained model parameters
    encoder2decoder.load_state_dict(torch.load(args.encoder2decoder_path))

    # Prepare an image
    image = load_image(args.image, transform)

    # Generate a caption from the image
    sampled_ids, attention, Beta = encoder2decoder.sampler(image)
    sampled_ids = Variable(torch.LongTensor(sampled_ids))
    sampled_ids = sampled_ids[0].cpu().numpy()  # (1, max_seq_length) -> (max_seq_length)
    # Convert word_ids to words
    sampled_caption = []
    for word_id in sampled_ids:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<end>':
            break
    sentence = ' '.join(sampled_caption)

    # Print out the image and the generated caption
    print(sentence)
    image = Image.open(args.image)
    plt.imshow(np.asarray(image))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image', type=str, required=True, help='input image for generating caption')
    parser.add_argument('--encoder2decoder_path', type=str, default='models/adaptive-1.pkl',
                        help='path for trained encoder2decoder')
    parser.add_argument('--vocab_path', type=str, default='data/vocab.pkl', help='path for vocabulary wrapper')

    # Model parameters (should be same as parameters in train.py)
    parser.add_argument('--embed_size', type=int, default=256, help='dimension of word embedding vectors')
    parser.add_argument('--hidden_size', type=int, default=512, help='dimension of lstm hidden states')
    parser.add_argument('--num_layers', type=int, default=1, help='number of layers in lstm')
    args = parser.parse_args()
    main(args)

RuntimeError: Expected object of type torch.LongTensor but found type torch.cuda.LongTensor for argument #3 'index'

please help!!!
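
A note on the error above: a message of this form usually means that an index tensor sits on the GPU while the weights it indexes (for example an embedding table) are still on the CPU. In the script, the model is never moved to a device, so if the repository's `sampler` moves its own tensors to CUDA whenever a GPU is available (an assumption, since the method body is not shown in this thread), the two would end up on different devices. The sketch below is not the repository's code; it uses a small `nn.Embedding` as a stand-in model to show the usual pattern of loading a checkpoint with `map_location` and keeping weights and inputs on one device.

```python
import io

import torch
import torch.nn as nn

# Pick one device and use it for the weights and for every input tensor.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stand-in for the captioning model: an embedding lookup, the kind of layer
# that complains about its 'index' argument on a device mismatch.
model = nn.Embedding(num_embeddings=10, embedding_dim=4)

# Simulate a saved checkpoint, then load it directly onto the chosen device.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)
state = torch.load(buffer, map_location=device)
model.load_state_dict(state)
model = model.to(device).eval()

# Inputs must live on the same device as the weights.
word_ids = torch.tensor([[1, 2, 3]], dtype=torch.long).to(device)
with torch.no_grad():
    embedded = model(word_ids)

# Move back to the CPU before converting to NumPy.
result = embedded.cpu().numpy()
print(result.shape)  # (1, 3, 4)
```

Applied to the script above, the same idea would mean calling `.to(device)` on `encoder2decoder` and on the tensor returned by `load_image` before sampling. Whether that is sufficient depends on how `sampler` places its own tensors.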

17000432 commented on August 21, 2024

@inkzk

Did you later write the test module to generate image captions? Could you share it with me?

Allenxq commented on August 21, 2024

@inkzk

Did you later write the test module to generate image captions? Could you share it with me?

Did you solve it?
