
Comments (2)

Coderx7 commented on May 6, 2024

Oh, I got it!
Silly me, I miscalculated img_size_flattened as 728! (It should be 28 * 28 = 784.)

from tensorflow-tutorials.

tmagg commented on May 6, 2024

I can run the MNIST training code below (segment 1) without any error. However, when I restore the model in segment 2 below, I get this error: Cannot feed value of shape (784,) for Tensor 'input:0', which has shape '(?, 784)'. Kindly suggest a fix.

Segment 1

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
from random import randint
import numpy as np

logs_path = 'log_mnist_softmax'
batch_size = 100
learning_rate = 0.5
training_epochs = 10
mnist = input_data.read_data_sets("data", one_hot=True)

# Placeholders for flattened 28x28 images and one-hot labels
X = tf.placeholder(tf.float32, [None, 784], name = "input")
Y_ = tf.placeholder(tf.float32, [None, 10])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
XX = tf.reshape(X, [-1, 784])

# Logits, loss, and accuracy
Y = tf.matmul(X, W) + b
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=Y_, logits=Y), name = "output")
correct_prediction = tf.equal(tf.argmax(Y, 1), tf.argmax(Y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

train_step = tf.train.GradientDescentOptimizer(0.005).minimize(cross_entropy)

tf.summary.scalar("cost", cross_entropy)
tf.summary.scalar("accuracy", accuracy)
summary_op = tf.summary.merge_all()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())
    for epoch in range(training_epochs):
        batch_count = int(mnist.train.num_examples / batch_size)
        for i in range(batch_count):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            _, summary = sess.run([train_step, summary_op], feed_dict={X: batch_x, Y_: batch_y})
            writer.add_summary(summary, epoch * batch_count + i)
        print("Epoch: ", epoch)

    print("Accuracy: ", accuracy.eval(feed_dict={X: mnist.test.images, Y_: mnist.test.labels}))
    print("done")

    # randint includes both endpoints, so subtract 1 to stay in range
    num = randint(0, mnist.test.images.shape[0] - 1)
    img = mnist.test.images[num]

    # Wrapping img in a list gives the feed a batch dimension: shape (1, 784)
    classification = sess.run(tf.argmax(Y, 1), feed_dict={X: [img]})
    print('Neural Network predicted', classification[0])
    print('Real label is:', np.argmax(mnist.test.labels[num]))

    saver = tf.train.Saver()
    save_path = saver.save(sess, "data/saved_mnist_cnn.ckpt")
    print("Model saved to %s" % save_path)

Segment 2 for restoring the model

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data', one_hot=True)
sess = tf.InteractiveSession()
new_saver = tf.train.import_meta_graph('data/saved_mnist_cnn.ckpt.meta')
new_saver.restore(sess, 'data/saved_mnist_cnn.ckpt')
tf.get_default_graph().as_graph_def()

x = sess.graph.get_tensor_by_name("input:0")

y_conv = sess.graph.get_tensor_by_name("output:0")
image_b = mnist.test.images[100]
result = sess.run(y_conv, feed_dict={x:image_b})
print(result)
print(sess.run(tf.argmax(result, 1)))

plt.imshow(image_b.reshape([28, 28]), cmap='Greys')
plt.show()
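
The error message points at a missing batch dimension: indexing a single row of mnist.test.images yields shape (784,), while the placeholder named "input" was declared with shape [None, 784]. The sketch below reproduces the shape difference with NumPy alone; the zero-filled `images` array is a hypothetical stand-in for the MNIST test set, not the real data.

```python
import numpy as np

# Hypothetical stand-in for mnist.test.images: 5 flattened 28x28 images.
images = np.zeros((5, 784), dtype=np.float32)

# Indexing a single row drops the batch axis.
image_b = images[3]
print(image_b.shape)        # (784,)

# Three equivalent ways to get a batch of one image, shape (1, 784):
batch_reshape = image_b.reshape(1, 784)   # explicit reshape
batch_newaxis = image_b[np.newaxis, :]    # insert a leading axis
batch_slice = images[3:4]                 # slicing keeps the batch axis

print(batch_reshape.shape)  # (1, 784)
```

In segment 2 this would likely mean feeding `{x: image_b.reshape(1, 784)}` or `{x: [image_b]}`, which is the same pattern segment 1 already uses with `feed_dict={X: [img]}`. Separately, segment 1 attaches the name "output" to the cross-entropy loss rather than to the logits `Y`, so fetching "output:0" would probably also require the labels to be fed and would return a scalar; naming the logits instead (for example via `tf.identity(Y, name=...)`, an assumption about the intended design) may be closer to what the restore script expects.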

