Please refer to line 152 in commit 4c0555a for testing on the validation dataset.
Can you share the test code? I tried to write the test part myself, but ran into some problems:
# -*- coding: utf-8 -*-
import torch
import matplotlib.pyplot as plt
import numpy as np
import argparse
import pickle
import os
from torchvision import transforms
from build_vocab import Vocabulary
from adaptive import AttentiveCNN, Decoder, Encoder2Decoder
from PIL import Image
from torch.autograd import Variable

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.cuda.set_device(0)
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use the second GPU

def load_image(image_path, transform=None):
    image = Image.open(image_path)
    image = image.resize([224, 224], Image.LANCZOS)
    # image = Variable(image)
    if transform is not None:
        image = transform(image).unsqueeze(0)
    return image

def main(args):
    # Image preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build models
    encoder = AttentiveCNN(args.embed_size, args.hidden_size).eval()  # eval mode (batchnorm uses moving mean/variance)
    encoder2decoder = Encoder2Decoder(args.embed_size, len(vocab), args.hidden_size)

    # Load the trained model parameters
    encoder2decoder.load_state_dict(torch.load(args.encoder2decoder_path))

    # Prepare an image
    image = load_image(args.image, transform)

    # Generate a caption from the image
    sampled_ids, attention, Beta = encoder2decoder.sampler(image)
    sampled_ids = Variable(torch.LongTensor(sampled_ids))
    sampled_ids = sampled_ids[0].cpu().numpy()  # (1, max_seq_length) -> (max_seq_length)

    # Convert word_ids to words
    sampled_caption = []
    for word_id in sampled_ids:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<end>':
            break
    sentence = ' '.join(sampled_caption)

    # Print out the image and the generated caption
    print(sentence)
    image = Image.open(args.image)
    plt.imshow(np.asarray(image))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image', type=str, required=True, help='input image for generating caption')
    parser.add_argument('--encoder2decoder_path', type=str, default='models/adaptive-1.pkl',
                        help='path for trained encoder2decoder')
    parser.add_argument('--vocab_path', type=str, default='data/vocab.pkl', help='path for vocabulary wrapper')
    # Model parameters (should be the same as the parameters in train.py)
    parser.add_argument('--embed_size', type=int, default=256, help='dimension of word embedding vectors')
    parser.add_argument('--hidden_size', type=int, default=512, help='dimension of lstm hidden states')
    parser.add_argument('--num_layers', type=int, default=1, help='number of layers in lstm')
    args = parser.parse_args()
    main(args)
RuntimeError: Expected object of type torch.LongTensor but found type torch.cuda.LongTensor for argument #3 'index'
Please help!
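A likely cause, offered as an unverified guess: the script defines `device` but never uses it, so `encoder2decoder` and its embedding weights stay on the CPU. If `sampler` builds its start-token index tensor on the GPU whenever CUDA is available (an assumption about `adaptive.py`, not checked here), the embedding lookup receives a `torch.cuda.LongTensor` index for a CPU weight, which is consistent with the message above. The sketch below shows the device-consistent pattern with a hypothetical stand-in model (`TinyCaptioner`) so it runs on any machine; in the real script the corresponding changes would be `torch.load(args.encoder2decoder_path, map_location=device)`, `encoder2decoder = encoder2decoder.to(device).eval()` and `image = image.to(device)`.

```python
import torch
import torch.nn as nn

# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class TinyCaptioner(nn.Module):
    """Hypothetical stand-in for Encoder2Decoder: only an embedding lookup."""

    def __init__(self, vocab_size=10, embed_size=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)

    def forward(self, captions):
        # nn.Embedding indexes its weight with `captions`; the weight and the
        # index tensor must live on the same device, otherwise PyTorch raises
        # a RuntimeError about mismatched tensor types/devices.
        return self.embed(captions)


model = TinyCaptioner()
# Real script (assumption): load weights with map_location=device first.
model = model.to(device).eval()  # move the weights, switch to eval mode

# Start token id assumed to be 1, shaped (batch=1, seq=1) and created
# directly on the same device as the model.
start = torch.ones(1, 1, dtype=torch.long, device=device)

with torch.no_grad():  # no gradients are needed at test time
    out = model(start)

print(out.shape, out.device)
```

Routing every tensor through one `device` object avoids hard-coded `.cuda()` calls. Two further details in the posted script may matter: `.eval()` is applied to the unused `encoder` rather than to `encoder2decoder`, and `CUDA_VISIBLE_DEVICES` may be ignored if it is set after PyTorch has already queried CUDA, so setting it before the `torch.device(...)` line is the safer order.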
Did you eventually write the test module for generating image captions? Could you share it with me?
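Any test module has to feed the model images preprocessed the same way as during training (assumed here to match what the script in this thread does: resize to 224×224, then `transforms.ToTensor()` and `transforms.Normalize` with the ImageNet mean/std). For a single 8-bit RGB pixel, that arithmetic can be sketched in plain Python; `normalize_pixel` is a hypothetical helper for illustration only.

```python
# ImageNet channel statistics used in the script's transforms.Normalize call.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Mimic ToTensor + Normalize for one 8-bit RGB pixel.

    ToTensor scales 0..255 to 0.0..1.0; Normalize then computes
    (value - mean) / std separately for each channel.
    """
    return tuple((c / 255.0 - m) / s for c, m, s in zip(rgb, MEAN, STD))


print(normalize_pixel((255, 255, 255)))  # roughly (2.249, 2.429, 2.640)
print(normalize_pixel((0, 0, 0)))        # roughly (-2.118, -2.036, -1.804)
```

If the test-time statistics differ from the training ones, the model sees inputs on a shifted scale, which can degrade captions without raising any error.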
Did you eventually write the test module for generating image captions? Could you share it with me?
Did you solve it?
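For the last step of such a test module, the word-id loop in the posted script keeps the `<start>` and `<end>` markers in the printed sentence. A minimal sketch of a cleaner conversion follows; the helper name `ids_to_caption` and the toy vocabulary are hypothetical, and the `<start>` token is an assumption (the script itself only shows `<end>`).

```python
def ids_to_caption(word_ids, idx2word, start='<start>', end='<end>'):
    """Map word ids to words, stop at the end token, drop start/end markers."""
    words = []
    for word_id in word_ids:
        word = idx2word[int(word_id)]  # int() also accepts NumPy integers
        if word == end:
            break
        if word != start:
            words.append(word)
    return ' '.join(words)


# Hypothetical toy vocabulary; the real one would come from data/vocab.pkl.
idx2word = {0: '<pad>', 1: '<start>', 2: '<end>', 3: 'a', 4: 'dog', 5: 'runs'}
print(ids_to_caption([1, 3, 4, 5, 2, 0, 0], idx2word))  # a dog runs
```

In the script, `ids_to_caption(sampled_ids, vocab.idx2word)` would replace the loop and the `' '.join(...)` call, assuming `vocab.idx2word` is a dict keyed by integer ids.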