diffusers-原始碼解析-十五-

绝不原创的飞龙發表於2024-10-22

diffusers 原始碼解析(十五)

.\diffusers\models\unets\unet_3d_condition.py

# 版權宣告,宣告此程式碼的版權資訊和所有權
# Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
# 版權宣告,宣告此程式碼的版權資訊和所有權
# Copyright 2024 The ModelScope Team.
#
# 許可宣告,宣告本程式碼使用的 Apache 許可證 2.0 版本
# Licensed under the Apache License, Version 2.0 (the "License");
# 使用此檔案前需遵守許可證規定
# you may not use this file except in compliance with the License.
# 可在以下網址獲取許可證副本
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 免責宣告,說明軟體在許可下按 "原樣" 提供,不附加任何明示或暗示的保證
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 許可證中規定的許可權和限制說明
# See the License for the specific language governing permissions and
# limitations under the License.

# 從 dataclasses 模組匯入 dataclass 裝飾器
from dataclasses import dataclass
# 從 typing 模組匯入所需的型別提示
from typing import Any, Dict, List, Optional, Tuple, Union

# 匯入 PyTorch 庫
import torch
# 匯入 PyTorch 神經網路模組
import torch.nn as nn
# 匯入 PyTorch 的檢查點工具
import torch.utils.checkpoint

# 匯入配置相關的工具類和函式
from ...configuration_utils import ConfigMixin, register_to_config
# 匯入 UNet2D 條件載入器混合類
from ...loaders import UNet2DConditionLoadersMixin
# 匯入基本輸出類和日誌工具
from ...utils import BaseOutput, logging
# 匯入啟用函式獲取工具
from ..activations import get_activation
# 匯入各種注意力處理器相關元件
from ..attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,  # 匯入新增鍵值對注意力處理器
    CROSS_ATTENTION_PROCESSORS,      # 匯入交叉注意力處理器
    Attention,                       # 匯入注意力類
    AttentionProcessor,              # 匯入注意力處理器基類
    AttnAddedKVProcessor,            # 匯入新增鍵值對的注意力處理器
    AttnProcessor,                   # 匯入普通注意力處理器
    FusedAttnProcessor2_0,           # 匯入融合注意力處理器
)
# 匯入時間步嵌入和時間步類
from ..embeddings import TimestepEmbedding, Timesteps
# 匯入模型混合類
from ..modeling_utils import ModelMixin
# 匯入時間變換器模型
from ..transformers.transformer_temporal import TransformerTemporalModel
# 匯入 3D UNet 相關的塊
from .unet_3d_blocks import (
    CrossAttnDownBlock3D,          # 匯入交叉注意力下采樣塊
    CrossAttnUpBlock3D,            # 匯入交叉注意力上取樣塊
    DownBlock3D,                   # 匯入下采樣塊
    UNetMidBlock3DCrossAttn,      # 匯入 UNet 中間交叉注意力塊
    UpBlock3D,                     # 匯入上取樣塊
    get_down_block,                # 匯入獲取下采樣塊的函式
    get_up_block,                  # 匯入獲取上取樣塊的函式
)

# 建立日誌記錄器,使用當前模組的名稱
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 定義 UNet3DConditionOutput 資料類,繼承自 BaseOutput
@dataclass
class UNet3DConditionOutput(BaseOutput):
    """
    [`UNet3DConditionModel`] 的輸出類。

    引數:
        sample (`torch.Tensor` 的形狀為 `(batch_size, num_channels, num_frames, height, width)`):
            基於 `encoder_hidden_states` 輸入的隱藏狀態輸出。模型最後一層的輸出。
    """

    sample: torch.Tensor  # 定義樣本輸出,型別為 PyTorch 張量

# 定義 UNet3DConditionModel 類,繼承自多個混合類
class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    條件 3D UNet 模型,接受噪聲樣本、條件狀態和時間步,並返回形狀為樣本的輸出。

    此模型繼承自 [`ModelMixin`]。有關其通用方法的文件,請參閱超類文件(如下載或儲存)。
    # 引數說明部分
    Parameters:
        # 輸入/輸出樣本的高度和寬度,型別可以為整數或元組,預設為 None
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        # 輸入樣本的通道數,預設為 4
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        # 輸出的通道數,預設為 4
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        # 使用的下采樣塊型別的元組,預設為指定的四種塊
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
            The tuple of downsample blocks to use.
        # 使用的上取樣塊型別的元組,預設為指定的四種塊
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
            The tuple of upsample blocks to use.
        # 每個塊的輸出通道數的元組,預設為 (320, 640, 1280, 1280)
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        # 每個塊的層數,預設為 2
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        # 下采樣卷積使用的填充,預設為 1
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        # 中間塊使用的縮放因子,預設為 1.0
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        # 使用的啟用函式,預設為 "silu"
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        # 用於歸一化的組數,預設為 32;如果為 None,則跳過歸一化和啟用層
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers is skipped in post-processing.
        # 歸一化使用的 epsilon 值,預設為 1e-5
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        # 交叉注意力特徵的維度,預設為 1024
        cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
        # 注意力頭的維度,預設為 64
        attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
        # 注意力頭的數量,型別為整數,預設為 None
        num_attention_heads (`int`, *optional*): The number of attention heads.
        # 時間條件投影層的維度,預設為 None
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of `cond_proj` layer in the timestep embedding.
    """

    # 是否支援梯度檢查點,預設為 False
    _supports_gradient_checkpointing = False

    # 將此類註冊到配置中
    @register_to_config
    # 初始化方法,用於建立類的例項
        def __init__(
            # 樣本大小,預設為 None
            self,
            sample_size: Optional[int] = None,
            # 輸入通道數量,預設為 4
            in_channels: int = 4,
            # 輸出通道數量,預設為 4
            out_channels: int = 4,
            # 下采樣塊型別的元組,定義模型的下采樣結構
            down_block_types: Tuple[str, ...] = (
                "CrossAttnDownBlock3D",
                "CrossAttnDownBlock3D",
                "CrossAttnDownBlock3D",
                "DownBlock3D",
            ),
            # 上取樣塊型別的元組,定義模型的上取樣結構
            up_block_types: Tuple[str, ...] = (
                "UpBlock3D",
                "CrossAttnUpBlock3D",
                "CrossAttnUpBlock3D",
                "CrossAttnUpBlock3D",
            ),
            # 每個塊的輸出通道數量,定義模型每個層的通道設定
            block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
            # 每個塊的層數,預設為 2
            layers_per_block: int = 2,
            # 下采樣時的填充大小,預設為 1
            downsample_padding: int = 1,
            # 中間塊的縮放因子,預設為 1
            mid_block_scale_factor: float = 1,
            # 啟用函式型別,預設為 "silu"
            act_fn: str = "silu",
            # 歸一化組的數量,預設為 32
            norm_num_groups: Optional[int] = 32,
            # 歸一化的 epsilon 值,預設為 1e-5
            norm_eps: float = 1e-5,
            # 跨注意力維度,預設為 1024
            cross_attention_dim: int = 1024,
            # 注意力頭的維度,可以是單一整數或整數元組,預設為 64
            attention_head_dim: Union[int, Tuple[int]] = 64,
            # 注意力頭的數量,可選引數
            num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
            # 時間條件投影維度,可選引數
            time_cond_proj_dim: Optional[int] = None,
        @property
        # 從 UNet2DConditionModel 複製的屬性,獲取注意力處理器
        # 返回所有注意力處理器的字典,以權重名稱為索引
        def attn_processors(self) -> Dict[str, AttentionProcessor]:
            r"""
            Returns:
                `dict` of attention processors: A dictionary containing all attention processors used in the model with
                indexed by its weight name.
            """
            # 初始化處理器字典
            processors = {}
    
            # 遞迴新增處理器的函式
            def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
                # 如果模組有獲取處理器的方法,新增到處理器字典中
                if hasattr(module, "get_processor"):
                    processors[f"{name}.processor"] = module.get_processor()
    
                # 遍歷子模組,遞迴呼叫該函式
                for sub_name, child in module.named_children():
                    fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
    
                # 返回處理器字典
                return processors
    
            # 遍歷當前類的子模組,呼叫遞迴新增處理器的函式
            for name, module in self.named_children():
                fn_recursive_add_processors(name, module, processors)
    
            # 返回所有處理器
            return processors
    
        # 從 UNet2DConditionModel 複製的設定注意力切片的方法
        # 從 UNet2DConditionModel 複製的設定注意力處理器的方法
    # 定義一個方法用於設定注意力處理器
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        設定用於計算注意力的處理器。
    
        引數:
            processor(`dict` of `AttentionProcessor` 或僅 `AttentionProcessor`):
                例項化的處理器類或一個處理器類的字典,將作為所有 `Attention` 層的處理器。
    
                如果 `processor` 是一個字典,鍵需要定義相應的交叉注意力處理器的路徑。
                在設定可訓練的注意力處理器時,強烈推薦這樣做。
    
        """
        # 獲取當前注意力處理器的數量
        count = len(self.attn_processors.keys())
    
        # 如果傳入的處理器是字典,且數量不等於注意力層數量,丟擲錯誤
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"傳入了一個處理器字典,但處理器的數量 {len(processor)} 與"
                f" 注意力層的數量 {count} 不匹配。請確保傳入 {count} 個處理器類。"
            )
    
        # 定義一個遞迴函式來設定每個模組的處理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 如果模組有設定處理器的方法
            if hasattr(module, "set_processor"):
                # 如果處理器不是字典,直接設定處理器
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 從字典中獲取相應的處理器並設定
                    module.set_processor(processor.pop(f"{name}.processor"))
    
            # 遍歷子模組並遞迴呼叫處理器設定
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
    
        # 遍歷當前物件的所有子模組,並呼叫遞迴設定函式
        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
    # 定義一個方法來啟用前饋層的分塊處理
        def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
            """
            設定注意力處理器以使用 [前饋分塊](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers)。
    
            引數:
                chunk_size (`int`, *可選*):
                    前饋層的分塊大小。如果未指定,將對維度為`dim`的每個張量單獨執行前饋層。
                dim (`int`, *可選*, 預設為`0`):
                    應對哪個維度進行前饋計算的分塊。可以選擇 dim=0(批次)或 dim=1(序列長度)。
            """
            # 確保 dim 引數為 0 或 1
            if dim not in [0, 1]:
                raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
    
            # 預設的分塊大小為 1
            chunk_size = chunk_size or 1
    
            # 定義一個遞迴函式來設定每個模組的分塊前饋處理
            def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
                # 如果模組具有設定分塊前饋的屬性,則設定它
                if hasattr(module, "set_chunk_feed_forward"):
                    module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
    
                # 遍歷子模組,遞迴呼叫函式
                for child in module.children():
                    fn_recursive_feed_forward(child, chunk_size, dim)
    
            # 遍歷當前例項的子模組,應用遞迴函式
            for module in self.children():
                fn_recursive_feed_forward(module, chunk_size, dim)
    
        # 定義一個方法來禁用前饋層的分塊處理
        def disable_forward_chunking(self):
            # 定義一個遞迴函式來禁用分塊前饋處理
            def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
                # 如果模組具有設定分塊前饋的屬性,則設定為 None
                if hasattr(module, "set_chunk_feed_forward"):
                    module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
    
                # 遍歷子模組,遞迴呼叫函式
                for child in module.children():
                    fn_recursive_feed_forward(child, chunk_size, dim)
    
            # 遍歷當前例項的子模組,應用遞迴函式,禁用分塊
            for module in self.children():
                fn_recursive_feed_forward(module, None, 0)
    
        # 從 diffusers.models.unets.unet_2d_condition 中複製的方法,設定預設注意力處理器
        def set_default_attn_processor(self):
            """
            禁用自定義注意力處理器並設定預設注意力實現。
            """
            # 檢查所有注意力處理器是否為新增的 KV 處理器
            if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
                processor = AttnAddedKVProcessor()  # 設定為新增的 KV 處理器
            # 檢查所有注意力處理器是否為交叉注意力處理器
            elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
                processor = AttnProcessor()  # 設定為普通注意力處理器
            else:
                # 丟擲異常,若注意力處理器型別不符合預期
                raise ValueError(
                    f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
                )
    
            # 設定選定的注意力處理器
            self.set_attn_processor(processor)
    
        # 定義一個私有方法來設定模組的梯度檢查點
        def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
            # 檢查模組是否屬於特定型別
            if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
                module.gradient_checkpointing = value  # 設定梯度檢查點值
    
        # 從 diffusers.models.unets.unet_2d_condition 中複製的方法,啟用自由度
    # 啟用 FreeU 機制,引數為兩個縮放因子和兩個增強因子的值
    def enable_freeu(self, s1, s2, b1, b2):
        r"""從 https://arxiv.org/abs/2309.11497 啟用 FreeU 機制。

        縮放因子的字尾表示它們應用的階段塊。

        請參考 [官方倉庫](https://github.com/ChenyangSi/FreeU) 以獲取在不同管道(如 Stable Diffusion v1、v2 和 Stable Diffusion XL)中已知效果良好的值組合。

        Args:
            s1 (`float`):
                第1階段的縮放因子,用於減弱跳躍特徵的貢獻,以減輕增強去噪過程中的“過平滑效應”。
            s2 (`float`):
                第2階段的縮放因子,用於減弱跳躍特徵的貢獻,以減輕增強去噪過程中的“過平滑效應”。
            b1 (`float`): 第1階段的縮放因子,用於增強骨幹特徵的貢獻。
            b2 (`float`): 第2階段的縮放因子,用於增強骨幹特徵的貢獻。
        """
        # 遍歷上取樣塊,給每個塊設定縮放因子和增強因子
        for i, upsample_block in enumerate(self.up_blocks):
            # 設定第1階段的縮放因子
            setattr(upsample_block, "s1", s1)
            # 設定第2階段的縮放因子
            setattr(upsample_block, "s2", s2)
            # 設定第1階段的增強因子
            setattr(upsample_block, "b1", b1)
            # 設定第2階段的增強因子
            setattr(upsample_block, "b2", b2)

    # 從 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu 複製
    # 禁用 FreeU 機制
    def disable_freeu(self):
        """禁用 FreeU 機制。"""
        # 定義 FreeU 機制的關鍵屬性
        freeu_keys = {"s1", "s2", "b1", "b2"}
        # 遍歷上取樣塊
        for i, upsample_block in enumerate(self.up_blocks):
            # 遍歷 FreeU 關鍵屬性
            for k in freeu_keys:
                # 如果上取樣塊有該屬性,或者該屬性值不為 None
                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                    # 將屬性值設定為 None,禁用 FreeU
                    setattr(upsample_block, k, None)

    # 從 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections 複製
    # 啟用融合的 QKV 投影
    def fuse_qkv_projections(self):
        """
        啟用融合的 QKV 投影。對於自注意力模組,所有投影矩陣(即查詢、鍵、值)都被融合。對於交叉注意力模組,鍵和值投影矩陣被融合。

        <Tip warning={true}>

        此 API 是 🧪 實驗性的。

        </Tip>
        """
        # 儲存原始的注意力處理器
        self.original_attn_processors = None

        # 遍歷注意力處理器
        for _, attn_processor in self.attn_processors.items():
            # 如果注意力處理器的類名中包含“Added”
            if "Added" in str(attn_processor.__class__.__name__):
                # 丟擲錯誤,表示不支援具有附加 KV 投影的模型
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        # 儲存當前的注意力處理器
        self.original_attn_processors = self.attn_processors

        # 遍歷所有模組
        for module in self.modules():
            # 如果模組是 Attention 型別
            if isinstance(module, Attention):
                # 融合投影
                module.fuse_projections(fuse=True)

        # 設定注意力處理器為融合的注意力處理器
        self.set_attn_processor(FusedAttnProcessor2_0())

    # 從 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 複製
    # 定義一個方法,用於禁用已啟用的融合 QKV 投影
    def unfuse_qkv_projections(self):
        """禁用已啟用的融合 QKV 投影。
    
        <Tip warning={true}>
    
        該 API 是 🧪 實驗性的。
    
        </Tip>
    
        """
        # 如果存在原始的注意力處理器,則設定當前的注意力處理器為原始處理器
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)
    
    # 定義前向傳播方法,接受多個引數進行計算
    def forward(
        self,
        sample: torch.Tensor,  # 輸入樣本,張量格式
        timestep: Union[torch.Tensor, float, int],  # 當前時間步,可以是張量、浮點數或整數
        encoder_hidden_states: torch.Tensor,  # 編碼器的隱藏狀態,張量格式
        class_labels: Optional[torch.Tensor] = None,  # 類別標籤,預設為 None
        timestep_cond: Optional[torch.Tensor] = None,  # 時間步條件,預設為 None
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩碼,預設為 None
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 跨注意力的關鍵字引數,預設為 None
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,  # 降級塊的附加殘差,預設為 None
        mid_block_additional_residual: Optional[torch.Tensor] = None,  # 中間塊的附加殘差,預設為 None
        return_dict: bool = True,  # 是否返回字典格式的結果,預設為 True

.\diffusers\models\unets\unet_i2vgen_xl.py

# 版權宣告,表明版權歸2024年阿里巴巴DAMO-VILAB和HuggingFace團隊所有
# 提供Apache許可證2.0版本的使用條款
# 說明只能在遵循許可證的情況下使用此檔案
# 可在指定網址獲取許可證副本
#
# 除非適用法律或書面協議另有約定,否則軟體按“原樣”分發
# 不提供任何形式的擔保或條件
# 請參見許可證以獲取與許可權和限制相關的具體資訊

from typing import Any, Dict, Optional, Tuple, Union  # 匯入型別提示工具,用於型別註解

import torch  # 匯入PyTorch庫
import torch.nn as nn  # 匯入PyTorch的神經網路模組
import torch.utils.checkpoint  # 匯入PyTorch的檢查點工具

from ...configuration_utils import ConfigMixin, register_to_config  # 從配置工具匯入類和函式
from ...loaders import UNet2DConditionLoadersMixin  # 匯入2D條件載入器混合類
from ...utils import logging  # 匯入日誌工具
from ..activations import get_activation  # 匯入啟用函式獲取工具
from ..attention import Attention, FeedForward  # 匯入注意力機制和前饋網路
from ..attention_processor import (  # 從注意力處理器模組匯入多個處理器
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
    FusedAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps  # 匯入時間步嵌入和時間步類
from ..modeling_utils import ModelMixin  # 匯入模型混合類
from ..transformers.transformer_temporal import TransformerTemporalModel  # 匯入時間變換器模型
from .unet_3d_blocks import (  # 從3D U-Net塊模組匯入多個類
    CrossAttnDownBlock3D,
    CrossAttnUpBlock3D,
    DownBlock3D,
    UNetMidBlock3DCrossAttn,
    UpBlock3D,
    get_down_block,
    get_up_block,
)
from .unet_3d_condition import UNet3DConditionOutput  # 匯入3D條件輸出類

logger = logging.get_logger(__name__)  # 建立日誌記錄器,用於記錄當前模組的資訊

class I2VGenXLTransformerTemporalEncoder(nn.Module):  # 定義一個名為I2VGenXLTransformerTemporalEncoder的類,繼承自nn.Module
    def __init__(  # 建構函式,用於初始化類的例項
        self,
        dim: int,  # 輸入的特徵維度
        num_attention_heads: int,  # 注意力頭的數量
        attention_head_dim: int,  # 每個注意力頭的維度
        activation_fn: str = "geglu",  # 啟用函式型別,預設使用geglu
        upcast_attention: bool = False,  # 是否提升注意力計算的精度
        ff_inner_dim: Optional[int] = None,  # 前饋網路的內部維度,預設為None
        dropout: int = 0.0,  # dropout機率,預設為0.0
    ):
        super().__init__()  # 呼叫父類建構函式
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5)  # 初始化層歸一化層
        self.attn1 = Attention(  # 初始化注意力層
            query_dim=dim,  # 查詢維度
            heads=num_attention_heads,  # 注意力頭數量
            dim_head=attention_head_dim,  # 每個頭的維度
            dropout=dropout,  # dropout機率
            bias=False,  # 不使用偏置
            upcast_attention=upcast_attention,  # 是否提升注意力計算精度
            out_bias=True,  # 輸出使用偏置
        )
        self.ff = FeedForward(  # 初始化前饋網路
            dim,  # 輸入維度
            dropout=dropout,  # dropout機率
            activation_fn=activation_fn,  # 啟用函式型別
            final_dropout=False,  # 最後層不使用dropout
            inner_dim=ff_inner_dim,  # 內部維度
            bias=True,  # 使用偏置
        )

    def forward(  # 定義前向傳播方法
        self,
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態
    # 該方法返回處理後的隱藏狀態張量
    ) -> torch.Tensor:
        # 對隱藏狀態進行歸一化處理
        norm_hidden_states = self.norm1(hidden_states)
        # 計算注意力輸出,使用歸一化後的隱藏狀態
        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
        # 將注意力輸出與原始隱藏狀態相加,更新隱藏狀態
        hidden_states = attn_output + hidden_states
        # 如果隱藏狀態是四維,則去掉第一維
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)
    
        # 透過前饋網路處理隱藏狀態
        ff_output = self.ff(hidden_states)
        # 將前饋輸出與當前隱藏狀態相加,更新隱藏狀態
        hidden_states = ff_output + hidden_states
        # 如果隱藏狀態是四維,則去掉第一維
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)
    
        # 返回最終的隱藏狀態
        return hidden_states
