交互式在线版本: Open In Colab

Haiku 基础知识#

在本 Colab 中,您将学习 Haiku 的基础知识。

是什么和为什么?

Haiku 是一个用于 JAX 的简单神经网络库,它使用户能够使用熟悉的面相对象编程模型,同时允许完全访问 JAX 的纯函数转换。Haiku 的设计目的是使我们常做的事情(例如管理模型参数和其他模型状态)更简单,并且在精神上类似于 DeepMind 广泛使用的 Sonnet 库。它保留了 Sonnet 基于模块的状态管理编程模型,同时保留了对 JAX 函数转换的访问。可以预期 Haiku 可以与其他库组合使用,并且可以与 JAX 的其余部分良好地协同工作。

[2]:
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

使用 hk.transform 的第一个例子#

作为 Haiku 的初步介绍,让我们构建一个具有自定义初始化的权重和偏差的线性模块。

与 Sonnet 模块类似,Haiku 模块是 Python 对象,它们保存对其自身参数、其他模块以及将函数应用于用户输入的方法的引用。另一方面,由于 JAX 对纯函数转换进行操作,因此 Haiku 模块不能逐字实例化。相反,模块需要被包装到纯函数转换中。

Haiku 提供了一个简单的函数转换 hk.transform,它可以将使用这些面向对象、功能上“不纯”的模块的函数转换为可以与 JAX 一起使用的纯函数。

[3]:
class MyLinear1(hk.Module):

  def __init__(self, output_size, name=None):
    super().__init__(name=name)
    self.output_size = output_size

  def __call__(self, x):
    j, k = x.shape[-1], self.output_size
    w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
    w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
    b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.ones)
    return jnp.dot(x, w) + b
[4]:
def _forward_fn_linear1(x):
  module = MyLinear1(output_size=2)
  return module(x)

forward_linear1 = hk.transform(_forward_fn_linear1)

我们看到,前向包装器对象现在包含两个方法,initapply,它们用于初始化变量并对模块执行前向推理。

[5]:
forward_linear1
[5]:
Transformed(init=<function without_state.<locals>.init_fn at 0x7f2e310354c0>, apply=<function without_state.<locals>.apply_fn at 0x7f2e31035550>)

调用 init 方法将初始化网络的参数并返回它们,如下所示。init 方法接受一个 jax.random.PRNGKey 和一个样本输入(通常只是一些虚拟值来告诉网络关于预期的形状)。

[6]:
dummy_x = jnp.array([[1., 2., 3.]])
rng_key = jax.random.PRNGKey(42)

params = forward_linear1.init(rng=rng_key, x=dummy_x)
print(params)
{'my_linear1': {'w': DeviceArray([[-0.30350363,  0.5123802 ],
             [ 0.08009142, -0.3163005 ],
             [ 0.6056666 ,  0.5820702 ]], dtype=float32), 'b': DeviceArray([1., 1.], dtype=float32)}}

我们现在可以使用参数将前向函数应用于一些输入。

[7]:
sample_x = jnp.array([[1., 2., 3.]])
sample_x_2 = jnp.array([[4., 5., 6.], [7., 8., 9.]])

output_1 = forward_linear1.apply(params=params, x=sample_x, rng=rng_key)
# Outputs are identical for given inputs since the forward inference is non-stochastic.
output_2 = forward_linear1.apply(params=params, x=sample_x, rng=rng_key)

output_3 = forward_linear1.apply(params=params, x=sample_x_2, rng=rng_key)

print(f'Output 1 : {output_1}')
print(f'Output 2 (same as output 1): {output_2}')
print(f'Output 3 : {output_3}')
Output 1 : [[2.6736789 2.6259897]]
Output 2 (same as output 1): [[2.6736789 2.6259897]]
Output 3 : [[3.820442 4.960439]
 [4.967205 7.294889]]

没有随机密钥的推断

我们构建的模块本质上是非随机的。在这种情况下,将随机密钥传递给 apply 方法似乎是多余的。Haiku 提供了另一种转换 hk.without_apply_rng,它可以进一步包装在我们的 hk.transform 方法周围。

