Comments (4)

mrnikwaws commented on July 25, 2024

Thanks for raising the issue. We’re going to take a look at the broader issue soon. I'm checking with the team to see whether there are any workarounds in the meantime that would let you compile large models.

from aws-neuron-sdk.

mrnikwaws commented on July 25, 2024

Hi shepsels,

Thanks for your question. Your problem is related to the 2GB size limit on protobuf messages, which the serialized tf.GraphDef runs into once large weights are frozen into it. The team are looking at several improvements, and we'll provide more updates once we've worked through those ideas.

For now, here is some code shared by a colleague showing how to split compilation of a model that uses embeddings. This won't help in every situation, but it might give you some options with your existing models that exceed 2GB.

The key reusable part is tf.import_graph_def, which allows graph surgery by remapping the inputs of an imported graph to tensors in the current graph. We’ve omitted training parameters and parameter loading code in this example to keep it brief.
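To show the input remapping in isolation, here is a minimal sketch that is separate from the workaround itself. It assumes the tensorflow.compat.v1 API is available (TensorFlow 1.14 or later, or 2.x); the tensor names 'x' and 'y' and the toy multiply graph are made up for illustration only.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Build a small stand-in graph whose only input is a placeholder named 'x'.
with tf.Graph().as_default() as small_graph:
    x = tf.placeholder(tf.float32, [2], name='x')
    y = tf.multiply(x, 2.0, name='y')
    small_graph_def = small_graph.as_graph_def()

# Import that GraphDef into a fresh graph, rewiring every consumer of 'x:0'
# to read from a tensor built in the new graph instead of the placeholder.
with tf.Session(graph=tf.Graph()) as sess:
    replacement = tf.constant([1.0, 2.0]) + 1.0
    (y_new,) = tf.import_graph_def(
        small_graph_def,
        name='',
        input_map={'x:0': replacement},
        return_elements=['y:0'],
    )
    # No feed_dict is needed: the placeholder is no longer on the path to y.
    result = sess.run(y_new)

print(result)
```

The same pattern is what lets a separately compiled subgraph be attached to an embedding lookup that was never part of the compiled GraphDef.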

We can likely provide some additional advice if you are still stuck and can share the code you are trying to compile.

Problem code:

# show_protobuf_size_limit.py
import tensorflow as tf


dataset_size = 65536
embedding_size = 8192
hidden_size = 1024
num_classes = 2
batch_size = 1

with tf.Session(graph=tf.Graph()) as sess:
    index = tf.placeholder(tf.int32, [batch_size])
    huge_embedding_table = tf.get_variable('huge_embedding_table', [dataset_size, embedding_size])
    embedding = tf.gather(huge_embedding_table, index, axis=0)
    hidden = tf.keras.layers.Dense(hidden_size, activation='relu')(embedding)
    probabilities = tf.keras.layers.Dense(num_classes, activation='softmax')(hidden)
    sess.run(tf.global_variables_initializer())
    feed_dict = {index: [3]}
    result = sess.run(probabilities, feed_dict)
    print('got result', result)
    graph_def = sess.graph.as_graph_def()

    # the following function call breaks due to tf.GraphDef protobuf message size limit
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, [probabilities.op.name])
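As a rough size check of the script above, the embedding table alone already reaches the limit once it is frozen into the graph as a constant. This sketch assumes the default float32 dtype for tf.get_variable and a serialized-message ceiling of 2**31 - 1 bytes, which is the commonly cited protobuf limit rather than something stated in this thread.

```python
# Hyperparameters copied from show_protobuf_size_limit.py
dataset_size = 65536
embedding_size = 8192

bytes_per_float32 = 4  # assumption: the variable uses the default float32 dtype

# Raw size of the embedding table once it becomes a constant in the GraphDef
table_bytes = dataset_size * embedding_size * bytes_per_float32

# Assumed ceiling on a single serialized protobuf message
protobuf_limit_bytes = 2**31 - 1

print(table_bytes, table_bytes > protobuf_limit_bytes)
```

The dense layers add only a few tens of megabytes on top of this, which is why moving the table out of the compiled graph is enough in this example.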

Functional code:

# workaround_protobuf_size_limit.py
import shutil
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants, signature_constants
import tensorflow.neuron as tfn


# some hyperparameters
dataset_size = 65536
embedding_size = 8192
hidden_size = 1024
num_classes = 2
batch_size = 1

# implement a model that skips the huge embedding table first
model_dir = './saved_model'
neuron_model_dir = './neuron_saved_model'
full_neuron_model_dir = './full_neuron_saved_model'
with tf.Session(graph=tf.Graph()) as sess:
    embedding = tf.placeholder(tf.float32, [batch_size, embedding_size])
    hidden = tf.keras.layers.Dense(hidden_size, activation='relu')(embedding)
    probabilities = tf.keras.layers.Dense(num_classes, activation='softmax')(hidden)
    sess.run(tf.global_variables_initializer())

    # generate a SavedModel
    inputs = {embedding.name: embedding}
    outputs = {probabilities.name: probabilities}
    shutil.rmtree(model_dir, ignore_errors=True)
    tf.saved_model.simple_save(sess, model_dir, inputs, outputs)
    shutil.rmtree(neuron_model_dir, ignore_errors=True)

    # compile SavedModel to Neuron-compatible SavedModel
    tfn.saved_model.compile(model_dir, neuron_model_dir)
    
    # cache some tensor names for later use
    embedding_name = embedding.name
    probabilities_name = probabilities.name

# extract Neuron-compatible GraphDef from Neuron-compatible SavedModel
with tf.Session(graph=tf.Graph()) as sess:
    tags = [tag_constants.SERVING]
    meta_graph = tf.saved_model.loader.load(sess, tags, neuron_model_dir)
    compiled_graph_def = meta_graph.graph_def

# construct the full model: the embedding lookup followed by the compiled Neuron subgraph
with tf.Session(graph=tf.Graph()) as sess:
    index = tf.placeholder(tf.int32, [batch_size])
    huge_embedding_table = tf.get_variable('huge_embedding_table', [dataset_size, embedding_size])
    embedding = tf.gather(huge_embedding_table, index, axis=0)

    # extend the current tensorflow session/graph with the Neuron-compatible GraphDef
    tf.import_graph_def(compiled_graph_def, name='', input_map={embedding_name: embedding})
    probabilities = sess.graph.get_tensor_by_name(probabilities_name)
    sess.run(tf.global_variables_initializer())

    # run an inference to verify
    feed_dict = {index: [3]}
    result = sess.run(probabilities, feed_dict)

    # save the full SavedModel
    inputs = {index.name: index}
    outputs = {probabilities.name: probabilities}
    shutil.rmtree(full_neuron_model_dir, ignore_errors=True)
    tf.saved_model.simple_save(sess, full_neuron_model_dir, inputs, outputs)

OpUs-Nebula commented on July 25, 2024

Any progress on this? Saw that ONNX had fixed this issue: onnx/tensorflow-onnx#1090

BugFreeee commented on July 25, 2024

I'm also having the same issue in 2021. Any updates?