# 定義 I2VGenXL UNet 類,繼承多個混入類以增加功能
class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    I2VGenXL UNet。一個條件3D UNet模型,接收噪聲樣本、條件狀態和時間步,
    返回與樣本形狀相同的輸出。

    該模型繼承自 [`ModelMixin`]。有關所有模型實現的通用方法(如下載或儲存),
    請檢視超類文件。

    引數:
        sample_size (`int` 或 `Tuple[int, int]`, *可選*, 預設值為 `None`):
            輸入/輸出樣本的高度和寬度。
        in_channels (`int`, *可選*, 預設值為 4): 輸入樣本的通道數。
        out_channels (`int`, *可選*, 預設值為 4): 輸出樣本的通道數。
        down_block_types (`Tuple[str]`, *可選*, 預設值為 `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            使用的下采樣塊的元組。
        up_block_types (`Tuple[str]`, *可選*, 預設值為 `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            使用的上取樣塊的元組。
        block_out_channels (`Tuple[int]`, *可選*, 預設值為 `(320, 640, 1280, 1280)`):
            每個塊的輸出通道元組。
        layers_per_block (`int`, *可選*, 預設值為 2): 每個塊的層數。
        norm_num_groups (`int`, *可選*, 預設值為 32): 用於歸一化的組數。
            如果為 `None`,則跳過後處理中的歸一化和啟用層。
        cross_attention_dim (`int`, *可選*, 預設值為 1280): 跨注意力特徵的維度。
        attention_head_dim (`int`, *可選*, 預設值為 64): 注意力頭的維度。
        num_attention_heads (`int`, *可選*): 注意力頭的數量。
    """

    # 設定不支援梯度檢查點的屬性為 False
    _supports_gradient_checkpointing = False

    @register_to_config
    # 初始化方法,接受多種可選引數以設定模型配置
    def __init__(
        self,
        sample_size: Optional[int] = None,  # 輸入/輸出樣本大小,預設為 None
        in_channels: int = 4,  # 輸入樣本的通道數,預設為 4
        out_channels: int = 4,  # 輸出樣本的通道數,預設為 4
        down_block_types: Tuple[str, ...] = (  # 下采樣塊的型別,預設為指定的元組
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        up_block_types: Tuple[str, ...] = (  # 上取樣塊的型別,預設為指定的元組
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),  # 每個塊的輸出通道,預設為指定的元組
        layers_per_block: int = 2,  # 每個塊的層數,預設為 2
        norm_num_groups: Optional[int] = 32,  # 歸一化組數,預設為 32
        cross_attention_dim: int = 1024,  # 跨注意力特徵的維度,預設為 1024
        attention_head_dim: Union[int, Tuple[int]] = 64,  # 注意力頭的維度,預設為 64
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,  # 注意力頭的數量,預設為 None
    @property
    # 該屬性從 UNet2DConditionModel 的 attn_processors 複製
    # 定義返回注意力處理器的函式,返回型別為字典,鍵為字串,值為 AttentionProcessor 物件
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # 建立一個空字典,用於儲存處理器
        processors = {}

        # 定義一個遞迴函式,用於新增處理器到字典
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 檢查模組是否有 get_processor 方法
            if hasattr(module, "get_processor"):
                # 將處理器新增到字典中,鍵為名稱加上 ".processor"
                processors[f"{name}.processor"] = module.get_processor()

            # 遍歷模組的子模組
            for sub_name, child in module.named_children():
                # 遞迴呼叫,處理子模組
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            # 返回更新後的處理器字典
            return processors

        # 遍歷當前物件的所有子模組
        for name, module in self.named_children():
            # 呼叫遞迴函式,將處理器新增到字典中
            fn_recursive_add_processors(name, module, processors)

        # 返回包含所有處理器的字典
        return processors

    # 從 diffusers.models.unets.unet_2d_condition 中複製的設定注意力處理器的函式
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        # 獲取當前注意力處理器的數量
        count = len(self.attn_processors.keys())

        # 如果傳入的是字典且數量不匹配,則引發錯誤
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        # 定義一個遞迴函式,用於設定處理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 檢查模組是否有 set_processor 方法
            if hasattr(module, "set_processor"):
                # 如果傳入的處理器不是字典,直接設定
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 從字典中移除並設定對應的處理器
                    module.set_processor(processor.pop(f"{name}.processor"))

            # 遍歷模組的子模組
            for sub_name, child in module.named_children():
                # 遞迴呼叫,設定子模組的處理器
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        # 遍歷當前物件的所有子模組
        for name, module in self.named_children():
            # 呼叫遞迴函式,為每個模組設定處理器
            fn_recursive_attn_processor(name, module, processor)

    # 從 diffusers.models.unets.unet_3d_condition 中複製的啟用前向分塊的函式
    # 啟用前饋層的分塊處理
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        設定注意力處理器使用[前饋分塊](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers)。

        引數:
            chunk_size (`int`, *可選*):
                前饋層的塊大小。如果未指定,將對維度為`dim`的每個張量單獨執行前饋層。
            dim (`int`, *可選*, 預設為`0`):
                前饋計算應分塊的維度。可以選擇dim=0(批次)或dim=1(序列長度)。
        """
        # 檢查維度是否在有效範圍內
        if dim not in [0, 1]:
            # 丟擲錯誤,確保dim只為0或1
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # 預設塊大小為1
        chunk_size = chunk_size or 1

        # 定義遞迴函式,用於設定每個模組的前饋分塊
        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            # 如果模組有設定分塊前饋的方法,呼叫該方法
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            # 遞迴遍歷子模組
            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        # 對當前物件的所有子模組應用前饋分塊設定
        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    # 從diffusers.models.unets.unet_3d_condition.UNet3DConditionModel複製的禁用前饋分塊的方法
    def disable_forward_chunking(self):
        # 定義遞迴函式,用於禁用模組的前饋分塊
        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            # 如果模組有設定分塊前饋的方法,呼叫該方法
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            # 遞迴遍歷子模組
            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        # 對當前物件的所有子模組應用禁用前饋分塊設定
        for module in self.children():
            fn_recursive_feed_forward(module, None, 0)

    # 從diffusers.models.unets.unet_2d_condition.UNet2DConditionModel複製的設定預設注意力處理器的方法
    def set_default_attn_processor(self):
        """
        禁用自定義注意力處理器並設定預設的注意力實現。
        """
        # 檢查所有注意力處理器是否屬於已新增的KV注意力處理器類
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 如果是,則設定為已新增KV處理器
            processor = AttnAddedKVProcessor()
        # 檢查所有注意力處理器是否屬於交叉注意力處理器類
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 如果是,則設定為標準注意力處理器
            processor = AttnProcessor()
        else:
            # 丟擲錯誤,說明當前的注意力處理器型別不被支援
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        # 設定當前物件的注意力處理器為選擇的處理器
        self.set_attn_processor(processor)

    # 從diffusers.models.unets.unet_3d_condition.UNet3DConditionModel複製的設定梯度檢查點的方法
    # 設定梯度檢查點,指定模組和布林值
    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
        # 檢查模組是否為指定的型別之一
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
            # 設定模組的梯度檢查點屬性為指定值
            module.gradient_checkpointing = value

    # 從 UNet2DConditionModel 中複製的啟用 FreeU 方法
    def enable_freeu(self, s1, s2, b1, b2):
        r"""啟用 FreeU 機制,詳情見 https://arxiv.org/abs/2309.11497.

        字尾表示縮放因子應用的階段塊。

        請參考 [官方庫](https://github.com/ChenyangSi/FreeU) 以獲取適用於不同管道(如 Stable Diffusion v1, v2 和 Stable Diffusion XL)的有效值組合。

        引數:
            s1 (`float`):
                階段 1 的縮放因子,用於減弱跳過特徵的貢獻,以緩解增強去噪過程中的“過平滑效應”。
            s2 (`float`):
                階段 2 的縮放因子,用於減弱跳過特徵的貢獻,以緩解增強去噪過程中的“過平滑效應”。
            b1 (`float`): 階段 1 的縮放因子,用於放大主幹特徵的貢獻。
            b2 (`float`): 階段 2 的縮放因子,用於放大主幹特徵的貢獻。
        """
        # 遍歷上取樣塊,索引 i 和塊物件 upsample_block
        for i, upsample_block in enumerate(self.up_blocks):
            # 設定上取樣塊的屬性 s1 為給定值 s1
            setattr(upsample_block, "s1", s1)
            # 設定上取樣塊的屬性 s2 為給定值 s2
            setattr(upsample_block, "s2", s2)
            # 設定上取樣塊的屬性 b1 為給定值 b1
            setattr(upsample_block, "b1", b1)
            # 設定上取樣塊的屬性 b2 為給定值 b2
            setattr(upsample_block, "b2", b2)

    # 從 UNet2DConditionModel 中複製的禁用 FreeU 方法
    def disable_freeu(self):
        """禁用 FreeU 機制。"""
        # 定義 FreeU 相關的屬性鍵
        freeu_keys = {"s1", "s2", "b1", "b2"}
        # 遍歷上取樣塊,索引 i 和塊物件 upsample_block
        for i, upsample_block in enumerate(self.up_blocks):
            # 遍歷 FreeU 屬性鍵
            for k in freeu_keys:
                # 如果上取樣塊具有該屬性或屬性值不為 None
                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                    # 將上取樣塊的該屬性設定為 None
                    setattr(upsample_block, k, None)

    # 從 UNet2DConditionModel 中複製的融合 QKV 投影方法
    # 定義一個方法,用於啟用融合的 QKV 投影
    def fuse_qkv_projections(self):
        # 提供方法的文件字串,描述其功能和警告資訊
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.
    
        <Tip warning={true}>
    
        This API is 🧪 experimental.
    
        </Tip>
        """
        # 初始化原始注意力處理器為 None
        self.original_attn_processors = None
    
        # 遍歷當前物件的注意力處理器
        for _, attn_processor in self.attn_processors.items():
            # 檢查處理器類名中是否包含 "Added"
            if "Added" in str(attn_processor.__class__.__name__):
                # 如果包含,丟擲異常提示不支援
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
    
        # 儲存當前的注意力處理器以備後用
        self.original_attn_processors = self.attn_processors
    
        # 遍歷當前物件的所有模組
        for module in self.modules():
            # 檢查模組是否為 Attention 型別
            if isinstance(module, Attention):
                # 呼叫模組的方法,啟用融合投影
                module.fuse_projections(fuse=True)
    
        # 設定注意力處理器為 FusedAttnProcessor2_0 的例項
        self.set_attn_processor(FusedAttnProcessor2_0())
    
    # 從 UNet2DConditionModel 複製的方法,用於禁用融合的 QKV 投影
    def unfuse_qkv_projections(self):
        # 提供方法的文件字串,描述其功能和警告資訊
        """Disables the fused QKV projection if enabled.
    
        <Tip warning={true}>
    
        This API is 🧪 experimental.
    
        </Tip>
    
        """
        # 檢查原始注意力處理器是否不為 None
        if self.original_attn_processors is not None:
            # 如果不為 None,恢復原始的注意力處理器
            self.set_attn_processor(self.original_attn_processors)
    
    # 定義前向傳播方法,接受多個輸入引數
    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        fps: torch.Tensor,
        image_latents: torch.Tensor,
        image_embeddings: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,

.\diffusers\models\unets\unet_kandinsky3.py

# 版權宣告,指明該檔案屬於 HuggingFace 團隊,所有權利保留
# 
# 根據 Apache License 2.0 版(“許可證”)授權;
# 除非遵循許可證,否則不得使用此檔案。
# 可以在以下網址獲取許可證的副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非適用法律要求或書面同意,否則根據許可證分發的軟體
# 是在“原樣”基礎上分發的,不附帶任何形式的保證或條件。
# 有關特定語言的許可條款和條件,請參見許可證。

from dataclasses import dataclass  # 從 dataclasses 模組匯入 dataclass 裝飾器
from typing import Dict, Tuple, Union  # 匯入用於型別提示的字典、元組和聯合型別

import torch  # 匯入 PyTorch 庫
import torch.utils.checkpoint  # 匯入 PyTorch 的檢查點工具
from torch import nn  # 從 PyTorch 匯入神經網路模組

from ...configuration_utils import ConfigMixin, register_to_config  # 從配置工具匯入混合類和註冊函式
from ...utils import BaseOutput, logging  # 從工具模組匯入基礎輸出類和日誌功能
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor  # 匯入注意力處理器相關類
from ..embeddings import TimestepEmbedding, Timesteps  # 匯入時間步嵌入相關類
from ..modeling_utils import ModelMixin  # 匯入模型混合類

logger = logging.get_logger(__name__)  # 建立一個記錄器,用於當前模組的日誌記錄

@dataclass  # 將該類標記為資料類,以簡化初始化和表示
class Kandinsky3UNetOutput(BaseOutput):  # 定義 Kandinsky3UNetOutput 類,繼承自 BaseOutput
    sample: torch.Tensor = None  # 定義輸出樣本,預設為 None

class Kandinsky3EncoderProj(nn.Module):  # 定義 Kandinsky3EncoderProj 類,繼承自 nn.Module
    def __init__(self, encoder_hid_dim, cross_attention_dim):  # 初始化方法,接收隱藏維度和交叉注意力維度
        super().__init__()  # 呼叫父類的初始化方法
        self.projection_linear = nn.Linear(encoder_hid_dim, cross_attention_dim, bias=False)  # 定義線性投影層,不使用偏置
        self.projection_norm = nn.LayerNorm(cross_attention_dim)  # 定義層歸一化層

    def forward(self, x):  # 定義前向傳播方法
        x = self.projection_linear(x)  # 透過線性層處理輸入
        x = self.projection_norm(x)  # 透過層歸一化處理輸出
        return x  # 返回處理後的結果

class Kandinsky3UNet(ModelMixin, ConfigMixin):  # 定義 Kandinsky3UNet 類,繼承自 ModelMixin 和 ConfigMixin
    @register_to_config  # 將該方法註冊到配置中
    def __init__(  # 初始化方法
        self,
        in_channels: int = 4,  # 輸入通道數,預設值為 4
        time_embedding_dim: int = 1536,  # 時間嵌入維度,預設值為 1536
        groups: int = 32,  # 組數,預設值為 32
        attention_head_dim: int = 64,  # 注意力頭維度,預設值為 64
        layers_per_block: Union[int, Tuple[int]] = 3,  # 每個塊的層數,預設值為 3,可以是整數或元組
        block_out_channels: Tuple[int] = (384, 768, 1536, 3072),  # 塊輸出通道,預設為指定元組
        cross_attention_dim: Union[int, Tuple[int]] = 4096,  # 交叉注意力維度,預設值為 4096
        encoder_hid_dim: int = 4096,  # 編碼器隱藏維度,預設值為 4096
    @property  # 定義一個屬性
    def attn_processors(self) -> Dict[str, AttentionProcessor]:  # 返回注意力處理器字典
        r"""  # 文件字串,描述該方法的功能
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # 設定一個空字典以遞迴儲存處理器
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):  # 定義遞迴函式新增處理器
            if hasattr(module, "set_processor"):  # 檢查模組是否具有 set_processor 屬性
                processors[f"{name}.processor"] = module.processor  # 將處理器新增到字典中

            for sub_name, child in module.named_children():  # 遍歷模組的所有子模組
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)  # 遞迴呼叫自身

            return processors  # 返回更新後的處理器字典

        for name, module in self.named_children():  # 遍歷當前類的所有子模組
            fn_recursive_add_processors(name, module, processors)  # 呼叫遞迴函式

        return processors  # 返回包含所有處理器的字典
    # 定義設定注意力處理器的方法,引數為處理器,可以是 AttentionProcessor 類或其字典形式
        def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
            r"""
            設定用於計算注意力的處理器。
    
            引數:
                processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                    例項化的處理器類或處理器類的字典,將作為 **所有** `Attention` 層的處理器。
    
                    如果 `processor` 是一個字典,鍵需要定義相應交叉注意力處理器的路徑。這在設定可訓練注意力處理器時強烈推薦。
    
            """
            # 獲取當前注意力處理器的數量
            count = len(self.attn_processors.keys())
    
            # 如果傳入的是字典且其長度與注意力層的數量不匹配,則丟擲錯誤
            if isinstance(processor, dict) and len(processor) != count:
                raise ValueError(
                    f"傳入了處理器字典,但處理器的數量 {len(processor)} 與"
                    f" 注意力層的數量 {count} 不匹配。請確保傳入 {count} 個處理器類。"
                )
    
            # 定義遞迴設定注意力處理器的方法
            def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
                # 如果模組有設定處理器的方法
                if hasattr(module, "set_processor"):
                    # 如果處理器不是字典,則直接設定
                    if not isinstance(processor, dict):
                        module.set_processor(processor)
                    else:
                        # 從字典中獲取對應的處理器並設定
                        module.set_processor(processor.pop(f"{name}.processor"))
    
                # 遍歷模組的所有子模組
                for sub_name, child in module.named_children():
                    # 遞迴呼叫處理子模組
                    fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
    
            # 遍歷當前物件的所有子模組
            for name, module in self.named_children():
                # 遞迴設定每個子模組的處理器
                fn_recursive_attn_processor(name, module, processor)
    
        # 定義設定預設注意力處理器的方法
        def set_default_attn_processor(self):
            """
            禁用自定義注意力處理器,並設定預設的注意力實現。
            """
            # 呼叫設定注意力處理器的方法,使用預設的 AttnProcessor 例項
            self.set_attn_processor(AttnProcessor())
    
        # 定義設定梯度檢查點的方法
        def _set_gradient_checkpointing(self, module, value=False):
            # 如果模組有梯度檢查點的屬性
            if hasattr(module, "gradient_checkpointing"):
                # 設定該屬性為指定的值
                module.gradient_checkpointing = value
    # 定義前向傳播函式,接收樣本、時間步以及可選的編碼器隱藏狀態和注意力掩碼
    def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
        # 如果存在編碼器注意力掩碼,則進行調整以適應後續計算
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            # 增加一個維度,以便後續處理
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
    
        # 檢查時間步是否為張量型別
        if not torch.is_tensor(timestep):
            # 根據時間步型別確定資料型別
            dtype = torch.float32 if isinstance(timestep, float) else torch.int32
            # 將時間步轉換為張量並指定裝置
            timestep = torch.tensor([timestep], dtype=dtype, device=sample.device)
        # 如果時間步為標量,則擴充套件為一維張量
        elif len(timestep.shape) == 0:
            timestep = timestep[None].to(sample.device)
    
        # 擴充套件時間步到與批次維度相容的形狀
        timestep = timestep.expand(sample.shape[0])
        # 透過時間投影獲取時間嵌入輸入並轉換為樣本的資料型別
        time_embed_input = self.time_proj(timestep).to(sample.dtype)
        # 獲取時間嵌入
        time_embed = self.time_embedding(time_embed_input)
    
        # 對編碼器隱藏狀態進行線性變換
        encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
    
        # 如果存在編碼器隱藏狀態,則將時間嵌入與隱藏狀態結合
        if encoder_hidden_states is not None:
            time_embed = self.add_time_condition(time_embed, encoder_hidden_states, encoder_attention_mask)
    
        # 初始化隱藏狀態列表
        hidden_states = []
        # 對輸入樣本進行初步卷積處理
        sample = self.conv_in(sample)
        # 遍歷下采樣塊
        for level, down_sample in enumerate(self.down_blocks):
            # 透過下采樣塊處理樣本
            sample = down_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
            # 如果不是最後一個層級,記錄當前樣本狀態
            if level != self.num_levels - 1:
                hidden_states.append(sample)
    
        # 遍歷上取樣塊
        for level, up_sample in enumerate(self.up_blocks):
            # 如果不是第一個層級,則拼接當前樣本與之前的隱藏狀態
            if level != 0:
                sample = torch.cat([sample, hidden_states.pop()], dim=1)
            # 透過上取樣塊處理樣本
            sample = up_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
    
        # 進行輸出卷積規範化
        sample = self.conv_norm_out(sample)
        # 進行輸出啟用
        sample = self.conv_act_out(sample)
        # 進行最終輸出卷積
        sample = self.conv_out(sample)
    
        # 根據返回標誌返回相應的結果
        if not return_dict:
            return (sample,)
        # 返回結果物件
        return Kandinsky3UNetOutput(sample=sample)