[8]:
forward_without_rng = hk.without_apply_rng(hk.transform(_forward_fn_linear1))
params = forward_without_rng.init(rng=rng_key, x=sample_x)
output = forward_without_rng.apply(x=sample_x, params=params)
print(f'Output without random key in forward pass \n {output_1}')
Output without random key in forward pass
 [[2.6736789 2.6259897]]

我们还可以改变参数,然后执行前向推理,为相同的输入生成不同的输出。这就是在学习时将梯度下降应用于我们的参数所做的事情。

[9]:
mutated_params = jax.tree_util.tree_map(lambda x: x+1., params)
print(f'Mutated params \n : {mutated_params}')
mutated_output = forward_without_rng.apply(x=sample_x, params=mutated_params)
print(f'Output with mutated params \n {mutated_output}')
Mutated params
 : {'my_linear1': {'b': DeviceArray([2., 2.], dtype=float32), 'w': DeviceArray([[0.69649637, 1.5123801 ],
             [1.0800915 , 0.6836995 ],
             [1.6056666 , 1.5820701 ]], dtype=float32)}}
Output with mutated params
 [[9.673679 9.62599 ]]

Haiku 中的有状态推断#

对于某些模块,您可能希望在函数调用之间维护和传递内部状态。在这里,我们演示一个简单的例子,我们在 Haiku 转换中声明一个状态变量 counter,它在每次函数调用时都会更新。请注意,我们没有将其显式实例化为 Haiku 模块(相同的模块可以复制为 hk 模块,如前所示)。

