diffusers-原始碼解析-十三-

绝不原创的飞龙發表於2024-10-22

diffusers 原始碼解析(十三)

.\diffusers\models\unets\unet_2d.py

# 版權宣告,表示該程式碼由 HuggingFace 團隊所有
# 
# 根據 Apache 2.0 許可證進行許可;
# 除非遵循許可證,否則不得使用此檔案。
# 可以在以下地址獲取許可證的副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非適用法律要求或書面同意,軟體按 "原樣" 分發,
# 不提供任何形式的保證或條件,無論是明示或暗示。
# 請參閱許可證以瞭解有關許可權和
# 限制的具體資訊。
from dataclasses import dataclass  # 從 dataclasses 模組匯入 dataclass 裝飾器
from typing import Optional, Tuple, Union  # 從 typing 模組匯入型別提示相關的類

import torch  # 匯入 PyTorch 庫
import torch.nn as nn  # 匯入 PyTorch 的神經網路模組

from ...configuration_utils import ConfigMixin, register_to_config  # 從配置工具匯入配置混合類和註冊函式
from ...utils import BaseOutput  # 從工具模組匯入基礎輸出類
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps  # 從嵌入模組匯入相關類
from ..modeling_utils import ModelMixin  # 從建模工具匯入模型混合類
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block  # 從 UNet 2D 塊匯入相關元件


@dataclass
class UNet2DOutput(BaseOutput):  # 定義 UNet2DOutput 類,繼承自 BaseOutput
    """
    [`UNet2DModel`] 的輸出類。

    引數:
        sample (`torch.Tensor` 形狀為 `(batch_size, num_channels, height, width)`):
            從模型最後一層輸出的隱藏狀態。
    """

    sample: torch.Tensor  # 定義輸出樣本的屬性,型別為 torch.Tensor


