2022年使用JAX

2022-02-17 02:09:21

自JAX于2018年底问世以来,它的受欢迎程度一直在稳步增长,这是有充分理由的。DeepMind在2020年宣布,它正在使用JAX来加速其研究,来自Google Brain和其他公司的越来越多的出版物和项目正在使用JAX。有了这么多的热议,JAX似乎是下一个大型深度学习框架,对吧?

错误的在本文中,我们将阐明JAX是什么(不是),为什么应该关心(或不应该,但可能应该),以及是否应该(或不应该)使用它。

如果你';如果您已经熟悉JAX,并且想跳过基准测试,那么您可以在这里跳转到我们关于何时使用它的建议

最好从JAX不是的东西开始。JAX不是一个深度学习框架或库,它本身也不是一个深度学习框架或库。总之,JAX是一个数字计算库,它包含了可组合函数转换[1]。正如我们所见,深度学习只是JAX所能做的一小部分:

简而言之,就是速度。这是JAX的通用方面,与任何用例都相关。

让';s用NumPy和JAX求矩阵的前三次幂的和(按元素)。首先是我们的NumPy实现:

我们发现这个计算大约需要478毫秒。接下来,我们用JAX实现这个计算:

JAX只需5.54毫秒即可完成此计算,比NumPy快86倍多。

事情并不像"那么简单;使用JAX,你的程序会快86倍;,但使用JAX仍然有很多理由。由于JAX为科学计算提供了一个通用的基础,对于不同的领域,不同的人有不同的原因。基本上,如果你在任何与科学计算相关的领域,你都应该关心JAX。

1.加速器上的NumPy——NumPy是使用Python进行科学计算的基本软件包之一,但它只与CPU兼容。JAX提供了NumPy的一个实现(具有几乎相同的API),可以非常轻松地在GPU和TPU上工作。对于许多用户来说,仅此一点就足以证明使用JAX的合理性。

2、XLA—XLA,或加速线性代数,是一个专门为线性代数设计的完整的程序优化编译器。Jax是基于XLA构建的,极大地提高了计算速度上限[1 ]。

3、JIT JAX允许您使用XLA(7)将您自己的函数转换成即时(JIT)编译版本。这意味着您可以通过在计算函数中添加一个简单的函数装饰器,将计算速度潜在地提高几个数量级。

4.自动区分——JAX文档将JAX称为";AutoGrad和XLA,汇集了34个;[ 1]. 自动区分的能力在科学计算的许多领域都至关重要,JAX提供了几种强大的自动区分工具。

5、深度学习——虽然不是一个深刻的学习框架本身,但JAX当然为深入学习的目的提供了足够多的基础。有许多基于JAX构建的库寻求构建深度学习能力,包括Flax、俳句和挽歌。在最近的PyTorch vs TensorFlow文章中,我们甚至强调JAX是一个值得关注的“框架”,建议将其用于基于TPU的深度学习研究。JAX';Hessians的高效计算也与深度学习相关,因为它们使高阶优化技术更加可行。

6.通用可微编程范式——虽然使用JAX来构建和训练深度学习模型当然是可能的,但它也为通用可微编程提供了一个框架。这意味着,通过使用基于模型的机器学习方法来解决问题,JAX可以利用经过几十年研究积累起来的特定领域的先验知识。

XLA,或加速线性代数,正是JAX强大的基础。由谷歌开发的XLA是一种基于领域的、基于图形的、即时的线性代数编译器[2 ],它可以通过各种全局程序优化显著地提高计算速度[3 ]。

在一个例子(2)中,XLA单独从计算角度提高了伯特训练速度几乎7.3倍,但是由于使用XLA也使得存储器使用率降低,从而使得梯度累加,导致计算吞吐量的惊人增长12倍。

XLA被烘烤成JAX的DNA,从他们的标志中你可以看到JAX的成功依赖XLA。

正确回答为什么XLA是如此大的交易可以产生一个非常技术性的(和长期的)讨论。对于我们的目的,足够的XLA是重要的,因为它大大提高了执行速度,并通过融合低级操作降低内存使用。

XLA不预先将单个操作编译成计算核,而是将整个图编译成一个专门为该图生成的计算内核序列。

