Hello, I have tried running the PolyRNN++ demos following the instructions in the README; however, I keep getting an error when trying to restore the PolyRNN metagraph. I would appreciate any help with this. More details about the error are below:
Code:
# Initialize PolyRNN++ from its exported metagraph, then restore the checkpoint weights into a session.
model = PolygonModel(PolyRNN_metagraph, polyGraph)  # ValueError raised here by import_meta_graph: GatherTree expected 4 int32 inputs, got 3 -- NOTE(review): likely a TensorFlow version mismatch with the exporter; confirm
model.register_eval_fn(lambda input_: evaluator.do_test(evalSess, input_))  # eval callback: forwards input_ to evaluator.do_test in evalSess
polySess = tf.Session(config=tf.ConfigProto(  # session bound to polyGraph, the graph the metagraph is imported into
allow_soft_placement=True
), graph=polyGraph)
model.saver.restore(polySess, PolyRNN_checkpoint)  # model.saver is the Saver returned by tf.train.import_meta_graph
Error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
in <module>()
1 #Initializing and restoring PolyRNN++
----> 2 model = PolygonModel(PolyRNN_metagraph, polyGraph)
3 model.register_eval_fn(lambda input_: evaluator.do_test(evalSess, input_))
4 polySess = tf.Session(config=tf.ConfigProto(
5 allow_soft_placement=True
/media/nelson/Workspace1/Projects/building_reconstruction/polyrnn/src/PolygonModel.py in __init__(self, meta_graph_path, graph)
30 self.saver = None
31 self.eval_pred_fn = None
---> 32 self._restore_graph(meta_graph_path)
33
34 def _restore_graph(self, meta_graph_path):
/media/nelson/Workspace1/Projects/building_reconstruction/polyrnn/src/PolygonModel.py in _restore_graph(self, meta_graph_path)
34 def _restore_graph(self, meta_graph_path):
35 with self.graph.as_default():
---> 36 self.saver = tf.train.import_meta_graph(meta_graph_path, clear_devices=False)
37
38 def _prediction(self):
/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.pyc in import_meta_graph(meta_graph_or_file, clear_devices, import_scope, **kwargs)
1925 clear_devices=clear_devices,
1926 import_scope=import_scope,
-> 1927 **kwargs)
1928 if meta_graph_def.HasField("saver_def"):
1929 return Saver(saver_def=meta_graph_def.saver_def, name=import_scope)
/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/meta_graph.pyc in import_scoped_meta_graph(meta_graph_or_file, clear_devices, graph, import_scope, input_map, unbound_inputs_col_name, restore_collections_predicate)
739 importer.import_graph_def(
740 input_graph_def, name=(import_scope or ""), input_map=input_map,
--> 741 producer_op_list=producer_op_list)
742
743 # Restores all the other collections.
/usr/local/lib/python2.7/dist-packages/tensorflow/python/util/deprecation.pyc in new_func(*args, **kwargs)
430 'in a future version' if date is None else ('after %s' % date),
431 instructions)
--> 432 return func(*args, **kwargs)
433 return tf_decorator.make_decorator(func, new_func, 'deprecated',
434 _add_deprecated_arg_notice_to_docstring(
/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/importer.pyc in import_graph_def(graph_def, input_map, return_elements, name, op_dict, producer_op_list)
678 'Input types mismatch (expected %r but got %r)'
679 % (', '.join(dtypes.as_dtype(x).name for x in input_types),
--> 680 ', '.join(x.name for x in op._input_types))))
681 # pylint: enable=protected-access
682
ValueError: graph_def is invalid at node u'GatherTree': Input types mismatch (expected 'int32, int32, int32, int32' but got 'int32, int32, int32').