class UNet2DModel(ModelMixin, ConfigMixin):  # 定義 UNet2DModel 類,繼承自 ModelMixin 和 ConfigMixin
    r"""
    一個 2D UNet 模型,接收一個有噪聲的樣本和一個時間步,返回一個樣本形狀的輸出。

    此模型繼承自 [`ModelMixin`]。請檢視超類文件以瞭解其為所有模型實現的通用方法
    (例如下載或儲存)。
    """

    @register_to_config  # 使用裝飾器將該方法註冊到配置中
    def __init__(  # 定義初始化方法
        self,
        sample_size: Optional[Union[int, Tuple[int, int]]] = None,  # 可選的樣本大小,可以是整數或整數元組
        in_channels: int = 3,  # 輸入通道數,預設為 3
        out_channels: int = 3,  # 輸出通道數,預設為 3
        center_input_sample: bool = False,  # 是否將輸入樣本居中,預設為 False
        time_embedding_type: str = "positional",  # 時間嵌入型別,預設為 "positional"
        freq_shift: int = 0,  # 頻率偏移量,預設為 0
        flip_sin_to_cos: bool = True,  # 是否將正弦函式翻轉為餘弦函式,預設為 True
        down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),  # 下采樣塊型別
        up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),  # 上取樣塊型別
        block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),  # 各塊輸出通道數
        layers_per_block: int = 2,  # 每個塊的層數,預設為 2
        mid_block_scale_factor: float = 1,  # 中間塊縮放因子,預設為 1
        downsample_padding: int = 1,  # 下采樣的填充大小,預設為 1
        downsample_type: str = "conv",  # 下采樣型別,預設為卷積
        upsample_type: str = "conv",  # 上取樣型別,預設為卷積
        dropout: float = 0.0,  # dropout 機率,預設為 0.0
        act_fn: str = "silu",  # 啟用函式型別,預設為 "silu"
        attention_head_dim: Optional[int] = 8,  # 注意力頭維度,預設為 8
        norm_num_groups: int = 32,  # 規範化組數量,預設為 32
        attn_norm_num_groups: Optional[int] = None,  # 注意力規範化組數量,可選
        norm_eps: float = 1e-5,  # 規範化的 epsilon 值,預設為 1e-5
        resnet_time_scale_shift: str = "default",  # ResNet 時間縮放偏移型別,預設為 "default"
        add_attention: bool = True,  # 是否新增註意力機制,預設為 True
        class_embed_type: Optional[str] = None,  # 類別嵌入型別,可選
        num_class_embeds: Optional[int] = None,  # 類別嵌入數量,可選
        num_train_timesteps: Optional[int] = None,  # 訓練時間步數量,可選
    # 定義一個名為 forward 的方法
        def forward(
            # 輸入引數 sample,型別為 torch.Tensor,表示樣本資料
            self,
            sample: torch.Tensor,
            # 輸入引數 timestep,可以是 torch.Tensor、float 或 int,表示時間步
            timestep: Union[torch.Tensor, float, int],
            # 可選引數 class_labels,型別為 torch.Tensor,表示分類標籤,預設為 None
            class_labels: Optional[torch.Tensor] = None,
            # 可選引數 return_dict,型別為 bool,表示是否以字典形式返回結果,預設為 True
            return_dict: bool = True,

.\diffusers\models\unets\unet_2d_blocks.py

# 版權所有 2024 HuggingFace 團隊,保留所有權利。
#
# 根據 Apache 許可證第 2.0 版(“許可證”)進行許可;
# 除非遵循許可證,否則不得使用此檔案。
# 可以在以下網址獲取許可證副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律或書面協議另有規定,否則根據許可證分發的軟體按“現狀”基礎分發,
# 不提供任何形式的保證或條件,無論是明示還是暗示。
# 請參閱許可證以瞭解特定語言的許可權和
# 限制。
from typing import Any, Dict, Optional, Tuple, Union  # 匯入型別註解相關的模組

import numpy as np  # 匯入 NumPy 庫用於數值計算
import torch  # 匯入 PyTorch 庫用於深度學習
import torch.nn.functional as F  # 匯入 PyTorch 的功能性 API
from torch import nn  # 匯入 PyTorch 的神經網路模組

from ...utils import deprecate, is_torch_version, logging  # 從工具模組匯入日誌和版本檢測功能
from ...utils.torch_utils import apply_freeu  # 從 PyTorch 工具模組匯入 apply_freeu 函式
from ..activations import get_activation  # 從啟用函式模組匯入 get_activation 函式
from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0  # 匯入注意力處理器相關的類
from ..normalization import AdaGroupNorm  # 從歸一化模組匯入 AdaGroupNorm 類
from ..resnet import (  # 從 ResNet 模組匯入多個下采樣和上取樣類
    Downsample2D,
    FirDownsample2D,
    FirUpsample2D,
    KDownsample2D,
    KUpsample2D,
    ResnetBlock2D,
    ResnetBlockCondNorm2D,
    Upsample2D,
)
from ..transformers.dual_transformer_2d import DualTransformer2DModel  # 匯入雙重變換器模型
from ..transformers.transformer_2d import Transformer2DModel  # 匯入二維變換器模型


logger = logging.get_logger(__name__)  # 獲取當前模組的日誌記錄器例項


def get_down_block(  # 定義獲取下采樣塊的函式
    down_block_type: str,  # 下采樣塊的型別
    num_layers: int,  # 下采樣層的數量
    in_channels: int,  # 輸入通道數
    out_channels: int,  # 輸出通道數
    temb_channels: int,  # 時間嵌入通道數
    add_downsample: bool,  # 是否新增下采樣標誌
    resnet_eps: float,  # ResNet 中的 epsilon 引數
    resnet_act_fn: str,  # ResNet 使用的啟用函式
    transformer_layers_per_block: int = 1,  # 每個塊中的變換器層數,預設為 1
    num_attention_heads: Optional[int] = None,  # 注意力頭的數量,預設為 None
    resnet_groups: Optional[int] = None,  # ResNet 中的組數,預設為 None
    cross_attention_dim: Optional[int] = None,  # 交叉注意力維度,預設為 None
    downsample_padding: Optional[int] = None,  # 下采樣填充引數,預設為 None
    dual_cross_attention: bool = False,  # 是否使用雙重交叉注意力標誌
    use_linear_projection: bool = False,  # 是否使用線性投影標誌
    only_cross_attention: bool = False,  # 是否僅使用交叉注意力標誌
    upcast_attention: bool = False,  # 是否上升注意力標誌
    resnet_time_scale_shift: str = "default",  # ResNet 時間縮放移位,預設為“default”
    attention_type: str = "default",  # 注意力型別,預設為“default”
    resnet_skip_time_act: bool = False,  # ResNet 跳過時間啟用標誌
    resnet_out_scale_factor: float = 1.0,  # ResNet 輸出縮放因子,預設為 1.0
    cross_attention_norm: Optional[str] = None,  # 交叉注意力歸一化,預設為 None
    attention_head_dim: Optional[int] = None,  # 注意力頭維度,預設為 None
    downsample_type: Optional[str] = None,  # 下采樣型別,預設為 None
    dropout: float = 0.0,  # dropout 比例,預設為 0.0
):
    # 如果沒有定義注意力頭維度,預設設定為頭的數量
    if attention_head_dim is None:
        logger.warning(  # 記錄警告資訊
            f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."  # 提醒使用者使用預設的注意力頭維度
        )
        attention_head_dim = num_attention_heads  # 將注意力頭維度設定為頭的數量

    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type  # 處理下采樣塊型別的字串
    # 檢查下行塊的型別是否為 "DownBlock2D"
        if down_block_type == "DownBlock2D":
            # 返回 DownBlock2D 例項,傳入相關引數
            return DownBlock2D(
                # 傳入層數
                num_layers=num_layers,
                # 輸入通道數
                in_channels=in_channels,
                # 輸出通道數
                out_channels=out_channels,
                # 時間嵌入通道數
                temb_channels=temb_channels,
                # dropout 比率
                dropout=dropout,
                # 是否新增下采樣
                add_downsample=add_downsample,
                # ResNet 的 epsilon 值
                resnet_eps=resnet_eps,
                # ResNet 的啟用函式
                resnet_act_fn=resnet_act_fn,
                # ResNet 的分組數
                resnet_groups=resnet_groups,
                # 下采樣的填充
                downsample_padding=downsample_padding,
                # ResNet 的時間尺度偏移
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
        # 檢查下行塊的型別是否為 "ResnetDownsampleBlock2D"
        elif down_block_type == "ResnetDownsampleBlock2D":
            # 返回 ResnetDownsampleBlock2D 例項,傳入相關引數
            return ResnetDownsampleBlock2D(
                # 傳入層數
                num_layers=num_layers,
                # 輸入通道數
                in_channels=in_channels,
                # 輸出通道數
                out_channels=out_channels,
                # 時間嵌入通道數
                temb_channels=temb_channels,
                # dropout 比率
                dropout=dropout,
                # 是否新增下采樣
                add_downsample=add_downsample,
                # ResNet 的 epsilon 值
                resnet_eps=resnet_eps,
                # ResNet 的啟用函式
                resnet_act_fn=resnet_act_fn,
                # ResNet 的分組數
                resnet_groups=resnet_groups,
                # ResNet 的時間尺度偏移
                resnet_time_scale_shift=resnet_time_scale_shift,
                # ResNet 的時間啟用跳過標誌
                skip_time_act=resnet_skip_time_act,
                # ResNet 的輸出縮放因子
                output_scale_factor=resnet_out_scale_factor,
            )
        # 檢查下行塊的型別是否為 "AttnDownBlock2D"
        elif down_block_type == "AttnDownBlock2D":
            # 如果不新增下采樣,則將下采樣型別設為 None
            if add_downsample is False:
                downsample_type = None
            else:
                # 如果新增下采樣,則預設下采樣型別為 'conv'
                downsample_type = downsample_type or "conv"  # default to 'conv'
            # 返回 AttnDownBlock2D 例項,傳入相關引數
            return AttnDownBlock2D(
                # 傳入層數
                num_layers=num_layers,
                # 輸入通道數
                in_channels=in_channels,
                # 輸出通道數
                out_channels=out_channels,
                # 時間嵌入通道數
                temb_channels=temb_channels,
                # dropout 比率
                dropout=dropout,
                # ResNet 的 epsilon 值
                resnet_eps=resnet_eps,
                # ResNet 的啟用函式
                resnet_act_fn=resnet_act_fn,
                # ResNet 的分組數
                resnet_groups=resnet_groups,
                # 下采樣的填充
                downsample_padding=downsample_padding,
                # 注意力頭的維度
                attention_head_dim=attention_head_dim,
                # ResNet 的時間尺度偏移
                resnet_time_scale_shift=resnet_time_scale_shift,
                # 下采樣型別
                downsample_type=downsample_type,
            )
    # 檢查下行塊型別是否為 CrossAttnDownBlock2D
    elif down_block_type == "CrossAttnDownBlock2D":
        # 如果 cross_attention_dim 未指定,則丟擲錯誤
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
        # 返回 CrossAttnDownBlock2D 例項,使用提供的引數進行初始化
        return CrossAttnDownBlock2D(
            # 設定層的數量
            num_layers=num_layers,
            # 每個塊的變換層數量
            transformer_layers_per_block=transformer_layers_per_block,
            # 輸入通道數
            in_channels=in_channels,
            # 輸出通道數
            out_channels=out_channels,
            # 時間嵌入通道數
            temb_channels=temb_channels,
            # dropout 比例
            dropout=dropout,
            # 是否新增下采樣
            add_downsample=add_downsample,
            # ResNet 中的 epsilon 值
            resnet_eps=resnet_eps,
            # ResNet 啟用函式
            resnet_act_fn=resnet_act_fn,
            # ResNet 分組數
            resnet_groups=resnet_groups,
            # 下采樣填充
            downsample_padding=downsample_padding,
            # 跨注意力維度
            cross_attention_dim=cross_attention_dim,
            # 注意力頭數
            num_attention_heads=num_attention_heads,
            # 是否使用雙向跨注意力
            dual_cross_attention=dual_cross_attention,
            # 是否使用線性投影
            use_linear_projection=use_linear_projection,
            # 是否僅使用跨注意力
            only_cross_attention=only_cross_attention,
            # 是否上溯注意力
            upcast_attention=upcast_attention,
            # ResNet 時間尺度偏移
            resnet_time_scale_shift=resnet_time_scale_shift,
            # 注意力型別
            attention_type=attention_type,
        )
    # 檢查下行塊型別是否為 SimpleCrossAttnDownBlock2D
    elif down_block_type == "SimpleCrossAttnDownBlock2D":
        # 如果 cross_attention_dim 未指定,則丟擲錯誤
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
        # 返回 SimpleCrossAttnDownBlock2D 例項,使用提供的引數進行初始化
        return SimpleCrossAttnDownBlock2D(
            # 設定層的數量
            num_layers=num_layers,
            # 輸入通道數
            in_channels=in_channels,
            # 輸出通道數
            out_channels=out_channels,
            # 時間嵌入通道數
            temb_channels=temb_channels,
            # dropout 比例
            dropout=dropout,
            # 是否新增下采樣
            add_downsample=add_downsample,
            # ResNet 中的 epsilon 值
            resnet_eps=resnet_eps,
            # ResNet 啟用函式
            resnet_act_fn=resnet_act_fn,
            # ResNet 分組數
            resnet_groups=resnet_groups,
            # 跨注意力維度
            cross_attention_dim=cross_attention_dim,
            # 注意力頭的維度
            attention_head_dim=attention_head_dim,
            # ResNet 時間尺度偏移
            resnet_time_scale_shift=resnet_time_scale_shift,
            # 是否跳過時間啟用
            skip_time_act=resnet_skip_time_act,
            # 輸出縮放因子
            output_scale_factor=resnet_out_scale_factor,
            # 是否僅使用跨注意力
            only_cross_attention=only_cross_attention,
            # 跨注意力規範化
            cross_attention_norm=cross_attention_norm,
        )
    # 檢查下行塊型別是否為 SkipDownBlock2D
    elif down_block_type == "SkipDownBlock2D":
        # 返回 SkipDownBlock2D 例項,使用提供的引數進行初始化
        return SkipDownBlock2D(
            # 設定層的數量
            num_layers=num_layers,
            # 輸入通道數
            in_channels=in_channels,
            # 輸出通道數
            out_channels=out_channels,
            # 時間嵌入通道數
            temb_channels=temb_channels,
            # dropout 比例
            dropout=dropout,
            # 是否新增下采樣
            add_downsample=add_downsample,
            # ResNet 中的 epsilon 值
            resnet_eps=resnet_eps,
            # ResNet 啟用函式
            resnet_act_fn=resnet_act_fn,
            # 下采樣填充
            downsample_padding=downsample_padding,
            # ResNet 時間尺度偏移
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    # 檢查下采樣塊的型別是否為 "AttnSkipDownBlock2D"
    elif down_block_type == "AttnSkipDownBlock2D":
        # 返回一個 AttnSkipDownBlock2D 物件,初始化引數傳入
        return AttnSkipDownBlock2D(
            # 設定層數
            num_layers=num_layers,
            # 輸入通道數
            in_channels=in_channels,
            # 輸出通道數
            out_channels=out_channels,
            # 時間嵌入通道數
            temb_channels=temb_channels,
            # dropout 率
            dropout=dropout,
            # 是否新增下采樣
            add_downsample=add_downsample,
            # ResNet 的 epsilon 值
            resnet_eps=resnet_eps,
            # ResNet 的啟用函式
            resnet_act_fn=resnet_act_fn,
            # 注意力頭的維度
            attention_head_dim=attention_head_dim,
            # ResNet 時間縮放偏移
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    # 檢查下采樣塊的型別是否為 "DownEncoderBlock2D"
    elif down_block_type == "DownEncoderBlock2D":
        # 返回一個 DownEncoderBlock2D 物件,初始化引數傳入
        return DownEncoderBlock2D(
            # 設定層數
            num_layers=num_layers,
            # 輸入通道數
            in_channels=in_channels,
            # 輸出通道數
            out_channels=out_channels,
            # dropout 率
            dropout=dropout,
            # 是否新增下采樣
            add_downsample=add_downsample,
            # ResNet 的 epsilon 值
            resnet_eps=resnet_eps,
            # ResNet 的啟用函式
            resnet_act_fn=resnet_act_fn,
            # ResNet 的組數
            resnet_groups=resnet_groups,
            # 下采樣填充
            downsample_padding=downsample_padding,
            # ResNet 時間縮放偏移
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    # 檢查下采樣塊的型別是否為 "AttnDownEncoderBlock2D"
    elif down_block_type == "AttnDownEncoderBlock2D":
        # 返回一個 AttnDownEncoderBlock2D 物件,初始化引數傳入
        return AttnDownEncoderBlock2D(
            # 設定層數
            num_layers=num_layers,
            # 輸入通道數
            in_channels=in_channels,
            # 輸出通道數
            out_channels=out_channels,
            # dropout 率
            dropout=dropout,
            # 是否新增下采樣
            add_downsample=add_downsample,
            # ResNet 的 epsilon 值
            resnet_eps=resnet_eps,
            # ResNet 的啟用函式
            resnet_act_fn=resnet_act_fn,
            # ResNet 的組數
            resnet_groups=resnet_groups,
            # 下采樣填充
            downsample_padding=downsample_padding,
            # 注意力頭的維度
            attention_head_dim=attention_head_dim,
            # ResNet 時間縮放偏移
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    # 檢查下采樣塊的型別是否為 "KDownBlock2D"
    elif down_block_type == "KDownBlock2D":
        # 返回一個 KDownBlock2D 物件,初始化引數傳入
        return KDownBlock2D(
            # 設定層數
            num_layers=num_layers,
            # 輸入通道數
            in_channels=in_channels,
            # 輸出通道數
            out_channels=out_channels,
            # 時間嵌入通道數
            temb_channels=temb_channels,
            # dropout 率
            dropout=dropout,
            # 是否新增下采樣
            add_downsample=add_downsample,
            # ResNet 的 epsilon 值
            resnet_eps=resnet_eps,
            # ResNet 的啟用函式
            resnet_act_fn=resnet_act_fn,
        )
    # 檢查下采樣塊的型別是否為 "KCrossAttnDownBlock2D"
    elif down_block_type == "KCrossAttnDownBlock2D":
        # 返回一個 KCrossAttnDownBlock2D 物件,初始化引數傳入
        return KCrossAttnDownBlock2D(
            # 設定層數
            num_layers=num_layers,
            # 輸入通道數
            in_channels=in_channels,
            # 輸出通道數
            out_channels=out_channels,
            # 時間嵌入通道數
            temb_channels=temb_channels,
            # dropout 率
            dropout=dropout,
            # 是否新增下采樣
            add_downsample=add_downsample,
            # ResNet 的 epsilon 值
            resnet_eps=resnet_eps,
            # ResNet 的啟用函式
            resnet_act_fn=resnet_act_fn,
            # 跨注意力的維度
            cross_attention_dim=cross_attention_dim,
            # 注意力頭的維度
            attention_head_dim=attention_head_dim,
            # 是否新增自注意力
            add_self_attention=True if not add_downsample else False,
        )
    # 如果下采樣塊型別不匹配,則丟擲異常
    raise ValueError(f"{down_block_type} does not exist.")
# 根據給定引數生成中間塊(mid block)
def get_mid_block(
    # 中間塊的型別
    mid_block_type: str,
    # 嵌入通道數
    temb_channels: int,
    # 輸入通道數
    in_channels: int,
    # ResNet 的 epsilon 值
    resnet_eps: float,
    # ResNet 的啟用函式型別
    resnet_act_fn: str,
    # ResNet 的組數
    resnet_groups: int,
    # 輸出縮放因子,預設為 1.0
    output_scale_factor: float = 1.0,
    # 每個塊的變換層數,預設為 1
    transformer_layers_per_block: int = 1,
    # 注意力頭的數量,預設為 None
    num_attention_heads: Optional[int] = None,
    # 跨注意力的維度,預設為 None
    cross_attention_dim: Optional[int] = None,
    # 是否使用雙重跨注意力,預設為 False
    dual_cross_attention: bool = False,
    # 是否使用線性投影,預設為 False
    use_linear_projection: bool = False,
    # 是否僅使用跨注意力作為中間塊,預設為 False
    mid_block_only_cross_attention: bool = False,
    # 是否提升注意力精度,預設為 False
    upcast_attention: bool = False,
    # ResNet 的時間縮放偏移,預設為 "default"
    resnet_time_scale_shift: str = "default",
    # 注意力型別,預設為 "default"
    attention_type: str = "default",
    # ResNet 是否跳過時間啟用,預設為 False
    resnet_skip_time_act: bool = False,
    # 跨注意力的歸一化型別,預設為 None
    cross_attention_norm: Optional[str] = None,
    # 注意力頭的維度,預設為 1
    attention_head_dim: Optional[int] = 1,
    # dropout 機率,預設為 0.0
    dropout: float = 0.0,
):
    # 根據中間塊的型別生成對應的物件
    if mid_block_type == "UNetMidBlock2DCrossAttn":
        # 建立 UNet 的 2D 跨注意力中間塊
        return UNetMidBlock2DCrossAttn(
            # 設定變換層數
            transformer_layers_per_block=transformer_layers_per_block,
            # 設定輸入通道數
            in_channels=in_channels,
            # 設定嵌入通道數
            temb_channels=temb_channels,
            # 設定 dropout 機率
            dropout=dropout,
            # 設定 ResNet epsilon 值
            resnet_eps=resnet_eps,
            # 設定 ResNet 啟用函式
            resnet_act_fn=resnet_act_fn,
            # 設定輸出縮放因子
            output_scale_factor=output_scale_factor,
            # 設定時間縮放偏移
            resnet_time_scale_shift=resnet_time_scale_shift,
            # 設定跨注意力維度
            cross_attention_dim=cross_attention_dim,
            # 設定注意力頭數量
            num_attention_heads=num_attention_heads,
            # 設定 ResNet 組數
            resnet_groups=resnet_groups,
            # 設定是否使用雙重跨注意力
            dual_cross_attention=dual_cross_attention,
            # 設定是否使用線性投影
            use_linear_projection=use_linear_projection,
            # 設定是否提升注意力精度
            upcast_attention=upcast_attention,
            # 設定注意力型別
            attention_type=attention_type,
        )
    # 檢查是否為簡單跨注意力中間塊
    elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
        # 建立 UNet 的 2D 簡單跨注意力中間塊
        return UNetMidBlock2DSimpleCrossAttn(
            # 設定輸入通道數
            in_channels=in_channels,
            # 設定嵌入通道數
            temb_channels=temb_channels,
            # 設定 dropout 機率
            dropout=dropout,
            # 設定 ResNet epsilon 值
            resnet_eps=resnet_eps,
            # 設定 ResNet 啟用函式
            resnet_act_fn=resnet_act_fn,
            # 設定輸出縮放因子
            output_scale_factor=output_scale_factor,
            # 設定跨注意力維度
            cross_attention_dim=cross_attention_dim,
            # 設定注意力頭的維度
            attention_head_dim=attention_head_dim,
            # 設定 ResNet 組數
            resnet_groups=resnet_groups,
            # 設定時間縮放偏移
            resnet_time_scale_shift=resnet_time_scale_shift,
            # 設定是否跳過時間啟用
            skip_time_act=resnet_skip_time_act,
            # 設定是否僅使用跨注意力
            only_cross_attention=mid_block_only_cross_attention,
            # 設定跨注意力的歸一化型別
            cross_attention_norm=cross_attention_norm,
        )
    # 檢查是否為標準的 2D 中間塊
    elif mid_block_type == "UNetMidBlock2D":
        # 建立 UNet 的 2D 中間塊
        return UNetMidBlock2D(
            # 設定輸入通道數
            in_channels=in_channels,
            # 設定嵌入通道數
            temb_channels=temb_channels,
            # 設定 dropout 機率
            dropout=dropout,
            # 設定層數為 0
            num_layers=0,
            # 設定 ResNet epsilon 值
            resnet_eps=resnet_eps,
            # 設定 ResNet 啟用函式
            resnet_act_fn=resnet_act_fn,
            # 設定輸出縮放因子
            output_scale_factor=output_scale_factor,
            # 設定 ResNet 組數
            resnet_groups=resnet_groups,
            # 設定時間縮放偏移
            resnet_time_scale_shift=resnet_time_scale_shift,
            # 不新增註意力
            add_attention=False,
        )
    # 檢查中間塊型別是否為 None
    elif mid_block_type is None:
        # 返回 None
        return None
    # 丟擲未知型別的異常
    else:
        raise ValueError(f"unknown mid_block_type : {mid_block_type}")
    # 輸出通道的數量
        out_channels: int,
        # 前一層輸出通道的數量
        prev_output_channel: int,
        # 嵌入層通道的數量
        temb_channels: int,
        # 是否新增上取樣層
        add_upsample: bool,
        # ResNet 中的 epsilon 值,用於數值穩定性
        resnet_eps: float,
        # ResNet 中的啟用函式型別
        resnet_act_fn: str,
        # 解析度索引,預設為 None
        resolution_idx: Optional[int] = None,
        # 每個塊中的變換層數量
        transformer_layers_per_block: int = 1,
        # 注意力頭的數量,預設為 None
        num_attention_heads: Optional[int] = None,
        # ResNet 中的組數量,預設為 None
        resnet_groups: Optional[int] = None,
        # 交叉注意力的維度,預設為 None
        cross_attention_dim: Optional[int] = None,
        # 是否使用雙重交叉注意力
        dual_cross_attention: bool = False,
        # 是否使用線性投影
        use_linear_projection: bool = False,
        # 是否僅使用交叉注意力
        only_cross_attention: bool = False,
        # 是否上取樣時提高注意力精度
        upcast_attention: bool = False,
        # ResNet 時間縮放偏移的型別,預設為 "default"
        resnet_time_scale_shift: str = "default",
        # 注意力型別,預設為 "default"
        attention_type: str = "default",
        # ResNet 中跳過時間啟用的標誌
        resnet_skip_time_act: bool = False,
        # ResNet 輸出縮放因子,預設為 1.0
        resnet_out_scale_factor: float = 1.0,
        # 交叉注意力的歸一化方式,預設為 None
        cross_attention_norm: Optional[str] = None,
        # 注意力頭的維度,預設為 None
        attention_head_dim: Optional[int] = None,
        # 上取樣型別,預設為 None
        upsample_type: Optional[str] = None,
        # 丟棄率,預設為 0.0
        dropout: float = 0.0,
) -> nn.Module:  # 指定該函式的返回型別為 nn.Module
    # 如果未定義注意力頭的維度,預設設定為注意力頭的數量
    if attention_head_dim is None:
        logger.warning(  # 記錄警告資訊
            f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."  # 提示使用者提供 attention_head_dim
        )
        attention_head_dim = num_attention_heads  # 將 attention_head_dim 設定為 num_attention_heads

    # 如果 up_block_type 以 "UNetRes" 開頭,則去掉字首
    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    # 檢查 up_block_type 是否為 "UpBlock2D"
    if up_block_type == "UpBlock2D":
        return UpBlock2D(  # 返回 UpBlock2D 物件
            num_layers=num_layers,  # 設定網路層數
            in_channels=in_channels,  # 設定輸入通道數
            out_channels=out_channels,  # 設定輸出通道數
            prev_output_channel=prev_output_channel,  # 設定前一個輸出通道
            temb_channels=temb_channels,  # 設定時間嵌入通道數
            resolution_idx=resolution_idx,  # 設定解析度索引
            dropout=dropout,  # 設定 dropout 引數
            add_upsample=add_upsample,  # 設定是否新增上取樣
            resnet_eps=resnet_eps,  # 設定 ResNet 的 epsilon
            resnet_act_fn=resnet_act_fn,  # 設定 ResNet 的啟用函式
            resnet_groups=resnet_groups,  # 設定 ResNet 的組數
            resnet_time_scale_shift=resnet_time_scale_shift,  # 設定 ResNet 的時間縮放偏移
        )
    # 檢查 up_block_type 是否為 "ResnetUpsampleBlock2D"
    elif up_block_type == "ResnetUpsampleBlock2D":
        return ResnetUpsampleBlock2D(  # 返回 ResnetUpsampleBlock2D 物件
            num_layers=num_layers,  # 設定網路層數
            in_channels=in_channels,  # 設定輸入通道數
            out_channels=out_channels,  # 設定輸出通道數
            prev_output_channel=prev_output_channel,  # 設定前一個輸出通道
            temb_channels=temb_channels,  # 設定時間嵌入通道數
            resolution_idx=resolution_idx,  # 設定解析度索引
            dropout=dropout,  # 設定 dropout 引數
            add_upsample=add_upsample,  # 設定是否新增上取樣
            resnet_eps=resnet_eps,  # 設定 ResNet 的 epsilon
            resnet_act_fn=resnet_act_fn,  # 設定 ResNet 的啟用函式
            resnet_groups=resnet_groups,  # 設定 ResNet 的組數
            resnet_time_scale_shift=resnet_time_scale_shift,  # 設定 ResNet 的時間縮放偏移
            skip_time_act=resnet_skip_time_act,  # 設定是否跳過時間啟用
            output_scale_factor=resnet_out_scale_factor,  # 設定輸出縮放因子
        )
    # 檢查 up_block_type 是否為 "CrossAttnUpBlock2D"
    elif up_block_type == "CrossAttnUpBlock2D":
        # 如果未定義交叉注意力維度,則丟擲異常
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")  # 丟擲值錯誤
        return CrossAttnUpBlock2D(  # 返回 CrossAttnUpBlock2D 物件
            num_layers=num_layers,  # 設定網路層數
            transformer_layers_per_block=transformer_layers_per_block,  # 設定每個塊的變換層數
            in_channels=in_channels,  # 設定輸入通道數
            out_channels=out_channels,  # 設定輸出通道數
            prev_output_channel=prev_output_channel,  # 設定前一個輸出通道
            temb_channels=temb_channels,  # 設定時間嵌入通道數
            resolution_idx=resolution_idx,  # 設定解析度索引
            dropout=dropout,  # 設定 dropout 引數
            add_upsample=add_upsample,  # 設定是否新增上取樣
            resnet_eps=resnet_eps,  # 設定 ResNet 的 epsilon
            resnet_act_fn=resnet_act_fn,  # 設定 ResNet 的啟用函式
            resnet_groups=resnet_groups,  # 設定 ResNet 的組數
            cross_attention_dim=cross_attention_dim,  # 設定交叉注意力維度
            num_attention_heads=num_attention_heads,  # 設定注意力頭的數量
            dual_cross_attention=dual_cross_attention,  # 設定雙重交叉注意力
            use_linear_projection=use_linear_projection,  # 設定是否使用線性投影
            only_cross_attention=only_cross_attention,  # 設定是否僅使用交叉注意力
            upcast_attention=upcast_attention,  # 設定是否提升注意力
            resnet_time_scale_shift=resnet_time_scale_shift,  # 設定 ResNet 的時間縮放偏移
            attention_type=attention_type,  # 設定注意力型別
        )
    # 檢查上取樣塊的型別是否為 SimpleCrossAttnUpBlock2D
    elif up_block_type == "SimpleCrossAttnUpBlock2D":
        # 如果未指定 cross_attention_dim,則丟擲值錯誤
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
        # 返回 SimpleCrossAttnUpBlock2D 例項,使用相關引數初始化
        return SimpleCrossAttnUpBlock2D(
            num_layers=num_layers,  # 設定層數
            in_channels=in_channels,  # 設定輸入通道數
            out_channels=out_channels,  # 設定輸出通道數
            prev_output_channel=prev_output_channel,  # 設定前一層的輸出通道數
            temb_channels=temb_channels,  # 設定時間嵌入通道數
            resolution_idx=resolution_idx,  # 設定解析度索引
            dropout=dropout,  # 設定 dropout 機率
            add_upsample=add_upsample,  # 設定是否新增上取樣
            resnet_eps=resnet_eps,  # 設定 ResNet 的 epsilon
            resnet_act_fn=resnet_act_fn,  # 設定 ResNet 的啟用函式
            resnet_groups=resnet_groups,  # 設定 ResNet 的組數
            cross_attention_dim=cross_attention_dim,  # 設定交叉注意力維度
            attention_head_dim=attention_head_dim,  # 設定注意力頭維度
            resnet_time_scale_shift=resnet_time_scale_shift,  # 設定 ResNet 的時間縮放偏移
            skip_time_act=resnet_skip_time_act,  # 設定是否跳過時間啟用
            output_scale_factor=resnet_out_scale_factor,  # 設定輸出縮放因子
            only_cross_attention=only_cross_attention,  # 設定是否僅使用交叉注意力
            cross_attention_norm=cross_attention_norm,  # 設定交叉注意力的歸一化方式
        )
    # 檢查上取樣塊的型別是否為 AttnUpBlock2D
    elif up_block_type == "AttnUpBlock2D":
        # 如果未新增上取樣,則將上取樣型別設為 None
        if add_upsample is False:
            upsample_type = None
        else:
            # 預設將上取樣型別設為 'conv'
            upsample_type = upsample_type or "conv"

        # 返回 AttnUpBlock2D 例項,使用相關引數初始化
        return AttnUpBlock2D(
            num_layers=num_layers,  # 設定層數
            in_channels=in_channels,  # 設定輸入通道數
            out_channels=out_channels,  # 設定輸出通道數
            prev_output_channel=prev_output_channel,  # 設定前一層的輸出通道數
            temb_channels=temb_channels,  # 設定時間嵌入通道數
            resolution_idx=resolution_idx,  # 設定解析度索引
            dropout=dropout,  # 設定 dropout 機率
            resnet_eps=resnet_eps,  # 設定 ResNet 的 epsilon
            resnet_act_fn=resnet_act_fn,  # 設定 ResNet 的啟用函式
            resnet_groups=resnet_groups,  # 設定 ResNet 的組數
            attention_head_dim=attention_head_dim,  # 設定注意力頭維度
            resnet_time_scale_shift=resnet_time_scale_shift,  # 設定 ResNet 的時間縮放偏移
            upsample_type=upsample_type,  # 設定上取樣型別
        )
    # 檢查上取樣塊的型別是否為 SkipUpBlock2D
    elif up_block_type == "SkipUpBlock2D":
        # 返回 SkipUpBlock2D 例項,使用相關引數初始化
        return SkipUpBlock2D(
            num_layers=num_layers,  # 設定層數
            in_channels=in_channels,  # 設定輸入通道數
            out_channels=out_channels,  # 設定輸出通道數
            prev_output_channel=prev_output_channel,  # 設定前一層的輸出通道數
            temb_channels=temb_channels,  # 設定時間嵌入通道數
            resolution_idx=resolution_idx,  # 設定解析度索引
            dropout=dropout,  # 設定 dropout 機率
            add_upsample=add_upsample,  # 設定是否新增上取樣
            resnet_eps=resnet_eps,  # 設定 ResNet 的 epsilon
            resnet_act_fn=resnet_act_fn,  # 設定 ResNet 的啟用函式
            resnet_time_scale_shift=resnet_time_scale_shift,  # 設定 ResNet 的時間縮放偏移
        )
    # 檢查上取樣塊的型別是否為 AttnSkipUpBlock2D
    elif up_block_type == "AttnSkipUpBlock2D":
        # 返回 AttnSkipUpBlock2D 例項,使用相關引數初始化
        return AttnSkipUpBlock2D(
            num_layers=num_layers,  # 設定層數
            in_channels=in_channels,  # 設定輸入通道數
            out_channels=out_channels,  # 設定輸出通道數
            prev_output_channel=prev_output_channel,  # 設定前一層的輸出通道數
            temb_channels=temb_channels,  # 設定時間嵌入通道數
            resolution_idx=resolution_idx,  # 設定解析度索引
            dropout=dropout,  # 設定 dropout 機率
            add_upsample=add_upsample,  # 設定是否新增上取樣
            resnet_eps=resnet_eps,  # 設定 ResNet 的 epsilon
            resnet_act_fn=resnet_act_fn,  # 設定 ResNet 的啟用函式
            attention_head_dim=attention_head_dim,  # 設定注意力頭維度
            resnet_time_scale_shift=resnet_time_scale_shift,  # 設定 ResNet 的時間縮放偏移
        )
    # 檢查上取樣塊型別是否為 UpDecoderBlock2D
    elif up_block_type == "UpDecoderBlock2D":
        # 返回 UpDecoderBlock2D 的例項,傳入相應引數
        return UpDecoderBlock2D(
            num_layers=num_layers,  # 設定層數
            in_channels=in_channels,  # 設定輸入通道數
            out_channels=out_channels,  # 設定輸出通道數
            resolution_idx=resolution_idx,  # 設定解析度索引
            dropout=dropout,  # 設定 dropout 比例
            add_upsample=add_upsample,  # 設定是否新增上取樣
            resnet_eps=resnet_eps,  # 設定 ResNet 的 epsilon 值
            resnet_act_fn=resnet_act_fn,  # 設定 ResNet 的啟用函式
            resnet_groups=resnet_groups,  # 設定 ResNet 的組數
            resnet_time_scale_shift=resnet_time_scale_shift,  # 設定時間尺度偏移
            temb_channels=temb_channels,  # 設定時間嵌入通道數
        )
    # 檢查上取樣塊型別是否為 AttnUpDecoderBlock2D
    elif up_block_type == "AttnUpDecoderBlock2D":
        # 返回 AttnUpDecoderBlock2D 的例項,傳入相應引數
        return AttnUpDecoderBlock2D(
            num_layers=num_layers,  # 設定層數
            in_channels=in_channels,  # 設定輸入通道數
            out_channels=out_channels,  # 設定輸出通道數
            resolution_idx=resolution_idx,  # 設定解析度索引
            dropout=dropout,  # 設定 dropout 比例
            add_upsample=add_upsample,  # 設定是否新增上取樣
            resnet_eps=resnet_eps,  # 設定 ResNet 的 epsilon 值
            resnet_act_fn=resnet_act_fn,  # 設定 ResNet 的啟用函式
            resnet_groups=resnet_groups,  # 設定 ResNet 的組數
            attention_head_dim=attention_head_dim,  # 設定注意力頭維度
            resnet_time_scale_shift=resnet_time_scale_shift,  # 設定時間尺度偏移
            temb_channels=temb_channels,  # 設定時間嵌入通道數
        )
    # 檢查上取樣塊型別是否為 KUpBlock2D
    elif up_block_type == "KUpBlock2D":
        # 返回 KUpBlock2D 的例項,傳入相應引數
        return KUpBlock2D(
            num_layers=num_layers,  # 設定層數
            in_channels=in_channels,  # 設定輸入通道數
            out_channels=out_channels,  # 設定輸出通道數
            temb_channels=temb_channels,  # 設定時間嵌入通道數
            resolution_idx=resolution_idx,  # 設定解析度索引
            dropout=dropout,  # 設定 dropout 比例
            add_upsample=add_upsample,  # 設定是否新增上取樣
            resnet_eps=resnet_eps,  # 設定 ResNet 的 epsilon 值
            resnet_act_fn=resnet_act_fn,  # 設定 ResNet 的啟用函式
        )
    # 檢查上取樣塊型別是否為 KCrossAttnUpBlock2D
    elif up_block_type == "KCrossAttnUpBlock2D":
        # 返回 KCrossAttnUpBlock2D 的例項,傳入相應引數
        return KCrossAttnUpBlock2D(
            num_layers=num_layers,  # 設定層數
            in_channels=in_channels,  # 設定輸入通道數
            out_channels=out_channels,  # 設定輸出通道數
            temb_channels=temb_channels,  # 設定時間嵌入通道數
            resolution_idx=resolution_idx,  # 設定解析度索引
            dropout=dropout,  # 設定 dropout 比例
            add_upsample=add_upsample,  # 設定是否新增上取樣
            resnet_eps=resnet_eps,  # 設定 ResNet 的 epsilon 值
            resnet_act_fn=resnet_act_fn,  # 設定 ResNet 的啟用函式
            cross_attention_dim=cross_attention_dim,  # 設定交叉注意力維度
            attention_head_dim=attention_head_dim,  # 設定注意力頭維度
        )

    # 如果未匹配到任何上取樣塊型別,丟擲值錯誤
    raise ValueError(f"{up_block_type} does not exist.")
# 定義一個小型自編碼器塊,繼承自 nn.Module
class AutoencoderTinyBlock(nn.Module):
    """
    Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
    blocks.

    Args:
        in_channels (`int`): The number of input channels.
        out_channels (`int`): The number of output channels.
        act_fn (`str`):
            ` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.

    Returns:
        `torch.Tensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
        `out_channels`.
    """

    # 初始化函式,接受輸入通道數、輸出通道數和啟用函式型別
    def __init__(self, in_channels: int, out_channels: int, act_fn: str):
        # 呼叫父類初始化
        super().__init__()
        # 獲取指定的啟用函式
        act_fn = get_activation(act_fn)
        # 定義一個序列,包括多個卷積層和啟用函式
        self.conv = nn.Sequential(
            # 第一層卷積,輸入通道數、輸出通道數、卷積核大小和填充方式
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            # 新增啟用函式
            act_fn,
            # 第二層卷積,保持輸出通道數
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            # 新增啟用函式
            act_fn,
            # 第三層卷積,保持輸出通道數
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )
        # 判斷輸入和輸出通道是否相同,決定使用卷積或身份對映
        self.skip = (
            # 如果通道數不一致,使用 1x1 卷積進行跳躍連線
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
            if in_channels != out_channels
            else nn.Identity()
        )
        # 使用 ReLU 進行特徵融合
        self.fuse = nn.ReLU()

    # 定義前向傳播函式,接受輸入張量 x
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 返回卷積輸出和跳躍連線的和,經過融合啟用函式處理
        return self.fuse(self.conv(x) + self.skip(x))


# 定義一個 2D UNet 中間塊,繼承自 nn.Module
class UNetMidBlock2D(nn.Module):
    """
    A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
    # 引數說明
    Args:
        in_channels (`int`): 輸入通道的數量。
        temb_channels (`int`): 時間嵌入通道的數量。
        dropout (`float`, *optional*, defaults to 0.0): dropout 比率,用於防止過擬合。
        num_layers (`int`, *optional*, defaults to 1): 殘差塊的數量。
        resnet_eps (`float`, *optional*, 1e-6 ): 殘差塊的 epsilon 值,用於數值穩定性。
        resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
            應用於時間嵌入的歸一化型別。可以改善模型在長時間依賴任務上的表現。
        resnet_act_fn (`str`, *optional*, defaults to `swish`): 殘差塊的啟用函式型別。
        resnet_groups (`int`, *optional*, defaults to 32):
            殘差塊中組歸一化層使用的組數量。
        attn_groups (`Optional[int]`, *optional*, defaults to None): 注意力塊的組數量。
        resnet_pre_norm (`bool`, *optional*, defaults to `True`):
            是否在殘差塊中使用預歸一化。
        add_attention (`bool`, *optional*, defaults to `True`): 是否新增註意力塊。
        attention_head_dim (`int`, *optional*, defaults to 1):
            單個注意力頭的維度。注意力頭的數量由該值和輸入通道的數量決定。
        output_scale_factor (`float`, *optional*, defaults to 1.0): 輸出的縮放因子。

    # 返回值說明
    Returns:
        `torch.Tensor`: 最後一個殘差塊的輸出,形狀為 `(batch_size, in_channels,
        height, width)`。

    """

    # 初始化函式,設定模型的各種引數
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        temb_channels: int,  # 時間嵌入通道數
        dropout: float = 0.0,  # dropout比率,預設為0.0
        num_layers: int = 1,  # 殘差塊數量,預設為1
        resnet_eps: float = 1e-6,  # 殘差塊的epsilon值
        resnet_time_scale_shift: str = "default",  # 時間尺度歸一化的型別
        resnet_act_fn: str = "swish",  # 殘差塊的啟用函式
        resnet_groups: int = 32,  # 殘差塊的組數量
        attn_groups: Optional[int] = None,  # 注意力塊的組數量
        resnet_pre_norm: bool = True,  # 是否使用預歸一化
        add_attention: bool = True,  # 是否新增註意力塊
        attention_head_dim: int = 1,  # 注意力頭的維度
        output_scale_factor: float = 1.0,  # 輸出縮放因子
    # 前向傳播函式,接受隱藏狀態和時間嵌入作為輸入
    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        # 將輸入的隱藏狀態透過第一個殘差塊進行處理
        hidden_states = self.resnets[0](hidden_states, temb)
        # 遍歷剩餘的注意力塊和殘差塊進行處理
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            # 如果當前存在注意力塊,則進行處理
            if attn is not None:
                hidden_states = attn(hidden_states, temb=temb)
            # 將處理後的隱藏狀態透過當前的殘差塊進行處理
            hidden_states = resnet(hidden_states, temb)

        # 返回最終的隱藏狀態
        return hidden_states
# 定義一個繼承自 nn.Module 的類 UNetMidBlock2DCrossAttn
class UNetMidBlock2DCrossAttn(nn.Module):
    # 初始化方法,定義各個引數
    def __init__(
        # 輸入通道數
        in_channels: int,
        # 時間嵌入通道數
        temb_channels: int,
        # 輸出通道數(可選,預設為 None)
        out_channels: Optional[int] = None,
        # dropout 機率(預設為 0.0)
        dropout: float = 0.0,
        # 層數(預設為 1)
        num_layers: int = 1,
        # 每個塊的變換器層數(預設為 1)
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        # ResNet 的 epsilon 值(預設為 1e-6)
        resnet_eps: float = 1e-6,
        # ResNet 的時間縮放偏移方式(預設為 "default")
        resnet_time_scale_shift: str = "default",
        # ResNet 的啟用函式型別(預設為 "swish")
        resnet_act_fn: str = "swish",
        # ResNet 的組數(預設為 32)
        resnet_groups: int = 32,
        # 輸出的 ResNet 組數(可選,預設為 None)
        resnet_groups_out: Optional[int] = None,
        # 是否進行 ResNet 預歸一化(預設為 True)
        resnet_pre_norm: bool = True,
        # 注意力頭的數量(預設為 1)
        num_attention_heads: int = 1,
        # 輸出縮放因子(預設為 1.0)
        output_scale_factor: float = 1.0,
        # 交叉注意力維度(預設為 1280)
        cross_attention_dim: int = 1280,
        # 是否使用雙重交叉注意力(預設為 False)
        dual_cross_attention: bool = False,
        # 是否使用線性投影(預設為 False)
        use_linear_projection: bool = False,
        # 是否提升注意力計算精度(預設為 False)
        upcast_attention: bool = False,
        # 注意力型別(預設為 "default")
        attention_type: str = "default",
    # 前向傳播方法
    def forward(
        # 隱藏狀態的張量
        hidden_states: torch.Tensor,
        # 時間嵌入的張量(可選,預設為 None)
        temb: Optional[torch.Tensor] = None,
        # 編碼器隱藏狀態的張量(可選,預設為 None)
        encoder_hidden_states: Optional[torch.Tensor] = None,
        # 注意力掩碼(可選,預設為 None)
        attention_mask: Optional[torch.Tensor] = None,
        # 交叉注意力引數(可選,預設為 None)
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        # 編碼器注意力掩碼(可選,預設為 None)
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:  # 定義返回型別為 torch.Tensor
        if cross_attention_kwargs is not None:  # 檢查交叉注意力引數是否存在
            if cross_attention_kwargs.get("scale", None) is not None:  # 檢查是否有 scale 引數
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")  # 記錄警告,提示 scale 引數已棄用

        hidden_states = self.resnets[0](hidden_states, temb)  # 使用第一個殘差網路處理隱藏狀態和時間嵌入
        for attn, resnet in zip(self.attentions, self.resnets[1:]):  # 遍歷注意力層和殘差網路(跳過第一個)
            if self.training and self.gradient_checkpointing:  # 如果處於訓練模式並且啟用了梯度檢查點

                def create_custom_forward(module, return_dict=None):  # 定義自定義前向傳播函式
                    def custom_forward(*inputs):  # 定義實際的前向傳播實現
                        if return_dict is not None:  # 如果返回字典不為 None
                            return module(*inputs, return_dict=return_dict)  # 呼叫模組並返回字典
                        else:  # 如果返回字典為 None
                            return module(*inputs)  # 直接呼叫模組並返回結果

                    return custom_forward  # 返回自定義前向傳播函式

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}  # 設定檢查點的關鍵字引數
                hidden_states = attn(  # 呼叫注意力層處理隱藏狀態
                    hidden_states,  # 輸入隱藏狀態
                    encoder_hidden_states=encoder_hidden_states,  # 編碼器隱藏狀態
                    cross_attention_kwargs=cross_attention_kwargs,  # 交叉注意力引數
                    attention_mask=attention_mask,  # 注意力掩碼
                    encoder_attention_mask=encoder_attention_mask,  # 編碼器注意力掩碼
                    return_dict=False,  # 不返回字典格式
                )[0]  # 取處理結果的第一個元素
                hidden_states = torch.utils.checkpoint.checkpoint(  # 使用梯度檢查點
                    create_custom_forward(resnet),  # 建立殘差網路的自定義前向函式
                    hidden_states,  # 輸入隱藏狀態
                    temb,  # 輸入時間嵌入
                    **ckpt_kwargs,  # 解包關鍵字引數
                )
            else:  # 如果不處於訓練模式或不啟用梯度檢查點
                hidden_states = attn(  # 呼叫注意力層處理隱藏狀態
                    hidden_states,  # 輸入隱藏狀態
                    encoder_hidden_states=encoder_hidden_states,  # 編碼器隱藏狀態
                    cross_attention_kwargs=cross_attention_kwargs,  # 交叉注意力引數
                    attention_mask=attention_mask,  # 注意力掩碼
                    encoder_attention_mask=encoder_attention_mask,  # 編碼器注意力掩碼
                    return_dict=False,  # 不返回字典格式
                )[0]  # 取處理結果的第一個元素
                hidden_states = resnet(hidden_states, temb)  # 使用殘差網路處理隱藏狀態和時間嵌入

        return hidden_states  # 返回最終的隱藏狀態
# 定義一個 UNet 中間塊類,繼承自 nn.Module
class UNetMidBlock2DSimpleCrossAttn(nn.Module):
    # 初始化方法,設定類的引數
    def __init__(
        # 輸入通道數
        in_channels: int,
        # 時間嵌入通道數
        temb_channels: int,
        # Dropout 比例,預設為 0.0
        dropout: float = 0.0,
        # 層數,預設為 1
        num_layers: int = 1,
        # ResNet 的小 epsilon 值,預設為 1e-6
        resnet_eps: float = 1e-6,
        # ResNet 時間尺度偏移,預設為 "default"
        resnet_time_scale_shift: str = "default",
        # ResNet 啟用函式型別,預設為 "swish"
        resnet_act_fn: str = "swish",
        # ResNet 組數,預設為 32
        resnet_groups: int = 32,
        # 是否使用預歸一化,預設為 True
        resnet_pre_norm: bool = True,
        # 注意力頭維度,預設為 1
        attention_head_dim: int = 1,
        # 輸出縮放因子,預設為 1.0
        output_scale_factor: float = 1.0,
        # 交叉注意力維度,預設為 1280
        cross_attention_dim: int = 1280,
        # 是否跳過時間啟用,預設為 False
        skip_time_act: bool = False,
        # 是否僅使用交叉注意力,預設為 False
        only_cross_attention: bool = False,
        # 交叉注意力的歸一化方法,可選引數
        cross_attention_norm: Optional[str] = None,
    ):
        # 呼叫父類的初始化方法
        super().__init__()

        # 設定是否使用交叉注意力
        self.has_cross_attention = True

        # 設定注意力頭的維度
        self.attention_head_dim = attention_head_dim
        
        # 計算 ResNet 的組數,如果未提供,則取輸入通道數的四分之一和 32 的最小值
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # 計算注意力頭的數量
        self.num_heads = in_channels // self.attention_head_dim

        # 至少存在一個 ResNet 塊
        resnets = [
            # 建立一個 ResNet 塊
            ResnetBlock2D(
                # 輸入通道數
                in_channels=in_channels,
                # 輸出通道數
                out_channels=in_channels,
                # 時間嵌入通道數
                temb_channels=temb_channels,
                # 正則化引數
                eps=resnet_eps,
                # ResNet 組數
                groups=resnet_groups,
                # dropout 機率
                dropout=dropout,
                # 時間嵌入歸一化方法
                time_embedding_norm=resnet_time_scale_shift,
                # 非線性啟用函式
                non_linearity=resnet_act_fn,
                # 輸出縮放因子
                output_scale_factor=output_scale_factor,
                # 是否進行預歸一化
                pre_norm=resnet_pre_norm,
                # 是否跳過時間啟用
                skip_time_act=skip_time_act,
            )
        ]
        # 初始化注意力層列表
        attentions = []

        # 迴圈建立指定數量的層
        for _ in range(num_layers):
            # 根據是否具有縮放點積注意力,選擇處理器
            processor = (
                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
            )

            # 新增註意力層到列表中
            attentions.append(
                Attention(
                    # 查詢的維度
                    query_dim=in_channels,
                    # 交叉注意力的維度
                    cross_attention_dim=in_channels,
                    # 注意力頭的數量
                    heads=self.num_heads,
                    # 每個頭的維度
                    dim_head=self.attention_head_dim,
                    # 額外的 KV 投影維度
                    added_kv_proj_dim=cross_attention_dim,
                    # 歸一化的組數
                    norm_num_groups=resnet_groups,
                    # 是否使用偏置
                    bias=True,
                    # 是否上cast softmax
                    upcast_softmax=True,
                    # 是否僅使用交叉注意力
                    only_cross_attention=only_cross_attention,
                    # 交叉注意力的歸一化方法
                    cross_attention_norm=cross_attention_norm,
                    # 設定處理器
                    processor=processor,
                )
            )
            # 新增另一個 ResNet 塊到列表中
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    skip_time_act=skip_time_act,
                )
            )

        # 將注意力層轉為可訓練模組列表
        self.attentions = nn.ModuleList(attentions)
        # 將 ResNet 塊轉為可訓練模組列表
        self.resnets = nn.ModuleList(resnets)

    def forward(
        # 前向傳播函式的定義
        self,
        # 輸入的隱狀態
        hidden_states: torch.Tensor,
        # 可選的時間嵌入
        temb: Optional[torch.Tensor] = None,
        # 可選的編碼器隱狀態
        encoder_hidden_states: Optional[torch.Tensor] = None,
        # 可選的注意力掩碼
        attention_mask: Optional[torch.Tensor] = None,
        # 可選的交叉注意力引數
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        # 可選的編碼器注意力掩碼
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # 如果 cross_attention_kwargs 為 None,則初始化為空字典
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        # 檢查 cross_attention_kwargs 中是否有 "scale" 引數,如果有則記錄警告資訊
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # 如果 attention_mask 為 None
        if attention_mask is None:
            # 如果 encoder_hidden_states 已定義,表示正在進行交叉注意力,因此使用交叉注意力掩碼
            mask = None if encoder_hidden_states is None else encoder_attention_mask
        else:
            # 當 attention_mask 已定義時,不檢查 encoder_attention_mask
            # 這是為了與 UnCLIP 相容,UnCLIP 使用 'attention_mask' 引數作為交叉注意力掩碼
            # TODO: UnCLIP 應該透過 encoder_attention_mask 引數而不是 attention_mask 引數來表達交叉注意力掩碼
            #       然後可以簡化整個 if/else 塊為:
            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
            mask = attention_mask

        # 透過第一個殘差網路處理隱藏狀態
        hidden_states = self.resnets[0](hidden_states, temb)
        # 遍歷注意力層和後續的殘差網路
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            # 使用當前的注意力層處理隱藏狀態
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,  # 傳遞編碼器的隱藏狀態
                attention_mask=mask,  # 傳遞注意力掩碼
                **cross_attention_kwargs,  # 解包交叉注意力引數
            )

            # 透過當前的殘差網路處理隱藏狀態
            hidden_states = resnet(hidden_states, temb)

        # 返回處理後的隱藏狀態
        return hidden_states
# 定義一個名為 AttnDownBlock2D 的類,繼承自 nn.Module
class AttnDownBlock2D(nn.Module):
    # 初始化方法,接受多個引數以設定層的屬性
    def __init__(
        # 輸入通道數
        in_channels: int,
        # 輸出通道數
        out_channels: int,
        # 時間嵌入通道數
        temb_channels: int,
        # dropout 機率,預設為 0.0
        dropout: float = 0.0,
        # 層數,預設為 1
        num_layers: int = 1,
        # ResNet 的 epsilon 值,防止除零錯誤,預設為 1e-6
        resnet_eps: float = 1e-6,
        # ResNet 時間尺度偏移,預設為 "default"
        resnet_time_scale_shift: str = "default",
        # ResNet 啟用函式,預設為 "swish"
        resnet_act_fn: str = "swish",
        # ResNet 組數,預設為 32
        resnet_groups: int = 32,
        # 是否在 ResNet 中使用預歸一化,預設為 True
        resnet_pre_norm: bool = True,
        # 注意力頭的維度,預設為 1
        attention_head_dim: int = 1,
        # 輸出縮放因子,預設為 1.0
        output_scale_factor: float = 1.0,
        # 下采樣的填充大小,預設為 1
        downsample_padding: int = 1,
        # 下采樣型別,預設為 "conv"
        downsample_type: str = "conv",
    ):
        # 呼叫父類建構函式初始化
        super().__init__()
        # 初始化空列表,用於儲存 ResNet 塊
        resnets = []
        # 初始化空列表,用於儲存注意力機制模組
        attentions = []
        # 儲存下采樣型別
        self.downsample_type = downsample_type

        # 如果未指定注意力頭的維度,則發出警告,並預設使用輸出通道數
        if attention_head_dim is None:
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
            )
            # 將注意力頭維度設定為輸出通道數
            attention_head_dim = out_channels

        # 遍歷層數以構建 ResNet 塊和注意力模組
        for i in range(num_layers):
            # 確定輸入通道數,第一層使用初始通道數,其餘層使用輸出通道數
            in_channels = in_channels if i == 0 else out_channels
            # 建立並新增 ResNet 塊到 resnets 列表
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,  # 輸入通道數
                    out_channels=out_channels,  # 輸出通道數
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    eps=resnet_eps,  # 防止除零的epsilon值
                    groups=resnet_groups,  # 分組數
                    dropout=dropout,  # dropout 比率
                    time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入的歸一化方式
                    non_linearity=resnet_act_fn,  # 非線性啟用函式
                    output_scale_factor=output_scale_factor,  # 輸出縮放因子
                    pre_norm=resnet_pre_norm,  # 是否在前面進行歸一化
                )
            )
            # 建立並新增註意力模組到 attentions 列表
            attentions.append(
                Attention(
                    out_channels,  # 輸出通道數
                    heads=out_channels // attention_head_dim,  # 注意力頭的數量
                    dim_head=attention_head_dim,  # 每個注意力頭的維度
                    rescale_output_factor=output_scale_factor,  # 輸出縮放因子
                    eps=resnet_eps,  # 防止除零的epsilon值
                    norm_num_groups=resnet_groups,  # 歸一化的組數
                    residual_connection=True,  # 是否使用殘差連線
                    bias=True,  # 是否使用偏置
                    upcast_softmax=True,  # 是否上溯 softmax
                    _from_deprecated_attn_block=True,  # 是否來自於已棄用的注意力塊
                )
            )

        # 將注意力模組列表轉換為 nn.ModuleList,便於管理
        self.attentions = nn.ModuleList(attentions)
        # 將 ResNet 塊列表轉換為 nn.ModuleList,便於管理
        self.resnets = nn.ModuleList(resnets)

        # 根據下采樣型別選擇相應的下采樣方法
        if downsample_type == "conv":
            # 建立卷積下采樣模組並儲存
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,  # 輸出通道數
                        use_conv=True,  # 是否使用卷積
                        out_channels=out_channels,  # 輸出通道數
                        padding=downsample_padding,  # 填充
                        name="op"  # 模組名稱
                    )
                ]
            )
        elif downsample_type == "resnet":
            # 建立 ResNet 下采樣塊並儲存
            self.downsamplers = nn.ModuleList(
                [
                    ResnetBlock2D(
                        in_channels=out_channels,  # 輸入通道數
                        out_channels=out_channels,  # 輸出通道數
                        temb_channels=temb_channels,  # 時間嵌入通道數
                        eps=resnet_eps,  # 防止除零的epsilon值
                        groups=resnet_groups,  # 分組數
                        dropout=dropout,  # dropout 比率
                        time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入的歸一化方式
                        non_linearity=resnet_act_fn,  # 非線性啟用函式
                        output_scale_factor=output_scale_factor,  # 輸出縮放因子
                        pre_norm=resnet_pre_norm,  # 是否在前面進行歸一化
                        down=True,  # 指示為下采樣
                    )
                ]
            )
        else:
            # 如果沒有匹配的下采樣型別,則將下采樣模組設定為 None
            self.downsamplers = None
    # 前向傳播方法,處理隱狀態和可選的其他引數,返回處理後的隱狀態和輸出狀態
    def forward(
            self,
            hidden_states: torch.Tensor,  # 輸入的隱狀態張量
            temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
            upsample_size: Optional[int] = None,  # 可選的上取樣尺寸
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 可選的交叉注意力引數字典
        ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:  # 返回隱狀態和輸出狀態的元組
            # 如果沒有提供交叉注意力引數,則初始化為空字典
            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
            # 檢查是否傳入了 scale 引數,如果有,則發出警告,因為這個引數已被棄用
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
    
            output_states = ()  # 初始化輸出狀態為一個空元組
    
            # 遍歷自定義的殘差網路和注意力層
            for resnet, attn in zip(self.resnets, self.attentions):
                # 將隱狀態傳遞給殘差網路並更新隱狀態
                hidden_states = resnet(hidden_states, temb)
                # 將隱狀態傳遞給注意力層,並更新隱狀態
                hidden_states = attn(hidden_states, **cross_attention_kwargs)
                # 將當前隱狀態新增到輸出狀態元組中
                output_states = output_states + (hidden_states,)
    
            # 檢查是否存在下采樣層
            if self.downsamplers is not None:
                # 遍歷每個下采樣層
                for downsampler in self.downsamplers:
                    # 根據下采樣型別選擇不同的處理方式
                    if self.downsample_type == "resnet":
                        hidden_states = downsampler(hidden_states, temb=temb)  # 使用時間嵌入處理
                    else:
                        hidden_states = downsampler(hidden_states)  # 不使用時間嵌入處理
    
                # 將最後的隱狀態新增到輸出狀態中
                output_states += (hidden_states,)
    
            # 返回處理後的隱狀態和輸出狀態
            return hidden_states, output_states
# 定義一個名為 CrossAttnDownBlock2D 的類,繼承自 nn.Module
class CrossAttnDownBlock2D(nn.Module):
    # 初始化方法,接收多個引數以配置模組
    def __init__(
        # 輸入通道數
        self,
        in_channels: int,
        # 輸出通道數
        out_channels: int,
        # 時間嵌入通道數
        temb_channels: int,
        # dropout 比例,預設為 0.0
        dropout: float = 0.0,
        # 層數,預設為 1
        num_layers: int = 1,
        # 每個塊的變換器層數,可以是單個整數或整數元組
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        # ResNet 的 epsilon 值,預設為 1e-6
        resnet_eps: float = 1e-6,
        # ResNet 時間縮放偏移的型別,預設為 "default"
        resnet_time_scale_shift: str = "default",
        # ResNet 的啟用函式,預設為 "swish"
        resnet_act_fn: str = "swish",
        # ResNet 的分組數,預設為 32
        resnet_groups: int = 32,
        # 是否使用預歸一化,預設為 True
        resnet_pre_norm: bool = True,
        # 注意力頭的數量,預設為 1
        num_attention_heads: int = 1,
        # 交叉注意力維度,預設為 1280
        cross_attention_dim: int = 1280,
        # 輸出縮放因子,預設為 1.0
        output_scale_factor: float = 1.0,
        # 下采樣填充的大小,預設為 1
        downsample_padding: int = 1,
        # 是否新增下采樣,預設為 True
        add_downsample: bool = True,
        # 是否使用雙重交叉注意力,預設為 False
        dual_cross_attention: bool = False,
        # 是否使用線性投影,預設為 False
        use_linear_projection: bool = False,
        # 是否僅使用交叉注意力,預設為 False
        only_cross_attention: bool = False,
        # 是否提升注意力精度,預設為 False
        upcast_attention: bool = False,
        # 注意力型別,預設為 "default"
        attention_type: str = "default",
    # 初始化父類
        ):
            super().__init__()
            # 初始化殘差塊列表
            resnets = []
            # 初始化注意力機制列表
            attentions = []
    
            # 設定是否使用交叉注意力
            self.has_cross_attention = True
            # 設定注意力頭的數量
            self.num_attention_heads = num_attention_heads
            # 如果每個塊的變換層是整數,則擴充套件為列表
            if isinstance(transformer_layers_per_block, int):
                transformer_layers_per_block = [transformer_layers_per_block] * num_layers
    
            # 遍歷每一層
            for i in range(num_layers):
                # 確定輸入通道數
                in_channels = in_channels if i == 0 else out_channels
                # 新增殘差塊到列表
                resnets.append(
                    ResnetBlock2D(
                        in_channels=in_channels,  # 輸入通道數
                        out_channels=out_channels,  # 輸出通道數
                        temb_channels=temb_channels,  # 時間嵌入通道數
                        eps=resnet_eps,  #  epsilon值
                        groups=resnet_groups,  # 組數
                        dropout=dropout,  # dropout比率
                        time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入歸一化
                        non_linearity=resnet_act_fn,  # 非線性啟用函式
                        output_scale_factor=output_scale_factor,  # 輸出縮放因子
                        pre_norm=resnet_pre_norm,  # 是否使用預歸一化
                    )
                )
                # 檢查是否使用雙重交叉注意力
                if not dual_cross_attention:
                    # 新增普通的變換模型到列表
                    attentions.append(
                        Transformer2DModel(
                            num_attention_heads,  # 注意力頭數量
                            out_channels // num_attention_heads,  # 每個頭的輸出通道數
                            in_channels=out_channels,  # 輸入通道數
                            num_layers=transformer_layers_per_block[i],  # 層數
                            cross_attention_dim=cross_attention_dim,  # 交叉注意力維度
                            norm_num_groups=resnet_groups,  # 歸一化組數
                            use_linear_projection=use_linear_projection,  # 是否使用線性投影
                            only_cross_attention=only_cross_attention,  # 是否僅使用交叉注意力
                            upcast_attention=upcast_attention,  # 是否向上投射注意力
                            attention_type=attention_type,  # 注意力型別
                        )
                    )
                else:
                    # 新增雙重變換模型到列表
                    attentions.append(
                        DualTransformer2DModel(
                            num_attention_heads,  # 注意力頭數量
                            out_channels // num_attention_heads,  # 每個頭的輸出通道數
                            in_channels=out_channels,  # 輸入通道數
                            num_layers=1,  # 層數固定為1
                            cross_attention_dim=cross_attention_dim,  # 交叉注意力維度
                            norm_num_groups=resnet_groups,  # 歸一化組數
                        )
                    )
            # 將注意力模型列表轉換為nn.ModuleList
            self.attentions = nn.ModuleList(attentions)
            # 將殘差塊列表轉換為nn.ModuleList
            self.resnets = nn.ModuleList(resnets)
    
            # 檢查是否新增下采樣層
            if add_downsample:
                # 建立下采樣層列表
                self.downsamplers = nn.ModuleList(
                    [
                        Downsample2D(
                            out_channels,  # 輸出通道數
                            use_conv=True,  # 是否使用卷積
                            out_channels=out_channels,  # 輸出通道數
                            padding=downsample_padding,  # 填充
                            name="op"  # 操作名稱
                        )
                    ]
                )
            else:
                # 如果不新增下采樣層,則設定為None
                self.downsamplers = None
    
            # 設定梯度檢查點開關為關閉
            self.gradient_checkpointing = False
    # 定義前向傳播函式,接收多個引數
        def forward(
            self,
            # 隱藏狀態張量,表示模型的內部狀態
            hidden_states: torch.Tensor,
            # 可選的時間嵌入張量,用於控制生成的時間步
            temb: Optional[torch.Tensor] = None,
            # 可選的編碼器隱藏狀態張量,表示編碼器的輸出
            encoder_hidden_states: Optional[torch.Tensor] = None,
            # 可選的注意力掩碼,用於遮蔽輸入中不需要關注的部分
            attention_mask: Optional[torch.Tensor] = None,
            # 可選的交叉注意力引數字典,用於傳遞其他配置
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 可選的編碼器注意力掩碼,控制編碼器的注意力機制
            encoder_attention_mask: Optional[torch.Tensor] = None,
            # 可選的附加殘差張量,作為額外的資訊傳遞
            additional_residuals: Optional[torch.Tensor] = None,
    # 返回型別為元組,包含張量和一個張量元組
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        # 檢查交叉注意力引數是否不為空
        if cross_attention_kwargs is not None:
            # 如果"scale"引數存在,則發出警告,說明其已被棄用
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
    
        # 初始化輸出狀態為空元組
        output_states = ()
    
        # 將殘差網路和注意力層配對
        blocks = list(zip(self.resnets, self.attentions))
    
        # 遍歷每一對殘差網路和注意力層
        for i, (resnet, attn) in enumerate(blocks):
            # 如果在訓練中並且啟用了梯度檢查點
            if self.training and self.gradient_checkpointing:
    
                # 定義建立自定義前向傳播的函式
                def create_custom_forward(module, return_dict=None):
                    # 定義自定義前向傳播邏輯
                    def custom_forward(*inputs):
                        # 根據是否返回字典呼叫模組
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)
    
                    return custom_forward
    
                # 根據 PyTorch 版本設定檢查點引數
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                # 使用檢查點機制計算隱藏狀態
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
                # 透過注意力層處理隱藏狀態
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            else:
                # 直接透過殘差網路處理隱藏狀態
                hidden_states = resnet(hidden_states, temb)
                # 透過注意力層處理隱藏狀態
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
    
            # 如果是最後一對塊並且有額外的殘差
            if i == len(blocks) - 1 and additional_residuals is not None:
                # 將額外的殘差新增到隱藏狀態
                hidden_states = hidden_states + additional_residuals
    
            # 更新輸出狀態,新增當前隱藏狀態
            output_states = output_states + (hidden_states,)
    
        # 如果下采樣器不為空
        if self.downsamplers is not None:
            # 遍歷每個下采樣器
            for downsampler in self.downsamplers:
                # 使用下采樣器處理隱藏狀態
                hidden_states = downsampler(hidden_states)
    
            # 更新輸出狀態,新增當前隱藏狀態
            output_states = output_states + (hidden_states,)
    
        # 返回最終的隱藏狀態和輸出狀態
        return hidden_states, output_states
# 定義一個二維向下塊,繼承自 nn.Module
class DownBlock2D(nn.Module):
    # 初始化方法,定義各個引數及其預設值
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        out_channels: int,  # 輸出通道數
        temb_channels: int,  # 時間嵌入通道數
        dropout: float = 0.0,  # dropout 機率
        num_layers: int = 1,  # ResNet 層數
        resnet_eps: float = 1e-6,  # ResNet 中的 epsilon
        resnet_time_scale_shift: str = "default",  # 時間縮放偏移設定
        resnet_act_fn: str = "swish",  # ResNet 的啟用函式
        resnet_groups: int = 32,  # ResNet 的分組數
        resnet_pre_norm: bool = True,  # 是否使用預歸一化
        output_scale_factor: float = 1.0,  # 輸出縮放因子
        add_downsample: bool = True,  # 是否新增下采樣
        downsample_padding: int = 1,  # 下采樣填充
    ):
        # 呼叫父類建構函式
        super().__init__()
        # 初始化空的 ResNet 塊列表
        resnets = []

        # 根據層數迴圈建立 ResNet 塊
        for i in range(num_layers):
            # 確定當前層的輸入通道數
            in_channels = in_channels if i == 0 else out_channels
            # 新增 ResNet 塊到列表
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,  # 當前層的輸入通道數
                    out_channels=out_channels,  # 當前層的輸出通道數
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    eps=resnet_eps,  # epsilon 引數
                    groups=resnet_groups,  # 分組數
                    dropout=dropout,  # dropout 機率
                    time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入歸一化
                    non_linearity=resnet_act_fn,  # 啟用函式
                    output_scale_factor=output_scale_factor,  # 輸出縮放因子
                    pre_norm=resnet_pre_norm,  # 預歸一化標誌
                )
            )

        # 將 ResNet 塊列表轉換為 ModuleList
        self.resnets = nn.ModuleList(resnets)

        # 根據標誌決定是否新增下采樣層
        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(  # 建立下采樣層
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            # 如果不新增下采樣,則將其設定為 None
            self.downsamplers = None

        # 初始化梯度檢查點標誌為 False
        self.gradient_checkpointing = False

    # 定義前向傳播方法
    def forward(
        self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        # 檢查位置引數是否大於0,或關鍵字引數中的 "scale" 是否不為 None
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 定義棄用訊息,提示使用者 "scale" 引數已棄用,未來將引發錯誤
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 呼叫棄用函式,記錄 "scale" 引數的棄用
            deprecate("scale", "1.0.0", deprecation_message)

        # 初始化輸出狀態元組
        output_states = ()

        # 遍歷自定義的 ResNet 模組列表
        for resnet in self.resnets:
            # 如果在訓練模式下並且啟用了梯度檢查點
            if self.training and self.gradient_checkpointing:

                # 定義用於建立自定義前向傳播的函式
                def create_custom_forward(module):
                    # 定義自定義前向傳播函式
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                # 檢查 PyTorch 版本是否大於等於 1.11.0
                if is_torch_version(">=", "1.11.0"):
                    # 使用檢查點技術來計算隱藏狀態,防止記憶體洩漏
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
                    )
                else:
                    # 在較早的版本中呼叫檢查點計算隱藏狀態
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet), hidden_states, temb
                    )
            else:
                # 在非訓練模式下直接呼叫 ResNet 模組以計算隱藏狀態
                hidden_states = resnet(hidden_states, temb)

            # 將計算出的隱藏狀態新增到輸出狀態元組中
            output_states = output_states + (hidden_states,)

        # 如果存在下采樣器
        if self.downsamplers is not None:
            # 遍歷下采樣器並計算隱藏狀態
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            # 將下采樣後的隱藏狀態新增到輸出狀態元組中
            output_states = output_states + (hidden_states,)

        # 返回當前的隱藏狀態和輸出狀態元組
        return hidden_states, output_states
