Comments (12)
Can you change this in `infer_single_example`?

```python
src_input_ids: np.array(features.src_input_ids).reshape(1, -1),
src_segment_ids: np.array(features.src_segment_ids).reshape(1, -1)
```
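The order of the arguments to `reshape` matters, as this minimal NumPy sketch shows. It assumes the feature list is padded to `max_seq_length_src = 512` (the value in the `config.py` posted later in this thread) and that the input placeholders are laid out as `[batch_size, seq_len]`. Dummy zeros stand in for real token ids. `reshape(1, -1)` keeps one row of 512 ids, which is a batch of 1. `reshape(-1, 1)` produces 512 rows of one id each, which would be read as a batch of 512 examples.

```python
import numpy as np

max_seq_length_src = 512  # value from the config.py posted in this thread

# Stand-in for features.src_input_ids: a flat, padded list of token ids.
src_input_ids = [0] * max_seq_length_src

# Suggested fix: one row of 512 ids, i.e. batch_size = 1.
good = np.array(src_input_ids).reshape(1, -1)
# Original code: 512 rows of one id each, i.e. batch_size = 512.
bad = np.array(src_input_ids).reshape(-1, 1)

print(good.shape)  # (1, 512)
print(bad.shape)   # (512, 1)
```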
from abstractive-summarization-with-transfer-learning.
Have you checked the batch size? You might be using a batch size that is too large to fit in your GPU memory.
No, I haven't changed any batch sizes. By default it should be 1 for inference too, right? I'm using an Azure NVIDIA Tesla M60 GPU with 8 GiB of memory.
I think the model has around 250 million parameters, and I doubt 8 GB can hold it along with the data. Please try with 16 GB of memory.
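Here is a back-of-the-envelope check of the 250-million-parameter estimate. It assumes 32-bit floats (4 bytes per parameter) and counts only the weights themselves. Activations, and during training also gradients and optimizer state, come on top of that. Activation memory grows with batch size and sequence length.

```python
# Rough lower bound on GPU memory needed for the model weights alone.
num_params = 250_000_000  # estimate quoted in the comment above
bytes_per_param = 4       # assumption: float32 weights

weight_bytes = num_params * bytes_per_param
weight_gib = weight_bytes / 2**30

print(f"{weight_bytes:,} bytes = {weight_gib:.2f} GiB")  # 1,000,000,000 bytes = 0.93 GiB
```

By this estimate the weights alone take about 1 GB, so most of an 8 GiB card is left for everything else.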
But I was able to train the model on the same GPU without any issues. I only hit this problem when I try to run inference on the trained model (the last checkpoint).
Please post a link to the inference code you are running. A tensor of size 512,10,50,512 looks wrong; the problem might be the way you are passing in the data.
Check the size of the input tensor: it should be 1×512.
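This worked arithmetic shows why a tensor of shape 512,10,50,512 is suspicious. It assumes float32 elements (4 bytes each). It also assumes the leading 512 is the batch dimension, as the reply suggests, when it should be 1. The meaning of the middle dimensions is not stated in the thread, so the sketch only computes sizes.

```python
import math

# Tensor shape reported in the out-of-memory error.
shape = (512, 10, 50, 512)
bytes_per_elem = 4  # assumption: float32 activations

num_elems = math.prod(shape)
size_mib = num_elems * bytes_per_elem / 2**20

# If the leading 512 is really the batch dimension (it should be 1),
# a single-example batch would need 1/512 of that memory.
size_if_batch_1_mib = size_mib / shape[0]

print(f"{num_elems:,} elements, {size_mib:.1f} MiB")        # 131,072,000 elements, 500.0 MiB
print(f"with batch size 1: {size_if_batch_1_mib:.2f} MiB")  # with batch size 1: 0.98 MiB
```

A single intermediate tensor of this shape takes 500 MiB. A transformer creates many intermediate tensors per layer, so an accidental batch of 512 can plausibly exhaust 8 GiB.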
This is the inference code that I'm running.
```python
from flask import Flask, request, render_template
import requests
import json
from collections import OrderedDict
import os
import numpy as np
import tensorflow as tf

app = Flask(__name__)

import sys
if 'texar_repo' not in sys.path:
    sys.path += ['texar_repo']

from config import *
from model import *
from preprocess import *

start_tokens = tf.fill([tx.utils.get_batch_size(src_input_ids)],
                       bos_token_id)
predictions = decoder(
    memory=encoder_output,
    memory_sequence_length=src_input_length,
    decoding_strategy='infer_greedy',
    beam_width=beam_width,
    alpha=alpha,
    start_tokens=start_tokens,
    end_token=eos_token_id,
    max_decoding_length=300,
    mode=tf.estimator.ModeKeys.PREDICT
)
if beam_width <= 1:
    inferred_ids = predictions[0].sample_id
else:
    # Uses the best sample by beam search
    inferred_ids = predictions['sample_id'][:, :, 0]

tokenizer = tokenization.FullTokenizer(
    vocab_file=os.path.join(bert_pretrain_dir, 'vocab.txt'),
    do_lower_case=True)

sess = tf.Session()


def infer_single_example(story, actual_summary, tokenizer):
    example = {"src_txt": story,
               "tgt_txt": actual_summary
               }
    features = convert_single_example(1, example, max_seq_length_src, max_seq_length_tgt, tokenizer)
    # NOTE: reshape(-1, 1) gives shape (max_seq_length_src, 1), i.e. 512 rows
    # of one token each, instead of the expected 1 x 512.
    feed_dict = {
        src_input_ids: np.array(features.src_input_ids).reshape(-1, 1),
        src_segment_ids: np.array(features.src_segment_ids).reshape(-1, 1)
    }
    references, hypotheses = [], []
    fetches = {
        'inferred_ids': inferred_ids,
    }
    fetches_ = sess.run(fetches, feed_dict=feed_dict)
    labels = np.array(features.tgt_labels).reshape(-1, 1)
    hypotheses.extend(h.tolist() for h in fetches_['inferred_ids'])
    # references.extend(r.tolist() for r in labels)
    hypotheses = utils.list_strip_eos(hypotheses, eos_token_id)
    # references = utils.list_strip_eos(references, eos_token_id)
    hwords = tokenizer.convert_ids_to_tokens(hypotheses[0])
    # rwords = tokenizer.convert_ids_to_tokens(references[0])
    hwords = tx.utils.str_join(hwords).replace(" ##", "")
    # rwords = tx.utils.str_join(rwords).replace(" ##", "")
    # print("Original", rwords)
    print("Generated", hwords)
    return hwords


@app.route("/results", methods=["GET", "POST"])
def results():
    story = request.form['story']
    summary = request.form['summary']
    hwords = infer_single_example(story, summary, tokenizer)
    return hwords


if __name__ == "__main__":
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(tf.tables_initializer())
    saver.restore(sess, tf.train.latest_checkpoint(model_dir))
    # app.run(host="0.0.0.0", port=1118, debug=False)
    story = "Story text about 200 tokens"
    summary = "Summary text about 150 tokens"
    # story = input("Enter article:").strip("\n")
    # summary = input("Enter summary:").strip("\n")
    hwords = infer_single_example(story.strip("\n"), summary.strip("\n"), tokenizer)
    print(hwords)
```

And this is the `config.py`:

```python
import texar as tx

dcoder_config = {
    'dim': 768,
    'num_blocks': 6,
    'multihead_attention': {
        'num_heads': 8,
        'output_dim': 768
        # See documentation for more optional hyperparameters
    },
    'position_embedder_hparams': {
        'dim': 768
    },
    'initializer': {
        'type': 'variance_scaling_initializer',
        'kwargs': {
            'scale': 1.0,
            'mode': 'fan_avg',
            'distribution': 'uniform',
        },
    },
    'poswise_feedforward': tx.modules.default_transformer_poswise_net_hparams(
        output_dim=768)
}

loss_label_confidence = 0.9
random_seed = 1234
beam_width = 5
alpha = 0.6
hidden_dim = 768

opt = {
    'optimizer': {
        'type': 'AdamOptimizer',
        'kwargs': {
            'beta1': 0.9,
            'beta2': 0.997,
            'epsilon': 1e-9
        }
    }
}

# warmup steps must be 0.1% of number of iterations
lr = {
    'learning_rate_schedule': 'constant.linear_warmup.rsqrt_decay.rsqrt_depth',
    'lr_constant': 2 * (hidden_dim ** -0.5),
    'static_lr': 1e-3,
    'warmup_steps': 10000,
}

bos_token_id = 101
eos_token_id = 102

model_dir = "./models"
run_mode = "train_and_evaluate"
batch_size = 1
eval_batch_size = 1
test_batch_size = 1
max_train_steps = 100000
display_steps = 1
checkpoint_steps = 500
eval_steps = 50000
max_decoding_length = 400
max_seq_length_src = 512
max_seq_length_tgt = 400
epochs = 10
is_distributed = False

data_dir = r"data/"
train_out_file = r"data/train.tf_record"
eval_out_file = r"data/eval.tf_record"
bert_pretrain_dir = r"./bert_uncased_model"
train_story = r"data/train_story.txt"
train_summ = r"data/train_summ.txt"
eval_story = r"data/eval_story.txt"
eval_summ = r"data/eval_summ.txt"
# This reassignment overrides the bert_pretrain_dir set a few lines above.
bert_pretrain_dir = r"../uncased_L-12_H-768_A-12"
```
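The last few lines of `infer_single_example` turn the decoded ids back into text. `utils.list_strip_eos` cuts each hypothesis at the EOS id, and `str_join(...).replace(" ##", "")` glues WordPiece continuations back onto the previous token. The sketch below reproduces that logic in plain Python so it can be checked without TensorFlow or texar. The helper names `strip_after_eos` and `merge_wordpieces` are hypothetical. Cutting at the first EOS is an assumption about what `list_strip_eos` does.

```python
from typing import List

EOS_TOKEN_ID = 102  # eos_token_id from the config.py above


def strip_after_eos(ids: List[int], eos_id: int = EOS_TOKEN_ID) -> List[int]:
    """Drop the first EOS id and everything after it, if EOS is present."""
    return ids[:ids.index(eos_id)] if eos_id in ids else ids


def merge_wordpieces(tokens: List[str]) -> str:
    """Join WordPiece tokens with spaces, then glue '##' continuations back on."""
    return " ".join(tokens).replace(" ##", "")


print(strip_after_eos([5, 6, 102, 0, 0]))                               # [5, 6]
print(merge_wordpieces(["the", "sum", "##mar", "##y", "is", "short"]))  # the summary is short
```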
Hello, I ran into the same problem. How did you solve it? @Tanmay06
Hi, I was actually away working on a different project. @Simons2017, I think you should try @santhoshkolloju's reply just before your comment; it should work.
@santhoshkolloju, hi, how do I change this?

```python
src_input_ids: np.array(features.src_input_ids).reshape(1, -1),
src_segment_ids: np.array(features.src_segment_ids).reshape(1, -1),
```
@yuyanzhoufang, change these lines in the original code: `Abstractive-Summarization-With-Transfer-Learning/Inference.py`, lines 56 to 57 (commit 97ff2ae).