Torch中的RNN底層程式碼實現

Snail_Walker發表於2018-01-04

Torch中的RNN【1】這個package包括了RNN,RL,通過這個package可以很容易構建RNN,RL的模型。

安裝:

luarocks install torch
luarocks install nn
luarocks install torchx
luarocks install dataload

如果有CUDA:
luarocks install cutorch
luarocks install cunn

記得安裝:
luarocks install rnn

但是如果要使用nn.Reccurent,需要安裝:【4】

理論篇

這一次主要是講最簡單的RNN,也就是Simple RNN。實現的話是根據這兩篇論文:【6】,【7】

首先介紹一下Simple RNN的整個網路結構,再說一下ρ

\rho
step 的BPTT。

整個網路可以用下圖來表示:(這種網路的輸入一部分是當前的輸入,另外一部分來自於hidden layer的上一個輸出,這種叫做Elman Network。另外一種網路是一部分來自於當前輸入,另外一部分來自於整個網路的上一個輸出)
這裡寫圖片描述

  • 當前輸入wt
    w_t
    與上一個hidden layer的輸出st1
    s_{t-1}
    兩個vector相加,得到真正輸入到網路裡面的東西。
    這裡寫圖片描述
  • 接著是把輸入送進一個logistic regression裡面,得到hidden layer:st
    s_t
    . st
    s_t
    一方面往輸出那條路徑走,另外一方面往快取或者叫做Context裡面存起來,稱為下一個輸入需要的一部分,替換st1
    s_{t-1}

    這裡寫圖片描述
    這裡寫圖片描述
  • st
    s_t
    輸出該時刻的output: yt
    y_t
    一個Linear加上softmax,非常簡單。
    這裡寫圖片描述
    這裡寫圖片描述

這樣呢就把整個網路結構描述完了,接下來就是如何訓練得到引數了。(其實RNN,LSTM還有很多小的trick,同樣的演算法,trick不一樣,結果都會千差萬別)

在另外的論文裡面把這幅圖給完整畫了出來,顯得更加清晰:
這裡寫圖片描述

瞭解了整個網路的以後,需要定義loss,在進行BP的時候,首先定義loss function,一般採用的是SSE:dpk

d_{pk}
就是第p個sample,輸出的feature第k個的label。y
y
是prediction。
這裡寫圖片描述

對於w的更新,都是採用梯度下降:
這裡寫圖片描述

對於輸出output部分進行求導:
這裡寫圖片描述

再進一步輸出output的linear regression部分的w進行求導:
這裡寫圖片描述

接著是hidden layer進行求導:
這裡寫圖片描述

對hidden layer的輸入的input部分引數進行求導:
這裡寫圖片描述

對hidden layer的上一個hidden layer的作為input的部分進行求導:
這裡寫圖片描述

目前的loss為SSE的時候,一般採用logistic function作為輸出的函式:
這裡寫圖片描述
這裡寫圖片描述

當然,也可以有別的loss function,對應的output function g也會做鄉相應改變。

比如對於Gaussian Distribution:
這裡寫圖片描述
這裡寫圖片描述
這裡寫圖片描述

使用cross-entropy作為loss:(g為logistic function)
這裡寫圖片描述
這裡寫圖片描述

對於分類問題,採用multinomial dostribution:
採用的是softmax作為g,然後loss function為:
這裡寫圖片描述

這裡寫圖片描述

這裡寫圖片描述

在RNN中經常聽到BPTT,就是讓RNN在進行後向傳遞的時候不僅僅是當前這個時期,還有的是更多時刻:τ。

比如τ = 3,展開的圖如下:展開了三次,那麼進行BP的時候,就把各個引數往後相乘來更新w,這裡需要注意vanishing gradient effect和explode gradient effect的東西,一個梯度衰減比如為0,一個梯度爆炸。
這裡寫圖片描述

還有一種圖可以表示梯度變化,紅色的表示梯度的方向:
這裡寫圖片描述

