交互式在线版本: Open In Colab

[ ]:
import haiku as hk
import jax
import jax.numpy as jnp

TL;DR:hk.transform 内部的 JAX 变换很可能会变换一个有副作用的函数,这将导致 UnexpectedTracerError。此页面描述了两种解决此问题的方法。

嵌套 JAX 函数和 Haiku 模块的限制#

一旦使用 hk.transform 将 Haiku 网络转换为一对纯函数,就可以自由地将它们与任何 JAX 变换(如 jax.jitjax.gradjax.lax.scan 等)结合使用。

但是,如果要在 hk.transform 内部 使用 JAX 变换,则需要更加小心。这是可能的,但是 hk.transform 边界内部的大多数函数仍然是有副作用的,并且不能安全地被 JAX 变换。这是使用 Haiku 的代码中出现 UnexpectedTracerError 的常见原因。这些错误是在有副作用的函数上使用 JAX 变换的结果(有关此 JAX 错误的更多信息,请参见 https://jax.net.cn/en/latest/errors.html#jax.errors.UnexpectedTracerError)。

使用 jax.eval_shape 的示例

[ ]:
def net(x): # inside of a hk.transform, this is still side-effecting
  w = hk.get_parameter("w", (2, 2), init=jnp.ones)
  return w @ x

def eval_shape_net(x):
  output_shape = jax.eval_shape(net, x) # eval_shape on side-effecting function
  return net(x)                         # UnexpectedTracerError!

init, _ = hk.transform(eval_shape_net)
try:
  init(jax.random.PRNGKey(666), jnp.ones((2, 2)))
except jax.errors.UnexpectedTracerError:
  print("UnexpectedTracerError: applied JAX transform to side effecting function")
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
UnexpectedTracerError: applied JAX transform to side effecting function

这些示例使用 jax.eval_shape,但也可以使用任何高阶 JAX 函数(例如,jax.vmapjax.lax.scanjax.while_loop 等)。

该错误指向 hk.get_parameter。这是使 net 成为有副作用的函数的操作。在这种情况下,副作用是参数的创建,该参数被存储到 Haiku 状态中。类似地,使用 hk.next_rng_key 也会产生错误,因为它会推进 Haiku RNG 状态并将新的 PRNGKey 存储到 Haiku 状态中。一般来说,变换一个未变换的 Haiku 模块将导致 UnexpectedTracerError

您可以重写上面的代码,在 eval_shape 变换之外创建参数,通过显式地将参数作为参数传递,使 net 成为纯函数

[ ]:
def net(w, x): # no side effects!
  return w @ x

def eval_shape_net(x):
  w = hk.get_parameter("w", (3, 2), init=jnp.ones)
  output_shape = jax.eval_shape(net, w, x) # net is now side-effect free
  return output_shape, net(w, x)

key = jax.random.PRNGKey(777)
x = jnp.ones((2, 3))
init, apply = hk.transform(eval_shape_net)
params = init(key, x)
apply(params, key, x)
(ShapeDtypeStruct(shape=(3, 3), dtype=float32),
 DeviceArray([[2., 2., 2.],
              [2., 2., 2.],
              [2., 2., 2.]], dtype=float32))

但是,这并非总是可能的。考虑以下代码,它调用了我们不拥有的 Haiku 模块 (hk.nets.MLP)。此模块将在内部调用 get_parameter

[ ]:
def eval_shape_net(x):
  net = hk.nets.MLP([300, 100])
  output_shape = jax.eval_shape(net, x)
  return output_shape, net(x)

init, _ = hk.transform(eval_shape_net)
try:
  init(jax.random.PRNGKey(666), jnp.ones((2, 2)))
except jax.errors.UnexpectedTracerError:
  print("UnexpectedTracerError: applied JAX transform to side effecting function")
UnexpectedTracerError: applied JAX transform to side effecting function

使用 hk.lift#

我们想要一种方法来访问我们隐式的 Haiku 状态,并获得 hk.nets.MLP 的功能纯版本。通常实现此目的的方法是使用 hk.transform,因此我们所需要的只是在外部 hk.transform 内部嵌套一个内部 hk.tranform 的方法!我们将通过 hk.transform 创建另一对 initapply 函数,然后可以将它们与任何高阶 JAX 函数安全地结合使用。

但是,我们需要一种方法将此嵌套的 hk.tranform 状态注册到外部作用域。我们可以为此使用 hk.lift。用 hk.lift 包装我们的内部 init 函数会将我们的内部 params 注册到外部参数作用域。

[ ]:
def eval_shape_net(x):
  net = hk.nets.MLP([300, 100])    # still side-effecting
  init, apply = hk.transform(net)  # nested transform
  params = hk.lift(init, name="inner")(hk.next_rng_key(), x) # register parameters in outer module scope with name "inner"
  output_shape = jax.eval_shape(apply, params, hk.next_rng_key(), x) # apply is a functionaly pure function and can be transformed!
  out = net(x)
  return out, output_shape


init, apply = hk.transform(eval_shape_net)
params = init(jax.random.PRNGKey(777), jnp.ones((100, 100)))
apply(params, jax.random.PRNGKey(777), jnp.ones((100, 100)))
jax.tree_util.tree_map(lambda x: x.shape, params)
FlatMap({
  'inner/mlp/~/linear_0': FlatMap({'b': (300,), 'w': (100, 300)}),
  'inner/mlp/~/linear_1': FlatMap({'b': (100,), 'w': (300, 100)}),
  'mlp/~/linear_0': FlatMap({'b': (300,), 'w': (100, 300)}),
  'mlp/~/linear_1': FlatMap({'b': (100,), 'w': (300, 100)}),
})

使用 Haiku 版本的 JAX 变换#

Haiku 为了方便起见,还提供了一些 JAX 函数的包装版本。例如:hk.gradhk.vmap 等。有关可用函数的完整列表,请参见 https://haiku.jax.net.cn/en/latest/api.html#jax-fundamentals

这些包装器通过为您执行显式的状态传递,将 JAX 函数应用于 Haiku 函数的功能纯版本。它们不像 lift 那样引入额外的命名空间级别。

[ ]:
def eval_shape_net(x):
  net = hk.nets.MLP([300, 100])         # still side-effecting
  output_shape = hk.eval_shape(net, x)  # hk.eval_shape threads through the Haiku state for you
  out = net(x)
  return out, output_shape


init, apply = hk.transform(eval_shape_net)
params = init(jax.random.PRNGKey(777), jnp.ones((100, 100)))
out = apply(params, jax.random.PRNGKey(777), jnp.ones((100, 100)))

总结#

总结一下,以下是一些组合 JAX 变换和 Haiku 模块的好的和坏的例子

什么?

有效?

示例

在 hk.transform 外部进行 vmapping

✔ 是的!

jax.vmap(hk.transform(hk.nets.ResNet50))

在 hk.transform 内部进行 vmapping

✖ 意外的 tracer 错误

hk.transform(jax.vmap(hk.nets.ResNet50))

vmapping 嵌套的 hk.transform(不使用 lift)

✖ 内部状态未注册

hk.transform(jax.vmap(hk.transform(hk.nets.ResNet50)))

vmapping 嵌套的 hk.transform(使用 lift)

✔ 是的!

hk.transform(jax.vmap(hk.lift(hk.transform(hk.nets.ResNet50))))

使用 hk.vmap

✔ 是的!

hk.transform(hk.vmap(hk.nets.ResNet50))