
Error when performing QAT with ONNX (about tinyneuralnetwork, CLOSED)

zeo233 commented on July 24, 2024
Error when performing QAT with ONNX

from tinyneuralnetwork.

Comments (3)

zk1998 commented on July 24, 2024

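This happens because the code uses ACT2CLS, which prevents our tracer from capturing gelu.

  1. If you can modify the model definition directly, simply change the GELUActivation in the code you provided to: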
import math

import torch
from torch import Tensor, nn


class GELUActivation(nn.Module):
    def __init__(self, use_gelu_python: bool = False):
        super().__init__()
        # The original code stored the activation at construction time; disabled here:
        # if use_gelu_python:
        #     self.act = self._gelu_python
        # else:
        #     self.act = nn.functional.gelu

    def _gelu_python(self, input: Tensor) -> Tensor:
        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))

    def forward(self, input: Tensor) -> Tensor:
        # return self.act(input)
        # Call the functional op directly so the tracer can capture it.
        return torch.nn.functional.gelu(input)
  2. If modifying the model definition is inconvenient, you can add import_patcher when initializing the model:
def main_worker(args):

    with model_tracer():
        # Add import_patcher
        with import_patcher():
            model = WhisperModelForSequenceClassification()

        # Provide a viable input for the model
        dummy_input = torch.rand((1, 80, 3000))
        print(model(dummy_input))
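
Editor's note: the first fix above moves the gelu lookup from `__init__` to `forward`. One plausible reading, which the thread does not confirm, is that the tracer observes ops by temporarily replacing library functions, so a function reference stored on the module beforehand never reaches the replacement. The sketch below shows only that general Python behaviour. It does not use TinyNN, and `functional`, `StoredRef`, `LateLookup` and `trace` are hypothetical names invented for the illustration.

```python
import types

# Hypothetical stand-in for a library namespace such as torch.nn.functional.
functional = types.SimpleNamespace(gelu=lambda x: x * 0.5)

calls = []  # names of the ops the toy tracer managed to observe


class StoredRef:
    """Binds the function object at construction time (like `self.act = nn.functional.gelu`)."""

    def __init__(self):
        self.act = functional.gelu

    def forward(self, x):
        return self.act(x)


class LateLookup:
    """Looks the function up at call time (like calling the functional op inside forward)."""

    def forward(self, x):
        return functional.gelu(x)


def trace(module, x):
    """Temporarily swap functional.gelu for a recording wrapper, then run the module."""
    original = functional.gelu

    def recording_gelu(value):
        calls.append("gelu")
        return original(value)

    functional.gelu = recording_gelu
    try:
        return module.forward(x)
    finally:
        functional.gelu = original  # always restore the original function


stored = StoredRef()  # built before patching, so it holds the unpatched function
late = LateLookup()

calls.clear()
trace(stored, 2.0)
stored_calls = list(calls)  # the wrapper was bypassed

calls.clear()
trace(late, 2.0)
late_calls = list(calls)  # the wrapper ran once
```

Both modules compute the same value. They differ only in whether the recording wrapper sees the call, which is why the late lookup in `forward` is the safer pattern under this assumption.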


zeo233 commented on July 24, 2024

Sorry, part of the error message I pasted above was wrong. Here is the correct error message:
tensor([[ 0.0254, -0.0090, 0.1155, -0.0195, 0.0668, 0.0332, -0.0074, -0.0186,
0.0280, -0.0085]], grad_fn=)
ERROR (tinynn.graph.tracer) Connection is lost when generating code for audio_encoder_model_layers_0_fc2 of type torch.nn.modules.linear.Linear
Traceback (most recent call last):
  File "tinynn/graph/tracer.py", line 3343, in trace
    new_graph.init()
  File "tinynn/graph/tracer.py", line 2041, in init
    self.module(*actual_input)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "onnx_qat_wt.py", line 249, in forward
    output_encoder = self.audio_encoder(input_features)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "onnx_qat_wt.py", line 231, in forward
    model_output = self.model(input_features)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "onnx_qat_wt.py", line 199, in forward
    layer_outputs = encoder_layer(hidden_states)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "onnx_qat_wt.py", line 158, in forward
    hidden_states = self.fc2(hidden_states)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1581, in _call_impl
    hook_result = hook(self, args, result)
  File "tinynn/graph/tracer.py", line 1662, in submodule_tracer
    add_forward_node(node, inputs, outputs)
  File "tinynn/graph/tracer.py", line 1593, in add_forward_node
    pre_node_name = current_graph().tensor_pre_node_dict[id(t)]
KeyError: 140042784477152
ERROR (tinynn.graph.tracer) inputs: ['input_0_f']
ERROR (tinynn.graph.tracer) forwards: ['audio_encoder_model_conv1', 'gelu_0_f', 'audio_encoder_model_conv2', 'gelu_1_f', 'permute_0_f', 'audio_encoder_model_embed_positions', 'weight_0_f', 'add_0_f', 'audio_encoder_model_layers_0_self_attn_layer_norm', 'size_0_f', 'audio_encoder_model_layers_0_self_attn_q_proj', 'mul_0_f', 'audio_encoder_model_layers_0_self_attn_k_proj', 'view_0_f', 'transpose_0_f', 'contiguous_0_f', 'audio_encoder_model_layers_0_self_attn_v_proj', 'view_1_f', 'transpose_1_f', 'contiguous_1_f', 'mul_1_f', 'view_2_f', 'transpose_2_f', 'contiguous_2_f', 'view_3_f', 'reshape_0_f', 'reshape_1_f', 'transpose_3_f', 'bmm_0_f', 'softmax_0_f', 'bmm_1_f', 'view_4_f', 'transpose_4_f', 'reshape_2_f', 'audio_encoder_model_layers_0_self_attn_out_proj', 'add_1_f', 'audio_encoder_model_layers_0_final_layer_norm', 'audio_encoder_model_layers_0_fc1']
ERROR (tinynn.graph.tracer) outputs: []
ERROR (tinynn.graph.tracer) constants: []
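
Editor's note: in the log above, the recorded forwards end at `audio_encoder_model_layers_0_fc1`, and the failure occurs while handling `fc2`, at the lookup `tensor_pre_node_dict[id(t)]`. A reasonable interpretation is that the tensor fed to `fc2` came from an op the tracer never recorded, so its id has no entry. The toy registry below reproduces that failure mode in plain Python. It is not TinyNN code, and `Value`, `producer_of`, `record` and `lookup_producer` are hypothetical names.

```python
class Value:
    """Minimal stand-in for a tensor; only its object identity matters here."""

    def __init__(self, data):
        self.data = data


producer_of = {}  # maps id(output object) -> name of the node that produced it


def record(node_name, output):
    """Register `output` as produced by `node_name`, as a tracer hook might."""
    producer_of[id(output)] = node_name
    return output


def lookup_producer(value):
    """Find the node that produced `value`; raises KeyError if none was recorded."""
    return producer_of[id(value)]


fc1_out = record("fc1", Value(1.0))

# An activation applied without the tracer noticing: a new object, never recorded.
act_out = Value(fc1_out.data * 0.5)

known = lookup_producer(fc1_out)  # "fc1"

try:
    lookup_producer(act_out)
    connection_lost = False
except KeyError:
    connection_lost = True  # the same kind of failure as in the traceback
```

Under this reading, the error is reported at `fc2` even though the op that went unrecorded is the activation between `fc1` and `fc2`.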


zeo233 commented on July 24, 2024

The problem is solved, thanks a lot!
