Comments (8)
Here is a draft implementation:
import functools
from typing import List, Optional, Tuple

import pytensor.graph.fg
import pytensor.tensor as pt
import pytensor.utils
from pytensor.tensor.elemwise import DimShuffle, Elemwise


@functools.singledispatch
def vectorize_symbolic(
    op, inputs: List[pt.Variable], dims: List[Optional[int]]
) -> Tuple[List[pt.Variable], List[Optional[int]]]:
    # Generic fallback: only valid when no input carries a batch axis.
    if any(d is not None for d in dims):
        raise NotImplementedError(f"vectorization is not implemented for {op}")
    return no_vectorize(op, inputs)


def no_vectorize(
    op, inputs: List[pt.Variable]
) -> Tuple[List[pt.Variable], List[Optional[int]]]:
    # Apply the op unchanged; none of the outputs carries a batch axis.
    outs = pytensor.utils.flatten(op(*inputs))
    return outs, [None] * len(outs)


def batch_axis_at(tens, batch_axis, at=0):
    # Move the batch axis of `tens` to position `at`, or insert a new
    # broadcastable axis there if the tensor carries no batch axis.
    pattern = list(range(tens.ndim))
    if batch_axis is None:
        pattern.insert(at, "x")
    elif batch_axis == at:
        return tens
    else:
        pattern.insert(at, pattern.pop(batch_axis))
    return tens.dimshuffle(*pattern)


@vectorize_symbolic.register(Elemwise)
def vectorize_elemwise(
    op, inputs: List[pt.Variable], dims: List[Optional[int]]
) -> Tuple[List[pt.Variable], List[Optional[int]]]:
    # Align every batch axis at position 0 and rely on broadcasting.
    vec_inputs = [batch_axis_at(v, d) for v, d in zip(inputs, dims)]
    vec_outputs = pytensor.utils.flatten(op(*vec_inputs))
    return vec_outputs, [0] * len(vec_outputs)


@vectorize_symbolic.register(DimShuffle)
def vectorize_dimshuffle(
    op, inputs: List[pt.Variable], dims: List[Optional[int]]
) -> Tuple[List[pt.Variable], List[Optional[int]]]:
    # Move the batch axis to the front; the remaining axes keep their
    # relative order but shift by one if they sit at or after the position
    # where the batch axis was inserted, e.g. a batch axis "b" inserted at
    # position 2 maps the old axes [0, 1, 2, 3, 4] to [0, 1, 3, 4, 5].
    def move(i, b):
        if i == "x":
            return "x"
        elif i < b:
            return i
        else:
            return i + 1

    batch_dim = dims[0]
    assert batch_dim is not None  # dispatch only gets here with a batch axis
    order = [batch_dim] + [move(i, batch_dim) for i in op.new_order]
    broadcastable = list(op.input_broadcastable)
    broadcastable.insert(batch_dim, False)
    new_op = DimShuffle(broadcastable, order)
    vec_outputs = pytensor.utils.flatten(new_op(*inputs))
    return vec_outputs, [0] * len(vec_outputs)
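A quick sanity check of the Elemwise rule (a sketch; `xb` plays the role of an input that already carries a batch axis at position 0):

xb = pt.matrix()  # batched version of a vector, batch axis 0
yv = pt.vector()  # unbatched
outs, out_dims = vectorize_symbolic(pt.add, [xb, yv], [0, None])
# outs[0] computes xb + yv via broadcasting; out_dims == [0]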
You can then recreate the vectorized fgraph in a toposort loop:
x = pt.vector()
y = pt.vector()
z = (x + y) ** 2
inputs = [x, y]
outputs = [z.T]
fg = pytensor.graph.fg.FunctionGraph(inputs, outputs)

# Batched replacements for fg.inputs: x becomes a matrix with batch axis 0,
# y stays an unbatched vector.
new_inputs = [pt.matrix(), pt.vector()]
vec_mapping = dict(zip(new_inputs, [0, None]))  # vec variable -> batch axis
inputs_map = dict(zip(fg.inputs, new_inputs))   # original -> vec variable
for apply in fg.toposort():
    orig_inputs = apply.inputs
    # Try to get a vectorized input for this operation;
    # otherwise no batching was performed for that input.
    vec_inputs = [inputs_map.get(o, o) for o in orig_inputs]
    vec_patterns = [vec_mapping.get(v, None) for v in vec_inputs]
    if all(d is None for d in vec_patterns):
        vec_outputs, dims = no_vectorize(apply.op, vec_inputs)
    else:
        vec_outputs, dims = vectorize_symbolic(apply.op, vec_inputs, vec_patterns)
    vec_mapping.update(zip(vec_outputs, dims))
    inputs_map.update(zip(apply.outputs, vec_outputs))
vfg = pytensor.graph.fg.FunctionGraph(
    [inputs_map[i] for i in fg.inputs],
    [inputs_map[o] for o in fg.outputs],
)
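As a quick smoke test (a sketch, not from the original comment), the vectorized graph can be compiled and compared against an explicit Python loop over the batch axis:

import numpy as np

f = pytensor.function(vfg.inputs, vfg.outputs)
xs = np.random.normal(size=(4, 3)).astype(pytensor.config.floatX)
ys = np.random.normal(size=3).astype(pytensor.config.floatX)
batched, = f(xs, ys)
# The same computation, looping over the batch axis of xs.
looped = np.stack([((x_i + ys) ** 2).T for x_i in xs])
np.testing.assert_allclose(batched, looped)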
With this approach, we can be generic about how we rewrite: for Blockwise Ops we can just move dims around, and this core loop stays fixed (see the fallback sketched below). The two PRs complement each other in a sense.
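A hypothetical fallback rule, assuming a `Blockwise(core_op)` wrapper along the lines of aesara-devs/aesara#1215 (the `Blockwise` name and call signature here are this sketch's assumptions, not the PR's actual API):

def vectorize_blockwise_fallback(op, inputs, dims):
    # Assumption: Blockwise(op) loops the core op over a shared leading
    # batch axis, so it suffices to move every batch axis to position 0.
    vec_inputs = [batch_axis_at(v, d) for v, d in zip(inputs, dims)]
    vec_outputs = pytensor.utils.flatten(Blockwise(op)(*vec_inputs))
    return vec_outputs, [0] * len(vec_outputs)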
Don't forget to look at the work that was already started in aesara-devs/aesara#1215.
Seems like that work is quite different from the proposal
It sounds like you're talking about a high-level interface, whereas Blockwise tackles one possible implementation at the Op level?
That PR implements symbolic vectorization for individual Ops. What you're talking about would be something like a fusion of multiple vectorized Ops, the same way Composite fuses multiple Elemwises?
In the end, vectorization is just an abstraction for looping with repeated operations over some tensors. If Scan were fast, that wouldn't be a terrible first way to achieve vectorization; see the sketch below.
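For instance, a minimal sketch of that Scan-based fallback (illustrative, not from either PR):

import pytensor
import pytensor.tensor as pt

xb = pt.matrix("xb")  # batch of vectors, batch axis 0
yv = pt.vector("yv")  # shared across the batch

# Loop the core computation over the batch axis with Scan instead of
# rewriting the graph; slow, but semantically the same as vectorizing.
outs, _ = pytensor.scan(lambda x_i, y: (x_i + y) ** 2,
                        sequences=[xb], non_sequences=[yv])
f = pytensor.function([xb, yv], outs)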
Regarding NUTS, you can't really vectorize something like that unless you implement the algorithm in the same language.
I thought about creating a new vectorized fgraph by applying vectorized translations.
An Op would have a dispatch rule:

@_vectorize.register(Op)
def vectorize_node(op: Op, inputs: List[Variable], dims: List[Optional[int]]) -> Tuple[List[Variable], List[Optional[int]]]:
    ...

The result can potentially be multiple additional operations that achieve the broadcasting pattern, or completely different Ops implementing the same thing (see the reduction sketch below).
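For example, a hedged sketch of such a rule for a reduction, reusing the `vectorize_symbolic` dispatcher and `batch_axis_at` helper from the draft above (the axis arithmetic is this sketch's assumption):

from pytensor.tensor.math import Sum

@vectorize_symbolic.register(Sum)
def vectorize_sum(op, inputs, dims):
    # Move the batch axis to the front, then shift the reduced core axes
    # past it: the result is a different reduction, not a loop.
    (x,), (d,) = inputs, dims
    x = batch_axis_at(x, d)
    core_axes = op.axis if op.axis is not None else tuple(range(x.ndim - 1))
    out = x.sum(axis=tuple(a + 1 for a in core_axes))
    return [out], [0]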
Precisely. Now if you pair this with Blockwise, you expand coverage of the Ops that can be vectorized to nearly 100%.
For now you would be limited to DimShuffles and Elemwises.
That's what I meant by "keep that in mind".
from pytensor.
I see