# 定義 Kandinsky3UpSampleBlock 類,繼承自 nn.Module
class Kandinsky3UpSampleBlock(nn.Module):
    # 初始化方法,設定各引數
    def __init__(
        self,
        in_channels,  # 輸入通道數
        cat_dim,  # 拼接維度
        out_channels,  # 輸出通道數
        time_embed_dim,  # 時間嵌入維度
        context_dim=None,  # 上下文維度,可選
        num_blocks=3,  # 塊的數量
        groups=32,  # 分組數
        head_dim=64,  # 頭維度
        expansion_ratio=4,  # 擴充套件比例
        compression_ratio=2,  # 壓縮比例
        up_sample=True,  # 是否上取樣
        self_attention=True,  # 是否使用自注意力
    ):
        # 呼叫父類初始化方法
        super().__init__()
        # 設定上取樣解析度
        up_resolutions = [[None, True if up_sample else None, None, None]] + [[None] * 4] * (num_blocks - 1)
        # 設定隱藏通道數
        hidden_channels = (
            [(in_channels + cat_dim, in_channels)]  # 第一層的通道
            + [(in_channels, in_channels)] * (num_blocks - 2)  # 中間層的通道
            + [(in_channels, out_channels)]  # 最後一層的通道
        )
        attentions = []  # 用於儲存注意力塊
        resnets_in = []  # 用於儲存輸入 ResNet 塊
        resnets_out = []  # 用於儲存輸出 ResNet 塊

        # 設定自注意力和上下文維度
        self.self_attention = self_attention
        self.context_dim = context_dim

        # 如果使用自注意力,新增註意力塊
        if self_attention:
            attentions.append(
                Kandinsky3AttentionBlock(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
            )
        else:
            attentions.append(nn.Identity())  # 否則新增身份對映

        # 遍歷隱藏通道和上取樣解析度
        for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
            # 新增輸入 ResNet 塊
            resnets_in.append(
                Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution)
            )

            # 如果上下文維度不為 None,新增註意力塊
            if context_dim is not None:
                attentions.append(
                    Kandinsky3AttentionBlock(
                        in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
                    )
                )
            else:
                attentions.append(nn.Identity())  # 否則新增身份對映

            # 新增輸出 ResNet 塊
            resnets_out.append(
                Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
            )

        # 將注意力塊和 ResNet 塊轉換為模組列表
        self.attentions = nn.ModuleList(attentions)
        self.resnets_in = nn.ModuleList(resnets_in)
        self.resnets_out = nn.ModuleList(resnets_out)

    # 前向傳播方法
    def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
        # 遍歷注意力塊和 ResNet 塊進行前向計算
        for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out):
            x = resnet_in(x, time_embed)  # 輸入經過 ResNet 塊
            if self.context_dim is not None:  # 如果上下文維度存在
                x = attention(x, time_embed, context, context_mask, image_mask)  # 應用注意力塊
            x = resnet_out(x, time_embed)  # 輸出經過 ResNet 塊

        # 如果使用自注意力,應用首個注意力塊
        if self.self_attention:
            x = self.attentions[0](x, time_embed, image_mask=image_mask)
        return x  # 返回處理後的結果


