Comments (2)
@Neargye I wrote some code for creating a Tensor of strings.
// Builds a TF_STRING tensor of shape {batch_size, max_length} from
// std::vector<std::vector<std::string> > batch_tokens, then decodes it again
// as a sanity check.
//
// Buffer layout produced below:
//   [ batch_size * max_length uint64 offsets ][ TF_StringEncode'd strings ]
// Each stored offset is relative to the start of the encoded-strings region.
//
// Every failure path releases what it has acquired so far (status objects,
// the raw buffer, or the tensor) before returning false.

// just create flat vector of strings
const std::vector<std::string> flat_batch_tokens = FlattenVector(batch_tokens);
// we have matrix batch_size x max_length
const std::vector<std::int64_t> flat_batch_tokens_dims = { batch_size, max_length };
const std::int64_t num_elements = batch_size * max_length;
// The offset table has exactly num_elements slots; more tokens than that
// would write offsets into the encoded-strings region.
if (num_elements < 0 || flat_batch_tokens.size() > static_cast<std::size_t>(num_elements)) {
    std::cout << "ERROR: number of tokens exceeds batch_size * max_length" << std::endl;
    return false;
}
// get length of all tokens (strings)
std::size_t flat_batch_tokens_len = 0;
for (const auto& token : flat_batch_tokens) {
    flat_batch_tokens_len += TF_StringEncodedSize(token.size());
}
// add offsets length
const std::size_t flat_batch_tokens_offset_len = static_cast<std::size_t>(num_elements) * sizeof(std::uint64_t);
const std::size_t flat_batch_tokens_total_len = flat_batch_tokens_offset_len + flat_batch_tokens_len;
// prepare buffer for tokens data
auto flat_batch_tokens_data = static_cast<char*>(std::malloc(flat_batch_tokens_total_len));
if (flat_batch_tokens_data == nullptr) {
    std::cout << "ERROR: failed to allocate buffer for tokens data" << std::endl;
    return false;
}
// Zero the buffer so that offset slots without a token stay 0.
memset(flat_batch_tokens_data, 0, flat_batch_tokens_total_len);
// prepare pointer for tokens offsets
std::uint64_t* offset_uint64_data = reinterpret_cast<std::uint64_t*>(flat_batch_tokens_data);
std::uint64_t tokens_offset = flat_batch_tokens_offset_len;
TF_Status* encode_status = TF_NewStatus();
std::size_t it = 0;
// encode every token and add its offset to offset_uint64_data
for (const auto& token : flat_batch_tokens) {
    const std::size_t encoded_size = TF_StringEncodedSize(token.size());
    TF_StringEncode(token.c_str(), token.size(), flat_batch_tokens_data + tokens_offset, encoded_size, encode_status); // fills the rest of tensor data
    //std::cout << token << " (" << token.size() << ") " << tokens_offset << " => " << (flat_batch_tokens_data + tokens_offset) << " (" << encoded_size << ")" << std::endl;
    if (TF_GetCode(encode_status) != TF_OK) {
        std::cout << "ERROR: something wrong with encoding: " << TF_Message(encode_status) << std::endl;
        // The tensor does not exist yet, so the buffer is still ours to free.
        TF_DeleteStatus(encode_status);
        std::free(flat_batch_tokens_data);
        return false;
    }
    offset_uint64_data[it] = (tokens_offset - flat_batch_tokens_offset_len);
    tokens_offset += encoded_size;
    ++it;
}
TF_DeleteStatus(encode_status);
// From here on the buffer is handed to the tensor together with
// DeallocateTensor (presumably it frees the malloc'd data — confirm).
TF_Tensor* tokens_tensor = TF_NewTensor(TF_STRING,
    flat_batch_tokens_dims.data(), static_cast<int>(flat_batch_tokens_dims.size()),
    flat_batch_tokens_data, flat_batch_tokens_total_len,
    DeallocateTensor, nullptr);
