Interactive online version: Open In Colab

Parameter sharing in Haiku#

Introduction#

In Haiku, parameter reuse is determined uniquely by module instance names: if a module instance has the same name as another module instance, they share parameters.

Unless otherwise specified, module instance names are determined automatically by Haiku from the module class name (following the pattern established in TensorFlow 1 and Sonnet V1). In more detail, module naming obeys the following rules:

  1. Module names are assigned when a module instance is constructed. Unless a name is provided as a constructor argument, Haiku generates one from the current module class name (essentially: to_snake_case(CurrentClassName)); the sketch just after this list shows the explicit-name case.

  2. If the module instance name does not end in _N (where N is a number) and another module instance with the same name already exists, Haiku appends an incrementing number to the new module instance name (e.g., module_1).

  3. When two modules are nested (i.e., a module instance is constructed inside the class definition of another module), the inner module name is prefixed with the outer module name and (possibly, see the next point) the outer module method currently being called. The constructor (i.e., __init__) is replaced by a tilde ~ symbol.

  4. If the calling method name is __call__, it is omitted (the name is prefixed only by the outer module name).

  5. With multiple levels of nesting, the previous rules are applied at each nesting level, and each inner module name is based on the name and calling method of the module that immediately precedes it in the call hierarchy.
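As a quick illustration of rule 1, an explicitly provided name takes precedence over the automatically generated one. A minimal sketch (the name my_linear is arbitrary, chosen for illustration):

import haiku as hk
import jax
import jax.numpy as jnp

def f(x):
  # Without an explicit name this instance would be called `linear`;
  # the name passed to the constructor takes precedence (rule 1).
  return hk.Linear(4, name="my_linear")(x)

params = hk.transform(f).init(jax.random.PRNGKey(42), jnp.ones((2, 3)))
# params contains a single entry named 'my_linear', with `w` of shape
# (3, 4) and `b` of shape (4,).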

Let's see how this works with a practical example.

Flat modules (no nesting)#

This section covers parameter sharing when modules are not nested.

[4]:
#@title Imports and accessory functions
import functools
import haiku as hk
import jax
import jax.numpy as jnp


def parameter_shapes(params):
  """Make printing parameters a little more readable."""
  return jax.tree_util.tree_map(lambda p: p.shape, params)


def transform_and_print_shapes(fn, x_shape=(2, 3)):
  """Print name and shape of the parameters."""
  rng = jax.random.PRNGKey(42)
  x = jnp.ones(x_shape)

  transformed_fn = hk.transform(fn)
  params = transformed_fn.init(rng, x)
  print('\nThe name and shape of the parameters are:')
  print(parameter_shapes(params))

def assert_all_equal(params_1, params_2):
  assert all(jax.tree_util.tree_leaves(
      jax.tree_util.tree_map(lambda a, b: (a == b).all(), params_1, params_2)))
[6]:
w_init = hk.initializers.TruncatedNormal(stddev=1)

class SimpleModule(hk.Module):
  """A simple module class with one variable."""

  def __init__(self, output_channels, name=None):
    super().__init__(name)
    assert isinstance(output_channels, int)
    self._output_channels = output_channels

  def __call__(self, x):
    w_shape = (x.shape[-1], self._output_channels)
    w = hk.get_parameter("w", w_shape, x.dtype, init=w_init)
    return jnp.dot(x, w)
[ ]:
def f(x):
  # This instance will be named `simple_module`.
  simple = SimpleModule(output_channels=2)
  simple_out = simple(x)  # implicitly calls simple.__call__()
  print(f'The name assigned to "simple" is: "{simple.module_name}".')
  return simple_out

transform_and_print_shapes(f)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
The name assigned to "simple" is: "simple_module".

The name and shape of the parameters are:
{'simple_module': {'w': (3, 2)}}

Great! Here we see that if we create a SimpleModule instance without specifying a name, Haiku assigns it the name simple_module. This is also reflected in the parameters associated with the module.

What happens if we instantiate SimpleModule twice? Does Haiku assign the same name to both instances?

[ ]:
def f(x):
  # This instance will be named `simple_module`.
  simple_one = SimpleModule(output_channels=2)
  # This instance will be named `simple_module_1`.
  simple_two = SimpleModule(output_channels=2)
  first_out = simple_one(x)
  second_out = simple_two(x)
  print(f'The name assigned to "simple_one" is: "{simple_one.module_name}".')
  print(f'The name assigned to "simple_two" is: "{simple_two.module_name}".')
  return first_out + second_out

transform_and_print_shapes(f)
The name assigned to "simple_one" is: "simple_module".
The name assigned to "simple_two" is: "simple_module_1".

The name and shape of the parameters are:
{'simple_module': {'w': (3, 2)}, 'simple_module_1': {'w': (3, 2)}}

As expected, Haiku is smart enough to distinguish the two instances and avoid accidental parameter sharing: the second instance is named simple_module_1, and each instance gets its own set of parameters. Nice!

But what if we do want to share parameters? In that case we must instantiate the module only once and call it multiple times. Let's see how this works:

[ ]:
def f(x):
  # This instance will be named `simple_module`.
  simple_one = SimpleModule(output_channels=2)
  first_out = simple_one(x)
  second_out = simple_one(x)  # share parameters w/ previous call
  print(f'The name assigned to "simple_one" is: "{simple_one.module_name}".')
  return first_out + second_out

transform_and_print_shapes(f)
The name assigned to "simple_one" is: "simple_module".

The name and shape of the parameters are:
{'simple_module': {'w': (3, 2)}}
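Rule 2 also implies an escape hatch: an explicit name that already ends in _N is not uniquified, so two instances constructed with the same such name resolve to the same module name and share parameters. A minimal sketch (the name shared_0 is arbitrary; SimpleModule and transform_and_print_shapes are defined above):

def f(x):
  # Both instances are explicitly named `shared_0`. Since the name ends
  # in `_N`, Haiku does not append a counter (rule 2), so the two
  # instances resolve to the same name and share `w`.
  first = SimpleModule(output_channels=2, name="shared_0")
  second = SimpleModule(output_channels=2, name="shared_0")
  return first(x) + second(x)

transform_and_print_shapes(f)
# Expected: a single shared entry, {'shared_0': {'w': (3, 2)}}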

Nested modules#

In this section we will see what happens when we nest one hk.Module inside another.

[ ]:
class NestedModule(hk.Module):
  """A module class with a nested module created in the constructor."""

  def __init__(self, output_channels, name=None):
    super().__init__(name)
    assert isinstance(output_channels, int)
    self._output_channels = output_channels
    self.inner_simple = SimpleModule(self._output_channels)

  def __call__(self, x):
    w_shape = (x.shape[-1], self._output_channels)
    # Another variable that is also called `w`.
    w = hk.get_parameter("w", w_shape, x.dtype, init=w_init)
    return jnp.dot(x, w) + self.inner_simple(x)
[ ]:
def f(x):
  # This will be named `nested_module` and the SimpleModule instance created
  # inside it will be named `nested_module/~/simple_module`.
  nested = NestedModule(output_channels=2)
  nested_out = nested(x)
  print('The name assigned to outer module (i.e., "nested") is: '
        f'"{nested.module_name}".')
  print('The name assigned to the inner module (i.e., inside "nested") is: "'
        f'{nested.inner_simple.module_name}".')
  return nested_out

transform_and_print_shapes(f)
The name assigned to outer module (i.e., "nested") is: "nested_module".
The name assigned to the inner module (i.e., inside "nested") is: "nested_module/~/simple_module".

The name and shape of the parameters are:
{'nested_module': {'w': (3, 2)}, 'nested_module/~/simple_module': {'w': (3, 2)}}

As expected, the inner module name depends on: (a) the outer module name; and (b) the method of the outer module being called.

Note also how the outer module's constructor name __init__ is replaced by ~ in the parameter name. If the inner module instance had been created inside the outer module's __call__ method instead, the inner module instance name would be 'nested_module/simple_module'.
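A minimal sketch of that variant, reusing SimpleModule from above: the inner module is created inside __call__, so no ~ appears, and per rule 4 the method name itself is also omitted:

class CallNestedModule(hk.Module):
  """Creates the inner module inside __call__ rather than __init__."""

  def __call__(self, x):
    # `inner.module_name` is 'call_nested_module/simple_module': no `~`
    # (the module is not created in the constructor) and no method name
    # (rule 4: `__call__` is omitted from the prefix).
    inner = SimpleModule(output_channels=2)
    return inner(x)

transform_and_print_shapes(lambda x: CallNestedModule()(x))
# Expected: {'call_nested_module/simple_module': {'w': (3, 2)}}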

In this example we defined all modules from scratch, but the same holds for any module or network defined in Haiku (e.g., hk.Linear, hk.nets.MLP, etc.). If you're curious, check what happens if you assign self.inner_simple to an instance of hk.Linear instead of an instance of SimpleModule.

Now let's try multiple levels of nesting:

[ ]:
class TwiceNestedModule(hk.Module):
  """A module class with a nested module containing a nested module."""

  def __init__(self, output_channels, name=None):
    super().__init__(name)
    assert isinstance(output_channels, int)
    self._output_channels = output_channels
    self.inner_nested = NestedModule(self._output_channels)

  def __call__(self, x):
    w_shape = (x.shape[-1], self._output_channels)
    w = hk.get_parameter("w", w_shape, x.dtype, init=w_init)
    return jnp.dot(x, w) + self.inner_nested(x)
[ ]:
def f(x):
  """Create the module instances and inspect their names."""
  # Instantiate a TwiceNestedModule instance. This will be named
  # `twice_nested_module`. The NestedModule instance created inside it will
  # be named `twice_nested_module/~/nested_module`.
  outer = TwiceNestedModule(output_channels=2)
  outer_out = outer(x)
  print(f'The name assigned to the outermost module is: "{outer.module_name}".')
  print('The name assigned to the module inside "outer" is: "'
        f'{outer.inner_nested.module_name}".')
  print('The name assigned to the module inside it is "'
        f'{outer.inner_nested.inner_simple.module_name}".')
  return outer_out

transform_and_print_shapes(f)
The name assigned to the outermost module is: "twice_nested_module".
The name assigned to the module inside "outer" is: "twice_nested_module/~/nested_module".
The name assigned to the module inside it is "twice_nested_module/~/nested_module/~/simple_module".

The name and shape of the parameters are:
{'twice_nested_module': {'w': (3, 2)}, 'twice_nested_module/~/nested_module': {'w': (3, 2)}, 'twice_nested_module/~/nested_module/~/simple_module': {'w': (3, 2)}}

Great, this also works as expected: the full hierarchy of module names and calls is reflected in the inner module names.

Multi-transform: merging parameters without sharing them#

Sometimes, when we have several transformed functions, it is convenient to merge all their parameters into a single structure, to reduce the number of dictionaries we have to store and pass around. But sometimes these functions instantiate the same modules, and we want to make sure their parameters are not accidentally shared.

hk.multi_transform can help us in this situation: it merges the parameters into a single dictionary and makes sure that duplicated parameters are renamed to avoid accidental sharing.

[ ]:
def f(x):
  """A SimpleModule followed by a Linear layer."""
  module_instance = SimpleModule(output_channels=2)
  out = module_instance(x)
  linear = hk.Linear(40)
  return linear(out)

def g(x):
  """A SimpleModule followed by an MLP."""
  module_instance = SimpleModule(output_channels=2)
  return module_instance(x) * 2  # twice

# Transform both functions, and print their respective parameter shapes.
rng = jax.random.PRNGKey(42)
x = jnp.ones((2, 3))
transformed_f = hk.transform(f)
params_f = transformed_f.init(rng, x)
transformed_g = hk.transform(g)
params_g = transformed_g.init(rng, x)
print('f parameters:', parameter_shapes(params_f))
print('g parameters:', parameter_shapes(params_g))

# Transform both functions at once with hk.multi_transform, and print the
# resulting merged parameter structure.

def multitransform_f_and_g():
  def template(x):
    return f(x), g(x)
  return template, (f, g)
init, (f_apply, g_apply) = hk.multi_transform(multitransform_f_and_g)
merged_params = init(rng, x)

print('\nThe name and shape of the multi-transform parameters are:\n',
      parameter_shapes(merged_params))
f parameters: {'linear': {'b': (40,), 'w': (2, 40)}, 'simple_module': {'w': (3, 2)}}
g parameters: {'simple_module': {'w': (3, 2)}}

The name and shape of the multi-transform parameters are:
 {'linear': {'b': (40,), 'w': (2, 40)}, 'simple_module': {'w': (3, 2)}, 'simple_module_1': {'w': (3, 2)}}

In this example, f and g both instantiate a SimpleModule with the same arguments, and if we transform them separately we see that both dictionaries contain a 'simple_module' key.

When we transform them together, hk.multi_transform takes care of renaming one of them to 'simple_module_1' for us, preventing accidental parameter sharing.
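As a brief usage note, the per-function apply functions returned by hk.multi_transform all accept the merged dictionary, each picking out the entries it needs:

# Both functions can be evaluated from the single merged structure.
f_out = f_apply(merged_params, rng, x)
g_out = g_apply(merged_params, rng, x)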

Sharing parameters across transformed functions#

Now that we understand how module names are assigned and how this affects parameter sharing, let's look at how to share parameters across transformed functions.

In this section we will consider two functions, f and g, and explore different strategies for sharing their parameters. We will look at a number of cases, which differ in how many of the modules instantiated by each function are the same, and in whether their parameters have the same shapes.

Case 1: All modules have the same names and the same shapes#

Let's reuse one of the modules we created earlier and try instantiating it inside two different functions:

[ ]:
def f(x):
  """Apply SimpleModule to x."""
  module_instance = SimpleModule(output_channels=2)
  out = module_instance(x)
  return out

def g(x):
  """Like f, but double the output"""
  module_instance = SimpleModule(output_channels=2)
  out = module_instance(x)
  return out * 2

# Transform both functions, and print the parameter shapes.
rng = jax.random.PRNGKey(42)
x = jnp.ones((2, 3))

transformed_f = hk.transform(f)
params_f = transformed_f.init(rng, x)
transformed_g = hk.transform(g)
params_g = transformed_g.init(rng, x)

print('f parameters:', parameter_shapes(params_f))
print('g parameters:', parameter_shapes(params_g))
f parameters: {'simple_module': {'w': (3, 2)}}
g parameters: {'simple_module': {'w': (3, 2)}}

Great! Since f and g use exactly the same modules, each resulting set of initialized parameters has the same name structure (note that the actual values may differ, depending on the initialization).

Now, if we want to share parameters in this situation, we can initialize just one of the two functions (e.g., f) and use the resulting parameters for both, i.e., whenever we call transformed_f.apply or transformed_g.apply.
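A minimal sketch of this: initialize f only, then pass the same parameters to both apply calls. Since g computes exactly twice f, we can even check that the parameters really are shared:

shared_params = transformed_f.init(rng, x)
f_out = transformed_f.apply(shared_params, rng, x)
g_out = transformed_g.apply(shared_params, rng, x)  # reuses f's parameters

# g doubles f's output, so with shared parameters this holds exactly.
assert (g_out == 2 * f_out).all()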

Case 2: Common modules have the same names and the same shapes#

That's a nice trick, but what if the two functions are not identical? Let's build two such functions:

[ ]:
def f(x):
  """A SimpleModule followed by a Linear layer."""
  module_instance = SimpleModule(output_channels=2)
  out = module_instance(x)
  linear = hk.Linear(40)
  return linear(out)

def g(x):
  """A SimpleModule followed by an MLP."""
  module_instance = SimpleModule(output_channels=2)
  out = module_instance(x)
  mlp = hk.nets.MLP((10, 40))
  return mlp(out)

# Transform both functions, and print the parameter shapes.
rng = jax.random.PRNGKey(42)
x = jnp.ones((2, 3))

transformed_f = hk.transform(f)
params_f = transformed_f.init(rng, x)
transformed_g = hk.transform(g)
params_g = transformed_g.init(rng, x)

print('\nThe name and shape of the f parameters are:\n',
      parameter_shapes(params_f))
print('\nThe name and shape of the g parameters are:\n',
      parameter_shapes(params_g))
The name and shape of the f parameters are:
 {'linear': {'b': (40,), 'w': (2, 40)}, 'simple_module': {'w': (3, 2)}}

The name and shape of the g parameters are:
 {'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (40,), 'w': (10, 40)}, 'simple_module': {'w': (3, 2)}}

Now we have a problem! Both sets of parameters have a 'simple_module' component, but each also contains parameters specific to its own function, so we can't simply initialize just one of the functions and use the returned parameters for both, as before. But we still want to share the parameters of 'simple_module'. How do we do that?

One option here is to use `haiku.data_structures.merge <https://haiku.jax.net.cn/en/latest/api.html#haiku.data_structures.merge>`__ to combine the two sets of parameters. This merges the two structures and, when both contain the same parameter (i.e., 'simple_module' in our example), keeps only the value from the last structure. Let's try it:

[ ]:
merged_params = hk.data_structures.merge(params_f, params_g)
print('\nThe name and shape of the shared parameters are:\n',
      parameter_shapes(merged_params))
The name and shape of the shared parameters are:
 {'linear': {'b': (40,), 'w': (2, 40)}, 'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (40,), 'w': (10, 40)}, 'simple_module': {'w': (3, 2)}}

Great! Now we have a shared set of parameters containing all the disjoint parameters, plus a single copy of the shared 'simple_module' parameters. Let's verify that we can use this set of parameters when calling either function:

[ ]:
f_out = transformed_f.apply(merged_params, rng, x)
g_out = transformed_g.apply(merged_params, rng, x)

print('f_out mean:', f_out.mean())
print('g_out mean:', g_out.mean())
f_out mean: 0.037986994
g_out mean: 0.104857825

But this gives us very little control over what gets shared: what if the two functions have parameters with the same name that we don't want to share?
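One way to regain some control (a sketch, not the approach taken below) is to select only the parameters we intend to share with hk.data_structures.filter before merging, e.g. propagating only f's 'simple_module' entry into g's parameter set:

# Keep only the parameters we actually want to share...
shared = hk.data_structures.filter(
    lambda module_name, name, value: module_name == 'simple_module',
    params_f)
# ...and overwrite just those entries in g's parameters.
params_g_shared = hk.data_structures.merge(params_g, shared)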

Case 3: Common modules have the same names, but different shapes#

Let's modify the previous example so that both functions use a final hk.Linear layer, but with different sizes:

[ ]:
def f(x):
  """A SimpleModule followed by two Linear layers."""
  module_instance = SimpleModule(output_channels=2)
  out = module_instance(x)
  mlp = hk.nets.MLP((10, 5))
  out = mlp(out)
  last_linear = hk.Linear(4)
  return last_linear(out)

def g(x):
  """Same as f, with a bigger final layer."""
  module_instance = SimpleModule(output_channels=2)
  out = module_instance(x)
  mlp = hk.nets.MLP((10, 5))
  out = mlp(out)
  last_linear = hk.Linear(20)  # another Linear, but bigger
  return last_linear(out)

# Transform both functions, and print the parameter shapes.
rng = jax.random.PRNGKey(42)
x = jnp.ones((2, 3))

transformed_f = hk.transform(f)
params_f = transformed_f.init(rng, x)
transformed_g = hk.transform(g)
params_g = transformed_g.init(rng, x)

print('\nThe name and shape of the f parameters are:\n',
      parameter_shapes(params_f))
print('\nThe name and shape of the g parameters are:\n',
      parameter_shapes(params_g))
The name and shape of the f parameters are:
 {'linear': {'b': (4,), 'w': (5, 4)}, 'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'simple_module': {'w': (3, 2)}}

The name and shape of the g parameters are:
 {'linear': {'b': (20,), 'w': (5, 20)}, 'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'simple_module': {'w': (3, 2)}}

Now we have a problem! Both sets of parameters have a 'linear' component, but those parameters have different shapes. If we merge them as before, the parameters of f's 'linear' are dropped, and we can no longer call f with the merged set:

merged_params = hk.data_structures.merge(params_f, params_g)
print('\nThe name and shape of the merged parameters are:\n',
      parameter_shapes(merged_params))

f_out = transformed_f.apply(merged_params, rng, x)  # fails
# ValueError: 'linear/w' with retrieved shape (5, 20) does not match shape=[5, 4] dtype=dtype('float32')

How can we share the parameters of 'simple_module' and the mlp while keeping the parameters of the two output linear layers separate?

One solution would be to instantiate simple_module and the mlp outside the functions, so that they are instantiated only once, and then use those instances in both functions. But all Haiku modules must be initialized inside a transform, so doing this naively results in an error:

module_instance = SimpleModule(output_channels=2)  # this fails
# ValueError: All `hk.Module`s must be initialized inside an `hk.transform`.
mlp = hk.nets.MLP((10, 5))

def f(x):
  """A SimpleModule followed by a Linear layer."""
  out = module_instance(x)
  out = mlp(out)
  linear = hk.Linear(4)
  return linear(out)

def g(x):
  """A SimpleModule followed by a bigger Linear layer."""
  out = module_instance(x)
  out = mlp(out)
  linear = hk.Linear(20)  # another Linear, but bigger
  return linear(out)

We can get around this with a class that caches the module instances:

[ ]:
class CachedModule():

  def __call__(self, *inputs):
    # Create the instances if they are not yet cached.
    if not hasattr(self, 'cached_simple_module'):
      self.cached_simple_module = SimpleModule(output_channels=2)
    if not hasattr(self, 'cached_mlp'):
      self.cached_mlp = hk.nets.MLP((10, 5))

    # Apply the cached instances.
    out = self.cached_simple_module(*inputs)
    out = self.cached_mlp(out)
    return out


def f(x):
  """A SimpleModule followed by a Linear layer."""
  shared_preprocessing = CachedModule()
  out = shared_preprocessing(x)
  linear = hk.Linear(4)
  return linear(out)

def g(x):
  """A SimpleModule followed by a bigger Linear layer."""
  shared_preprocessing = CachedModule()
  out = shared_preprocessing(x)
  linear = hk.Linear(20)  # another Linear, but bigger
  return linear(out)


# Transform both functions, and print the parameter shapes.
rng = jax.random.PRNGKey(42)
x = jnp.ones((2, 3))

transformed_f = hk.transform(f)
params_f = transformed_f.init(rng, x)
transformed_g = hk.transform(g)
params_g = transformed_g.init(rng, x)

print('\nThe name and shape of the f parameters are:\n',
      parameter_shapes(params_f))
print('\nThe name and shape of the g parameters are:\n',
      parameter_shapes(params_g))

# Verify that the MLP parameters are shared.
assert_all_equal(params_f['mlp/~/linear_0'],
                 params_g['mlp/~/linear_0'])
assert_all_equal(params_f['mlp/~/linear_1'],
                 params_g['mlp/~/linear_1'])
print('\nThe MLP parameters are shared!')
The name and shape of the f parameters are:
 {'linear': {'b': (4,), 'w': (5, 4)}, 'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'simple_module': {'w': (3, 2)}}

The name and shape of the g parameters are:
 {'linear': {'b': (20,), 'w': (5, 20)}, 'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'simple_module': {'w': (3, 2)}}

The MLP parameters are shared!

If we want to share many modules, manually caching each one inside CachedModule can become tedious. Also, it would be nice not to have to define a different CachedModule object for every function we want to cache.

We can use hk.to_module to build a more generic CachedModule object that takes an arbitrary Haiku function and caches it. hk.to_module turns the function into a module class whose instances are named after the function (here shared_preprocessing_fn), which is why that prefix appears in the parameter names below:

[ ]:
class CachedModule():
  """Cache one instance of the function and call it multiple times."""
  def __init__(self, fn):
    self._fn = fn

  def __call__(self, *args, **kwargs):
    if not hasattr(self, "_instance"):
      ModularisedFn = hk.to_module(self._fn)
      self._instance = ModularisedFn()
    return self._instance(*args, **kwargs)

def shared_preprocessing_fn(x):
  simple_module = SimpleModule(output_channels=2)
  out = simple_module(x)
  mlp = hk.nets.MLP((10, 5))
  return mlp(out)

def f(x):
  """A SimpleModule followed by a Linear layer."""
  shared_preprocessing = CachedModule(shared_preprocessing_fn)
  out = shared_preprocessing(x)
  linear = hk.Linear(4)
  return linear(out)

def g(x):
  """A SimpleModule followed by a bigger Linear layer."""
  shared_preprocessing = CachedModule(shared_preprocessing_fn)
  out = shared_preprocessing(x)
  linear = hk.Linear(20)  # another Linear, but bigger
  return linear(out)


# Transform both functions, and print the parameter shapes.
rng = jax.random.PRNGKey(42)
x = jnp.ones((2, 3))

transformed_f = hk.transform(f)
params_f = transformed_f.init(rng, x)
transformed_g = hk.transform(g)
params_g = transformed_g.init(rng, x)

print('\nThe name and shape of the f parameters are:\n',
      parameter_shapes(params_f))
print('\nThe name and shape of the g parameters are:\n',
      parameter_shapes(params_g))

# Verify that the MLP parameters are shared.
assert_all_equal(params_f['shared_preprocessing_fn/mlp/~/linear_0'],
                 params_g['shared_preprocessing_fn/mlp/~/linear_0'])
assert_all_equal(params_f['shared_preprocessing_fn/mlp/~/linear_1'],
                 params_g['shared_preprocessing_fn/mlp/~/linear_1'])
print('\nThe MLP parameters are shared!')
The name and shape of the f parameters are:
 {'linear': {'b': (4,), 'w': (5, 4)}, 'shared_preprocessing_fn/mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'shared_preprocessing_fn/mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'shared_preprocessing_fn/simple_module': {'w': (3, 2)}}

The name and shape of the g parameters are:
 {'linear': {'b': (20,), 'w': (5, 20)}, 'shared_preprocessing_fn/mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'shared_preprocessing_fn/mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'shared_preprocessing_fn/simple_module': {'w': (3, 2)}}

The MLP parameters are shared!

When we are working with objects, it can also be convenient to define a decorator that does the same thing:

[7]:
def share_parameters():
  def decorator(fn):
    def wrapper(*args, **kwargs):
      if wrapper.instance is None:
        wrapper.instance = hk.to_module(fn)()
      return wrapper.instance(*args, **kwargs)
    wrapper.instance = None
    return functools.wraps(fn)(wrapper)
  return decorator


class Wrapper():

  @share_parameters()
  def shared_preprocessing(self, x):
    simple_module = SimpleModule(output_channels=2)
    out = simple_module(x)
    mlp = hk.nets.MLP((10, 5))
    return mlp(out)

  def f(self, x):
    """A SimpleModule followed by a Linear layer."""
    out = self.shared_preprocessing(x)
    linear = hk.Linear(4)
    return linear(out)

  def g(self, x):
    """A SimpleModule followed by a bigger Linear layer."""
    out = self.shared_preprocessing(x)
    linear = hk.Linear(20)  # another Linear, but bigger
    return linear(out)

# Transform both functions, and print the parameter shapes.
rng = jax.random.PRNGKey(42)
x = jnp.ones((2, 3))

wrapper = Wrapper()
transformed_f = hk.transform(wrapper.f)
params_f = transformed_f.init(rng, x)
transformed_g = hk.transform(wrapper.g)
params_g = transformed_g.init(rng, x)

print('\nThe name and shape of the f parameters are:\n',
      parameter_shapes(params_f))
print('\nThe name and shape of the g parameters are:\n',
      parameter_shapes(params_g))

# Verify that the MLP parameters are shared.
assert_all_equal(params_f['shared_preprocessing/mlp/~/linear_0'],
                 params_g['shared_preprocessing/mlp/~/linear_0'])
assert_all_equal(params_f['shared_preprocessing/mlp/~/linear_1'],
                 params_g['shared_preprocessing/mlp/~/linear_1'])
print('\nThe MLP parameters are shared!')
The name and shape of the f parameters are:
 {'linear': {'b': (4,), 'w': (5, 4)}, 'shared_preprocessing/mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'shared_preprocessing/mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'shared_preprocessing/simple_module': {'w': (3, 2)}}

The name and shape of the g parameters are:
 {'linear': {'b': (20,), 'w': (5, 20)}, 'shared_preprocessing/mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'shared_preprocessing/mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'shared_preprocessing/simple_module': {'w': (3, 2)}}

The MLP parameters are shared!