# 定義 Kandinsky3DownSampleBlock 類,繼承自 nn.Module
class Kandinsky3DownSampleBlock(nn.Module):
    # 初始化方法,設定各引數
    def __init__(
        self,
        in_channels,  # 輸入通道數
        out_channels,  # 輸出通道數
        time_embed_dim,  # 時間嵌入維度
        context_dim=None,  # 上下文維度,可選
        num_blocks=3,  # 塊的數量
        groups=32,  # 分組數
        head_dim=64,  # 頭維度
        expansion_ratio=4,  # 擴充套件比例
        compression_ratio=2,  # 壓縮比例
        down_sample=True,  # 是否下采樣
        self_attention=True,  # 是否使用自注意力
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        # 初始化注意力模組列表
        attentions = []
        # 初始化輸入殘差塊列表
        resnets_in = []
        # 初始化輸出殘差塊列表
        resnets_out = []

        # 儲存自注意力標誌
        self.self_attention = self_attention
        # 儲存上下文維度
        self.context_dim = context_dim

        # 如果啟用自注意力
        if self_attention:
            # 新增 Kandinsky3AttentionBlock 到注意力列表
            attentions.append(
                Kandinsky3AttentionBlock(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
            )
        else:
            # 否則新增身份層(不改變輸入)
            attentions.append(nn.Identity())

        # 生成上取樣解析度列表
        up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, False if down_sample else None, None]]
        # 生成隱藏通道的元組列表
        hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)
        # 遍歷隱藏通道和上取樣解析度
        for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
            # 新增輸入殘差塊到列表
            resnets_in.append(
                Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
            )

            # 如果上下文維度不為 None
            if context_dim is not None:
                # 新增 Kandinsky3AttentionBlock 到注意力列表
                attentions.append(
                    Kandinsky3AttentionBlock(
                        out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
                    )
                )
            else:
                # 否則新增身份層(不改變輸入)
                attentions.append(nn.Identity())

            # 新增輸出殘差塊到列表
            resnets_out.append(
                Kandinsky3ResNetBlock(
                    out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution
                )
            )

        # 將注意力模組列表轉換為 nn.ModuleList 以便管理
        self.attentions = nn.ModuleList(attentions)
        # 將輸入殘差塊列表轉換為 nn.ModuleList 以便管理
        self.resnets_in = nn.ModuleList(resnets_in)
        # 將輸出殘差塊列表轉換為 nn.ModuleList 以便管理
        self.resnets_out = nn.ModuleList(resnets_out)

    # 定義前向傳播方法
    def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
        # 如果啟用自注意力
        if self.self_attention:
            # 使用第一個注意力模組處理輸入
            x = self.attentions[0](x, time_embed, image_mask=image_mask)

        # 遍歷剩餘的注意力模組、輸入和輸出殘差塊
        for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out):
            # 透過輸入殘差塊處理輸入
            x = resnet_in(x, time_embed)
            # 如果上下文維度不為 None
            if self.context_dim is not None:
                # 使用當前注意力模組處理輸入
                x = attention(x, time_embed, context, context_mask, image_mask)
            # 透過輸出殘差塊處理輸入
            x = resnet_out(x, time_embed)
        # 返回處理後的輸出
        return x
# 定義 Kandinsky3ConditionalGroupNorm 類,繼承自 nn.Module
class Kandinsky3ConditionalGroupNorm(nn.Module):
    # 初始化方法,設定分組數、標準化形狀和上下文維度
    def __init__(self, groups, normalized_shape, context_dim):
        # 呼叫父類建構函式
        super().__init__()
        # 建立分組歸一化層,不使用仿射變換
        self.norm = nn.GroupNorm(groups, normalized_shape, affine=False)
        # 定義上下文多層感知機,包含 SiLU 啟用和線性層
        self.context_mlp = nn.Sequential(nn.SiLU(), nn.Linear(context_dim, 2 * normalized_shape))
        # 將線性層的權重初始化為零
        self.context_mlp[1].weight.data.zero_()
        # 將線性層的偏置初始化為零
        self.context_mlp[1].bias.data.zero_()

    # 前向傳播方法,接收輸入和上下文
    def forward(self, x, context):
        # 透過上下文多層感知機處理上下文
        context = self.context_mlp(context)

        # 為了匹配輸入的維度,逐層擴充套件上下文
        for _ in range(len(x.shape[2:])):
            context = context.unsqueeze(-1)

        # 將上下文分割為縮放和偏移量
        scale, shift = context.chunk(2, dim=1)
        # 應用歸一化並進行縮放和偏移
        x = self.norm(x) * (scale + 1.0) + shift
        # 返回處理後的輸入
        return x


# 定義 Kandinsky3Block 類,繼承自 nn.Module
class Kandinsky3Block(nn.Module):
    # 初始化方法,設定輸入通道、輸出通道、時間嵌入維度等引數
    def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
        # 呼叫父類建構函式
        super().__init__()
        # 建立條件分組歸一化層
        self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
        # 定義 SiLU 啟用函式
        self.activation = nn.SiLU()
        # 如果需要上取樣,使用轉置卷積進行上取樣
        if up_resolution is not None and up_resolution:
            self.up_sample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
        else:
            # 否則使用恆等對映
            self.up_sample = nn.Identity()

        # 根據卷積核大小確定填充
        padding = int(kernel_size > 1)
        # 定義卷積投影層
        self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)

        # 如果不需要上取樣,定義下采樣卷積層
        if up_resolution is not None and not up_resolution:
            self.down_sample = nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
        else:
            # 否則使用恆等對映
            self.down_sample = nn.Identity()

    # 前向傳播方法,接收輸入和時間嵌入
    def forward(self, x, time_embed):
        # 透過條件分組歸一化處理輸入
        x = self.group_norm(x, time_embed)
        # 應用啟用函式
        x = self.activation(x)
        # 進行上取樣
        x = self.up_sample(x)
        # 透過卷積投影層處理輸入
        x = self.projection(x)
        # 進行下采樣
        x = self.down_sample(x)
        # 返回處理後的輸出
        return x


# 定義 Kandinsky3ResNetBlock 類,繼承自 nn.Module
class Kandinsky3ResNetBlock(nn.Module):
    # 初始化方法,設定輸入通道、輸出通道、時間嵌入維度等引數
    def __init__(
        self, in_channels, out_channels, time_embed_dim, norm_groups=32, compression_ratio=2, up_resolutions=4 * [None]
    # 初始化父類
        ):
            super().__init__()
            # 定義卷積核的大小
            kernel_sizes = [1, 3, 3, 1]
            # 計算隱藏通道數
            hidden_channel = max(in_channels, out_channels) // compression_ratio
            # 構建隱藏通道的元組列表
            hidden_channels = (
                [(in_channels, hidden_channel)] + [(hidden_channel, hidden_channel)] * 2 + [(hidden_channel, out_channels)]
            )
            # 建立包含多個 Kandinsky3Block 的模組列表
            self.resnet_blocks = nn.ModuleList(
                [
                    Kandinsky3Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution)
                    # 將隱藏通道、卷積核大小和上取樣解析度組合在一起
                    for (in_channel, out_channel), kernel_size, up_resolution in zip(
                        hidden_channels, kernel_sizes, up_resolutions
                    )
                ]
            )
            # 定義上取樣的快捷連線
            self.shortcut_up_sample = (
                nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
                # 如果存在上取樣解析度,則使用反摺積;否則使用恆等對映
                if True in up_resolutions
                else nn.Identity()
            )
            # 定義通道數不同時的投影連線
            self.shortcut_projection = (
                nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()
            )
            # 定義下采樣的快捷連線
            self.shortcut_down_sample = (
                nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
                # 如果存在下采樣解析度,則使用卷積;否則使用恆等對映
                if False in up_resolutions
                else nn.Identity()
            )
    
        # 前向傳播方法
        def forward(self, x, time_embed):
            # 初始化輸出為輸入
            out = x
            # 依次透過每個 ResNet 塊
            for resnet_block in self.resnet_blocks:
                out = resnet_block(out, time_embed)
    
            # 上取樣輸入
            x = self.shortcut_up_sample(x)
            # 投影輸入到輸出通道
            x = self.shortcut_projection(x)
            # 下采樣輸入
            x = self.shortcut_down_sample(x)
            # 將輸出與處理後的輸入相加
            x = x + out
            # 返回最終輸出
            return x
# 定義 Kandinsky3AttentionPooling 類,繼承自 nn.Module
class Kandinsky3AttentionPooling(nn.Module):
    # 初始化方法,接受通道數、上下文維度和頭維度
    def __init__(self, num_channels, context_dim, head_dim=64):
        # 呼叫父類建構函式
        super().__init__()
        # 建立注意力機制物件,指定輸入和輸出維度及其他引數
        self.attention = Attention(
            context_dim,
            context_dim,
            dim_head=head_dim,
            out_dim=num_channels,
            out_bias=False,
        )

    # 前向傳播方法
    def forward(self, x, context, context_mask=None):
        # 將上下文掩碼轉換為與上下文相同的資料型別
        context_mask = context_mask.to(dtype=context.dtype)
        # 使用注意力機制計算上下文與其平均值的加權和
        context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask)
        # 返回輸入與上下文的和
        return x + context.squeeze(1)


