[HN Gopher] 4000x Speedup in Reinforcement Learning with Jax
___________________________________________________________________
 
4000x Speedup in Reinforcement Learning with Jax
 
Author : _hark
Score  : 18 points
Date   : 2023-04-06 21:46 UTC (1 hour ago)
 
(HTM) web link (chrislu.page)
(TXT) w3m dump (chrislu.page)
 
| _hark wrote:
| jax.vmap() is all you need?
| 
| schizo89 wrote:
| Not only vectorization, but the plethora of environments
| written in Jax. Hopefully someone will port MuJoCo to Jax soon.
| 
| sillysaurusx wrote:
| It's a little disingenuous to say that the 4000x speedup is due
| to Jax. I'm a huge Jax fanboy (one of the biggest), but the
| speedup here is thanks to running the simulation environment on
| a GPU. As much as I love Jax, it's still extraordinarily
| difficult to implement even simple environments purely on a GPU.
| 
| My long-term ambition is to replicate OpenAI's Dota 2
| reinforcement learning work, since it's one of the most
| impactful (or at least most entertaining) uses of RL. It would
| be more or less impossible to translate the game logic into
| Jax, short of transpiling C++ to Jax somehow. Which isn't a bad
| idea - someone should make that.
| 
| It should also be noted that there's a long history of RL being
| done on accelerators. AlphaZero's chess evaluations ran
| entirely on TPUs. PyTorch CUDA graphs also make it easier to
| implement this kind of thing nowadays, since (again, as much as
| I love Jax) some PyTorch constructs are simply easier to use
| than turning everything into a functional programming paradigm.
| 
| All that said, you should really try out Jax. The fact that you
| can calculate gradients w.r.t. any arbitrary function is just
| amazing, and you have complete control over what's JIT'ed into
| a GPU graph and what's not. It's a wonderful feeling compared
| to using PyTorch's accursed .backward() accumulation scheme.
| 
| Can't wait for a framework that feels closer to pure arbitrary
| Python.
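The vmap-plus-GPU pattern the comments describe can be sketched in a few lines. The environment below is a deliberately trivial stand-in (not the article's actual benchmark, and all names here are illustrative): jax.vmap turns a single-environment step into a batched step, and jax.jit compiles the whole batch into one XLA graph that can run on a GPU.

```python
import jax
import jax.numpy as jnp

# Toy "environment" step: purely illustrative, not a real simulator.
def env_step(state, action):
    new_state = state + action
    reward = -jnp.sum(new_state ** 2)  # scalar reward per environment
    return new_state, reward

# vmap maps the single-env step over a leading batch axis;
# jit compiles the batched step into a single XLA computation.
batched_step = jax.jit(jax.vmap(env_step))

states = jnp.zeros((4096, 2))   # 4096 parallel environments
actions = jnp.ones((4096, 2))
new_states, rewards = batched_step(states, actions)
# new_states: (4096, 2), rewards: (4096,)

# Gradients w.r.t. an arbitrary function, as praised above:
# differentiate mean reward with respect to the batch of actions.
loss = lambda a: -jnp.mean(batched_step(states, a)[1])
action_grads = jax.grad(loss)(actions)  # same shape as actions
```

The same composed transform runs unchanged on CPU, GPU, or TPU; scaling the batch dimension is where speedups of this magnitude come from.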
| Maybe AI can figure out how to do it.
| 
| schizo89 wrote:
| Neural differential equations are also easier with Jax, and
| sim2real may be easier with a simulator where some of the hard
| computations are replaced with neural approximations.
___________________________________________________________________
(page generated 2023-04-06 23:00 UTC)
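The neural-differential-equation point above can be sketched with jax.experimental.ode.odeint, which is differentiable end-to-end. Here a tiny linear map stands in for a learned network; the parameter names and dynamics are illustrative assumptions, not anything from the article.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

# "Neural" dynamics: a single linear layer with tanh, standing in
# for a learned network (illustrative only).
def dynamics(y, t, params):
    W, b = params
    return jnp.tanh(W @ y + b)

def final_state(params, y0):
    ts = jnp.linspace(0.0, 1.0, 5)
    ys = odeint(dynamics, y0, ts, params)  # integrate the ODE
    return ys[-1]

W = jnp.eye(2) * 0.1
b = jnp.zeros(2)
y0 = jnp.ones(2)

# Gradients of a loss on the final state flow back through the
# solver to the parameters, which is what makes training neural
# ODEs (or hybrid simulators) straightforward in Jax.
loss = lambda p: jnp.sum(final_state(p, y0) ** 2)
gW, gb = jax.grad(loss)((W, b))
```

A hybrid simulator would follow the same shape: keep the cheap physics as explicit Jax code and swap the expensive terms for a network inside `dynamics`.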