# 定義一個 2D 下采樣編碼塊的類,繼承自 nn.Module
class DownEncoderBlock2D(nn.Module):
    # 初始化方法,設定各類引數
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        out_channels: int,  # 輸出通道數
        dropout: float = 0.0,  # dropout 機率,預設 0
        num_layers: int = 1,  # 層數,預設 1
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 值,預設 1e-6
        resnet_time_scale_shift: str = "default",  # 時間尺度偏移方式,預設值為 "default"
        resnet_act_fn: str = "swish",  # ResNet 使用的啟用函式,預設是 "swish"
        resnet_groups: int = 32,  # ResNet 的組數,預設 32
        resnet_pre_norm: bool = True,  # 是否在前面進行歸一化,預設 True
        output_scale_factor: float = 1.0,  # 輸出縮放因子,預設 1.0
        add_downsample: bool = True,  # 是否新增下采樣層,預設 True
        downsample_padding: int = 1,  # 下采樣的填充大小,預設 1
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        # 初始化一個空的列表,用於存放 ResNet 塊
        resnets = []

        # 根據層數建立 ResNet 塊
        for i in range(num_layers):
            # 如果是第一層,使用輸入通道數;否則使用輸出通道數
            in_channels = in_channels if i == 0 else out_channels
            # 根據時間尺度偏移方式選擇相應的 ResNet 塊
            if resnet_time_scale_shift == "spatial":
                # 建立一個帶條件歸一化的 ResNet 塊,並新增到 resnets 列表
                resnets.append(
                    ResnetBlockCondNorm2D(
                        in_channels=in_channels,  # 輸入通道數
                        out_channels=out_channels,  # 輸出通道數
                        temb_channels=None,  # 時間嵌入通道數,預設 None
                        eps=resnet_eps,  # epsilon 值
                        groups=resnet_groups,  # 組數
                        dropout=dropout,  # dropout 機率
                        time_embedding_norm="spatial",  # 時間嵌入的歸一化方式
                        non_linearity=resnet_act_fn,  # 啟用函式
                        output_scale_factor=output_scale_factor,  # 輸出縮放因子
                    )
                )
            else:
                # 建立一個標準的 ResNet 塊,並新增到 resnets 列表
                resnets.append(
                    ResnetBlock2D(
                        in_channels=in_channels,  # 輸入通道數
                        out_channels=out_channels,  # 輸出通道數
                        temb_channels=None,  # 時間嵌入通道數,預設 None
                        eps=resnet_eps,  # epsilon 值
                        groups=resnet_groups,  # 組數
                        dropout=dropout,  # dropout 機率
                        time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入的歸一化方式
                        non_linearity=resnet_act_fn,  # 啟用函式
                        output_scale_factor=output_scale_factor,  # 輸出縮放因子
                        pre_norm=resnet_pre_norm,  # 是否前歸一化
                    )
                )

        # 將 ResNet 塊列表轉為 nn.ModuleList,便於管理
        self.resnets = nn.ModuleList(resnets)

        # 根據 add_downsample 引數決定是否新增下采樣層
        if add_downsample:
            # 建立下采樣層,並新增到 nn.ModuleList
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,  # 輸入通道數
                        use_conv=True,  # 是否使用卷積進行下采樣
                        out_channels=out_channels,  # 輸出通道數
                        padding=downsample_padding,  # 填充大小
                        name="op"  # 下采樣層的名稱
                    )
                ]
            )
        else:
            # 如果不新增下采樣層,設定為 None
            self.downsamplers = None
    # 定義前向傳播函式,接受隱藏狀態和可選引數
        def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            # 檢查是否傳入多餘的引數或已棄用的 `scale` 引數
            if len(args) > 0 or kwargs.get("scale", None) is not None:
                # 設定棄用資訊,提醒使用者移除 `scale` 引數
                deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
                # 呼叫棄用函式,記錄警告資訊
                deprecate("scale", "1.0.0", deprecation_message)
    
            # 遍歷每個 ResNet 層,更新隱藏狀態
            for resnet in self.resnets:
                hidden_states = resnet(hidden_states, temb=None)
    
            # 如果存在下采樣層,則逐個應用下采樣
            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)
    
            # 返回最終的隱藏狀態
            return hidden_states
