toshas / torch-householder

Efficient Householder Transformation in PyTorch

License: Other

Shell 0.54% Python 88.86% C++ 10.60%
householder-product householder-reflectors householder-transformation lapack linalg orgqr orthogonality parameterization pytorch

torch-householder's People

Contributors

toshas

torch-householder's Issues

Slow performance compared to `torch.linalg.householder_product` in forward pass

Problem

I'm using an orthonormally constrained matrix in a learnable filterbank setting. I want to optimize the training, so I ran some profiling with torch, but I'm getting strange results. I just want to double-check here whether I'm doing something wrong.

Code

I'm constructing the matrix during forward pass like this:

def __init__(self, ..):
        [..]

        # Householder decomposition of the initial filters
        decomp, tau = torch.geqrf(filters)

        # assume that DCT is orthogonal
        # keep the reflectors below the diagonal and put ones on the diagonal
        filters = decomp.tril(diagonal=-1) + torch.eye(decomp.shape[0], decomp.shape[1])

        # register the reflectors as a learnable parameter named 'filter_q';
        # the name differs from the filters() method below to avoid a clash
        self.filter_q = torch.nn.Parameter(filters, requires_grad=True)

def filters(self):
        # only the strictly lower-triangular part holds free coefficients
        valid_coeffs = self.filter_q.tril(diagonal=-1)
        # tau = 2 / ||v||^2, where each v has an implicit 1 on the diagonal
        tau = 2. / (1 + torch.norm(valid_coeffs, dim=0) ** 2)
        return torch.linalg.householder_product(valid_coeffs, tau)
        #return torch_householder_orgqr(valid_coeffs, tau)

Profiles

All profiles were created with the PyTorch profiler, using one warm-up run and two trial runs:

Profile torch householder_product (matrix 512x512 f32)

  • forward pass: ~823us
  • backward pass: ~790ms

The forward and backward passes are marked in light green:

image

Profile torch-householder (matrix 512x512 f32)

  • forward pass: ~240ms
  • backward pass: ~513ms

image

Questions

I'm not an expert in torch and do not follow its development closely. There is an issue, pytorch/pytorch#50104, about adding CUDA support to orgqr. Could this cause the difference in time?

  • Why is the torch-householder library so much slower in the forward pass?
  • Is this performance expected from AD of a matrix w.r.t. its Householder reflectors, or am I doing something wrong here?
  • Why do the numbers add up to roughly 800 ms in both cases? This makes me suspect that my profiling is doing something wrong, but I couldn't find a cause.

I'm also happy to share the traces with you, just ping me :)
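For anyone who wants to reproduce the measurement outside the profiler, here is a minimal timing sketch. It assumes the `torch.linalg.householder_product` path from the issue, runs on the CPU, and uses a small matrix; the names `build_q` and `time_forward_backward` are hypothetical helpers, not part of either library. It is not a benchmark of torch-householder itself.

```python
import time

import torch


def build_q(coeffs: torch.Tensor) -> torch.Tensor:
    """Rebuild the orthogonal matrix from reflector coefficients, as in filters()."""
    valid = coeffs.tril(diagonal=-1)
    # tau_i = 2 / ||v_i||^2, where v_i carries an implicit 1 on the diagonal
    tau = 2.0 / (1.0 + torch.linalg.norm(valid, dim=0) ** 2)
    return torch.linalg.householder_product(valid, tau)


def time_forward_backward(n: int = 64, repeats: int = 3):
    """Return (best forward seconds, best backward seconds, last Q) on the CPU."""
    coeffs = torch.randn(n, n, requires_grad=True)
    build_q(coeffs).sum().backward()  # warm-up run, not timed
    fwd_times, bwd_times = [], []
    for _ in range(repeats):
        coeffs.grad = None
        t0 = time.perf_counter()
        q = build_q(coeffs)
        t1 = time.perf_counter()
        q.sum().backward()
        t2 = time.perf_counter()
        fwd_times.append(t1 - t0)
        bwd_times.append(t2 - t1)
    return min(fwd_times), min(bwd_times), q.detach()


fwd_s, bwd_s, q = time_forward_backward()
print(f"forward: {fwd_s * 1e3:.3f} ms, backward: {bwd_s * 1e3:.3f} ms")
```

On a GPU, kernels are launched asynchronously, so a wall-clock measurement like this would need a `torch.cuda.synchronize()` call before each timestamp. Without it, time can be attributed to the wrong phase.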

[POLL] Should the package switch install-time to run-time native code compilation?

Currently, the package compiles native code (C++) upon package installation. This saves a few seconds at run time, as no compilation happens when the user code starts. However, it hurts when the package is installed from a different environment or machine than the one the code will actually run on. This is the case in most cluster environments, where packages may be installed from a login node rather than from the machine with the GPU.

Should compilation be performed at run time instead?

๐Ÿ‘ - Move compilation to run time
๐Ÿ‘Ž - Keep as is

> \prod_{i=1}^k H(v_i)

\prod_{i=1}^k H(v_i), is that right?
Right.
Hi toshas, your work is great, thanks!
You confirmed that, with H(v) = I_n - v v^T / \norm{v}^2 denoting a Householder reflection, torch_householder_orgqr() computes \prod_{i=1}^k H(v_i). I have two questions. First, I think the reflection should be H(v) = I_n - 2 v v^T / \norm{v}^2. Second, does \prod_{i=1}^k H(v_i) mean H_1 H_2 ... H_(n-1) or H_(n-1) H_(n-2) ... H_1? If it is H_1 H_2 ... H_(n-1), it corresponds to Q in the QR decomposition, right? But you said earlier that this is not a QR decomposition, so I am a little confused. I hope to get a prompt reply, thank you.
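Both questions can be checked numerically against PyTorch's built-in `torch.linalg.householder_product`. The sketch below assumes the reflection H(v) = I_n - 2 v v^T / \norm{v}^2 (with the factor 2) and the LAPACK reflector layout with a unit diagonal. The helper `householder` is hypothetical, and the check says nothing about the internals of torch_householder_orgqr().

```python
import torch


def householder(v: torch.Tensor) -> torch.Tensor:
    """Dense reflection H(v) = I - 2 v v^T / ||v||^2."""
    n = v.shape[0]
    return torch.eye(n, dtype=v.dtype) - 2.0 * torch.outer(v, v) / v.dot(v)


torch.manual_seed(0)
n, k = 5, 3
dtype = torch.float64

# Reflectors in LAPACK layout: unit diagonal, zeros above, free entries below.
V = torch.randn(n, k, dtype=dtype).tril(diagonal=-1) + torch.eye(n, k, dtype=dtype)
tau = 2.0 / (V * V).sum(dim=0)

# Explicit left-to-right product H_1 H_2 ... H_k.
Q_explicit = torch.eye(n, dtype=dtype)
for i in range(k):
    Q_explicit = Q_explicit @ householder(V[:, i])

# PyTorch returns the first k columns of that product.
Q_torch = torch.linalg.householder_product(V, tau)

print(torch.allclose(Q_torch, Q_explicit[:, :k]))
```

With tau_i = 2 / \norm{v_i}^2, the built-in routine agrees with the left-to-right product H_1 H_2 ... H_k, and the resulting columns are orthonormal.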

The parametrisation does not seem to be surjective.

When having a look at the implementation and looking at the differences with torch.householder_product, I found a weird thing.

At the moment, when called with a tensor of the form hh = param.tril(diagonal=-1) + torch.eye(d, r) (as per the documentation) we are passing nk - k(k+1)/2 parameters. This is correct, as it is the number of parameters necessary to parametrise the orthogonal matrices (i.e. it is the dimension of the Stiefel manifold).

Now, when this matrix is passed to torch_householder_orgqr, its columns are normalised:

param_normalized = param / torch.linalg.norm(param, dim=param.dim()-2, keepdim=True).clamp(min=eps)

Now, this is another constraint. Each column is normalised, which removes another k degrees of freedom (k dimensions to be precise). This means that the current implementation cannot represent all the possible orthogonal matrices.

Do you know what is going on here?
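The parameter count above, and one property that may be relevant to the question, can be checked in a few lines. This sketch assumes the reflection is H(v) = I - 2 v v^T / \norm{v}^2, under which H(v) depends only on the direction of v. The helper names are hypothetical, and the sketch does not settle whether the parametrisation is surjective.

```python
def stiefel_dim(n: int, k: int) -> int:
    """Dimension of the Stiefel manifold of n x k matrices with orthonormal columns."""
    return n * k - k * (k + 1) // 2


def strictly_lower_count(n: int, k: int) -> int:
    """Number of entries strictly below the diagonal of an n x k matrix (n >= k)."""
    return sum(n - 1 - j for j in range(k))


def householder(v):
    """Dense H(v) = I - 2 v v^T / ||v||^2 as a list of rows."""
    n = len(v)
    norm_sq = sum(x * x for x in v)
    return [
        [(1.0 if i == j else 0.0) - 2.0 * v[i] * v[j] / norm_sq for j in range(n)]
        for i in range(n)
    ]


# The free entries of param.tril(diagonal=-1) match the Stiefel dimension.
print(stiefel_dim(5, 3), strictly_lower_count(5, 3))

# Rescaling v leaves the reflection unchanged, up to rounding.
v = [1.0, 0.5, -2.0]
scaled = [3.0 * x for x in v]
max_diff = max(
    abs(a - b)
    for row_a, row_b in zip(householder(v), householder(scaled))
    for a, b in zip(row_a, row_b)
)
print(max_diff)
```

Under this definition, dividing a column by its norm does not change the reflection it generates, so the normalisation step alone need not remove degrees of freedom.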

Support of complex tensors

I need support for complex datatype to learn a unitary matrix. How much effort would it be to add that?
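As a point of comparison, PyTorch's built-in `torch.linalg.householder_product` accepts complex inputs. The sketch below uses that function, not torch-householder, and assumes reflectors with an implicit unit diagonal and a real-valued tau_i = 2 / \norm{v_i}^2. It only shows that the result is unitary; whether this construction reaches every unitary matrix is not addressed.

```python
import torch

torch.manual_seed(0)
n = 4
dtype = torch.complex128

# Free complex coefficients strictly below the diagonal; the unit diagonal is implicit.
V = torch.randn(n, n, dtype=dtype).tril(diagonal=-1)

# tau_i = 2 / ||v_i||^2, cast to the input dtype as householder_product expects.
norm_sq = 1.0 + (V.abs() ** 2).sum(dim=0)
tau = (2.0 / norm_sq).to(dtype)

# Each factor is H_i = I - tau_i v_i v_i^H, so the product U is unitary.
U = torch.linalg.householder_product(V, tau)

identity = torch.eye(n, dtype=dtype)
print(torch.allclose(U.conj().T @ U, identity, atol=1e-10))
```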

error when installing torch-householder via pip

Hello,
I am having trouble installing the package. I don't know if there is a dependency problem. Could the installation fail because of the CUDA version? I am using Ubuntu 18.04 and CUDA 11.3. Is there any difference from installing with CUDA 10.2?

Collecting torch-householder
  Using cached torch_householder-1.0.1.tar.gz (457 kB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... error
  ERROR: Command errored out with exit status 1:
   command: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmp_p_1kd61
       cwd: /tmp/pip-install-ot451iqp/torch-householder_bee491246ef740cbb618d42e92070ff3
  Complete output (21 lines):
  Traceback (most recent call last):
    File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 363, in <module>
      main()
    File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 345, in main
      json_out['return_val'] = hook(**hook_input['kwargs'])
    File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 130, in get_requires_for_build_wheel
      return hook(config_settings)
    File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 338, in get_requires_for_build_wheel
      return self._get_build_requires(config_settings, requirements=['wheel'])
    File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 320, in _get_build_requires
      self.run_setup()
    File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 335, in run_setup
      exec(code, locals())
    File "<string>", line 4, in <module>
    File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/torch/__init__.py", line 191, in <module>
      _load_global_deps()
    File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/torch/__init__.py", line 153, in _load_global_deps
      ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
    File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/ctypes/__init__.py", line 374, in __init__
      self._handle = _dlopen(self._name, mode)
  OSError: /tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/torch/lib/../../nvidia/cublas/lib/libcublas.so.11: symbol cublasLtGetStatusString version libcublasLt.so.11 not defined in file libcublasLt.so.11 with link time reference
  ----------------------------------------
WARNING: Discarding https://files.pythonhosted.org/packages/f3/7d/a87d4ea6c11f23d237fc81c094a6c18909486fdb9914599479cbeb5d089f/torch_householder-1.0.1.tar.gz#sha256=9a4b240c68947491c4e96a78771497562650f9a555001e062a0969fce206f786 (from https://pypi.org/simple/torch-householder/) (requires-python:>=3.6). Command errored out with exit status 1: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmp_p_1kd61 Check the logs for full command output.
  Using cached torch_householder-1.0.0.tar.gz (177 kB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... error
  ERROR: Command errored out with exit status 1:
   command: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmpiqcotimp
       cwd: /tmp/pip-install-ot451iqp/torch-householder_bb93639781d44e6799b67c9f4f83fae9
  Complete output (21 lines):
  Traceback (most recent call last):
    File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 363, in <module>
      main()
    File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 345, in main
      json_out['return_val'] = hook(**hook_input['kwargs'])
    File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 130, in get_requires_for_build_wheel
      return hook(config_settings)
    File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 338, in get_requires_for_build_wheel
      return self._get_build_requires(config_settings, requirements=['wheel'])
    File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 320, in _get_build_requires
      self.run_setup()
    File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 335, in run_setup
      exec(code, locals())
    File "<string>", line 4, in <module>
    File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/torch/__init__.py", line 191, in <module>
      _load_global_deps()
    File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/torch/__init__.py", line 153, in _load_global_deps
      ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
    File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/ctypes/__init__.py", line 374, in __init__
      self._handle = _dlopen(self._name, mode)
  OSError: /tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/torch/lib/../../nvidia/cublas/lib/libcublas.so.11: symbol cublasLtGetStatusString version libcublasLt.so.11 not defined in file libcublasLt.so.11 with link time reference
  ----------------------------------------
WARNING: Discarding https://files.pythonhosted.org/packages/c2/9c/7af0f1414e24c09ddc67439364c70c7e894c65ed83f94aa82a0ff3308673/torch_householder-1.0.0.tar.gz#sha256=e9f06c29685a6bcbc360af5c8cbae9227d9990e3e9acc744b0ea15654ab782c0 (from https://pypi.org/simple/torch-householder/) (requires-python:>=3.6). Command errored out with exit status 1: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmpiqcotimp Check the logs for full command output.
ERROR: Could not find a version that satisfies the requirement torch-householder (from versions: 1.0.0, 1.0.1)
ERROR: No matching distribution found for torch-householder

How are the gradients implemented for non-full rank matrices?

It is not clear how to implement the "gradient" (adjoint) of the QR decomposition for a matrix that is not full rank (e.g. the zero matrix). How does this package handle this?

Also, if this is a drop-in replacement for the QR decomposition implemented in PyTorch and it works better, why not make a PR to core PyTorch with it? Where does the speed-up over LAPACK / MAGMA come from?