如果用公式來表示這個整個過程就是前向:
這裡寫圖片描述

後向更新梯度:(每個時刻的梯度都會進行疊加到最後的w更新)
這裡寫圖片描述

程式碼篇

這次描述的是Simple RNN,函式為nn.Recurrent 。在nn中有兩個抽象類,一個是nn,用來構建網路,一個是Criterion【3】,用來提供比如cross entropy,reward。具體介紹可以看【2】,還有對應的論文。

在【3】中提出了一個簡單的例子:目前下面的nn.Recurrent已經不在Torch的庫中,所以要使用的話,就去安裝這個人寫的【4】

這裡面的實現的RNN是最簡單的,連hidden layer都沒有,直接transfer就是輸出了。

nn.Recurrent(start, input, feedback, [transfer, rho, merge])
-- start:對初始t=0的input進行處理
-- input:對t~=0的時候input進行處理
-- feedback:對s(t)進行處理快取到了s(t-1)
-- transfer:對輸出進行處理的函式
-- rho:進行BPTT的steps的數目
-- merge:對input x(t)和上一個時刻的輸出s(t-1)進行融合
-- generate some dummy inputs and gradOutputs sequences
-- 生成dummy input
inputs, gradOutputs = {}, {} 
for step=1,rho do
    inputs[step] = torch.randn(batchSize,inputSize)
    gradOutputs[step] = torch.randn(batchSize,inputSize) 
end

-- 呼叫RNN
-- an AbstractRecurrent instance
rnn = nn.Recurrent(
    hiddenSize, -- size of the input layer(隱層的size)
    nn.Linear(inputSize,outputSize), -- input layer(輸入層進行linear regression) 
    nn.Linear(outputSize, outputSize), -- recurrent layer 輸出層的linear regression
    nn.Sigmoid(), -- transfer function,把輸入通過linear regression之後的結果送到這個函式得到s(t),這個函式也可以改成ReLU別的啟用函式
    rho -- maximum number of time-steps for BPTT,進行BPTT時候的steps
)

-- feed-forward and backpropagate through time like this :
for step=1,rho 
do
    rnn:forward(inputs[step])
    rnn:backward(inputs[step], gradOutputs[step])
end

rnn:backwardThroughTime() -- call backward on the internal modules 
gradInputs = rnn.gradInputs
rnn:updateParameters(0.1)
rnn:forget() -- resets the time-step counter

對完整的nn.Reccurent的理解:【10】

assert(not nn.Recurrent, "update nnx package : luarocks install nnx")
local Recurrent, parent = torch.class('nn.Recurrent', 'nn.AbstractRecurrent')

-- 把各個module放到RNN的對應位置
-- start是對最開始t=0輸入inut做的處理
-- input是對t~=0的時刻進行input的處理
-- feedback是對s(t)進行處理快取到s(t-1)的函式
-- transfer是對最後的輸出的activation function
-- rho:是進行BPTT的時間
-- merge:對於輸入x(t)和上一個時刻的hidden layer的輸出s(t-1)的融合方法
function Recurrent:__init(start, input, feedback, transfer, rho, merge)
   parent.__init(self, rho)

   local ts = torch.type(start)
   if ts == 'torch.LongStorage' or ts == 'number' then
      start = nn.Add(start)
   elseif ts == 'table' then
      start = nn.Add(torch.LongStorage(start))
   elseif not torch.isTypeOf(start, 'nn.Module') then
      error"Recurrent : expecting arg 1 of type nn.Module, torch.LongStorage, number or table"
   end

   self.startModule = start
   self.inputModule = input
   self.feedbackModule = feedback
   self.transferModule = transfer or nn.Sigmoid()
   self.mergeModule = merge or nn.CAddTable()

   self.modules = {self.startModule, self.inputModule, self.feedbackModule, self.transferModule, self.mergeModule}

   self:buildInitialModule()
   self:buildRecurrentModule()
   self.sharedClones[2] = self.recurrentModule
end

