This happens because the code uses ACT2CLS, which prevents our tracer from capturing gelu.
- If you can modify the model definition directly, simply change GELUActivation in the code you provided to:
```python
import math

import torch
from torch import Tensor, nn


class GELUActivation(nn.Module):
    def __init__(self, use_gelu_python: bool = False):
        super().__init__()
        # No longer store the activation in self.act (use_gelu_python is now ignored);
        # forward() calls gelu directly instead.
        # if use_gelu_python:
        #     self.act = self._gelu_python
        # else:
        #     self.act = nn.functional.gelu

    def _gelu_python(self, input: Tensor) -> Tensor:
        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))

    def forward(self, input: Tensor) -> Tensor:
        # return self.act(input)
        return torch.nn.functional.gelu(input)
```
- If modifying the model definition is inconvenient, add import_patcher when initializing the model:
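```python
import torch
from tinynn.graph.tracer import import_patcher, model_tracer


def main_worker(args):
    with model_tracer():
        # Add import_patcher
        with import_patcher():
            model = WhisperModelForSequenceClassification()  # model class from the issue's script

        # Provide a viable input for the model
        dummy_input = torch.rand((1, 80, 3000))
        print(model(dummy_input))
```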
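The thread does not say how TinyNN's tracer or import_patcher work internally. Assuming the tracer records calls by temporarily swapping functions such as torch.nn.functional.gelu for recording wrappers, the difference between the original GELUActivation and the first workaround can be sketched in plain Python. In this sketch `unittest.mock.patch.object` stands in for the tracer's patch, a `SimpleNamespace` stands in for torch.nn.functional, and `CachedAct`, `LookupAct` and `recording_gelu` are hypothetical names, not TinyNN APIs.

```python
import types
from unittest import mock

# Stand-in for torch.nn.functional; `gelu` here is a toy function, not torch's.
functional = types.SimpleNamespace(gelu=lambda x: x)

calls = []  # inputs seen by the (hypothetical) recording wrapper


def recording_gelu(x):
    """Hypothetical tracer wrapper: records the call."""
    calls.append(x)
    return x  # same result as the toy gelu above


class CachedAct:
    """Like the original GELUActivation: caches the function in __init__."""

    def __init__(self):
        self.act = functional.gelu

    def forward(self, x):
        return self.act(x)


class LookupAct:
    """Like the modified GELUActivation: looks gelu up on every call."""

    def forward(self, x):
        return functional.gelu(x)


cached_early = CachedAct()  # built before any patch is active
lookup = LookupAct()

with mock.patch.object(functional, "gelu", recording_gelu):
    cached_late = CachedAct()  # built while the patch is active
    cached_early.forward("a")  # still calls the original: not recorded
    lookup.forward("b")        # resolves the wrapper at call time: recorded
    cached_late.forward("c")   # cached the wrapper itself: recorded

print(calls)  # ['b', 'c']
```

In this sketch, `LookupAct` mirrors the first workaround. `cached_late` only illustrates why the moment at which the model is constructed can matter; it is an assumption that import_patcher works this way.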
from tinyneuralnetwork.
Sorry, part of the error message I pasted above was wrong. Here is the correct error message:
```
tensor([[ 0.0254, -0.0090, 0.1155, -0.0195, 0.0668, 0.0332, -0.0074, -0.0186,
0.0280, -0.0085]], grad_fn=)
ERROR (tinynn.graph.tracer) Connection is lost when generating code for audio_encoder_model_layers_0_fc2 of type torch.nn.modules.linear.Linear
Traceback (most recent call last):
  File "tinynn/graph/tracer.py", line 3343, in trace
    new_graph.init()
  File "tinynn/graph/tracer.py", line 2041, in init
    self.module(*actual_input)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "onnx_qat_wt.py", line 249, in forward
    output_encoder = self.audio_encoder(input_features)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "onnx_qat_wt.py", line 231, in forward
    model_output = self.model(input_features)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "onnx_qat_wt.py", line 199, in forward
    layer_outputs = encoder_layer(hidden_states)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "onnx_qat_wt.py", line 158, in forward
    hidden_states = self.fc2(hidden_states)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/conda/envs/qt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1581, in _call_impl
    hook_result = hook(self, args, result)
  File "tinynn/graph/tracer.py", line 1662, in submodule_tracer
    add_forward_node(node, inputs, outputs)
  File "tinynn/graph/tracer.py", line 1593, in add_forward_node
    pre_node_name = current_graph().tensor_pre_node_dict[id(t)]
KeyError: 140042784477152
ERROR (tinynn.graph.tracer) inputs: ['input_0_f']
ERROR (tinynn.graph.tracer) forwards: ['audio_encoder_model_conv1', 'gelu_0_f', 'audio_encoder_model_conv2', 'gelu_1_f', 'permute_0_f', 'audio_encoder_model_embed_positions', 'weight_0_f', 'add_0_f', 'audio_encoder_model_layers_0_self_attn_layer_norm', 'size_0_f', 'audio_encoder_model_layers_0_self_attn_q_proj', 'mul_0_f', 'audio_encoder_model_layers_0_self_attn_k_proj', 'view_0_f', 'transpose_0_f', 'contiguous_0_f', 'audio_encoder_model_layers_0_self_attn_v_proj', 'view_1_f', 'transpose_1_f', 'contiguous_1_f', 'mul_1_f', 'view_2_f', 'transpose_2_f', 'contiguous_2_f', 'view_3_f', 'reshape_0_f', 'reshape_1_f', 'transpose_3_f', 'bmm_0_f', 'softmax_0_f', 'bmm_1_f', 'view_4_f', 'transpose_4_f', 'reshape_2_f', 'audio_encoder_model_layers_0_self_attn_out_proj', 'add_1_f', 'audio_encoder_model_layers_0_final_layer_norm', 'audio_encoder_model_layers_0_fc1']
ERROR (tinynn.graph.tracer) outputs: []
ERROR (tinynn.graph.tracer) constants: []
```
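This log is consistent with the explanation at the top of the thread (the tracer does not capture gelu): the recorded forwards end at audio_encoder_model_layers_0_fc1, and the failing lookup `tensor_pre_node_dict[id(t)]` happens when fc2 receives a tensor whose producing node was never recorded. The toy model below reproduces that failure pattern in plain Python. It assumes, based only on the names in the traceback, that the tracer keys tensors by `id()` and looks up their producer when a module runs; `FakeTensor`, `traced_op` and `untraced_op` are hypothetical names, and this is not TinyNN's implementation.

```python
# Hypothetical, simplified model of the bookkeeping named in the traceback:
# the tracer maps id(tensor) -> name of the node that produced the tensor.
tensor_pre_node_dict = {}
edges = []  # (producer, consumer) pairs recorded by the toy tracer


class FakeTensor:
    """Placeholder standing in for a torch.Tensor."""


def traced_op(name, *inputs):
    """A traced op: links each input to its producer, then registers its output."""
    for t in inputs:
        edges.append((tensor_pre_node_dict[id(t)], name))  # KeyError if unknown
    out = FakeTensor()
    tensor_pre_node_dict[id(out)] = name
    return out


def untraced_op(t):
    """An op the tracer never sees, e.g. gelu called through a cached reference."""
    return FakeTensor()


x = FakeTensor()
tensor_pre_node_dict[id(x)] = "input_0_f"
h = traced_op("fc1", x)
h = untraced_op(h)  # this output is never registered in the dict

connection_lost = False
try:
    traced_op("fc2", h)
except KeyError:
    connection_lost = True  # mirrors "Connection is lost ... fc2"

print(edges)            # [('input_0_f', 'fc1')]
print(connection_lost)  # True
```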
That solved the problem, thanks a lot!