Comments (5)
I experimented with two approaches for
mlx/python/mlx/nn/layers/dropout.py
Line 35 in 1086dc4
return (1 / self._p_1) * mask * x # Return dtype is mx.float32
but when I replace that with
return mask * x * (1 / self._p_1) # Return dtype is mx.bfloat16
This problem comes down to type promotion, and similar cases exist in NumPy and many other libraries.
Python evaluates the expression from left to right. In the first case, the weakly typed float (1 / self._p_1) is multiplied by mask (mx.bool_) first. That product has dtype mx.float32, so the rest of the expression is computed in mx.float32.
In the second case, mask (mx.bool_) is multiplied by x (mx.bfloat16) first, which gives mx.bfloat16. The weakly typed float then takes on mx.bfloat16 as well, so the input dtype is preserved.
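Since the same ordering effect shows up in NumPy, here is a minimal sketch of it there. Note the assumptions: NumPy has no bfloat16, so float16 stands in for it, and NumPy's default float is float64 rather than float32. The helper names `dropout_scale_first` and `dropout_scale_last` and the `keep_prob` parameter are made up for illustration and are not part of MLX or NumPy.

```python
import numpy as np


def dropout_scale_first(x, mask, keep_prob):
    # The Python float meets the bool mask first, so that product is
    # promoted to NumPy's default float and the whole result is upcast.
    return (1 / keep_prob) * mask * x


def dropout_scale_last(x, mask, keep_prob):
    # The bool mask meets the half-precision array first, so the product
    # stays in half precision; the Python float then adopts that dtype.
    return mask * x * (1 / keep_prob)


# float16 is a stand-in for bfloat16 (an assumption for this sketch).
x = np.ones((2, 3), dtype=np.float16)
mask = np.array([[True, False, True], [False, True, True]])
keep_prob = 0.5

scale_first = dropout_scale_first(x, mask, keep_prob)
scale_last = dropout_scale_last(x, mask, keep_prob)

print("scale first:", scale_first.dtype)
print("scale last: ", scale_last.dtype)
```

Both versions compute the same values. Only the order of the multiplications, and with it the result dtype, differs.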
from mlx.
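from typing import Optional

import mlx.core as mx
import mlx.nn as nn


class Dropout(nn.Module):
    def __init__(self, p: Optional[float] = 0.5):
        super().__init__()
        self.p = p

    def __call__(self, x: mx.array, key: Optional[mx.array] = None):
        if not self.training or self.p == 0:
            return x
        else:
            # Keep each element with probability 1 - p
            mask = mx.random.uniform(shape=x.shape, dtype=x.dtype, key=key) > self.p
            # Multiply mask and x first so the result keeps x's dtype
            return mask * x / (1.0 - self.p)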
from mlx.
Just a heads up: tree_flatten drops the Dropout layers from the network, so when we serialized the optimizer with
mx.save_safetensors(file=optimizer_state_path, arrays=dict(optimizer_state))
and then decided to continue training from a previously saved checkpoint, we got a bunch of key errors...
from mlx.
I'm not following. The Dropout layers shouldn't be saved in tree_flatten since they don't have parameters. Could you share some code which reproduces the errors you are getting?
from mlx.
@awni created a separate issue for this: #1328
from mlx.