Haiku API 参考#

Haiku 基础#

Haiku 转换#

transform(f, *[, apply_rng])

使用 Haiku 模块将函数转换为一对纯函数。

transform_with_state(f)

使用 Haiku 模块将函数转换为一对纯函数。

multi_transform(f)

使用 Haiku 将函数集合转换为纯函数。

multi_transform_with_state(f)

使用 Haiku 将函数集合转换为纯函数。

without_apply_rng(f)

从 apply 函数中移除 rng 参数。

without_state(f)

包装转换后的元组并忽略状态输入/输出。

transform#

haiku.transform(f, *, apply_rng=True)[source]#

使用 Haiku 模块将函数转换为一对纯函数。

对于函数 out = f(*a, **k),此函数返回一对调用 f(*a, **k) 的两个纯函数,显式收集和注入参数值

params = init(rng, *a, **k)
out = apply(params, rng, *a, **k)

请注意,rng 参数通常不是 apply 所必需的,并且接受传递 None

首先要做的是定义一个 Module。模块封装了一些参数以及对这些参数的计算

>>> class MyModule(hk.Module):
...   def __call__(self, x):
...     w = hk.get_parameter("w", [], init=jnp.zeros)
...     return x + w

接下来,定义一些创建和应用模块的函数。我们使用 transform() 将该函数转换为一对函数,这些函数允许我们将所有参数从函数中提取出来 (f.init) 并使用给定的参数集应用该函数 (f.apply)

>>> def f(x):
...   a = MyModule()
...   b = MyModule()
...   return a(x) + b(x)
>>> f = hk.transform(f)

要获取模块的初始状态,请使用示例输入调用 init

>>> params = f.init(None, 1)
>>> params
{'my_module': {'w': ...Array(0., dtype=float32)},
 'my_module_1': {'w': ...Array(0., dtype=float32)}}

然后,您可以使用给定的参数通过调用 apply 来应用该函数(请注意,由于我们不使用 Haiku 的随机数 API 来应用我们的网络,因此我们传递 None 作为 RNG 密钥)

>>> print(f.apply(params, None, 1))
2.0

预计您的程序会在某个时候生成更新的参数,并且您将希望重新应用 apply。您可以通过使用不同的参数调用 apply 来执行此操作

>>> new_params = {"my_module": {"w": jnp.array(2.)},
...               "my_module_1": {"w": jnp.array(3.)}}
>>> print(f.apply(new_params, None, 2))
9.0

如果您的转换函数需要维护内部状态(例如,批归一化中的移动平均值),请参阅 transform_with_state()

参数
  • f – 一个闭包 Module 实例的函数。

  • apply_rng – 正在移除的过程中。只能取值 True

返回类型

Transformed

返回

一个 Transformed 元组,其中包含 initapply 纯函数。

transform_with_state#

haiku.transform_with_state(f)[source]#

使用 Haiku 模块将函数转换为一对纯函数。

有关 Haiku 转换的常规详细信息,请参阅 transform()

对于函数 out = f(*a, **k),此函数返回一对调用 f(*a, **k) 的两个纯函数,显式收集和注入参数值和状态

params, state = init(rng, *a, **k)
out, state = apply(params, state, rng, *a, **k)

请注意,rng 参数通常不是 apply 所必需的,并且接受传递 None

