Comments (12)

Romit-Maulik avatar Romit-Maulik commented on June 12, 2024 2

Yup! So for a pretrained and frozen graph you would use something like

// Inputs and dimensions (num_samples and num_inputs must be compile-time constants here)
float input_vals[num_samples][num_inputs]; // Populate this with the inputs
const std::vector<std::int64_t> input_dims = {num_samples, num_inputs}; // Tensor dimensions
// Set up tensors
TF_Tensor* output_tensor_ = nullptr;
// Assumes the pointer-based tf_utils::CreateTensor overload (dims pointer, number of dims, data pointer, byte length)
TF_Tensor* input_tensor_ = tf_utils::CreateTensor(TF_FLOAT,
                                          input_dims.data(), input_dims.size(),
                                          input_vals, num_samples*num_inputs*sizeof(float));
// Arrays of tensors
TF_Tensor* input_tensors_[1] = {input_tensor_}; // Array of all the inputs to the model
TF_Tensor* output_tensors_[1] = {output_tensor_}; // Array of all the outputs from the model; the run writes its result to output_tensors_[0]
// Arrays of operations
TF_Output inputs[1] = {input_ph_}; // The input placeholder(s)
TF_Output outputs[1] = {output_}; // The output operation

TF_SessionRun(sess_,
                nullptr, // Run options.
                inputs, input_tensors_, 1, // Input tensor ops, input tensor values, number of inputs.
                outputs, output_tensors_, 1, // Output tensor ops, output tensor values, number of outputs.
                nullptr, 0, // *No* target operations, number of target ops.
                nullptr, // Run metadata.
                status_ // Output status.
);
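
The byte length in the snippet above, num_samples*num_inputs*sizeof(float), assumes the samples sit contiguously in row-major order. A minimal sketch of building such a buffer from a std::vector when the sizes are only known at run time; the helper names FlattenRowMajor and ByteLength are hypothetical and not part of hello_tf_c_api:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: copy a samples-by-inputs table into one contiguous
// row-major buffer, the layout a {num_samples, num_inputs} float tensor uses.
std::vector<float> FlattenRowMajor(const std::vector<std::vector<float>>& rows,
                                   std::size_t num_inputs) {
  std::vector<float> flat;
  flat.reserve(rows.size() * num_inputs);
  for (const auto& row : rows) {
    // Rows shorter than num_inputs are zero-padded; longer rows are truncated.
    for (std::size_t j = 0; j < num_inputs; ++j) {
      flat.push_back(j < row.size() ? row[j] : 0.0f);
    }
  }
  return flat;
}

// Hypothetical helper: number of bytes a float tensor with these dims occupies.
std::size_t ByteLength(const std::vector<std::int64_t>& dims) {
  std::size_t count = 1;
  for (std::int64_t d : dims) count *= static_cast<std::size_t>(d);
  return count * sizeof(float);
}
```

With these, flat.data() and ByteLength(input_dims) could stand in for the stack array and the hand-written size expression.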

from hello_tf_c_api.

Romit-Maulik avatar Romit-Maulik commented on June 12, 2024 1

I see - thanks! That's what I started out with and am now toying with the idea of moving more functionality into C++.

PS - I found a Stack Overflow answer that goes into some detail on how one may save a graph to a checkpoint in C++. I'll look into it.

Neargye avatar Neargye commented on June 12, 2024

Try this

...
TF_Operation* train_ = TF_GraphOperationByName(graph_, "train_step"); 
std::vector<TF_Operation*> target_opers = {train_ };
...
TF_SessionRun(sess_, nullptr, // Run options.
	                &input_op_, &input_tensor_, 1, // Input tensors, input tensor values, number of inputs.
	                &out_op_, &output_tensor_, 1, // Output tensors, output tensor values, number of outputs.
	                target_opers.data(), target_opers.size(), // Target operations, number of targets.
	                nullptr, // Run metadata.
	                status_ // Output status.
	                );

Neargye avatar Neargye commented on June 12, 2024

TF_SessionRun takes an array of target operations, so the parameter type is const TF_Operation* const*.
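
As a rough illustration of why both a std::vector of handles and the address of a single handle can be passed to a parameter of that shape, here is a small self-contained sketch. Op is a hypothetical stand-in for an opaque handle type, and CountTargets only mimics the parameter shape; neither comes from the TensorFlow C API.

```cpp
#include <cstddef>
#include <vector>

struct Op {};  // Hypothetical stand-in for an opaque handle such as TF_Operation.

// Mirrors the shape of a "targets" parameter: a read-only array of read-only
// pointers plus a count. Returns how many non-null targets were passed.
int CountTargets(const Op* const* targets, int ntargets) {
  int count = 0;
  for (int i = 0; i < ntargets; ++i) {
    if (targets[i] != nullptr) ++count;
  }
  return count;
}

int CountSingle(Op* op) {
  // &op has type Op**; C++ converts it implicitly to const Op* const*,
  // so the address of one pointer acts as an array of length 1.
  return CountTargets(&op, 1);
}

int CountMany(const std::vector<Op*>& ops) {
  // On a const vector, data() yields Op* const*, which converts the same way.
  // The count parameter is an int, so the size is cast explicitly.
  return CountTargets(ops.data(), static_cast<int>(ops.size()));
}
```

Note that this implicit conversion is a C++ rule; plain C would need an explicit cast.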

Romit-Maulik avatar Romit-Maulik commented on June 12, 2024

Hi Neargye - thanks! I managed to fix the issue by doing the following:

TF_Operation* train_ = TF_GraphOperationByName(graph_, "train_step");
TF_SessionRun(sess_, nullptr, // Run options.
	                &input_op_, &input_tensor_, 1, // Input tensors, input tensor values, number of inputs.
	                &out_op_, &output_tensor_, 1, // Output tensors, output tensor values, number of outputs.
	                &train_, 1, // Target operations, number of targets.
	                nullptr, // Run metadata.
	                status_ // Output status.
	                );

which I suspect allowed a safe cast.

While I have you here: I was trying to use the above code to train an unfrozen graph in C++ (I had defined the graph through the TensorFlow Python API). However, I noticed that I could not get any sort of output from the unfrozen graph (for instance, to print the loss). Have you had any experience with this issue, and is there a workaround for it?

I suspect I could dump a checkpoint graph to disk and assess on the side - but in-situ would be nice. Do you have an example that shows how to save a graph to disk from C++? Thanks!

Neargye avatar Neargye commented on June 12, 2024

Unfortunately, I have not worked with unfrozen graphs. I usually train in Python and then use C++ only to run the trained model.

carlos-vaz avatar carlos-vaz commented on June 12, 2024

What are target operations? I thought that the session (which contains the graph) included all operations. Why would TF_SessionRun take the session as the first argument as well as a list of target operations?

Romit-Maulik avatar Romit-Maulik commented on June 12, 2024

The target operations here correspond to those you want to run during the session on the unfrozen graph. In the previous snippet they were those related to optimization of the graph. This gist has all the code you'll need to understand this!

https://gist.github.com/asimshankar/7c9f8a9b04323e93bb217109da8c7ad2
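
To make the distinction concrete without depending on TensorFlow, here is a toy model of the idea: fetched outputs return values, while targets are run only for their side effects. ToySession, MakeToySession and the op names are all hypothetical, and the ordering (fetches evaluated before targets) is a simplification chosen for this sketch, not a statement about TensorFlow internals.

```cpp
#include <functional>
#include <map>
#include <string>
#include <vector>

// Toy stand-in for a session: named operations that read or update state.
struct ToySession {
  double weight = 0.0;
  std::map<std::string, std::function<double(ToySession&)>> ops;

  // Evaluates each fetch and returns its value, then runs each target purely
  // for its side effect (its return value is discarded).
  std::vector<double> Run(const std::vector<std::string>& fetches,
                          const std::vector<std::string>& targets) {
    std::vector<double> values;
    for (const auto& name : fetches) values.push_back(ops.at(name)(*this));
    for (const auto& name : targets) ops.at(name)(*this);
    return values;
  }
};

// Builds a session with a quadratic loss and one gradient-descent step.
ToySession MakeToySession() {
  ToySession s;
  // loss = (weight - 3)^2; reading it does not change any state.
  s.ops["loss"] = [](ToySession& t) {
    const double d = t.weight - 3.0;
    return d * d;
  };
  // train_step moves weight against the gradient 2 * (weight - 3).
  s.ops["train_step"] = [](ToySession& t) {
    t.weight -= 0.1 * 2.0 * (t.weight - 3.0);
    return 0.0;
  };
  return s;
}
```

In this model, Run({"loss"}, {}) only reads the loss, which matches the inference call earlier in the thread that passes nullptr, 0 for the targets, while Run({"loss"}, {"train_step"}) also updates the weight.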

carlos-vaz avatar carlos-vaz commented on June 12, 2024

Thanks Romit! So if I understand correctly, if you want to include all graph operations you leave Target ops as NULL? I am using a frozen graph of Deeplab pre-trained on the Cityscapes dataset, so I believe it is fully optimized and ready to predict.

youngallien avatar youngallien commented on June 12, 2024

Hi,
I want to do gradient descent with the TensorFlow C API, but I can't find any example that shows how to code it, such as:
How do I create and initialize a variable?
How do I set up an optimizer and train?
@Romit-Maulik @Neargye Thanks for your attention and hope to receive your reply.

Neargye avatar Neargye commented on June 12, 2024

@yangjituan you can also check https://gist.github.com/asimshankar/7c9f8a9b04323e93bb217109da8c7ad2.

youngallien avatar youngallien commented on June 12, 2024

@Neargye Thanks very much, it's very useful.
