交互式在线版本: Open In Colab

构建你自己的 Haiku#

在这个 Colab 中,我们将从头开始构建一个高度简化的 Haiku 版本,以便您深入了解 Haiku 的工作原理。

这是一个“高级”教程,面向寻求深入了解 Haiku 内部机制的人们。 实际上,这不是理解如何在实践中使用 Haiku 的必要条件。(“高级”之所以加引号,是因为它实际上并没有那么复杂,所以不要害怕!)

这里的实现基于真实 Haiku 库的设计,但大多数细节都经过简化。 因此,虽然这应该让您对概念下的底层原理有一个相当准确的认识,但不要依赖细节来匹配。

问题#

我们希望能够编写带有参数属性的面向对象类,就像这样,

[1]:
class MyModule:

  def apply(self, x):
    return self.w * x

并自动将它们转换为纯函数,就像这样

[2]:
def my_stateless_apply(params, x):
  return params['w'] * x

(但是,我们没有使用通过 self.* 的属性访问,而是定义了自己的访问器函数,名为 get_param()。 这使得拦截其用法变得更加容易,我们需要稍后收集和注入参数值。)

此外,如果此转换还定义了参数初始化,并自动处理为参数分配唯一名称,那就太好了,因为在大型网络中手动管理这些名称可能会变得笨拙。 例如,如果某个其他模块也将其参数称为 w,我们希望自动解决此类冲突。

我们将分步解决这个问题。

  • 在第一步中,我们将实现一个基本的 transform,它将面向对象风格的函数转换为纯函数。

  • 下一步将是添加初始化。

  • 最后,我们将处理当使用同一模块的多个副本或不同模块为其参数使用相同名称时涉及的管道。

在该阶段,我们将能够像使用真实的 Haiku 一样定义和训练一个简单的神经网络。

基本策略#

我们将定义一个函数,该函数实现从使用 get_param 的有状态样式到无状态函数的转换。 此函数将恰如其分地称为 transform。 它将 MyModule().apply 包装到一个函数中,该函数的工作方式类似于 my_stateless_apply

以下是它的工作方式。 transform(f) 将返回 f 的包装版本,它接受一个额外的 params 参数。 调用时,它将运行 f,并且每次 f 调用 get_param 时,它都会从 params 中提取相应的值并返回它。

[3]:
# Global state which holds the parameters for the transformed function.
# get_param uses this to know where to get params from.
current_params = []

def transform(f):

  def apply_f(params, *args, **kwargs):
    current_params.append(params)
    outs = f(*args, **kwargs)
    current_params.pop()
    return outs

  return apply_f


def get_param(identifier):
  return current_params[-1][identifier]

让我们测试一下

[4]:
params = dict(w=5)
my_stateless_apply(params, 5)
[4]:
25
[5]:
class MyModule:
  def apply(self, x):
    return get_param('w') * x

transform(MyModule().apply)(params, 5)
[5]:
25

“等一下!” 你说。 JAX 不是完全关于没有全局状态吗? 这在 JAX 中可能行不通! 好吧,让我们用 JAX 试试看

[6]:
import jax
import jax.numpy as jnp

def linear(x):
  return x @ get_param('w') + get_param('b')

params = dict(w=jnp.ones((3, 5)), b=jnp.ones((5,)))
apply = transform(linear)

jax.jit(apply)(params, jnp.ones((10, 3)))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[6]:
DeviceArray([[4., 4., 4., 4., 4.],
             [4., 4., 4., 4., 4.],
             [4., 4., 4., 4., 4.],
             [4., 4., 4., 4., 4.],
             [4., 4., 4., 4., 4.],
             [4., 4., 4., 4., 4.],
             [4., 4., 4., 4., 4.],
             [4., 4., 4., 4., 4.],
             [4., 4., 4., 4., 4.],
             [4., 4., 4., 4., 4.]], dtype=float32)

这样做有效的原因是,虽然我们使用了全局状态,但我们对如何使用它很谨慎。 我们使函数调用后的全局状态与调用前相同,并确保包装函数的输出仅取决于其输入。 因此,JAX 并不知情 - 就它而言,转换后的函数是纯函数。

添加初始化#

到目前为止,一切都很好,但是我们无法重用模块,因为我们的转换将在同一模块的所有副本之间共享参数,因为它们的名称都相同。 此外,定义初始状态很麻烦 - 我们可以自动化吗?

让我们首先解决初始化问题。 为简单起见,我们的参数将始终使用正态分布进行初始化,但添加不同初始化器的选项并不难。

在这个新版本中,我们将一个面向对象的状态函数 transform 为两个纯函数:一个初始化参数,另一个应用参数。 这些对应于在两种模式下运行原始函数:初始化和应用。

