Pytorch——torch.nn.Sequential()詳解

希望每天漲粉發表於2021-10-20

參考:官方文件    原始碼

官方文件

nn.Sequential

  A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in.

  翻譯:一個有序的容器,神經網路模組將按照在傳入構造器的順序依次被新增到計算圖中執行,同時以神經網路模組為元素的有序字典也可以作為傳入引數。

 

使用方法一:作為一個有順序的容器

  作為一個有順序的容器,將特定神經網路模組按照在傳入構造器的順序依次被新增到計算圖中執行。

官方 Example:

model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )
# When `model` is run,input will first be passed to `Conv2d(1,20,5)`.
# The output of `Conv2d(1,20,5)` will be used as the input to the first `ReLU`;
# the output of the first `ReLU` will become the input for `Conv2d(20,64,5)`.
# Finally, the output of `Conv2d(20,64,5)` will be used as input to the second `ReLU`

例子:

net = nn.Sequential(
    nn.Linear(num_inputs, num_hidden)
    # 傳入其他層
    )

 

使用方法二:作為一個有序字典

  將以特定神經網路模組為元素的有序字典(OrderedDict)為引數傳入。

官方 Example :

model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))

例子:

net = nn.Sequential()
net.add_module('linear1', nn.Linear(num_inputs, num_hiddens))
net.add_module('linear2', nn.Linear(num_hiddens, num_ouputs))
# net.add_module ......

 

原始碼分析

初始化函式 init 

    def __init__(self, *args):
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

  __init__ 首先使用 if 條件進行判斷,若傳入的引數為 1 個,且型別為 OrderedDict,則通過字典索引的方式利用 add_module 函式 將子模組新增到現有模組中。否則,通過 for 迴圈遍歷引數,將所有的子模組新增到現有模組中。 這裡需要注意,Sequential 模組的初始換函式沒有異常處理。

forward 函式

    def forward(self, input):
        for module in self:
            input = module(input)
        return input

  因為每一個 module 都繼承於 nn.Module,都會實現 __call__ 與 forward 函式,所以 forward 函式中通過 for 迴圈依次呼叫新增到 self._module 中的子模組,最後輸出經過所有神經網路層的結果。

 

相關文章