Haiku 文档#

Haiku 是一个构建于 JAX 之上的库,旨在为机器学习研究提供简单、可组合的抽象。

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  mlp = hk.nets.MLP([300, 100, 10])
  return mlp(x)

forward = hk.transform(forward)

rng = hk.PRNGSequence(jax.random.PRNGKey(42))
x = jnp.ones([8, 28 * 28])
params = forward.init(next(rng), x)
logits = forward.apply(params, next(rng), x)

安装#

请参阅 https://github.com/google/jax#pip-installation 以获取有关安装 JAX 的说明。

我们建议运行以下命令安装最新版本的 Haiku

$ pip install git+https://github.com/deepmind/dm-haiku

或者,你可以通过 PyPI 安装

$ pip install -U dm-haiku

已知问题#

警告

在 Haiku 网络内部使用 JAX 转换(如 jax.jit()jax.remat())可能会导致难以解释的跟踪错误和潜在的静默错误结果。请阅读 嵌套 JAX 函数和 Haiku 模块的限制 以了解如何解决这些问题。

贡献#

支持#

如果您遇到问题,请在我们的 issue 跟踪器 上提交 issue 告知我们。

许可#

Haiku 在 Apache 2.0 许可下获得许可。

索引和表格#