​Jax 生態再添新庫:DeepMind 開源 Haiku、RLax

AIBigbull2050發表於2020-02-26
2020-02-25 10:51:05

機器之心報導

參與:一鳴

Jax 是一個優秀的程式碼庫,在進行科學計算的同時能夠自動微分,還有 GPU、TPU 的效能加速加持。但是 Jax 的生態還不夠完善,使用者相比 TF、PyTorch 少得多。近日,DeepMind 開源了兩個基於 Jax 的新庫,給這個生態注入了新的活力。

Jax 生態再添新庫:DeepMind 開源 Haiku、RLax


Jax 是谷歌開源的一個科學計算庫,能對 Python 程式與 NumPy 運算執行自動微分,而且能夠在 GPU 和 TPU 上執行,具有很高的效能。基於 Jax 已有很多優秀的開源專案,如 Trax 等。近日,DeepMind 開源了兩個基於 Jax 的新機器學習庫,分別是 Haiku 和 RLax,它們都有著各自的特色,對於豐富深度學習社群框架、提升研究者和開發者的使用體驗有著不小的意義。
Haiku::
Haiku:在 Jax 上進行物件導向開發
首先值得注意的是 Haiku,這是一個面向 Jax 的深度學習程式碼庫,它是由 Sonnet 作者——一個谷歌的神經網路庫團隊開發的。
為什麼要使用 Haiku?這是因為其支援的是 Jax,Jax 在靈活性和效能上具有相當的優勢。但是另一方面,Jax 本身是函式式的,和麵向物件的使用者習慣有差別。因此,透過 Haiku,使用者可以在 Jax 上進行物件導向開發了。
此外,Haiku 的 API 和程式設計模型都是基於 Sonnet,因此使用過 Sonnet 的使用者可以快速上手。專案作者也表示,Sonnet 之於 TensorFlow 的提升就如同 Haiku 之於 Jax。
目前,Haiku 已公開了 Alpha 版本,已完全開源。專案作者歡迎使用者提出建議。
Haiku 怎麼和 Jax 互動
Haiku 主要分為兩個模組:hk.Modules和 hk.transform。下文將會分別介紹。
hk.Modules 是 Python 物件,儲存著到引數、其他模組和方法的參照(references)。
hk.transform 則負責將物件導向的模組轉換為純粹的函式式程式碼,然後讓 jax 中的 jax.jit, jax.grad, jax.pmap 等進行處理,從而實現和 Jax 元件的相容。
Haiku 的功能
Haiku 能夠做到很多機器學習需要完成的任務,相關功能和程式碼如下:


自定義你的模組


在 Haiku 中,類似於 TF2.0 和 PyTorch,你可以自定義模組,作為 hk.Module 的子類。例如,自定義一個線性層:

class MyLinear(hk.Module):  def __init__(self, output_size, name=None):    super(MyLinear, self).__init__(name=name)    self.output_size = output_size  def __call__(self, x):    j, k = x.shape[-1], self.output_size    w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))    w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)    b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)    return jnp.dot(x, w) + b


可以看出,Haiku 的程式碼和 TensorFlow 等非常相似,但是你可以看到包括 numpy 等的方法還可以定義在模組中。Haiku 的優勢就在於,它不是一個封閉的框架,而是程式碼庫,因此可以在定義模組的過程中呼叫其他的庫和方法。
當定義好線性層後,我們想要試試自動微分的方法了:

def forward_fn(x):  model = MyLinear(10)  return model(x)# Turn `forward_fn` into an object with `init` and `apply` methods.forward = hk.transform(forward_fn)x = jnp.ones([1, 1])# When we run `forward.init`, Haiku will run `forward(x)` and collect initial# parameter values. Haiku requires you pass a RNG key to `init`, since parameters# are typically initialized randomly:key = hk.PRNGSequence(42)params = forward.init(next(key), x)# When we run `forward.apply`, Haiku will run `forward(x)` and inject parameter# values from the `params` that are passed as the first argument. We do not require# an RNG key by default since models are deterministic. You can (of course!) change# this using `hk.transform(f, apply_rng=True)` if you prefer:y = forward.apply(params, x)


這裡可以看到,定義好模組和前向傳播的函式後,使用 hk.transform(forward_fn) 可以將這種物件導向的方法轉換成 Jax 底層的函式式程式碼進行處理,因此你不需要擔心底層的計算問題。另外,這裡的程式碼相比 TensorFlow 還要簡潔。
非訓練狀態
有時候,我們想要在訓練的過程中保持某些內部引數的狀態,在 Haiku 上這也是非常容易實現的。

def forward(x, is_training):  net = hk.nets.ResNet50(1000)  return net(x, is_training)forward = hk.transform_with_state(forward)# The `init` function now returns parameters **and** state. State contains# anything that was created using `hk.set_state`. The structure is the same as# params (e.g. it is a per-module mapping of named values).params, state = forward.init(rng, x, is_training=True)# The apply function now takes both params **and** state. Additionally it will# return updated values for state. In the resnet example this will be the# updated values for moving averages used in the batch norm layers.logits, state = forward.apply(params, state, rng, x, is_training=True)


如上所示,只需要兩行程式碼進行設定。
和 jax.pmap 聯合進行分散式訓練
由於所有的程式碼都會被轉換成 Jax 的函式,因此它們和 jax.pmap. 是完全相容的。這說明,我們可以利用 jax.pmap 來進行分散式計算。
如下為進行資料分割的分散式加速程式碼,首先,我們先定義模型和訓練步驟:

