[HN Gopher] Training Deep Networks with Data Parallelism in Jax
       ___________________________________________________________________
        
       Training Deep Networks with Data Parallelism in Jax
        
       Author : sebg
       Score  : 65 points
       Date   : 2023-02-24 17:05 UTC (5 hours ago)
        
 (HTM) web link (www.mishalaskin.com)
 (TXT) w3m dump (www.mishalaskin.com)
        
       | qmatch wrote:
        | A nice, simple walkthrough in this post, but it would be nice if
        | it were updated to show how to do this with sharding and the new
        | jax.Array type introduced not too long ago
       | 
       | https://github.com/google/jax/pull/11233/files
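        | 
        | For anyone curious, a rough sketch of what that can look like with
        | the newer sharding API (device layout and shapes here are made up
        | for illustration, and assume the batch axis divides evenly across
        | the local devices):
        | 
        |     import jax, jax.numpy as jnp
        |     import numpy as np
        |     from jax.sharding import Mesh, NamedSharding
        |     from jax.sharding import PartitionSpec as P
        | 
        |     # lay the local devices out as a 1D "data" mesh
        |     mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
        |     sharding = NamedSharding(mesh, P("data"))
        | 
        |     # shard the batch dimension of x across the devices
        |     x = jnp.arange(32.0).reshape(8, 4)
        |     x = jax.device_put(x, sharding)
        | 
        |     # jit-compiled code consumes the sharded jax.Array directly
        |     y = jax.jit(lambda a: jnp.tanh(a) * 2)(x)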
        
       | amrb wrote:
        | So I've been looking into ONNX to speed up inference; is there
        | some killer feature I should look at JAX for?
        
         | chas wrote:
         | Jax is a great tool, but it's really best for training and
         | experimentation. The transformations outlined in this post
         | (amongst others) make it easy to turn simple and
         | straightforward code into high performance parallel code. While
          | this is changing, inference hasn't been a historical area of
         | emphasis for the project, so it wouldn't be my first choice if
         | that was your primary goal.
        
         | UncleOxidant wrote:
         | About a year ago I was tasked with comparing ONNX runtime
         | implementations of certain components like convolution with
         | some of our own in-house implementations. There was just no
         | comparison. ONNX runtime has some pretty crazy fast
         | implementations. Lots of magic in there. Concluded that we
         | weren't going to be able to beat those without a lot of effort
         | and expertise that we didn't have in our team.
        
         | kkielhofner wrote:
         | Generally from what I've seen the biggest inference speedup win
         | with ONNX is to get the model to ONNX then to TRT (TensorRT) -
         | assuming Nvidia hardware. Once in "TRT land" you can play with
         | fp32, fp16, int8 (with calibration), etc. That said (again
         | generally) ONNX does tend to perform better when compared to
         | native (TF savedmodel, pytorch torchscript, whatever). With
         | TensorFlow and Pytorch there are also ways to export/compile
         | directly to TRT but from an ops standpoint I don't prefer this.
         | 
         | Certain inference serving solutions like Nvidia Triton
         | Inference Server will even take an ONNX model and then do TRT
         | compilation (with cache!) on the actual inference hardware
         | dynamically at model load time. This is really nice because you
         | can deploy a standard ONNX model across instances and varying
          | GPU hardware and always get a TRT-optimized build that is
          | compatible with the device's Compute Capability, etc. Really
          | handy, and it basically comes down
         | to a few lines of config in the model configuration.
         | 
         | I'm not terribly familiar with JAX but I have to imagine
         | there's ONNX export or straight to TRT export somewhere.
        
       | mccoyb wrote:
        | JAX is such a beautiful system. There are many deep PL ideas in
        | the design of JAX; one could spend years thinking about them.
        | It's wonderfully fun to design new interpreters implementing
        | some semantics and stage them out -- automatically gaining
        | access to JIT compilation and accelerators via JAX's other
        | higher-order primitives/transformations.
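        | 
        | As a tiny illustration of that composability (a toy function, not
        | from the article):
        | 
        |     import jax
        |     import jax.numpy as jnp
        | 
        |     def f(x):
        |         return jnp.sum(jnp.sin(x) ** 2)
        | 
        |     # transformations are higher-order functions, so they compose:
        |     # per-example gradients of f, JIT-compiled
        |     g = jax.jit(jax.vmap(jax.grad(f)))
        |     g(jnp.ones((8, 3)))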
       | 
       | I've become a big believer that it would be beneficial for PL
       | research in ML which makes heavy use of program transformations
       | to provide small JAX-based implementations. There's really no
       | other system which allows you to express interpreter-based
       | transformations with the benefits that JAX provides (maybe
        | `functorch` in a few months? I have some doubts about
        | transformation composition with systems like torchdynamo - but
        | I don't know much about it).
       | 
       | Edit: note this is coming from a long time Julia stan, take that
       | for what it is worth :)
        
         | UncleOxidant wrote:
         | Since you're a Julia stan: Don't you think that the program
         | transformations that JAX is doing could be done much more
         | easily in Julia since Julia has macros? Aren't there things
          | that are similar to JAX in the Julia ecosystem? (i.e. a few
         | different autodiff packages that do program transformation)
        
           | mccoyb wrote:
           | No -- macros are a non-intrusive transformation -- they don't
           | allow you to transform callees of a function, unless you wrap
           | the macro around the callee. People have tried this in Julia,
           | and it's horribly slow.
           | 
            | There's another mechanism in Julia - generated functions.
            | These allow method body specialization given type knowledge
            | about the signature of the function -- so a user can write
            | code, depending on the inferred types, that generates the
            | method body once inference determines the signature (and the
            | inferred signature is tight enough).
           | 
           | All of Julia's program transformation based AD packages are
           | based on the latter transformation -- most of them do
           | terrible things to the compiler, including massively blowing
           | up the size of code before optimization.
           | 
           | The only package which is more promising is Diffractor -- but
           | I'm not convinced it is more than a research prototype at its
            | current level of development. That may change. It was
            | written by one of the compiler devs, and uses lower-level
           | hooks into the compiler, developed to support its
           | transformation.
           | 
           | The big issue in general: Julia doesn't let you write
           | transformations on its typed IR from user space, unless you
           | want to ignore Julia's native execution engine. There are
           | hooks that someone can work with -- but they aren't user-
           | facing (for all but the most advanced users) -- and they
           | break pass composability with code generation using the
           | native engine (this may have changed since I last looked at
            | this!) I would know, because I've made several attempts at
            | doing stuff like this, and produced crappy, unstable
            | packages :)
           | 
           | Separately: macros are one level of reflection -> code
           | generation. JAX supports a different form -- you can't emit
           | data which represents generic expressions -- it's not quite
           | like Lisp in that sense. It's better to think about JAX as a
           | "two-level" language system -- where you have a meta-level
           | which is Python, and there's a statically typed array
           | language which is the object level. JAX supports a stage
           | -like operation which allows transforming compat subset of
           | Python to the statically typed array language. But you can
           | write interpreters in the full dynamism of Python -- as long
           | as the "active paths" (under tracers) are in that compat set,
           | you can then stage out applications of the interpreters on
           | Python functions, etc.
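            | 
            | A concrete way to see that two-level split is
            | jax.make_jaxpr, which stages a Python function out to the
            | typed array IR (toy function here):
            | 
            |     import jax
            |     import jax.numpy as jnp
            | 
            |     def f(x):
            |         # ordinary Python at the meta level ...
            |         return jnp.tanh(x) + 1.0
            | 
            |     # ... staged out to the typed array language (a jaxpr)
            |     print(jax.make_jaxpr(f)(jnp.ones(3)))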
           | 
           | JAX provides one solution to the composable transformation
           | problem, and they've done it in an elegant way - that's
           | ~pretty~ easy to understand (c.f. the Autodidax tutorial).
           | With my current knowledge of things, I can't effectively
           | argue that Julia supports the same right now (caveat: things
           | may have changed since I last had a look). This is an area
           | where a lot of stuff seems to be going on in Julia, so I
           | doubt it will remain this way forever.
        
       | jdeaton wrote:
       | The abstractions provided by JAX for parallelism are beautiful.
       | JAX is an absolute master-class in programming-interface design
       | and a lesson in the power of providing composable primitive
        | operations and FP-inspired design. An astounding amount of
       | complexity is hidden from the user behind primitives like pmap,
       | and the power is exposed in such a simple interface.
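        | 
        | The article goes into the details, but the shape of the interface
        | is roughly this small (a toy per-device computation, nothing from
        | the post itself):
        | 
        |     import jax
        |     import jax.numpy as jnp
        |     from functools import partial
        | 
        |     @partial(jax.pmap, axis_name="devices")
        |     def step(x):
        |         # runs per device; pmean is a cross-device collective
        |         return jax.lax.pmean(x ** 2, axis_name="devices")
        | 
        |     # leading axis of the input maps over the local devices
        |     step(jnp.arange(float(jax.local_device_count())))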
        
         | alfalfasprout wrote:
         | Agreed. Though keep in mind they built on a lot of failed
         | attempts at doing the same to get here.
        
         | 6gvONxR4sf7o wrote:
         | That's true, and is a massive part of what I love about JAX,
         | but they also form barriers in weird parts of your code,
          | preventing standard introspection tools from working, which is
          | the single thing I hate about JAX. The errors are amazingly
          | opaque.
        
           | mattjjatgoogle wrote:
           | If you have any particular examples in mind, and time to
           | share them on https://github.com/google/jax/issues, we'd love
           | to try to improve them. Improving error messages is a
           | priority.
           | 
           | About introspection tools, at least for runtime value
           | debugging there is to some extent a fundamental challenge:
           | since jax.jit stages computation out of Python (though
            | jax.grad and jax.vmap don't), standard Python runtime value
            | inspection tools, like printing and pdb, can't work under a
            | jax.jit, as the values aren't available while the Python code
            | is executing. You can always remove the jax.jit
           | while debugging (or use `with jax.disable_jit(): ...`), but
           | that's not always convenient, and we need jax.jit for good
           | performance.
           | 
           | We recently added some runtime value debugging tools which
           | work even with jax.jit-staged-out code (even in automatically
           | parallelized code!), though they're not the standard
           | introspection tools: see `jax.debug.print` and
           | `jax.debug.breakpoint` on
           | https://jax.readthedocs.io/en/latest/debugging/index.html and
           | https://jax.readthedocs.io/en/latest/debugging/print_breakpo.
           | ...
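            | 
            | For example (toy function):
            | 
            |     import jax
            |     import jax.numpy as jnp
            | 
            |     @jax.jit
            |     def f(x):
            |         # prints at runtime, even though f is staged out
            |         jax.debug.print("x = {}", x)
            |         return x * 2
            | 
            |     f(jnp.arange(3.0))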
           | 
           | If you were thinking about other kinds of introspection
           | tooling, I'd love to hear about it!
        
             | 6gvONxR4sf7o wrote:
             | > with jax.disable_jit(): ...
             | 
             | That's handy, and I hadn't seen it before, thanks.
             | 
             | It's been a bit, but I think the most frustrating errors
             | were around mapping pytrees (like this issue
              | https://github.com/google/jax/issues/9928). I'm not sure of
              | the exact solution, but the axis juggling and
             | specifications were where I remember a lot of pain, and the
             | docs (though extensive) were unclear. At times it feels
             | like improvements are punted on in the hopes that xmap
             | eventually fixes everything (and xmap has been in
             | experimental for far longer than I expected).
             | 
             | Also the barriers where I couldn't disable jit. IIRC pmap
             | automatically jits, so there was no way to avoid staging
             | that part out. When it came to doing some complex
             | jax.lax.ppermute, it felt more difficult than it needed to
             | be to debug.
             | 
             | Next time I encounter something particularly opaque, I'll
             | share on the github issue tracker.
        
               | mattjjatgoogle wrote:
               | Thanks for taking the time to explain these.
               | 
               | > It's been a bit, but I think the most frustrating
               | errors were around mapping pytrees (like this issue
               | https://github.com/google/jax/issues/9928).
               | 
               | We've improved some of these pytree error messages but it
                | seems that the vmap one is still not great. Thanks for the
               | ping on it.
               | 
               | > Also the barriers where I couldn't disable jit. IIRC
               | pmap automatically jits, so there was no way to avoid
               | staging that part out.
               | 
               | That was indeed a longstanding issue in pmap's
               | implementation. And since people came to expect jit to be
               | "built in" to pmap, it wasn't easy to revise.
               | 
               | However, we recently
               | (https://github.com/google/jax/pull/11854) made
               | `jax.disable_jit()` work with pmap, in the sense that it
               | makes pmap execute eagerly, so that you can print/pdb/etc
               | to your heart's content. (The pmap successor, shard_map
               | (https://jax.readthedocs.io/en/latest/jep/14273-shard-
               | map.htm...), is eager by default. Also it has uniformly
               | good error messages from the start!)
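                | 
                | To give a flavor, roughly (toy example against the
                | experimental API, shapes just for illustration):
                | 
                |   import jax, jax.numpy as jnp
                |   import numpy as np
                |   from jax.sharding import Mesh
                |   from jax.sharding import PartitionSpec as P
                |   from jax.experimental.shard_map import shard_map
                | 
                |   mesh = Mesh(np.array(jax.devices()),
                |               axis_names=("i",))
                | 
                |   def double(block):
                |       # sees only its per-device shard of the input
                |       return block * 2
                | 
                |   g = shard_map(double, mesh=mesh,
                |                 in_specs=P("i"), out_specs=P("i"))
                |   g(jnp.arange(float(jax.device_count())))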
               | 
               | > Next time I encounter something particularly opaque,
               | I'll share on the github issue tracker.
               | 
               | Thank you for the constructive feedback!
        
               | 6gvONxR4sf7o wrote:
               | Thanks! One last thing, since I have your ear. The
               | function transformation aspects of jax seem to make their
               | way into downstream libraries like haiku, resulting in a
               | lot of "magic" that can be difficult to examine and
               | debug. Are there any utils you made to make jax's own
               | transformations more transparent, which you think might
               | be helpful to third party transformations?
               | 
               | Higher order functions are difficult in general, and it
               | would be fantastic to have core patterns or tools for
               | breaking them open.
        
         | mattjjatgoogle wrote:
         | Thanks for the kind words! We've been doing a lot more work in
         | this direction too, for both compiler-based automatic
         | parallelization [0] and a work-in-progress pmap successor for
         | 'manual' parallelism (per-device code and explicit collectives)
         | [1] which composes seamlessly with the compiler-based stuff.
         | 
         | [0]
         | https://jax.readthedocs.io/en/latest/notebooks/Distributed_a...
         | 
         | [1] https://jax.readthedocs.io/en/latest/jep/14273-shard-
         | map.htm...
        
       | uptownfunk wrote:
        | It seems like the ecosystem is still dominated by PyTorch. Is Jax
        | supposed to be a competitor? Any signs of Jax overtaking PyTorch
       | anytime soon? Is it perhaps too early for its time? Or is there a
       | critical flaw in the underlying design?
        
         | mccoyb wrote:
         | I think it's young - and perhaps JAX itself is not so
          | specialized to a specific task (but there are plenty of
          | libraries for deep-learning-focused tooling, although they're
          | not as mature as PyTorch). It has often been said in other
          | threads on JAX, but
         | it feels like a very different type of library from other AD
         | systems -- the focus on concisely exposing/allowing users to
         | express composable transformations seems novel! (But I may be
         | mistaken)
         | 
         | But in general, I would suspect youth.
        
         | time_to_smile wrote:
         | I think it's better to think of JAX as a more general framework
         | for differentiable programming and PyTorch more focused
         | specifically on deep learning/neural networks.
         | 
          | The beauty of JAX is that basic usage comes down to a single
          | function: `grad`.
         | 
         | You just write whatever Python function you want and can get
         | the derivative/gradient of it trivially. It gets a bit trickier
         | when you need more sophisticated numeric tools like
          | numpy/scipy, but in those cases it's just about swapping in the
          | JAX versions of those.
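          | 
          | For instance (a toy function, nothing special about it):
          | 
          |     import jax
          |     import jax.numpy as jnp   # drop-in numpy replacement
          | 
          |     def loss(w):
          |         return jnp.sum(jnp.tanh(w) ** 2)
          | 
          |     grad_loss = jax.grad(loss)   # also just a function
          |     grad_loss(jnp.ones(3))       # d(loss)/dw at w = [1,1,1]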
         | 
          | In this sense JAX is the spiritual successor to Autograd.
          | However, the really amazing thing about JAX is that not only do
          | you get the autodiff basically for free, you also get very good
          | performance, and GPU parallelism basically without needing to
          | think about it at all.
         | 
          | PyTorch is an awesome library, but largely focused on building
          | neural networks specifically. JAX should be thought of as a tool
         | that basically any Python programmer can just throw in there
         | whenever they come across a problem that benefits from having
         | differentiable code (which is a lot of cases once you start
          | thinking about differentiation as a first-class feature).
        
       ___________________________________________________________________
       (page generated 2023-02-24 23:00 UTC)