[HN Gopher] 4000x Speedup in Reinforcement Learning with Jax
       ___________________________________________________________________
        
       4000x Speedup in Reinforcement Learning with Jax
        
       Author : _hark
       Score  : 18 points
       Date   : 2023-04-06 21:46 UTC (1 hour ago)
        
 (HTM) web link (chrislu.page)
 (TXT) w3m dump (chrislu.page)
        
       | _hark wrote:
       | jax.vmap() is all you need?
        
         | schizo89 wrote:
          | It's not only the vectorization, but also the plethora of
          | environments written in Jax. Hopefully someone will port
          | MuJoCo to Jax soon.
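       A minimal sketch of the vmap-style vectorization being
       discussed (the toy environment and all names here are
       hypothetical, not from any particular library):

```python
import jax
import jax.numpy as jnp

# Toy environment: state is a scalar position, the action nudges it.
# Jax RL environments are written as pure functions like this one.
def env_step(state, action):
    new_state = state + action
    reward = -jnp.abs(new_state)  # reward for staying near zero
    return new_state, reward

# jax.vmap turns the single-environment step into a batched step
# that advances thousands of environments in one call on the GPU.
batched_step = jax.vmap(env_step)

states = jnp.zeros(4096)
actions = jnp.full(4096, 0.1)
new_states, rewards = batched_step(states, actions)
```

       The same pattern extends to vmapping whole rollouts, which is
       where the headline speedups come from.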
        
       | sillysaurusx wrote:
       | It's a little disingenuous to say that the 4000x speedup is due
       | to Jax. I'm a huge Jax fanboy (one of the biggest) but the
       | speedup here is thanks to running the simulation environment on a
       | GPU. But as much as I love Jax, it's still extraordinarily
       | difficult to implement even simple environments purely on a GPU.
       | 
       | My long-term ambition is to replicate OpenAI's Dota 2
       | reinforcement learning work, since it's one of the most impactful
        | (or at least most entertaining) uses of RL. It would be more
        | or
       | less impossible to translate the game logic into Jax, short of
       | transpiling C++ to Jax somehow. Which isn't a bad idea - someone
       | should make that.
       | 
       | It should also be noted that there's a long history of RL being
       | done on accelerators. AlphaZero's chess evaluations ran entirely
        | on TPUs. PyTorch CUDA graphs also make it easier to implement
        | this kind of thing nowadays, since (again, as much as I love
        | Jax) some PyTorch constructs are simply easier to use than
        | turning everything into a functional programming paradigm.
       | 
       | All that said, you should really try out Jax. The fact that you
       | can calculate gradients w.r.t. any arbitrary function is just
       | amazing, and you have complete control over what's JIT'ed into a
        | GPU graph and what's not. It's a wonderful feeling compared
        | to using PyTorch's accursed .backward() accumulation scheme.
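       For reference, the gradients-of-arbitrary-functions point looks
       roughly like this in practice (hypothetical toy function, not
       from the linked article):

```python
import jax
import jax.numpy as jnp

# Any pure function of arrays can be differentiated directly.
def loss(w):
    return jnp.sum((2.0 * w - 1.0) ** 2)

grad_loss = jax.grad(loss)      # d(loss)/dw, itself a plain function
fast_grad = jax.jit(grad_loss)  # you choose exactly what gets JIT'ed

w = jnp.array([0.5, 1.0])
g = fast_grad(w)  # analytically 4 * (2w - 1)
```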
       | 
       | Can't wait for a framework that feels closer to pure arbitrary
       | Python. Maybe AI can figure out how to do it.
        
       | schizo89 wrote:
        | Neural differential equations are also easier with Jax.
        | Sim2real may also be easier with a simulator where some of
        | the hard computations are replaced with neural
        | approximations.
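       A hedged sketch of that idea: an Euler rollout whose dynamics
       are a small learned network standing in for an expensive
       physics computation, with gradients taken through the whole
       rollout (all names and shapes here are hypothetical; libraries
       like diffrax build on this pattern):

```python
import jax
import jax.numpy as jnp

# A tiny MLP standing in for the "hard" dynamics computation.
def mlp_dynamics(params, x):
    w1, b1, w2, b2 = params
    h = jnp.tanh(x @ w1 + b1)
    return h @ w2 + b2

def euler_rollout(params, x0, dt=0.01, steps=100):
    # Integrate dx/dt = f(x) with Euler steps via lax.scan;
    # the whole rollout stays differentiable.
    def step(x, _):
        return x + dt * mlp_dynamics(params, x), None
    x_final, _ = jax.lax.scan(step, x0, None, length=steps)
    return x_final

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params = (jax.random.normal(k1, (2, 16)) * 0.1, jnp.zeros(16),
          jax.random.normal(k2, (16, 2)) * 0.1, jnp.zeros(2))
x0 = jnp.array([1.0, 0.0])

# Gradients flow through the simulator itself, which is what makes
# sim2real-style training of the approximation tractable.
rollout_loss = lambda p: jnp.sum(euler_rollout(p, x0) ** 2)
grads = jax.grad(rollout_loss)(params)
```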
        
       ___________________________________________________________________
       (page generated 2023-04-06 23:00 UTC)