# 定義 Kandinsky3AttentionBlock 類,繼承自 nn.Module
class Kandinsky3AttentionBlock(nn.Module):
    # 初始化方法,接受多種引數
    def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
        # 呼叫父類建構函式
        super().__init__()
        # 建立條件組歸一化物件,用於輸入規範化
        self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
        # 建立注意力機制物件,指定輸入和輸出維度及其他引數
        self.attention = Attention(
            num_channels,
            context_dim or num_channels,
            dim_head=head_dim,
            out_dim=num_channels,
            out_bias=False,
        )

        # 計算隱藏通道數,作為擴充套件比和通道數的乘積
        hidden_channels = expansion_ratio * num_channels
        # 建立條件組歸一化物件,用於輸出規範化
        self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
        # 定義前饋網路,包含兩個卷積層和啟用函式
        self.feed_forward = nn.Sequential(
            nn.Conv2d(num_channels, hidden_channels, kernel_size=1, bias=False),
            nn.SiLU(),
            nn.Conv2d(hidden_channels, num_channels, kernel_size=1, bias=False),
        )

    # 前向傳播方法
    def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
        # 獲取輸入的高度和寬度
        height, width = x.shape[-2:]
        # 對輸入進行歸一化處理
        out = self.in_norm(x, time_embed)
        # 將輸出重塑為適合注意力機制的形狀
        out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1)
        # 如果沒有上下文,則使用當前的輸出作為上下文
        context = context if context is not None else out
        # 如果存在上下文掩碼,轉換為與上下文相同的資料型別
        if context_mask is not None:
            context_mask = context_mask.to(dtype=context.dtype)

        # 使用注意力機制處理輸出和上下文
        out = self.attention(out, context, context_mask)
        # 重塑輸出為原始輸入形狀
        out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width)
        # 將處理後的輸出與原輸入相加
        x = x + out

        # 對相加後的結果進行輸出歸一化
        out = self.out_norm(x, time_embed)
        # 透過前饋網路處理歸一化輸出
        out = self.feed_forward(out)
        # 將處理後的輸出與相加後的輸入相加
        x = x + out
        # 返回最終輸出
        return x

.\diffusers\models\unets\unet_motion_model.py

# 版權宣告,表明該檔案的所有權及相關使用條款
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根據 Apache License, Version 2.0 (“許可證”) 授權;
# 除非遵守許可證,否則您不得使用此檔案。
# 您可以在以下網址獲取許可證副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律或書面協議另有規定,否則根據許可證分發的軟體
# 是在“按原樣”基礎上分發的,不提供任何形式的保證或條件,
# 無論是明示還是暗示。
# 有關許可證所管轄的許可權和限制,請參見許可證。
#
# 匯入所需的庫和模組
from dataclasses import dataclass  # 匯入資料類裝飾器
from typing import Any, Dict, Optional, Tuple, Union  # 匯入型別提示相關的型別

import torch  # 匯入 PyTorch 庫
import torch.nn as nn  # 匯入 PyTorch 的神經網路模組
import torch.nn.functional as F  # 匯入 PyTorch 的功能性神經網路模組
import torch.utils.checkpoint  # 匯入 PyTorch 的檢查點功能

# 匯入自定義的配置和載入工具
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput, deprecate, is_torch_version, logging  # 匯入常用的工具函式
from ...utils.torch_utils import apply_freeu  # 匯入應用 FreeU 的工具函式
from ..attention import BasicTransformerBlock  # 匯入基礎變換器模組
from ..attention_processor import (  # 匯入注意力處理器相關的類
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    Attention,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
    AttnProcessor2_0,
    FusedAttnProcessor2_0,
    IPAdapterAttnProcessor,
    IPAdapterAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps  # 匯入時間步嵌入相關的類
from ..modeling_utils import ModelMixin  # 匯入模型混合工具
from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D  # 匯入 ResNet 相關的模組
from ..transformers.dual_transformer_2d import DualTransformer2DModel  # 匯入雙重變換器模型
from ..transformers.transformer_2d import Transformer2DModel  # 匯入 2D 變換器模型
from .unet_2d_blocks import UNetMidBlock2DCrossAttn  # 匯入 U-Net 中間塊
from .unet_2d_condition import UNet2DConditionModel  # 匯入條件 U-Net 模型

logger = logging.get_logger(__name__)  # 獲取當前模組的日誌記錄器,便於除錯和日誌輸出

@dataclass
class UNetMotionOutput(BaseOutput):  # 定義 UNetMotionOutput 資料類,繼承自 BaseOutput
    """
    [`UNetMotionOutput`] 的輸出。

    引數:
        sample (`torch.Tensor` 的形狀為 `(batch_size, num_channels, num_frames, height, width)`):
            基於 `encoder_hidden_states` 輸入的隱藏狀態輸出。模型最後一層的輸出。
    """

    sample: torch.Tensor  # 定義 sample 屬性,型別為 torch.Tensor


class AnimateDiffTransformer3D(nn.Module):  # 定義 AnimateDiffTransformer3D 類,繼承自 nn.Module
    """
    一個用於影片類資料的變換器模型。
    # 引數說明部分,描述初始化函式中每個引數的用途
    Parameters:
        # 多頭注意力機制中頭的數量,預設為16
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        # 每個頭中的通道數,預設為88
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        # 輸入和輸出的通道數,如果輸入是**連續**,則需要指定
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        # Transformer塊的層數,預設為1
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        # dropout機率,預設為0.0
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        # 使用的`encoder_hidden_states`維度數
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        # 配置`TransformerBlock`的注意力是否包含偏置引數
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlock` attention should contain a bias parameter.
        # 潛在影像的寬度,如果輸入是**離散**,則需要指定
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            # 該值在訓練期間固定,用於學習位置嵌入的數量
            This is fixed during training since it is used to learn a number of position embeddings.
        # 前饋中的啟用函式,預設為"geglu"
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
            activation functions.
        # 配置`TransformerBlock`是否使用可學習的逐元素仿射引數進行歸一化
        norm_elementwise_affine (`bool`, *optional*):
            Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
        # 配置每個`TransformerBlock`是否包含兩個自注意力層
        double_self_attention (`bool`, *optional*):
            Configure if each `TransformerBlock` should contain two self-attention layers.
        # 應用到序列輸入的位置資訊嵌入的型別
        positional_embeddings: (`str`, *optional*):
            The type of positional embeddings to apply to the sequence input before passing use.
        # 應用位置嵌入的最大序列長度
        num_positional_embeddings: (`int`, *optional*):
            The maximum length of the sequence over which to apply positional embeddings.
    """

    # 初始化方法定義
    def __init__(
        # 多頭注意力機制中頭的數量,預設為16
        self,
        num_attention_heads: int = 16,
        # 每個頭中的通道數,預設為88
        attention_head_dim: int = 88,
        # 輸入通道數,可選
        in_channels: Optional[int] = None,
        # 輸出通道數,可選
        out_channels: Optional[int] = None,
        # Transformer塊的層數,預設為1
        num_layers: int = 1,
        # dropout機率,預設為0.0
        dropout: float = 0.0,
        # 歸一化分組數,預設為32
        norm_num_groups: int = 32,
        # 使用的`encoder_hidden_states`維度數,可選
        cross_attention_dim: Optional[int] = None,
        # 注意力是否包含偏置引數,預設為False
        attention_bias: bool = False,
        # 潛在影像的寬度,可選
        sample_size: Optional[int] = None,
        # 前饋中的啟用函式,預設為"geglu"
        activation_fn: str = "geglu",
        # 歸一化是否使用可學習的逐元素仿射引數,預設為True
        norm_elementwise_affine: bool = True,
        # 每個`TransformerBlock`是否包含兩個自注意力層,預設為True
        double_self_attention: bool = True,
        # 位置資訊嵌入的型別,可選
        positional_embeddings: Optional[str] = None,
        # 應用位置嵌入的最大序列長度,可選
        num_positional_embeddings: Optional[int] = None,
    ):
        # 呼叫父類的建構函式以初始化父類的屬性
        super().__init__()
        # 設定注意力頭的數量
        self.num_attention_heads = num_attention_heads
        # 設定每個注意力頭的維度
        self.attention_head_dim = attention_head_dim
        # 計算內部維度,等於注意力頭數量與每個注意力頭維度的乘積
        inner_dim = num_attention_heads * attention_head_dim

        # 設定輸入通道數
        self.in_channels = in_channels

        # 定義歸一化層,使用組歸一化,允許可學習的偏移
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        # 定義輸入線性變換層,將輸入通道對映到內部維度
        self.proj_in = nn.Linear(in_channels, inner_dim)

        # 3. 定義變換器塊
        self.transformer_blocks = nn.ModuleList(
            [
                # 建立指定數量的基本變換器塊
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    double_self_attention=double_self_attention,
                    norm_elementwise_affine=norm_elementwise_affine,
                    positional_embeddings=positional_embeddings,
                    num_positional_embeddings=num_positional_embeddings,
                )
                # 遍歷建立 num_layers 個基本變換器塊
                for _ in range(num_layers)
            ]
        )

        # 定義輸出線性變換層,將內部維度對映回輸入通道數
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(
        # 定義前向傳播方法的輸入引數
        self,
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態張量
        encoder_hidden_states: Optional[torch.LongTensor] = None,  # 編碼器的隱藏狀態,預設為 None
        timestep: Optional[torch.LongTensor] = None,  # 時間步,預設為 None
        class_labels: Optional[torch.LongTensor] = None,  # 類標籤,預設為 None
        num_frames: int = 1,  # 幀數,預設值為 1
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 跨注意力引數,預設為 None
    # 該方法用於 [`AnimateDiffTransformer3D`] 的前向傳播
    
        ) -> torch.Tensor:
            """
            方法引數說明:
                hidden_states (`torch.LongTensor`): 輸入的隱狀態,形狀為 `(batch size, num latent pixels)` 或 `(batch size, channel, height, width)` 
                encoder_hidden_states ( `torch.LongTensor`, *可選*): 
                    交叉注意力層的條件嵌入。如果未提供,交叉注意力將預設使用自注意力。
                timestep ( `torch.LongTensor`, *可選*): 
                    用於指示去噪步驟的時間戳。
                class_labels ( `torch.LongTensor`, *可選*): 
                    用於指示類別標籤的條件嵌入。
                num_frames (`int`, *可選*, 預設為 1): 
                    每個批次處理的幀數,用於重新形狀隱狀態。
                cross_attention_kwargs (`dict`, *可選*): 
                    可選的關鍵字字典,傳遞給 `AttentionProcessor`。
            返回值:
                torch.Tensor: 
                    輸出張量。
            """
            # 1. 輸入
            # 獲取輸入隱狀態的形狀資訊
            batch_frames, channel, height, width = hidden_states.shape
            # 計算批次大小
            batch_size = batch_frames // num_frames
    
            # 將隱狀態保留用於殘差連線
            residual = hidden_states
    
            # 調整隱狀態的形狀以適應批次和幀數
            hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
            # 調整維度順序以便後續處理
            hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
    
            # 對隱狀態進行規範化
            hidden_states = self.norm(hidden_states)
            # 再次調整維度順序並重塑為適當的形狀
            hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
    
            # 輸入層投影
            hidden_states = self.proj_in(hidden_states)
    
            # 2. 處理塊
            # 遍歷每個變換塊以處理隱狀態
            for block in self.transformer_blocks:
                hidden_states = block(
                    hidden_states,  # 當前的隱狀態
                    encoder_hidden_states=encoder_hidden_states,  # 可選的編碼器隱狀態
                    timestep=timestep,  # 可選的時間戳
                    cross_attention_kwargs=cross_attention_kwargs,  # 可選的交叉注意力引數
                    class_labels=class_labels,  # 可選的類標籤
                )
    
            # 3. 輸出
            # 輸出層投影
            hidden_states = self.proj_out(hidden_states)
            # 調整輸出張量的形狀
            hidden_states = (
                hidden_states[None, None, :]  # 新增維度
                .reshape(batch_size, height, width, num_frames, channel)  # 重塑為適當形狀
                .permute(0, 3, 4, 1, 2)  # 調整維度順序
                .contiguous()  # 確保記憶體連續性
            )
            # 最終調整輸出的形狀
            hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
    
            # 將殘差新增到輸出中以形成最終輸出
            output = hidden_states + residual
            # 返回最終的輸出張量
            return output
