santhoshkolloju commented on May 27, 2024

Can you change this in `infer_single_example`:

```python
feed_dict = {
    src_input_ids: np.array(features.src_input_ids).reshape(1, -1),
    src_segment_ids: np.array(features.src_segment_ids).reshape(1, -1),
}
```
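The same fix can be wrapped in a small guard so that a wrongly shaped array fails loudly instead of being fed to the graph. This is a minimal sketch that assumes the padded length is `max_seq_length_src = 512` (as in the `config.py` posted in this thread) and that the model expects `[batch, seq_len]` input; `as_single_example_batch` is a hypothetical helper, not part of the repo.

```python
import numpy as np


def as_single_example_batch(ids, max_len=512):
    """Return `ids` shaped (1, max_len): one example of max_len tokens.

    Hypothetical helper; max_len mirrors max_seq_length_src in config.py.
    Raises ValueError instead of silently producing a (max_len, 1) array.
    """
    batch = np.asarray(ids).reshape(1, -1)  # one row = one example
    if batch.shape != (1, max_len):
        raise ValueError(f"expected shape (1, {max_len}), got {batch.shape}")
    return batch
```

Inside `infer_single_example`, the feed entries would then read `src_input_ids: as_single_example_batch(features.src_input_ids, max_seq_length_src)`, and likewise for `src_segment_ids`.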

from abstractive-summarization-with-transfer-learning.

santhoshkolloju commented on May 27, 2024

Have you checked? You might be using a bigger batch size that doesn't fit in your memory.

Tanmay06 commented on May 27, 2024

No, I haven't changed any batch sizes; by default it should be 1 for inference too, right? I'm using an Azure NVIDIA Tesla M60 GPU with 8 GiB of memory.

santhoshkolloju commented on May 27, 2024

I think the model has around 250 million parameters; I doubt 8 GB can handle this along with the data. Please try with 16 GB of RAM.
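A back-of-the-envelope check of that memory estimate, assuming the ~250 million parameter figure quoted above and 32-bit floats (4 bytes per value): the weights alone come to roughly 0.93 GiB. Training with Adam additionally stores one gradient and two moment estimates per parameter. Activation memory (including beam-search state) is deliberately ignored, and `tensor_memory_gib` is a hypothetical helper.

```python
def tensor_memory_gib(num_values: int, bytes_per_value: int = 4) -> float:
    """Memory for num_values floats in GiB (2**30 bytes); 4 bytes = float32."""
    return num_values * bytes_per_value / 2**30


num_params = 250_000_000  # approximate figure quoted in the thread

weights = tensor_memory_gib(num_params)
# Training with Adam keeps a gradient plus two moment slots (m and v) per
# parameter, so about 4x the weights; activations are not counted here.
training_state = tensor_memory_gib(4 * num_params)

print(f"weights only:                {weights:.2f} GiB")
print(f"weights + grads + Adam m, v: {training_state:.2f} GiB")
```

By this rough estimate, parameter-related memory at inference time is smaller than during training, which suggests an out-of-memory error that appears only at inference is worth checking against the input shapes as well.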

Tanmay06 commented on May 27, 2024

But I was able to train the model on the same GPU without any issues. I'm facing this problem only when I try to run inference on the trained model (the last checkpoint).

santhoshkolloju commented on May 27, 2024

Please post a link to the inference code you are running. The tensor size [512, 10, 50, 512] looks wrong; the problem might be the way you are passing the data.
Check the size of the input tensor: it should be 1×512.
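To see why the input shape matters, compare the two reshapes on one padded sequence of 512 token ids (the `max_seq_length_src` value in the config posted in this thread). The inference code below builds `start_tokens` from `tx.utils.get_batch_size(src_input_ids)`, which in texar is the size of the first axis, so `reshape(-1, 1)` is read as 512 one-token examples rather than one 512-token example. The token values here are arbitrary placeholders.

```python
import numpy as np

# Placeholder token ids standing in for features.src_input_ids:
# one example padded to max_seq_length_src = 512 (values are arbitrary).
src_input_ids = list(range(512))

wrong = np.array(src_input_ids).reshape(-1, 1)  # first axis = 512
right = np.array(src_input_ids).reshape(1, -1)  # first axis = 1

# With a [batch, seq_len] input, the first axis is read as the batch size.
print("reshape(-1, 1):", wrong.shape)  # (512, 1): 512 examples, 1 token each
print("reshape(1, -1):", right.shape)  # (1, 512): 1 example, 512 tokens
```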

Tanmay06 commented on May 27, 2024

This is the inference code that I'm running.

```python
from flask import Flask, request, render_template
import requests
import json
from collections import OrderedDict
import os
import sys

import numpy as np
import tensorflow as tf

app = Flask(__name__)

if 'texar_repo' not in sys.path:
    sys.path += ['texar_repo']

from config import *
from model import *
from preprocess import *

# Batch size is taken from the first axis of src_input_ids
start_tokens = tf.fill([tx.utils.get_batch_size(src_input_ids)],
                       bos_token_id)
predictions = decoder(
    memory=encoder_output,
    memory_sequence_length=src_input_length,
    decoding_strategy='infer_greedy',
    beam_width=beam_width,
    alpha=alpha,
    start_tokens=start_tokens,
    end_token=eos_token_id,
    max_decoding_length=300,
    mode=tf.estimator.ModeKeys.PREDICT
)
if beam_width <= 1:
    inferred_ids = predictions[0].sample_id
else:
    # Uses the best sample by beam search
    inferred_ids = predictions['sample_id'][:, :, 0]

tokenizer = tokenization.FullTokenizer(
    vocab_file=os.path.join(bert_pretrain_dir, 'vocab.txt'),
    do_lower_case=True)

sess = tf.Session()


def infer_single_example(story, actual_summary, tokenizer):
    example = {"src_txt": story,
               "tgt_txt": actual_summary
               }
    features = convert_single_example(1, example, max_seq_length_src,
                                      max_seq_length_tgt, tokenizer)
    feed_dict = {
        # NOTE: reshape(-1, 1) yields shape (512, 1), which the graph reads as
        # 512 examples of length 1; the expected shape is (1, 512),
        # i.e. reshape(1, -1).
        src_input_ids: np.array(features.src_input_ids).reshape(-1, 1),
        src_segment_ids: np.array(features.src_segment_ids).reshape(-1, 1)
    }

    references, hypotheses = [], []
    fetches = {
        'inferred_ids': inferred_ids,
    }
    fetches_ = sess.run(fetches, feed_dict=feed_dict)
    labels = np.array(features.tgt_labels).reshape(-1, 1)
    hypotheses.extend(h.tolist() for h in fetches_['inferred_ids'])
    # references.extend(r.tolist() for r in labels)
    hypotheses = utils.list_strip_eos(hypotheses, eos_token_id)
    # references = utils.list_strip_eos(references, eos_token_id)
    hwords = tokenizer.convert_ids_to_tokens(hypotheses[0])
    # rwords = tokenizer.convert_ids_to_tokens(references[0])

    # Join WordPiece tokens back into words
    hwords = tx.utils.str_join(hwords).replace(" ##", "")
    # rwords = tx.utils.str_join(rwords).replace(" ##", "")
    # print("Original", rwords)
    print("Generated", hwords)
    return hwords


@app.route("/results", methods=["GET", "POST"])
def results():
    story = request.form['story']
    summary = request.form['summary']
    hwords = infer_single_example(story, summary, tokenizer)
    return hwords


if __name__ == "__main__":
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(tf.tables_initializer())
    saver.restore(sess, tf.train.latest_checkpoint(model_dir))
    # app.run(host="0.0.0.0", port=1118, debug=False)
    story = "Story text about 200 tokens"
    summary = "Summary text about 150 tokens"
    # story = input("Enter article:").strip()
    # summary = input("Enter summary:").strip()
    # strip() removes leading/trailing whitespace, including newlines
    hwords = infer_single_example(story.strip(), summary.strip(), tokenizer)
    print(hwords)
```

And this is the `config.py`:

```python
import texar as tx

dcoder_config = {
    'dim': 768,
    'num_blocks': 6,
    'multihead_attention': {
        'num_heads': 8,
        'output_dim': 768
        # See documentation for more optional hyperparameters
    },
    'position_embedder_hparams': {
        'dim': 768
    },
    'initializer': {
        'type': 'variance_scaling_initializer',
        'kwargs': {
            'scale': 1.0,
            'mode': 'fan_avg',
            'distribution': 'uniform',
        },
    },
    'poswise_feedforward': tx.modules.default_transformer_poswise_net_hparams(
        output_dim=768)
}

loss_label_confidence = 0.9

random_seed = 1234
beam_width = 5
alpha = 0.6
hidden_dim = 768

opt = {
    'optimizer': {
        'type': 'AdamOptimizer',
        'kwargs': {
            'beta1': 0.9,
            'beta2': 0.997,
            'epsilon': 1e-9
        }
    }
}

# Warmup steps should be about 10% (0.1x) of the number of iterations
# (here 10000 of max_train_steps = 100000)
lr = {
    'learning_rate_schedule': 'constant.linear_warmup.rsqrt_decay.rsqrt_depth',
    'lr_constant': 2 * (hidden_dim ** -0.5),
    'static_lr': 1e-3,
    'warmup_steps': 10000,
}

bos_token_id = 101  # [CLS] in the BERT uncased vocabulary
eos_token_id = 102  # [SEP] in the BERT uncased vocabulary

model_dir = "./models"
run_mode = "train_and_evaluate"
batch_size = 1
eval_batch_size = 1
test_batch_size = 1

max_train_steps = 100000

display_steps = 1
checkpoint_steps = 500
eval_steps = 50000

max_decoding_length = 400

max_seq_length_src = 512
max_seq_length_tgt = 400

epochs = 10

is_distributed = False

data_dir = r"data/"

train_out_file = r"data/train.tf_record"
eval_out_file = r"data/eval.tf_record"

bert_pretrain_dir = r"./bert_uncased_model"

train_story = r"data/train_story.txt"
train_summ = r"data/train_summ.txt"

eval_story = r"data/eval_story.txt"
eval_summ = r"data/eval_summ.txt"

# NOTE: this later assignment overrides bert_pretrain_dir above
bert_pretrain_dir = r"../uncased_L-12_H-768_A-12"
```
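The `lr` settings above name a composite schedule, `constant.linear_warmup.rsqrt_decay.rsqrt_depth`. One common reading is the Transformer-style schedule: a linear ramp up to `warmup_steps`, then decay proportional to 1/sqrt(step), scaled by a constant. The sketch below assumes that reading, with the `rsqrt_depth` factor (`hidden_dim ** -0.5`) already folded into `lr_constant` as in the posted config; whether the repo's training code computes exactly this is an assumption, and `transformer_lr` is a hypothetical name.

```python
import math


def transformer_lr(step, lr_constant, warmup_steps):
    """Linear warmup to warmup_steps, then 1/sqrt(step) decay.

    Hypothetical helper: assumes rsqrt_depth (hidden_dim ** -0.5) is already
    included in lr_constant, as in the posted config.
    """
    warmup = min(1.0, step / warmup_steps)             # linear_warmup
    decay = 1.0 / math.sqrt(max(step, warmup_steps))   # rsqrt_decay
    return lr_constant * warmup * decay


hidden_dim = 768
lr_constant = 2 * (hidden_dim ** -0.5)  # same expression as in config.py
for step in (1000, 10000, 40000, 100000):
    print(step, f"{transformer_lr(step, lr_constant, 10000):.2e}")
```

Under this reading the rate peaks at step 10000 at about 7.2e-4 and then decays; with `max_train_steps = 100000`, warmup covers the first 10% of training.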

Simons2017 commented on May 27, 2024

Hello, I ran into the same problem. How did you solve it, @Tanmay06?

Tanmay06 commented on May 27, 2024

Hi, I was away working on a different project. @Simons2017, I think you should try @santhoshkolloju's reply just before your comment; it should work.

yuyanzhoufang commented on May 27, 2024

Hi @santhoshkolloju, how do I change this?

```python
src_input_ids: np.array(features.src_input_ids).reshape(1, -1),
src_segment_ids: np.array(features.src_segment_ids).reshape(1, -1),
```

Tanmay06 commented on May 27, 2024

@yuyanzhoufang, change these lines in the original code:

```python
src_input_ids: np.array(features.src_input_ids).reshape(-1, 1),
src_segment_ids: np.array(features.src_segment_ids).reshape(-1, 1)
```

so that they use `reshape(1, -1)` instead.

yuyanzhoufang commented on May 27, 2024