-- 對最開始t=0的時候構建模型
-- build module used for the first step (steps == 1)
function Recurrent:buildInitialModule()
   self.initialModule = nn.Sequential()
   self.initialModule:add(self.inputModule:sharedClone())
   self.initialModule:add(self.startModule)
   self.initialModule:add(self.transferModule:sharedClone())
end

-- build module used for the other steps (steps > 1)
-- 構建整個模型
function Recurrent:buildRecurrentModule()
   local parallelModule = nn.ParallelTable()
   parallelModule:add(self.inputModule)
   parallelModule:add(self.feedbackModule)
   self.recurrentModule = nn.Sequential()
   self.recurrentModule:add(parallelModule)
   self.recurrentModule:add(self.mergeModule)
   self.recurrentModule:add(self.transferModule)
end

-- 更新輸出
function Recurrent:updateOutput(input)
   -- output(t) = transfer(feedback(output_(t-1)) + input(input_(t)))
   local output
   if self.step == 1 then
      output = self.initialModule:updateOutput(input)
   else
      if self.train ~= false then
         -- set/save the output states
         self:recycle()
         local recurrentModule = self:getStepModule(self.step)
          -- self.output is the previous output of this module
         output = recurrentModule:updateOutput{input, self.outputs[self.step-1]}
      else
         -- self.output is the previous output of this module
         output = self.recurrentModule:updateOutput{input, self.outputs[self.step-1]}
      end
   end

   self.outputs[self.step] = output
   self.output = output
   self.step = self.step + 1
   self.gradPrevOutput = nil
   self.updateGradInputStep = nil
   self.accGradParametersStep = nil
   return self.output
end

-- 求解梯度,沒有累加
function Recurrent:_updateGradInput(input, gradOutput)
   assert(self.step > 1, "expecting at least one updateOutput")
   local step = self.updateGradInputStep - 1

   local gradInput

   if self.gradPrevOutput then
      self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], self.gradPrevOutput)
      nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
      gradOutput = self._gradOutputs[step]
   end

   local output = self.outputs[step-1]
   if step > 1 then
      local recurrentModule = self:getStepModule(step)
      gradInput, self.gradPrevOutput = unpack(recurrentModule:updateGradInput({input, output}, gradOutput))
   elseif step == 1 then
      gradInput = self.initialModule:updateGradInput(input, gradOutput)
   else
      error"non-positive time-step"
   end

   return gradInput
end

-- 求解梯度,但是會把t steps的梯度相加
function Recurrent:_accGradParameters(input, gradOutput, scale)
   local step = self.accGradParametersStep - 1

   local gradOutput = (step == self.step-1) and gradOutput or self._gradOutputs[step]
   local output = self.outputs[step-1]

   if step > 1 then
      local recurrentModule = self:getStepModule(step)
      recurrentModule:accGradParameters({input, output}, gradOutput, scale)
   elseif step == 1 then
      self.initialModule:accGradParameters(input, gradOutput, scale)
   else
      error"non-positive time-step"
   end
end

function Recurrent:recycle()
   return parent.recycle(self, 1)
end

function Recurrent:forget()
   return parent.forget(self, 1)
end

function Recurrent:includingSharedClones(f)
   local modules = self.modules
   self.modules = {}
   local sharedClones = self.sharedClones
   self.sharedClones = nil
   local initModule = self.initialModule
   self.initialModule = nil
   for i,modules in ipairs{modules, sharedClones, {initModule}} do
      for j, module in pairs(modules) do
         table.insert(self.modules, module)
      end
   end
   local r = f()
   self.modules = modules
   self.sharedClones = sharedClones
   self.initialModule = initModule
   return r
end

function Recurrent:reinforce(reward)
   if torch.type(reward) == 'table' then
      -- multiple rewards, one per time-step
      local rewards = reward
      for step, reward in ipairs(rewards) do
         if step == 1 then
            self.initialModule:reinforce(reward)
         else
            local sm = self:getStepModule(step)
            sm:reinforce(reward)
         end
      end
   else
      -- one reward broadcast to all time-steps
      return self:includingSharedClones(function()
         return parent.reinforce(self, reward)
      end)
   end
