[HN Gopher] Training Deep Networks with Data Parallelism in Jax
___________________________________________________________________
Training Deep Networks with Data Parallelism in Jax

Author : sebg
Score  : 65 points
Date   : 2023-02-24 17:05 UTC (5 hours ago)

(HTM) web link (www.mishalaskin.com)
(TXT) w3m dump (www.mishalaskin.com)

| qmatch wrote:
| A nice, simple walkthrough in this post, but it would be nice if it were updated to show how to do this with sharding and the new jax.Array type introduced not too long ago:
|
| https://github.com/google/jax/pull/11233/files
| amrb wrote:
| So I've been looking into ONNX to speed up inference - is there some killer feature I should look at JAX for?
| chas wrote:
| Jax is a great tool, but it's really best for training and experimentation. The transformations outlined in this post (amongst others) make it easy to turn simple and straightforward code into high-performance parallel code. While this is changing, inference hasn't been a historical area of emphasis for the project, so it wouldn't be my first choice if that was your primary goal.
| UncleOxidant wrote:
| About a year ago I was tasked with comparing ONNX Runtime implementations of certain components, like convolution, with some of our own in-house implementations. There was just no comparison: ONNX Runtime has some pretty crazy fast implementations. Lots of magic in there. We concluded that we weren't going to be able to beat those without a lot of effort and expertise that we didn't have on our team.
| kkielhofner wrote:
| Generally, from what I've seen, the biggest inference speedup win with ONNX is to get the model to ONNX and then to TRT (TensorRT) - assuming Nvidia hardware. Once in "TRT land" you can play with fp32, fp16, int8 (with calibration), etc. That said (again, generally), ONNX does tend to perform better when compared to native formats (TF SavedModel, PyTorch TorchScript, whatever). With TensorFlow and PyTorch there are also ways to export/compile directly to TRT, but from an ops standpoint I don't prefer this.
|
| Certain inference serving solutions like Nvidia Triton Inference Server will even take an ONNX model and then do TRT compilation (with a cache!) on the actual inference hardware dynamically at model load time. This is really nice because you can deploy a standard ONNX model across instances and varying GPU hardware and always get a TRT engine optimized for, and compatible with, that device's Compute Capability, etc. Really handy, and it basically comes down to a few lines in the model configuration.
|
| I'm not terribly familiar with JAX, but I have to imagine there's ONNX export or straight-to-TRT export somewhere.
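A rough sketch of the ONNX-to-TRT path described above, using ONNX Runtime's TensorRT execution provider (this assumes an onnxruntime build with TensorRT support; the model path, input shape, and option values are placeholders, and option names may vary across onnxruntime versions):

    import numpy as np
    import onnxruntime as ort

    # The TensorRT execution provider builds (and here caches) a TRT engine
    # for the ONNX model on the actual inference hardware.
    providers = [
        ("TensorrtExecutionProvider", {
            "trt_fp16_enable": True,           # play with precision in "TRT land"
            "trt_engine_cache_enable": True,   # cache the compiled engine
            "trt_engine_cache_path": "./trt_cache",
        }),
        "CUDAExecutionProvider",               # fallback for unsupported ops
    ]
    sess = ort.InferenceSession("model.onnx", providers=providers)

    # Hypothetical input; real input names/shapes come from the exported model.
    x = np.random.rand(1, 3, 224, 224).astype(np.float32)
    outputs = sess.run(None, {sess.get_inputs()[0].name: x})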
| mccoyb wrote:
| JAX is such a beautiful system. There are many deep PL ideas in the design of JAX; one could spend years thinking about them. It's wonderfully fun to design new interpreters, implementing some semantics, stage them out -- and automatically gain access to JIT compilation and accelerators via JAX's other higher-order primitives/transformations.
|
| I've become a big believer that PL research in ML that makes heavy use of program transformations would benefit from providing small JAX-based implementations. There's really no other system which allows you to express interpreter-based transformations with the benefits that JAX provides (maybe `functorch` in a few months? I have some doubts about transformation composition with systems like torchdynamo - but I don't know much about it).
|
| Edit: note this is coming from a long-time Julia stan, take that for what it is worth :)
| UncleOxidant wrote:
| Since you're a Julia stan: don't you think that the program transformations that JAX is doing could be done much more easily in Julia, since Julia has macros? Aren't there things that are similar to JAX in the Julia ecosystem (i.e. a few different autodiff packages that do program transformation)?
| mccoyb wrote:
| No -- macros are a non-intrusive transformation -- they don't allow you to transform the callees of a function, unless you wrap the macro around the callee. People have tried this in Julia, and it's horribly slow.
|
| There's another mechanism in Julia: generated functions. These allow method-body specialization given type knowledge about the signature of the function -- so a user can write code which is generated for the method body when inference determines the signature (and the inferred signature is tight enough), and which depends on the inferred types.
|
| All of Julia's program-transformation-based AD packages are based on the latter mechanism -- and most of them do terrible things to the compiler, including massively blowing up the size of code before optimization.
|
| The only package which is more promising is Diffractor -- but I'm not convinced it is more than a research prototype at its current level of development. That may change. It was written by one of the compiler devs, and uses lower-level hooks into the compiler, developed to support its transformation.
|
| The big issue in general: Julia doesn't let you write transformations on its typed IR from user space, unless you want to ignore Julia's native execution engine. There are hooks that someone can work with -- but they aren't user-facing (for all but the most advanced users) -- and they break pass composability with code generation using the native engine (this may have changed since I last looked at this!). I would know, because I've spent several attempts trying to do stuff like this, and making crappy, unstable packages :)
|
| Separately: macros are one level of reflection -> code generation. JAX supports a different form -- you can't emit data which represents generic expressions -- it's not quite like Lisp in that sense. It's better to think about JAX as a "two-level" language system -- where you have a meta-level, which is Python, and a statically typed array language, which is the object level. JAX supports a staging-like operation which allows transforming a compatible subset of Python to the statically typed array language. But you can write interpreters in the full dynamism of Python -- as long as the "active paths" (under tracers) are in that compatible subset, you can then stage out applications of the interpreters on Python functions, etc.
|
| JAX provides one solution to the composable transformation problem, and they've done it in an elegant way - one that's ~pretty~ easy to understand (c.f. the Autodidax tutorial). With my current knowledge of things, I can't effectively argue that Julia supports the same right now (caveat: things may have changed since I last had a look). This is an area where a lot of stuff seems to be going on in Julia, so I doubt it will remain this way forever.
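A minimal sketch of the staging and composition described above (the function, names, and shapes here are illustrative, not from the thread): a Python function whose traced paths stay inside JAX's array-language subset can be staged out to a jaxpr and composed with further transformations.

    import jax
    import jax.numpy as jnp

    def loss(w, x):
        # Ordinary Python at the meta-level; the traced array ops are what get staged out.
        return jnp.sum((x @ w) ** 2)

    w = jnp.ones(4)
    x = jnp.ones((3, 4))

    # Stage the function out to JAX's statically typed array IR (a jaxpr).
    print(jax.make_jaxpr(loss)(w, x))

    # Transformations compose: per-example gradients, vectorized and JIT-compiled.
    per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))
    print(per_example_grads(w, jnp.ones((8, 3, 4))).shape)  # (8, 4)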
| jdeaton wrote:
| The abstractions provided by JAX for parallelism are beautiful. JAX is an absolute master-class in programming-interface design and a lesson in the power of providing composable primitive operations and FP-inspired design. An astounding amount of complexity is hidden from the user behind primitives like pmap, and the power is exposed through such a simple interface.
| alfalfasprout wrote:
| Agreed. Though keep in mind they built on a lot of failed attempts at doing the same to get here.
| 6gvONxR4sf7o wrote:
| That's true, and it's a massive part of what I love about JAX, but the transformations also form barriers in weird parts of your code that prevent standard introspection tools from working, which is the single thing I hate about JAX. The errors are amazingly opaque.
| mattjjatgoogle wrote:
| If you have any particular examples in mind, and time to share them on https://github.com/google/jax/issues, we'd love to try to improve them. Improving error messages is a priority.
|
| About introspection tools: at least for runtime value debugging there is, to some extent, a fundamental challenge. Since jax.jit stages computation out of Python (though jax.grad and jax.vmap don't), standard Python runtime value inspection tools, like printing and pdb, can't work under a jax.jit, as the values aren't available while the Python code is executing. You can always remove the jax.jit while debugging (or use `with jax.disable_jit(): ...`), but that's not always convenient, and we need jax.jit for good performance.
|
| We recently added some runtime value debugging tools which work even with jax.jit-staged-out code (even in automatically parallelized code!), though they're not the standard introspection tools: see `jax.debug.print` and `jax.debug.breakpoint` on https://jax.readthedocs.io/en/latest/debugging/index.html and https://jax.readthedocs.io/en/latest/debugging/print_breakpo...
|
| If you were thinking about other kinds of introspection tooling, I'd love to hear about it!
| 6gvONxR4sf7o wrote:
| > with jax.disable_jit(): ...
|
| That's handy, and I hadn't seen it before, thanks.
|
| It's been a bit, but I think the most frustrating errors were around mapping pytrees (like this issue: https://github.com/google/jax/issues/9928). I'm not sure of the exact solution, but the axis juggling and specifications were where I remember a lot of pain, and the docs (though extensive) were unclear. At times it feels like improvements are punted on in the hope that xmap eventually fixes everything (and xmap has been in experimental for far longer than I expected).
|
| Also the barriers where I couldn't disable jit. IIRC pmap automatically jits, so there was no way to avoid staging that part out. When it came to doing some complex jax.lax.ppermute, it felt more difficult than it needed to be to debug.
|
| Next time I encounter something particularly opaque, I'll share it on the GitHub issue tracker.
| mattjjatgoogle wrote:
| Thanks for taking the time to explain these.
|
| > It's been a bit, but I think the most frustrating errors were around mapping pytrees (like this issue https://github.com/google/jax/issues/9928).
|
| We've improved some of these pytree error messages, but it seems that the vmap one is still not great. Thanks for the ping on it.
|
| > Also the barriers where I couldn't disable jit. IIRC pmap automatically jits, so there was no way to avoid staging that part out.
|
| That was indeed a longstanding issue in pmap's implementation. And since people came to expect jit to be "built in" to pmap, it wasn't easy to revise.
|
| However, we recently (https://github.com/google/jax/pull/11854) made `jax.disable_jit()` work with pmap, in the sense that it makes pmap execute eagerly, so that you can print/pdb/etc. to your heart's content. (The pmap successor, shard_map (https://jax.readthedocs.io/en/latest/jep/14273-shard-map.htm...), is eager by default. It also has uniformly good error messages from the start!)
|
| > Next time I encounter something particularly opaque, I'll share on the github issue tracker.
|
| Thank you for the constructive feedback!
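A small sketch of the debugging tools mentioned above, shown under jit for simplicity (per the comment, the same jax.debug calls also work in pmap/staged-out code; the function here is made up):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def step(x):
        y = x * 2.0
        # Runtime value debugging that works even though jit stages this out of Python.
        jax.debug.print("y = {y}", y=y)
        return jnp.sum(y)

    step(jnp.arange(4.0))

    # For ordinary print/pdb-style inspection, fall back to eager execution.
    with jax.disable_jit():
        step(jnp.arange(4.0))  # runs op by op, so standard Python tools apply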
| 6gvONxR4sf7o wrote:
| Thanks! One last thing, since I have your ear. The function transformation aspects of jax seem to make their way into downstream libraries like haiku, resulting in a lot of "magic" that can be difficult to examine and debug. Are there any utils you made to make jax's own transformations more transparent, which you think might be helpful to third-party transformations?
|
| Higher-order functions are difficult in general, and it would be fantastic to have core patterns or tools for breaking them open.
| mattjjatgoogle wrote:
| Thanks for the kind words! We've been doing a lot more work in this direction too, for both compiler-based automatic parallelization [0] and a work-in-progress pmap successor for 'manual' parallelism (per-device code and explicit collectives) [1], which composes seamlessly with the compiler-based stuff.
|
| [0] https://jax.readthedocs.io/en/latest/notebooks/Distributed_a...
|
| [1] https://jax.readthedocs.io/en/latest/jep/14273-shard-map.htm...
| uptownfunk wrote:
| It seems like the ecosystem is still dominated by PyTorch. Is Jax supposed to be a competitor? Any signs of Jax taking over PyTorch anytime soon? Is it perhaps too early for its time? Or is there a critical flaw in the underlying design?
| mccoyb wrote:
| I think it's young - and perhaps JAX itself is not so specialized to a specific task (but there are plenty of libraries for deep-learning-focused tooling, although not as mature as PyTorch). It has often been said in other threads on JAX, but it feels like a very different type of library from other AD systems -- the focus on concisely exposing/allowing users to express composable transformations seems novel! (But I may be mistaken.)
|
| But in general, I would suspect youth.
| time_to_smile wrote:
| I think it's better to think of JAX as a more general framework for differentiable programming, and of PyTorch as more focused specifically on deep learning/neural networks.
|
| The beauty of JAX is that basic usage is basically a single function: `grad`.
|
| You just write whatever Python function you want and can get the derivative/gradient of it trivially. It gets a bit trickier when you need more sophisticated numeric tools like numpy/scipy, but in those cases it's just a matter of swapping in a JAX version of those.
|
| In this sense JAX is the spiritual successor to Autograd. However, the really amazing thing about JAX is that not only do you get the autodiff basically for free, you also get very good performance, and basically GPU parallelism without needing to think about it at all.
|
| PyTorch is an awesome library, but it is largely focused on building neural networks specifically. JAX should be thought of as a tool that basically any Python programmer can just throw in whenever they come across a problem that benefits from having differentiable code (which is a lot of cases once you start thinking about differentiation as a first-class feature).
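A minimal illustration of that single-function usage (the function here is made up, not from the thread):

    import jax
    import jax.numpy as jnp

    # Any plain numerical Python function written with jax.numpy...
    def f(x):
        return jnp.sin(x) ** 2 + x

    # ...is made differentiable by a single transformation.
    df = jax.grad(f)

    print(df(1.5))           # derivative of f at x = 1.5
    print(jax.jit(df)(1.5))  # the same gradient, JIT-compiled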
___________________________________________________________________
(page generated 2023-02-24 23:00 UTC)