# 定義一個二維注意力下采樣編碼器塊的類,繼承自 nn.Module
class AttnDownEncoderBlock2D(nn.Module):
    # 初始化方法,接收多個引數用於配置編碼器塊
    def __init__(
        # 輸入通道數
        in_channels: int,
        # 輸出通道數
        out_channels: int,
        # 丟棄率,控制神經元隨機失活的比例
        dropout: float = 0.0,
        # 編碼器塊的層數
        num_layers: int = 1,
        # ResNet 的小常量,用於防止除零錯誤
        resnet_eps: float = 1e-6,
        # ResNet 的時間尺度偏移,預設配置
        resnet_time_scale_shift: str = "default",
        # ResNet 使用的啟用函式型別,預設為 swish
        resnet_act_fn: str = "swish",
        # ResNet 中的分組數
        resnet_groups: int = 32,
        # 是否在前面進行歸一化,預設為 True
        resnet_pre_norm: bool = True,
        # 注意力頭的維度
        attention_head_dim: int = 1,
        # 輸出縮放因子,預設為 1.0
        output_scale_factor: float = 1.0,
        # 是否新增下采樣層,預設為 True
        add_downsample: bool = True,
        # 下采樣的填充大小,預設為 1
        downsample_padding: int = 1,
    ):
        # 呼叫父類建構函式
        super().__init__()
        # 初始化空列表以儲存殘差塊
        resnets = []
        # 初始化空列表以儲存注意力模組
        attentions = []

        # 檢查是否傳入注意力頭維度
        if attention_head_dim is None:
            # 記錄警告資訊,預設設定注意力頭維度為輸出通道數
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
            )
            # 將注意力頭維度設定為輸出通道數
            attention_head_dim = out_channels

        # 遍歷層數以構建殘差塊和注意力模組
        for i in range(num_layers):
            # 第一層輸入通道為 in_channels,其餘層為 out_channels
            in_channels = in_channels if i == 0 else out_channels
            # 根據時間縮放偏移型別構建不同型別的殘差塊
            if resnet_time_scale_shift == "spatial":
                # 新增條件歸一化的殘差塊到列表
                resnets.append(
                    ResnetBlockCondNorm2D(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        temb_channels=None,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm="spatial",
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                    )
                )
            else:
                # 新增普通殘差塊到列表
                resnets.append(
                    ResnetBlock2D(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        temb_channels=None,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
            # 新增註意力模組到列表
            attentions.append(
                Attention(
                    out_channels,
                    heads=out_channels // attention_head_dim,
                    dim_head=attention_head_dim,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    norm_num_groups=resnet_groups,
                    residual_connection=True,
                    bias=True,
                    upcast_softmax=True,
                    _from_deprecated_attn_block=True,
                )
            )

        # 將注意力模組列表轉換為 nn.ModuleList
        self.attentions = nn.ModuleList(attentions)
        # 將殘差塊列表轉換為 nn.ModuleList
        self.resnets = nn.ModuleList(resnets)

        # 根據標誌決定是否新增下采樣層
        if add_downsample:
            # 建立下采樣模組並新增到列表
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            # 如果不新增下采樣層,將其設定為 None
            self.downsamplers = None
    # 定義前向傳播方法,接收隱藏狀態和其他引數
        def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            # 檢查是否有額外的引數或已棄用的 scale 引數
            if len(args) > 0 or kwargs.get("scale", None) is not None:
                # 構建棄用警告資訊
                deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
                # 呼叫 deprecate 函式顯示棄用警告
                deprecate("scale", "1.0.0", deprecation_message)
    
            # 遍歷自定義的 ResNet 和注意力層進行處理
            for resnet, attn in zip(self.resnets, self.attentions):
                # 透過 ResNet 層處理隱藏狀態
                hidden_states = resnet(hidden_states, temb=None)
                # 透過注意力層處理更新後的隱藏狀態
                hidden_states = attn(hidden_states)
    
            # 如果有下采樣層,則依次處理隱藏狀態
            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    # 透過下采樣層處理隱藏狀態
                    hidden_states = downsampler(hidden_states)
    
            # 返回處理後的隱藏狀態
            return hidden_states
# 定義一個名為 AttnSkipDownBlock2D 的類,繼承自 nn.Module
class AttnSkipDownBlock2D(nn.Module):
    # 初始化方法,定義類的建構函式
    def __init__(
        # 輸入通道數,整型
        in_channels: int,
        # 輸出通道數,整型
        out_channels: int,
        # 嵌入通道數,整型
        temb_channels: int,
        # dropout 率,浮點型,預設為 0.0
        dropout: float = 0.0,
        # 網路層數,整型,預設為 1
        num_layers: int = 1,
        # ResNet 的 epsilon 值,浮點型,預設為 1e-6
        resnet_eps: float = 1e-6,
        # ResNet 的時間尺度偏移方式,字串,預設為 "default"
        resnet_time_scale_shift: str = "default",
        # ResNet 的啟用函式型別,字串,預設為 "swish"
        resnet_act_fn: str = "swish",
        # 是否在前面進行規範化,布林值,預設為 True
        resnet_pre_norm: bool = True,
        # 注意力頭的維度,整型,預設為 1
        attention_head_dim: int = 1,
        # 輸出縮放因子,浮點型,預設為平方根2
        output_scale_factor: float = np.sqrt(2.0),
        # 是否新增下采樣層,布林值,預設為 True
        add_downsample: bool = True,
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        # 初始化一個空的模組列表用於儲存注意力層
        self.attentions = nn.ModuleList([])
        # 初始化一個空的模組列表用於儲存殘差塊
        self.resnets = nn.ModuleList([])

        # 檢查 attention_head_dim 是否為 None
        if attention_head_dim is None:
            # 如果為 None,記錄警告資訊,並將其設定為輸出通道數
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
            )
            attention_head_dim = out_channels

        # 根據層數建立殘差塊和注意力層
        for i in range(num_layers):
            # 設定當前層的輸入通道數,如果是第一層則使用 in_channels,否則使用 out_channels
            in_channels = in_channels if i == 0 else out_channels
            # 新增一個 ResnetBlock2D 到 resnets 列表中
            self.resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,  # 輸入通道數
                    out_channels=out_channels,  # 輸出通道數
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    eps=resnet_eps,  # 小常數以防止除零
                    groups=min(in_channels // 4, 32),  # 分組數
                    groups_out=min(out_channels // 4, 32),  # 輸出分組數
                    dropout=dropout,  # dropout 機率
                    time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入歸一化方式
                    non_linearity=resnet_act_fn,  # 非線性啟用函式
                    output_scale_factor=output_scale_factor,  # 輸出縮放因子
                    pre_norm=resnet_pre_norm,  # 是否在殘差塊前進行歸一化
                )
            )
            # 新增一個 Attention 層到 attentions 列表中
            self.attentions.append(
                Attention(
                    out_channels,  # 輸出通道數
                    heads=out_channels // attention_head_dim,  # 注意力頭的數量
                    dim_head=attention_head_dim,  # 每個注意力頭的維度
                    rescale_output_factor=output_scale_factor,  # 輸出縮放因子
                    eps=resnet_eps,  # 小常數以防止除零
                    norm_num_groups=32,  # 歸一化分組數
                    residual_connection=True,  # 是否使用殘差連線
                    bias=True,  # 是否使用偏置
                    upcast_softmax=True,  # 是否使用上溢位 softmax
                    _from_deprecated_attn_block=True,  # 是否來自過時的注意力塊
                )
            )

        # 檢查是否需要新增下采樣層
        if add_downsample:
            # 建立一個 ResnetBlock2D 作為下采樣層
            self.resnet_down = ResnetBlock2D(
                in_channels=out_channels,  # 輸入通道數
                out_channels=out_channels,  # 輸出通道數
                temb_channels=temb_channels,  # 時間嵌入通道數
                eps=resnet_eps,  # 小常數以防止除零
                groups=min(out_channels // 4, 32),  # 分組數
                dropout=dropout,  # dropout 機率
                time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入歸一化方式
                non_linearity=resnet_act_fn,  # 非線性啟用函式
                output_scale_factor=output_scale_factor,  # 輸出縮放因子
                pre_norm=resnet_pre_norm,  # 是否在殘差塊前進行歸一化
                use_in_shortcut=True,  # 是否在快捷連線中使用
                down=True,  # 是否進行下采樣
                kernel="fir",  # 卷積核型別
            )
            # 建立下采樣模組列表
            self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
            # 建立跳躍連線卷積層
            self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
        else:
            # 如果不新增下采樣層,則將相關屬性設定為 None
            self.resnet_down = None
            self.downsamplers = None
            self.skip_conv = None

    # 前向傳播方法
    def forward(
        self,
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入
        skip_sample: Optional[torch.Tensor] = None,  # 可選的跳躍樣本
        *args,  # 額外的位置引數
        **kwargs,  # 額外的關鍵字引數
    # 定義返回型別為元組,包含張量和多個張量的元組
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
        # 檢查傳入的引數是否存在,或 kwargs 中的 scale 是否不為 None
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 定義棄用資訊,說明 scale 引數將被忽略並將來會引發錯誤
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 呼叫棄用函式,記錄 scale 引數的棄用資訊
            deprecate("scale", "1.0.0", deprecation_message)
    
        # 初始化輸出狀態為一個空元組
        output_states = ()
    
        # 遍歷 resnet 和 attention 的組合
        for resnet, attn in zip(self.resnets, self.attentions):
            # 使用 resnet 處理隱藏狀態和時間嵌入
            hidden_states = resnet(hidden_states, temb)
            # 使用 attention 處理更新後的隱藏狀態
            hidden_states = attn(hidden_states)
            # 將當前隱藏狀態新增到輸出狀態元組中
            output_states += (hidden_states,)
    
        # 檢查是否存在下采樣器
        if self.downsamplers is not None:
            # 使用下采樣網路處理隱藏狀態
            hidden_states = self.resnet_down(hidden_states, temb)
            # 遍歷每個下采樣器並處理跳躍樣本
            for downsampler in self.downsamplers:
                skip_sample = downsampler(skip_sample)
    
            # 結合跳躍樣本和隱藏狀態,更新隱藏狀態
            hidden_states = self.skip_conv(skip_sample) + hidden_states
    
            # 將當前隱藏狀態新增到輸出狀態元組中
            output_states += (hidden_states,)
    
        # 返回更新後的隱藏狀態、輸出狀態元組和跳躍樣本
        return hidden_states, output_states, skip_sample
# 定義一個二維跳過塊的類,繼承自 nn.Module
class SkipDownBlock2D(nn.Module):
    # 初始化方法,設定輸入和輸出通道等引數
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        out_channels: int,  # 輸出通道數
        temb_channels: int,  # 時間嵌入通道數
        dropout: float = 0.0,  # dropout 機率
        num_layers: int = 1,  # 層數
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 引數
        resnet_time_scale_shift: str = "default",  # 時間縮放偏移方式
        resnet_act_fn: str = "swish",  # ResNet 啟用函式
        resnet_pre_norm: bool = True,  # 是否在前面進行歸一化
        output_scale_factor: float = np.sqrt(2.0),  # 輸出縮放因子
        add_downsample: bool = True,  # 是否新增下采樣層
        downsample_padding: int = 1,  # 下采樣時的填充
    ):
        super().__init__()  # 呼叫父類建構函式
        self.resnets = nn.ModuleList([])  # 初始化 ResNet 塊列表

        # 迴圈建立每一層的 ResNet 塊
        for i in range(num_layers):
            # 第一層使用輸入通道,後續層使用輸出通道
            in_channels = in_channels if i == 0 else out_channels
            self.resnets.append(
                ResnetBlock2D(  # 新增 ResNet 塊
                    in_channels=in_channels,  # 輸入通道數
                    out_channels=out_channels,  # 輸出通道數
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    eps=resnet_eps,  # epsilon 引數
                    groups=min(in_channels // 4, 32),  # 輸入通道組數
                    groups_out=min(out_channels // 4, 32),  # 輸出通道組數
                    dropout=dropout,  # dropout 機率
                    time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入歸一化
                    non_linearity=resnet_act_fn,  # 啟用函式
                    output_scale_factor=output_scale_factor,  # 輸出縮放因子
                    pre_norm=resnet_pre_norm,  # 前歸一化
                )
            )

        # 如果需要新增下采樣層
        if add_downsample:
            self.resnet_down = ResnetBlock2D(  # 建立下采樣 ResNet 塊
                in_channels=out_channels,  # 輸入通道數
                out_channels=out_channels,  # 輸出通道數
                temb_channels=temb_channels,  # 時間嵌入通道數
                eps=resnet_eps,  # epsilon 引數
                groups=min(out_channels // 4, 32),  # 輸出通道組數
                dropout=dropout,  # dropout 機率
                time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入歸一化
                non_linearity=resnet_act_fn,  # 啟用函式
                output_scale_factor=output_scale_factor,  # 輸出縮放因子
                pre_norm=resnet_pre_norm,  # 前歸一化
                use_in_shortcut=True,  # 使用短接
                down=True,  # 啟用下采樣
                kernel="fir",  # 指定卷積核型別
            )
            # 建立下采樣模組列表
            self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
            # 建立跳過連線卷積層
            self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
        else:  # 如果不新增下采樣層
            self.resnet_down = None  # 不使用下采樣 ResNet 塊
            self.downsamplers = None  # 不使用下采樣模組列表
            self.skip_conv = None  # 不使用跳過連線卷積層

    # 前向傳播方法
    def forward(
        self,
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入
        skip_sample: Optional[torch.Tensor] = None,  # 可選的跳過樣本
        *args,  # 可變位置引數
        **kwargs,  # 可變關鍵字引數
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:  # 定義返回型別為元組,包含一個張量、一個張量元組和另一個張量
        if len(args) > 0 or kwargs.get("scale", None) is not None:  # 檢查是否傳入了位置引數或名為“scale”的關鍵字引數
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."  # 定義廢棄警告資訊
            deprecate("scale", "1.0.0", deprecation_message)  # 呼叫 deprecate 函式,記錄“scale”引數的廢棄資訊

        output_states = ()  # 初始化輸出狀態為一個空元組

        for resnet in self.resnets:  # 遍歷自定義的 ResNet 模型列表
            hidden_states = resnet(hidden_states, temb)  # 將當前的隱藏狀態和時間嵌入傳遞給 ResNet,獲取更新後的隱藏狀態
            output_states += (hidden_states,)  # 將當前的隱藏狀態新增到輸出狀態元組中

        if self.downsamplers is not None:  # 檢查是否存在下采樣模組
            hidden_states = self.resnet_down(hidden_states, temb)  # 使用 ResNet 下采樣隱藏狀態和時間嵌入
            for downsampler in self.downsamplers:  # 遍歷下采樣模組
                skip_sample = downsampler(skip_sample)  # 對跳過連線樣本進行下采樣

            hidden_states = self.skip_conv(skip_sample) + hidden_states  # 透過跳過卷積處理下采樣樣本,並與當前的隱藏狀態相加

            output_states += (hidden_states,)  # 將更新後的隱藏狀態新增到輸出狀態元組中

        return hidden_states, output_states, skip_sample  # 返回更新後的隱藏狀態、輸出狀態元組和跳過樣本
# 定義一個 2D ResNet 下采樣塊,繼承自 nn.Module
class ResnetDownsampleBlock2D(nn.Module):
    # 初始化函式,定義各層引數
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        out_channels: int,  # 輸出通道數
        temb_channels: int,  # 時間嵌入通道數
        dropout: float = 0.0,  # dropout 機率
        num_layers: int = 1,  # ResNet 層數
        resnet_eps: float = 1e-6,  # ResNet 中的 epsilon 值
        resnet_time_scale_shift: str = "default",  # 時間縮放偏移方式
        resnet_act_fn: str = "swish",  # 啟用函式
        resnet_groups: int = 32,  # 組數
        resnet_pre_norm: bool = True,  # 是否使用預歸一化
        output_scale_factor: float = 1.0,  # 輸出縮放因子
        add_downsample: bool = True,  # 是否新增下采樣層
        skip_time_act: bool = False,  # 是否跳過時間啟用
    ):
        # 呼叫父類建構函式
        super().__init__()
        resnets = []  # 初始化 ResNet 層列表

        # 根據指定的層數構建 ResNet 塊
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels  # 確定輸入通道數
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,  # 輸入通道數
                    out_channels=out_channels,  # 輸出通道數
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    eps=resnet_eps,  # epsilon 值
                    groups=resnet_groups,  # 組數
                    dropout=dropout,  # dropout 機率
                    time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入歸一化方式
                    non_linearity=resnet_act_fn,  # 啟用函式
                    output_scale_factor=output_scale_factor,  # 輸出縮放因子
                    pre_norm=resnet_pre_norm,  # 是否使用預歸一化
                    skip_time_act=skip_time_act,  # 是否跳過時間啟用
                )
            )

        self.resnets = nn.ModuleList(resnets)  # 將 ResNet 層轉換為模組列表

        # 如果需要,新增下采樣層
        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    ResnetBlock2D(
                        in_channels=out_channels,  # 輸入通道數
                        out_channels=out_channels,  # 輸出通道數
                        temb_channels=temb_channels,  # 時間嵌入通道數
                        eps=resnet_eps,  # epsilon 值
                        groups=resnet_groups,  # 組數
                        dropout=dropout,  # dropout 機率
                        time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入歸一化方式
                        non_linearity=resnet_act_fn,  # 啟用函式
                        output_scale_factor=output_scale_factor,  # 輸出縮放因子
                        pre_norm=resnet_pre_norm,  # 是否使用預歸一化
                        skip_time_act=skip_time_act,  # 是否跳過時間啟用
                        down=True,  # 指定為下采樣層
                    )
                ]
            )
        else:
            self.downsamplers = None  # 如果不需要下采樣層,則為 None

        self.gradient_checkpointing = False  # 初始化梯度檢查點為 False

    # 定義前向傳播函式
    def forward(
        self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs  # 前向傳播輸入
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        # 檢查引數是否存在,或者是否提供了已棄用的 `scale` 引數
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 定義棄用訊息,告知使用者 `scale` 引數已棄用並將在未來的版本中引發錯誤
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 呼叫 deprecate 函式記錄棄用資訊
            deprecate("scale", "1.0.0", deprecation_message)

        # 初始化輸出狀態元組
        output_states = ()

        # 遍歷所有的 ResNet 模型
        for resnet in self.resnets:
            # 如果處於訓練模式並啟用梯度檢查點
            if self.training and self.gradient_checkpointing:
                
                # 定義一個函式用於建立自定義前向傳播
                def create_custom_forward(module):
                    # 定義自定義前向傳播函式,呼叫傳入的模組
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                # 檢查 PyTorch 版本是否大於等於 1.11.0
                if is_torch_version(">=", "1.11.0"):
                    # 使用梯度檢查點機制計算隱藏狀態,禁用重入
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
                    )
                else:
                    # 在較舊版本的 PyTorch 中使用梯度檢查點機制計算隱藏狀態
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet), hidden_states, temb
                    )
            else:
                # 在非訓練模式下直接呼叫 ResNet 計算隱藏狀態
                hidden_states = resnet(hidden_states, temb)

            # 將當前的隱藏狀態新增到輸出狀態元組中
            output_states = output_states + (hidden_states,)

        # 檢查是否存在下采樣器
        if self.downsamplers is not None:
            # 遍歷所有的下采樣器
            for downsampler in self.downsamplers:
                # 使用下采樣器計算隱藏狀態
                hidden_states = downsampler(hidden_states, temb)

            # 將當前的隱藏狀態新增到輸出狀態元組中
            output_states = output_states + (hidden_states,)

        # 返回最終的隱藏狀態和輸出狀態元組
        return hidden_states, output_states