// check tensor: can be removed
if (tokens_tensor == nullptr) {
    // NOTE(review): the buffer was already passed to TF_NewTensor with its
    // deallocator; it is deliberately not freed here to avoid a possible
    // double free — confirm against the TF C API version in use.
    std::cout << "Wrong creat tensor" << std::endl;
    return false;
}
if (TF_TensorType(tokens_tensor) != TF_STRING) {
    std::cout << "Wrong tensor type" << std::endl;
    TF_DeleteTensor(tokens_tensor);
    return false;
}
if (TF_NumDims(tokens_tensor) != static_cast<int>(flat_batch_tokens_dims.size())) {
    std::cout << "Wrong number of dimensions" << std::endl;
    TF_DeleteTensor(tokens_tensor);
    return false;
}
for (std::size_t i = 0; i < flat_batch_tokens_dims.size(); ++i) {
    if (TF_Dim(tokens_tensor, static_cast<int>(i)) != flat_batch_tokens_dims[i]) {
        std::cout << "Wrong dimension size for dim: " << i << std::endl;
        TF_DeleteTensor(tokens_tensor);
        return false;
    }
}
if (TF_TensorByteSize(tokens_tensor) != flat_batch_tokens_total_len) {
    std::cout << "Wrong tensor byte size" << std::endl;
    TF_DeleteTensor(tokens_tensor);
    return false;
}
const auto tokens_tensor_data = static_cast<char*>(TF_TensorData(tokens_tensor));
if (tokens_tensor_data == nullptr) {
    std::cout << "Wrong data tensor" << std::endl;
    TF_DeleteTensor(tokens_tensor);
    return false;
}
// checking tensor strings: can be removed
// TF_STRING tensors require copying since Tensor class expects a sequence of string objects.
const std::size_t src_size = TF_TensorByteSize(tokens_tensor);
if (static_cast<std::int64_t>(src_size / sizeof(std::uint64_t)) < num_elements) {
    std::cout << "Malformed TF_STRING tensor; too short to hold number of elements" << std::endl;
    TF_DeleteTensor(tokens_tensor);
    return false;
}
const char* data_start = tokens_tensor_data + sizeof(std::uint64_t) * num_elements;
const char* limit = tokens_tensor_data + src_size;
// Non-negative thanks to the size check above; kept unsigned so that a huge
// stored offset cannot wrap to a negative value and slip past the range check.
const std::uint64_t data_region_len = static_cast<std::uint64_t>(limit - data_start);
TF_Status* decode_status = TF_NewStatus();
for (std::int64_t i = 0; i < num_elements; ++i) {
    const std::uint64_t offset = reinterpret_cast<const std::uint64_t*>(tokens_tensor_data)[i];
    if (offset >= data_region_len) {
        std::cout << "Malformed TF_STRING tensor; element " << i << " out of range" << std::endl;
        TF_DeleteStatus(decode_status);
        TF_DeleteTensor(tokens_tensor);
        return false;
    }
    std::size_t len = 0;
    const char* p = nullptr;
    const char* srcp = data_start + offset;
    TF_StringDecode(srcp, limit - srcp, &p, &len, decode_status);
    if (TF_GetCode(decode_status) != TF_OK) {
        std::cout << "ERROR: something wrong with decoding: " << TF_Message(decode_status) << std::endl;
        TF_DeleteStatus(decode_status);
        TF_DeleteTensor(tokens_tensor);
        return false;
    }
    const std::string p_string(p, p + len);
    std::cout << "p " << i << " == " << p_string << " len = " << len << std::endl;
}
TF_DeleteStatus(decode_status);
std::cout << "Success create tokens_tensor" << std::endl;
This code consists of two parts: creating a tensor of strings and checking it by decoding.
from hello_tf_c_api.
I will fix the API soon to support string tensors.
Could you help me test this?
from hello_tf_c_api.
Related Issues (20)
- Memory leak during inference with frozen graph HOT 9
- session_run hangs on GPU (libtensorflow-gpu) HOT 4
- question about this library HOT 3
- how to turn off verbose and idle threads?
- GPU dll HOT 5
- cuda_driver.cc:175] Check failed HOT 1
- How to create Tensor of TF_BOOL? HOT 2
- TF_INVALID_ARGUMENT
- Inference is running very slow on CPU HOT 1
- Multiple models inference HOT 4
- 3D input to model returns different output than python HOT 1
- What is this actually doing? HOT 2
- TF_SessionRun with multiple outputs gives Segmentation Fault HOT 5
- TF_INVALID_ARGUMENT HOT 1
- Multiple GPU Inferencing HOT 1
- cmake -G "Unix Makefiles" .. stop HOT 1
- Confine TensorFlow C API not to generate more than one threads
- Import LSTM-Layer: Expected input[1] to be control input
- when i load graph the TF_Code is ‘TF_UNKNOWN’ , why?
- when i load graph the TF_Code is ‘TF_INVALID_ARGUMENT ’ , why?
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Something interesting about the web. A new door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Something interesting about visualization; use data art.
-
Game
Something interesting about games; make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from hello_tf_c_api.