[10]:
def stateful_f(x):
  counter = hk.get_state("counter", shape=[], dtype=jnp.int32, init=jnp.ones)
  multiplier = hk.get_parameter('multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones)
  hk.set_state("counter", counter + 1)
  output = x + multiplier * counter
  return output

stateful_forward = hk.without_apply_rng(hk.transform_with_state(stateful_f))
sample_x = jnp.array([[5., ]])
params, state = stateful_forward.init(x=sample_x, rng=rng_key)
print(f'Initial params:\n{params}\nInitial state:\n{state}')
print('##########')
for i in range(3):
  output, state = stateful_forward.apply(params, state, x=sample_x)
  print(f'After {i+1} iterations:\nOutput: {output}\nState: {state}')
  print('##########')
Initial params:
{'~': {'multiplier': DeviceArray([1.], dtype=float32)}}
Initial state:
{'~': {'counter': DeviceArray(1, dtype=int32)}}
##########
After 1 iterations:
Output: [[6.]]
State: {'~': {'counter': DeviceArray(2, dtype=int32)}}
##########
After 2 iterations:
Output: [[7.]]
State: {'~': {'counter': DeviceArray(3, dtype=int32)}}
##########
After 3 iterations:
Output: [[8.]]
State: {'~': {'counter': DeviceArray(4, dtype=int32)}}
##########

内置 Haiku 网络和嵌套模块#

我们常用的网络(例如 MLP、Convnets 等)已经在 Haiku 中定义,我们可以组合这些模块来构建我们自定义的 Haiku 模块。

查看 params 字典,了解参数是如何以与模块在我们自定义 Haiku 模块中嵌套的相同方式嵌套的。

[11]:
# See: https://haiku.jax.net.cn/en/latest/api.html#common-modules

class MyModuleCustom(hk.Module):
  def __init__(self, output_size=2, name='custom_linear'):
    super().__init__(name=name)
    self._internal_linear_1 = hk.nets.MLP(output_sizes=[2, 3], name='hk_internal_linear')
    self._internal_linear_2 = MyLinear1(output_size=output_size, name='old_linear')

  def __call__(self, x):
    return self._internal_linear_2(self._internal_linear_1(x))

def _custom_forward_fn(x):
  module = MyModuleCustom()
  return module(x)

custom_forward_without_rng = hk.without_apply_rng(hk.transform(_custom_forward_fn))
params = custom_forward_without_rng.init(rng=rng_key, x=sample_x)
params
[11]:
{'custom_linear/~/hk_internal_linear/~/linear_0': {'b': DeviceArray([0., 0.], dtype=float32),
  'w': DeviceArray([[ 1.51595   , -0.23353337]], dtype=float32)},
 'custom_linear/~/hk_internal_linear/~/linear_1': {'b': DeviceArray([0., 0., 0.], dtype=float32),
  'w': DeviceArray([[-0.22075887, -0.27375957,  0.5931483 ],
               [ 0.7818068 ,  0.72626334, -0.6860752 ]], dtype=float32)},
 'custom_linear/~/old_linear': {'b': DeviceArray([1., 1.], dtype=float32),
  'w': DeviceArray([[ 0.28584382,  0.31626168],
               [ 0.2335775 , -0.4827032 ],
               [-0.14647584, -0.7185701 ]], dtype=float32)}}

带有 hk.next_rng_key() 的 RNG 密钥#

我们之前看到的模块都是非随机的。下面我们展示如何采样随机数来进行随机推断。

Haiku 提供了一个用于处理随机数的简单模型。在转换后的函数中,hk.next_rng_key() 返回一个唯一的 rng 密钥。这些唯一的密钥是从传递到顶层转换函数的初始随机密钥确定性地派生的,因此可以安全地与 JAX 程序转换一起使用。

让我们定义一个简单的 haiku 函数,我们在其中生成两个随机样本。请注意,next_rng_keys 是从传递给顶层转换函数的 apply 方法的初始随机密钥确定的。

[15]:
class HkRandom2(hk.Module):
  def __init__(self, rate=0.5):
    super().__init__()
    self.rate = rate

  def __call__(self, x):
    key1 = hk.next_rng_key()
    return jax.random.bernoulli(key1, 1.0 - self.rate, shape=x.shape)


class HkRandomNest(hk.Module):
  def __init__(self, rate=0.5):
    super().__init__()
    self.rate = rate
    self._another_random_module = HkRandom2()

  def __call__(self, x):
    key2 = hk.next_rng_key()
    p1 = self._another_random_module(x)
    p2 = jax.random.bernoulli(key2, 1.0 - self.rate, shape=x.shape)
    print(f'Bernoullis are  : {p1, p2}')

# Note that the modules that are stochastic cannot be wrapped with hk.without_apply_rng()
forward = hk.transform(lambda x: HkRandomNest()(x))

x = jnp.array(1.)
print("INIT:")
params = forward.init(rng_key, x=x)
print("APPLY:")
prediction = forward.apply(params, x=x, rng=rng_key)
INIT:
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))
APPLY:
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

请注意,这意味着将相同的随机密钥传递给 apply 函数的多次调用将生成相同的随机结果!

[16]:
for _ in range(3):
  forward.apply(params, x=x, rng=rng_key)
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

确保分离出新的 RNG 密钥,以便在 apply 调用中获得不同的随机行为,并且永远不要重复使用 RNG 密钥。(有关如何在 JAX 中处理随机状态的更全面的解释,请查看此 RNG 教程:https://jax.net.cn/en/latest/jax-101/05-random-numbers.html。)

[19]:
for _ in range(3):
  rng_key, apply_rng_key = jax.random.split(rng_key)
  forward.apply(params, x=x, rng=apply_rng_key)
Bernoullis are  : (DeviceArray(False, dtype=bool), DeviceArray(False, dtype=bool))
Bernoullis are  : (DeviceArray(False, dtype=bool), DeviceArray(True, dtype=bool))
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

Haiku 还提供了 hk.PRNGSequence,它返回随机密钥的迭代器。

[20]:
rng_sequence = hk.PRNGSequence(rng_key)
for _ in range(3):
  forward.apply(params, x=x, rng=next(rng_sequence))
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(True, dtype=bool))
Bernoullis are  : (DeviceArray(False, dtype=bool), DeviceArray(False, dtype=bool))
Bernoullis are  : (DeviceArray(False, dtype=bool), DeviceArray(True, dtype=bool))