此函数等效于 transform(),但是它允许您通过 get_state()set_state() 维护和更新内部状态(例如,BatchNorm 中的 ExponentialMovingAverage

>>> def f():
...   counter = hk.get_state("counter", shape=[], dtype=jnp.int32,
...                          init=jnp.zeros)
...   hk.set_state("counter", counter + 1)
...   return counter
>>> f = hk.transform_with_state(f)
>>> params, state = f.init(None)
>>> for _ in range(10):
...   counter, state = f.apply(params, state, None)
>>> print(counter)
9
参数

f – 一个闭包 Module 实例的函数。

返回类型

TransformedWithState

返回

一个 TransformedWithState 元组,其中包含 initapply 纯函数。

multi_transform#

haiku.multi_transform(f)[source]#

使用 Haiku 将函数集合转换为纯函数。

在许多场景中,我们有多个模块,这些模块既可以用作多个 Haiku 模块/函数的原语,也可以将其纯版本在下游代码中重用。此实用程序通过将 transform() 应用于 Haiku 函数的任意树来实现此目的,这些函数共享模块并具有通用的 init 函数。

f 预计返回两个元素的元组。第一个是 template Haiku 函数,它提供了一个关于所有内部 Haiku 模块如何连接的示例。此函数用于创建通用的 init 函数(带有您的参数)。

第二个对象是 Haiku 函数的任意树,所有这些函数都重用了 template 函数中连接的模块。这些函数被转换为纯 apply 函数。

示例

>>> def f():
...   encoder = hk.Linear(1, name="encoder")
...   decoder = hk.Linear(1, name="decoder")
...
...   def init(x):
...     z = encoder(x)
...     return decoder(z)
...
...   return init, (encoder, decoder)
>>> f = hk.multi_transform(f)
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([1, 1])
>>> params = f.init(rng, x)
>>> jax.tree_util.tree_map(jnp.shape, params)
{'decoder': {'b': (1,), 'w': (1, 1)},
 'encoder': {'b': (1,), 'w': (1, 1)}}
>>> encode, decode = f.apply
>>> z = encode(params, None, x)
>>> y = decode(params, None, z)
参数

f (Callable[[], tuple[TemplateFn, TreeOfApplyFns]]) – 一个工厂函数,它返回两个函数,首先是一个创建所有模块的通用 init 函数,其次是一个使用这些模块的 apply 函数的 pytree。

返回类型

MultiTransformed

返回

一个 MultiTransformed 实例,其中包含一个纯 init 函数

,它创建所有参数,以及一个纯 apply 函数的 pytree,给定参数,应用给定函数。

另请参阅

multi_transform_with_state():等效于使用状态的模块。

multi_transform_with_state#

haiku.multi_transform_with_state(f)[source]#

使用 Haiku 将函数集合转换为纯函数。

有关更多详细信息,请参阅 multi_transform()

示例

>>> def f():
...   encoder = hk.Linear(1, name="encoder")
...   decoder = hk.Linear(1, name="decoder")
...
...   def init(x):
...     z = encoder(x)
...     return decoder(z)
...
...   return init, (encoder, decoder)
>>> f = hk.multi_transform_with_state(f)
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([1, 1])
>>> params, state = f.init(rng, x)
>>> jax.tree_util.tree_map(jnp.shape, params)
{'decoder': {'b': (1,), 'w': (1, 1)},
 'encoder': {'b': (1,), 'w': (1, 1)}}
>>> encode, decode = f.apply
>>> z, state = encode(params, state, None, x)
>>> y, state = decode(params, state, None, z)
参数

f (Callable[[], tuple[TemplateFn, TreeOfApplyFns]]) – 函数返回一个“模板”函数和一个使用模板函数中连接的模块的任意函数树。

返回类型

MultiTransformedWithState

返回

一个 init 函数和一个纯 apply 函数树。

另请参阅

without_apply_rng#

haiku.without_apply_rng(f)[source]#

从 apply 函数中移除 rng 参数。

这是一个便利包装器,使 f.applyrng 参数默认为 None。当 f 实际上不使用随机数作为其计算的一部分时,这很有用,这样就不会使用 rng 参数。请注意,如果 f 确实 使用随机数,这将导致抛出错误,抱怨 f 需要一个非 None PRNGKey。

参数

f (TransformedT) – 一个转换后的函数。

返回类型

TransformedT

返回

相同的转换函数,带有一个修改后的 apply

without_state#

haiku.without_state(f)[source]#

包装转换后的元组并忽略状态输入/输出。

以下示例等效于 f = hk.transform(f)

>>> def f(x):
...   mod = hk.Linear(10)
...   return mod(x)
>>> f = hk.without_state(hk.transform_with_state(f))
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.zeros([1, 1])
>>> params = f.init(rng, x)
>>> print(f.apply(params, rng, x))
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
参数

f (TransformedWithState) – 一个转换后的函数。

返回类型

Transformed

返回

一个不接受或返回状态的转换函数。

with_empty_state#

haiku.with_empty_state(f)[source]#

包装转换后的元组并在输入/输出中传递空状态。

以下示例等效于 f = hk.transform_with_state(f)

>>> def f(x):
...   mod = hk.Linear(10)
...   return mod(x)
>>> f = hk.with_empty_state(hk.transform(f))
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.zeros([1, 1])
>>> params, state = f.init(rng, x)
>>> state
{}
>>> out, state = f.apply(params, state, rng, x)
>>> print(out)
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
>>> state
{}
参数

f (Transformed) – 一个转换后的函数。

返回类型

TransformedWithState

返回

一个接受和返回状态的转换函数。

模块、参数和状态#

Module([name])

Haiku 模块的基类。

to_module(f)

将函数转换为可调用模块类。

get_parameter(name, shape[, dtype, init])

为给定的转换函数创建或重用参数。

get_state(name[, shape, dtype, init])

获取状态的当前值,带有可选的初始化器。

set_state(name, value)

设置某些状态的当前值。

Module#

class haiku.Module(name=None)[source]#

Haiku 模块的基类。

Haiku 模块是变量和其他模块的轻量级容器。模块通常定义一个或多个“前向”方法(例如 __call__),这些方法应用组合用户输入和模块参数的操作。

模块必须在 transform() 调用内部初始化。

例如

>>> class AddModule(hk.Module):
...   def __call__(self, x):
...     w = hk.get_parameter("w", [], init=jnp.ones)
...     return x + w
>>> def forward_fn(x):
...   mod = AddModule()
...   return mod(x)
>>> forward = hk.transform(forward_fn)
>>> x = 1.
>>> rng = None
>>> params = forward.init(rng, x)
>>> print(forward.apply(params, None, x))
2.0
__init__(name=None)[source]#

使用给定名称初始化当前模块。

子类应在创建其他模块或变量之前调用此构造函数,以便正确命名这些模块。

参数

name (Optional[str]) – 类的可选字符串名称。必须是有效的 Python 标识符。如果未提供 name,则当前实例的类名将转换为 lower_snake_case 并改为使用。

__post_init__(name=None)[source]#

使用给定名称初始化当前模块。

子类应在创建其他模块或变量之前调用此构造函数,以便正确命名这些模块。

参数

name (Optional[str]) – 类的可选字符串名称。必须是有效的 Python 标识符。如果未提供 name,则当前实例的类名将转换为 lower_snake_case 并改为使用。

params_dict()[source]#

返回此模块和子模块的按名称键控的参数。

返回类型

Mapping[str, jnp.ndarray]

state_dict()[source]#

返回此模块和子模块的按名称键控的状态。

返回类型

Mapping[str, jnp.ndarray]

to_module#

haiku.to_module(f)[source]#

将函数转换为可调用模块类。

用法示例

>>> def bias_fn(x):
...   b = hk.get_parameter("b", [], init=hk.initializers.RandomNormal())
...   return x + b
>>> Bias = hk.to_module(bias_fn)
>>> def net(x, y):
...   b = Bias(name="my_bias")
...   # Bias x and y by the same amount.
...   return b(x) * b(y)
参数

f (Callable[..., Any]) – 要转换的函数。

返回类型

type[CallableModule]

返回

一个模块类,在调用时运行 f

get_parameter#

haiku.get_parameter(name, shape, dtype=<class 'jax.numpy.float32'>, init=None)[source]#

为给定的转换函数创建或重用参数。

>>> print(hk.get_parameter("w", [], init=jnp.ones))
1.0

在同一 transform() 和/或 Module 中具有相同名称的参数具有相同的值

>>> w1 = hk.get_parameter("w", [], init=jnp.zeros)
>>> w2 = hk.get_parameter("w", [], init=jnp.zeros)
>>> assert w1 is w2
参数
  • name (str) – 参数的名称。

  • shape (Sequence[int]) – 参数的形状。

  • dtype (Any) – 参数的 dtype。

  • init (Optional[Initializer]) – 形状、dtype 的可调用对象,用于生成参数的初始值。

返回类型

jax.Array

返回

具有给定形状参数的 jax.Array。

get_state#

haiku.get_state(name, shape=None, dtype=<class 'jax.numpy.float32'>, init=None)[source]#

获取状态的当前值,带有可选的初始化器。

“状态”可用于表示网络中的可变状态。状态最常见的用法是表示批归一化中使用的移动平均值(请参阅 ExponentialMovingAverage)。如果您的网络使用“状态”,则您需要使用 transform_with_state() 并将状态传入和传出 apply 函数。

>>> print(hk.get_state("counter", [], init=jnp.zeros))
0.0

如果给定状态的值已定义(例如,使用 set_state()),则您可以仅使用名称调用

>>> print(hk.get_state("counter"))
0.0

注意:在同一 transform() 和/或 Module 中具有相同名称的状态具有相同的值

>>> c1 = hk.get_state("counter")
>>> c2 = hk.get_state("counter")
>>> assert c1 is c2
参数
  • name (str) – 状态的名称。

  • shape (Optional[Sequence[int]]) – 状态的形状。

  • dtype (Any) – 状态的 dtype。

  • init (Optional[Initializer]) – 可调用对象 f(shape, dtype),它返回状态的初始值。

返回类型

jax.Array

返回

具有给定形状状态的 jax.Array。

set_state#

haiku.set_state(name, value)[source]#

设置某些状态的当前值。

请参阅 get_state()

“状态”可用于表示网络中的可变状态。状态最常见的用法是表示批归一化中使用的移动平均值(请参阅 ExponentialMovingAverage)。如果您的网络使用“状态”,则您需要使用 transform_with_state() 并将状态传入和传出 apply 函数。

>>> hk.set_state("counter", jnp.zeros([]))
>>> print(hk.get_state("counter"))
0.0

注意:在同一 transform() 和/或 Module 中具有相同名称的状态具有相同的值

>>> w1 = hk.get_state("counter")
>>> w2 = hk.get_state("counter")
>>> assert w1 is w2
参数
  • name (str) – 状态的名称。

  • value – 要设置的值。

Getter 和 Interceptor#

custom_creator(creator, *[, params, state])

注册自定义参数和/或状态创建器。

custom_getter(getter, *[, params, state])

注册自定义参数或状态 getter。

custom_setter(setter)

注册自定义状态 setter。

GetterContext(full_name, module, ...)

有关参数创建位置的上下文。

SetterContext(full_name, module, ...)

有关状态设置位置的上下文。

intercept_methods(interceptor)

注册新的方法 interceptor。

MethodContext(module, method_name, ...)

只读状态,显示方法的调用上下文。

custom_creator#

haiku.custom_creator(creator, *, params=True, state=False)[source]#

注册自定义参数和/或状态创建器。

当通过 get_parameter() 创建新参数时,我们首先运行自定义创建器,传递用户定义的值。例如

>>> def zeros_creator(next_creator, shape, dtype, init, context):
...   init = jnp.zeros
...   return next_creator(shape, dtype, init)
>>> with hk.custom_creator(zeros_creator):
...   z = hk.get_parameter("z", [], jnp.float32, jnp.ones)
>>> print(z)
0.0

如果 state=True,则您的创建器还将运行在对 get_state() 的调用上

>>> with hk.custom_creator(zeros_creator, state=True):
...   z = hk.get_state("z", [], jnp.float32, jnp.ones)
>>> print(z)
0.0
参数
  • creator (Creator) – 参数创建器。

  • params (bool) – 是否拦截参数创建,默认为 True

  • state (bool) – 是否拦截状态创建,默认为 False

返回类型

contextlib.AbstractContextManager

返回

创建器处于活动状态的上下文管理器。

custom_getter#

haiku.custom_getter(getter, *, params=True, state=False)[source]#

注册自定义参数或状态 getter。

当使用 get_parameter() 检索参数时,我们总是在向用户返回值之前运行所有自定义 getter。

>>> def bf16_getter(next_getter, value, context):
...   value = value.astype(jnp.bfloat16)
...   return next_getter(value)
>>> with hk.custom_getter(bf16_getter):
...   w = hk.get_parameter("w", [], jnp.float32, jnp.ones)
>>> w.dtype
dtype(bfloat16)

如果 state=True,则 getter 还将为 get_state() 的调用运行。

>>> with hk.custom_getter(bf16_getter, state=True):
...   c = hk.get_state("c", [], jnp.float32, jnp.ones)
>>> c.dtype
dtype(bfloat16)
参数
  • getter (Getter) – 参数 getter。

  • params (bool) – getter 是否应在 get_parameter() 上运行

  • state (bool) – getter 是否应在 get_state() 上运行。

返回类型

contextlib.AbstractContextManager

返回

getter 在其下处于活动状态的上下文管理器。

custom_setter#

haiku.custom_setter(setter)[source]#

注册自定义状态 setter。

当使用 set_state() 设置状态时,我们总是在保存值之前运行所有自定义 setter。

>>> def zero_during_init(next_setter, value, context):
...   if hk.running_init():
...     value = jnp.zeros_like(value)
...   return next_setter(value)
>>> with hk.custom_setter(zero_during_init):
...   hk.set_state("x", jnp.ones([2]))
...   x = hk.get_state("x")
>>> print(x)
[0. 0.]
参数

setter (Setter) – 状态 setter。

返回类型

contextlib.AbstractContextManager

返回

setter 在其下处于活动状态的上下文管理器。

GetterContext#

class haiku.GetterContext(full_name: str, module: Optional[Module], original_dtype: Any, original_shape: Sequence[int], original_init: Optional[Initializer], lifted_prefix_name: Optional[str])[source]#

有关参数创建位置的上下文。

full_name#

给定参数的完整名称 (例如 mlp/~/linear_0/w)。

类型

str

module#

拥有当前参数的模块,如果此参数存在于任何模块之外,则为 None

类型

Optional[Module]

original_dtype#

get_parameter()get_state() 最初调用时的数据类型。

类型

Any

original_shape#

get_parameter()get_state() 最初调用时的形状。

类型

Sequence[int]

original_init#

get_parameter()get_state() 最初调用时的初始化器。

类型

Optional[Initializer]

lifted_prefix_name#

所有封闭 lifted 模块的模块名称(有关更多上下文,请参见 lift())。 将此字符串作为前缀添加到 full_name 将等于外部转换的参数字典中的最终参数名称。 注意:当在 apply 上下文中调用 get_parameter()get_state() 时,此名称将始终为 None,因为仅提升 init 函数。

类型

Optional[str]

module_name#

封闭模块的完整名称。

name#

此参数的名称。

SetterContext#

class haiku.SetterContext(full_name: str, module: Optional[Module], original_dtype: Any, original_shape: Sequence[int], lifted_prefix_name: Optional[str])[source]#

有关状态设置位置的上下文。

full_name#

给定状态的完整名称 (例如 mlp/~/linear_0/w)。

类型

str

module#

拥有当前状态的模块,如果此状态存在于任何模块之外,则为 None

类型

Optional[Module]

original_dtype#

set_state() 最初调用时的数据类型。

类型

Any

original_shape#

set_state()get_state() 最初调用时的形状。

类型

Sequence[int]

lifted_prefix_name#

所有封闭 lifted 模块的模块名称(有关更多上下文,请参见 lift())。 将此字符串作为前缀添加到 full_name 将等于外部转换的参数字典中的最终参数名称。 注意:当在 apply 上下文中调用 get_parameter()get_state() 时,此名称将始终为 None,因为仅提升 init 函数。

类型

Optional[str]

module_name#

封闭模块的完整名称。

name#

此状态的名称。

intercept_methods#

haiku.intercept_methods(interceptor)[source]#

注册新的方法 interceptor。

方法拦截器允许您(远程地)拦截对模块的方法调用,并在调用底层方法之前修改 args/kwargs。在调用底层方法后,您可以修改其结果,然后再将其传递回用户。

例如,您可以拦截对 BatchNorm 的方法调用,并确保它始终以全精度计算

>>> def my_interceptor(next_f, args, kwargs, context):
...   if (type(context.module) is not hk.BatchNorm
...       or context.method_name != "__call__"):
...     # We ignore methods other than BatchNorm.__call__.
...     return next_f(*args, **kwargs)
...
...   def cast_if_array(x):
...     if isinstance(x, jax.Array):
...       x = x.astype(jnp.float32)
...     return x
...
...   args, kwargs = jax.tree_util.tree_map(cast_if_array, (args, kwargs))
...   out = next_f(*args, **kwargs)
...   return out

我们可以像往常一样创建和使用我们的模块,我们只需要将我们想要拦截的任何方法调用包装在上下文管理器中

>>> mod = hk.BatchNorm(decay_rate=0.9, create_scale=True, create_offset=True)
>>> x = jnp.ones([], jnp.bfloat16)
>>> with hk.intercept_methods(my_interceptor):
...   out = mod(x, is_training=True)
>>> assert out.dtype == jnp.float32

如果没有拦截器,BatchNorm 将以 bf16 计算,但是由于我们在调用底层方法之前转换了 x,因此我们在 f32 中计算。

参数

interceptor (MethodGetter) – 方法拦截器。

返回

拦截器在其下处于活动状态的上下文管理器。

MethodContext#

class haiku.MethodContext(module: 'Module', method_name: str, orig_method: Callable[..., Any], orig_class: type['Module'])[source]#

只读状态,显示方法的调用上下文。

例如,让我们定义两个拦截器并在上下文中打印值。此外,我们将使第一个拦截器有条件地短路,因为拦截器堆叠并按顺序运行,所以较早的拦截器可以决定调用下一个拦截器,或者短路并直接调用底层方法

>>> module = hk.Linear(1, name="method_context_example")
>>> short_circuit = False
>>> def my_interceptor_1(next_fun, args, kwargs, context):
...   print('running my_interceptor_1')
...   print('- module.name: ', context.module.name)
...   print('- method_name: ', context.method_name)
...   if short_circuit:
...     return context.orig_method(*args, **kwargs)
...   else:
...     return next_fun(*args, **kwargs)
>>> def my_interceptor_2(next_fun, args, kwargs, context):
...   print('running my_interceptor_2')
...   print('- module.name: ', context.module.name)
...   print('- method_name: ', context.method_name)
...   return next_fun(*args, **kwargs)

short_circuit=False 时,两个拦截器将按顺序运行

>>> with hk.intercept_methods(my_interceptor_1), \
...      hk.intercept_methods(my_interceptor_2):
...   _ = module(jnp.ones([1, 1]))
running my_interceptor_1
- module.name:  method_context_example
- method_name:  __call__
running my_interceptor_2
- module.name:  method_context_example
- method_name:  __call__

设置 short_circuit=True 将导致第一个拦截器调用原始方法(而不是 next_fun,这将触发下一个拦截器)

>>> short_circuit = True
>>> with hk.intercept_methods(my_interceptor_1), \
...      hk.intercept_methods(my_interceptor_2):
...   _ = module(jnp.ones([1, 1]))
running my_interceptor_1
- module.name:  method_context_example
- method_name:  __call__
module#

正在调用其方法的 Module 实例。

类型

‘Module’

method_name#

正在模块上调用的方法的名称。

类型

str

orig_method#

模块上的底层方法,调用该方法不会触发拦截器。 仅当您要短路所有其他拦截器时才应调用此方法,通常您应该首选调用传递给您的拦截器的 next_fun,它将在运行所有其他拦截器后运行 orig_method

类型

Callable[…, Any]

orig_class#

定义 orig_method 的类。 请注意,使用继承时,这不一定与 type(module) 相同。

类型

type[‘Module’]

Random Numbers#

PRNGSequence(key_or_seed)

JAX 随机密钥的迭代器。

next_rng_key()

返回从当前全局密钥拆分的唯一 JAX 随机密钥。

next_rng_keys(num)

返回从当前全局密钥拆分的一个或多个 JAX 随机密钥。

maybe_next_rng_key()

next_rng_key() 如果随机数可用,否则为 None

reserve_rng_keys(num)

预分配一些 JAX RNG 密钥。

with_rng(key)

next_rng_key() 提供新的序列以从中提取。

maybe_get_rng_sequence_state()

返回 PRNG 序列的内部状态。

replace_rng_sequence_state(state)

用给定状态替换 PRNG 序列的内部状态。

PRNGSequence#

class haiku.PRNGSequence(key_or_seed)[source]#

JAX 随机密钥的迭代器。

>>> seq = hk.PRNGSequence(42)  # OR pass a jax.random.PRNGKey
>>> key1 = next(seq)
>>> key2 = next(seq)
>>> assert key1 is not key2

如果您知道您需要多少个密钥,则可以使用 reserve() 来更有效地拆分您需要的密钥

>>> seq.reserve(4)
>>> keys = [next(seq) for _ in range(4)]
__init__(key_or_seed)[source]#

创建一个新的 PRNGSequence

reserve(num)[source]#

拆分额外的 num 个密钥以供以后使用。

__next__()[source]#

从迭代器返回下一个项目。 当耗尽时,引发 StopIteration

返回类型

PRNGKey

next()[source]#

从迭代器返回下一个项目。 当耗尽时,引发 StopIteration

返回类型

PRNGKey

next_rng_key#

haiku.next_rng_key()[source]#

返回从当前全局密钥拆分的唯一 JAX 随机密钥。

>>> key = hk.next_rng_key()
>>> _ = jax.random.uniform(key, [])
返回类型

PRNGKey

返回

一个唯一的(在调用 initapply 中)JAX rng 密钥,可与 jax.random.uniform() 等 API 一起使用。

next_rng_keys#

haiku.next_rng_keys(num)[source]#

返回从当前全局密钥拆分的一个或多个 JAX 随机密钥。

>>> k1, k2 = hk.next_rng_keys(2)
>>> assert (k1 != k2).all()
>>> a = jax.random.uniform(k1, [])
>>> b = jax.random.uniform(k2, [])
>>> assert a != b
参数

num (int) – 要拆分的密钥数量。

返回类型

jax.Array

返回

形状为 [num, 2] 的数组,其中包含唯一的(在转换后的函数中)JAX rng 密钥,可与 jax.random.uniform() 等 API 一起使用。

maybe_next_rng_key#

haiku.maybe_next_rng_key()[source]#

next_rng_key() 如果随机数可用,否则为 None

返回类型

Optional[PRNGKey]

reserve_rng_keys#

haiku.reserve_rng_keys(num)[source]#

预分配一些 JAX RNG 密钥。

参见 next_rng_key()

此 API 提供了一种微优化在使用 Haiku 时如何拆分 RNG 密钥的方法。 除非您发现 init 函数的编译时间有问题,或者您在 apply 中采样大量随机数,否则您不太可能需要它。

>>> hk.reserve_rng_keys(2)  # Pre-allocate 2 keys for us to consume.
>>> _ = hk.next_rng_key()   # Takes the first pre-allocated key.
>>> _ = hk.next_rng_key()   # Takes the second pre-allocated key.
>>> _ = hk.next_rng_key()   # Splits a new key.
参数

num (int) – 要分配的 JAX rng 密钥的数量。

with_rng#

haiku.with_rng(key)[source]#

next_rng_key() 提供新的序列以从中提取。

当调用 next_rng_key() 时,它从由转换函数的输入密钥定义的 PRNGSequence 中提取新密钥。 此上下文管理器在作用域的持续时间内覆盖序列。

>>> with hk.with_rng(jax.random.PRNGKey(428)):
...   s = jax.random.uniform(hk.next_rng_key(), ())
>>> print("{:.1f}".format(s))
0.5
参数

key (PRNGKey) – 用于为序列播种的密钥。

返回

给定序列在其下处于活动状态的上下文管理器。

maybe_get_rng_sequence_state#

haiku.maybe_get_rng_sequence_state()[source]#

返回 PRNG 序列的内部状态。

返回类型

Optional[PRNGSequenceState]

返回

如果随机数可用,则为内部状态,否则为 None

replace_rng_sequence_state#

haiku.replace_rng_sequence_state(state)[source]#

用给定状态替换 PRNG 序列的内部状态。

参数

state (PRNGSequenceState) – 新的内部状态或 None

Raises

MissingRNGError – 如果随机数不可用。

Type Hints#

LSTMState(hidden, cell)

LSTM 核心状态由隐藏向量和单元向量组成。

Params

Mapping 是用于关联键/值对的通用容器。

MutableParams

MutableMapping 是用于关联键/值对的通用容器。

State

Mapping 是用于关联键/值对的通用容器。

MutableState

MutableMapping 是用于关联键/值对的通用容器。

Transformed(init, apply)

保存一对纯函数。

TransformedWithState(init, apply)

保存一对纯函数。

MultiTransformed(init, apply)

保存纯函数的集合。

MultiTransformedWithState(init, apply)

保存纯函数的集合。

ModuleProtocol(*args, **kwargs)

Module 类似类型的协议。

SupportsCall(*args, **kwargs)

可调用的 Module 类似类型的协议。

LSTMState#

class haiku.LSTMState(hidden: jax.Array, cell: jax.Array)[source]#

LSTM 核心状态由隐藏向量和单元向量组成。

hidden#

隐藏状态。

类型

jax.Array

cell#

单元状态。

类型

jax.Array

Params#

haiku.Params#

alias of collections.abc.Mapping[str, collections.abc.Mapping[str, jax.Array]]

MutableParams#

haiku.MutableParams#

alias of collections.abc.MutableMapping[str, collections.abc.MutableMapping[str, jax.Array]]

State#

haiku.State#

alias of collections.abc.Mapping[str, collections.abc.Mapping[str, jax.Array]]

MutableState#

haiku.MutableState#

alias of collections.abc.MutableMapping[str, collections.abc.MutableMapping[str, jax.Array]]

Transformed#

class haiku.Transformed(init: Callable[..., hk.MutableParams], apply: Callable[..., Any])[source]#

保存一对纯函数。

init#

一个纯函数:params = init(rng, *a, **k)

类型

Callable[…, hk.MutableParams]

apply#

一个纯函数:out = apply(params, rng, *a, **k)

类型

Callable[…, Any]

TransformedWithState#

class haiku.TransformedWithState(init: Callable[..., tuple[hk.MutableParams, hk.MutableState]], apply: Callable[..., tuple[Any, hk.MutableState]])[source]#

保存一对纯函数。

init#

一个纯函数:params, state = init(rng, *a, **k)

类型

Callable[…, tuple[hk.MutableParams, hk.MutableState]]

apply#

一个纯函数:out, state = apply(params, state, rng, *a, **k)

类型

Callable[…, tuple[Any, hk.MutableState]]

MultiTransformed#

class haiku.MultiTransformed(init: Callable[..., hk.MutableParams], apply: Any)[source]#

保存纯函数的集合。

init#

一个纯函数:params = init(rng, *a, **k)

类型

Callable[…, hk.MutableParams]

apply#

一个 JAX 纯函数树,每个函数都具有以下签名:out = apply(params, rng, *a, **k)

类型

Any

另请参阅

MultiTransformedWithState#

class haiku.MultiTransformedWithState(init: Callable[..., tuple[hk.MutableParams, hk.MutableState]], apply: Any)[source]#

保存纯函数的集合。

init#

一个纯函数:params, state = init(rng, *a, **k)

类型

Callable[…, tuple[hk.MutableParams, hk.MutableState]]

apply#

一个 JAX 纯函数树,每个函数都具有以下签名:out, state = apply(params, state, rng, *a, **k)

类型

Any

另请参阅

ModuleProtocol#

class haiku.ModuleProtocol(*args, **kwargs)[source]#

Module 类似类型的协议。

SupportsCall#

class haiku.SupportsCall(*args, **kwargs)[source]#

可调用的 Module 类似类型的协议。

作为一个协议意味着你不需要显式地扩展此类型,以便支持使用它的实例检查。例如,Linear 仅扩展了 Module,但由于它符合(例如,实现了 __call__)此协议,因此你可以使用它进行实例检查

>>> assert isinstance(hk.Linear(1), hk.SupportsCall)

Flax Interop#

Haiku 在 Flax 内部#

Module#

class haiku.experimental.flax.Module(transformed, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

一个 Flax nn.Module,它运行 Haiku 转换函数。

此类型旨在使获取 Haiku 转换函数和/或 Haiku 模块并在其他使用 Flax 的程序中使用它变得容易。

给定一个 Haiku 转换函数

>>> def f(x):
...   return hk.Linear(1)(x)
>>> f = hk.transform(f)

你可以使用以下方法将其转换为 Flax 模块

>>> mod = hk.experimental.flax.Module(f)

调用此模块与调用任何常规 Flax 模块相同

>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([1, 1])
>>> variables = mod.init(rng, x)
>>> out = mod.apply(variables, x)

如果你只想转换 Haiku 模块类,使其可以与 Flax 一起使用,则可以使用 create 类方法

>>> mod = hk.experimental.flax.Module.create(hk.Linear, 1)
>>> variables = mod.init(rng, x)
>>> out = mod.apply(variables, x)

flatten_flax_to_haiku#

haiku.experimental.flax.flatten_flax_to_haiku(collection)[source]#

将 Flax 变量集合(例如 params)展平为 Haiku 字典。

返回类型

HaikuParamsOrState

Flax 在 Haiku 内部#

lift#

haiku.experimental.flax.lift(mod, *, name)[source]#

将 flax nn.Module 提升为 Haiku 转换函数。

对于 Flax 模块(例如 mod = nn.Dense(10)),mod = lift(mod) 允许你运行模块的 call 方法,就像该模块是常规 Haiku 模块一样。

Flax 模块的参数和状态已在 Haiku 中注册,并成为 params/state 字典的一部分(从 init/apply 返回)。

>>> def f(x):
...   # Create and "lift" a Flax module.
...   mod = hk.experimental.flax.lift(nn.Dense(300), name='dense')
...   x = mod(x)                  # Any params/state will be registered
...                               # with Haiku when applying the module.
...   x = jax.nn.relu(x)
...   x = hk.nets.MLP([100, 10])  # You can of course mix Haiku modules in.
...   return x
>>> f = hk.transform(f)
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([1, 1])
>>> params = f.init(rng, x)
>>> out = f.apply(params, None, x)
参数
  • mod (nn.Module) – 任何 Flax nn.Module 实例。

  • name (str) – 用于为外部 params/state 字典中的条目添加前缀的名称作用域。

返回类型

Callable[…, Any]

返回

一个函数,在应用时调用给定 Flax 模块的 call 方法并返回其输出。 作为调用模块的副作用,任何模块参数和状态变量都会在 Haiku 中注册。

高级状态管理#

提升 (Lifting)#

lift(init_fn, *[, allow_reuse, name])

在外部转换中注册来自内部 init 函数的参数。

lift_with_state(init_fn, *[, allow_reuse, name])

在外部转换中注册来自 init 函数的参数和状态。

transparent_lift(init_fn, *[, allow_reuse])

在外部转换中注册参数,而不添加名称作用域。

transparent_lift_with_state(init_fn, *[, ...])

在外部转换中注册参数和状态,而不添加作用域。

LiftWithStateUpdater(name)

处理更新 lift_with_state 计算的状态。

lift#

haiku.lift(init_fn, *, allow_reuse=False, name='lifted')[source]#

在外部转换中注册来自内部 init 函数的参数。

提示:当你想在 transform()transform_with_state()内部对 JAX 转换(例如 jax.vmap)进行非平凡的使用时,可以使用 lift()。 我们通常建议尝试对 transform() 返回的纯函数使用 JAX 转换,在这种情况下,你不需要 lift()

在嵌套 Haiku 转换以在任何外部转换中注册内部转换的参数时,请使用 lift()。 当在 Haiku 模块内部使用 JAX 函数(例如,在层上使用 jax.vmap)时,这非常有用。 有关何时使用 lift() 的更多说明,请参阅 https://haiku.jax.net.cn/en/latest/notebooks/transforms.html#Using-hk.lift。 (如果你在模块内部没有使用 JAX 函数,或者在转换内部不需要访问你的参数,那么你可能不需要使用 lift()

必须在 transform() 内部调用,并传递 Transformedinit 成员。

在 init 期间,返回的可调用对象将运行给定的 init_fn,并将生成的参数包含在外部转换的字典中。 在 apply 期间,返回的可调用对象将改为从外部转换的字典中拉取相关参数。

默认情况下,用户必须确保给定的 init 不会通过函数闭包意外捕获来自外部 transform() 的模块。 如果需要此行为,请将 allow_reuse 设置为 True

示例

lift() 的常见用法是在 transform() 内部以非平凡的方式使用 JAX 转换,例如 vmap。 例如,我们可以使用 lift()jax.vmap 来创建集成。

首先,我们将创建一个辅助函数,该函数使用 lift()vmap 应用于我们的模型。 正如你从注释中看到的那样,我们使用 vmap 来更改应如何创建参数(在这种情况下,我们为集成的每个成员创建一组唯一的参数),并且我们更改了 apply 的工作方式(我们“映射”参数,这意味着 JAX 将为集成的每个成员分别并行计算前向传递)

>>> def create_ensemble(model, size: int):
...   init_rng = hk.next_rng_keys(size) if hk.running_init() else None
...   model = hk.transform(model)
...   # in_axes: rng is mapped, data is not.
...   init_model = jax.vmap(model.init, in_axes=(0, None))
...   # Use hk.lift to "lift" parameters created by `init_model` into the
...   # outer transform.
...   init_model = hk.lift(init_model, name="ensemble")
...   def ensemble(x):
...     params = init_model(init_rng, x)
...     # in_axes: params are mapped, rng/data are not.
...     return jax.vmap(model.apply, in_axes=(0, None, None))(params, None, x)
...   return ensemble

我们现在可以使用此函数来集成任何 Haiku 模块,在转换内部。 首先,我们为集成的每个成员定义一个函数

>>> def member_fn(x):
...   return hk.nets.MLP([300, 100, 10])(x)

其次,我们可以将我们的两个函数组合在 transform() 内部以创建集成

>>> def f(x):
...   ensemble = create_ensemble(member_fn, size=4)
...   x = ensemble(x)
...   # You could create other modules here which were not ensembled.
...   return x
>>> f = hk.transform(f)

当我们初始化网络时,我们的集成成员的参数具有集成大小的前导维度

>>> rng = jax.random.PRNGKey(777)
>>> x = jnp.ones([32, 128])
>>> params = f.init(rng, x)
>>> jax.tree_util.tree_map(lambda x: x.shape, params)
{'ensemble/mlp/~/linear_0': {'b': (4, 300), 'w': (4, 128, 300)},
 'ensemble/mlp/~/linear_1': {'b': (4, 100), 'w': (4, 300, 100)},
 'ensemble/mlp/~/linear_2': {'b': (4, 10), 'w': (4, 100, 10)}}

当我们应用网络时,我们为整个批次的集成的每个成员获得一个输出

>>> y = f.apply(params, None, x)
>>> y.shape
(4, 32, 10)
参数
  • init_fn (Callable[..., hk.Params]) – 来自 Transformedinit 函数。

  • allow_reuse (bool) – 允许从外部 transform() 重用提升的参数和状态。 当在控制流(例如 hk.scan)中使用 lift 时,这可能是理想的。

  • name (str) – 用于为参数添加前缀的字符串名称。

返回类型

Callable[…, hk.Params]

返回

一个可调用对象,在 init 期间将参数值注入到外部上下文中,并在 apply 期间从外部上下文中检索参数。 在这两种情况下,都返回要与 apply 函数一起使用的参数值。

另请参阅

lift_with_state#

haiku.lift_with_state(init_fn, *, allow_reuse=False, name='lifted')[source]#

在外部转换中注册来自 init 函数的参数和状态。

有关何时使用 lift 的更多上下文,请参阅 lift()

此函数返回两个对象。 第一个是可调用对象,它在 init 与 apply 时间运行你的 init 函数时,行为略有不同。 第二个是更新器,可用于传递运行你的 apply 函数后产生的更新状态值。 有关工作示例,请参阅文档后面的内容。

在 init 期间,返回的可调用对象将运行给定的 init_fn,并将生成的参数/状态包含在外部转换的字典中。 在 apply 期间,返回的可调用对象将改为从外部转换的字典中拉取相关参数/状态。

必须在 transform_with_state() 内部调用,并传递 TransformedWithStateinit 成员。

默认情况下,用户必须确保给定的 init 不会通过函数闭包意外捕获来自外部 transform_with_state() 的模块。 如果需要此行为,请将 allow_reuse 设置为 True

示例

>>> def g(x):
...   return hk.nets.ResNet50(1)(x, True)
>>> g = hk.transform_with_state(g)
>>> params_and_state_fn, updater = (
...   hk.lift_with_state(g.init, name='f_lift'))
>>> init_rng = hk.next_rng_key() if hk.running_init() else None
>>> x = jnp.ones([1, 224, 224, 3])
>>> params, state = params_and_state_fn(init_rng, x)
>>> out, state = g.apply(params, state, None, x)
>>> updater.update(state)
参数
  • init_fn (Callable[..., tuple[hk.Params, hk.State]]) – 来自 TransformedWithStateinit 函数。

  • allow_reuse (bool) – 允许从外部 transform_with_state() 重用提升的参数和状态。 当在控制流(例如 hk.scan)中使用 lift_with_state 时,这可能是理想的。

  • name (str) – 用于为参数添加前缀的字符串名称。

返回类型

tuple[Callable[…, tuple[hk.Params, hk.State]], LiftWithStateUpdater]

返回

一个可调用对象,在 init 期间将参数值注入到外部上下文中,并在 apply 期间重用外部上下文中的参数。 在这两种情况下,都返回要与 apply 函数一起使用的参数值。 init 函数还会返回一个对象,该对象用于在使用 apply 后更新外部上下文中的新状态。

另请参阅

transparent_lift#

haiku.transparent_lift(init_fn, *, allow_reuse=False)[source]#

在外部转换中注册参数,而不添加名称作用域。

在功能上,这等效于 lift(),但没有自动添加额外的变量作用域。 请注意,不允许从外部作用域关闭模块。

有关何时使用 lift 的更多上下文,请参阅 lift()

参数
  • init_fn (Callable[..., hk.Params]) – 来自 Transformedinit 函数。

  • allow_reuse (bool) – 允许从外部 transform_with_state() 重用提升的参数。 例如,当在控制流(例如 hk.scan)中时,这可能是理想的。

返回类型

Callable[…, hk.Params]

返回

一个可调用对象,在 init 期间将参数值注入到外部上下文中,并在 apply 期间重用外部上下文中的参数。 在这两种情况下,都返回要与 apply 函数一起使用的参数值。

另请参阅

transparent_lift_with_state#

haiku.transparent_lift_with_state(init_fn, *, allow_reuse=False)[source]#

在外部转换中注册参数和状态,而不添加作用域。

从功能上讲,这等同于 lift_with_state(),但不会自动添加额外的变量作用域。

有关何时使用 lift_with_state 的更多上下文,请参阅 lift_with_state()

参数
  • init_fn (Callable[..., tuple[hk.Params, hk.State]]) – 来自 TransformedWithStateinit 函数。

  • allow_reuse (bool) – 允许从外部 transform_with_state() 重用提升的参数和状态。例如,在控制流(例如 hk.scan)中,这可能是理想的。

返回类型

tuple[Callable[…, tuple[hk.Params, hk.State]], LiftWithStateUpdater]

返回

一个可调用对象,在 init 期间将参数值注入到外部上下文中,并在 apply 期间重用外部上下文中的参数。 在这两种情况下,都返回要与 apply 函数一起使用的参数值。 init 函数还会返回一个对象,该对象用于在使用 apply 后更新外部上下文中的新状态。

另请参阅

LiftWithStateUpdater#

class haiku.LiftWithStateUpdater(name)[source]#

处理更新 lift_with_state 计算的状态。

层叠 (Layer Stack)#

layer_stack(num_layers[, ...])

用于包装 Haiku 函数并将其递归应用于输入的实用工具。

LayerStackTransparencyMapping(*args, **kwargs)

用于透明 layer_stack 的模块名称映射。

layer_stack#

class haiku.layer_stack(num_layers, with_per_layer_inputs=False, unroll=1, pass_reverse_to_layer_fn=False, transparent=False, transparency_map=None, name=None)[source]#

用于包装 Haiku 函数并将其递归应用于输入的实用工具。

这可以用于改善模型编译时间。

如果一个函数仅使用显式位置参数,并且其返回类型与其输入类型匹配,则该函数是有效的。位置参数可以是任意嵌套的结构,叶节点上带有 jax.Array。请注意,不支持 kwargs,也不支持具有可变数量参数的函数(由 *args 指定)。

请注意,目前 layer_stack 不能与构建带有状态的 Haiku 模块的函数一起使用。

如果 with_per_layer_inputs=False,那么新的、包装后的函数可以理解为执行以下操作

>>> f = lambda x: x+1
>>> num_layers = 4
>>> x = 0
>>> for i in range(num_layers):
...   x = f(x)
>>> x
4

如果 with_per_layer_inputs=True,假设 f 除了 x 之外还需要两个参数

>>> f = lambda x, y0, y1: (x+1, y0+y1)
>>> num_layers = 4
>>> x = 0
>>> ys_0 = [1, 2, 3, 4]
>>> ys_1 = [5, 6, 7, 8]
>>> zs = []
>>> for i in range(num_layers):
...   x, z = f(x, ys_0[i], ys_1[i])
...   zs.append(z)
>>> x, zs
(4, [6, 8, 10, 12])

使用 layer_stack 用于上述函数的代码将是

>>> f = lambda x, y0, y1: (x+1, y0+y1)
>>> num_layers = 4
>>> x = 0
>>> ys_0 = jnp.array([1, 2, 3, 4])
>>> ys_1 = jnp.array([5, 6, 7, 8])
>>> stack = hk.layer_stack(num_layers, with_per_layer_inputs=True)
>>> x, zs = stack(f)(x, ys_0, ys_1)
>>> print(x, zs)
4 [ 6  8 10 12]

有关更多示例,请查看 layer_stack_test.py 中的测试。

至关重要的是,在 f 内部创建的任何参数都不会在迭代之间共享。

参数
  • num_layers (int) – 包装函数迭代的次数。

  • with_per_layer_inputs – 是否将每层输入传递给包装函数。

  • unroll (int) – scan 使用的展开 (unroll)。

  • pass_reverse_to_layer_fn (bool) – 是否将 reverse 关键字传递给函数 f,以便它知道层叠是否正在正向或反向运行(以及底层的 scan)。 要反向运行层叠,您需要将 reverse=True 传递给层叠的调用。

  • transparent (bool) – 是否透明地应用 layer_stack。 当此项为 True 且提供了正确的 transparency_map 时,将以这样一种方式生成参数:layer_stack 可以被常规 for 循环替换,而不会更改参数树。

  • transparency_map (Optional[LayerStackTransparencyMapping]) – 如何将堆叠的模块名称映射到扁平名称和反向映射。 有关示例,请参阅 LayerStackTransparencyMappinglayer_stack_test.py

  • name (Optional[str]) – Haiku 上下文的名称。

返回

当使用有效函数调用时,将生成层叠的可调用对象。

LayerStackTransparencyMapping#

class haiku.LayerStackTransparencyMapping(*args, **kwargs)[source]#

用于透明 layer_stack 的模块名称映射。

命名 (Naming)#

name_scope(name, *[, method_name])

上下文管理器,为所有新模块、参数或状态添加前缀。

current_name()

返回当前活动的模块名称。

DO_NOT_STORE

导致参数或状态值不被存储。

get_params()

返回当前 transform() 的参数。

get_current_state()

返回当前 transform_with_state() 的当前状态。

get_initial_state()

返回当前 transform_with_state() 的初始状态。

force_name(name)

强制 Haiku 使用此名称,忽略所有上下文信息。

name_like(method_name)

允许方法命名得像其他方法。

transparent(method)

装饰器,用于包装方法,防止自动变量作用域包装。

name_scope#

haiku.name_scope(name, *, method_name='__call__')[source]#

上下文管理器,为所有新模块、参数或状态添加前缀。

>>> with hk.name_scope("my_name_scope"):
...   net = hk.Linear(1, name="my_linear")
>>> net.module_name
'my_name_scope/my_linear'

在模块内部使用时,在名称作用域内创建的任何子模块、参数或状态都将在其名称中添加前缀

>>> class MyModule(hk.Module):
...   def __call__(self, x):
...     with hk.name_scope("my_name_scope"):
...       submodule = hk.Linear(1, name="submodule")
...       w = hk.get_parameter("w", [], init=jnp.ones)
...     return submodule(x) + w
>>> f = hk.transform(lambda x: MyModule()(x))
>>> params = f.init(jax.random.PRNGKey(42), jnp.ones([1, 1]))
>>> jax.tree_util.tree_map(jnp.shape, params)
{'my_module/my_name_scope': {'w': ()},
 'my_module/my_name_scope/submodule': {'b': (1,), 'w': (1, 1)}}

名称作用域非常类似于将上下文管理器内的所有代码放在 Module 上的一个方法中,并使用您提供的名称。在幕后,这正是名称作用域的实现方式。

如果您熟悉 TensorFlow,那么 Haiku 的 name_scope() 类似于 TensorFlow 1 中的 tf.variable_scope(..) 和 TensorFlow 1 和 2 中的 tf.name_scope(..),因为它可以更改与模块、参数和状态关联的名称。

参数
  • name (str) – 要使用的名称作用域(例如 "foo""foo/bar")。

  • method_name (str) – (仅限高级用法)。 由于名称作用域等效于调用模块上的方法,因此方法名称属性允许您指定要模拟的方法名称。 大多数用户应将其保留为默认值 (“__call__”)。

返回类型

ContextManager[None]

返回

单次使用的上下文管理器,当激活时,会使用给定名称为新的模块、参数或状态添加前缀。

current_name#

haiku.current_name()[source]#

返回当前活动的模块名称。

在 Haiku 模块外部(但在 Haiku 转换内部),这将返回 ~,它与 params/state 字典中存储顶层值的键匹配。

>>> hk.current_name()
'~'

在模块内部,这将返回当前模块名称

>>> class ExampleModule(hk.Module):
...   def __call__(self):
...     return hk.current_name()
>>> ExampleModule()()
'example_module'

在名称作用域内部,这将返回当前名称作用域

>>> with hk.name_scope('example_name_scope'):
...   print(hk.current_name())
example_name_scope
返回类型

str

返回

当前活动的模块或名称作用域名称。 如果正在使用模块或名称作用域,则返回 ~

DO_NOT_STORE#

haiku.DO_NOT_STORE = <haiku._src.base.DoNotStore object>#

导致参数或状态值不被存储。

默认情况下,Haiku 会将从 get_parameter()get_state()set_state() 返回的值放入 init 返回的字典中。 这并非总是理想的。

例如,用户可能希望其网络的一部分来自预训练的检查点,并且他们可能希望冻结这些值(即,让它们不出现在稍后传递给 grad 的 params 字典中)。 您可以通过操作 params 字典来实现此目的,但有时使用自定义创建器/getter/setter 会更方便。

考虑以下函数

>>> def f(x):
...   x = hk.Linear(300, name='torso')(x)
...   x = hk.Linear(10, name='tail')(x)
...   return x

假设您有一组躯干的预训练权重

>>> pretrained = {'torso': {'w': jnp.ones([28 * 28, 300]),
...                         'b': jnp.ones([300])}}

首先,我们定义一个创建器,它告诉 Haiku 不要存储任何属于预训练字典一部分的参数

>>> def my_creator(next_creator, shape, dtype, init, context):
...   if context.module_name in pretrained:
...     return hk.DO_NOT_STORE
...   return next_creator(shape, dtype, init)

然后我们需要一个 getter,它从预训练字典中提供参数值

>>> def my_getter(next_getter, value, context):
...   if context.module_name in pretrained:
...     assert value is hk.DO_NOT_STORE
...     value = pretrained[context.module_name][context.name]
...   return next_getter(value)

最后,我们将把我们的函数包装在上下文管理器中,激活我们的创建器和 getter

>>> def f_with_pretrained_torso(x):
...   with hk.custom_creator(my_creator), \
...        hk.custom_getter(my_getter):
...     return f(x)

您可以看到,当我们运行我们的函数时,我们只从不在预训练字典中的模块中获取参数

>>> f_with_pretrained_torso = hk.transform(f_with_pretrained_torso)
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([1, 28 * 28])
>>> params = f_with_pretrained_torso.init(rng, x)
>>> assert list(params) == ['tail']

此值可用于初始化器、custom_creator()custom_setter()

get_params#

haiku.get_params()[source]#

返回当前 transform() 的参数。

>>> def report(when):
...   shapes = jax.tree_util.tree_map(jnp.shape, hk.get_params())
...   print(f'{when}: {shapes}')
>>> def f(x):
...   report('Before call')
...   x = hk.Linear(1)(x)
...   report('After call')
...   return x
>>> f = hk.transform(f)
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([1, 1])

init 期间,参数字典将在调用模块时被填充

>>> params = f.init(rng, x)
Before call: {}
After call: {'linear': {'b': (1,), 'w': (1, 1)}}

apply 期间,参数字典将保持不变

>>> _ = f.apply(params, None, x)
Before call: {'linear': {'b': (1,), 'w': (1, 1)}}
After call: {'linear': {'b': (1,), 'w': (1, 1)}}

注意:不运行 custom_getters() 或参数初始化器。

返回类型

Params

返回

参数字典的副本。 在 init 期间,此字典将填充到目前为止已创建的任何参数。 在 apply 期间,这将包含所有模块的所有参数(params 字典在 apply 期间不会更改)。

另请参阅

get_current_state#

haiku.get_current_state()[source]#

返回当前 transform_with_state() 的当前状态。

示例

>>> def report(when):
...   state = jax.tree_util.tree_map(int, hk.get_current_state())
...   print(f'{when}: {state}')
>>> def f():
...   report('Before get_state')
...   x = hk.get_state('x', [], init=jnp.zeros)
...   report('After get_state')
...   hk.set_state('x', x + 1)
...   report('After set_state')
>>> f = hk.transform_with_state(f)

init 期间,将返回最近设置的值(直接通过 set_state() 设置,或通过 get_state()init 参数设置)

>>> _, state = f.init(None)
Before get_state: {}
After get_state: {'~': {'x': 0}}
After set_state: {'~': {'x': 1}}

apply 期间,将使用最近设置的值;如果未设置任何值,则将使用传递到 apply 的值

>>> state = {'~': {'x': 10}}
>>> _ = f.apply({}, state, None)
Before get_state: {'~': {'x': 10}}
After get_state: {'~': {'x': 10}}
After set_state: {'~': {'x': 11}}

注意:不运行 custom_getters() 或状态初始化器。

返回类型

State

返回

将从 initapply 返回的状态字典的副本。

另请参阅

get_initial_state#

haiku.get_initial_state()[source]#

返回当前 transform_with_state() 的初始状态。

示例

>>> def report(when):
...   state = jax.tree_util.tree_map(int, hk.get_initial_state())
...   print(f'{when}: {state}')
>>> def f():
...   report('Before get_state')
...   x = hk.get_state('x', [], init=jnp.zeros)
...   report('After get_state')
...   hk.set_state('x', x + 1)
...   report('After set_state')
>>> f = hk.transform_with_state(f)

init 期间,将返回第一个设置的值(直接通过 set_state() 设置,或通过 get_state()init 参数设置)

>>> _, state = f.init(None)
Before get_state: {}
After get_state: {'~': {'x': 0}}
After set_state: {'~': {'x': 0}}

apply 期间,将使用传递到 apply 函数的值

>>> state = {'~': {'x': 10}}
>>> _ = f.apply({}, state, None)
Before get_state: {'~': {'x': 10}}
After get_state: {'~': {'x': 10}}
After set_state: {'~': {'x': 10}}

注意:不运行 custom_getters() 或状态初始化器。

返回类型

State

返回

将从 init 返回或传递到 apply 的状态字典的副本。

另请参阅

force_name#

haiku.force_name(name)[source]#

强制 Haiku 使用此名称,忽略所有上下文信息。

注意:此方法仅适用于高级用例,应尽可能避免使用,因为它在设置绝对名称时有效地强制执行单例模式。

Haiku 根据模块的创建位置(例如,创建它们的模块堆栈或当前的 name_scope())命名模块。 此函数允许您创建忽略所有这些内容并具有您提供的精确名称的模块。

如果您有两个模块并且想要强制它们共享参数,这可能会很有用

>>> mod0 = hk.Linear(1)
>>> some_hyperparameter = True
>>> if some_hyperparameter:
...   # Force mod1 and mod0 to have shared weights.
...   mod1 = hk.Linear(1, name=hk.force_name(mod0.module_name))
... else:
...   # mod0 and mod1 are independent.
...   mod1 = hk.Linear(1)

(此代码段的更简单版本是执行 mod1 = mod0 而不是使用 force_name,但是在实际示例中,使用 force_name 可能更简单,尤其是在您可能无法访问模块实例而需要大量管道的情况下,但获取模块名称很容易 [例如,它是一个超参数])。

参数

name (str) – 模块的字符串名称。 例如 "foo""foo/bar"

返回类型

str

返回

适合传递到任何 Haiku 模块构造函数的 name 参数中的值。

name_like#

haiku.name_like(method_name)[source]#

允许方法命名得像其他方法。

在 Haiku 中,子模块根据其父模块的名称以及创建它们的方法进行命名。 当重构代码时,可能需要保留以前的名称以保持检查点兼容性,这可以使用 name_like() 来实现。

例如,考虑以下玩具自动编码器

>>> class Autoencoder(hk.Module):
...   def __call__(self, x):
...     z = hk.Linear(10, name="enc")(x)  # name: autoencoder/enc
...     y = hk.Linear(10, name="dec")(z)  # name: autoencoder/dec
...     return y

如果我们想重构它,以便用户可以编码或解码,我们将创建两个方法(encode,decode),它们将创建和应用我们的模块。 为了保持与原始模块的检查点兼容性,我们可以使用 name_like() 来命名这些子模块,就像它们是在 __call__ 内部创建的一样

>>> class Autoencoder(hk.Module):
...   @hk.name_like("__call__")
...   def encode(self, x):
...     return hk.Linear(10, name="enc")(x)  # name: autoencoder/enc
...
...   @hk.name_like("__call__")
...   def decode(self, z):
...     return hk.Linear(10, name="dec")(z)  # name: autoencoder/dec
...
...   def __call__(self, x):
...     return self.decode(self.encode(x))

一个明显的缺点是,如果用户依赖 Haiku 的编号来处理给出唯一名称并使用 name_like() 进行重构。 例如,当重构以下内容时

>>> class Autoencoder(hk.Module):
...   def __call__(self, x):
...     y = hk.Linear(10)(z)  # name: autoencoder/linear_1
...     z = hk.Linear(10)(x)  # name: autoencoder/linear
...     return y

要使用 name_like(),encode/decode 中未命名的线性模块最终将具有相同的名称(均为:autoencoder/linear),因为模块编号仅在方法内应用

>>> class Autoencoder(hk.Module):
...   @hk.name_like("__call__")
...   def encode(self, x):
...     return hk.Linear(10)(x)  # name: autoencoder/linear
...
...   @hk.name_like("__call__")
...   def decode(self, z):
...     return hk.Linear(10)(z)  # name: autoencoder/linear  <-- NOT INTENDED

要解决这种情况,您需要在方法中使用以前的名称显式命名模块

>>> class Autoencoder(hk.Module):
...   @hk.name_like("__call__")
...   def encode(self, x):
...     return hk.Linear(10, name="linear")(x)    # name: autoencoder/linear
...
...   @hk.name_like("__call__")
...   def decode(self, z):
...     return hk.Linear(10, name="linear_1")(z)  # name: autoencoder/linear_1
参数

method_name (str) – 我们应采用其名称的方法的名称。 此方法实际上不必在类上定义。

返回类型

Callable[[T], T]

返回

一个装饰器,当应用于方法时,将其标记为具有不同的名称。

transparent#

haiku.transparent(method)[source]#

装饰器,用于包装方法,防止自动变量作用域包装。

默认情况下,在一个方法中创建的所有变量和模块都由模块和方法名称限定作用域。 在某些情况下,这是不希望的。 任何使用 transparent() 修饰的方法将在调用它的作用域中创建变量和模块。

参数

method (T) – 要包装的方法。

返回类型

T

返回

该方法,带有一个标志,指示不应发生名称作用域包装。

可视化 (Visualisation)#

to_dot(fun)

将使用 Haiku 模块的函数转换为 dot 图。

to_dot#

haiku.to_dot(fun)[source]#

将使用 Haiku 模块的函数转换为 dot 图。

要在 Google Colab 或 iPython 笔记本中查看结果图,请使用 graphviz

dot = hk.to_dot(f)(x)
import graphviz
graphviz.Source(dot)
参数

fun (Callable[..., Any]) – 使用 Haiku 模块的函数。

返回类型

Callable[…, str]

返回

一个函数,它返回 graphviz 图的源代码字符串,该图描述给定函数执行的操作,并按 Haiku 模块进行聚类。

另请参阅

abstract_to_dot():使用抽象输入生成 graphviz 图。

常用模块 (Common Modules)#

Linear#

Linear(output_size[, with_bias, w_init, ...])

线性模块 (Linear module)。

Bias([output_size, bias_dims, b_init, name])

向输入添加偏置 (bias)。

Linear#

class haiku.Linear(output_size, with_bias=True, w_init=None, b_init=None, name=None)[source]#

线性模块 (Linear module)。

__init__(output_size, with_bias=True, w_init=None, b_init=None, name=None)[source]#

构造 Linear 模块。

参数
  • output_size (int) – 输出维度。

  • with_bias (bool) – 是否向输出添加偏置 (bias)。

  • w_init (Optional[hk.initializers.Initializer]) – 权重的可选初始化器。 默认情况下,使用截断正态分布的随机值,标准差为 1 / sqrt(fan_in)。 请参阅 https://arxiv.org/abs/1502.03167v3

  • b_init (Optional[hk.initializers.Initializer]) – 偏置 (bias) 的可选初始化器。 默认情况下为零。

  • name (Optional[str]) – 模块的名称。

__call__(inputs, *, precision=None)[source]#

计算输入的线性变换。

返回类型

jax.Array

Bias#

class haiku.Bias(output_size=None, bias_dims=None, b_init=None, name=None)[source]#

向输入添加偏置 (bias)。

用法示例

>>> N, H, W, C = 1, 2, 3, 4
>>> x = jnp.ones([N, H, W, C])
>>> scalar_bias = hk.Bias(bias_dims=[])
>>> scalar_bias_output = scalar_bias(x)
>>> assert scalar_bias.bias_shape == ()

在所有非小批量维度上创建偏置 (bias)

>>> all_bias = hk.Bias()
>>> all_bias_output = all_bias(x)
>>> assert all_bias.bias_shape == (H, W, C)

创建跨越最后一个非小批量维度的偏置

>>> last_bias = hk.Bias(bias_dims=[-1])
>>> last_bias_output = last_bias(x)
>>> assert last_bias.bias_shape == (C,)

创建跨越第一个非小批量维度的偏置

>>> first_bias = hk.Bias(bias_dims=[1])
>>> first_bias_output = first_bias(x)
>>> assert first_bias.bias_shape == (H, 1, 1)

减去并在之后加上相同的学习偏置

>>> bias = hk.Bias()
>>> h1 = bias(x, multiplier=-1)
>>> h2 = bias(x)
>>> h3 = bias(x, multiplier=-1)
>>> reconstructed_x = bias(h3)
>>> assert (x == reconstructed_x).all()
__init__(output_size=None, bias_dims=None, b_init=None, name=None)[source]#

构造一个支持广播的 Bias 模块。

参数
  • output_size (Optional[Sequence[int]]) – 输出大小(不包含批次维度的输出形状)。如果 output_size 保留为 None,则大小将由输入直接推断。

  • bias_dims (Optional[Sequence[int]]) – 构建偏置时,从输入形状中保留的维度序列。剩余维度将被广播(给定大小为 1),并且前导维度将被完全移除。有关示例,请参阅类文档。

  • b_init (Optional[hk.initializers.Initializer]) – 偏置的可选初始化器。默认为零。

  • name (Optional[str]) – 模块的名称。

__call__(inputs, multiplier=None)[source]#

将偏置添加到 inputs,并可选择乘以 multiplier

参数
  • inputs (jax.Array) – 大小为 [batch_size, input_size1, ...] 的张量。

  • multiplier (Optional[Union[float, jax.Array]]) – 标量或张量,偏置项在添加到 inputs 之前与之相乘。任何在表达式 bias * multiplier 中起作用的东西在这里都是可以接受的。如果您想在一个地方添加偏置,并在另一个地方通过 multiplier=-1 减去相同的偏置,这可能很有用。

返回类型

jax.Array

返回

大小为 [batch_size, input_size1, ...] 的张量。

池化#

avg_pool(value, window_shape, strides, padding)

平均池化。

AvgPool(window_shape, strides, padding[, ...])

平均池化。

max_pool(value, window_shape, strides, padding)

最大池化。

MaxPool(window_shape, strides, padding[, ...])

最大池化。

平均池化#

haiku.avg_pool(value, window_shape, strides, padding, channel_axis=- 1)[source]#

平均池化。

参数
  • value (jax.Array) – 要池化的值。

  • window_shape (Union[int, Sequence[int]]) – 池化窗口的形状,与值具有相同的秩。

  • strides (Union[int, Sequence[int]]) – 池化窗口的步幅,与值具有相同的秩。

  • padding (str) – 填充算法。可以是 VALIDSAME

  • channel_axis (Optional[int]) – 跳过池化的空间通道的轴。

返回类型

jax.Array

返回

池化结果。与值具有相同的秩。

Raises

ValueError – 如果填充无效。

class haiku.AvgPool(window_shape, strides, padding, channel_axis=- 1, name=None)[source]#

平均池化。

等效于 avg_pool() 的部分应用。

__init__(window_shape, strides, padding, channel_axis=- 1, name=None)[source]#

平均池化。

参数
  • window_shape (Union[int, Sequence[int]]) – 池化窗口的形状,与值具有相同的秩。

  • strides (Union[int, Sequence[int]]) – 池化窗口的步幅,与值具有相同的秩。

  • padding (str) – 填充算法。可以是 VALIDSAME

  • channel_axis (Optional[int]) – 跳过池化的空间通道的轴。

  • name (Optional[str]) – 模块的字符串名称。

__call__(value)[source]#

将自身作为函数调用。

返回类型

jax.Array

最大池化#

haiku.max_pool(value, window_shape, strides, padding, channel_axis=- 1)[source]#

最大池化。

参数
  • value (jax.Array) – 要池化的值。

  • window_shape (Union[int, Sequence[int]]) – 池化窗口的形状,与值具有相同的秩。

  • strides (Union[int, Sequence[int]]) – 池化窗口的步幅,与值具有相同的秩。

  • padding (str) – 填充算法。可以是 VALIDSAME

  • channel_axis (Optional[int]) – 跳过池化的空间通道的轴。

返回类型

jax.Array

返回

池化结果。与值具有相同的秩。

class haiku.MaxPool(window_shape, strides, padding, channel_axis=- 1, name=None)[source]#

最大池化。

等效于 max_pool() 的部分应用。

__init__(window_shape, strides, padding, channel_axis=- 1, name=None)[source]#

最大池化。

参数
  • window_shape (Union[int, Sequence[int]]) – 池化窗口的形状,与值具有相同的秩。

  • strides (Union[int, Sequence[int]]) – 池化窗口的步幅,与值具有相同的秩。

  • padding (str) – 填充算法。可以是 VALIDSAME

  • channel_axis (Optional[int]) – 跳过池化的空间通道的轴。

  • name (Optional[str]) – 模块的字符串名称。

__call__(value)[source]#

将自身作为函数调用。

返回类型

jax.Array

Dropout#

dropout(rng, rate, x[, broadcast_dims])

以给定的速率随机丢弃输入中的单元。

dropout#

haiku.dropout(rng, rate, x, broadcast_dims=())[source]#

以给定的速率随机丢弃输入中的单元。

参见:http://www.cs.toronto.edu/~hinton/absps/dropout.pdf

参数
  • rng (PRNGKey) – JAX 随机密钥。

  • rate (float) – x 的每个元素被丢弃的概率。必须是范围 [0, 1) 内的标量。

  • x (jax.Array) – 要进行 dropout 的值。

  • broadcast_dims (Sequence[int]) – 指定将共享相同 dropout mask 的维度。

返回类型

jax.Array

返回

x,但已 dropout 并按 1 / (1 - rate) 缩放。

注意

这涉及到生成 x.size 个来自 U([0, 1)) 的伪随机样本,这些样本以与 rate 比较所需的完整精度计算。当 rate 是 Python 浮点数时,这通常是 32 位,这通常超过了应用程序的需求。一种解决方法是以较低的精度传递 rate,例如使用 np.float16(rate)

组合器#

Sequential(layers[, name])

顺序调用给定的层列表。

Sequential#

class haiku.Sequential(layers, name=None)[source]#

顺序调用给定的层列表。

请注意,Sequential 在它可以处理的可能架构范围方面是有限的。这是一个故意的设计决策;Sequential 仅用于融合模块/操作的简单情况,其中特定模块/操作的输入是前一个模块/操作的输出。

另一个限制是,在 __call__() 方法中不可能有传递给模块组成部分的其他参数 - 例如,如果 Sequential 中有一个 BatchNorm 模块,并且用户希望切换 is_training 标志。如果这是期望的用例,建议的解决方案是子类化 Module 并实现 __call__

>>> class CustomModule(hk.Module):
...   def __call__(self, x, is_training):
...     x = hk.Conv2D(32, 4, 2)(x)
...     x = hk.BatchNorm(True, True, 0.9)(x, is_training)
...     x = jax.nn.relu(x)
...     return x
__init__(layers, name=None)[source]#

使用给定名称初始化当前模块。

子类应在创建其他模块或变量之前调用此构造函数,以便正确命名这些模块。

参数

name (Optional[str]) – 类的可选字符串名称。必须是有效的 Python 标识符。如果未提供 name,则当前实例的类名将转换为 lower_snake_case 并改为使用。

__call__(inputs, *args, **kwargs)[source]#

顺序调用所有层。

卷积#

ConvND(num_spatial_dims, output_channels, ...)

通用 N 维卷积。

Conv1D(output_channels, kernel_shape[, ...])

一维卷积。

Conv2D(output_channels, kernel_shape[, ...])

二维卷积。

Conv3D(output_channels, kernel_shape[, ...])

三维卷积。

ConvNDTranspose(num_spatial_dims, ...[, ...])

通用 n 维转置卷积(又名反卷积)。

Conv1DTranspose(output_channels, kernel_shape)

一维转置卷积(又名反卷积)。

Conv2DTranspose(output_channels, kernel_shape)

二维转置卷积(又名反卷积)。

Conv3DTranspose(output_channels, kernel_shape)

三维转置卷积(又名反卷积)。

DepthwiseConv1D(channel_multiplier, kernel_shape)

一维卷积。

DepthwiseConv2D(channel_multiplier, kernel_shape)

二维卷积。

DepthwiseConv3D(channel_multiplier, kernel_shape)

三维卷积。

get_channel_index(data_format)

给定有效的数据格式时,返回通道索引。

ConvND#

class haiku.ConvND(num_spatial_dims, output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, feature_group_count=1, name=None)[source]#

通用 N 维卷积。

__init__(num_spatial_dims, output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, feature_group_count=1, name=None)[source]#

初始化模块。

参数
  • num_spatial_dims (int) – 输入的空间维度数。

  • output_channels (int) – 输出通道数。

  • kernel_shape (Union[int, Sequence[int]]) – 卷积核的形状。可以是整数或长度为 num_spatial_dims 的序列。

  • stride (Union[int, Sequence[int]]) – 卷积核的可选步幅。可以是整数或长度为 num_spatial_dims 的序列。默认为 1。

  • rate (Union[int, Sequence[int]]) – 可选的卷积核膨胀率。可以是整数或长度为 num_spatial_dims 的序列。1 对应于标准 ND 卷积,rate > 1 对应于膨胀卷积。默认为 1。

  • padding (Union[str, Sequence[tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – 可选的填充算法。可以是 VALIDSAME,或者是 n 个 (low, high) 整数对的序列,这些整数对给出了在每个空间维度之前和之后应用的填充。或者是一个可调用对象或大小为 num_spatial_dims 的可调用对象序列。任何可调用对象都必须接受一个等于有效卷积核大小的整数参数,并返回一个包含两个整数的序列,表示之前和之后的填充。有关更多详细信息和示例函数,请参阅 haiku.pad.*。默认为 SAME。参见:https://tensorflowcn.cn/xla/operation_semantics#conv_convolution

  • with_bias (bool) – 是否添加偏置。默认情况下为 true。

  • w_init (Optional[hk.initializers.Initializer]) – 可选的权重初始化。默认情况下为截断正态分布。

  • b_init (Optional[hk.initializers.Initializer]) – 可选的偏置初始化。默认情况下为零。

  • data_format (str) – 输入的数据格式。可以是 channels_firstchannels_lastN...CNC...。默认情况下为 channels_last。参见 get_channel_index()

  • mask (Optional[jax.Array]) – 可选的权重掩码。

  • feature_group_count (int) – 分组卷积中的可选组数。默认值 1 对应于正常的密集卷积。如果使用更高的值,则卷积将分别应用于那么多组,然后堆叠在一起。这减少了给定 output_channels 的参数数量,并可能减少计算量。参见:https://tensorflowcn.cn/xla/operation_semantics#conv_convolution

  • name (Optional[str]) – 模块的名称。

__call__(inputs, *, precision=None)[source]#

连接 ConvND 层。

参数
  • inputs (jax.Array) – 如果未批处理,则形状为 [spatial_dims, C] 且秩为 N+1 的数组;如果已批处理,则形状为 [N, spatial_dims, C] 且秩为 N+2 的数组。

  • precision (Optional[lax.Precision]) – 可选的 jax.lax.Precision,用于传递给 jax.lax.conv_general_dilated()

返回类型

jax.Array

返回

如果未批处理,则形状为 [spatial_dims, output_channels] 且秩为 N+1 的数组

如果已批处理,则形状为 [N, spatial_dims, output_channels] 且秩为 N+2 的数组。

Conv1D#

class haiku.Conv1D(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, feature_group_count=1, name=None)[source]#

一维卷积。

__init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, feature_group_count=1, name=None)[source]#

初始化模块。

参数
  • output_channels (int) – 输出通道数。

  • kernel_shape (Union[int, Sequence[int]]) – 卷积核的形状。可以是整数或长度为 1 的序列。

  • stride (Union[int, Sequence[int]]) – 卷积核的可选步幅。可以是整数或长度为 1 的序列。默认为 1。

  • rate (Union[int, Sequence[int]]) – 可选的卷积核膨胀率。可以是整数或长度为 1 的序列。1 对应于标准 ND 卷积,rate > 1 对应于膨胀卷积。默认为 1。

  • padding (Union[str, Sequence[tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – 可选的填充算法。可以是 VALIDSAME,或者长度为 1 的可调用对象或可调用对象序列。任何可调用对象都必须接受一个等于有效卷积核大小的整数参数,并返回一个包含两个整数的列表,表示之前和之后的填充。有关更多详细信息和示例函数,请参阅 haiku.pad.*。默认为 SAME。参见:https://tensorflowcn.cn/xla/operation_semantics#conv_convolution

  • with_bias (bool) – 是否添加偏置。默认情况下为 true。

  • w_init (Optional[hk.initializers.Initializer]) – 可选的权重初始化。默认情况下为截断正态分布。

  • b_init (Optional[hk.initializers.Initializer]) – 可选的偏置初始化。默认情况下为零。

  • data_format (str) – 输入的数据格式。可以是 NWCNCW。默认情况下为 NWC

  • mask (Optional[jax.Array]) – 可选的权重掩码。

  • feature_group_count (int) – 分组卷积中的可选组数。默认值 1 对应于正常的密集卷积。如果使用更高的值,则卷积将分别应用于那么多组,然后堆叠在一起。这减少了给定 output_channels 的参数数量,并可能减少计算量。参见:https://tensorflowcn.cn/xla/operation_semantics#conv_convolution

  • name (Optional[str]) – 模块的名称。

Conv2D#

class haiku.Conv2D(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, feature_group_count=1, name=None)[source]#

二维卷积。

__init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, feature_group_count=1, name=None)[source]#

初始化模块。

参数
  • output_channels (int) – 输出通道数。

  • kernel_shape (Union[int, Sequence[int]]) – 卷积核的形状。可以是整数或长度为 2 的序列。

  • stride (Union[int, Sequence[int]]) – 卷积核的可选步幅。可以是整数或长度为 2 的序列。默认为 1。

  • rate (Union[int, Sequence[int]]) – 可选的卷积核空洞率。可以是整数或长度为 2 的序列。1 对应于标准 ND 卷积,rate > 1 对应于空洞卷积。默认为 1。

  • padding (Union[str, Sequence[tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – 可选的填充算法。可以是 VALIDSAME,也可以是可调用对象或长度为 2 的可调用对象序列。任何可调用对象都必须接受一个等于有效卷积核大小的整数参数,并返回一个包含两个整数的列表,分别表示之前和之后的填充。有关更多详细信息和示例函数,请参阅 haiku.pad.*。默认为 SAME。参见: https://tensorflowcn.cn/xla/operation_semantics#conv_convolution

  • with_bias (bool) – 是否添加偏置。默认情况下为 true。

  • w_init (Optional[hk.initializers.Initializer]) – 可选的权重初始化。默认情况下为截断正态分布。

  • b_init (Optional[hk.initializers.Initializer]) – 可选的偏置初始化。默认情况下为零。

  • data_format (str) – 输入数据格式。可以是 NHWCNCHW。默认情况下为 NHWC

  • mask (Optional[jax.Array]) – 可选的权重掩码。

  • feature_group_count (int) – 分组卷积中的可选组数。默认值 1 对应于正常的密集卷积。如果使用更高的值,则卷积将分别应用于那么多组,然后堆叠在一起。这减少了给定 output_channels 的参数数量,并可能减少计算量。参见:https://tensorflowcn.cn/xla/operation_semantics#conv_convolution

  • name (Optional[str]) – 模块的名称。

Conv3D#

class haiku.Conv3D(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, feature_group_count=1, name=None)[source]#

三维卷积。

__init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, feature_group_count=1, name=None)[source]#

初始化模块。

参数
  • output_channels (int) – 输出通道数。

  • kernel_shape (Union[int, Sequence[int]]) – 卷积核的形状。可以是整数或长度为 3 的序列。

  • stride (Union[int, Sequence[int]]) – 卷积核的可选步幅。可以是整数或长度为 3 的序列。默认为 1。

  • rate (Union[int, Sequence[int]]) – 可选的卷积核空洞率。可以是整数或长度为 3 的序列。1 对应于标准 ND 卷积,rate > 1 对应于空洞卷积。默认为 1。

  • padding (Union[str, Sequence[tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – 可选的填充算法。可以是 VALIDSAME,也可以是可调用对象或长度为 3 的可调用对象序列。任何可调用对象都必须接受一个等于有效卷积核大小的整数参数,并返回一个包含两个整数的列表,分别表示之前和之后的填充。有关更多详细信息和示例函数,请参阅 haiku.pad.*。默认为 SAME。参见: https://tensorflowcn.cn/xla/operation_semantics#conv_convolution

  • with_bias (bool) – 是否添加偏置。默认情况下为 true。

  • w_init (Optional[hk.initializers.Initializer]) – 可选的权重初始化。默认情况下为截断正态分布。

  • b_init (Optional[hk.initializers.Initializer]) – 可选的偏置初始化。默认情况下为零。

  • data_format (str) – 输入数据格式。可以是 NDHWCNCDHW。默认情况下为 NDHWC

  • mask (Optional[jax.Array]) – 可选的权重掩码。

  • feature_group_count (int) – 分组卷积中的可选组数。默认值 1 对应于正常的密集卷积。如果使用更高的值,则卷积将分别应用于那么多组,然后堆叠在一起。这减少了给定 output_channels 的参数数量,并可能减少计算量。参见:https://tensorflowcn.cn/xla/operation_semantics#conv_convolution

  • name (Optional[str]) – 模块的名称。

ConvNDTranspose#

class haiku.ConvNDTranspose(num_spatial_dims, output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, name=None)[source]#

通用的 n 维转置卷积(又称反卷积)。

__init__(num_spatial_dims, output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, name=None)[source]#

初始化模块。

参数
  • num_spatial_dims (int) – 输入的空间维度数。

  • output_channels (int) – 输出通道数。

  • kernel_shape (Union[int, Sequence[int]]) – 卷积核的形状。可以是整数或长度为 num_spatial_dims 的序列。

  • stride (Union[int, Sequence[int]]) – 卷积核的可选步幅。可以是整数或长度为 num_spatial_dims 的序列。默认为 1。

  • output_shape (Optional[Union[int, Sequence[int]]]) – 转置卷积的空间维度的输出形状。可以是整数或整数的可迭代对象。如果给定 None 值,则会自动计算默认形状。

  • padding (Union[str, Sequence[tuple[int, int]]]) – 可选的填充算法。可以是 “VALID” 或 “SAME”。默认为 “SAME”。参见: https://tensorflowcn.cn/xla/operation_semantics#conv_convolution

  • with_bias (bool) – 是否添加偏置。默认情况下为 true。

  • w_init (Optional[hk.initializers.Initializer]) – 可选的权重初始化。默认情况下为截断正态分布。

  • b_init (Optional[hk.initializers.Initializer]) – 可选的偏置初始化。默认情况下为零。

  • data_format (str) – 输入数据格式。可以是 channels_first, channels_last, N...CNC...。默认情况下为 channels_last

  • mask (Optional[jax.Array]) – 可选的权重掩码。

  • name (Optional[str]) – 模块的名称。

__call__(inputs, *, precision=None)[source]#

计算输入的转置卷积。

参数
  • inputs (jax.Array) – 如果未批处理,则形状为 [spatial_dims, C] 且秩为 N+1 的数组;如果已批处理,则形状为 [N, spatial_dims, C] 且秩为 N+2 的数组。

  • precision (Optional[lax.Precision]) – 可选的 jax.lax.Precision 传递给 jax.lax.conv_transpose()

返回类型

jax.Array

返回

如果未批处理,则形状为 [spatial_dims, output_channels] 且秩为 N+1 的数组

如果已批处理,则形状为 [N, spatial_dims, output_channels] 且秩为 N+2 的数组。

Conv1DTranspose#

class haiku.Conv1DTranspose(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, name=None)[source]#

一维转置卷积(又称反卷积)。

__init__(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, name=None)[source]#

初始化模块。

参数
  • output_channels (int) – 输出通道数。

  • kernel_shape (Union[int, Sequence[int]]) – 卷积核的形状。可以是整数或长度为 1 的序列。

  • stride (Union[int, Sequence[int]]) – 卷积核的可选步幅。可以是整数或长度为 1 的序列。默认为 1。

  • output_shape (Optional[Union[int, Sequence[int]]]) – 转置卷积的空间维度的输出形状。可以是整数或整数的可迭代对象。如果给定 None 值,则会自动计算默认形状。

  • padding (Union[str, Sequence[tuple[int, int]]]) – 可选的填充算法。可以是 VALIDSAME。默认为 SAME。参见: https://tensorflowcn.cn/xla/operation_semantics#conv_convolution

  • with_bias (bool) – 是否添加偏置。默认情况下为 true。

  • w_init (Optional[hk.initializers.Initializer]) – 可选的权重初始化。默认情况下为截断正态分布。

  • b_init (Optional[hk.initializers.Initializer]) – 可选的偏置初始化。默认情况下为零。

  • data_format (str) – 输入的数据格式。可以是 NWCNCW。默认情况下为 NWC

  • mask (Optional[jax.Array]) – 可选的权重掩码。

  • name (Optional[str]) – 模块的名称。

Conv2DTranspose#

class haiku.Conv2DTranspose(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, name=None)[source]#

二维转置卷积(又称反卷积)。

__init__(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, name=None)[source]#

初始化模块。

参数
  • output_channels (int) – 输出通道数。

  • kernel_shape (Union[int, Sequence[int]]) – 卷积核的形状。可以是整数或长度为 2 的序列。

  • stride (Union[int, Sequence[int]]) – 卷积核的可选步幅。可以是整数或长度为 2 的序列。默认为 1。

  • output_shape (Optional[Union[int, Sequence[int]]]) – 转置卷积的空间维度的输出形状。可以是整数或整数的可迭代对象。如果给定 None 值,则会自动计算默认形状。

  • padding (Union[str, Sequence[tuple[int, int]]]) – 可选的填充算法。可以是 VALIDSAME。默认为 SAME。参见: https://tensorflowcn.cn/xla/operation_semantics#conv_convolution

  • with_bias (bool) – 是否添加偏置。默认情况下为 true。

  • w_init (Optional[hk.initializers.Initializer]) – 可选的权重初始化。默认情况下为截断正态分布。

  • b_init (Optional[hk.initializers.Initializer]) – 可选的偏置初始化。默认情况下为零。

  • data_format (str) – 输入数据格式。可以是 NHWCNCHW。默认情况下为 NHWC

  • mask (Optional[jax.Array]) – 可选的权重掩码。

  • name (Optional[str]) – 模块的名称。

Conv3DTranspose#

class haiku.Conv3DTranspose(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, name=None)[source]#

三维转置卷积(又称反卷积)。

__init__(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, name=None)[source]#

初始化模块。

参数
  • output_channels (int) – 输出通道数。

  • kernel_shape (Union[int, Sequence[int]]) – 卷积核的形状。可以是整数或长度为 3 的序列。

  • stride (Union[int, Sequence[int]]) – 卷积核的可选步幅。可以是整数或长度为 3 的序列。默认为 1。

  • output_shape (Optional[Union[int, Sequence[int]]]) – 转置卷积的空间维度的输出形状。可以是整数或整数的可迭代对象。如果给定 None 值,则会自动计算默认形状。

  • padding (Union[str, Sequence[tuple[int, int]]]) – 可选的填充算法。可以是 VALIDSAME。默认为 SAME。参见: https://tensorflowcn.cn/xla/operation_semantics#conv_convolution

  • with_bias (bool) – 是否添加偏置。默认情况下为 true。

  • w_init (Optional[hk.initializers.Initializer]) – 可选的权重初始化。默认情况下为截断正态分布。

  • b_init (Optional[hk.initializers.Initializer]) – 可选的偏置初始化。默认情况下为零。

  • data_format (str) – 输入数据格式。可以是 NDHWCNCDHW。默认情况下为 NDHWC

  • mask (Optional[jax.Array]) – 可选的权重掩码。

  • name (Optional[str]) – 模块的名称。

DepthwiseConv1D#

class haiku.DepthwiseConv1D(channel_multiplier, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', name=None)[source]#

一维卷积。

__init__(channel_multiplier, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', name=None)[source]#

构建一个 1D 深度可分离卷积。

参数
  • channel_multiplier (int) – 输出通道的倍数。要保持输出通道数与输入通道数相同,请设置为 1。

  • kernel_shape (Union[int, Sequence[int]]) – 卷积核的形状。可以是整数或长度为 1 的序列。

  • stride (Union[int, Sequence[int]]) – 卷积核的可选步幅。可以是整数或长度为 1 的序列。默认为 1。

  • rate (Union[int, Sequence[int]]) – 可选的卷积核膨胀率。可以是整数或长度为 1 的序列。1 对应于标准 ND 卷积,rate > 1 对应于膨胀卷积。默认为 1。

  • padding (Union[str, Sequence[tuple[int, int]]]) – 可选的填充算法。可以是 VALID, SAMEbefore, after 对的序列。默认为 SAME。参见: https://tensorflowcn.cn/xla/operation_semantics#conv_convolution

  • with_bias (bool) – 是否添加偏置。默认情况下为 true。

  • w_init (Optional[hk.initializers.Initializer]) – 可选的权重初始化。默认情况下为截断正态分布。

  • b_init (Optional[hk.initializers.Initializer]) – 可选的偏置初始化。默认情况下为零。

  • data_format (str) – 输入的数据格式。可以是 channels_firstchannels_lastN...CNC...。默认情况下为 channels_last。参见 get_channel_index()

  • name (Optional[str]) – 模块的名称。

DepthwiseConv2D#

class haiku.DepthwiseConv2D(channel_multiplier, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]#

二维卷积。

__init__(channel_multiplier, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]#

构建一个 2D 深度可分离卷积。

参数
  • channel_multiplier (int) – 输出通道的倍数。要保持输出通道数与输入通道数相同,请设置为 1。

  • kernel_shape (Union[int, Sequence[int]]) – 卷积核的形状。可以是整数或长度为 2 的序列。

  • stride (Union[int, Sequence[int]]) – 卷积核的可选步幅。可以是整数或长度为 2 的序列。默认为 1。

  • rate (Union[int, Sequence[int]]) – 可选的卷积核膨胀率。可以是整数或长度为 1 的序列。1 对应于标准 ND 卷积,rate > 1 对应于膨胀卷积。默认为 1。

  • padding (Union[str, Sequence[tuple[int, int]]]) – 可选的填充算法。可以是 VALID, SAMEbefore, after 对的序列。默认为 SAME。参见: https://tensorflowcn.cn/xla/operation_semantics#conv_convolution

  • with_bias (bool) – 是否添加偏置。默认情况下为 true。

  • w_init (Optional[hk.initializers.Initializer]) – 可选的权重初始化。默认情况下为截断正态分布。

  • b_init (Optional[hk.initializers.Initializer]) – 可选的偏置初始化。默认情况下为零。

  • data_format (str) – 输入的数据格式。可以是 channels_firstchannels_lastN...CNC...。默认情况下为 channels_last。参见 get_channel_index()

  • name (Optional[str]) – 模块的名称。

DepthwiseConv3D#

class haiku.DepthwiseConv3D(channel_multiplier, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', name=None)[source]#

三维卷积。

__init__(channel_multiplier, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', name=None)[source]#

构建一个 3D 深度可分离卷积。

参数
  • channel_multiplier (int) – 输出通道的倍数。要保持输出通道数与输入通道数相同,请设置为 1。

  • kernel_shape (Union[int, Sequence[int]]) – 卷积核的形状。可以是整数或长度为 3 的序列。

  • stride (Union[int, Sequence[int]]) – 卷积核的可选步幅。可以是整数或长度为 3 的序列。默认为 1。

  • rate (Union[int, Sequence[int]]) – 可选的卷积核膨胀率。可以是整数或长度为 1 的序列。1 对应于标准 ND 卷积,rate > 1 对应于膨胀卷积。默认为 1。

  • padding (Union[str, Sequence[tuple[int, int]]]) – 可选的填充算法。可以是 VALID, SAMEbefore, after 对的序列。默认为 SAME。参见: https://tensorflowcn.cn/xla/operation_semantics#conv_convolution

  • with_bias (bool) – 是否添加偏置。默认情况下为 true。

  • w_init (Optional[hk.initializers.Initializer]) – 可选的权重初始化。默认情况下为截断正态分布。

  • b_init (Optional[hk.initializers.Initializer]) – 可选的偏置初始化。默认情况下为零。

  • data_format (str) – 输入的数据格式。可以是 channels_firstchannels_lastN...CNC...。默认情况下为 channels_last。参见 get_channel_index()

  • name (Optional[str]) – 模块的名称。

SeparableDepthwiseConv2D#

class haiku.SeparableDepthwiseConv2D(channel_multiplier, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]#

可分离的 2-D 深度可分离卷积模块。

__init__(channel_multiplier, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]#

构建一个可分离的 2D 深度可分离卷积模块。

参数
  • channel_multiplier (int) – 输出通道的倍数。要保持输出通道数与输入通道数相同,请设置为 1。

  • kernel_shape (Union[int, Sequence[int]]) – 卷积核的形状。可以是整数或长度为 num_spatial_dims 的序列。

  • stride (Union[int, Sequence[int]]) – 卷积核的可选步幅。可以是整数或长度为 num_spatial_dims 的序列。默认为 1。

  • padding (Union[str, Sequence[tuple[int, int]]]) – 可选的填充算法。可以是 VALID, SAMEbefore, after 对的序列。默认为 SAME。参见: https://tensorflowcn.cn/xla/operation_semantics#conv_convolution

  • with_bias (bool) – 是否添加偏置。默认情况下为 true。

  • w_init (Optional[hk.initializers.Initializer]) – 可选的权重初始化。默认情况下为截断正态分布。

  • b_init (Optional[hk.initializers.Initializer]) – 可选的偏置初始化。默认情况下为零。

  • data_format (str) – 输入数据格式。可以是 channels_first, channels_last, N...CNC...。默认情况下为 channels_last

  • name (Optional[str]) – 模块的名称。

__call__(inputs)[source]#

将自身作为函数调用。

返回类型

jax.Array

get_channel_index#

haiku.get_channel_index(data_format)[source]#

给定有效的数据格式时,返回通道索引。

>>> hk.get_channel_index('channels_last')
-1
>>> hk.get_channel_index('channels_first')
1
>>> hk.get_channel_index('N...C')
-1
>>> hk.get_channel_index('NCHW')
1
参数

data_format (str) – 字符串,用于从中获取通道索引的数据格式。有效的数据格式为空间 (例如,``NCHW``)、序列 (例如, BTHWD)、 channels_firstchannels_last)。

返回类型

int

返回

通道索引,为整数,可以是 1-1

Raises

ValueError – 如果数据格式无法识别。

Normalization#

BatchNorm(create_scale, create_offset, ...)

对输入进行归一化,以保持均值约为 ~0,标准差约为 ~1。

GroupNorm(groups[, axis, create_scale, ...])

组归一化模块。

InstanceNorm(create_scale, create_offset[, ...])

沿着空间维度对输入进行归一化。

LayerNorm(axis, create_scale, create_offset)

层归一化模块。

RMSNorm(axis[, eps, scale_init, name, ...])

RMSNorm 模块。

SpectralNorm([eps, n_steps, name])

通过其第一个奇异值对输入进行归一化。

ExponentialMovingAverage(decay[, ...])

维护指数移动平均值。

SNParamsTree([eps, n_steps, ignore_regex, name])

将谱归一化应用于树中的所有参数。

EMAParamsTree(decay[, zero_debias, ...])

为树中的所有参数维护指数移动平均值。

BatchNorm#

class haiku.BatchNorm(create_scale, create_offset, decay_rate, eps=1e-05, scale_init=None, offset_init=None, axis=None, cross_replica_axis=None, cross_replica_axis_index_groups=None, data_format='channels_last', name=None)[source]#

对输入进行归一化,以保持均值约为 ~0,标准差约为 ~1。

参见: https://arxiv.org/abs/1502.03167

用户在需要时,管理 scale 和 offset 的方式有很多不同的变体。这些变体包括:

  • 没有 scale/offset,在这种情况下,create_* 应设置为 False,并且在调用模块时不会传递 scale/offset

  • 可训练的 scale/offset,在这种情况下,create_* 应设置为 True,并且在调用模块时同样不会传递 scale/offset。在这种情况下,此模块创建并拥有 scale/offset 变量。

  • 外部生成的 scale/offset,例如用于条件归一化,在这种情况下,create_* 应设置为 False,然后在调用时传入这些值。

注意:jax.vmap(hk.transform(BatchNorm)) 将更新摘要统计信息并按批次归一化值;我们目前支持跨 vmap 引入的批次轴进行归一化。

__init__(create_scale, create_offset, decay_rate, eps=1e-05, scale_init=None, offset_init=None, axis=None, cross_replica_axis=None, cross_replica_axis_index_groups=None, data_format='channels_last', name=None)[source]#

构建一个 BatchNorm 模块。

参数
  • create_scale (bool) – 是否包含可训练的缩放因子。

  • create_offset (bool) – 是否包含可训练的偏移。

  • decay_rate (float) – EMA 的衰减率。

  • eps (float) – 用于避免除以零方差的小 epsilon 值。默认值为 1e-5,与论文和 Sonnet 中的值相同。

  • scale_init (Optional[hk.initializers.Initializer]) – gain(又名 scale)的可选初始化器。仅当 create_scale=True 时才能设置。默认值为 1

  • offset_init (Optional[hk.initializers.Initializer]) – bias(又名 offset)的可选初始化器。仅当 create_offset=True 时才能设置。默认值为 0

  • axis (Optional[Sequence[int]]) – 要在其上进行缩减的轴。默认值 (None) 表示应归一化除通道轴之外的所有轴。否则,这是将计算归一化统计信息的轴索引列表。

  • cross_replica_axis (Optional[Union[str, Sequence[str]]]) – 如果不为 None,则应为字符串(或字符串序列),表示在此模块中在 jax map(例如 jax.pmapjax.vmap)中运行的轴名称。提供此参数意味着跨命名轴上的所有副本计算批次统计信息。

  • cross_replica_axis_index_groups (Optional[Sequence[Sequence[int]]]) – 指定如何对设备进行分组。仅在 jax.pmap 集体通信中有效。

  • data_format (str) – 输入的数据格式。可以是 channels_firstchannels_lastN...CNC...。默认值为 channels_last。请参阅 get_channel_index()

  • name (Optional[str]) – 模块名称。

__call__(inputs, is_training, test_local_stats=False, scale=None, offset=None)[source]#

计算输入的归一化版本。

参数
  • inputs (jax.Array) – 数组,其中数据格式为 [..., C]

  • is_training (bool) – 是否处于训练期间。

  • test_local_stats (bool) – 当 is_training=False 时是否使用局部统计信息。

  • scale (Optional[jax.Array]) – 最多 n-D 的数组。此张量的形状必须可广播到 inputs 的形状。这是应用于归一化输入的 scale。如果模块构造时 create_scale=True,则无法传入此参数。

  • offset (Optional[jax.Array]) – 最多 n-D 的数组。此张量的形状必须可广播到 inputs 的形状。这是应用于归一化输入的 offset。如果模块构造时 create_offset=True,则无法传入此参数。

返回类型

jax.Array

返回

数组,在除最后一个维度之外的所有维度上进行归一化。

GroupNorm#

class haiku.GroupNorm(groups, axis=slice(1, None, None), create_scale=True, create_offset=True, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]#

组归一化模块。

这会将组归一化应用于 x。这涉及在计算均值和方差之前将通道拆分为组。默认行为是计算空间维度和分组通道上的均值和方差。永远不会在创建的组轴上计算均值和方差。

它将输入 x 转换为

\[\d{outputs} = \d{scale} \dfrac{x - \mu}{\sigma + \epsilon} + \d{offset}\]

其中 \(\mu\)\(\sigma\) 分别是 x 的均值和标准差。

用户在需要时,管理 scale 和 offset 的方式有很多不同的变体。这些变体包括:

  • 没有 scale/offset,在这种情况下,create_* 应设置为 False,并且在调用模块时不会传递 scale/offset

  • 可训练的 scale/offset,在这种情况下,create_* 应设置为 True,并且在调用模块时同样不会传递 scale/offset。在这种情况下,此模块创建并拥有 scale/offset 参数。

  • 外部生成的 scale/offset,例如用于条件归一化,在这种情况下,create_* 应设置为 False,然后在调用时传入这些值。

__init__(groups, axis=slice(1, None, None), create_scale=True, create_offset=True, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]#

构建一个 GroupNorm 模块。

参数
  • groups (int) – 用于划分通道的组数。通道数必须可被此数整除。

  • axis (Union[int, slice, Sequence[int]]) – intslice 或表示应在其上进行归一化的轴的整数序列。默认情况下,这是除第一个维度之外的所有维度。对于时间序列数据,请使用 slice(2, None) 对无批次和时间数据求平均值。

  • create_scale (bool) – 是否创建在归一化后应用的每个通道的可训练 scale。

  • create_offset (bool) – 是否创建在归一化和缩放后应用的每个通道的可训练 offset。

  • eps (float) – 添加到方差以避免除以零的小 epsilon 值。默认为 1e-5

  • scale_init (Optional[hk.initializers.Initializer]) – scale 参数的可选初始化器。仅当 create_scale=True 时才能设置。默认情况下,scale 初始化为 1

  • offset_init (Optional[hk.initializers.Initializer]) – offset 参数的可选初始化器。仅当 create_offset=True 时才能设置。默认情况下,offset 初始化为 0

  • data_format (str) – 输入的数据格式。可以是 channels_firstchannels_lastN...CNC...。默认值为 channels_last。请参阅 get_channel_index()

  • name (Optional[str]) – 模块的名称。

__call__(x, scale=None, offset=None)[source]#

返回归一化的输入。

参数
  • x (jax.Array) – 构造函数中指定的 data_format 的 n-D 张量,将在其上执行变换。

  • scale (Optional[jax.Array]) – 最多 n-D 的张量。此张量的形状必须可广播到 x 的形状。这是应用于归一化 x 的 scale。如果模块构造时 create_scale=True,则无法传入此参数。

  • offset (Optional[jax.Array]) – 最多 n-D 的张量。此张量的形状必须可广播到 x 的形状。这是应用于归一化 x 的 offset。如果模块构造时 create_offset=True,则无法传入此参数。

返回类型

jax.Array

返回

与 x 形状相同的 n 维张量,已归一化。

InstanceNorm#

class haiku.InstanceNorm(create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]#

沿着空间维度对输入进行归一化。

有关更多详细信息,请参见 LayerNorm

__init__(create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]#

构建一个 InstanceNorm 模块。

此方法创建一个模块,该模块在空间维度上进行归一化。

参数
  • create_scale (bool) – bool,表示是否创建在归一化后应用的每个通道的可训练 scale。

  • create_offset (bool) – bool,表示是否创建在归一化和缩放后应用的每个通道的可训练 offset。

  • eps (float) – 用于避免除以零方差的小 epsilon 值。默认为 1e-5

  • scale_init (Optional[hk.initializers.Initializer]) – scale 变量的可选初始化器。仅当 create_scale=True 时才能设置。默认情况下,scale 初始化为 1

  • offset_init (Optional[hk.initializers.Initializer]) – offset 变量的可选初始化器。仅当 create_offset=True 时才能设置。默认情况下,offset 初始化为 0

  • data_format (str) – 输入的数据格式。可以是 channels_firstchannels_lastN...CNC...。默认值为 channels_last。请参阅 get_channel_index()

  • name (Optional[str]) – 模块的名称。

LayerNorm#

class haiku.LayerNorm(axis, create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, use_fast_variance=False, name=None, *, param_axis=None)[source]#

层归一化模块。

参见: https://arxiv.org/abs/1607.06450

使用示例

>>> ln = hk.LayerNorm(axis=-1, param_axis=-1,
...                   create_scale=True, create_offset=True)
>>> x = ln(jnp.ones([8, 224, 224, 3]))
__init__(axis, create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, use_fast_variance=False, name=None, *, param_axis=None)[source]#

构建一个 LayerNorm 模块。

参数
  • axis (AxisOrAxes) – 整数,整数列表或切片,指示要标准化的轴。 请注意,scale/offset 参数的形状由 param_axis 参数控制。

  • create_scale (bool) – 布尔值,定义是否创建在标准化后应用的每个通道的可训练 scale。

  • create_offset (bool) – 布尔值,定义是否创建在标准化和缩放后应用的每个通道的可训练 offset。

  • eps (float) – 用于避免除以零方差的小 epsilon 值。默认值为 1e-5,与论文和 Sonnet 中的值相同。

  • scale_init (Optional[hk.initializers.Initializer]) – gain(也称为 scale)的可选初始化器。 默认值为 1。

  • offset_init (Optional[hk.initializers.Initializer]) – bias(也称为 offset)的可选初始化器。 默认值为零。

  • use_fast_variance (bool) – 如果为 true,则使用更快但数值稳定性较差的公式来计算方差。

  • name (Optional[str]) – 模块名称。

  • param_axis (Optional[AxisOrAxes]) – 用于确定可学习的 scale/offset 参数形状的轴。 Sonnet 将其设置为通道/特征轴(例如,对于 NHWC 设置为 -1)。 其他库将其设置为与 reduction 轴相同(例如 axis=param_axis)。

__call__(inputs, scale=None, offset=None)[source]#

连接 layer norm。

参数
  • inputs (jax.Array) – 一个数组,其中数据格式为 [N, ..., C]

  • scale (Optional[jax.Array]) – 最多 n-D 的数组。此张量的形状必须可广播到 inputs 的形状。这是应用于归一化输入的 scale。如果模块构造时 create_scale=True,则无法传入此参数。

  • offset (Optional[jax.Array]) – 最多 n-D 的数组。此张量的形状必须可广播到 inputs 的形状。这是应用于归一化输入的 offset。如果模块构造时 create_offset=True,则无法传入此参数。

返回类型

jax.Array

返回

标准化的数组。

RMSNorm#

class haiku.RMSNorm(axis, eps=1e-05, scale_init=None, name=None, create_scale=True, *, param_axis=None)[source]#

RMSNorm 模块。

RMSNorm 提供了一种替代方案,它比 LayerNorm 更快且更稳定。 输入通过均方根 (RMS) 进行标准化,并通过学习的参数进行缩放,但它们不会围绕均值重新居中。

参见 https://arxiv.org/pdf/1910.07467.pdf

__init__(axis, eps=1e-05, scale_init=None, name=None, create_scale=True, *, param_axis=None)[source]#

构建一个 RMSNorm 模块。

参数
  • axis (AxisOrAxes) – 整数,整数列表或切片,指示要标准化轴。

  • eps (float) – 小 epsilon 值,以避免除以零方差。 默认为 1e-5。

  • scale_init (Optional[hk.initializers.Initializer]) – gain(也称为 scale)的可选初始化器。 默认值为 1。

  • name (Optional[str]) – 模块名称。

  • create_scale (bool) – 布尔值,定义是否创建在标准化后应用的每个通道的可训练 scale。

  • param_axis (Optional[AxisOrAxes]) – 用于确定可学习的 scale/offset 参数形状的轴。 Sonnet 将其设置为通道/特征轴(例如,对于 NHWC 设置为 -1)。 其他库将其设置为与 reduction 轴相同(例如 axis=param_axis)。 None 默认为 (-1,)。

__call__(inputs)[source]#

连接 layer norm。

参数

inputs (jax.Array) – 一个数组,其中数据格式为 [N, ..., C]

返回

标准化的数组,形状与输入相同。

SpectralNorm#

class haiku.SpectralNorm(eps=0.0001, n_steps=1, name=None)[source]#

通过其第一个奇异值对输入进行归一化。

此模块使用幂迭代法根据输入和内部隐藏状态计算此值。

__init__(eps=0.0001, n_steps=1, name=None)[source]#

初始化 SpectralNorm 模块。

参数
  • eps (float) – 用于数值稳定性的常数。

  • n_steps (int) – 执行多少步幂迭代来近似输入的奇异值。

  • name (Optional[str]) – 模块的名称。

__call__(value, update_stats=True, error_on_non_matrix=False)[source]#

执行谱归一化并返回新值。

参数
  • value – 您想要对其执行谱归一化的类数组对象。

  • update_stats (bool) – 一个布尔值,默认为 True。 无论此参数如何,此函数都将返回归一化的输入。 当 update_stats 为 True 时,此对象的内部状态也将更新以反映输入值。 当 update_stats 为 False 时,内部统计信息将保持不变。

  • error_on_non_matrix (bool) – 谱归一化仅在矩阵上定义。 默认情况下,此模块将返回未更改的标量,并在其前导维度中展平高阶张量。 将此标志设置为 True 将在这些情况下抛出错误。

返回类型

jax.Array

返回

输入值通过其第一个奇异值进行归一化。

Raises

ValueError – 如果 error_on_non_matrix 为 True 且 value 的 ndims > 2。

ExponentialMovingAverage#

class haiku.ExponentialMovingAverage(decay, zero_debias=True, warmup_length=0, name=None)[source]#

维护指数移动平均值。

这使用了 Adam 去偏差程序。 有关详细信息,请参见 https://arxiv.org/pdf/1412.6980.pdf

__init__(decay, zero_debias=True, warmup_length=0, name=None)[source]#

初始化 ExponentialMovingAverage 模块。

参数
  • decay – 选择的衰减率。 必须在 [0, 1) 范围内。 接近 1 的值会导致缓慢衰减; 接近 0 的值会导致快速衰减。

  • zero_debias (bool) – 是否在零偏差下运行。

  • warmup_length (int) – 一个正整数,在内部计数器达到 warmup_length 之前,EMA 不起作用,此时衰减平均值的初始值初始化为 warmup_length 迭代后的输入值。

  • name (Optional[str]) – 模块的名称。

initialize(shape, dtype=<class 'jax.numpy.float32'>)[source]#

如果未初始化,则将平均值设置为给定 shape/dtype 的 zeros

__call__(value, update_stats=True)[source]#

更新 EMA 并返回新值。

参数
  • value (Union[float, jax.Array]) – 您想要对其执行指数衰减的类数组对象。

  • update_stats (bool) – 一个布尔值,指示是否更新此对象的内部状态以反映输入值。 当 update_stats 为 False 时,内部统计信息将保持不变。

返回类型

jax.Array

返回

输入值的指数加权平均值。

SNParamsTree#

class haiku.SNParamsTree(eps=0.0001, n_steps=1, ignore_regex='', name=None)[source]#

将谱归一化应用于树中的所有参数。

这与 moving_averages.py 中的 EMAParamsTree 同构。

__init__(eps=0.0001, n_steps=1, ignore_regex='', name=None)[source]#

初始化 SNParamsTree 模块。

参数
  • eps (float) – 用于数值稳定性的常数。

  • n_steps (int) – 执行多少步幂迭代来近似输入的奇异值。

  • ignore_regex (str) – 字符串。 树中名称与此正则表达式匹配的任何参数都不会应用谱归一化。 空字符串表示此模块适用于所有参数。

  • name (Optional[str]) – 模块的名称。

__call__(tree, update_stats=True)[source]#

将自身作为函数调用。

EMAParamsTree#

class haiku.EMAParamsTree(decay, zero_debias=True, warmup_length=0, ignore_regex='', name=None)[source]#

为树中的所有参数维护指数移动平均值。

虽然 ExponentialMovingAverage 旨在应用于函数中的单个参数,但此类旨在应用于函数的整个参数树。

给定某个网络的参数集

>>> network_fn = lambda x: hk.Linear(10)(x)
>>> x = jnp.ones([1, 1])
>>> params = hk.transform(network_fn).init(jax.random.PRNGKey(428), x)

您可以像下面这样使用 EMAParamsTree

>>> ema_fn = hk.transform_with_state(lambda x: hk.EMAParamsTree(0.2)(x))
>>> _, ema_state = ema_fn.init(None, params)
>>> ema_params, ema_state = ema_fn.apply(None, ema_state, None, params)

在这里,我们正在转换一个 Haiku 函数,并通过 init_fn 正常构建其参数,但正在创建第二个转换后的函数,该函数期望参数树作为输入。 然后使用当前参数作为输入调用此函数,然后返回一个相同的树,其中每个参数都替换为其指数衰减的平均值。 然后可以将此 ema_params 对象像往常一样传递到 network_fn 中,这将导致它使用 EMA 权重运行。

__init__(decay, zero_debias=True, warmup_length=0, ignore_regex='', name=None)[source]#

初始化 EMAParamsTree 模块。

参数
  • decay – 选择的衰减率。 必须在 [0, 1) 范围内。 接近 1 的值会导致缓慢衰减; 接近 0 的值会导致快速衰减。

  • zero_debias (bool) – 是否在零偏差下运行。

  • warmup_length (int) – 一个正整数,在内部计数器达到 warmup_length 之前,EMA 不起作用,此时衰减平均值的初始值初始化为 warmup_length 迭代后的输入值。

  • ignore_regex (str) – 字符串。 树中名称与此正则表达式匹配的任何参数都不会应用任何移动平均。 空字符串表示此模块将 EMA 所有参数。

  • name (Optional[str]) – 模块的名称。

__call__(tree, update_stats=True)[source]#

将自身作为函数调用。

Recurrent#

RNNCore([name])

RNN 核的基本类。

dynamic_unroll(core, input_sequence, ...[, ...])

执行 RNN 的动态展开。

static_unroll(core, input_sequence, ...[, ...])

执行 RNN 的静态展开。

expand_apply(f[, axis])

包装 f 以临时向其输入添加大小为 1 的轴。

VanillaRNN(hidden_size[, double_bias, name])

基本全连接 RNN 核。

LSTM(hidden_size[, name])

长短期记忆 (LSTM) RNN 核。

GRU(hidden_size[, w_i_init, w_h_init, ...])

门控循环单元。

DeepRNN(layers[, name])

将一系列核和可调用对象包装为单个核。

deep_rnn_with_skip_connections(layers[, name])

构建具有跳跃连接的 DeepRNN

ResetCore(core[, name])

用于管理展开期间状态重置的包装器。

IdentityCore([name])

一个循环核,它转发输入和一个空状态。

Conv1DLSTM(input_shape, output_channels, ...)

1-D 卷积 LSTM。

Conv2DLSTM(input_shape, output_channels, ...)

2-D 卷积 LSTM。

Conv3DLSTM(input_shape, output_channels, ...)

3-D 卷积 LSTM。

RNNCore#

class haiku.RNNCore(name=None)[source]#

RNN 核的基本类。

此类定义了每个核应实现的基本功能:initial_state(),用于构造核状态的示例;以及 __call__(),它将由先前状态参数化的核应用于输入。

核可以与 dynamic_unroll()static_unroll() 一起使用,以从给定的输入序列迭代构造输出序列。

abstract __call__(inputs, prev_state)[source]#

运行 RNN 的一个步骤。

参数
  • inputs – 任意嵌套的结构。

  • prev_state – 先前的核状态。

返回类型

tuple[Any, Any]

返回

具有两个元素的元组 output, next_stateoutput 是任意嵌套的结构。 next_state 是下一个核状态,它必须与 prev_state 具有相同的形状。

abstract initial_state(batch_size)[source]#

为此核构造一个初始状态。

参数

batch_size (可选[int]) – 可选的整数或表示批大小的整数标量张量。如果为 None,核心可能会失败,或者(实验性地)返回一个没有批次维度的初始状态。

返回

此核心的任意嵌套初始状态。

dynamic_unroll#

haiku.dynamic_unroll(core, input_sequence, initial_state, time_major=True, reverse=False, return_all_states=False, unroll=1)[source]#

执行 RNN 的动态展开。

展开对应于在一个循环中对输入序列的每个元素调用核心,并传递状态。

state = initial_state
for t in range(len(input_sequence)):
   outputs, state = core(input_sequence[t], state)

当在 jax.jit() 内部执行时,动态展开会保留循环结构。有关将循环替换为其主体重复多次的展开函数,请参阅 static_unroll()

参数
  • core – 要展开的 RNNCore

  • input_sequence – 如果 time-major=True,则为形状为 [T, ...] 的张量的任意嵌套结构;如果 time_major=False,则为形状为 [B, T, ...] 的张量的任意嵌套结构,其中 T 是时间步数。

  • initial_state – 给定核心的初始状态。

  • time_major – 如果为 True,则输入预计为时间优先 (time-major),否则预计为批次优先 (batch-major)。

  • reverse – 如果为 True,则按相反的顺序扫描输入。等效于反转输入和输出中的时间维度。有关更多详细信息,请参阅 https://jax.net.cn/en/latest/_autosummary/jax.lax.scan.html

  • return_all_states – 如果为 True,则返回所有中间状态,而不仅仅是时间上的最后一个状态。

  • unroll – 在循环的单次迭代中展开多少次扫描迭代。

返回

  • output_sequence - 如果为 time-major,则为形状为 [T, ...] 的张量的任意嵌套结构,否则为 [B, T, ...]

  • state_sequence - 如果 return_all_states 为 True,则返回核心状态序列。否则,返回时间步 T 的核心状态。

返回类型

包含两个元素的元组

static_unroll#

haiku.static_unroll(core, input_sequence, initial_state, time_major=True)[source]#

执行 RNN 的静态展开。

展开对应于在一个循环中对输入序列的每个元素调用核心,并传递状态。

state = initial_state
for t in range(len(input_sequence)):
   outputs, state = core(input_sequence[t], state)

当在 jax.jit() 内部执行时,静态展开会将循环替换为其主体重复多次的形式。

state = initial_state
outputs0, state = core(input_sequence[0], state)
outputs1, state = core(input_sequence[1], state)
outputs2, state = core(input_sequence[2], state)
...

有关保留循环的展开函数,请参阅 dynamic_unroll()

参数
  • core – 要展开的 RNNCore

  • input_sequence – 如果 time-major=True,则为形状为 [T, ...] 的张量的任意嵌套结构;如果 time_major=False,则为形状为 [B, T, ...] 的张量的任意嵌套结构,其中 T 是时间步数。

  • initial_state – 给定核心的初始状态。

  • time_major – 如果为 True,则输入预计为时间优先 (time-major),否则预计为批次优先 (batch-major)。

返回

  • output_sequence - 如果为 time-major,则为形状为 [T, ...] 的张量的任意嵌套结构,否则为 [B, T, ...]

  • final_state - 时间步 T 的核心状态。

返回类型

包含两个元素的元组

expand_apply#

haiku.expand_apply(f, axis=0)[source]#

包装 f 以临时向其输入添加大小为 1 的轴。

以下语法的简写:

ins = jax.tree_util.tree_map(lambda t: np.expand_dims(t, axis=axis), ins)
out = f(ins)
out = jax.tree_util.tree_map(lambda t: np.squeeze(t, axis=axis), out)

这可能对于将为 [Time, Batch, ...] 数组构建的函数应用于单个时间步非常有用。

参数
  • f – 要应用于扩展输入的被调用对象。

  • axis – 在哪里添加额外的轴。

返回

f,如上所述包装。

VanillaRNN#

class haiku.VanillaRNN(hidden_size, double_bias=True, name=None)[source]#

基本全连接 RNN 核。

给定 \(x_t\) 和先前的隐藏状态 \(h_{t-1}\),核心计算:

\[h_t = \operatorname{ReLU}(w_i x_t + b_i + w_h h_{t-1} + b_h)\]

输出等于新状态 \(h_t\)

__init__(hidden_size, double_bias=True, name=None)[source]#

构建一个 vanilla RNN 核心。

参数
  • hidden_size (int) – 隐藏层大小。

  • double_bias (bool) – 是否在两个线性层中使用偏置。这不会改变单元的学习性能。但是,加倍将创建两组偏置参数,而不是一组。

  • name (Optional[str]) – 模块的名称。

__call__(inputs, prev_state)[source]#

运行 RNN 的一个步骤。

参数
  • inputs – 任意嵌套的结构。

  • prev_state – 先前的核状态。

返回

具有两个元素的元组 output, next_stateoutput 是任意嵌套的结构。 next_state 是下一个核状态,它必须与 prev_state 具有相同的形状。

initial_state(batch_size)[source]#

为此核构造一个初始状态。

参数

batch_size (可选[int]) – 可选的整数或表示批大小的整数标量张量。如果为 None,核心可能会失败,或者(实验性地)返回一个没有批次维度的初始状态。

返回

此核心的任意嵌套初始状态。

LSTM#

class haiku.LSTM(hidden_size, name=None)[source]#

长短期记忆 (LSTM) RNN 核。

该实现基于 [1]。给定 \(x_t\) 和先前的状态 \((h_{t-1}, c_{t-1})\),核心计算:

\[\begin{array}{ll} i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\ f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\ o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]

其中 \(i_t\)\(f_t\)\(o_t\) 是输入门、遗忘门和输出门激活,\(g_t\) 是单元更新向量。

输出等于新的隐藏状态 \(h_t\)

注意

遗忘门初始化

根据 [2],我们在初始化后将 1.0 添加到 \(b_f\),以减少训练开始时遗忘的规模。

__init__(hidden_size, name=None)[source]#

构建一个 LSTM。

参数
  • hidden_size (int) – 隐藏层大小。

  • name (Optional[str]) – 模块的名称。

__call__(inputs, prev_state)[source]#

运行 RNN 的一个步骤。

参数
  • inputs (jax.Array) – 任意嵌套结构。

  • prev_state (LSTMState) – 先前的核心状态。

返回类型

tuple[jax.Array, LSTMState]

返回

具有两个元素的元组 output, next_stateoutput 是任意嵌套的结构。 next_state 是下一个核状态,它必须与 prev_state 具有相同的形状。

initial_state(batch_size)[source]#

为此核构造一个初始状态。

参数

batch_size (可选[int]) – 可选的整数或表示批大小的整数标量张量。如果为 None,核心可能会失败,或者(实验性地)返回一个没有批次维度的初始状态。

返回类型

LSTMState

返回

此核心的任意嵌套初始状态。

GRU#

class haiku.GRU(hidden_size, w_i_init=None, w_h_init=None, b_init=None, name=None)[source]#

门控循环单元。

该实现基于:https://arxiv.org/pdf/1412.3555v1.pdf,带有偏置。

给定 \(x_t\) 和先前的状态 \(h_{t-1}\),核心计算:

\[\begin{array}{ll} z_t &= \sigma(W_{iz} x_t + W_{hz} h_{t-1} + b_z) \\ r_t &= \sigma(W_{ir} x_t + W_{hr} h_{t-1} + b_r) \\ a_t &= \tanh(W_{ia} x_t + W_{ha} (r_t \bigodot h_{t-1}) + b_a) \\ h_t &= (1 - z_t) \bigodot h_{t-1} + z_t \bigodot a_t \end{array}\]

其中 \(z_t\)\(r_t\) 是重置门和更新门。

输出等于新的隐藏状态 \(h_t\)

__init__(hidden_size, w_i_init=None, w_h_init=None, b_init=None, name=None)[source]#

使用给定名称初始化当前模块。

子类应在创建其他模块或变量之前调用此构造函数,以便正确命名这些模块。

参数

name (Optional[str]) – 类的可选字符串名称。必须是有效的 Python 标识符。如果未提供 name,则当前实例的类名将转换为 lower_snake_case 并改为使用。

__call__(inputs, state)[source]#

运行 RNN 的一个步骤。

参数
  • inputs – 任意嵌套的结构。

  • prev_state – 先前的核状态。

返回

具有两个元素的元组 output, next_stateoutput 是任意嵌套的结构。 next_state 是下一个核状态,它必须与 prev_state 具有相同的形状。

initial_state(batch_size)[source]#

为此核构造一个初始状态。

参数

batch_size (可选[int]) – 可选的整数或表示批大小的整数标量张量。如果为 None,核心可能会失败,或者(实验性地)返回一个没有批次维度的初始状态。

返回

此核心的任意嵌套初始状态。

DeepRNN#

class haiku.DeepRNN(layers, name=None)[source]#

将一系列核和可调用对象包装为单个核。

>>> deep_rnn = hk.DeepRNN([
...     hk.LSTM(hidden_size=4),
...     jax.nn.relu,
...     hk.LSTM(hidden_size=2),
... ])

DeepRNN 的状态是一个元组,每个 RNNCore 元素对应一个元素。如果没有任何层是 RNNCore,则状态为空元组。

__init__(layers, name=None)[source]#

使用给定名称初始化当前模块。

子类应在创建其他模块或变量之前调用此构造函数,以便正确命名这些模块。

参数

name (Optional[str]) – 类的可选字符串名称。必须是有效的 Python 标识符。如果未提供 name,则当前实例的类名将转换为 lower_snake_case 并改为使用。

haiku.deep_rnn_with_skip_connections(layers, name=None)[source]#

构建具有跳跃连接的 DeepRNN

跳跃连接改变了 DeepRNN 内的依赖结构。具体来说,第 i 层(i > 0)的输入由核心的输入和第 (i-1) 层的输出的串联给出。

DeepRNN 的输出是所有核心的输出的串联。

outputs0, ... = layers[0](inputs, ...)
outputs1, ... = layers[1](tf.concat([inputs, outputs0], axis=-1], ...)
outputs2, ... = layers[2](tf.concat([inputs, outputs1], axis=-1], ...)
...
参数
  • layers (Sequence[RNNCore]) – RNNCore 的列表。

  • name (Optional[str]) – 模块的名称。

返回类型

RNNCore

返回

具有跳跃连接的 _DeepRNN

Raises

ValueError – 如果任何层不是 RNNCore

ResetCore#

class haiku.ResetCore(core, name=None)[source]#

用于管理展开期间状态重置的包装器。

当在一批输入序列上展开 RNNCore 时,可能需要在不同时间步为批次的不同元素重置核心状态。ResetCore 类通过接收一批 should_reset 布尔值以及一批输入来实现这一点,并有条件地为批次的各个元素重置核心状态。您还可以通过传递与状态结构兼容的 should_reset 嵌套来重置状态的各个条目。

__init__(core, name=None)[source]#

使用给定名称初始化当前模块。

子类应在创建其他模块或变量之前调用此构造函数,以便正确命名这些模块。

参数

name (Optional[str]) – 类的可选字符串名称。必须是有效的 Python 标识符。如果未提供 name,则当前实例的类名将转换为 lower_snake_case 并改为使用。

__call__(inputs, state)[source]#

运行包装核心的一个步骤,处理状态重置。

参数
  • inputs – 包含两个元素的元组,inputs, should_reset,其中 should_reset 是用于重置包装核心状态的信号。should_reset 可以是张量或嵌套。如果为嵌套,则 should_reset 必须与状态结构匹配,并且其组件的形状必须是状态嵌套中相应条目张量形状的前缀。如果为张量,则支持的形状是状态组件张量的所有公共形状前缀,例如 [batch_size]

  • state – 先前的包装核心状态。

返回

包装核心的 output, next_state 元组。

initial_state(batch_size)[source]#

为此核构造一个初始状态。

参数

batch_size (可选[int]) – 可选的整数或表示批大小的整数标量张量。如果为 None,核心可能会失败,或者(实验性地)返回一个没有批次维度的初始状态。

返回

此核心的任意嵌套初始状态。

IdentityCore#

class haiku.IdentityCore(name=None)[source]#

一个循环核,它转发输入和一个空状态。

当在模型的循环版本和前馈版本之间切换时,这通常用于保持相同的接口。

__call__(inputs, state)[source]#

运行 RNN 的一个步骤。

参数
  • inputs – 任意嵌套的结构。

  • prev_state – 先前的核状态。

返回

具有两个元素的元组 output, next_stateoutput 是任意嵌套的结构。 next_state 是下一个核状态,它必须与 prev_state 具有相同的形状。

initial_state(batch_size)[source]#

为此核构造一个初始状态。

参数

batch_size (可选[int]) – 可选的整数或表示批大小的整数标量张量。如果为 None,核心可能会失败,或者(实验性地)返回一个没有批次维度的初始状态。

返回

此核心的任意嵌套初始状态。

Conv1DLSTM#

class haiku.Conv1DLSTM(input_shape, output_channels, kernel_shape, name=None)[source]#

1-D 卷积 LSTM。

该实现基于 [3]。给定 \(x_t\) 和先前的状态 \((h_{t-1}, c_{t-1})\),核心计算:

\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]

其中 \(*\) 表示卷积运算符;\(i_t\)\(f_t\)\(o_t\) 是输入门、遗忘门和输出门激活,\(g_t\) 是单元更新向量。

输出等于新的隐藏状态 \(h_t\)

注意

遗忘门初始化

根据 [2],我们在初始化后将 1.0 添加到 \(b_f\),以减少训练开始时遗忘的规模。

__init__(input_shape, output_channels, kernel_shape, name=None)[source]#

构建一个 1-D 卷积 LSTM。

参数
  • input_shape (Sequence[int]) – 输入的形状,不包括批次大小。

  • output_channels (int) – 输出通道数。

  • kernel_shape (Union[int, Sequence[int]]) – 核大小序列(长度为 1)或整数。kernel_shape 将被扩展以定义所有维度上的核大小。

  • name (Optional[str]) – 模块的名称。

Conv2DLSTM#

class haiku.Conv2DLSTM(input_shape, output_channels, kernel_shape, name=None)[source]#

2-D 卷积 LSTM。

该实现基于 [3]。给定 \(x_t\) 和先前的状态 \((h_{t-1}, c_{t-1})\),核心计算:

\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]

其中 \(*\) 表示卷积运算符;\(i_t\)\(f_t\)\(o_t\) 是输入门、遗忘门和输出门激活,\(g_t\) 是单元更新向量。

输出等于新的隐藏状态 \(h_t\)

注意

遗忘门初始化

根据 [2],我们在初始化后将 1.0 添加到 \(b_f\),以减少训练开始时遗忘的规模。

__init__(input_shape, output_channels, kernel_shape, name=None)[source]#

构建一个 2-D 卷积 LSTM。

参数
  • input_shape (Sequence[int]) – 输入的形状,不包括批次大小。

  • output_channels (int) – 输出通道数。

  • kernel_shape (Union[int, Sequence[int]]) – 核大小序列(长度为 2)或整数。kernel_shape 将被扩展以定义所有维度上的核大小。

  • name (Optional[str]) – 模块的名称。

Conv3DLSTM#

class haiku.Conv3DLSTM(input_shape, output_channels, kernel_shape, name=None)[source]#

3-D 卷积 LSTM。

该实现基于 [3]。给定 \(x_t\) 和先前的状态 \((h_{t-1}, c_{t-1})\),核心计算:

\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]

其中 \(*\) 表示卷积运算符;\(i_t\)\(f_t\)\(o_t\) 是输入门、遗忘门和输出门激活,\(g_t\) 是单元更新向量。

输出等于新的隐藏状态 \(h_t\)

注意

遗忘门初始化

根据 [2],我们在初始化后将 1.0 添加到 \(b_f\),以减少训练开始时遗忘的规模。

__init__(input_shape, output_channels, kernel_shape, name=None)[source]#

构建一个 3-D 卷积 LSTM。

参数
  • input_shape (Sequence[int]) – 输入的形状,不包括批次大小。

  • output_channels (int) – 输出通道数。

  • kernel_shape (Union[int, Sequence[int]]) – 核大小序列(长度为 3)或整数。kernel_shape 将被扩展以定义所有维度上的核大小。

  • name (Optional[str]) – 模块的名称。

Attention#

MultiHeadAttention#

class haiku.MultiHeadAttention(num_heads, key_size, w_init_scale=None, *, w_init=None, with_bias=True, b_init=None, value_size=None, model_size=None, name=None)[source]#

多头注意力 (MHA) 模块。

此模块旨在关注向量序列。

粗略草图:- 将键 (K)、查询 (Q) 和值 (V) 计算为输入的投影。- 注意力权重计算为 W = softmax(QK^T / sqrt(key_size))。- 输出是 WV^T 的另一个投影。

有关更多详细信息,请参阅原始 Transformer 论文

“Attention is all you need” https://arxiv.org/abs/1706.03762

形状词汇表:- T:序列长度。- D:向量(嵌入)大小。- H:注意力头的数量。

__init__(num_heads, key_size, w_init_scale=None, *, w_init=None, with_bias=True, b_init=None, value_size=None, model_size=None, name=None)[source]#

初始化模块。

参数
  • num_heads (int) – 独立注意力头的数量 (H)。

  • key_size (int) – 用于注意力的键 (K) 和查询的大小。

  • w_init_scale (Optional[float]) – 已弃用。请改用 w_init。

  • w_init (Optional[hk.initializers.Initializer]) – 线性映射中权重的初始化器。一旦 w_init_scale 完全弃用,w_init 将变为强制性。在此之前,它具有默认值 None 以实现向后兼容性。

  • with_bias (bool) – 计算各种线性投影时是否添加偏置。

  • b_init (Optional[hk.initializers.Initializer]) – 偏置 (bias) 的可选初始化器。 默认情况下为零。

  • value_size (Optional[int]) – 值投影 (V) 的可选大小。如果为 None,则默认为键大小 (K)。

  • model_size (Optional[int]) – 输出嵌入 (D') 的可选大小。如果为 None,则默认为键大小乘以注意力头的数量 (K * H)。

  • name (Optional[str]) – 此模块的可选名称。

__call__(query, key, value, mask=None)[source]#

计算(可选的带掩码)MHA,使用查询、键和值。

此模块在零个或多个“批次式”前导维度上广播。

参数
  • query (jax.Array) – 用于计算查询的嵌入序列;形状为 […, T’, D_q]。

  • key (jax.Array) – 用于计算键的嵌入序列;形状为 […, T, D_k]。

  • value (jax.Array) – 用于计算值的嵌入序列;形状为 […, T, D_v]。

  • mask (Optional[jax.Array]) – 应用于注意力权重的可选掩码;形状为 […, H=1, T’, T]。

返回类型

jax.Array

返回

由注意力加权的值投影组成的新嵌入序列的投影;

形状为 […, T’, D’]。

Batch#

Reshape(output_shape[, preserve_dims, name])

重塑输入 Tensor,保留批次维度。

Flatten([preserve_dims, name])

展平输入,保留批次维度。

BatchApply(f[, num_dims])

临时合并输入张量的前导维度。

Reshape#

class haiku.Reshape(output_shape, preserve_dims=1, name=None)[source]#

重塑输入 Tensor,保留批次维度。

例如,给定形状为 [B, H, W, C, D] 的输入张量

>>> B, H, W, C, D = range(1, 6)
>>> x = jnp.ones([B, H, W, C, D])

output_shape(-1, D) 时的默认行为是展平 BD 之间的所有维度

>>> mod = hk.Reshape(output_shape=(-1, D))
>>> assert mod(x).shape == (B, H*W*C, D)

您可以通过 preserve_dims 更改保留的前导维度的数量

>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=2)
>>> assert mod(x).shape == (B, H, W*C, D)

>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=3)
>>> assert mod(x).shape == (B, H, W, C, D)

>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=4)
>>> assert mod(x).shape == (B, H, W, C, 1, D)

或者,preserve_dims 的负值指定要用 output_shape 替换的尾随维度的数量

>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=-3)
>>> assert mod(x).shape == (B, H, W*C, D)

这在将同一模块应用于批量和非批量输出的情况下很有用

>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=-3)
>>> assert mod(x[0]).shape == (H, W*C, D)
__init__(output_shape, preserve_dims=1, name=None)[source]#

构建 Reshape 模块。

参数
  • output_shape (Sequence[int]) – 将输入张量重塑为的形状,同时保留其前 preserve_dims 维度。当特殊值 -1 出现在 output_shape 中时,将自动推断相应的大小。请注意,-1 只能在 output_shape 中出现一次。要展平所有非批次维度,请使用 Flatten

  • preserve_dims (int) – 不会被重塑的前导维度的数量。如果为负数,则将其解释为要用新形状替换的尾随维度的数量。

  • name (Optional[str]) – 模块的名称。

Raises

ValueError – 如果 preserve_dims 为零。

__call__(inputs)[source]#

将自身作为函数调用。

Flatten#

class haiku.Flatten(preserve_dims=1, name=None)[source]#

展平输入,保留批次维度。

默认情况下,Flatten 合并除第一个维度之外的所有维度。可以通过设置 preserve_dims 来保留其他前导维度。

>>> x = jnp.ones([3, 2, 4])
>>> flat = hk.Flatten()
>>> flat(x).shape
(3, 8)

当要展平的输入的维度少于 preserve_dims 维度时,它将保持不变

>>> x = jnp.ones([3])
>>> flat(x).shape
(3,)

或者,preserve_dims 的负值指定要展平的尾随维度的数量

>>> x = jnp.ones([3, 2, 4])
>>> negative_flat = hk.Flatten(preserve_dims=-2)
>>> negative_flat(x).shape
(3, 8)

这允许将同一模块无缝应用于单个元素或具有相同元素形状的批次元素

>> negative_flat(x[0]).shape (8,)

__init__(preserve_dims=1, name=None)[source]#

构建 Reshape 模块。

参数
  • output_shape – 将输入张量重塑为的形状,同时保留其前 preserve_dims 维度。当特殊值 -1 出现在 output_shape 中时,将自动推断相应的大小。请注意,-1 只能在 output_shape 中出现一次。要展平所有非批次维度,请使用 Flatten

  • preserve_dims (int) – 不会被重塑的前导维度的数量。如果为负数,则将其解释为要用新形状替换的尾随维度的数量。

  • name (Optional[str]) – 模块的名称。

Raises

ValueError – 如果 preserve_dims 为零。

BatchApply#

class haiku.BatchApply(f, num_dims=2)[source]#

临时合并输入张量的前导维度。

将张量的前导维度合并为单个维度,运行给定的可调用对象,然后拆分结果的前导维度以匹配输入。

秩小于要折叠的维度数的输入数组将未经修改地传递。

这可能对于将模块应用于例如 [Time, Batch, ...] 数组的每个时间步很有用。

对于某些 f 和平台,这可能比 jax.vmap() 更有效,尤其是在与其他转换(如 jax.grad())结合使用时。

__init__(f, num_dims=2)[source]#

构建 BatchApply 模块。

参数
  • f – 要应用于重塑数组的可调用对象。

  • num_dims – 要合并的维度数。

__call__(*args, **kwargs)[source]#

将自身作为函数调用。

Embedding#

Embed([vocab_size, embed_dim, ...])

用于将标记嵌入到低维空间的模块。

EmbedLookupStyle(value[, names, module, ...])

如何返回给定 ID 的嵌入矩阵。

Embed#

class haiku.Embed(vocab_size=None, embed_dim=None, embedding_matrix=None, w_init=None, lookup_style='ARRAY_INDEX', name=None, precision=<Precision.HIGHEST: 2>)[source]#

用于将标记嵌入到低维空间的模块。

__init__(vocab_size=None, embed_dim=None, embedding_matrix=None, w_init=None, lookup_style='ARRAY_INDEX', name=None, precision=<Precision.HIGHEST: 2>)[source]#

构建 Embed 模块。

参数
  • vocab_size (Optional[int]) – 要嵌入的唯一标记的数量。如果未提供,则必须提供现有词汇表矩阵,可以从中推断出 vocab_size 作为 embedding_matrix

  • embed_dim (Optional[int]) – 要分配给每个嵌入的维度数。如果现有词汇表矩阵初始化了模块,则不应提供此参数,因为它将被推断出来。

  • embedding_matrix (Optional[Union[np.ndarray, jax.Array]]) – 大小等同于 [vocab_size, embed_dim] 的类矩阵对象。如果给定,它将用作嵌入矩阵的初始值,并且不需要给定 vocab_sizeembed_dim。如果给定了它们,则会检查它们的值是否与 embedding_matrix 的维度一致。

  • w_init (Optional[hk.initializers.Initializer]) – 嵌入矩阵的初始化器。默认情况下,嵌入通过截断的正态分布初始化。

  • lookup_style (Union[str, EmbedLookupStyle]) – EmbedLookupStyle 的枚举值之一,确定如何根据 ID 访问嵌入的值。无论如何,输入都应该是表示 ID 的整数值的密集数组。此设置更改了此模块内部如何将这些 ID 映射到嵌入。结果是相同的,但速度和内存权衡是不同的。它默认为使用 NumPy 风格的数组索引。此值仅是模块的默认值,并且在任何给定调用中都可以在 __call__() 中覆盖。

  • name (Optional[str]) – 此模块的可选名称。

  • precision (jax.lax.Precision) – 仅当 lookup_style 为 ONE_HOT 时使用。用于单热编码输入与嵌入向量之间的点积的精度。使用 jax.lax.Precision.DEFAULT 可能会在 TPU 上实现约 2 倍的加速,但代价是精度略有降低。

Raises

ValueError – 如果未提供 embed_dimembedding_matrixvocab_size 中的任何一个,或者如果提供了 embedding_matrixembed_dimvocab_size 与提供的矩阵不一致。

__call__(ids, lookup_style=None, precision=None)[source]#

查找嵌入。

ids 中的每个值查找嵌入向量。所有 ID 必须在 [0, vocab_size) 范围内,以防止 NaN 传播。

参数
  • ids (Union[jax.Array, Sequence[int]]) – 整数数组。

  • lookup_style (Optional[Union[str, hk.EmbedLookupStyle]]) – 覆盖构造函数中给出的 lookup_style

  • precision (Optional[jax.lax.Precision]) – 覆盖构造函数中给出的 precision

返回类型

jax.Array

返回

形状为 ids.shape + [embedding_dim] 的张量。

Raises
  • AttributeError – 如果 lookup_style 无效。

  • ValueError – 如果 ids 不是整数数组。

EmbedLookupStyle#

class haiku.EmbedLookupStyle(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

如何返回给定 ID 的嵌入矩阵。

ARRAY_INDEX = 1#
ONE_HOT = 2#

Utilities#

Deferred(factory[, call_methods])

延迟另一个模块的构建,直到首次调用。

Deferred#

class haiku.Deferred(factory, call_methods=('__call__',))[source]#

延迟另一个模块的构建,直到首次调用。

Deferred 可用于声明依赖于其他模块的计算属性的模块,在这些模块被定义之前。这允许用户分离模块的声明和使用。例如,在程序开始时,您可以声明两个耦合的模块

>>> encoder = hk.Linear(64)
>>> decoder = hk.Deferred(lambda: hk.Linear(encoder.input_size))

稍后您可以自然地使用这些模块(注意:首先使用 decoder 会导致错误,因为 encoder.input_size 仅在调用 encoder 后定义)

>>> x = jnp.ones([8, 32])
>>> y = encoder(x)
>>> z = decoder(y)  # Constructs the Linear encoder by calling the lambda.

结果将满足以下条件

>>> assert x.shape == z.shape
>>> assert y.shape == (8, 64)
>>> assert decoder.input_size == encoder.output_size
>>> assert decoder.output_size == encoder.input_size
__init__(factory, call_methods=('__call__',))[source]#

初始化 Deferred 模块。

参数
  • factory (Callable[[], T]) – 一个无参数可调用对象,用于构建要延迟到的模块。首次调用 call_methods 之一时,将运行 factory,然后将使用与延迟模块相同的方法和参数调用构建的模块。

  • call_methods (Sequence[str]) – 应触发目标模块构建的方法。默认值将此模块配置为在首次运行 __call__ 时构建。如果您想添加除 call 之外的方法,则应显式传递它们(可选),例如 call_methods=(“__call__”, “encode”, “decode”)

property target: T#

返回目标模块。

如果 factory 尚未运行,这将触发构建。后续对 target 的调用将返回相同的实例。

返回类型

T

返回

由传递到构造函数的 factory 函数创建的 Module 实例。

__call__(*args, **kwargs)[source]#

将自身作为函数调用。

__setattr__(name, value)[source]#

实现 setattr(self, name, value)。

__delattr__(name)[source]#

实现 delattr(self, name)。

Initializers#

Initializer

alias of Callable[[collections.abc.Sequence[int], Any], jax.Array]

Constant(constant)

使用常量初始化。

Identity([gain])

生成单位矩阵的初始化器。

Orthogonal([scale, axis])

均匀缩放初始化器。

RandomNormal([stddev, mean])

通过从正态分布中采样进行初始化。

RandomUniform([minval, maxval])

通过从均匀分布中采样进行初始化。

TruncatedNormal([stddev, mean, lower, upper])

通过从截断的正态分布中采样进行初始化。

VarianceScaling([scale, mode, distribution, ...])

初始化器,其根据初始化的数组的形状调整其比例。

UniformScaling([scale])

均匀缩放初始化器。

Initializer#

haiku.initializers.Initializer#

alias of Callable[[collections.abc.Sequence[int], Any], jax.Array]

Constant#

class haiku.initializers.Constant(constant)[source]#

使用常量初始化。

__init__(constant)[source]#

构造 Constant 初始化器。

参数

constant (Union[float, int, complex, np.ndarray, jax.Array]) – 用于初始化的常数值。

__call__(shape, dtype)[source]#

将自身作为函数调用。

返回类型

jax.Array

Identity#

class haiku.initializers.Identity(gain=1.0)[source]#

生成单位矩阵的初始化器。

构造 2D 单位矩阵或这些矩阵的批次。

__init__(gain=1.0)[source]#

构造 Identity 初始化器。

参数

gain (Union[float, np.ndarray, jax.Array]) – 应用于单位矩阵的乘法因子。

__call__(shape, dtype)[source]#

将自身作为函数调用。

返回类型

jax.Array

Orthogonal#

class haiku.initializers.Orthogonal(scale=1.0, axis=- 1)[source]#

均匀缩放初始化器。

__init__(scale=1.0, axis=- 1)[source]#

构造一个均匀分布的正交矩阵的初始化器。

这些矩阵将沿着 axis 指定的轴是行正交的。如果权重的秩大于 2,形状将在所有其他维度中展平,然后沿着最后一个维度行正交。请注意,这仅在 axis 维度较大时才有效,否则矩阵将被转置(等效地,它将是列正交而不是行正交)。

如果形状不是正方形,则矩阵将具有正交的行或列,具体取决于哪一侧较小。

参数
  • scale – 比例因子。

  • axis – 哪个轴对应于张量的“输出维度”。

返回

一个正交初始化的参数。

__call__(shape, dtype)[source]#

将自身作为函数调用。

返回类型

jax.Array

RandomNormal#

class haiku.initializers.RandomNormal(stddev=1.0, mean=0.0)[source]#

通过从正态分布中采样进行初始化。

__init__(stddev=1.0, mean=0.0)[source]#

构造一个 RandomNormal 初始化器。

参数
  • stddev – 从中采样的正态分布的标准差。

  • mean – 从中采样的正态分布的均值。

__call__(shape, dtype)[source]#

将自身作为函数调用。

返回类型

jax.Array

RandomUniform#

class haiku.initializers.RandomUniform(minval=0.0, maxval=1.0)[source]#

通过从均匀分布中采样进行初始化。

__init__(minval=0.0, maxval=1.0)[source]#

构造一个 RandomUniform 初始化器。

参数
  • minval – 均匀分布的下限。

  • maxval – 均匀分布的上限。

__call__(shape, dtype)[source]#

将自身作为函数调用。

返回类型

jax.Array

TruncatedNormal#

class haiku.initializers.TruncatedNormal(stddev=1.0, mean=0.0, lower=- 2.0, upper=2.0)[source]#

通过从截断的正态分布中采样进行初始化。

__init__(stddev=1.0, mean=0.0, lower=- 2.0, upper=2.0)[source]#

构造一个 TruncatedNormal 初始化器。

参数
  • stddev (Union[float, jax.Array]) – 截断正态分布的标准差参数。

  • mean (Union[float, complex, jax.Array]) – 截断正态分布的均值。

  • lower (Union[float, jax.Array]) – 表示截断下界的浮点数或数组。

  • upper (Union[float, jax.Array]) – 表示截断上界的浮点数或数组。

__call__(shape, dtype)[source]#

将自身作为函数调用。

返回类型

jax.Array

VarianceScaling#

class haiku.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='truncated_normal', fan_in_axes=None)[source]#

初始化器,其根据初始化的数组的形状调整其比例。

此初始化器首先计算缩放因子 s = scale / n,其中 n 是

  • 权重张量中的输入单元数,如果 mode = fan_in

  • 输出单元数,如果 mode = fan_out

  • 输入和输出单元数的平均值,如果 mode = fan_avg

然后,使用 distribution="truncated_normal""normal",从均值为零且标准差(如果使用截断,则在截断后)为 stddev = sqrt(s) 的分布中抽取样本。

使用 distribution=uniform,从 [-limit, limit] 范围内的均匀分布中抽取样本,其中 limit = sqrt(3 * s)

方差缩放初始化器可以配置为使用 scale、mode 和 distribution 参数生成其他标准初始化器。以下是一些示例配置

名称

参数

glorot_uniform

VarianceScaling(1.0, “fan_avg”, “uniform”)

glorot_normal

VarianceScaling(1.0, “fan_avg”, “truncated_normal”)

lecun_uniform

VarianceScaling(1.0, “fan_in”, “uniform”)

lecun_normal

VarianceScaling(1.0, “fan_in”, “truncated_normal”)

he_uniform

VarianceScaling(2.0, “fan_in”, “uniform”)

he_normal

VarianceScaling(2.0, “fan_in”, “truncated_normal”)

__init__(scale=1.0, mode='fan_in', distribution='truncated_normal', fan_in_axes=None)[source]#

构造 VarianceScaling 初始化器。

参数
  • scale – 用于乘以方差的比例。

  • modefan_infan_outfan_avg 之一

  • distribution – 要使用的随机分布。truncated_normalnormaluniform 之一。

  • fan_in_axes – 可选的整数序列,指定形状的哪些轴是 fan-in 的一部分。如果未提供,则假定权重类似于卷积核,其中所有前导维度都是 fan-in 的一部分,只有尾部维度是 fan-out 的一部分。在实例化多头注意力权重时很有用。

__call__(shape, dtype)[source]#

将自身作为函数调用。

返回类型

jax.Array

UniformScaling#

class haiku.initializers.UniformScaling(scale=1.0)[source]#

均匀缩放初始化器。

通过从均匀分布中采样进行初始化,但方差按输入单元数的平方根倒数缩放,并乘以比例。

__init__(scale=1.0)[source]#

构造 UniformScaling 初始化器。

参数

scale – 用于乘以均匀分布上限的比例。

__call__(shape, dtype)[source]#

将自身作为函数调用。

返回类型

jax.Array

Paddings#

PadFn

alias of Callable[[int], tuple[int, int]]

is_padfn(padding)

测试给定参数是否为单个或 PadFn 序列。

create(padding, kernel, rate, n)

为给定的填充算法生成所需的填充。

create_from_padfn(padding, kernel, rate, n)

为给定的填充算法生成所需的填充。

create_from_tuple(padding, n)

使用部分指定的填充元组创建填充元组。

causal(effective_kernel_size)

预填充,使输出不依赖于未来。

full(effective_kernel_size)

最大填充,同时不只在填充元素上进行卷积。

reverse_causal(effective_kernel_size)

后填充,使输出不依赖于过去。

same(effective_kernel_size)

填充,使得步长为 1 时输出大小与输入大小匹配。

valid(effective_kernel_size)

无填充。

PadFn#

haiku.pad.PadFn#

alias of Callable[[int], tuple[int, int]]

is_padfn#

haiku.pad.is_padfn(padding)[source]#

测试给定参数是否为单个或 PadFn 序列。

返回类型

bool

create#

haiku.pad.create(padding, kernel, rate, n)[source]#

为给定的填充算法生成所需的填充。

参数
  • padding (Union[hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – callable/tuple 或 callable/tuple 的序列。callable 接受一个整数,表示有效内核大小(当 rate 为 1 时的内核大小),并返回一个包含两个整数的序列,表示该维度的前填充和后填充。元组定义了两个元素:前填充和后填充。如果 padding 是一个序列,则其长度必须为 1 或 n

  • kernel (Union[int, Sequence[int]]) – int 或长度为 n 的 int 序列。每个维度的内核大小。如果它是一个 int,它将在非通道和批次维度上复制。

  • rate (Union[int, Sequence[int]]) – int 或长度为 n 的 int 序列。每个维度的扩张率。如果它是一个 int,它将在非通道和批次维度上复制。

  • n (int) – 空间维度的数量。

返回类型

Sequence[tuple[int, int]]

返回

长度为 n 的序列,包含每个元素的填充。这些形式为 [pad_before, pad_after]

create_from_padfn#

haiku.pad.create_from_padfn(padding, kernel, rate, n)[source]#

为给定的填充算法生成所需的填充。

参数
  • padding (Union[hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – callable/tuple 或 callable/tuple 的序列。callable 接受一个整数,表示有效内核大小(当 rate 为 1 时的内核大小),并返回一个包含两个整数的序列,表示该维度的前填充和后填充。元组定义了两个元素:前填充和后填充。如果 padding 是一个序列,则其长度必须为 1 或 n

  • kernel (Union[int, Sequence[int]]) – int 或长度为 n 的 int 序列。每个维度的内核大小。如果它是一个 int,它将在非通道和批次维度上复制。

  • rate (Union[int, Sequence[int]]) – int 或长度为 n 的 int 序列。每个维度的扩张率。如果它是一个 int,它将在非通道和批次维度上复制。

  • n (int) – 空间维度的数量。

返回类型

Sequence[tuple[int, int]]

返回

长度为 n 的序列,包含每个元素的填充。这些形式为 [pad_before, pad_after]

create_from_tuple#

haiku.pad.create_from_tuple(padding, n)[source]#

使用部分指定的填充元组创建填充元组。

返回类型

Sequence[tuple[int, int]]

causal#

haiku.pad.causal(effective_kernel_size)[source]#

预填充,使输出不依赖于未来。

返回类型

tuple[int, int]

full#

haiku.pad.full(effective_kernel_size)[source]#

最大填充,同时不只在填充元素上进行卷积。

返回类型

tuple[int, int]

reverse_causal#

haiku.pad.reverse_causal(effective_kernel_size)[source]#

后填充,使输出不依赖于过去。

返回类型

tuple[int, int]

same#

haiku.pad.same(effective_kernel_size)[source]#

填充,使得步长为 1 时输出大小与输入大小匹配。

返回类型

tuple[int, int]

valid#

haiku.pad.valid(effective_kernel_size)[source]#

无填充。

返回类型

tuple[int, int]

Full Networks#

MLP#

class haiku.nets.MLP(output_sizes, w_init=None, b_init=None, with_bias=True, activation=<jax._src.custom_derivatives.custom_jvp object>, activate_final=False, name=None)[source]#

一个多层感知器模块。

__init__(output_sizes, w_init=None, b_init=None, with_bias=True, activation=<jax._src.custom_derivatives.custom_jvp object>, activate_final=False, name=None)[source]#

构造一个 MLP。

参数
  • output_sizes (Iterable[int]) – 层大小的序列。

  • w_init (Optional[hk.initializers.Initializer]) – Linear 权重的初始化器。

  • b_init (Optional[hk.initializers.Initializer]) – Linear 偏置的初始化器。如果 with_bias=False,则必须为 None

  • with_bias (bool) – 是否在每层中应用偏置。

  • activation (Callable[[jax.Array], jax.Array]) – 应用于 Linear 层之间的激活函数。默认为 ReLU。

  • activate_final (bool) – 是否激活 MLP 的最后一层。

  • name (Optional[str]) – 此模块的可选名称。

Raises

ValueError – 如果 with_biasFalseb_init 不为 None

__call__(inputs, dropout_rate=None, rng=None)[source]#

将模块连接到一些输入。

参数
  • inputs (jax.Array) – 形状为 [batch_size, input_size] 的张量。

  • dropout_rate (Optional[float]) – 可选的 dropout 率。

  • rng – 可选的 RNG 键。当使用 dropout 时需要。

返回类型

jax.Array

返回

大小为 [batch_size, output_size] 的模型输出。

reverse(activate_final=None, name=None)[source]#

返回一个新的 MLP,它是此 MLP 的逐层反向。

注意:由于计算 MLP 的反向需要知道每个线性层的输入大小,因此如果模块尚未至少调用一次,此方法将失败。

反向模块的约定是,它将接收父模块的输出作为输入,并产生一个输出,该输出是父模块的输入大小。

>>> mlp = hk.nets.MLP([1, 2, 3])
>>> mlp_in = jnp.ones([1, 2])
>>> y = mlp(mlp_in)
>>> rev = mlp.reverse()
>>> rev_mlp_out = rev(y)
>>> mlp_in.shape == rev_mlp_out.shape
True
参数
  • activate_final (可选[bool]) – 是否应激活 MLP 的最后一层。

  • name (可选[str]) – 新模块的可选名称。默认名称将是当前模块的名称,并以 "reversed_" 为前缀。

返回类型

MLP

返回

作为当前实例反向的 MLP 实例。请注意,这些实例不共享权重,并且除了彼此对称外,没有任何耦合关系。

MobileNet#

MobileNetV1#

class haiku.nets.MobileNetV1(strides=(1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1), channels=(64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024), num_classes=1000, use_bn=True, name=None)[source]#

MobileNetV1 模型。

__init__(strides=(1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1), channels=(64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024), num_classes=1000, use_bn=True, name=None)[source]#

构建 MobileNetV1 模型。

参数
  • strides (序列[int]) – 在每个 MobileNet 模块的深度卷积中使用的步幅。

  • channels (序列[int]) – 从逐点卷积输出的通道数,用于每个模块中。

  • num_classes (int) – 类别数量。

  • use_bn (bool) – 是否使用批归一化。默认为 True。如果为 true,则不使用偏置。如果为 false,则使用偏置。

  • name (Optional[str]) – 模块的名称。

__call__(inputs, is_training)[source]#

将自身作为函数调用。

返回类型

jax.Array

ResNet#

ResNet(blocks_per_group, num_classes[, ...])

ResNet 模型。

ResNet.BlockGroup(channels, num_blocks, ...)

用于 ResNet 实现的更高级别模块。

ResNet.BlockV1(channels, stride, ...[, name])

带有可选瓶颈的 ResNet V1 模块。

ResNet.BlockV2(channels, stride, ...[, name])

带有可选瓶颈的 ResNet V2 模块。

ResNet18(num_classes[, bn_config, ...])

ResNet18。

ResNet34(num_classes[, bn_config, ...])

ResNet34。

ResNet50(num_classes[, bn_config, ...])

ResNet50。

ResNet101(num_classes[, bn_config, ...])

ResNet101。

ResNet152(num_classes[, bn_config, ...])

ResNet152。

ResNet200(num_classes[, bn_config, ...])

ResNet200。

ResNet#

class haiku.nets.ResNet(blocks_per_group, num_classes, bn_config=None, resnet_v2=False, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), use_projection=(True, True, True, True), logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#

ResNet 模型。

class BlockGroup(channels, num_blocks, stride, bn_config, resnet_v2, bottleneck, use_projection, name=None)[source]#

用于 ResNet 实现的更高级别模块。

__call__(inputs, is_training, test_local_stats)[source]#

将自身作为函数调用。

__init__(channels, num_blocks, stride, bn_config, resnet_v2, bottleneck, use_projection, name=None)[source]#

使用给定名称初始化当前模块。

子类应在创建其他模块或变量之前调用此构造函数,以便正确命名这些模块。

参数

name (Optional[str]) – 类的可选字符串名称。必须是有效的 Python 标识符。如果未提供 name,则当前实例的类名将转换为 lower_snake_case 并改为使用。

class BlockV1(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]#

带有可选瓶颈的 ResNet V1 模块。

__call__(inputs, is_training, test_local_stats)[source]#

将自身作为函数调用。

__init__(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]#

使用给定名称初始化当前模块。

子类应在创建其他模块或变量之前调用此构造函数,以便正确命名这些模块。

参数

name (Optional[str]) – 类的可选字符串名称。必须是有效的 Python 标识符。如果未提供 name,则当前实例的类名将转换为 lower_snake_case 并改为使用。

class BlockV2(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]#

带有可选瓶颈的 ResNet V2 模块。

__call__(inputs, is_training, test_local_stats)[source]#

将自身作为函数调用。

__init__(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]#

使用给定名称初始化当前模块。

子类应在创建其他模块或变量之前调用此构造函数,以便正确命名这些模块。

参数

name (Optional[str]) – 类的可选字符串名称。必须是有效的 Python 标识符。如果未提供 name,则当前实例的类名将转换为 lower_snake_case 并改为使用。

__init__(blocks_per_group, num_classes, bn_config=None, resnet_v2=False, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), use_projection=(True, True, True, True), logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#

构建 ResNet 模型。

参数
  • blocks_per_group (序列[int]) – 长度为 4 的序列,指示每个组中创建的块的数量。

  • num_classes (int) – 将输入分类到的类别数量。

  • bn_config (可选[Mapping[str, FloatStrOrBool]]) – 包含两个元素的字典,decay_rateeps,将传递给 BatchNorm 层。 默认情况下,decay_rate0.9eps1e-5

  • resnet_v2 (bool) – 是否使用 v1 或 v2 ResNet 实现。默认为 False

  • bottleneck (bool) – 模块是否应使用瓶颈结构。默认为 True

  • channels_per_group (序列[int]) – 长度为 4 的序列,指示每个组中每个块使用的通道数。

  • use_projection (序列[bool]) – 长度为 4 的序列,指示每个残差块是否应使用投影。

  • logits_config (可选[Mapping[str, Any]]) – logits 层的关键字参数字典。

  • name (Optional[str]) – 模块的名称。

  • initial_conv_config (可选[Mapping[str, FloatStrOrBool]]) – 传递给初始 Conv2D 模块构造函数的关键字参数。

  • strides (序列[int]) – 长度为 4 的序列,指示每个组中每个块的卷积步幅大小。

__call__(inputs, is_training, test_local_stats=False)[source]#

将自身作为函数调用。

ResNet18#

class haiku.nets.ResNet18(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#

ResNet18。

__init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#

构建 ResNet 模型。

参数
  • num_classes (int) – 将输入分类到的类别数量。

  • bn_config (可选[Mapping[str, FloatStrOrBool]]) – 包含两个元素的字典,decay_rateeps,将传递给 BatchNorm 层。

  • resnet_v2 (bool) – 是否使用 v1 或 v2 ResNet 实现。默认为 False

  • logits_config (可选[Mapping[str, Any]]) – logits 层的关键字参数字典。

  • name (Optional[str]) – 模块的名称。

  • initial_conv_config (可选[Mapping[str, FloatStrOrBool]]) – 传递给初始 Conv2D 模块构造函数的关键字参数。

  • strides (序列[int]) – 长度为 4 的序列,指示每个组中每个块的卷积步幅大小。

ResNet34#

class haiku.nets.ResNet34(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#

ResNet34。

__init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#

构建 ResNet 模型。

参数
  • num_classes (int) – 将输入分类到的类别数量。

  • bn_config (可选[Mapping[str, FloatStrOrBool]]) – 包含两个元素的字典,decay_rateeps,将传递给 BatchNorm 层。

  • resnet_v2 (bool) – 是否使用 v1 或 v2 ResNet 实现。默认为 False

  • logits_config (可选[Mapping[str, Any]]) – logits 层的关键字参数字典。

  • name (Optional[str]) – 模块的名称。

  • initial_conv_config (可选[Mapping[str, FloatStrOrBool]]) – 传递给初始 Conv2D 模块构造函数的关键字参数。

  • strides (序列[int]) – 长度为 4 的序列,指示每个组中每个块的卷积步幅大小。

ResNet50#

class haiku.nets.ResNet50(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#

ResNet50。

__init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#

构建 ResNet 模型。

参数
  • num_classes (int) – 将输入分类到的类别数量。

  • bn_config (可选[Mapping[str, FloatStrOrBool]]) – 包含两个元素的字典,decay_rateeps,将传递给 BatchNorm 层。

  • resnet_v2 (bool) – 是否使用 v1 或 v2 ResNet 实现。默认为 False

  • logits_config (可选[Mapping[str, Any]]) – logits 层的关键字参数字典。

  • name (Optional[str]) – 模块的名称。

  • initial_conv_config (可选[Mapping[str, FloatStrOrBool]]) – 传递给初始 Conv2D 模块构造函数的关键字参数。

  • strides (序列[int]) – 长度为 4 的序列,指示每个组中每个块的卷积步幅大小。

ResNet101#

class haiku.nets.ResNet101(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#

ResNet101。

__init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#

构建 ResNet 模型。

参数
  • num_classes (int) – 将输入分类到的类别数量。

  • bn_config (可选[Mapping[str, FloatStrOrBool]]) – 包含两个元素的字典,decay_rateeps,将传递给 BatchNorm 层。

  • resnet_v2 (bool) – 是否使用 v1 或 v2 ResNet 实现。默认为 False

  • logits_config (可选[Mapping[str, Any]]) – logits 层的关键字参数字典。

  • name (Optional[str]) – 模块的名称。

  • initial_conv_config (可选[Mapping[str, FloatStrOrBool]]) – 传递给初始 Conv2D 模块构造函数的关键字参数。

  • strides (序列[int]) – 长度为 4 的序列,指示每个组中每个块的卷积步幅大小。

ResNet152#

class haiku.nets.ResNet152(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#

ResNet152。

__init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#

构建 ResNet 模型。

参数
  • num_classes (int) – 将输入分类到的类别数量。

  • bn_config (可选[Mapping[str, FloatStrOrBool]]) – 包含两个元素的字典,decay_rateeps,将传递给 BatchNorm 层。

  • resnet_v2 (bool) – 是否使用 v1 或 v2 ResNet 实现。默认为 False

  • logits_config (可选[Mapping[str, Any]]) – logits 层的关键字参数字典。

  • name (Optional[str]) – 模块的名称。

  • initial_conv_config (可选[Mapping[str, FloatStrOrBool]]) – 传递给初始 Conv2D 模块构造函数的关键字参数。

  • strides (序列[int]) – 长度为 4 的序列,指示每个组中每个块的卷积步幅大小。

ResNet200#

class haiku.nets.ResNet200(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#

ResNet200。

__init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#

构建 ResNet 模型。

参数
  • num_classes (int) – 将输入分类到的类别数量。

  • bn_config (可选[Mapping[str, FloatStrOrBool]]) – 包含两个元素的字典,decay_rateeps,将传递给 BatchNorm 层。

  • resnet_v2 (bool) – 是否使用 v1 或 v2 ResNet 实现。默认为 False

  • logits_config (可选[Mapping[str, Any]]) – logits 层的关键字参数字典。

  • name (Optional[str]) – 模块的名称。

  • initial_conv_config (可选[Mapping[str, FloatStrOrBool]]) – 传递给初始 Conv2D 模块构造函数的关键字参数。

  • strides (序列[int]) – 长度为 4 的序列,指示每个组中每个块的卷积步幅大小。

VectorQuantizer#

VectorQuantizer(embedding_dim, ...[, dtype, ...])

Haiku 模块,表示 VQ-VAE 层。

VectorQuantizerEMA(embedding_dim, ...[, ...])

Haiku 模块,表示 VQ-VAE 层。

VectorQuantizer#

class haiku.nets.VectorQuantizer(embedding_dim, num_embeddings, commitment_cost, dtype=<class 'jax.numpy.float32'>, name=None, cross_replica_axis=None)[source]#

Haiku 模块,表示 VQ-VAE 层。

实现了 van den Oord 等人在“神经离散表示学习”中提出的算法。https://arxiv.org/abs/1711.00937

输入任何需要量化的张量。最后一维将用作量化的空间。所有其他维度将被展平,并被视为不同的量化示例。

输出张量将具有与输入相同的形状。

例如,形状为 [16, 32, 32, 64] 的张量将被重塑为 [16384, 64],并且所有 16384 个向量(每个向量为 64 维)将被独立量化。

embedding_dim#

整数,表示量化空间中张量的维度。模块的输入也必须是这种格式。

num_embeddings#

整数,量化空间中向量的数量。

commitment_cost#

标量,用于控制损失项的权重(请参阅论文中的公式 4 - 此变量为 Beta)。

__init__(embedding_dim, num_embeddings, commitment_cost, dtype=<class 'jax.numpy.float32'>, name=None, cross_replica_axis=None)[source]#

初始化 VQ-VAE 模块。

参数
  • embedding_dim (int) – 量化空间中张量的维度。模块的输入也必须是这种格式。

  • num_embeddings (int) – 量化空间中向量的数量。

  • commitment_cost (float) – 标量,用于控制损失项的权重(请参阅论文中的公式 4 - 此变量为 Beta)。

  • dtype (Any) – 嵌入变量的数据类型,默认为 float32

  • name (Optional[str]) – 模块的名称。

  • cross_replica_axis (Optional[str]) – 如果不是 None,则它应该是一个字符串,表示此模块在 jax.pmap() 中运行的轴名称。提供此参数意味着跨该轴上的所有副本计算困惑度。

__call__(inputs, is_training)[source]#

将模块连接到一些输入。

参数
  • inputs – 张量,最后一维必须等于 embedding_dim。所有其他前导维度将被展平并视为一个大的批次。

  • is_training – 布尔值,指示此连接是否用于训练数据。

返回

包含以下键和值的字典
  • quantize: 包含输入量化版本的张量。

  • loss: 包含要优化的损失的张量。

  • perplexity: 包含编码困惑度的张量。

  • encodings: 包含离散编码的张量,即每个输入元素映射到量化空间中的哪个元素。

  • encoding_indices: 包含离散编码索引的张量,即每个输入元素映射到量化空间中的哪个元素。

返回类型

dict

quantize(encoding_indices)[source]#

返回一批索引的嵌入张量。

VectorQuantizerEMA#

class haiku.nets.VectorQuantizerEMA(embedding_dim, num_embeddings, commitment_cost, decay, epsilon=1e-05, dtype=<class 'jax.numpy.float32'>, cross_replica_axis=None, name=None)[source]#

Haiku 模块,表示 VQ-VAE 层。

实现了 van den Oord 等人在“神经离散表示学习”中提出的算法。https://arxiv.org/abs/1711.00937

VectorQuantizerEMAVectorQuantizer 之间的区别在于,此模块使用 ExponentialMovingAverage 来更新嵌入向量,而不是辅助损失。这样做的好处是,嵌入更新独立于用于编码器、解码器和架构其他部分的优化器选择(SGD、RMSProp、Adam、K-Fac 等)。对于大多数实验,EMA 版本比非 EMA 版本训练速度更快。

输入任何需要量化的张量。最后一维将用作量化的空间。所有其他维度将被展平,并被视为不同的量化示例。

输出张量将具有与输入相同的形状。

例如,形状为 [16, 32, 32, 64] 的张量将被重塑为 [16384, 64],并且所有 16384 个向量(每个向量为 64 维)将被独立量化。

embedding_dim#

整数,表示量化空间中张量的维度。模块的输入也必须是这种格式。

num_embeddings#

整数,量化空间中向量的数量。

commitment_cost#

标量,用于控制损失项的权重(请参阅论文中的公式 4)。

decay#

浮点数,移动平均的衰减率。

epsilon#

小的浮点常数,用于避免数值不稳定。

__init__(embedding_dim, num_embeddings, commitment_cost, decay, epsilon=1e-05, dtype=<class 'jax.numpy.float32'>, cross_replica_axis=None, name=None)[source]#

初始化 VQ-VAE EMA 模块。

参数
  • embedding_dim – 整数,表示量化空间中张量的维度。模块的输入也必须是这种格式。

  • num_embeddings – 整数,量化空间中向量的数量。

  • commitment_cost – 标量,用于控制损失项的权重(请参阅论文中的公式 4 - 此变量为 Beta)。

  • decay – 介于 0 和 1 之间的浮点数,控制指数移动平均的速度。

  • epsilon (float) – 小常数,用于辅助数值稳定性,默认为 1e-5

  • dtype (Any) – 嵌入变量的数据类型,默认为 float32

  • cross_replica_axis (Optional[str]) – 如果不是 None,则它应该是一个字符串,表示此模块在 jax.pmap() 中运行的轴名称。提供此参数意味着跨该轴上的所有副本计算集群统计信息和困惑度。

  • name (Optional[str]) – 模块的名称。

__call__(inputs, is_training)[source]#

将模块连接到一些输入。

参数
  • inputs – 张量,最后一维必须等于 embedding_dim。所有其他前导维度将被展平并视为一个大的批次。

  • is_training – 布尔值,指示此连接是否用于训练数据。当设置为 False 时,内部移动平均统计信息将不会更新。

返回

包含以下键和值的字典
  • quantize: 包含输入量化版本的张量。

  • loss: 包含要优化的损失的张量。

  • perplexity: 包含编码困惑度的张量。

  • encodings: 包含离散编码的张量,即每个输入元素映射到量化空间中的哪个元素。

  • encoding_indices: 包含离散编码索引的张量,即每个输入元素映射到量化空间中的哪个元素。

返回类型

dict

quantize(encoding_indices)[source]#

返回一批索引的嵌入张量。

JAX 基础知识#

控制流#

cond(pred, true_fun, false_fun, *operands[, ...])

有条件地应用 `true_fun` 或 `false_fun`。

fori_loop(lower, upper, body_fun, init_val)

等效于 `jax.lax.fori_loop()`,但 Haiku 状态会传入/传出。

map(f, xs)

等效于 `jax.lax.map()`,但 Haiku 状态会传入/传出。

scan(f, init, xs[, length, reverse, unroll])

等效于 `jax.lax.scan()`,但 Haiku 状态会传入/传出。

switch(index, branches, *operands)

等效于 `jax.lax.switch()`,但 Haiku 状态会传入/传出。

while_loop(cond_fun, body_fun, init_val)

等效于 jax.lax.while_loop,但 Haiku 状态会在其中传递。

cond#

haiku.cond(pred, true_fun, false_fun, *operands, operand=<object object>, linear=None)[source]#

有条件地应用 `true_fun` 或 `false_fun`。

包装了 XLA 的 Conditional 运算符。

如果提供的参数类型正确,则 `cond()` 具有与此 Python 实现等效的语义,其中 `pred` 必须是标量类型

def cond(pred, true_fun, false_fun, *operands):
  if pred:
    return true_fun(*operands)
  else:
    return false_fun(*operands)

与 `jax.lax.select()` 相比,使用 `cond` 表示仅执行两个分支之一(取决于编译器重写和优化)。但是,当使用 `vmap()` 转换以对一批谓词进行操作时,`cond` 将转换为 `select()`。

参数
  • pred – 布尔标量类型,指示要应用哪个分支函数。

  • true_fun (Callable) – 函数 (A -> B),如果 `pred` 为 True,则应用此函数。

  • false_fun (Callable) – 函数 (A -> B),如果 `pred` 为 False,则应用此函数。

  • operands – 操作数 (A),根据 `pred` 输入到任一分支。类型可以是标量、数组或任何 pytree(嵌套的 Python 元组/列表/字典)。

返回

值 (B) 为 `true_fun(*operands)` 或 `false_fun(*operands)`,具体取决于 `pred` 的值。类型可以是标量、数组或任何 pytree(嵌套的 Python 元组/列表/字典)。

fori_loop#

haiku.fori_loop(lower, upper, body_fun, init_val)[source]#

等效于 `jax.lax.fori_loop()`,但 Haiku 状态会传入/传出。

map#

haiku.map(f, xs)[source]#

等效于 `jax.lax.map()`,但 Haiku 状态会传入/传出。

scan#

haiku.scan(f, init, xs, length=None, reverse=False, unroll=1)[source]#

等效于 `jax.lax.scan()`,但 Haiku 状态会传入/传出。

switch#

haiku.switch(index, branches, *operands)[source]#

等效于 `jax.lax.switch()`,但 Haiku 状态会传入/传出。

请注意,不支持在 switch 分支内部创建参数,因此在初始化时,我们建议您无条件地评估 switch 的所有分支,并且仅在应用时使用 switch。例如

>>> experts = [hk.nets.MLP([300, 100, 10]) for _ in range(5)]
>>> x = jnp.ones([1, 28 * 28])
>>> if hk.running_init():
...   # During init unconditionally create params/state for all experts.
...   for expert in experts:
...     out = expert(x)
... else:
...   # During apply conditionally apply (and update) only one expert.
...   index = jax.random.randint(hk.next_rng_key(), [], 0, len(experts) - 1)
...   out = hk.switch(index, experts, x)
参数
  • index – 整数标量类型,指示要应用哪个分支函数。

  • branches – 函数序列 (A -> B),根据索引应用。

  • operands – 操作数 (A),输入到应用的任何分支。

返回

值 (B) 为 branch(*operands),用于基于索引选择的分支。

while_loop#

haiku.while_loop(cond_fun, body_fun, init_val)[source]#

等效于 jax.lax.while_loop,但 Haiku 状态会在其中传递。

JAX 转换#

eval_shape(fun, *args, **kwargs)

等效于 jax.eval_shape,但会丢弃任何更改的 Haiku 状态。

grad(fun[, argnums, has_aux, holomorphic])

创建一个函数,用于评估 `fun` 的梯度。

remat(fun, *[, prevent_cse, policy, ...])

等效于 jax.checkpoint,但会传递 Haiku 状态。

value_and_grad(fun[, argnums, has_aux, ...])

创建一个函数,用于评估 `fun` 和 `fun` 的梯度。

vmap(fun[, in_axes, out_axes, axis_name, ...])

等效于 `jax.vmap()`,但模块参数/状态未映射。

eval_shape#

haiku.eval_shape(fun, *args, **kwargs)[source]#

等效于 jax.eval_shape,但会丢弃任何更改的 Haiku 状态。

grad#

haiku.grad(fun, argnums=0, has_aux=False, holomorphic=False)[source]#

创建一个函数,用于评估 `fun` 的梯度。

注意:只有在非常特殊的情况下才需要此项,即您想在 `transform()` 转换的函数 内部 获取梯度,并且您要区分的函数使用 `set_state()`。例如

>>> class MyModule(hk.Module):
...   def __call__(self, x):
...     hk.set_state("last", x ** 2)
...     return x ** 2
>>> def f(x):
...   m = MyModule()
...   g = hk.grad(m)(x)
...   return g
>>> f = hk.transform_with_state(f)
>>> x = jnp.array(2.)
>>> params, state = jax.jit(f.init)(None, x)
>>> print(state["my_module"]["last"])
4.0
参数
  • fun – 要微分的函数。其在由 `argnums` 指定的位置的参数应为数组、标量或标准 Python 容器。它应返回一个标量(包括形状为 `()` 但不包括形状为 `(1,)` 等的数组)

  • argnums – 可选,整数或整数元组。指定要相对于哪个位置参数进行微分(默认为 0)。

  • has_aux – 可选,布尔值。指示 `fun` 是否返回一对,其中第一个元素被视为要微分的数学函数的输出,第二个元素是辅助数据。默认为 False。

  • holomorphic – 可选,布尔值。指示 `fun` 是否保证是全纯的。默认为 False。

返回

一个与 fun 具有相同参数的函数,它评估 fun 的梯度。如果 argnums 是一个整数,则梯度具有与该整数指示的位置参数相同的形状和类型。如果 argnums 是一个整数元组,则梯度是一个值元组,这些值具有与相应参数相同的形状和类型。如果 has_aux 为 True,则返回一对 gradient, auxiliary_data

例如

>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.96...

remat#

haiku.remat(fun, *, prevent_cse=True, policy=None, static_argnums=())[source]#

等效于 jax.checkpoint,但会传递 Haiku 状态。

返回类型

可调用对象 (Callable)

value_and_grad#

haiku.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False)[source]#

创建一个函数,用于评估 `fun` 和 `fun` 的梯度。

注意:只有在非常特殊的情况下才需要此项,即您想在 `transform()` 转换的函数 内部 获取梯度,并且您要区分的函数使用 `set_state()`。例如

>>> class MyModule(hk.Module):
...   def __call__(self, x):
...     hk.set_state("last", jnp.sum(x))
...     return x ** 2
>>> def f(x):
...   m = MyModule()
...   y, g = hk.value_and_grad(m)(x)
...   return y, g
>>> f = hk.transform_with_state(f)
>>> x = jnp.array(2.)
>>> _ = jax.jit(f.init)(None, x)
参数
  • fun – 要微分的函数。其在由 `argnums` 指定的位置的参数应为数组、标量或标准 Python 容器。它应返回一个标量(包括形状为 `()` 但不包括形状为 `(1,)` 等的数组)

  • argnums – 可选,整数或整数元组。指定要相对于哪个位置参数进行微分(默认为 0)。

  • has_aux – 可选,布尔值。指示 `fun` 是否返回一对,其中第一个元素被视为要微分的数学函数的输出,第二个元素是辅助数据。默认为 False。

  • holomorphic – 可选,布尔值。指示 `fun` 是否保证是全纯的。默认为 False。

返回

一个与 fun 具有相同参数的函数,它同时评估 funfun 的梯度,并将它们作为一对(一个双元素元组)返回。如果 argnums 是一个整数,则梯度具有与该整数指示的位置参数相同的形状和类型。如果 argnums 是一个整数元组,则梯度是一个值元组,这些值具有与相应参数相同的形状和类型。

vmap#

haiku.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, *, split_rng)[source]#

等效于 `jax.vmap()`,但模块参数/状态未映射。

Haiku 随机密钥 API 在 vmap() 下的行为由 split_rng 参数控制

>>> x = jnp.arange(2)
>>> f = hk.vmap(lambda _: hk.next_rng_key(), split_rng=False)
>>> key1, key2 = f(x)
>>> assert (key1 == key2).all()
>>> f = hk.vmap(lambda _: hk.next_rng_key(), split_rng=True)
>>> key1, key2 = f(x)
>>> assert not (key1 == key2).all()

Haiku 中的随机数通常用于两件事,首先是初始化模型参数,其次是在神经网络的前向传播中创建随机样本(例如,用于 dropout)。如果您将 vmap() 与一个同时使用 Haiku 随机密钥的模块一起使用(例如,您没有将密钥显式传递到网络中),那么您很可能希望根据我们是初始化(例如,创建模型参数)还是应用模型来改变 split_rng 的值。一种简单的方法是将 split_rng 设置为 (not hk.running_init())

对于更高级的用例,例如映射模块参数,我们建议用户改用 lift()transparent_lift() 结合 jax.vmap()

参数
  • fun (Callable[..., Any]) – 请参阅 jax.vmap()

  • in_axes – 请参阅 jax.vmap()

  • out_axes – 请参阅 jax.vmap()

  • axis_name (Optional[str]) – 请参阅 jax.vmap()

  • axis_size (Optional[int]) – 请参阅 jax.vmap()

  • split_rng (bool) – 控制 Haiku 中的随机密钥 API(例如 next_rng_key())是返回不同的密钥(即,在调用映射函数之前拆分内部密钥)还是相同的密钥(即,在调用映射函数之前广播内部密钥)。请参阅文档字符串中的示例。

返回类型

Callable[…, Any]

返回

请参阅 jax.vmap()

混合精度#

自动混合精度#

set_policy(cls, policy)

对模块类的所有实例使用给定的策略。

current_policy()

检索当前上下文中当前活动的策略。

get_policy(cls)

检索给定类当前活动的策略。

clear_policy(cls)

清除与给定类关联的任何策略。

push_policy(cls, policy)

在上下文处于活动状态时,为给定类设置给定的策略。

set_policy#

haiku.mixed_precision.set_policy(cls, policy)[source]#

对模块类的所有实例使用给定的策略。

注意:策略仅应用于在当前线程中创建的模块。

混合精度策略描述了在运行时应如何转换输入、模块参数和模块输出。通过将策略应用于给定类型的模块,您可以控制该模块的所有实例在程序中的行为方式。

例如,您可能希望尝试在 GPU 上以 float16float32 的混合精度运行 ResNet50 模型,以获得更高的吞吐量。为此,您可以将混合精度策略应用于 ResNet50 类型,该策略将以 float32 创建参数,但在使用之前将其转换为 float16,以及所有模块输入

>>> policy = jmp.get_policy('params=float32,compute=float16,output=float32')
>>> hk.mixed_precision.set_policy(hk.nets.ResNet50, policy)
>>> net = hk.nets.ResNet50(4)
>>> x = jnp.ones([4, 224, 224, 3])
>>> print(net(x, is_training=True))
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]

有关完整的工作混合精度示例,请参阅 Haiku 示例目录中的 imagenet 示例。此示例展示了 GPU 上的混合精度,在训练时间上提供了 2 倍的加速,而对最终 top-1 准确率的影响很小。

>>> hk.mixed_precision.clear_policy(hk.nets.ResNet50)
参数
  • cls (type[hk.Module]) – Haiku 模块类。

  • policy (jmp.Policy) – 要应用于模块的 JMP 策略。

另请参阅

current_policy#

haiku.mixed_precision.current_policy()[source]#

检索当前上下文中当前活动的策略。

返回类型

Optional[jmp.Policy]

返回

当前活动的混合精度策略,或 None

另请参阅

get_policy#

haiku.mixed_precision.get_policy(cls)[source]#

检索给定类当前活动的策略。

请注意,显式应用于顶层类(例如,ResNet)的策略将隐式应用于从父级调用的所有子模块(例如,ConvND)。此函数仅返回已显式应用的策略(例如,通过 set_policy())。

参数

cls (type[hk.Module]) – Haiku 模块类。

返回类型

Optional[jmp.Policy]

返回

用于给定类的 JMP 策略,如果没有活动的策略,则为 None

另请参阅

clear_policy#

haiku.mixed_precision.clear_policy(cls)[source]#

清除与给定类关联的任何策略。

参数

cls (type[hk.Module]) – Haiku 模块类。

另请参阅

push_policy#

haiku.mixed_precision.push_policy(cls, policy)[source]#

在上下文处于活动状态时,为给定类设置给定的策略。

参数
  • cls (type[hk.Module]) – Haiku 模块类。

  • policy (jmp.Policy) – 要应用于模块的 JMP 策略。

Yields

None.

另请参阅

🚧 实验性 (Experimental)#

Graphviz 可视化#

abstract_to_dot(fun)

将使用 Haiku 模块的函数转换为 dot 图。

abstract_to_dot#

haiku.experimental.abstract_to_dot(fun)[source]#

将使用 Haiku 模块的函数转换为 dot 图。

to_dot() 相同,但使用 JAX 的抽象解释机制来评估函数,而无需具体的输入。包装函数的有效输入包括 jax.ShapeDtypeStruct

abstract_to_dot() 不支持数据相关的控制流,因为没有为函数提供具体的值。

参数

fun (Callable[..., Any]) – 使用 Haiku 模块的函数。

返回类型

Callable[…, str]

返回

一个函数,它返回 graphviz 图的源代码字符串,该图描述给定函数执行的操作,并按 Haiku 模块进行聚类。

另请参阅

to_dot():使用具体输入生成 graphviz 图。

摘要#

tabulate(f, *[, columns, filters, ...])

生成 f 执行的摘要视图。

eval_summary(f)

记录 f 执行的模块方法调用。

ArraySpec(shape, dtype)

数组的形状和大小规范。

MethodInvocation(module_details, args_spec, ...)

在给定模块上调用方法的记录。

ModuleDetails(module, method_name, params, state)

模块和方法相关信息。

tabulate#

haiku.experimental.tabulate(f, *, columns=('module', 'config', 'owned_params', 'input', 'output', 'params_size', 'params_bytes'), filters=('has_output',), tabulate_kwargs={'tablefmt': 'grid'})[source]#

生成 f 执行的摘要视图。

>>> def f(x):
...   return hk.nets.MLP([300, 100, 10])(x)
>>> x = jnp.ones([8, 28 * 28])
>>> f = hk.transform(f)
>>> print(hk.experimental.tabulate(f)(x))
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| Module                  | Config                                   | Module params   | Input      | Output     |   Param count |   Param bytes |
+=========================+==========================================+=================+============+============+===============+===============+
| mlp (MLP)               | MLP(output_sizes=[300, 100, 10])         |                 | f32[8,784] | f32[8,10]  |       266,610 |       1.07 MB |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| mlp/~/linear_0 (Linear) | Linear(output_size=300, name='linear_0') | w: f32[784,300] | f32[8,784] | f32[8,300] |       235,500 |     942.00 KB |
|  └ mlp (MLP)            |                                          | b: f32[300]     |            |            |               |               |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| mlp/~/linear_1 (Linear) | Linear(output_size=100, name='linear_1') | w: f32[300,100] | f32[8,300] | f32[8,100] |        30,100 |     120.40 KB |
|  └ mlp (MLP)            |                                          | b: f32[100]     |            |            |               |               |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| mlp/~/linear_2 (Linear) | Linear(output_size=10, name='linear_2')  | w: f32[100,10]  | f32[8,100] | f32[8,10]  |         1,010 |       4.04 KB |
|  └ mlp (MLP)            |                                          | b: f32[10]      |            |            |               |               |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+

columns 的可能值

  • module:显示模块和方法名称。

  • config:显示用于模块的构造函数参数。

  • owned_params:显示此模块直接拥有的参数。

  • input:显示模块输入。

  • output:显示模块输出。

  • params_size:显示参数的数量

  • params_bytes:以字节为单位显示参数大小。

filters 的可能值

  • has_output:仅包括返回除 None 之外的值的方法。

  • has_params:从不具有参数的模块中删除方法。

参数
  • f (Union[Callable[..., Any], hk.Transformed, hk.TransformedWithState]) – 要转换的函数,或来自 Haiku 的 init/apply 函数之一,或 transform()transform_with_state() 的结果。

  • columns (Optional[Sequence[str]]) – 要启用的列名称列表。

  • filters (Optional[Sequence[str]]) – 要应用于删除某些模块方法的过滤器列表。

  • tabulate_kwargs – 要传递给 tabulate.tabulate(..) 的关键字参数。

返回类型

Callable[…, str]

返回

一个可调用对象,它接受与 f 相同的参数,但返回一个字符串,概述了在 f 执行期间使用的模块。

另请参阅

eval_summary():用于生成此表的原始数据。

eval_summary#

haiku.experimental.eval_summary(f)[source]#

记录 f 执行的模块方法调用。

>>> f = lambda x: hk.nets.MLP([300, 100, 10])(x)
>>> x = jnp.ones([8, 28 * 28])
>>> for i in hk.experimental.eval_summary(f)(x):
...   print("mod := {:14} | in := {} out := {}".format(
...       i.module_details.module.module_name, i.args_spec[0], i.output_spec))
mod := mlp            | in := f32[8,784] out := f32[8,10]
mod := mlp/~/linear_0 | in := f32[8,784] out := f32[8,300]
mod := mlp/~/linear_1 | in := f32[8,300] out := f32[8,100]
mod := mlp/~/linear_2 | in := f32[8,100] out := f32[8,10]
参数

f (Union[Callable[..., Any], hk.Transformed, hk.TransformedWithState]) – 要追踪的函数或转换后的函数。

返回类型

Callable[…, Sequence[MethodInvocation]]

返回

一个可调用对象,它接受与提供的函数相同的参数,但返回一个 MethodInvocation 实例序列,揭示在应用 f 时在每个模块上调用的方法。

另请参阅

tabulate():漂亮地打印函数执行的摘要。

ArraySpec#

class haiku.experimental.ArraySpec(shape, dtype)[source]#

数组的形状和大小规范。

shape#

数组的形状。

类型

Sequence[int]

dtype#

数组的 DType。

类型

jnp.dtype

__delattr__(name)#

实现 delattr(self, name)。

__eq__(other)#

Return self==value.

__hash__()#

Return hash(self).

__init__(shape, dtype)#
__setattr__(name, value)#

实现 setattr(self, name, value)。

MethodInvocation#

class haiku.experimental.MethodInvocation(module_details, args_spec, kwargs_spec, output_spec, context, call_stack)[source]#

在给定模块上调用方法的记录。

module_details#

有关调用了哪个模块和方法的详细信息。

类型

ModuleDetails

args_spec#

方法调用的位置参数,其中数组被替换为 ArraySpec

类型

tuple[Any, …]

kwargs_spec#

方法调用的关键字参数,其中数组被替换为 ArraySpec

类型

dict[str, Any]

output_spec#

方法调用的输出,其中数组被替换为 ArraySpec

类型

Any

context#

intercept_methods() 提供的用于方法调用的其他上下文信息。

类型

hk.MethodContext

call_stack#

在调用此模块方法时当前活动的模块堆栈。例如,如果 A 调用 B,而 B 调用 C,则 C 的调用堆栈将为 [B_DETAILS, A_DETAILS]

类型

Sequence[ModuleDetails]

__delattr__(name)#

实现 delattr(self, name)。

__eq__(other)#

Return self==value.

__hash__()#

Return hash(self).

__init__(module_details, args_spec, kwargs_spec, output_spec, context, call_stack)#
__setattr__(name, value)#

实现 setattr(self, name, value)。

ModuleDetails#

class haiku.experimental.ModuleDetails(module, method_name, params, state)[source]#

模块和方法相关信息。

module#

一个 Module 实例。

类型

hk.Module

method_name#

在模块上调用的方法名称。

类型

str

params#

模块的 params 字典,其中数组已转换为 ArraySpec

类型

Mapping[str, ArraySpec]

state#

模块的 state 字典,其中数组已转换为 ArraySpec

类型

Mapping[str, ArraySpec]

__delattr__(name)#

实现 delattr(self, name)。

__eq__(other)#

Return self==value.

__hash__()#

Return hash(self).

__init__(module, method_name, params, state)#
__setattr__(name, value)#

实现 setattr(self, name, value)。

管理状态#

check_jax_usage([enabled])

确保 JAX API(例如

check_jax_usage#

haiku.experimental.check_jax_usage(enabled=True)[source]#

确保 JAX API(例如 jax.vmap())与 Haiku 正确使用。

JAX 变换 (例如 jax.vmap()) 和控制流 (例如 jax.lax.cond()) 期望传入纯函数。Haiku 中的一些函数(例如 get_parameter())具有副作用,因此使用它们的函数只有在使用 transform() (等等) 后才是纯函数。

有时在转换函数之前使用 JAX 变换或控制流会很方便(例如,对模块的应用进行 vmap() 操作),但这样做时,您需要小心使用底层 JAX 函数的 Haiku 重载版本,该版本会谨慎地将您传入的函数变成纯函数,然后再调用底层 JAX 函数。

check_jax_usage() 允许检查原始 JAX 变换是否在 Haiku 转换函数内部被正确使用。不正确地使用 JAX 变换将导致错误。

考虑下面的函数,它不是一个纯函数(一个仅依赖于其输入且没有副作用的函数),因为我们调用了 Haiku API (get_parameter()),它在初始化期间将创建一个参数并将其注册到 Haiku。

>>> def f():
...   return hk.get_parameter("some_param", [], init=jnp.zeros)

我们不应该将此函数与 JAX API(如 jax.vmap())一起使用(因为它不是一个纯函数)。check_jax_usage() 允许您告知 Haiku 将 JAX API 的不正确使用视为错误。

>>> previous_value = hk.experimental.check_jax_usage(True)
>>> jax.vmap(f, axis_size=2)()
Traceback (most recent call last):
  ...
haiku.JaxUsageError: ...

使用 Haiku 包装的版本可以正确工作。

>>> print(hk.vmap(f, axis_size=2, split_rng=False)())
[0. 0.]
参数

enabled (bool) – 指示是否应检查使用情况的布尔值。

返回类型

bool

返回

此设置先前值的布尔值。

优化#

optimize_rng_use(fun)

优化 fun 中的 RNG 密钥拆分。

module_auto_repr(enabled)

禁用自动生成 Module.__repr__ 的实现。

fast_eval_shape(fun, *args, **kwargs)

等效于 JAX 中的 eval_shape

rng_reserve_size(size)

更改调用 next_rng_key 时保留的 RNG 密钥数量。

optimize_rng_use#

haiku.experimental.optimize_rng_use(fun)[source]#

优化 fun 中的 RNG 密钥拆分。

我们的策略是使用抽象解释来运行您的函数两次,第一次我们使用 jax.eval_shape() 以避免花费任何浮点运算,并简单地观察您调用 next_rng_key() 的次数。然后我们再次运行您的函数,但这次我们提前保留了足够的 RNG 密钥,这样我们只需要调用一次 jax.random.split()

在以下示例中,我们的 3 层 MLP 中的权重矩阵需要三个随机样本。为了绘制这些样本,我们使用 next_rng_key(),它将为每个样本拆分一个新的密钥。通过使用 optimize_rng_use(),Haiku 将预先分配足够的 RNG,以便通过一次且仅一次拆分输入密钥来评估 f。对于大型模型(与此示例不同),这可以减少 initapply 的编译时间,其中 init 预计会看到更大的加速,因为它通常会执行更多的 RNG 密钥拆分。

>>> def f(x):
...   net = hk.nets.MLP([300, 100, 10])
...   return net(x)
>>> f = hk.experimental.optimize_rng_use(f)
>>> f = hk.transform(f)
>>> params = f.init(jax.random.PRNGKey(42), jnp.ones([1, 1]))
参数

fun – 要包装的函数。

返回

一个应用 fun 的函数,但 Haiku 仅需要调用一次 jax.random.split()

module_auto_repr#

haiku.experimental.module_auto_repr(enabled)[source]#

禁用自动生成 Module.__repr__ 的实现。

默认情况下,Haiku 将自动生成模块的有用字符串表示形式以进行打印。例如:

>>> print(hk.Linear(1))
Linear(output_size=1)

在某些情况下,传递到模块构造函数中的对象可能打印速度很慢,例如非常嵌套的数据结构,或者您可能正在快速创建和丢弃模块(例如在测试中),并且不想支付转换为字符串的开销。

此配置选项允许用户在 Haiku 中全局禁用自动 repr 功能。

>>> previous_value = hk.experimental.module_auto_repr(False)
>>> print(hk.Linear(1))
<...Linear object at ...>
>>> previous_value = hk.experimental.module_auto_repr(True)
>>> print(hk.Linear(1))
Linear(output_size=1)

要在每个子类的基础上禁用该功能,请将 AUTO_REPR = False 作为属性分配给您的类,例如:

>>> class NoAutoRepr(hk.Module):
...   AUTO_REPR = False
>>> print(NoAutoRepr())
<...NoAutoRepr object at ...>
参数

enabled (bool) – 指示是否应启用模块的布尔值。

返回类型

bool

返回

此配置设置的先前值。

fast_eval_shape#

haiku.experimental.fast_eval_shape(fun, *args, **kwargs)[source]#

等效于 JAX 中的 eval_shape

此实用程序等效于 JAX 中的 eval_shape,不同之处在于它避免运行形状是平凡已知的 Haiku 函数。这可以避免 JAX 中一些 Python 开销,这些开销可能会在非常大的模型中累积。

优化

  • 所有参数/状态初始化器都替换为零。

  • hk.dropout 替换为 identity。

  • jax.random.fold_in 替换为 identity。

参数
  • fun – 要追踪的函数。

  • *args – 传递给 fun 的位置参数。

  • **kwargs – 传递给 fun 的关键字参数。

返回

fun 对于给定 args/kwargs 生成的形状。

rng_reserve_size#

haiku.experimental.rng_reserve_size(size)[source]#

更改调用 next_rng_key 时保留的 RNG 密钥数量。

参数

size (int) – 通过 next_rng_key() 拆分密钥时要保留的密钥数量,默认为 1。保留较大的密钥块可以提高模型的编译和运行时性能。更改保留大小将更改 next_rng_key 返回的 RNG 密钥,并将更改生成的随机数。

返回类型

int

返回

rng_reserve_size 设置的先前值。

jaxpr_info#

make_model_info(f[, name, ...])

创建一个计算浮点运算次数、参数和状态信息的函数。

as_html(module[, min_flop, outvars, last])

Module 格式化为交互式 HTML 元素树。

as_html_page(module[, min_flop])

make_model_info 的输出格式化为交互式 HTML 页面。

css()

Module 的 HTML 可视化的 CSS。

format_module(module[, depth])

以递归方式将模块信息格式化为人类可读的字符串。

js()

Module 的 HTML 可视化的 JavaScript。

Expression(primitive, invars, outvars[, ...])

有关单个 JAX 表达式的信息。

Module(name[, flops, expressions, ...])

有关 Haiku 模块的信息。

make_model_info#

haiku.experimental.jaxpr_info.make_model_info(f, name=None, include_module_info=True, compute_flops=None, axis_env=None)[source]#

创建一个计算浮点运算次数、参数和状态信息的函数。

参数
  • f (Callable[..., Any]) – 用于计算信息的函数。Haiku 模块和 jax.named_call 表达式将在结果中表示为嵌套的 Module。

  • name (Optional[str]) – 可选,根表达式的名称。

  • include_module_info (bool) – 是否包含 haiku 模块的参数和状态计数信息。对于非常大的计算,这可能会很慢。

  • compute_flops (Optional[ComputeFlopsFn]) – 可选,一个函数,返回执行方程所需的浮点运算次数的估计值。

  • axis_env (Optional[Sequence[tuple[Any, int]]]) – pmapped 轴的大小。有关详细信息,请参阅 jax.make_jaxpr 的文档。

返回类型

Callable[…, Module]

返回

f 的包装版本,当应用于示例参数时,返回这些参数的 fModule 表示形式。

ModuleExpression 包含有关 JAX 操作 (jaxpr) 的高级信息,并且可以以简洁和交互式格式可视化;请参阅 format_moduleas_html_pageas_html

as_html#

haiku.experimental.jaxpr_info.as_html(module, min_flop=1000, outvars='', last=False)[source]#

Module 格式化为交互式 HTML 元素树。

当将其嵌入到页面中时,还必须嵌入 cssjs 的输出,以便可视化工作。要仅直接可视化单个模块,请参阅 as_html_page

参数
  • module (Module) – 要可视化的模块,作为交互式 HTML 树。

  • min_flop (int) – 要显示的操作的最小浮点运算次数。

  • outvars (str) – 供内部使用,此模块的输出。

  • last (bool) – 供内部使用,此模块是否是其同级模块中的最后一个。

返回类型

str

返回

module 的 HTML 表示形式。

as_html_page#

haiku.experimental.jaxpr_info.as_html_page(module, min_flop=1000)[source]#

make_model_info 的输出格式化为交互式 HTML 页面。

返回类型

str

css#

haiku.experimental.jaxpr_info.css()[source]#

Module 的 HTML 可视化的 CSS。

返回类型

str

format_module#

haiku.experimental.jaxpr_info.format_module(module, depth=0)[source]#

以递归方式将模块信息格式化为人类可读的字符串。

返回类型

str

js#

haiku.experimental.jaxpr_info.js()[source]#

Module 的 HTML 可视化的 JavaScript。

返回类型

str

Expression#

class haiku.experimental.jaxpr_info.Expression(primitive, invars, outvars, flops=None, details='', params=<factory>, submodule=None, first_outvar='', name_stack=<factory>)[source]#

有关单个 JAX 表达式的信息。

Module#

class haiku.experimental.jaxpr_info.Module(name, flops=None, expressions=<factory>, total_param_size=0, param_info=<factory>, total_state_size=0, state_info=<factory>)[source]#

有关 Haiku 模块的信息。

配置#

context(*[, check_jax_usage, ...])

用于设置配置选项的上下文管理器。

set(*[, check_jax_usage, module_auto_repr, ...])

设置给定的配置选项。

context#

haiku.config.context(*, check_jax_usage=None, module_auto_repr=None, restore_flatmap=None, rng_reserve_size=None)[source]#

用于设置配置选项的上下文管理器。

此上下文管理器可用于在给定上下文中覆盖配置设置,未显式作为关键字参数传递的值将保留其当前值。

>>> with hk.config.context(check_jax_usage=True):
...   pass
参数
  • check_jax_usage (Optional[bool]) – 检查 JAX 变换和控制流是否在 Haiku 转换函数中被正确使用。

  • module_auto_repr (Optional[bool]) – 可用于禁用作为 Haiku 基本构造函数一部分的“转换为字符串”功能。

  • restore_flatmap (Optional[bool]) – 传统检查点是否应在旧的 FlatMap 数据类型(由 to_immtable_dict 返回)中恢复,默认是将这些恢复为纯字典。

  • rng_reserve_size (Optional[int]) – 通过 next_rng_key() 拆分密钥时要保留的密钥数量,默认为 1。保留较大的密钥块可以提高模型的编译和运行时性能。更改保留大小将更改 next_rng_key 返回的 RNG 密钥,并将更改生成的随机数。

返回

在活动期间应用给定配置的上下文管理器。

set#

haiku.config.set(*, check_jax_usage=None, module_auto_repr=None, restore_flatmap=None, rng_reserve_size=None)[source]#

设置给定的配置选项。

>>> hk.config.set(module_auto_repr=False)
>>> hk.Linear(1)
<...Linear object at ...>
>>> hk.config.set(module_auto_repr=True)
>>> hk.Linear(1)
Linear(output_size=1)
参数
  • check_jax_usage (Optional[bool]) – 检查 JAX 变换和控制流是否在 Haiku 转换函数中被正确使用。

  • module_auto_repr (Optional[bool]) – 可用于禁用作为 Haiku 基本构造函数一部分的“转换为字符串”功能。

  • restore_flatmap (Optional[bool]) – 传统检查点是否应在旧的 FlatMap 数据类型(由 to_immtable_dict 返回)中恢复,默认是将这些恢复为纯字典。

  • rng_reserve_size (Optional[int]) – 通过 next_rng_key() 拆分密钥时要保留的密钥数量,默认为 1。保留较大的密钥块可以提高模型的编译和运行时性能。更改保留大小将更改 next_rng_key 返回的 RNG 密钥,并将更改生成的随机数。

实用程序#

数据结构#

filter(predicate, structure)

根据用户指定的谓词过滤输入结构。

is_subset(*, subset, superset)

检查子集的叶节点是否出现在超集中。

map(fn, structure)

相应地将函数映射到输入结构。

merge(*structures[, check_duplicates])

合并多个输入结构。

partition(predicate, structure)

根据给定的谓词将输入结构划分为两个部分。

partition_n(fn, structure, n)

将一个结构划分为 n 个结构。

to_haiku_dict(structure)

返回给定二级结构的副本。

to_immutable_dict(mapping)

返回给定映射的不可变副本。

to_mutable_dict(mapping)

将不可变的 FlatMapping 转换为可变的 dict。

traverse(structure)

迭代一个结构,产生模块名称、名称和值。

tree_bytes(tree)

对 pytree 中所有数组的大小(以字节为单位)求和。

tree_size(tree)

对 pytree 中所有数组的大小求和。

filter#

haiku.data_structures.filter(predicate, structure)[source]#

根据用户指定的谓词过滤输入结构。

>>> params = {'linear': {'w': None, 'b': None}}
>>> predicate = lambda module_name, name, value: name == 'w'
>>> hk.data_structures.filter(predicate, params)
{'linear': {'w': None}}

注意:返回新结构,而不是视图。

参数
  • predicate (Callable[[str, str, T], bool]) – 用于划分输入数据的标准。predicate 参数应为一个布尔函数,它接受模块名称、模块数据包中给定条目的名称(例如,参数名称)和相应的数据作为输入。

  • structure (Mapping[str, Mapping[str, T]]) – 要过滤的 Haiku 参数或状态数据结构。

返回类型

Mapping[str, Mapping[str, T]]

返回

由输入谓词选择的所有输入参数或状态。

is_subset#

haiku.data_structures.is_subset(*, subset, superset)[source]#

检查子集的叶节点是否出现在超集中。

请注意,如果两个结构都没有叶节点,则这在空洞上为真。

>>> hk.data_structures.is_subset(subset={'a': {}}, superset={})
True
参数
  • subset (Mapping[str, Mapping[str, Any]]) – 要检查的子集。

  • superset (Mapping[str, Mapping[str, Any]]) – 要检查的超集。

返回类型

bool

返回

指示子集中的所有元素是否包含在超集中的布尔值。

map#

haiku.data_structures.map(fn, structure)[source]#

相应地将函数映射到输入结构。

>>> params = {'linear': {'w': 1.0, 'b': 2.0}}
>>> fn = lambda module_name, name, value: 2 * value if name == 'w' else value
>>> hk.data_structures.map(fn, params)
{'linear': {'b': 2.0, 'w': 2.0}}

注意:返回新结构,而不是视图。

参数
  • fn (Callable[[str, str, InT], OutT]) – 用于映射输入数据的标准。fn 参数应为一个函数,它接受模块名称、模块数据包中给定条目的名称(例如,参数名称)和相应的数据作为输入,并返回一个新值。

  • structure (Mapping[str, Mapping[str, InT]]) – 要映射的 Haiku 参数或状态数据结构。

返回类型

Mapping[str, Mapping[str, OutT]]

返回

由输入 fn 映射的所有输入参数或状态。

merge#

haiku.data_structures.merge(*structures, check_duplicates=False)[source]#

合并多个输入结构。

>>> weights = {'linear': {'w': None}}
>>> biases = {'linear': {'b': None}}
>>> hk.data_structures.merge(weights, biases)
{'linear': {'w': None, 'b': None}}

当结构不 disjoint 时,输出将包含每个路径上最后一个结构的值

>>> weights1 = {'linear': {'w': 1}}
>>> weights2 = {'linear': {'w': 2}}
>>> hk.data_structures.merge(weights1, weights2)
{'linear': {'w': 2}}

注意:返回新结构,而不是视图。

参数
  • *structures – 要合并的一个或多个结构。

  • check_duplicates (bool) – 如果为 True,则当在多个结构中找到数组但形状和 dtype 不同时,将抛出 ValueError。

返回类型

MutableMapping[str, MutableMapping[str, Any]]

返回

一个包含输入结构中每个路径的条目的单一结构。

partition#

haiku.data_structures.partition(predicate, structure)[source]#

根据给定的谓词将输入结构划分为两个部分。

对于给定的参数集,您可以使用 partition() 来拆分它们

>>> params = {'linear': {'w': None, 'b': None}}
>>> predicate = lambda module_name, name, value: name == 'w'
>>> weights, biases = hk.data_structures.partition(predicate, params)
>>> weights
{'linear': {'w': None}}
>>> biases
{'linear': {'b': None}}

注意:返回新的结构,而不是视图。

参数
  • predicate (Callable[[str, str, jax.Array], bool]) – 用于划分输入数据的标准。 predicate 参数应为一个布尔函数,它接受模块名称、模块数据包中给定条目的名称(例如,参数名称)和相应的数据作为输入。

  • structure (Mapping[str, Mapping[str, T]]) – 要划分的 Haiku 参数或状态数据结构。

返回类型

tuple[Mapping[str, Mapping[str, T]], Mapping[str, Mapping[str, T]]]

返回

一个元组,包含所有按输入划分的参数或状态

predicate。匹配 predicate 的条目将在第一个结构中,其余的将在第二个结构中。

partition_n#

haiku.data_structures.partition_n(fn, structure, n)[source]#

将一个结构划分为 n 个结构。

对于给定的参数集,您可以使用 partition_n() 将它们拆分为 n 组。例如,按模块名称拆分您的参数/梯度

>>> def partition_by_module(structure):
...   cnt = itertools.count()
...   d = collections.defaultdict(lambda: next(cnt))
...   fn = lambda m, n, v: d[m]
...   return hk.data_structures.partition_n(fn, structure, len(structure))
>>> structure = {f'layer_{i}': {'w': None, 'b': None} for i in range(3)}
>>> for substructure in partition_by_module(structure):
...   print(substructure)
{'layer_0': {'b': None, 'w': None}}
{'layer_1': {'b': None, 'w': None}}
{'layer_2': {'b': None, 'w': None}}
参数
  • fn (Callable[[str, str, T], int]) – Callable,返回给定元素应输出到 [0, n) 中的哪个 bucket。

  • structure (Mapping[str, Mapping[str, T]]) – 要划分的 Haiku 参数或状态数据结构。

  • n (int) – bucket 的总数。

返回类型

tuple[Mapping[str, Mapping[str, T]], …]

返回

大小为 n 的元组,其中每个元素将包含函数返回当前索引的值。

to_haiku_dict#

haiku.data_structures.to_haiku_dict(structure)[source]#

返回给定二级结构的副本。

使用与 Haiku 从 initapply 函数返回的相同的映射类型。

参数

structure (Mapping[K, V]) – 要复制的两级映射。

返回类型

MutableMapping[K, V]

返回

一个新的两级映射,其内容与输入相同。

to_immutable_dict#

haiku.data_structures.to_immutable_dict(mapping)[source]#

返回给定映射的不可变副本。

返回类型

Mapping[K, V]

to_mutable_dict#

haiku.data_structures.to_mutable_dict(mapping)[source]#

将不可变的 FlatMapping 转换为可变的 dict。

traverse#

haiku.data_structures.traverse(structure)[source]#

迭代一个结构,产生模块名称、名称和值。

注意:条目按键排序顺序迭代。

参数

structure (Mapping[str, Mapping[str, T]]) – 要遍历的结构。

Yields

来自给定结构的模块名称、名称和值的元组。

返回类型

Generator[tuple[str, str, T], None, None]

tree_bytes#

haiku.data_structures.tree_bytes(tree)[source]#

对 pytree 中所有数组的大小(以字节为单位)求和。

请注意,这是数组的最小大小(例如,对于 float32,我们至少需要 4 个字节),但在某些加速器上,由于填充/对齐约束,缓冲区可能会占用更多内存。

例如,给定一个 ResNet50 模型

>>> f = hk.transform_with_state(lambda x: hk.nets.ResNet50(1000)(x, True))
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([128, 224, 224, 3])
>>> params, state = f.init(rng, x)

我们可以计算 f32 处参数的数量及其大小

>>> num_params = hk.data_structures.tree_size(params)
>>> byte_size = hk.data_structures.tree_bytes(params)
>>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
25557032 params, size: 102.23MB

并将它与将参数转换为 bf16 进行比较

>>> params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
>>> num_params = hk.data_structures.tree_size(params)
>>> byte_size = hk.data_structures.tree_bytes(params)
>>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
25557032 params, size: 51.11MB
参数

tree – jax.Arrays 的树。

返回类型

int

返回

输入中数组的总字节大小。

tree_size#

haiku.data_structures.tree_size(tree)[source]#

对 pytree 中所有数组的大小求和。

例如,给定一个 ResNet50 模型

>>> f = hk.transform_with_state(lambda x: hk.nets.ResNet50(1000)(x, True))
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([128, 224, 224, 3])
>>> params, state = f.init(rng, x)

我们可以计算 f32 处参数的数量及其大小

>>> num_params = hk.data_structures.tree_size(params)
>>> byte_size = hk.data_structures.tree_bytes(params)
>>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
25557032 params, size: 102.23MB

并将它与将参数转换为 bf16 进行比较

>>> params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
>>> num_params = hk.data_structures.tree_size(params)
>>> byte_size = hk.data_structures.tree_bytes(params)
>>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
25557032 params, size: 51.11MB
参数

tree – jax.Arrays 的树。

返回类型

int

返回

输入中数组的总大小(元素数量)。

Testing#

transform_and_run([f, seed, run_apply, ...])

转换给定的函数并运行 init,然后(可选)运行 apply。

transform_and_run#

haiku.testing.transform_and_run(f=None, seed=42, run_apply=True, jax_transform=None, *, map_rng=None)[source]#

转换给定的函数并运行 init,然后(可选)运行 apply。

等效于

>>> def f(x):
...   return x
>>> x = jnp.ones([])
>>> rng = jax.random.PRNGKey(42)
>>> f = hk.transform_with_state(f)
>>> params, state = f.init(rng, x)
>>> out = f.apply(params, state, rng, x)

此函数使单元测试 Haiku 非常方便

>>> class MyTest(unittest.TestCase):
...   @hk.testing.transform_and_run
...   def test_linear_output(self):
...     mod = hk.Linear(1)
...     out = mod(jnp.ones([1, 1]))
...     self.assertEqual(out.ndim, 2)

它还可以与 chex 结合使用,以测试函数的所有 pure/jit/pmap 版本

>>> class MyTest(unittest.TestCase):
...   @chex.all_variants
...   def test_linear_output(self):
...     @hk.testing.transform_and_run(jax_transform=self.variant)
...     def f(inputs):
...       mod = hk.Linear(1)
...       return mod(inputs)
...     out = f(jnp.ones([1, 1]))
...     self.assertEqual(out.ndim, 2)

并且在交互式环境(如 ipython、Jupyter 或 Google Colaboratory)中也很有用

>>> f = lambda x: hk.Bias()(x)
>>> print(hk.testing.transform_and_run(f)(jnp.ones([1, 1])))
[[1.]]

有关更多详细信息,请参见 transform()

要将此与 pmap 一起使用(不使用 chex),您还需要传入一个函数来映射 init/apply rng 键。例如,如果您希望 pmap 的每个实例都具有相同的键

>>> def same_key_on_all_devices(key):
...   return jnp.broadcast_to(key, (jax.local_device_count(), *key.shape))
>>> @hk.testing.transform_and_run(jax_transform=jax.pmap,
...                               map_rng=same_key_on_all_devices)
... def test_something():
...   ...

或者您可以使用不同的键

>>> def different_key_on_all_devices(key):
...   return jax.random.split(key, jax.local_device_count())
>>> @hk.testing.transform_and_run(jax_transform=jax.pmap,
...                               map_rng=different_key_on_all_devices)
... def test_something_else():
...   ...
参数
  • f (Optional[Fn]) – 要转换的函数方法。

  • seed (Optional[int]) – 要传递给 init 和 apply 的种子。

  • run_apply (bool) – 是否运行 apply 以及 init。默认为 true。

  • jax_transform (Optional[Callable[[Fn], Fn]]) – 一个可选的 jax 转换,应用于 init 和 apply 函数。

  • map_rng (Optional[Callable[[Key], Key]]) – 如果设置为非 None 值,则广播 init/apply rngs broadcast_rng 次。

返回类型

T

返回

一个函数,它 transform() f 并运行 init 和可选的 apply

Conditional Computation#

running_init()

如果运行 Haiku 转换的 init 函数,则返回 True。

running_init#

haiku.running_init()[source]#

如果运行 Haiku 转换的 init 函数,则返回 True。

通常,您不应基于您是在运行 init 还是 apply 来门控模块的行为,但有时(例如,当使用 JAX 控制流时)这是必需的。

例如,如果您想使用 switch() 在专家之间进行选择,当我们运行您的 init 函数时,我们需要确保为所有专家创建参数/状态(无条件地),但在 apply 期间,我们希望有条件地应用(并且可能更新内部状态)我们的专家之一

>>> experts = [hk.nets.ResNet50(10) for _ in range(5)]
>>> x = jnp.ones([1, 224, 224, 3])
>>> if hk.running_init():
...   # During init unconditionally create params/state for all experts.
...   for expert in experts:
...     out = expert(x, is_training=True)
... else:
...   # During apply conditionally apply (and update) only one expert.
...   index = jax.random.randint(hk.next_rng_key(), [], 0, len(experts) - 1)
...   out = hk.switch(index, experts, x)
返回类型

bool

返回

如果正在运行 init 则为 True,否则为 False。

Functions#

multinomial(rng, logits, num_samples)

从多项分布中抽取样本。

one_hot(x, num_classes[, dtype])

返回索引的 one-hot 版本。

multinomial#

haiku.multinomial(rng, logits, num_samples)[source]#

从多项分布中抽取样本。

已弃用:请改用 jax.random.categorical

参数
  • rng – 一个 JAX PRNGKey。

  • logits – 未归一化的对数概率,其中最后一个维度是类别。

  • num_samples – 要抽取的样本数。

返回

选择的类别,形状为 logits.shape[:-1] + (num_samples,)

one_hot#

haiku.one_hot(x, num_classes, dtype=<class 'jax.numpy.float32'>)[source]#

返回索引的 one-hot 版本。

已弃用:请改用 jax.nn.one_hot(x, num_classes).astype(dtype)

参数
  • x – 索引张量。

  • num_classes – one-hot 维度中的类别数。

  • dtype – dtype。

返回

one-hot 张量。如果索引的形状为 [A, B, …],则形状为

[A, B, … num_classes]。

References#

1

Wojciech Zaremba, Ilya Sutskever, 和 Oriol Vinyals. Recurrent neural network regularization. arXiv preprint arXiv:1409.2329, 2014. URL: https://arxiv.org/abs/1409.2329.

2(1,2,3,4)

Rafal Jozefowicz, Wojciech Zaremba, 和 Ilya Sutskever. An empirical exploration of recurrent network architectures. In International Conference on Machine Learning, 2342–2350. 2015.

3(1,2,3)

SHI Xingjian, Zhourong Chen, Hao Wang, Dit-Yan Yeung, Wai-Kin Wong, 和 Wang-chun Woo. Convolutional lstm network: a machine learning approach for precipitation nowcasting. In Advances in neural information processing systems, 802–810. 2015.