ENH: Symbolic Vectorization (pytensor, 8 comments, closed)

pymc-devs commented on May 27, 2024

Comments (8)

ferrine commented on May 27, 2024

Here is a draft implementation:

import functools
from typing import List, Optional, Tuple

import pytensor.graph.fg
import pytensor.tensor as pt
import pytensor.utils
from pytensor.tensor.elemwise import DimShuffle, Elemwise

@functools.singledispatch
def vectorize_symbolic(op, inputs: List[pt.TensorVariable], dims: List[Optional[int]]) -> Tuple[List[pt.TensorVariable], List[Optional[int]]]:
    # Fallback: no rule is registered for this Op, so batching is impossible.
    if any(d is not None for d in dims):
        raise RuntimeError(f"vectorization was not implemented for {op}")
    return no_vectorize(op, inputs)

def no_vectorize(op, inputs: List[pt.TensorVariable]) -> Tuple[List[pt.TensorVariable], List[Optional[int]]]:
    # Apply the op as-is: none of its outputs carry a batch axis.
    outs = pytensor.utils.flatten(op(*inputs))
    return outs, [None] * len(outs)

def batch_axis_at(tens, batch_axis, at=0):
    """Move (or create) the batch axis of ``tens`` at position ``at``."""
    pattern = list(range(tens.ndim))
    if batch_axis is None:
        pattern.insert(at, "x")  # add a broadcastable batch axis
    elif batch_axis == at:
        return tens
    else:
        pattern.insert(at, pattern.pop(batch_axis))
    return tens.dimshuffle(*pattern)

@vectorize_symbolic.register(Elemwise)
def vectorize_elementwise(op, inputs: List[pt.TensorVariable], dims: List[Optional[int]]) -> Tuple[List[pt.TensorVariable], List[Optional[int]]]:
    # Put every batch axis first (adding a broadcastable one where missing)
    # and let Elemwise broadcasting do the rest.
    vec_inputs = [batch_axis_at(v, d) for v, d in zip(inputs, dims)]
    vec_outputs = pytensor.utils.flatten(op(*vec_inputs))
    return vec_outputs, [0] * len(vec_outputs)

@vectorize_symbolic.register(DimShuffle)
def vectorize_dimshuffle(op, inputs: List[pt.TensorVariable], dims: List[Optional[int]]) -> Tuple[List[pt.TensorVariable], List[Optional[int]]]:
    # The batch axis stays first; the remaining axes shift past it.
    def move(i, b):
        # Remap axis i of the unbatched input to its position in the
        # batched input, where the batch axis occupies position b, e.g.
        # [0, 1,  +1> 2, 3, 4]
        # [0, 1, "b", 2, 3, 4]
        if b is None or i == "x":
            return i
        return i if i < b else i + 1

    batch_dim = dims[0]
    order = [batch_dim] + [move(i, batch_dim) for i in op.new_order]
    broadcastable = list(op.input_broadcastable)
    broadcastable.insert(batch_dim, False)
    new_op = DimShuffle(broadcastable, order)
    vec_outputs = pytensor.utils.flatten(new_op(*inputs))
    return vec_outputs, [0] * len(vec_outputs)
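To sanity-check the index remapping, here is a small pure-Python sketch (a standalone copy of `move` with hypothetical example values, not pytensor code) showing how a transpose's `new_order` is rewritten once a batch axis sits at position 0:

```python
def move(i, b):
    # Remap axis i of the unbatched input to its position in the
    # batched input, where the batch axis occupies position b.
    if b is None or i == "x":
        return i
    return i if i < b else i + 1

# Transpose of a matrix has new_order == (1, 0).  With a batch axis at
# position 0, the batched DimShuffle keeps the batch axis first and
# shifts both core axes by one.
batch_dim = 0
new_order = (1, 0)
order = [batch_dim] + [move(i, batch_dim) for i in new_order]
print(order)  # [0, 2, 1]: batch axis stays put, core axes are transposed
```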

You can recreate the vectorized fgraph in a toposort loop:

x = pt.vector()
y = pt.vector()

z = (x + y) ** 2

inputs = [x, y]
outputs = [z.T]
fg = pytensor.graph.fg.FunctionGraph(inputs, outputs)

# Batched replacements for the original inputs: x gains a leading batch
# axis (so it becomes a matrix with batch axis 0), y stays unbatched.
vec_inputs = [pt.matrix(), pt.vector()]
vec_mapping = dict(zip(vec_inputs, [0, None]))  # vectorized var -> batch axis
inputs_map = dict(zip(fg.inputs, vec_inputs))   # original var -> vectorized var


for apply in fg.toposort():
    orig_inputs = apply.inputs
    # Try to get a vectorized input for this operation;
    # otherwise no batching was performed for that input.
    vec_inputs = [inputs_map.get(o, o) for o in orig_inputs]
    vec_patterns = [vec_mapping.get(v, None) for v in vec_inputs]
    if all(d is None for d in vec_patterns):
        vec_outputs, dims = no_vectorize(apply.op, vec_inputs)
    else:
        vec_outputs, dims = vectorize_symbolic(apply.op, vec_inputs, vec_patterns)
    vec_mapping.update(zip(vec_outputs, dims))
    inputs_map.update(zip(apply.outputs, vec_outputs))

vfg = pytensor.graph.fg.FunctionGraph(
    [inputs_map[i] for i in fg.inputs],
    [inputs_map[o] for o in fg.outputs],
)


ferrine commented on May 27, 2024

With this approach, we can be generic about how we rewrite: for blockwise ops we can just move dims around while this core loop stays fixed. The two PRs complement each other, in a sense.


ricardoV94 commented on May 27, 2024

Don't forget to see the work that was already started in aesara-devs/aesara#1215

CC @Sayam753 and @purna135


ferrine commented on May 27, 2024

Seems like that work is quite different from the proposal


ricardoV94 commented on May 27, 2024

> Seems like that work is quite different from the proposal

It sounds like you're talking about a high-level interface, whereas the Blockwise PR tackles one possible implementation at the Op level?

That PR implements symbolic vectorization for individual Ops. What you're talking about would be something like a fusion of multiple vectorized Ops, the same way Composite fuses multiple Elemwise Ops?

In the end, vectorization is just an abstraction for looping with repeated operations over some tensors. If Scan were fast, that wouldn't be a terrible first way to achieve vectorization.
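That equivalence can be illustrated with a toy example in plain Python (no pytensor; the names here are illustrative): applying a core operation in a loop over the batch axis and applying one "vectorized" operation over the stacked batch give the same result.

```python
# Core operation: works on one "vector" (a list of numbers).
def core_op(v):
    return [x ** 2 for x in v]

batch = [[1, 2], [3, 4], [5, 6]]

# Looping over the batch dimension (what Scan would do)...
looped = [core_op(v) for v in batch]

# ...is equivalent to one operation that treats the leading axis as batch.
def vectorized_op(b):
    return [[x ** 2 for x in v] for v in b]

assert vectorized_op(batch) == looped
print(looped)  # [[1, 4], [9, 16], [25, 36]]
```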


Regarding NUTS, you can't really vectorize something like that unless you implement the algorithm in the same language.


ferrine commented on May 27, 2024

I thought about creating a new vectorized fgraph by applying vectorized translations.

An Op would have a dispatch rule:

@_vectorize.register(Op)
def vectorize_node(op: Op, inputs: List[Variable], dims: List[Optional[int]]) -> Tuple[List[Variable], List[Optional[int]]]:
    ...

The result can potentially be multiple additional operations that achieve the broadcasting pattern, or completely different Ops implementing the same thing.


ricardoV94 commented on May 27, 2024

Precisely. Now if you pair this with Blockwise, you expand the coverage of Ops that can be vectorized to nearly 100%.

For now you would be limited to DimShuffles and Elemwises.

That's what I meant by "keep that in mind".


ferrine commented on May 27, 2024

I see