# 定義一個名為 DownBlockMotion 的類,繼承自 nn.Module
class DownBlockMotion(nn.Module):
    # 初始化方法,定義多個引數,包括輸入輸出通道、dropout 率等
    def __init__(
        self,
        in_channels: int,  # 輸入通道數量
        out_channels: int,  # 輸出通道數量
        temb_channels: int,  # 時間嵌入通道數量
        dropout: float = 0.0,  # dropout 率,預設為 0
        num_layers: int = 1,  # 網路層數,預設為 1
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 引數
        resnet_time_scale_shift: str = "default",  # ResNet 時間尺度偏移
        resnet_act_fn: str = "swish",  # ResNet 啟用函式,預設為 swish
        resnet_groups: int = 32,  # ResNet 組數,預設為 32
        resnet_pre_norm: bool = True,  # ResNet 是否使用預歸一化
        output_scale_factor: float = 1.0,  # 輸出縮放因子
        add_downsample: bool = True,  # 是否新增下采樣層
        downsample_padding: int = 1,  # 下采樣時的填充
        temporal_num_attention_heads: Union[int, Tuple[int]] = 1,  # 時間注意力頭數
        temporal_cross_attention_dim: Optional[int] = None,  # 時間交叉注意力維度
        temporal_max_seq_length: int = 32,  # 最大序列長度
        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 每個塊的變換器層數
        temporal_double_self_attention: bool = True,  # 是否雙重自注意力
    ):
    # 前向傳播方法,接收隱藏狀態和時間嵌入等引數
    def forward(
        self,
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態張量
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
        num_frames: int = 1,  # 幀數,預設為 1
        *args,  # 接受任意位置引數
        **kwargs,  # 接受任意關鍵字引數
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:  # 返回隱藏狀態和輸出狀態的張量或元組
        # 檢查位置引數或關鍵字引數中的 scale 是否被傳遞
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 定義棄用資訊,提示使用者 scale 引數將被忽略
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 呼叫棄用函式,發出警告
            deprecate("scale", "1.0.0", deprecation_message)

        # 初始化輸出狀態為一個空元組
        output_states = ()

        # 將 ResNet 和運動模組進行配對
        blocks = zip(self.resnets, self.motion_modules)
        # 遍歷每對 ResNet 和運動模組
        for resnet, motion_module in blocks:
            # 如果處於訓練模式且啟用了梯度檢查點
            if self.training and self.gradient_checkpointing:
                # 定義一個自定義前向傳播函式
                def create_custom_forward(module):
                    def custom_forward(*inputs):  # 自定義前向函式,接受任意輸入
                        return module(*inputs)  # 返回模組的輸出

                    return custom_forward  # 返回自定義前向函式

                # 如果 PyTorch 版本大於等於 1.11.0
                if is_torch_version(">=", "1.11.0"):
                    # 使用檢查點機制來節省記憶體
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),  # 建立的自定義前向函式
                        hidden_states,  # 輸入的隱藏狀態
                        temb,  # 輸入的時間嵌入
                        use_reentrant=False,  # 不使用重入
                    )
                else:
                    # 在較早版本中也使用檢查點機制
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet), hidden_states, temb
                    )

            else:
                # 如果不是訓練模式,直接透過 ResNet 處理隱藏狀態
                hidden_states = resnet(hidden_states, temb)

            # 使用運動模組處理當前的隱藏狀態
            hidden_states = motion_module(hidden_states, num_frames=num_frames)

            # 將當前隱藏狀態新增到輸出狀態中
            output_states = output_states + (hidden_states,)

        # 如果下采樣器不為空
        if self.downsamplers is not None:
            # 遍歷每個下采樣器
            for downsampler in self.downsamplers:
                # 透過下采樣器處理隱藏狀態
                hidden_states = downsampler(hidden_states)

            # 將下采樣後的隱藏狀態新增到輸出狀態中
            output_states = output_states + (hidden_states,)

        # 返回最終的隱藏狀態和輸出狀態
        return hidden_states, output_states
    # 初始化方法,用於設定網路的引數
        def __init__(
            # 輸入通道數量
            self,
            in_channels: int,
            # 輸出通道數量
            out_channels: int,
            # 時間嵌入通道數量
            temb_channels: int,
            # dropout 機率,預設為 0.0
            dropout: float = 0.0,
            # 網路層數,預設為 1
            num_layers: int = 1,
            # 每個塊中的變換器層數,預設為 1
            transformer_layers_per_block: Union[int, Tuple[int]] = 1,
            # ResNet 中的 epsilon 值,預設為 1e-6
            resnet_eps: float = 1e-6,
            # ResNet 時間尺度偏移,預設為 "default"
            resnet_time_scale_shift: str = "default",
            # ResNet 啟用函式,預設為 "swish"
            resnet_act_fn: str = "swish",
            # ResNet 中的組數,預設為 32
            resnet_groups: int = 32,
            # 是否在 ResNet 中使用預歸一化,預設為 True
            resnet_pre_norm: bool = True,
            # 注意力頭的數量,預設為 1
            num_attention_heads: int = 1,
            # 交叉注意力維度,預設為 1280
            cross_attention_dim: int = 1280,
            # 輸出縮放因子,預設為 1.0
            output_scale_factor: float = 1.0,
            # 下采樣填充,預設為 1
            downsample_padding: int = 1,
            # 是否新增下采樣層,預設為 True
            add_downsample: bool = True,
            # 是否使用雙交叉注意力,預設為 False
            dual_cross_attention: bool = False,
            # 是否使用線性投影,預設為 False
            use_linear_projection: bool = False,
            # 是否僅使用交叉注意力,預設為 False
            only_cross_attention: bool = False,
            # 是否提升注意力計算精度,預設為 False
            upcast_attention: bool = False,
            # 注意力型別,預設為 "default"
            attention_type: str = "default",
            # 時間交叉注意力維度,可選引數
            temporal_cross_attention_dim: Optional[int] = None,
            # 時間注意力頭數量,預設為 8
            temporal_num_attention_heads: int = 8,
            # 時間序列的最大長度,預設為 32
            temporal_max_seq_length: int = 32,
            # 時間變換器塊中的層數,預設為 1
            temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
            # 是否使用雙重自注意力,預設為 True
            temporal_double_self_attention: bool = True,
        # 前向傳播方法,定義如何透過模型傳遞資料
        def forward(
            # 隱藏狀態張量,輸入到模型中的主要資料
            self,
            hidden_states: torch.Tensor,
            # 可選的時間嵌入張量
            temb: Optional[torch.Tensor] = None,
            # 可選的編碼器隱藏狀態
            encoder_hidden_states: Optional[torch.Tensor] = None,
            # 可選的注意力掩碼
            attention_mask: Optional[torch.Tensor] = None,
            # 每次處理的幀數,預設為 1
            num_frames: int = 1,
            # 可選的編碼器注意力掩碼
            encoder_attention_mask: Optional[torch.Tensor] = None,
            # 可選的交叉注意力引數
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 可選的額外殘差連線
            additional_residuals: Optional[torch.Tensor] = None,
    ):
        # 檢查 cross_attention_kwargs 是否不為空
        if cross_attention_kwargs is not None:
            # 檢查 scale 引數是否存在,若存在則發出警告
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # 初始化輸出狀態為空元組
        output_states = ()

        # 將自殘差網路、注意力模組和運動模組組合成一個列表
        blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
        # 遍歷組合後的模組及其索引
        for i, (resnet, attn, motion_module) in enumerate(blocks):
            # 如果處於訓練狀態且啟用了梯度檢查點
            if self.training and self.gradient_checkpointing:

                # 定義自定義前向傳播函式
                def create_custom_forward(module, return_dict=None):
                    # 定義實際的前向傳播邏輯
                    def custom_forward(*inputs):
                        # 根據 return_dict 的值選擇返回方式
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                # 定義檢查點引數字典,根據 PyTorch 版本設定
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                # 使用檢查點機制計算隱藏狀態
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
                # 透過注意力模組處理隱藏狀態
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            else:
                # 在非訓練模式下直接透過殘差網路處理隱藏狀態
                hidden_states = resnet(hidden_states, temb)

                # 透過注意力模組處理隱藏狀態
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            # 透過運動模組處理隱藏狀態
            hidden_states = motion_module(
                hidden_states,
                num_frames=num_frames,
            )

            # 如果是最後一對模組且有額外殘差,則將其應用到隱藏狀態
            if i == len(blocks) - 1 and additional_residuals is not None:
                hidden_states = hidden_states + additional_residuals

            # 將當前隱藏狀態新增到輸出狀態中
            output_states = output_states + (hidden_states,)

        # 如果存在下采樣模組,則依次應用它們
        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            # 將下采樣後的隱藏狀態新增到輸出狀態中
            output_states = output_states + (hidden_states,)

        # 返回最終的隱藏狀態和輸出狀態
        return hidden_states, output_states
# 定義一個繼承自 nn.Module 的類,用於交叉注意力上取樣塊
class CrossAttnUpBlockMotion(nn.Module):
    # 初始化方法,設定各層的引數
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        out_channels: int,  # 輸出通道數
        prev_output_channel: int,  # 前一層輸出的通道數
        temb_channels: int,  # 時間嵌入通道數
        resolution_idx: Optional[int] = None,  # 解析度索引,預設為 None
        dropout: float = 0.0,  # dropout 機率
        num_layers: int = 1,  # 層數
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 每個塊的變換器層數
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 值
        resnet_time_scale_shift: str = "default",  # ResNet 時間縮放偏移
        resnet_act_fn: str = "swish",  # ResNet 啟用函式
        resnet_groups: int = 32,  # ResNet 組數
        resnet_pre_norm: bool = True,  # 是否在前面進行歸一化
        num_attention_heads: int = 1,  # 注意力頭的數量
        cross_attention_dim: int = 1280,  # 交叉注意力的維度
        output_scale_factor: float = 1.0,  # 輸出縮放因子
        add_upsample: bool = True,  # 是否新增上取樣
        dual_cross_attention: bool = False,  # 是否使用雙重交叉注意力
        use_linear_projection: bool = False,  # 是否使用線性投影
        only_cross_attention: bool = False,  # 是否僅使用交叉注意力
        upcast_attention: bool = False,  # 是否上浮注意力
        attention_type: str = "default",  # 注意力型別
        temporal_cross_attention_dim: Optional[int] = None,  # 時間交叉注意力維度,預設為 None
        temporal_num_attention_heads: int = 8,  # 時間注意力頭數量
        temporal_max_seq_length: int = 32,  # 時間序列的最大長度
        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 時間塊的變換器層數
    # 定義前向傳播方法
    def forward(
        self,
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態張量
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],  # 之前隱藏狀態的元組
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 可選的編碼器隱藏狀態
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 交叉注意力的可選引數
        upsample_size: Optional[int] = None,  # 可選的上取樣大小
        attention_mask: Optional[torch.Tensor] = None,  # 可選的注意力掩碼
        encoder_attention_mask: Optional[torch.Tensor] = None,  # 可選的編碼器注意力掩碼
        num_frames: int = 1,  # 幀數,預設為 1
# 定義一個繼承自 nn.Module 的類,用於上取樣塊
class UpBlockMotion(nn.Module):
    # 初始化方法,設定各層的引數
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        prev_output_channel: int,  # 前一層輸出的通道數
        out_channels: int,  # 輸出通道數
        temb_channels: int,  # 時間嵌入通道數
        resolution_idx: Optional[int] = None,  # 解析度索引,預設為 None
        dropout: float = 0.0,  # dropout 機率
        num_layers: int = 1,  # 層數
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 值
        resnet_time_scale_shift: str = "default",  # ResNet 時間縮放偏移
        resnet_act_fn: str = "swish",  # ResNet 啟用函式
        resnet_groups: int = 32,  # ResNet 組數
        resnet_pre_norm: bool = True,  # 是否在前面進行歸一化
        output_scale_factor: float = 1.0,  # 輸出縮放因子
        add_upsample: bool = True,  # 是否新增上取樣
        temporal_cross_attention_dim: Optional[int] = None,  # 時間交叉注意力維度,預設為 None
        temporal_num_attention_heads: int = 8,  # 時間注意力頭數量
        temporal_max_seq_length: int = 32,  # 時間序列的最大長度
        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 時間塊的變換器層數
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        # 初始化空列表,用於存放 ResNet 模組
        resnets = []
        # 初始化空列表,用於存放運動模組
        motion_modules = []

        # 支援每個時間塊的變換層數量為變數
        if isinstance(temporal_transformer_layers_per_block, int):
            # 將單個整數轉換為與層數相同的元組
            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
        elif len(temporal_transformer_layers_per_block) != num_layers:
            # 檢查傳入的層數是否與預期一致
            raise ValueError(
                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
            )

        # 遍歷每層,構建 ResNet 和運動模組
        for i in range(num_layers):
            # 設定跳過連線的通道數
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # 設定當前層的輸入通道數
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            # 新增 ResNetBlock2D 模組到 resnets 列表
            resnets.append(
                ResnetBlock2D(
                    # 輸入通道數為當前層的輸入和跳過連線的通道數之和
                    in_channels=resnet_in_channels + res_skip_channels,
                    # 輸出通道數設定
                    out_channels=out_channels,
                    # 時間嵌入通道數
                    temb_channels=temb_channels,
                    # 小常數以避免除零
                    eps=resnet_eps,
                    # 組歸一化的組數
                    groups=resnet_groups,
                    # Dropout 率
                    dropout=dropout,
                    # 時間嵌入的歸一化方式
                    time_embedding_norm=resnet_time_scale_shift,
                    # 啟用函式設定
                    non_linearity=resnet_act_fn,
                    # 輸出尺度因子
                    output_scale_factor=output_scale_factor,
                    # 是否使用預歸一化
                    pre_norm=resnet_pre_norm,
                )
            )

            # 新增 AnimateDiffTransformer3D 模組到 motion_modules 列表
            motion_modules.append(
                AnimateDiffTransformer3D(
                    # 注意力頭的數量
                    num_attention_heads=temporal_num_attention_heads,
                    # 輸入通道數
                    in_channels=out_channels,
                    # 當前層的變換層數量
                    num_layers=temporal_transformer_layers_per_block[i],
                    # 組歸一化的組數
                    norm_num_groups=resnet_groups,
                    # 跨注意力維度
                    cross_attention_dim=temporal_cross_attention_dim,
                    # 是否使用注意力偏置
                    attention_bias=False,
                    # 啟用函式型別
                    activation_fn="geglu",
                    # 位置資訊嵌入型別
                    positional_embeddings="sinusoidal",
                    # 位置資訊嵌入數量
                    num_positional_embeddings=temporal_max_seq_length,
                    # 每個注意力頭的維度
                    attention_head_dim=out_channels // temporal_num_attention_heads,
                )
            )

        # 將 ResNet 模組列表轉換為 nn.ModuleList
        self.resnets = nn.ModuleList(resnets)
        # 將運動模組列表轉換為 nn.ModuleList
        self.motion_modules = nn.ModuleList(motion_modules)

        # 如果需要上取樣,則初始化上取樣模組
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            # 否則,設定為 None
            self.upsamplers = None

        # 設定梯度檢查點標誌為 False
        self.gradient_checkpointing = False
        # 儲存解析度索引
        self.resolution_idx = resolution_idx

    def forward(
        # 前向傳播方法的引數定義
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        # 可選的時間嵌入
        temb: Optional[torch.Tensor] = None,
        # 上取樣大小
        upsample_size=None,
        # 幀數,預設為 1
        num_frames: int = 1,
        # 額外的引數
        *args,
        **kwargs,
    # 函式返回型別為 torch.Tensor
    ) -> torch.Tensor:
        # 檢查傳入引數是否存在或 "scale" 引數是否為非 None
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 定義棄用提示資訊
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 呼叫 deprecate 函式記錄棄用警告
            deprecate("scale", "1.0.0", deprecation_message)
    
        # 檢查 FreeU 是否啟用,確保相關屬性均不為 None
        is_freeu_enabled = (
            getattr(self, "s1", None)
            and getattr(self, "s2", None)
            and getattr(self, "b1", None)
            and getattr(self, "b2", None)
        )
    
        # 將自定義模組打包成元組,方便遍歷
        blocks = zip(self.resnets, self.motion_modules)
    
        # 遍歷每一對 resnet 和 motion_module
        for resnet, motion_module in blocks:
            # 從隱藏狀態元組中彈出最後一個隱藏狀態
            res_hidden_states = res_hidden_states_tuple[-1]
            # 更新隱藏狀態元組,移除最後一個元素
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
    
            # 如果啟用 FreeU,則僅對前兩個階段進行操作
            if is_freeu_enabled:
                # 應用 FreeU 函式獲取新的隱藏狀態
                hidden_states, res_hidden_states = apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    res_hidden_states,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )
    
            # 將當前隱藏狀態和殘差隱藏狀態在維度 1 上拼接
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
    
            # 如果在訓練模式並且啟用了梯度檢查點
            if self.training and self.gradient_checkpointing:
                # 定義建立自定義前向傳播函式
                def create_custom_forward(module):
                    # 定義自定義前向傳播的實現
                    def custom_forward(*inputs):
                        return module(*inputs)
    
                    return custom_forward
    
                # 如果 torch 版本大於等於 1.11.0
                if is_torch_version(">=", "1.11.0"):
                    # 使用檢查點機制儲存記憶體
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        use_reentrant=False,
                    )
                else:
                    # 否則使用舊版檢查點機制
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet), hidden_states, temb
                    )
            else:
                # 否則直接透過 resnet 計算隱藏狀態
                hidden_states = resnet(hidden_states, temb)
    
            # 透過 motion_module 處理隱藏狀態,傳入幀數
            hidden_states = motion_module(hidden_states, num_frames=num_frames)
    
        # 如果存在上取樣器,則對每個上取樣器進行處理
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                # 透過上取樣器處理隱藏狀態,傳入上取樣大小
                hidden_states = upsampler(hidden_states, upsample_size)
    
        # 返回最終處理後的隱藏狀態
        return hidden_states
