
jluntamazon commented on August 29, 2024

This function cannot be traced with torch_neuronx because it cannot be traced with torch itself (see: https://pytorch.org/docs/stable/generated/torch.jit.trace.html). Tracing works by recording the operations applied to tensors as they flow through the compute graph.

There are two separate issues that prevent torchcrf from producing a graph that can be traced:

  1. The graph calls Tensor.item (https://pytorch.org/docs/stable/generated/torch.Tensor.item.html). This converts a tensor into a plain Python number, so the trace can no longer track the resulting value.

    The torchcrf calls are made here: https://github.com/kmkurn/pytorch-crf/blob/master/torchcrf/__init__.py#L326-L334

    Here is a reproduction showing that torch raises an error when tracing a function that returns a .item() value:

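```python
import torch

def item_example(tensor):
    # .item() returns a Python number, not a tensor, so the tracer cannot record it
    return tensor.item()

inputs = (torch.tensor(1),)
# Fails: the traced function returns a Python number instead of a tensor
trace = torch.jit.trace(item_example, inputs)
```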
  2. The graph returns a List[List[int]], which is not a supported output type for torch tracing.

    The torchcrf output type is defined here: https://github.com/kmkurn/pytorch-crf/blob/master/torchcrf/__init__.py#L319

    Here is a reproduction showing that torch raises an error when a traced function returns a list of lists:

```python
import torch

def nested_list_example(tensor):
    return [[tensor]]

inputs = (torch.tensor(1),)
trace = torch.jit.trace(nested_list_example, inputs)
```
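The `.item()` problem is not limited to return values. As a minimal sketch (the function names are hypothetical, not torchcrf code): when `.item()` is used inside the computation, tracing does not necessarily fail outright. It emits a `TracerWarning` and records the Python number it saw as a constant, so the traced graph silently gives wrong results on new inputs. Keeping the value as a tensor lets the trace follow it.

```python
import torch

def scale_by_first_item(x):
    # .item() turns x[0] into a Python float; the tracer cannot follow it,
    # so the value seen at trace time is baked into the graph as a constant.
    factor = x[0].item()
    return x * factor

def scale_by_first_tensor(x):
    # Indexing keeps x[0] as a 0-d tensor, so the trace records the data flow.
    factor = x[0]
    return x * factor

example = torch.tensor([2.0, 3.0])
traced_item = torch.jit.trace(scale_by_first_item, (example,))  # emits a TracerWarning
traced_tensor = torch.jit.trace(scale_by_first_tensor, (example,))

new_input = torch.tensor([5.0, 1.0])
print(scale_by_first_item(new_input))  # eager: values [25, 5]
print(traced_item(new_input))          # traced: still multiplies by 2.0, values [10, 2]
print(traced_tensor(new_input))        # traced: values [25, 5]
```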

To resolve these issues, you would have to modify torchcrf to remove the .item() calls and instead return a Tuple[List[torch.Tensor]] or any other compatible type of your choice. I was able to execute the model successfully, with correct results, after making these changes in a local copy of the torchcrf package.
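As a sketch of the kind of change described above (not the exact patch used here), Viterbi decoding can be written with tensor operations only, so nothing passes through `.item()` and the result is a single `LongTensor` of shape `(batch, seq_len)` instead of `List[List[int]]`. The function name is hypothetical; the parameters loosely mirror torchcrf's `transitions`, `start_transitions` and `end_transitions`, with `batch_first=True` emissions. It also assumes there is no padding mask, i.e. every sequence uses the full length.

```python
import torch

def viterbi_decode_tensor(emissions, transitions, start_transitions, end_transitions):
    """Best tag path per sequence using tensor ops only (no .item(), no int lists).

    emissions:         (batch, seq_len, num_tags), batch_first layout (assumed)
    transitions:       (num_tags, num_tags), transitions[i, j] = score of tag i -> tag j
    start_transitions: (num_tags,)
    end_transitions:   (num_tags,)
    Returns a LongTensor of shape (batch, seq_len). Assumes no padding mask.
    """
    seq_len = emissions.size(1)
    score = start_transitions + emissions[:, 0]  # (batch, num_tags)
    history = []
    for i in range(1, seq_len):
        # (batch, prev, 1) + (prev, next) + (batch, 1, next) -> (batch, prev, next)
        next_score = score.unsqueeze(2) + transitions + emissions[:, i].unsqueeze(1)
        # Best score and best previous tag for every possible next tag.
        score, best_prev = next_score.max(dim=1)
        history.append(best_prev)
    score = score + end_transitions
    best_tag = score.argmax(dim=1)  # (batch,) best final tag
    path = [best_tag]
    for best_prev in reversed(history):
        # Follow the back-pointers one step towards the start of the sequence.
        best_tag = best_prev.gather(1, best_tag.unsqueeze(1)).squeeze(1)
        path.append(best_tag)
    path.reverse()
    return torch.stack(path, dim=1)  # (batch, seq_len)
```

Because the Python loops run over `seq_len`, tracing unrolls them for the sequence length of the example input, so the traced graph is specific to that length. Keeping the batch on dim 0 for both input and output also matches the dimension `torch.nn.DataParallel` scatters and gathers along by default.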

In general, a good rule of thumb is to first ensure that your model can be traced using torch.jit.trace() before trying torch_neuronx.trace().
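One way to apply this rule of thumb is a small check like the one below. It is a hypothetical helper (not part of torch or torch_neuronx) that traces a function with plain `torch.jit.trace`, escalates `torch.jit.TracerWarning` (such as the one `.item()` triggers) to an error, and compares traced and eager outputs on a second, different input. It assumes the function returns a single tensor.

```python
import warnings

import torch

def check_traceable(fn, example_inputs, test_inputs, atol=1e-6):
    """Hypothetical helper: trace fn with example_inputs, then compare traced and
    eager outputs on test_inputs. Assumes fn returns a single tensor."""
    try:
        with warnings.catch_warnings():
            # Treat tracer warnings (e.g. from .item()) as hard failures.
            warnings.simplefilter("error", torch.jit.TracerWarning)
            traced = torch.jit.trace(fn, example_inputs)
        expected = fn(*test_inputs)
        actual = traced(*test_inputs)
    except Exception as exc:  # Warning subclasses Exception, so this catches both
        print(f"not traceable: {exc}")
        return False
    # A successful trace can still have baked values in as constants;
    # comparing on different inputs exposes that.
    return torch.allclose(expected, actual, atol=atol)
```

Testing on an input different from the one used for tracing matters: as with `.item()`, a trace can succeed and still have captured values as constants, which only shows up once the inputs change.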

Note that the compute defined in this module may not be a good candidate for Neuron hardware if it is executed in isolation. Neuron hardware excels at dense, numerically intensive compute with many matrix multiplications.

from aws-neuron-sdk.

PrateekAg1511 commented on August 29, 2024

@jluntamazon Thanks a lot for the response and great insights!

I implemented the changes and it works!

I will be using CRF as the last layer of an NER model along with BERT, so I was looking to compile the entire model with torch_neuronx.


PrateekAg1511 commented on August 29, 2024

@jluntamazon Now I am facing an issue with batch inference using this approach. My output is a Tuple[List[torch.Tensor]], which works well for batch_size = 1. But when I try to use DataParallel on the traced model, it reports an inconsistent size between inputs and outputs. I looked into jit tracing and found that even converting the output directly with torch.tensor would not work, because torch.tensor is treated as a constant. I also tried creating torch.zeros(batch_size, seq_length) and then filling in the values, but that did not work either.

Any pointers on how to make DataParallel work with the CRF?

It would be of great help!


jyang-aws commented on August 29, 2024

Closing, as the initial issue is resolved. Opening a new issue to follow up on DataParallel (DP) support.

