
Using `jaxtyping` with torch (jaxtyping issue, closed)

google commented on July 28, 2024
Using `jaxtyping` with torch

from jaxtyping.

Comments (6)

patrick-kidger commented on July 28, 2024

Hey there! Yep, I'd love to update torchtyping as well.

I'm planning to essentially:

  • copy-paste the jaxtyping codebase
  • remove the JAX parts (e.g. PyTrees)
  • add a backward-compatible new version of torchtyping.TensorType that internally lowers to produce the new thing (and for patch_typeguard, one that just prints a deprecation warning and then does nothing)
  • update all the documentation, tests, etc.

In practice that's a fair amount of work so it's not a priority for me right now.

If this is important to you then I'd welcome a PR against torchtyping that does the above.
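Editor's sketch: the backward-compatibility step in the plan above might look roughly like the following. This is only an illustration under stated assumptions, not torchtyping's actual code. `lower_tensortype_args` and the marker names in `_DTYPE_NAMES` are hypothetical, and the sketch handles only named dimensions and Python scalar types.

```python
import warnings

# Hypothetical mapping from Python scalar types to new-style dtype marker names.
_DTYPE_NAMES = {float: "Float", int: "Int", bool: "Bool"}


def patch_typeguard() -> None:
    """Backward-compatible no-op: warn that the call is deprecated, then do nothing."""
    warnings.warn(
        "patch_typeguard() is deprecated and no longer does anything.",
        DeprecationWarning,
        stacklevel=2,
    )


def lower_tensortype_args(*args):
    """Lower old-style TensorType[...] arguments to (dtype marker, shape string).

    Hypothetical helper: only named dimensions (str) and Python scalar
    types are supported; anything else raises TypeError.
    """
    dims = []
    dtype_name = "Shaped"  # fallback marker when no dtype is given
    for arg in args:
        if isinstance(arg, str):
            dims.append(arg)
        elif arg in _DTYPE_NAMES:
            dtype_name = _DTYPE_NAMES[arg]
        else:
            raise TypeError(f"unsupported TensorType argument in this sketch: {arg!r}")
    return dtype_name, " ".join(dims)


# Old style: TensorType["batch", "channels", float]
print(lower_tensortype_args("batch", "channels", float))  # ('Float', 'batch channels')
```

A real shim would build the new annotation type from these two pieces instead of returning a tuple.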


brentyi commented on July 28, 2024

Makes sense, thanks for the fast response!

Have you toyed with the possibility of merging the projects?

I'm happy to carve out time for a PR, but the duplication is mildly uncomfortable. torch will apparently also have PyTrees soon so fully removing that may not even be desired: pytorch/pytorch#65761

I could also see myself using this kind of thing for projects with just vanilla Numpy. 🙂


patrick-kidger commented on July 28, 2024

Yep, I did consider merging things. This ends up being infeasible due to edge cases, e.g. JAX and PyTorch (eventually) having different PyTree manipulation routines.

Anyway, great - go ahead and open a PR on torchtyping whenever you like. (And for now I'm happy to cut out PyTrees since PyTorch doesn't support these yet; easy enough to add back later.)


brentyi commented on July 28, 2024

Ok! Added torchtyping updates to my to-do list.

Two questions if you have a moment:

(1)

> This ends up being infeasible due to edge cases, e.g. JAX and PyTorch (eventually) having different PyTree manipulation routines.

Are there other edge cases you'd be able to elaborate on? Maybe that I should be aware of when making torchtyping updates? I'm having some trouble understanding your note on the feasibility of unification, and generally why the benefits of it don't (greatly) outweigh the tradeoffs that would be required, particularly when it seems like you've already done the hard parts of making jaxtyping torch-friendly.

Did you have thoughts, for example, on an API where:

  • The core utilities are moved into arraytyping.Float32, arraytyping.Shaped, etc.
  • JAX-specific things are moved into arraytyping.jax.PyTree, etc. (is there anything else?)
  • Torch-specific things are moved into arraytyping.torch.PyTree, etc. (is there anything else?)

(2)

Was syntax that looks like Array[Float32, "c h w"] considered as an alternative to Float32[Array, "c h w"]? Just curious; it seems like either could be written in a way that's compatible with both static and dynamic type checks.

Thanks again for your library!!


patrick-kidger commented on July 28, 2024

  1. I suspect at a technical level this might be feasible. (PyTrees and dtypes are the main two discrepancies.) The dominant issues here are:

    • Time: supporting more things takes up dev time and maintenance time, in particular thinking hard about how to achieve compatibility. I'm not a current PyTorch user, and jaxtyping isn't my day job -- just a side-project (one of many I support).
    • Marketability: "jaxtyping" is a much easier thing to advertise than a generic arraytyping library!
    • Simplicity: it's important both to me as a dev, and to new users looking to understand the API, that we don't have too many special cases squirrelled away.
    • Feature creep: once we have .jax and .torch namespaces then before long someone will ask me to support another library I don't currently use ;)

    (Besides which I think the name arraytyping would need to be revisited given the presence of PyTrees in the library!)

  2. Yup, this was considered. At the moment, as far as static type checking is concerned, our Float32 is actually a typing.Annotated, so Float32[Array, "foo bar"] is treated as Annotated[Array, "foo bar"] and thus as just Array. If we were to go for an Array[Float32, "foo bar"] syntax then we'd be setting Array = Annotated, and that would break whenever someone just writes def foo(x: Array).
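Editor's sketch: the `Annotated` behaviour described in point 2 can be seen with the standard library alone. The `Array` class below is a hypothetical stand-in for a real array type, and the example shows only how `typing.Annotated` resolves at runtime, not jaxtyping's internals.

```python
from typing import Annotated, get_type_hints


class Array:
    """Stand-in for a real array type (hypothetical, for illustration only)."""


def foo(x: Annotated[Array, "foo bar"]) -> None:
    """Annotated the way the comment says Float32[Array, "foo bar"] is seen statically."""


# By default the metadata is stripped, leaving just the underlying type.
plain = get_type_hints(foo)
print(plain["x"] is Array)  # True

# With include_extras=True the shape string is still available at runtime.
extras = get_type_hints(foo, include_extras=True)
print(extras["x"].__metadata__)  # ('foo bar',)
```

With the order reversed, `Array` itself would have to be the `Annotated` alias, so a bare `x: Array` annotation would no longer name a concrete type.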


brentyi commented on July 28, 2024

Thanks for the detailed explanation! That's convincing on both fronts; on (2) it seems like removing Array = Annotated and just using Shaped = Annotated with Float32 = Annotated[jnp.ndarray, "some_float32_marker"] would result in a sensible API (e.g. func(shaped_array: Shaped[Float32, "c h w"], unshaped_array: Float32)), but the distinction is minor and there are cons to that approach too.

Will do my best to make a PR this week (probably at the end)! I started an effort here but was derailed by life; if anybody else has time, feel free 🙃
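Editor's sketch: the `Shaped = Annotated` alternative mentioned above can be tried at runtime with the standard library. `Array` is a hypothetical stand-in for `jnp.ndarray`, and the sketch checks runtime behaviour only; whether static type checkers accept an alias like `Shaped = Annotated` is not verified here.

```python
from typing import Annotated, get_type_hints


class Array:
    """Stand-in for jnp.ndarray (hypothetical, for illustration only)."""


Shaped = Annotated
Float32 = Annotated[Array, "some_float32_marker"]


def func(shaped_array: Shaped[Float32, "c h w"], unshaped_array: Float32) -> None:
    """Signature taken from the comment above."""


hints = get_type_hints(func, include_extras=True)

# Nested Annotated types are flattened, so both markers end up on one type.
print(hints["shaped_array"].__metadata__)    # ('some_float32_marker', 'c h w')
print(hints["unshaped_array"].__metadata__)  # ('some_float32_marker',)

# Without extras, both parameters resolve to the plain array type.
print(get_type_hints(func)["shaped_array"] is Array)  # True
```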

