[HN Gopher] Useful algorithms that are not optimized by Jax, PyT...
___________________________________________________________________

Useful algorithms that are not optimized by Jax, PyTorch, or
TensorFlow

Author : ChrisRackauckas
Score  : 152 points
Date   : 2021-07-21 11:10 UTC (1 day ago)

(HTM) web link (www.stochasticlifestyle.com)
(TXT) w3m dump (www.stochasticlifestyle.com)

| cjv wrote:
| ...doesn't the JAX example just need the argument set to
| static_argnums and then it will work?
| ChrisRackauckas wrote:
| static_argnums is really just a way to give the compiler a few
| more assumptions so it can attempt to build quasi-static code
| even if it's using dynamic constructs. In this example that will
| force it to trace only one of the two branches (whichever one
| static_argnums sends it down). That is going to generate
| incorrect code for input values which should've traced the other
| branch (so the real solution, `lax.cond`, is to always trace and
| always compute both branches, as mentioned in the post). If the
| computation is actually not quasi-static, there's no good choice
| for a static argnum. See the factorial example.
| cjv wrote:
| Ah, thanks for the explanation.
| ssivark wrote:
| There are many interesting threads in this post, one of which is
| using "non standard interpretations" of programs, and enabling
| the compiler to augment the human-written code with the extra
| pieces necessary to get gradients, propagate uncertainties, etc.
| I wonder whether there's a more unified discussion of the
| potential of these methods. I suspect that a lot of "solvers"
| (each typically with their own DSL for specifying the problem)
| might be nicely formulated in such a framework. (Particularly in
| the case of auto diff, I found recent work/talks by Conal
| Elliott and Tom Minka quite enlightening.)
|
| Tangentially, thinking about Julia, while one initially gets
| awed by the speed, and then the multiple dispatch, I wonder
| whether its deepest superpower (that we're still discovering)
| might be the expressiveness to augment the compiler to do
| interesting things with a piece of code. Generic programming
| then acts as a lever to use these improvements for a variety of
| use cases, and the speed is merely the icing on the cake!
| shakow wrote:
| > the expressiveness to augment the compiler to do interesting
| things with a piece of code.
|
| Julia has very interesting propositions on the subject, from
| language-level autodiff (https://fluxml.ai/Zygote.jl/latest/) to
| automated probabilistic programming (https://turing.ml/dev/)
| through DEs (https://diffeq.sciml.ai/stable/) and optimization
| (https://jump.dev/).
|
| The whole ecosystem is in ebullition, and I'm very eager to see
| if it will be able to transform in the coming years into a solid
| foundation able to rival the layers of warts stacked on top of
| Python.
| mccoyb wrote:
| Just a comment: you're right on the money here. This is the
| dream that a few people in the Julia community are working
| towards.
|
| The framework of abstract interpretation, when combined with
| multiple dispatch as a language design feature, is absolutely
| insane.
|
| I think programming language enthusiasts might meditate on these
| points --- and get quite excited with the direction that the
| Julia compiler implementation is heading.
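To make the static_argnums / lax.cond distinction from the exchange
at the top of the thread concrete, here is a minimal JAX sketch; the
functions and values are illustrative, not code from the post:

    # Illustrative sketch: lax.cond stages both branches into one
    # compiled graph, while static_argnums bakes in one branch per
    # compilation.
    from functools import partial

    import jax
    import jax.numpy as jnp
    from jax import lax

    @jax.jit
    def f(x):
        # Both branches are traced into the compiled graph; the
        # predicate selects the result at run time, so one jitted
        # function handles positive and negative inputs.
        return lax.cond(x > 0, lambda v: jnp.sin(v), lambda v: v ** 2, x)

    @partial(jax.jit, static_argnums=1)
    def g(x, positive):
        # `positive` is a compile-time constant, so a plain Python
        # `if` is allowed -- but each static value compiles only one
        # branch, which is wrong if the branch really depends on the
        # value of x.
        if positive:
            return jnp.sin(x)
        return x ** 2

    print(f(1.5), f(-1.5))                # one compilation, both branches staged
    print(g(1.5, True), g(-1.5, False))   # one compilation per static value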
| marcle wrote:
| There is no free lunch :).
|
| I remember spending a summer using Template Model Builder (TMB),
| which is a useful R/C++ automatic differentiation (AD)
| framework, for working with accelerated failure time models. For
| these models, the survival to time T given covariates X is
| defined by S(t|X) = P(T > t|X) = S_0(t exp(-beta^T X)) for
| baseline survival S_0(t). I wanted to use splines for the
| baseline survival and then use AD for gradients and random
| effects. Unfortunately, after implementing the splines in
| template C++, I found a web page entitled "Things you should NOT
| do in TMB"
| (https://github.com/kaskr/adcomp/wiki/Things-you-should-NOT-d...)
| - which included using if statements that are based on
| coefficients. In this case, the splines for S_0 depend on beta,
| which is exactly the excluded case :(. An older framework (ADMB)
| did not have this constraint, but dissemination of code was more
| difficult. Finally, PyTorch did not have an implementation of
| B-splines or an implementation of Laplace's approximation.
| Returning to my opening comment, there is no free lunch.
| hyperbovine wrote:
| Were you optimizing over the knots as well? Otherwise I can't
| see why this would be disallowed using either forward- or
| reverse-mode AD. An infinitesimal perturbation of beta will not
| cause t * exp(-beta^T x) to cross a knot, so the whole thing is
| smooth. (And, with B-splines the derivatives are continuous from
| piece to piece anyways.) But in general I agree -- a good spline
| implementation is something I miss the most when moving from
| scipy.interpolate to jax.scipy. Given that the SciPy
| implementation is mostly F77 code written before I was born, I
| do not see this situation resolving itself anytime soon.
| svantana wrote:
| It's not about smoothness, it's about how to JIT the gradient
| function. ML libs don't generally do interpolation, partly
| because it's tricky to vectorize (you have to search for which
| segment to use for each element) and partly because most ML
| practitioners don't need it. What I've done in my code is use
| all the vertices for all the elements, but with weights that are
| mostly zero. It's pretty fast on GPU because I don't use that
| many vertices.
| ChrisRackauckas wrote:
| There is definitely no free lunch, and it's good to really
| delineate the engineering trade-offs you're making! A lot of
| this work actually comes from the fact that some people I work
| with were building tools that could efficiently handle dynamic
| control flow without requiring tracing (see the description of
| Zygote.jl, https://arxiv.org/abs/1810.07951). I had to bring up
| the question: why? It's much harder to build, needs more
| machinery, and in some cases can make fewer assumptions/fewer
| fusions (a general form of vmap is much harder, for example, if
| you cannot trace; see KernelAbstractions.jl for details). This
| line of inquiry led to an example of why you might want to
| support such dynamic behaviors, so I'll leave it up to someone
| else to declare whether the maintenance or complexity cost is
| worth it to them. I wouldn't say that this means Jax or
| Tensorflow are doomed (far from it: simple ML architectures are
| quasi-static, so it's building for the correct audience), but
| it's good to know what exactly you're leaving out when you make
| a simplifying assumption.
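A rough JAX sketch of the "all vertices, mostly-zero weights"
interpolation trick svantana describes above; the names and the
uniform-knot assumption are mine, not theirs. For plain 1-D linear
interpolation jnp.interp may already suffice, but the weight
formulation makes the absence of data-dependent branching explicit:

    # Illustrative sketch: piecewise-linear interpolation without a
    # segment search (assumes uniformly spaced knots and query points
    # inside their range).
    import jax
    import jax.numpy as jnp

    def lerp(x, knots, values):
        # Every knot gets a "hat" weight; only the two neighbouring
        # knots are nonzero, so there is no branching to upset jit,
        # vmap, or grad.
        h = knots[1] - knots[0]
        w = jnp.maximum(0.0, 1.0 - jnp.abs(x - knots) / h)
        return jnp.dot(w, values)

    knots = jnp.linspace(0.0, 1.0, 11)
    values = jnp.sin(2 * jnp.pi * knots)

    batched = jax.jit(jax.vmap(lerp, in_axes=(0, None, None)))
    print(batched(jnp.array([0.05, 0.5, 0.93]), knots, values))
    print(jax.grad(lerp)(0.33, knots, values))  # differentiates without issue

The cost is that every knot participates in every evaluation, which
is exactly the "fast on GPU as long as there aren't many vertices"
trade-off mentioned.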
| _hl_ wrote:
| Tangentially related: Faster training of Neural ODEs is super
| exciting! There are a lot of promising applications (although
| personally I believe that the intuition of "magically choosing
| the number of layers" is misguided, but I'm not an expert and
| might be wrong). Right now it takes forever to train even on toy
| problems, but I'm sure that enough work in this direction will
| eventually lead to more practical methods.
| 6gvONxR4sf7o wrote:
| This is a really cool post.
|
| It seems like you can't solve this kind of thing with a new jax
| primitive for the algorithm, but what prevents new function
| transformations from doing what the mentioned Julia libraries
| do? It seems like between new function transformations and new
| primitives, you ought to be able to do just about anything. Is
| XLA the issue, such that you could run but not jit the result?
| ChrisRackauckas wrote:
| XLA is the limiting factor in a lot of these cases, though maybe
| saying limiting factor is wrong because it's more of a
| "trade-off factor". XLA wants to know the static size of a lot
| of arguments so it can build a mathematical description of the
| compute graph and fuse linear algebra commands freely. What the
| Julia libraries like Zygote do is say "there is no good
| mathematical description of this, so I will generate source code
| instead" (and some programs like Tapenade are similar). For
| example, while loops are translated into for loops where a stack
| of the Boolean choices is stored so they can be run in reverse
| during the backpass. The Julia libraries can sometimes have more
| trouble automatically fusing linear algebra commands though,
| since then they need to say "my IR lets non-static things occur,
| so therefore I need to prove it's static before doing
| transformation X". It's much easier to know you can do such
| transformations if anything written in the IR obeys the rules
| required for the transform! So it's a trade-off. In the search
| for allowing differentiation of any program in the language, the
| Julia AD tools have gone for extreme flexibility (and can rely
| on the fact that Julia has a compiler that can JIT compile any
| generated valid Julia code), and I find it really interesting to
| try and elucidate what you actually gain from that.
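A small JAX illustration of the kind of value-dependent control flow
being discussed here; the function is a generic example, not one
from the post. The trip count depends on the input's value, so there
is no static structure to unroll:

    # Illustrative value-dependent loop: the number of iterations
    # depends on the input itself, so there is no static trip count.
    import jax
    from jax import lax

    def halve_until_small(x):
        def cond(v):
            return v >= 1.0
        def body(v):
            return v / 2.0
        return lax.while_loop(cond, body, x)

    print(jax.jit(halve_until_small)(37.0))            # runs fine under jit
    print(jax.jvp(halve_until_small, (37.0,), (1.0,)))  # forward-mode AD works

    # Reverse mode is where the trade-off bites: the backward pass
    # would need the per-iteration history that XLA's static graph has
    # nowhere to store, so (at least at the time of this thread) the
    # line below raises an error rather than returning a gradient.
    # print(jax.grad(halve_until_small)(37.0))

Zygote's source-to-source approach, as described above, instead
records the loop's choices on a stack during the forward pass and
replays them in reverse, at the cost of some of the whole-graph
fusion XLA gets from its static assumptions.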
| awaythrowact wrote:
| If the next machine learning killer-app model requires
| autodiff'ed dynamic control flow, do you think Google/Facebook
| will build that capability into XLA/TorchScript? Seems like if
| NLP SOTA requires dynamic control flow, Google will build it?
| Maybe they let you declare some subgraph as "dynamic" to avoid
| static optimizations? But maybe the static graph assumption is
| so deeply embedded in the XLA architecture that they'd be better
| off just adopting Julia? (I honestly don't know the answer,
| asking your opinion!)
| ChrisRackauckas wrote:
| "Maybe they let you declare some subgraph as 'dynamic' to avoid
| static optimizations?" What you just described is Tensorflow
| Eager, and that is why it has some performance issues (but more
| flexibility!). XLA makes some pretty strong assumptions and I
| don't think that should change. Tensorflow's ability to
| automatically generate well-parallelized production code stems
| from the restrictions it has imposed. So I wouldn't even try for
| a "one true AD to rule them all", since making things more
| flexible will reduce the amount of compiler optimizations that
| can be automatically performed.
|
| To get the more flexible form, you really would want to do it in
| a way that uses a full programming language's IR as its target.
| I think trying to use a fully dynamic programming language's IR
| (Python, R, etc.) directly would be pretty insane, because it
| would be hard to enforce rules and get performance. So some
| language that has a front end over an optimizing compiler (LLVM)
| would probably make the most sense. Zygote and Diffractor use
| Julia's IR, but there are other ways to do this as well. Enzyme
| (https://github.com/wsmoses/Enzyme.jl) uses the LLVM IR directly
| for doing source-to-source translations. Using some dialect of
| LLVM (provided by MLIR) might be an interesting place to write a
| more ML-focused flexible AD system. Swift for Tensorflow used
| the Swift IR. This mindset starts to show why those tools were
| chosen.
| awaythrowact wrote:
| Makes sense. I don't use TF Eager, but I do use Jax, and Jax
| lets you arbitrarily compose JITed and non-JITed code, which
| made me think that might be a viable pattern. I guess I wondered
| if there might be something like "nonstatic_jit(foo)" that would
| do "julia style" compiling on function foo, in addition to
| "jit(foo)" that compiles foo to optimized XLA ops. Probably
| impractical. Thanks.
| 6gvONxR4sf7o wrote:
| > and can rely on the fact that Julia has a compiler that can
| JIT compile any generated valid Julia code
|
| This seems to be the key bit. It's a great data point around the
| meme of "with a sufficiently advanced compiler..." In this case
| we have sufficiently advanced compilers to make very different
| JIT trade-offs. XLA is differently powerful compared to Julia.
| Very cool, thanks for the insight.
| ipsum2 wrote:
| The example that fails in Jax would work fine in PyTorch. If
| you're working on purely training the model, TorchScript doesn't
| give many benefits, if any.
___________________________________________________________________
(page generated 2021-07-22 23:00 UTC)