def loss_fn(inputs, labels):  logits = hk.nets.MLP([8, 4, 2])(x)  return jnp.mean(softmax_cross_entropy(logits, labels))loss_obj = hk.transform(loss_fn)# Initialize the model on a single device.rng = jax.random.PRNGKey(428)sample_image, sample_label = next(input_dataset)params = loss_obj.init(rng, sample_image, sample_label)


然後設定將引數複製到所有的裝置上:

# Replicate params onto all devices.num_devices = jax.local_device_count()params = jax.tree_util.tree_map(lambda x: np.stack([x] * num_devices), params)


定義資料分批的方法,以及引數更新的方法:

def make_superbatch():  """Constructs a superbatch, i.e. one batch of data per device."""  # Get N batches, then split into list-of-images and list-of-labels.  superbatch = [next(input_dataset) for _ in range(num_devices)]  superbatch_images, superbatch_labels = zip(*superbatch)  # Stack the superbatches to be one array with a leading dimension, rather than  # a python list. This is what `jax.pmap` expects as input.  superbatch_images = np.stack(superbatch_images)  superbatch_labels = np.stack(superbatch_labels)  return superbatch_images, superbatch_labelsdef update(params, inputs, labels, axis_name='i'):  """Updates params based on performance on inputs and labels."""  grads = jax.grad(loss_obj.apply)(params, inputs, labels)  # Take the mean of the gradients across all data-parallel replicas.  grads = jax.lax.pmean(grads, axis_name)  # Update parameters using SGD or Adam or ...  new_params = my_update_rule(params, grads)  return new_params


最後開始分散式計算即可:

# Run several training updates.for _ in range(10):  superbatch_images, superbatch_labels = make_superbatch()  params = jax.pmap(update, axis_name='i')(params, superbatch_images,                                           superbatch_labels)


RLax:Jax 上也有強化學習庫了
除了令人印象深刻的 Haiku 外,DeepMind 還開源了 RLax——這是一個基於 Jax 的強化學習庫。
相比 Haiku,RLax 專門針對強化學習。專案作者認為,儘管強化學習中的運算元和函式並不是完全的演算法,但是,如果需要構建完全基於函式式的智慧體,就需要特定的數學運算元。
因此,函式式的 Jax 就成為了一個不錯的選擇。在 Jax 上進行一定的開發後,就可以有專用的強化學習庫了。RLax 目前的資料還較少,但專案已提供了一個示例程式碼:使用 RLax 進行 Q-learning 模型的搭建和訓練。
程式碼如下,首先,使用 Haiku 構建基本的強化學習模型:

def build_network(num_actions: int) -> hk.Transformed:  def q(obs):    flatten = lambda x: jnp.reshape(x, (-1,))    network = hk.Sequential(        [flatten, nets.MLP([FLAGS.hidden_units, num_actions])])    return network(obs)  return hk.transform(q)


設定訓練的方法:

def main_loop(unused_arg):  env = catch.Catch(seed=FLAGS.seed)  rng = hk.PRNGSequence(jax.random.PRNGKey(FLAGS.seed))  # Build and initialize Q-network.  num_actions = env.action_spec().num_values  network = build_network(num_actions)  sample_input = env.observation_spec().generate_value()  net_params = network.init(next(rng), sample_input)  # Build and initialize optimizer.  optimizer = optix.adam(FLAGS.learning_rate)  opt_state = optimizer.init(net_params)


以下和 Jax 結合,定義策略、獎勵等:

@jax.jitdef policy(net_params, key, obs):    """Sample action from epsilon-greedy policy."""    q = network.apply(net_params, obs)    a = rlax.epsilon_greedy(epsilon=FLAGS.epsilon).sample(key, q)    return q, a@jax.jitdef eval_policy(net_params, key, obs):    """Sample action from greedy policy."""    q = network.apply(net_params, obs)    return rlax.greedy().sample(key, q)@jax.jitdef update(net_params, opt_state, obs_tm1, a_tm1, r_t, discount_t, q_t):    """Update network weights wrt Q-learning loss."""def q_learning_loss(net_params, obs_tm1, a_tm1, r_t, discount_t, q_t):    q_tm1 = network.apply(net_params, obs_tm1)    td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)    return rlax.l2_loss(td_error)    dloss_dtheta = jax.grad(q_learning_loss)(net_params, obs_tm1, a_tm1, r_t,                                             discount_t, q_t)    updates, opt_state = optimizer.update(dloss_dtheta, opt_state)    net_params = optix.apply_updates(net_params, updates)    return net_params, opt_stateprint(f"Training agent for {FLAGS.train_episodes} episodes...")


可以看到,RLax 基於 jax.jit 的方法,在效能方面有不錯的提升。更有趣的是,構建模型的過程中使用了前文提到的 Haiku,可見基於 Jax 生態的程式碼庫之間都是可以相容的。
從 DeepMind 近日開源的兩個程式碼庫可以看到,雖然現在深度學習框架依然在穩步發展,但是針對高效能的科學計算也漸漸變得更為重要了。而 Jax 這樣的優秀開源專案,無疑也需要更多的生態支援。這次開源的 Haiku 和 RLax,無疑能夠鞏固 Jax 的地位,使其優秀的特性進一步得到發揮。







來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69946223/viewspace-2677328/,如需轉載,請註明出處,否則將追究法律責任。

相關文章