# 定義一個簡單的二維交叉注意力下采樣塊類,繼承自 nn.Module
class SimpleCrossAttnDownBlock2D(nn.Module):
    # 初始化方法,設定輸入、輸出通道等引數
    def __init__(
        # 輸入通道數
        in_channels: int,
        # 輸出通道數
        out_channels: int,
        # 時間嵌入通道數
        temb_channels: int,
        # dropout 機率,預設值為 0.0
        dropout: float = 0.0,
        # 層數,預設值為 1
        num_layers: int = 1,
        # ResNet 中的 epsilon 值,預設值為 1e-6
        resnet_eps: float = 1e-6,
        # ResNet 的時間尺度偏移設定,預設為 "default"
        resnet_time_scale_shift: str = "default",
        # ResNet 的啟用函式型別,預設為 "swish"
        resnet_act_fn: str = "swish",
        # ResNet 的分組數,預設為 32
        resnet_groups: int = 32,
        # 是否在 ResNet 中使用預歸一化,預設為 True
        resnet_pre_norm: bool = True,
        # 注意力頭的維度,預設為 1
        attention_head_dim: int = 1,
        # 交叉注意力的維度,預設為 1280
        cross_attention_dim: int = 1280,
        # 輸出縮放因子,預設為 1.0
        output_scale_factor: float = 1.0,
        # 是否新增下采樣層,預設為 True
        add_downsample: bool = True,
        # 是否跳過時間啟用,預設為 False
        skip_time_act: bool = False,
        # 是否僅使用交叉注意力,預設為 False
        only_cross_attention: bool = False,
        # 交叉注意力的歸一化方式,預設為 None
        cross_attention_norm: Optional[str] = None,
    ):
        # 呼叫父類的初始化方法
        super().__init__()

        # 初始化是否有交叉注意力標誌
        self.has_cross_attention = True

        # 初始化殘差網路和注意力模組的列表
        resnets = []
        attentions = []

        # 設定注意力頭的維度
        self.attention_head_dim = attention_head_dim
        # 計算注意力頭的數量
        self.num_heads = out_channels // self.attention_head_dim

        # 根據層數建立殘差塊
        for i in range(num_layers):
            # 設定輸入通道,第一層使用給定的輸入通道,其餘層使用輸出通道
            in_channels = in_channels if i == 0 else out_channels
            # 將殘差塊新增到列表中
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,  # 輸入通道數
                    out_channels=out_channels,  # 輸出通道數
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    eps=resnet_eps,  # 殘差網路中的 epsilon 值
                    groups=resnet_groups,  # 分組數量
                    dropout=dropout,  # dropout 機率
                    time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入規範化
                    non_linearity=resnet_act_fn,  # 非線性啟用函式
                    output_scale_factor=output_scale_factor,  # 輸出縮放因子
                    pre_norm=resnet_pre_norm,  # 是否預歸一化
                    skip_time_act=skip_time_act,  # 是否跳過時間啟用
                )
            )

            # 根據是否有縮放點積注意力建立處理器
            processor = (
                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
            )

            # 將注意力模組新增到列表中
            attentions.append(
                Attention(
                    query_dim=out_channels,  # 查詢維度
                    cross_attention_dim=out_channels,  # 交叉注意力維度
                    heads=self.num_heads,  # 注意力頭數量
                    dim_head=attention_head_dim,  # 每個頭的維度
                    added_kv_proj_dim=cross_attention_dim,  # 額外的鍵值投影維度
                    norm_num_groups=resnet_groups,  # 規範化的組數量
                    bias=True,  # 是否使用偏置
                    upcast_softmax=True,  # 是否上調 softmax
                    only_cross_attention=only_cross_attention,  # 是否僅使用交叉注意力
                    cross_attention_norm=cross_attention_norm,  # 交叉注意力的規範化
                    processor=processor,  # 使用的處理器
                )
            )
        # 將注意力模組列表轉換為可訓練模組
        self.attentions = nn.ModuleList(attentions)
        # 將殘差塊列表轉換為可訓練模組
        self.resnets = nn.ModuleList(resnets)

        # 如果需要新增下采樣
        if add_downsample:
            # 建立下采樣的殘差塊
            self.downsamplers = nn.ModuleList(
                [
                    ResnetBlock2D(
                        in_channels=out_channels,  # 輸入通道數
                        out_channels=out_channels,  # 輸出通道數
                        temb_channels=temb_channels,  # 時間嵌入通道數
                        eps=resnet_eps,  # 殘差網路中的 epsilon 值
                        groups=resnet_groups,  # 分組數量
                        dropout=dropout,  # dropout 機率
                        time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入規範化
                        non_linearity=resnet_act_fn,  # 非線性啟用函式
                        output_scale_factor=output_scale_factor,  # 輸出縮放因子
                        pre_norm=resnet_pre_norm,  # 是否預歸一化
                        skip_time_act=skip_time_act,  # 是否跳過時間啟用
                        down=True,  # 表示這是下采樣
                    )
                ]
            )
        else:
            # 如果不需要下采樣,將下采樣設定為 None
            self.downsamplers = None

        # 初始化梯度檢查點標誌為 False
        self.gradient_checkpointing = False
    # 定義一個前向傳播方法,接受多個輸入引數
        def forward(
            self,
            # 輸入的隱藏狀態張量
            hidden_states: torch.Tensor,
            # 可選的時間嵌入張量
            temb: Optional[torch.Tensor] = None,
            # 可選的編碼器隱藏狀態張量
            encoder_hidden_states: Optional[torch.Tensor] = None,
            # 可選的注意力掩碼張量
            attention_mask: Optional[torch.Tensor] = None,
            # 可選的交叉注意力引數字典
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 可選的編碼器注意力掩碼張量
            encoder_attention_mask: Optional[torch.Tensor] = None,
    # 返回值型別為元組,包含一個張量和一個張量元組
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        # 如果未提供 cross_attention_kwargs,則使用空字典
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        # 檢查是否傳入了 scale 引數,若有則發出棄用警告
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
    
        # 初始化輸出狀態為一個空元組
        output_states = ()
    
        # 檢查 attention_mask 是否為 None
        if attention_mask is None:
            # 如果 encoder_hidden_states 已定義,則進行交叉注意力,使用交叉注意力掩碼
            mask = None if encoder_hidden_states is None else encoder_attention_mask
        else:
            # 如果已定義 attention_mask,則直接使用,不檢查 encoder_attention_mask
            # 這為 UnCLIP 相容性提供支援
            # TODO: UnCLIP 應該透過 encoder_attention_mask 參數列達交叉注意力掩碼
            mask = attention_mask
    
        # 遍歷 ResNet 和注意力層
        for resnet, attn in zip(self.resnets, self.attentions):
            # 在訓練中且開啟了梯度檢查點
            if self.training and self.gradient_checkpointing:
                # 定義一個自定義前向傳播函式
                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        # 根據 return_dict 決定返回方式
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)
    
                    return custom_forward
    
                # 使用檢查點進行前向傳播,節省記憶體
                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                # 執行注意力層
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=mask,
                    **cross_attention_kwargs,
                )
            else:
                # 否則直接使用 ResNet 進行前向傳播
                hidden_states = resnet(hidden_states, temb)
    
                # 執行注意力層
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=mask,
                    **cross_attention_kwargs,
                )
    
            # 將當前隱藏狀態新增到輸出狀態元組中
            output_states = output_states + (hidden_states,)
    
        # 如果存在下采樣層
        if self.downsamplers is not None:
            # 遍歷所有下采樣層
            for downsampler in self.downsamplers:
                # 執行下采樣
                hidden_states = downsampler(hidden_states, temb)
    
            # 將下采樣後的隱藏狀態新增到輸出狀態元組中
            output_states = output_states + (hidden_states,)
    
        # 返回最終的隱藏狀態和輸出狀態元組
        return hidden_states, output_states
# 定義一個二維下采樣的神經網路模組,繼承自 nn.Module
class KDownBlock2D(nn.Module):
    # 初始化方法,定義輸入輸出通道數、時間嵌入通道、dropout 機率等引數
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        out_channels: int,  # 輸出通道數
        temb_channels: int,  # 時間嵌入通道數
        dropout: float = 0.0,  # dropout 機率,預設為 0
        num_layers: int = 4,  # 殘差層的數量,預設為 4
        resnet_eps: float = 1e-5,  # 殘差層中的 epsilon 值,防止除零錯誤
        resnet_act_fn: str = "gelu",  # 殘差層使用的啟用函式,預設為 GELU
        resnet_group_size: int = 32,  # 殘差層中組的大小
        add_downsample: bool = False,  # 是否新增下采樣層的標誌
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        resnets = []  # 初始化一個空列表用於儲存殘差塊

        # 根據層數構建殘差塊
        for i in range(num_layers):
            # 第一層使用輸入通道,其他層使用輸出通道
            in_channels = in_channels if i == 0 else out_channels
            # 計算組的數量
            groups = in_channels // resnet_group_size
            # 計算輸出組的數量
            groups_out = out_channels // resnet_group_size

            # 建立殘差塊並新增到列表中
            resnets.append(
                ResnetBlockCondNorm2D(
                    in_channels=in_channels,  # 當前層的輸入通道數
                    out_channels=out_channels,  # 當前層的輸出通道數
                    dropout=dropout,  # 當前層的 dropout 機率
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    groups=groups,  # 當前層的組數量
                    groups_out=groups_out,  # 輸出層的組數量
                    eps=resnet_eps,  # 殘差層中的 epsilon 值
                    non_linearity=resnet_act_fn,  # 殘差層的啟用函式
                    time_embedding_norm="ada_group",  # 時間嵌入的歸一化方式
                    conv_shortcut_bias=False,  # 卷積快捷連線是否使用偏置
                )
            )

        # 將殘差塊列表轉換為 nn.ModuleList,以便於引數管理
        self.resnets = nn.ModuleList(resnets)

        # 根據標誌決定是否新增下采樣層
        if add_downsample:
            # 如果需要,建立一個下采樣層並新增到列表中
            self.downsamplers = nn.ModuleList([KDownsample2D()])
        else:
            # 如果不需要下采樣,設定為 None
            self.downsamplers = None

        # 初始化梯度檢查點標誌為 False
        self.gradient_checkpointing = False

    # 前向傳播方法,接收隱藏狀態和時間嵌入(可選)
    def forward(
        self, hidden_states: torch.Tensor,  # 輸入的隱藏狀態張量
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
        *args, **kwargs  # 其他可選引數
    # 函式返回一個包含張量和元組的元組
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        # 檢查是否有額外引數或“scale”關鍵字引數不為 None
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 定義關於“scale”引數的棄用訊息
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 呼叫棄用函式記錄“scale”的棄用
            deprecate("scale", "1.0.0", deprecation_message)
    
        # 初始化輸出狀態為一個空元組
        output_states = ()
    
        # 遍歷所有的 ResNet 模組
        for resnet in self.resnets:
            # 如果處於訓練模式且啟用了梯度檢查點
            if self.training and self.gradient_checkpointing:
    
                # 定義一個建立自定義前向傳播的函式
                def create_custom_forward(module):
                    # 定義自定義前向傳播函式,接受任意輸入並返回模組的輸出
                    def custom_forward(*inputs):
                        return module(*inputs)
    
                    return custom_forward
    
                # 如果 PyTorch 版本大於等於 1.11.0
                if is_torch_version(">=", "1.11.0"):
                    # 使用檢查點功能計算隱藏狀態,傳遞自定義前向函式
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
                    )
                else:
                    # 否則使用檢查點功能計算隱藏狀態
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet), hidden_states, temb
                    )
            else:
                # 如果不滿足條件,則直接透過 ResNet 模組計算隱藏狀態
                hidden_states = resnet(hidden_states, temb)
    
            # 將當前隱藏狀態新增到輸出狀態元組中
            output_states += (hidden_states,)
    
        # 如果存在下采樣器
        if self.downsamplers is not None:
            # 遍歷每個下采樣器
            for downsampler in self.downsamplers:
                # 透過下采樣器計算隱藏狀態
                hidden_states = downsampler(hidden_states)
    
        # 返回最終的隱藏狀態和輸出狀態
        return hidden_states, output_states
# 定義一個名為 KCrossAttnDownBlock2D 的類,繼承自 nn.Module
class KCrossAttnDownBlock2D(nn.Module):
    # 初始化方法,接受多個引數以設定模型的結構
    def __init__(
        self,
        in_channels: int,               # 輸入通道數
        out_channels: int,              # 輸出通道數
        temb_channels: int,             # 時間嵌入通道數
        cross_attention_dim: int,       # 跨注意力維度
        dropout: float = 0.0,           # dropout 機率,預設為 0
        num_layers: int = 4,            # 層數,預設為 4
        resnet_group_size: int = 32,    # ResNet 組的大小,預設為 32
        add_downsample: bool = True,    # 是否新增下采樣,預設為 True
        attention_head_dim: int = 64,    # 注意力頭維度,預設為 64
        add_self_attention: bool = False, # 是否新增自注意力,預設為 False
        resnet_eps: float = 1e-5,       # ResNet 的 epsilon 值,預設為 1e-5
        resnet_act_fn: str = "gelu",    # ResNet 的啟用函式型別,預設為 "gelu"
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        # 初始化空列表以存放 ResNet 塊
        resnets = []
        # 初始化空列表以存放注意力塊
        attentions = []

        # 設定是否包含跨注意力標誌
        self.has_cross_attention = True

        # 建立指定數量的層
        for i in range(num_layers):
            # 第一層的輸入通道數為 in_channels,之後的層使用 out_channels
            in_channels = in_channels if i == 0 else out_channels
            # 計算組數
            groups = in_channels // resnet_group_size
            groups_out = out_channels // resnet_group_size

            # 將 ResnetBlockCondNorm2D 新增到 resnets 列表
            resnets.append(
                ResnetBlockCondNorm2D(
                    in_channels=in_channels,        # 輸入通道數
                    out_channels=out_channels,      # 輸出通道數
                    dropout=dropout,                # dropout 機率
                    temb_channels=temb_channels,    # 時間嵌入通道數
                    groups=groups,                  # 組數
                    groups_out=groups_out,          # 輸出組數
                    eps=resnet_eps,                 # epsilon 值
                    non_linearity=resnet_act_fn,    # 啟用函式
                    time_embedding_norm="ada_group", # 時間嵌入歸一化型別
                    conv_shortcut_bias=False,       # 是否使用卷積快捷連線偏置
                )
            )
            # 將 KAttentionBlock 新增到 attentions 列表
            attentions.append(
                KAttentionBlock(
                    out_channels,                     # 輸出通道數
                    out_channels // attention_head_dim, # 注意力頭數量
                    attention_head_dim,              # 注意力頭維度
                    cross_attention_dim=cross_attention_dim, # 跨注意力維度
                    temb_channels=temb_channels,     # 時間嵌入通道數
                    attention_bias=True,             # 是否使用注意力偏置
                    add_self_attention=add_self_attention, # 是否新增自注意力
                    cross_attention_norm="layer_norm", # 跨注意力歸一化型別
                    group_size=resnet_group_size,    # 組大小
                )
            )

        # 將 resnets 列表轉換為 nn.ModuleList,以便可以在模型中使用
        self.resnets = nn.ModuleList(resnets)
        # 將 attentions 列表轉換為 nn.ModuleList
        self.attentions = nn.ModuleList(attentions)

        # 根據引數決定是否新增下采樣層
        if add_downsample:
            # 新增下采樣模組
            self.downsamplers = nn.ModuleList([KDownsample2D()])
        else:
            # 如果不新增下采樣,設定為 None
            self.downsamplers = None

        # 初始化梯度檢查點標誌為 False
        self.gradient_checkpointing = False

    # 前向傳播方法,定義輸入和輸出
    def forward(
        self,
        hidden_states: torch.Tensor,         # 隱藏狀態輸入
        temb: Optional[torch.Tensor] = None, # 可選的時間嵌入
        encoder_hidden_states: Optional[torch.Tensor] = None, # 可選的編碼器隱藏狀態
        attention_mask: Optional[torch.Tensor] = None, # 可選的注意力掩碼
        cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 可選的跨注意力引數
        encoder_attention_mask: Optional[torch.Tensor] = None, # 可選的編碼器注意力掩碼
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        # 如果沒有傳入 cross_attention_kwargs,則初始化為空字典
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        # 檢查 cross_attention_kwargs 中是否存在 "scale" 引數,併發出警告
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # 初始化輸出狀態為一個空元組
        output_states = ()

        # 遍歷 resnets 和 attentions 的對應元素
        for resnet, attn in zip(self.resnets, self.attentions):
            # 如果處於訓練模式且開啟了梯度檢查點
            if self.training and self.gradient_checkpointing:

                # 建立自定義前向傳播函式
                def create_custom_forward(module, return_dict=None):
                    # 定義自定義前向傳播邏輯
                    def custom_forward(*inputs):
                        # 如果指定了返回字典,則返回包含字典的結果
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            # 否則返回普通結果
                            return module(*inputs)

                    return custom_forward

                # 設定檢查點引數,針對 PyTorch 版本進行不同處理
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                # 使用檢查點機制計算隱藏狀態
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),  # 傳入自定義前向函式
                    hidden_states,  # 輸入隱藏狀態
                    temb,  # 傳入時間嵌入
                    **ckpt_kwargs,  # 傳入檢查點引數
                )
                # 使用注意力機制更新隱藏狀態
                hidden_states = attn(
                    hidden_states,  # 輸入隱藏狀態
                    encoder_hidden_states=encoder_hidden_states,  # 編碼器隱藏狀態
                    emb=temb,  # 時間嵌入
                    attention_mask=attention_mask,  # 注意力掩碼
                    cross_attention_kwargs=cross_attention_kwargs,  # 交叉注意力引數
                    encoder_attention_mask=encoder_attention_mask,  # 編碼器注意力掩碼
                )
            else:
                # 如果不是訓練模式或沒有使用梯度檢查點,直接透過 ResNet 更新隱藏狀態
                hidden_states = resnet(hidden_states, temb)
                # 使用注意力機制更新隱藏狀態
                hidden_states = attn(
                    hidden_states,  # 輸入隱藏狀態
                    encoder_hidden_states=encoder_hidden_states,  # 編碼器隱藏狀態
                    emb=temb,  # 時間嵌入
                    attention_mask=attention_mask,  # 注意力掩碼
                    cross_attention_kwargs=cross_attention_kwargs,  # 交叉注意力引數
                    encoder_attention_mask=encoder_attention_mask,  # 編碼器注意力掩碼
                )

            # 如果沒有下采樣層,輸出狀態新增 None
            if self.downsamplers is None:
                output_states += (None,)
            else:
                # 否則將當前隱藏狀態新增到輸出狀態
                output_states += (hidden_states,)

        # 如果存在下采樣層,則依次對隱藏狀態進行下采樣
        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        # 返回最終的隱藏狀態和輸出狀態
        return hidden_states, output_states