# 定義 UNetMidBlockCrossAttnMotion 類,繼承自 nn.Module
class UNetMidBlockCrossAttnMotion(nn.Module):
    # 初始化方法,定義類的引數
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        temb_channels: int,  # 時間嵌入通道數
        dropout: float = 0.0,  # Dropout 率
        num_layers: int = 1,  # 層數
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 每個塊的變換層數
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 值
        resnet_time_scale_shift: str = "default",  # ResNet 時間尺度偏移
        resnet_act_fn: str = "swish",  # ResNet 啟用函式型別
        resnet_groups: int = 32,  # ResNet 組數
        resnet_pre_norm: bool = True,  # 是否進行前置歸一化
        num_attention_heads: int = 1,  # 注意力頭數量
        output_scale_factor: float = 1.0,  # 輸出縮放因子
        cross_attention_dim: int = 1280,  # 交叉注意力維度
        dual_cross_attention: bool = False,  # 是否使用雙重交叉注意力
        use_linear_projection: bool = False,  # 是否使用線性投影
        upcast_attention: bool = False,  # 是否上升注意力精度
        attention_type: str = "default",  # 注意力型別
        temporal_num_attention_heads: int = 1,  # 時間注意力頭數量
        temporal_cross_attention_dim: Optional[int] = None,  # 時間交叉注意力維度
        temporal_max_seq_length: int = 32,  # 時間序列最大長度
        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 時間塊的變換層數
    # 前向傳播方法,定義輸入和輸出
    def forward(
        self,
        hidden_states: torch.Tensor,  # 隱藏狀態的輸入張量
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 可選的編碼器隱藏狀態
        attention_mask: Optional[torch.Tensor] = None,  # 可選的注意力掩碼
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 可選的交叉注意力引數
        encoder_attention_mask: Optional[torch.Tensor] = None,  # 可選的編碼器注意力掩碼
        num_frames: int = 1,  # 幀數
    # 該函式的返回型別為 torch.Tensor
        ) -> torch.Tensor:
            # 檢查交叉注意力引數是否不為 None
            if cross_attention_kwargs is not None:
                # 如果引數中包含 "scale",發出警告,說明該引數已棄用
                if cross_attention_kwargs.get("scale", None) is not None:
                    logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
    
            # 透過第一個殘差網路處理隱藏狀態
            hidden_states = self.resnets[0](hidden_states, temb)
    
            # 將注意力層、殘差網路和運動模組打包在一起
            blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
            # 遍歷每個注意力層、殘差網路和運動模組
            for attn, resnet, motion_module in blocks:
                # 如果在訓練模式下並且啟用了梯度檢查點
                if self.training and self.gradient_checkpointing:
    
                    # 建立自定義前向函式
                    def create_custom_forward(module, return_dict=None):
                        # 定義自定義前向函式,接受任意輸入
                        def custom_forward(*inputs):
                            # 如果返回字典不為 None,使用返回字典呼叫模組
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                # 否則直接呼叫模組
                                return module(*inputs)
    
                        return custom_forward
    
                    # 根據 PyTorch 版本設定檢查點引數
                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    # 呼叫注意力模組並獲取輸出的第一個元素
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                        attention_mask=attention_mask,
                        encoder_attention_mask=encoder_attention_mask,
                        return_dict=False,
                    )[0]
                    # 使用梯度檢查點對運動模組進行前向傳播
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )
                    # 使用梯度檢查點對殘差網路進行前向傳播
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )
                else:
                    # 在非訓練模式下直接呼叫注意力模組
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                        attention_mask=attention_mask,
                        encoder_attention_mask=encoder_attention_mask,
                        return_dict=False,
                    )[0]
                    # 呼叫運動模組,傳入隱藏狀態和幀數
                    hidden_states = motion_module(
                        hidden_states,
                        num_frames=num_frames,
                    )
                    # 呼叫殘差網路處理隱藏狀態
                    hidden_states = resnet(hidden_states, temb)
    
            # 返回處理後的隱藏狀態
            return hidden_states
# 定義一個繼承自 nn.Module 的運動模組類
class MotionModules(nn.Module):
    # 初始化方法,接收多個引數配置運動模組
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        layers_per_block: int = 2,  # 每個模組塊的層數,預設是 2
        transformer_layers_per_block: Union[int, Tuple[int]] = 8,  # 每個塊中的變換層數
        num_attention_heads: Union[int, Tuple[int]] = 8,  # 注意力頭的數量
        attention_bias: bool = False,  # 是否使用注意力偏差
        cross_attention_dim: Optional[int] = None,  # 交叉注意力維度
        activation_fn: str = "geglu",  # 啟用函式,預設使用 "geglu"
        norm_num_groups: int = 32,  # 歸一化組的數量
        max_seq_length: int = 32,  # 最大序列長度
    ):
        # 呼叫父類初始化方法
        super().__init__()
        # 初始化運動模組列表
        self.motion_modules = nn.ModuleList([])

        # 如果變換層數是整數,重複為每個模組塊填充
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = (transformer_layers_per_block,) * layers_per_block
        # 檢查變換層數與塊層數是否匹配
        elif len(transformer_layers_per_block) != layers_per_block:
            raise ValueError(
                f"The number of transformer layers per block must match the number of layers per block, "
                f"got {layers_per_block} and {len(transformer_layers_per_block)}"
            )

        # 遍歷每個模組塊
        for i in range(layers_per_block):
            # 向運動模組列表新增 AnimateDiffTransformer3D 例項
            self.motion_modules.append(
                AnimateDiffTransformer3D(
                    in_channels=in_channels,  # 輸入通道數
                    num_layers=transformer_layers_per_block[i],  # 當前塊的變換層數
                    norm_num_groups=norm_num_groups,  # 歸一化組的數量
                    cross_attention_dim=cross_attention_dim,  # 交叉注意力維度
                    activation_fn=activation_fn,  # 啟用函式
                    attention_bias=attention_bias,  # 注意力偏差
                    num_attention_heads=num_attention_heads,  # 注意力頭數量
                    attention_head_dim=in_channels // num_attention_heads,  # 每個注意力頭的維度
                    positional_embeddings="sinusoidal",  # 使用正弦波的位置嵌入
                    num_positional_embeddings=max_seq_length,  # 位置嵌入的數量
                )
            )


# 定義一個運動介面卡類,結合多個混合類
class MotionAdapter(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    @register_to_config
    # 初始化方法,配置多個運動介面卡引數
    def __init__(
        self,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),  # 塊輸出通道
        motion_layers_per_block: Union[int, Tuple[int]] = 2,  # 每個運動塊的層數
        motion_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]] = 1,  # 每個運動塊中的變換層數
        motion_mid_block_layers_per_block: int = 1,  # 中間塊的層數
        motion_transformer_layers_per_mid_block: Union[int, Tuple[int]] = 1,  # 中間塊中的變換層數
        motion_num_attention_heads: Union[int, Tuple[int]] = 8,  # 中間塊的注意力頭數量
        motion_norm_num_groups: int = 32,  # 中間塊的歸一化組數量
        motion_max_seq_length: int = 32,  # 中間塊的最大序列長度
        use_motion_mid_block: bool = True,  # 是否使用中間塊
        conv_in_channels: Optional[int] = None,  # 輸入通道數
    ):
        pass  # 前向傳播方法,尚未實現


