Haiku 文档
目录
Haiku 文档#
Haiku 是一个构建于 JAX 之上的库,旨在为机器学习研究提供简单、可组合的抽象。
import haiku as hk
import jax
import jax.numpy as jnp
def forward(x):
mlp = hk.nets.MLP([300, 100, 10])
return mlp(x)
forward = hk.transform(forward)
rng = hk.PRNGSequence(jax.random.PRNGKey(42))
x = jnp.ones([8, 28 * 28])
params = forward.init(next(rng), x)
logits = forward.apply(params, next(rng), x)
安装#
请参阅 https://github.com/google/jax#pip-installation 以获取有关安装 JAX 的说明。
我们建议运行以下命令安装最新版本的 Haiku
$ pip install git+https://github.com/deepmind/dm-haiku
或者,你可以通过 PyPI 安装
$ pip install -U dm-haiku
已知问题#
警告
在 Haiku 网络内部使用 JAX 转换(如 jax.jit()
和 jax.remat()
)可能会导致难以解释的跟踪错误和潜在的静默错误结果。请阅读 嵌套 JAX 函数和 Haiku 模块的限制 以了解如何解决这些问题。
贡献#
支持#
如果您遇到问题,请在我们的 issue 跟踪器 上提交 issue 告知我们。
许可#
Haiku 在 Apache 2.0 许可下获得许可。