Haiku 和 Flax 互操作 🥂
目录
Haiku 和 Flax 互操作 🥂#
在 Haiku 和 Flax 之间无缝移动的实用工具。
Flax 在 Haiku 内部#
在 hk.transform
(或 hk.transform_with_state
) 中使用 Flax 模块非常直接。
首先构建你的模块的实例,然后使用 hkflax.lift
将 Flax 模块的参数和状态“提升”(参见 [hk.lift
])到 Haiku 转换中。
示例
[ ]:
import jax
import jax.numpy as jnp
import haiku as hk
import haiku.experimental.flax as hkflax
import flax.linen as flax_nn
def f(x):
mod = hkflax.lift(flax_nn.Dense(10), name='my_flax_module')
x = mod(x)
return x
f = hk.transform(f)
x = jnp.ones([1, 1])
rng = jax.random.PRNGKey(42)
params = f.init(rng, x) # params contains the parameters for MyFlaxModule.
f.apply(params, None, x) # MyFlaxModule will be passed parameters from params.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Array([[ 0.33030465, -1.3496182 , 0.02847686, -1.6579462 , -0.9166192 ,
0.2883583 , -0.046898 , 0.6414894 , -0.404975 , -2.1162813 ]], dtype=float32)
要使用有状态模块,只需将 hk.transform
替换为 hk.transform_with_state
。
Haiku 在 Flax 内部#
有两种受支持的方法可以将 Haiku
代码转换为 Flax
。两者都生成一个 Flax linen nn.Module
,它封装了 Haiku 代码并提供 init
和 apply
方法来创建和使用参数和状态。
将 ``hk.Module`` 转换为
nn.Module
<#hk-Module>`__。将 ``hk.transform`` 转换为
nn.Module
<#hk-transform>`__。将 ``hk.transform_with_state`` 转换为
nn.Module
<#hk-transform>`__。
转换 hk.Module
#
对于无状态模块,你只需通过 hkflax.Module.create
构建 Flax 模块
[ ]:
mod = hkflax.Module.create(hk.Linear, 1) # hk.Linear(1)
你可以像使用常规 Flax nn.Module
一样使用它(因为它就是!)
[ ]:
rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 1])
variables = mod.init(rng, x)
out = mod.apply(variables, x)
对于像 ResNet 这样的有状态模块,你还需要处理输出状态,这与有状态 Flax 模块相同
[ ]:
mod = hkflax.Module.create(hk.nets.ResNet50, 10)
# Regular Flax code from here on:
rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 224, 224, 3])
variables = mod.init(rng, x, is_training=True)
for _ in range(10):
out, state_out = mod.apply(variables, x, is_training=True,
mutable=['state'])
variables = {**variables, **state_out}
转换 hk.transform
或 hk.transform_with_state
#
如果你愿意,可以从 hk.transform
或 hk.transform_with_state
的结果创建 hkflax.Module
[ ]:
def mlp(x):
x = hk.Linear(300)(x)
x = hk.Linear(100)(x)
x = hk.Linear(10)(x)
return x
mlp = hk.transform(mlp)
mlp = hkflax.Module(mlp)
rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 28 * 28])
variables = mlp.init(rng, x)
out = mlp.apply(variables, x)
注意事项#
初始化不同#
Flax 和 Haiku 采用不同的 RNG 密钥分割方法。因此,目前从 hkflax.Module(f).init
返回的参数将不同于 hk.transform(f).init
。
我们有一条途径来支持使 Haiku 转换匹配 Flax 模块的初始化,但目前没有相反方向的路径。
如果跨 Haiku 和 Flax 对齐初始化对你很重要,我们建议使用其中一个库来创建参数,然后根据需要操作参数/状态字典以匹配另一个库
# Utilities.
import haiku.data_structures as hkds
make_flat = {f'{m}/{n}': w for m, n, w in hkds.traverse(d)}
def make_nested(d):
out = {}
for k, w in d.items():
m, n = k.rsplit('/', 1)
out.setdefault(m, {})
out[m][n] = w
return out
# The two modules here should be equivalent when run with Flax or Haiku.
f = hk.transform_with_state(...)
flax_mod = hkflax.Module(f)
# Option 1: Convert Haiku initialized params/state to Flax.
params, state = f.init(...)
variables = {'params': make_flat(params), 'state': make_flat(state)}
# Option 2: Convert Flax initialized variables to Haiku.
variables = flax_mod.init(...)
params = make_nested(variables.get('params', {}))
state = make_nested(variables.get('state', {}))
# The output of the Haiku transformed function or the Flax function should be
# equivalent with either init.
out, state = f.apply(params, state, ...)
out, variables_out = flax_mod.apply(variables, ..., mutable=['state'])
多个前向方法#
hkflax.Module
目前仅支持 __call__
,如果这阻碍了你,请告知我们。