未经 Jax、PyTorch 或 TensorFlow 优化的有用算法

2021-07-24 07:00:16

在之前的一些博客文章中,我们详细描述了如何通过将图转换合并到代码生成中来泛化自动微分以自动增强稳定性和各种其他细节。但是,我们没有过多讨论的一件事是这些类型的算法的局限性。这种限制就是我们所说的“准静态”,即一种算法可以重新解释为某种静态算法的特性。事实证明,出于非常根本的原因,这与某些主要机器学习框架对它们可以完全优化的代码(例如 Jax 或 Tensorflow)强加的限制相同。这让我们产生了一个问题:在这种心态下是否存在不可优化的算法,为什么?答案现已在 ICML 2021 上发布,让我们深入研究这个更高层次的概念。首先,让我们对什么是准静态算法有一个具体的想法。它是算法空间,可以以某种方式重新表示为静态算法。可以将“静态算法”视为具有不需要完整计算机描述的简单数学描述的算法,即没有循环、重写到内存等。作为示例,让我们看一下 Jax 文档中的一个示例。下面是Jax JIT的作用: 注意它是用控制流表示的,即用循环表示的代码,但是循环不是必须的 我们也可以把这个方法理解为2*2*2 *x 或 8*x。默认情况下 JIT 将失败的示例是:@jit def f (x): if x < 3: return 3. * x ** 2 else: return - 4 * x # 这将失败! try: f ( 2 ) except Exception as e: print ( "Exception {}".format (e) ) 在这种情况下,我们可以看到本质上有两个计算图在 x<3 处拆分,因此如上所述这并没有有一个描述计算的数学语句。你可以通过执行 lax.cond(x < 3, 3. * x ** 2, -4 * x) 来解决这个问题,但请注意这是一个根本不同的计算:lax.cond 形式总是计算 if 的两边在选择要结转哪一个之前的语句,而 true if 语句根据条件更改其计算。 lax.cond 形式之所以如此适用于 Jax 的 JIT 编译系统,是因为它是准静态的。将发生的计算是固定的,即使结果不是,而原始 if 语句将更改基于输入值计算的内容。存在此限制是因为 Jax 跟踪程序以尝试在以下情况下构建静态计算图hood,然后尝试在该图上进行实际转换。还有其他类型的框架可以做类似的事情吗?事实证明,可转换为纯符号语言的算法集是准静态算法集,所以像 Symbolics.jl 这样的东西在其算法的行为中也有一种准静态表现形式。也是出于同样的原因:在符号算法中,您可以定义诸如“x”和“y”之类的符号变量,然后通过程序进行交易以构建“2x^2 + 3y”的静态计算图,然后您对其进行符号处理。在常见问题中,有一个问题是当函数到符号的转换失败时会发生什么。如果你看一下例子: function factorial (x ) out = x while x > 1 x -= 1 out *= x end out end @variables xfactorial (x ) 你可以看到这是因为算法不能表示为单个数学表达式:阶乘不能写为固定乘法次数,因为乘法次数取决于您尝试计算 x 的值 x!为了!符号语言抛出的错误是“ERROR: TypeError: non-boolean (Num) used in boolean context”,这是说它不知道如何符号扩展“while x > 1”来表示它静态地。这不是不一定“可修复”的东西,这是该算法无法由固定计算表示并且必须需要根据输入更改计算这一事实的基础。