# 定義一個名為 AttnUpBlock2D 的類,繼承自 nn.Module
class AttnUpBlock2D(nn.Module):
    # 初始化方法,接受多個引數以配置該模組
    def __init__(
        # 輸入通道數
        self,
        in_channels: int,
        # 前一層輸出通道數
        prev_output_channel: int,
        # 輸出通道數
        out_channels: int,
        # 嵌入通道數
        temb_channels: int,
        # 解析度索引,預設為 None
        resolution_idx: int = None,
        # dropout 率,預設為 0.0
        dropout: float = 0.0,
        # 層數,預設為 1
        num_layers: int = 1,
        # ResNet 中的小常數,避免除零
        resnet_eps: float = 1e-6,
        # ResNet 時間縮放偏移,預設為 "default"
        resnet_time_scale_shift: str = "default",
        # ResNet 啟用函式,預設為 "swish"
        resnet_act_fn: str = "swish",
        # ResNet 組數,預設為 32
        resnet_groups: int = 32,
        # 是否在 ResNet 中使用預歸一化,預設為 True
        resnet_pre_norm: bool = True,
        # 注意力頭維度,預設為 1
        attention_head_dim: int = 1,
        # 輸出縮放因子,預設為 1.0
        output_scale_factor: float = 1.0,
        # 上取樣型別,預設為 "conv"
        upsample_type: str = "conv",
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        # 初始化空列表用於儲存 ResNet 塊
        resnets = []
        # 初始化空列表用於儲存注意力層
        attentions = []

        # 設定上取樣型別
        self.upsample_type = upsample_type

        # 如果沒有傳入注意力頭維度
        if attention_head_dim is None:
            # 記錄警告,建議使用預設的頭維度
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
            )
            # 將注意力頭維度設定為輸出通道數
            attention_head_dim = out_channels

        # 遍歷每一層
        for i in range(num_layers):
            # 設定殘差跳過通道數,最後一層使用輸入通道
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # 設定當前 ResNet 塊的輸入通道數
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            # 建立 ResNet 塊並新增到列表
            resnets.append(
                ResnetBlock2D(
                    # 輸入通道數為當前 ResNet 的輸入通道加上跳過的通道
                    in_channels=resnet_in_channels + res_skip_channels,
                    # 輸出通道數
                    out_channels=out_channels,
                    # 時間嵌入通道數
                    temb_channels=temb_channels,
                    # 餘弦相似性小常數
                    eps=resnet_eps,
                    # 分組數
                    groups=resnet_groups,
                    # dropout 機率
                    dropout=dropout,
                    # 時間嵌入的歸一化方式
                    time_embedding_norm=resnet_time_scale_shift,
                    # 非線性啟用函式
                    non_linearity=resnet_act_fn,
                    # 輸出縮放因子
                    output_scale_factor=output_scale_factor,
                    # 是否進行預歸一化
                    pre_norm=resnet_pre_norm,
                )
            )
            # 建立注意力層並新增到列表
            attentions.append(
                Attention(
                    # 輸出通道數
                    out_channels,
                    # 注意力頭的數量
                    heads=out_channels // attention_head_dim,
                    # 每個頭的維度
                    dim_head=attention_head_dim,
                    # 輸出重縮放因子
                    rescale_output_factor=output_scale_factor,
                    # 餘弦相似性小常數
                    eps=resnet_eps,
                    # 歸一化的組數
                    norm_num_groups=resnet_groups,
                    # 是否使用殘差連線
                    residual_connection=True,
                    # 是否使用偏置
                    bias=True,
                    # 是否上升 softmax 的精度
                    upcast_softmax=True,
                    # 是否從已棄用的注意力塊中獲取
                    _from_deprecated_attn_block=True,
                )
            )

        # 將注意力層列表轉換為 nn.ModuleList,以便於管理
        self.attentions = nn.ModuleList(attentions)
        # 將 ResNet 塊列表轉換為 nn.ModuleList
        self.resnets = nn.ModuleList(resnets)

        # 根據上取樣型別選擇上取樣方法
        if upsample_type == "conv":
            # 使用卷積上取樣,並建立 ModuleList
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        elif upsample_type == "resnet":
            # 使用 ResNet 塊進行上取樣,並建立 ModuleList
            self.upsamplers = nn.ModuleList(
                [
                    ResnetBlock2D(
                        # 輸入通道數
                        in_channels=out_channels,
                        # 輸出通道數
                        out_channels=out_channels,
                        # 時間嵌入通道數
                        temb_channels=temb_channels,
                        # 餘弦相似性小常數
                        eps=resnet_eps,
                        # 分組數
                        groups=resnet_groups,
                        # dropout 機率
                        dropout=dropout,
                        # 時間嵌入的歸一化方式
                        time_embedding_norm=resnet_time_scale_shift,
                        # 非線性啟用函式
                        non_linearity=resnet_act_fn,
                        # 輸出縮放因子
                        output_scale_factor=output_scale_factor,
                        # 是否進行預歸一化
                        pre_norm=resnet_pre_norm,
                        # 表示這是一個上取樣的塊
                        up=True,
                    )
                ]
            )
        else:
            # 如果上取樣型別無效,則設為 None
            self.upsamplers = None

        # 儲存當前解析度索引
        self.resolution_idx = resolution_idx
    # 定義前向傳播函式,接收隱藏狀態及其他引數
    def forward(
            self,
            hidden_states: torch.Tensor,  # 當前的隱藏狀態張量
            res_hidden_states_tuple: Tuple[torch.Tensor, ...],  # 之前的隱藏狀態元組
            temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
            upsample_size: Optional[int] = None,  # 可選的上取樣大小
            *args,  # 額外的位置引數
            **kwargs,  # 額外的關鍵字引數
        ) -> torch.Tensor:  # 返回型別為張量
            # 檢查是否傳入了多餘的引數或已棄用的 scale 引數
            if len(args) > 0 or kwargs.get("scale", None) is not None:
                # 設定棄用警告資訊
                deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
                # 呼叫棄用函式發出警告
                deprecate("scale", "1.0.0", deprecation_message)
    
            # 遍歷每一對殘差網路和注意力層
            for resnet, attn in zip(self.resnets, self.attentions):
                # 從元組中彈出最後一個殘差隱藏狀態
                res_hidden_states = res_hidden_states_tuple[-1]
                # 更新殘差隱藏狀態元組,去掉最後一個元素
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                # 將當前隱藏狀態和殘差隱藏狀態在維度1上拼接
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
    
                # 將拼接後的隱藏狀態傳入殘差網路
                hidden_states = resnet(hidden_states, temb)
                # 將輸出的隱藏狀態傳入注意力層
                hidden_states = attn(hidden_states)
    
            # 檢查是否存在上取樣器
            if self.upsamplers is not None:
                # 遍歷每個上取樣器
                for upsampler in self.upsamplers:
                    # 根據上取樣型別選擇處理方式
                    if self.upsample_type == "resnet":
                        # 將隱藏狀態傳入上取樣器並提供時間嵌入
                        hidden_states = upsampler(hidden_states, temb=temb)
                    else:
                        # 將隱藏狀態傳入上取樣器
                        hidden_states = upsampler(hidden_states)
    
            # 返回處理後的隱藏狀態
            return hidden_states
# 定義一個名為 CrossAttnUpBlock2D 的類,繼承自 nn.Module
class CrossAttnUpBlock2D(nn.Module):
    # 初始化方法,定義該類的建構函式
    def __init__(
        # 輸入通道數
        self,
        in_channels: int,
        # 輸出通道數
        out_channels: int,
        # 前一層輸出通道數
        prev_output_channel: int,
        # 時間嵌入通道數
        temb_channels: int,
        # 可選的解析度索引
        resolution_idx: Optional[int] = None,
        # Dropout 機率
        dropout: float = 0.0,
        # 層數
        num_layers: int = 1,
        # 每個塊的 Transformer 層數,可以是單個整數或元組
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        # ResNet 的 epsilon 值,避免除零錯誤
        resnet_eps: float = 1e-6,
        # ResNet 的時間尺度偏移引數
        resnet_time_scale_shift: str = "default",
        # ResNet 的啟用函式型別
        resnet_act_fn: str = "swish",
        # ResNet 的組數
        resnet_groups: int = 32,
        # 是否使用預歸一化
        resnet_pre_norm: bool = True,
        # 注意力頭的數量
        num_attention_heads: int = 1,
        # 跨注意力的維度
        cross_attention_dim: int = 1280,
        # 輸出縮放因子
        output_scale_factor: float = 1.0,
        # 是否新增上取樣層
        add_upsample: bool = True,
        # 是否使用雙重跨注意力
        dual_cross_attention: bool = False,
        # 是否使用線性投影
        use_linear_projection: bool = False,
        # 是否僅使用跨注意力
        only_cross_attention: bool = False,
        # 是否提升注意力計算精度
        upcast_attention: bool = False,
        # 注意力型別
        attention_type: str = "default",
    # 繼承父類的初始化方法
        ):
            super().__init__()
            # 初始化空列表用於儲存 ResNet 和注意力模組
            resnets = []
            attentions = []
    
            # 設定是否有交叉注意力的標誌
            self.has_cross_attention = True
            # 設定注意力頭的數量
            self.num_attention_heads = num_attention_heads
    
            # 如果輸入的是整數,則將其轉換為包含多個相同值的列表
            if isinstance(transformer_layers_per_block, int):
                transformer_layers_per_block = [transformer_layers_per_block] * num_layers
    
            # 遍歷每一層,構建 ResNet 和注意力模組
            for i in range(num_layers):
                # 設定殘差跳躍通道數量
                res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
                # 設定 ResNet 輸入通道
                resnet_in_channels = prev_output_channel if i == 0 else out_channels
    
                # 將 ResNet 模組新增到列表中
                resnets.append(
                    ResnetBlock2D(
                        in_channels=resnet_in_channels + res_skip_channels,  # 輸入通道數
                        out_channels=out_channels,  # 輸出通道數
                        temb_channels=temb_channels,  # 時間嵌入通道數
                        eps=resnet_eps,  # 小常數用於數值穩定性
                        groups=resnet_groups,  # 分組數
                        dropout=dropout,  # Dropout 率
                        time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入的歸一化方式
                        non_linearity=resnet_act_fn,  # 非線性啟用函式
                        output_scale_factor=output_scale_factor,  # 輸出縮放因子
                        pre_norm=resnet_pre_norm,  # 是否進行預歸一化
                    )
                )
                # 根據是否啟用雙重交叉注意力,選擇不同的注意力模組
                if not dual_cross_attention:
                    attentions.append(
                        Transformer2DModel(
                            num_attention_heads,  # 注意力頭數量
                            out_channels // num_attention_heads,  # 每個頭的輸出通道數
                            in_channels=out_channels,  # 輸入通道數
                            num_layers=transformer_layers_per_block[i],  # 當前層的變換器層數
                            cross_attention_dim=cross_attention_dim,  # 交叉注意力維度
                            norm_num_groups=resnet_groups,  # 歸一化分組數
                            use_linear_projection=use_linear_projection,  # 是否使用線性投影
                            only_cross_attention=only_cross_attention,  # 是否僅使用交叉注意力
                            upcast_attention=upcast_attention,  # 是否上調注意力精度
                            attention_type=attention_type,  # 注意力型別
                        )
                    )
                else:
                    attentions.append(
                        DualTransformer2DModel(
                            num_attention_heads,  # 注意力頭數量
                            out_channels // num_attention_heads,  # 每個頭的輸出通道數
                            in_channels=out_channels,  # 輸入通道數
                            num_layers=1,  # 僅使用一層
                            cross_attention_dim=cross_attention_dim,  # 交叉注意力維度
                            norm_num_groups=resnet_groups,  # 歸一化分組數
                        )
                    )
            # 將注意力模組和 ResNet 模組轉換為 nn.ModuleList 以便於管理
            self.attentions = nn.ModuleList(attentions)
            self.resnets = nn.ModuleList(resnets)
    
            # 根據是否新增上取樣層初始化上取樣模組
            if add_upsample:
                self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            else:
                self.upsamplers = None
    
            # 初始化梯度檢查點標誌
            self.gradient_checkpointing = False
            # 設定解析度索引
            self.resolution_idx = resolution_idx
    # 定義前向傳播函式,接收多個引數
        def forward(
            self,
            # 當前隱藏狀態的張量
            hidden_states: torch.Tensor,
            # 元組,包含殘差隱藏狀態的張量
            res_hidden_states_tuple: Tuple[torch.Tensor, ...],
            # 可選的時間嵌入張量
            temb: Optional[torch.Tensor] = None,
            # 可選的編碼器隱藏狀態張量
            encoder_hidden_states: Optional[torch.Tensor] = None,
            # 可選的跨注意力引數字典
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 可選的上取樣大小
            upsample_size: Optional[int] = None,
            # 可選的注意力掩碼張量
            attention_mask: Optional[torch.Tensor] = None,
            # 可選的編碼器注意力掩碼張量
            encoder_attention_mask: Optional[torch.Tensor] = None,
# 定義一個名為 UpBlock2D 的類,繼承自 nn.Module
class UpBlock2D(nn.Module):
    # 初始化方法,接受多個引數來構造 UpBlock2D 物件
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        prev_output_channel: int,  # 前一層輸出的通道數
        out_channels: int,  # 輸出通道數
        temb_channels: int,  # 時間嵌入通道數
        resolution_idx: Optional[int] = None,  # 解析度索引,預設為 None
        dropout: float = 0.0,  # dropout 機率,預設為 0.0
        num_layers: int = 1,  # 層數,預設為 1
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 值,預設為 1e-6
        resnet_time_scale_shift: str = "default",  # ResNet 的時間縮放偏移,預設為 "default"
        resnet_act_fn: str = "swish",  # ResNet 的啟用函式,預設為 "swish"
        resnet_groups: int = 32,  # ResNet 的組數,預設為 32
        resnet_pre_norm: bool = True,  # 是否進行預歸一化,預設為 True
        output_scale_factor: float = 1.0,  # 輸出縮放因子,預設為 1.0
        add_upsample: bool = True,  # 是否新增上取樣層,預設為 True
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        resnets = []  # 初始化一個空列表用於儲存 ResNet 塊

        # 根據 num_layers 建立 ResNet 塊
        for i in range(num_layers):
            # 設定跳過通道數,如果是最後一層,則使用 in_channels,否則使用 out_channels
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # 設定 ResNet 輸入通道數,第一層使用 prev_output_channel,其餘層使用 out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            # 將 ResNet 塊新增到列表中
            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,  # 輸入通道數加上跳過通道數
                    out_channels=out_channels,  # 輸出通道數
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    eps=resnet_eps,  # epsilon 值
                    groups=resnet_groups,  # 組數
                    dropout=dropout,  # dropout 機率
                    time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入歸一化
                    non_linearity=resnet_act_fn,  # 非線性啟用函式
                    output_scale_factor=output_scale_factor,  # 輸出縮放因子
                    pre_norm=resnet_pre_norm,  # 預歸一化
                )
            )

        # 將 ResNet 塊列表轉換為 nn.ModuleList 以便於管理
        self.resnets = nn.ModuleList(resnets)

        # 如果需要新增上取樣層,則初始化上取樣模組
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None  # 不新增上取樣層,設定為 None

        self.gradient_checkpointing = False  # 初始化梯度檢查點標誌為 False
        self.resolution_idx = resolution_idx  # 儲存解析度索引

    # 定義前向傳播方法,接受輸入的隱藏狀態及其他引數
    def forward(
        self,
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態張量
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],  # 之前層的隱藏狀態元組
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
        upsample_size: Optional[int] = None,  # 可選的上取樣大小
        *args,  # 其他位置引數
        **kwargs,  # 其他關鍵字引數
    ) -> torch.Tensor:
        # 檢查是否傳入引數或 'scale' 引數
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 設定棄用訊息,提示使用者 'scale' 引數已被棄用
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 呼叫 deprecate 函式記錄棄用
            deprecate("scale", "1.0.0", deprecation_message)

        # 檢查 FreeU 是否啟用
        is_freeu_enabled = (
            getattr(self, "s1", None)  # 獲取屬性 s1
            and getattr(self, "s2", None)  # 獲取屬性 s2
            and getattr(self, "b1", None)  # 獲取屬性 b1
            and getattr(self, "b2", None)  # 獲取屬性 b2
        )

        # 遍歷所有 ResNet 模型
        for resnet in self.resnets:
            # 彈出最後的 ResNet 隱藏狀態
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]  # 更新元組,去掉最後一個元素

            # 如果 FreeU 被啟用,則僅對前兩個階段進行操作
            if is_freeu_enabled:
                # 呼叫 apply_freeu 函式處理隱藏狀態
                hidden_states, res_hidden_states = apply_freeu(
                    self.resolution_idx,  # 當前解析度索引
                    hidden_states,  # 當前隱藏狀態
                    res_hidden_states,  # ResNet 隱藏狀態
                    s1=self.s1,  # s1 屬性
                    s2=self.s2,  # s2 屬性
                    b1=self.b1,  # b1 屬性
                    b2=self.b2,  # b2 屬性
                )

            # 連線當前隱藏狀態和 ResNet 隱藏狀態
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            # 如果處於訓練模式且啟用梯度檢查點
            if self.training and self.gradient_checkpointing:
                # 建立自定義前向傳播函式
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)  # 呼叫模組的前向傳播

                    return custom_forward

                # 根據 PyTorch 版本選擇檢查點方式
                if is_torch_version(">=", "1.11.0"):
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),  # 使用自定義前向函式
                        hidden_states,  # 當前隱藏狀態
                        temb,  # 時間嵌入
                        use_reentrant=False  # 不使用可重入檢查點
                    )
                else:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),  # 使用自定義前向函式
                        hidden_states,  # 當前隱藏狀態
                        temb  # 時間嵌入
                    )
            else:
                # 如果不啟用檢查點,直接呼叫 ResNet
                hidden_states = resnet(hidden_states, temb)

        # 如果存在上取樣器,則遍歷進行上取樣
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)  # 呼叫上取樣器

        # 返回最終的隱藏狀態
        return hidden_states
# 定義一個 2D 上取樣解碼塊類,繼承自 nn.Module
class UpDecoderBlock2D(nn.Module):
    # 初始化方法,接受多個引數用於構造解碼塊
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        out_channels: int,  # 輸出通道數
        resolution_idx: Optional[int] = None,  # 解析度索引,預設為 None
        dropout: float = 0.0,  # dropout 機率,預設為 0
        num_layers: int = 1,  # 層數,預設為 1
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 值,預設為 1e-6
        resnet_time_scale_shift: str = "default",  # 時間尺度偏移型別,預設為 "default"
        resnet_act_fn: str = "swish",  # ResNet 的啟用函式,預設為 "swish"
        resnet_groups: int = 32,  # ResNet 的分組數,預設為 32
        resnet_pre_norm: bool = True,  # 是否在 ResNet 中預先歸一化,預設為 True
        output_scale_factor: float = 1.0,  # 輸出縮放因子,預設為 1.0
        add_upsample: bool = True,  # 是否新增上取樣層,預設為 True
        temb_channels: Optional[int] = None,  # 時間嵌入通道數,預設為 None
    ):
        # 呼叫父類構造方法
        super().__init__()
        # 初始化一個空的 ResNet 列表
        resnets = []

        # 根據層數建立相應數量的 ResNet 層
        for i in range(num_layers):
            # 第一個層使用輸入通道數,其餘層使用輸出通道數
            input_channels = in_channels if i == 0 else out_channels

            # 根據時間尺度偏移型別建立不同的 ResNet 塊
            if resnet_time_scale_shift == "spatial":
                resnets.append(
                    ResnetBlockCondNorm2D(  # 新增條件歸一化的 ResNet 塊
                        in_channels=input_channels,  # 輸入通道數
                        out_channels=out_channels,  # 輸出通道數
                        temb_channels=temb_channels,  # 時間嵌入通道數
                        eps=resnet_eps,  # epsilon 值
                        groups=resnet_groups,  # 分組數
                        dropout=dropout,  # dropout 機率
                        time_embedding_norm="spatial",  # 時間嵌入歸一化型別
                        non_linearity=resnet_act_fn,  # 啟用函式
                        output_scale_factor=output_scale_factor,  # 輸出縮放因子
                    )
                )
            else:
                resnets.append(
                    ResnetBlock2D(  # 新增普通的 ResNet 塊
                        in_channels=input_channels,  # 輸入通道數
                        out_channels=out_channels,  # 輸出通道數
                        temb_channels=temb_channels,  # 時間嵌入通道數
                        eps=resnet_eps,  # epsilon 值
                        groups=resnet_groups,  # 分組數
                        dropout=dropout,  # dropout 機率
                        time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入歸一化型別
                        non_linearity=resnet_act_fn,  # 啟用函式
                        output_scale_factor=output_scale_factor,  # 輸出縮放因子
                        pre_norm=resnet_pre_norm,  # 是否預先歸一化
                    )
                )

        # 將建立的 ResNet 塊儲存在 ModuleList 中,以便於管理
        self.resnets = nn.ModuleList(resnets)

        # 根據是否新增上取樣層初始化上取樣層列表
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None  # 如果不新增,則設為 None

        # 儲存解析度索引
        self.resolution_idx = resolution_idx

    # 定義前向傳播方法
    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        # 遍歷所有 ResNet 層進行前向傳播
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=temb)  # 更新隱藏狀態

        # 如果存在上取樣層,則遍歷進行上取樣
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)  # 更新隱藏狀態

        # 返回最終的隱藏狀態
        return hidden_states