这种方法通过不执行不必要的内核启动以及利用局部信息进行优化来提高速度[3]。由于XLA不在操作序列中实现中间数组(而是在GPU寄存器中保持值并将它们流到3),使用XLA也减少了内存消耗。

这种降低的内存消耗会产生进一步的速度提升,因为(i)内存通常是用GPU计算的限制因素,并且(ii)XLA不会浪费执行无关数据移动的时间。

虽然操作融合(或核融合)是XLA的旗舰特征,但应该注意到XLA还执行了大量其他的整体程序优化,例如专门针对已知张量形状(允许更积极的恒定传播),分析和调度内存使用以消除中间存储缓冲器[4 ],执行内存布局操作,并仅计算请求值的子集(如果不是全部返回的话)[5]。

由于所有Jax操作都是在XLA的操作中实现的,JAX有一个统一的计算语言,允许它在CPU、TPU和GPU之间无缝运行,而库调用及时编译和执行(1)。

如果上面的术语没有一个对你有意义,不要担心——只知道XLA是一个非常快的编译器,它是JAX在各种各样的硬件上使用的唯一强大和简单的基础。

到目前为止,我们已经谈到了XLA,以及它如何允许JAX在加速器上实现NoMPY;但请记住,这只是我们对JAX定义的一半。JAX不仅为强大的科学计算提供了工具,还为可组合的函数转换提供了工具。

简单地说,函数变换是一个函数上的运算符,其输出是另一个函数。如果我们对标量值函数f(x)使用梯度函数变换,那么我们得到一个向量值函数f';(x) 它给出了函数在f(x)域中任意点的梯度。

JAX为此类功能转换整合了一个可扩展系统,并有四个典型用户感兴趣的主要转换:

让我们依次看看这些转变,并讨论它们的原因';你太激动人心了。

为了能够训练机器学习模型,需要能够执行反向传播。与TensorFlow或Pytork通过在计算图中反向传播来计算损失函数在某一点的梯度不同,JAX grad()函数变换输出梯度函数,然后可以在其域中的任何点对其进行计算。

JAX中的自动区分功能非常强大,这部分源于JAX在“何处”可以计算梯度方面的灵活性。使用grad(),您可以通过本机Python和NumPy函数[6]进行区分,例如循环、分支、递归、闭包和“PyTrees”(例如字典)。

让我们看一个例子——我们将用Python控制流定义一个经过修正的立方体函数f(x)=abs(x3)。这个实现显然不是计算效率最高的方法,但它帮助我们强调grad()如何通过原生Python控制流和嵌套在条件中的循环工作。

def_立方体(x):如果x<;0.:对于范围(3)中的i:r*=xr=-r其他:对于范围(3)中的i:r*=x返回rgradient_函数=grad(矫正的_立方体)打印(f";x=2f(x)={矫正的_立方体(2.)}f';(x) =3*x^2={gradient_函数(2.)}")打印(f";x=-3F(x)={正方体(-3.)}f';(x) =-3*x^2={gradient_函数(-3.)}")

x=2f(x)=8.0f';(x) =3*x^2=12.0x=-3f(x)=27.0f';(x) =-3*x^2=-27.0

我们可以看到,在x=2和x=-3时计算函数及其导数时,我们得到了预期的结果。

通过重复应用grad(),JAX可以轻松区分任何顺序。

#对于x>;=0:f(x)=x^3=>;f';(x) =3*x^2=>;f''(x) =3*2*x=>;f'''(x) =6third_deriv=grad(grad(grad(grad(rectived_cube)))表示范围(5)中的i:打印(third_deriv(float(i)))

我们可以看到,对函数的三阶导数的几个输入求值得到f'的恒定预期输出''(x) =6。

从更一般的角度来看,快速、简单地获取多个导数的能力对于深度学习以外的许多更一般的计算领域都有实际用途,例如动力系统的研究。

正如您所料,grad()采用标量值函数的梯度,这意味着将标量/向量映射到标量的函数。这种函数的梯度对于反向传播非常有用,例如,我们通过从(标量)损失函数反向传播来更新模型权重来训练模型。

虽然grad()对于各种项目来说都足够了,但它并不是JAX可以执行的唯一一种差异化类型。

