The Most Detailed PyTorch ConvLSTM Code Walkthrough and Analysis
# -*- coding:utf-8 -*-
"""
Author: Refrain
Date: 2020.10.29
"""
import torch
import torch.nn as nn


class ConvLSTMCell(nn.Module):

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        # input_dim is the number of input channels fed to this layer at each time step
        # hidden_dim is the number of hidden channels of this layer, e.g. 64 for the first
        # layer, 128 for the second and 128 for the third
        # kernel_size is the size of the convolution kernel
        super(ConvLSTMCell, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        # this padding keeps the spatial size unchanged after the convolution
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,  # channels of the concatenated input
                              out_channels=4 * self.hidden_dim,  # the LSTM cell has four gates, so 4x the hidden channels
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        # input_tensor has shape (batch_size, channel, height, width); there is no time_step dimension
        # cur_state has shape (batch_size, hidden_dim, height, width); it is the state returned by init_hidden
        h_cur, c_cur = cur_state
        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
        # unlike nn.Linear, the convolution accepts spatial input directly, as long as the channel count matches
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        # split the 4 * hidden_dim output into the four gates
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)
        c_next = f * c_cur + i * g       # next cell state
        h_next = o * torch.tanh(c_next)  # next hidden state
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
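
# A minimal single-step sketch of how ConvLSTMCell is used (the sizes below are
# illustrative assumptions, not from the original post): one cell with 3 input
# channels and 16 hidden channels processes a single 32x32 frame.
#
#   cell = ConvLSTMCell(input_dim=3, hidden_dim=16, kernel_size=(3, 3), bias=True)
#   x = torch.rand(4, 3, 32, 32)                      # (batch, channel, height, width)
#   h0, c0 = cell.init_hidden(batch_size=4, image_size=(32, 32))
#   h1, c1 = cell(x, [h0, c0])                        # both have shape (4, 16, 32, 32)
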
class ConvLSTM(nn.Module):
    """
    Parameters:
        input_dim: Number of channels in input
        hidden_dim: Number of hidden channels
        kernel_size: Size of kernel in convolutions
        num_layers: Number of LSTM layers stacked on each other
        batch_first: Whether or not dimension 0 is the batch or not
        bias: Bias or no bias in Convolution
        return_all_layers: Return the list of computations for all layers
        Note: Will do same padding.

    Input:
        A tensor of size B, T, C, H, W or T, B, C, H, W

    Output:
        A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
            0 - layer_output_list is the list of lists of length T of each output
            1 - last_state_list is the list of last states
                each element of the list is a tuple (h, c) for hidden state and memory

    Example:
        >> x = torch.rand((32, 10, 64, 128, 128))
        >> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
        >> _, last_states = convlstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM, self).__init__()
        # validate kernel_size with a static method
        self._check_kernel_size_consistency(kernel_size)
        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers,
        # because __init__ iterates over them once per layer
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        # if return_all_layers is True, the outputs and states of every layer are returned;
        # if False, only those of the last layer are returned
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            # the first layer takes input_dim channels as input; every later layer takes
            # the hidden_dim of the previous layer as its input channel count
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
            cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))
        # with num_layers == 3, cell_list holds three ConvLSTMCell instances
        # nn.ModuleList registers the cells as submodules while keeping list-style indexing,
        # so forward can access cell_list[0] ... cell_list[num_layers - 1]
        self.cell_list = nn.ModuleList(cell_list)
    def forward(self, input_tensor, hidden_state=None):
        # hidden_state is None on the first call
        # input_tensor has shape (batch_size, time_step, channel, height, width)
        """
        Parameters
        ----------
        input_tensor: todo
            5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
        hidden_state: todo
            None. todo implement stateful

        Returns
        -------
        last_state_list, layer_output
        """
        # forward first normalises the dimension order of input_tensor, then initialises the hidden states
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        b, _, _, h, w = input_tensor.size()

        # Implement stateful ConvLSTM
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            # Since the init is done in forward. Can send image size here
            # this calls ConvLSTM._init_hidden, which delegates to init_hidden of each cell;
            # the returned hidden_state holds num_layers pairs of (h, c)
            hidden_state = self._init_hidden(batch_size=b,
                                             image_size=(h, w))

        layer_output_list = []
        last_state_list = []

        seq_len = input_tensor.size(1)  # number of time steps
        cur_layer_input = input_tensor

        # with the hidden states initialised, start the forward pass
        for layer_idx in range(self.num_layers):
            # take the initial (h0, c0) of this layer from the initialised hidden_state
            h, c = hidden_state[layer_idx]
            output_inner = []
            # unroll this layer over the time steps
            for t in range(seq_len):
                # cell_list[layer_idx] is the ConvLSTMCell of this layer; compute h and c at each time step
                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
                                                 cur_state=[h, c])
                # collect the hidden state of every time step
                output_inner.append(h)

            # layer_output is a 5-D tensor; stacking on dim=1 keeps the same layout as input_tensor
            layer_output = torch.stack(output_inner, dim=1)
            # the 5-D output of this layer becomes the input of the next layer, since every layer
            # expects a 5-D tensor without a num_layers dimension
            cur_layer_input = layer_output

            # layer_output_list stores the 5-D output tensor of every layer
            layer_output_list.append(layer_output)
            # last_state_list stores the (h, c) of the last time step of every layer
            last_state_list.append([h, c])

        if not self.return_all_layers:
            # if return_all_layers is False, keep only the 5-D output and the (h, c) of the last layer
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size):
        init_states = []
        for i in range(self.num_layers):
            # cell_list[i] is a ConvLSTMCell, so its init_hidden method can be called here
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        # init_states holds num_layers pairs of h and c, each of shape (batch_size, hidden_dim, height, width)
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        if not (isinstance(kernel_size, tuple) or
                (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        if not isinstance(param, list):
            param = [param] * num_layers
        return param
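
To make the end-to-end behaviour concrete, here is a minimal usage sketch of the stacked ConvLSTM above. The batch size, sequence length, spatial size and the hidden_dim list are illustrative assumptions, not values from the original post:

# Illustrative example: a 3-layer ConvLSTM on a batch of 8 sequences,
# each with 10 frames of 64-channel 32x32 feature maps.
x = torch.rand((8, 10, 64, 32, 32))  # (batch, time, channel, height, width)
convlstm = ConvLSTM(input_dim=64,
                    hidden_dim=[64, 128, 128],
                    kernel_size=(3, 3),
                    num_layers=3,
                    batch_first=True,
                    bias=True,
                    return_all_layers=False)
layer_output_list, last_state_list = convlstm(x)
print(layer_output_list[0].shape)  # torch.Size([8, 10, 128, 32, 32]), the output of the last layer
h, c = last_state_list[0]          # (h, c) of the last layer at the final time step
print(h.shape, c.shape)            # both torch.Size([8, 128, 32, 32])

Note that hidden_dim and kernel_size may also be given as a single value, in which case _extend_for_multilayer repeats them for every layer.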