“解决方案”是通过“@register factorial(x)”为图形定义一个新的原语,这样这个函数本身就是一个固定节点,不会尝试进行符号扩展。这与定义 Jax 原语或 Tensorflow 原语的概念相同,其中算法根本不是准静态的,因此获得准静态计算图的方法是将动态块视为函数“y = f (x)”是注定存在的。在符号语言和机器学习框架的上下文中,要使其充分发挥作用,您还需要定义所述函数的导数。最后一部分是捕获。如果您再看一下其中一些工具的文档的深度,您会注意到许多表示非静态控制流的原语都超出了完全处理的领域。在文档中,它指出您可以用 lax.while_loop 替换 while 循环,但这不适用于反向模式自动微分。原因是因为它的反向模式 AD 实现假设存在这种准静态算法并将其用于两个目的,一是用于生成反向传递,二是生成算法的 XLA(“Tensorflow”)描述,然后进行 JIT 编译优化。 XLA 需要静态计算图,对于这种情况,它也不一定存在,因此存在基本限制。解决这个问题的方法当然是用它自己的快速梯度计算定义你自己的原语,这个问题就消失了......有机器学习框架不假设准静态但也优化,大多数Julia 编程语言中的 Diffractor.jl、Zygote.jl 和 Enzyme.jl(注意 PyTorch 不假设准静态表示,尽管 TorchScript 的 JIT 编译会)。这让我思考:是否存在真正的机器学习算法,这是一个真正的限制?这是一个很好的问题,因为如果你提出像卷积神经网络这样的标准方法,那就是一个固定函数内核调用,定义了一个很好的导数,或者一个循环神经网络,这是一个固定大小的循环。如果你想打破这个假设,你必须进入一个基本上关于算法的空间,在那里你无法知道“计算量”,直到你知道问题中的特定值,而方程求解器就是这种形式。牛顿法收敛需要多少步?自适应 ODE 求解器需要多少步?这不是可以先验回答的问题:它们基本上是需要了解的问题:因此,在 Python 框架中工作的人们一直在寻找处理方程求解的“正确”方法(ODE 求解,求根 f( x)=0 等)作为黑盒表示。如果你再看一下神经常微分方程论文,它提出的一项重要建议是将神经 ODE 处理为一个黑盒,其导数由 ODE 伴随定义。原因当然是因为自适应 ODE 求解器必须迭代到容差,所以必然存在诸如“while t <tend”之类的东西,这取决于当前计算是否计算到容差。作为在他们工作的框架中没有优化的东西,这是使算法工作所必需的。不,将此类算法视为黑盒并不是根本。事实上,几年前我们有一篇相当受欢迎的论文,表明可以通过一些 Julia AD 工具直接使用正向和反向模式自动微分训练神经随机微分方程。原因是因为这些 AD 工具(Zygote、Diffractor、Enzyme 等)由于它们如何进行直接的源到源转换而不一定采用准静态形式,因此它们可以直接区分自适应求解器并吐出正确的梯度。所以你不一定要以“定义一个 Tensorflow op”的风格来做,但哪个更好?事实证明,“更好”真的很难定义,因为这两种算法不一定相同并且可以计算不同的值。您可以将其归结为:您是要对方程的求解器进行微分,还是要对方程进行微分并对其应用求解器?前者相当于算法的自动微分,称为离散灵敏度分析或离散然后优化。后者是连续敏感性分析或优化然后离散化的方法。机器学习并不是第一个遇到这个问题的领域,所以关于通用微分方程和科学机器学习生态系统的论文有一个相当长的描述,我将引用:“”“之前的研究表明离散伴随方法是在某些情况下比连续伴随更稳定 [41, 37, 42, 43, 44, 45] 而连续伴随已被证明在其他情况下更稳定 [46, 43] 并且可以减少虚假振荡 [47, 48, 49] . 离散和连续伴随方法之间的这种权衡已经在一些方程中被证明是稳定性和计算效率之间的权衡 [50, 51, 52, 53, 54, 55, 56, 57, 58]。必须小心被视为伴随方法的稳定性可能取决于所选的离散化方法 [59, 60, 61, 62, 63],我们的软件贡献帮助研究人员在所有这些优化方法之间切换,并结合数百个微分方程求解er 方法只需一行代码更改。 """

或者,tl; dr:有大量先前的研究表明,连续伴随词不如离散伴随词稳定,但它们可以更快。我们最近进行了跟进,表明这些说法在现代软件的现代问题上是正确的。具体来说,这篇关于刚性神经 ODE 的论文说明了为什么在多尺度数据上训练时离散伴随比连续伴随更稳定,但我们最近还表明,连续伴随在梯度计算方面比(某些)当前用于离散伴随的 AD 技术要快得多。好吧,如果您正在处理这些硬微分方程、微分偏微分方程等,那么使用离散伴随技术确实有好处,这在 80 年代在控制理论领域就已为人所知。但除此之外,这是一种洗礼,因此尚不清楚区分此类算法在机器学习中是否更好,对吗?这现在让我们了解最近的 ICML 论文如何适应这种叙述。是否存在真正对标准机器学习有用的非准静态算法?答案是肯定的,但是如何到达那里需要一些巧妙的技巧。首先,设置。神经 ODE 可能是一种有趣的机器学习方法,因为它们使用自适应 ODE 求解器从本质上为您选择层数,因此它就像一个循环神经网络(或者更具体地说,像一个残差神经网络)自动找到“正确”层数,其中层数是 ODE 求解器决定采取的步数。换句话说,用于图像处理的神经 ODE 是一种自动进行超参数优化的算法。整洁的!但是……“正确”的层数是多少?对于超参数优化,您会假设这将是“准确预测的最少层数”。但是,默认情况下,神经 ODE 不会为您提供那么多层:它们会给您任何感觉。事实上,如果你看一下原始的神经 ODE 论文,随着神经 ODE 训练它不断增加它使用的层数:那么有没有办法改变神经 ODE 以使其将“正确的层数”定义为“最少的层数”?在学习易于求解的微分方程的工作中,他们就是这样做的。他们是如何做到的,他们正则化了神经 ODE 的训练过程。他们查看了解决方案并指出,发生更多变化的 ODE 必然更难求解,因此您可以通过添加一个表示“使高阶导数项尽可能小”的正则化项将训练过程转换为超参数优化.本文的其余部分是如何实现这个想法。那是怎么做的?好吧,如果您必须将算法视为黑盒,则需要定义一些黑盒方法来定义高阶导数,从而导致 Jesse 非常酷的泰勒模式自动微分公式。但无论你怎么说,这都将是一个昂贵的计算对象:计算梯度比前向传递更昂贵,二阶导数比梯度更贵,三阶等等,所以一个需要六阶导数的算法训练会很讨厌。通过一些非常英勇的工作,他们得到了这个黑盒操作的公式,它需要两倍的时间来训练,但成功地完成了超参数优化。有没有办法通过神经 ODE 训练更快地进行自动超参数优化?是的,我们的论文不仅使它们的训练速度比其他方法快,而且比普通神经 ODE 训练速度更快。我们可以使层超参数优化不那么免费:我们可以使它比不进行优化更便宜!但是如何?诀窍是打开黑匣子。让我向您展示自适应 ODE 求解器的步骤是什么样的:请注意,自适应 ODE 求解器通过使用误差估计来选择时间步长是否合适。 ODE 算法的实际构造是为了免费计算误差估计值,即“这个 ODE 求解的难度”的估计值。如果我们使用这个自由误差估计作为我们的正则化技术会怎样?事实证明,训练速度比以前快 10 倍,同时类似地自动执行超参数优化。

请注意我们结束的地方:结果算法不一定是准静态的。该误差估计是通过自适应 ODE 求解器的实际步骤计算的:要计算该误差估计,您必须执行与 ODE 求解器相同的计算和相同的 while 循环。在此算法中,您无法避免直接对 ODE 求解器进行微分,因为 ODE 求解器的内部计算部分现在是正则化的一部分。这从根本上没有被需要准静态计算图(Jax、Tensorflow 等)的方法优化,并且这使得超参数优化比不进行超参数优化更便宜,因为正则化器是免费计算的。我只是觉得这个结果太酷了!所以是的,这篇论文是一篇关于如何使用神经 ODE 的技巧免费进行超参数优化的机器学习论文,但我认为它所处的一般软件上下文突出了论文的真实发现。这是我所知道的第一个算法,它有明确的动机用于现代机器学习,而且,Jax 和 Tensorflow 等常见机器学习框架无法处理它们是有根本原因的最佳。即使是 PyTorch 的 TorchScript,由于其编译过程的假设,从根本上也不会在该算法上工作。这些假设是明智地选择的,因为大多数算法都可以满足它们,但这个算法不能。这是否意味着机器学习在算法上停滞不前?可能是因为我完全相信在不优化此算法的工具集中工作的人永远不会找到它,这让我非常发人深省。还有哪些算法比我们当前的方法更好,但仅因为当前的机器学习框架而更差?我迫不及待地等到 Diffractor.jl 的发布才开始深入探讨这个问题。