diffusers 原始碼解析(十四)
.\diffusers\models\unets\unet_2d_blocks_flax.py
# 版權宣告,說明該檔案的版權資訊及相關許可協議
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 許可資訊,使用 Apache License 2.0 許可
# Licensed under the Apache License, Version 2.0 (the "License");
# 該檔案只能在符合許可證的情況下使用
# you may not use this file except in compliance with the License.
# 提供許可證獲取連結
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 免責宣告,表明不提供任何形式的保證或條件
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 許可證的相關許可權及限制
# See the License for the specific language governing permissions and
# limitations under the License.
# 匯入 flax.linen 模組,用於構建神經網路
import flax.linen as nn
# 匯入 jax.numpy,用於數值計算
import jax.numpy as jnp
# 從其他模組匯入特定的類,用於構建模型的各個元件
from ..attention_flax import FlaxTransformer2DModel
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
# 定義 FlaxCrossAttnDownBlock2D 類,表示一個 2D 跨注意力下采樣模組
class FlaxCrossAttnDownBlock2D(nn.Module):
r"""
跨注意力 2D 下采樣塊 - 原始架構來自 Unet transformers:
https://arxiv.org/abs/2103.06104
引數說明:
in_channels (:obj:`int`):
輸入通道數
out_channels (:obj:`int`):
輸出通道數
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout 率
num_layers (:obj:`int`, *optional*, defaults to 1):
注意力塊層數
num_attention_heads (:obj:`int`, *optional*, defaults to 1):
每個空間變換塊的注意力頭數
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
是否在每個最終輸出之前新增下采樣層
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
啟用記憶體高效的注意力 https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
是否將頭維度拆分為一個新的軸進行自注意力計算。在大多數情況下,
啟用此標誌應加快 Stable Diffusion 2.x 和 Stable Diffusion XL 的計算速度。
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
引數的資料型別
"""
# 定義輸入通道數
in_channels: int
# 定義輸出通道數
out_channels: int
# 定義 Dropout 率,預設為 0.0
dropout: float = 0.0
# 定義注意力塊的層數,預設為 1
num_layers: int = 1
# 定義注意力頭數,預設為 1
num_attention_heads: int = 1
# 定義是否新增下采樣層,預設為 True
add_downsample: bool = True
# 定義是否使用線性投影,預設為 False
use_linear_projection: bool = False
# 定義是否僅使用跨注意力,預設為 False
only_cross_attention: bool = False
# 定義是否啟用記憶體高效注意力,預設為 False
use_memory_efficient_attention: bool = False
# 定義是否拆分頭維度,預設為 False
split_head_dim: bool = False
# 定義引數的資料型別,預設為 jnp.float32
dtype: jnp.dtype = jnp.float32
# 定義每個塊的變換器層數,預設為 1
transformer_layers_per_block: int = 1
# 設定模型的各個組成部分,包括殘差塊和注意力塊
def setup(self):
# 初始化殘差塊列表
resnets = []
# 初始化注意力塊列表
attentions = []
# 遍歷每一層,構建殘差塊和注意力塊
for i in range(self.num_layers):
# 第一層的輸入通道為 in_channels,其他層為 out_channels
in_channels = self.in_channels if i == 0 else self.out_channels
# 建立一個 FlaxResnetBlock2D 例項
res_block = FlaxResnetBlock2D(
in_channels=in_channels, # 輸入通道
out_channels=self.out_channels, # 輸出通道
dropout_prob=self.dropout, # 丟棄率
dtype=self.dtype, # 資料型別
)
# 將殘差塊新增到列表中
resnets.append(res_block)
# 建立一個 FlaxTransformer2DModel 例項
attn_block = FlaxTransformer2DModel(
in_channels=self.out_channels, # 輸入通道
n_heads=self.num_attention_heads, # 注意力頭數
d_head=self.out_channels // self.num_attention_heads, # 每個頭的維度
depth=self.transformer_layers_per_block, # 每個塊的層數
use_linear_projection=self.use_linear_projection, # 是否使用線性投影
only_cross_attention=self.only_cross_attention, # 是否只使用交叉注意力
use_memory_efficient_attention=self.use_memory_efficient_attention, # 是否使用記憶體高效的注意力
split_head_dim=self.split_head_dim, # 是否拆分頭的維度
dtype=self.dtype, # 資料型別
)
# 將注意力塊新增到列表中
attentions.append(attn_block)
# 將殘差塊列表賦值給例項變數
self.resnets = resnets
# 將注意力塊列表賦值給例項變數
self.attentions = attentions
# 如果需要下采樣,則建立下采樣層
if self.add_downsample:
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
# 定義前向呼叫方法,處理隱藏狀態和編碼器隱藏狀態
def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
# 初始化輸出狀態元組
output_states = ()
# 遍歷殘差塊和注意力塊並進行處理
for resnet, attn in zip(self.resnets, self.attentions):
# 透過殘差塊處理隱藏狀態
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
# 透過注意力塊處理隱藏狀態
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
# 將當前隱藏狀態新增到輸出狀態元組中
output_states += (hidden_states,)
# 如果需要下采樣,則進行下采樣
if self.add_downsample:
hidden_states = self.downsamplers_0(hidden_states)
# 將下采樣後的隱藏狀態新增到輸出狀態元組中
output_states += (hidden_states,)
# 返回最終的隱藏狀態和輸出狀態元組
return hidden_states, output_states
# 定義 Flax 2D 降維塊類,繼承自 nn.Module
class FlaxDownBlock2D(nn.Module):
r"""
Flax 2D downsizing block
Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention blocks layers
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsampling layer before each final output
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
# 宣告輸入輸出通道和其他引數
in_channels: int
out_channels: int
dropout: float = 0.0
num_layers: int = 1
add_downsample: bool = True
dtype: jnp.dtype = jnp.float32
# 設定方法,用於初始化模型的層
def setup(self):
# 建立空列表以儲存殘差塊
resnets = []
# 根據層數建立殘差塊
for i in range(self.num_layers):
# 第一個塊的輸入通道為 in_channels,其餘為 out_channels
in_channels = self.in_channels if i == 0 else self.out_channels
# 建立殘差塊例項
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
dtype=self.dtype,
)
# 將殘差塊新增到列表中
resnets.append(res_block)
# 將列表賦值給例項屬性
self.resnets = resnets
# 如果需要,新增降取樣層
if self.add_downsample:
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
# 呼叫方法,執行前向傳播
def __call__(self, hidden_states, temb, deterministic=True):
# 建立空元組以儲存輸出狀態
output_states = ()
# 遍歷所有殘差塊進行前向傳播
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
# 將當前隱藏狀態新增到輸出狀態中
output_states += (hidden_states,)
# 如果需要,應用降取樣層
if self.add_downsample:
hidden_states = self.downsamplers_0(hidden_states)
# 將降取樣後的隱藏狀態新增到輸出狀態中
output_states += (hidden_states,)
# 返回最終的隱藏狀態和輸出狀態
return hidden_states, output_states
# 定義 Flax 交叉注意力 2D 上取樣塊類,繼承自 nn.Module
class FlaxCrossAttnUpBlock2D(nn.Module):
r"""
Cross Attention 2D Upsampling block - original architecture from Unet transformers:
https://arxiv.org/abs/2103.06104
# 定義引數的文件字串,描述各個引數的用途和型別
Parameters:
in_channels (:obj:`int`): # 輸入通道數
Input channels
out_channels (:obj:`int`): # 輸出通道數
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0): # Dropout 率,預設值為 0.0
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1): # 注意力塊的層數,預設值為 1
Number of attention blocks layers
num_attention_heads (:obj:`int`, *optional*, defaults to 1): # 每個空間變換塊的注意力頭數量,預設值為 1
Number of attention heads of each spatial transformer block
add_upsample (:obj:`bool`, *optional*, defaults to `True`): # 是否在每個最終輸出前新增上取樣層,預設值為 True
Whether to add upsampling layer before each final output
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): # 啟用記憶體高效注意力,預設值為 False
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`): # 是否將頭維度拆分為新軸以進行自注意力計算,預設值為 False
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): # 資料型別引數,預設值為 jnp.float32
Parameters `dtype`
"""
in_channels: int # 輸入通道數的宣告
out_channels: int # 輸出通道數的宣告
prev_output_channel: int # 前一個輸出通道數的宣告
dropout: float = 0.0 # Dropout 率的宣告,預設值為 0.0
num_layers: int = 1 # 注意力層數的宣告,預設值為 1
num_attention_heads: int = 1 # 注意力頭數量的宣告,預設值為 1
add_upsample: bool = True # 是否新增上取樣的宣告,預設值為 True
use_linear_projection: bool = False # 是否使用線性投影的宣告,預設值為 False
only_cross_attention: bool = False # 是否僅使用交叉注意力的宣告,預設值為 False
use_memory_efficient_attention: bool = False # 是否啟用記憶體高效注意力的宣告,預設值為 False
split_head_dim: bool = False # 是否拆分頭維度的宣告,預設值為 False
dtype: jnp.dtype = jnp.float32 # 資料型別的宣告,預設值為 jnp.float32
transformer_layers_per_block: int = 1 # 每個塊的變換層數的宣告,預設值為 1
# 設定方法,初始化網路結構
def setup(self):
# 初始化空列表以儲存 ResNet 塊
resnets = []
# 初始化空列表以儲存注意力塊
attentions = []
# 遍歷每一層以建立相應的 ResNet 和注意力塊
for i in range(self.num_layers):
# 設定跳躍連線的通道數,最後一層使用輸入通道,否則使用輸出通道
res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
# 設定當前 ResNet 塊的輸入通道,第一層使用前一層的輸出通道
resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
# 建立 FlaxResnetBlock2D 例項
res_block = FlaxResnetBlock2D(
# 設定輸入通道為當前 ResNet 塊輸入通道加跳躍連線通道
in_channels=resnet_in_channels + res_skip_channels,
# 設定輸出通道為指定的輸出通道
out_channels=self.out_channels,
# 設定 dropout 機率
dropout_prob=self.dropout,
# 設定資料型別
dtype=self.dtype,
)
# 將建立的 ResNet 塊新增到列表中
resnets.append(res_block)
# 建立 FlaxTransformer2DModel 例項
attn_block = FlaxTransformer2DModel(
# 設定輸入通道為輸出通道
in_channels=self.out_channels,
# 設定注意力頭數
n_heads=self.num_attention_heads,
# 設定每個注意力頭的維度
d_head=self.out_channels // self.num_attention_heads,
# 設定 transformer 塊的深度
depth=self.transformer_layers_per_block,
# 設定是否使用線性投影
use_linear_projection=self.use_linear_projection,
# 設定是否僅使用交叉注意力
only_cross_attention=self.only_cross_attention,
# 設定是否使用記憶體高效的注意力機制
use_memory_efficient_attention=self.use_memory_efficient_attention,
# 設定是否分割頭部維度
split_head_dim=self.split_head_dim,
# 設定資料型別
dtype=self.dtype,
)
# 將建立的注意力塊新增到列表中
attentions.append(attn_block)
# 將 ResNet 列表儲存到例項屬性
self.resnets = resnets
# 將注意力列表儲存到例項屬性
self.attentions = attentions
# 如果需要新增上取樣層,則建立相應的 FlaxUpsample2D 例項
if self.add_upsample:
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
# 定義呼叫方法,接受隱藏狀態和其他引數
def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
# 遍歷 ResNet 和注意力塊
for resnet, attn in zip(self.resnets, self.attentions):
# 從跳躍連線的隱藏狀態元組中取出最後一個狀態
res_hidden_states = res_hidden_states_tuple[-1]
# 更新跳躍連線的隱藏狀態元組,去掉最後一個狀態
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 將隱藏狀態與跳躍連線的隱藏狀態在最後一個軸上拼接
hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
# 使用當前的 ResNet 塊處理隱藏狀態
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
# 使用當前的注意力塊處理隱藏狀態
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
# 如果需要新增上取樣,則使用上取樣層處理隱藏狀態
if self.add_upsample:
hidden_states = self.upsamplers_0(hidden_states)
# 返回處理後的隱藏狀態
return hidden_states
# 定義一個 2D 上取樣塊類,繼承自 nn.Module
class FlaxUpBlock2D(nn.Module):
r"""
Flax 2D upsampling block
Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
prev_output_channel (:obj:`int`):
Output channels from the previous block
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention blocks layers
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsampling layer before each final output
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
# 定義輸入輸出通道和其他引數
in_channels: int
out_channels: int
prev_output_channel: int
dropout: float = 0.0
num_layers: int = 1
add_upsample: bool = True
dtype: jnp.dtype = jnp.float32
# 設定方法用於初始化塊的結構
def setup(self):
resnets = [] # 建立一個空列表用於儲存 ResNet 塊
# 遍歷每一層,建立 ResNet 塊
for i in range(self.num_layers):
# 計算跳躍連線通道數
res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
# 設定輸入通道數
resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
# 建立一個新的 FlaxResnetBlock2D 例項
res_block = FlaxResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block) # 將塊新增到列表中
self.resnets = resnets # 將列表賦值給例項變數
# 如果需要上取樣,初始化上取樣層
if self.add_upsample:
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
# 定義前向傳播方法
def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
# 遍歷每個 ResNet 塊進行前向傳播
for resnet in self.resnets:
# 從元組中彈出最後的殘差隱藏狀態
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1] # 更新元組,去掉最後一項
# 連線當前隱藏狀態與殘差隱藏狀態
hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
# 透過 ResNet 塊處理隱藏狀態
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
# 如果需要上取樣,呼叫上取樣層
if self.add_upsample:
hidden_states = self.upsamplers_0(hidden_states)
return hidden_states # 返回處理後的隱藏狀態
# 定義一個 2D 中級交叉注意力塊類,繼承自 nn.Module
class FlaxUNetMidBlock2DCrossAttn(nn.Module):
r"""
Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
# 定義引數的文件字串
Parameters:
in_channels (:obj:`int`): # 輸入通道數
Input channels
dropout (:obj:`float`, *optional*, defaults to 0.0): # Dropout比率,預設為0.0
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1): # 注意力層的數量,預設為1
Number of attention blocks layers
num_attention_heads (:obj:`int`, *optional*, defaults to 1): # 每個空間變換塊的注意力頭數量,預設為1
Number of attention heads of each spatial transformer block
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): # 是否啟用記憶體高效的注意力機制,預設為False
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`): # 是否將頭維度分割為新的軸以加速計算,預設為False
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): # 資料型別引數,預設為jnp.float32
Parameters `dtype`
"""
in_channels: int # 輸入通道數的型別
dropout: float = 0.0 # Dropout比率的預設值
num_layers: int = 1 # 注意力層數量的預設值
num_attention_heads: int = 1 # 注意力頭數量的預設值
use_linear_projection: bool = False # 是否使用線性投影的預設值
use_memory_efficient_attention: bool = False # 是否使用記憶體高效注意力的預設值
split_head_dim: bool = False # 是否分割頭維度的預設值
dtype: jnp.dtype = jnp.float32 # 資料型別的預設值
transformer_layers_per_block: int = 1 # 每個塊中的變換層數量的預設值
def setup(self): # 設定方法,用於初始化
# 至少會有一個ResNet塊
resnets = [ # 建立ResNet塊列表
FlaxResnetBlock2D( # 建立一個ResNet塊
in_channels=self.in_channels, # 輸入通道數
out_channels=self.in_channels, # 輸出通道數
dropout_prob=self.dropout, # Dropout機率
dtype=self.dtype, # 資料型別
)
]
attentions = [] # 初始化注意力塊列表
for _ in range(self.num_layers): # 遍歷指定的注意力層數
attn_block = FlaxTransformer2DModel( # 建立一個Transformer塊
in_channels=self.in_channels, # 輸入通道數
n_heads=self.num_attention_heads, # 注意力頭數量
d_head=self.in_channels // self.num_attention_heads, # 每個頭的維度
depth=self.transformer_layers_per_block, # 變換層深度
use_linear_projection=self.use_linear_projection, # 是否使用線性投影
use_memory_efficient_attention=self.use_memory_efficient_attention, # 是否使用記憶體高效注意力
split_head_dim=self.split_head_dim, # 是否分割頭維度
dtype=self.dtype, # 資料型別
)
attentions.append(attn_block) # 將注意力塊新增到列表中
res_block = FlaxResnetBlock2D( # 建立一個ResNet塊
in_channels=self.in_channels, # 輸入通道數
out_channels=self.in_channels, # 輸出通道數
dropout_prob=self.dropout, # Dropout機率
dtype=self.dtype, # 資料型別
)
resnets.append(res_block) # 將ResNet塊新增到列表中
self.resnets = resnets # 將ResNet塊列表賦值給例項屬性
self.attentions = attentions # 將注意力塊列表賦值給例項屬性
def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): # 呼叫方法
hidden_states = self.resnets[0](hidden_states, temb) # 透過第一個ResNet塊處理隱藏狀態
for attn, resnet in zip(self.attentions, self.resnets[1:]): # 遍歷每個注意力塊和後續ResNet塊
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) # 處理隱藏狀態
hidden_states = resnet(hidden_states, temb, deterministic=deterministic) # 再次處理隱藏狀態
return hidden_states # 返回處理後的隱藏狀態
.\diffusers\models\unets\unet_2d_condition.py
# 版權宣告,標明版權資訊和使用許可
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 按照 Apache License 2.0 版本進行許可
# Licensed under the Apache License, Version 2.0 (the "License");
# 你不得在未遵守許可的情況下使用此檔案
# you may not use this file except in compliance with the License.
# 你可以在以下網址獲得許可證的副本
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律要求或書面同意,軟體以“按原樣”方式分發,不提供任何形式的擔保或條件
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 檢視許可證以瞭解特定語言所適用的許可權和限制
# See the License for the specific language governing permissions and
# limitations under the License.
# 從 dataclasses 模組匯入 dataclass 裝飾器,用於簡化類的定義
from dataclasses import dataclass
# 匯入所需的型別註釋
from typing import Any, Dict, List, Optional, Tuple, Union
# 匯入 PyTorch 庫和相關模組
import torch
import torch.nn as nn
import torch.utils.checkpoint
# 從配置和載入器模組中匯入所需的類和函式
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..activations import get_activation
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS, # 匯入與注意力機制相關的處理器
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
FusedAttnProcessor2_0,
)
from ..embeddings import (
GaussianFourierProjection, # 匯入多種嵌入方法
GLIGENTextBoundingboxProjection,
ImageHintTimeEmbedding,
ImageProjection,
ImageTimeEmbedding,
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
TimestepEmbedding,
Timesteps,
)
from ..modeling_utils import ModelMixin # 匯入模型混合類
from .unet_2d_blocks import (
get_down_block, # 匯入下采樣塊的建構函式
get_mid_block, # 匯入中間塊的建構函式
get_up_block, # 匯入上取樣塊的建構函式
)
# 建立一個日誌記錄器,用於記錄模型相關資訊
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# 定義 UNet2DConditionOutput 資料類,用於儲存 UNet2DConditionModel 的輸出
@dataclass
class UNet2DConditionOutput(BaseOutput):
"""
UNet2DConditionModel 的輸出。
引數:
sample (`torch.Tensor`,形狀為 `(batch_size, num_channels, height, width)`):
基於 `encoder_hidden_states` 輸入的隱藏狀態輸出,模型最後一層的輸出。
"""
sample: torch.Tensor = None # 定義一個樣本屬性,預設為 None
# 定義 UNet2DConditionModel 類,表示一個條件 2D UNet 模型
class UNet2DConditionModel(
ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
):
r"""
一個條件 2D UNet 模型,接受一個噪聲樣本、條件狀態和時間步,並返回樣本形狀的輸出。
該模型繼承自 [`ModelMixin`]。檢視超類文件以獲取其為所有模型實現的通用方法
(例如下載或儲存)。
"""
_supports_gradient_checkpointing = True # 表示該模型支援梯度檢查點
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] # 不進行拆分的模組列表
@register_to_config # 將該方法註冊到配置中
# 初始化方法,設定類的基本屬性
def __init__(
# 樣本大小,預設為 None
self,
sample_size: Optional[int] = None,
# 輸入通道數,預設為 4
in_channels: int = 4,
# 輸出通道數,預設為 4
out_channels: int = 4,
# 是否將輸入樣本中心化,預設為 False
center_input_sample: bool = False,
# 是否將正弦函式翻轉為餘弦函式,預設為 True
flip_sin_to_cos: bool = True,
# 頻率偏移量,預設為 0
freq_shift: int = 0,
# 向下取樣的塊型別,包含多種塊型別
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
# 中間塊的型別,預設為 UNet 的中間塊型別
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
# 向上取樣的塊型別,包含多種塊型別
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
# 是否僅使用交叉注意力,預設為 False
only_cross_attention: Union[bool, Tuple[bool]] = False,
# 每個塊的輸出通道數
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
# 每個塊的層數,預設為 2
layers_per_block: Union[int, Tuple[int]] = 2,
# 下采樣時的填充大小,預設為 1
downsample_padding: int = 1,
# 中間塊的縮放因子,預設為 1
mid_block_scale_factor: float = 1,
# dropout 機率,預設為 0.0
dropout: float = 0.0,
# 啟用函式型別,預設為 "silu"
act_fn: str = "silu",
# 歸一化的組數,預設為 32
norm_num_groups: Optional[int] = 32,
# 歸一化的 epsilon 值,預設為 1e-5
norm_eps: float = 1e-5,
# 交叉注意力的維度,預設為 1280
cross_attention_dim: Union[int, Tuple[int]] = 1280,
# 每個塊的變換層數,預設為 1
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
# 反向變換層的塊數,預設為 None
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
# 編碼器隱藏層的維度,預設為 None
encoder_hid_dim: Optional[int] = None,
# 編碼器隱藏層型別,預設為 None
encoder_hid_dim_type: Optional[str] = None,
# 注意力頭的維度,預設為 8
attention_head_dim: Union[int, Tuple[int]] = 8,
# 注意力頭的數量,預設為 None
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
# 是否使用雙交叉注意力,預設為 False
dual_cross_attention: bool = False,
# 是否使用線性投影,預設為 False
use_linear_projection: bool = False,
# 類嵌入型別,預設為 None
class_embed_type: Optional[str] = None,
# 附加嵌入型別,預設為 None
addition_embed_type: Optional[str] = None,
# 附加時間嵌入維度,預設為 None
addition_time_embed_dim: Optional[int] = None,
# 類嵌入數量,預設為 None
num_class_embeds: Optional[int] = None,
# 是否上溯注意力,預設為 False
upcast_attention: bool = False,
# ResNet 時間縮放偏移型別,預設為 "default"
resnet_time_scale_shift: str = "default",
# ResNet 是否跳過時間啟用,預設為 False
resnet_skip_time_act: bool = False,
# ResNet 輸出縮放因子,預設為 1.0
resnet_out_scale_factor: float = 1.0,
# 時間嵌入型別,預設為 "positional"
time_embedding_type: str = "positional",
# 時間嵌入維度,預設為 None
time_embedding_dim: Optional[int] = None,
# 時間嵌入啟用函式,預設為 None
time_embedding_act_fn: Optional[str] = None,
# 時間步後啟用函式,預設為 None
timestep_post_act: Optional[str] = None,
# 時間條件投影維度,預設為 None
time_cond_proj_dim: Optional[int] = None,
# 輸入卷積核大小,預設為 3
conv_in_kernel: int = 3,
# 輸出卷積核大小,預設為 3
conv_out_kernel: int = 3,
# 投影類嵌入輸入維度,預設為 None
projection_class_embeddings_input_dim: Optional[int] = None,
# 注意力型別,預設為 "default"
attention_type: str = "default",
# 類嵌入是否拼接,預設為 False
class_embeddings_concat: bool = False,
# 中間塊是否僅使用交叉注意力,預設為 None
mid_block_only_cross_attention: Optional[bool] = None,
# 交叉注意力歸一化型別,預設為 None
cross_attention_norm: Optional[str] = None,
# 附加嵌入型別的頭數量,預設為 64
addition_embed_type_num_heads: int = 64,
# 定義一個私有方法,用於檢查配置引數
def _check_config(
self,
# 定義下行塊型別的元組,表示模型的結構
down_block_types: Tuple[str],
# 定義上行塊型別的元組,表示模型的結構
up_block_types: Tuple[str],
# 定義僅使用交叉注意力的標誌,可以是布林值或布林值的元組
only_cross_attention: Union[bool, Tuple[bool]],
# 定義每個塊的輸出通道數的元組,表示層的寬度
block_out_channels: Tuple[int],
# 定義每個塊的層數,可以是整數或整數的元組
layers_per_block: Union[int, Tuple[int]],
# 定義交叉注意力維度,可以是整數或整數的元組
cross_attention_dim: Union[int, Tuple[int]],
# 定義每個塊的變換器層數,可以是整數、整數的元組或元組的元組
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
# 定義是否反轉變換器層的布林值
reverse_transformer_layers_per_block: bool,
# 定義注意力頭的維度,表示注意力的解析度
attention_head_dim: int,
# 定義注意力頭的數量,可以是可選的整數或整數的元組
num_attention_heads: Optional[Union[int, Tuple[int]],
):
# 檢查 down_block_types 和 up_block_types 的長度是否相同
if len(down_block_types) != len(up_block_types):
# 如果不同,丟擲值錯誤並提供詳細資訊
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
# 檢查 block_out_channels 和 down_block_types 的長度是否相同
if len(block_out_channels) != len(down_block_types):
# 如果不同,丟擲值錯誤並提供詳細資訊
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
# 檢查 only_cross_attention 是否為布林值且長度與 down_block_types 相同
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
# 如果不滿足條件,丟擲值錯誤並提供詳細資訊
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
# 檢查 num_attention_heads 是否為整數且長度與 down_block_types 相同
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
# 如果不滿足條件,丟擲值錯誤並提供詳細資訊
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
# 檢查 attention_head_dim 是否為整數且長度與 down_block_types 相同
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
# 如果不滿足條件,丟擲值錯誤並提供詳細資訊
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
)
# 檢查 cross_attention_dim 是否為列表且長度與 down_block_types 相同
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
# 如果不滿足條件,丟擲值錯誤並提供詳細資訊
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
# 檢查 layers_per_block 是否為整數且長度與 down_block_types 相同
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
# 如果不滿足條件,丟擲值錯誤並提供詳細資訊
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)
# 檢查 transformer_layers_per_block 是否為列表且 reverse_transformer_layers_per_block 為 None
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
# 遍歷 transformer_layers_per_block 中的每個層
for layer_number_per_block in transformer_layers_per_block:
# 檢查每個層是否為列表
if isinstance(layer_number_per_block, list):
# 如果是,則丟擲值錯誤,提示需要提供 reverse_transformer_layers_per_block
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
# 定義設定時間投影的私有方法
def _set_time_proj(
self,
# 時間嵌入型別
time_embedding_type: str,
# 塊輸出通道數
block_out_channels: int,
# 是否翻轉正弦和餘弦
flip_sin_to_cos: bool,
# 頻率偏移
freq_shift: float,
# 時間嵌入維度
time_embedding_dim: int,
# 返回時間嵌入維度和時間步輸入維度的元組
) -> Tuple[int, int]:
# 判斷時間嵌入型別是否為傅立葉
if time_embedding_type == "fourier":
# 計算時間嵌入維度,預設為 block_out_channels[0] * 2
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
# 確保時間嵌入維度為偶數
if time_embed_dim % 2 != 0:
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
# 初始化高斯傅立葉投影,設定相關引數
self.time_proj = GaussianFourierProjection(
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
# 設定時間步輸入維度為時間嵌入維度
timestep_input_dim = time_embed_dim
# 判斷時間嵌入型別是否為位置編碼
elif time_embedding_type == "positional":
# 計算時間嵌入維度,預設為 block_out_channels[0] * 4
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
# 初始化時間步物件,設定相關引數
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
# 設定時間步輸入維度為 block_out_channels[0]
timestep_input_dim = block_out_channels[0]
# 如果時間嵌入型別不合法,丟擲錯誤
else:
raise ValueError(
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
)
# 返回時間嵌入維度和時間步輸入維度
return time_embed_dim, timestep_input_dim
# 定義設定編碼器隱藏投影的方法
def _set_encoder_hid_proj(
self,
encoder_hid_dim_type: Optional[str],
cross_attention_dim: Union[int, Tuple[int]],
encoder_hid_dim: Optional[int],
):
# 如果編碼器隱藏維度型別為空且隱藏維度已定義
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
# 預設將編碼器隱藏維度型別設為'text_proj'
encoder_hid_dim_type = "text_proj"
# 註冊編碼器隱藏維度型別到配置中
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
# 記錄資訊日誌
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
# 如果編碼器隱藏維度為空且隱藏維度型別已定義,丟擲錯誤
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
)
# 判斷編碼器隱藏維度型別是否為'text_proj'
if encoder_hid_dim_type == "text_proj":
# 初始化線性投影層,輸入維度為encoder_hid_dim,輸出維度為cross_attention_dim
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
# 判斷編碼器隱藏維度型別是否為'text_image_proj'
elif encoder_hid_dim_type == "text_image_proj":
# 初始化文字-影像投影物件,設定相關引數
self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim,
image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim,
)
# 判斷編碼器隱藏維度型別是否為'image_proj'
elif encoder_hid_dim_type == "image_proj":
# 初始化影像投影物件,設定相關引數
self.encoder_hid_proj = ImageProjection(
image_embed_dim=encoder_hid_dim,
cross_attention_dim=cross_attention_dim,
)
# 如果編碼器隱藏維度型別不合法,丟擲錯誤
elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
)
# 如果都不符合,將編碼器隱藏投影設為None
else:
self.encoder_hid_proj = None
# 設定類嵌入的私有方法
def _set_class_embedding(
self,
class_embed_type: Optional[str], # 嵌入型別,可能為 None 或特定字串
act_fn: str, # 啟用函式的名稱
num_class_embeds: Optional[int], # 類嵌入數量,可能為 None
projection_class_embeddings_input_dim: Optional[int], # 投影類嵌入輸入維度,可能為 None
time_embed_dim: int, # 時間嵌入的維度
timestep_input_dim: int, # 時間步輸入的維度
):
# 如果嵌入型別為 None 且類嵌入數量不為 None
if class_embed_type is None and num_class_embeds is not None:
# 建立嵌入層,大小為類嵌入數量和時間嵌入維度
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
# 如果嵌入型別為 "timestep"
elif class_embed_type == "timestep":
# 建立時間步嵌入物件
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
# 如果嵌入型別為 "identity"
elif class_embed_type == "identity":
# 建立恆等層,輸入和輸出維度相同
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
# 如果嵌入型別為 "projection"
elif class_embed_type == "projection":
# 如果投影類嵌入輸入維度為 None,丟擲錯誤
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# 建立投影時間步嵌入物件
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
# 如果嵌入型別為 "simple_projection"
elif class_embed_type == "simple_projection":
# 如果投影類嵌入輸入維度為 None,丟擲錯誤
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
)
# 建立線性層作為簡單投影
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
# 如果沒有匹配的嵌入型別
else:
# 將類嵌入設定為 None
self.class_embedding = None
# 設定附加嵌入的私有方法
def _set_add_embedding(
self,
addition_embed_type: str, # 附加嵌入型別
addition_embed_type_num_heads: int, # 附加嵌入型別的頭數
addition_time_embed_dim: Optional[int], # 附加時間嵌入維度,可能為 None
flip_sin_to_cos: bool, # 是否翻轉正弦到餘弦
freq_shift: float, # 頻率偏移量
cross_attention_dim: Optional[int], # 交叉注意力維度,可能為 None
encoder_hid_dim: Optional[int], # 編碼器隱藏維度,可能為 None
projection_class_embeddings_input_dim: Optional[int], # 投影類嵌入輸入維度,可能為 None
time_embed_dim: int, # 時間嵌入維度
):
# 檢查附加嵌入型別是否為 "text"
if addition_embed_type == "text":
# 如果編碼器隱藏維度不為 None,則使用該維度
if encoder_hid_dim is not None:
text_time_embedding_from_dim = encoder_hid_dim
# 否則使用交叉注意力維度
else:
text_time_embedding_from_dim = cross_attention_dim
# 建立文字時間嵌入物件
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
)
# 檢查附加嵌入型別是否為 "text_image"
elif addition_embed_type == "text_image":
# text_embed_dim 和 image_embed_dim 不必是 `cross_attention_dim`,為了避免 __init__ 過於繁雜
# 在這裡設定為 `cross_attention_dim`,因為這是當前唯一使用情況的所需維度 (Kandinsky 2.1)
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
)
# 檢查附加嵌入型別是否為 "text_time"
elif addition_embed_type == "text_time":
# 建立時間投影物件
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
# 建立時間嵌入物件
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
# 檢查附加嵌入型別是否為 "image"
elif addition_embed_type == "image":
# Kandinsky 2.2
# 建立影像時間嵌入物件
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
# 檢查附加嵌入型別是否為 "image_hint"
elif addition_embed_type == "image_hint":
# Kandinsky 2.2 ControlNet
# 建立影像提示時間嵌入物件
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
# 檢查附加嵌入型別是否為 None 以外的值
elif addition_embed_type is not None:
# 丟擲值錯誤,提示無效的附加嵌入型別
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
# 定義一個屬性方法,用於設定位置網路
def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
# 檢查注意力型別是否為 "gated" 或 "gated-text-image"
if attention_type in ["gated", "gated-text-image"]:
positive_len = 768 # 預設的正向長度
# 如果交叉注意力維度是整數,則使用該值
if isinstance(cross_attention_dim, int):
positive_len = cross_attention_dim
# 如果交叉注意力維度是列表或元組,則使用第一個值
elif isinstance(cross_attention_dim, (list, tuple)):
positive_len = cross_attention_dim[0]
# 根據注意力型別確定特徵型別
feature_type = "text-only" if attention_type == "gated" else "text-image"
# 建立 GLIGEN 文字邊界框投影物件
self.position_net = GLIGENTextBoundingboxProjection(
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
)
# 定義一個屬性
@property
# 定義一個方法,返回一個字典,包含模型中所有的注意力處理器
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# 初始化一個空字典,用於儲存注意力處理器
processors = {}
# 定義一個遞迴函式,用於新增處理器到字典
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# 檢查模組是否有獲取處理器的方法
if hasattr(module, "get_processor"):
# 將處理器新增到字典中,鍵為名稱,值為處理器
processors[f"{name}.processor"] = module.get_processor()
# 遍歷模組的所有子模組
for sub_name, child in module.named_children():
# 遞迴呼叫,處理子模組
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
# 返回更新後的處理器字典
return processors
# 遍歷當前模組的所有子模組
for name, module in self.named_children():
# 呼叫遞迴函式,新增處理器
fn_recursive_add_processors(name, module, processors)
# 返回包含所有處理器的字典
return processors
# 定義一個方法,設定用於計算注意力的處理器
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
# 獲取當前處理器的數量
count = len(self.attn_processors.keys())
# 如果傳入的是字典,且字典長度與注意力層數量不匹配,丟擲錯誤
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
# 定義一個遞迴函式,用於設定處理器
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 檢查模組是否有設定處理器的方法
if hasattr(module, "set_processor"):
# 如果處理器不是字典,直接設定
if not isinstance(processor, dict):
module.set_processor(processor)
else:
# 從字典中彈出對應的處理器並設定
module.set_processor(processor.pop(f"{name}.processor"))
# 遍歷模組的所有子模組
for sub_name, child in module.named_children():
# 遞迴呼叫,設定子模組的處理器
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍歷當前模組的所有子模組
for name, module in self.named_children():
# 呼叫遞迴函式,設定處理器
fn_recursive_attn_processor(name, module, processor)
# 定義設定預設注意力處理器的方法
def set_default_attn_processor(self):
"""
禁用自定義注意力處理器並設定預設的注意力實現。
"""
# 檢查所有注意力處理器是否屬於新增的鍵值注意力處理器類
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 建立新增鍵值注意力處理器的例項
processor = AttnAddedKVProcessor()
# 檢查所有注意力處理器是否屬於交叉注意力處理器類
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 建立標準注意力處理器的例項
processor = AttnProcessor()
else:
# 如果注意力處理器型別不匹配,則丟擲錯誤
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
# 設定選定的注意力處理器
self.set_attn_processor(processor)
# 定義設定梯度檢查點的方法
def _set_gradient_checkpointing(self, module, value=False):
# 如果模組具有梯度檢查點屬性,則設定其值
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
# 定義啟用 FreeU 機制的方法
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""啟用 FreeU 機制,詳細資訊請見 https://arxiv.org/abs/2309.11497。
在縮放因子後面的字尾表示它們被應用的階段塊。
請參考 [官方倉庫](https://github.com/ChenyangSi/FreeU) 以獲取已知在不同管道(如 Stable Diffusion v1、v2 和 Stable Diffusion XL)中效果良好的值組合。
引數:
s1 (`float`):
階段 1 的縮放因子,用於減弱跳躍特徵的貢獻,以減輕增強去噪過程中的“過平滑效應”。
s2 (`float`):
階段 2 的縮放因子,用於減弱跳躍特徵的貢獻,以減輕增強去噪過程中的“過平滑效應”。
b1 (`float`): 階段 1 的縮放因子,用於增強骨幹特徵的貢獻。
b2 (`float`): 階段 2 的縮放因子,用於增強骨幹特徵的貢獻。
"""
# 遍歷上取樣塊並設定相應的縮放因子
for i, upsample_block in enumerate(self.up_blocks):
setattr(upsample_block, "s1", s1) # 設定階段 1 的縮放因子
setattr(upsample_block, "s2", s2) # 設定階段 2 的縮放因子
setattr(upsample_block, "b1", b1) # 設定階段 1 的骨幹縮放因子
setattr(upsample_block, "b2", b2) # 設定階段 2 的骨幹縮放因子
# 定義禁用 FreeU 機制的方法
def disable_freeu(self):
"""禁用 FreeU 機制。"""
freeu_keys = {"s1", "s2", "b1", "b2"} # 定義 FreeU 相關的鍵
# 遍歷上取樣塊
for i, upsample_block in enumerate(self.up_blocks):
# 遍歷每個 FreeU 鍵
for k in freeu_keys:
# 如果上取樣塊具有該鍵的屬性或其值不為 None,則將其值設定為 None
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
# 定義一個方法,用於啟用融合的 QKV 投影
def fuse_qkv_projections(self):
"""
啟用融合的 QKV 投影。對於自注意力模組,所有投影矩陣(即查詢、鍵、值)都被融合。
對於交叉注意力模組,鍵和值的投影矩陣被融合。
<Tip warning={true}>
此 API 是 🧪 實驗性的。
</Tip>
"""
# 初始化原始注意力處理器為 None
self.original_attn_processors = None
# 遍歷注意力處理器,檢查是否包含“Added”字樣
for _, attn_processor in self.attn_processors.items():
# 如果發現新增的 KV 投影,丟擲錯誤
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
# 儲存當前的注意力處理器
self.original_attn_processors = self.attn_processors
# 遍歷模組,查詢型別為 Attention 的模組
for module in self.modules():
if isinstance(module, Attention):
# 啟用投影融合
module.fuse_projections(fuse=True)
# 設定注意力處理器為融合的處理器
self.set_attn_processor(FusedAttnProcessor2_0())
# 定義一個方法,用於禁用已啟用的融合 QKV 投影
def unfuse_qkv_projections(self):
"""禁用已啟用的融合 QKV 投影。
<Tip warning={true}>
此 API 是 🧪 實驗性的。
</Tip>
"""
# 如果原始注意力處理器不為 None,則恢復到原始處理器
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
# 定義一個方法,用於獲取時間嵌入
def get_time_embed(
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
) -> Optional[torch.Tensor]:
# 將時間步長賦值給 timesteps
timesteps = timestep
# 如果 timesteps 不是張量
if not torch.is_tensor(timesteps):
# TODO: 這需要在 CPU 和 GPU 之間同步。因此,如果可以的話,儘量將 timesteps 作為張量傳遞
# 這將是使用 `match` 語句的好例子(Python 3.10+)
is_mps = sample.device.type == "mps" # 檢查裝置型別是否為 MPS
# 根據時間步長型別設定資料型別
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64 # 浮點數型別
else:
dtype = torch.int32 if is_mps else torch.int64 # 整數型別
# 將 timesteps 轉換為張量
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
# 如果 timesteps 是標量(零維張量),則擴充套件維度
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) # 增加一個維度並轉移到樣本裝置
# 將 timesteps 廣播到與樣本批次維度相容的方式
timesteps = timesteps.expand(sample.shape[0]) # 擴充套件到批次大小
# 透過時間投影獲得時間嵌入
t_emb = self.time_proj(timesteps)
# `Timesteps` 不包含任何權重,總是返回 f32 張量
# 但時間嵌入可能實際在 fp16 中執行,因此需要進行型別轉換。
# 可能有更好的方法來封裝這一點。
t_emb = t_emb.to(dtype=sample.dtype) # 轉換 t_emb 的資料型別
# 返回時間嵌入
return t_emb
# 獲取類嵌入的方法,接受樣本張量和可選的類標籤
def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
# 初始化類嵌入為 None
class_emb = None
# 檢查類嵌入是否存在
if self.class_embedding is not None:
# 如果類標籤為 None,丟擲錯誤
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
# 檢查類嵌入型別是否為時間步
if self.config.class_embed_type == "timestep":
# 將類標籤透過時間投影處理
class_labels = self.time_proj(class_labels)
# `Timesteps` 不包含權重,總是返回 f32 張量
# 可能有更好的方式來封裝這一點
class_labels = class_labels.to(dtype=sample.dtype)
# 獲取類嵌入並轉換為與樣本相同的資料型別
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
# 返回類嵌入
return class_emb
# 獲取增強嵌入的方法,接受嵌入張量、編碼器隱藏狀態和額外條件引數
def get_aug_embed(
self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
# 處理編碼器隱藏狀態的方法,接受編碼器隱藏狀態和額外條件引數
def process_encoder_hidden_states(
self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
# 定義返回型別為 torch.Tensor
) -> torch.Tensor:
# 檢查是否存在隱藏層投影,並且配置為 "text_proj"
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
# 使用文字投影對編碼隱藏狀態進行轉換
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
# 檢查是否存在隱藏層投影,並且配置為 "text_image_proj"
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
# 檢查條件中是否包含 "image_embeds"
if "image_embeds" not in added_cond_kwargs:
# 丟擲錯誤提示缺少必要引數
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
# 獲取傳入的影像嵌入
image_embeds = added_cond_kwargs.get("image_embeds")
# 對編碼隱藏狀態和影像嵌入進行投影轉換
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
# 檢查是否存在隱藏層投影,並且配置為 "image_proj"
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
# 檢查條件中是否包含 "image_embeds"
if "image_embeds" not in added_cond_kwargs:
# 丟擲錯誤提示缺少必要引數
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
# 獲取傳入的影像嵌入
image_embeds = added_cond_kwargs.get("image_embeds")
# 使用影像嵌入對編碼隱藏狀態進行投影轉換
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
# 檢查是否存在隱藏層投影,並且配置為 "ip_image_proj"
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
# 檢查條件中是否包含 "image_embeds"
if "image_embeds" not in added_cond_kwargs:
# 丟擲錯誤提示缺少必要引數
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
# 如果存在文字編碼器的隱藏層投影,則對編碼隱藏狀態進行投影轉換
if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
# 獲取傳入的影像嵌入
image_embeds = added_cond_kwargs.get("image_embeds")
# 對影像嵌入進行投影轉換
image_embeds = self.encoder_hid_proj(image_embeds)
# 將編碼隱藏狀態和影像嵌入打包成元組
encoder_hidden_states = (encoder_hidden_states, image_embeds)
# 返回最終的編碼隱藏狀態
return encoder_hidden_states
# 定義前向傳播函式
def forward(
# 輸入的樣本資料,型別為 PyTorch 張量
sample: torch.Tensor,
# 當前時間步,型別可以是張量、浮點數或整數
timestep: Union[torch.Tensor, float, int],
# 編碼器的隱藏狀態,型別為 PyTorch 張量
encoder_hidden_states: torch.Tensor,
# 可選的類別標籤,型別為 PyTorch 張量
class_labels: Optional[torch.Tensor] = None,
# 可選的時間步條件,型別為 PyTorch 張量
timestep_cond: Optional[torch.Tensor] = None,
# 可選的注意力掩碼,型別為 PyTorch 張量
attention_mask: Optional[torch.Tensor] = None,
# 可選的交叉注意力引數,型別為字典,包含額外的關鍵字引數
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可選的附加條件引數,型別為字典,鍵為字串,值為 PyTorch 張量
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
# 可選的下層塊附加殘差,型別為元組,包含 PyTorch 張量
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
# 可選的中間塊附加殘差,型別為 PyTorch 張量
mid_block_additional_residual: Optional[torch.Tensor] = None,
# 可選的下層內部塊附加殘差,型別為元組,包含 PyTorch 張量
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
# 可選的編碼器注意力掩碼,型別為 PyTorch 張量
encoder_attention_mask: Optional[torch.Tensor] = None,
# 返回結果的標誌,布林值,預設值為 True
return_dict: bool = True,
.\diffusers\models\unets\unet_2d_condition_flax.py
# 版權宣告,表明該檔案的版權所有者及相關資訊
#
# 根據 Apache License 2.0 版本的許可協議
# 除非遵守該許可協議,否則不得使用本檔案
# 可以在以下地址獲取許可證副本
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律或書面協議另有約定,軟體在“按現狀”基礎上分發,
# 不提供任何明示或暗示的保證或條件
# 請參閱許可證瞭解管理許可權和限制的具體條款
from typing import Dict, Optional, Tuple, Union # 從 typing 模組匯入型別註釋工具
import flax # 匯入 flax 庫用於構建神經網路
import flax.linen as nn # 從 flax 中匯入 linen 模組,方便定義神經網路層
import jax # 匯入 jax 庫用於高效數值計算
import jax.numpy as jnp # 匯入 jax 的 numpy 模組,提供張量操作功能
from flax.core.frozen_dict import FrozenDict # 從 flax 匯入 FrozenDict,用於不可變字典
from ...configuration_utils import ConfigMixin, flax_register_to_config # 匯入配置相關工具
from ...utils import BaseOutput # 匯入基礎輸出類
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps # 匯入時間步嵌入相關類
from ..modeling_flax_utils import FlaxModelMixin # 匯入模型混合類
from .unet_2d_blocks_flax import ( # 匯入 UNet 的不同構建塊
FlaxCrossAttnDownBlock2D, # 匯入交叉注意力下采樣塊
FlaxCrossAttnUpBlock2D, # 匯入交叉注意力上取樣塊
FlaxDownBlock2D, # 匯入下采樣塊
FlaxUNetMidBlock2DCrossAttn, # 匯入中間塊,帶有交叉注意力
FlaxUpBlock2D, # 匯入上取樣塊
)
@flax.struct.dataclass # 使用 flax 的資料類裝飾器
class FlaxUNet2DConditionOutput(BaseOutput): # 定義 UNet 條件輸出類,繼承自基礎輸出類
"""
[`FlaxUNet2DConditionModel`] 的輸出。
引數:
sample (`jnp.ndarray` 的形狀為 `(batch_size, num_channels, height, width)`):
基於 `encoder_hidden_states` 輸入的隱藏狀態輸出。模型最後一層的輸出。
"""
sample: jnp.ndarray # 定義輸出樣本,資料型別為 jnp.ndarray
@flax_register_to_config # 使用裝飾器將模型註冊到配置中
class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): # 定義條件 UNet 模型類,繼承多個混合類
r"""
一個條件 2D UNet 模型,接收噪聲樣本、條件狀態和時間步,並返回樣本形狀的輸出。
此模型繼承自 [`FlaxModelMixin`]。請檢視超類文件以瞭解其通用方法
(例如下載或儲存)。
此模型也是 Flax Linen 的 [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
子類。將其作為常規 Flax Linen 模組使用,具體使用和行為請參閱 Flax 文件。
支援以下 JAX 特性:
- [即時編譯 (JIT)](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [自動微分](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [向量化](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [並行化](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
# 引數說明部分
Parameters:
# 輸入樣本的大小,型別為整型,選填引數
sample_size (`int`, *optional*):
The size of the input sample.
# 輸入樣本的通道數,型別為整型,預設為4
in_channels (`int`, *optional*, defaults to 4):
The number of channels in the input sample.
# 輸出的通道數,型別為整型,預設為4
out_channels (`int`, *optional*, defaults to 4):
The number of channels in the output.
# 使用的下采樣塊的元組,型別為字串元組,預設為特定的下采樣塊
down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
The tuple of downsample blocks to use.
# 使用的上取樣塊的元組,型別為字串元組,預設為特定的上取樣塊
up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
The tuple of upsample blocks to use.
# UNet中間塊的型別,型別為字串,預設為"UNetMidBlock2DCrossAttn"
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer
is skipped.
# 每個塊的輸出通道的元組,型別為整型元組,預設為特定的輸出通道
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
# 每個塊的層數,型別為整型,預設為2
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
# 注意力頭的維度,可以是整型或整型元組,預設為8
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
The dimension of the attention heads.
# 注意力頭的數量,可以是整型或整型元組,選填引數
num_attention_heads (`int` or `Tuple[int]`, *optional*):
The number of attention heads.
# 交叉注意力特徵的維度,型別為整型,預設為768
cross_attention_dim (`int`, *optional*, defaults to 768):
The dimension of the cross attention features.
# dropout的機率,型別為浮點數,預設為0
dropout (`float`, *optional*, defaults to 0):
Dropout probability for down, up and bottleneck blocks.
# 是否在時間嵌入中將正弦轉換為餘弦,型別為布林值,預設為True
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
# 應用於時間嵌入的頻率偏移,型別為整型,預設為0
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
# 是否啟用記憶體高效的注意力機制,型別為布林值,預設為False
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
# 是否將頭維度拆分為新的軸進行自注意力計算,型別為布林值,預設為False
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
"""
# 定義樣本大小,預設為32
sample_size: int = 32
# 定義輸入通道數,預設為4
in_channels: int = 4
# 定義輸出通道數,預設為4
out_channels: int = 4
# 定義下采樣塊的型別元組
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D", # 第一個下采樣塊
"CrossAttnDownBlock2D", # 第二個下采樣塊
"CrossAttnDownBlock2D", # 第三個下采樣塊
"DownBlock2D", # 第四個下采樣塊
)
# 定義上取樣塊的型別元組
up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
# 定義中間塊型別,預設為"UNetMidBlock2DCrossAttn"
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn"
# 定義是否只使用交叉注意力,預設為False
only_cross_attention: Union[bool, Tuple[bool]] = False
# 定義每個塊的輸出通道元組
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
# 每個塊的層數設為 2
layers_per_block: int = 2
# 注意力頭的維度設為 8
attention_head_dim: Union[int, Tuple[int, ...]] = 8
# 可選的注意力頭數量,預設為 None
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
# 跨注意力的維度設為 1280
cross_attention_dim: int = 1280
# dropout 比率設為 0.0
dropout: float = 0.0
# 是否使用線性投影,預設為 False
use_linear_projection: bool = False
# 資料型別設為 float32
dtype: jnp.dtype = jnp.float32
# flip_sin_to_cos 設為 True
flip_sin_to_cos: bool = True
# 頻移設為 0
freq_shift: int = 0
# 是否使用記憶體高效的注意力,預設為 False
use_memory_efficient_attention: bool = False
# 是否拆分頭維度,預設為 False
split_head_dim: bool = False
# 每個塊的變換層數設為 1
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1
# 可選的附加嵌入型別,預設為 None
addition_embed_type: Optional[str] = None
# 可選的附加時間嵌入維度,預設為 None
addition_time_embed_dim: Optional[int] = None
# 附加嵌入型別的頭數量設為 64
addition_embed_type_num_heads: int = 64
# 可選的投影類嵌入輸入維度,預設為 None
projection_class_embeddings_input_dim: Optional[int] = None
# 初始化權重函式,接受隨機數生成器作為引數
def init_weights(self, rng: jax.Array) -> FrozenDict:
# 初始化輸入張量的形狀
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
# 建立全零的輸入樣本
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
# 建立全一的時間步張量
timesteps = jnp.ones((1,), dtype=jnp.int32)
# 初始化編碼器的隱藏狀態
encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
# 分割隨機數生成器,用於引數和 dropout
params_rng, dropout_rng = jax.random.split(rng)
# 建立隨機數字典
rngs = {"params": params_rng, "dropout": dropout_rng}
# 初始化附加條件關鍵字引數
added_cond_kwargs = None
# 判斷嵌入型別是否為 "text_time"
if self.addition_embed_type == "text_time":
# 透過反向計算獲取期望的文字嵌入維度
is_refiner = (
5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim
== self.config.projection_class_embeddings_input_dim
)
# 確定微條件的數量
num_micro_conditions = 5 if is_refiner else 6
# 計算文字嵌入維度
text_embeds_dim = self.config.projection_class_embeddings_input_dim - (
num_micro_conditions * self.config.addition_time_embed_dim
)
# 計算時間 ID 的通道數和維度
time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
time_ids_dims = time_ids_channels // self.addition_time_embed_dim
# 建立附加條件關鍵字引數字典
added_cond_kwargs = {
"text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32),
"time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32),
}
# 返回初始化後的引數字典
return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
# 定義呼叫函式,接收多個輸入引數
def __call__(
self,
sample: jnp.ndarray,
timesteps: Union[jnp.ndarray, float, int],
encoder_hidden_states: jnp.ndarray,
# 可選的附加條件關鍵字引數
added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
# 可選的下塊附加殘差
down_block_additional_residuals: Optional[Tuple[jnp.ndarray, ...]] = None,
# 可選的中塊附加殘差
mid_block_additional_residual: Optional[jnp.ndarray] = None,
# 是否返回字典,預設為 True
return_dict: bool = True,
# 是否為訓練模式,預設為 False
train: bool = False,
.\diffusers\models\unets\unet_3d_blocks.py
# 版權所有 2024 HuggingFace 團隊。所有權利保留。
#
# 根據 Apache 許可證 2.0 版本("許可證")授權;
# 除非遵守許可證,否則您不得使用此檔案。
# 您可以在以下位置獲得許可證副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律或書面同意,否則根據許可證分發的軟體是按“原樣”基礎分發,
# 不提供任何形式的保證或條件,無論是明示或暗示。
# 有關許可證的具體許可權和限制,請參閱許可證。
# 匯入型別提示中的任何型別
from typing import Any, Dict, Optional, Tuple, Union
# 匯入 PyTorch 庫
import torch
# 從 PyTorch 匯入神經網路模組
from torch import nn
# 匯入實用工具函式,包括棄用和日誌記錄
from ...utils import deprecate, is_torch_version, logging
# 匯入 PyTorch 相關的工具函式
from ...utils.torch_utils import apply_freeu
# 匯入注意力機制相關的類
from ..attention import Attention
# 匯入 ResNet 相關的類
from ..resnet import (
Downsample2D, # 匯入 2D 下采樣模組
ResnetBlock2D, # 匯入 2D ResNet 塊
SpatioTemporalResBlock, # 匯入時空 ResNet 塊
TemporalConvLayer, # 匯入時間卷積層
Upsample2D, # 匯入 2D 上取樣模組
)
# 匯入 2D 變換器模型
from ..transformers.transformer_2d import Transformer2DModel
# 匯入時間相關的變換器模型
from ..transformers.transformer_temporal import (
TransformerSpatioTemporalModel, # 匯入時空變換器模型
TransformerTemporalModel, # 匯入時間變換器模型
)
# 匯入運動模型的 UNet 相關類
from .unet_motion_model import (
CrossAttnDownBlockMotion, # 匯入交叉注意力下塊運動類
CrossAttnUpBlockMotion, # 匯入交叉注意力上塊運動類
DownBlockMotion, # 匯入下塊運動類
UNetMidBlockCrossAttnMotion, # 匯入中間塊交叉注意力運動類
UpBlockMotion, # 匯入上塊運動類
)
# 建立一個日誌記錄器,使用當前模組的名稱
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# 定義 DownBlockMotion 類,繼承自 DownBlockMotion
class DownBlockMotion(DownBlockMotion):
# 初始化方法,接受任意引數和關鍵字引數
def __init__(self, *args, **kwargs):
# 設定棄用訊息,提醒使用者變更
deprecation_message = "Importing `DownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import DownBlockMotion` instead."
# 呼叫棄用函式,記錄棄用資訊
deprecate("DownBlockMotion", "1.0.0", deprecation_message)
# 呼叫父類的初始化方法
super().__init__(*args, **kwargs)
# 定義 CrossAttnDownBlockMotion 類,繼承自 CrossAttnDownBlockMotion
class CrossAttnDownBlockMotion(CrossAttnDownBlockMotion):
# 初始化方法,接受任意引數和關鍵字引數
def __init__(self, *args, **kwargs):
# 設定棄用訊息,提醒使用者變更
deprecation_message = "Importing `CrossAttnDownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnDownBlockMotion` instead."
# 呼叫棄用函式,記錄棄用資訊
deprecate("CrossAttnDownBlockMotion", "1.0.0", deprecation_message)
# 呼叫父類的初始化方法
super().__init__(*args, **kwargs)
# 定義 UpBlockMotion 類,繼承自 UpBlockMotion
class UpBlockMotion(UpBlockMotion):
# 初始化方法,接受任意引數和關鍵字引數
def __init__(self, *args, **kwargs):
# 設定棄用訊息,提醒使用者變更
deprecation_message = "Importing `UpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UpBlockMotion` instead."
# 呼叫棄用函式,記錄棄用資訊
deprecate("UpBlockMotion", "1.0.0", deprecation_message)
# 呼叫父類的初始化方法
super().__init__(*args, **kwargs)
# 定義 CrossAttnUpBlockMotion 類,繼承自 CrossAttnUpBlockMotion
class CrossAttnUpBlockMotion(CrossAttnUpBlockMotion):
# 初始化方法,用於建立類的例項
def __init__(self, *args, **kwargs):
# 定義一個關於匯入的棄用警告資訊
deprecation_message = "Importing `CrossAttnUpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnUpBlockMotion` instead."
# 呼叫棄用警告函式,記錄該功能的棄用資訊及版本
deprecate("CrossAttnUpBlockMotion", "1.0.0", deprecation_message)
# 呼叫父類的初始化方法,傳遞引數以初始化父類部分
super().__init__(*args, **kwargs)
# 定義一個名為 UNetMidBlockCrossAttnMotion 的類,繼承自同名父類
class UNetMidBlockCrossAttnMotion(UNetMidBlockCrossAttnMotion):
# 初始化方法,接收可變引數和關鍵字引數
def __init__(self, *args, **kwargs):
# 定義棄用警告資訊,提示使用者更新匯入路徑
deprecation_message = "Importing `UNetMidBlockCrossAttnMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UNetMidBlockCrossAttnMotion` instead."
# 觸發棄用警告
deprecate("UNetMidBlockCrossAttnMotion", "1.0.0", deprecation_message)
# 呼叫父類的初始化方法
super().__init__(*args, **kwargs)
# 定義一個函式,返回不同型別的下采樣塊
def get_down_block(
# 定義引數,型別和含義
down_block_type: str, # 下采樣塊的型別
num_layers: int, # 層數
in_channels: int, # 輸入通道數
out_channels: int, # 輸出通道數
temb_channels: int, # 時間嵌入通道數
add_downsample: bool, # 是否新增下采樣
resnet_eps: float, # ResNet 的 epsilon 引數
resnet_act_fn: str, # ResNet 的啟用函式
num_attention_heads: int, # 注意力頭數
resnet_groups: Optional[int] = None, # ResNet 的分組數,可選
cross_attention_dim: Optional[int] = None, # 交叉注意力維度,可選
downsample_padding: Optional[int] = None, # 下采樣填充,可選
dual_cross_attention: bool = False, # 是否使用雙重交叉注意力
use_linear_projection: bool = True, # 是否使用線性投影
only_cross_attention: bool = False, # 是否僅使用交叉注意力
upcast_attention: bool = False, # 是否提升注意力精度
resnet_time_scale_shift: str = "default", # ResNet 時間尺度偏移
temporal_num_attention_heads: int = 8, # 時間注意力頭數
temporal_max_seq_length: int = 32, # 時間序列最大長度
transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 每個塊的變換器層數
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 時間變換器每塊層數
dropout: float = 0.0, # dropout 機率
) -> Union[
"DownBlock3D", # 返回的可能型別之一:3D 下采樣塊
"CrossAttnDownBlock3D", # 返回的可能型別之二:交叉注意力下采樣塊
"DownBlockSpatioTemporal", # 返回的可能型別之三:時空下采樣塊
"CrossAttnDownBlockSpatioTemporal", # 返回的可能型別之四:時空交叉注意力下采樣塊
]:
# 檢查下采樣塊型別是否為 DownBlock3D
if down_block_type == "DownBlock3D":
# 建立並返回 DownBlock3D 例項
return DownBlock3D(
num_layers=num_layers, # 傳入層數
in_channels=in_channels, # 傳入輸入通道數
out_channels=out_channels, # 傳入輸出通道數
temb_channels=temb_channels, # 傳入時間嵌入通道數
add_downsample=add_downsample, # 傳入是否新增下采樣
resnet_eps=resnet_eps, # 傳入 ResNet 的 epsilon 引數
resnet_act_fn=resnet_act_fn, # 傳入啟用函式
resnet_groups=resnet_groups, # 傳入分組數
downsample_padding=downsample_padding, # 傳入下采樣填充
resnet_time_scale_shift=resnet_time_scale_shift, # 傳入時間尺度偏移
dropout=dropout, # 傳入 dropout 機率
)
# 檢查下采樣塊型別是否為 CrossAttnDownBlock3D
elif down_block_type == "CrossAttnDownBlock3D":
# 如果交叉注意力維度未指定,丟擲錯誤
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
# 建立並返回 CrossAttnDownBlock3D 例項
return CrossAttnDownBlock3D(
num_layers=num_layers, # 傳入層數
in_channels=in_channels, # 傳入輸入通道數
out_channels=out_channels, # 傳入輸出通道數
temb_channels=temb_channels, # 傳入時間嵌入通道數
add_downsample=add_downsample, # 傳入是否新增下采樣
resnet_eps=resnet_eps, # 傳入 ResNet 的 epsilon 引數
resnet_act_fn=resnet_act_fn, # 傳入啟用函式
resnet_groups=resnet_groups, # 傳入分組數
downsample_padding=downsample_padding, # 傳入下采樣填充
cross_attention_dim=cross_attention_dim, # 傳入交叉注意力維度
num_attention_heads=num_attention_heads, # 傳入注意力頭數
dual_cross_attention=dual_cross_attention, # 傳入是否使用雙重交叉注意力
use_linear_projection=use_linear_projection, # 傳入是否使用線性投影
only_cross_attention=only_cross_attention, # 傳入是否僅使用交叉注意力
upcast_attention=upcast_attention, # 傳入是否提升注意力精度
resnet_time_scale_shift=resnet_time_scale_shift, # 傳入時間尺度偏移
dropout=dropout, # 傳入 dropout 機率
)
# 檢查下一個塊的型別是否為時空下采樣塊
elif down_block_type == "DownBlockSpatioTemporal":
# 為 SDV 進行了新增
# 返回一個時空下采樣塊的例項
return DownBlockSpatioTemporal(
# 設定層數
num_layers=num_layers,
# 輸入通道數
in_channels=in_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# 是否新增下采樣
add_downsample=add_downsample,
)
# 檢查下一個塊的型別是否為交叉注意力時空下采樣塊
elif down_block_type == "CrossAttnDownBlockSpatioTemporal":
# 為 SDV 進行了新增
# 如果沒有指定交叉注意力維度,丟擲錯誤
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal")
# 返回一個交叉注意力時空下采樣塊的例項
return CrossAttnDownBlockSpatioTemporal(
# 輸入通道數
in_channels=in_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# 設定層數
num_layers=num_layers,
# 每個塊的變換層數
transformer_layers_per_block=transformer_layers_per_block,
# 是否新增下采樣
add_downsample=add_downsample,
# 設定交叉注意力維度
cross_attention_dim=cross_attention_dim,
# 注意力頭數
num_attention_heads=num_attention_heads,
)
# 如果塊型別不匹配,則丟擲錯誤
raise ValueError(f"{down_block_type} does not exist.")
# 定義函式 get_up_block,返回不同型別的上取樣模組
def get_up_block(
# 上取樣塊型別
up_block_type: str,
# 層數
num_layers: int,
# 輸入通道數
in_channels: int,
# 輸出通道數
out_channels: int,
# 上一層的輸出通道數
prev_output_channel: int,
# 時間嵌入通道數
temb_channels: int,
# 是否新增上取樣
add_upsample: bool,
# ResNet 的 epsilon 值
resnet_eps: float,
# ResNet 的啟用函式型別
resnet_act_fn: str,
# 注意力頭的數量
num_attention_heads: int,
# 解析度索引(可選)
resolution_idx: Optional[int] = None,
# ResNet 組數(可選)
resnet_groups: Optional[int] = None,
# 跨注意力維度(可選)
cross_attention_dim: Optional[int] = None,
# 是否使用雙重跨注意力
dual_cross_attention: bool = False,
# 是否使用線性投影
use_linear_projection: bool = True,
# 是否僅使用跨注意力
only_cross_attention: bool = False,
# 是否提升注意力計算
upcast_attention: bool = False,
# ResNet 時間尺度移位的設定
resnet_time_scale_shift: str = "default",
# 時間上的注意力頭數量
temporal_num_attention_heads: int = 8,
# 時間上的跨注意力維度(可選)
temporal_cross_attention_dim: Optional[int] = None,
# 時間序列最大長度
temporal_max_seq_length: int = 32,
# 每個塊的變換器層數
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# 每個塊的時間變換器層數
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# dropout 機率
dropout: float = 0.0,
) -> Union[
# 返回型別為不同的上取樣塊
"UpBlock3D",
"CrossAttnUpBlock3D",
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
]:
# 判斷上取樣塊型別是否為 UpBlock3D
if up_block_type == "UpBlock3D":
# 建立並返回 UpBlock3D 例項
return UpBlock3D(
# 傳入引數設定
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
dropout=dropout,
)
# 判斷上取樣塊型別是否為 CrossAttnUpBlock3D
elif up_block_type == "CrossAttnUpBlock3D":
# 檢查是否提供跨注意力維度
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
# 建立並返回 CrossAttnUpBlock3D 例項
return CrossAttnUpBlock3D(
# 傳入引數設定
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
dropout=dropout,
)
# 檢查上升塊型別是否為 "UpBlockSpatioTemporal"
elif up_block_type == "UpBlockSpatioTemporal":
# 為 SDV 新增的內容
# 返回 UpBlockSpatioTemporal 例項,使用指定的引數
return UpBlockSpatioTemporal(
# 層數引數
num_layers=num_layers,
# 輸入通道數
in_channels=in_channels,
# 輸出通道數
out_channels=out_channels,
# 前一個輸出通道數
prev_output_channel=prev_output_channel,
# 時間嵌入通道數
temb_channels=temb_channels,
# 解析度索引
resolution_idx=resolution_idx,
# 是否新增上取樣
add_upsample=add_upsample,
)
# 檢查上升塊型別是否為 "CrossAttnUpBlockSpatioTemporal"
elif up_block_type == "CrossAttnUpBlockSpatioTemporal":
# 為 SDV 新增的內容
# 如果沒有指定交叉注意力維度,丟擲錯誤
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal")
# 返回 CrossAttnUpBlockSpatioTemporal 例項,使用指定的引數
return CrossAttnUpBlockSpatioTemporal(
# 輸入通道數
in_channels=in_channels,
# 輸出通道數
out_channels=out_channels,
# 前一個輸出通道數
prev_output_channel=prev_output_channel,
# 時間嵌入通道數
temb_channels=temb_channels,
# 層數引數
num_layers=num_layers,
# 每個塊的變換層數
transformer_layers_per_block=transformer_layers_per_block,
# 是否新增上取樣
add_upsample=add_upsample,
# 交叉注意力維度
cross_attention_dim=cross_attention_dim,
# 注意力頭數
num_attention_heads=num_attention_heads,
# 解析度索引
resolution_idx=resolution_idx,
)
# 如果上升塊型別不符合任何已知型別,丟擲錯誤
raise ValueError(f"{up_block_type} does not exist.")
# 定義一個名為 UNetMidBlock3DCrossAttn 的類,繼承自 nn.Module
class UNetMidBlock3DCrossAttn(nn.Module):
# 初始化方法,設定類的引數
def __init__(
self,
in_channels: int, # 輸入通道數
temb_channels: int, # 時間嵌入通道數
dropout: float = 0.0, # dropout 機率
num_layers: int = 1, # 層數
resnet_eps: float = 1e-6, # ResNet 中的小常數,避免除零
resnet_time_scale_shift: str = "default", # ResNet 時間縮放偏移
resnet_act_fn: str = "swish", # ResNet 啟用函式型別
resnet_groups: int = 32, # ResNet 分組數
resnet_pre_norm: bool = True, # 是否在 ResNet 前進行歸一化
num_attention_heads: int = 1, # 注意力頭數
output_scale_factor: float = 1.0, # 輸出縮放因子
cross_attention_dim: int = 1280, # 交叉注意力維度
dual_cross_attention: bool = False, # 是否使用雙交叉注意力
use_linear_projection: bool = True, # 是否使用線性投影
upcast_attention: bool = False, # 是否使用上取樣注意力
):
# 省略具體初始化程式碼,通常會在這裡初始化各個層和引數
# 定義前向傳播方法
def forward(
self,
hidden_states: torch.Tensor, # 輸入的隱藏狀態張量
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 可選的編碼器隱藏狀態
attention_mask: Optional[torch.Tensor] = None, # 可選的注意力掩碼
num_frames: int = 1, # 幀數
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 交叉注意力的額外引數
) -> torch.Tensor: # 返回一個張量
# 透過第一個 ResNet 層處理隱藏狀態
hidden_states = self.resnets[0](hidden_states, temb)
# 透過第一個時間卷積層處理隱藏狀態
hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
# 遍歷所有的注意力層、時間注意力層、ResNet 層和時間卷積層
for attn, temp_attn, resnet, temp_conv in zip(
self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
):
# 透過當前注意力層處理隱藏狀態
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0] # 只取返回的第一個元素
# 透過當前時間注意力層處理隱藏狀態
hidden_states = temp_attn(
hidden_states,
num_frames=num_frames,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0] # 只取返回的第一個元素
# 透過當前 ResNet 層處理隱藏狀態
hidden_states = resnet(hidden_states, temb)
# 透過當前時間卷積層處理隱藏狀態
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
# 返回最終的隱藏狀態
return hidden_states
# 定義一個名為 CrossAttnDownBlock3D 的類,繼承自 nn.Module
class CrossAttnDownBlock3D(nn.Module):
# 初始化方法,設定類的引數
def __init__(
self,
in_channels: int, # 輸入通道數
out_channels: int, # 輸出通道數
temb_channels: int, # 時間嵌入通道數
dropout: float = 0.0, # dropout 機率
num_layers: int = 1, # 層數
resnet_eps: float = 1e-6, # ResNet 中的小常數,避免除零
resnet_time_scale_shift: str = "default", # ResNet 時間縮放偏移
resnet_act_fn: str = "swish", # ResNet 啟用函式型別
resnet_groups: int = 32, # ResNet 分組數
resnet_pre_norm: bool = True, # 是否在 ResNet 前進行歸一化
num_attention_heads: int = 1, # 注意力頭數
cross_attention_dim: int = 1280, # 交叉注意力維度
output_scale_factor: float = 1.0, # 輸出縮放因子
downsample_padding: int = 1, # 下采樣的填充大小
add_downsample: bool = True, # 是否新增下采樣層
dual_cross_attention: bool = False, # 是否使用雙交叉注意力
use_linear_projection: bool = False, # 是否使用線性投影
only_cross_attention: bool = False, # 是否只使用交叉注意力
upcast_attention: bool = False, # 是否使用上取樣注意力
):
# 呼叫父類的初始化方法
super().__init__()
# 初始化殘差塊列表
resnets = []
# 初始化注意力層列表
attentions = []
# 初始化臨時注意力層列表
temp_attentions = []
# 初始化臨時卷積層列表
temp_convs = []
# 設定是否使用交叉注意力
self.has_cross_attention = True
# 設定注意力頭的數量
self.num_attention_heads = num_attention_heads
# 根據層數建立各層模組
for i in range(num_layers):
# 設定輸入通道數,第一層使用傳入的輸入通道數,後續層使用輸出通道數
in_channels = in_channels if i == 0 else out_channels
# 建立殘差塊並新增到列表中
resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # 殘差塊的 epsilon 值
groups=resnet_groups, # 組卷積的組數
dropout=dropout, # Dropout 比例
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入的歸一化方式
non_linearity=resnet_act_fn, # 非線性啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 是否進行預歸一化
)
)
# 建立時間卷積層並新增到列表中
temp_convs.append(
TemporalConvLayer(
out_channels, # 輸入通道數
out_channels, # 輸出通道數
dropout=0.1, # Dropout 比例
norm_num_groups=resnet_groups, # 歸一化的組數
)
)
# 建立二維變換器模型並新增到列表中
attentions.append(
Transformer2DModel(
out_channels // num_attention_heads, # 每個注意力頭的通道數
num_attention_heads, # 注意力頭的數量
in_channels=out_channels, # 輸入通道數
num_layers=1, # 變換器層數
cross_attention_dim=cross_attention_dim, # 交叉注意力維度
norm_num_groups=resnet_groups, # 歸一化的組數
use_linear_projection=use_linear_projection, # 是否使用線性對映
only_cross_attention=only_cross_attention, # 是否只使用交叉注意力
upcast_attention=upcast_attention, # 是否上溢注意力
)
)
# 建立時間變換器模型並新增到列表中
temp_attentions.append(
TransformerTemporalModel(
out_channels // num_attention_heads, # 每個注意力頭的通道數
num_attention_heads, # 注意力頭的數量
in_channels=out_channels, # 輸入通道數
num_layers=1, # 變換器層數
cross_attention_dim=cross_attention_dim, # 交叉注意力維度
norm_num_groups=resnet_groups, # 歸一化的組數
)
)
# 將殘差塊列表轉換為模組列表
self.resnets = nn.ModuleList(resnets)
# 將臨時卷積層列表轉換為模組列表
self.temp_convs = nn.ModuleList(temp_convs)
# 將注意力層列表轉換為模組列表
self.attentions = nn.ModuleList(attentions)
# 將臨時注意力層列表轉換為模組列表
self.temp_attentions = nn.ModuleList(temp_attentions)
# 如果需要新增下采樣層
if add_downsample:
# 建立下采樣模組列表
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, # 輸出通道數
use_conv=True, # 是否使用卷積進行下采樣
out_channels=out_channels, # 輸出通道數
padding=downsample_padding, # 下采樣時的填充
name="op", # 模組名稱
)
]
)
else:
# 如果不需要下采樣,設定為 None
self.downsamplers = None
# 初始化梯度檢查點設定為 False
self.gradient_checkpointing = False
# 定義前向傳播方法,接受多個輸入引數並返回張量或元組
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1,
cross_attention_kwargs: Dict[str, Any] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
# TODO(Patrick, William) - 注意力掩碼未使用
output_states = () # 初始化輸出狀態為元組
# 遍歷所有的殘差網路、臨時卷積、注意力和臨時注意力層
for resnet, temp_conv, attn, temp_attn in zip(
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
):
# 使用殘差網路處理隱狀態和時間嵌入
hidden_states = resnet(hidden_states, temb)
# 使用臨時卷積處理隱狀態,考慮幀數
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
# 使用注意力層處理隱狀態,返回字典設為 False
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0] # 取返回的第一個元素
# 使用臨時注意力層處理隱狀態,返回字典設為 False
hidden_states = temp_attn(
hidden_states,
num_frames=num_frames,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0] # 取返回的第一個元素
# 將當前隱狀態新增到輸出狀態中
output_states += (hidden_states,)
# 如果存在下采樣器,則逐個應用
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states) # 應用下采樣器
# 將下采樣後的隱狀態新增到輸出狀態中
output_states += (hidden_states,)
# 返回最終的隱狀態和所有輸出狀態
return hidden_states, output_states
# 定義一個 3D 下采樣模組,繼承自 nn.Module
class DownBlock3D(nn.Module):
# 初始化方法,設定各個引數及模組
def __init__(
self,
in_channels: int, # 輸入通道數
out_channels: int, # 輸出通道數
temb_channels: int, # 時間嵌入通道數
dropout: float = 0.0, # dropout 機率
num_layers: int = 1, # 層數
resnet_eps: float = 1e-6, # ResNet 中的 epsilon 值
resnet_time_scale_shift: str = "default", # 時間尺度偏移方式
resnet_act_fn: str = "swish", # ResNet 啟用函式型別
resnet_groups: int = 32, # ResNet 中的組數
resnet_pre_norm: bool = True, # 是否使用預歸一化
output_scale_factor: float = 1.0, # 輸出縮放因子
add_downsample: bool = True, # 是否新增下采樣層
downsample_padding: int = 1, # 下采樣的填充大小
):
# 呼叫父類建構函式
super().__init__()
# 初始化 ResNet 模組列表
resnets = []
# 初始化時間卷積層列表
temp_convs = []
# 遍歷層數,構建 ResNet 模組和時間卷積層
for i in range(num_layers):
# 第一層使用輸入通道數,後續層使用輸出通道數
in_channels = in_channels if i == 0 else out_channels
# 新增 ResNet 模組到列表中
resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # ResNet 中的 epsilon 值
groups=resnet_groups, # 組數
dropout=dropout, # dropout 機率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入歸一化方式
non_linearity=resnet_act_fn, # 啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 是否預歸一化
)
)
# 新增時間卷積層到列表中
temp_convs.append(
TemporalConvLayer(
out_channels, # 輸入通道數
out_channels, # 輸出通道數
dropout=0.1, # dropout 機率
norm_num_groups=resnet_groups, # 組數
)
)
# 將 ResNet 模組列表轉換為 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 將時間卷積層列表轉換為 nn.ModuleList
self.temp_convs = nn.ModuleList(temp_convs)
# 如果需要新增下采樣層
if add_downsample:
# 初始化下采樣模組列表
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, # 輸出通道數
use_conv=True, # 使用卷積
out_channels=out_channels, # 輸出通道數
padding=downsample_padding, # 填充大小
name="op", # 模組名稱
)
]
)
else:
# 不新增下采樣層
self.downsamplers = None
# 初始化梯度檢查點開關為 False
self.gradient_checkpointing = False
# 前向傳播方法
def forward(
self,
hidden_states: torch.Tensor, # 輸入的隱藏狀態張量
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
num_frames: int = 1, # 幀數
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
# 初始化輸出狀態元組
output_states = ()
# 遍歷 ResNet 模組和時間卷積層
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
# 透過 ResNet 模組處理隱藏狀態
hidden_states = resnet(hidden_states, temb)
# 透過時間卷積層處理隱藏狀態
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
# 將當前的隱藏狀態新增到輸出狀態元組中
output_states += (hidden_states,)
# 如果存在下采樣層
if self.downsamplers is not None:
# 遍歷下采樣層
for downsampler in self.downsamplers:
# 透過下采樣層處理隱藏狀態
hidden_states = downsampler(hidden_states)
# 將當前的隱藏狀態新增到輸出狀態元組中
output_states += (hidden_states,)
# 返回最終的隱藏狀態和輸出狀態元組
return hidden_states, output_states
# 定義一個 3D 交叉注意力上取樣模組,繼承自 nn.Module
class CrossAttnUpBlock3D(nn.Module):
# 初始化方法,用於設定類的屬性
def __init__(
# 輸入通道數
self,
in_channels: int,
# 輸出通道數
out_channels: int,
# 前一個輸出通道數
prev_output_channel: int,
# 時間嵌入通道數
temb_channels: int,
# dropout比率,預設為0.0
dropout: float = 0.0,
# 網路層數,預設為1
num_layers: int = 1,
# ResNet的epsilon值,預設為1e-6
resnet_eps: float = 1e-6,
# ResNet的時間尺度偏移,預設為"default"
resnet_time_scale_shift: str = "default",
# ResNet啟用函式,預設為"swish"
resnet_act_fn: str = "swish",
# ResNet的組數,預設為32
resnet_groups: int = 32,
# ResNet是否使用預歸一化,預設為True
resnet_pre_norm: bool = True,
# 注意力頭的數量,預設為1
num_attention_heads: int = 1,
# 交叉注意力的維度,預設為1280
cross_attention_dim: int = 1280,
# 輸出縮放因子,預設為1.0
output_scale_factor: float = 1.0,
# 是否新增上取樣,預設為True
add_upsample: bool = True,
# 雙重交叉注意力的開關,預設為False
dual_cross_attention: bool = False,
# 是否使用線性投影,預設為False
use_linear_projection: bool = False,
# 僅使用交叉注意力的開關,預設為False
only_cross_attention: bool = False,
# 是否上取樣注意力,預設為False
upcast_attention: bool = False,
# 解析度索引,預設為None
resolution_idx: Optional[int] = None,
):
# 呼叫父類的建構函式進行初始化
super().__init__()
# 初始化用於儲存 ResNet 塊的列表
resnets = []
# 初始化用於儲存時間卷積層的列表
temp_convs = []
# 初始化用於儲存注意力模型的列表
attentions = []
# 初始化用於儲存時間注意力模型的列表
temp_attentions = []
# 設定是否使用交叉注意力
self.has_cross_attention = True
# 設定注意力頭的數量
self.num_attention_heads = num_attention_heads
# 遍歷每一層以構建網路結構
for i in range(num_layers):
# 確定殘差跳過通道數
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 確定 ResNet 的輸入通道數
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 新增一個 ResNet 塊到列表中
resnets.append(
ResnetBlock2D(
# 設定輸入通道數,包括殘差跳過通道
in_channels=resnet_in_channels + res_skip_channels,
# 設定輸出通道數
out_channels=out_channels,
# 設定時間嵌入通道數
temb_channels=temb_channels,
# 設定 epsilon 值
eps=resnet_eps,
# 設定組數
groups=resnet_groups,
# 設定 dropout 機率
dropout=dropout,
# 設定時間嵌入歸一化方式
time_embedding_norm=resnet_time_scale_shift,
# 設定非線性啟用函式
non_linearity=resnet_act_fn,
# 設定輸出縮放因子
output_scale_factor=output_scale_factor,
# 設定是否在預歸一化
pre_norm=resnet_pre_norm,
)
)
# 新增一個時間卷積層到列表中
temp_convs.append(
TemporalConvLayer(
# 設定輸出通道數
out_channels,
# 設定輸入通道數
out_channels,
# 設定 dropout 機率
dropout=0.1,
# 設定組數
norm_num_groups=resnet_groups,
)
)
# 新增一個 2D 轉換器模型到列表中
attentions.append(
Transformer2DModel(
# 設定每個注意力頭的通道數
out_channels // num_attention_heads,
# 設定注意力頭的數量
num_attention_heads,
# 設定輸入通道數
in_channels=out_channels,
# 設定層數
num_layers=1,
# 設定交叉注意力維度
cross_attention_dim=cross_attention_dim,
# 設定組數
norm_num_groups=resnet_groups,
# 是否使用線性投影
use_linear_projection=use_linear_projection,
# 是否僅使用交叉注意力
only_cross_attention=only_cross_attention,
# 是否提升注意力精度
upcast_attention=upcast_attention,
)
)
# 新增一個時間轉換器模型到列表中
temp_attentions.append(
TransformerTemporalModel(
# 設定每個注意力頭的通道數
out_channels // num_attention_heads,
# 設定注意力頭的數量
num_attention_heads,
# 設定輸入通道數
in_channels=out_channels,
# 設定層數
num_layers=1,
# 設定交叉注意力維度
cross_attention_dim=cross_attention_dim,
# 設定組數
norm_num_groups=resnet_groups,
)
)
# 將 ResNet 塊列表轉換為 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 將時間卷積層列表轉換為 nn.ModuleList
self.temp_convs = nn.ModuleList(temp_convs)
# 將注意力模型列表轉換為 nn.ModuleList
self.attentions = nn.ModuleList(attentions)
# 將時間注意力模型列表轉換為 nn.ModuleList
self.temp_attentions = nn.ModuleList(temp_attentions)
# 如果需要上取樣,則初始化上取樣層
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
# 否則將上取樣層設定為 None
self.upsamplers = None
# 設定梯度檢查點標誌
self.gradient_checkpointing = False
# 設定解析度索引
self.resolution_idx = resolution_idx
# 定義前向傳播函式,接受多個引數並返回張量
def forward(
self,
hidden_states: torch.Tensor, # 當前層的隱藏狀態
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 以前層的隱藏狀態元組
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 可選的編碼器隱藏狀態
upsample_size: Optional[int] = None, # 可選的上取樣大小
attention_mask: Optional[torch.Tensor] = None, # 可選的注意力掩碼
num_frames: int = 1, # 幀數,預設值為1
cross_attention_kwargs: Dict[str, Any] = None, # 可選的交叉注意力引數
) -> torch.Tensor: # 返回一個張量
# 檢查 FreeU 是否啟用,基於多個屬性的存在性
is_freeu_enabled = (
getattr(self, "s1", None) # 獲取屬性 s1
and getattr(self, "s2", None) # 獲取屬性 s2
and getattr(self, "b1", None) # 獲取屬性 b1
and getattr(self, "b2", None) # 獲取屬性 b2
)
# TODO(Patrick, William) - 注意力掩碼尚未使用
for resnet, temp_conv, attn, temp_attn in zip( # 遍歷網路模組
self.resnets, self.temp_convs, self.attentions, self.temp_attentions # 從各模組提取
):
# 從元組中彈出最後一個殘差隱藏狀態
res_hidden_states = res_hidden_states_tuple[-1]
# 更新元組,去掉最後一個隱藏狀態
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU:僅對前兩個階段操作
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu( # 應用 FreeU 操作
self.resolution_idx, # 當前解析度索引
hidden_states, # 當前隱藏狀態
res_hidden_states, # 殘差隱藏狀態
s1=self.s1, # 屬性 s1
s2=self.s2, # 屬性 s2
b1=self.b1, # 屬性 b1
b2=self.b2, # 屬性 b2
)
# 將當前隱藏狀態與殘差隱藏狀態在維度1上拼接
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 透過 ResNet 模組處理隱藏狀態
hidden_states = resnet(hidden_states, temb)
# 透過臨時卷積模組處理隱藏狀態
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
# 透過注意力模組處理隱藏狀態,並提取第一個返回值
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states, # 傳遞編碼器隱藏狀態
cross_attention_kwargs=cross_attention_kwargs, # 傳遞交叉注意力引數
return_dict=False, # 不返回字典形式的結果
)[0] # 提取第一個返回值
# 透過臨時注意力模組處理隱藏狀態,並提取第一個返回值
hidden_states = temp_attn(
hidden_states,
num_frames=num_frames, # 傳遞幀數
cross_attention_kwargs=cross_attention_kwargs, # 傳遞交叉注意力引數
return_dict=False, # 不返回字典形式的結果
)[0] # 提取第一個返回值
# 如果存在上取樣模組
if self.upsamplers is not None:
for upsampler in self.upsamplers: # 遍歷上取樣模組
hidden_states = upsampler(hidden_states, upsample_size) # 應用上取樣模組
# 返回最終的隱藏狀態
return hidden_states
# 定義一個名為 UpBlock3D 的類,繼承自 nn.Module
class UpBlock3D(nn.Module):
# 初始化函式,接受多個引數以配置網路層
def __init__(
self,
in_channels: int, # 輸入通道數
prev_output_channel: int, # 前一層的輸出通道數
out_channels: int, # 當前層的輸出通道數
temb_channels: int, # 時間嵌入通道數
dropout: float = 0.0, # dropout 機率
num_layers: int = 1, # 層數
resnet_eps: float = 1e-6, # ResNet 的 epsilon 引數,用於數值穩定
resnet_time_scale_shift: str = "default", # ResNet 時間縮放偏移設定
resnet_act_fn: str = "swish", # ResNet 啟用函式
resnet_groups: int = 32, # ResNet 的組數
resnet_pre_norm: bool = True, # 是否使用前置歸一化
output_scale_factor: float = 1.0, # 輸出縮放因子
add_upsample: bool = True, # 是否新增上取樣層
resolution_idx: Optional[int] = None, # 解析度索引,預設為 None
):
# 呼叫父類建構函式
super().__init__()
# 建立一個空列表,用於儲存 ResNet 層
resnets = []
# 建立一個空列表,用於儲存時間卷積層
temp_convs = []
# 根據層數建立 ResNet 和時間卷積層
for i in range(num_layers):
# 確定跳躍連線的通道數
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 確定當前 ResNet 層的輸入通道數
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 建立 ResNet 層,並新增到 resnets 列表中
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # epsilon 引數
groups=resnet_groups, # 組數
dropout=dropout, # dropout 機率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入歸一化
non_linearity=resnet_act_fn, # 啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 前置歸一化
)
)
# 建立時間卷積層,並新增到 temp_convs 列表中
temp_convs.append(
TemporalConvLayer(
out_channels, # 輸入通道數
out_channels, # 輸出通道數
dropout=0.1, # dropout 機率
norm_num_groups=resnet_groups, # 歸一化組數
)
)
# 將 ResNet 層的列表轉為 nn.ModuleList,以便於管理
self.resnets = nn.ModuleList(resnets)
# 將時間卷積層的列表轉為 nn.ModuleList,以便於管理
self.temp_convs = nn.ModuleList(temp_convs)
# 如果需要新增上取樣層,則建立並新增
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
# 如果不需要上取樣層,則設定為 None
self.upsamplers = None
# 初始化梯度檢查點標誌為 False
self.gradient_checkpointing = False
# 設定解析度索引
self.resolution_idx = resolution_idx
# 定義前向傳播函式
def forward(
self,
hidden_states: torch.Tensor, # 輸入的隱藏狀態張量
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 跳躍連線的隱藏狀態元組
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
upsample_size: Optional[int] = None, # 可選的上取樣尺寸
num_frames: int = 1, # 幀數
) -> torch.Tensor:
# 判斷是否啟用 FreeU,檢查相關屬性是否存在且不為 None
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
# 遍歷自定義的 resnets 和 temp_convs,進行逐對處理
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
# 從 res_hidden_states_tuple 中彈出最後一個隱藏狀態
res_hidden_states = res_hidden_states_tuple[-1]
# 更新 res_hidden_states_tuple,去掉最後一個隱藏狀態
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 如果啟用了 FreeU,則僅對前兩個階段進行操作
if is_freeu_enabled:
# 呼叫 apply_freeu 函式,處理隱藏狀態
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
# 將當前的隱藏狀態與殘差隱藏狀態在維度 1 上連線
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 透過當前的 resnet 處理隱藏狀態和 temb
hidden_states = resnet(hidden_states, temb)
# 透過當前的 temp_conv 處理隱藏狀態,傳入 num_frames 引數
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
# 如果存在上取樣器,則對每個上取樣器進行處理
if self.upsamplers is not None:
for upsampler in self.upsamplers:
# 透過當前的 upsampler 處理隱藏狀態,傳入 upsample_size 引數
hidden_states = upsampler(hidden_states, upsample_size)
# 返回最終的隱藏狀態
return hidden_states
# 定義一箇中間塊時間解碼器類,繼承自 nn.Module
class MidBlockTemporalDecoder(nn.Module):
# 初始化函式,定義輸入輸出通道、注意力頭維度、層數等引數
def __init__(
self,
in_channels: int,
out_channels: int,
attention_head_dim: int = 512,
num_layers: int = 1,
upcast_attention: bool = False,
):
# 呼叫父類的初始化函式
super().__init__()
# 初始化 ResNet 和 Attention 列表
resnets = []
attentions = []
# 根據層數建立相應數量的 ResNet
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
# 將 SpatioTemporalResBlock 例項新增到 ResNet 列表中
resnets.append(
SpatioTemporalResBlock(
in_channels=input_channels,
out_channels=out_channels,
temb_channels=None,
eps=1e-6,
temporal_eps=1e-5,
merge_factor=0.0,
merge_strategy="learned",
switch_spatial_to_temporal_mix=True,
)
)
# 新增 Attention 例項到 Attention 列表中
attentions.append(
Attention(
query_dim=in_channels,
heads=in_channels // attention_head_dim,
dim_head=attention_head_dim,
eps=1e-6,
upcast_attention=upcast_attention,
norm_num_groups=32,
bias=True,
residual_connection=True,
)
)
# 將 Attention 和 ResNet 列表轉換為 ModuleList
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
# 前向傳播函式,定義輸入的隱藏狀態和影像指示器的處理
def forward(
self,
hidden_states: torch.Tensor,
image_only_indicator: torch.Tensor,
):
# 處理第一個 ResNet 的輸出
hidden_states = self.resnets[0](
hidden_states,
image_only_indicator=image_only_indicator,
)
# 遍歷剩餘的 ResNet 和 Attention,交替處理
for resnet, attn in zip(self.resnets[1:], self.attentions):
hidden_states = attn(hidden_states) # 應用注意力機制
# 處理 ResNet 的輸出
hidden_states = resnet(
hidden_states,
image_only_indicator=image_only_indicator,
)
# 返回最終的隱藏狀態
return hidden_states
# 定義一個上取樣塊時間解碼器類,繼承自 nn.Module
class UpBlockTemporalDecoder(nn.Module):
# 初始化函式,定義輸入輸出通道、層數等引數
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
add_upsample: bool = True,
):
# 呼叫父類的初始化函式
super().__init__()
# 初始化 ResNet 列表
resnets = []
# 根據層數建立相應數量的 ResNet
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
# 將 SpatioTemporalResBlock 例項新增到 ResNet 列表中
resnets.append(
SpatioTemporalResBlock(
in_channels=input_channels,
out_channels=out_channels,
temb_channels=None,
eps=1e-6,
temporal_eps=1e-5,
merge_factor=0.0,
merge_strategy="learned",
switch_spatial_to_temporal_mix=True,
)
)
# 將 ResNet 列表轉換為 ModuleList
self.resnets = nn.ModuleList(resnets)
# 如果需要,初始化上取樣層
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
# 定義前向傳播函式,接收隱藏狀態和影像指示器作為輸入,返回處理後的張量
def forward(
self,
hidden_states: torch.Tensor,
image_only_indicator: torch.Tensor,
) -> torch.Tensor:
# 遍歷每個 ResNet 模組,更新隱藏狀態
for resnet in self.resnets:
hidden_states = resnet(
hidden_states,
image_only_indicator=image_only_indicator,
)
# 如果存在上取樣模組,則對隱藏狀態進行上取樣處理
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
# 返回最終的隱藏狀態
return hidden_states
# 定義一個名為 UNetMidBlockSpatioTemporal 的類,繼承自 nn.Module
class UNetMidBlockSpatioTemporal(nn.Module):
# 初始化方法,接收多個引數以配置該模組
def __init__(
self,
in_channels: int, # 輸入通道數
temb_channels: int, # 時間嵌入通道數
num_layers: int = 1, # 層數,預設為 1
transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 每個塊的變換層數,預設為 1
num_attention_heads: int = 1, # 注意力頭數,預設為 1
cross_attention_dim: int = 1280, # 交叉注意力維度,預設為 1280
):
# 呼叫父類的初始化方法
super().__init__()
# 設定是否使用交叉注意力標誌
self.has_cross_attention = True
# 儲存注意力頭的數量
self.num_attention_heads = num_attention_heads
# 支援每個塊的變換層數為可變的
if isinstance(transformer_layers_per_block, int):
# 如果是整數,則將其轉換為包含 num_layers 個相同元素的列表
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
# 至少有一個 ResNet 塊
resnets = [
# 建立第一個時空殘差塊
SpatioTemporalResBlock(
in_channels=in_channels, # 輸入通道數
out_channels=in_channels, # 輸出通道數與輸入相同
temb_channels=temb_channels, # 時間嵌入通道數
eps=1e-5, # 小常數用於數值穩定性
)
]
# 初始化注意力模組列表
attentions = []
# 遍歷層數以新增註意力和殘差塊
for i in range(num_layers):
# 新增時空變換模型到注意力列表
attentions.append(
TransformerSpatioTemporalModel(
num_attention_heads, # 注意力頭數
in_channels // num_attention_heads, # 每個頭的通道數
in_channels=in_channels, # 輸入通道數
num_layers=transformer_layers_per_block[i], # 當前層的變換層數
cross_attention_dim=cross_attention_dim, # 交叉注意力維度
)
)
# 新增另一個時空殘差塊到殘差列表
resnets.append(
SpatioTemporalResBlock(
in_channels=in_channels, # 輸入通道數
out_channels=in_channels, # 輸出通道數與輸入相同
temb_channels=temb_channels, # 時間嵌入通道數
eps=1e-5, # 小常數用於數值穩定性
)
)
# 將注意力模組列表轉換為 nn.ModuleList
self.attentions = nn.ModuleList(attentions)
# 將殘差模組列表轉換為 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 初始化梯度檢查點標誌為 False
self.gradient_checkpointing = False
# 前向傳播方法,接收多個輸入引數
def forward(
self,
hidden_states: torch.Tensor, # 隱藏狀態張量
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 可選的編碼器隱藏狀態
image_only_indicator: Optional[torch.Tensor] = None, # 可選的影像指示張量
# 返回型別為 torch.Tensor 的函式
) -> torch.Tensor:
# 使用第一個殘差網路處理隱藏狀態,傳入時間嵌入和影像指示器
hidden_states = self.resnets[0](
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
# 遍歷注意力層和後續殘差網路的組合
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# 檢查是否在訓練中且開啟了梯度檢查點
if self.training and self.gradient_checkpointing: # TODO
# 建立自定義前向傳播函式
def create_custom_forward(module, return_dict=None):
# 定義自定義前向傳播內部函式
def custom_forward(*inputs):
# 根據返回字典引數決定是否使用返回字典
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
# 設定檢查點引數,取決於 PyTorch 版本
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
# 使用注意力層處理隱藏狀態,傳入編碼器的隱藏狀態和影像指示器
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
# 使用檢查點進行殘差網路的前向傳播
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
**ckpt_kwargs,
)
else:
# 使用注意力層處理隱藏狀態,傳入編碼器的隱藏狀態和影像指示器
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
# 使用殘差網路處理隱藏狀態
hidden_states = resnet(
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
# 返回處理後的隱藏狀態
return hidden_states
# 定義一個下采樣的時空塊,繼承自 nn.Module
class DownBlockSpatioTemporal(nn.Module):
# 初始化方法,設定輸入輸出通道和層數等引數
def __init__(
self,
in_channels: int, # 輸入通道數
out_channels: int, # 輸出通道數
temb_channels: int, # 時間嵌入通道數
num_layers: int = 1, # 層數,預設為1
add_downsample: bool = True, # 是否新增下采樣
):
super().__init__() # 呼叫父類的初始化方法
resnets = [] # 初始化一個空列表以儲存殘差塊
# 根據層數建立相應數量的 SpatioTemporalResBlock
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels # 確定當前層的輸入通道數
resnets.append(
SpatioTemporalResBlock( # 新增一個新的時空殘差塊
in_channels=in_channels, # 設定輸入通道
out_channels=out_channels, # 設定輸出通道
temb_channels=temb_channels, # 設定時間嵌入通道
eps=1e-5, # 設定 epsilon 值
)
)
self.resnets = nn.ModuleList(resnets) # 將殘差塊列表轉化為 ModuleList
# 如果需要下采樣,建立下采樣模組
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D( # 新增一個下采樣層
out_channels, # 設定輸入通道
use_conv=True, # 是否使用卷積進行下采樣
out_channels=out_channels, # 設定輸出通道
name="op", # 下采樣層名稱
)
]
)
else:
self.downsamplers = None # 不新增下采樣模組
self.gradient_checkpointing = False # 初始化梯度檢查點為 False
# 前向傳播方法
def forward(
self,
hidden_states: torch.Tensor, # 隱藏狀態輸入
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入
image_only_indicator: Optional[torch.Tensor] = None, # 可選的影像指示器
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
output_states = () # 初始化輸出狀態元組
for resnet in self.resnets: # 遍歷每個殘差塊
if self.training and self.gradient_checkpointing: # 如果在訓練且啟用了梯度檢查點
# 定義一個建立自定義前向傳播的方法
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs) # 返回模組的前向輸出
return custom_forward
# 根據 PyTorch 版本進行檢查點操作
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), # 使用自定義前向傳播
hidden_states, # 輸入隱藏狀態
temb, # 輸入時間嵌入
image_only_indicator, # 輸入影像指示器
use_reentrant=False, # 不使用重入
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), # 使用自定義前向傳播
hidden_states, # 輸入隱藏狀態
temb, # 輸入時間嵌入
image_only_indicator, # 輸入影像指示器
)
else:
hidden_states = resnet( # 直接透過殘差塊進行前向傳播
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
output_states = output_states + (hidden_states,) # 將當前隱藏狀態新增到輸出狀態中
if self.downsamplers is not None: # 如果存在下采樣模組
for downsampler in self.downsamplers: # 遍歷下采樣模組
hidden_states = downsampler(hidden_states) # 對隱藏狀態進行下采樣
output_states = output_states + (hidden_states,) # 將下采樣後的狀態新增到輸出狀態中
return hidden_states, output_states # 返回最終的隱藏狀態和輸出狀態元組
# 定義一個交叉注意力下采樣時空塊,繼承自 nn.Module
class CrossAttnDownBlockSpatioTemporal(nn.Module):
# 初始化方法,設定模型的基本引數
def __init__(
# 輸入通道數
self,
in_channels: int,
# 輸出通道數
out_channels: int,
# 時間嵌入通道數
temb_channels: int,
# 層數,預設為1
num_layers: int = 1,
# 每個塊的變換層數,可以是整數或元組,預設為1
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# 注意力頭數,預設為1
num_attention_heads: int = 1,
# 交叉注意力的維度,預設為1280
cross_attention_dim: int = 1280,
# 是否新增下采樣,預設為True
add_downsample: bool = True,
):
# 呼叫父類初始化方法
super().__init__()
# 初始化殘差網路列表
resnets = []
# 初始化注意力層列表
attentions = []
# 設定是否使用交叉注意力為True
self.has_cross_attention = True
# 設定注意力頭數
self.num_attention_heads = num_attention_heads
# 如果變換層數是整數,則擴充套件為列表
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
# 遍歷層數以建立殘差塊和注意力層
for i in range(num_layers):
# 如果是第一層,使用輸入通道數,否則使用輸出通道數
in_channels = in_channels if i == 0 else out_channels
# 新增殘差塊到列表
resnets.append(
SpatioTemporalResBlock(
# 輸入通道數
in_channels=in_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# 防止除零的微小值
eps=1e-6,
)
)
# 新增註意力模型到列表
attentions.append(
TransformerSpatioTemporalModel(
# 注意力頭數
num_attention_heads,
# 每個頭的輸出通道數
out_channels // num_attention_heads,
# 輸入通道數
in_channels=out_channels,
# 該層的變換層數
num_layers=transformer_layers_per_block[i],
# 交叉注意力的維度
cross_attention_dim=cross_attention_dim,
)
)
# 將注意力層轉換為nn.ModuleList以支援PyTorch模型
self.attentions = nn.ModuleList(attentions)
# 將殘差層轉換為nn.ModuleList以支援PyTorch模型
self.resnets = nn.ModuleList(resnets)
# 如果需要新增下采樣層
if add_downsample:
# 新增下采樣層到nn.ModuleList
self.downsamplers = nn.ModuleList(
[
Downsample2D(
# 輸出通道數
out_channels,
# 是否使用卷積
use_conv=True,
# 輸出通道數
out_channels=out_channels,
# 填充大小
padding=1,
# 操作名稱
name="op",
)
]
)
else:
# 如果不需要下采樣層,則設定為None
self.downsamplers = None
# 初始化梯度檢查點為False
self.gradient_checkpointing = False
# 前向傳播方法
def forward(
# 隱藏狀態的張量
hidden_states: torch.Tensor,
# 時間嵌入的可選張量
temb: Optional[torch.Tensor] = None,
# 編碼器隱藏狀態的可選張量
encoder_hidden_states: Optional[torch.Tensor] = None,
# 僅影像指示的可選張量
image_only_indicator: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: # 定義返回型別為一個元組,包含一個張量和多個張量的元組
output_states = () # 初始化一個空元組,用於儲存輸出狀態
blocks = list(zip(self.resnets, self.attentions)) # 將自定義的殘差網路和注意力模組打包成一個列表
for resnet, attn in blocks: # 遍歷每個殘差網路和對應的注意力模組
if self.training and self.gradient_checkpointing: # 如果處於訓練模式並且啟用了梯度檢查點
def create_custom_forward(module, return_dict=None): # 定義一個建立自定義前向傳播函式的輔助函式
def custom_forward(*inputs): # 定義自定義前向傳播函式,接受任意數量的輸入
if return_dict is not None: # 如果提供了返回字典引數
return module(*inputs, return_dict=return_dict) # 使用返回字典引數呼叫模組
else: # 如果沒有提供返回字典
return module(*inputs) # 直接呼叫模組並返回結果
return custom_forward # 返回自定義前向傳播函式
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} # 根據 PyTorch 版本設定檢查點引數
hidden_states = torch.utils.checkpoint.checkpoint( # 使用檢查點功能計算隱藏狀態以節省記憶體
create_custom_forward(resnet), # 將殘差網路傳入自定義前向函式
hidden_states, # 將當前隱藏狀態作為輸入
temb, # 傳遞時間嵌入
image_only_indicator, # 傳遞影像指示器
**ckpt_kwargs, # 解包檢查點引數
)
hidden_states = attn( # 使用注意力模組處理隱藏狀態
hidden_states, # 輸入隱藏狀態
encoder_hidden_states=encoder_hidden_states, # 傳遞編碼器隱藏狀態
image_only_indicator=image_only_indicator, # 傳遞影像指示器
return_dict=False, # 不返回字典形式的結果
)[0] # 獲取輸出的第一個元素
else: # 如果不處於訓練模式或未啟用梯度檢查點
hidden_states = resnet( # 直接使用殘差網路處理隱藏狀態
hidden_states, # 輸入當前隱藏狀態
temb, # 傳遞時間嵌入
image_only_indicator=image_only_indicator, # 傳遞影像指示器
)
hidden_states = attn( # 使用注意力模組處理隱藏狀態
hidden_states, # 輸入隱藏狀態
encoder_hidden_states=encoder_hidden_states, # 傳遞編碼器隱藏狀態
image_only_indicator=image_only_indicator, # 傳遞影像指示器
return_dict=False, # 不返回字典形式的結果
)[0] # 獲取輸出的第一個元素
output_states = output_states + (hidden_states,) # 將當前的隱藏狀態新增到輸出狀態元組中
if self.downsamplers is not None: # 如果存在下采樣模組
for downsampler in self.downsamplers: # 遍歷每個下采樣模組
hidden_states = downsampler(hidden_states) # 使用下采樣模組處理隱藏狀態
output_states = output_states + (hidden_states,) # 將處理後的隱藏狀態新增到輸出狀態元組中
return hidden_states, output_states # 返回當前的隱藏狀態和所有輸出狀態
# 定義一個上取樣的時空塊類,繼承自 nn.Module
class UpBlockSpatioTemporal(nn.Module):
# 初始化方法,定義類的引數
def __init__(
self,
in_channels: int, # 輸入通道數
prev_output_channel: int, # 前一層輸出通道數
out_channels: int, # 當前層輸出通道數
temb_channels: int, # 時間嵌入通道數
resolution_idx: Optional[int] = None, # 可選的解析度索引
num_layers: int = 1, # 層數,預設為1
resnet_eps: float = 1e-6, # ResNet 的 epsilon 值
add_upsample: bool = True, # 是否新增上取樣層
):
# 呼叫父類的初始化方法
super().__init__()
# 初始化一個空的列表,用於儲存 ResNet 模組
resnets = []
# 根據層數建立對應的 ResNet 塊
for i in range(num_layers):
# 如果是最後一層,使用輸入通道數,否則使用輸出通道數
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 第一層的輸入通道為前一層輸出通道,後續層為輸出通道
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 建立時空 ResNet 塊並新增到列表中
resnets.append(
SpatioTemporalResBlock(
in_channels=resnet_in_channels + res_skip_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # ResNet 的 epsilon 值
)
)
# 將 ResNet 模組列表轉換為 nn.ModuleList,以便在模型中管理
self.resnets = nn.ModuleList(resnets)
# 如果需要新增上取樣層,則建立對應的 nn.ModuleList
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
# 如果不新增上取樣層,設定為 None
self.upsamplers = None
# 初始化梯度檢查點標誌
self.gradient_checkpointing = False
# 設定解析度索引
self.resolution_idx = resolution_idx
# 定義前向傳播方法
def forward(
self,
hidden_states: torch.Tensor, # 輸入的隱藏狀態張量
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 之前層的隱藏狀態元組
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
image_only_indicator: Optional[torch.Tensor] = None, # 可選的影像指示器張量
) -> torch.Tensor: # 定義函式返回型別為 PyTorch 的張量
for resnet in self.resnets: # 遍歷當前物件中的所有 ResNet 模型
# pop res hidden states # 從隱藏狀態元組中提取最後一個隱藏狀態
res_hidden_states = res_hidden_states_tuple[-1] # 獲取最後的 ResNet 隱藏狀態
res_hidden_states_tuple = res_hidden_states_tuple[:-1] # 更新元組,移除最後一個隱藏狀態
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) # 將當前隱藏狀態與 ResNet 隱藏狀態在維度 1 上拼接
if self.training and self.gradient_checkpointing: # 如果處於訓練狀態並啟用了梯度檢查點
def create_custom_forward(module): # 定義用於建立自定義前向傳播函式的內部函式
def custom_forward(*inputs): # 自定義前向傳播,接收任意數量的輸入
return module(*inputs) # 呼叫原始模組的前向傳播
return custom_forward # 返回自定義前向傳播函式
if is_torch_version(">=", "1.11.0"): # 檢查當前 PyTorch 版本是否大於等於 1.11.0
hidden_states = torch.utils.checkpoint.checkpoint( # 使用檢查點機制儲存記憶體
create_custom_forward(resnet), # 傳入自定義前向傳播函式
hidden_states, # 傳入當前的隱藏狀態
temb, # 傳入時間嵌入
image_only_indicator, # 傳入影像指示器
use_reentrant=False, # 禁用重入
)
else: # 如果 PyTorch 版本小於 1.11.0
hidden_states = torch.utils.checkpoint.checkpoint( # 使用檢查點機制儲存記憶體
create_custom_forward(resnet), # 傳入自定義前向傳播函式
hidden_states, # 傳入當前的隱藏狀態
temb, # 傳入時間嵌入
image_only_indicator, # 傳入影像指示器
)
else: # 如果不是訓練狀態或沒有啟用梯度檢查點
hidden_states = resnet( # 直接呼叫 ResNet 模型處理隱藏狀態
hidden_states, # 傳入當前的隱藏狀態
temb, # 傳入時間嵌入
image_only_indicator=image_only_indicator, # 傳入影像指示器
)
if self.upsamplers is not None: # 如果存在上取樣模組
for upsampler in self.upsamplers: # 遍歷所有上取樣模組
hidden_states = upsampler(hidden_states) # 呼叫上取樣模組處理隱藏狀態
return hidden_states # 返回處理後的隱藏狀態
# 定義一個時空交叉注意力上取樣塊類,繼承自 nn.Module
class CrossAttnUpBlockSpatioTemporal(nn.Module):
# 初始化方法,設定網路的各個引數
def __init__(
# 輸入通道數
in_channels: int,
# 輸出通道數
out_channels: int,
# 前一層輸出通道數
prev_output_channel: int,
# 時間嵌入通道數
temb_channels: int,
# 解析度索引,可選
resolution_idx: Optional[int] = None,
# 層數
num_layers: int = 1,
# 每個塊的變換器層數,支援單個整數或元組
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# ResNet 的 epsilon 值,防止除零錯誤
resnet_eps: float = 1e-6,
# 注意力頭的數量
num_attention_heads: int = 1,
# 交叉注意力維度
cross_attention_dim: int = 1280,
# 是否新增上取樣層
add_upsample: bool = True,
):
# 呼叫父類的初始化方法
super().__init__()
# 儲存 ResNet 層的列表
resnets = []
# 儲存注意力層的列表
attentions = []
# 指示是否使用交叉注意力
self.has_cross_attention = True
# 設定注意力頭的數量
self.num_attention_heads = num_attention_heads
# 如果是整數,將其轉換為列表,包含 num_layers 個相同的元素
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
# 遍歷每一層,構建 ResNet 和注意力層
for i in range(num_layers):
# 根據當前層是否是最後一層來決定跳過連線的通道數
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 根據當前層決定 ResNet 的輸入通道數
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 新增時空 ResNet 塊到列表中
resnets.append(
SpatioTemporalResBlock(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
)
)
# 新增時空變換器模型到列表中
attentions.append(
TransformerSpatioTemporalModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
)
)
# 將注意力層列表轉換為 nn.ModuleList,以便於管理
self.attentions = nn.ModuleList(attentions)
# 將 ResNet 層列表轉換為 nn.ModuleList,以便於管理
self.resnets = nn.ModuleList(resnets)
# 如果需要,新增上取樣層
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
# 否則上取樣層設定為 None
self.upsamplers = None
# 設定梯度檢查點標誌為 False
self.gradient_checkpointing = False
# 儲存解析度索引
self.resolution_idx = resolution_idx
# 前向傳播方法,定義輸入和輸出
def forward(
# 隱藏狀態張量
hidden_states: torch.Tensor,
# 上一層隱藏狀態的元組
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
# 可選的時間嵌入張量
temb: Optional[torch.Tensor] = None,
# 可選的編碼器隱藏狀態張量
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可選的影像指示器張量
image_only_indicator: Optional[torch.Tensor] = None,
# 返回一個 torch.Tensor 型別的結果
) -> torch.Tensor:
# 遍歷每個 resnet 和 attention 模組的組合
for resnet, attn in zip(self.resnets, self.attentions):
# 從隱藏狀態元組中彈出最後一個 res 隱藏狀態
res_hidden_states = res_hidden_states_tuple[-1]
# 更新隱藏狀態元組,去掉最後一個元素
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 在指定維度連線當前的 hidden_states 和 res_hidden_states
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 如果處於訓練模式且開啟了梯度檢查點
if self.training and self.gradient_checkpointing: # TODO
# 定義一個用於建立自定義前向傳播函式的函式
def create_custom_forward(module, return_dict=None):
# 定義自定義前向傳播邏輯
def custom_forward(*inputs):
# 根據是否返回字典,呼叫模組的前向傳播
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
# 根據 PyTorch 版本選擇 checkpoint 的引數
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
# 使用檢查點儲存記憶體,呼叫自定義前向傳播
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
**ckpt_kwargs,
)
# 透過 attention 模組處理 hidden_states
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
else:
# 如果不使用檢查點,直接透過 resnet 模組處理 hidden_states
hidden_states = resnet(
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
# 透過 attention 模組處理 hidden_states
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
# 如果存在上取樣模組,逐個應用於 hidden_states
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
# 返回處理後的 hidden_states
return hidden_states