对于将向量映射到向量的向量值函数,与梯度类似的是雅可比矩阵。通过函数转换jacfwd()和jacrev(),JAX返回一个函数,当在域中的某个点求值时,该函数将生成雅可比矩阵。

def映射(v):x=v[0]y=v[1]z=v[2]返回jnp。数组([x*x,y*z])#3个输入,2个输出#[d/dx x^2,d/dy x^2,d/dz x^2]#[d/dx y*z,d/dy*z,d/dz y*z]#[2*x,0,0]#[0,z,y]f=jax。jacfwd(映射)v=jnp。数组([4,5,9.]))印刷品(f(v))

例如,您也可以使用雅可比矩阵,以便更有效地计算函数相对于数据矩阵中每个基准的权重矩阵的梯度。

从深度学习的角度来看,JAX最令人兴奋的一个方面可能是,它使计算黑森人变得极其简单和高效。由于XLA,JAX可以比PyTrac计算Hessian的速度快得多,这使得实现高阶优化技术如AdHessian更加实用。这一事实本身就足以为一些从业者提供使用JAX的理由。

最慢的跑步比最快的跑长8.14倍。这可能意味着正在缓存中间结果。10次循环,最佳5次:每次循环16.3毫秒

如我们所见,计算大约需要16.3毫秒。Let';让我们在JAX中尝试同样的计算:

最慢的跑步比最快的跑长47.27倍。这可能意味着正在缓存中间结果。1000圈,最佳5圈:每圈1.55毫秒

JAX甚至可以计算雅可比矢量积和雅可比矢量积。考虑光滑流形之间的光滑映射。JAX可以计算这个映射的推进,将一个流形上的点的切向量映射到另一个流形上的切向量。

如果这部分令人困惑或不熟悉,不要担心!这是一个高级主题,可能与典型用户无关。我们指出这种能力的存在只是为了强调JAX为各种各样的计算任务提供了非常强大的基础。例如,向前推在微分几何领域很重要,我们可以使用JAX来研究。

通过数学转换到更实际的/计算转换,我们得到了vmap()。考虑一下我们想在一组对象上重复应用一个函数的情况。让我们考虑,例如,添加两个数字列表的任务。实现这种操作的简单方法是简单地使用for循环,即对于第一个列表中的每个数字,将其添加到第二个列表中的相应值,并将结果写入一个新列表。

通过vmap()转换,JAX执行相同的计算,但将循环向下推到基本操作以获得更好的性能[6],从而生成计算的自动矢量化版本。

当然,我们可以简单地将列表定义为JAX数组,并使用JAX';s数组添加,但由于许多原因,vmap()仍然很有用。

一个基本原因是,我们可以用更多的本地Python代码编写操作,然后使用vmap()编写操作,从而生成高度Pythonic的、可能更可读的代码。另一个原因当然是推广到没有简单的矢量化替代方案来实现的情况。

分布式计算一年比一年变得越来越重要,这在深度学习中尤其如此,正如下图所示,SOTA模型已经发展到绝对天文数字。例如,GPT-4将有超过100万亿个参数。

我们';以上讨论了如何利用XLA,JAX可以轻松地计算加速器,但JAX也可以容易地用多个加速器计算,用单个命令PMAP-()来执行SPMD程序的分布式训练。

考虑向量矩阵乘法的例子。假设我们通过顺序计算向量与矩阵每一行的点积来执行这个计算。我们需要一次一个地通过硬件完成这些计算。

使用JAX,只需将操作包装在pmap()中,我们就可以轻松地将这些计算分布到4个TPU中。这允许我们在每个TPU上同时执行一个点积,显著提高了计算速度(对于大型计算)。

这里非常值得注意的是,对我们的代码所做的更改是如此之小。由于JAX是建立在XLA上的,我们可以轻松地将计算映射到硬件。

即时编译(Just-in-time,简称JIT编译)是一种执行介于解释和提前编译(AoT)之间的代码的方法。重要的事实是,JIT编译器会在运行时将代码编译成快速的可执行文件,代价是第一次运行的速度较慢。

