diffusers 原始碼解析(十三)
.\diffusers\models\unets\unet_2d.py
# 版權宣告,表示該程式碼由 HuggingFace 團隊所有
#
# 根據 Apache 2.0 許可證進行許可;
# 除非遵循許可證,否則不得使用此檔案。
# 可以在以下地址獲取許可證的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律要求或書面同意,軟體按 "原樣" 分發,
# 不提供任何形式的保證或條件,無論是明示或暗示。
# 請參閱許可證以瞭解有關許可權和
# 限制的具體資訊。
from dataclasses import dataclass # 從 dataclasses 模組匯入 dataclass 裝飾器
from typing import Optional, Tuple, Union # 從 typing 模組匯入型別提示相關的類
import torch # 匯入 PyTorch 庫
import torch.nn as nn # 匯入 PyTorch 的神經網路模組
from ...configuration_utils import ConfigMixin, register_to_config # 從配置工具匯入配置混合類和註冊函式
from ...utils import BaseOutput # 從工具模組匯入基礎輸出類
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps # 從嵌入模組匯入相關類
from ..modeling_utils import ModelMixin # 從建模工具匯入模型混合類
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block # 從 UNet 2D 塊匯入相關元件
@dataclass
class UNet2DOutput(BaseOutput): # 定義 UNet2DOutput 類,繼承自 BaseOutput
"""
[`UNet2DModel`] 的輸出類。
引數:
sample (`torch.Tensor` 形狀為 `(batch_size, num_channels, height, width)`):
從模型最後一層輸出的隱藏狀態。
"""
sample: torch.Tensor # 定義輸出樣本的屬性,型別為 torch.Tensor
class UNet2DModel(ModelMixin, ConfigMixin): # 定義 UNet2DModel 類,繼承自 ModelMixin 和 ConfigMixin
r"""
一個 2D UNet 模型,接收一個有噪聲的樣本和一個時間步,返回一個樣本形狀的輸出。
此模型繼承自 [`ModelMixin`]。請檢視超類文件以瞭解其為所有模型實現的通用方法
(例如下載或儲存)。
"""
@register_to_config # 使用裝飾器將該方法註冊到配置中
def __init__( # 定義初始化方法
self,
sample_size: Optional[Union[int, Tuple[int, int]]] = None, # 可選的樣本大小,可以是整數或整數元組
in_channels: int = 3, # 輸入通道數,預設為 3
out_channels: int = 3, # 輸出通道數,預設為 3
center_input_sample: bool = False, # 是否將輸入樣本居中,預設為 False
time_embedding_type: str = "positional", # 時間嵌入型別,預設為 "positional"
freq_shift: int = 0, # 頻率偏移量,預設為 0
flip_sin_to_cos: bool = True, # 是否將正弦函式翻轉為餘弦函式,預設為 True
down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), # 下采樣塊型別
up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), # 上取樣塊型別
block_out_channels: Tuple[int, ...] = (224, 448, 672, 896), # 各塊輸出通道數
layers_per_block: int = 2, # 每個塊的層數,預設為 2
mid_block_scale_factor: float = 1, # 中間塊縮放因子,預設為 1
downsample_padding: int = 1, # 下采樣的填充大小,預設為 1
downsample_type: str = "conv", # 下采樣型別,預設為卷積
upsample_type: str = "conv", # 上取樣型別,預設為卷積
dropout: float = 0.0, # dropout 機率,預設為 0.0
act_fn: str = "silu", # 啟用函式型別,預設為 "silu"
attention_head_dim: Optional[int] = 8, # 注意力頭維度,預設為 8
norm_num_groups: int = 32, # 規範化組數量,預設為 32
attn_norm_num_groups: Optional[int] = None, # 注意力規範化組數量,可選
norm_eps: float = 1e-5, # 規範化的 epsilon 值,預設為 1e-5
resnet_time_scale_shift: str = "default", # ResNet 時間縮放偏移型別,預設為 "default"
add_attention: bool = True, # 是否新增註意力機制,預設為 True
class_embed_type: Optional[str] = None, # 類別嵌入型別,可選
num_class_embeds: Optional[int] = None, # 類別嵌入數量,可選
num_train_timesteps: Optional[int] = None, # 訓練時間步數量,可選
# 定義一個名為 forward 的方法
def forward(
# 輸入引數 sample,型別為 torch.Tensor,表示樣本資料
self,
sample: torch.Tensor,
# 輸入引數 timestep,可以是 torch.Tensor、float 或 int,表示時間步
timestep: Union[torch.Tensor, float, int],
# 可選引數 class_labels,型別為 torch.Tensor,表示分類標籤,預設為 None
class_labels: Optional[torch.Tensor] = None,
# 可選引數 return_dict,型別為 bool,表示是否以字典形式返回結果,預設為 True
return_dict: bool = True,
.\diffusers\models\unets\unet_2d_blocks.py
# 版權所有 2024 HuggingFace 團隊,保留所有權利。
#
# 根據 Apache 許可證第 2.0 版(“許可證”)進行許可;
# 除非遵循許可證,否則不得使用此檔案。
# 可以在以下網址獲取許可證副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律或書面協議另有規定,否則根據許可證分發的軟體按“現狀”基礎分發,
# 不提供任何形式的保證或條件,無論是明示還是暗示。
# 請參閱許可證以瞭解特定語言的許可權和
# 限制。
from typing import Any, Dict, Optional, Tuple, Union # 匯入型別註解相關的模組
import numpy as np # 匯入 NumPy 庫用於數值計算
import torch # 匯入 PyTorch 庫用於深度學習
import torch.nn.functional as F # 匯入 PyTorch 的功能性 API
from torch import nn # 匯入 PyTorch 的神經網路模組
from ...utils import deprecate, is_torch_version, logging # 從工具模組匯入日誌和版本檢測功能
from ...utils.torch_utils import apply_freeu # 從 PyTorch 工具模組匯入 apply_freeu 函式
from ..activations import get_activation # 從啟用函式模組匯入 get_activation 函式
from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 # 匯入注意力處理器相關的類
from ..normalization import AdaGroupNorm # 從歸一化模組匯入 AdaGroupNorm 類
from ..resnet import ( # 從 ResNet 模組匯入多個下采樣和上取樣類
Downsample2D,
FirDownsample2D,
FirUpsample2D,
KDownsample2D,
KUpsample2D,
ResnetBlock2D,
ResnetBlockCondNorm2D,
Upsample2D,
)
from ..transformers.dual_transformer_2d import DualTransformer2DModel # 匯入雙重變換器模型
from ..transformers.transformer_2d import Transformer2DModel # 匯入二維變換器模型
logger = logging.get_logger(__name__) # 獲取當前模組的日誌記錄器例項
def get_down_block( # 定義獲取下采樣塊的函式
down_block_type: str, # 下采樣塊的型別
num_layers: int, # 下采樣層的數量
in_channels: int, # 輸入通道數
out_channels: int, # 輸出通道數
temb_channels: int, # 時間嵌入通道數
add_downsample: bool, # 是否新增下采樣標誌
resnet_eps: float, # ResNet 中的 epsilon 引數
resnet_act_fn: str, # ResNet 使用的啟用函式
transformer_layers_per_block: int = 1, # 每個塊中的變換器層數,預設為 1
num_attention_heads: Optional[int] = None, # 注意力頭的數量,預設為 None
resnet_groups: Optional[int] = None, # ResNet 中的組數,預設為 None
cross_attention_dim: Optional[int] = None, # 交叉注意力維度,預設為 None
downsample_padding: Optional[int] = None, # 下采樣填充引數,預設為 None
dual_cross_attention: bool = False, # 是否使用雙重交叉注意力標誌
use_linear_projection: bool = False, # 是否使用線性投影標誌
only_cross_attention: bool = False, # 是否僅使用交叉注意力標誌
upcast_attention: bool = False, # 是否上升注意力標誌
resnet_time_scale_shift: str = "default", # ResNet 時間縮放移位,預設為“default”
attention_type: str = "default", # 注意力型別,預設為“default”
resnet_skip_time_act: bool = False, # ResNet 跳過時間啟用標誌
resnet_out_scale_factor: float = 1.0, # ResNet 輸出縮放因子,預設為 1.0
cross_attention_norm: Optional[str] = None, # 交叉注意力歸一化,預設為 None
attention_head_dim: Optional[int] = None, # 注意力頭維度,預設為 None
downsample_type: Optional[str] = None, # 下采樣型別,預設為 None
dropout: float = 0.0, # dropout 比例,預設為 0.0
):
# 如果沒有定義注意力頭維度,預設設定為頭的數量
if attention_head_dim is None:
logger.warning( # 記錄警告資訊
f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." # 提醒使用者使用預設的注意力頭維度
)
attention_head_dim = num_attention_heads # 將注意力頭維度設定為頭的數量
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type # 處理下采樣塊型別的字串
# 檢查下行塊的型別是否為 "DownBlock2D"
if down_block_type == "DownBlock2D":
# 返回 DownBlock2D 例項,傳入相關引數
return DownBlock2D(
# 傳入層數
num_layers=num_layers,
# 輸入通道數
in_channels=in_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# dropout 比率
dropout=dropout,
# 是否新增下采樣
add_downsample=add_downsample,
# ResNet 的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 的啟用函式
resnet_act_fn=resnet_act_fn,
# ResNet 的分組數
resnet_groups=resnet_groups,
# 下采樣的填充
downsample_padding=downsample_padding,
# ResNet 的時間尺度偏移
resnet_time_scale_shift=resnet_time_scale_shift,
)
# 檢查下行塊的型別是否為 "ResnetDownsampleBlock2D"
elif down_block_type == "ResnetDownsampleBlock2D":
# 返回 ResnetDownsampleBlock2D 例項,傳入相關引數
return ResnetDownsampleBlock2D(
# 傳入層數
num_layers=num_layers,
# 輸入通道數
in_channels=in_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# dropout 比率
dropout=dropout,
# 是否新增下采樣
add_downsample=add_downsample,
# ResNet 的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 的啟用函式
resnet_act_fn=resnet_act_fn,
# ResNet 的分組數
resnet_groups=resnet_groups,
# ResNet 的時間尺度偏移
resnet_time_scale_shift=resnet_time_scale_shift,
# ResNet 的時間啟用跳過標誌
skip_time_act=resnet_skip_time_act,
# ResNet 的輸出縮放因子
output_scale_factor=resnet_out_scale_factor,
)
# 檢查下行塊的型別是否為 "AttnDownBlock2D"
elif down_block_type == "AttnDownBlock2D":
# 如果不新增下采樣,則將下采樣型別設為 None
if add_downsample is False:
downsample_type = None
else:
# 如果新增下采樣,則預設下采樣型別為 'conv'
downsample_type = downsample_type or "conv" # default to 'conv'
# 返回 AttnDownBlock2D 例項,傳入相關引數
return AttnDownBlock2D(
# 傳入層數
num_layers=num_layers,
# 輸入通道數
in_channels=in_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# dropout 比率
dropout=dropout,
# ResNet 的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 的啟用函式
resnet_act_fn=resnet_act_fn,
# ResNet 的分組數
resnet_groups=resnet_groups,
# 下采樣的填充
downsample_padding=downsample_padding,
# 注意力頭的維度
attention_head_dim=attention_head_dim,
# ResNet 的時間尺度偏移
resnet_time_scale_shift=resnet_time_scale_shift,
# 下采樣型別
downsample_type=downsample_type,
)
# 檢查下行塊型別是否為 CrossAttnDownBlock2D
elif down_block_type == "CrossAttnDownBlock2D":
# 如果 cross_attention_dim 未指定,則丟擲錯誤
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
# 返回 CrossAttnDownBlock2D 例項,使用提供的引數進行初始化
return CrossAttnDownBlock2D(
# 設定層的數量
num_layers=num_layers,
# 每個塊的變換層數量
transformer_layers_per_block=transformer_layers_per_block,
# 輸入通道數
in_channels=in_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# dropout 比例
dropout=dropout,
# 是否新增下采樣
add_downsample=add_downsample,
# ResNet 中的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 啟用函式
resnet_act_fn=resnet_act_fn,
# ResNet 分組數
resnet_groups=resnet_groups,
# 下采樣填充
downsample_padding=downsample_padding,
# 跨注意力維度
cross_attention_dim=cross_attention_dim,
# 注意力頭數
num_attention_heads=num_attention_heads,
# 是否使用雙向跨注意力
dual_cross_attention=dual_cross_attention,
# 是否使用線性投影
use_linear_projection=use_linear_projection,
# 是否僅使用跨注意力
only_cross_attention=only_cross_attention,
# 是否上溯注意力
upcast_attention=upcast_attention,
# ResNet 時間尺度偏移
resnet_time_scale_shift=resnet_time_scale_shift,
# 注意力型別
attention_type=attention_type,
)
# 檢查下行塊型別是否為 SimpleCrossAttnDownBlock2D
elif down_block_type == "SimpleCrossAttnDownBlock2D":
# 如果 cross_attention_dim 未指定,則丟擲錯誤
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
# 返回 SimpleCrossAttnDownBlock2D 例項,使用提供的引數進行初始化
return SimpleCrossAttnDownBlock2D(
# 設定層的數量
num_layers=num_layers,
# 輸入通道數
in_channels=in_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# dropout 比例
dropout=dropout,
# 是否新增下采樣
add_downsample=add_downsample,
# ResNet 中的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 啟用函式
resnet_act_fn=resnet_act_fn,
# ResNet 分組數
resnet_groups=resnet_groups,
# 跨注意力維度
cross_attention_dim=cross_attention_dim,
# 注意力頭的維度
attention_head_dim=attention_head_dim,
# ResNet 時間尺度偏移
resnet_time_scale_shift=resnet_time_scale_shift,
# 是否跳過時間啟用
skip_time_act=resnet_skip_time_act,
# 輸出縮放因子
output_scale_factor=resnet_out_scale_factor,
# 是否僅使用跨注意力
only_cross_attention=only_cross_attention,
# 跨注意力規範化
cross_attention_norm=cross_attention_norm,
)
# 檢查下行塊型別是否為 SkipDownBlock2D
elif down_block_type == "SkipDownBlock2D":
# 返回 SkipDownBlock2D 例項,使用提供的引數進行初始化
return SkipDownBlock2D(
# 設定層的數量
num_layers=num_layers,
# 輸入通道數
in_channels=in_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# dropout 比例
dropout=dropout,
# 是否新增下采樣
add_downsample=add_downsample,
# ResNet 中的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 啟用函式
resnet_act_fn=resnet_act_fn,
# 下采樣填充
downsample_padding=downsample_padding,
# ResNet 時間尺度偏移
resnet_time_scale_shift=resnet_time_scale_shift,
)
# 檢查下采樣塊的型別是否為 "AttnSkipDownBlock2D"
elif down_block_type == "AttnSkipDownBlock2D":
# 返回一個 AttnSkipDownBlock2D 物件,初始化引數傳入
return AttnSkipDownBlock2D(
# 設定層數
num_layers=num_layers,
# 輸入通道數
in_channels=in_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# dropout 率
dropout=dropout,
# 是否新增下采樣
add_downsample=add_downsample,
# ResNet 的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 的啟用函式
resnet_act_fn=resnet_act_fn,
# 注意力頭的維度
attention_head_dim=attention_head_dim,
# ResNet 時間縮放偏移
resnet_time_scale_shift=resnet_time_scale_shift,
)
# 檢查下采樣塊的型別是否為 "DownEncoderBlock2D"
elif down_block_type == "DownEncoderBlock2D":
# 返回一個 DownEncoderBlock2D 物件,初始化引數傳入
return DownEncoderBlock2D(
# 設定層數
num_layers=num_layers,
# 輸入通道數
in_channels=in_channels,
# 輸出通道數
out_channels=out_channels,
# dropout 率
dropout=dropout,
# 是否新增下采樣
add_downsample=add_downsample,
# ResNet 的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 的啟用函式
resnet_act_fn=resnet_act_fn,
# ResNet 的組數
resnet_groups=resnet_groups,
# 下采樣填充
downsample_padding=downsample_padding,
# ResNet 時間縮放偏移
resnet_time_scale_shift=resnet_time_scale_shift,
)
# 檢查下采樣塊的型別是否為 "AttnDownEncoderBlock2D"
elif down_block_type == "AttnDownEncoderBlock2D":
# 返回一個 AttnDownEncoderBlock2D 物件,初始化引數傳入
return AttnDownEncoderBlock2D(
# 設定層數
num_layers=num_layers,
# 輸入通道數
in_channels=in_channels,
# 輸出通道數
out_channels=out_channels,
# dropout 率
dropout=dropout,
# 是否新增下采樣
add_downsample=add_downsample,
# ResNet 的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 的啟用函式
resnet_act_fn=resnet_act_fn,
# ResNet 的組數
resnet_groups=resnet_groups,
# 下采樣填充
downsample_padding=downsample_padding,
# 注意力頭的維度
attention_head_dim=attention_head_dim,
# ResNet 時間縮放偏移
resnet_time_scale_shift=resnet_time_scale_shift,
)
# 檢查下采樣塊的型別是否為 "KDownBlock2D"
elif down_block_type == "KDownBlock2D":
# 返回一個 KDownBlock2D 物件,初始化引數傳入
return KDownBlock2D(
# 設定層數
num_layers=num_layers,
# 輸入通道數
in_channels=in_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# dropout 率
dropout=dropout,
# 是否新增下采樣
add_downsample=add_downsample,
# ResNet 的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 的啟用函式
resnet_act_fn=resnet_act_fn,
)
# 檢查下采樣塊的型別是否為 "KCrossAttnDownBlock2D"
elif down_block_type == "KCrossAttnDownBlock2D":
# 返回一個 KCrossAttnDownBlock2D 物件,初始化引數傳入
return KCrossAttnDownBlock2D(
# 設定層數
num_layers=num_layers,
# 輸入通道數
in_channels=in_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# dropout 率
dropout=dropout,
# 是否新增下采樣
add_downsample=add_downsample,
# ResNet 的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 的啟用函式
resnet_act_fn=resnet_act_fn,
# 跨注意力的維度
cross_attention_dim=cross_attention_dim,
# 注意力頭的維度
attention_head_dim=attention_head_dim,
# 是否新增自注意力
add_self_attention=True if not add_downsample else False,
)
# 如果下采樣塊型別不匹配,則丟擲異常
raise ValueError(f"{down_block_type} does not exist.")
# 根據給定引數生成中間塊(mid block)
def get_mid_block(
# 中間塊的型別
mid_block_type: str,
# 嵌入通道數
temb_channels: int,
# 輸入通道數
in_channels: int,
# ResNet 的 epsilon 值
resnet_eps: float,
# ResNet 的啟用函式型別
resnet_act_fn: str,
# ResNet 的組數
resnet_groups: int,
# 輸出縮放因子,預設為 1.0
output_scale_factor: float = 1.0,
# 每個塊的變換層數,預設為 1
transformer_layers_per_block: int = 1,
# 注意力頭的數量,預設為 None
num_attention_heads: Optional[int] = None,
# 跨注意力的維度,預設為 None
cross_attention_dim: Optional[int] = None,
# 是否使用雙重跨注意力,預設為 False
dual_cross_attention: bool = False,
# 是否使用線性投影,預設為 False
use_linear_projection: bool = False,
# 是否僅使用跨注意力作為中間塊,預設為 False
mid_block_only_cross_attention: bool = False,
# 是否提升注意力精度,預設為 False
upcast_attention: bool = False,
# ResNet 的時間縮放偏移,預設為 "default"
resnet_time_scale_shift: str = "default",
# 注意力型別,預設為 "default"
attention_type: str = "default",
# ResNet 是否跳過時間啟用,預設為 False
resnet_skip_time_act: bool = False,
# 跨注意力的歸一化型別,預設為 None
cross_attention_norm: Optional[str] = None,
# 注意力頭的維度,預設為 1
attention_head_dim: Optional[int] = 1,
# dropout 機率,預設為 0.0
dropout: float = 0.0,
):
# 根據中間塊的型別生成對應的物件
if mid_block_type == "UNetMidBlock2DCrossAttn":
# 建立 UNet 的 2D 跨注意力中間塊
return UNetMidBlock2DCrossAttn(
# 設定變換層數
transformer_layers_per_block=transformer_layers_per_block,
# 設定輸入通道數
in_channels=in_channels,
# 設定嵌入通道數
temb_channels=temb_channels,
# 設定 dropout 機率
dropout=dropout,
# 設定 ResNet epsilon 值
resnet_eps=resnet_eps,
# 設定 ResNet 啟用函式
resnet_act_fn=resnet_act_fn,
# 設定輸出縮放因子
output_scale_factor=output_scale_factor,
# 設定時間縮放偏移
resnet_time_scale_shift=resnet_time_scale_shift,
# 設定跨注意力維度
cross_attention_dim=cross_attention_dim,
# 設定注意力頭數量
num_attention_heads=num_attention_heads,
# 設定 ResNet 組數
resnet_groups=resnet_groups,
# 設定是否使用雙重跨注意力
dual_cross_attention=dual_cross_attention,
# 設定是否使用線性投影
use_linear_projection=use_linear_projection,
# 設定是否提升注意力精度
upcast_attention=upcast_attention,
# 設定注意力型別
attention_type=attention_type,
)
# 檢查是否為簡單跨注意力中間塊
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
# 建立 UNet 的 2D 簡單跨注意力中間塊
return UNetMidBlock2DSimpleCrossAttn(
# 設定輸入通道數
in_channels=in_channels,
# 設定嵌入通道數
temb_channels=temb_channels,
# 設定 dropout 機率
dropout=dropout,
# 設定 ResNet epsilon 值
resnet_eps=resnet_eps,
# 設定 ResNet 啟用函式
resnet_act_fn=resnet_act_fn,
# 設定輸出縮放因子
output_scale_factor=output_scale_factor,
# 設定跨注意力維度
cross_attention_dim=cross_attention_dim,
# 設定注意力頭的維度
attention_head_dim=attention_head_dim,
# 設定 ResNet 組數
resnet_groups=resnet_groups,
# 設定時間縮放偏移
resnet_time_scale_shift=resnet_time_scale_shift,
# 設定是否跳過時間啟用
skip_time_act=resnet_skip_time_act,
# 設定是否僅使用跨注意力
only_cross_attention=mid_block_only_cross_attention,
# 設定跨注意力的歸一化型別
cross_attention_norm=cross_attention_norm,
)
# 檢查是否為標準的 2D 中間塊
elif mid_block_type == "UNetMidBlock2D":
# 建立 UNet 的 2D 中間塊
return UNetMidBlock2D(
# 設定輸入通道數
in_channels=in_channels,
# 設定嵌入通道數
temb_channels=temb_channels,
# 設定 dropout 機率
dropout=dropout,
# 設定層數為 0
num_layers=0,
# 設定 ResNet epsilon 值
resnet_eps=resnet_eps,
# 設定 ResNet 啟用函式
resnet_act_fn=resnet_act_fn,
# 設定輸出縮放因子
output_scale_factor=output_scale_factor,
# 設定 ResNet 組數
resnet_groups=resnet_groups,
# 設定時間縮放偏移
resnet_time_scale_shift=resnet_time_scale_shift,
# 不新增註意力
add_attention=False,
)
# 檢查中間塊型別是否為 None
elif mid_block_type is None:
# 返回 None
return None
# 丟擲未知型別的異常
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
# 輸出通道的數量
out_channels: int,
# 前一層輸出通道的數量
prev_output_channel: int,
# 嵌入層通道的數量
temb_channels: int,
# 是否新增上取樣層
add_upsample: bool,
# ResNet 中的 epsilon 值,用於數值穩定性
resnet_eps: float,
# ResNet 中的啟用函式型別
resnet_act_fn: str,
# 解析度索引,預設為 None
resolution_idx: Optional[int] = None,
# 每個塊中的變換層數量
transformer_layers_per_block: int = 1,
# 注意力頭的數量,預設為 None
num_attention_heads: Optional[int] = None,
# ResNet 中的組數量,預設為 None
resnet_groups: Optional[int] = None,
# 交叉注意力的維度,預設為 None
cross_attention_dim: Optional[int] = None,
# 是否使用雙重交叉注意力
dual_cross_attention: bool = False,
# 是否使用線性投影
use_linear_projection: bool = False,
# 是否僅使用交叉注意力
only_cross_attention: bool = False,
# 是否上取樣時提高注意力精度
upcast_attention: bool = False,
# ResNet 時間縮放偏移的型別,預設為 "default"
resnet_time_scale_shift: str = "default",
# 注意力型別,預設為 "default"
attention_type: str = "default",
# ResNet 中跳過時間啟用的標誌
resnet_skip_time_act: bool = False,
# ResNet 輸出縮放因子,預設為 1.0
resnet_out_scale_factor: float = 1.0,
# 交叉注意力的歸一化方式,預設為 None
cross_attention_norm: Optional[str] = None,
# 注意力頭的維度,預設為 None
attention_head_dim: Optional[int] = None,
# 上取樣型別,預設為 None
upsample_type: Optional[str] = None,
# 丟棄率,預設為 0.0
dropout: float = 0.0,
) -> nn.Module: # 指定該函式的返回型別為 nn.Module
# 如果未定義注意力頭的維度,預設設定為注意力頭的數量
if attention_head_dim is None:
logger.warning( # 記錄警告資訊
f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." # 提示使用者提供 attention_head_dim
)
attention_head_dim = num_attention_heads # 將 attention_head_dim 設定為 num_attention_heads
# 如果 up_block_type 以 "UNetRes" 開頭,則去掉字首
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
# 檢查 up_block_type 是否為 "UpBlock2D"
if up_block_type == "UpBlock2D":
return UpBlock2D( # 返回 UpBlock2D 物件
num_layers=num_layers, # 設定網路層數
in_channels=in_channels, # 設定輸入通道數
out_channels=out_channels, # 設定輸出通道數
prev_output_channel=prev_output_channel, # 設定前一個輸出通道
temb_channels=temb_channels, # 設定時間嵌入通道數
resolution_idx=resolution_idx, # 設定解析度索引
dropout=dropout, # 設定 dropout 引數
add_upsample=add_upsample, # 設定是否新增上取樣
resnet_eps=resnet_eps, # 設定 ResNet 的 epsilon
resnet_act_fn=resnet_act_fn, # 設定 ResNet 的啟用函式
resnet_groups=resnet_groups, # 設定 ResNet 的組數
resnet_time_scale_shift=resnet_time_scale_shift, # 設定 ResNet 的時間縮放偏移
)
# 檢查 up_block_type 是否為 "ResnetUpsampleBlock2D"
elif up_block_type == "ResnetUpsampleBlock2D":
return ResnetUpsampleBlock2D( # 返回 ResnetUpsampleBlock2D 物件
num_layers=num_layers, # 設定網路層數
in_channels=in_channels, # 設定輸入通道數
out_channels=out_channels, # 設定輸出通道數
prev_output_channel=prev_output_channel, # 設定前一個輸出通道
temb_channels=temb_channels, # 設定時間嵌入通道數
resolution_idx=resolution_idx, # 設定解析度索引
dropout=dropout, # 設定 dropout 引數
add_upsample=add_upsample, # 設定是否新增上取樣
resnet_eps=resnet_eps, # 設定 ResNet 的 epsilon
resnet_act_fn=resnet_act_fn, # 設定 ResNet 的啟用函式
resnet_groups=resnet_groups, # 設定 ResNet 的組數
resnet_time_scale_shift=resnet_time_scale_shift, # 設定 ResNet 的時間縮放偏移
skip_time_act=resnet_skip_time_act, # 設定是否跳過時間啟用
output_scale_factor=resnet_out_scale_factor, # 設定輸出縮放因子
)
# 檢查 up_block_type 是否為 "CrossAttnUpBlock2D"
elif up_block_type == "CrossAttnUpBlock2D":
# 如果未定義交叉注意力維度,則丟擲異常
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") # 丟擲值錯誤
return CrossAttnUpBlock2D( # 返回 CrossAttnUpBlock2D 物件
num_layers=num_layers, # 設定網路層數
transformer_layers_per_block=transformer_layers_per_block, # 設定每個塊的變換層數
in_channels=in_channels, # 設定輸入通道數
out_channels=out_channels, # 設定輸出通道數
prev_output_channel=prev_output_channel, # 設定前一個輸出通道
temb_channels=temb_channels, # 設定時間嵌入通道數
resolution_idx=resolution_idx, # 設定解析度索引
dropout=dropout, # 設定 dropout 引數
add_upsample=add_upsample, # 設定是否新增上取樣
resnet_eps=resnet_eps, # 設定 ResNet 的 epsilon
resnet_act_fn=resnet_act_fn, # 設定 ResNet 的啟用函式
resnet_groups=resnet_groups, # 設定 ResNet 的組數
cross_attention_dim=cross_attention_dim, # 設定交叉注意力維度
num_attention_heads=num_attention_heads, # 設定注意力頭的數量
dual_cross_attention=dual_cross_attention, # 設定雙重交叉注意力
use_linear_projection=use_linear_projection, # 設定是否使用線性投影
only_cross_attention=only_cross_attention, # 設定是否僅使用交叉注意力
upcast_attention=upcast_attention, # 設定是否提升注意力
resnet_time_scale_shift=resnet_time_scale_shift, # 設定 ResNet 的時間縮放偏移
attention_type=attention_type, # 設定注意力型別
)
# 檢查上取樣塊的型別是否為 SimpleCrossAttnUpBlock2D
elif up_block_type == "SimpleCrossAttnUpBlock2D":
# 如果未指定 cross_attention_dim,則丟擲值錯誤
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
# 返回 SimpleCrossAttnUpBlock2D 例項,使用相關引數初始化
return SimpleCrossAttnUpBlock2D(
num_layers=num_layers, # 設定層數
in_channels=in_channels, # 設定輸入通道數
out_channels=out_channels, # 設定輸出通道數
prev_output_channel=prev_output_channel, # 設定前一層的輸出通道數
temb_channels=temb_channels, # 設定時間嵌入通道數
resolution_idx=resolution_idx, # 設定解析度索引
dropout=dropout, # 設定 dropout 機率
add_upsample=add_upsample, # 設定是否新增上取樣
resnet_eps=resnet_eps, # 設定 ResNet 的 epsilon
resnet_act_fn=resnet_act_fn, # 設定 ResNet 的啟用函式
resnet_groups=resnet_groups, # 設定 ResNet 的組數
cross_attention_dim=cross_attention_dim, # 設定交叉注意力維度
attention_head_dim=attention_head_dim, # 設定注意力頭維度
resnet_time_scale_shift=resnet_time_scale_shift, # 設定 ResNet 的時間縮放偏移
skip_time_act=resnet_skip_time_act, # 設定是否跳過時間啟用
output_scale_factor=resnet_out_scale_factor, # 設定輸出縮放因子
only_cross_attention=only_cross_attention, # 設定是否僅使用交叉注意力
cross_attention_norm=cross_attention_norm, # 設定交叉注意力的歸一化方式
)
# 檢查上取樣塊的型別是否為 AttnUpBlock2D
elif up_block_type == "AttnUpBlock2D":
# 如果未新增上取樣,則將上取樣型別設為 None
if add_upsample is False:
upsample_type = None
else:
# 預設將上取樣型別設為 'conv'
upsample_type = upsample_type or "conv"
# 返回 AttnUpBlock2D 例項,使用相關引數初始化
return AttnUpBlock2D(
num_layers=num_layers, # 設定層數
in_channels=in_channels, # 設定輸入通道數
out_channels=out_channels, # 設定輸出通道數
prev_output_channel=prev_output_channel, # 設定前一層的輸出通道數
temb_channels=temb_channels, # 設定時間嵌入通道數
resolution_idx=resolution_idx, # 設定解析度索引
dropout=dropout, # 設定 dropout 機率
resnet_eps=resnet_eps, # 設定 ResNet 的 epsilon
resnet_act_fn=resnet_act_fn, # 設定 ResNet 的啟用函式
resnet_groups=resnet_groups, # 設定 ResNet 的組數
attention_head_dim=attention_head_dim, # 設定注意力頭維度
resnet_time_scale_shift=resnet_time_scale_shift, # 設定 ResNet 的時間縮放偏移
upsample_type=upsample_type, # 設定上取樣型別
)
# 檢查上取樣塊的型別是否為 SkipUpBlock2D
elif up_block_type == "SkipUpBlock2D":
# 返回 SkipUpBlock2D 例項,使用相關引數初始化
return SkipUpBlock2D(
num_layers=num_layers, # 設定層數
in_channels=in_channels, # 設定輸入通道數
out_channels=out_channels, # 設定輸出通道數
prev_output_channel=prev_output_channel, # 設定前一層的輸出通道數
temb_channels=temb_channels, # 設定時間嵌入通道數
resolution_idx=resolution_idx, # 設定解析度索引
dropout=dropout, # 設定 dropout 機率
add_upsample=add_upsample, # 設定是否新增上取樣
resnet_eps=resnet_eps, # 設定 ResNet 的 epsilon
resnet_act_fn=resnet_act_fn, # 設定 ResNet 的啟用函式
resnet_time_scale_shift=resnet_time_scale_shift, # 設定 ResNet 的時間縮放偏移
)
# 檢查上取樣塊的型別是否為 AttnSkipUpBlock2D
elif up_block_type == "AttnSkipUpBlock2D":
# 返回 AttnSkipUpBlock2D 例項,使用相關引數初始化
return AttnSkipUpBlock2D(
num_layers=num_layers, # 設定層數
in_channels=in_channels, # 設定輸入通道數
out_channels=out_channels, # 設定輸出通道數
prev_output_channel=prev_output_channel, # 設定前一層的輸出通道數
temb_channels=temb_channels, # 設定時間嵌入通道數
resolution_idx=resolution_idx, # 設定解析度索引
dropout=dropout, # 設定 dropout 機率
add_upsample=add_upsample, # 設定是否新增上取樣
resnet_eps=resnet_eps, # 設定 ResNet 的 epsilon
resnet_act_fn=resnet_act_fn, # 設定 ResNet 的啟用函式
attention_head_dim=attention_head_dim, # 設定注意力頭維度
resnet_time_scale_shift=resnet_time_scale_shift, # 設定 ResNet 的時間縮放偏移
)
# 檢查上取樣塊型別是否為 UpDecoderBlock2D
elif up_block_type == "UpDecoderBlock2D":
# 返回 UpDecoderBlock2D 的例項,傳入相應引數
return UpDecoderBlock2D(
num_layers=num_layers, # 設定層數
in_channels=in_channels, # 設定輸入通道數
out_channels=out_channels, # 設定輸出通道數
resolution_idx=resolution_idx, # 設定解析度索引
dropout=dropout, # 設定 dropout 比例
add_upsample=add_upsample, # 設定是否新增上取樣
resnet_eps=resnet_eps, # 設定 ResNet 的 epsilon 值
resnet_act_fn=resnet_act_fn, # 設定 ResNet 的啟用函式
resnet_groups=resnet_groups, # 設定 ResNet 的組數
resnet_time_scale_shift=resnet_time_scale_shift, # 設定時間尺度偏移
temb_channels=temb_channels, # 設定時間嵌入通道數
)
# 檢查上取樣塊型別是否為 AttnUpDecoderBlock2D
elif up_block_type == "AttnUpDecoderBlock2D":
# 返回 AttnUpDecoderBlock2D 的例項,傳入相應引數
return AttnUpDecoderBlock2D(
num_layers=num_layers, # 設定層數
in_channels=in_channels, # 設定輸入通道數
out_channels=out_channels, # 設定輸出通道數
resolution_idx=resolution_idx, # 設定解析度索引
dropout=dropout, # 設定 dropout 比例
add_upsample=add_upsample, # 設定是否新增上取樣
resnet_eps=resnet_eps, # 設定 ResNet 的 epsilon 值
resnet_act_fn=resnet_act_fn, # 設定 ResNet 的啟用函式
resnet_groups=resnet_groups, # 設定 ResNet 的組數
attention_head_dim=attention_head_dim, # 設定注意力頭維度
resnet_time_scale_shift=resnet_time_scale_shift, # 設定時間尺度偏移
temb_channels=temb_channels, # 設定時間嵌入通道數
)
# 檢查上取樣塊型別是否為 KUpBlock2D
elif up_block_type == "KUpBlock2D":
# 返回 KUpBlock2D 的例項,傳入相應引數
return KUpBlock2D(
num_layers=num_layers, # 設定層數
in_channels=in_channels, # 設定輸入通道數
out_channels=out_channels, # 設定輸出通道數
temb_channels=temb_channels, # 設定時間嵌入通道數
resolution_idx=resolution_idx, # 設定解析度索引
dropout=dropout, # 設定 dropout 比例
add_upsample=add_upsample, # 設定是否新增上取樣
resnet_eps=resnet_eps, # 設定 ResNet 的 epsilon 值
resnet_act_fn=resnet_act_fn, # 設定 ResNet 的啟用函式
)
# 檢查上取樣塊型別是否為 KCrossAttnUpBlock2D
elif up_block_type == "KCrossAttnUpBlock2D":
# 返回 KCrossAttnUpBlock2D 的例項,傳入相應引數
return KCrossAttnUpBlock2D(
num_layers=num_layers, # 設定層數
in_channels=in_channels, # 設定輸入通道數
out_channels=out_channels, # 設定輸出通道數
temb_channels=temb_channels, # 設定時間嵌入通道數
resolution_idx=resolution_idx, # 設定解析度索引
dropout=dropout, # 設定 dropout 比例
add_upsample=add_upsample, # 設定是否新增上取樣
resnet_eps=resnet_eps, # 設定 ResNet 的 epsilon 值
resnet_act_fn=resnet_act_fn, # 設定 ResNet 的啟用函式
cross_attention_dim=cross_attention_dim, # 設定交叉注意力維度
attention_head_dim=attention_head_dim, # 設定注意力頭維度
)
# 如果未匹配到任何上取樣塊型別,丟擲值錯誤
raise ValueError(f"{up_block_type} does not exist.")
# 定義一個小型自編碼器塊,繼承自 nn.Module
class AutoencoderTinyBlock(nn.Module):
"""
Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
blocks.
Args:
in_channels (`int`): The number of input channels.
out_channels (`int`): The number of output channels.
act_fn (`str`):
` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
Returns:
`torch.Tensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
`out_channels`.
"""
# 初始化函式,接受輸入通道數、輸出通道數和啟用函式型別
def __init__(self, in_channels: int, out_channels: int, act_fn: str):
# 呼叫父類初始化
super().__init__()
# 獲取指定的啟用函式
act_fn = get_activation(act_fn)
# 定義一個序列,包括多個卷積層和啟用函式
self.conv = nn.Sequential(
# 第一層卷積,輸入通道數、輸出通道數、卷積核大小和填充方式
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
# 新增啟用函式
act_fn,
# 第二層卷積,保持輸出通道數
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
# 新增啟用函式
act_fn,
# 第三層卷積,保持輸出通道數
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
)
# 判斷輸入和輸出通道是否相同,決定使用卷積或身份對映
self.skip = (
# 如果通道數不一致,使用 1x1 卷積進行跳躍連線
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
if in_channels != out_channels
else nn.Identity()
)
# 使用 ReLU 進行特徵融合
self.fuse = nn.ReLU()
# 定義前向傳播函式,接受輸入張量 x
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 返回卷積輸出和跳躍連線的和,經過融合啟用函式處理
return self.fuse(self.conv(x) + self.skip(x))
# 定義一個 2D UNet 中間塊,繼承自 nn.Module
class UNetMidBlock2D(nn.Module):
"""
A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
# 引數說明
Args:
in_channels (`int`): 輸入通道的數量。
temb_channels (`int`): 時間嵌入通道的數量。
dropout (`float`, *optional*, defaults to 0.0): dropout 比率,用於防止過擬合。
num_layers (`int`, *optional*, defaults to 1): 殘差塊的數量。
resnet_eps (`float`, *optional*, 1e-6 ): 殘差塊的 epsilon 值,用於數值穩定性。
resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
應用於時間嵌入的歸一化型別。可以改善模型在長時間依賴任務上的表現。
resnet_act_fn (`str`, *optional*, defaults to `swish`): 殘差塊的啟用函式型別。
resnet_groups (`int`, *optional*, defaults to 32):
殘差塊中組歸一化層使用的組數量。
attn_groups (`Optional[int]`, *optional*, defaults to None): 注意力塊的組數量。
resnet_pre_norm (`bool`, *optional*, defaults to `True`):
是否在殘差塊中使用預歸一化。
add_attention (`bool`, *optional*, defaults to `True`): 是否新增註意力塊。
attention_head_dim (`int`, *optional*, defaults to 1):
單個注意力頭的維度。注意力頭的數量由該值和輸入通道的數量決定。
output_scale_factor (`float`, *optional*, defaults to 1.0): 輸出的縮放因子。
# 返回值說明
Returns:
`torch.Tensor`: 最後一個殘差塊的輸出,形狀為 `(batch_size, in_channels,
height, width)`。
"""
# 初始化函式,設定模型的各種引數
def __init__(
self,
in_channels: int, # 輸入通道數
temb_channels: int, # 時間嵌入通道數
dropout: float = 0.0, # dropout比率,預設為0.0
num_layers: int = 1, # 殘差塊數量,預設為1
resnet_eps: float = 1e-6, # 殘差塊的epsilon值
resnet_time_scale_shift: str = "default", # 時間尺度歸一化的型別
resnet_act_fn: str = "swish", # 殘差塊的啟用函式
resnet_groups: int = 32, # 殘差塊的組數量
attn_groups: Optional[int] = None, # 注意力塊的組數量
resnet_pre_norm: bool = True, # 是否使用預歸一化
add_attention: bool = True, # 是否新增註意力塊
attention_head_dim: int = 1, # 注意力頭的維度
output_scale_factor: float = 1.0, # 輸出縮放因子
# 前向傳播函式,接受隱藏狀態和時間嵌入作為輸入
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
# 將輸入的隱藏狀態透過第一個殘差塊進行處理
hidden_states = self.resnets[0](hidden_states, temb)
# 遍歷剩餘的注意力塊和殘差塊進行處理
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# 如果當前存在注意力塊,則進行處理
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
# 將處理後的隱藏狀態透過當前的殘差塊進行處理
hidden_states = resnet(hidden_states, temb)
# 返回最終的隱藏狀態
return hidden_states
# 定義一個繼承自 nn.Module 的類 UNetMidBlock2DCrossAttn
class UNetMidBlock2DCrossAttn(nn.Module):
# 初始化方法,定義各個引數
def __init__(
# 輸入通道數
in_channels: int,
# 時間嵌入通道數
temb_channels: int,
# 輸出通道數(可選,預設為 None)
out_channels: Optional[int] = None,
# dropout 機率(預設為 0.0)
dropout: float = 0.0,
# 層數(預設為 1)
num_layers: int = 1,
# 每個塊的變換器層數(預設為 1)
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# ResNet 的 epsilon 值(預設為 1e-6)
resnet_eps: float = 1e-6,
# ResNet 的時間縮放偏移方式(預設為 "default")
resnet_time_scale_shift: str = "default",
# ResNet 的啟用函式型別(預設為 "swish")
resnet_act_fn: str = "swish",
# ResNet 的組數(預設為 32)
resnet_groups: int = 32,
# 輸出的 ResNet 組數(可選,預設為 None)
resnet_groups_out: Optional[int] = None,
# 是否進行 ResNet 預歸一化(預設為 True)
resnet_pre_norm: bool = True,
# 注意力頭的數量(預設為 1)
num_attention_heads: int = 1,
# 輸出縮放因子(預設為 1.0)
output_scale_factor: float = 1.0,
# 交叉注意力維度(預設為 1280)
cross_attention_dim: int = 1280,
# 是否使用雙重交叉注意力(預設為 False)
dual_cross_attention: bool = False,
# 是否使用線性投影(預設為 False)
use_linear_projection: bool = False,
# 是否提升注意力計算精度(預設為 False)
upcast_attention: bool = False,
# 注意力型別(預設為 "default")
attention_type: str = "default",
# 前向傳播方法
def forward(
# 隱藏狀態的張量
hidden_states: torch.Tensor,
# 時間嵌入的張量(可選,預設為 None)
temb: Optional[torch.Tensor] = None,
# 編碼器隱藏狀態的張量(可選,預設為 None)
encoder_hidden_states: Optional[torch.Tensor] = None,
# 注意力掩碼(可選,預設為 None)
attention_mask: Optional[torch.Tensor] = None,
# 交叉注意力引數(可選,預設為 None)
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 編碼器注意力掩碼(可選,預設為 None)
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor: # 定義返回型別為 torch.Tensor
if cross_attention_kwargs is not None: # 檢查交叉注意力引數是否存在
if cross_attention_kwargs.get("scale", None) is not None: # 檢查是否有 scale 引數
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") # 記錄警告,提示 scale 引數已棄用
hidden_states = self.resnets[0](hidden_states, temb) # 使用第一個殘差網路處理隱藏狀態和時間嵌入
for attn, resnet in zip(self.attentions, self.resnets[1:]): # 遍歷注意力層和殘差網路(跳過第一個)
if self.training and self.gradient_checkpointing: # 如果處於訓練模式並且啟用了梯度檢查點
def create_custom_forward(module, return_dict=None): # 定義自定義前向傳播函式
def custom_forward(*inputs): # 定義實際的前向傳播實現
if return_dict is not None: # 如果返回字典不為 None
return module(*inputs, return_dict=return_dict) # 呼叫模組並返回字典
else: # 如果返回字典為 None
return module(*inputs) # 直接呼叫模組並返回結果
return custom_forward # 返回自定義前向傳播函式
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} # 設定檢查點的關鍵字引數
hidden_states = attn( # 呼叫注意力層處理隱藏狀態
hidden_states, # 輸入隱藏狀態
encoder_hidden_states=encoder_hidden_states, # 編碼器隱藏狀態
cross_attention_kwargs=cross_attention_kwargs, # 交叉注意力引數
attention_mask=attention_mask, # 注意力掩碼
encoder_attention_mask=encoder_attention_mask, # 編碼器注意力掩碼
return_dict=False, # 不返回字典格式
)[0] # 取處理結果的第一個元素
hidden_states = torch.utils.checkpoint.checkpoint( # 使用梯度檢查點
create_custom_forward(resnet), # 建立殘差網路的自定義前向函式
hidden_states, # 輸入隱藏狀態
temb, # 輸入時間嵌入
**ckpt_kwargs, # 解包關鍵字引數
)
else: # 如果不處於訓練模式或不啟用梯度檢查點
hidden_states = attn( # 呼叫注意力層處理隱藏狀態
hidden_states, # 輸入隱藏狀態
encoder_hidden_states=encoder_hidden_states, # 編碼器隱藏狀態
cross_attention_kwargs=cross_attention_kwargs, # 交叉注意力引數
attention_mask=attention_mask, # 注意力掩碼
encoder_attention_mask=encoder_attention_mask, # 編碼器注意力掩碼
return_dict=False, # 不返回字典格式
)[0] # 取處理結果的第一個元素
hidden_states = resnet(hidden_states, temb) # 使用殘差網路處理隱藏狀態和時間嵌入
return hidden_states # 返回最終的隱藏狀態
# 定義一個 UNet 中間塊類,繼承自 nn.Module
class UNetMidBlock2DSimpleCrossAttn(nn.Module):
# 初始化方法,設定類的引數
def __init__(
# 輸入通道數
in_channels: int,
# 時間嵌入通道數
temb_channels: int,
# Dropout 比例,預設為 0.0
dropout: float = 0.0,
# 層數,預設為 1
num_layers: int = 1,
# ResNet 的小 epsilon 值,預設為 1e-6
resnet_eps: float = 1e-6,
# ResNet 時間尺度偏移,預設為 "default"
resnet_time_scale_shift: str = "default",
# ResNet 啟用函式型別,預設為 "swish"
resnet_act_fn: str = "swish",
# ResNet 組數,預設為 32
resnet_groups: int = 32,
# 是否使用預歸一化,預設為 True
resnet_pre_norm: bool = True,
# 注意力頭維度,預設為 1
attention_head_dim: int = 1,
# 輸出縮放因子,預設為 1.0
output_scale_factor: float = 1.0,
# 交叉注意力維度,預設為 1280
cross_attention_dim: int = 1280,
# 是否跳過時間啟用,預設為 False
skip_time_act: bool = False,
# 是否僅使用交叉注意力,預設為 False
only_cross_attention: bool = False,
# 交叉注意力的歸一化方法,可選引數
cross_attention_norm: Optional[str] = None,
):
# 呼叫父類的初始化方法
super().__init__()
# 設定是否使用交叉注意力
self.has_cross_attention = True
# 設定注意力頭的維度
self.attention_head_dim = attention_head_dim
# 計算 ResNet 的組數,如果未提供,則取輸入通道數的四分之一和 32 的最小值
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# 計算注意力頭的數量
self.num_heads = in_channels // self.attention_head_dim
# 至少存在一個 ResNet 塊
resnets = [
# 建立一個 ResNet 塊
ResnetBlock2D(
# 輸入通道數
in_channels=in_channels,
# 輸出通道數
out_channels=in_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# 正則化引數
eps=resnet_eps,
# ResNet 組數
groups=resnet_groups,
# dropout 機率
dropout=dropout,
# 時間嵌入歸一化方法
time_embedding_norm=resnet_time_scale_shift,
# 非線性啟用函式
non_linearity=resnet_act_fn,
# 輸出縮放因子
output_scale_factor=output_scale_factor,
# 是否進行預歸一化
pre_norm=resnet_pre_norm,
# 是否跳過時間啟用
skip_time_act=skip_time_act,
)
]
# 初始化注意力層列表
attentions = []
# 迴圈建立指定數量的層
for _ in range(num_layers):
# 根據是否具有縮放點積注意力,選擇處理器
processor = (
AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
)
# 新增註意力層到列表中
attentions.append(
Attention(
# 查詢的維度
query_dim=in_channels,
# 交叉注意力的維度
cross_attention_dim=in_channels,
# 注意力頭的數量
heads=self.num_heads,
# 每個頭的維度
dim_head=self.attention_head_dim,
# 額外的 KV 投影維度
added_kv_proj_dim=cross_attention_dim,
# 歸一化的組數
norm_num_groups=resnet_groups,
# 是否使用偏置
bias=True,
# 是否上cast softmax
upcast_softmax=True,
# 是否僅使用交叉注意力
only_cross_attention=only_cross_attention,
# 交叉注意力的歸一化方法
cross_attention_norm=cross_attention_norm,
# 設定處理器
processor=processor,
)
)
# 新增另一個 ResNet 塊到列表中
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
)
# 將注意力層轉為可訓練模組列表
self.attentions = nn.ModuleList(attentions)
# 將 ResNet 塊轉為可訓練模組列表
self.resnets = nn.ModuleList(resnets)
def forward(
# 前向傳播函式的定義
self,
# 輸入的隱狀態
hidden_states: torch.Tensor,
# 可選的時間嵌入
temb: Optional[torch.Tensor] = None,
# 可選的編碼器隱狀態
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可選的注意力掩碼
attention_mask: Optional[torch.Tensor] = None,
# 可選的交叉注意力引數
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可選的編碼器注意力掩碼
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# 如果 cross_attention_kwargs 為 None,則初始化為空字典
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 檢查 cross_attention_kwargs 中是否有 "scale" 引數,如果有則記錄警告資訊
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 如果 attention_mask 為 None
if attention_mask is None:
# 如果 encoder_hidden_states 已定義,表示正在進行交叉注意力,因此使用交叉注意力掩碼
mask = None if encoder_hidden_states is None else encoder_attention_mask
else:
# 當 attention_mask 已定義時,不檢查 encoder_attention_mask
# 這是為了與 UnCLIP 相容,UnCLIP 使用 'attention_mask' 引數作為交叉注意力掩碼
# TODO: UnCLIP 應該透過 encoder_attention_mask 引數而不是 attention_mask 引數來表達交叉注意力掩碼
# 然後可以簡化整個 if/else 塊為:
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
mask = attention_mask
# 透過第一個殘差網路處理隱藏狀態
hidden_states = self.resnets[0](hidden_states, temb)
# 遍歷注意力層和後續的殘差網路
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# 使用當前的注意力層處理隱藏狀態
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states, # 傳遞編碼器的隱藏狀態
attention_mask=mask, # 傳遞注意力掩碼
**cross_attention_kwargs, # 解包交叉注意力引數
)
# 透過當前的殘差網路處理隱藏狀態
hidden_states = resnet(hidden_states, temb)
# 返回處理後的隱藏狀態
return hidden_states
# 定義一個名為 AttnDownBlock2D 的類,繼承自 nn.Module
class AttnDownBlock2D(nn.Module):
# 初始化方法,接受多個引數以設定層的屬性
def __init__(
# 輸入通道數
in_channels: int,
# 輸出通道數
out_channels: int,
# 時間嵌入通道數
temb_channels: int,
# dropout 機率,預設為 0.0
dropout: float = 0.0,
# 層數,預設為 1
num_layers: int = 1,
# ResNet 的 epsilon 值,防止除零錯誤,預設為 1e-6
resnet_eps: float = 1e-6,
# ResNet 時間尺度偏移,預設為 "default"
resnet_time_scale_shift: str = "default",
# ResNet 啟用函式,預設為 "swish"
resnet_act_fn: str = "swish",
# ResNet 組數,預設為 32
resnet_groups: int = 32,
# 是否在 ResNet 中使用預歸一化,預設為 True
resnet_pre_norm: bool = True,
# 注意力頭的維度,預設為 1
attention_head_dim: int = 1,
# 輸出縮放因子,預設為 1.0
output_scale_factor: float = 1.0,
# 下采樣的填充大小,預設為 1
downsample_padding: int = 1,
# 下采樣型別,預設為 "conv"
downsample_type: str = "conv",
):
# 呼叫父類建構函式初始化
super().__init__()
# 初始化空列表,用於儲存 ResNet 塊
resnets = []
# 初始化空列表,用於儲存注意力機制模組
attentions = []
# 儲存下采樣型別
self.downsample_type = downsample_type
# 如果未指定注意力頭的維度,則發出警告,並預設使用輸出通道數
if attention_head_dim is None:
logger.warning(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
# 將注意力頭維度設定為輸出通道數
attention_head_dim = out_channels
# 遍歷層數以構建 ResNet 塊和注意力模組
for i in range(num_layers):
# 確定輸入通道數,第一層使用初始通道數,其餘層使用輸出通道數
in_channels = in_channels if i == 0 else out_channels
# 建立並新增 ResNet 塊到 resnets 列表
resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # 防止除零的epsilon值
groups=resnet_groups, # 分組數
dropout=dropout, # dropout 比率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入的歸一化方式
non_linearity=resnet_act_fn, # 非線性啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 是否在前面進行歸一化
)
)
# 建立並新增註意力模組到 attentions 列表
attentions.append(
Attention(
out_channels, # 輸出通道數
heads=out_channels // attention_head_dim, # 注意力頭的數量
dim_head=attention_head_dim, # 每個注意力頭的維度
rescale_output_factor=output_scale_factor, # 輸出縮放因子
eps=resnet_eps, # 防止除零的epsilon值
norm_num_groups=resnet_groups, # 歸一化的組數
residual_connection=True, # 是否使用殘差連線
bias=True, # 是否使用偏置
upcast_softmax=True, # 是否上溯 softmax
_from_deprecated_attn_block=True, # 是否來自於已棄用的注意力塊
)
)
# 將注意力模組列表轉換為 nn.ModuleList,便於管理
self.attentions = nn.ModuleList(attentions)
# 將 ResNet 塊列表轉換為 nn.ModuleList,便於管理
self.resnets = nn.ModuleList(resnets)
# 根據下采樣型別選擇相應的下采樣方法
if downsample_type == "conv":
# 建立卷積下采樣模組並儲存
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, # 輸出通道數
use_conv=True, # 是否使用卷積
out_channels=out_channels, # 輸出通道數
padding=downsample_padding, # 填充
name="op" # 模組名稱
)
]
)
elif downsample_type == "resnet":
# 建立 ResNet 下采樣塊並儲存
self.downsamplers = nn.ModuleList(
[
ResnetBlock2D(
in_channels=out_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # 防止除零的epsilon值
groups=resnet_groups, # 分組數
dropout=dropout, # dropout 比率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入的歸一化方式
non_linearity=resnet_act_fn, # 非線性啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 是否在前面進行歸一化
down=True, # 指示為下采樣
)
]
)
else:
# 如果沒有匹配的下采樣型別,則將下采樣模組設定為 None
self.downsamplers = None
# 前向傳播方法,處理隱狀態和可選的其他引數,返回處理後的隱狀態和輸出狀態
def forward(
self,
hidden_states: torch.Tensor, # 輸入的隱狀態張量
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
upsample_size: Optional[int] = None, # 可選的上取樣尺寸
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 可選的交叉注意力引數字典
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: # 返回隱狀態和輸出狀態的元組
# 如果沒有提供交叉注意力引數,則初始化為空字典
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 檢查是否傳入了 scale 引數,如果有,則發出警告,因為這個引數已被棄用
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
output_states = () # 初始化輸出狀態為一個空元組
# 遍歷自定義的殘差網路和注意力層
for resnet, attn in zip(self.resnets, self.attentions):
# 將隱狀態傳遞給殘差網路並更新隱狀態
hidden_states = resnet(hidden_states, temb)
# 將隱狀態傳遞給注意力層,並更新隱狀態
hidden_states = attn(hidden_states, **cross_attention_kwargs)
# 將當前隱狀態新增到輸出狀態元組中
output_states = output_states + (hidden_states,)
# 檢查是否存在下采樣層
if self.downsamplers is not None:
# 遍歷每個下采樣層
for downsampler in self.downsamplers:
# 根據下采樣型別選擇不同的處理方式
if self.downsample_type == "resnet":
hidden_states = downsampler(hidden_states, temb=temb) # 使用時間嵌入處理
else:
hidden_states = downsampler(hidden_states) # 不使用時間嵌入處理
# 將最後的隱狀態新增到輸出狀態中
output_states += (hidden_states,)
# 返回處理後的隱狀態和輸出狀態
return hidden_states, output_states
# 定義一個名為 CrossAttnDownBlock2D 的類,繼承自 nn.Module
class CrossAttnDownBlock2D(nn.Module):
# 初始化方法,接收多個引數以配置模組
def __init__(
# 輸入通道數
self,
in_channels: int,
# 輸出通道數
out_channels: int,
# 時間嵌入通道數
temb_channels: int,
# dropout 比例,預設為 0.0
dropout: float = 0.0,
# 層數,預設為 1
num_layers: int = 1,
# 每個塊的變換器層數,可以是單個整數或整數元組
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# ResNet 的 epsilon 值,預設為 1e-6
resnet_eps: float = 1e-6,
# ResNet 時間縮放偏移的型別,預設為 "default"
resnet_time_scale_shift: str = "default",
# ResNet 的啟用函式,預設為 "swish"
resnet_act_fn: str = "swish",
# ResNet 的分組數,預設為 32
resnet_groups: int = 32,
# 是否使用預歸一化,預設為 True
resnet_pre_norm: bool = True,
# 注意力頭的數量,預設為 1
num_attention_heads: int = 1,
# 交叉注意力維度,預設為 1280
cross_attention_dim: int = 1280,
# 輸出縮放因子,預設為 1.0
output_scale_factor: float = 1.0,
# 下采樣填充的大小,預設為 1
downsample_padding: int = 1,
# 是否新增下采樣,預設為 True
add_downsample: bool = True,
# 是否使用雙重交叉注意力,預設為 False
dual_cross_attention: bool = False,
# 是否使用線性投影,預設為 False
use_linear_projection: bool = False,
# 是否僅使用交叉注意力,預設為 False
only_cross_attention: bool = False,
# 是否提升注意力精度,預設為 False
upcast_attention: bool = False,
# 注意力型別,預設為 "default"
attention_type: str = "default",
# 初始化父類
):
super().__init__()
# 初始化殘差塊列表
resnets = []
# 初始化注意力機制列表
attentions = []
# 設定是否使用交叉注意力
self.has_cross_attention = True
# 設定注意力頭的數量
self.num_attention_heads = num_attention_heads
# 如果每個塊的變換層是整數,則擴充套件為列表
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
# 遍歷每一層
for i in range(num_layers):
# 確定輸入通道數
in_channels = in_channels if i == 0 else out_channels
# 新增殘差塊到列表
resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # epsilon值
groups=resnet_groups, # 組數
dropout=dropout, # dropout比率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入歸一化
non_linearity=resnet_act_fn, # 非線性啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 是否使用預歸一化
)
)
# 檢查是否使用雙重交叉注意力
if not dual_cross_attention:
# 新增普通的變換模型到列表
attentions.append(
Transformer2DModel(
num_attention_heads, # 注意力頭數量
out_channels // num_attention_heads, # 每個頭的輸出通道數
in_channels=out_channels, # 輸入通道數
num_layers=transformer_layers_per_block[i], # 層數
cross_attention_dim=cross_attention_dim, # 交叉注意力維度
norm_num_groups=resnet_groups, # 歸一化組數
use_linear_projection=use_linear_projection, # 是否使用線性投影
only_cross_attention=only_cross_attention, # 是否僅使用交叉注意力
upcast_attention=upcast_attention, # 是否向上投射注意力
attention_type=attention_type, # 注意力型別
)
)
else:
# 新增雙重變換模型到列表
attentions.append(
DualTransformer2DModel(
num_attention_heads, # 注意力頭數量
out_channels // num_attention_heads, # 每個頭的輸出通道數
in_channels=out_channels, # 輸入通道數
num_layers=1, # 層數固定為1
cross_attention_dim=cross_attention_dim, # 交叉注意力維度
norm_num_groups=resnet_groups, # 歸一化組數
)
)
# 將注意力模型列表轉換為nn.ModuleList
self.attentions = nn.ModuleList(attentions)
# 將殘差塊列表轉換為nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 檢查是否新增下采樣層
if add_downsample:
# 建立下采樣層列表
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, # 輸出通道數
use_conv=True, # 是否使用卷積
out_channels=out_channels, # 輸出通道數
padding=downsample_padding, # 填充
name="op" # 操作名稱
)
]
)
else:
# 如果不新增下采樣層,則設定為None
self.downsamplers = None
# 設定梯度檢查點開關為關閉
self.gradient_checkpointing = False
# 定義前向傳播函式,接收多個引數
def forward(
self,
# 隱藏狀態張量,表示模型的內部狀態
hidden_states: torch.Tensor,
# 可選的時間嵌入張量,用於控制生成的時間步
temb: Optional[torch.Tensor] = None,
# 可選的編碼器隱藏狀態張量,表示編碼器的輸出
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可選的注意力掩碼,用於遮蔽輸入中不需要關注的部分
attention_mask: Optional[torch.Tensor] = None,
# 可選的交叉注意力引數字典,用於傳遞其他配置
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可選的編碼器注意力掩碼,控制編碼器的注意力機制
encoder_attention_mask: Optional[torch.Tensor] = None,
# 可選的附加殘差張量,作為額外的資訊傳遞
additional_residuals: Optional[torch.Tensor] = None,
# 返回型別為元組,包含張量和一個張量元組
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
# 檢查交叉注意力引數是否不為空
if cross_attention_kwargs is not None:
# 如果"scale"引數存在,則發出警告,說明其已被棄用
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 初始化輸出狀態為空元組
output_states = ()
# 將殘差網路和注意力層配對
blocks = list(zip(self.resnets, self.attentions))
# 遍歷每一對殘差網路和注意力層
for i, (resnet, attn) in enumerate(blocks):
# 如果在訓練中並且啟用了梯度檢查點
if self.training and self.gradient_checkpointing:
# 定義建立自定義前向傳播的函式
def create_custom_forward(module, return_dict=None):
# 定義自定義前向傳播邏輯
def custom_forward(*inputs):
# 根據是否返回字典呼叫模組
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
# 根據 PyTorch 版本設定檢查點引數
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
# 使用檢查點機制計算隱藏狀態
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
# 透過注意力層處理隱藏狀態
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
# 直接透過殘差網路處理隱藏狀態
hidden_states = resnet(hidden_states, temb)
# 透過注意力層處理隱藏狀態
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
# 如果是最後一對塊並且有額外的殘差
if i == len(blocks) - 1 and additional_residuals is not None:
# 將額外的殘差新增到隱藏狀態
hidden_states = hidden_states + additional_residuals
# 更新輸出狀態,新增當前隱藏狀態
output_states = output_states + (hidden_states,)
# 如果下采樣器不為空
if self.downsamplers is not None:
# 遍歷每個下采樣器
for downsampler in self.downsamplers:
# 使用下采樣器處理隱藏狀態
hidden_states = downsampler(hidden_states)
# 更新輸出狀態,新增當前隱藏狀態
output_states = output_states + (hidden_states,)
# 返回最終的隱藏狀態和輸出狀態
return hidden_states, output_states
# 定義一個二維向下塊,繼承自 nn.Module
class DownBlock2D(nn.Module):
# 初始化方法,定義各個引數及其預設值
def __init__(
self,
in_channels: int, # 輸入通道數
out_channels: int, # 輸出通道數
temb_channels: int, # 時間嵌入通道數
dropout: float = 0.0, # dropout 機率
num_layers: int = 1, # ResNet 層數
resnet_eps: float = 1e-6, # ResNet 中的 epsilon
resnet_time_scale_shift: str = "default", # 時間縮放偏移設定
resnet_act_fn: str = "swish", # ResNet 的啟用函式
resnet_groups: int = 32, # ResNet 的分組數
resnet_pre_norm: bool = True, # 是否使用預歸一化
output_scale_factor: float = 1.0, # 輸出縮放因子
add_downsample: bool = True, # 是否新增下采樣
downsample_padding: int = 1, # 下采樣填充
):
# 呼叫父類建構函式
super().__init__()
# 初始化空的 ResNet 塊列表
resnets = []
# 根據層數迴圈建立 ResNet 塊
for i in range(num_layers):
# 確定當前層的輸入通道數
in_channels = in_channels if i == 0 else out_channels
# 新增 ResNet 塊到列表
resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 當前層的輸入通道數
out_channels=out_channels, # 當前層的輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # epsilon 引數
groups=resnet_groups, # 分組數
dropout=dropout, # dropout 機率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入歸一化
non_linearity=resnet_act_fn, # 啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 預歸一化標誌
)
)
# 將 ResNet 塊列表轉換為 ModuleList
self.resnets = nn.ModuleList(resnets)
# 根據標誌決定是否新增下采樣層
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D( # 建立下采樣層
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
# 如果不新增下采樣,則將其設定為 None
self.downsamplers = None
# 初始化梯度檢查點標誌為 False
self.gradient_checkpointing = False
# 定義前向傳播方法
def forward(
self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
# 檢查位置引數是否大於0,或關鍵字引數中的 "scale" 是否不為 None
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定義棄用訊息,提示使用者 "scale" 引數已棄用,未來將引發錯誤
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 呼叫棄用函式,記錄 "scale" 引數的棄用
deprecate("scale", "1.0.0", deprecation_message)
# 初始化輸出狀態元組
output_states = ()
# 遍歷自定義的 ResNet 模組列表
for resnet in self.resnets:
# 如果在訓練模式下並且啟用了梯度檢查點
if self.training and self.gradient_checkpointing:
# 定義用於建立自定義前向傳播的函式
def create_custom_forward(module):
# 定義自定義前向傳播函式
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 檢查 PyTorch 版本是否大於等於 1.11.0
if is_torch_version(">=", "1.11.0"):
# 使用檢查點技術來計算隱藏狀態,防止記憶體洩漏
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
# 在較早的版本中呼叫檢查點計算隱藏狀態
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
# 在非訓練模式下直接呼叫 ResNet 模組以計算隱藏狀態
hidden_states = resnet(hidden_states, temb)
# 將計算出的隱藏狀態新增到輸出狀態元組中
output_states = output_states + (hidden_states,)
# 如果存在下采樣器
if self.downsamplers is not None:
# 遍歷下采樣器並計算隱藏狀態
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
# 將下采樣後的隱藏狀態新增到輸出狀態元組中
output_states = output_states + (hidden_states,)
# 返回當前的隱藏狀態和輸出狀態元組
return hidden_states, output_states
# 定義一個 2D 下采樣編碼塊的類,繼承自 nn.Module
class DownEncoderBlock2D(nn.Module):
# 初始化方法,設定各類引數
def __init__(
self,
in_channels: int, # 輸入通道數
out_channels: int, # 輸出通道數
dropout: float = 0.0, # dropout 機率,預設 0
num_layers: int = 1, # 層數,預設 1
resnet_eps: float = 1e-6, # ResNet 的 epsilon 值,預設 1e-6
resnet_time_scale_shift: str = "default", # 時間尺度偏移方式,預設值為 "default"
resnet_act_fn: str = "swish", # ResNet 使用的啟用函式,預設是 "swish"
resnet_groups: int = 32, # ResNet 的組數,預設 32
resnet_pre_norm: bool = True, # 是否在前面進行歸一化,預設 True
output_scale_factor: float = 1.0, # 輸出縮放因子,預設 1.0
add_downsample: bool = True, # 是否新增下采樣層,預設 True
downsample_padding: int = 1, # 下采樣的填充大小,預設 1
):
# 呼叫父類的初始化方法
super().__init__()
# 初始化一個空的列表,用於存放 ResNet 塊
resnets = []
# 根據層數建立 ResNet 塊
for i in range(num_layers):
# 如果是第一層,使用輸入通道數;否則使用輸出通道數
in_channels = in_channels if i == 0 else out_channels
# 根據時間尺度偏移方式選擇相應的 ResNet 塊
if resnet_time_scale_shift == "spatial":
# 建立一個帶條件歸一化的 ResNet 塊,並新增到 resnets 列表
resnets.append(
ResnetBlockCondNorm2D(
in_channels=in_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=None, # 時間嵌入通道數,預設 None
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 組數
dropout=dropout, # dropout 機率
time_embedding_norm="spatial", # 時間嵌入的歸一化方式
non_linearity=resnet_act_fn, # 啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
)
)
else:
# 建立一個標準的 ResNet 塊,並新增到 resnets 列表
resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=None, # 時間嵌入通道數,預設 None
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 組數
dropout=dropout, # dropout 機率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入的歸一化方式
non_linearity=resnet_act_fn, # 啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 是否前歸一化
)
)
# 將 ResNet 塊列表轉為 nn.ModuleList,便於管理
self.resnets = nn.ModuleList(resnets)
# 根據 add_downsample 引數決定是否新增下采樣層
if add_downsample:
# 建立下采樣層,並新增到 nn.ModuleList
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, # 輸入通道數
use_conv=True, # 是否使用卷積進行下采樣
out_channels=out_channels, # 輸出通道數
padding=downsample_padding, # 填充大小
name="op" # 下采樣層的名稱
)
]
)
else:
# 如果不新增下采樣層,設定為 None
self.downsamplers = None
# 定義前向傳播函式,接受隱藏狀態和可選引數
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
# 檢查是否傳入多餘的引數或已棄用的 `scale` 引數
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 設定棄用資訊,提醒使用者移除 `scale` 引數
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 呼叫棄用函式,記錄警告資訊
deprecate("scale", "1.0.0", deprecation_message)
# 遍歷每個 ResNet 層,更新隱藏狀態
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None)
# 如果存在下采樣層,則逐個應用下采樣
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
# 返回最終的隱藏狀態
return hidden_states
# 定義一個二維注意力下采樣編碼器塊的類,繼承自 nn.Module
class AttnDownEncoderBlock2D(nn.Module):
# 初始化方法,接收多個引數用於配置編碼器塊
def __init__(
# 輸入通道數
in_channels: int,
# 輸出通道數
out_channels: int,
# 丟棄率,控制神經元隨機失活的比例
dropout: float = 0.0,
# 編碼器塊的層數
num_layers: int = 1,
# ResNet 的小常量,用於防止除零錯誤
resnet_eps: float = 1e-6,
# ResNet 的時間尺度偏移,預設配置
resnet_time_scale_shift: str = "default",
# ResNet 使用的啟用函式型別,預設為 swish
resnet_act_fn: str = "swish",
# ResNet 中的分組數
resnet_groups: int = 32,
# 是否在前面進行歸一化,預設為 True
resnet_pre_norm: bool = True,
# 注意力頭的維度
attention_head_dim: int = 1,
# 輸出縮放因子,預設為 1.0
output_scale_factor: float = 1.0,
# 是否新增下采樣層,預設為 True
add_downsample: bool = True,
# 下采樣的填充大小,預設為 1
downsample_padding: int = 1,
):
# 呼叫父類建構函式
super().__init__()
# 初始化空列表以儲存殘差塊
resnets = []
# 初始化空列表以儲存注意力模組
attentions = []
# 檢查是否傳入注意力頭維度
if attention_head_dim is None:
# 記錄警告資訊,預設設定注意力頭維度為輸出通道數
logger.warning(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
# 將注意力頭維度設定為輸出通道數
attention_head_dim = out_channels
# 遍歷層數以構建殘差塊和注意力模組
for i in range(num_layers):
# 第一層輸入通道為 in_channels,其餘層為 out_channels
in_channels = in_channels if i == 0 else out_channels
# 根據時間縮放偏移型別構建不同型別的殘差塊
if resnet_time_scale_shift == "spatial":
# 新增條件歸一化的殘差塊到列表
resnets.append(
ResnetBlockCondNorm2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm="spatial",
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
)
)
else:
# 新增普通殘差塊到列表
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
# 新增註意力模組到列表
attentions.append(
Attention(
out_channels,
heads=out_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
# 將注意力模組列表轉換為 nn.ModuleList
self.attentions = nn.ModuleList(attentions)
# 將殘差塊列表轉換為 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 根據標誌決定是否新增下采樣層
if add_downsample:
# 建立下采樣模組並新增到列表
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
# 如果不新增下采樣層,將其設定為 None
self.downsamplers = None
# 定義前向傳播方法,接收隱藏狀態和其他引數
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
# 檢查是否有額外的引數或已棄用的 scale 引數
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 構建棄用警告資訊
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 呼叫 deprecate 函式顯示棄用警告
deprecate("scale", "1.0.0", deprecation_message)
# 遍歷自定義的 ResNet 和注意力層進行處理
for resnet, attn in zip(self.resnets, self.attentions):
# 透過 ResNet 層處理隱藏狀態
hidden_states = resnet(hidden_states, temb=None)
# 透過注意力層處理更新後的隱藏狀態
hidden_states = attn(hidden_states)
# 如果有下采樣層,則依次處理隱藏狀態
if self.downsamplers is not None:
for downsampler in self.downsamplers:
# 透過下采樣層處理隱藏狀態
hidden_states = downsampler(hidden_states)
# 返回處理後的隱藏狀態
return hidden_states
# 定義一個名為 AttnSkipDownBlock2D 的類,繼承自 nn.Module
class AttnSkipDownBlock2D(nn.Module):
# 初始化方法,定義類的建構函式
def __init__(
# 輸入通道數,整型
in_channels: int,
# 輸出通道數,整型
out_channels: int,
# 嵌入通道數,整型
temb_channels: int,
# dropout 率,浮點型,預設為 0.0
dropout: float = 0.0,
# 網路層數,整型,預設為 1
num_layers: int = 1,
# ResNet 的 epsilon 值,浮點型,預設為 1e-6
resnet_eps: float = 1e-6,
# ResNet 的時間尺度偏移方式,字串,預設為 "default"
resnet_time_scale_shift: str = "default",
# ResNet 的啟用函式型別,字串,預設為 "swish"
resnet_act_fn: str = "swish",
# 是否在前面進行規範化,布林值,預設為 True
resnet_pre_norm: bool = True,
# 注意力頭的維度,整型,預設為 1
attention_head_dim: int = 1,
# 輸出縮放因子,浮點型,預設為平方根2
output_scale_factor: float = np.sqrt(2.0),
# 是否新增下采樣層,布林值,預設為 True
add_downsample: bool = True,
):
# 呼叫父類的初始化方法
super().__init__()
# 初始化一個空的模組列表用於儲存注意力層
self.attentions = nn.ModuleList([])
# 初始化一個空的模組列表用於儲存殘差塊
self.resnets = nn.ModuleList([])
# 檢查 attention_head_dim 是否為 None
if attention_head_dim is None:
# 如果為 None,記錄警告資訊,並將其設定為輸出通道數
logger.warning(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
attention_head_dim = out_channels
# 根據層數建立殘差塊和注意力層
for i in range(num_layers):
# 設定當前層的輸入通道數,如果是第一層則使用 in_channels,否則使用 out_channels
in_channels = in_channels if i == 0 else out_channels
# 新增一個 ResnetBlock2D 到 resnets 列表中
self.resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # 小常數以防止除零
groups=min(in_channels // 4, 32), # 分組數
groups_out=min(out_channels // 4, 32), # 輸出分組數
dropout=dropout, # dropout 機率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入歸一化方式
non_linearity=resnet_act_fn, # 非線性啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 是否在殘差塊前進行歸一化
)
)
# 新增一個 Attention 層到 attentions 列表中
self.attentions.append(
Attention(
out_channels, # 輸出通道數
heads=out_channels // attention_head_dim, # 注意力頭的數量
dim_head=attention_head_dim, # 每個注意力頭的維度
rescale_output_factor=output_scale_factor, # 輸出縮放因子
eps=resnet_eps, # 小常數以防止除零
norm_num_groups=32, # 歸一化分組數
residual_connection=True, # 是否使用殘差連線
bias=True, # 是否使用偏置
upcast_softmax=True, # 是否使用上溢位 softmax
_from_deprecated_attn_block=True, # 是否來自過時的注意力塊
)
)
# 檢查是否需要新增下采樣層
if add_downsample:
# 建立一個 ResnetBlock2D 作為下采樣層
self.resnet_down = ResnetBlock2D(
in_channels=out_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # 小常數以防止除零
groups=min(out_channels // 4, 32), # 分組數
dropout=dropout, # dropout 機率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入歸一化方式
non_linearity=resnet_act_fn, # 非線性啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 是否在殘差塊前進行歸一化
use_in_shortcut=True, # 是否在快捷連線中使用
down=True, # 是否進行下采樣
kernel="fir", # 卷積核型別
)
# 建立下采樣模組列表
self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
# 建立跳躍連線卷積層
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
else:
# 如果不新增下采樣層,則將相關屬性設定為 None
self.resnet_down = None
self.downsamplers = None
self.skip_conv = None
# 前向傳播方法
def forward(
self,
hidden_states: torch.Tensor, # 輸入的隱藏狀態
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入
skip_sample: Optional[torch.Tensor] = None, # 可選的跳躍樣本
*args, # 額外的位置引數
**kwargs, # 額外的關鍵字引數
# 定義返回型別為元組,包含張量和多個張量的元組
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
# 檢查傳入的引數是否存在,或 kwargs 中的 scale 是否不為 None
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定義棄用資訊,說明 scale 引數將被忽略並將來會引發錯誤
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 呼叫棄用函式,記錄 scale 引數的棄用資訊
deprecate("scale", "1.0.0", deprecation_message)
# 初始化輸出狀態為一個空元組
output_states = ()
# 遍歷 resnet 和 attention 的組合
for resnet, attn in zip(self.resnets, self.attentions):
# 使用 resnet 處理隱藏狀態和時間嵌入
hidden_states = resnet(hidden_states, temb)
# 使用 attention 處理更新後的隱藏狀態
hidden_states = attn(hidden_states)
# 將當前隱藏狀態新增到輸出狀態元組中
output_states += (hidden_states,)
# 檢查是否存在下采樣器
if self.downsamplers is not None:
# 使用下采樣網路處理隱藏狀態
hidden_states = self.resnet_down(hidden_states, temb)
# 遍歷每個下采樣器並處理跳躍樣本
for downsampler in self.downsamplers:
skip_sample = downsampler(skip_sample)
# 結合跳躍樣本和隱藏狀態,更新隱藏狀態
hidden_states = self.skip_conv(skip_sample) + hidden_states
# 將當前隱藏狀態新增到輸出狀態元組中
output_states += (hidden_states,)
# 返回更新後的隱藏狀態、輸出狀態元組和跳躍樣本
return hidden_states, output_states, skip_sample
# 定義一個二維跳過塊的類,繼承自 nn.Module
class SkipDownBlock2D(nn.Module):
# 初始化方法,設定輸入和輸出通道等引數
def __init__(
self,
in_channels: int, # 輸入通道數
out_channels: int, # 輸出通道數
temb_channels: int, # 時間嵌入通道數
dropout: float = 0.0, # dropout 機率
num_layers: int = 1, # 層數
resnet_eps: float = 1e-6, # ResNet 的 epsilon 引數
resnet_time_scale_shift: str = "default", # 時間縮放偏移方式
resnet_act_fn: str = "swish", # ResNet 啟用函式
resnet_pre_norm: bool = True, # 是否在前面進行歸一化
output_scale_factor: float = np.sqrt(2.0), # 輸出縮放因子
add_downsample: bool = True, # 是否新增下采樣層
downsample_padding: int = 1, # 下采樣時的填充
):
super().__init__() # 呼叫父類建構函式
self.resnets = nn.ModuleList([]) # 初始化 ResNet 塊列表
# 迴圈建立每一層的 ResNet 塊
for i in range(num_layers):
# 第一層使用輸入通道,後續層使用輸出通道
in_channels = in_channels if i == 0 else out_channels
self.resnets.append(
ResnetBlock2D( # 新增 ResNet 塊
in_channels=in_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # epsilon 引數
groups=min(in_channels // 4, 32), # 輸入通道組數
groups_out=min(out_channels // 4, 32), # 輸出通道組數
dropout=dropout, # dropout 機率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入歸一化
non_linearity=resnet_act_fn, # 啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 前歸一化
)
)
# 如果需要新增下采樣層
if add_downsample:
self.resnet_down = ResnetBlock2D( # 建立下采樣 ResNet 塊
in_channels=out_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # epsilon 引數
groups=min(out_channels // 4, 32), # 輸出通道組數
dropout=dropout, # dropout 機率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入歸一化
non_linearity=resnet_act_fn, # 啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 前歸一化
use_in_shortcut=True, # 使用短接
down=True, # 啟用下采樣
kernel="fir", # 指定卷積核型別
)
# 建立下采樣模組列表
self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
# 建立跳過連線卷積層
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
else: # 如果不新增下采樣層
self.resnet_down = None # 不使用下采樣 ResNet 塊
self.downsamplers = None # 不使用下采樣模組列表
self.skip_conv = None # 不使用跳過連線卷積層
# 前向傳播方法
def forward(
self,
hidden_states: torch.Tensor, # 輸入的隱藏狀態
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入
skip_sample: Optional[torch.Tensor] = None, # 可選的跳過樣本
*args, # 可變位置引數
**kwargs, # 可變關鍵字引數
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]: # 定義返回型別為元組,包含一個張量、一個張量元組和另一個張量
if len(args) > 0 or kwargs.get("scale", None) is not None: # 檢查是否傳入了位置引數或名為“scale”的關鍵字引數
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." # 定義廢棄警告資訊
deprecate("scale", "1.0.0", deprecation_message) # 呼叫 deprecate 函式,記錄“scale”引數的廢棄資訊
output_states = () # 初始化輸出狀態為一個空元組
for resnet in self.resnets: # 遍歷自定義的 ResNet 模型列表
hidden_states = resnet(hidden_states, temb) # 將當前的隱藏狀態和時間嵌入傳遞給 ResNet,獲取更新後的隱藏狀態
output_states += (hidden_states,) # 將當前的隱藏狀態新增到輸出狀態元組中
if self.downsamplers is not None: # 檢查是否存在下采樣模組
hidden_states = self.resnet_down(hidden_states, temb) # 使用 ResNet 下采樣隱藏狀態和時間嵌入
for downsampler in self.downsamplers: # 遍歷下采樣模組
skip_sample = downsampler(skip_sample) # 對跳過連線樣本進行下采樣
hidden_states = self.skip_conv(skip_sample) + hidden_states # 透過跳過卷積處理下采樣樣本,並與當前的隱藏狀態相加
output_states += (hidden_states,) # 將更新後的隱藏狀態新增到輸出狀態元組中
return hidden_states, output_states, skip_sample # 返回更新後的隱藏狀態、輸出狀態元組和跳過樣本
# 定義一個 2D ResNet 下采樣塊,繼承自 nn.Module
class ResnetDownsampleBlock2D(nn.Module):
# 初始化函式,定義各層引數
def __init__(
self,
in_channels: int, # 輸入通道數
out_channels: int, # 輸出通道數
temb_channels: int, # 時間嵌入通道數
dropout: float = 0.0, # dropout 機率
num_layers: int = 1, # ResNet 層數
resnet_eps: float = 1e-6, # ResNet 中的 epsilon 值
resnet_time_scale_shift: str = "default", # 時間縮放偏移方式
resnet_act_fn: str = "swish", # 啟用函式
resnet_groups: int = 32, # 組數
resnet_pre_norm: bool = True, # 是否使用預歸一化
output_scale_factor: float = 1.0, # 輸出縮放因子
add_downsample: bool = True, # 是否新增下采樣層
skip_time_act: bool = False, # 是否跳過時間啟用
):
# 呼叫父類建構函式
super().__init__()
resnets = [] # 初始化 ResNet 層列表
# 根據指定的層數構建 ResNet 塊
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels # 確定輸入通道數
resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 組數
dropout=dropout, # dropout 機率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入歸一化方式
non_linearity=resnet_act_fn, # 啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 是否使用預歸一化
skip_time_act=skip_time_act, # 是否跳過時間啟用
)
)
self.resnets = nn.ModuleList(resnets) # 將 ResNet 層轉換為模組列表
# 如果需要,新增下采樣層
if add_downsample:
self.downsamplers = nn.ModuleList(
[
ResnetBlock2D(
in_channels=out_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 組數
dropout=dropout, # dropout 機率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入歸一化方式
non_linearity=resnet_act_fn, # 啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 是否使用預歸一化
skip_time_act=skip_time_act, # 是否跳過時間啟用
down=True, # 指定為下采樣層
)
]
)
else:
self.downsamplers = None # 如果不需要下采樣層,則為 None
self.gradient_checkpointing = False # 初始化梯度檢查點為 False
# 定義前向傳播函式
def forward(
self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs # 前向傳播輸入
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
# 檢查引數是否存在,或者是否提供了已棄用的 `scale` 引數
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定義棄用訊息,告知使用者 `scale` 引數已棄用並將在未來的版本中引發錯誤
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 呼叫 deprecate 函式記錄棄用資訊
deprecate("scale", "1.0.0", deprecation_message)
# 初始化輸出狀態元組
output_states = ()
# 遍歷所有的 ResNet 模型
for resnet in self.resnets:
# 如果處於訓練模式並啟用梯度檢查點
if self.training and self.gradient_checkpointing:
# 定義一個函式用於建立自定義前向傳播
def create_custom_forward(module):
# 定義自定義前向傳播函式,呼叫傳入的模組
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 檢查 PyTorch 版本是否大於等於 1.11.0
if is_torch_version(">=", "1.11.0"):
# 使用梯度檢查點機制計算隱藏狀態,禁用重入
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
# 在較舊版本的 PyTorch 中使用梯度檢查點機制計算隱藏狀態
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
# 在非訓練模式下直接呼叫 ResNet 計算隱藏狀態
hidden_states = resnet(hidden_states, temb)
# 將當前的隱藏狀態新增到輸出狀態元組中
output_states = output_states + (hidden_states,)
# 檢查是否存在下采樣器
if self.downsamplers is not None:
# 遍歷所有的下采樣器
for downsampler in self.downsamplers:
# 使用下采樣器計算隱藏狀態
hidden_states = downsampler(hidden_states, temb)
# 將當前的隱藏狀態新增到輸出狀態元組中
output_states = output_states + (hidden_states,)
# 返回最終的隱藏狀態和輸出狀態元組
return hidden_states, output_states
# 定義一個簡單的二維交叉注意力下采樣塊類,繼承自 nn.Module
class SimpleCrossAttnDownBlock2D(nn.Module):
# 初始化方法,設定輸入、輸出通道等引數
def __init__(
# 輸入通道數
in_channels: int,
# 輸出通道數
out_channels: int,
# 時間嵌入通道數
temb_channels: int,
# dropout 機率,預設值為 0.0
dropout: float = 0.0,
# 層數,預設值為 1
num_layers: int = 1,
# ResNet 中的 epsilon 值,預設值為 1e-6
resnet_eps: float = 1e-6,
# ResNet 的時間尺度偏移設定,預設為 "default"
resnet_time_scale_shift: str = "default",
# ResNet 的啟用函式型別,預設為 "swish"
resnet_act_fn: str = "swish",
# ResNet 的分組數,預設為 32
resnet_groups: int = 32,
# 是否在 ResNet 中使用預歸一化,預設為 True
resnet_pre_norm: bool = True,
# 注意力頭的維度,預設為 1
attention_head_dim: int = 1,
# 交叉注意力的維度,預設為 1280
cross_attention_dim: int = 1280,
# 輸出縮放因子,預設為 1.0
output_scale_factor: float = 1.0,
# 是否新增下采樣層,預設為 True
add_downsample: bool = True,
# 是否跳過時間啟用,預設為 False
skip_time_act: bool = False,
# 是否僅使用交叉注意力,預設為 False
only_cross_attention: bool = False,
# 交叉注意力的歸一化方式,預設為 None
cross_attention_norm: Optional[str] = None,
):
# 呼叫父類的初始化方法
super().__init__()
# 初始化是否有交叉注意力標誌
self.has_cross_attention = True
# 初始化殘差網路和注意力模組的列表
resnets = []
attentions = []
# 設定注意力頭的維度
self.attention_head_dim = attention_head_dim
# 計算注意力頭的數量
self.num_heads = out_channels // self.attention_head_dim
# 根據層數建立殘差塊
for i in range(num_layers):
# 設定輸入通道,第一層使用給定的輸入通道,其餘層使用輸出通道
in_channels = in_channels if i == 0 else out_channels
# 將殘差塊新增到列表中
resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # 殘差網路中的 epsilon 值
groups=resnet_groups, # 分組數量
dropout=dropout, # dropout 機率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入規範化
non_linearity=resnet_act_fn, # 非線性啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 是否預歸一化
skip_time_act=skip_time_act, # 是否跳過時間啟用
)
)
# 根據是否有縮放點積注意力建立處理器
processor = (
AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
)
# 將注意力模組新增到列表中
attentions.append(
Attention(
query_dim=out_channels, # 查詢維度
cross_attention_dim=out_channels, # 交叉注意力維度
heads=self.num_heads, # 注意力頭數量
dim_head=attention_head_dim, # 每個頭的維度
added_kv_proj_dim=cross_attention_dim, # 額外的鍵值投影維度
norm_num_groups=resnet_groups, # 規範化的組數量
bias=True, # 是否使用偏置
upcast_softmax=True, # 是否上調 softmax
only_cross_attention=only_cross_attention, # 是否僅使用交叉注意力
cross_attention_norm=cross_attention_norm, # 交叉注意力的規範化
processor=processor, # 使用的處理器
)
)
# 將注意力模組列表轉換為可訓練模組
self.attentions = nn.ModuleList(attentions)
# 將殘差塊列表轉換為可訓練模組
self.resnets = nn.ModuleList(resnets)
# 如果需要新增下采樣
if add_downsample:
# 建立下采樣的殘差塊
self.downsamplers = nn.ModuleList(
[
ResnetBlock2D(
in_channels=out_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # 殘差網路中的 epsilon 值
groups=resnet_groups, # 分組數量
dropout=dropout, # dropout 機率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入規範化
non_linearity=resnet_act_fn, # 非線性啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 是否預歸一化
skip_time_act=skip_time_act, # 是否跳過時間啟用
down=True, # 表示這是下采樣
)
]
)
else:
# 如果不需要下采樣,將下采樣設定為 None
self.downsamplers = None
# 初始化梯度檢查點標誌為 False
self.gradient_checkpointing = False
# 定義一個前向傳播方法,接受多個輸入引數
def forward(
self,
# 輸入的隱藏狀態張量
hidden_states: torch.Tensor,
# 可選的時間嵌入張量
temb: Optional[torch.Tensor] = None,
# 可選的編碼器隱藏狀態張量
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可選的注意力掩碼張量
attention_mask: Optional[torch.Tensor] = None,
# 可選的交叉注意力引數字典
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可選的編碼器注意力掩碼張量
encoder_attention_mask: Optional[torch.Tensor] = None,
# 返回值型別為元組,包含一個張量和一個張量元組
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
# 如果未提供 cross_attention_kwargs,則使用空字典
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 檢查是否傳入了 scale 引數,若有則發出棄用警告
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 初始化輸出狀態為一個空元組
output_states = ()
# 檢查 attention_mask 是否為 None
if attention_mask is None:
# 如果 encoder_hidden_states 已定義,則進行交叉注意力,使用交叉注意力掩碼
mask = None if encoder_hidden_states is None else encoder_attention_mask
else:
# 如果已定義 attention_mask,則直接使用,不檢查 encoder_attention_mask
# 這為 UnCLIP 相容性提供支援
# TODO: UnCLIP 應該透過 encoder_attention_mask 參數列達交叉注意力掩碼
mask = attention_mask
# 遍歷 ResNet 和注意力層
for resnet, attn in zip(self.resnets, self.attentions):
# 在訓練中且開啟了梯度檢查點
if self.training and self.gradient_checkpointing:
# 定義一個自定義前向傳播函式
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
# 根據 return_dict 決定返回方式
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
# 使用檢查點進行前向傳播,節省記憶體
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
# 執行注意力層
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=mask,
**cross_attention_kwargs,
)
else:
# 否則直接使用 ResNet 進行前向傳播
hidden_states = resnet(hidden_states, temb)
# 執行注意力層
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=mask,
**cross_attention_kwargs,
)
# 將當前隱藏狀態新增到輸出狀態元組中
output_states = output_states + (hidden_states,)
# 如果存在下采樣層
if self.downsamplers is not None:
# 遍歷所有下采樣層
for downsampler in self.downsamplers:
# 執行下采樣
hidden_states = downsampler(hidden_states, temb)
# 將下采樣後的隱藏狀態新增到輸出狀態元組中
output_states = output_states + (hidden_states,)
# 返回最終的隱藏狀態和輸出狀態元組
return hidden_states, output_states
# 定義一個二維下采樣的神經網路模組,繼承自 nn.Module
class KDownBlock2D(nn.Module):
# 初始化方法,定義輸入輸出通道數、時間嵌入通道、dropout 機率等引數
def __init__(
self,
in_channels: int, # 輸入通道數
out_channels: int, # 輸出通道數
temb_channels: int, # 時間嵌入通道數
dropout: float = 0.0, # dropout 機率,預設為 0
num_layers: int = 4, # 殘差層的數量,預設為 4
resnet_eps: float = 1e-5, # 殘差層中的 epsilon 值,防止除零錯誤
resnet_act_fn: str = "gelu", # 殘差層使用的啟用函式,預設為 GELU
resnet_group_size: int = 32, # 殘差層中組的大小
add_downsample: bool = False, # 是否新增下采樣層的標誌
):
# 呼叫父類的初始化方法
super().__init__()
resnets = [] # 初始化一個空列表用於儲存殘差塊
# 根據層數構建殘差塊
for i in range(num_layers):
# 第一層使用輸入通道,其他層使用輸出通道
in_channels = in_channels if i == 0 else out_channels
# 計算組的數量
groups = in_channels // resnet_group_size
# 計算輸出組的數量
groups_out = out_channels // resnet_group_size
# 建立殘差塊並新增到列表中
resnets.append(
ResnetBlockCondNorm2D(
in_channels=in_channels, # 當前層的輸入通道數
out_channels=out_channels, # 當前層的輸出通道數
dropout=dropout, # 當前層的 dropout 機率
temb_channels=temb_channels, # 時間嵌入通道數
groups=groups, # 當前層的組數量
groups_out=groups_out, # 輸出層的組數量
eps=resnet_eps, # 殘差層中的 epsilon 值
non_linearity=resnet_act_fn, # 殘差層的啟用函式
time_embedding_norm="ada_group", # 時間嵌入的歸一化方式
conv_shortcut_bias=False, # 卷積快捷連線是否使用偏置
)
)
# 將殘差塊列表轉換為 nn.ModuleList,以便於引數管理
self.resnets = nn.ModuleList(resnets)
# 根據標誌決定是否新增下采樣層
if add_downsample:
# 如果需要,建立一個下采樣層並新增到列表中
self.downsamplers = nn.ModuleList([KDownsample2D()])
else:
# 如果不需要下采樣,設定為 None
self.downsamplers = None
# 初始化梯度檢查點標誌為 False
self.gradient_checkpointing = False
# 前向傳播方法,接收隱藏狀態和時間嵌入(可選)
def forward(
self, hidden_states: torch.Tensor, # 輸入的隱藏狀態張量
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
*args, **kwargs # 其他可選引數
# 函式返回一個包含張量和元組的元組
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
# 檢查是否有額外引數或“scale”關鍵字引數不為 None
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定義關於“scale”引數的棄用訊息
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 呼叫棄用函式記錄“scale”的棄用
deprecate("scale", "1.0.0", deprecation_message)
# 初始化輸出狀態為一個空元組
output_states = ()
# 遍歷所有的 ResNet 模組
for resnet in self.resnets:
# 如果處於訓練模式且啟用了梯度檢查點
if self.training and self.gradient_checkpointing:
# 定義一個建立自定義前向傳播的函式
def create_custom_forward(module):
# 定義自定義前向傳播函式,接受任意輸入並返回模組的輸出
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 如果 PyTorch 版本大於等於 1.11.0
if is_torch_version(">=", "1.11.0"):
# 使用檢查點功能計算隱藏狀態,傳遞自定義前向函式
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
# 否則使用檢查點功能計算隱藏狀態
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
# 如果不滿足條件,則直接透過 ResNet 模組計算隱藏狀態
hidden_states = resnet(hidden_states, temb)
# 將當前隱藏狀態新增到輸出狀態元組中
output_states += (hidden_states,)
# 如果存在下采樣器
if self.downsamplers is not None:
# 遍歷每個下采樣器
for downsampler in self.downsamplers:
# 透過下采樣器計算隱藏狀態
hidden_states = downsampler(hidden_states)
# 返回最終的隱藏狀態和輸出狀態
return hidden_states, output_states
# 定義一個名為 KCrossAttnDownBlock2D 的類,繼承自 nn.Module
class KCrossAttnDownBlock2D(nn.Module):
# 初始化方法,接受多個引數以設定模型的結構
def __init__(
self,
in_channels: int, # 輸入通道數
out_channels: int, # 輸出通道數
temb_channels: int, # 時間嵌入通道數
cross_attention_dim: int, # 跨注意力維度
dropout: float = 0.0, # dropout 機率,預設為 0
num_layers: int = 4, # 層數,預設為 4
resnet_group_size: int = 32, # ResNet 組的大小,預設為 32
add_downsample: bool = True, # 是否新增下采樣,預設為 True
attention_head_dim: int = 64, # 注意力頭維度,預設為 64
add_self_attention: bool = False, # 是否新增自注意力,預設為 False
resnet_eps: float = 1e-5, # ResNet 的 epsilon 值,預設為 1e-5
resnet_act_fn: str = "gelu", # ResNet 的啟用函式型別,預設為 "gelu"
):
# 呼叫父類的初始化方法
super().__init__()
# 初始化空列表以存放 ResNet 塊
resnets = []
# 初始化空列表以存放注意力塊
attentions = []
# 設定是否包含跨注意力標誌
self.has_cross_attention = True
# 建立指定數量的層
for i in range(num_layers):
# 第一層的輸入通道數為 in_channels,之後的層使用 out_channels
in_channels = in_channels if i == 0 else out_channels
# 計算組數
groups = in_channels // resnet_group_size
groups_out = out_channels // resnet_group_size
# 將 ResnetBlockCondNorm2D 新增到 resnets 列表
resnets.append(
ResnetBlockCondNorm2D(
in_channels=in_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
dropout=dropout, # dropout 機率
temb_channels=temb_channels, # 時間嵌入通道數
groups=groups, # 組數
groups_out=groups_out, # 輸出組數
eps=resnet_eps, # epsilon 值
non_linearity=resnet_act_fn, # 啟用函式
time_embedding_norm="ada_group", # 時間嵌入歸一化型別
conv_shortcut_bias=False, # 是否使用卷積快捷連線偏置
)
)
# 將 KAttentionBlock 新增到 attentions 列表
attentions.append(
KAttentionBlock(
out_channels, # 輸出通道數
out_channels // attention_head_dim, # 注意力頭數量
attention_head_dim, # 注意力頭維度
cross_attention_dim=cross_attention_dim, # 跨注意力維度
temb_channels=temb_channels, # 時間嵌入通道數
attention_bias=True, # 是否使用注意力偏置
add_self_attention=add_self_attention, # 是否新增自注意力
cross_attention_norm="layer_norm", # 跨注意力歸一化型別
group_size=resnet_group_size, # 組大小
)
)
# 將 resnets 列表轉換為 nn.ModuleList,以便可以在模型中使用
self.resnets = nn.ModuleList(resnets)
# 將 attentions 列表轉換為 nn.ModuleList
self.attentions = nn.ModuleList(attentions)
# 根據引數決定是否新增下采樣層
if add_downsample:
# 新增下采樣模組
self.downsamplers = nn.ModuleList([KDownsample2D()])
else:
# 如果不新增下采樣,設定為 None
self.downsamplers = None
# 初始化梯度檢查點標誌為 False
self.gradient_checkpointing = False
# 前向傳播方法,定義輸入和輸出
def forward(
self,
hidden_states: torch.Tensor, # 隱藏狀態輸入
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入
encoder_hidden_states: Optional[torch.Tensor] = None, # 可選的編碼器隱藏狀態
attention_mask: Optional[torch.Tensor] = None, # 可選的注意力掩碼
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 可選的跨注意力引數
encoder_attention_mask: Optional[torch.Tensor] = None, # 可選的編碼器注意力掩碼
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
# 如果沒有傳入 cross_attention_kwargs,則初始化為空字典
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 檢查 cross_attention_kwargs 中是否存在 "scale" 引數,併發出警告
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 初始化輸出狀態為一個空元組
output_states = ()
# 遍歷 resnets 和 attentions 的對應元素
for resnet, attn in zip(self.resnets, self.attentions):
# 如果處於訓練模式且開啟了梯度檢查點
if self.training and self.gradient_checkpointing:
# 建立自定義前向傳播函式
def create_custom_forward(module, return_dict=None):
# 定義自定義前向傳播邏輯
def custom_forward(*inputs):
# 如果指定了返回字典,則返回包含字典的結果
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
# 否則返回普通結果
return module(*inputs)
return custom_forward
# 設定檢查點引數,針對 PyTorch 版本進行不同處理
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
# 使用檢查點機制計算隱藏狀態
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), # 傳入自定義前向函式
hidden_states, # 輸入隱藏狀態
temb, # 傳入時間嵌入
**ckpt_kwargs, # 傳入檢查點引數
)
# 使用注意力機制更新隱藏狀態
hidden_states = attn(
hidden_states, # 輸入隱藏狀態
encoder_hidden_states=encoder_hidden_states, # 編碼器隱藏狀態
emb=temb, # 時間嵌入
attention_mask=attention_mask, # 注意力掩碼
cross_attention_kwargs=cross_attention_kwargs, # 交叉注意力引數
encoder_attention_mask=encoder_attention_mask, # 編碼器注意力掩碼
)
else:
# 如果不是訓練模式或沒有使用梯度檢查點,直接透過 ResNet 更新隱藏狀態
hidden_states = resnet(hidden_states, temb)
# 使用注意力機制更新隱藏狀態
hidden_states = attn(
hidden_states, # 輸入隱藏狀態
encoder_hidden_states=encoder_hidden_states, # 編碼器隱藏狀態
emb=temb, # 時間嵌入
attention_mask=attention_mask, # 注意力掩碼
cross_attention_kwargs=cross_attention_kwargs, # 交叉注意力引數
encoder_attention_mask=encoder_attention_mask, # 編碼器注意力掩碼
)
# 如果沒有下采樣層,輸出狀態新增 None
if self.downsamplers is None:
output_states += (None,)
else:
# 否則將當前隱藏狀態新增到輸出狀態
output_states += (hidden_states,)
# 如果存在下采樣層,則依次對隱藏狀態進行下采樣
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
# 返回最終的隱藏狀態和輸出狀態
return hidden_states, output_states
# 定義一個名為 AttnUpBlock2D 的類,繼承自 nn.Module
class AttnUpBlock2D(nn.Module):
# 初始化方法,接受多個引數以配置該模組
def __init__(
# 輸入通道數
self,
in_channels: int,
# 前一層輸出通道數
prev_output_channel: int,
# 輸出通道數
out_channels: int,
# 嵌入通道數
temb_channels: int,
# 解析度索引,預設為 None
resolution_idx: int = None,
# dropout 率,預設為 0.0
dropout: float = 0.0,
# 層數,預設為 1
num_layers: int = 1,
# ResNet 中的小常數,避免除零
resnet_eps: float = 1e-6,
# ResNet 時間縮放偏移,預設為 "default"
resnet_time_scale_shift: str = "default",
# ResNet 啟用函式,預設為 "swish"
resnet_act_fn: str = "swish",
# ResNet 組數,預設為 32
resnet_groups: int = 32,
# 是否在 ResNet 中使用預歸一化,預設為 True
resnet_pre_norm: bool = True,
# 注意力頭維度,預設為 1
attention_head_dim: int = 1,
# 輸出縮放因子,預設為 1.0
output_scale_factor: float = 1.0,
# 上取樣型別,預設為 "conv"
upsample_type: str = "conv",
):
# 呼叫父類的初始化方法
super().__init__()
# 初始化空列表用於儲存 ResNet 塊
resnets = []
# 初始化空列表用於儲存注意力層
attentions = []
# 設定上取樣型別
self.upsample_type = upsample_type
# 如果沒有傳入注意力頭維度
if attention_head_dim is None:
# 記錄警告,建議使用預設的頭維度
logger.warning(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
# 將注意力頭維度設定為輸出通道數
attention_head_dim = out_channels
# 遍歷每一層
for i in range(num_layers):
# 設定殘差跳過通道數,最後一層使用輸入通道
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 設定當前 ResNet 塊的輸入通道數
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 建立 ResNet 塊並新增到列表
resnets.append(
ResnetBlock2D(
# 輸入通道數為當前 ResNet 的輸入通道加上跳過的通道
in_channels=resnet_in_channels + res_skip_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# 餘弦相似性小常數
eps=resnet_eps,
# 分組數
groups=resnet_groups,
# dropout 機率
dropout=dropout,
# 時間嵌入的歸一化方式
time_embedding_norm=resnet_time_scale_shift,
# 非線性啟用函式
non_linearity=resnet_act_fn,
# 輸出縮放因子
output_scale_factor=output_scale_factor,
# 是否進行預歸一化
pre_norm=resnet_pre_norm,
)
)
# 建立注意力層並新增到列表
attentions.append(
Attention(
# 輸出通道數
out_channels,
# 注意力頭的數量
heads=out_channels // attention_head_dim,
# 每個頭的維度
dim_head=attention_head_dim,
# 輸出重縮放因子
rescale_output_factor=output_scale_factor,
# 餘弦相似性小常數
eps=resnet_eps,
# 歸一化的組數
norm_num_groups=resnet_groups,
# 是否使用殘差連線
residual_connection=True,
# 是否使用偏置
bias=True,
# 是否上升 softmax 的精度
upcast_softmax=True,
# 是否從已棄用的注意力塊中獲取
_from_deprecated_attn_block=True,
)
)
# 將注意力層列表轉換為 nn.ModuleList,以便於管理
self.attentions = nn.ModuleList(attentions)
# 將 ResNet 塊列表轉換為 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 根據上取樣型別選擇上取樣方法
if upsample_type == "conv":
# 使用卷積上取樣,並建立 ModuleList
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
elif upsample_type == "resnet":
# 使用 ResNet 塊進行上取樣,並建立 ModuleList
self.upsamplers = nn.ModuleList(
[
ResnetBlock2D(
# 輸入通道數
in_channels=out_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# 餘弦相似性小常數
eps=resnet_eps,
# 分組數
groups=resnet_groups,
# dropout 機率
dropout=dropout,
# 時間嵌入的歸一化方式
time_embedding_norm=resnet_time_scale_shift,
# 非線性啟用函式
non_linearity=resnet_act_fn,
# 輸出縮放因子
output_scale_factor=output_scale_factor,
# 是否進行預歸一化
pre_norm=resnet_pre_norm,
# 表示這是一個上取樣的塊
up=True,
)
]
)
else:
# 如果上取樣型別無效,則設為 None
self.upsamplers = None
# 儲存當前解析度索引
self.resolution_idx = resolution_idx
# 定義前向傳播函式,接收隱藏狀態及其他引數
def forward(
self,
hidden_states: torch.Tensor, # 當前的隱藏狀態張量
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 之前的隱藏狀態元組
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
upsample_size: Optional[int] = None, # 可選的上取樣大小
*args, # 額外的位置引數
**kwargs, # 額外的關鍵字引數
) -> torch.Tensor: # 返回型別為張量
# 檢查是否傳入了多餘的引數或已棄用的 scale 引數
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 設定棄用警告資訊
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 呼叫棄用函式發出警告
deprecate("scale", "1.0.0", deprecation_message)
# 遍歷每一對殘差網路和注意力層
for resnet, attn in zip(self.resnets, self.attentions):
# 從元組中彈出最後一個殘差隱藏狀態
res_hidden_states = res_hidden_states_tuple[-1]
# 更新殘差隱藏狀態元組,去掉最後一個元素
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 將當前隱藏狀態和殘差隱藏狀態在維度1上拼接
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 將拼接後的隱藏狀態傳入殘差網路
hidden_states = resnet(hidden_states, temb)
# 將輸出的隱藏狀態傳入注意力層
hidden_states = attn(hidden_states)
# 檢查是否存在上取樣器
if self.upsamplers is not None:
# 遍歷每個上取樣器
for upsampler in self.upsamplers:
# 根據上取樣型別選擇處理方式
if self.upsample_type == "resnet":
# 將隱藏狀態傳入上取樣器並提供時間嵌入
hidden_states = upsampler(hidden_states, temb=temb)
else:
# 將隱藏狀態傳入上取樣器
hidden_states = upsampler(hidden_states)
# 返回處理後的隱藏狀態
return hidden_states
# 定義一個名為 CrossAttnUpBlock2D 的類,繼承自 nn.Module
class CrossAttnUpBlock2D(nn.Module):
# 初始化方法,定義該類的建構函式
def __init__(
# 輸入通道數
self,
in_channels: int,
# 輸出通道數
out_channels: int,
# 前一層輸出通道數
prev_output_channel: int,
# 時間嵌入通道數
temb_channels: int,
# 可選的解析度索引
resolution_idx: Optional[int] = None,
# Dropout 機率
dropout: float = 0.0,
# 層數
num_layers: int = 1,
# 每個塊的 Transformer 層數,可以是單個整數或元組
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# ResNet 的 epsilon 值,避免除零錯誤
resnet_eps: float = 1e-6,
# ResNet 的時間尺度偏移引數
resnet_time_scale_shift: str = "default",
# ResNet 的啟用函式型別
resnet_act_fn: str = "swish",
# ResNet 的組數
resnet_groups: int = 32,
# 是否使用預歸一化
resnet_pre_norm: bool = True,
# 注意力頭的數量
num_attention_heads: int = 1,
# 跨注意力的維度
cross_attention_dim: int = 1280,
# 輸出縮放因子
output_scale_factor: float = 1.0,
# 是否新增上取樣層
add_upsample: bool = True,
# 是否使用雙重跨注意力
dual_cross_attention: bool = False,
# 是否使用線性投影
use_linear_projection: bool = False,
# 是否僅使用跨注意力
only_cross_attention: bool = False,
# 是否提升注意力計算精度
upcast_attention: bool = False,
# 注意力型別
attention_type: str = "default",
# 繼承父類的初始化方法
):
super().__init__()
# 初始化空列表用於儲存 ResNet 和注意力模組
resnets = []
attentions = []
# 設定是否有交叉注意力的標誌
self.has_cross_attention = True
# 設定注意力頭的數量
self.num_attention_heads = num_attention_heads
# 如果輸入的是整數,則將其轉換為包含多個相同值的列表
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
# 遍歷每一層,構建 ResNet 和注意力模組
for i in range(num_layers):
# 設定殘差跳躍通道數量
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 設定 ResNet 輸入通道
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 將 ResNet 模組新增到列表中
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # 小常數用於數值穩定性
groups=resnet_groups, # 分組數
dropout=dropout, # Dropout 率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入的歸一化方式
non_linearity=resnet_act_fn, # 非線性啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 是否進行預歸一化
)
)
# 根據是否啟用雙重交叉注意力,選擇不同的注意力模組
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads, # 注意力頭數量
out_channels // num_attention_heads, # 每個頭的輸出通道數
in_channels=out_channels, # 輸入通道數
num_layers=transformer_layers_per_block[i], # 當前層的變換器層數
cross_attention_dim=cross_attention_dim, # 交叉注意力維度
norm_num_groups=resnet_groups, # 歸一化分組數
use_linear_projection=use_linear_projection, # 是否使用線性投影
only_cross_attention=only_cross_attention, # 是否僅使用交叉注意力
upcast_attention=upcast_attention, # 是否上調注意力精度
attention_type=attention_type, # 注意力型別
)
)
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads, # 注意力頭數量
out_channels // num_attention_heads, # 每個頭的輸出通道數
in_channels=out_channels, # 輸入通道數
num_layers=1, # 僅使用一層
cross_attention_dim=cross_attention_dim, # 交叉注意力維度
norm_num_groups=resnet_groups, # 歸一化分組數
)
)
# 將注意力模組和 ResNet 模組轉換為 nn.ModuleList 以便於管理
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
# 根據是否新增上取樣層初始化上取樣模組
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
# 初始化梯度檢查點標誌
self.gradient_checkpointing = False
# 設定解析度索引
self.resolution_idx = resolution_idx
# 定義前向傳播函式,接收多個引數
def forward(
self,
# 當前隱藏狀態的張量
hidden_states: torch.Tensor,
# 元組,包含殘差隱藏狀態的張量
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
# 可選的時間嵌入張量
temb: Optional[torch.Tensor] = None,
# 可選的編碼器隱藏狀態張量
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可選的跨注意力引數字典
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可選的上取樣大小
upsample_size: Optional[int] = None,
# 可選的注意力掩碼張量
attention_mask: Optional[torch.Tensor] = None,
# 可選的編碼器注意力掩碼張量
encoder_attention_mask: Optional[torch.Tensor] = None,
# 定義一個名為 UpBlock2D 的類,繼承自 nn.Module
class UpBlock2D(nn.Module):
# 初始化方法,接受多個引數來構造 UpBlock2D 物件
def __init__(
self,
in_channels: int, # 輸入通道數
prev_output_channel: int, # 前一層輸出的通道數
out_channels: int, # 輸出通道數
temb_channels: int, # 時間嵌入通道數
resolution_idx: Optional[int] = None, # 解析度索引,預設為 None
dropout: float = 0.0, # dropout 機率,預設為 0.0
num_layers: int = 1, # 層數,預設為 1
resnet_eps: float = 1e-6, # ResNet 的 epsilon 值,預設為 1e-6
resnet_time_scale_shift: str = "default", # ResNet 的時間縮放偏移,預設為 "default"
resnet_act_fn: str = "swish", # ResNet 的啟用函式,預設為 "swish"
resnet_groups: int = 32, # ResNet 的組數,預設為 32
resnet_pre_norm: bool = True, # 是否進行預歸一化,預設為 True
output_scale_factor: float = 1.0, # 輸出縮放因子,預設為 1.0
add_upsample: bool = True, # 是否新增上取樣層,預設為 True
):
# 呼叫父類的初始化方法
super().__init__()
resnets = [] # 初始化一個空列表用於儲存 ResNet 塊
# 根據 num_layers 建立 ResNet 塊
for i in range(num_layers):
# 設定跳過通道數,如果是最後一層,則使用 in_channels,否則使用 out_channels
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 設定 ResNet 輸入通道數,第一層使用 prev_output_channel,其餘層使用 out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 將 ResNet 塊新增到列表中
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels, # 輸入通道數加上跳過通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 組數
dropout=dropout, # dropout 機率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入歸一化
non_linearity=resnet_act_fn, # 非線性啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 預歸一化
)
)
# 將 ResNet 塊列表轉換為 nn.ModuleList 以便於管理
self.resnets = nn.ModuleList(resnets)
# 如果需要新增上取樣層,則初始化上取樣模組
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None # 不新增上取樣層,設定為 None
self.gradient_checkpointing = False # 初始化梯度檢查點標誌為 False
self.resolution_idx = resolution_idx # 儲存解析度索引
# 定義前向傳播方法,接受輸入的隱藏狀態及其他引數
def forward(
self,
hidden_states: torch.Tensor, # 輸入的隱藏狀態張量
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 之前層的隱藏狀態元組
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
upsample_size: Optional[int] = None, # 可選的上取樣大小
*args, # 其他位置引數
**kwargs, # 其他關鍵字引數
) -> torch.Tensor:
# 檢查是否傳入引數或 'scale' 引數
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 設定棄用訊息,提示使用者 'scale' 引數已被棄用
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 呼叫 deprecate 函式記錄棄用
deprecate("scale", "1.0.0", deprecation_message)
# 檢查 FreeU 是否啟用
is_freeu_enabled = (
getattr(self, "s1", None) # 獲取屬性 s1
and getattr(self, "s2", None) # 獲取屬性 s2
and getattr(self, "b1", None) # 獲取屬性 b1
and getattr(self, "b2", None) # 獲取屬性 b2
)
# 遍歷所有 ResNet 模型
for resnet in self.resnets:
# 彈出最後的 ResNet 隱藏狀態
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1] # 更新元組,去掉最後一個元素
# 如果 FreeU 被啟用,則僅對前兩個階段進行操作
if is_freeu_enabled:
# 呼叫 apply_freeu 函式處理隱藏狀態
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx, # 當前解析度索引
hidden_states, # 當前隱藏狀態
res_hidden_states, # ResNet 隱藏狀態
s1=self.s1, # s1 屬性
s2=self.s2, # s2 屬性
b1=self.b1, # b1 屬性
b2=self.b2, # b2 屬性
)
# 連線當前隱藏狀態和 ResNet 隱藏狀態
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 如果處於訓練模式且啟用梯度檢查點
if self.training and self.gradient_checkpointing:
# 建立自定義前向傳播函式
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs) # 呼叫模組的前向傳播
return custom_forward
# 根據 PyTorch 版本選擇檢查點方式
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), # 使用自定義前向函式
hidden_states, # 當前隱藏狀態
temb, # 時間嵌入
use_reentrant=False # 不使用可重入檢查點
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), # 使用自定義前向函式
hidden_states, # 當前隱藏狀態
temb # 時間嵌入
)
else:
# 如果不啟用檢查點,直接呼叫 ResNet
hidden_states = resnet(hidden_states, temb)
# 如果存在上取樣器,則遍歷進行上取樣
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size) # 呼叫上取樣器
# 返回最終的隱藏狀態
return hidden_states
# 定義一個 2D 上取樣解碼塊類,繼承自 nn.Module
class UpDecoderBlock2D(nn.Module):
# 初始化方法,接受多個引數用於構造解碼塊
def __init__(
self,
in_channels: int, # 輸入通道數
out_channels: int, # 輸出通道數
resolution_idx: Optional[int] = None, # 解析度索引,預設為 None
dropout: float = 0.0, # dropout 機率,預設為 0
num_layers: int = 1, # 層數,預設為 1
resnet_eps: float = 1e-6, # ResNet 的 epsilon 值,預設為 1e-6
resnet_time_scale_shift: str = "default", # 時間尺度偏移型別,預設為 "default"
resnet_act_fn: str = "swish", # ResNet 的啟用函式,預設為 "swish"
resnet_groups: int = 32, # ResNet 的分組數,預設為 32
resnet_pre_norm: bool = True, # 是否在 ResNet 中預先歸一化,預設為 True
output_scale_factor: float = 1.0, # 輸出縮放因子,預設為 1.0
add_upsample: bool = True, # 是否新增上取樣層,預設為 True
temb_channels: Optional[int] = None, # 時間嵌入通道數,預設為 None
):
# 呼叫父類構造方法
super().__init__()
# 初始化一個空的 ResNet 列表
resnets = []
# 根據層數建立相應數量的 ResNet 層
for i in range(num_layers):
# 第一個層使用輸入通道數,其餘層使用輸出通道數
input_channels = in_channels if i == 0 else out_channels
# 根據時間尺度偏移型別建立不同的 ResNet 塊
if resnet_time_scale_shift == "spatial":
resnets.append(
ResnetBlockCondNorm2D( # 新增條件歸一化的 ResNet 塊
in_channels=input_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 分組數
dropout=dropout, # dropout 機率
time_embedding_norm="spatial", # 時間嵌入歸一化型別
non_linearity=resnet_act_fn, # 啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
)
)
else:
resnets.append(
ResnetBlock2D( # 新增普通的 ResNet 塊
in_channels=input_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 分組數
dropout=dropout, # dropout 機率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入歸一化型別
non_linearity=resnet_act_fn, # 啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 是否預先歸一化
)
)
# 將建立的 ResNet 塊儲存在 ModuleList 中,以便於管理
self.resnets = nn.ModuleList(resnets)
# 根據是否新增上取樣層初始化上取樣層列表
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None # 如果不新增,則設為 None
# 儲存解析度索引
self.resolution_idx = resolution_idx
# 定義前向傳播方法
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
# 遍歷所有 ResNet 層進行前向傳播
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb) # 更新隱藏狀態
# 如果存在上取樣層,則遍歷進行上取樣
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states) # 更新隱藏狀態
# 返回最終的隱藏狀態
return hidden_states
# 定義一個注意力上取樣解碼塊類,繼承自 nn.Module
class AttnUpDecoderBlock2D(nn.Module):
# 初始化類的建構函式,定義各引數
def __init__(
# 輸入通道數,決定輸入資料的特徵維度
self,
in_channels: int,
# 輸出通道數,決定輸出資料的特徵維度
out_channels: int,
# 解析度索引,選擇特定解析度(可選)
resolution_idx: Optional[int] = None,
# dropout比率,用於防止過擬合,預設為0
dropout: float = 0.0,
# 網路層數,決定模型的深度,預設為1
num_layers: int = 1,
# ResNet的epsilon值,防止分母為0的情況,預設為1e-6
resnet_eps: float = 1e-6,
# ResNet的時間尺度偏移設定,預設為"default"
resnet_time_scale_shift: str = "default",
# ResNet的啟用函式型別,預設為"swish"
resnet_act_fn: str = "swish",
# ResNet中的組數,影響計算和模型複雜度,預設為32
resnet_groups: int = 32,
# 是否在ResNet中使用預歸一化,預設為True
resnet_pre_norm: bool = True,
# 注意力頭的維度,影響注意力機制的計算,預設為1
attention_head_dim: int = 1,
# 輸出縮放因子,調整輸出的大小,預設為1.0
output_scale_factor: float = 1.0,
# 是否新增上取樣層,影響模型的結構,預設為True
add_upsample: bool = True,
# 時間嵌入通道數(可選),用於特定的時間資訊表示
temb_channels: Optional[int] = None,
):
# 呼叫父類的初始化方法
super().__init__()
# 初始化儲存殘差塊的列表
resnets = []
# 初始化儲存注意力層的列表
attentions = []
# 如果未指定注意力頭維度,則發出警告並使用輸出通道數作為預設值
if attention_head_dim is None:
logger.warning(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
)
# 將注意力頭維度設定為輸出通道數
attention_head_dim = out_channels
# 遍歷每一層以構建殘差塊和注意力層
for i in range(num_layers):
# 如果是第一層,則輸入通道為輸入通道數,否則為輸出通道數
input_channels = in_channels if i == 0 else out_channels
# 如果時間尺度偏移為"spatial",則使用條件歸一化的殘差塊
if resnet_time_scale_shift == "spatial":
resnets.append(
ResnetBlockCondNorm2D(
# 輸入通道數
in_channels=input_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# 殘差塊的epsilon引數
eps=resnet_eps,
# 組歸一化的組數
groups=resnet_groups,
# dropout比例
dropout=dropout,
# 時間嵌入的歸一化方式
time_embedding_norm="spatial",
# 非線性啟用函式
non_linearity=resnet_act_fn,
# 輸出縮放因子
output_scale_factor=output_scale_factor,
)
)
else:
# 否則使用普通的2D殘差塊
resnets.append(
ResnetBlock2D(
# 輸入通道數
in_channels=input_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# 殘差塊的epsilon引數
eps=resnet_eps,
# 組歸一化的組數
groups=resnet_groups,
# dropout比例
dropout=dropout,
# 時間嵌入的歸一化方式
time_embedding_norm=resnet_time_scale_shift,
# 非線性啟用函式
non_linearity=resnet_act_fn,
# 輸出縮放因子
output_scale_factor=output_scale_factor,
# 是否使用預歸一化
pre_norm=resnet_pre_norm,
)
)
# 新增註意力層
attentions.append(
Attention(
# 輸出通道數
out_channels,
# 計算注意力頭數
heads=out_channels // attention_head_dim,
# 注意力頭維度
dim_head=attention_head_dim,
# 輸出縮放因子
rescale_output_factor=output_scale_factor,
# epsilon引數
eps=resnet_eps,
# 組歸一化的組數(如果不是空間歸一化)
norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None,
# 空間歸一化維度(如果是空間歸一化)
spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
# 是否使用殘差連線
residual_connection=True,
# 是否使用偏置
bias=True,
# 是否使用上取樣的softmax
upcast_softmax=True,
# 從過時的注意力塊建立
_from_deprecated_attn_block=True,
)
)
# 將注意力層轉換為模組列表
self.attentions = nn.ModuleList(attentions)
# 將殘差塊轉換為模組列表
self.resnets = nn.ModuleList(resnets)
# 如果需要新增上取樣層
if add_upsample:
# 建立上取樣層的模組列表
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
# 如果不需要上取樣,則將其設定為None
self.upsamplers = None
# 設定解析度索引
self.resolution_idx = resolution_idx
# 定義前向傳播函式,接收隱藏狀態和可選的時間嵌入,返回處理後的隱藏狀態
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
# 遍歷殘差網路和注意力模組的組合
for resnet, attn in zip(self.resnets, self.attentions):
# 將當前隱藏狀態輸入到殘差網路中,可能包含時間嵌入
hidden_states = resnet(hidden_states, temb=temb)
# 將殘差網路的輸出輸入到注意力模組中,可能包含時間嵌入
hidden_states = attn(hidden_states, temb=temb)
# 如果上取樣模組不為空,則執行上取樣操作
if self.upsamplers is not None:
# 遍歷所有上取樣模組
for upsampler in self.upsamplers:
# 將當前隱藏狀態輸入到上取樣模組中
hidden_states = upsampler(hidden_states)
# 返回最終處理後的隱藏狀態
return hidden_states
# 定義一個名為 AttnSkipUpBlock2D 的類,繼承自 nn.Module
class AttnSkipUpBlock2D(nn.Module):
# 初始化方法,定義該類的屬性
def __init__(
# 輸入通道數
in_channels: int,
# 前一個輸出通道數
prev_output_channel: int,
# 輸出通道數
out_channels: int,
# 時間嵌入通道數
temb_channels: int,
# 可選的解析度索引
resolution_idx: Optional[int] = None,
# dropout 機率
dropout: float = 0.0,
# 層數
num_layers: int = 1,
# ResNet 的 epsilon 值
resnet_eps: float = 1e-6,
# ResNet 時間縮放偏移設定
resnet_time_scale_shift: str = "default",
# ResNet 啟用函式型別
resnet_act_fn: str = "swish",
# 是否使用 ResNet 預歸一化
resnet_pre_norm: bool = True,
# 注意力頭的維度
attention_head_dim: int = 1,
# 輸出縮放因子
output_scale_factor: float = np.sqrt(2.0),
# 是否新增上取樣
add_upsample: bool = True,
):
# 初始化父類
super().__init__()
# 此處應有具體的初始化程式碼(如層的定義),略去
# 定義前向傳播方法
def forward(
# 隱藏狀態輸入
hidden_states: torch.Tensor,
# 之前的隱藏狀態元組
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
# 可選的時間嵌入
temb: Optional[torch.Tensor] = None,
# 可選的跳躍樣本
skip_sample=None,
# 額外引數
*args,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 檢查 args 和 kwargs 是否包含 scale 引數
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定義棄用資訊
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 呼叫棄用函式
deprecate("scale", "1.0.0", deprecation_message)
# 遍歷 ResNet 層
for resnet in self.resnets:
# 從元組中提取最近的隱藏狀態
res_hidden_states = res_hidden_states_tuple[-1]
# 更新隱藏狀態元組,去掉最近的隱藏狀態
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 將當前隱藏狀態與提取的隱藏狀態拼接在一起
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 透過 ResNet 層處理隱藏狀態
hidden_states = resnet(hidden_states, temb)
# 透過注意力層處理隱藏狀態
hidden_states = self.attentions[0](hidden_states)
# 檢查跳躍樣本是否為 None
if skip_sample is not None:
# 如果不是,則透過上取樣層處理跳躍樣本
skip_sample = self.upsampler(skip_sample)
else:
# 如果是,則將跳躍樣本設為 0
skip_sample = 0
# 檢查 ResNet 上取樣層是否存在
if self.resnet_up is not None:
# 透過跳躍歸一化層處理隱藏狀態
skip_sample_states = self.skip_norm(hidden_states)
# 應用啟用函式
skip_sample_states = self.act(skip_sample_states)
# 透過跳躍卷積層處理狀態
skip_sample_states = self.skip_conv(skip_sample_states)
# 更新跳躍樣本
skip_sample = skip_sample + skip_sample_states
# 透過 ResNet 上取樣層處理隱藏狀態
hidden_states = self.resnet_up(hidden_states, temb)
# 返回處理後的隱藏狀態和跳躍樣本
return hidden_states, skip_sample
# 定義一個名為 SkipUpBlock2D 的類,繼承自 nn.Module
class SkipUpBlock2D(nn.Module):
# 初始化方法,定義該類的屬性
def __init__(
# 輸入通道數
in_channels: int,
# 前一個輸出通道數
prev_output_channel: int,
# 輸出通道數
out_channels: int,
# 時間嵌入通道數
temb_channels: int,
# 可選的解析度索引
resolution_idx: Optional[int] = None,
# dropout 機率
dropout: float = 0.0,
# 層數
num_layers: int = 1,
# ResNet 的 epsilon 值
resnet_eps: float = 1e-6,
# ResNet 時間縮放偏移設定
resnet_time_scale_shift: str = "default",
# ResNet 啟用函式型別
resnet_act_fn: str = "swish",
# 是否使用 ResNet 預歸一化
resnet_pre_norm: bool = True,
# 輸出縮放因子
output_scale_factor: float = np.sqrt(2.0),
# 是否新增上取樣
add_upsample: bool = True,
# 上取樣填充大小
upsample_padding: int = 1,
):
# 初始化父類
super().__init__()
# 此處應有具體的初始化程式碼(如層的定義),略去
):
# 呼叫父類的初始化方法
super().__init__()
# 建立一個空的 ModuleList 用於儲存 ResnetBlock2D 層
self.resnets = nn.ModuleList([])
# 根據 num_layers 的數量來新增 ResnetBlock2D 層
for i in range(num_layers):
# 計算跳過通道數,如果是最後一層則使用 in_channels,否則使用 out_channels
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 確定當前 ResNet 塊的輸入通道數
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 向 resnets 列表中新增一個新的 ResnetBlock2D 例項
self.resnets.append(
ResnetBlock2D(
# 輸入通道數為當前層輸入通道加跳過通道數
in_channels=resnet_in_channels + res_skip_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# 歸一化的 epsilon 值
eps=resnet_eps,
# 分組數為輸入通道數的一部分,最多為 32
groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
# 輸出分組數,同樣最多為 32
groups_out=min(out_channels // 4, 32),
# dropout 機率
dropout=dropout,
# 時間嵌入的歸一化方式
time_embedding_norm=resnet_time_scale_shift,
# 啟用函式型別
non_linearity=resnet_act_fn,
# 輸出縮放因子
output_scale_factor=output_scale_factor,
# 是否使用預歸一化
pre_norm=resnet_pre_norm,
)
)
# 初始化上取樣層
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
# 如果需要新增上取樣層
if add_upsample:
# 新增一個上取樣的 ResnetBlock2D
self.resnet_up = ResnetBlock2D(
# 輸入通道數
in_channels=out_channels,
# 輸出通道數
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# 歸一化的 epsilon 值
eps=resnet_eps,
# 分組數,最多為 32
groups=min(out_channels // 4, 32),
# 輸出分組數,最多為 32
groups_out=min(out_channels // 4, 32),
# dropout 機率
dropout=dropout,
# 時間嵌入的歸一化方式
time_embedding_norm=resnet_time_scale_shift,
# 啟用函式型別
non_linearity=resnet_act_fn,
# 輸出縮放因子
output_scale_factor=output_scale_factor,
# 是否使用預歸一化
pre_norm=resnet_pre_norm,
# 是否在快捷路徑中使用
use_in_shortcut=True,
# 標記為上取樣
up=True,
# 使用 FIR 卷積核
kernel="fir",
)
# 定義跳過連線的卷積層,將輸出通道數對映到 3 通道
self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# 定義跳過連線的歸一化層
self.skip_norm = torch.nn.GroupNorm(
num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
)
# 定義啟用函式為 SiLU
self.act = nn.SiLU()
else:
# 如果不新增上取樣層,則將相關屬性設為 None
self.resnet_up = None
self.skip_conv = None
self.skip_norm = None
self.act = None
# 儲存解析度索引
self.resolution_idx = resolution_idx
# 定義前向傳播方法
def forward(
# 定義前向傳播輸入引數
hidden_states: torch.Tensor,
# 儲存殘差隱藏狀態的元組
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
# 可選的時間嵌入張量
temb: Optional[torch.Tensor] = None,
# 跳過取樣的可選引數
skip_sample=None,
# 可變位置引數
*args,
# 可變關鍵字引數
**kwargs,
# 函式返回兩個張量,表示隱藏狀態和跳過的樣本
) -> Tuple[torch.Tensor, torch.Tensor]:
# 檢查引數長度或關鍵字引數 "scale" 是否存在
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 建立棄用訊息,提醒使用者 "scale" 引數將被忽略
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 呼叫棄用函式,傳遞引數資訊
deprecate("scale", "1.0.0", deprecation_message)
# 遍歷自定義的 ResNet 模組
for resnet in self.resnets:
# 從隱藏狀態元組中彈出最後一個 ResNet 隱藏狀態
res_hidden_states = res_hidden_states_tuple[-1]
# 更新隱藏狀態元組,去掉最後一個元素
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 將當前的隱藏狀態與 ResNet 隱藏狀態連線
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 透過 ResNet 模組更新隱藏狀態
hidden_states = resnet(hidden_states, temb)
# 檢查跳過樣本是否存在
if skip_sample is not None:
# 如果存在,使用上取樣器處理跳過樣本
skip_sample = self.upsampler(skip_sample)
else:
# 否則,將跳過樣本初始化為 0
skip_sample = 0
# 檢查是否存在 ResNet 上取樣模組
if self.resnet_up is not None:
# 對隱藏狀態應用歸一化
skip_sample_states = self.skip_norm(hidden_states)
# 對歸一化結果應用啟用函式
skip_sample_states = self.act(skip_sample_states)
# 對啟用結果應用卷積操作
skip_sample_states = self.skip_conv(skip_sample_states)
# 將跳過樣本與處理後的狀態相加
skip_sample = skip_sample + skip_sample_states
# 透過 ResNet 上取樣模組更新隱藏狀態
hidden_states = self.resnet_up(hidden_states, temb)
# 返回最終的隱藏狀態和跳過樣本
return hidden_states, skip_sample
# 定義一個 2D 上取樣的 ResNet 塊,繼承自 nn.Module
class ResnetUpsampleBlock2D(nn.Module):
# 初始化方法,設定網路引數
def __init__(
self,
in_channels: int, # 輸入通道數
prev_output_channel: int, # 前一層輸出的通道數
out_channels: int, # 輸出通道數
temb_channels: int, # 時間嵌入的通道數
resolution_idx: Optional[int] = None, # 解析度索引(可選)
dropout: float = 0.0, # dropout 比例
num_layers: int = 1, # ResNet 層數
resnet_eps: float = 1e-6, # ResNet 的 epsilon 值
resnet_time_scale_shift: str = "default", # 時間縮放偏移方式
resnet_act_fn: str = "swish", # 啟用函式型別
resnet_groups: int = 32, # 組數
resnet_pre_norm: bool = True, # 是否使用預歸一化
output_scale_factor: float = 1.0, # 輸出縮放因子
add_upsample: bool = True, # 是否新增上取樣層
skip_time_act: bool = False, # 是否跳過時間啟用
):
# 呼叫父類建構函式
super().__init__()
# 初始化一個空的 ResNet 列表
resnets = []
# 遍歷層數,建立每一層的 ResNet 塊
for i in range(num_layers):
# 確定跳過通道數,如果是最後一層,則使用輸入通道數,否則使用輸出通道數
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 確定當前層的輸入通道數
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 將 ResNet 塊新增到列表中
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 組數
dropout=dropout, # dropout 比例
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入歸一化方式
non_linearity=resnet_act_fn, # 非線性啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 是否使用預歸一化
skip_time_act=skip_time_act, # 是否跳過時間啟用
)
)
# 將 ResNet 列表轉換為 nn.ModuleList,便於 PyTorch 管理
self.resnets = nn.ModuleList(resnets)
# 如果需要新增上取樣層,則建立上取樣模組
if add_upsample:
self.upsamplers = nn.ModuleList(
[
ResnetBlock2D(
in_channels=out_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 組數
dropout=dropout, # dropout 比例
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入歸一化方式
non_linearity=resnet_act_fn, # 非線性啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 是否使用預歸一化
skip_time_act=skip_time_act, # 是否跳過時間啟用
up=True, # 標記為上取樣塊
)
]
)
else:
# 如果不需要上取樣,則將其設為 None
self.upsamplers = None
# 初始化梯度檢查點為 False
self.gradient_checkpointing = False
# 設定解析度索引
self.resolution_idx = resolution_idx
# 定義前向傳播方法
def forward(
self,
hidden_states: torch.Tensor, # 輸入的隱藏狀態
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 額外的隱藏狀態元組
temb: Optional[torch.Tensor] = None, # 時間嵌入(可選)
upsample_size: Optional[int] = None, # 上取樣大小(可選)
*args, # 額外的位置引數
**kwargs, # 額外的關鍵字引數
) -> torch.Tensor: # 定義一個返回 torch.Tensor 的函式
# 檢查引數列表是否包含引數或 "scale" 是否不為 None
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定義棄用資訊,說明 "scale" 引數將被忽略
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 呼叫 deprecate 函式以記錄 "scale" 引數的棄用
deprecate("scale", "1.0.0", deprecation_message)
# 遍歷儲存的 ResNet 模型列表
for resnet in self.resnets:
# 從隱藏狀態元組中彈出最後一個 ResNet 隱藏狀態
res_hidden_states = res_hidden_states_tuple[-1]
# 更新隱藏狀態元組,去掉最後一個元素
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 將當前的隱藏狀態與 ResNet 隱藏狀態在指定維度上拼接
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 如果處於訓練模式且開啟了梯度檢查點
if self.training and self.gradient_checkpointing:
# 定義一個建立自定義前向傳播函式的內部函式
def create_custom_forward(module):
# 定義自定義前向傳播函式,呼叫模組
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 檢查 PyTorch 版本是否大於等於 1.11.0
if is_torch_version(">=", "1.11.0"):
# 使用梯度檢查點功能進行前向傳播
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
# 使用梯度檢查點功能進行前向傳播,不使用可重入選項
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
# 直接呼叫 ResNet 模型進行前向傳播
hidden_states = resnet(hidden_states, temb)
# 如果存在上取樣器
if self.upsamplers is not None:
# 遍歷每個上取樣器
for upsampler in self.upsamplers:
# 呼叫上取樣器進行上取樣處理
hidden_states = upsampler(hidden_states, temb)
# 返回最終的隱藏狀態
return hidden_states
# 定義一個簡單的二維交叉注意力上取樣模組,繼承自 nn.Module
class SimpleCrossAttnUpBlock2D(nn.Module):
# 初始化函式,接受多個引數來配置模組
def __init__(
# 輸入通道數
in_channels: int,
# 輸出通道數
out_channels: int,
# 上一個輸出通道數
prev_output_channel: int,
# 時間嵌入通道數
temb_channels: int,
# 可選的解析度索引
resolution_idx: Optional[int] = None,
# Dropout 機率
dropout: float = 0.0,
# 層數
num_layers: int = 1,
# ResNet 的 epsilon 值
resnet_eps: float = 1e-6,
# ResNet 時間縮放偏移
resnet_time_scale_shift: str = "default",
# ResNet 啟用函式
resnet_act_fn: str = "swish",
# ResNet 中的組數
resnet_groups: int = 32,
# 是否在 ResNet 中使用預歸一化
resnet_pre_norm: bool = True,
# 注意力頭的維度
attention_head_dim: int = 1,
# 交叉注意力的維度
cross_attention_dim: int = 1280,
# 輸出縮放因子
output_scale_factor: float = 1.0,
# 是否新增上取樣
add_upsample: bool = True,
# 是否跳過時間啟用
skip_time_act: bool = False,
# 是否僅使用交叉注意力
only_cross_attention: bool = False,
# 可選的交叉注意力歸一化方式
cross_attention_norm: Optional[str] = None,
# 初始化函式
):
# 呼叫父類的初始化方法
super().__init__()
# 建立一個空列表用於儲存 ResNet 模組
resnets = []
# 建立一個空列表用於儲存 Attention 模組
attentions = []
# 設定是否使用交叉注意力
self.has_cross_attention = True
# 設定每個注意力頭的維度
self.attention_head_dim = attention_head_dim
# 計算注意力頭的數量
self.num_heads = out_channels // self.attention_head_dim
# 遍歷每一層以構建 ResNet 模組
for i in range(num_layers):
# 設定跳躍連線通道數
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 設定當前 ResNet 輸入通道數
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 新增 ResNet 塊到列表
resnets.append(
ResnetBlock2D(
# 設定輸入通道數為 ResNet 輸入通道加上跳躍連線通道
in_channels=resnet_in_channels + res_skip_channels,
# 設定輸出通道數
out_channels=out_channels,
# 設定時間嵌入通道數
temb_channels=temb_channels,
# 設定小常數用於數值穩定性
eps=resnet_eps,
# 設定分組數
groups=resnet_groups,
# 設定 dropout 比例
dropout=dropout,
# 設定時間嵌入的歸一化方式
time_embedding_norm=resnet_time_scale_shift,
# 設定啟用函式
non_linearity=resnet_act_fn,
# 設定輸出縮放因子
output_scale_factor=output_scale_factor,
# 設定是否預歸一化
pre_norm=resnet_pre_norm,
# 設定是否跳過時間啟用
skip_time_act=skip_time_act,
)
)
# 根據是否支援縮放點積注意力選擇處理器
processor = (
AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
)
# 新增 Attention 模組到列表
attentions.append(
Attention(
# 設定查詢維度
query_dim=out_channels,
# 設定交叉注意力維度
cross_attention_dim=out_channels,
# 設定頭的數量
heads=self.num_heads,
# 設定每個頭的維度
dim_head=self.attention_head_dim,
# 設定額外的 KV 投影維度
added_kv_proj_dim=cross_attention_dim,
# 設定歸一化的組數
norm_num_groups=resnet_groups,
# 設定是否使用偏置
bias=True,
# 設定是否上溯 softmax
upcast_softmax=True,
# 設定是否僅使用交叉注意力
only_cross_attention=only_cross_attention,
# 設定交叉注意力的歸一化方式
cross_attention_norm=cross_attention_norm,
# 設定處理器
processor=processor,
)
)
# 將 Attention 模組列表轉換為 ModuleList
self.attentions = nn.ModuleList(attentions)
# 將 ResNet 模組列表轉換為 ModuleList
self.resnets = nn.ModuleList(resnets)
# 如果需要新增上取樣模組
if add_upsample:
# 建立一個上取樣的 ResNet 模組列表
self.upsamplers = nn.ModuleList(
[
ResnetBlock2D(
# 設定上取樣的輸入和輸出通道數
in_channels=out_channels,
out_channels=out_channels,
# 設定時間嵌入通道數
temb_channels=temb_channels,
# 設定小常數用於數值穩定性
eps=resnet_eps,
# 設定分組數
groups=resnet_groups,
# 設定 dropout 比例
dropout=dropout,
# 設定時間嵌入的歸一化方式
time_embedding_norm=resnet_time_scale_shift,
# 設定啟用函式
non_linearity=resnet_act_fn,
# 設定輸出縮放因子
output_scale_factor=output_scale_factor,
# 設定是否預歸一化
pre_norm=resnet_pre_norm,
# 設定是否跳過時間啟用
skip_time_act=skip_time_act,
# 設定為上取樣模式
up=True,
)
]
)
else:
# 如果不需要上取樣,則將上取樣模組設定為 None
self.upsamplers = None
# 初始化梯度檢查點設定為 False
self.gradient_checkpointing = False
# 設定解析度索引
self.resolution_idx = resolution_idx
# 定義前向傳播方法,接收多個輸入引數
def forward(
self,
# 當前隱藏狀態的張量
hidden_states: torch.Tensor,
# 包含殘差隱藏狀態的元組
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
# 可選的時間嵌入張量
temb: Optional[torch.Tensor] = None,
# 可選的編碼器隱藏狀態張量
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可選的上取樣大小
upsample_size: Optional[int] = None,
# 可選的注意力掩碼張量
attention_mask: Optional[torch.Tensor] = None,
# 可選的跨注意力引數字典
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可選的編碼器注意力掩碼張量
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor: # 定義函式的返回型別為 torch.Tensor
# 如果 cross_attention_kwargs 為 None,則初始化為空字典
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 如果 cross_attention_kwargs 中的 "scale" 存在,發出警告,提示已棄用
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 如果 attention_mask 為 None
if attention_mask is None:
# 如果 encoder_hidden_states 已定義,則進行交叉注意力,使用 encoder_attention_mask
mask = None if encoder_hidden_states is None else encoder_attention_mask
else:
# 如果 attention_mask 已定義,不檢查 encoder_attention_mask
# 這樣做是為了相容 UnCLIP,後者使用 'attention_mask' 引數作為交叉注意力掩碼
# TODO: UnCLIP 應該透過 encoder_attention_mask 參數列達交叉注意力掩碼,而不是透過 attention_mask
# 那樣可以簡化整個 if/else 語句塊
mask = attention_mask # 使用提供的 attention_mask
# 遍歷 self.resnets 和 self.attentions 的元素
for resnet, attn in zip(self.resnets, self.attentions):
# 獲取最後一項的殘差隱藏狀態
res_hidden_states = res_hidden_states_tuple[-1]
# 更新 res_hidden_states_tuple,去掉最後一項
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 將當前的 hidden_states 和殘差隱藏狀態在維度 1 上連線
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 如果處於訓練模式並且開啟了梯度檢查點
if self.training and self.gradient_checkpointing:
# 定義建立自定義前向傳播函式的內部函式
def create_custom_forward(module, return_dict=None):
# 定義自定義前向傳播函式
def custom_forward(*inputs):
# 如果 return_dict 不為 None,則返回字典形式的結果
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs) # 否則直接返回結果
return custom_forward # 返回自定義前向傳播函式
# 使用檢查點機制進行前向傳播,節省記憶體
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
# 進行注意力計算
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=mask,
**cross_attention_kwargs,
)
else:
# 直接進行前向傳播計算
hidden_states = resnet(hidden_states, temb)
# 進行注意力計算
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=mask,
**cross_attention_kwargs,
)
# 如果存在上取樣器
if self.upsamplers is not None:
# 遍歷所有上取樣器
for upsampler in self.upsamplers:
# 進行上取樣操作
hidden_states = upsampler(hidden_states, temb)
# 返回最終的隱藏狀態
return hidden_states
# 定義一個名為 KUpBlock2D 的神經網路模組,繼承自 nn.Module
class KUpBlock2D(nn.Module):
# 初始化函式,設定網路層的引數
def __init__(
self,
in_channels: int, # 輸入通道數
out_channels: int, # 輸出通道數
temb_channels: int, # 時間嵌入通道數
resolution_idx: int, # 解析度索引
dropout: float = 0.0, # dropout 比例,預設 0
num_layers: int = 5, # 網路層數,預設 5
resnet_eps: float = 1e-5, # ResNet 中的 epsilon 值
resnet_act_fn: str = "gelu", # ResNet 中使用的啟用函式,預設 "gelu"
resnet_group_size: Optional[int] = 32, # ResNet 中的組大小,預設 32
add_upsample: bool = True, # 是否新增上取樣層,預設 True
):
# 呼叫父類初始化函式
super().__init__()
# 建立一個空的列表,用於存放 ResNet 模組
resnets = []
# 定義輸入通道的數量,設定為輸出通道的兩倍
k_in_channels = 2 * out_channels
# 定義輸出通道的數量
k_out_channels = in_channels
# 減少層數,以適應後續的迴圈
num_layers = num_layers - 1
# 建立指定層數的 ResNet 模組
for i in range(num_layers):
# 第一層的輸入通道為 k_in_channels,其餘層為 out_channels
in_channels = k_in_channels if i == 0 else out_channels
# 計算組的數量
groups = in_channels // resnet_group_size
# 計算輸出組的數量
groups_out = out_channels // resnet_group_size
# 將 ResNet 模組新增到列表中
resnets.append(
ResnetBlockCondNorm2D(
in_channels=in_channels, # 輸入通道數
out_channels=k_out_channels if (i == num_layers - 1) else out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # epsilon 值
groups=groups, # 輸入組數量
groups_out=groups_out, # 輸出組數量
dropout=dropout, # dropout 比例
non_linearity=resnet_act_fn, # 啟用函式
time_embedding_norm="ada_group", # 時間嵌入規範化方式
conv_shortcut_bias=False, # 是否使用卷積快捷連線的偏置
)
)
# 將 ResNet 模組列表轉換為 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 根據是否新增上取樣層來初始化上取樣模組
if add_upsample:
# 如果新增上取樣,建立上取樣層列表
self.upsamplers = nn.ModuleList([KUpsample2D()])
else:
# 如果不新增上取樣,設定為 None
self.upsamplers = None
# 初始化梯度檢查點為 False
self.gradient_checkpointing = False
# 儲存解析度索引
self.resolution_idx = resolution_idx
# 定義前向傳播函式
def forward(
self,
hidden_states: torch.Tensor, # 輸入的隱藏狀態張量
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 傳入的隱藏狀態元組
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
upsample_size: Optional[int] = None, # 可選的上取樣大小
*args, # 額外的位置引數
**kwargs, # 額外的關鍵字引數
# 定義返回值為 torch.Tensor 的函式結束部分
) -> torch.Tensor:
# 檢查傳入引數是否包含 args 或者 kwargs 中的 "scale" 引數
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定義棄用的提示資訊,告知使用者應刪除 "scale" 引數
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 呼叫 deprecate 函式記錄棄用資訊
deprecate("scale", "1.0.0", deprecation_message)
# 取 res_hidden_states_tuple 的最後一個元素
res_hidden_states_tuple = res_hidden_states_tuple[-1]
# 如果 res_hidden_states_tuple 不為 None,則將其與 hidden_states 拼接
if res_hidden_states_tuple is not None:
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
# 遍歷 self.resnets 列表中的每個 resnet 模組
for resnet in self.resnets:
# 如果處於訓練模式且開啟梯度檢查點功能
if self.training and self.gradient_checkpointing:
# 定義一個建立自定義前向函式的內部函式
def create_custom_forward(module):
# 定義自定義前向函式,呼叫模組處理輸入
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 檢查 PyTorch 版本是否大於等於 1.11.0
if is_torch_version(">=", "1.11.0"):
# 使用檢查點功能進行前向傳播,避免計算圖儲存
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
# 使用檢查點功能進行前向傳播
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
# 在非訓練模式下直接透過 resnet 處理 hidden_states
hidden_states = resnet(hidden_states, temb)
# 如果存在 upsamplers,則遍歷每個 upsampler
if self.upsamplers is not None:
for upsampler in self.upsamplers:
# 透過 upsampler 處理 hidden_states
hidden_states = upsampler(hidden_states)
# 返回處理後的 hidden_states
return hidden_states
# 定義一個 KCrossAttnUpBlock2D 類,繼承自 nn.Module
class KCrossAttnUpBlock2D(nn.Module):
# 初始化方法,定義該類的屬性
def __init__(
# 輸入通道數
in_channels: int,
# 輸出通道數
out_channels: int,
# 額外的嵌入通道數
temb_channels: int,
# 當前解析度索引
resolution_idx: int,
# dropout 機率,預設為 0.0
dropout: float = 0.0,
# 殘差網路的層數,預設為 4
num_layers: int = 4,
# 殘差網路的 epsilon 值,預設為 1e-5
resnet_eps: float = 1e-5,
# 殘差網路的啟用函式型別,預設為 "gelu"
resnet_act_fn: str = "gelu",
# 殘差網路的分組大小,預設為 32
resnet_group_size: int = 32,
# 注意力的維度,預設為 1
attention_head_dim: int = 1, # attention dim_head
# 交叉注意力的維度,預設為 768
cross_attention_dim: int = 768,
# 是否新增上取樣,預設為 True
add_upsample: bool = True,
# 是否上溢注意力,預設為 False
upcast_attention: bool = False,
):
# 呼叫父類的建構函式
super().__init__()
# 初始化一個空列表,用於儲存 ResNet 塊
resnets = []
# 初始化一個空列表,用於儲存注意力塊
attentions = []
# 判斷是否為第一個塊:輸入、輸出和時間嵌入通道是否相等
is_first_block = in_channels == out_channels == temb_channels
# 判斷是否為中間塊:輸入和輸出通道是否不相等
is_middle_block = in_channels != out_channels
# 如果是第一個塊,設定為 True 以新增自注意力
add_self_attention = True if is_first_block else False
# 設定跨注意力的標誌為 True
self.has_cross_attention = True
# 儲存注意力頭的維度
self.attention_head_dim = attention_head_dim
# 定義當前塊的輸入通道,若是第一個塊則使用輸出通道,否則使用兩倍的輸出通道
k_in_channels = out_channels if is_first_block else 2 * out_channels
# 當前塊的輸出通道為輸入通道
k_out_channels = in_channels
# 減少層數以計算迴圈中的層數
num_layers = num_layers - 1
# 根據層數迴圈建立 ResNet 塊和注意力塊
for i in range(num_layers):
# 第一個層使用 k_in_channels,後續層使用 out_channels
in_channels = k_in_channels if i == 0 else out_channels
# 計算組數,以便在 ResNet 中分組
groups = in_channels // resnet_group_size
groups_out = out_channels // resnet_group_size
# 判斷是否為中間塊並且是最後一層,設定卷積的輸出通道
if is_middle_block and (i == num_layers - 1):
conv_2d_out_channels = k_out_channels
else:
# 如果不是,設定為 None
conv_2d_out_channels = None
# 建立並新增 ResNet 塊到 resnets 列表
resnets.append(
ResnetBlockCondNorm2D(
# 輸入通道
in_channels=in_channels,
# 輸出通道
out_channels=out_channels,
# 卷積輸出通道
conv_2d_out_channels=conv_2d_out_channels,
# 時間嵌入通道
temb_channels=temb_channels,
# 設定 epsilon 值
eps=resnet_eps,
# 輸入組數
groups=groups,
# 輸出組數
groups_out=groups_out,
# dropout 機率
dropout=dropout,
# 非線性啟用函式
non_linearity=resnet_act_fn,
# 時間嵌入的歸一化方式
time_embedding_norm="ada_group",
# 是否使用卷積快捷連線的偏置
conv_shortcut_bias=False,
)
)
# 建立並新增註意力塊到 attentions 列表
attentions.append(
KAttentionBlock(
# 最後一個層使用 k_out_channels,否則使用 out_channels
k_out_channels if (i == num_layers - 1) else out_channels,
# 最後一個層注意力維度
k_out_channels // attention_head_dim
if (i == num_layers - 1)
else out_channels // attention_head_dim,
# 注意力頭的維度
attention_head_dim,
# 跨注意力維度
cross_attention_dim=cross_attention_dim,
# 時間嵌入通道
temb_channels=temb_channels,
# 是否新增註意力偏置
attention_bias=True,
# 是否新增自注意力
add_self_attention=add_self_attention,
# 跨注意力歸一化方式
cross_attention_norm="layer_norm",
# 是否上溯注意力
upcast_attention=upcast_attention,
)
)
# 將 ResNet 塊列表轉為 PyTorch 的 ModuleList
self.resnets = nn.ModuleList(resnets)
# 將注意力塊列表轉為 PyTorch 的 ModuleList
self.attentions = nn.ModuleList(attentions)
# 如果需要上取樣,則建立一個包含上取樣塊的 ModuleList
if add_upsample:
self.upsamplers = nn.ModuleList([KUpsample2D()])
else:
# 否則將上取樣塊設定為 None
self.upsamplers = None
# 初始化梯度檢查點為 False
self.gradient_checkpointing = False
# 儲存當前解析度的索引
self.resolution_idx = resolution_idx
# 定義前向傳播函式,接受多個輸入引數,返回處理後的張量
def forward(
self,
hidden_states: torch.Tensor, # 輸入的隱藏狀態張量
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 先前隱藏狀態的元組
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 可選的編碼器隱藏狀態
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 交叉注意力的可選引數
upsample_size: Optional[int] = None, # 可選的上取樣尺寸
attention_mask: Optional[torch.Tensor] = None, # 可選的注意力掩碼
encoder_attention_mask: Optional[torch.Tensor] = None, # 可選的編碼器注意力掩碼
) -> torch.Tensor: # 函式返回型別為張量
res_hidden_states_tuple = res_hidden_states_tuple[-1] # 獲取最後一個隱藏狀態
if res_hidden_states_tuple is not None: # 檢查是否存在先前的隱藏狀態
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) # 拼接當前和先前的隱藏狀態
for resnet, attn in zip(self.resnets, self.attentions): # 遍歷每個殘差網路和注意力層
if self.training and self.gradient_checkpointing: # 檢查是否在訓練模式且使用梯度檢查點
def create_custom_forward(module, return_dict=None): # 定義建立自定義前向函式的內部函式
def custom_forward(*inputs): # 自定義前向函式
if return_dict is not None: # 檢查是否需要返回字典
return module(*inputs, return_dict=return_dict) # 使用返回字典的方式呼叫模組
else:
return module(*inputs) # 普通呼叫模組
return custom_forward # 返回自定義前向函式
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} # 根據Torch版本設定檢查點引數
hidden_states = torch.utils.checkpoint.checkpoint( # 使用檢查點進行前向傳播以節省記憶體
create_custom_forward(resnet), # 建立自定義前向函式
hidden_states, # 輸入隱藏狀態
temb, # 輸入時間嵌入
**ckpt_kwargs, # 傳遞檢查點引數
)
hidden_states = attn( # 透過注意力層處理隱藏狀態
hidden_states, # 輸入隱藏狀態
encoder_hidden_states=encoder_hidden_states, # 輸入編碼器隱藏狀態
emb=temb, # 輸入時間嵌入
attention_mask=attention_mask, # 輸入注意力掩碼
cross_attention_kwargs=cross_attention_kwargs, # 交叉注意力引數
encoder_attention_mask=encoder_attention_mask, # 編碼器注意力掩碼
)
else: # 如果不使用梯度檢查點
hidden_states = resnet(hidden_states, temb) # 直接透過殘差網路處理隱藏狀態
hidden_states = attn( # 透過注意力層處理隱藏狀態
hidden_states, # 輸入隱藏狀態
encoder_hidden_states=encoder_hidden_states, # 輸入編碼器隱藏狀態
emb=temb, # 輸入時間嵌入
attention_mask=attention_mask, # 輸入注意力掩碼
cross_attention_kwargs=cross_attention_kwargs, # 交叉注意力引數
encoder_attention_mask=encoder_attention_mask, # 編碼器注意力掩碼
)
if self.upsamplers is not None: # 檢查是否有上取樣層
for upsampler in self.upsamplers: # 遍歷每個上取樣層
hidden_states = upsampler(hidden_states) # 透過上取樣層處理隱藏狀態
return hidden_states # 返回處理後的隱藏狀態
# 可以潛在地更名為 `No-feed-forward` 注意力
class KAttentionBlock(nn.Module):
r"""
基本的 Transformer 塊。
引數:
dim (`int`): 輸入和輸出的通道數。
num_attention_heads (`int`): 用於多頭注意力的頭數。
attention_head_dim (`int`): 每個頭的通道數。
dropout (`float`, *可選*, 預設為 0.0): 使用的丟棄機率。
cross_attention_dim (`int`, *可選*): 用於交叉注意力的 encoder_hidden_states 向量的大小。
attention_bias (`bool`, *可選*, 預設為 `False`):
配置注意力層是否應該包含偏置引數。
upcast_attention (`bool`, *可選*, 預設為 `False`):
設定為 `True` 以將注意力計算上調為 `float32`。
temb_channels (`int`, *可選*, 預設為 768):
令牌嵌入中的通道數。
add_self_attention (`bool`, *可選*, 預設為 `False`):
設定為 `True` 以將自注意力新增到塊中。
cross_attention_norm (`str`, *可選*, 預設為 `None`):
用於交叉注意力的規範化型別。可以是 `None`、`layer_norm` 或 `group_norm`。
group_size (`int`, *可選*, 預設為 32):
用於組規範化將通道分成的組數。
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout: float = 0.0,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
upcast_attention: bool = False,
temb_channels: int = 768, # 用於 ada_group_norm
add_self_attention: bool = False,
cross_attention_norm: Optional[str] = None,
group_size: int = 32,
):
# 呼叫父類建構函式,初始化 nn.Module
super().__init__()
# 設定是否新增自注意力的標誌
self.add_self_attention = add_self_attention
# 1. 自注意力
if add_self_attention:
# 初始化自注意力的歸一化層
self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
# 初始化自注意力機制
self.attn1 = Attention(
query_dim=dim, # 查詢向量的維度
heads=num_attention_heads, # 注意力頭數
dim_head=attention_head_dim, # 每個頭的維度
dropout=dropout, # 丟棄率
bias=attention_bias, # 是否使用偏置
cross_attention_dim=None, # 交叉注意力維度
cross_attention_norm=None, # 交叉注意力的歸一化
)
# 2. 交叉注意力
# 初始化交叉注意力的歸一化層
self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
# 初始化交叉注意力機制
self.attn2 = Attention(
query_dim=dim, # 查詢向量的維度
cross_attention_dim=cross_attention_dim, # 交叉注意力維度
heads=num_attention_heads, # 注意力頭數
dim_head=attention_head_dim, # 每個頭的維度
dropout=dropout, # 丟棄率
bias=attention_bias, # 是否使用偏置
upcast_attention=upcast_attention, # 是否上調注意力計算
cross_attention_norm=cross_attention_norm, # 交叉注意力的歸一化
)
# 將隱藏狀態轉換為 3D 張量,包含 batch size, height*weight 和通道數
def _to_3d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
# 重新排列維度,並調整形狀為 (batch size, height*weight, -1)
return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)
# 將隱藏狀態轉換為 4D 張量,包含 batch size, 通道數, height 和 weight
def _to_4d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
# 重新排列維度,並調整形狀為 (batch size, -1, height, weight)
return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
# TODO: 將 emb 標記為非可選 (self.norm2 需要它)。
# 需要評估對位置引數介面更改的影響。
emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# 如果 cross_attention_kwargs 為空,則初始化為一個空字典
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 檢查 "scale" 引數是否存在,如果存在則發出警告
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 1. 自注意力
if self.add_self_attention:
# 使用 norm1 對隱藏狀態進行歸一化處理
norm_hidden_states = self.norm1(hidden_states, emb)
# 獲取歸一化後狀態的高度和寬度
height, weight = norm_hidden_states.shape[2:]
# 將歸一化後的隱藏狀態轉換為 3D 張量
norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)
# 執行自注意力操作
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
# 將自注意力輸出轉換為 4D 張量
attn_output = self._to_4d(attn_output, height, weight)
# 將自注意力輸出與原始隱藏狀態相加
hidden_states = attn_output + hidden_states
# 2. 交叉注意力或無交叉注意力
# 使用 norm2 對隱藏狀態進行歸一化處理
norm_hidden_states = self.norm2(hidden_states, emb)
# 獲取歸一化後狀態的高度和寬度
height, weight = norm_hidden_states.shape[2:]
# 將歸一化後的隱藏狀態轉換為 3D 張量
norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)
# 執行交叉注意力操作
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask,
**cross_attention_kwargs,
)
# 將交叉注意力輸出轉換為 4D 張量
attn_output = self._to_4d(attn_output, height, weight)
# 將交叉注意力輸出與隱藏狀態相加
hidden_states = attn_output + hidden_states
# 返回最終的隱藏狀態
return hidden_states