f-dangel / singd

[ICML 2024] SINGD: KFAC-like Structured Inverse-Free Natural Gradient Descent (http://arxiv.org/abs/2312.05705)

Home Page: https://singd.readthedocs.io/en/latest/

Languages: Python 99.06%, Makefile 0.87%, Shell 0.07%

singd's People

Contributors

f-dangel, runame, wiseodd, yorkerlin


singd's Issues

Deprecate `.from_inner2` from `StructuredMatrix` interface

We originally incorporated this function to support updates of the form K.T @ (a @ a.T) @ K and K.T @ (g @ g.T) @ K, as used, e.g., in the private ASDL implementation.

But now we use an update that instead computes (K.T @ a) @ (K.T @ a).T and (K.T @ g) @ (K.T @ g).T via StructuredMatrix.from_inner. Hence, we can remove from_inner2 and the associated tests.
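
The algebraic identity behind this replacement is easy to check numerically. The dense torch snippet below is purely illustrative and does not use the StructuredMatrix interface itself:

"""Sanity check of the identity that lets `from_inner` replace `from_inner2`."""
import torch

torch.manual_seed(0)
K = torch.rand(4, 4)
a = torch.rand(4, 3)

via_from_inner2 = K.T @ (a @ a.T) @ K      # update style supported by `from_inner2`
via_from_inner = (K.T @ a) @ (K.T @ a).T   # update style used via `from_inner`

assert torch.allclose(via_from_inner2, via_from_inner, atol=1e-6)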

Outsource all KFAC computations to `curvlinops`

Once KFAC is part of a curvlinops release, we can try to remove all KFAC-related computations from this repository and use curvlinops instead. Currently this is not trivially possible, but it would be really valuable to move to this solution at some point, since it would allow us to maintain only one KFAC implementation that other methods and projects can use.

Distinguish between expand and reduce setting and KFAC approximation

Issue

There is the expand and the reduce setting, which is determined by the number of loss terms. If there are N * R loss terms, we are in the expand setting; if there are N loss terms, we are in the reduce setting (N is the number of examples and R is the size of the weight-sharing dimension, e.g. the sequence length).

Then there are the KFAC-expand and KFAC-reduce approximations. Both approximations can be applied in both settings.

Currently, we only have one argument to the SINGD optimizer called kfac_approx, allowing the user to specify the KFAC approximation that should be used. There is no way to specify the setting that is used, i.e. the number of loss terms. This is fine in all but one case: when batch_averaged=True, the preconditioner is rescaled such that the scaling is consistent with the one of the mini-batch gradient. Hence, we need to know over how many terms the loss was averaged to apply the correct scaling. Since the setting cannot be deduced from the KFAC approximation, additional information is necessary here. See these comments for more information on this issue.

In the current code, we assume that the setting corresponds to the KFAC approximation that is used. This does not hold in many cases, e.g. in the vision transformer image classification experiment with KFAC-expand, which is an example of the reduce setting.

Potential solution

Add a second argument allowing the user to explicitly specify the setting, i.e. the number of loss terms. We have to find a way that is not too confusing to people unfamiliar with the expand and reduce terminology.

It might make sense to only use "expand" and "reduce" for the KFAC approximation and not the setting. Instead, we could call the second argument something like loss_type and possible values could be something like "batch" (reduce setting) and "batch+sequence" (expand setting). Alternatively, n_loss_terms with options "batch" and "batch*sequence".
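
To ground the terminology, here is a hedged, self-contained illustration of how the number of loss terms differs between the two settings for a sequence model; the proposed loss_type/n_loss_terms argument does not exist yet, and the snippet only illustrates what it would distinguish:

"""Expand vs. reduce setting, illustrated by the number of loss terms."""
import torch
from torch.nn.functional import cross_entropy

N, R, C = 4, 10, 5                         # batch size, sequence length, classes
logits = torch.rand(N, R, C)
targets = torch.randint(C, (N, R))

# Expand setting: one loss term per (example, position) pair -> N * R terms.
loss_expand = cross_entropy(logits.flatten(0, 1), targets.flatten())

# Reduce setting: per-position losses are first reduced over the sequence -> N terms.
per_term = cross_entropy(logits.flatten(0, 1), targets.flatten(), reduction="none")
loss_reduce = per_term.reshape(N, R).sum(dim=1).mean()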

[REF] Register mechanism for tensors of structured matrices

We have started using _tensors_to_sync in many other places that have nothing to do with data-parallel training.

It would be better if there was a mechanism to register tensors (similar to register_parameter for nn.Modules):

  • register_tensor(self, t: Tensor, name: str): Register a tensor as component of a structured matrix.
  • tensors(self) -> Iterable[Tensor]: Yield all tensor constituents of a structured matrix.
  • (Optional) named_tensors(self) -> Iterable[Tuple[str, Tensor]]: Yield names and tensor constituents of a structured matrix.

Then, we can remove _tensors_to_sync and use tensors() to handle accumulation, but also to implement in-place addition & multiplication in the base class.
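
A minimal sketch of what this could look like; the class and method bodies below mirror the proposal above and are illustrative, not existing SINGD code:

"""Sketch of the proposed register mechanism (names and API are the proposal, not existing code)."""
from typing import Dict, Iterable, Tuple

from torch import Tensor, rand


class StructuredMatrixSketch:
    """Minimal base class illustrating `register_tensor` and `tensors`."""

    def __init__(self) -> None:
        self._registered: Dict[str, Tensor] = {}

    def register_tensor(self, t: Tensor, name: str) -> None:
        """Register a tensor as a constituent of the structured matrix."""
        self._registered[name] = t
        setattr(self, name, t)

    def tensors(self) -> Iterable[Tensor]:
        """Yield all tensor constituents."""
        yield from self._registered.values()

    def named_tensors(self) -> Iterable[Tuple[str, Tensor]]:
        """Yield (name, tensor) pairs, analogous to `named_parameters`."""
        yield from self._registered.items()

    def mul_(self, alpha: float) -> "StructuredMatrixSketch":
        """In-place scalar multiplication implemented once in the base class."""
        for t in self.tensors():
            t.mul_(alpha)
        return self


# Usage: a diagonal structure registering its single constituent.
diag = StructuredMatrixSketch()
diag.register_tensor(rand(5), "_mat_diag")
diag.mul_(0.5)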

Accumulators inconsistent with PyTorch convention

The PyTorch convention is to divide the loss by the number of iterations (number of micro-batches) that are accumulated into a mini-batch gradient, and by the number of processes:

loss /= iters_to_accumulate * num_procs

Instead, our example on how to use gradient accumulation does not apply this division and is therefore inconsistent. To fix this, we have to rewrite the accumulator update function.

For the documentation, we should mention that it is crucial to use the PyTorch convention when accumulating over micro-batches.
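
For the documentation, a hedged sketch of the convention could look as follows (plain SGD is used as a stand-in here; with SINGD, this scaling is exactly what the issue is about):

"""Gradient accumulation with the PyTorch loss-scaling convention (illustrative sketch)."""
import torch
from torch.nn import Linear, MSELoss

model = Linear(4, 1)
loss_func = MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # SINGD would be used the same way

iters_to_accumulate = 4
num_procs = 1  # world size in the distributed case

optimizer.zero_grad()
for i in range(iters_to_accumulate):
    X, y = torch.rand(8, 4), torch.rand(8, 1)  # one micro-batch
    loss = loss_func(model(X), y)
    # Scale so the accumulated gradient matches the mini-batch gradient.
    loss = loss / (iters_to_accumulate * num_procs)
    loss.backward()
optimizer.step()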

Support convolutions with string-valued padding

(follow-up from #34)

The current version of _extract_patches only supports 2-tuples and does not support string-valued padding, although torch.nn.Conv2d also allows strings ("same", "valid") to specify the padding.

We could use this function from my einconv library to compute the left and right paddings given the convolution's hyper-parameters.
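
For illustration, a hand-rolled resolution of the string values (not einconv's actual helper, whose signature may differ) could look like this; it should match PyTorch's convention for stride-1 "same" convolutions:

"""Sketch: resolve string-valued Conv2d padding to explicit (left, right) amounts."""
from typing import Tuple


def resolve_padding(padding: str, kernel_size: int, dilation: int = 1) -> Tuple[int, int]:
    """Return (left, right) padding for one spatial dimension.

    Mirrors PyTorch's convention: 'same' keeps the spatial size for stride 1,
    'valid' applies no padding.
    """
    if padding == "valid":
        return (0, 0)
    if padding == "same":
        total = dilation * (kernel_size - 1)
        left = total // 2
        return (left, total - left)
    raise ValueError(f"Unknown string padding: {padding!r}")


print(resolve_padding("same", kernel_size=3))  # (1, 1)
print(resolve_padding("same", kernel_size=4))  # (1, 2)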

Test edge cases of structured matrices

The structured matrices should work with any dimension.

For neural networks, this dimension is often >10, but there might be rare cases where the dimension is 1. This should be added to the test cases.

(Maybe not necessary: testing that everything works for 0 x 0 matrices, because this never happens in practice.)
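
A parametrized test along these lines could cover the small-dimension cases; the from_dense/to_dense round trip and the import path below are assumptions about the interface and may need to be adapted:

"""Hedged sketch of an edge-case test for very small dimensions."""
import pytest
import torch


@pytest.mark.parametrize("dim", [1, 2, 10])
def test_round_trip_small_dims(dim: int):
    """Structures should also work for dim=1, not only the typical dim > 10."""
    from singd.structures.dense import DenseMatrix  # assumed import path

    sym = torch.rand(dim, dim)
    sym = sym @ sym.T  # symmetric input, as expected for curvature matrices

    structured = DenseMatrix.from_dense(sym)  # assumed constructor
    assert torch.allclose(structured.to_dense(), sym)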

Clean up post-pruned git repo

  • Remove pruned folders from linter and git config files
  • Remove asdl-related code
  • Remove submodules that are not required
  • Remove untested and undocumented sparse_ngd.experimental code
  • Make CI work

Bug: Full backward hook incompatible with in-place activations

SINGD relies on full_backward_hooks, which are incompatible with in-place operations (see pytorch/pytorch#61519).

Minimum example to reproduce:

"""SINGD with a model that uses in-place activations."""
from singd.optim.optimizer import SINGD
from torch import rand
from torch.nn import Conv2d, ReLU, Sequential

TRIGGER_BUG = True

model = Sequential(
    Conv2d(1, 1, 3),
    ReLU(inplace=TRIGGER_BUG),
)
optim = SINGD(model)  # install hooks

batch_size = 2
X = rand(batch_size, 1, 5, 5)

output = model(X)

Output:

  File "...", line 16, in <module>
    output = model(X)
  File "...", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/.../python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/.../python3.9/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/.../python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/.../python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/.../python3.9/site-packages/torch/nn/modules/activation.py", line 101, in forward
    return F.relu(input, inplace=self.inplace)
  File "/.../python3.9/site-packages/torch/nn/functional.py", line 1469, in relu
    result = torch.relu_(input)
RuntimeError: Output 0 of BackwardHookFunctionBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.

Potential fix (not verified): replace the full_backward_hook with a tensor backward hook on the output of a module. This is the officially recommended solution. The tensor hook can be installed with a forward hook (see this solution).
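
A sketch of this direction follows; the function names are illustrative, not SINGD's internals, and, as noted above, whether the hooked gradients are the intended ones still needs verification:

"""Sketch of the proposed fix: a tensor hook on the module output, installed from a forward hook."""
from torch import Tensor, rand
from torch.nn import Conv2d, Module, ReLU, Sequential


def accumulate_curvature(module: Module, grad_output: Tensor) -> None:
    """Placeholder for the computation of H_K, H_C from the output gradient."""
    print(f"{module}: got grad_output of shape {tuple(grad_output.shape)}")


def install_tensor_hook(module: Module, inputs, output: Tensor) -> None:
    """Forward hook that attaches a backward hook to the output tensor."""
    output.register_hook(lambda grad_output: accumulate_curvature(module, grad_output))


model = Sequential(Conv2d(1, 1, 3), ReLU(inplace=True))
handles = [m.register_forward_hook(install_tensor_hook) for m in model]

loss = model(rand(2, 1, 5, 5)).sum()
loss.backward()  # runs without the error above; correctness of the hooks still to be verified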

Kronecker approximations for GNNs

The Kronecker approximation depends on the NN architecture. We should support important GNN layers such as the GCNConv layer. A reference Kronecker approximation for the GCNConv can be found here. @runame proposed and implemented Kronecker approximations for another GNN layer.

`float16` & where to hide complexity from scaling tricks

During one of our internal discussions, we realized that the code is starting to accumulate multiple re-scaling operations that are required to avoid overflow/underflow when using float16 (e.g. using the average trace #24, splitting the 1 / batch_size and grad_scale when computing H_K, H_C). The problem with those tricks is that their implementation is often non-local, i.e. they affect multiple functions and are thus hard to understand. Long-term, the accumulation of such tricks will make the code unusably complex.

In my opinion, we should attempt to keep the optimizer's update step as scaling trick-free as possible, and move such complexity inside the StructuredMatrix interface. Also, we will need a strong motivation to implement this, and examples demonstrating the effectiveness of such tricks, as it could double the code base for the structured matrices. One possible scenario is that float16 is simply too unstable, and we will thus not support it.

  • Idea 1 (stable structured matrices): We could support a DenseStableMatrix class, which is the equivalent of DenseMatrix but has special implementations with higher stability (but also cost) in float16. This structure could then be easily accessed by the optimizer by adding a 'dense_stable' option to its supported entries in structures. A simple idea that might keep the computation stable is to treat each Tensor t of a structured matrix internally as a (float, Tensor) tuple (scale, t_scaled), i.e. a normalized tensor and its scaling factor. For instance, multiplication by a scalar alpha then just corresponds to using (alpha * scale, t_scaled) internally (see the sketch after this list).

    However, to me it is currently unclear if such a simple heuristic will consistently make all operations stable.

  • Alternative ideas go here
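
A hedged sketch of the (scale, t_scaled) idea; ScaledTensor and its methods are illustrative only and not part of the StructuredMatrix interface:

"""Store each constituent as `scale * t_scaled` so scalar multiplications only touch the scale."""
from dataclasses import dataclass

import torch
from torch import Tensor


@dataclass
class ScaledTensor:
    """A tensor represented as `scale * t_scaled` with `t_scaled` normalized."""

    scale: float
    t_scaled: Tensor

    @classmethod
    def from_tensor(cls, t: Tensor) -> "ScaledTensor":
        scale = t.abs().max().clamp(min=torch.finfo(t.dtype).tiny).item()
        return cls(scale, t / scale)

    def mul(self, alpha: float) -> "ScaledTensor":
        """Scalar multiplication without touching the (float16) tensor entries."""
        return ScaledTensor(self.scale * alpha, self.t_scaled)

    def to_tensor(self) -> Tensor:
        return self.scale * self.t_scaled


# A value close to the float16 maximum (~65504) that would overflow if scaled directly.
t = ScaledTensor.from_tensor(torch.full((2, 2), 60000.0, dtype=torch.float16))
print(t.mul(4.0).scale)  # scaling handled as a Python float, no float16 overflow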

Improve `_extract_patches` for grouped convolutions

Related to #53.

When processing the input of a convolution, we unfold it first, then average over channel groups.
Instead, we can first average the channel groups, then unfold. This reduces the unfold op's memory and FLOPs by a factor of groups.
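
A hedged sketch of the reordering: averaging over channel groups commutes with unfolding, so the mean can be taken first on the smaller tensor. The exact averaging semantics of _extract_patches may differ; this only illustrates the reordering argument:

"""Averaging over channel groups before vs. after unfolding gives the same result."""
import torch
from torch.nn.functional import unfold

N, groups, C_per_group, H, W = 2, 4, 3, 8, 8
kernel_size = 3
x = torch.rand(N, groups * C_per_group, H, W)

# Current order: unfold the full input, then average over groups.
patches = unfold(x, kernel_size)  # (N, C * k * k, L)
patches = patches.reshape(N, groups, C_per_group * kernel_size**2, -1).mean(dim=1)

# Proposed order: average over groups first, then unfold the smaller tensor.
x_avg = x.reshape(N, groups, C_per_group, H, W).mean(dim=1)
patches_cheap = unfold(x_avg, kernel_size)

assert torch.allclose(patches, patches_cheap, atol=1e-6)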

Polish README and API documentation

The current documentation has many placeholders that need to be filled (e.g. installation instructions, link to documentation/examples, ...).

  • Fix typo in paper title in BibTeX block of README

Refactor KFAC utils w.r.t. readability

After #26, the KFAC utilities will have solid tests and we can apply one more round of more aggressive refactoring targeting readability of the code.

As an example, I think we should extensively rely on einops and get rid of all the ugly reshapes/transposes/sums/means that require mentally keeping track of what is happening with the shape of a quantity.
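
To make the intent concrete, compare a raw reshape/mean with the equivalent einops expressions (the shapes here are made up for illustration):

"""Example of the readability gain from einops over raw reshape/mean calls."""
import torch
from einops import rearrange, reduce

N, R, D = 2, 5, 12                  # batch, weight-sharing dim, feature dim
a = torch.rand(N, R, D)

# Raw PyTorch: the reader has to track shapes mentally.
flat = a.reshape(N * R, D)
pooled = a.mean(dim=1)

# einops: the same operations, with the shapes spelled out in the pattern.
flat_einops = rearrange(a, "n r d -> (n r) d")
pooled_einops = reduce(a, "n r d -> n d", "mean")

assert torch.equal(flat, flat_einops)
assert torch.allclose(pooled, pooled_einops)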

Add internal section in the docs for structured matrices and templates

Most of the structures are easy to explain in equations, but equations are hard to read in source code (docstrings).
Hence, it would be beneficial to add an 'Internal' section to the documentation that contains all structured matrices and the templates to create new ones, with the math in their docstrings nicely rendered.

Add example for DDP training

We only support a specific form of DDP. The easiest way to explain this to a user is with an example. However, it will be challenging to execute such an example in the CI as it requires access to a multi-GPU machine.

Support dilated convolutions

(follow-up from #34)

The current version of _extract_patches does not take into account the dilation of a convolution. Supporting this is not urgent as I am not aware of any torchvision CNN that uses dilated convolutions. But we should fix this long-term because the code might run without errors while silently producing wrong results.

Replace `StructuredMatrix.trace` with `.average_trace`

Dividing the diagonal elements by the dimension before summing them is numerically more stable in low precision. At the moment, we multiply the entire structured matrix (including the off-diagonal elements) by 1 / dim, then call .trace(). This wastes computation and memory.
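
The numerical argument in a nutshell (the float16 maximum is about 65504); this generic snippet only illustrates the point and does not use the structured matrix classes:

"""Why dividing the diagonal by `dim` before summing is preferable in float16."""
import torch

dim = 1024
mat = torch.full((dim, dim), 64.0, dtype=torch.float16)

# Naive trace-then-divide: the intermediate sum 1024 * 64 = 65536 overflows float16.
naive = mat.diagonal().sum() / dim

# Current approach: scale the whole matrix by 1 / dim, then take the trace.
# Stable, but touches all dim^2 entries.
current = (mat / dim).diagonal().sum()

# Proposed `average_trace`: divide only the diagonal, then sum. Stable and cheap.
proposed = (mat.diagonal() / dim).sum()

print(naive.item(), current.item(), proposed.item())  # inf 64.0 64.0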

Support KFAC-expand and KFAC-reduce

We should support specifying the type of KFAC approximation (expand or reduce) per parameter group, as we already do, e.g., for the sparsity structure.
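
To recall the difference between the two approximations, here is a small numerical illustration of the input-based Kronecker factor (following the usual KFAC-expand/reduce definitions, with simplified scaling conventions):

"""KFAC-expand vs. KFAC-reduce for the input-based Kronecker factor (illustrative)."""
import torch

N, R, D = 4, 10, 6                 # batch size, weight-sharing dim, layer input dim
a = torch.rand(N, R, D)            # layer inputs with a weight-sharing dimension

# KFAC-expand: treat every (example, position) pair as an independent input.
A_expand = torch.einsum("nrd,nre->de", a, a) / (N * R)

# KFAC-reduce: average the inputs over the weight-sharing dimension first.
a_reduced = a.mean(dim=1)
A_reduce = torch.einsum("nd,ne->de", a_reduced, a_reduced) / N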

Add examples to documentation

I am currently thinking of incorporating the example files via mkdocs-gallery. The benefit is that they remain .py files that we can still run linters on and execute in the test suite.

Add animations of available structures in documentation

There should be a section in the documentation that visualizes the available structures. I've already written a visual test which creates .gif animations. So we only have to figure out how to add them to the documentation.

Merge hooks that save layer inputs and compute `H_K`, `H_C`

After merging #63, the accumulation of H_K, H_C will be done with a tensor hook that gets installed onto the output of a layer. We also have a second hook that saves the input to a layer into an .inputs dictionary, which acts like a global variable.

Proposal: Instead of saving the input to a layer into SINGD.inputs, directly pass it to the tensor hook that computes H_K, H_C (see the sketch below). This

  1. Eliminates a 'global' variable (SINGD.inputs)
  2. Reduces the number of hooks (and related boilerplate code like checking if .step matches the update frequency) from 2 to 1
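
A sketch of the proposed merged hook (function names are illustrative, not SINGD's internals): the layer input lives in a closure instead of a global dictionary, and a single forward hook installs the tensor hook on the output.

"""Sketch: one forward hook that captures the input and installs the output tensor hook."""
from torch import Tensor, rand
from torch.nn import Linear, Module, Sequential


def update_factors(module: Module, layer_input: Tensor, grad_output: Tensor) -> None:
    """Placeholder for accumulating H_K (from the input) and H_C (from grad_output)."""
    print(f"{module}: input {tuple(layer_input.shape)}, grad_output {tuple(grad_output.shape)}")


def forward_hook(module: Module, inputs, output: Tensor) -> None:
    """Single hook: no global `.inputs` dictionary, the input lives in the closure."""
    layer_input = inputs[0].detach()
    output.register_hook(lambda grad_output: update_factors(module, layer_input, grad_output))


model = Sequential(Linear(4, 3), Linear(3, 2))
handles = [m.register_forward_hook(forward_hook) for m in model]

model(rand(5, 4)).sum().backward()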