end

function Recurrent:maskZero()
   error("Recurrent doesn't support maskZero as it uses a different "..
      "module for the first time-step. Use nn.Recurrence instead.")
end

function Recurrent:trimZero()
   error("Recurrent doesn't support trimZero as it uses a different "..
      "module for the first time-step. Use nn.Recurrence instead.")
end

-- 把模型列印出來
-- 比如我呼叫的是:
-- nn.Recurrent(256, nn.Identity(), nn.Linear(256, 256), nn['ReLU'](), 99999)
-- [[[
{input(t), output(t-1)} -> (1) -> (2) -> (3) -> output(t)]
    (1):  {
           input(t)
                |`-> (t==0): nn.Add
                |`-> (t~=0): nn.Identity
           output(t-1)
                |`-> nn.Linear(256 -> 256)
          }
    (2): nn.CAddTable
    (3): nn.ReLU
    }
---]]
function Recurrent:__tostring__()
   local tab = '  '
   local line = '\n'
   local next = ' -> '
   local str = torch.type(self)
   str = str .. ' {' .. line .. tab .. '[{input(t), output(t-1)}'
   for i=1,3 do
      str = str .. next .. '(' .. i .. ')'
   end
   str = str .. next .. 'output(t)]'

   local tab = '  '
   local line = '\n  '
   local next = '  |`-> '
   local ext = '  |    '
   local last = '   ... -> '
   str = str .. line ..  '(1): ' .. ' {' .. line .. tab .. 'input(t)'
   str = str .. line .. tab .. next .. '(t==0): ' .. tostring(self.startModule):gsub('\n', '\n' .. tab .. ext)
   str = str .. line .. tab .. next .. '(t~=0): ' .. tostring(self.inputModule):gsub('\n', '\n' .. tab .. ext)
   str = str .. line .. tab .. 'output(t-1)'
   str = str .. line .. tab .. next .. tostring(self.feedbackModule):gsub('\n', line .. tab .. ext)
   str = str .. line .. "}"
   local tab = '  '
   local line = '\n'
   local next = ' -> '
   str = str .. line .. tab .. '(' .. 2 .. '): ' .. tostring(self.mergeModule):gsub(line, line .. tab)
   str = str .. line .. tab .. '(' .. 3 .. '): ' .. tostring(self.transferModule):gsub(line, line .. tab)
   str = str .. line .. '}'
   return str
end

轉載請註明出處: http://blog.csdn.net/c602273091/article/details/78975636

參考連結:
【1】RNN地址: https://github.com/torch/rnn
【2】nn Package: https://arxiv.org/pdf/1511.07889.pdf
【3】RNN Code: https://github.com/torch/rnn/blob/master/doc/recurrent.md#rnn.Recurrence
【4】nn.Reccurent: https://github.com/Element-Research/rnn/blob/master/Recurrent.lua
【5】nn RNN: https://github.com/Element-Research/rnn
【6】Recurrent neural network based language model: http://www.fit.vutbr.cz/research/groups/speech/publi/2010/mikolov_interspeech2010_IS100722.pdf
【7】 A guide to recurrent neural networks and backpropagation: http://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=CDD081815C5FAC4835EF27B81EEA5F8C?doi=10.1.1.3.9311&rep=rep1&type=pdf
【8】STATISTICAL LANGUAGE MODELS BASED ON NEURAL NETWORKS: (3.2~3.3)http://www.fit.vutbr.cz/%7Eimikolov/rnnlm/thesis.pdf
【9】TRAINING RECURRENT NEURAL NETWORKS:(2.5~2.8) http://www.cs.utoronto.ca/%7Eilya/pubs/ilya_sutskever_phd_thesis.pdf
【10】nn.Reccurent: https://github.com/Element-Research/rnn/blob/master/Recurrent.lua

相關文章