为了支持这一点,我们添加了额外的机制(Frame),以跟踪我们所处的模式,并更改 get_param() 的行为

  • 我们添加一个 shape 参数,它告诉我们如果初始化,参数应该是什么形状。

  • 如果正在初始化,get_param() 将创建正确形状的参数,并将其添加到 Frame 中的当前参数,然后再返回。

因此,get_param() 会及时生成初始值,以便它们在有状态函数的调用中使用。

[7]:
from typing import NamedTuple, Dict, Callable
import numpy as np
[8]:
# Since we're tracking more than just the current params,
# we introduce the concept of a frame as the object that holds
# state during a transformed execution.
frame_stack = []

class Frame(NamedTuple):
  """Tracks what's going on during a call of a transformed function."""
  params: Dict[str, jax.Array]
  is_initialising: bool = False

def current_frame():
  return frame_stack[-1]


class Transformed(NamedTuple):
  init: Callable
  apply: Callable


def transform(f) -> Transformed:

  def init_f(*args, **kwargs):
    frame_stack.append(Frame({}, is_initialising=True))
    f(*args, **kwargs)
    frame = frame_stack.pop()
    return frame.params

  def apply_f(params, *args, **kwargs):
    frame_stack.append(Frame(params))
    outs = f(*args, **kwargs)
    frame_stack.pop()
    return outs

  return Transformed(init_f, apply_f)

def get_param(identifier, shape):
  if current_frame().is_initialising:
    current_frame().params[identifier] = np.random.normal(size=shape)

  return current_frame().params[identifier]

让我们通过实现一个线性模块来测试它

[9]:
# Make printing parameters a little more readable
def parameter_shapes(params):
  return jax.tree_util.tree_map(lambda p: p.shape, params)


class Linear:

  def __init__(self, width):
    self._width = width

  def __call__(self, x):
    w = get_param('w', shape=(x.shape[-1], self._width))
    b = get_param('b', shape=(self._width,))
    return x @ w + b

init, apply = transform(Linear(4))

data = jnp.ones((2, 3))

params = init(data)
parameter_shapes(params)
[9]:
{'b': (4,), 'w': (3, 4)}
[10]:
apply(params, data)
[10]:
DeviceArray([[-1.0345883,  0.3280404, -2.4382973,  0.5717376],
             [-1.0345883,  0.3280404, -2.4382973,  0.5717376]],            dtype=float32)

添加唯一参数名称:我们完成的迷你 Haiku#

好的! 是时候处理嵌套模块了,我们的原型将完成。

为此,我们需要为每个参数指定一个明确的名称。 在这里,我们将使用一种与真实 Haiku 有些不同且不兼容的方案,但该方案更简单。 这个想法是记录被调用函数的名称,并根据每个参数在调用堆栈中的位置为其分配一个唯一标识符。

为此,我们将定义一个 Module 类。 每个模块都将具有一个唯一的标识符,该标识符基于类名和到目前为止创建的模块实例数。 (真正的 Haiku 允许自定义这些名称,但为简单起见,我们忽略了这一点)

我们还将为 Module 方法定义一个装饰器,名为 module_method,它将告诉我们何时调用包装函数,从而使我们能够跟踪当前参数范围。 真正的 haiku 使用元类来自动包装 Module 上的所有方法,但为简单起见,我们手动执行此操作。

[11]:
import dataclasses
import collections

@dataclasses.dataclass
class Frame:
  """Tracks what's going on during a call of a transformed function."""
  params: Dict[str, jax.Array]
  is_initialising: bool = False

  # Keeps track of how many modules of each class have been created so far.
  # Used to assign new modules unique names.
  module_counts: Dict[str, int] = dataclasses.field(
      default_factory=lambda: collections.defaultdict(lambda: 0))

  # Keeps track of the entire path to the current module method call.
  # Module methods, when called, will add themselves to this stack.
  # Used to give each parameter a unique name corresponding to the
  # method scope it is in.
  call_stack: list = dataclasses.field(default_factory=list)

  def create_param_path(self, identifier) -> str:
    """Creates a unique path for this param."""
    return '/'.join(['~'] + self.call_stack + [identifier])

  def create_unique_module_name(self, module_name: str) -> str:
    """Assigns a unique name to the module by appending its number to its name."""
    number = self.module_counts[module_name]
    self.module_counts[module_name] += 1
    return f"{module_name}_{number}"

frame_stack = []

def current_frame():
  return frame_stack[-1]


class Module:
  def __init__(self):
    # Assign a unique (for the current `transform` call)
    # name to this instance of the module.
    self._unique_name = current_frame().create_unique_module_name(
        self.__class__.__name__)


