f-dangel / singd
[ICML 2024] SINGD: KFAC-like Structured Inverse-Free Natural Gradient Descent (http://arxiv.org/abs/2312.05705)
Home Page: https://singd.readthedocs.io/en/latest/
At the moment, all methods fall back to `.to_dense` and `.from_dense`.
We originally incorporated this function to support updates based on `K.T @ (a @ a.T) @ K` and `K.T @ (g @ g.T) @ K`, e.g. in the private ASDL implementation.
But now we are using an update that computes `(K.T @ a) @ (K.T @ a).T` and `(K.T @ g) @ (K.T @ g).T` instead via `StructuredMatrix.from_inner`. Hence, we can remove `from_inner2` and the associated tests.
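The identity behind this change can be checked numerically. The following is a minimal sketch that uses dense `torch` tensors as a stand-in for the structured matrix `K`; the shapes and variable names are illustrative assumptions and not taken from the code base.

```python
import torch

torch.manual_seed(0)

dim, batch_size = 4, 3
# Dense stand-in for a structured matrix (assumption for illustration).
K = torch.rand(dim, dim, dtype=torch.float64)
# One column per datum, e.g. layer inputs.
a = torch.rand(dim, batch_size, dtype=torch.float64)

# Old-style update: sandwich the outer product a @ a.T between K.T and K.
sandwich = K.T @ (a @ a.T) @ K

# Current update: form K.T @ a first, then take its self-outer product.
KT_a = K.T @ a
self_outer = KT_a @ KT_a.T

same = torch.allclose(sandwich, self_outer)
print(same)
```

Both expressions agree up to floating-point error, but the second one never forms the `dim x dim` matrix `a @ a.T` explicitly.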
Once KFAC is part of a `curvlinops` release, we can try to remove all KFAC-related computations from this repository and use `curvlinops` instead. Currently, this is not trivially possible, but I think it would be really valuable to move to this solution at some point, since it allows us to maintain only one KFAC implementation that other methods and projects can use.
There is the expand and the reduce setting, which is determined by the number of loss terms. If there are `NR` loss terms, we are in the expand setting, and if there are `N` loss terms, we are in the reduce setting (`N` is the number of examples and `R` is the size of the weight-sharing dimension, e.g. the sequence length).
Then there are the KFAC-expand and KFAC-reduce approximations. Both approximations can be applied in both settings.
Currently, we only have one argument to the `SINGD` optimizer called `kfac_approx`, allowing the user to specify the KFAC approximation that should be used. There is no way to specify the setting that is used, i.e. the number of loss terms. This is not an issue in almost all cases, except one: When `batch_averaged=True`, the preconditioner is rescaled such that its scaling is consistent with that of the mini-batch gradient. Hence, we need to know over how many terms the loss was averaged to apply the correct scaling. Since the setting cannot be deduced from the KFAC approximation, additional information is necessary here. See these comments for more information on this issue.
In the current code, we assume that the setting corresponds to the KFAC approximation that is used. This does not hold in many cases, e.g. in the vision transformer image classification experiment with KFAC-expand, which is an example of the reduce setting.
Add a second argument allowing the user to explicitly specify the setting, i.e. the number of loss terms. We have to find a way that is not too confusing to people unfamiliar with the expand and reduce terminology.
It might make sense to only use `"expand"` and `"reduce"` for the KFAC approximation and not the setting. Instead, we could call the second argument something like `loss_type` and possible values could be something like `"batch"` (reduce setting) and `"batch+sequence"` (expand setting). Alternatively, `n_loss_terms` with options `"batch"` and `"batch*sequence"`.
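To illustrate how the second proposed argument could resolve the scaling ambiguity, here is a minimal sketch. The helper `num_loss_terms` is hypothetical and only mirrors the `n_loss_terms` options suggested above; it is not part of the current `SINGD` interface.

```python
def num_loss_terms(batch_size: int, sequence_length: int, n_loss_terms: str) -> int:
    """Return the number of terms the loss was averaged over (hypothetical helper)."""
    if n_loss_terms == "batch":
        # Reduce setting: one loss term per example (N terms).
        return batch_size
    if n_loss_terms == "batch*sequence":
        # Expand setting: one loss term per example and shared position (N * R terms).
        return batch_size * sequence_length
    raise ValueError(f"Unknown n_loss_terms: {n_loss_terms!r}")


N, R = 8, 16  # number of examples, size of the weight-sharing dimension
terms_reduce = num_loss_terms(N, R, "batch")
terms_expand = num_loss_terms(N, R, "batch*sequence")
print(terms_reduce, terms_expand)
```

With such a value available, the rescaling for `batch_averaged=True` would no longer have to be inferred from `kfac_approx`.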
We have started using `_tensors_to_sync` in many other places that have nothing to do with data-parallel training.
It would be better if there was a mechanism to register tensors (similar to `register_parameter` for `nn.Module`s):
- `register_tensor(self, t: Tensor, name: str)`: Register a tensor as component of a structured matrix.
- `tensors(self) -> Iterable[Tensor]`: Yield all tensor constituents of a structured matrix.
- `named_tensors(self) -> Iterable[Tuple[str, Tensor]]`: Yield names and tensor constituents of a structured matrix.
Then, we can remove `_tensors_to_sync` and use `tensors()` to handle accumulation, but also to implement in-place addition & multiplication in the base class.
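A minimal sketch of how such a registration mechanism could look is given below. The classes `StructuredMatrixSketch` and `DiagonalSketch` and the tensor name `"_mat_diag"` are hypothetical and chosen for illustration; only the three method signatures follow the proposal.

```python
from typing import Dict, Iterator, Tuple

import torch
from torch import Tensor


class StructuredMatrixSketch:
    """Toy base class illustrating the proposed tensor registration."""

    def __init__(self) -> None:
        self._tensors: Dict[str, Tensor] = {}

    def register_tensor(self, t: Tensor, name: str) -> None:
        """Register a tensor as component of the structured matrix."""
        if name in self._tensors:
            raise ValueError(f"Tensor name {name!r} is already registered.")
        self._tensors[name] = t

    def named_tensors(self) -> Iterator[Tuple[str, Tensor]]:
        """Yield names and tensor constituents."""
        yield from self._tensors.items()

    def tensors(self) -> Iterator[Tensor]:
        """Yield all tensor constituents."""
        yield from self._tensors.values()

    def mul_(self, factor: float) -> "StructuredMatrixSketch":
        """In-place scalar multiplication, implemented once in the base class."""
        for t in self.tensors():
            t.mul_(factor)
        return self


class DiagonalSketch(StructuredMatrixSketch):
    """Diagonal matrix that stores only its diagonal."""

    def __init__(self, diag: Tensor) -> None:
        super().__init__()
        self.register_tensor(diag, "_mat_diag")


D = DiagonalSketch(torch.ones(3)).mul_(2.0)
names = [name for name, _ in D.named_tensors()]
print(names, list(D.tensors()))
```

The base-class `mul_` is only valid for structures that are linear in their registered tensors, which is an assumption of this sketch.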
This additional scaling trick is crucial to improve stability in `bfloat16`, and should make the method as stable as @yorkerlin's hacky version.
For example:
The PyTorch convention is to divide the loss by the number of iterations (number of micro-batches) that are accumulated into a mini-batch gradient, and by the number of processes:
loss /= iters_to_accumulate * num_procs
Our example on how to use gradient accumulation, however, does not apply this division and is therefore inconsistent. To fix this, we have to rewrite the accumulator `update` function.
For the documentation, we should mention that it is crucial to use the PyTorch convention when accumulating over micro-batches.
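The effect of the convention can be seen in a small, self-contained sketch. The model, data, and loss below are illustrative assumptions, and the example uses a single process and equally sized micro-batches; it does not involve `SINGD`.

```python
import torch
from torch.nn import Linear, MSELoss

torch.manual_seed(0)
model = Linear(3, 1).double()
loss_func = MSELoss()  # default reduction is "mean"
X = torch.rand(8, 3, dtype=torch.float64)
y = torch.rand(8, 1, dtype=torch.float64)

# Reference: gradient of the loss on the full mini-batch.
model.zero_grad()
loss_func(model(X), y).backward()
reference = model.weight.grad.clone()

# Accumulate over micro-batches, dividing each loss by their number.
iters_to_accumulate = 4
num_procs = 1  # single process in this sketch
model.zero_grad()
for X_micro, y_micro in zip(X.chunk(iters_to_accumulate), y.chunk(iters_to_accumulate)):
    loss = loss_func(model(X_micro), y_micro)
    loss = loss / (iters_to_accumulate * num_procs)
    loss.backward()  # gradients are summed into .grad

accumulated = model.weight.grad.clone()
match = torch.allclose(reference, accumulated)
print(match)
```

Without the division, the accumulated gradient would be `iters_to_accumulate` times larger than the mini-batch gradient.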
(follow-up from #34)
The current version of `_extract_patches` does not support string-valued padding, but only 2-tuples, although strings are allowed to specify the padding in `torch.nn.Conv2d`.
We could use this function from my `einconv` library to compute the left and right paddings given the convolution's hyper-parameters.
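For reference, the following sketch shows how string-valued padding could be translated into explicit left and right paddings along one spatial dimension. The helper `left_right_padding` is hypothetical and is not the `einconv` function mentioned above; it assumes the convention that `"same"` padding adds `dilation * (kernel_size - 1)` entries in total, with any odd remainder placed on the right.

```python
from typing import Tuple, Union


def left_right_padding(
    padding: Union[int, str], kernel_size: int, dilation: int = 1, stride: int = 1
) -> Tuple[int, int]:
    """Return (left, right) padding along one spatial dimension (hypothetical helper)."""
    if isinstance(padding, int):
        # Integer padding is applied symmetrically.
        return padding, padding
    if padding == "valid":
        return 0, 0
    if padding == "same":
        if stride != 1:
            raise ValueError("padding='same' requires stride 1.")
        total = dilation * (kernel_size - 1)
        left = total // 2
        return left, total - left  # odd remainder goes to the right
    raise ValueError(f"Unknown padding: {padding!r}")


print(left_right_padding("same", kernel_size=3))
print(left_right_padding("same", kernel_size=4))
print(left_right_padding("valid", kernel_size=3))
```

For a 2d convolution, the helper would be called once per spatial dimension.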
The current code still has commented code leftovers and should be at least cleaned up a bit before the first release.
Also rename the optimizer from `SNGD` to `SINGD`.
An unpolished version has already been created in `docs/examples/example_04_advanced.py`.
You can look at `docs/examples/example_03_param_groups.py` for inspiration.
The structured matrices should work with any dimension.
For neural networks, this dimension is often >10, but there might be rare cases where the dimension is 1. This should be added to the test cases.
(Maybe not necessary: Testing that everything works for `0 x 0` matrices, because this never happens in practice.)
- `asdl`-related code
- `sparse_ngd.experimental` code
`SINGD` relies on `full_backward_hook`s, which are incompatible with in-place operations (see pytorch/pytorch#61519).
Minimum example to reproduce:
```python
"""SINGD with a model that uses in-place activations."""

from singd.optim.optimizer import SINGD
from torch import rand
from torch.nn import Conv2d, ReLU, Sequential

TRIGGER_BUG = True

model = Sequential(
    Conv2d(1, 1, 3),
    ReLU(inplace=TRIGGER_BUG),
)
optim = SINGD(model)  # install hooks
batch_size = 2
X = rand(batch_size, 1, 5, 5)
output = model(X)
```
Output:
```
  File "...", line 16, in <module>
    output = model(X)
  File "...", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/.../python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/.../python3.9/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/.../python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/.../python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/.../python3.9/site-packages/torch/nn/modules/activation.py", line 101, in forward
    return F.relu(input, inplace=self.inplace)
  File "/.../python3.9/site-packages/torch/nn/functional.py", line 1469, in relu
    result = torch.relu_(input)
RuntimeError: Output 0 of BackwardHookFunctionBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.
```
Potential fix (not verified if working): Replace the `full_backward_hook` with a tensor `backward_hook` on the output of a module. This is the officially recommended solution. The tensor hook can be installed with a forward hook (see this solution).
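The idea can be sketched independently of `SINGD`. In the example below, the forward hook `install_tensor_hook` and the dictionary `grad_outputs` are hypothetical names for illustration; the sketch only shows that a tensor hook installed from a forward hook coexists with an in-place activation, not how the optimizer would use the captured gradient.

```python
import torch
from torch.nn import Conv2d, ReLU, Sequential

grad_outputs = {}  # hypothetical storage for captured gradients, keyed by module


def install_tensor_hook(module, inputs, output):
    """Forward hook that registers a backward hook on the module's output tensor."""

    def save_grad(grad):
        grad_outputs[module] = grad.detach()

    if output.requires_grad:
        output.register_hook(save_grad)


model = Sequential(
    Conv2d(1, 1, 3),
    ReLU(inplace=True),
)
handle = model[0].register_forward_hook(install_tensor_hook)

X = torch.rand(2, 1, 5, 5)
output = model(X)  # the in-place ReLU does not raise with this hook type
output.sum().backward()

grad_shape = tuple(grad_outputs[model[0]].shape)
print(grad_shape)
handle.remove()
```

The tensor hook is registered before the in-place ReLU modifies the convolution's output, so it is attached to the convolution's node in the autograd graph.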
The Kronecker approximation depends on the NN architecture. We should support important GNN layers such as the `GCNConv` layer. A reference Kronecker approximation for `GCNConv` can be found here. @runame proposed and implemented Kronecker approximations for another GNN layer.
During one of our internal discussions, we realized that the code starts to accumulate multiple re-scaling operations which are required to avoid overflow/underflow when using `float16` (e.g. using the average trace #24, splitting the `1 / batch_size` and `grad_scale` when computing `H_K, H_C`). The problem with those tricks is that their implementation is often non-local, i.e. they affect multiple functions and are thus hard to understand. Long-term, the accumulation of such tricks will make the code unusably complex.
In my opinion, we should attempt to keep the optimizer's update step as scaling trick-free as possible, and move such complexity inside the `StructuredMatrix` interface. Also, we will need a strong motivation to implement this, and examples demonstrating the effectiveness of such tricks, as it could double the code base for the structured matrices. One possible scenario is that `float16` is simply too unstable, and we will thus not support it.
Idea 1 (stable structured matrices): We could support a `DenseStableMatrix` class, which is the equivalent of `DenseMatrix` but has special implementations with higher stability (but also cost) in `float16`. This structure could then be easily accessed by the optimizer by adding a `'dense_stable'` option to its supported entries in `structures`. A simple idea that might keep the computation stable might be to treat each `Tensor` `t` of a structured matrix internally as a `(float, Tensor)` tuple `(scale, t_scaled)`, that is a normalized tensor and its scaling factor. For instance, multiplication by a scalar `alpha` will then just correspond to using `(alpha * scale, t_scaled)` internally.
However, to me it is currently unclear if such a simple heuristic will consistently make all operations stable.
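As a rough illustration of the `(scale, t_scaled)` idea, consider the following sketch. The class `ScaledTensor` is hypothetical, normalizes by the largest magnitude (one possible choice among many), and only covers scalar multiplication.

```python
import torch
from torch import Tensor


class ScaledTensor:
    """Hypothetical (scale, t_scaled) pair that represents scale * t_scaled."""

    def __init__(self, t: Tensor) -> None:
        # Normalize by the largest magnitude; keep the factor as a Python float.
        max_abs = t.float().abs().max().item()
        self.scale = max_abs if max_abs > 0.0 else 1.0
        self.t_scaled = t / self.scale

    def mul_(self, alpha: float) -> "ScaledTensor":
        """Multiply by a scalar without touching the low-precision tensor."""
        self.scale *= alpha
        return self

    def to_dense(self) -> Tensor:
        """Materialize the represented tensor in the original data type."""
        return self.scale * self.t_scaled


t = torch.tensor([1.0, 2.0, 4.0], dtype=torch.float16)

# Naive: the intermediate result is too small for float16 and loses information.
naive = (t * 1e-8) * 1e8

# Scaled: the small factor is absorbed by the Python float.
stable = ScaledTensor(t).mul_(1e-8).mul_(1e8).to_dense()
print(naive, stable)
```

In this toy case, the scaled representation recovers `t`, while the naive computation does not. Whether the same holds for matrix products and sums is exactly the open question raised above.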
Alternative ideas go here
Related to #53.
When processing the input of a convolution, we unfold it first, then average over channel groups.
Instead, we can first average the channel groups, then unfold. This reduces the unfold op's memory and FLOPS by a factor of `groups`.
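Because unfolding is linear and orders its rows channel-first, both orders give the same result. A minimal sketch with `torch.nn.functional.unfold`, where all shapes and names are illustrative assumptions rather than the package's internals:

```python
import torch
from torch.nn.functional import unfold

torch.manual_seed(0)
N, C_in, H, W = 2, 4, 5, 5
groups, kernel_size = 2, 3
x = torch.rand(N, C_in, H, W, dtype=torch.float64)

# Current order: unfold all channels, then average over channel groups.
patches = unfold(x, kernel_size)  # [N, C_in * k * k, L]
patches = patches.reshape(N, groups, -1, patches.shape[-1])
unfold_then_average = patches.mean(dim=1)  # [N, C_in // groups * k * k, L]

# Proposed order: average over channel groups, then unfold fewer channels.
x_averaged = x.reshape(N, groups, C_in // groups, H, W).mean(dim=1)
average_then_unfold = unfold(x_averaged, kernel_size)

same = torch.allclose(unfold_then_average, average_then_unfold)
print(same, tuple(average_then_unfold.shape))
```

In the second variant, `unfold` only processes `C_in // groups` channels.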
A placeholder has already been created in `docs/examples/example_01_basic.py`.
You can look at `docs/examples/example_03_param_groups.py` for inspiration.
We can thus deprecate the functions `supported_conv1d`, `supported_matmul`, and `supported_einsum`.
The current documentation has many placeholders that need to be filled (e.g. installation instructions, link to documentation/examples, ...).
- `bibtex` block of README
After #26, the KFAC utilities will have solid tests and we can apply one more round of more aggressive refactoring targeting readability of the code.
As an example, I think we should extensively rely on `einops` and get rid of all the ugly reshapes/transposes/sums/means that require mentally keeping track of what is happening with the shape of a quantity.
At the moment, all methods fall back to `.to_dense` and `.from_dense`.
Most of the structures are easy to explain in equations. But equations are hard to read in source code (docstrings).
Hence, it will be beneficial to add an 'Internal' section to the documentation where all structured matrices and templates to create new ones are contained and the math in their docstrings is nicely rendered.
We only support a specific form of DDP. The easiest way to explain it to a user is an example. However, it will be challenging to execute such an example in the CI as it requires access to a multi-GPU machine.
(follow-up from #34)
The current version of `_extract_patches` does not take into account the `dilation` of a convolution. Supporting this is not urgent as I am not aware of any `torchvision` CNN that uses dilated convolution. But we should fix this long-term because the code might silently run without errors.
We might be able to implement `_extract_patches` more efficiently for some special cases, e.g.
Dividing the diagonal elements before averaging them is numerically more stable in low precision. At the moment, we multiply the entire structured matrix (also the off-diagonal elements) with `1 / dim`, then call `.trace()`. This wastes computation and memory.
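The two ways of computing the average trace can be compared on a dense tensor. This is a minimal sketch that uses a plain `torch` matrix in `float32` as a stand-in for a structured matrix; the variable names are illustrative.

```python
import torch

dim = 4
mat = torch.full((dim, dim), 3.0)

# Current approach: scale all dim * dim entries, then take the trace.
average_trace_full = (mat * (1 / dim)).trace()

# Cheaper: only touch the dim diagonal entries, dividing before summing.
average_trace_diag = (mat.diag() / dim).sum()

print(average_trace_full, average_trace_diag)
```

The second variant touches `dim` instead of `dim * dim` entries and does not allocate a scaled copy of the full matrix.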
We should support specifying the type of KFAC approximation (expand or reduce) to be used per param group, like, e.g., we do for the sparsity structure.
I am currently thinking of incorporating the example files via `mkdocs-gallery`. The benefit will be that they will remain `.py` files that we can still run linters on and execute in the test suite.
I parameterized the test that compares with the Lin 2023 implementation with reductions `'mean'` and `'sum'`. For `reduction='sum'`, our implementation differs from Lin 2023. See the `bug-reduction-sum` branch to reproduce this bug.
At the moment, I haven't done any debugging to narrow down the cause of this bug.
Related to #53.
For KFAC-reduce, we unfold the input, then average over the output locations. This is not necessary. One way to do the averaging on the fly is via `einconv`, which is already a dependency of this package. See https://arxiv.org/abs/2307.02275 for details.
There is already a template in `docs/examples/example_02_unsupported_parameters.py`. You can take a look at `docs/examples/example_03_param_groups.py` for inspiration.
There should be a section in the documentation that visualizes the available structures. I've already written a visual test which creates `.gif` animations. So we only have to figure out how to add them to the documentation.
@yorkerlin pointed out a bug in the update of IKFAC here.
Fix:
Replace
second_term = C_tC
with
second_term = C_tC * damping
After merging #63, the accumulation of `H_K, H_C` will be done with a tensor hook that gets installed onto the output of a layer. We also have a second hook which saves the input to a layer into a `.inputs` dictionary which acts like a global variable.
Proposal: Instead of saving the input to a layer into `SINGD.layers`, directly pass it to the tensor hook that computes `H_K, H_C`. This
- … `SINGD.inputs`)
- … `.step` matches the update frequency) from 2 to 1