LeeDark commented on June 12, 2024

@Neargye I have written some code for creating a tensor of strings.
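The snippet calls two helpers that the comment does not show: `FlattenVector` and the `DeallocateTensor` callback passed to `TF_NewTensor`. Below is a minimal sketch of both, under stated assumptions: `FlattenVector` here is a hypothetical helper that only concatenates rows in row-major order (it assumes each row was already padded to `max_length`), and `DeallocateTensor` assumes the buffer came from `std::malloc`, so it frees with `std::free`. Its signature matches the `void (*)(void* data, size_t len, void* arg)` deallocator that `TF_NewTensor` expects.

```cpp
#include <cstddef>
#include <cstdlib>
#include <string>
#include <vector>

// Hypothetical helper: concatenates the rows of a 2-D vector in row-major order.
// Assumes every row has already been padded to the same length (max_length).
template <typename T>
std::vector<T> FlattenVector(const std::vector<std::vector<T>>& rows) {
    std::size_t total = 0;
    for (const auto& row : rows) total += row.size();
    std::vector<T> flat;
    flat.reserve(total);
    for (const auto& row : rows) flat.insert(flat.end(), row.begin(), row.end());
    return flat;
}

// Deallocator handed to TF_NewTensor. The buffer is allocated with std::malloc,
// so it must be released with std::free; len and arg are not needed here.
void DeallocateTensor(void* data, std::size_t /*len*/, void* /*arg*/) {
    std::free(data);
}
```

Whether padding happens inside `FlattenVector` or before it is a design choice; the tensor code only needs the flattened size to equal `batch_size * max_length`.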

```cpp
// Assumes this runs inside a function that returns bool, where the caller provides
// std::vector<std::vector<std::string>> batch_tokens (every row padded to max_length)
// and std::int64_t batch_size, max_length.
// Needs <tensorflow/c/c_api.h>, <cstdint>, <cstdlib>, <cstring>, <iostream>, <memory>, <string>, <vector>.

// Flatten the batch into one row-major vector of strings.
std::vector<std::string> flat_batch_tokens = FlattenVector(batch_tokens);
// The tensor is a batch_size x max_length matrix.
const std::vector<std::int64_t> flat_batch_tokens_dims = { batch_size, max_length };
const std::int64_t num_elements = batch_size * max_length;
if (static_cast<std::int64_t>(flat_batch_tokens.size()) != num_elements) {
    std::cout << "ERROR: expected " << num_elements << " tokens, got " << flat_batch_tokens.size() << std::endl;
    return false;
}

// Total size of all encoded tokens (each one is a length prefix plus its bytes).
std::size_t flat_batch_tokens_len = 0;
for (const auto& token : flat_batch_tokens) {
    flat_batch_tokens_len += TF_StringEncodedSize(token.size());
}

// The buffer starts with one uint64 offset per element.
const std::size_t flat_batch_tokens_offset_len = static_cast<std::size_t>(num_elements) * sizeof(std::uint64_t);
const std::size_t total_len = flat_batch_tokens_offset_len + flat_batch_tokens_len;

// Allocate the tensor buffer with malloc; DeallocateTensor must release it with std::free.
auto flat_batch_tokens_data = static_cast<char*>(std::malloc(total_len));
if (flat_batch_tokens_data == nullptr) {
    std::cout << "ERROR: failed to allocate " << total_len << " bytes" << std::endl;
    return false;
}
std::memset(flat_batch_tokens_data, 0, total_len);

// The offset table occupies the front of the buffer.
auto offset_uint64_data = reinterpret_cast<std::uint64_t*>(flat_batch_tokens_data);

// Encode every token after the offset table and record its offset,
// measured from the start of the data region rather than the start of the buffer.
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> encode_status(TF_NewStatus(), TF_DeleteStatus);
std::size_t tokens_offset = flat_batch_tokens_offset_len;
for (std::size_t i = 0; i < flat_batch_tokens.size(); ++i) {
    const std::string& token = flat_batch_tokens[i];
    const std::size_t encoded_size = TF_StringEncodedSize(token.size());
    TF_StringEncode(token.data(), token.size(), flat_batch_tokens_data + tokens_offset,
                    encoded_size, encode_status.get());
    if (TF_GetCode(encode_status.get()) != TF_OK) {
        std::cout << "ERROR: failed to encode token " << i << ": " << TF_Message(encode_status.get()) << std::endl;
        std::free(flat_batch_tokens_data);  // no tensor owns the buffer yet
        return false;
    }
    offset_uint64_data[i] = tokens_offset - flat_batch_tokens_offset_len;
    tokens_offset += encoded_size;
}

// Hand the buffer to TensorFlow; from now on DeallocateTensor frees it.
TF_Tensor* tokens_tensor = TF_NewTensor(TF_STRING,
    flat_batch_tokens_dims.data(), static_cast<int>(flat_batch_tokens_dims.size()),
    flat_batch_tokens_data, total_len,
    DeallocateTensor, nullptr);

// Check the tensor (optional, can be removed).
if (tokens_tensor == nullptr) {
    std::cout << "ERROR: failed to create tensor" << std::endl;
    return false;
}
// Deletes the tensor if any check below fails; released once all checks pass.
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor_guard(tokens_tensor, TF_DeleteTensor);
if (TF_TensorType(tokens_tensor) != TF_STRING) {
    std::cout << "ERROR: wrong tensor type" << std::endl;
    return false;
}
if (TF_NumDims(tokens_tensor) != static_cast<int>(flat_batch_tokens_dims.size())) {
    std::cout << "ERROR: wrong number of dimensions" << std::endl;
    return false;
}
for (std::size_t i = 0; i < flat_batch_tokens_dims.size(); ++i) {
    if (TF_Dim(tokens_tensor, static_cast<int>(i)) != flat_batch_tokens_dims[i]) {
        std::cout << "ERROR: wrong size for dimension " << i << std::endl;
        return false;
    }
}
const std::size_t src_size = TF_TensorByteSize(tokens_tensor);
if (src_size != total_len) {
    std::cout << "ERROR: wrong tensor byte size" << std::endl;
    return false;
}
const auto tokens_tensor_data = static_cast<const char*>(TF_TensorData(tokens_tensor));
if (tokens_tensor_data == nullptr) {
    std::cout << "ERROR: tensor has no data" << std::endl;
    return false;
}

// Decode every string back (optional, can be removed).
// TF_STRING tensors require copying, since the Tensor class expects a sequence of string objects.
if (src_size / sizeof(std::uint64_t) < static_cast<std::size_t>(num_elements)) {
    std::cout << "ERROR: malformed TF_STRING tensor; too short to hold the offset table" << std::endl;
    return false;
}
const char* data_start = tokens_tensor_data + sizeof(std::uint64_t) * static_cast<std::size_t>(num_elements);
const char* limit = tokens_tensor_data + src_size;
const auto* offsets = reinterpret_cast<const std::uint64_t*>(tokens_tensor_data);
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> decode_status(TF_NewStatus(), TF_DeleteStatus);
for (std::int64_t i = 0; i < num_elements; ++i) {
    const std::uint64_t offset = offsets[i];
    if (offset >= static_cast<std::uint64_t>(limit - data_start)) {
        std::cout << "ERROR: malformed TF_STRING tensor; element " << i << " out of range" << std::endl;
        return false;
    }
    const char* srcp = data_start + offset;
    const char* p = nullptr;
    std::size_t len = 0;
    TF_StringDecode(srcp, static_cast<std::size_t>(limit - srcp), &p, &len, decode_status.get());
    if (TF_GetCode(decode_status.get()) != TF_OK) {
        std::cout << "ERROR: failed to decode element " << i << ": " << TF_Message(decode_status.get()) << std::endl;
        return false;
    }
    std::cout << "element " << i << " = \"" << std::string(p, len) << "\" (len " << len << ")" << std::endl;
}

tokens_tensor = tensor_guard.release();  // all checks passed; the caller keeps the tensor
std::cout << "Successfully created tokens_tensor" << std::endl;
```

This code consists of two parts: creating a tensor of strings, and checking it by decoding every element back.
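Both parts can be tried without linking TensorFlow. The sketch below is a standalone reimplementation of the buffer layout the code builds, assuming the format used by TensorFlow releases before 2.4 (later releases store `TF_TString` elements and dropped `TF_StringEncode`/`TF_StringDecode`). In that format each string is written as a varint64 length prefix followed by its raw bytes, so `TF_StringEncodedSize(n)` is the varint length of `n` plus `n`. `VarintLength`, `EncodeStringBuffer` and `DecodeStringBuffer` are hypothetical names, not TensorFlow API.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <vector>

// Number of bytes in the base-128 varint encoding of v.
std::size_t VarintLength(std::uint64_t v) {
    std::size_t n = 1;
    while (v >= 0x80) { v >>= 7; ++n; }
    return n;
}

// Builds [one uint64 offset per element][varint length + bytes per element],
// with offsets measured from the start of the data region (assumed pre-2.4 layout).
std::vector<char> EncodeStringBuffer(const std::vector<std::string>& tokens) {
    const std::size_t table = tokens.size() * sizeof(std::uint64_t);
    std::size_t data = 0;
    for (const auto& t : tokens) data += VarintLength(t.size()) + t.size();
    std::vector<char> buf(table + data, 0);
    std::size_t pos = table;
    for (std::size_t i = 0; i < tokens.size(); ++i) {
        const std::uint64_t offset = pos - table;
        std::memcpy(buf.data() + i * sizeof offset, &offset, sizeof offset);
        std::uint64_t len = tokens[i].size();
        while (len >= 0x80) {  // low 7 bits first; high bit means "more bytes follow"
            buf[pos++] = static_cast<char>((len & 0x7F) | 0x80);
            len >>= 7;
        }
        buf[pos++] = static_cast<char>(len);
        std::memcpy(buf.data() + pos, tokens[i].data(), tokens[i].size());
        pos += tokens[i].size();
    }
    return buf;
}

// Reverses EncodeStringBuffer, validating offsets and lengths like the check loop above.
std::vector<std::string> DecodeStringBuffer(const std::vector<char>& buf, std::size_t n) {
    const std::size_t table = n * sizeof(std::uint64_t);
    if (buf.size() < table) throw std::runtime_error("buffer too short for offset table");
    std::vector<std::string> out;
    out.reserve(n);
    for (std::size_t i = 0; i < n; ++i) {
        std::uint64_t offset;
        std::memcpy(&offset, buf.data() + i * sizeof offset, sizeof offset);
        if (offset >= buf.size() - table) throw std::runtime_error("element offset out of range");
        std::size_t pos = table + static_cast<std::size_t>(offset);
        std::uint64_t len = 0;
        int shift = 0;
        while (true) {  // read the varint length prefix
            if (pos >= buf.size() || shift > 63) throw std::runtime_error("bad varint length");
            const auto byte = static_cast<unsigned char>(buf[pos++]);
            len |= static_cast<std::uint64_t>(byte & 0x7F) << shift;
            if ((byte & 0x80) == 0) break;
            shift += 7;
        }
        if (len > buf.size() - pos) throw std::runtime_error("string runs past end of buffer");
        out.emplace_back(buf.data() + pos, static_cast<std::size_t>(len));
    }
    return out;
}
```

For example, `{"hello", "", std::string(200, 'x')}` needs 24 bytes of offsets plus 6 + 1 + 202 bytes of data, since 200 needs a two-byte varint. Comparing this buffer byte for byte with `TF_TensorData` on a pre-2.4 build is one way to confirm the layout assumption.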

from hello_tf_c_api.

Neargye commented on June 12, 2024

I will fix the API soon to support string tensors.
Could you help me test it?