交互式在线版本: Open In Colab

Haiku 和 Flax 互操作 🥂#

在 Haiku 和 Flax 之间无缝移动的实用工具。

Flax 在 Haiku 内部#

hk.transform (或 hk.transform_with_state) 中使用 Flax 模块非常直接。

首先构建你的模块的实例,然后使用 hkflax.lift 将 Flax 模块的参数和状态“提升”(参见 [hk.lift])到 Haiku 转换中。

示例

[ ]:
import jax
import jax.numpy as jnp
import haiku as hk
import haiku.experimental.flax as hkflax
import flax.linen as flax_nn

def f(x):
  mod = hkflax.lift(flax_nn.Dense(10), name='my_flax_module')
  x = mod(x)
  return x

f = hk.transform(f)
x = jnp.ones([1, 1])
rng = jax.random.PRNGKey(42)
params = f.init(rng, x)   # params contains the parameters for MyFlaxModule.
f.apply(params, None, x)  # MyFlaxModule will be passed parameters from params.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Array([[ 0.33030465, -1.3496182 ,  0.02847686, -1.6579462 , -0.9166192 ,
         0.2883583 , -0.046898  ,  0.6414894 , -0.404975  , -2.1162813 ]],      dtype=float32)

要使用有状态模块,只需将 hk.transform 替换为 hk.transform_with_state

Haiku 在 Flax 内部#

有两种受支持的方法可以将 Haiku 代码转换为 Flax。两者都生成一个 Flax linen nn.Module,它封装了 Haiku 代码并提供 initapply 方法来创建和使用参数和状态。

  • 将 ``hk.Module`` 转换为 nn.Module <#hk-Module>`__。

  • 将 ``hk.transform`` 转换为 nn.Module <#hk-transform>`__。

  • 将 ``hk.transform_with_state`` 转换为 nn.Module <#hk-transform>`__。

转换 hk.Module#

对于无状态模块,你只需通过 hkflax.Module.create 构建 Flax 模块

[ ]:
mod = hkflax.Module.create(hk.Linear, 1)  # hk.Linear(1)

你可以像使用常规 Flax nn.Module 一样使用它(因为它就是!)

[ ]:
rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 1])
variables = mod.init(rng, x)
out = mod.apply(variables, x)

对于像 ResNet 这样的有状态模块,你还需要处理输出状态,这与有状态 Flax 模块相同

[ ]:
mod = hkflax.Module.create(hk.nets.ResNet50, 10)

# Regular Flax code from here on:
rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 224, 224, 3])
variables = mod.init(rng, x, is_training=True)
for _ in range(10):
  out, state_out = mod.apply(variables, x, is_training=True,
                             mutable=['state'])
  variables = {**variables, **state_out}

转换 hk.transformhk.transform_with_state#

如果你愿意,可以从 hk.transformhk.transform_with_state 的结果创建 hkflax.Module

[ ]:
def mlp(x):
  x = hk.Linear(300)(x)
  x = hk.Linear(100)(x)
  x = hk.Linear(10)(x)
  return x

mlp = hk.transform(mlp)
mlp = hkflax.Module(mlp)

rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 28 * 28])
variables = mlp.init(rng, x)
out = mlp.apply(variables, x)

注意事项#

初始化不同#

Flax 和 Haiku 采用不同的 RNG 密钥分割方法。因此,目前从 hkflax.Module(f).init 返回的参数将不同于 hk.transform(f).init

我们有一条途径来支持使 Haiku 转换匹配 Flax 模块的初始化,但目前没有相反方向的路径。

如果跨 Haiku 和 Flax 对齐初始化对你很重要,我们建议使用其中一个库来创建参数,然后根据需要操作参数/状态字典以匹配另一个库

# Utilities.
import haiku.data_structures as hkds

make_flat = {f'{m}/{n}': w for m, n, w in hkds.traverse(d)}

def make_nested(d):
  out = {}
  for k, w in d.items():
    m, n = k.rsplit('/', 1)
    out.setdefault(m, {})
    out[m][n] = w
  return out

# The two modules here should be equivalent when run with Flax or Haiku.
f = hk.transform_with_state(...)
flax_mod = hkflax.Module(f)

# Option 1: Convert Haiku initialized params/state to Flax.
params, state = f.init(...)
variables = {'params': make_flat(params), 'state': make_flat(state)}

# Option 2: Convert Flax initialized variables to Haiku.
variables = flax_mod.init(...)
params = make_nested(variables.get('params', {}))
state = make_nested(variables.get('state', {}))

# The output of the Haiku transformed function or the Flax function should be
# equivalent with either init.
out, state = f.apply(params, state, ...)
out, variables_out = flax_mod.apply(variables, ..., mutable=['state'])

多个前向方法#

hkflax.Module 目前仅支持 __call__,如果这阻碍了你,请告知我们。