Comments (5)
Do you mean keras.ops.vectorized_map?
Hi François,
Thank you for your quick response.
Sorry, I am not so familiar with JAX. I have now found that vmap is similar to vectorized_map in TF and Keras.
I am particularly interested in map_fn because my operation cannot be vectorized due to the large intermediate variables generated during the computation. I have used map_fn extensively in my project to simulate physical processes.
I understand that (1) tf.map_fn uses while_loop under the hood, and (2) both TF and JAX will convert a Python loop to a graph using while_loop.
However, my issue is that with (2) I cannot set parallel_iterations, and (1) is currently unavailable in Keras 3.0. I am now trying to build a map_fn function myself using while_loop. One puzzle for me is that tf.map_fn uses tf.TensorArray to accumulate the result during the iteration, which is also unavailable in Keras 3.0. It would be very useful if there were some examples in Keras on this task.
Here is an example of my code using map_fn:
import numpy as np
import tensorflow as tf
from keras import ops
data = np.random.randn(3, 1024, 1024)
data = ops.convert_to_tensor(data)
dataFT = tf.map_fn(
    lambda elem: ops.fft2(elem),
    elems=(ops.real(data), ops.imag(data)),
    fn_output_signature=(data.dtype, data.dtype),
)
As you can see, this code does not work with Keras using the JAX backend. Thank you very much for any hints.
P.S. Another change in Keras that significantly affects my application is that layers no longer support complex variables. This is very strange, because keras.ops provides complex operations such as conjugation, yet I cannot pass complex variables from one layer to another.
Kind regards,
Yifeng Shao
Another op you can try is keras.ops.vectorize, which is equivalent to np.vectorize and is effectively the same as vmap, but with a nicer syntax.
def myfunc(a, b):
    return a + b

vfunc = keras.ops.vectorize(myfunc)
y = vfunc([1, 2, 3, 4], 2)  # Returns Tensor([3, 4, 5, 6])
Now, if you want to use tf.map_fn specifically, you can also use that with the TF backend.
In TensorFlow, tf.map_fn is different from tf.vectorized_map: tf.map_fn applies the function to each slice sequentially, while tf.vectorized_map auto-vectorizes it across the leading dimension.
In JAX, jax.vmap is similar to tf.vectorized_map in TensorFlow.
In NumPy, np.vectorize is similar to tf.map_fn in TensorFlow.
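The np.vectorize / tf.map_fn analogy can be seen directly: np.vectorize calls the scalar function once per element in a plain Python loop rather than compiling one batched kernel. A small sketch that counts the invocations (passing otypes, since without it np.vectorize makes one extra probe call to infer the output dtype):

```python
import numpy as np

calls = []

def double(x):
    calls.append(x)  # record every scalar invocation
    return x * 2

# otypes avoids the extra probe call np.vectorize makes to infer the
# output dtype, so the call count matches the element count.
vdouble = np.vectorize(double, otypes=[int])
out = vdouble(np.array([1, 2, 3]))
print(out.tolist(), len(calls))  # [2, 4, 6] 3
```

So np.vectorize is sequential like tf.map_fn, whereas tf.vectorized_map and jax.vmap rewrite the function to operate on the whole batch at once.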
Dear Edward,
Thank you for your further clarification.
It seems that the map_fn function is unique to TensorFlow, and no similar function can be found in other projects. In physics simulations, I believe such a function is very important.
Could you let me know what happens when converting a Python loop (e.g. pre-allocating memory by initializing an empty variable and then filling in the elements through a loop) to a graph? Is this equivalent to map_fn?
import numpy as np
from keras import ops

data = np.random.randn(3, 1024, 1024)
data_real = np.zeros_like(data)
data_imag = np.zeros_like(data)
for ind in range(data.shape[0]):
    data_real[ind], data_imag[ind] = ops.fft2((ops.real(data[ind]), ops.imag(data[ind])))
It seems that such a practice is not common in the machine-learning community... Thanks a lot for any help here.
Kind regards,
Yifeng