[HN Gopher] Useful algorithms that are not optimized by Jax, PyT...
       ___________________________________________________________________
        
       Useful algorithms that are not optimized by Jax, PyTorch, or
       TensorFlow
        
       Author : ChrisRackauckas
       Score  : 152 points
        Date   : 2021-07-21 11:10 UTC (1 day ago)
        
 (HTM) web link (www.stochasticlifestyle.com)
 (TXT) w3m dump (www.stochasticlifestyle.com)
        
       | cjv wrote:
        | ...doesn't the JAX example just need the argument added to
        | static_argnums, and then it will work?
        
         | ChrisRackauckas wrote:
          | static_argnums is really just a way to supply a few more
          | assumptions in an attempt to build quasi-static code even
          | if it's using dynamic constructs. In this example it will
          | force the tracer down only one of the two branches
          | (whichever one static_argnums sends it down). That is going
          | to generate incorrect code for input values which should
          | have traced the other branch (hence the real solution,
          | `lax.cond`, is to always trace and always compute both
          | branches, as mentioned in the post). If the computation is
          | actually not quasi-static, there's no good choice for a
          | static argnum. See the factorial example.
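          | 
          | To make the trade-off concrete, here is a minimal toy sketch
          | (my own illustration, not the example from the post):
          | 
          |     import jax
          |     import jax.numpy as jnp
          |     from jax import lax
          | 
          |     def f(x):
          |         # a Python `if` on a traced value fails under jit
          |         if x > 0:
          |             return jnp.sin(x)
          |         return jnp.cos(x)
          | 
          |     # Marking x static runs the `if` at trace time, so a
          |     # fresh trace/compile happens for each new value of x.
          |     f_static = jax.jit(f, static_argnums=0)
          | 
          |     # The quasi-static route: lax.cond traces both branches
          |     # into the graph and selects between them at run time.
          |     @jax.jit
          |     def f_cond(x):
          |         return lax.cond(x > 0, jnp.sin, jnp.cos, x)
          | 
          |     print(f_static(2.0), f_cond(2.0))
          |     print(f_static(-2.0), f_cond(-2.0))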
        
           | cjv wrote:
           | Ah, thanks for the explanation.
        
       | ssivark wrote:
       | There are many interesting threads in this post, one of which is
       | using "non standard interpretations" of programs, and enabling
       | the compiler to augment the human-written code with the extra
       | pieces necessary to get gradients, propagate uncertainties, etc.
       | I wonder whether there's a more unified discussion of the
       | potential of these methods. I suspect that a lot of "solvers"
       | (each typically with their own DSL for specifying the problem)
       | might be nicely formulated in such a framework. (Particularly in
        | the case of auto diff, I found recent work/talks by Conal Elliott
       | and Tom Minka quite enlightening.)
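        | 
        | (For concreteness, here is a toy "non-standard interpretation"
        | in plain Python: dual numbers that carry a derivative along
        | with each value, so ordinary-looking code computes gradients.
        | Purely an illustrative sketch, not any particular library.)
        | 
        |     class Dual:
        |         # a value together with its derivative
        |         def __init__(self, val, der=0.0):
        |             self.val, self.der = val, der
        |         def __add__(self, o):
        |             o = o if isinstance(o, Dual) else Dual(o)
        |             return Dual(self.val + o.val, self.der + o.der)
        |         __radd__ = __add__
        |         def __mul__(self, o):
        |             o = o if isinstance(o, Dual) else Dual(o)
        |             return Dual(self.val * o.val,
        |                         self.der * o.val + self.val * o.der)
        |         __rmul__ = __mul__
        | 
        |     def f(x):
        |         return 3 * x * x + x + 1   # ordinary-looking code
        | 
        |     y = f(Dual(2.0, 1.0))          # seed dx/dx = 1
        |     print(y.val, y.der)            # 15.0 and f'(2) = 13.0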
       | 
       | Tangentially, thinking about Julia, while one initially gets awed
       | by the speed, and then the multiple dispatch, I wonder whether
        | its deepest superpower (that we're still discovering) might be
       | the expressiveness to augment the compiler to do interesting
       | things with a piece of code. Generic programming then acts as a
       | lever to use these improvements for a variety of use cases, and
       | the speed is merely the icing on the cake!
        
         | shakow wrote:
         | > the expressiveness to augment the compiler to do interesting
         | things with a piece of code.
         | 
          | Julia has some very interesting offerings on this subject, from
         | language-level autodiff (https://fluxml.ai/Zygote.jl/latest/)
         | to automated probabilistic programming (https://turing.ml/dev/)
         | through DEs (https://diffeq.sciml.ai/stable/) and optimization
         | (https://jump.dev/).
         | 
          | The whole ecosystem is bubbling with activity, and I'm very
          | eager to see whether, in the coming years, it can grow into
          | a solid foundation able to rival the layers of warts stacked
          | on top of Python.
        
         | mccoyb wrote:
         | Just a comment: you're right on the money here. This is the
         | dream that a few people in the Julia community are working
         | towards.
         | 
         | The framework of abstract interpretation, when combined with
         | multiple dispatch as a language design feature, is absolutely
         | insane.
         | 
          | I think programming language enthusiasts might meditate on
          | these points --- and get quite excited about the direction
          | in which the Julia compiler implementation is heading.
        
       | marcle wrote:
       | There is no free lunch:).
       | 
       | I remember spending a summer using Template Model Builder (TMB),
       | which is a useful R/C++ automatic differentiation (AD) framework,
       | for working with accelerated failure time models. For these
        | models, survival beyond time t given covariates X is defined by
       | S(t|X) = P(T>t|X) = S_0(t exp(-beta^T X)) for baseline survival
       | S_0(t). I wanted to use splines for the baseline survival and
       | then use AD for gradients and random effects. Unfortunately,
       | after implementing the splines in template C++, I found a web
       | page entitled "Things you should NOT do in TMB"
       | (https://github.com/kaskr/adcomp/wiki/Things-you-should-NOT-d...)
       | - which included using if statements that are based on
       | coefficients. In this case, the splines for S_0 depend on beta,
        | which is exactly the excluded case :(. An older framework (ADMB)
       | did not have this constraint, but dissemination of code was more
        | difficult. Finally, PyTorch had neither an implementation
        | of B-splines nor one of Laplace's approximation.
       | Returning to my opening comment, there is no free lunch.
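        | 
        | (For reference, the model in code: a rough JAX sketch with
        | made-up knots/values and jnp.interp standing in for a real
        | spline basis for S_0.)
        | 
        |     import jax
        |     import jax.numpy as jnp
        | 
        |     knots = jnp.array([0.0, 1.0, 2.0, 4.0, 8.0])
        |     logS0 = jnp.array([0.0, -0.2, -0.6, -1.5, -3.0])
        | 
        |     def survival(t, X, beta):
        |         # u depends on beta, so the spline segment selection
        |         # inside the interpolation depends on beta as well
        |         u = t * jnp.exp(-X @ beta)
        |         return jnp.exp(jnp.interp(u, knots, logS0))
        | 
        |     X = jnp.array([1.0, 0.5])
        |     beta = jnp.array([0.3, -0.1])
        |     print(survival(2.0, X, beta))
        |     print(jax.grad(survival, argnums=2)(2.0, X, beta))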
        
         | hyperbovine wrote:
         | Were you optimizing over the knots as well? Otherwise I can't
         | see why this would be disallowed using either forward or
         | reverse-mode AD. An infinitesimal perturbation of beta will not
         | cause t * exp(-beta^T x) to cross a knot, so the whole thing is
         | smooth. (And, with B-splines the derivatives are continuous
         | from piece to piece anyways.) But in general I agree--a good
          | spline implementation is something I miss the most when moving
         | from scipy.interpolate to jax.scipy. Given that the SciPy
         | implementation is mostly F77 code written before I was born, I
         | do not see this situation resolving itself anytime soon.
        
           | svantana wrote:
           | It's not about smoothness, it's about how to JIT the gradient
           | function. ML libs don't generally do interpolation, partly
           | because it's tricky to vectorize (you have to search for
           | which segment to use for each element) and partly because
            | most ML practitioners don't need it. What I've done in my code
           | is use all the vertices for all the elements, but with
           | weights that are mostly zero. It's pretty fast on GPU because
           | I don't use that many vertices.
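            | 
            | (A minimal sketch of that dense-weight idea, assuming
            | uniformly spaced knots and piecewise-linear "hat" weights;
            | a toy version, not exactly what my code does:)
            | 
            |     import jax.numpy as jnp
            | 
            |     def interp_dense(x, knots, values):
            |         # every query point gets a weight for every knot
            |         # (mostly zeros), so there is no data-dependent
            |         # segment search to upset jit/vmap
            |         h = knots[1] - knots[0]
            |         w = 1.0 - jnp.abs(x[:, None] - knots) / h
            |         w = jnp.clip(w, 0.0, 1.0)
            |         return w @ values
            | 
            |     knots = jnp.linspace(0.0, 1.0, 16)
            |     vals = jnp.sin(2 * jnp.pi * knots)
            |     print(interp_dense(jnp.array([0.25, 0.7]), knots, vals))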
        
         | ChrisRackauckas wrote:
         | There is definitely no free lunch, it's good to really
         | delineate the engineering trade-offs you're making! A lot of
         | this work actually comes from the fact that some people I work
         | with were building tools that could efficiently handle dynamic
         | control flow without requiring tracing (see the description of
         | Zygote.jl https://arxiv.org/abs/1810.07951). I had to bring up
         | the question: why? It's much harder to build, needs more
          | machinery, and in some cases can make fewer assumptions/fewer
         | fusions (a general form of vmap is much harder for example if
         | you cannot trace, see KernelAbstractions.jl for details). This
          | line of inquiry led to an example of why you might want to
          | support such dynamic behaviors, so I'll leave it up to someone
          | else to declare whether the maintenance or complexity cost is
          | worth it to them. I wouldn't say that Jax or Tensorflow are
         | doomed (far from it: simple ML architectures are quasi-static,
          | so they're building for the correct audience), but it's good to
         | know what exactly you're leaving out when you make a
         | simplifying assumption.
        
       | _hl_ wrote:
        | Tangentially related: Faster training of Neural ODEs is super
        | exciting! There are a lot of promising applications (although
        | personally I believe that the intuition of "magically choosing
        | the number of layers" is misguided, but I'm not an expert and
        | might be wrong). Right now it takes forever to train even on
        | toy problems, but I'm sure that enough work in this direction
        | will eventually lead to more practical methods.
        
       | 6gvONxR4sf7o wrote:
       | This is a really cool post.
       | 
       | It seems like you can't solve this kind of thing with a new jax
       | primitive for the algorithm, but what prevents new function
       | transformations from doing what the mentioned julia libraries do?
       | It seems like between new function transformations and new
        | primitives, you ought to be able to do just about anything. Is
        | XLA the issue, such that you could run but not jit the result?
        
         | ChrisRackauckas wrote:
         | XLA is the limiting factor in a lot of these cases, though
         | maybe saying limiting factor is wrong because it's more of a
         | "trade-off factor". XLA wants to know the static size of a lot
         | of arguments so it can build a mathematical description of the
         | compute graph and fuse linear algebra commands freely. What the
         | Julia libraries like Zygote do is say "there is no good
         | mathematical description of this, so I will generate source
         | code instead" (and some programs like Tapenade are similar).
         | For example, while loops are translated into for loops where a
          | stack of the Boolean choices is stored so they can be run in
         | reverse during the backpass. The Julia libraries can sometimes
         | have more trouble automatically fusing linear algebra commands
         | though, since then they need to say "my IR lets non-static
          | things occur, so I need to prove it's static before
         | doing transformation X". It's much easier to know you can do
         | such transformations if anything written in the IR obeys the
         | rules required for the transform! So it's a trade-off. In the
         | search for allowing differentiation of any program in the
         | language, the Julia AD tools have gone for extreme flexibility
         | (and can rely on the fact that Julia has a compiler that can
         | JIT compile any generated valid Julia code) and I find it
         | really interesting to try and elucidate what you actually gain
         | from that.
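          | 
          | (A toy, hand-written illustration of that "record a stack,
          | run it in reverse" idea; the real source-to-source machinery
          | is of course far more general:)
          | 
          |     def f_forward(x):
          |         tape = []            # stack of loop iterations
          |         while x >= 1.0:      # data-dependent while loop
          |             tape.append(x)
          |             x = x / 2.0
          |         return x * x, (x, tape)
          | 
          |     def f_backward(res, dy):
          |         x, tape = res
          |         dx = 2.0 * x * dy    # back through y = x*x
          |         for _ in reversed(tape):
          |             dx = dx / 2.0    # back through x -> x/2
          |         return dx
          | 
          |     y, res = f_forward(5.0)
          |     print(y, f_backward(res, 1.0))   # 0.390625 0.15625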
        
           | awaythrowact wrote:
           | If the next machine learning killer-app model requires
           | autodiff'ed dynamic control flow, do you think
           | Google/Facebook will build that capability into
           | XLA/TorchScript? Seems like if NLP SOTA requires dynamic
           | control flow, Google will build it? Maybe they let you
           | declare some subgraph as "dynamic" to avoid static
           | optimizations? But maybe the static graph assumption is so
           | deeply embedded into the XLA architecture, they'd be better
           | off just adopting Julia? (I honestly don't know the answer,
           | asking your opinion!)
        
             | ChrisRackauckas wrote:
             | "Maybe they let you declare some subgraph as 'dynamic' to
             | avoid static optimizations?" What you just described is
             | Tensorflow Eager and why it has some performance issues
             | (but more flexibility!). XLA makes some pretty strong
             | assumptions and I don't think that should change.
             | Tensorflow's ability to automatically generate good
             | automatically parallelized production code stems from the
             | restrictions it has imposed. So I wouldn't even try for a
             | "one true AD to rule them all" since making things more
             | flexible will reduce the amount of compiler optimizations
             | that can be automatically performed.
             | 
             | To get the more flexible form, you really would want to do
             | it in a way that uses a full programming language's IR as
             | its target. I think trying to use a fully dynamic
              | programming language IR (Python, R, etc.) directly
             | would be pretty insane because it would be hard to enforce
             | rules and get performance. So some language that has a
             | front end over an optimizing compiler (LLVM) would probably
              | make the most sense. Zygote and Diffractor use Julia's IR,
             | but there are other ways to do this as well. Enzyme
             | (https://github.com/wsmoses/Enzyme.jl) uses the LLVM IR
             | directly for doing source-to-source translations. Using
             | some dialect of LLVM (provided by MLIR) might be an
             | interesting place to write a more ML-focused flexible AD
             | system. Swift for Tensorflow used the Swift IR. This
             | mindset starts to show why those tools were chosen.
        
               | awaythrowact wrote:
               | Makes sense. I don't use TF Eager, but I do use Jax, and
               | Jax lets you arbitrarily compose JITed and non-JITed
               | code, which made me think that might be a viable pattern.
               | I guess I wondered if there might be something like
               | "nonstatic_jit(foo)" that would do "julia style"
               | compiling on function foo, in addition to "jit(foo)" that
               | compiles foo to optimized XLA ops. Probably impractical.
               | Thanks.
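                | 
                | (For what it's worth, the kind of composition I had in
                | mind looks like this toy example: dynamic Python
                | control flow outside, a jitted kernel inside.)
                | 
                |     import jax
                |     import jax.numpy as jnp
                | 
                |     @jax.jit
                |     def step(x):
                |         # some compiled update rule
                |         return x - 0.1 * jnp.tanh(x)
                | 
                |     def run(x):
                |         # the data-dependent loop stays in Python;
                |         # each call reuses the compiled kernel
                |         while float(jnp.abs(x)) > 1e-3:
                |             x = step(x)
                |         return x
                | 
                |     print(run(jnp.array(2.0)))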
        
           | 6gvONxR4sf7o wrote:
           | > and can rely on the fact that Julia has a compiler that can
           | JIT compile any generated valid Julia code
           | 
           | This seems to be the key bit. It's a great data point around
           | the meme of "with a sufficiently advanced compiler..." In
           | this case we have sufficiently advanced compilers to make
            | very different JIT trade-offs. XLA is differently powerful
           | compared to Julia. Very cool, thanks for the insight.
        
       | ipsum2 wrote:
       | The example that fails in Jax would work fine in PyTorch. If
        | you're purely training the model, TorchScript doesn't give
        | many benefits, if any.
        
       ___________________________________________________________________
       (page generated 2021-07-22 23:00 UTC)