There's something to tease apart here: you can put for-loops in `@jit` functions and they'll compile just fine if the loop bounds don't depend on the values of the arguments to the function (but instead only depend on their shapes, or some other fixed values). The XLA code that gets compiled will have those loops unrolled.
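Here is a minimal sketch of that unrolling behavior, assuming a recent JAX release. The function name `repeat_tanh` is hypothetical. The loop count comes from `x.shape`, so Python runs the loop during tracing. `jax.make_jaxpr` should then show one `tanh` per iteration and no loop primitive.

```python
import jax
import jax.numpy as jnp

def repeat_tanh(x):
    # The loop bound comes from x.shape, which is fixed at trace time,
    # so Python executes this loop while tracing and XLA sees it unrolled.
    for _ in range(x.shape[0]):
        x = jnp.tanh(x)
    return x

fast_repeat_tanh = jax.jit(repeat_tanh)

x = jnp.array([0.5, 1.0, 2.0])
print(fast_repeat_tanh(x))

# Inspect the traced program: expect a straight-line sequence of tanh ops
# (one per iteration) rather than a `while` loop construct.
jaxpr = jax.make_jaxpr(repeat_tanh)(x)
print(jaxpr)
```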
This `lax.while` business is about compiling XLA code that itself has loop constructs in it, rather than unrolled loops. You need that if, for example, the loop exit condition depends on the value of an input argument to the `@jit` function. Even if the loop bounds were static, it can be preferable to generate compiled loop constructs because they might reduce compile times.
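The following sketches the value-dependent case. It assumes the construct referred to here is what current JAX exposes as `jax.lax.while_loop(cond_fun, body_fun, init_val)`. `steps_to_exceed` is a hypothetical example function. The iteration count depends on runtime values, so the compiled program keeps a real loop. Calling it again with a different `threshold` does not require retracing.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def steps_to_exceed(x, threshold):
    # The number of iterations depends on the *values* of x and threshold,
    # so a Python `while` on them would fail under jit; lax.while_loop
    # instead emits a loop construct inside the compiled XLA program.
    def cond_fun(state):
        value, count = state
        return value <= threshold

    def body_fun(state):
        value, count = state
        return value * 2.0, count + 1

    _, count = lax.while_loop(cond_fun, body_fun, (x, jnp.int32(0)))
    return count

# Doubling 1.0 seven times gives 128.0, the first value above 100.0.
print(steps_to_exceed(1.0, 100.0))
```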
I think @alexbw's answer was about the latter case (generating XLA code with loop constructs in it), but I want to underscore that the former case (unrolling Python loops into XLA code) works without using any special constructs.
It's probably also useful to underscore that:
- If you put `@jit` on a function with Python control flow that depends on the values of the `@jit` function's arguments, you'll get a loud error. No silent failures!
- These constraints only apply to using `jit` and `vmap`. Automatic differentiation using functions like `grad` doesn't have any of these constraints, so it works just like in Autograd, with no need for special control-flow constructs.
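A minimal illustration of both points, assuming a recent JAX release. `relu_like` is a hypothetical function. The exact exception subclass raised under `jit` has changed across versions, so the sketch catches the documented parent class `jax.errors.ConcretizationTypeError`.

```python
import jax

def relu_like(x):
    # A Python `if` on the *value* of x: fine for grad, not for jit.
    if x > 0:
        return x ** 2
    else:
        return -x

# grad's tracers still carry concrete values, so ordinary Python control flow works.
print(jax.grad(relu_like)(3.0))   # → 6.0
print(jax.grad(relu_like)(-3.0))  # → -1.0

# Under jit, x is an abstract tracer, so `x > 0` has no concrete bool value
# and JAX raises an error instead of silently picking a branch.
try:
    jax.jit(relu_like)(3.0)
    jit_failed = False
except jax.errors.ConcretizationTypeError as err:
    jit_failed = True
    print(type(err).__name__)
```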
Here's a relevant paragraph from the README:
JAX provides some small, experimental libraries for machine learning. These libraries are in part about providing tools and in part about serving as examples for how to build such libraries using JAX. Each one is only a few hundred lines of code, so take a look inside and adapt them as you need!
To that end, stax is a minimalistic library for building neural networks. It's not meant to be complete, and that's the main way it compares to other libraries: it's pretty limited. But its power-to-weight ratio is pretty high: it's only a couple hundred lines of code!
Many users have found it to be a useful starting point in writing their own libraries. There are some already open-sourced, like trax, which has a lot more capabilities, more models, and more features, and I can tell you there are several more being developed by users inside Google that I suspect will be open-sourced soon.
The JAX core team is planning to make a more complete yet still minimalistic library over the next several months. It might not ever include all the bells and whistles that other libraries do, but we think we can do more than just stax and have it be both a useful tool and a jumping-off point that inspires how others write libraries.
WDYT?
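For readers wondering what stax looks like in practice, here is a small sketch. It assumes a recent JAX release, where stax ships as `jax.example_libraries.stax` (older releases had it under `jax.experimental`). Each layer is an `(init_fun, apply_fun)` pair, and `stax.serial` composes layers into another such pair. The layer sizes here are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# A tiny MLP: each stax layer is an (init_fun, apply_fun) pair, and
# stax.serial composes them into a single (init_fun, apply_fun) pair.
init_fun, apply_fun = stax.serial(
    stax.Dense(32),
    stax.Relu,
    stax.Dense(1),
)

rng = jax.random.PRNGKey(0)
# -1 stands for an unspecified batch dimension.
output_shape, params = init_fun(rng, (-1, 4))

x = jnp.ones((8, 4))
y = apply_fun(params, x)
print(output_shape, y.shape)  # → (-1, 1) (8, 1)
```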
We're definitely accepting contributions and bugfixes, but we're also fine with people who rely on the current version of stax making their own copies of `stax.py`, especially if you want to add many layers or features.
We have an FAQ now! It's not exactly what was outlined here, but, well, we can grow it as we find new frequently asked questions.
I'd like to add a question, @alexbw!
- Does JIT-compilation compile the for-loops in my code, or does it only compile the array computations to GPU/TPU?
(Forgive me if I have any misconceptions here about what JIT-compilation is all about... if my question "isn't even wrong", please let me know!)
@alexbw yes, it would! For most Python programmers, it is much more natural to write for-loops. Good to know that the `cond` and `while` constructs exist, but yes, `for` and `if` would play nicely with Pythonic conventions!
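A small sketch of the `cond` construct mentioned above, assuming the current `jax.lax.cond(pred, true_fun, false_fun, *operands)` signature. `piecewise` is a hypothetical function. It expresses a value-dependent branch that a plain Python `if` could not express under `jit`.

```python
import jax
from jax import lax

@jax.jit
def piecewise(x):
    # lax.cond selects a branch from a traced boolean, so it works under jit,
    # where a Python `if x > 0:` would raise an error. Both branches must
    # return values with the same shape and dtype.
    return lax.cond(x > 0, lambda v: v ** 2, lambda v: -v, x)

print(piecewise(3.0))   # → 9.0
print(piecewise(-3.0))  # → 3.0
```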
@mattjj thanks for the helpful response, and hello from Cambridge! 😄
Hi, I would also like to add a question, which is probably related to the first two:
What is stax? How complete is it supposed to be? How does it compare to tensorflow/pytorch/other commonly used DL frameworks?
Huuuum, I think I got it, though it is an uncommon goal (compared to most other dl libraries). So, are you still accepting contributions to stax? Or do you see it as something that only the core team should/will touch?