def module_method(f):
  """A decorator for Module methods."""
  # In the real Haiku, this doesn't face the user but is applied by a metaclass.

  def wrapped(self, *args, **kwargs):
    """A version of f that lets the frame know it's being called."""
    # Self is the instance to which this method is attached.
    module_name = self._unique_name
    call_stack = current_frame().call_stack
    call_stack.append(module_name)
    call_stack.append(f.__name__)
    outs = f(self, *args, **kwargs)
    assert call_stack.pop() == f.__name__
    assert call_stack.pop() == module_name
    return outs

  return wrapped


def get_param(identifier, shape):
  frame = current_frame()
  param_path = frame.create_param_path(identifier)

  if frame.is_initialising:
    frame.params[param_path] = np.random.normal(size=shape)

  return frame.params[param_path]


class Linear(Module):

  def __init__(self, width):
    super().__init__()
    self._width = width

  @module_method  # Again, this decorator is behind-the-scenes in real Haiku.
  def __call__(self, x):
    w = get_param('w', shape=(x.shape[-1], self._width))
    b = get_param('b', shape=(self._width,))
    return x @ w + b

在这个阶段,我们已经复制了一些核心 Haiku 功能,但我们仍然没有:* 对初始化的控制 * rng 处理 * 状态处理(尽管从概念上讲,这类似于参数处理)* 任何类型的验证和错误处理 * 一旦创建就冻结参数 * 线程安全 * transform 内的 JAX 转换(例如 hk.remat)* transforms 内的 JAX 控制流(例如 hk.scan)* 最后但并非最不重要的,文档 :)

还有很多。 但是,基本功能有效,因此我们可以试用一下我们的迷你 Haiku

[12]:
init, apply = transform(lambda x: Linear(4)(x))

params = init(data)
parameter_shapes(params)
[12]:
{'~/Linear_0/__call__/b': (4,), '~/Linear_0/__call__/w': (3, 4)}
[13]:
apply(params, data)
[13]:
DeviceArray([[-1.1969297,  1.3215988,  5.175427 , -1.9018829],
             [-1.1969297,  1.3215988,  5.175427 , -1.9018829]],            dtype=float32)

函数调用中的不同模块都具有单独的参数

[14]:
class MLP(Module):

  def __init__(self, widths):
    super().__init__()
    self._widths = widths

  @module_method
  def __call__(self, x):
    for w in self._widths:
      out = Linear(w)(x)
      x = jax.nn.sigmoid(out)
    return out
[15]:
init, apply = transform(lambda x: MLP([3, 5])(x))
parameter_shapes(init(data))
[15]:
{'~/MLP_0/__call__/Linear_0/__call__/b': (3,),
 '~/MLP_0/__call__/Linear_0/__call__/w': (3, 3),
 '~/MLP_0/__call__/Linear_1/__call__/b': (5,),
 '~/MLP_0/__call__/Linear_1/__call__/w': (3, 5)}

而在不同位置调用的同一模块会重用参数

[16]:
class ParameterReuseTest(Module):

  @module_method
  def __call__(self, x):
    f = Linear(x.shape[-1])

    x = f(x)
    x = jax.nn.relu(x)
    return f(x)

init, forward = transform(lambda x: ParameterReuseTest()(x))
parameter_shapes(init(data))
[16]:
{'~/ParameterReuseTest_0/__call__/Linear_0/__call__/b': (3,),
 '~/ParameterReuseTest_0/__call__/Linear_0/__call__/w': (3, 3)}

示例训练循环#

[17]:
import matplotlib.pyplot as plt
[18]:
# Data: a quadratic curve.
xs = np.linspace(-2., 2., num=128)[:, None]  # Generate array of shape (128, 1).
ys = xs ** 2

# Model
def mlp(x):
  return MLP([128, 128, 1])(x)

init, forward = transform(mlp)
params = init(xs)
parameter_shapes(params)
[18]:
{'~/MLP_0/__call__/Linear_0/__call__/b': (128,),
 '~/MLP_0/__call__/Linear_0/__call__/w': (1, 128),
 '~/MLP_0/__call__/Linear_1/__call__/b': (128,),
 '~/MLP_0/__call__/Linear_1/__call__/w': (128, 128),
 '~/MLP_0/__call__/Linear_2/__call__/b': (1,),
 '~/MLP_0/__call__/Linear_2/__call__/w': (128, 1)}
[19]:
# Loss function and update function
def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

LEARNING_RATE = 0.003

@jax.jit
def update(params, x, y):
  grads = jax.grad(loss_fn)(params, x, y)
  return jax.tree_util.tree_map(
      lambda p, g: p - LEARNING_RATE * g, params, grads
  )
[20]:
for _ in range(5000):
  params = update(params, xs, ys)
[21]:
plt.scatter(xs, ys, label='Data')
plt.scatter(xs, forward(params, xs), label='Model prediction')
plt.legend()
plt.show()
../_images/notebooks_build_your_own_haiku_36_0.png
[ ]: