Comments (6)
There is no Flatten layer yet. You would have to redo the computation like so:
class MLP(nn.Module):
    def __init__(self, out_dims):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 20, 5),  # input channels, output channels, kernel size
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # kernel size, stride length
            nn.Conv2d(20, 50, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.mlp = nn.Sequential(
            nn.Linear(800, 500),
            nn.ReLU(),
            nn.Linear(500, out_dims),
        )

    def __call__(self, x):
        x = self.conv(x)
        x = x.flatten(-3, -1)
        x = self.mlp(x)
        return x
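For reference, the 800 input features of the first Linear layer come from tracing the spatial dimensions through the conv stack. A minimal sketch, assuming 28x28 MNIST inputs (the input size is an assumption, not stated above):

```python
# Trace the spatial size through conv/pool layers using the
# "valid" output-size formula: (size - kernel) // stride + 1.
def conv_out(size, kernel, stride=1):
    return (size - kernel) // stride + 1

s = 28                    # assumed MNIST input: 28x28
s = conv_out(s, 5)        # Conv2d(1, 20, 5)   -> 24
s = conv_out(s, 2, 2)     # MaxPool2d(2, 2)    -> 12
s = conv_out(s, 5)        # Conv2d(20, 50, 5)  -> 8
s = conv_out(s, 2, 2)     # MaxPool2d(2, 2)    -> 4
features = 50 * s * s     # 50 channels * 4 * 4
print(features)           # 800
```

This is why `flatten(-3, -1)` followed by `nn.Linear(800, 500)` matches up.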
At some point we considered adding Flatten, but we decided we'd prefer not to mirror every op with an nn equivalent, and the equivalent above is not so onerous.
from mlx.
@awni Thanks, I have tried this. But I've run into another bug. Here is a minimal example:
i = 0
for X, y in batch_iterate(64, train_images, train_labels):
    i += 1
    loss, grads = loss_and_grad_fn(model, X, y)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)
    if i == 100:
        break
print(model.parameters())
The parameters become nan at some point between the 100th and 150th batch. It varies from run to run but is usually within this range. Is this related to #1277 or #319?
That I have no idea about. You'd need to share more code that fully reproduces this so we can help debug. Also, make sure you are using the latest MLX.
@awni I'm attaching my code here. I am using version 0.16.1.
You are converting train_labels to one-hot and then using its size to determine the size of the dataset. That is a bug because the size will be a factor of 10 too large, so in your batch_iterate function you will be reading lots of uninitialized memory.
My recommendation would be not to convert the labels to one-hot; just use them as-is, which works with cross_entropy and is more efficient.
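To see the size mismatch concretely, here is a small sketch with numpy standing in for mx (the toy labels and the eye-indexing one-hot construction are assumptions for illustration, not the original code):

```python
import numpy as np

labels = np.array([3, 1, 4, 1, 5])   # 5 training labels
one_hot = np.eye(10)[labels]         # one-hot over 10 classes, shape (5, 10)

print(one_hot.size)        # 50 -- total element count, 10x the dataset size
print(labels.shape[0])     # 5  -- the actual dataset size
```

Using `.size` of the one-hot array where the number of samples is expected makes the permutation range 10x too large, hence the out-of-bounds reads.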
Alternatively you could change your batch iteration to get the right dataset size:
def batch_iterate(batch_size, X, y):
    perm = mx.array(np.random.permutation(y.shape[0]))
    for s in range(0, y.shape[0], batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]
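A quick sanity check of that iterator, as a pure-numpy sketch (np.array stands in for mx.array, and the toy shapes are assumptions):

```python
import numpy as np

def batch_iterate(batch_size, X, y):
    # same logic as above, with numpy in place of mx
    perm = np.random.permutation(y.shape[0])
    for s in range(0, y.shape[0], batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]

X = np.zeros((10, 4))
y = np.arange(10)
batches = list(batch_iterate(4, X, y))
print([b[1].shape[0] for b in batches])   # [4, 4, 2]

# every label is visited exactly once per epoch
seen = np.sort(np.concatenate([b[1] for b in batches]))
print(np.array_equal(seen, np.arange(10)))  # True
```

Note the last batch is smaller when the dataset size is not a multiple of the batch size, which slicing past the end of `perm` handles gracefully.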
Great catch. I overlooked y.size. Thank you for the help.