Comments (7)
Could you try float32 first, without AMP? I don't have an environment available right now; I'll take a look for you tomorrow.
Isn't main_grad only used for AMP?
inp = paddle.normal(mean=0, std=0.01, shape=[1, 32, 32]).astype('float32')
Using fp32 for the input is fine. AMP will cast it to 16-bit internally, so changing the input to fp32 makes the script run.
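A minimal sketch of that suggestion (assuming the same auto_cast settings as the script later in this thread, and current dygraph AMP behavior): the input stays fp32 and AMP downcasts it inside the context.

import paddle

# Keep the input fp32; O2 auto_cast downcasts it for white-list ops like linear.
inp = paddle.normal(mean=0, std=0.01, shape=[1, 32, 32]).astype('float32')
layer = paddle.nn.Linear(32, 32)
with paddle.amp.auto_cast(True, level="O2", dtype="bfloat16"):
    out = layer(inp)
print(out.dtype)  # expected: bfloat16, since linear ran under auto_cast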
Paddle currently supports main_grad in two ways (see the sketch after this list):
- wrap the optimizer with mix_precision_utils.MixPrecisionOptimizer
- call paddle.amp.decorate with master_grad=True
The two approaches cannot be enabled at the same time; in the current distributed setting, the first one is recommended. The framework will unify the two usages later.
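A minimal sketch of the two options, using the same names as the script later in this thread (exact behavior may differ across Paddle versions); enable only one of them:

import paddle
from paddle.distributed.fleet.utils import mix_precision_utils

model = paddle.nn.Linear(32, 32)

# Option 1: the fleet utils keep an fp32 main_grad on each parameter.
model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16")
opt = paddle.optimizer.AdamW(parameters=model.parameters(), multi_precision=True)
opt = mix_precision_utils.MixPrecisionOptimizer(opt)

# Option 2 (do not combine with option 1): amp.decorate with master_grad=True.
# model = paddle.amp.decorate(models=model, dtype="bfloat16", level="O2", master_grad=True)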
Following your advice, I used an FP32 input and wrapped the optimizer only with mix_precision_utils.MixPrecisionOptimizer, but the test still hits the same error.
Running the script below locally works for me; maybe our environments are not aligned?
(The script copies the PaddleNLP code directly, so it runs as a single file.)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import _C_ops
from paddle.framework import core
import numpy as np
from paddle.distributed.fleet.utils import mix_precision_utils
def is_fused_matmul_bias_supported():
    if (paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm()) or paddle.is_compiled_with_xpu():
        return hasattr(core.eager.ops.legacy, "fused_gemm_epilogue")
    else:
        return False
if is_fused_matmul_bias_supported():
    origin_linear = paddle.incubate.nn.functional.fused_linear
else:
    origin_linear = paddle.nn.functional.linear
class FusedLinearWithGradAdd(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, x, weight, bias=None, name=None):
        y = origin_linear(x, weight, bias)
        ctx.save_for_backward(x, weight, bias)
        return y

    @staticmethod
    def backward(ctx, y_grad):
        x, weight, bias = ctx.saved_tensor()
        x_grad = paddle.matmul(y_grad, weight, transpose_y=True)
        # _C_ops.fused_linear_param_grad_add(x, y_grad, dw, db, multi_precision, has_bias)
        if bias is None:
            if hasattr(weight, "main_grad"):
                # Accumulate directly into the fp32 main_grad (multi_precision=True).
                weight.main_grad, _ = _C_ops.fused_linear_param_grad_add(
                    x, y_grad, weight.main_grad, None, True, False
                )
                return x_grad, None
            else:
                if weight.grad is not None:
                    # Accumulate into the existing grad in place.
                    weight.grad, _ = _C_ops.fused_linear_param_grad_add(
                        x, y_grad, weight.grad, None, False, False
                    )
                    return x_grad, None
                else:
                    weight_grad, _ = _C_ops.fused_linear_param_grad_add(
                        x, y_grad, None, None, False, False
                    )
                    return x_grad, weight_grad

        # bias is present: also produce/accumulate the bias gradient.
        if hasattr(weight, "main_grad") and hasattr(bias, "main_grad"):
            weight.main_grad, bias.main_grad = _C_ops.fused_linear_param_grad_add(
                x, y_grad, weight.main_grad, bias.main_grad, True, True
            )
            return x_grad, None, None
        else:
            if weight.grad is not None:
                assert bias.grad is not None
                weight.grad, bias.grad = _C_ops.fused_linear_param_grad_add(
                    x, y_grad, weight.grad, bias.grad, False, True
                )
                return x_grad, None, None
            else:
                weight_grad, bias_grad = _C_ops.fused_linear_param_grad_add(
                    x, y_grad, None, None, False, True
                )
                return x_grad, weight_grad, bias_grad
def mock_layers():
    paddle.nn.functional.linear = FusedLinearWithGradAdd.apply
    if is_fused_matmul_bias_supported():
        paddle.incubate.nn.functional.fused_linear = FusedLinearWithGradAdd.apply

mock_layers()
def create_optimizer(model, use_pure_bf16, use_main_grad):
    if use_main_grad:
        assert use_pure_bf16
        model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16")
    optimizer = paddle.optimizer.AdamW(
        parameters=model.parameters(),
        learning_rate=0.0001,
        multi_precision=use_pure_bf16,
    )
    if use_main_grad:
        optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer)
    return optimizer
class Net(paddle.nn.Layer):
    """Network used for recompute testing."""

    def __init__(self):
        super().__init__()
        self.layer = paddle.nn.Linear(32, 32)

    def forward(self, inp):
        out = self.layer(inp)
        return out
def main():
    paddle.seed(10)
    model = Net()
    optimizer = create_optimizer(model, use_pure_bf16=True, use_main_grad=True)
    model = paddle.amp.decorate(models=model, dtype="bfloat16", level='O2', master_grad=True)
    model.train()
    for _ in range(10):
        inp = paddle.normal(mean=0, std=0.01, shape=[1, 32, 32]).astype('float32')
        inp.stop_gradient = False
        with paddle.amp.auto_cast(True, level="O2", dtype="bfloat16"):
            out = model(inp)
            loss = out.mean()
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        print(loss)

if __name__ == "__main__":
    main()
Resolved; closing this issue.