Comments (3)
I also can't figure out how to make it work in Numba mode; for some reason it isn't applied there.
from pytensor.
Some progress: it seems that in NUMBA mode (and, I assume, JAX as well), the inner fgraph of Scan is not optimized according to the Scan mode. In the default backend the optimization is triggered when the lazily computed property `node.op.fn` is accessed, which usually happens first in `make_thunk`. I imagine this path is never taken by the JIT linkers.
If I access the property manually, the correct graph is obtained:
```python
import aesara
aesara.config.mode = "NUMBA"
import aesara.tensor as at
from aesara.compile.builders import OpFromGraph

x = at.scalar("x")
out = at.log(x)
op = OpFromGraph([x], [out], inline=True)

xs = at.vector("xs")
seq, _ = aesara.scan(
    fn=lambda x: op(x),
    sequences=[xs],
)

seq.owner.op.fn  # Trigger inner fgraph optimization manually!!!
# aesara.config.optimizer_verbose = True
f = aesara.function([xs], seq)
aesara.dprint(f)
```
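To make the mechanism concrete, here is a pure-Python toy model with no aesara/pytensor import. `ToyScan`, `inline_rewrite`, `jit_convert` and `jit_convert_fixed` are hypothetical names, not the library's API. The sketch assumes, consistent with the observation that touching `fn` by hand fixes the NUMBA output, that the lazy rewrite mutates the same inner graph the JIT linker later reads.

```python
from typing import List, Optional


def inline_rewrite(nodes: List[str]) -> None:
    """Hypothetical stand-in for the Scan mode's rewrites: replace each
    'OpFromGraph(<op>)' wrapper with the op it wraps, mutating `nodes` in place."""
    for i, name in enumerate(nodes):
        if name.startswith("OpFromGraph(") and name.endswith(")"):
            nodes[i] = name[len("OpFromGraph("):-1]


class ToyScan:
    """Toy Scan whose inner graph is rewritten only when `fn` is first read."""

    def __init__(self, inner_nodes: List[str]) -> None:
        self.inner_nodes = list(inner_nodes)  # plays the role of the inner fgraph
        self._fn: Optional[List[str]] = None

    @property
    def fn(self) -> List[str]:
        # Lazy and cached: the rewrite runs once, on first access, in place.
        if self._fn is None:
            inline_rewrite(self.inner_nodes)
            self._fn = self.inner_nodes
        return self._fn

    def make_thunk(self) -> List[str]:
        # Default-backend path: goes through `fn`, so the graph gets rewritten.
        return list(self.fn)


def jit_convert(op: ToyScan) -> List[str]:
    # JIT-linker-style path: reads the inner graph directly and never
    # touches `fn`, so the rewrite never runs.
    return list(op.inner_nodes)


def jit_convert_fixed(op: ToyScan) -> List[str]:
    # One possible fix (hypothetical): trigger the lazy rewrite before converting.
    return list(op.fn)


if __name__ == "__main__":
    print(ToyScan(["OpFromGraph(log)"]).make_thunk())        # ['log']
    print(jit_convert(ToyScan(["OpFromGraph(log)"])))        # ['OpFromGraph(log)']
    print(jit_convert_fixed(ToyScan(["OpFromGraph(log)"])))  # ['log']

    manual = ToyScan(["OpFromGraph(log)"])
    manual.fn  # manual trigger, analogous to `seq.owner.op.fn`
    print(jit_convert(manual))                               # ['log']
```

The last case mirrors the workaround in the snippet: the JIT path itself is unchanged, but because the cached property already rewrote the shared inner graph, it ends up converting the optimized version.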
#28 solves this for the NUMBA backend, but not for JAX. I did not try to change it there, since Scan is completely broken for JAX.
Related Issues (20)
- BUG: Differentiating Fourier transform fails when the batch dimension has size 1
- BUG: `log10` gradient introduces `float64`
- Implement matrix_transpose and mT method for TensorVariables
- Drop support for Python 3.9
- Add `name` keyword argument to `Op.__call__`
- Test on numpy 2.0
- Add type hints / documentation for `mode` argument
- Add helper function to find all inputs needed to a compile a graph
- DOC: Installation instructions in the developer start guide outdated
- Move constants inside OpFromGraph and remove unused inputs/outputs
- `pt.specify_broadcastable` does not work with negative axis values
- Add kwarg to call to return next rng from RandomVariables
- Get rid of dummy type input to RVs
- BUG: configdefaults invokes config.cxx even when config.cxx is absent
- Don't fail when squeezing specific non-broadcastable axis
- Avoid dimshuffle if expand_dims has empty axis
- Add pre-commit hook to avoid print statements.
- Include Op in message when raising NotImplementedError from grad
- Upload codecov job is failing systematically
- Reconcile environment for development and docs building