Comments (4)

tombettany commented on July 25, 2024

We had a similar problem when running torch_neuronx.trace(). The cause was that our model's output tensor was on the CPU rather than on an XLA device. Our solution was to set the device when creating the tensor, i.e. FloatTensor(a, device=torch_xla.core.xla_model.xla_device()). Alternatively, you can call .to(torch_xla.core.xla_model.xla_device()) on the output:

import torch_xla.core.xla_model as xm

def tags(output, mask):
    # Move the decoded output onto the XLA device before returning it
    return model.crf.decode(output, mask).to(xm.xla_device())

...

Presumably you could set the device on the input tensor instead, but we had issues with that approach on our model.

Hopefully this gets you a little bit further towards solving the problem.

from aws-neuron-sdk.

PrateekAg1511 commented on July 25, 2024

@tombettany Thanks!

I tried this but then got the following warning:

/usr/local/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:143: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=0, shape=torch.Size([1, 60, 184]), dtype=torch.float32)
  warnings.warn(
/usr/local/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:143: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=1, shape=torch.Size([1, 60]), dtype=torch.uint8)
  warnings.warn(

Now the traced model is giving the same output for every input that it gets.

jluntamazon commented on July 25, 2024

Would you be able to give more details about the model you are trying to trace? If there is a minimal open source reproduction of the error you are encountering, we can try to help you solve the problem.

The warning you are running into indicates that the output of the model does not appear to depend on the inputs. This can happen when the output is calculated entirely based on tensors which are newly constructed within the forward function. This likely happens due to the implementation of the model.crf.decode method.
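To make the point about newly constructed tensors concrete, here is a toy, pure-Python sketch of dependency tracking during a trace. It is not how torch_neuronx is implemented; the `Traced` class and the `unused_inputs`, `stays_in_graph`, and `rebuilt_from_python` functions are hypothetical names made up for illustration. It only shows why a value rebuilt from plain Python data no longer looks connected to the inputs.

```python
class Traced:
    """Toy stand-in for a traced tensor (hypothetical, not a torch_neuronx type).

    It remembers which input indices its value was computed from.
    """

    def __init__(self, value, deps=frozenset()):
        self.value = value
        self.deps = deps

    def __add__(self, other):
        if not isinstance(other, Traced):
            other = Traced(other)
        # The result depends on everything either operand depended on.
        return Traced(self.value + other.value, self.deps | other.deps)

    def item(self):
        # Returning a plain Python number drops the dependency record,
        # much like converting a tensor to a Python list or scalar.
        return self.value


def unused_inputs(fn, *values):
    """Return the indices of inputs that the output of fn does not depend on."""
    inputs = [Traced(v, frozenset({i})) for i, v in enumerate(values)]
    out = fn(*inputs)
    return sorted(set(range(len(values))) - out.deps)


def stays_in_graph(a, b):
    # Uses only tracked operations, so both inputs remain connected.
    return a + b


def rebuilt_from_python(a, b):
    # Leaves the tracked world, then constructs a brand-new value.
    return Traced(a.item() + b.item())


print(unused_inputs(stays_in_graph, 1.0, 2.0))
print(unused_inputs(rebuilt_from_python, 1.0, 2.0))
```

In the second function every input is reported as unused, which is the same shape of symptom as the "input tensor that was unused" warnings above: the result is computed once from Python values and then treated as a constant.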

PrateekAg1511 commented on July 25, 2024

@jluntamazon Here is the minimal open source reproduction of the error:

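import torch
import torch_neuronx
import torch_xla.core.xla_model as xm
from torchcrf import CRF

num_tags = 184
model = CRF(num_tags)

emissions = torch.rand([1, 60, 184])
mask = torch.ones([1, 60], dtype=torch.uint8)

def decode_fn(emissions, mask):
    # CRF.decode returns a Python list of tag sequences, not a tensor
    a = model.decode(emissions, mask)
    # Build a new tensor from that list and move it to the XLA device
    a = torch.Tensor(a)
    a = a.to(xm.xla_device())
    return a

inputs_crf = emissions, mask

trace_crf = torch_neuronx.trace(decode_fn, inputs_crf)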

After running the trace, I get the warning message for both inputs:

/aws_neuron_venv_pytorch/lib64/python3.9/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:144: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=0, shape=torch.Size([1, 60, 184]), dtype=torch.float32)
  warnings.warn(
/aws_neuron_venv_pytorch/lib64/python3.9/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:144: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=1, shape=torch.Size([1, 60]), dtype=torch.uint8)
  warnings.warn(