# 定義一個修改後的條件 2D UNet 模型
class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
    r"""
    一個修改後的條件 2D UNet 模型,接收嘈雜樣本、條件狀態和時間步,返回形狀輸出。

    該模型繼承自 [`ModelMixin`]。檢視超類文件以獲取所有模型的通用方法實現(如下載或儲存)。
    """

    # 支援梯度檢查點
    _supports_gradient_checkpointing = True

    @register_to_config
    # 初始化方法,用於建立類的例項
    def __init__(
        # 可選引數,樣本大小,預設為 None
        self,
        sample_size: Optional[int] = None,
        # 輸入通道數,預設為 4
        in_channels: int = 4,
        # 輸出通道數,預設為 4
        out_channels: int = 4,
        # 下采樣塊的型別元組
        down_block_types: Tuple[str, ...] = (
            "CrossAttnDownBlockMotion",  # 第一個下采樣塊型別
            "CrossAttnDownBlockMotion",  # 第二個下采樣塊型別
            "CrossAttnDownBlockMotion",  # 第三個下采樣塊型別
            "DownBlockMotion",            # 第四個下采樣塊型別
        ),
        # 上取樣塊的型別元組
        up_block_types: Tuple[str, ...] = (
            "UpBlockMotion",              # 第一個上取樣塊型別
            "CrossAttnUpBlockMotion",    # 第二個上取樣塊型別
            "CrossAttnUpBlockMotion",    # 第三個上取樣塊型別
            "CrossAttnUpBlockMotion",    # 第四個上取樣塊型別
        ),
        # 塊的輸出通道數元組
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        # 每個塊的層數,預設為 2
        layers_per_block: Union[int, Tuple[int]] = 2,
        # 下采樣填充,預設為 1
        downsample_padding: int = 1,
        # 中間塊的縮放因子,預設為 1
        mid_block_scale_factor: float = 1,
        # 啟用函式型別,預設為 "silu"
        act_fn: str = "silu",
        # 歸一化的組數,預設為 32
        norm_num_groups: int = 32,
        # 歸一化的 epsilon 值,預設為 1e-5
        norm_eps: float = 1e-5,
        # 交叉注意力的維度,預設為 1280
        cross_attention_dim: int = 1280,
        # 每個塊的變換器層數,預設為 1
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        # 可選引數,反向變換器層數,預設為 None
        reverse_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
        # 時間變換器的層數,預設為 1
        temporal_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        # 可選引數,反向時間變換器層數,預設為 None
        reverse_temporal_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
        # 每個中間塊的變換器層數,預設為 None
        transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
        # 每個中間塊的時間變換器層數,預設為 1
        temporal_transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = 1,
        # 是否使用線性投影,預設為 False
        use_linear_projection: bool = False,
        # 注意力頭的數量,預設為 8
        num_attention_heads: Union[int, Tuple[int, ...]] = 8,
        # 動作最大序列長度,預設為 32
        motion_max_seq_length: int = 32,
        # 動作注意力頭的數量,預設為 8
        motion_num_attention_heads: Union[int, Tuple[int, ...]] = 8,
        # 可選引數,反向動作注意力頭的數量,預設為 None
        reverse_motion_num_attention_heads: Optional[Union[int, Tuple[int, ...], Tuple[Tuple[int, ...], ...]]] = None,
        # 是否使用動作中間塊,預設為 True
        use_motion_mid_block: bool = True,
        # 中間塊的層數,預設為 1
        mid_block_layers: int = 1,
        # 編碼器隱藏層維度,預設為 None
        encoder_hid_dim: Optional[int] = None,
        # 編碼器隱藏層型別,預設為 None
        encoder_hid_dim_type: Optional[str] = None,
        # 可選引數,附加嵌入型別,預設為 None
        addition_embed_type: Optional[str] = None,
        # 可選引數,附加時間嵌入維度,預設為 None
        addition_time_embed_dim: Optional[int] = None,
        # 可選引數,投影類別嵌入的輸入維度,預設為 None
        projection_class_embeddings_input_dim: Optional[int] = None,
        # 可選引數,時間條件投影維度,預設為 None
        time_cond_proj_dim: Optional[int] = None,
    # 類方法,用於從 UNet2DConditionModel 建立物件
    @classmethod
    def from_unet2d(
        cls,
        # UNet2DConditionModel 物件
        unet: UNet2DConditionModel,
        # 可選的運動介面卡,預設為 None
        motion_adapter: Optional[MotionAdapter] = None,
        # 是否載入權重,預設為 True
        load_weights: bool = True,
    # 凍結 UNet2DConditionModel 的權重,只保留運動模組可訓練,便於微調
    def freeze_unet2d_params(self) -> None:
        """Freeze the weights of just the UNet2DConditionModel, and leave the motion modules
        unfrozen for fine tuning.
        """
        # 凍結所有引數
        for param in self.parameters():
            # 將引數的 requires_grad 屬性設定為 False,禁止梯度更新
            param.requires_grad = False

        # 解凍運動模組
        for down_block in self.down_blocks:
            # 獲取當前下采樣塊的運動模組
            motion_modules = down_block.motion_modules
            for param in motion_modules.parameters():
                # 將運動模組引數的 requires_grad 屬性設定為 True,允許梯度更新
                param.requires_grad = True

        for up_block in self.up_blocks:
            # 獲取當前上取樣塊的運動模組
            motion_modules = up_block.motion_modules
            for param in motion_modules.parameters():
                # 將運動模組引數的 requires_grad 屬性設定為 True,允許梯度更新
                param.requires_grad = True

        # 檢查中間塊是否具有運動模組
        if hasattr(self.mid_block, "motion_modules"):
            # 獲取中間塊的運動模組
            motion_modules = self.mid_block.motion_modules
            for param in motion_modules.parameters():
                # 將運動模組引數的 requires_grad 屬性設定為 True,允許梯度更新
                param.requires_grad = True

    # 載入運動模組的狀態字典
    def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]) -> None:
        # 遍歷運動介面卡的下采樣塊
        for i, down_block in enumerate(motion_adapter.down_blocks):
            # 載入下采樣塊的運動模組狀態字典
            self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict())
        # 遍歷運動介面卡的上取樣塊
        for i, up_block in enumerate(motion_adapter.up_blocks):
            # 載入上取樣塊的運動模組狀態字典
            self.up_blocks[i].motion_modules.load_state_dict(up_block.motion_modules.state_dict())

        # 支援沒有中間塊的舊運動模組
        if hasattr(self.mid_block, "motion_modules"):
            # 載入中間塊的運動模組狀態字典
            self.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict())

    # 儲存運動模組的狀態
    def save_motion_modules(
        self,
        save_directory: str,
        is_main_process: bool = True,
        safe_serialization: bool = True,
        variant: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs,
    ) -> None:
        # 獲取當前模型的狀態字典
        state_dict = self.state_dict()

        # 提取所有運動模組的狀態
        motion_state_dict = {}
        for k, v in state_dict.items():
            # 篩選出包含 "motion_modules" 的鍵值對
            if "motion_modules" in k:
                motion_state_dict[k] = v

        # 建立運動介面卡例項
        adapter = MotionAdapter(
            block_out_channels=self.config["block_out_channels"],
            motion_layers_per_block=self.config["layers_per_block"],
            motion_norm_num_groups=self.config["norm_num_groups"],
            motion_num_attention_heads=self.config["motion_num_attention_heads"],
            motion_max_seq_length=self.config["motion_max_seq_length"],
            use_motion_mid_block=self.config["use_motion_mid_block"],
        )
        # 載入運動狀態字典
        adapter.load_state_dict(motion_state_dict)
        # 儲存介面卡的預訓練狀態
        adapter.save_pretrained(
            save_directory=save_directory,
            is_main_process=is_main_process,
            safe_serialization=safe_serialization,
            variant=variant,
            push_to_hub=push_to_hub,
            **kwargs,
        )

    @property
    # 從 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 複製的屬性
    # 定義一個方法,返回模型中所有注意力處理器的字典
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        返回值:
            `dict` 型別的注意力處理器: 包含模型中所有注意力處理器的字典,
            按照其權重名稱索引。
        """
        # 初始化一個空字典,用於儲存注意力處理器
        processors = {}

        # 定義一個遞迴函式,用於新增註意力處理器
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 檢查模組是否具有 'get_processor' 方法
            if hasattr(module, "get_processor"):
                # 將處理器新增到字典中,鍵為處理器名稱
                processors[f"{name}.processor"] = module.get_processor()

            # 遍歷模組的所有子模組
            for sub_name, child in module.named_children():
                # 遞迴呼叫,繼續新增子模組的處理器
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            # 返回處理器字典
            return processors

        # 遍歷當前物件的所有子模組
        for name, module in self.named_children():
            # 呼叫遞迴函式新增所有處理器
            fn_recursive_add_processors(name, module, processors)

        # 返回最終的處理器字典
        return processors

    # 從 diffusers.models.unets.unet_2d_condition 中複製的方法,用於設定注意力處理器
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        設定用於計算注意力的注意力處理器。

        引數:
            processor (`dict` of `AttentionProcessor` 或僅 `AttentionProcessor`):
                例項化的處理器類或處理器類的字典,將被設定為**所有** `Attention` 層的處理器。

                如果 `processor` 是字典,鍵需要定義相應的交叉注意力處理器的路徑。
                當設定可訓練的注意力處理器時,強烈推薦這樣做。

        """
        # 獲取當前注意力處理器字典的鍵數量
        count = len(self.attn_processors.keys())

        # 檢查傳入的處理器字典長度是否與注意力層數量匹配
        if isinstance(processor, dict) and len(processor) != count:
            # 如果不匹配,丟擲錯誤
            raise ValueError(
                f"傳入了處理器字典,但處理器數量 {len(processor)} 與"
                f" 注意力層數量 {count} 不匹配。請確保傳入 {count} 個處理器類。"
            )

        # 定義一個遞迴函式,用於設定注意力處理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 檢查模組是否具有 'set_processor' 方法
            if hasattr(module, "set_processor"):
                # 如果處理器不是字典,直接設定處理器
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 從字典中彈出對應的處理器並設定
                    module.set_processor(processor.pop(f"{name}.processor"))

            # 遍歷模組的所有子模組
            for sub_name, child in module.named_children():
                # 遞迴呼叫,繼續設定子模組的處理器
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        # 遍歷當前物件的所有子模組
        for name, module in self.named_children():
            # 呼叫遞迴函式設定所有處理器
            fn_recursive_attn_processor(name, module, processor)
    # 定義一個方法以啟用前向分塊處理
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        設定注意力處理器以使用[前饋分塊](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers)。

        引數:
            chunk_size (`int`, *可選*):
                前饋層的塊大小。如果未指定,將單獨對維度為`dim`的每個張量執行前饋層。
            dim (`int`, *可選*, 預設為`0`):
                前饋計算應分塊的維度。選擇dim=0(批次)或dim=1(序列長度)。
        """
        # 檢查dim引數是否在有效範圍內(0或1)
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # 預設塊大小為1
        chunk_size = chunk_size or 1

        # 定義遞迴前饋函式以設定模組的分塊前饋處理
        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            # 如果模組有set_chunk_feed_forward屬性,設定塊大小和維度
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            # 遍歷模組的子模組
            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        # 遍歷當前物件的子模組,應用遞迴前饋函式
        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    # 定義一個方法以禁用前向分塊處理
    def disable_forward_chunking(self) -> None:
        # 定義遞迴前饋函式以設定模組的分塊前饋處理為None
        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            # 如果模組有set_chunk_feed_forward屬性,設定塊大小和維度為None
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            # 遍歷模組的子模組
            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        # 遍歷當前物件的子模組,應用遞迴前饋函式
        for module in self.children():
            fn_recursive_feed_forward(module, None, 0)

    # 從diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor複製的方法
    def set_default_attn_processor(self) -> None:
        """
        禁用自定義注意力處理器並設定預設的注意力實現。
        """
        # 如果所有注意力處理器都是ADDED_KV_ATTENTION_PROCESSORS型別
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 設定處理器為AttnAddedKVProcessor
            processor = AttnAddedKVProcessor()
        # 如果所有注意力處理器都是CROSS_ATTENTION_PROCESSORS型別
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 設定處理器為AttnProcessor
            processor = AttnProcessor()
        else:
            # 丟擲錯誤,表示不能在不匹配的注意力處理器型別下呼叫該方法
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        # 設定當前物件的注意力處理器
        self.set_attn_processor(processor)

    # 定義一個方法以設定模組的梯度檢查點
    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
        # 檢查模組是否為特定型別,如果是則設定其梯度檢查點屬性
        if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
            module.gradient_checkpointing = value

    # 從diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu複製的方法
    # 啟用 FreeU 機制,接受四個浮點型縮放因子作為引數
    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
        # 文件字串,描述該方法的作用及引數含義
        r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

        The suffixes after the scaling factors represent the stage blocks where they are being applied.

        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
        are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.

        Args:
            s1 (`float`):
                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
                mitigate the "oversmoothing effect" in the enhanced denoising process.
            s2 (`float`):
                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
                mitigate the "oversmoothing effect" in the enhanced denoising process.
            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
        """
        # 遍歷上取樣塊,併為每個塊設定縮放因子
        for i, upsample_block in enumerate(self.up_blocks):
            # 為上取樣塊設定階段1的縮放因子
            setattr(upsample_block, "s1", s1)
            # 為上取樣塊設定階段2的縮放因子
            setattr(upsample_block, "s2", s2)
            # 為上取樣塊設定階段1的主幹特徵縮放因子
            setattr(upsample_block, "b1", b1)
            # 為上取樣塊設定階段2的主幹特徵縮放因子
            setattr(upsample_block, "b2", b2)

    # 禁用 FreeU 機制
    def disable_freeu(self) -> None:
        # 文件字串,描述該方法的作用
        """Disables the FreeU mechanism."""
        # 定義 FreeU 相關的鍵名集合
        freeu_keys = {"s1", "s2", "b1", "b2"}
        # 遍歷上取樣塊
        for i, upsample_block in enumerate(self.up_blocks):
            # 遍歷 FreeU 鍵名
            for k in freeu_keys:
                # 檢查上取樣塊是否具有該屬性或該屬性是否不為 None
                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                    # 將上取樣塊的該屬性設定為 None
                    setattr(upsample_block, k, None)

    # 啟用融合的 QKV 投影
    def fuse_qkv_projections(self):
        # 文件字串,描述該方法的作用
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        # 初始化原始注意力處理器為 None
        self.original_attn_processors = None

        # 遍歷注意力處理器
        for _, attn_processor in self.attn_processors.items():
            # 檢查注意力處理器類名中是否包含 "Added"
            if "Added" in str(attn_processor.__class__.__name__):
                # 丟擲異常,說明不支援該操作
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        # 儲存原始的注意力處理器
        self.original_attn_processors = self.attn_processors

        # 遍歷所有模組
        for module in self.modules():
            # 檢查模組是否為 Attention 型別
            if isinstance(module, Attention):
                # 融合投影
                module.fuse_projections(fuse=True)

        # 設定融合後的注意力處理器
        self.set_attn_processor(FusedAttnProcessor2_0())

    # 解融合 QKV 投影的方法(省略具體實現)
    # 定義一個禁用融合 QKV 投影的方法
    def unfuse_qkv_projections(self):
        """如果啟用了,禁用融合 QKV 投影。
    
        <Tip warning={true}>
        
        此 API 是 🧪 實驗性。
        
        </Tip>
    
        """
        # 檢查原始注意力處理器是否不為 None
        if self.original_attn_processors is not None:
            # 設定當前注意力處理器為原始的注意力處理器
            self.set_attn_processor(self.original_attn_processors)
    
    # 定義前向傳播方法,接收多個引數
    def forward(
        self,
        # 輸入樣本張量
        sample: torch.Tensor,
        # 時間步,可以是張量、浮點數或整數
        timestep: Union[torch.Tensor, float, int],
        # 編碼器隱藏狀態張量
        encoder_hidden_states: torch.Tensor,
        # 可選的時間步條件張量
        timestep_cond: Optional[torch.Tensor] = None,
        # 可選的注意力掩碼張量
        attention_mask: Optional[torch.Tensor] = None,
        # 可選的交叉注意力引數字典
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        # 可選的附加條件引數字典
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        # 可選的下塊附加殘差元組
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        # 可選的中間塊附加殘差張量
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        # 是否返回字典格式的結果,預設為 True
        return_dict: bool = True,

相關文章