

subclass zoo

This repository contains a number of examples of Tensor subclasses in PyTorch, specifically using __torch_dispatch__ to integrate deeply into PyTorch's existing subsystems (there is also some use of modes). We're still working out good APIs for working with Tensor subclasses, and this repository is here to tell you what we've figured out so far! To run these examples, you will want a recent nightly of PyTorch.
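If you're new to the mechanism, here is a minimal sketch of the has-a "wrapper tensor" pattern that many of these examples elaborate on (our illustration, not one of the repo's files): unwrap the inputs, run the real op, and rewrap the outputs.

import torch
from torch.utils._pytree import tree_map

class WrapperSketch(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # A "wrapper subclass": the instance is a shell whose real data
        # lives in self.elem.
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device
        )
        r.elem = elem
        return r

    # Disable torch_function so only torch_dispatch intercepts ops.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            return t.elem if isinstance(t, cls) else t

        def wrap(t):
            return cls(t) if isinstance(t, torch.Tensor) else t

        # Run the real aten op on plain tensors, then rewrap the outputs.
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(wrap, out)

y = WrapperSketch(torch.ones(2)) + 1
print(type(y), y.elem)  # <class '__main__.WrapperSketch'> tensor([2., 2.])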

Here's what's in the repo so far:

  • inner_autograd_tensor.py shows how to override autograd from __torch_dispatch__, by deferring autograd to the inner tensor on a subclass.
  • negative_tensor.py is a reimplementation of negative tensor views as implemented in PyTorch core (pytorch/pytorch#56058).
  • python_meta_tensor.py is a demonstration of how to extend an existing tensor (meta tensor) with some extra behavior (in this case, implementations of meta functions for operations that don't support it natively).
  • sparse_output.py
  • tracer_tensor.py
  • trivial_tensors.py compares two ways to "wrap" tensors, one using inheritance (is-a) and one using composition (has-a; the latter are so-called wrapper tensors)
  • verifier_tensor.py

There are also some utility files:

  • base_tensor.py contains a common superclass that most of our tensors inherit from, that fixes up some problems with directly inheriting from torch.Tensor. We intend to upstream these changes so that this superclass is not necessary.
  • utils.py contains some handy utility functions that we found ourselves repeatedly using in our implementations.

We're still working on the APIs in question, so sometimes there will be bugs. bug_zoo.py contains repros for known bugs we're tracking in PyTorch proper.

TODO

  • CUDA sanitizer in Python (hard because there are no event hooks)
  • Sparse gradients / outputs per Christian (using modes; gradients are hard because they need torch function mode)
  • SSD tensor
  • Reimplement functionalization tensor
  • Nested tensor
  • Custom allocator mode (albanD)
  • Lazy tensor
  • Immutable tensor
  • Various ways of writing FX passes https://gist.github.com/1c640ea30fd7451b08e90e34461459c1

Work plan

  • TODO: merge BaseTensor into Tensor

  • Get rid of fill_defaults

  • Compositionality

    • TODO: suppress elem in init

Developer notes

  • This repo is formatted with ufmt and autoflake. Use ./format.sh to reformat all the files in this repository.


subclass_zoo's Issues

Simple functorch: vmap should handle case where in_dims is None

I was talking with @samdow this morning and it would be cool for us to update simple functorch (hopefully in a simple way) to get its implementation of vmap to handle cases where only some of the tensors are being vmapped over. This would let us experiment with what exactly the batching rule of custom_vjp should look like.
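For reference, the full functorch vmap already expresses "this tensor is not being vmapped over" through its in_dims argument; a minimal sketch of the target semantics (torch.func.vmap shown; functorch.vmap on older releases):

import torch
from torch.func import vmap  # functorch.vmap on older releases

def f(x, y):
    return x * y

x = torch.randn(3, 4)  # batched along dim 0
y = torch.randn(4)     # not batched: its in_dims entry is None

out = vmap(f, in_dims=(0, None))(x, y)  # y is broadcast across the batch
print(out.shape)  # torch.Size([3, 4])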

Simple functorch: cannot nest vmaps

In simple functorch, because the batch dimension is always expected to be the 0th dimension, we can't do

vmap(vmap(foo))(torch.randn(3, 4)) (I've removed the label and FuncTensor wrappers here for readability; the full example is at the end)

where the intention is to have an outer vmap with a batch size of 3 and a nested vmap with a batch size of 4. We can't even get the nested vmap to have a batch size of 4, because it sees the shape of the tensor as (3, 4), so its batch size is also 3. This means that when the first layer of vmap calls self.inner.size(input), we get (4,) back, because the inner layer eliminates the batch dimension (3), and the misunderstood batch dimension then causes an error.

In general, from talking with @zou3519, if we wanted to support this case, we would need to use a wrapper (see autodidax) or a weak map; a sketch of the wrapper idea follows the repro below. However, we might find that introducing these ruins some of the simplicity of simple_functorch (producing not_so_simple_functorch??)

Full repro:

def simpler(a):
    return a.unsqueeze(0)

a = FuncTensor(label(torch.randn(3, 4)), DISPATCHER)
vmap(vmap(simpler))(a)  # fails on: assert self.inner.size(input)[0] == self.length
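For reference, a sketch (names hypothetical, not from the discussion) of the per-level wrapper idea: each vmap level tags its tensors with a level id, so a nested vmap can tell its own batch dimension apart from an outer level's.

import torch

class BatchedTensor:
    def __init__(self, tensor, level, bdim=0):
        self.tensor = tensor  # underlying data, batch dim still present
        self.level = level    # which vmap invocation owns the batch dim
        self.bdim = bdim      # index of that batch dim in `tensor`

    def logical_size(self):
        # The shape user code at this level should observe: batch dim removed.
        s = list(self.tensor.size())
        del s[self.bdim]
        return torch.Size(s)

# vmap(vmap(f)) over a (3, 4) tensor would nest two wrappers; the inner
# vmap sees logical size (4,) and tags its own dim with its own level,
# instead of mistaking the outer batch dim (3) for its own.
inner = BatchedTensor(torch.randn(3, 4), level=1, bdim=0)
print(inner.logical_size())  # torch.Size([4])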

Save and restore TLS has to also reset torch_function disabled state

pytorch/pytorch#73942 is blocked because torch function gets disabled in the default implementation, and then if you fall through to the Python key it is still disabled, which breaks FX tracing. So we need to somehow re-enable torch function by the time we get to the Python key (it can't be saved via the TLS snapshot key, because by the time we get to the dispatcher it's already disabled). Alban suggests just unconditionally turning it back on when we go to the Python key.

Fix composite compliance

Aka, pytorch/pytorch#69991

Motivation

We should make sure all PyTorch operations are "Composite Compliant". This condition is necessary for operators that look out-of-place to preserve the Tensor subclass when they are passed a Tensor subclass.

Consider the following hypothetical out-of-place operation:

def my_add(x, y):
  result = x.clone()  # result has x's type, not y's
  result.add_(y)      # an in-place op can't change result's type
  return result

You may expect this to work the same as torch.add. However, if x is not a Tensor subclass but y is, then my_add returns a regular Tensor, NOT a Tensor subclass!
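To make this concrete, here is a small hypothetical demonstration (not from the issue; MyTensor is just a trivial illustrative subclass, and my_add is repeated so the snippet runs on its own):

import torch

class MyTensor(torch.Tensor):
    pass

def my_add(x, y):
    result = x.clone()  # result has x's type, not y's
    result.add_(y)      # an in-place op can't change result's type
    return result

x = torch.ones(3)
y = torch.ones(3).as_subclass(MyTensor)

print(type(torch.add(x, y)))  # <class '__main__.MyTensor'>: subclass preserved
print(type(my_add(x, y)))     # <class 'torch.Tensor'>: subclass silently dropped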

`__torch_dispatch__` vs `__torch_function__`

Hi, thanks for this amazing repo!

I have a question regarding the usage of __torch_dispatch__ and __torch_function__ in PyTorch. While this repository appears to focus on __torch_dispatch__, I haven't been able to find much information on it other than blog posts and discussions on GitHub. On the other hand, PyTorch's documentation only mentions __torch_function__.

To my understanding [0][1], __torch_dispatch__ is used for low-level adjustments, whereas __torch_function__ serves as syntactic sugar. However, the function signatures of both methods are nearly identical, making it unclear to practitioners unfamiliar with PyTorch's inner workings which one to use.

Could someone please explain the practical differences between these methods? Specifically, when should one use __torch_dispatch__ versus __torch_function__, and what benefits and limitations do each approach offer?

Thanks in advance for your time!

[0] https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557
[1] pytorch/pytorch#85826 (comment)
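For reference, a minimal sketch (assuming a recent PyTorch, where torch.Tensor provides default implementations of both hooks) that makes the layering visible: __torch_function__ fires first, at the public API level above autograd, while __torch_dispatch__ fires below autograd with the resolved aten overload.

import torch

class LayerDemo(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Fires at the Python API level, before autograd.
        print("torch_function:", func)
        return super().__torch_function__(func, types, args, kwargs or {})

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Fires below autograd, with the resolved aten overload.
        print("torch_dispatch:", func)
        return super().__torch_dispatch__(func, types, args, kwargs or {})

x = torch.ones(2).as_subclass(LayerDemo)
x + 1
# Prints something like:
#   torch_function: <slot wrapper '__add__' ...>
#   torch_dispatch: aten.add.Tensor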

Question: apply trainable scale for qdq `linear` and `matmul`

In a quantization scenario where fake quantization is used to assess the accuracy of a new algorithm with a trainable scale, we can implement it for an eager-mode model by replacing the Linear module with QDQLinear, as demonstrated below:

import torch

class QDQLinear(torch.nn.Module):
    def __init__(self, orign_linear: torch.nn.Module) -> None:
        super().__init__()
        self._orign_linear = orign_linear  # was: original_tensor (undefined name)
        # float tensor: integer tensors cannot require grad
        self.trainable_scale = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def qdq_tensor(self, input: torch.Tensor):
        # ... new qdq method that uses `self.trainable_scale` to update the q-dq tensor:
        # int_input = q(input)
        # qdq_input = dq(int_input)
        # return qdq_input
        pass

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = self.qdq_tensor(input)
        return torch.nn.functional.linear(input, self._orign_linear.weight, self._orign_linear.bias)


# ... then replace all `Linear` modules with `QDQLinear` in the model

However, some models use torch.matmul to do something similar to torch.nn.Linear. We also want to apply the aforementioned QDQ method to torch.matmul, but this cannot be achieved through module swapping.

We could probably write a custom TorchDispatchMode that replaces every aten.mm with qdq + aten.mm, to apply qdq to all input tensors of torch.matmul or torch.nn.Linear. However, I'm currently unsure how to handle the trainable_scale. Do you happen to have any suggestions? (One possible direction is sketched below.)

Thank you very much!
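One possible direction, sketched under assumptions rather than taken from this thread: a TorchFunctionMode instead of a TorchDispatchMode. Because it runs above autograd, the q-dq ops stay in the autograd graph and gradients can flow back to the trainable scale (below autograd, in a TorchDispatchMode, keeping the scale trainable is exactly the awkward part). qdq_tensor here is a hypothetical stand-in for a real fake-quant function.

import torch
from torch.overrides import TorchFunctionMode

def qdq_tensor(t, scale):
    # Hypothetical differentiable quantize-dequantize; a real version would
    # clamp to the integer range and use a straight-through estimator.
    return (t / scale).round() * scale

class QDQMatmulMode(TorchFunctionMode):
    def __init__(self, scale):
        super().__init__()
        self.scale = scale  # e.g. a torch.nn.Parameter owned by the training loop

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in (torch.matmul, torch.nn.functional.linear):
            # Fake-quantize every tensor input of matmul/linear.
            args = tuple(
                qdq_tensor(a, self.scale) if isinstance(a, torch.Tensor) else a
                for a in args
            )
        return func(*args, **kwargs)

scale = torch.nn.Parameter(torch.tensor(1.0))
with QDQMatmulMode(scale):
    out = torch.matmul(torch.randn(2, 3), torch.randn(3, 4))
out.sum().backward()
print(scale.grad is not None)  # True: the scale participates in autograd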

AttributeError: 'super' object has no attribute '__torch_dispatch__'

Hi, thanks for this great repo on subclassing torch.Tensor.
I get the following error when I try to run many of the scripts in the repo.

  • PyTorch version: 1.11.0+cu113

  • functorch version: 0.1.0

  • On Ubuntu 20.04, with Python 3.8.10

    Traceback (most recent call last):
      File "/home/jeff/workspace/subclass_zoo/python_meta_tensor.py", line 396, in test_embedding_via_mode
        embedding = torch.nn.Embedding(10, 3, device="meta")
      File "/home/jeff/miniconda3/envs/torch-nightly/lib/python3.9/site-packages/torch/nn/modules/sparse.py", line 139, in __init__
        self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
      File "/home/jeff/workspace/subclass_zoo/python_meta_tensor.py", line 91, in __torch_dispatch__
        return super().__torch_dispatch__(func, types, args, kwargs)
    AttributeError: 'super' object has no attribute '__torch_dispatch__'

Compositional modes

Currently only one mode is allowed for enable_python_mode. We should be able to specify multiple modes.

`torch.compile`'ed function inside `__torch_dispatch__`

Thanks again for the help last time!

Lately, we've been experimenting with calling a torch.compile'd function inside __torch_dispatch__ for performance reasons. However, I noticed that the compiled function does not work well with __torch_dispatch__. Here is the code snippet:

import torch
from torch._ops import OpOverload
from torch.utils import _pytree as pytree
from torch.testing._internal import common_subclass
from typing import Tuple, List, Dict, Any, Type

class ContainerTensor(common_subclass.WrapperTensor):

    @classmethod
    def get_wrapper_properties(
        cls,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        return tensor, {}

    def __init__(self, tensor: torch.Tensor) -> None:
        self._tensor = tensor

    # We disable torch function here to avoid any unwanted wrapping of the output
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(
        cls,
        func: OpOverload,
        types: List[Type],
        args: List[Any],
        kwargs: Dict[str, Any],
    ) -> torch.Tensor:

        @torch.compile
        def f(t):
            return t + 1

        # Unwrap and apply the function
        args_materialized = pytree.tree_map_only(
            ContainerTensor,
            lambda tensor: f(tensor._tensor),
            args)
        return func(*args_materialized, **kwargs)

Then the following code will fail:

A = torch.randn(3, device="cuda")
A_ = ContainerTensor(A)
A_ + 1

I'm wondering if there are better ways to call a compiled function in such a scenario? Thanks in advance for your time again!

How to retain the grad via __torch_dispatch__ for a torch.Tensor method

I have a question, which might be very simple, but I have no idea how to fix it.

I am trying to subclass torch.Tensor, and I want to retain the grad of the original torch.Tensor methods.

Here is my code:

import torch
from torch.utils._pytree import tree_map

class MyTensor(torch.Tensor):

    @staticmethod
    def __new__(cls, tensor):
        return torch.Tensor.as_subclass(tensor, cls)

    def __init__(self, tensor):
        self.tensor = tensor

    __torch_function__ = torch._C._disabled_torch_function_impl

    def __repr__(self):
        return self.__class__.__name__ +':\n'+ self.tensor.__repr__()

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            return t.tensor if isinstance(t, cls) else t

        def wrap(t):
            return cls(t) if isinstance(t, torch.Tensor) and not isinstance(t, cls) else t
        
        return tree_map(wrap, super().__torch_dispatch__(func, types, args, kwargs))
    
    def my_method(self):
        return self.tensor.exp()
Here is the result:

>>> x = MyTensor(torch.randn(3, requires_grad=True))
>>> x
MyTensor:
tensor([1.4196, 2.0849, 1.2102], requires_grad=True)

The original method doesn't retain grad:

>>> x.exp()
MyTensor:
tensor([4.1355, 8.0442, 3.3543])

The newly defined method retains grad:

>>> x.my_method()
tensor([4.1355, 8.0442, 3.3543], grad_fn=<ExpBackward0>)

If I use __torch_function__, it can retain the grad. How can I retain the grad using __torch_dispatch__?
Thank you so much!

Current "inheritance trivial tensor" doesn't seem to really work with in-place ops.

import contextlib

import torch
from torch.utils._pytree import tree_map


@contextlib.contextmanager
def no_dispatch():
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard


class BaseTensor(torch.Tensor):
    # See https://github.com/pytorch/pytorch/pull/73727 ; this is necessary
    # to ensure that super().__new__ calls can cooperate with each other
    @staticmethod
    def __new__(cls, elem, *, requires_grad=None):
        if requires_grad is None:
            return super().__new__(cls, elem)
        else:
            return cls._make_subclass(cls, elem, requires_grad)

    def __init__(self, elem):
        super().__init__()
    __torch_function__ = torch._C._disabled_torch_function_impl


class ISATensor(BaseTensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def wrap(e):
            if isinstance(e, torch.Tensor) and not isinstance(e, cls):
                return cls(e)
            else:
                return e

        with no_dispatch():
            print("\n-----------{}-----------".format(func))
            #print(args)
            result = tree_map(wrap, super().__torch_dispatch__(func, types, args, kwargs))
        return result

inp = torch.randn(3, requires_grad=True)

ISATensor(inp).sin().relu_()

This usage of the inheritance-based wrapper tensor fails with "RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation".

However, modifying it to this (which more closely mimics how the PythonTensor works) works.

class BaseTensor(torch.Tensor):
    # See https://github.com/pytorch/pytorch/pull/73727 ; this is necessary
    # to ensure that super().__new__ calls can cooperate with each other
    @staticmethod
    def __new__(cls, elem, *, requires_grad=None):
        if requires_grad is None:
            result = cls._make_subclass(cls, elem)
        else:
            result = cls._make_subclass(cls, elem, requires_grad)
        return result

    def __init__(self, elem):
        super().__init__()
    __torch_function__ = torch._C._disabled_torch_function_impl

Any thoughts on what's the right way to do things?

cc: @ezyang @albanD