# 定義一個注意力上取樣解碼塊類,繼承自 nn.Module
class AttnUpDecoderBlock2D(nn.Module):
    # 初始化類的建構函式,定義各引數
        def __init__(
            # 輸入通道數,決定輸入資料的特徵維度
            self,
            in_channels: int,
            # 輸出通道數,決定輸出資料的特徵維度
            out_channels: int,
            # 解析度索引,選擇特定解析度(可選)
            resolution_idx: Optional[int] = None,
            # dropout比率,用於防止過擬合,預設為0
            dropout: float = 0.0,
            # 網路層數,決定模型的深度,預設為1
            num_layers: int = 1,
            # ResNet的epsilon值,防止分母為0的情況,預設為1e-6
            resnet_eps: float = 1e-6,
            # ResNet的時間尺度偏移設定,預設為"default"
            resnet_time_scale_shift: str = "default",
            # ResNet的啟用函式型別,預設為"swish"
            resnet_act_fn: str = "swish",
            # ResNet中的組數,影響計算和模型複雜度,預設為32
            resnet_groups: int = 32,
            # 是否在ResNet中使用預歸一化,預設為True
            resnet_pre_norm: bool = True,
            # 注意力頭的維度,影響注意力機制的計算,預設為1
            attention_head_dim: int = 1,
            # 輸出縮放因子,調整輸出的大小,預設為1.0
            output_scale_factor: float = 1.0,
            # 是否新增上取樣層,影響模型的結構,預設為True
            add_upsample: bool = True,
            # 時間嵌入通道數(可選),用於特定的時間資訊表示
            temb_channels: Optional[int] = None,
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        # 初始化儲存殘差塊的列表
        resnets = []
        # 初始化儲存注意力層的列表
        attentions = []

        # 如果未指定注意力頭維度,則發出警告並使用輸出通道數作為預設值
        if attention_head_dim is None:
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
            )
            # 將注意力頭維度設定為輸出通道數
            attention_head_dim = out_channels

        # 遍歷每一層以構建殘差塊和注意力層
        for i in range(num_layers):
            # 如果是第一層,則輸入通道為輸入通道數,否則為輸出通道數
            input_channels = in_channels if i == 0 else out_channels

            # 如果時間尺度偏移為"spatial",則使用條件歸一化的殘差塊
            if resnet_time_scale_shift == "spatial":
                resnets.append(
                    ResnetBlockCondNorm2D(
                        # 輸入通道數
                        in_channels=input_channels,
                        # 輸出通道數
                        out_channels=out_channels,
                        # 時間嵌入通道數
                        temb_channels=temb_channels,
                        # 殘差塊的epsilon引數
                        eps=resnet_eps,
                        # 組歸一化的組數
                        groups=resnet_groups,
                        # dropout比例
                        dropout=dropout,
                        # 時間嵌入的歸一化方式
                        time_embedding_norm="spatial",
                        # 非線性啟用函式
                        non_linearity=resnet_act_fn,
                        # 輸出縮放因子
                        output_scale_factor=output_scale_factor,
                    )
                )
            else:
                # 否則使用普通的2D殘差塊
                resnets.append(
                    ResnetBlock2D(
                        # 輸入通道數
                        in_channels=input_channels,
                        # 輸出通道數
                        out_channels=out_channels,
                        # 時間嵌入通道數
                        temb_channels=temb_channels,
                        # 殘差塊的epsilon引數
                        eps=resnet_eps,
                        # 組歸一化的組數
                        groups=resnet_groups,
                        # dropout比例
                        dropout=dropout,
                        # 時間嵌入的歸一化方式
                        time_embedding_norm=resnet_time_scale_shift,
                        # 非線性啟用函式
                        non_linearity=resnet_act_fn,
                        # 輸出縮放因子
                        output_scale_factor=output_scale_factor,
                        # 是否使用預歸一化
                        pre_norm=resnet_pre_norm,
                    )
                )

            # 新增註意力層
            attentions.append(
                Attention(
                    # 輸出通道數
                    out_channels,
                    # 計算注意力頭數
                    heads=out_channels // attention_head_dim,
                    # 注意力頭維度
                    dim_head=attention_head_dim,
                    # 輸出縮放因子
                    rescale_output_factor=output_scale_factor,
                    # epsilon引數
                    eps=resnet_eps,
                    # 組歸一化的組數(如果不是空間歸一化)
                    norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None,
                    # 空間歸一化維度(如果是空間歸一化)
                    spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                    # 是否使用殘差連線
                    residual_connection=True,
                    # 是否使用偏置
                    bias=True,
                    # 是否使用上取樣的softmax
                    upcast_softmax=True,
                    # 從過時的注意力塊建立
                    _from_deprecated_attn_block=True,
                )
            )

        # 將注意力層轉換為模組列表
        self.attentions = nn.ModuleList(attentions)
        # 將殘差塊轉換為模組列表
        self.resnets = nn.ModuleList(resnets)

        # 如果需要新增上取樣層
        if add_upsample:
            # 建立上取樣層的模組列表
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            # 如果不需要上取樣,則將其設定為None
            self.upsamplers = None

        # 設定解析度索引
        self.resolution_idx = resolution_idx
    # 定義前向傳播函式,接收隱藏狀態和可選的時間嵌入,返回處理後的隱藏狀態
    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        # 遍歷殘差網路和注意力模組的組合
        for resnet, attn in zip(self.resnets, self.attentions):
            # 將當前隱藏狀態輸入到殘差網路中,可能包含時間嵌入
            hidden_states = resnet(hidden_states, temb=temb)
            # 將殘差網路的輸出輸入到注意力模組中,可能包含時間嵌入
            hidden_states = attn(hidden_states, temb=temb)
    
        # 如果上取樣模組不為空,則執行上取樣操作
        if self.upsamplers is not None:
            # 遍歷所有上取樣模組
            for upsampler in self.upsamplers:
                # 將當前隱藏狀態輸入到上取樣模組中
                hidden_states = upsampler(hidden_states)
    
        # 返回最終處理後的隱藏狀態
        return hidden_states
# 定義一個名為 AttnSkipUpBlock2D 的類,繼承自 nn.Module
class AttnSkipUpBlock2D(nn.Module):
    # 初始化方法,定義該類的屬性
    def __init__(
        # 輸入通道數
        in_channels: int,
        # 前一個輸出通道數
        prev_output_channel: int,
        # 輸出通道數
        out_channels: int,
        # 時間嵌入通道數
        temb_channels: int,
        # 可選的解析度索引
        resolution_idx: Optional[int] = None,
        # dropout 機率
        dropout: float = 0.0,
        # 層數
        num_layers: int = 1,
        # ResNet 的 epsilon 值
        resnet_eps: float = 1e-6,
        # ResNet 時間縮放偏移設定
        resnet_time_scale_shift: str = "default",
        # ResNet 啟用函式型別
        resnet_act_fn: str = "swish",
        # 是否使用 ResNet 預歸一化
        resnet_pre_norm: bool = True,
        # 注意力頭的維度
        attention_head_dim: int = 1,
        # 輸出縮放因子
        output_scale_factor: float = np.sqrt(2.0),
        # 是否新增上取樣
        add_upsample: bool = True,
    ):
        # 初始化父類
        super().__init__()
        # 此處應有具體的初始化程式碼(如層的定義),略去

    # 定義前向傳播方法
    def forward(
        # 隱藏狀態輸入
        hidden_states: torch.Tensor,
        # 之前的隱藏狀態元組
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        # 可選的時間嵌入
        temb: Optional[torch.Tensor] = None,
        # 可選的跳躍樣本
        skip_sample=None,
        # 額外引數
        *args,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 檢查 args 和 kwargs 是否包含 scale 引數
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 定義棄用資訊
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 呼叫棄用函式
            deprecate("scale", "1.0.0", deprecation_message)

        # 遍歷 ResNet 層
        for resnet in self.resnets:
            # 從元組中提取最近的隱藏狀態
            res_hidden_states = res_hidden_states_tuple[-1]
            # 更新隱藏狀態元組,去掉最近的隱藏狀態
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            # 將當前隱藏狀態與提取的隱藏狀態拼接在一起
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            # 透過 ResNet 層處理隱藏狀態
            hidden_states = resnet(hidden_states, temb)

        # 透過注意力層處理隱藏狀態
        hidden_states = self.attentions[0](hidden_states)

        # 檢查跳躍樣本是否為 None
        if skip_sample is not None:
            # 如果不是,則透過上取樣層處理跳躍樣本
            skip_sample = self.upsampler(skip_sample)
        else:
            # 如果是,則將跳躍樣本設為 0
            skip_sample = 0

        # 檢查 ResNet 上取樣層是否存在
        if self.resnet_up is not None:
            # 透過跳躍歸一化層處理隱藏狀態
            skip_sample_states = self.skip_norm(hidden_states)
            # 應用啟用函式
            skip_sample_states = self.act(skip_sample_states)
            # 透過跳躍卷積層處理狀態
            skip_sample_states = self.skip_conv(skip_sample_states)

            # 更新跳躍樣本
            skip_sample = skip_sample + skip_sample_states

            # 透過 ResNet 上取樣層處理隱藏狀態
            hidden_states = self.resnet_up(hidden_states, temb)

        # 返回處理後的隱藏狀態和跳躍樣本
        return hidden_states, skip_sample


# 定義一個名為 SkipUpBlock2D 的類,繼承自 nn.Module
class SkipUpBlock2D(nn.Module):
    # 初始化方法,定義該類的屬性
    def __init__(
        # 輸入通道數
        in_channels: int,
        # 前一個輸出通道數
        prev_output_channel: int,
        # 輸出通道數
        out_channels: int,
        # 時間嵌入通道數
        temb_channels: int,
        # 可選的解析度索引
        resolution_idx: Optional[int] = None,
        # dropout 機率
        dropout: float = 0.0,
        # 層數
        num_layers: int = 1,
        # ResNet 的 epsilon 值
        resnet_eps: float = 1e-6,
        # ResNet 時間縮放偏移設定
        resnet_time_scale_shift: str = "default",
        # ResNet 啟用函式型別
        resnet_act_fn: str = "swish",
        # 是否使用 ResNet 預歸一化
        resnet_pre_norm: bool = True,
        # 輸出縮放因子
        output_scale_factor: float = np.sqrt(2.0),
        # 是否新增上取樣
        add_upsample: bool = True,
        # 上取樣填充大小
        upsample_padding: int = 1,
    ):
        # 初始化父類
        super().__init__()
        # 此處應有具體的初始化程式碼(如層的定義),略去
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        # 建立一個空的 ModuleList 用於儲存 ResnetBlock2D 層
        self.resnets = nn.ModuleList([])

        # 根據 num_layers 的數量來新增 ResnetBlock2D 層
        for i in range(num_layers):
            # 計算跳過通道數,如果是最後一層則使用 in_channels,否則使用 out_channels
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # 確定當前 ResNet 塊的輸入通道數
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            # 向 resnets 列表中新增一個新的 ResnetBlock2D 例項
            self.resnets.append(
                ResnetBlock2D(
                    # 輸入通道數為當前層輸入通道加跳過通道數
                    in_channels=resnet_in_channels + res_skip_channels,
                    # 輸出通道數
                    out_channels=out_channels,
                    # 時間嵌入通道數
                    temb_channels=temb_channels,
                    # 歸一化的 epsilon 值
                    eps=resnet_eps,
                    # 分組數為輸入通道數的一部分,最多為 32
                    groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
                    # 輸出分組數,同樣最多為 32
                    groups_out=min(out_channels // 4, 32),
                    # dropout 機率
                    dropout=dropout,
                    # 時間嵌入的歸一化方式
                    time_embedding_norm=resnet_time_scale_shift,
                    # 啟用函式型別
                    non_linearity=resnet_act_fn,
                    # 輸出縮放因子
                    output_scale_factor=output_scale_factor,
                    # 是否使用預歸一化
                    pre_norm=resnet_pre_norm,
                )
            )

        # 初始化上取樣層
        self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
        # 如果需要新增上取樣層
        if add_upsample:
            # 新增一個上取樣的 ResnetBlock2D
            self.resnet_up = ResnetBlock2D(
                # 輸入通道數
                in_channels=out_channels,
                # 輸出通道數
                out_channels=out_channels,
                # 時間嵌入通道數
                temb_channels=temb_channels,
                # 歸一化的 epsilon 值
                eps=resnet_eps,
                # 分組數,最多為 32
                groups=min(out_channels // 4, 32),
                # 輸出分組數,最多為 32
                groups_out=min(out_channels // 4, 32),
                # dropout 機率
                dropout=dropout,
                # 時間嵌入的歸一化方式
                time_embedding_norm=resnet_time_scale_shift,
                # 啟用函式型別
                non_linearity=resnet_act_fn,
                # 輸出縮放因子
                output_scale_factor=output_scale_factor,
                # 是否使用預歸一化
                pre_norm=resnet_pre_norm,
                # 是否在快捷路徑中使用
                use_in_shortcut=True,
                # 標記為上取樣
                up=True,
                # 使用 FIR 卷積核
                kernel="fir",
            )
            # 定義跳過連線的卷積層,將輸出通道數對映到 3 通道
            self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            # 定義跳過連線的歸一化層
            self.skip_norm = torch.nn.GroupNorm(
                num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
            )
            # 定義啟用函式為 SiLU
            self.act = nn.SiLU()
        else:
            # 如果不新增上取樣層,則將相關屬性設為 None
            self.resnet_up = None
            self.skip_conv = None
            self.skip_norm = None
            self.act = None

        # 儲存解析度索引
        self.resolution_idx = resolution_idx

    # 定義前向傳播方法
    def forward(
        # 定義前向傳播輸入引數
        hidden_states: torch.Tensor,
        # 儲存殘差隱藏狀態的元組
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        # 可選的時間嵌入張量
        temb: Optional[torch.Tensor] = None,
        # 跳過取樣的可選引數
        skip_sample=None,
        # 可變位置引數
        *args,
        # 可變關鍵字引數
        **kwargs,
    # 函式返回兩個張量,表示隱藏狀態和跳過的樣本
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 檢查引數長度或關鍵字引數 "scale" 是否存在
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 建立棄用訊息,提醒使用者 "scale" 引數將被忽略
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 呼叫棄用函式,傳遞引數資訊
            deprecate("scale", "1.0.0", deprecation_message)
    
        # 遍歷自定義的 ResNet 模組
        for resnet in self.resnets:
            # 從隱藏狀態元組中彈出最後一個 ResNet 隱藏狀態
            res_hidden_states = res_hidden_states_tuple[-1]
            # 更新隱藏狀態元組,去掉最後一個元素
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            # 將當前的隱藏狀態與 ResNet 隱藏狀態連線
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
    
            # 透過 ResNet 模組更新隱藏狀態
            hidden_states = resnet(hidden_states, temb)
    
        # 檢查跳過樣本是否存在
        if skip_sample is not None:
            # 如果存在,使用上取樣器處理跳過樣本
            skip_sample = self.upsampler(skip_sample)
        else:
            # 否則,將跳過樣本初始化為 0
            skip_sample = 0
    
        # 檢查是否存在 ResNet 上取樣模組
        if self.resnet_up is not None:
            # 對隱藏狀態應用歸一化
            skip_sample_states = self.skip_norm(hidden_states)
            # 對歸一化結果應用啟用函式
            skip_sample_states = self.act(skip_sample_states)
            # 對啟用結果應用卷積操作
            skip_sample_states = self.skip_conv(skip_sample_states)
    
            # 將跳過樣本與處理後的狀態相加
            skip_sample = skip_sample + skip_sample_states
    
            # 透過 ResNet 上取樣模組更新隱藏狀態
            hidden_states = self.resnet_up(hidden_states, temb)
    
        # 返回最終的隱藏狀態和跳過樣本
        return hidden_states, skip_sample
# 定義一個 2D 上取樣的 ResNet 塊,繼承自 nn.Module
class ResnetUpsampleBlock2D(nn.Module):
    # 初始化方法,設定網路引數
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        prev_output_channel: int,  # 前一層輸出的通道數
        out_channels: int,  # 輸出通道數
        temb_channels: int,  # 時間嵌入的通道數
        resolution_idx: Optional[int] = None,  # 解析度索引(可選)
        dropout: float = 0.0,  # dropout 比例
        num_layers: int = 1,  # ResNet 層數
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 值
        resnet_time_scale_shift: str = "default",  # 時間縮放偏移方式
        resnet_act_fn: str = "swish",  # 啟用函式型別
        resnet_groups: int = 32,  # 組數
        resnet_pre_norm: bool = True,  # 是否使用預歸一化
        output_scale_factor: float = 1.0,  # 輸出縮放因子
        add_upsample: bool = True,  # 是否新增上取樣層
        skip_time_act: bool = False,  # 是否跳過時間啟用
    ):
        # 呼叫父類建構函式
        super().__init__()
        # 初始化一個空的 ResNet 列表
        resnets = []

        # 遍歷層數,建立每一層的 ResNet 塊
        for i in range(num_layers):
            # 確定跳過通道數,如果是最後一層,則使用輸入通道數,否則使用輸出通道數
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # 確定當前層的輸入通道數
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            # 將 ResNet 塊新增到列表中
            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,  # 輸入通道數
                    out_channels=out_channels,  # 輸出通道數
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    eps=resnet_eps,  # epsilon 值
                    groups=resnet_groups,  # 組數
                    dropout=dropout,  # dropout 比例
                    time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入歸一化方式
                    non_linearity=resnet_act_fn,  # 非線性啟用函式
                    output_scale_factor=output_scale_factor,  # 輸出縮放因子
                    pre_norm=resnet_pre_norm,  # 是否使用預歸一化
                    skip_time_act=skip_time_act,  # 是否跳過時間啟用
                )
            )

        # 將 ResNet 列表轉換為 nn.ModuleList,便於 PyTorch 管理
        self.resnets = nn.ModuleList(resnets)

        # 如果需要新增上取樣層,則建立上取樣模組
        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [
                    ResnetBlock2D(
                        in_channels=out_channels,  # 輸入通道數
                        out_channels=out_channels,  # 輸出通道數
                        temb_channels=temb_channels,  # 時間嵌入通道數
                        eps=resnet_eps,  # epsilon 值
                        groups=resnet_groups,  # 組數
                        dropout=dropout,  # dropout 比例
                        time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入歸一化方式
                        non_linearity=resnet_act_fn,  # 非線性啟用函式
                        output_scale_factor=output_scale_factor,  # 輸出縮放因子
                        pre_norm=resnet_pre_norm,  # 是否使用預歸一化
                        skip_time_act=skip_time_act,  # 是否跳過時間啟用
                        up=True,  # 標記為上取樣塊
                    )
                ]
            )
        else:
            # 如果不需要上取樣,則將其設為 None
            self.upsamplers = None

        # 初始化梯度檢查點為 False
        self.gradient_checkpointing = False
        # 設定解析度索引
        self.resolution_idx = resolution_idx

    # 定義前向傳播方法
    def forward(
        self,
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],  # 額外的隱藏狀態元組
        temb: Optional[torch.Tensor] = None,  # 時間嵌入(可選)
        upsample_size: Optional[int] = None,  # 上取樣大小(可選)
        *args,  # 額外的位置引數
        **kwargs,  # 額外的關鍵字引數
    ) -> torch.Tensor:  # 定義一個返回 torch.Tensor 的函式
        # 檢查引數列表是否包含引數或 "scale" 是否不為 None
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 定義棄用資訊,說明 "scale" 引數將被忽略
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 呼叫 deprecate 函式以記錄 "scale" 引數的棄用
            deprecate("scale", "1.0.0", deprecation_message)

        # 遍歷儲存的 ResNet 模型列表
        for resnet in self.resnets:
            # 從隱藏狀態元組中彈出最後一個 ResNet 隱藏狀態
            res_hidden_states = res_hidden_states_tuple[-1]
            # 更新隱藏狀態元組,去掉最後一個元素
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            # 將當前的隱藏狀態與 ResNet 隱藏狀態在指定維度上拼接
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            # 如果處於訓練模式且開啟了梯度檢查點
            if self.training and self.gradient_checkpointing:

                # 定義一個建立自定義前向傳播函式的內部函式
                def create_custom_forward(module):
                    # 定義自定義前向傳播函式,呼叫模組
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                # 檢查 PyTorch 版本是否大於等於 1.11.0
                if is_torch_version(">=", "1.11.0"):
                    # 使用梯度檢查點功能進行前向傳播
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
                    )
                else:
                    # 使用梯度檢查點功能進行前向傳播,不使用可重入選項
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet), hidden_states, temb
                    )
            else:
                # 直接呼叫 ResNet 模型進行前向傳播
                hidden_states = resnet(hidden_states, temb)

        # 如果存在上取樣器
        if self.upsamplers is not None:
            # 遍歷每個上取樣器
            for upsampler in self.upsamplers:
                # 呼叫上取樣器進行上取樣處理
                hidden_states = upsampler(hidden_states, temb)

        # 返回最終的隱藏狀態
        return hidden_states
