Comments (5)

bhargavyagnik commented on August 15, 2024

I experimented with two approaches for the line

return (1 / self._p_1) * mask * x

  • return (1 / self._p_1) * mask * x # Return dtype is mx.float32

but when I replace that with

  • return mask * x * ( 1 / self._p_1) # Return dtype is mx.bfloat16

This problem is related to type promotion, and similar cases exist in NumPy and numerous other libraries.

Python evaluates the expression from left to right. In the first case, the weakly typed float (1 / self._p_1) is first multiplied with mask, which has dtype mx.bool_. The resulting dtype is mx.float32, and thus the entire expression is calculated in mx.float32.

In the second case, mx.bool_ is first multiplied with mx.bfloat16, which gives mx.bfloat16. The weakly typed float is then promoted to just mx.bfloat16, thus preserving the dtype.
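The ordering effect described above can be reproduced with a small NumPy analogue. This is only a sketch under assumptions: NumPy has no bfloat16, so float16 stands in for it, and NumPy's default float is float64 where MLX's is mx.float32. The variable names are illustrative and are not taken from the MLX source.

```python
import numpy as np

scale = 1 / (1 - 0.5)                       # plain Python float (weakly typed)
x = np.ones(4, dtype=np.float16)            # stand-in for a bfloat16 activation
mask = np.array([True, False, True, True])  # boolean keep-mask

# (scale * mask) runs first: float scalar * bool array gives the default float dtype,
# and the later multiplication with x is promoted to that wider dtype.
scalar_first = scale * mask * x

# (mask * x) runs first: bool * float16 stays float16,
# and the weak Python scalar then adopts the array's dtype.
array_first = mask * x * scale

print(scalar_first.dtype)
print(array_first.dtype)
```

Here the first expression ends up as float64 and the second stays float16, mirroring the mx.float32 versus mx.bfloat16 difference reported above.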

from mlx.

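vgoklani commented on August 15, 2024

from typing import Optional

import mlx.core as mx
import mlx.nn as nn


class Dropout(nn.Module):
    def __init__(self, p: Optional[float] = 0.5):
        super().__init__()
        self.p = p

    def __call__(self, x: mx.array, key: Optional[mx.array] = None):
        if not self.training or self.p == 0:
            return x
        else:
            # Boolean keep-mask: each element is kept with probability 1 - p.
            mask = mx.random.uniform(shape=x.shape, dtype=x.dtype, key=key) > self.p
            # Multiply mask with x first so the Python float is applied to an array of x's dtype.
            return mask * x / (1.0 - self.p)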


vgoklani commented on August 15, 2024

Just a heads up: the tree_flatten method drops the Dropout layers from the network, so when we serialized the optimizer state with

mx.save_safetensors(file=optimizer_state_path, arrays=dict(optimizer_state))

and decided to continue training from a previously saved checkpoint, we got a bunch of key errors...


awni commented on August 15, 2024

I'm not following. The Dropout layers shouldn't be saved in tree_flatten since they don't have parameters. Could you share some code which reproduces the errors you are getting?
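To illustrate why a layer without parameters leaves no entries after flattening, here is a simplified pure-Python stand-in. It is not the MLX implementation of tree_flatten; the flatten helper and the placeholder parameter values are hypothetical and only sketch the general idea of turning a nested tree into dotted keys.

```python
def flatten(tree, prefix=""):
    """Flatten nested dicts/lists into (dotted_key, leaf) pairs."""
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (list, tuple)):
        items = enumerate(tree)
    else:
        # A leaf (e.g. an array) is emitted under the key built so far.
        return [(prefix, tree)]
    out = []
    for key, value in items:
        child = f"{prefix}.{key}" if prefix else str(key)
        out.extend(flatten(value, child))
    return out


# Hypothetical parameter tree: the middle layer has no parameters (like Dropout).
params = {
    "layers": [
        {"weight": "W0", "bias": "b0"},
        {},
        {"weight": "W2", "bias": "b2"},
    ]
}

flat = dict(flatten(params))
print(sorted(flat))
```

In this sketch the empty dict at index 1 has no leaves, so no key starting with layers.1 appears in the flattened result, while the other layers keep their original indices.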


vgoklani commented on August 15, 2024

@awni created a separate issue for this: #1328
