Skip to content

混合编程

一、两种编程方式:命令式和符号式

1. 命令式编程(我们平时常用的方式)

命令式编程就是像我们平时写 Python 代码那样,用print+if这些语句一步步改变程序状态,代码写了就直接执行。比如下面这段代码:

python
def add(a, b):
    return a + b

def fancy_func(a, b, c, d):
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g

print(fancy_func(1, 2, 3, 4))

Python 是解释型语言,执行这段代码的时候,会按顺序一步步来:先算e=add(a,b),把结果存在e里;再算f=add(c,d),存在f里;最后算g=add(e,f)。这种方式的好处是容易上手、容易调试,我们可以随时打印中间变量,用 Python 的调试工具找问题。

但它的缺点是效率不高:Python 解释器会一个个执行函数调用,就算同一个函数被重复调用(比如上面的add被调用了 3 次),也不会做优化;而且要一直保存ef这些变量,直到整个函数执行完,因为程序不知道后面会不会用到这些变量。如果在 GPU 上运行,Python 解释器的开销会很大,拖慢速度。

2. 符号式编程(为了效率的方式)

符号式编程是先定义好整个计算流程,然后把流程编译成可执行的程序,最后再传入数据执行。步骤是:

  1. 先写好整个计算的流程(比如先定义addfancy_func的逻辑)

  2. 把这个流程编译成机器能高效执行的程序

  3. 传入数据,运行编译好的程序

这种方式的好处是效率高、容易移植:编译器可以看到整个代码,能做很多优化,比如把上面的代码直接优化成print((1+2)+(3+4))甚至直接print(10);而且可以跳过 Python 解释器,避免解释器的性能瓶颈,还能把程序转换成和 Python 无关的格式,在非 Python 环境里运行。

但它的缺点是不够灵活,调试难:必须先把整个流程都定义好才能编译,没法像命令式那样边写边改;而且调试的时候很难看到中间变量的结果。

二、混合编程:把两种方式的优点结合起来

混合编程就是既可以用命令式编程来开发和调试(方便我们写代码、找 bug),又能把代码转换成符号式程序来运行(提高性能,方便部署)。不同框架的实现方式不一样:

  • MXNet 用HybridSequentialHybridBlock,调用hybridize函数

  • PyTorch 用 TorchScript

  • TensorFlow 用tf.function

  • 飞桨用paddle.jit.to_static

这里重点讲 PyTorch 的实现方式,因为我们平时用 PyTorch 比较多。

三、PyTorch 中的混合编程:TorchScript

1. 基本用法:把普通模型转换成 TorchScript

我们先定义一个简单的多层感知机:

python
import torch
from torch import nn
from d2l import torch as d2l

def get_net():
    net = nn.Sequential(
        nn.Linear(512, 256),
        nn.ReLU(),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Linear(128, 2)
    )
    return net

# 测试一下
x = torch.randn(size=(1, 512))
net = get_net()
net(x)

这个时候模型是命令式的,用 Python 解释器执行。如果想转换成符号式的,只需要用torch.jit.script把模型转换一下:

python
net = torch.jit.script(net)
net(x)

转换之后,模型的计算结果和之前一样,但是运行效率会提高。

2. 性能对比:转换前后的速度差异

我们可以用一个简单的工具类来测试速度:

python
class Benchmark:
    """用来测量运行时间"""
    def __init__(self, description='Done'):
        self.description = description
    def __enter__(self):
        self.timer = d2l.Timer()
        return self
    def __exit__(self, *args):
        print(f'{self.description}: {self.timer.stop():.4f} sec')

然后测试转换前后的速度:

python
# 测试普通的命令式模型
net = get_net()
with Benchmark('无torchscript'):
    for i in range(1000):
        net(x)

# 测试转换后的TorchScript模型
net = torch.jit.script(net)
with Benchmark('有torchscript'):
    for i in range(1000):
        net(x)

运行结果大概是:

Plain
无torchscript: 0.1361 sec
有torchscript: 0.1204 sec

可以看到,转换之后速度变快了,因为 TorchScript 把模型编译成了符号式的程序,跳过了 Python 解释器的开销。

3. 序列化:把编译好的模型保存下来

编译好的模型可以保存到磁盘,这样就可以部署到其他设备上,或者用其他编程语言调用。保存的方法很简单:

python
net.save('my_mlp')

保存之后会生成一个文件,这个文件里包含了模型的结构和参数,不需要 Python 环境也能运行,非常方便部署。

四、混合编程的注意事项

  1. 不是所有代码都能转换:如果模型里有一些自定义的 Python 代码(比如print语句、用了 Python 特有的函数),转换之后这些代码可能会被忽略或者报错。比如在 TorchScript 里,不能用x.asnumpy()这种只能在 Python 里用的函数,也不能用一些复杂的 Python 控制流。

  2. 变量赋值要注意:在混合编程里,像a += b或者a[:] = a + b这种操作,要改成a = a + b,因为符号式编程里的变量和命令式里的不一样,不能直接修改原变量。

  3. 调试的时候用命令式,运行的时候用符号式:开发和调试的时候用普通的命令式代码,方便找 bug;确定代码没问题了,再转换成符号式的来提高性能。

五、小结

  1. 命令式编程容易上手、调试方便,但效率低;符号式编程效率高、容易移植,但不够灵活。

  2. 混合编程把两者的优点结合起来,开发的时候用命令式,运行的时候用符号式。

  3. PyTorch 里用 TorchScript 实现混合编程,只需要用torch.jit.script转换模型,就能提高运行效率,还能把模型保存下来方便部署。

(注:文档部分内容可能由 AI 生成)

京ICP备2024093538号-1