# 定義一個簡單的二維交叉注意力上取樣模組,繼承自 nn.Module
class SimpleCrossAttnUpBlock2D(nn.Module):
    # 初始化函式,接受多個引數來配置模組
    def __init__(
        # 輸入通道數
        in_channels: int,
        # 輸出通道數
        out_channels: int,
        # 上一個輸出通道數
        prev_output_channel: int,
        # 時間嵌入通道數
        temb_channels: int,
        # 可選的解析度索引
        resolution_idx: Optional[int] = None,
        # Dropout 機率
        dropout: float = 0.0,
        # 層數
        num_layers: int = 1,
        # ResNet 的 epsilon 值
        resnet_eps: float = 1e-6,
        # ResNet 時間縮放偏移
        resnet_time_scale_shift: str = "default",
        # ResNet 啟用函式
        resnet_act_fn: str = "swish",
        # ResNet 中的組數
        resnet_groups: int = 32,
        # 是否在 ResNet 中使用預歸一化
        resnet_pre_norm: bool = True,
        # 注意力頭的維度
        attention_head_dim: int = 1,
        # 交叉注意力的維度
        cross_attention_dim: int = 1280,
        # 輸出縮放因子
        output_scale_factor: float = 1.0,
        # 是否新增上取樣
        add_upsample: bool = True,
        # 是否跳過時間啟用
        skip_time_act: bool = False,
        # 是否僅使用交叉注意力
        only_cross_attention: bool = False,
        # 可選的交叉注意力歸一化方式
        cross_attention_norm: Optional[str] = None,
    # 初始化函式
        ):
            # 呼叫父類的初始化方法
            super().__init__()
            # 建立一個空列表用於儲存 ResNet 模組
            resnets = []
            # 建立一個空列表用於儲存 Attention 模組
            attentions = []
    
            # 設定是否使用交叉注意力
            self.has_cross_attention = True
            # 設定每個注意力頭的維度
            self.attention_head_dim = attention_head_dim
    
            # 計算注意力頭的數量
            self.num_heads = out_channels // self.attention_head_dim
    
            # 遍歷每一層以構建 ResNet 模組
            for i in range(num_layers):
                # 設定跳躍連線通道數
                res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
                # 設定當前 ResNet 輸入通道數
                resnet_in_channels = prev_output_channel if i == 0 else out_channels
    
                # 新增 ResNet 塊到列表
                resnets.append(
                    ResnetBlock2D(
                        # 設定輸入通道數為 ResNet 輸入通道加上跳躍連線通道
                        in_channels=resnet_in_channels + res_skip_channels,
                        # 設定輸出通道數
                        out_channels=out_channels,
                        # 設定時間嵌入通道數
                        temb_channels=temb_channels,
                        # 設定小常數用於數值穩定性
                        eps=resnet_eps,
                        # 設定分組數
                        groups=resnet_groups,
                        # 設定 dropout 比例
                        dropout=dropout,
                        # 設定時間嵌入的歸一化方式
                        time_embedding_norm=resnet_time_scale_shift,
                        # 設定啟用函式
                        non_linearity=resnet_act_fn,
                        # 設定輸出縮放因子
                        output_scale_factor=output_scale_factor,
                        # 設定是否預歸一化
                        pre_norm=resnet_pre_norm,
                        # 設定是否跳過時間啟用
                        skip_time_act=skip_time_act,
                    )
                )
    
                # 根據是否支援縮放點積注意力選擇處理器
                processor = (
                    AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
                )
    
                # 新增 Attention 模組到列表
                attentions.append(
                    Attention(
                        # 設定查詢維度
                        query_dim=out_channels,
                        # 設定交叉注意力維度
                        cross_attention_dim=out_channels,
                        # 設定頭的數量
                        heads=self.num_heads,
                        # 設定每個頭的維度
                        dim_head=self.attention_head_dim,
                        # 設定額外的 KV 投影維度
                        added_kv_proj_dim=cross_attention_dim,
                        # 設定歸一化的組數
                        norm_num_groups=resnet_groups,
                        # 設定是否使用偏置
                        bias=True,
                        # 設定是否上溯 softmax
                        upcast_softmax=True,
                        # 設定是否僅使用交叉注意力
                        only_cross_attention=only_cross_attention,
                        # 設定交叉注意力的歸一化方式
                        cross_attention_norm=cross_attention_norm,
                        # 設定處理器
                        processor=processor,
                    )
                )
            # 將 Attention 模組列表轉換為 ModuleList
            self.attentions = nn.ModuleList(attentions)
            # 將 ResNet 模組列表轉換為 ModuleList
            self.resnets = nn.ModuleList(resnets)
    
            # 如果需要新增上取樣模組
            if add_upsample:
                # 建立一個上取樣的 ResNet 模組列表
                self.upsamplers = nn.ModuleList(
                    [
                        ResnetBlock2D(
                            # 設定上取樣的輸入和輸出通道數
                            in_channels=out_channels,
                            out_channels=out_channels,
                            # 設定時間嵌入通道數
                            temb_channels=temb_channels,
                            # 設定小常數用於數值穩定性
                            eps=resnet_eps,
                            # 設定分組數
                            groups=resnet_groups,
                            # 設定 dropout 比例
                            dropout=dropout,
                            # 設定時間嵌入的歸一化方式
                            time_embedding_norm=resnet_time_scale_shift,
                            # 設定啟用函式
                            non_linearity=resnet_act_fn,
                            # 設定輸出縮放因子
                            output_scale_factor=output_scale_factor,
                            # 設定是否預歸一化
                            pre_norm=resnet_pre_norm,
                            # 設定是否跳過時間啟用
                            skip_time_act=skip_time_act,
                            # 設定為上取樣模式
                            up=True,
                        )
                    ]
                )
            else:
                # 如果不需要上取樣,則將上取樣模組設定為 None
                self.upsamplers = None
    
            # 初始化梯度檢查點設定為 False
            self.gradient_checkpointing = False
            # 設定解析度索引
            self.resolution_idx = resolution_idx
    # 定義前向傳播方法,接收多個輸入引數
        def forward(
            self,
            # 當前隱藏狀態的張量
            hidden_states: torch.Tensor,
            # 包含殘差隱藏狀態的元組
            res_hidden_states_tuple: Tuple[torch.Tensor, ...],
            # 可選的時間嵌入張量
            temb: Optional[torch.Tensor] = None,
            # 可選的編碼器隱藏狀態張量
            encoder_hidden_states: Optional[torch.Tensor] = None,
            # 可選的上取樣大小
            upsample_size: Optional[int] = None,
            # 可選的注意力掩碼張量
            attention_mask: Optional[torch.Tensor] = None,
            # 可選的跨注意力引數字典
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 可選的編碼器注意力掩碼張量
            encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:  # 定義函式的返回型別為 torch.Tensor
        # 如果 cross_attention_kwargs 為 None,則初始化為空字典
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        # 如果 cross_attention_kwargs 中的 "scale" 存在,發出警告,提示已棄用
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # 如果 attention_mask 為 None
        if attention_mask is None:
            # 如果 encoder_hidden_states 已定義,則進行交叉注意力,使用 encoder_attention_mask
            mask = None if encoder_hidden_states is None else encoder_attention_mask
        else:
            # 如果 attention_mask 已定義,不檢查 encoder_attention_mask
            # 這樣做是為了相容 UnCLIP,後者使用 'attention_mask' 引數作為交叉注意力掩碼
            # TODO: UnCLIP 應該透過 encoder_attention_mask 參數列達交叉注意力掩碼,而不是透過 attention_mask
            #       那樣可以簡化整個 if/else 語句塊
            mask = attention_mask  # 使用提供的 attention_mask

        # 遍歷 self.resnets 和 self.attentions 的元素
        for resnet, attn in zip(self.resnets, self.attentions):
            # 獲取最後一項的殘差隱藏狀態
            res_hidden_states = res_hidden_states_tuple[-1]
            # 更新 res_hidden_states_tuple,去掉最後一項
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            # 將當前的 hidden_states 和殘差隱藏狀態在維度 1 上連線
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            # 如果處於訓練模式並且開啟了梯度檢查點
            if self.training and self.gradient_checkpointing:

                # 定義建立自定義前向傳播函式的內部函式
                def create_custom_forward(module, return_dict=None):
                    # 定義自定義前向傳播函式
                    def custom_forward(*inputs):
                        # 如果 return_dict 不為 None,則返回字典形式的結果
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)  # 否則直接返回結果

                    return custom_forward  # 返回自定義前向傳播函式

                # 使用檢查點機制進行前向傳播,節省記憶體
                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                # 進行注意力計算
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=mask,
                    **cross_attention_kwargs,
                )
            else:
                # 直接進行前向傳播計算
                hidden_states = resnet(hidden_states, temb)

                # 進行注意力計算
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=mask,
                    **cross_attention_kwargs,
                )

        # 如果存在上取樣器
        if self.upsamplers is not None:
            # 遍歷所有上取樣器
            for upsampler in self.upsamplers:
                # 進行上取樣操作
                hidden_states = upsampler(hidden_states, temb)

        # 返回最終的隱藏狀態
        return hidden_states
# 定義一個名為 KUpBlock2D 的神經網路模組,繼承自 nn.Module
class KUpBlock2D(nn.Module):
    # 初始化函式,設定網路層的引數
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        out_channels: int,  # 輸出通道數
        temb_channels: int,  # 時間嵌入通道數
        resolution_idx: int,  # 解析度索引
        dropout: float = 0.0,  # dropout 比例,預設 0
        num_layers: int = 5,  # 網路層數,預設 5
        resnet_eps: float = 1e-5,  # ResNet 中的 epsilon 值
        resnet_act_fn: str = "gelu",  # ResNet 中使用的啟用函式,預設 "gelu"
        resnet_group_size: Optional[int] = 32,  # ResNet 中的組大小,預設 32
        add_upsample: bool = True,  # 是否新增上取樣層,預設 True
    ):
        # 呼叫父類初始化函式
        super().__init__()
        # 建立一個空的列表,用於存放 ResNet 模組
        resnets = []
        # 定義輸入通道的數量,設定為輸出通道的兩倍
        k_in_channels = 2 * out_channels
        # 定義輸出通道的數量
        k_out_channels = in_channels
        # 減少層數,以適應後續的迴圈
        num_layers = num_layers - 1

        # 建立指定層數的 ResNet 模組
        for i in range(num_layers):
            # 第一層的輸入通道為 k_in_channels,其餘層為 out_channels
            in_channels = k_in_channels if i == 0 else out_channels
            # 計算組的數量
            groups = in_channels // resnet_group_size
            # 計算輸出組的數量
            groups_out = out_channels // resnet_group_size

            # 將 ResNet 模組新增到列表中
            resnets.append(
                ResnetBlockCondNorm2D(
                    in_channels=in_channels,  # 輸入通道數
                    out_channels=k_out_channels if (i == num_layers - 1) else out_channels,  # 輸出通道數
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    eps=resnet_eps,  # epsilon 值
                    groups=groups,  # 輸入組數量
                    groups_out=groups_out,  # 輸出組數量
                    dropout=dropout,  # dropout 比例
                    non_linearity=resnet_act_fn,  # 啟用函式
                    time_embedding_norm="ada_group",  # 時間嵌入規範化方式
                    conv_shortcut_bias=False,  # 是否使用卷積快捷連線的偏置
                )
            )

        # 將 ResNet 模組列表轉換為 nn.ModuleList
        self.resnets = nn.ModuleList(resnets)

        # 根據是否新增上取樣層來初始化上取樣模組
        if add_upsample:
            # 如果新增上取樣,建立上取樣層列表
            self.upsamplers = nn.ModuleList([KUpsample2D()])
        else:
            # 如果不新增上取樣,設定為 None
            self.upsamplers = None

        # 初始化梯度檢查點為 False
        self.gradient_checkpointing = False
        # 儲存解析度索引
        self.resolution_idx = resolution_idx

    # 定義前向傳播函式
    def forward(
        self,
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態張量
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],  # 傳入的隱藏狀態元組
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
        upsample_size: Optional[int] = None,  # 可選的上取樣大小
        *args,  # 額外的位置引數
        **kwargs,  # 額外的關鍵字引數
    # 定義返回值為 torch.Tensor 的函式結束部分
        ) -> torch.Tensor:
            # 檢查傳入引數是否包含 args 或者 kwargs 中的 "scale" 引數
            if len(args) > 0 or kwargs.get("scale", None) is not None:
                # 定義棄用的提示資訊,告知使用者應刪除 "scale" 引數
                deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
                # 呼叫 deprecate 函式記錄棄用資訊
                deprecate("scale", "1.0.0", deprecation_message)
    
            # 取 res_hidden_states_tuple 的最後一個元素
            res_hidden_states_tuple = res_hidden_states_tuple[-1]
            # 如果 res_hidden_states_tuple 不為 None,則將其與 hidden_states 拼接
            if res_hidden_states_tuple is not None:
                hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
    
            # 遍歷 self.resnets 列表中的每個 resnet 模組
            for resnet in self.resnets:
                # 如果處於訓練模式且開啟梯度檢查點功能
                if self.training and self.gradient_checkpointing:
    
                    # 定義一個建立自定義前向函式的內部函式
                    def create_custom_forward(module):
                        # 定義自定義前向函式,呼叫模組處理輸入
                        def custom_forward(*inputs):
                            return module(*inputs)
    
                        return custom_forward
    
                    # 檢查 PyTorch 版本是否大於等於 1.11.0
                    if is_torch_version(">=", "1.11.0"):
                        # 使用檢查點功能進行前向傳播,避免計算圖儲存
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
                        )
                    else:
                        # 使用檢查點功能進行前向傳播
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb
                        )
                else:
                    # 在非訓練模式下直接透過 resnet 處理 hidden_states
                    hidden_states = resnet(hidden_states, temb)
    
            # 如果存在 upsamplers,則遍歷每個 upsampler
            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    # 透過 upsampler 處理 hidden_states
                    hidden_states = upsampler(hidden_states)
    
            # 返回處理後的 hidden_states
            return hidden_states
# 定義一個 KCrossAttnUpBlock2D 類,繼承自 nn.Module
class KCrossAttnUpBlock2D(nn.Module):
    # 初始化方法,定義該類的屬性
    def __init__(
        # 輸入通道數
        in_channels: int,
        # 輸出通道數
        out_channels: int,
        # 額外的嵌入通道數
        temb_channels: int,
        # 當前解析度索引
        resolution_idx: int,
        # dropout 機率,預設為 0.0
        dropout: float = 0.0,
        # 殘差網路的層數,預設為 4
        num_layers: int = 4,
        # 殘差網路的 epsilon 值,預設為 1e-5
        resnet_eps: float = 1e-5,
        # 殘差網路的啟用函式型別,預設為 "gelu"
        resnet_act_fn: str = "gelu",
        # 殘差網路的分組大小,預設為 32
        resnet_group_size: int = 32,
        # 注意力的維度,預設為 1
        attention_head_dim: int = 1,  # attention dim_head
        # 交叉注意力的維度,預設為 768
        cross_attention_dim: int = 768,
        # 是否新增上取樣,預設為 True
        add_upsample: bool = True,
        # 是否上溢注意力,預設為 False
        upcast_attention: bool = False,
    ):
        # 呼叫父類的建構函式
        super().__init__()
        # 初始化一個空列表,用於儲存 ResNet 塊
        resnets = []
        # 初始化一個空列表,用於儲存注意力塊
        attentions = []

        # 判斷是否為第一個塊:輸入、輸出和時間嵌入通道是否相等
        is_first_block = in_channels == out_channels == temb_channels
        # 判斷是否為中間塊:輸入和輸出通道是否不相等
        is_middle_block = in_channels != out_channels
        # 如果是第一個塊,設定為 True 以新增自注意力
        add_self_attention = True if is_first_block else False

        # 設定跨注意力的標誌為 True
        self.has_cross_attention = True
        # 儲存注意力頭的維度
        self.attention_head_dim = attention_head_dim

        # 定義當前塊的輸入通道,若是第一個塊則使用輸出通道,否則使用兩倍的輸出通道
        k_in_channels = out_channels if is_first_block else 2 * out_channels
        # 當前塊的輸出通道為輸入通道
        k_out_channels = in_channels

        # 減少層數以計算迴圈中的層數
        num_layers = num_layers - 1

        # 根據層數迴圈建立 ResNet 塊和注意力塊
        for i in range(num_layers):
            # 第一個層使用 k_in_channels,後續層使用 out_channels
            in_channels = k_in_channels if i == 0 else out_channels
            # 計算組數,以便在 ResNet 中分組
            groups = in_channels // resnet_group_size
            groups_out = out_channels // resnet_group_size

            # 判斷是否為中間塊並且是最後一層,設定卷積的輸出通道
            if is_middle_block and (i == num_layers - 1):
                conv_2d_out_channels = k_out_channels
            else:
                # 如果不是,設定為 None
                conv_2d_out_channels = None

            # 建立並新增 ResNet 塊到 resnets 列表
            resnets.append(
                ResnetBlockCondNorm2D(
                    # 輸入通道
                    in_channels=in_channels,
                    # 輸出通道
                    out_channels=out_channels,
                    # 卷積輸出通道
                    conv_2d_out_channels=conv_2d_out_channels,
                    # 時間嵌入通道
                    temb_channels=temb_channels,
                    # 設定 epsilon 值
                    eps=resnet_eps,
                    # 輸入組數
                    groups=groups,
                    # 輸出組數
                    groups_out=groups_out,
                    # dropout 機率
                    dropout=dropout,
                    # 非線性啟用函式
                    non_linearity=resnet_act_fn,
                    # 時間嵌入的歸一化方式
                    time_embedding_norm="ada_group",
                    # 是否使用卷積快捷連線的偏置
                    conv_shortcut_bias=False,
                )
            )
            # 建立並新增註意力塊到 attentions 列表
            attentions.append(
                KAttentionBlock(
                    # 最後一個層使用 k_out_channels,否則使用 out_channels
                    k_out_channels if (i == num_layers - 1) else out_channels,
                    # 最後一個層注意力維度
                    k_out_channels // attention_head_dim
                    if (i == num_layers - 1)
                    else out_channels // attention_head_dim,
                    # 注意力頭的維度
                    attention_head_dim,
                    # 跨注意力維度
                    cross_attention_dim=cross_attention_dim,
                    # 時間嵌入通道
                    temb_channels=temb_channels,
                    # 是否新增註意力偏置
                    attention_bias=True,
                    # 是否新增自注意力
                    add_self_attention=add_self_attention,
                    # 跨注意力歸一化方式
                    cross_attention_norm="layer_norm",
                    # 是否上溯注意力
                    upcast_attention=upcast_attention,
                )
            )

        # 將 ResNet 塊列表轉為 PyTorch 的 ModuleList
        self.resnets = nn.ModuleList(resnets)
        # 將注意力塊列表轉為 PyTorch 的 ModuleList
        self.attentions = nn.ModuleList(attentions)

        # 如果需要上取樣,則建立一個包含上取樣塊的 ModuleList
        if add_upsample:
            self.upsamplers = nn.ModuleList([KUpsample2D()])
        else:
            # 否則將上取樣塊設定為 None
            self.upsamplers = None

        # 初始化梯度檢查點為 False
        self.gradient_checkpointing = False
        # 儲存當前解析度的索引
        self.resolution_idx = resolution_idx
    # 定義前向傳播函式,接受多個輸入引數,返回處理後的張量
        def forward(
            self,
            hidden_states: torch.Tensor,  # 輸入的隱藏狀態張量
            res_hidden_states_tuple: Tuple[torch.Tensor, ...],  # 先前隱藏狀態的元組
            temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
            encoder_hidden_states: Optional[torch.Tensor] = None,  # 可選的編碼器隱藏狀態
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 交叉注意力的可選引數
            upsample_size: Optional[int] = None,  # 可選的上取樣尺寸
            attention_mask: Optional[torch.Tensor] = None,  # 可選的注意力掩碼
            encoder_attention_mask: Optional[torch.Tensor] = None,  # 可選的編碼器注意力掩碼
        ) -> torch.Tensor:  # 函式返回型別為張量
            res_hidden_states_tuple = res_hidden_states_tuple[-1]  # 獲取最後一個隱藏狀態
            if res_hidden_states_tuple is not None:  # 檢查是否存在先前的隱藏狀態
                hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)  # 拼接當前和先前的隱藏狀態
    
            for resnet, attn in zip(self.resnets, self.attentions):  # 遍歷每個殘差網路和注意力層
                if self.training and self.gradient_checkpointing:  # 檢查是否在訓練模式且使用梯度檢查點
    
                    def create_custom_forward(module, return_dict=None):  # 定義建立自定義前向函式的內部函式
                        def custom_forward(*inputs):  # 自定義前向函式
                            if return_dict is not None:  # 檢查是否需要返回字典
                                return module(*inputs, return_dict=return_dict)  # 使用返回字典的方式呼叫模組
                            else:
                                return module(*inputs)  # 普通呼叫模組
    
                        return custom_forward  # 返回自定義前向函式
    
                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}  # 根據Torch版本設定檢查點引數
                    hidden_states = torch.utils.checkpoint.checkpoint(  # 使用檢查點進行前向傳播以節省記憶體
                        create_custom_forward(resnet),  # 建立自定義前向函式
                        hidden_states,  # 輸入隱藏狀態
                        temb,  # 輸入時間嵌入
                        **ckpt_kwargs,  # 傳遞檢查點引數
                    )
                    hidden_states = attn(  # 透過注意力層處理隱藏狀態
                        hidden_states,  # 輸入隱藏狀態
                        encoder_hidden_states=encoder_hidden_states,  # 輸入編碼器隱藏狀態
                        emb=temb,  # 輸入時間嵌入
                        attention_mask=attention_mask,  # 輸入注意力掩碼
                        cross_attention_kwargs=cross_attention_kwargs,  # 交叉注意力引數
                        encoder_attention_mask=encoder_attention_mask,  # 編碼器注意力掩碼
                    )
                else:  # 如果不使用梯度檢查點
                    hidden_states = resnet(hidden_states, temb)  # 直接透過殘差網路處理隱藏狀態
                    hidden_states = attn(  # 透過注意力層處理隱藏狀態
                        hidden_states,  # 輸入隱藏狀態
                        encoder_hidden_states=encoder_hidden_states,  # 輸入編碼器隱藏狀態
                        emb=temb,  # 輸入時間嵌入
                        attention_mask=attention_mask,  # 輸入注意力掩碼
                        cross_attention_kwargs=cross_attention_kwargs,  # 交叉注意力引數
                        encoder_attention_mask=encoder_attention_mask,  # 編碼器注意力掩碼
                    )
    
            if self.upsamplers is not None:  # 檢查是否有上取樣層
                for upsampler in self.upsamplers:  # 遍歷每個上取樣層
                    hidden_states = upsampler(hidden_states)  # 透過上取樣層處理隱藏狀態
    
            return hidden_states  # 返回處理後的隱藏狀態
# 可以潛在地更名為 `No-feed-forward` 注意力
class KAttentionBlock(nn.Module):
    r"""
    基本的 Transformer 塊。

    引數:
        dim (`int`): 輸入和輸出的通道數。
        num_attention_heads (`int`): 用於多頭注意力的頭數。
        attention_head_dim (`int`): 每個頭的通道數。
        dropout (`float`, *可選*, 預設為 0.0): 使用的丟棄機率。
        cross_attention_dim (`int`, *可選*): 用於交叉注意力的 encoder_hidden_states 向量的大小。
        attention_bias (`bool`, *可選*, 預設為 `False`):
            配置注意力層是否應該包含偏置引數。
        upcast_attention (`bool`, *可選*, 預設為 `False`):
            設定為 `True` 以將注意力計算上調為 `float32`。
        temb_channels (`int`, *可選*, 預設為 768):
            令牌嵌入中的通道數。
        add_self_attention (`bool`, *可選*, 預設為 `False`):
            設定為 `True` 以將自注意力新增到塊中。
        cross_attention_norm (`str`, *可選*, 預設為 `None`):
            用於交叉注意力的規範化型別。可以是 `None`、`layer_norm` 或 `group_norm`。
        group_size (`int`, *可選*, 預設為 32):
            用於組規範化將通道分成的組數。
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        upcast_attention: bool = False,
        temb_channels: int = 768,  # 用於 ada_group_norm
        add_self_attention: bool = False,
        cross_attention_norm: Optional[str] = None,
        group_size: int = 32,
    ):
        # 呼叫父類建構函式,初始化 nn.Module
        super().__init__()
        # 設定是否新增自注意力的標誌
        self.add_self_attention = add_self_attention

        # 1. 自注意力
        if add_self_attention:
            # 初始化自注意力的歸一化層
            self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
            # 初始化自注意力機制
            self.attn1 = Attention(
                query_dim=dim,  # 查詢向量的維度
                heads=num_attention_heads,  # 注意力頭數
                dim_head=attention_head_dim,  # 每個頭的維度
                dropout=dropout,  # 丟棄率
                bias=attention_bias,  # 是否使用偏置
                cross_attention_dim=None,  # 交叉注意力維度
                cross_attention_norm=None,  # 交叉注意力的歸一化
            )

        # 2. 交叉注意力
        # 初始化交叉注意力的歸一化層
        self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
        # 初始化交叉注意力機制
        self.attn2 = Attention(
            query_dim=dim,  # 查詢向量的維度
            cross_attention_dim=cross_attention_dim,  # 交叉注意力維度
            heads=num_attention_heads,  # 注意力頭數
            dim_head=attention_head_dim,  # 每個頭的維度
            dropout=dropout,  # 丟棄率
            bias=attention_bias,  # 是否使用偏置
            upcast_attention=upcast_attention,  # 是否上調注意力計算
            cross_attention_norm=cross_attention_norm,  # 交叉注意力的歸一化
        )
    # 將隱藏狀態轉換為 3D 張量,包含 batch size, height*weight 和通道數
    def _to_3d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
        # 重新排列維度,並調整形狀為 (batch size, height*weight, -1)
        return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)

    # 將隱藏狀態轉換為 4D 張量,包含 batch size, 通道數, height 和 weight
    def _to_4d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
        # 重新排列維度,並調整形狀為 (batch size, -1, height, weight)
        return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        # TODO: 將 emb 標記為非可選 (self.norm2 需要它)。
        #       需要評估對位置引數介面更改的影響。
        emb: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # 如果 cross_attention_kwargs 為空,則初始化為一個空字典
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        # 檢查 "scale" 引數是否存在,如果存在則發出警告
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # 1. 自注意力
        if self.add_self_attention:
            # 使用 norm1 對隱藏狀態進行歸一化處理
            norm_hidden_states = self.norm1(hidden_states, emb)

            # 獲取歸一化後狀態的高度和寬度
            height, weight = norm_hidden_states.shape[2:]
            # 將歸一化後的隱藏狀態轉換為 3D 張量
            norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)

            # 執行自注意力操作
            attn_output = self.attn1(
                norm_hidden_states,
                encoder_hidden_states=None,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
            # 將自注意力輸出轉換為 4D 張量
            attn_output = self._to_4d(attn_output, height, weight)

            # 將自注意力輸出與原始隱藏狀態相加
            hidden_states = attn_output + hidden_states

        # 2. 交叉注意力或無交叉注意力
        # 使用 norm2 對隱藏狀態進行歸一化處理
        norm_hidden_states = self.norm2(hidden_states, emb)

        # 獲取歸一化後狀態的高度和寬度
        height, weight = norm_hidden_states.shape[2:]
        # 將歸一化後的隱藏狀態轉換為 3D 張量
        norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)
        # 執行交叉注意力操作
        attn_output = self.attn2(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask,
            **cross_attention_kwargs,
        )
        # 將交叉注意力輸出轉換為 4D 張量
        attn_output = self._to_4d(attn_output, height, weight)

        # 將交叉注意力輸出與隱藏狀態相加
        hidden_states = attn_output + hidden_states

        # 返回最終的隱藏狀態
        return hidden_states

相關文章