We have submitted two PRs (this PR and this PR) to introduce a new custom data type for FP8 params, also known as OWG (overwrite-with-gradient) params. The primary purpose of this custom data type is custom gradient accumulation using the max operation.
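For context, the accumulation semantics we want for OWG params looks like this (using amax statistics as an illustrative example of what such a param might hold; that detail is my assumption, not something specified in the PRs):

```python
import jax.numpy as jnp

# An OWG param (e.g. an FP8 amax statistic) receives "gradients" that are
# really fresh statistics. When the same param is reused across pipeline
# iterations, the per-iteration contributions should be combined with max,
# not summed the way autograd accumulates cotangents by default.
per_iteration_contrib = jnp.array([0.5, 3.0, 1.2])
default_accum = per_iteration_contrib.sum()  # 4.7, what plain autograd would produce
owg_accum = per_iteration_contrib.max()      # 3.0, what the custom dtype should give us
```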
After the aforementioned PRs are merged, we still need one more change, most likely in `LayerwiseShardablePipelined`, to perform the dtype conversion outside the `scan_fn`. This is necessary because the custom data type has to be recognized before the params are broadcast into the iterations of the `scan_fn`, so that autograd correctly applies the custom gradient accumulation.
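To make the ordering constraint concrete, here is a minimal sketch with `jax.lax.scan`; `OWG_DTYPE` is just a runnable stand-in for the custom dtype from the PRs, and the shapes and names are made up:

```python
import jax
import jax.numpy as jnp

OWG_DTYPE = jnp.bfloat16  # placeholder; the real code would use the custom OWG dtype

def pipeline(w, owg_param, xs):
  # Cast OUTSIDE scan_fn: the param enters the scan already carrying the
  # custom dtype, so autograd can apply the custom (max) accumulation when
  # it combines the cotangents coming from each iteration.
  owg_param = owg_param.astype(OWG_DTYPE)

  def scan_fn(carry, x):
    # Casting only here, inside scan_fn, would be too late: the param has
    # already been broadcast into the iterations in its original dtype.
    y = carry @ w + x * owg_param.astype(x.dtype)
    return y, y

  _, ys = jax.lax.scan(scan_fn, jnp.zeros_like(xs[0]), xs)
  return ys

# Toy usage: 3 pipeline iterations over (batch=2, features=4) inputs.
w = jnp.eye(4)
owg_param = jnp.asarray(1.0, dtype=jnp.float32)
ys = pipeline(w, owg_param, jnp.ones((3, 2, 4)))
```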
I have prepared a self-contained Python script for this potential change, which you can find here. Essentially, you can disregard everything before line 199 as if it had already been merged. Line 243 is the proposed dtype conversion to be added to `LayerwiseShardablePipelined`, where we convert all OWG params into the custom data type.
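The conversion step itself would look roughly like the snippet below, assuming we already have a boolean `owg_mask` pytree; `jnp.bfloat16` is again only a placeholder for the custom dtype so the snippet runs as-is:

```python
import jax
import jax.numpy as jnp

def convert_owg_params(params, owg_mask, owg_dtype=jnp.bfloat16):
  """Casts only the params flagged by owg_mask to the (placeholder) OWG dtype."""
  return jax.tree_util.tree_map(
      lambda p, is_owg: p.astype(owg_dtype) if is_owg else p,
      params, owg_mask)

# Toy example: only amax_history is an OWG param here.
params = {'w': jnp.ones((4, 4)), 'amax_history': jnp.zeros((16,))}
owg_mask = {'w': False, 'amax_history': True}
converted = convert_owg_params(params, owg_mask)
```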
However, there is an open question about how to obtain the mask of the OWG params. As I understand it, OWG params physically reside in the PARAMS collection, and we rely on the weight hparams to determine whether a given param is OWG or not; unfortunately, those weight hparams do not seem to be accessible inside `LayerwiseShardablePipelined`. In the provided code, I compute the `owg_mask` outside the model at line 263 and pass it as an input to `model.apply`. I don't think this is an ideal design, since it changes the model call signature and is specific to the FP8 scenario.
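In spirit, that workaround looks like the following; the name-based predicate is purely hypothetical, since in reality the decision has to come from the weight hparams:

```python
import jax
import jax.numpy as jnp

def compute_owg_mask(params):
  # Hypothetical predicate: matching on the variable name is just a stand-in
  # for a proper check against the weight hparams.
  return jax.tree_util.tree_map_with_path(
      lambda path, _: 'amax' in jax.tree_util.keystr(path), params)

params = {'w': jnp.ones((4, 4)), 'amax_history': jnp.zeros((16,))}
owg_mask = compute_owg_mask(params)
# The awkward part is then threading the mask through the call, e.g.
# outputs = model.apply({'params': params}, inputs, owg_mask=owg_mask)
```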
Ideally, I would like to compute the `owg_mask` inside the layer (similar to line 226) by accessing the weight hparams. I've seen a similar example with `bf16_accum_in_fp32` here, although it doesn't need any weight hparams.
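Concretely, what I have in mind is something like the sketch below, where the shape of `weight_hparams_tree` and the `'overwrite_with_gradient'` collection tag are assumptions about what might be visible inside the layer; whether such a tree is actually reachable there is exactly my question:

```python
import jax
from types import SimpleNamespace

def owg_mask_from_weight_hparams(weight_hparams_tree):
  # Assumption: each leaf is a WeightHParams-like object exposing the
  # collections it belongs to, and OWG params carry an
  # "overwrite with gradient" collection tag.
  def is_owg(whp):
    return 'overwrite_with_gradient' in getattr(whp, 'collections', ())
  return jax.tree_util.tree_map(
      is_owg, weight_hparams_tree,
      is_leaf=lambda x: hasattr(x, 'collections'))

# Toy usage with stand-in hparams objects.
whps = {'w': SimpleNamespace(collections=()),
        'amax_history': SimpleNamespace(collections=('overwrite_with_gradient',))}
mask = owg_mask_from_weight_hparams(whps)  # {'w': False, 'amax_history': True}
```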
To sum up: what is the best practice for obtaining the `owg_mask` inside `LayerwiseShardablePipelined`, where the weight hparams are not available?
(Note: to run the gist code, you need a recent jax dev build, e.g. 0.4.24.devxxxxx.)
cc. @zhangqiaorjc