在JIT编译中,代码是在运行时编译的,因此在程序第一次运行时,由于需要编译和执行代码,因此会有一些初始开销。因此,AoT编译在第一次通过时可能会优于JIT;然而,对于重复执行,JIT编译的程序将使用之前编译的缓存代码来快速执行。JIT编译的程序在理论上可以比AoT编译的同一程序运行得更快,因为JIT编译器可以利用代码在将在其上执行的同一台机器上编译的事实,使用本地信息进行优化。

线条会变得模糊。例如,当Python运行时,它被编译成字节码,然后由Python的虚拟机(例如CPython)解释字节码,或者编译成机器码(PyPy)。如果这些细节令人困惑,请不要';别担心。重要的一点是,JIT编译JAX程序允许它们以极快的速度执行。

XLA原语是JIT编译的,但是JAX也允许JIT将自己的Python函数编译成XLA优化内核,既可以作为函数装饰器JIT,也可以作为函数本身JITE()1。

JIT将一次一次的操作调度到GPU中,而不是使用XLA将操作序列编译成一个内核,给出了函数的端到端编译的、高效的XLA实现[6 ] [7 ]。

为了提供一个例子,让我们定义一个函数来计算一个值矩阵的前三次幂之和。我们在一个5000 x 5000的矩阵上计算这个函数三次——一次使用NumPy,一次使用JAX,一次使用JIT编译版本的JAX。首先,我们在CPU上进行实验:

def fn(x):返回x+x*x+x*x*xx_np=np。随机的兰登(5000,5000)。aType(dtype=';float32';)x_jnp=jnp。数组(x_np)%timeit-n5-r5 fn(x_np)%timeit fn(x_jnp)。阻塞_,直到_ready()jitted=jit(fn)jitted(x_jnp)%timeit jitted(x_jnp)。阻塞_直到_就绪()

警告:absl:未找到GPU/TPU,正在返回CPU。(将TF_CPP_MIN_LOG_LEVEL设置为0,然后重新运行以获取更多信息。)5圈,最佳5:151毫秒/圈10圈,最佳5:109毫秒/圈100圈,最佳5:17.7毫秒/圈

我们看到JAX比NumPy快近40%,当我们对函数进行JIT时,我们发现JAX比NumPy快8.5倍。这些结果已经令人印象深刻,但让我们';提高赌注,让JAX在TPU上计算:

在本例中,我们看到JAX比NumPy快9.3倍,如果我们都在TPU上JIT函数和计算,我们会发现JAX比NumPy快57倍。

当然,速度的大幅提高并非没有代价。JAX对JIT允许使用哪些函数进行了限制,尽管通常允许使用只涉及上述NumPy操作的函数。此外,通过Python控制流进行JITting也有一些限制,因此';在编写函数时,我们必须记住这一点。

在使用jit之前,您应该确保了解它是如何工作的,以及在什么情况下允许使用它。如果你不了解这一点,但无论如何都要尝试使用jit,你要么会收到让你困惑的错误消息(如果你幸运的话),要么会收到未经跟踪且不受欢迎的副作用,这些副作用会悄悄地影响结果的准确性(如果你不幸运的话)。

JAX有4种主要的函数转换——grad()用于自动区分函数,vmap()用于自动矢量化操作,pmap()用于SPMD程序的并行计算,jit()用于将函数转换为jit编译版本。这些转换(大部分)是可组合的,非常强大,并且有可能使您的程序加速几倍。

我们看到了XLA和基本JAX转换如何有可能显著地提高程序的性能。虽然JAX非常强大,有可能在许多领域显著提高生产率,但它的使用需要谨慎。特别是如果您正在考虑从PyTorch或TensorFlow转移到JAX,您应该了解JAX的基本理念与两个深度学习框架截然不同。我们';我现在来谈谈主要的区别。

JAX的主要区别在于,它的转换和编译只适用于功能纯粹的程序。虽然如果你只是想使用JAX在GPU或TPU上进行NumPy计算,这一事实可能并不重要,但它与大量潜在的JAX应用程序有关,因此你应该确保在开始之前理解采用这种范式的含义。

纯函数的中心特征是引用透明性——纯函数可以随时用其求值结果替换,程序无法分辨两者之间的差异。在给定相同输入的情况下,函数应始终对程序具有相同的效果,而不管它是在什么时间或上下文中执行的。

这在原则上听起来很简单,但确实存在

......