嵌套 JAX 函数和 Haiku 模块的限制
目录
[ ]:
import haiku as hk
import jax
import jax.numpy as jnp
TL;DR: 在 hk.transform
内部的 JAX 变换很可能会变换一个有副作用的函数,这将导致 UnexpectedTracerError
。此页面描述了两种解决此问题的方法。
嵌套 JAX 函数和 Haiku 模块的限制#
一旦使用 hk.transform
将 Haiku 网络转换为一对纯函数,就可以自由地将它们与任何 JAX 变换(如 jax.jit
、jax.grad
、jax.lax.scan
等)结合使用。
但是,如果要在 hk.transform
内部 使用 JAX 变换,则需要更加小心。这是可能的,但是 hk.transform
边界内部的大多数函数仍然是有副作用的,并且不能安全地被 JAX 变换。这是使用 Haiku 的代码中出现 UnexpectedTracerError
的常见原因。这些错误是在有副作用的函数上使用 JAX 变换的结果(有关此 JAX 错误的更多信息,请参见 https://jax.net.cn/en/latest/errors.html#jax.errors.UnexpectedTracerError)。
使用 jax.eval_shape
的示例
[ ]:
def net(x): # inside of a hk.transform, this is still side-effecting
w = hk.get_parameter("w", (2, 2), init=jnp.ones)
return w @ x
def eval_shape_net(x):
output_shape = jax.eval_shape(net, x) # eval_shape on side-effecting function
return net(x) # UnexpectedTracerError!
init, _ = hk.transform(eval_shape_net)
try:
init(jax.random.PRNGKey(666), jnp.ones((2, 2)))
except jax.errors.UnexpectedTracerError:
print("UnexpectedTracerError: applied JAX transform to side effecting function")
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
UnexpectedTracerError: applied JAX transform to side effecting function
这些示例使用 jax.eval_shape
,但也可以使用任何高阶 JAX 函数(例如,jax.vmap
、jax.lax.scan
、jax.while_loop
等)。
该错误指向 hk.get_parameter
。这是使 net
成为有副作用的函数的操作。在这种情况下,副作用是参数的创建,该参数被存储到 Haiku 状态中。类似地,使用 hk.next_rng_key
也会产生错误,因为它会推进 Haiku RNG 状态并将新的 PRNGKey 存储到 Haiku 状态中。一般来说,变换一个未变换的 Haiku 模块将导致 UnexpectedTracerError
。
您可以重写上面的代码,在 eval_shape
变换之外创建参数,通过显式地将参数作为参数传递,使 net
成为纯函数
[ ]:
def net(w, x): # no side effects!
return w @ x
def eval_shape_net(x):
w = hk.get_parameter("w", (3, 2), init=jnp.ones)
output_shape = jax.eval_shape(net, w, x) # net is now side-effect free
return output_shape, net(w, x)
key = jax.random.PRNGKey(777)
x = jnp.ones((2, 3))
init, apply = hk.transform(eval_shape_net)
params = init(key, x)
apply(params, key, x)
(ShapeDtypeStruct(shape=(3, 3), dtype=float32),
DeviceArray([[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]], dtype=float32))
但是,这并非总是可能的。考虑以下代码,它调用了我们不拥有的 Haiku 模块 (hk.nets.MLP
)。此模块将在内部调用 get_parameter
。
[ ]:
def eval_shape_net(x):
net = hk.nets.MLP([300, 100])
output_shape = jax.eval_shape(net, x)
return output_shape, net(x)
init, _ = hk.transform(eval_shape_net)
try:
init(jax.random.PRNGKey(666), jnp.ones((2, 2)))
except jax.errors.UnexpectedTracerError:
print("UnexpectedTracerError: applied JAX transform to side effecting function")
UnexpectedTracerError: applied JAX transform to side effecting function
使用 hk.lift#
我们想要一种方法来访问我们隐式的 Haiku 状态,并获得 hk.nets.MLP
的功能纯版本。通常实现此目的的方法是使用 hk.transform
,因此我们所需要的只是在外部 hk.transform
内部嵌套一个内部 hk.tranform
的方法!我们将通过 hk.transform
创建另一对 init
和 apply
函数,然后可以将它们与任何高阶 JAX 函数安全地结合使用。
但是,我们需要一种方法将此嵌套的 hk.tranform
状态注册到外部作用域。我们可以为此使用 hk.lift
。用 hk.lift
包装我们的内部 init
函数会将我们的内部 params
注册到外部参数作用域。
[ ]:
def eval_shape_net(x):
net = hk.nets.MLP([300, 100]) # still side-effecting
init, apply = hk.transform(net) # nested transform
params = hk.lift(init, name="inner")(hk.next_rng_key(), x) # register parameters in outer module scope with name "inner"
output_shape = jax.eval_shape(apply, params, hk.next_rng_key(), x) # apply is a functionaly pure function and can be transformed!
out = net(x)
return out, output_shape
init, apply = hk.transform(eval_shape_net)
params = init(jax.random.PRNGKey(777), jnp.ones((100, 100)))
apply(params, jax.random.PRNGKey(777), jnp.ones((100, 100)))
jax.tree_util.tree_map(lambda x: x.shape, params)
FlatMap({
'inner/mlp/~/linear_0': FlatMap({'b': (300,), 'w': (100, 300)}),
'inner/mlp/~/linear_1': FlatMap({'b': (100,), 'w': (300, 100)}),
'mlp/~/linear_0': FlatMap({'b': (300,), 'w': (100, 300)}),
'mlp/~/linear_1': FlatMap({'b': (100,), 'w': (300, 100)}),
})
使用 Haiku 版本的 JAX 变换#
Haiku 为了方便起见,还提供了一些 JAX 函数的包装版本。例如:hk.grad
、hk.vmap
等。有关可用函数的完整列表,请参见 https://haiku.jax.net.cn/en/latest/api.html#jax-fundamentals。
这些包装器通过为您执行显式的状态传递,将 JAX 函数应用于 Haiku 函数的功能纯版本。它们不像 lift
那样引入额外的命名空间级别。
[ ]:
def eval_shape_net(x):
net = hk.nets.MLP([300, 100]) # still side-effecting
output_shape = hk.eval_shape(net, x) # hk.eval_shape threads through the Haiku state for you
out = net(x)
return out, output_shape
init, apply = hk.transform(eval_shape_net)
params = init(jax.random.PRNGKey(777), jnp.ones((100, 100)))
out = apply(params, jax.random.PRNGKey(777), jnp.ones((100, 100)))
总结#
总结一下,以下是一些组合 JAX 变换和 Haiku 模块的好的和坏的例子
什么? |
有效? |
示例 |
---|---|---|
在 hk.transform 外部进行 vmapping |
✔ 是的! |
jax.vmap(hk.transform(hk.nets.ResNet50)) |
在 hk.transform 内部进行 vmapping |
✖ 意外的 tracer 错误 |
hk.transform(jax.vmap(hk.nets.ResNet50)) |
vmapping 嵌套的 hk.transform(不使用 lift) |
✖ 内部状态未注册 |
hk.transform(jax.vmap(hk.transform(hk.nets.ResNet50))) |
vmapping 嵌套的 hk.transform(使用 lift) |
✔ 是的! |
hk.transform(jax.vmap(hk.lift(hk.transform(hk.nets.ResNet50)))) |
使用 hk.vmap |
✔ 是的! |
hk.transform(hk.vmap(hk.nets.ResNet50)) |