diffusers-原始碼解析-十四-

绝不原创的飞龙發表於2024-10-22

diffusers 原始碼解析(十四)

.\diffusers\models\unets\unet_2d_blocks_flax.py

# 版權宣告,說明該檔案的版權資訊及相關許可協議
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 許可資訊,使用 Apache License 2.0 許可
# Licensed under the Apache License, Version 2.0 (the "License");
# 該檔案只能在符合許可證的情況下使用
# you may not use this file except in compliance with the License.
# 提供許可證獲取連結
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 免責宣告,表明不提供任何形式的保證或條件
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 許可證的相關許可權及限制
# See the License for the specific language governing permissions and
# limitations under the License.

# 匯入 flax.linen 模組,用於構建神經網路
import flax.linen as nn
# 匯入 jax.numpy,用於數值計算
import jax.numpy as jnp

# 從其他模組匯入特定的類,用於構建模型的各個元件
from ..attention_flax import FlaxTransformer2DModel
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D


# 定義 FlaxCrossAttnDownBlock2D 類,表示一個 2D 跨注意力下采樣模組
class FlaxCrossAttnDownBlock2D(nn.Module):
    r"""
    跨注意力 2D 下采樣塊 - 原始架構來自 Unet transformers:
    https://arxiv.org/abs/2103.06104

    引數說明:
        in_channels (:obj:`int`):
            輸入通道數
        out_channels (:obj:`int`):
            輸出通道數
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout 率
        num_layers (:obj:`int`, *optional*, defaults to 1):
            注意力塊層數
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            每個空間變換塊的注意力頭數
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            是否在每個最終輸出之前新增下采樣層
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            啟用記憶體高效的注意力 https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            是否將頭維度拆分為一個新的軸進行自注意力計算。在大多數情況下,
            啟用此標誌應加快 Stable Diffusion 2.x 和 Stable Diffusion XL 的計算速度。
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            引數的資料型別
    """

    # 定義輸入通道數
    in_channels: int
    # 定義輸出通道數
    out_channels: int
    # 定義 Dropout 率,預設為 0.0
    dropout: float = 0.0
    # 定義注意力塊的層數,預設為 1
    num_layers: int = 1
    # 定義注意力頭數,預設為 1
    num_attention_heads: int = 1
    # 定義是否新增下采樣層,預設為 True
    add_downsample: bool = True
    # 定義是否使用線性投影,預設為 False
    use_linear_projection: bool = False
    # 定義是否僅使用跨注意力,預設為 False
    only_cross_attention: bool = False
    # 定義是否啟用記憶體高效注意力,預設為 False
    use_memory_efficient_attention: bool = False
    # 定義是否拆分頭維度,預設為 False
    split_head_dim: bool = False
    # 定義引數的資料型別,預設為 jnp.float32
    dtype: jnp.dtype = jnp.float32
    # 定義每個塊的變換器層數,預設為 1
    transformer_layers_per_block: int = 1
    # 設定模型的各個組成部分,包括殘差塊和注意力塊
        def setup(self):
            # 初始化殘差塊列表
            resnets = []
            # 初始化注意力塊列表
            attentions = []
    
            # 遍歷每一層,構建殘差塊和注意力塊
            for i in range(self.num_layers):
                # 第一層的輸入通道為 in_channels,其他層為 out_channels
                in_channels = self.in_channels if i == 0 else self.out_channels
    
                # 建立一個 FlaxResnetBlock2D 例項
                res_block = FlaxResnetBlock2D(
                    in_channels=in_channels,  # 輸入通道
                    out_channels=self.out_channels,  # 輸出通道
                    dropout_prob=self.dropout,  # 丟棄率
                    dtype=self.dtype,  # 資料型別
                )
                # 將殘差塊新增到列表中
                resnets.append(res_block)
    
                # 建立一個 FlaxTransformer2DModel 例項
                attn_block = FlaxTransformer2DModel(
                    in_channels=self.out_channels,  # 輸入通道
                    n_heads=self.num_attention_heads,  # 注意力頭數
                    d_head=self.out_channels // self.num_attention_heads,  # 每個頭的維度
                    depth=self.transformer_layers_per_block,  # 每個塊的層數
                    use_linear_projection=self.use_linear_projection,  # 是否使用線性投影
                    only_cross_attention=self.only_cross_attention,  # 是否只使用交叉注意力
                    use_memory_efficient_attention=self.use_memory_efficient_attention,  # 是否使用記憶體高效的注意力
                    split_head_dim=self.split_head_dim,  # 是否拆分頭的維度
                    dtype=self.dtype,  # 資料型別
                )
                # 將注意力塊新增到列表中
                attentions.append(attn_block)
    
            # 將殘差塊列表賦值給例項變數
            self.resnets = resnets
            # 將注意力塊列表賦值給例項變數
            self.attentions = attentions
    
            # 如果需要下采樣,則建立下采樣層
            if self.add_downsample:
                self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
    
        # 定義前向呼叫方法,處理隱藏狀態和編碼器隱藏狀態
        def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
            # 初始化輸出狀態元組
            output_states = ()
    
            # 遍歷殘差塊和注意力塊並進行處理
            for resnet, attn in zip(self.resnets, self.attentions):
                # 透過殘差塊處理隱藏狀態
                hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
                # 透過注意力塊處理隱藏狀態
                hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
                # 將當前隱藏狀態新增到輸出狀態元組中
                output_states += (hidden_states,)
    
            # 如果需要下采樣,則進行下采樣
            if self.add_downsample:
                hidden_states = self.downsamplers_0(hidden_states)
                # 將下采樣後的隱藏狀態新增到輸出狀態元組中
                output_states += (hidden_states,)
    
            # 返回最終的隱藏狀態和輸出狀態元組
            return hidden_states, output_states
# 定義 Flax 2D 降維塊類,繼承自 nn.Module
class FlaxDownBlock2D(nn.Module):
    r"""
    Flax 2D downsizing block

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention blocks layers
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add downsampling layer before each final output
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    
    # 宣告輸入輸出通道和其他引數
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    # 設定方法,用於初始化模型的層
    def setup(self):
        # 建立空列表以儲存殘差塊
        resnets = []

        # 根據層數建立殘差塊
        for i in range(self.num_layers):
            # 第一個塊的輸入通道為 in_channels,其餘為 out_channels
            in_channels = self.in_channels if i == 0 else self.out_channels

            # 建立殘差塊例項
            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            # 將殘差塊新增到列表中
            resnets.append(res_block)
        # 將列表賦值給例項屬性
        self.resnets = resnets

        # 如果需要,新增降取樣層
        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    # 呼叫方法,執行前向傳播
    def __call__(self, hidden_states, temb, deterministic=True):
        # 建立空元組以儲存輸出狀態
        output_states = ()

        # 遍歷所有殘差塊進行前向傳播
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            # 將當前隱藏狀態新增到輸出狀態中
            output_states += (hidden_states,)

        # 如果需要,應用降取樣層
        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            # 將降取樣後的隱藏狀態新增到輸出狀態中
            output_states += (hidden_states,)

        # 返回最終的隱藏狀態和輸出狀態
        return hidden_states, output_states


# 定義 Flax 交叉注意力 2D 上取樣塊類,繼承自 nn.Module
class FlaxCrossAttnUpBlock2D(nn.Module):
    r"""
    Cross Attention 2D Upsampling block - original architecture from Unet transformers:
    https://arxiv.org/abs/2103.06104
    # 定義引數的文件字串,描述各個引數的用途和型別
        Parameters:
            in_channels (:obj:`int`):  # 輸入通道數
                Input channels
            out_channels (:obj:`int`):  # 輸出通道數
                Output channels
            dropout (:obj:`float`, *optional*, defaults to 0.0):  # Dropout 率,預設值為 0.0
                Dropout rate
            num_layers (:obj:`int`, *optional*, defaults to 1):  # 注意力塊的層數,預設值為 1
                Number of attention blocks layers
            num_attention_heads (:obj:`int`, *optional*, defaults to 1):  # 每個空間變換塊的注意力頭數量,預設值為 1
                Number of attention heads of each spatial transformer block
            add_upsample (:obj:`bool`, *optional*, defaults to `True`):  # 是否在每個最終輸出前新增上取樣層,預設值為 True
                Whether to add upsampling layer before each final output
            use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):  # 啟用記憶體高效注意力,預設值為 False
                enable memory efficient attention https://arxiv.org/abs/2112.05682
            split_head_dim (`bool`, *optional*, defaults to `False`):  # 是否將頭維度拆分為新軸以進行自注意力計算,預設值為 False
                Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
                enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
            dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):  # 資料型別引數,預設值為 jnp.float32
                Parameters `dtype`
        """
    
        in_channels: int  # 輸入通道數的宣告
        out_channels: int  # 輸出通道數的宣告
        prev_output_channel: int  # 前一個輸出通道數的宣告
        dropout: float = 0.0  # Dropout 率的宣告,預設值為 0.0
        num_layers: int = 1  # 注意力層數的宣告,預設值為 1
        num_attention_heads: int = 1  # 注意力頭數量的宣告,預設值為 1
        add_upsample: bool = True  # 是否新增上取樣的宣告,預設值為 True
        use_linear_projection: bool = False  # 是否使用線性投影的宣告,預設值為 False
        only_cross_attention: bool = False  # 是否僅使用交叉注意力的宣告,預設值為 False
        use_memory_efficient_attention: bool = False  # 是否啟用記憶體高效注意力的宣告,預設值為 False
        split_head_dim: bool = False  # 是否拆分頭維度的宣告,預設值為 False
        dtype: jnp.dtype = jnp.float32  # 資料型別的宣告,預設值為 jnp.float32
        transformer_layers_per_block: int = 1  # 每個塊的變換層數的宣告,預設值為 1
    # 設定方法,初始化網路結構
    def setup(self):
        # 初始化空列表以儲存 ResNet 塊
        resnets = []
        # 初始化空列表以儲存注意力塊
        attentions = []
    
        # 遍歷每一層以建立相應的 ResNet 和注意力塊
        for i in range(self.num_layers):
            # 設定跳躍連線的通道數,最後一層使用輸入通道,否則使用輸出通道
            res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
            # 設定當前 ResNet 塊的輸入通道,第一層使用前一層的輸出通道
            resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
    
            # 建立 FlaxResnetBlock2D 例項
            res_block = FlaxResnetBlock2D(
                # 設定輸入通道為當前 ResNet 塊輸入通道加跳躍連線通道
                in_channels=resnet_in_channels + res_skip_channels,
                # 設定輸出通道為指定的輸出通道
                out_channels=self.out_channels,
                # 設定 dropout 機率
                dropout_prob=self.dropout,
                # 設定資料型別
                dtype=self.dtype,
            )
            # 將建立的 ResNet 塊新增到列表中
            resnets.append(res_block)
    
            # 建立 FlaxTransformer2DModel 例項
            attn_block = FlaxTransformer2DModel(
                # 設定輸入通道為輸出通道
                in_channels=self.out_channels,
                # 設定注意力頭數
                n_heads=self.num_attention_heads,
                # 設定每個注意力頭的維度
                d_head=self.out_channels // self.num_attention_heads,
                # 設定 transformer 塊的深度
                depth=self.transformer_layers_per_block,
                # 設定是否使用線性投影
                use_linear_projection=self.use_linear_projection,
                # 設定是否僅使用交叉注意力
                only_cross_attention=self.only_cross_attention,
                # 設定是否使用記憶體高效的注意力機制
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                # 設定是否分割頭部維度
                split_head_dim=self.split_head_dim,
                # 設定資料型別
                dtype=self.dtype,
            )
            # 將建立的注意力塊新增到列表中
            attentions.append(attn_block)
    
        # 將 ResNet 列表儲存到例項屬性
        self.resnets = resnets
        # 將注意力列表儲存到例項屬性
        self.attentions = attentions
    
        # 如果需要新增上取樣層,則建立相應的 FlaxUpsample2D 例項
        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
    
    # 定義呼叫方法,接受隱藏狀態和其他引數
    def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
        # 遍歷 ResNet 和注意力塊
        for resnet, attn in zip(self.resnets, self.attentions):
            # 從跳躍連線的隱藏狀態元組中取出最後一個狀態
            res_hidden_states = res_hidden_states_tuple[-1]
            # 更新跳躍連線的隱藏狀態元組,去掉最後一個狀態
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            # 將隱藏狀態與跳躍連線的隱藏狀態在最後一個軸上拼接
            hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
    
            # 使用當前的 ResNet 塊處理隱藏狀態
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            # 使用當前的注意力塊處理隱藏狀態
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
    
        # 如果需要新增上取樣,則使用上取樣層處理隱藏狀態
        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)
    
        # 返回處理後的隱藏狀態
        return hidden_states
# 定義一個 2D 上取樣塊類,繼承自 nn.Module
class FlaxUpBlock2D(nn.Module):
    r"""
    Flax 2D upsampling block

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        prev_output_channel (:obj:`int`):
            Output channels from the previous block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention blocks layers
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add downsampling layer before each final output
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    # 定義輸入輸出通道和其他引數
    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    # 設定方法用於初始化塊的結構
    def setup(self):
        resnets = []  # 建立一個空列表用於儲存 ResNet 塊

        # 遍歷每一層,建立 ResNet 塊
        for i in range(self.num_layers):
            # 計算跳躍連線通道數
            res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
            # 設定輸入通道數
            resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels

            # 建立一個新的 FlaxResnetBlock2D 例項
            res_block = FlaxResnetBlock2D(
                in_channels=resnet_in_channels + res_skip_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)  # 將塊新增到列表中

        self.resnets = resnets  # 將列表賦值給例項變數

        # 如果需要上取樣,初始化上取樣層
        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    # 定義前向傳播方法
    def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
        # 遍歷每個 ResNet 塊進行前向傳播
        for resnet in self.resnets:
            # 從元組中彈出最後的殘差隱藏狀態
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]  # 更新元組,去掉最後一項
            # 連線當前隱藏狀態與殘差隱藏狀態
            hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)

            # 透過 ResNet 塊處理隱藏狀態
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)

        # 如果需要上取樣,呼叫上取樣層
        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states  # 返回處理後的隱藏狀態


# 定義一個 2D 中級交叉注意力塊類,繼承自 nn.Module
class FlaxUNetMidBlock2DCrossAttn(nn.Module):
    r"""
    Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
    # 定義引數的文件字串
    Parameters:
        in_channels (:obj:`int`):  # 輸入通道數
            Input channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):  # Dropout比率,預設為0.0
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):  # 注意力層的數量,預設為1
            Number of attention blocks layers
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):  # 每個空間變換塊的注意力頭數量,預設為1
            Number of attention heads of each spatial transformer block
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):  # 是否啟用記憶體高效的注意力機制,預設為False
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):  # 是否將頭維度分割為新的軸以加速計算,預設為False
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):  # 資料型別引數,預設為jnp.float32
            Parameters `dtype`
    """

    in_channels: int  # 輸入通道數的型別
    dropout: float = 0.0  # Dropout比率的預設值
    num_layers: int = 1  # 注意力層數量的預設值
    num_attention_heads: int = 1  # 注意力頭數量的預設值
    use_linear_projection: bool = False  # 是否使用線性投影的預設值
    use_memory_efficient_attention: bool = False  # 是否使用記憶體高效注意力的預設值
    split_head_dim: bool = False  # 是否分割頭維度的預設值
    dtype: jnp.dtype = jnp.float32  # 資料型別的預設值
    transformer_layers_per_block: int = 1  # 每個塊中的變換層數量的預設值

    def setup(self):  # 設定方法,用於初始化
        # 至少會有一個ResNet塊
        resnets = [  # 建立ResNet塊列表
            FlaxResnetBlock2D(  # 建立一個ResNet塊
                in_channels=self.in_channels,  # 輸入通道數
                out_channels=self.in_channels,  # 輸出通道數
                dropout_prob=self.dropout,  # Dropout機率
                dtype=self.dtype,  # 資料型別
            )
        ]

        attentions = []  # 初始化注意力塊列表

        for _ in range(self.num_layers):  # 遍歷指定的注意力層數
            attn_block = FlaxTransformer2DModel(  # 建立一個Transformer塊
                in_channels=self.in_channels,  # 輸入通道數
                n_heads=self.num_attention_heads,  # 注意力頭數量
                d_head=self.in_channels // self.num_attention_heads,  # 每個頭的維度
                depth=self.transformer_layers_per_block,  # 變換層深度
                use_linear_projection=self.use_linear_projection,  # 是否使用線性投影
                use_memory_efficient_attention=self.use_memory_efficient_attention,  # 是否使用記憶體高效注意力
                split_head_dim=self.split_head_dim,  # 是否分割頭維度
                dtype=self.dtype,  # 資料型別
            )
            attentions.append(attn_block)  # 將注意力塊新增到列表中

            res_block = FlaxResnetBlock2D(  # 建立一個ResNet塊
                in_channels=self.in_channels,  # 輸入通道數
                out_channels=self.in_channels,  # 輸出通道數
                dropout_prob=self.dropout,  # Dropout機率
                dtype=self.dtype,  # 資料型別
            )
            resnets.append(res_block)  # 將ResNet塊新增到列表中

        self.resnets = resnets  # 將ResNet塊列表賦值給例項屬性
        self.attentions = attentions  # 將注意力塊列表賦值給例項屬性

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):  # 呼叫方法
        hidden_states = self.resnets[0](hidden_states, temb)  # 透過第一個ResNet塊處理隱藏狀態
        for attn, resnet in zip(self.attentions, self.resnets[1:]):  # 遍歷每個注意力塊和後續ResNet塊
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)  # 處理隱藏狀態
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)  # 再次處理隱藏狀態

        return hidden_states  # 返回處理後的隱藏狀態

.\diffusers\models\unets\unet_2d_condition.py

# 版權宣告,標明版權資訊和使用許可
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 按照 Apache License 2.0 版本進行許可
# Licensed under the Apache License, Version 2.0 (the "License");
# 你不得在未遵守許可的情況下使用此檔案
# you may not use this file except in compliance with the License.
# 你可以在以下網址獲得許可證的副本
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律要求或書面同意,軟體以“按原樣”方式分發,不提供任何形式的擔保或條件
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 檢視許可證以瞭解特定語言所適用的許可權和限制
# See the License for the specific language governing permissions and
# limitations under the License.

# 從 dataclasses 模組匯入 dataclass 裝飾器,用於簡化類的定義
from dataclasses import dataclass
# 匯入所需的型別註釋
from typing import Any, Dict, List, Optional, Tuple, Union

# 匯入 PyTorch 庫和相關模組
import torch
import torch.nn as nn
import torch.utils.checkpoint

# 從配置和載入器模組中匯入所需的類和函式
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..activations import get_activation
from ..attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,  # 匯入與注意力機制相關的處理器
    CROSS_ATTENTION_PROCESSORS,
    Attention,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
    FusedAttnProcessor2_0,
)
from ..embeddings import (
    GaussianFourierProjection,  # 匯入多種嵌入方法
    GLIGENTextBoundingboxProjection,
    ImageHintTimeEmbedding,
    ImageProjection,
    ImageTimeEmbedding,
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
from ..modeling_utils import ModelMixin  # 匯入模型混合類
from .unet_2d_blocks import (
    get_down_block,  # 匯入下采樣塊的建構函式
    get_mid_block,   # 匯入中間塊的建構函式
    get_up_block,    # 匯入上取樣塊的建構函式
)

# 建立一個日誌記錄器,用於記錄模型相關資訊
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 定義 UNet2DConditionOutput 資料類,用於儲存 UNet2DConditionModel 的輸出
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    UNet2DConditionModel 的輸出。

    引數:
        sample (`torch.Tensor`,形狀為 `(batch_size, num_channels, height, width)`):
            基於 `encoder_hidden_states` 輸入的隱藏狀態輸出,模型最後一層的輸出。
    """

    sample: torch.Tensor = None  # 定義一個樣本屬性,預設為 None

# 定義 UNet2DConditionModel 類,表示一個條件 2D UNet 模型
class UNet2DConditionModel(
    ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
):
    r"""
    一個條件 2D UNet 模型,接受一個噪聲樣本、條件狀態和時間步,並返回樣本形狀的輸出。

    該模型繼承自 [`ModelMixin`]。檢視超類文件以獲取其為所有模型實現的通用方法
    (例如下載或儲存)。
    """

    _supports_gradient_checkpointing = True  # 表示該模型支援梯度檢查點
    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]  # 不進行拆分的模組列表

    @register_to_config  # 將該方法註冊到配置中
    # 初始化方法,設定類的基本屬性
        def __init__(
            # 樣本大小,預設為 None
            self,
            sample_size: Optional[int] = None,
            # 輸入通道數,預設為 4
            in_channels: int = 4,
            # 輸出通道數,預設為 4
            out_channels: int = 4,
            # 是否將輸入樣本中心化,預設為 False
            center_input_sample: bool = False,
            # 是否將正弦函式翻轉為餘弦函式,預設為 True
            flip_sin_to_cos: bool = True,
            # 頻率偏移量,預設為 0
            freq_shift: int = 0,
            # 向下取樣的塊型別,包含多種塊型別
            down_block_types: Tuple[str] = (
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "DownBlock2D",
            ),
            # 中間塊的型別,預設為 UNet 的中間塊型別
            mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
            # 向上取樣的塊型別,包含多種塊型別
            up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
            # 是否僅使用交叉注意力,預設為 False
            only_cross_attention: Union[bool, Tuple[bool]] = False,
            # 每個塊的輸出通道數
            block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
            # 每個塊的層數,預設為 2
            layers_per_block: Union[int, Tuple[int]] = 2,
            # 下采樣時的填充大小,預設為 1
            downsample_padding: int = 1,
            # 中間塊的縮放因子,預設為 1
            mid_block_scale_factor: float = 1,
            # dropout 機率,預設為 0.0
            dropout: float = 0.0,
            # 啟用函式型別,預設為 "silu"
            act_fn: str = "silu",
            # 歸一化的組數,預設為 32
            norm_num_groups: Optional[int] = 32,
            # 歸一化的 epsilon 值,預設為 1e-5
            norm_eps: float = 1e-5,
            # 交叉注意力的維度,預設為 1280
            cross_attention_dim: Union[int, Tuple[int]] = 1280,
            # 每個塊的變換層數,預設為 1
            transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
            # 反向變換層的塊數,預設為 None
            reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
            # 編碼器隱藏層的維度,預設為 None
            encoder_hid_dim: Optional[int] = None,
            # 編碼器隱藏層型別,預設為 None
            encoder_hid_dim_type: Optional[str] = None,
            # 注意力頭的維度,預設為 8
            attention_head_dim: Union[int, Tuple[int]] = 8,
            # 注意力頭的數量,預設為 None
            num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
            # 是否使用雙交叉注意力,預設為 False
            dual_cross_attention: bool = False,
            # 是否使用線性投影,預設為 False
            use_linear_projection: bool = False,
            # 類嵌入型別,預設為 None
            class_embed_type: Optional[str] = None,
            # 附加嵌入型別,預設為 None
            addition_embed_type: Optional[str] = None,
            # 附加時間嵌入維度,預設為 None
            addition_time_embed_dim: Optional[int] = None,
            # 類嵌入數量,預設為 None
            num_class_embeds: Optional[int] = None,
            # 是否上溯注意力,預設為 False
            upcast_attention: bool = False,
            # ResNet 時間縮放偏移型別,預設為 "default"
            resnet_time_scale_shift: str = "default",
            # ResNet 是否跳過時間啟用,預設為 False
            resnet_skip_time_act: bool = False,
            # ResNet 輸出縮放因子,預設為 1.0
            resnet_out_scale_factor: float = 1.0,
            # 時間嵌入型別,預設為 "positional"
            time_embedding_type: str = "positional",
            # 時間嵌入維度,預設為 None
            time_embedding_dim: Optional[int] = None,
            # 時間嵌入啟用函式,預設為 None
            time_embedding_act_fn: Optional[str] = None,
            # 時間步後啟用函式,預設為 None
            timestep_post_act: Optional[str] = None,
            # 時間條件投影維度,預設為 None
            time_cond_proj_dim: Optional[int] = None,
            # 輸入卷積核大小,預設為 3
            conv_in_kernel: int = 3,
            # 輸出卷積核大小,預設為 3
            conv_out_kernel: int = 3,
            # 投影類嵌入輸入維度,預設為 None
            projection_class_embeddings_input_dim: Optional[int] = None,
            # 注意力型別,預設為 "default"
            attention_type: str = "default",
            # 類嵌入是否拼接,預設為 False
            class_embeddings_concat: bool = False,
            # 中間塊是否僅使用交叉注意力,預設為 None
            mid_block_only_cross_attention: Optional[bool] = None,
            # 交叉注意力歸一化型別,預設為 None
            cross_attention_norm: Optional[str] = None,
            # 附加嵌入型別的頭數量,預設為 64
            addition_embed_type_num_heads: int = 64,
    # 定義一個私有方法,用於檢查配置引數
        def _check_config(
            self,
            # 定義下行塊型別的元組,表示模型的結構
            down_block_types: Tuple[str],
            # 定義上行塊型別的元組,表示模型的結構
            up_block_types: Tuple[str],
            # 定義僅使用交叉注意力的標誌,可以是布林值或布林值的元組
            only_cross_attention: Union[bool, Tuple[bool]],
            # 定義每個塊的輸出通道數的元組,表示層的寬度
            block_out_channels: Tuple[int],
            # 定義每個塊的層數,可以是整數或整數的元組
            layers_per_block: Union[int, Tuple[int]],
            # 定義交叉注意力維度,可以是整數或整數的元組
            cross_attention_dim: Union[int, Tuple[int]],
            # 定義每個塊的變換器層數,可以是整數、整數的元組或元組的元組
            transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
            # 定義是否反轉變換器層的布林值
            reverse_transformer_layers_per_block: bool,
            # 定義注意力頭的維度,表示注意力的解析度
            attention_head_dim: int,
            # 定義注意力頭的數量,可以是可選的整數或整數的元組
            num_attention_heads: Optional[Union[int, Tuple[int]],
    ):
        # 檢查 down_block_types 和 up_block_types 的長度是否相同
        if len(down_block_types) != len(up_block_types):
            # 如果不同,丟擲值錯誤並提供詳細資訊
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        # 檢查 block_out_channels 和 down_block_types 的長度是否相同
        if len(block_out_channels) != len(down_block_types):
            # 如果不同,丟擲值錯誤並提供詳細資訊
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        # 檢查 only_cross_attention 是否為布林值且長度與 down_block_types 相同
        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            # 如果不滿足條件,丟擲值錯誤並提供詳細資訊
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        # 檢查 num_attention_heads 是否為整數且長度與 down_block_types 相同
        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            # 如果不滿足條件,丟擲值錯誤並提供詳細資訊
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        # 檢查 attention_head_dim 是否為整數且長度與 down_block_types 相同
        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            # 如果不滿足條件,丟擲值錯誤並提供詳細資訊
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        # 檢查 cross_attention_dim 是否為列表且長度與 down_block_types 相同
        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            # 如果不滿足條件,丟擲值錯誤並提供詳細資訊
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        # 檢查 layers_per_block 是否為整數且長度與 down_block_types 相同
        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            # 如果不滿足條件,丟擲值錯誤並提供詳細資訊
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )
        # 檢查 transformer_layers_per_block 是否為列表且 reverse_transformer_layers_per_block 為 None
        if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
            # 遍歷 transformer_layers_per_block 中的每個層
            for layer_number_per_block in transformer_layers_per_block:
                # 檢查每個層是否為列表
                if isinstance(layer_number_per_block, list):
                    # 如果是,則丟擲值錯誤,提示需要提供 reverse_transformer_layers_per_block
                    raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")

    # 定義設定時間投影的私有方法
    def _set_time_proj(
        self,
        # 時間嵌入型別
        time_embedding_type: str,
        # 塊輸出通道數
        block_out_channels: int,
        # 是否翻轉正弦和餘弦
        flip_sin_to_cos: bool,
        # 頻率偏移
        freq_shift: float,
        # 時間嵌入維度
        time_embedding_dim: int,
    # 返回時間嵌入維度和時間步輸入維度的元組
    ) -> Tuple[int, int]:
        # 判斷時間嵌入型別是否為傅立葉
        if time_embedding_type == "fourier":
            # 計算時間嵌入維度,預設為 block_out_channels[0] * 2
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
            # 確保時間嵌入維度為偶數
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            # 初始化高斯傅立葉投影,設定相關引數
            self.time_proj = GaussianFourierProjection(
                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            # 設定時間步輸入維度為時間嵌入維度
            timestep_input_dim = time_embed_dim
        # 判斷時間嵌入型別是否為位置編碼
        elif time_embedding_type == "positional":
            # 計算時間嵌入維度,預設為 block_out_channels[0] * 4
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
            # 初始化時間步物件,設定相關引數
            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            # 設定時間步輸入維度為 block_out_channels[0]
            timestep_input_dim = block_out_channels[0]
        # 如果時間嵌入型別不合法,丟擲錯誤
        else:
            raise ValueError(
                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
            )
    
        # 返回時間嵌入維度和時間步輸入維度
        return time_embed_dim, timestep_input_dim
    
    # 定義設定編碼器隱藏投影的方法
    def _set_encoder_hid_proj(
        self,
        encoder_hid_dim_type: Optional[str],
        cross_attention_dim: Union[int, Tuple[int]],
        encoder_hid_dim: Optional[int],
    ):
        # 如果編碼器隱藏維度型別為空且隱藏維度已定義
        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            # 預設將編碼器隱藏維度型別設為'text_proj'
            encoder_hid_dim_type = "text_proj"
            # 註冊編碼器隱藏維度型別到配置中
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            # 記錄資訊日誌
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
    
        # 如果編碼器隱藏維度為空且隱藏維度型別已定義,丟擲錯誤
        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )
    
        # 判斷編碼器隱藏維度型別是否為'text_proj'
        if encoder_hid_dim_type == "text_proj":
            # 初始化線性投影層,輸入維度為encoder_hid_dim,輸出維度為cross_attention_dim
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        # 判斷編碼器隱藏維度型別是否為'text_image_proj'
        elif encoder_hid_dim_type == "text_image_proj":
            # 初始化文字-影像投影物件,設定相關引數
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )
        # 判斷編碼器隱藏維度型別是否為'image_proj'
        elif encoder_hid_dim_type == "image_proj":
            # 初始化影像投影物件,設定相關引數
            self.encoder_hid_proj = ImageProjection(
                image_embed_dim=encoder_hid_dim,
                cross_attention_dim=cross_attention_dim,
            )
        # 如果編碼器隱藏維度型別不合法,丟擲錯誤
        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
        # 如果都不符合,將編碼器隱藏投影設為None
        else:
            self.encoder_hid_proj = None
    # 設定類嵌入的私有方法
        def _set_class_embedding(
            self,
            class_embed_type: Optional[str],  # 嵌入型別,可能為 None 或特定字串
            act_fn: str,  # 啟用函式的名稱
            num_class_embeds: Optional[int],  # 類嵌入數量,可能為 None
            projection_class_embeddings_input_dim: Optional[int],  # 投影類嵌入輸入維度,可能為 None
            time_embed_dim: int,  # 時間嵌入的維度
            timestep_input_dim: int,  # 時間步輸入的維度
        ):
            # 如果嵌入型別為 None 且類嵌入數量不為 None
            if class_embed_type is None and num_class_embeds is not None:
                # 建立嵌入層,大小為類嵌入數量和時間嵌入維度
                self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
            # 如果嵌入型別為 "timestep"
            elif class_embed_type == "timestep":
                # 建立時間步嵌入物件
                self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
            # 如果嵌入型別為 "identity"
            elif class_embed_type == "identity":
                # 建立恆等層,輸入和輸出維度相同
                self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
            # 如果嵌入型別為 "projection"
            elif class_embed_type == "projection":
                # 如果投影類嵌入輸入維度為 None,丟擲錯誤
                if projection_class_embeddings_input_dim is None:
                    raise ValueError(
                        "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                    )
                # 建立投影時間步嵌入物件
                self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
            # 如果嵌入型別為 "simple_projection"
            elif class_embed_type == "simple_projection":
                # 如果投影類嵌入輸入維度為 None,丟擲錯誤
                if projection_class_embeddings_input_dim is None:
                    raise ValueError(
                        "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
                    )
                # 建立線性層作為簡單投影
                self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
            # 如果沒有匹配的嵌入型別
            else:
                # 將類嵌入設定為 None
                self.class_embedding = None
    
        # 設定附加嵌入的私有方法
        def _set_add_embedding(
            self,
            addition_embed_type: str,  # 附加嵌入型別
            addition_embed_type_num_heads: int,  # 附加嵌入型別的頭數
            addition_time_embed_dim: Optional[int],  # 附加時間嵌入維度,可能為 None
            flip_sin_to_cos: bool,  # 是否翻轉正弦到餘弦
            freq_shift: float,  # 頻率偏移量
            cross_attention_dim: Optional[int],  # 交叉注意力維度,可能為 None
            encoder_hid_dim: Optional[int],  # 編碼器隱藏維度,可能為 None
            projection_class_embeddings_input_dim: Optional[int],  # 投影類嵌入輸入維度,可能為 None
            time_embed_dim: int,  # 時間嵌入維度
    ):
        # 檢查附加嵌入型別是否為 "text"
        if addition_embed_type == "text":
            # 如果編碼器隱藏維度不為 None,則使用該維度
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            # 否則使用交叉注意力維度
            else:
                text_time_embedding_from_dim = cross_attention_dim

            # 建立文字時間嵌入物件
            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        # 檢查附加嵌入型別是否為 "text_image"
        elif addition_embed_type == "text_image":
            # text_embed_dim 和 image_embed_dim 不必是 `cross_attention_dim`,為了避免 __init__ 過於繁雜
            # 在這裡設定為 `cross_attention_dim`,因為這是當前唯一使用情況的所需維度 (Kandinsky 2.1)
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        # 檢查附加嵌入型別是否為 "text_time"
        elif addition_embed_type == "text_time":
            # 建立時間投影物件
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            # 建立時間嵌入物件
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        # 檢查附加嵌入型別是否為 "image"
        elif addition_embed_type == "image":
            # Kandinsky 2.2
            # 建立影像時間嵌入物件
            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        # 檢查附加嵌入型別是否為 "image_hint"
        elif addition_embed_type == "image_hint":
            # Kandinsky 2.2 ControlNet
            # 建立影像提示時間嵌入物件
            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        # 檢查附加嵌入型別是否為 None 以外的值
        elif addition_embed_type is not None:
            # 丟擲值錯誤,提示無效的附加嵌入型別
            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")

    # 定義一個屬性方法,用於設定位置網路
    def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
        # 檢查注意力型別是否為 "gated" 或 "gated-text-image"
        if attention_type in ["gated", "gated-text-image"]:
            positive_len = 768  # 預設的正向長度
            # 如果交叉注意力維度是整數,則使用該值
            if isinstance(cross_attention_dim, int):
                positive_len = cross_attention_dim
            # 如果交叉注意力維度是列表或元組,則使用第一個值
            elif isinstance(cross_attention_dim, (list, tuple)):
                positive_len = cross_attention_dim[0]

            # 根據注意力型別確定特徵型別
            feature_type = "text-only" if attention_type == "gated" else "text-image"
            # 建立 GLIGEN 文字邊界框投影物件
            self.position_net = GLIGENTextBoundingboxProjection(
                positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
            )

    # 定義一個屬性
    @property
    # 定義一個方法,返回一個字典,包含模型中所有的注意力處理器
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # 初始化一個空字典,用於儲存注意力處理器
        processors = {}

        # 定義一個遞迴函式,用於新增處理器到字典
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 檢查模組是否有獲取處理器的方法
            if hasattr(module, "get_processor"):
                # 將處理器新增到字典中,鍵為名稱,值為處理器
                processors[f"{name}.processor"] = module.get_processor()

            # 遍歷模組的所有子模組
            for sub_name, child in module.named_children():
                # 遞迴呼叫,處理子模組
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            # 返回更新後的處理器字典
            return processors

        # 遍歷當前模組的所有子模組
        for name, module in self.named_children():
            # 呼叫遞迴函式,新增處理器
            fn_recursive_add_processors(name, module, processors)

        # 返回包含所有處理器的字典
        return processors

    # 定義一個方法,設定用於計算注意力的處理器
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        # 獲取當前處理器的數量
        count = len(self.attn_processors.keys())

        # 如果傳入的是字典,且字典長度與注意力層數量不匹配,丟擲錯誤
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        # 定義一個遞迴函式,用於設定處理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 檢查模組是否有設定處理器的方法
            if hasattr(module, "set_processor"):
                # 如果處理器不是字典,直接設定
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 從字典中彈出對應的處理器並設定
                    module.set_processor(processor.pop(f"{name}.processor"))

            # 遍歷模組的所有子模組
            for sub_name, child in module.named_children():
                # 遞迴呼叫,設定子模組的處理器
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        # 遍歷當前模組的所有子模組
        for name, module in self.named_children():
            # 呼叫遞迴函式,設定處理器
            fn_recursive_attn_processor(name, module, processor)
    # 定義設定預設注意力處理器的方法
    def set_default_attn_processor(self):
        """
        禁用自定義注意力處理器並設定預設的注意力實現。
        """
        # 檢查所有注意力處理器是否屬於新增的鍵值注意力處理器類
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 建立新增鍵值注意力處理器的例項
            processor = AttnAddedKVProcessor()
        # 檢查所有注意力處理器是否屬於交叉注意力處理器類
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 建立標準注意力處理器的例項
            processor = AttnProcessor()
        else:
            # 如果注意力處理器型別不匹配,則丟擲錯誤
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        # 設定選定的注意力處理器
        self.set_attn_processor(processor)

    # 定義設定梯度檢查點的方法
    def _set_gradient_checkpointing(self, module, value=False):
        # 如果模組具有梯度檢查點屬性,則設定其值
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    # 定義啟用 FreeU 機制的方法
    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
        r"""啟用 FreeU 機制,詳細資訊請見 https://arxiv.org/abs/2309.11497。

        在縮放因子後面的字尾表示它們被應用的階段塊。

        請參考 [官方倉庫](https://github.com/ChenyangSi/FreeU) 以獲取已知在不同管道(如 Stable Diffusion v1、v2 和 Stable Diffusion XL)中效果良好的值組合。

        引數:
            s1 (`float`):
                階段 1 的縮放因子,用於減弱跳躍特徵的貢獻,以減輕增強去噪過程中的“過平滑效應”。
            s2 (`float`):
                階段 2 的縮放因子,用於減弱跳躍特徵的貢獻,以減輕增強去噪過程中的“過平滑效應”。
            b1 (`float`): 階段 1 的縮放因子,用於增強骨幹特徵的貢獻。
            b2 (`float`): 階段 2 的縮放因子,用於增強骨幹特徵的貢獻。
        """
        # 遍歷上取樣塊並設定相應的縮放因子
        for i, upsample_block in enumerate(self.up_blocks):
            setattr(upsample_block, "s1", s1)  # 設定階段 1 的縮放因子
            setattr(upsample_block, "s2", s2)  # 設定階段 2 的縮放因子
            setattr(upsample_block, "b1", b1)  # 設定階段 1 的骨幹縮放因子
            setattr(upsample_block, "b2", b2)  # 設定階段 2 的骨幹縮放因子

    # 定義禁用 FreeU 機制的方法
    def disable_freeu(self):
        """禁用 FreeU 機制。"""
        freeu_keys = {"s1", "s2", "b1", "b2"}  # 定義 FreeU 相關的鍵
        # 遍歷上取樣塊
        for i, upsample_block in enumerate(self.up_blocks):
            # 遍歷每個 FreeU 鍵
            for k in freeu_keys:
                # 如果上取樣塊具有該鍵的屬性或其值不為 None,則將其值設定為 None
                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                    setattr(upsample_block, k, None)
    # 定義一個方法,用於啟用融合的 QKV 投影
    def fuse_qkv_projections(self):
        """
        啟用融合的 QKV 投影。對於自注意力模組,所有投影矩陣(即查詢、鍵、值)都被融合。
        對於交叉注意力模組,鍵和值的投影矩陣被融合。

        <Tip warning={true}>

        此 API 是 🧪 實驗性的。

        </Tip>
        """
        # 初始化原始注意力處理器為 None
        self.original_attn_processors = None

        # 遍歷注意力處理器,檢查是否包含“Added”字樣
        for _, attn_processor in self.attn_processors.items():
            # 如果發現新增的 KV 投影,丟擲錯誤
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        # 儲存當前的注意力處理器
        self.original_attn_processors = self.attn_processors

        # 遍歷模組,查詢型別為 Attention 的模組
        for module in self.modules():
            if isinstance(module, Attention):
                # 啟用投影融合
                module.fuse_projections(fuse=True)

        # 設定注意力處理器為融合的處理器
        self.set_attn_processor(FusedAttnProcessor2_0())

    # 定義一個方法,用於禁用已啟用的融合 QKV 投影
    def unfuse_qkv_projections(self):
        """禁用已啟用的融合 QKV 投影。

        <Tip warning={true}>

        此 API 是 🧪 實驗性的。

        </Tip>

        """
        # 如果原始注意力處理器不為 None,則恢復到原始處理器
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    # 定義一個方法,用於獲取時間嵌入
    def get_time_embed(
        self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
    ) -> Optional[torch.Tensor]:
        # 將時間步長賦值給 timesteps
        timesteps = timestep
        # 如果 timesteps 不是張量
        if not torch.is_tensor(timesteps):
            # TODO: 這需要在 CPU 和 GPU 之間同步。因此,如果可以的話,儘量將 timesteps 作為張量傳遞
            # 這將是使用 `match` 語句的好例子(Python 3.10+)
            is_mps = sample.device.type == "mps"  # 檢查裝置型別是否為 MPS
            # 根據時間步長型別設定資料型別
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64  # 浮點數型別
            else:
                dtype = torch.int32 if is_mps else torch.int64  # 整數型別
            # 將 timesteps 轉換為張量
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        # 如果 timesteps 是標量(零維張量),則擴充套件維度
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)  # 增加一個維度並轉移到樣本裝置

        # 將 timesteps 廣播到與樣本批次維度相容的方式
        timesteps = timesteps.expand(sample.shape[0])  # 擴充套件到批次大小

        # 透過時間投影獲得時間嵌入
        t_emb = self.time_proj(timesteps)
        # `Timesteps` 不包含任何權重,總是返回 f32 張量
        # 但時間嵌入可能實際在 fp16 中執行,因此需要進行型別轉換。
        # 可能有更好的方法來封裝這一點。
        t_emb = t_emb.to(dtype=sample.dtype)  # 轉換 t_emb 的資料型別
        # 返回時間嵌入
        return t_emb
    # 獲取類嵌入的方法,接受樣本張量和可選的類標籤
        def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
            # 初始化類嵌入為 None
            class_emb = None
            # 檢查類嵌入是否存在
            if self.class_embedding is not None:
                # 如果類標籤為 None,丟擲錯誤
                if class_labels is None:
                    raise ValueError("class_labels should be provided when num_class_embeds > 0")
    
                # 檢查類嵌入型別是否為時間步
                if self.config.class_embed_type == "timestep":
                    # 將類標籤透過時間投影處理
                    class_labels = self.time_proj(class_labels)
    
                    # `Timesteps` 不包含權重,總是返回 f32 張量
                    # 可能有更好的方式來封裝這一點
                    class_labels = class_labels.to(dtype=sample.dtype)
    
                # 獲取類嵌入並轉換為與樣本相同的資料型別
                class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
            # 返回類嵌入
            return class_emb
    
        # 獲取增強嵌入的方法,接受嵌入張量、編碼器隱藏狀態和額外條件引數
        def get_aug_embed(
            self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
        # 處理編碼器隱藏狀態的方法,接受編碼器隱藏狀態和額外條件引數
        def process_encoder_hidden_states(
            self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
    # 定義返回型別為 torch.Tensor
        ) -> torch.Tensor:
            # 檢查是否存在隱藏層投影,並且配置為 "text_proj"
            if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
                # 使用文字投影對編碼隱藏狀態進行轉換
                encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
            # 檢查是否存在隱藏層投影,並且配置為 "text_image_proj"
            elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
                # 檢查條件中是否包含 "image_embeds"
                if "image_embeds" not in added_cond_kwargs:
                    # 丟擲錯誤提示缺少必要引數
                    raise ValueError(
                        f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
                    )
    
                # 獲取傳入的影像嵌入
                image_embeds = added_cond_kwargs.get("image_embeds")
                # 對編碼隱藏狀態和影像嵌入進行投影轉換
                encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
            # 檢查是否存在隱藏層投影,並且配置為 "image_proj"
            elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
                # 檢查條件中是否包含 "image_embeds"
                if "image_embeds" not in added_cond_kwargs:
                    # 丟擲錯誤提示缺少必要引數
                    raise ValueError(
                        f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
                    )
                # 獲取傳入的影像嵌入
                image_embeds = added_cond_kwargs.get("image_embeds")
                # 使用影像嵌入對編碼隱藏狀態進行投影轉換
                encoder_hidden_states = self.encoder_hid_proj(image_embeds)
            # 檢查是否存在隱藏層投影,並且配置為 "ip_image_proj"
            elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
                # 檢查條件中是否包含 "image_embeds"
                if "image_embeds" not in added_cond_kwargs:
                    # 丟擲錯誤提示缺少必要引數
                    raise ValueError(
                        f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
                    )
    
                # 如果存在文字編碼器的隱藏層投影,則對編碼隱藏狀態進行投影轉換
                if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
                    encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
    
                # 獲取傳入的影像嵌入
                image_embeds = added_cond_kwargs.get("image_embeds")
                # 對影像嵌入進行投影轉換
                image_embeds = self.encoder_hid_proj(image_embeds)
                # 將編碼隱藏狀態和影像嵌入打包成元組
                encoder_hidden_states = (encoder_hidden_states, image_embeds)
            # 返回最終的編碼隱藏狀態
            return encoder_hidden_states
    # 定義前向傳播函式
    def forward(
            # 輸入的樣本資料,型別為 PyTorch 張量
            sample: torch.Tensor,
            # 當前時間步,型別可以是張量、浮點數或整數
            timestep: Union[torch.Tensor, float, int],
            # 編碼器的隱藏狀態,型別為 PyTorch 張量
            encoder_hidden_states: torch.Tensor,
            # 可選的類別標籤,型別為 PyTorch 張量
            class_labels: Optional[torch.Tensor] = None,
            # 可選的時間步條件,型別為 PyTorch 張量
            timestep_cond: Optional[torch.Tensor] = None,
            # 可選的注意力掩碼,型別為 PyTorch 張量
            attention_mask: Optional[torch.Tensor] = None,
            # 可選的交叉注意力引數,型別為字典,包含額外的關鍵字引數
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 可選的附加條件引數,型別為字典,鍵為字串,值為 PyTorch 張量
            added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
            # 可選的下層塊附加殘差,型別為元組,包含 PyTorch 張量
            down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
            # 可選的中間塊附加殘差,型別為 PyTorch 張量
            mid_block_additional_residual: Optional[torch.Tensor] = None,
            # 可選的下層內部塊附加殘差,型別為元組,包含 PyTorch 張量
            down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
            # 可選的編碼器注意力掩碼,型別為 PyTorch 張量
            encoder_attention_mask: Optional[torch.Tensor] = None,
            # 返回結果的標誌,布林值,預設值為 True
            return_dict: bool = True,

.\diffusers\models\unets\unet_2d_condition_flax.py

# 版權宣告,表明該檔案的版權所有者及相關資訊
# 
# 根據 Apache License 2.0 版本的許可協議
# 除非遵守該許可協議,否則不得使用本檔案
# 可以在以下地址獲取許可證副本
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非適用法律或書面協議另有約定,軟體在“按現狀”基礎上分發,
# 不提供任何明示或暗示的保證或條件
# 請參閱許可證瞭解管理許可權和限制的具體條款
from typing import Dict, Optional, Tuple, Union  # 從 typing 模組匯入型別註釋工具

import flax  # 匯入 flax 庫用於構建神經網路
import flax.linen as nn  # 從 flax 中匯入 linen 模組,方便定義神經網路層
import jax  # 匯入 jax 庫用於高效數值計算
import jax.numpy as jnp  # 匯入 jax 的 numpy 模組,提供張量操作功能
from flax.core.frozen_dict import FrozenDict  # 從 flax 匯入 FrozenDict,用於不可變字典

from ...configuration_utils import ConfigMixin, flax_register_to_config  # 匯入配置相關工具
from ...utils import BaseOutput  # 匯入基礎輸出類
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps  # 匯入時間步嵌入相關類
from ..modeling_flax_utils import FlaxModelMixin  # 匯入模型混合類
from .unet_2d_blocks_flax import (  # 匯入 UNet 的不同構建塊
    FlaxCrossAttnDownBlock2D,  # 匯入交叉注意力下采樣塊
    FlaxCrossAttnUpBlock2D,  # 匯入交叉注意力上取樣塊
    FlaxDownBlock2D,  # 匯入下采樣塊
    FlaxUNetMidBlock2DCrossAttn,  # 匯入中間塊,帶有交叉注意力
    FlaxUpBlock2D,  # 匯入上取樣塊
)


@flax.struct.dataclass  # 使用 flax 的資料類裝飾器
class FlaxUNet2DConditionOutput(BaseOutput):  # 定義 UNet 條件輸出類,繼承自基礎輸出類
    """
    [`FlaxUNet2DConditionModel`] 的輸出。

    引數:
        sample (`jnp.ndarray` 的形狀為 `(batch_size, num_channels, height, width)`):
            基於 `encoder_hidden_states` 輸入的隱藏狀態輸出。模型最後一層的輸出。
    """

    sample: jnp.ndarray  # 定義輸出樣本,資料型別為 jnp.ndarray


@flax_register_to_config  # 使用裝飾器將模型註冊到配置中
class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):  # 定義條件 UNet 模型類,繼承多個混合類
    r"""
    一個條件 2D UNet 模型,接收噪聲樣本、條件狀態和時間步,並返回樣本形狀的輸出。

    此模型繼承自 [`FlaxModelMixin`]。請檢視超類文件以瞭解其通用方法
    (例如下載或儲存)。

    此模型也是 Flax Linen 的 [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    子類。將其作為常規 Flax Linen 模組使用,具體使用和行為請參閱 Flax 文件。

    支援以下 JAX 特性:
    - [即時編譯 (JIT)](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [自動微分](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [向量化](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [並行化](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
    # 引數說明部分
    Parameters:
        # 輸入樣本的大小,型別為整型,選填引數
        sample_size (`int`, *optional*):
            The size of the input sample.
        # 輸入樣本的通道數,型別為整型,預設為4
        in_channels (`int`, *optional*, defaults to 4):
            The number of channels in the input sample.
        # 輸出的通道數,型別為整型,預設為4
        out_channels (`int`, *optional*, defaults to 4):
            The number of channels in the output.
        # 使用的下采樣塊的元組,型別為字串元組,預設為特定的下采樣塊
        down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
            The tuple of downsample blocks to use.
        # 使用的上取樣塊的元組,型別為字串元組,預設為特定的上取樣塊
        up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use.
        # UNet中間塊的型別,型別為字串,預設為"UNetMidBlock2DCrossAttn"
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer
            is skipped.
        # 每個塊的輸出通道的元組,型別為整型元組,預設為特定的輸出通道
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        # 每個塊的層數,型別為整型,預設為2
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        # 注意力頭的維度,可以是整型或整型元組,預設為8
        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
            The dimension of the attention heads.
        # 注意力頭的數量,可以是整型或整型元組,選填引數
        num_attention_heads (`int` or `Tuple[int]`, *optional*):
            The number of attention heads.
        # 交叉注意力特徵的維度,型別為整型,預設為768
        cross_attention_dim (`int`, *optional*, defaults to 768):
            The dimension of the cross attention features.
        # dropout的機率,型別為浮點數,預設為0
        dropout (`float`, *optional*, defaults to 0):
            Dropout probability for down, up and bottleneck blocks.
        # 是否在時間嵌入中將正弦轉換為餘弦,型別為布林值,預設為True
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        # 應用於時間嵌入的頻率偏移,型別為整型,預設為0
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        # 是否啟用記憶體高效的注意力機制,型別為布林值,預設為False
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
        # 是否將頭維度拆分為新的軸進行自注意力計算,型別為布林值,預設為False
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
    """
    
    # 定義樣本大小,預設為32
    sample_size: int = 32
    # 定義輸入通道數,預設為4
    in_channels: int = 4
    # 定義輸出通道數,預設為4
    out_channels: int = 4
    # 定義下采樣塊的型別元組
    down_block_types: Tuple[str, ...] = (
        "CrossAttnDownBlock2D",  # 第一個下采樣塊
        "CrossAttnDownBlock2D",  # 第二個下采樣塊
        "CrossAttnDownBlock2D",  # 第三個下采樣塊
        "DownBlock2D",           # 第四個下采樣塊
    )
    # 定義上取樣塊的型別元組
    up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
    # 定義中間塊型別,預設為"UNetMidBlock2DCrossAttn"
    mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn"
    # 定義是否只使用交叉注意力,預設為False
    only_cross_attention: Union[bool, Tuple[bool]] = False
    # 定義每個塊的輸出通道元組
    block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
    # 每個塊的層數設為 2
    layers_per_block: int = 2
    # 注意力頭的維度設為 8
    attention_head_dim: Union[int, Tuple[int, ...]] = 8
    # 可選的注意力頭數量,預設為 None
    num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
    # 跨注意力的維度設為 1280
    cross_attention_dim: int = 1280
    # dropout 比率設為 0.0
    dropout: float = 0.0
    # 是否使用線性投影,預設為 False
    use_linear_projection: bool = False
    # 資料型別設為 float32
    dtype: jnp.dtype = jnp.float32
    # flip_sin_to_cos 設為 True
    flip_sin_to_cos: bool = True
    # 頻移設為 0
    freq_shift: int = 0
    # 是否使用記憶體高效的注意力,預設為 False
    use_memory_efficient_attention: bool = False
    # 是否拆分頭維度,預設為 False
    split_head_dim: bool = False
    # 每個塊的變換層數設為 1
    transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1
    # 可選的附加嵌入型別,預設為 None
    addition_embed_type: Optional[str] = None
    # 可選的附加時間嵌入維度,預設為 None
    addition_time_embed_dim: Optional[int] = None
    # 附加嵌入型別的頭數量設為 64
    addition_embed_type_num_heads: int = 64
    # 可選的投影類嵌入輸入維度,預設為 None
    projection_class_embeddings_input_dim: Optional[int] = None

    # 初始化權重函式,接受隨機數生成器作為引數
    def init_weights(self, rng: jax.Array) -> FrozenDict:
        # 初始化輸入張量的形狀
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        # 建立全零的輸入樣本
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
        # 建立全一的時間步張量
        timesteps = jnp.ones((1,), dtype=jnp.int32)
        # 初始化編碼器的隱藏狀態
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)

        # 分割隨機數生成器,用於引數和 dropout
        params_rng, dropout_rng = jax.random.split(rng)
        # 建立隨機數字典
        rngs = {"params": params_rng, "dropout": dropout_rng}

        # 初始化附加條件關鍵字引數
        added_cond_kwargs = None
        # 判斷嵌入型別是否為 "text_time"
        if self.addition_embed_type == "text_time":
            # 透過反向計算獲取期望的文字嵌入維度
            is_refiner = (
                5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim
                == self.config.projection_class_embeddings_input_dim
            )
            # 確定微條件的數量
            num_micro_conditions = 5 if is_refiner else 6

            # 計算文字嵌入維度
            text_embeds_dim = self.config.projection_class_embeddings_input_dim - (
                num_micro_conditions * self.config.addition_time_embed_dim
            )

            # 計算時間 ID 的通道數和維度
            time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
            time_ids_dims = time_ids_channels // self.addition_time_embed_dim
            # 建立附加條件關鍵字引數字典
            added_cond_kwargs = {
                "text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32),
                "time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32),
            }
        # 返回初始化後的引數字典
        return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]

    # 定義呼叫函式,接收多個輸入引數
    def __call__(
        self,
        sample: jnp.ndarray,
        timesteps: Union[jnp.ndarray, float, int],
        encoder_hidden_states: jnp.ndarray,
        # 可選的附加條件關鍵字引數
        added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
        # 可選的下塊附加殘差
        down_block_additional_residuals: Optional[Tuple[jnp.ndarray, ...]] = None,
        # 可選的中塊附加殘差
        mid_block_additional_residual: Optional[jnp.ndarray] = None,
        # 是否返回字典,預設為 True
        return_dict: bool = True,
        # 是否為訓練模式,預設為 False
        train: bool = False,

.\diffusers\models\unets\unet_3d_blocks.py

# 版權所有 2024 HuggingFace 團隊。所有權利保留。
#
# 根據 Apache 許可證 2.0 版本("許可證")授權;
# 除非遵守許可證,否則您不得使用此檔案。
# 您可以在以下位置獲得許可證副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律或書面同意,否則根據許可證分發的軟體是按“原樣”基礎分發,
# 不提供任何形式的保證或條件,無論是明示或暗示。
# 有關許可證的具體許可權和限制,請參閱許可證。

# 匯入型別提示中的任何型別
from typing import Any, Dict, Optional, Tuple, Union

# 匯入 PyTorch 庫
import torch
# 從 PyTorch 匯入神經網路模組
from torch import nn

# 匯入實用工具函式,包括棄用和日誌記錄
from ...utils import deprecate, is_torch_version, logging
# 匯入 PyTorch 相關的工具函式
from ...utils.torch_utils import apply_freeu
# 匯入注意力機制相關的類
from ..attention import Attention
# 匯入 ResNet 相關的類
from ..resnet import (
    Downsample2D,  # 匯入 2D 下采樣模組
    ResnetBlock2D,  # 匯入 2D ResNet 塊
    SpatioTemporalResBlock,  # 匯入時空 ResNet 塊
    TemporalConvLayer,  # 匯入時間卷積層
    Upsample2D,  # 匯入 2D 上取樣模組
)
# 匯入 2D 變換器模型
from ..transformers.transformer_2d import Transformer2DModel
# 匯入時間相關的變換器模型
from ..transformers.transformer_temporal import (
    TransformerSpatioTemporalModel,  # 匯入時空變換器模型
    TransformerTemporalModel,  # 匯入時間變換器模型
)
# 匯入運動模型的 UNet 相關類
from .unet_motion_model import (
    CrossAttnDownBlockMotion,  # 匯入交叉注意力下塊運動類
    CrossAttnUpBlockMotion,  # 匯入交叉注意力上塊運動類
    DownBlockMotion,  # 匯入下塊運動類
    UNetMidBlockCrossAttnMotion,  # 匯入中間塊交叉注意力運動類
    UpBlockMotion,  # 匯入上塊運動類
)

# 建立一個日誌記錄器,使用當前模組的名稱
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 定義 DownBlockMotion 類,繼承自 DownBlockMotion
class DownBlockMotion(DownBlockMotion):
    # 初始化方法,接受任意引數和關鍵字引數
    def __init__(self, *args, **kwargs):
        # 設定棄用訊息,提醒使用者變更
        deprecation_message = "Importing `DownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import DownBlockMotion` instead."
        # 呼叫棄用函式,記錄棄用資訊
        deprecate("DownBlockMotion", "1.0.0", deprecation_message)
        # 呼叫父類的初始化方法
        super().__init__(*args, **kwargs)

# 定義 CrossAttnDownBlockMotion 類,繼承自 CrossAttnDownBlockMotion
class CrossAttnDownBlockMotion(CrossAttnDownBlockMotion):
    # 初始化方法,接受任意引數和關鍵字引數
    def __init__(self, *args, **kwargs):
        # 設定棄用訊息,提醒使用者變更
        deprecation_message = "Importing `CrossAttnDownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnDownBlockMotion` instead."
        # 呼叫棄用函式,記錄棄用資訊
        deprecate("CrossAttnDownBlockMotion", "1.0.0", deprecation_message)
        # 呼叫父類的初始化方法
        super().__init__(*args, **kwargs)

# 定義 UpBlockMotion 類,繼承自 UpBlockMotion
class UpBlockMotion(UpBlockMotion):
    # 初始化方法,接受任意引數和關鍵字引數
    def __init__(self, *args, **kwargs):
        # 設定棄用訊息,提醒使用者變更
        deprecation_message = "Importing `UpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UpBlockMotion` instead."
        # 呼叫棄用函式,記錄棄用資訊
        deprecate("UpBlockMotion", "1.0.0", deprecation_message)
        # 呼叫父類的初始化方法
        super().__init__(*args, **kwargs)

# 定義 CrossAttnUpBlockMotion 類,繼承自 CrossAttnUpBlockMotion
class CrossAttnUpBlockMotion(CrossAttnUpBlockMotion):
    # 初始化方法,用於建立類的例項
        def __init__(self, *args, **kwargs):
            # 定義一個關於匯入的棄用警告資訊
            deprecation_message = "Importing `CrossAttnUpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnUpBlockMotion` instead."
            # 呼叫棄用警告函式,記錄該功能的棄用資訊及版本
            deprecate("CrossAttnUpBlockMotion", "1.0.0", deprecation_message)
            # 呼叫父類的初始化方法,傳遞引數以初始化父類部分
            super().__init__(*args, **kwargs)
# 定義一個名為 UNetMidBlockCrossAttnMotion 的類,繼承自同名父類
class UNetMidBlockCrossAttnMotion(UNetMidBlockCrossAttnMotion):
    # 初始化方法,接收可變引數和關鍵字引數
    def __init__(self, *args, **kwargs):
        # 定義棄用警告資訊,提示使用者更新匯入路徑
        deprecation_message = "Importing `UNetMidBlockCrossAttnMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UNetMidBlockCrossAttnMotion` instead."
        # 觸發棄用警告
        deprecate("UNetMidBlockCrossAttnMotion", "1.0.0", deprecation_message)
        # 呼叫父類的初始化方法
        super().__init__(*args, **kwargs)


# 定義一個函式,返回不同型別的下采樣塊
def get_down_block(
    # 定義引數,型別和含義
    down_block_type: str,  # 下采樣塊的型別
    num_layers: int,  # 層數
    in_channels: int,  # 輸入通道數
    out_channels: int,  # 輸出通道數
    temb_channels: int,  # 時間嵌入通道數
    add_downsample: bool,  # 是否新增下采樣
    resnet_eps: float,  # ResNet 的 epsilon 引數
    resnet_act_fn: str,  # ResNet 的啟用函式
    num_attention_heads: int,  # 注意力頭數
    resnet_groups: Optional[int] = None,  # ResNet 的分組數,可選
    cross_attention_dim: Optional[int] = None,  # 交叉注意力維度,可選
    downsample_padding: Optional[int] = None,  # 下采樣填充,可選
    dual_cross_attention: bool = False,  # 是否使用雙重交叉注意力
    use_linear_projection: bool = True,  # 是否使用線性投影
    only_cross_attention: bool = False,  # 是否僅使用交叉注意力
    upcast_attention: bool = False,  # 是否提升注意力精度
    resnet_time_scale_shift: str = "default",  # ResNet 時間尺度偏移
    temporal_num_attention_heads: int = 8,  # 時間注意力頭數
    temporal_max_seq_length: int = 32,  # 時間序列最大長度
    transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 每個塊的變換器層數
    temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 時間變換器每塊層數
    dropout: float = 0.0,  # dropout 機率
) -> Union[
    "DownBlock3D",  # 返回的可能型別之一:3D 下采樣塊
    "CrossAttnDownBlock3D",  # 返回的可能型別之二:交叉注意力下采樣塊
    "DownBlockSpatioTemporal",  # 返回的可能型別之三:時空下采樣塊
    "CrossAttnDownBlockSpatioTemporal",  # 返回的可能型別之四:時空交叉注意力下采樣塊
]:
    # 檢查下采樣塊型別是否為 DownBlock3D
    if down_block_type == "DownBlock3D":
        # 建立並返回 DownBlock3D 例項
        return DownBlock3D(
            num_layers=num_layers,  # 傳入層數
            in_channels=in_channels,  # 傳入輸入通道數
            out_channels=out_channels,  # 傳入輸出通道數
            temb_channels=temb_channels,  # 傳入時間嵌入通道數
            add_downsample=add_downsample,  # 傳入是否新增下采樣
            resnet_eps=resnet_eps,  # 傳入 ResNet 的 epsilon 引數
            resnet_act_fn=resnet_act_fn,  # 傳入啟用函式
            resnet_groups=resnet_groups,  # 傳入分組數
            downsample_padding=downsample_padding,  # 傳入下采樣填充
            resnet_time_scale_shift=resnet_time_scale_shift,  # 傳入時間尺度偏移
            dropout=dropout,  # 傳入 dropout 機率
        )
    # 檢查下采樣塊型別是否為 CrossAttnDownBlock3D
    elif down_block_type == "CrossAttnDownBlock3D":
        # 如果交叉注意力維度未指定,丟擲錯誤
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
        # 建立並返回 CrossAttnDownBlock3D 例項
        return CrossAttnDownBlock3D(
            num_layers=num_layers,  # 傳入層數
            in_channels=in_channels,  # 傳入輸入通道數
            out_channels=out_channels,  # 傳入輸出通道數
            temb_channels=temb_channels,  # 傳入時間嵌入通道數
            add_downsample=add_downsample,  # 傳入是否新增下采樣
            resnet_eps=resnet_eps,  # 傳入 ResNet 的 epsilon 引數
            resnet_act_fn=resnet_act_fn,  # 傳入啟用函式
            resnet_groups=resnet_groups,  # 傳入分組數
            downsample_padding=downsample_padding,  # 傳入下采樣填充
            cross_attention_dim=cross_attention_dim,  # 傳入交叉注意力維度
            num_attention_heads=num_attention_heads,  # 傳入注意力頭數
            dual_cross_attention=dual_cross_attention,  # 傳入是否使用雙重交叉注意力
            use_linear_projection=use_linear_projection,  # 傳入是否使用線性投影
            only_cross_attention=only_cross_attention,  # 傳入是否僅使用交叉注意力
            upcast_attention=upcast_attention,  # 傳入是否提升注意力精度
            resnet_time_scale_shift=resnet_time_scale_shift,  # 傳入時間尺度偏移
            dropout=dropout,  # 傳入 dropout 機率
        )
    # 檢查下一個塊的型別是否為時空下采樣塊
    elif down_block_type == "DownBlockSpatioTemporal":
        # 為 SDV 進行了新增
        # 返回一個時空下采樣塊的例項
        return DownBlockSpatioTemporal(
            # 設定層數
            num_layers=num_layers,
            # 輸入通道數
            in_channels=in_channels,
            # 輸出通道數
            out_channels=out_channels,
            # 時間嵌入通道數
            temb_channels=temb_channels,
            # 是否新增下采樣
            add_downsample=add_downsample,
        )
    # 檢查下一個塊的型別是否為交叉注意力時空下采樣塊
    elif down_block_type == "CrossAttnDownBlockSpatioTemporal":
        # 為 SDV 進行了新增
        # 如果沒有指定交叉注意力維度,丟擲錯誤
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal")
        # 返回一個交叉注意力時空下采樣塊的例項
        return CrossAttnDownBlockSpatioTemporal(
            # 輸入通道數
            in_channels=in_channels,
            # 輸出通道數
            out_channels=out_channels,
            # 時間嵌入通道數
            temb_channels=temb_channels,
            # 設定層數
            num_layers=num_layers,
            # 每個塊的變換層數
            transformer_layers_per_block=transformer_layers_per_block,
            # 是否新增下采樣
            add_downsample=add_downsample,
            # 設定交叉注意力維度
            cross_attention_dim=cross_attention_dim,
            # 注意力頭數
            num_attention_heads=num_attention_heads,
        )

    # 如果塊型別不匹配,則丟擲錯誤
    raise ValueError(f"{down_block_type} does not exist.")
# 定義函式 get_up_block,返回不同型別的上取樣模組
def get_up_block(
    # 上取樣塊型別
    up_block_type: str,
    # 層數
    num_layers: int,
    # 輸入通道數
    in_channels: int,
    # 輸出通道數
    out_channels: int,
    # 上一層的輸出通道數
    prev_output_channel: int,
    # 時間嵌入通道數
    temb_channels: int,
    # 是否新增上取樣
    add_upsample: bool,
    # ResNet 的 epsilon 值
    resnet_eps: float,
    # ResNet 的啟用函式型別
    resnet_act_fn: str,
    # 注意力頭的數量
    num_attention_heads: int,
    # 解析度索引(可選)
    resolution_idx: Optional[int] = None,
    # ResNet 組數(可選)
    resnet_groups: Optional[int] = None,
    # 跨注意力維度(可選)
    cross_attention_dim: Optional[int] = None,
    # 是否使用雙重跨注意力
    dual_cross_attention: bool = False,
    # 是否使用線性投影
    use_linear_projection: bool = True,
    # 是否僅使用跨注意力
    only_cross_attention: bool = False,
    # 是否提升注意力計算
    upcast_attention: bool = False,
    # ResNet 時間尺度移位的設定
    resnet_time_scale_shift: str = "default",
    # 時間上的注意力頭數量
    temporal_num_attention_heads: int = 8,
    # 時間上的跨注意力維度(可選)
    temporal_cross_attention_dim: Optional[int] = None,
    # 時間序列最大長度
    temporal_max_seq_length: int = 32,
    # 每個塊的變換器層數
    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
    # 每個塊的時間變換器層數
    temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
    # dropout 機率
    dropout: float = 0.0,
) -> Union[
    # 返回型別為不同的上取樣塊
    "UpBlock3D",
    "CrossAttnUpBlock3D",
    "UpBlockSpatioTemporal",
    "CrossAttnUpBlockSpatioTemporal",
]:
    # 判斷上取樣塊型別是否為 UpBlock3D
    if up_block_type == "UpBlock3D":
        # 建立並返回 UpBlock3D 例項
        return UpBlock3D(
            # 傳入引數設定
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resolution_idx=resolution_idx,
            dropout=dropout,
        )
    # 判斷上取樣塊型別是否為 CrossAttnUpBlock3D
    elif up_block_type == "CrossAttnUpBlock3D":
        # 檢查是否提供跨注意力維度
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
        # 建立並返回 CrossAttnUpBlock3D 例項
        return CrossAttnUpBlock3D(
            # 傳入引數設定
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resolution_idx=resolution_idx,
            dropout=dropout,
        )
    # 檢查上升塊型別是否為 "UpBlockSpatioTemporal"
    elif up_block_type == "UpBlockSpatioTemporal":
        # 為 SDV 新增的內容
        # 返回 UpBlockSpatioTemporal 例項,使用指定的引數
        return UpBlockSpatioTemporal(
            # 層數引數
            num_layers=num_layers,
            # 輸入通道數
            in_channels=in_channels,
            # 輸出通道數
            out_channels=out_channels,
            # 前一個輸出通道數
            prev_output_channel=prev_output_channel,
            # 時間嵌入通道數
            temb_channels=temb_channels,
            # 解析度索引
            resolution_idx=resolution_idx,
            # 是否新增上取樣
            add_upsample=add_upsample,
        )
    # 檢查上升塊型別是否為 "CrossAttnUpBlockSpatioTemporal"
    elif up_block_type == "CrossAttnUpBlockSpatioTemporal":
        # 為 SDV 新增的內容
        # 如果沒有指定交叉注意力維度,丟擲錯誤
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal")
        # 返回 CrossAttnUpBlockSpatioTemporal 例項,使用指定的引數
        return CrossAttnUpBlockSpatioTemporal(
            # 輸入通道數
            in_channels=in_channels,
            # 輸出通道數
            out_channels=out_channels,
            # 前一個輸出通道數
            prev_output_channel=prev_output_channel,
            # 時間嵌入通道數
            temb_channels=temb_channels,
            # 層數引數
            num_layers=num_layers,
            # 每個塊的變換層數
            transformer_layers_per_block=transformer_layers_per_block,
            # 是否新增上取樣
            add_upsample=add_upsample,
            # 交叉注意力維度
            cross_attention_dim=cross_attention_dim,
            # 注意力頭數
            num_attention_heads=num_attention_heads,
            # 解析度索引
            resolution_idx=resolution_idx,
        )

    # 如果上升塊型別不符合任何已知型別,丟擲錯誤
    raise ValueError(f"{up_block_type} does not exist.")
# 定義一個名為 UNetMidBlock3DCrossAttn 的類,繼承自 nn.Module
class UNetMidBlock3DCrossAttn(nn.Module):
    # 初始化方法,設定類的引數
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        temb_channels: int,  # 時間嵌入通道數
        dropout: float = 0.0,  # dropout 機率
        num_layers: int = 1,  # 層數
        resnet_eps: float = 1e-6,  # ResNet 中的小常數,避免除零
        resnet_time_scale_shift: str = "default",  # ResNet 時間縮放偏移
        resnet_act_fn: str = "swish",  # ResNet 啟用函式型別
        resnet_groups: int = 32,  # ResNet 分組數
        resnet_pre_norm: bool = True,  # 是否在 ResNet 前進行歸一化
        num_attention_heads: int = 1,  # 注意力頭數
        output_scale_factor: float = 1.0,  # 輸出縮放因子
        cross_attention_dim: int = 1280,  # 交叉注意力維度
        dual_cross_attention: bool = False,  # 是否使用雙交叉注意力
        use_linear_projection: bool = True,  # 是否使用線性投影
        upcast_attention: bool = False,  # 是否使用上取樣注意力
    ):
        # 省略具體初始化程式碼,通常會在這裡初始化各個層和引數

    # 定義前向傳播方法
    def forward(
        self,
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態張量
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 可選的編碼器隱藏狀態
        attention_mask: Optional[torch.Tensor] = None,  # 可選的注意力掩碼
        num_frames: int = 1,  # 幀數
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 交叉注意力的額外引數
    ) -> torch.Tensor:  # 返回一個張量
        # 透過第一個 ResNet 層處理隱藏狀態
        hidden_states = self.resnets[0](hidden_states, temb)
        # 透過第一個時間卷積層處理隱藏狀態
        hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
        # 遍歷所有的注意力層、時間注意力層、ResNet 層和時間卷積層
        for attn, temp_attn, resnet, temp_conv in zip(
            self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
        ):
            # 透過當前注意力層處理隱藏狀態
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]  # 只取返回的第一個元素
            # 透過當前時間注意力層處理隱藏狀態
            hidden_states = temp_attn(
                hidden_states,
                num_frames=num_frames,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]  # 只取返回的第一個元素
            # 透過當前 ResNet 層處理隱藏狀態
            hidden_states = resnet(hidden_states, temb)
            # 透過當前時間卷積層處理隱藏狀態
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)

        # 返回最終的隱藏狀態
        return hidden_states


# 定義一個名為 CrossAttnDownBlock3D 的類,繼承自 nn.Module
class CrossAttnDownBlock3D(nn.Module):
    # 初始化方法,設定類的引數
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        out_channels: int,  # 輸出通道數
        temb_channels: int,  # 時間嵌入通道數
        dropout: float = 0.0,  # dropout 機率
        num_layers: int = 1,  # 層數
        resnet_eps: float = 1e-6,  # ResNet 中的小常數,避免除零
        resnet_time_scale_shift: str = "default",  # ResNet 時間縮放偏移
        resnet_act_fn: str = "swish",  # ResNet 啟用函式型別
        resnet_groups: int = 32,  # ResNet 分組數
        resnet_pre_norm: bool = True,  # 是否在 ResNet 前進行歸一化
        num_attention_heads: int = 1,  # 注意力頭數
        cross_attention_dim: int = 1280,  # 交叉注意力維度
        output_scale_factor: float = 1.0,  # 輸出縮放因子
        downsample_padding: int = 1,  # 下采樣的填充大小
        add_downsample: bool = True,  # 是否新增下采樣層
        dual_cross_attention: bool = False,  # 是否使用雙交叉注意力
        use_linear_projection: bool = False,  # 是否使用線性投影
        only_cross_attention: bool = False,  # 是否只使用交叉注意力
        upcast_attention: bool = False,  # 是否使用上取樣注意力
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        # 初始化殘差塊列表
        resnets = []
        # 初始化注意力層列表
        attentions = []
        # 初始化臨時注意力層列表
        temp_attentions = []
        # 初始化臨時卷積層列表
        temp_convs = []

        # 設定是否使用交叉注意力
        self.has_cross_attention = True
        # 設定注意力頭的數量
        self.num_attention_heads = num_attention_heads

        # 根據層數建立各層模組
        for i in range(num_layers):
            # 設定輸入通道數,第一層使用傳入的輸入通道數,後續層使用輸出通道數
            in_channels = in_channels if i == 0 else out_channels
            # 建立殘差塊並新增到列表中
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,  # 輸入通道數
                    out_channels=out_channels,  # 輸出通道數
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    eps=resnet_eps,  # 殘差塊的 epsilon 值
                    groups=resnet_groups,  # 組卷積的組數
                    dropout=dropout,  # Dropout 比例
                    time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入的歸一化方式
                    non_linearity=resnet_act_fn,  # 非線性啟用函式
                    output_scale_factor=output_scale_factor,  # 輸出縮放因子
                    pre_norm=resnet_pre_norm,  # 是否進行預歸一化
                )
            )
            # 建立時間卷積層並新增到列表中
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,  # 輸入通道數
                    out_channels,  # 輸出通道數
                    dropout=0.1,  # Dropout 比例
                    norm_num_groups=resnet_groups,  # 歸一化的組數
                )
            )
            # 建立二維變換器模型並新增到列表中
            attentions.append(
                Transformer2DModel(
                    out_channels // num_attention_heads,  # 每個注意力頭的通道數
                    num_attention_heads,  # 注意力頭的數量
                    in_channels=out_channels,  # 輸入通道數
                    num_layers=1,  # 變換器層數
                    cross_attention_dim=cross_attention_dim,  # 交叉注意力維度
                    norm_num_groups=resnet_groups,  # 歸一化的組數
                    use_linear_projection=use_linear_projection,  # 是否使用線性對映
                    only_cross_attention=only_cross_attention,  # 是否只使用交叉注意力
                    upcast_attention=upcast_attention,  # 是否上溢注意力
                )
            )
            # 建立時間變換器模型並新增到列表中
            temp_attentions.append(
                TransformerTemporalModel(
                    out_channels // num_attention_heads,  # 每個注意力頭的通道數
                    num_attention_heads,  # 注意力頭的數量
                    in_channels=out_channels,  # 輸入通道數
                    num_layers=1,  # 變換器層數
                    cross_attention_dim=cross_attention_dim,  # 交叉注意力維度
                    norm_num_groups=resnet_groups,  # 歸一化的組數
                )
            )
        # 將殘差塊列表轉換為模組列表
        self.resnets = nn.ModuleList(resnets)
        # 將臨時卷積層列表轉換為模組列表
        self.temp_convs = nn.ModuleList(temp_convs)
        # 將注意力層列表轉換為模組列表
        self.attentions = nn.ModuleList(attentions)
        # 將臨時注意力層列表轉換為模組列表
        self.temp_attentions = nn.ModuleList(temp_attentions)

        # 如果需要新增下采樣層
        if add_downsample:
            # 建立下采樣模組列表
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,  # 輸出通道數
                        use_conv=True,  # 是否使用卷積進行下采樣
                        out_channels=out_channels,  # 輸出通道數
                        padding=downsample_padding,  # 下采樣時的填充
                        name="op",  # 模組名稱
                    )
                ]
            )
        else:
            # 如果不需要下采樣,設定為 None
            self.downsamplers = None

        # 初始化梯度檢查點設定為 False
        self.gradient_checkpointing = False
    # 定義前向傳播方法,接受多個輸入引數並返回張量或元組
    def forward(
            self,
            hidden_states: torch.Tensor,
            temb: Optional[torch.Tensor] = None,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            num_frames: int = 1,
            cross_attention_kwargs: Dict[str, Any] = None,
        ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
            # TODO(Patrick, William) - 注意力掩碼未使用
            output_states = ()  # 初始化輸出狀態為元組
    
            # 遍歷所有的殘差網路、臨時卷積、注意力和臨時注意力層
            for resnet, temp_conv, attn, temp_attn in zip(
                self.resnets, self.temp_convs, self.attentions, self.temp_attentions
            ):
                # 使用殘差網路處理隱狀態和時間嵌入
                hidden_states = resnet(hidden_states, temb)
                # 使用臨時卷積處理隱狀態,考慮幀數
                hidden_states = temp_conv(hidden_states, num_frames=num_frames)
                # 使用注意力層處理隱狀態,返回字典設為 False
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]  # 取返回的第一個元素
                # 使用臨時注意力層處理隱狀態,返回字典設為 False
                hidden_states = temp_attn(
                    hidden_states,
                    num_frames=num_frames,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]  # 取返回的第一個元素
    
                # 將當前隱狀態新增到輸出狀態中
                output_states += (hidden_states,)
    
            # 如果存在下采樣器,則逐個應用
            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)  # 應用下采樣器
    
                # 將下采樣後的隱狀態新增到輸出狀態中
                output_states += (hidden_states,)
    
            # 返回最終的隱狀態和所有輸出狀態
            return hidden_states, output_states
# 定義一個 3D 下采樣模組,繼承自 nn.Module
class DownBlock3D(nn.Module):
    # 初始化方法,設定各個引數及模組
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        out_channels: int,  # 輸出通道數
        temb_channels: int,  # 時間嵌入通道數
        dropout: float = 0.0,  # dropout 機率
        num_layers: int = 1,  # 層數
        resnet_eps: float = 1e-6,  # ResNet 中的 epsilon 值
        resnet_time_scale_shift: str = "default",  # 時間尺度偏移方式
        resnet_act_fn: str = "swish",  # ResNet 啟用函式型別
        resnet_groups: int = 32,  # ResNet 中的組數
        resnet_pre_norm: bool = True,  # 是否使用預歸一化
        output_scale_factor: float = 1.0,  # 輸出縮放因子
        add_downsample: bool = True,  # 是否新增下采樣層
        downsample_padding: int = 1,  # 下采樣的填充大小
    ):
        # 呼叫父類建構函式
        super().__init__()
        # 初始化 ResNet 模組列表
        resnets = []
        # 初始化時間卷積層列表
        temp_convs = []

        # 遍歷層數,構建 ResNet 模組和時間卷積層
        for i in range(num_layers):
            # 第一層使用輸入通道數,後續層使用輸出通道數
            in_channels = in_channels if i == 0 else out_channels
            # 新增 ResNet 模組到列表中
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,  # 輸入通道數
                    out_channels=out_channels,  # 輸出通道數
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    eps=resnet_eps,  # ResNet 中的 epsilon 值
                    groups=resnet_groups,  # 組數
                    dropout=dropout,  # dropout 機率
                    time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入歸一化方式
                    non_linearity=resnet_act_fn,  # 啟用函式
                    output_scale_factor=output_scale_factor,  # 輸出縮放因子
                    pre_norm=resnet_pre_norm,  # 是否預歸一化
                )
            )
            # 新增時間卷積層到列表中
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,  # 輸入通道數
                    out_channels,  # 輸出通道數
                    dropout=0.1,  # dropout 機率
                    norm_num_groups=resnet_groups,  # 組數
                )
            )

        # 將 ResNet 模組列表轉換為 nn.ModuleList
        self.resnets = nn.ModuleList(resnets)
        # 將時間卷積層列表轉換為 nn.ModuleList
        self.temp_convs = nn.ModuleList(temp_convs)

        # 如果需要新增下采樣層
        if add_downsample:
            # 初始化下采樣模組列表
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,  # 輸出通道數
                        use_conv=True,  # 使用卷積
                        out_channels=out_channels,  # 輸出通道數
                        padding=downsample_padding,  # 填充大小
                        name="op",  # 模組名稱
                    )
                ]
            )
        else:
            # 不新增下采樣層
            self.downsamplers = None

        # 初始化梯度檢查點開關為 False
        self.gradient_checkpointing = False

    # 前向傳播方法
    def forward(
        self,
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態張量
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
        num_frames: int = 1,  # 幀數
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        # 初始化輸出狀態元組
        output_states = ()

        # 遍歷 ResNet 模組和時間卷積層
        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            # 透過 ResNet 模組處理隱藏狀態
            hidden_states = resnet(hidden_states, temb)
            # 透過時間卷積層處理隱藏狀態
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)

            # 將當前的隱藏狀態新增到輸出狀態元組中
            output_states += (hidden_states,)

        # 如果存在下采樣層
        if self.downsamplers is not None:
            # 遍歷下采樣層
            for downsampler in self.downsamplers:
                # 透過下采樣層處理隱藏狀態
                hidden_states = downsampler(hidden_states)

            # 將當前的隱藏狀態新增到輸出狀態元組中
            output_states += (hidden_states,)

        # 返回最終的隱藏狀態和輸出狀態元組
        return hidden_states, output_states


# 定義一個 3D 交叉注意力上取樣模組,繼承自 nn.Module
class CrossAttnUpBlock3D(nn.Module):
    # 初始化方法,用於設定類的屬性
        def __init__(
            # 輸入通道數
            self,
            in_channels: int,
            # 輸出通道數
            out_channels: int,
            # 前一個輸出通道數
            prev_output_channel: int,
            # 時間嵌入通道數
            temb_channels: int,
            # dropout比率,預設為0.0
            dropout: float = 0.0,
            # 網路層數,預設為1
            num_layers: int = 1,
            # ResNet的epsilon值,預設為1e-6
            resnet_eps: float = 1e-6,
            # ResNet的時間尺度偏移,預設為"default"
            resnet_time_scale_shift: str = "default",
            # ResNet啟用函式,預設為"swish"
            resnet_act_fn: str = "swish",
            # ResNet的組數,預設為32
            resnet_groups: int = 32,
            # ResNet是否使用預歸一化,預設為True
            resnet_pre_norm: bool = True,
            # 注意力頭的數量,預設為1
            num_attention_heads: int = 1,
            # 交叉注意力的維度,預設為1280
            cross_attention_dim: int = 1280,
            # 輸出縮放因子,預設為1.0
            output_scale_factor: float = 1.0,
            # 是否新增上取樣,預設為True
            add_upsample: bool = True,
            # 雙重交叉注意力的開關,預設為False
            dual_cross_attention: bool = False,
            # 是否使用線性投影,預設為False
            use_linear_projection: bool = False,
            # 僅使用交叉注意力的開關,預設為False
            only_cross_attention: bool = False,
            # 是否上取樣注意力,預設為False
            upcast_attention: bool = False,
            # 解析度索引,預設為None
            resolution_idx: Optional[int] = None,
    ):
        # 呼叫父類的建構函式進行初始化
        super().__init__()
        # 初始化用於儲存 ResNet 塊的列表
        resnets = []
        # 初始化用於儲存時間卷積層的列表
        temp_convs = []
        # 初始化用於儲存注意力模型的列表
        attentions = []
        # 初始化用於儲存時間注意力模型的列表
        temp_attentions = []

        # 設定是否使用交叉注意力
        self.has_cross_attention = True
        # 設定注意力頭的數量
        self.num_attention_heads = num_attention_heads

        # 遍歷每一層以構建網路結構
        for i in range(num_layers):
            # 確定殘差跳過通道數
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # 確定 ResNet 的輸入通道數
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            # 新增一個 ResNet 塊到列表中
            resnets.append(
                ResnetBlock2D(
                    # 設定輸入通道數,包括殘差跳過通道
                    in_channels=resnet_in_channels + res_skip_channels,
                    # 設定輸出通道數
                    out_channels=out_channels,
                    # 設定時間嵌入通道數
                    temb_channels=temb_channels,
                    # 設定 epsilon 值
                    eps=resnet_eps,
                    # 設定組數
                    groups=resnet_groups,
                    # 設定 dropout 機率
                    dropout=dropout,
                    # 設定時間嵌入歸一化方式
                    time_embedding_norm=resnet_time_scale_shift,
                    # 設定非線性啟用函式
                    non_linearity=resnet_act_fn,
                    # 設定輸出縮放因子
                    output_scale_factor=output_scale_factor,
                    # 設定是否在預歸一化
                    pre_norm=resnet_pre_norm,
                )
            )
            # 新增一個時間卷積層到列表中
            temp_convs.append(
                TemporalConvLayer(
                    # 設定輸出通道數
                    out_channels,
                    # 設定輸入通道數
                    out_channels,
                    # 設定 dropout 機率
                    dropout=0.1,
                    # 設定組數
                    norm_num_groups=resnet_groups,
                )
            )
            # 新增一個 2D 轉換器模型到列表中
            attentions.append(
                Transformer2DModel(
                    # 設定每個注意力頭的通道數
                    out_channels // num_attention_heads,
                    # 設定注意力頭的數量
                    num_attention_heads,
                    # 設定輸入通道數
                    in_channels=out_channels,
                    # 設定層數
                    num_layers=1,
                    # 設定交叉注意力維度
                    cross_attention_dim=cross_attention_dim,
                    # 設定組數
                    norm_num_groups=resnet_groups,
                    # 是否使用線性投影
                    use_linear_projection=use_linear_projection,
                    # 是否僅使用交叉注意力
                    only_cross_attention=only_cross_attention,
                    # 是否提升注意力精度
                    upcast_attention=upcast_attention,
                )
            )
            # 新增一個時間轉換器模型到列表中
            temp_attentions.append(
                TransformerTemporalModel(
                    # 設定每個注意力頭的通道數
                    out_channels // num_attention_heads,
                    # 設定注意力頭的數量
                    num_attention_heads,
                    # 設定輸入通道數
                    in_channels=out_channels,
                    # 設定層數
                    num_layers=1,
                    # 設定交叉注意力維度
                    cross_attention_dim=cross_attention_dim,
                    # 設定組數
                    norm_num_groups=resnet_groups,
                )
            )
        # 將 ResNet 塊列表轉換為 nn.ModuleList
        self.resnets = nn.ModuleList(resnets)
        # 將時間卷積層列表轉換為 nn.ModuleList
        self.temp_convs = nn.ModuleList(temp_convs)
        # 將注意力模型列表轉換為 nn.ModuleList
        self.attentions = nn.ModuleList(attentions)
        # 將時間注意力模型列表轉換為 nn.ModuleList
        self.temp_attentions = nn.ModuleList(temp_attentions)

        # 如果需要上取樣,則初始化上取樣層
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            # 否則將上取樣層設定為 None
            self.upsamplers = None

        # 設定梯度檢查點標誌
        self.gradient_checkpointing = False
        # 設定解析度索引
        self.resolution_idx = resolution_idx
    # 定義前向傳播函式,接受多個引數並返回張量
        def forward(
            self,
            hidden_states: torch.Tensor,  # 當前層的隱藏狀態
            res_hidden_states_tuple: Tuple[torch.Tensor, ...],  # 以前層的隱藏狀態元組
            temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
            encoder_hidden_states: Optional[torch.Tensor] = None,  # 可選的編碼器隱藏狀態
            upsample_size: Optional[int] = None,  # 可選的上取樣大小
            attention_mask: Optional[torch.Tensor] = None,  # 可選的注意力掩碼
            num_frames: int = 1,  # 幀數,預設值為1
            cross_attention_kwargs: Dict[str, Any] = None,  # 可選的交叉注意力引數
        ) -> torch.Tensor:  # 返回一個張量
            # 檢查 FreeU 是否啟用,基於多個屬性的存在性
            is_freeu_enabled = (
                getattr(self, "s1", None)  # 獲取屬性 s1
                and getattr(self, "s2", None)  # 獲取屬性 s2
                and getattr(self, "b1", None)  # 獲取屬性 b1
                and getattr(self, "b2", None)  # 獲取屬性 b2
            )
    
            # TODO(Patrick, William) - 注意力掩碼尚未使用
            for resnet, temp_conv, attn, temp_attn in zip(  # 遍歷網路模組
                self.resnets, self.temp_convs, self.attentions, self.temp_attentions  # 從各模組提取
            ):
                # 從元組中彈出最後一個殘差隱藏狀態
                res_hidden_states = res_hidden_states_tuple[-1]
                # 更新元組,去掉最後一個隱藏狀態
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
    
                # FreeU:僅對前兩個階段操作
                if is_freeu_enabled:
                    hidden_states, res_hidden_states = apply_freeu(  # 應用 FreeU 操作
                        self.resolution_idx,  # 當前解析度索引
                        hidden_states,  # 當前隱藏狀態
                        res_hidden_states,  # 殘差隱藏狀態
                        s1=self.s1,  # 屬性 s1
                        s2=self.s2,  # 屬性 s2
                        b1=self.b1,  # 屬性 b1
                        b2=self.b2,  # 屬性 b2
                    )
    
                # 將當前隱藏狀態與殘差隱藏狀態在維度1上拼接
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
    
                # 透過 ResNet 模組處理隱藏狀態
                hidden_states = resnet(hidden_states, temb)
                # 透過臨時卷積模組處理隱藏狀態
                hidden_states = temp_conv(hidden_states, num_frames=num_frames)
                # 透過注意力模組處理隱藏狀態,並提取第一個返回值
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,  # 傳遞編碼器隱藏狀態
                    cross_attention_kwargs=cross_attention_kwargs,  # 傳遞交叉注意力引數
                    return_dict=False,  # 不返回字典形式的結果
                )[0]  # 提取第一個返回值
                # 透過臨時注意力模組處理隱藏狀態,並提取第一個返回值
                hidden_states = temp_attn(
                    hidden_states,
                    num_frames=num_frames,  # 傳遞幀數
                    cross_attention_kwargs=cross_attention_kwargs,  # 傳遞交叉注意力引數
                    return_dict=False,  # 不返回字典形式的結果
                )[0]  # 提取第一個返回值
    
            # 如果存在上取樣模組
            if self.upsamplers is not None:
                for upsampler in self.upsamplers:  # 遍歷上取樣模組
                    hidden_states = upsampler(hidden_states, upsample_size)  # 應用上取樣模組
    
            # 返回最終的隱藏狀態
            return hidden_states
# 定義一個名為 UpBlock3D 的類,繼承自 nn.Module
class UpBlock3D(nn.Module):
    # 初始化函式,接受多個引數以配置網路層
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        prev_output_channel: int,  # 前一層的輸出通道數
        out_channels: int,  # 當前層的輸出通道數
        temb_channels: int,  # 時間嵌入通道數
        dropout: float = 0.0,  # dropout 機率
        num_layers: int = 1,  # 層數
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 引數,用於數值穩定
        resnet_time_scale_shift: str = "default",  # ResNet 時間縮放偏移設定
        resnet_act_fn: str = "swish",  # ResNet 啟用函式
        resnet_groups: int = 32,  # ResNet 的組數
        resnet_pre_norm: bool = True,  # 是否使用前置歸一化
        output_scale_factor: float = 1.0,  # 輸出縮放因子
        add_upsample: bool = True,  # 是否新增上取樣層
        resolution_idx: Optional[int] = None,  # 解析度索引,預設為 None
    ):
        # 呼叫父類建構函式
        super().__init__()
        # 建立一個空列表,用於儲存 ResNet 層
        resnets = []
        # 建立一個空列表,用於儲存時間卷積層
        temp_convs = []

        # 根據層數建立 ResNet 和時間卷積層
        for i in range(num_layers):
            # 確定跳躍連線的通道數
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # 確定當前 ResNet 層的輸入通道數
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            # 建立 ResNet 層,並新增到 resnets 列表中
            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,  # 輸入通道數
                    out_channels=out_channels,  # 輸出通道數
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    eps=resnet_eps,  # epsilon 引數
                    groups=resnet_groups,  # 組數
                    dropout=dropout,  # dropout 機率
                    time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入歸一化
                    non_linearity=resnet_act_fn,  # 啟用函式
                    output_scale_factor=output_scale_factor,  # 輸出縮放因子
                    pre_norm=resnet_pre_norm,  # 前置歸一化
                )
            )
            # 建立時間卷積層,並新增到 temp_convs 列表中
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,  # 輸入通道數
                    out_channels,  # 輸出通道數
                    dropout=0.1,  # dropout 機率
                    norm_num_groups=resnet_groups,  # 歸一化組數
                )
            )

        # 將 ResNet 層的列表轉為 nn.ModuleList,以便於管理
        self.resnets = nn.ModuleList(resnets)
        # 將時間卷積層的列表轉為 nn.ModuleList,以便於管理
        self.temp_convs = nn.ModuleList(temp_convs)

        # 如果需要新增上取樣層,則建立並新增
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            # 如果不需要上取樣層,則設定為 None
            self.upsamplers = None

        # 初始化梯度檢查點標誌為 False
        self.gradient_checkpointing = False
        # 設定解析度索引
        self.resolution_idx = resolution_idx

    # 定義前向傳播函式
    def forward(
        self,
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態張量
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],  # 跳躍連線的隱藏狀態元組
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
        upsample_size: Optional[int] = None,  # 可選的上取樣尺寸
        num_frames: int = 1,  # 幀數
    ) -> torch.Tensor:
        # 判斷是否啟用 FreeU,檢查相關屬性是否存在且不為 None
        is_freeu_enabled = (
            getattr(self, "s1", None)
            and getattr(self, "s2", None)
            and getattr(self, "b1", None)
            and getattr(self, "b2", None)
        )
        # 遍歷自定義的 resnets 和 temp_convs,進行逐對處理
        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            # 從 res_hidden_states_tuple 中彈出最後一個隱藏狀態
            res_hidden_states = res_hidden_states_tuple[-1]
            # 更新 res_hidden_states_tuple,去掉最後一個隱藏狀態
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            # 如果啟用了 FreeU,則僅對前兩個階段進行操作
            if is_freeu_enabled:
                # 呼叫 apply_freeu 函式,處理隱藏狀態
                hidden_states, res_hidden_states = apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    res_hidden_states,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )

            # 將當前的隱藏狀態與殘差隱藏狀態在維度 1 上連線
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            # 透過當前的 resnet 處理隱藏狀態和 temb
            hidden_states = resnet(hidden_states, temb)
            # 透過當前的 temp_conv 處理隱藏狀態,傳入 num_frames 引數
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)

        # 如果存在上取樣器,則對每個上取樣器進行處理
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                # 透過當前的 upsampler 處理隱藏狀態,傳入 upsample_size 引數
                hidden_states = upsampler(hidden_states, upsample_size)

        # 返回最終的隱藏狀態
        return hidden_states
# 定義一箇中間塊時間解碼器類,繼承自 nn.Module
class MidBlockTemporalDecoder(nn.Module):
    # 初始化函式,定義輸入輸出通道、注意力頭維度、層數等引數
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        attention_head_dim: int = 512,
        num_layers: int = 1,
        upcast_attention: bool = False,
    ):
        # 呼叫父類的初始化函式
        super().__init__()

        # 初始化 ResNet 和 Attention 列表
        resnets = []
        attentions = []
        # 根據層數建立相應數量的 ResNet
        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels
            # 將 SpatioTemporalResBlock 例項新增到 ResNet 列表中
            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=1e-6,
                    temporal_eps=1e-5,
                    merge_factor=0.0,
                    merge_strategy="learned",
                    switch_spatial_to_temporal_mix=True,
                )
            )

        # 新增 Attention 例項到 Attention 列表中
        attentions.append(
            Attention(
                query_dim=in_channels,
                heads=in_channels // attention_head_dim,
                dim_head=attention_head_dim,
                eps=1e-6,
                upcast_attention=upcast_attention,
                norm_num_groups=32,
                bias=True,
                residual_connection=True,
            )
        )

        # 將 Attention 和 ResNet 列表轉換為 ModuleList
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    # 前向傳播函式,定義輸入的隱藏狀態和影像指示器的處理
    def forward(
        self,
        hidden_states: torch.Tensor,
        image_only_indicator: torch.Tensor,
    ):
        # 處理第一個 ResNet 的輸出
        hidden_states = self.resnets[0](
            hidden_states,
            image_only_indicator=image_only_indicator,
        )
        # 遍歷剩餘的 ResNet 和 Attention,交替處理
        for resnet, attn in zip(self.resnets[1:], self.attentions):
            hidden_states = attn(hidden_states)  # 應用注意力機制
            # 處理 ResNet 的輸出
            hidden_states = resnet(
                hidden_states,
                image_only_indicator=image_only_indicator,
            )

        # 返回最終的隱藏狀態
        return hidden_states


# 定義一個上取樣塊時間解碼器類,繼承自 nn.Module
class UpBlockTemporalDecoder(nn.Module):
    # 初始化函式,定義輸入輸出通道、層數等引數
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        add_upsample: bool = True,
    ):
        # 呼叫父類的初始化函式
        super().__init__()
        # 初始化 ResNet 列表
        resnets = []
        # 根據層數建立相應數量的 ResNet
        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels
            # 將 SpatioTemporalResBlock 例項新增到 ResNet 列表中
            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=1e-6,
                    temporal_eps=1e-5,
                    merge_factor=0.0,
                    merge_strategy="learned",
                    switch_spatial_to_temporal_mix=True,
                )
            )
        # 將 ResNet 列表轉換為 ModuleList
        self.resnets = nn.ModuleList(resnets)

        # 如果需要,初始化上取樣層
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None
    # 定義前向傳播函式,接收隱藏狀態和影像指示器作為輸入,返回處理後的張量
        def forward(
            self,
            hidden_states: torch.Tensor,
            image_only_indicator: torch.Tensor,
        ) -> torch.Tensor:
            # 遍歷每個 ResNet 模組,更新隱藏狀態
            for resnet in self.resnets:
                hidden_states = resnet(
                    hidden_states,
                    image_only_indicator=image_only_indicator,
                )
    
            # 如果存在上取樣模組,則對隱藏狀態進行上取樣處理
            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states)
    
            # 返回最終的隱藏狀態
            return hidden_states
# 定義一個名為 UNetMidBlockSpatioTemporal 的類,繼承自 nn.Module
class UNetMidBlockSpatioTemporal(nn.Module):
    # 初始化方法,接收多個引數以配置該模組
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        temb_channels: int,  # 時間嵌入通道數
        num_layers: int = 1,  # 層數,預設為 1
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 每個塊的變換層數,預設為 1
        num_attention_heads: int = 1,  # 注意力頭數,預設為 1
        cross_attention_dim: int = 1280,  # 交叉注意力維度,預設為 1280
    ):
        # 呼叫父類的初始化方法
        super().__init__()

        # 設定是否使用交叉注意力標誌
        self.has_cross_attention = True
        # 儲存注意力頭的數量
        self.num_attention_heads = num_attention_heads

        # 支援每個塊的變換層數為可變的
        if isinstance(transformer_layers_per_block, int):
            # 如果是整數,則將其轉換為包含 num_layers 個相同元素的列表
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        # 至少有一個 ResNet 塊
        resnets = [
            # 建立第一個時空殘差塊
            SpatioTemporalResBlock(
                in_channels=in_channels,  # 輸入通道數
                out_channels=in_channels,  # 輸出通道數與輸入相同
                temb_channels=temb_channels,  # 時間嵌入通道數
                eps=1e-5,  # 小常數用於數值穩定性
            )
        ]
        # 初始化注意力模組列表
        attentions = []

        # 遍歷層數以新增註意力和殘差塊
        for i in range(num_layers):
            # 新增時空變換模型到注意力列表
            attentions.append(
                TransformerSpatioTemporalModel(
                    num_attention_heads,  # 注意力頭數
                    in_channels // num_attention_heads,  # 每個頭的通道數
                    in_channels=in_channels,  # 輸入通道數
                    num_layers=transformer_layers_per_block[i],  # 當前層的變換層數
                    cross_attention_dim=cross_attention_dim,  # 交叉注意力維度
                )
            )

            # 新增另一個時空殘差塊到殘差列表
            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=in_channels,  # 輸入通道數
                    out_channels=in_channels,  # 輸出通道數與輸入相同
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    eps=1e-5,  # 小常數用於數值穩定性
                )
            )

        # 將注意力模組列表轉換為 nn.ModuleList
        self.attentions = nn.ModuleList(attentions)
        # 將殘差模組列表轉換為 nn.ModuleList
        self.resnets = nn.ModuleList(resnets)

        # 初始化梯度檢查點標誌為 False
        self.gradient_checkpointing = False

    # 前向傳播方法,接收多個輸入引數
    def forward(
        self,
        hidden_states: torch.Tensor,  # 隱藏狀態張量
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 可選的編碼器隱藏狀態
        image_only_indicator: Optional[torch.Tensor] = None,  # 可選的影像指示張量
    # 返回型別為 torch.Tensor 的函式
    ) -> torch.Tensor:
        # 使用第一個殘差網路處理隱藏狀態,傳入時間嵌入和影像指示器
        hidden_states = self.resnets[0](
            hidden_states,
            temb,
            image_only_indicator=image_only_indicator,
        )
    
        # 遍歷注意力層和後續殘差網路的組合
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            # 檢查是否在訓練中且開啟了梯度檢查點
            if self.training and self.gradient_checkpointing:  # TODO
    
                # 建立自定義前向傳播函式
                def create_custom_forward(module, return_dict=None):
                    # 定義自定義前向傳播內部函式
                    def custom_forward(*inputs):
                        # 根據返回字典引數決定是否使用返回字典
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)
    
                    return custom_forward
    
                # 設定檢查點引數,取決於 PyTorch 版本
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                # 使用注意力層處理隱藏狀態,傳入編碼器的隱藏狀態和影像指示器
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    return_dict=False,
                )[0]
                # 使用檢查點進行殘差網路的前向傳播
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    image_only_indicator,
                    **ckpt_kwargs,
                )
            else:
                # 使用注意力層處理隱藏狀態,傳入編碼器的隱藏狀態和影像指示器
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    return_dict=False,
                )[0]
                # 使用殘差網路處理隱藏狀態
                hidden_states = resnet(
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )
    
        # 返回處理後的隱藏狀態
        return hidden_states
# 定義一個下采樣的時空塊,繼承自 nn.Module
class DownBlockSpatioTemporal(nn.Module):
    # 初始化方法,設定輸入輸出通道和層數等引數
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        out_channels: int,  # 輸出通道數
        temb_channels: int,  # 時間嵌入通道數
        num_layers: int = 1,  # 層數,預設為1
        add_downsample: bool = True,  # 是否新增下采樣
    ):
        super().__init__()  # 呼叫父類的初始化方法
        resnets = []  # 初始化一個空列表以儲存殘差塊

        # 根據層數建立相應數量的 SpatioTemporalResBlock
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels  # 確定當前層的輸入通道數
            resnets.append(
                SpatioTemporalResBlock(  # 新增一個新的時空殘差塊
                    in_channels=in_channels,  # 設定輸入通道
                    out_channels=out_channels,  # 設定輸出通道
                    temb_channels=temb_channels,  # 設定時間嵌入通道
                    eps=1e-5,  # 設定 epsilon 值
                )
            )

        self.resnets = nn.ModuleList(resnets)  # 將殘差塊列表轉化為 ModuleList

        # 如果需要下采樣,建立下采樣模組
        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(  # 新增一個下采樣層
                        out_channels,  # 設定輸入通道
                        use_conv=True,  # 是否使用卷積進行下采樣
                        out_channels=out_channels,  # 設定輸出通道
                        name="op",  # 下采樣層名稱
                    )
                ]
            )
        else:
            self.downsamplers = None  # 不新增下采樣模組

        self.gradient_checkpointing = False  # 初始化梯度檢查點為 False

    # 前向傳播方法
    def forward(
        self,
        hidden_states: torch.Tensor,  # 隱藏狀態輸入
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入
        image_only_indicator: Optional[torch.Tensor] = None,  # 可選的影像指示器
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        output_states = ()  # 初始化輸出狀態元組
        for resnet in self.resnets:  # 遍歷每個殘差塊
            if self.training and self.gradient_checkpointing:  # 如果在訓練且啟用了梯度檢查點

                # 定義一個建立自定義前向傳播的方法
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)  # 返回模組的前向輸出

                    return custom_forward

                # 根據 PyTorch 版本進行檢查點操作
                if is_torch_version(">=", "1.11.0"):
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),  # 使用自定義前向傳播
                        hidden_states,  # 輸入隱藏狀態
                        temb,  # 輸入時間嵌入
                        image_only_indicator,  # 輸入影像指示器
                        use_reentrant=False,  # 不使用重入
                    )
                else:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),  # 使用自定義前向傳播
                        hidden_states,  # 輸入隱藏狀態
                        temb,  # 輸入時間嵌入
                        image_only_indicator,  # 輸入影像指示器
                    )
            else:
                hidden_states = resnet(  # 直接透過殘差塊進行前向傳播
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )

            output_states = output_states + (hidden_states,)  # 將當前隱藏狀態新增到輸出狀態中

        if self.downsamplers is not None:  # 如果存在下采樣模組
            for downsampler in self.downsamplers:  # 遍歷下采樣模組
                hidden_states = downsampler(hidden_states)  # 對隱藏狀態進行下采樣

            output_states = output_states + (hidden_states,)  # 將下采樣後的狀態新增到輸出狀態中

        return hidden_states, output_states  # 返回最終的隱藏狀態和輸出狀態元組


# 定義一個交叉注意力下采樣時空塊,繼承自 nn.Module
class CrossAttnDownBlockSpatioTemporal(nn.Module):
    # 初始化方法,設定模型的基本引數
        def __init__(
            # 輸入通道數
            self,
            in_channels: int,
            # 輸出通道數
            out_channels: int,
            # 時間嵌入通道數
            temb_channels: int,
            # 層數,預設為1
            num_layers: int = 1,
            # 每個塊的變換層數,可以是整數或元組,預設為1
            transformer_layers_per_block: Union[int, Tuple[int]] = 1,
            # 注意力頭數,預設為1
            num_attention_heads: int = 1,
            # 交叉注意力的維度,預設為1280
            cross_attention_dim: int = 1280,
            # 是否新增下采樣,預設為True
            add_downsample: bool = True,
        ):
            # 呼叫父類初始化方法
            super().__init__()
            # 初始化殘差網路列表
            resnets = []
            # 初始化注意力層列表
            attentions = []
    
            # 設定是否使用交叉注意力為True
            self.has_cross_attention = True
            # 設定注意力頭數
            self.num_attention_heads = num_attention_heads
            # 如果變換層數是整數,則擴充套件為列表
            if isinstance(transformer_layers_per_block, int):
                transformer_layers_per_block = [transformer_layers_per_block] * num_layers
    
            # 遍歷層數以建立殘差塊和注意力層
            for i in range(num_layers):
                # 如果是第一層,使用輸入通道數,否則使用輸出通道數
                in_channels = in_channels if i == 0 else out_channels
                # 新增殘差塊到列表
                resnets.append(
                    SpatioTemporalResBlock(
                        # 輸入通道數
                        in_channels=in_channels,
                        # 輸出通道數
                        out_channels=out_channels,
                        # 時間嵌入通道數
                        temb_channels=temb_channels,
                        # 防止除零的微小值
                        eps=1e-6,
                    )
                )
                # 新增註意力模型到列表
                attentions.append(
                    TransformerSpatioTemporalModel(
                        # 注意力頭數
                        num_attention_heads,
                        # 每個頭的輸出通道數
                        out_channels // num_attention_heads,
                        # 輸入通道數
                        in_channels=out_channels,
                        # 該層的變換層數
                        num_layers=transformer_layers_per_block[i],
                        # 交叉注意力的維度
                        cross_attention_dim=cross_attention_dim,
                    )
                )
    
            # 將注意力層轉換為nn.ModuleList以支援PyTorch模型
            self.attentions = nn.ModuleList(attentions)
            # 將殘差層轉換為nn.ModuleList以支援PyTorch模型
            self.resnets = nn.ModuleList(resnets)
    
            # 如果需要新增下采樣層
            if add_downsample:
                # 新增下采樣層到nn.ModuleList
                self.downsamplers = nn.ModuleList(
                    [
                        Downsample2D(
                            # 輸出通道數
                            out_channels,
                            # 是否使用卷積
                            use_conv=True,
                            # 輸出通道數
                            out_channels=out_channels,
                            # 填充大小
                            padding=1,
                            # 操作名稱
                            name="op",
                        )
                    ]
                )
            else:
                # 如果不需要下采樣層,則設定為None
                self.downsamplers = None
    
            # 初始化梯度檢查點為False
            self.gradient_checkpointing = False
    
        # 前向傳播方法
        def forward(
            # 隱藏狀態的張量
            hidden_states: torch.Tensor,
            # 時間嵌入的可選張量
            temb: Optional[torch.Tensor] = None,
            # 編碼器隱藏狀態的可選張量
            encoder_hidden_states: Optional[torch.Tensor] = None,
            # 僅影像指示的可選張量
            image_only_indicator: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:  # 定義返回型別為一個元組,包含一個張量和多個張量的元組
        output_states = ()  # 初始化一個空元組,用於儲存輸出狀態

        blocks = list(zip(self.resnets, self.attentions))  # 將自定義的殘差網路和注意力模組打包成一個列表
        for resnet, attn in blocks:  # 遍歷每個殘差網路和對應的注意力模組
            if self.training and self.gradient_checkpointing:  # 如果處於訓練模式並且啟用了梯度檢查點

                def create_custom_forward(module, return_dict=None):  # 定義一個建立自定義前向傳播函式的輔助函式
                    def custom_forward(*inputs):  # 定義自定義前向傳播函式,接受任意數量的輸入
                        if return_dict is not None:  # 如果提供了返回字典引數
                            return module(*inputs, return_dict=return_dict)  # 使用返回字典引數呼叫模組
                        else:  # 如果沒有提供返回字典
                            return module(*inputs)  # 直接呼叫模組並返回結果

                    return custom_forward  # 返回自定義前向傳播函式

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}  # 根據 PyTorch 版本設定檢查點引數
                hidden_states = torch.utils.checkpoint.checkpoint(  # 使用檢查點功能計算隱藏狀態以節省記憶體
                    create_custom_forward(resnet),  # 將殘差網路傳入自定義前向函式
                    hidden_states,  # 將當前隱藏狀態作為輸入
                    temb,  # 傳遞時間嵌入
                    image_only_indicator,  # 傳遞影像指示器
                    **ckpt_kwargs,  # 解包檢查點引數
                )

                hidden_states = attn(  # 使用注意力模組處理隱藏狀態
                    hidden_states,  # 輸入隱藏狀態
                    encoder_hidden_states=encoder_hidden_states,  # 傳遞編碼器隱藏狀態
                    image_only_indicator=image_only_indicator,  # 傳遞影像指示器
                    return_dict=False,  # 不返回字典形式的結果
                )[0]  # 獲取輸出的第一個元素
            else:  # 如果不處於訓練模式或未啟用梯度檢查點
                hidden_states = resnet(  # 直接使用殘差網路處理隱藏狀態
                    hidden_states,  # 輸入當前隱藏狀態
                    temb,  # 傳遞時間嵌入
                    image_only_indicator=image_only_indicator,  # 傳遞影像指示器
                )
                hidden_states = attn(  # 使用注意力模組處理隱藏狀態
                    hidden_states,  # 輸入隱藏狀態
                    encoder_hidden_states=encoder_hidden_states,  # 傳遞編碼器隱藏狀態
                    image_only_indicator=image_only_indicator,  # 傳遞影像指示器
                    return_dict=False,  # 不返回字典形式的結果
                )[0]  # 獲取輸出的第一個元素

            output_states = output_states + (hidden_states,)  # 將當前的隱藏狀態新增到輸出狀態元組中

        if self.downsamplers is not None:  # 如果存在下采樣模組
            for downsampler in self.downsamplers:  # 遍歷每個下采樣模組
                hidden_states = downsampler(hidden_states)  # 使用下采樣模組處理隱藏狀態

            output_states = output_states + (hidden_states,)  # 將處理後的隱藏狀態新增到輸出狀態元組中

        return hidden_states, output_states  # 返回當前的隱藏狀態和所有輸出狀態
# 定義一個上取樣的時空塊類,繼承自 nn.Module
class UpBlockSpatioTemporal(nn.Module):
    # 初始化方法,定義類的引數
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        prev_output_channel: int,  # 前一層輸出通道數
        out_channels: int,  # 當前層輸出通道數
        temb_channels: int,  # 時間嵌入通道數
        resolution_idx: Optional[int] = None,  # 可選的解析度索引
        num_layers: int = 1,  # 層數,預設為1
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 值
        add_upsample: bool = True,  # 是否新增上取樣層
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        # 初始化一個空的列表,用於儲存 ResNet 模組
        resnets = []

        # 根據層數建立對應的 ResNet 塊
        for i in range(num_layers):
            # 如果是最後一層,使用輸入通道數,否則使用輸出通道數
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # 第一層的輸入通道為前一層輸出通道,後續層為輸出通道
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            # 建立時空 ResNet 塊並新增到列表中
            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=resnet_in_channels + res_skip_channels,  # 輸入通道數
                    out_channels=out_channels,  # 輸出通道數
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    eps=resnet_eps,  # ResNet 的 epsilon 值
                )
            )

        # 將 ResNet 模組列表轉換為 nn.ModuleList,以便在模型中管理
        self.resnets = nn.ModuleList(resnets)

        # 如果需要新增上取樣層,則建立對應的 nn.ModuleList
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            # 如果不新增上取樣層,設定為 None
            self.upsamplers = None

        # 初始化梯度檢查點標誌
        self.gradient_checkpointing = False
        # 設定解析度索引
        self.resolution_idx = resolution_idx

    # 定義前向傳播方法
    def forward(
        self,
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態張量
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],  # 之前層的隱藏狀態元組
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
        image_only_indicator: Optional[torch.Tensor] = None,  # 可選的影像指示器張量
    ) -> torch.Tensor:  # 定義函式返回型別為 PyTorch 的張量
        for resnet in self.resnets:  # 遍歷當前物件中的所有 ResNet 模型
            # pop res hidden states  # 從隱藏狀態元組中提取最後一個隱藏狀態
            res_hidden_states = res_hidden_states_tuple[-1]  # 獲取最後的 ResNet 隱藏狀態
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]  # 更新元組,移除最後一個隱藏狀態

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)  # 將當前隱藏狀態與 ResNet 隱藏狀態在維度 1 上拼接

            if self.training and self.gradient_checkpointing:  # 如果處於訓練狀態並啟用了梯度檢查點
                def create_custom_forward(module):  # 定義用於建立自定義前向傳播函式的內部函式
                    def custom_forward(*inputs):  # 自定義前向傳播,接收任意數量的輸入
                        return module(*inputs)  # 呼叫原始模組的前向傳播

                    return custom_forward  # 返回自定義前向傳播函式

                if is_torch_version(">=", "1.11.0"):  # 檢查當前 PyTorch 版本是否大於等於 1.11.0
                    hidden_states = torch.utils.checkpoint.checkpoint(  # 使用檢查點機制儲存記憶體
                        create_custom_forward(resnet),  # 傳入自定義前向傳播函式
                        hidden_states,  # 傳入當前的隱藏狀態
                        temb,  # 傳入時間嵌入
                        image_only_indicator,  # 傳入影像指示器
                        use_reentrant=False,  # 禁用重入
                    )
                else:  # 如果 PyTorch 版本小於 1.11.0
                    hidden_states = torch.utils.checkpoint.checkpoint(  # 使用檢查點機制儲存記憶體
                        create_custom_forward(resnet),  # 傳入自定義前向傳播函式
                        hidden_states,  # 傳入當前的隱藏狀態
                        temb,  # 傳入時間嵌入
                        image_only_indicator,  # 傳入影像指示器
                    )
            else:  # 如果不是訓練狀態或沒有啟用梯度檢查點
                hidden_states = resnet(  # 直接呼叫 ResNet 模型處理隱藏狀態
                    hidden_states,  # 傳入當前的隱藏狀態
                    temb,  # 傳入時間嵌入
                    image_only_indicator=image_only_indicator,  # 傳入影像指示器
                )

        if self.upsamplers is not None:  # 如果存在上取樣模組
            for upsampler in self.upsamplers:  # 遍歷所有上取樣模組
                hidden_states = upsampler(hidden_states)  # 呼叫上取樣模組處理隱藏狀態

        return hidden_states  # 返回處理後的隱藏狀態
# 定義一個時空交叉注意力上取樣塊類,繼承自 nn.Module
class CrossAttnUpBlockSpatioTemporal(nn.Module):
    # 初始化方法,設定網路的各個引數
    def __init__(
        # 輸入通道數
        in_channels: int,
        # 輸出通道數
        out_channels: int,
        # 前一層輸出通道數
        prev_output_channel: int,
        # 時間嵌入通道數
        temb_channels: int,
        # 解析度索引,可選
        resolution_idx: Optional[int] = None,
        # 層數
        num_layers: int = 1,
        # 每個塊的變換器層數,支援單個整數或元組
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        # ResNet 的 epsilon 值,防止除零錯誤
        resnet_eps: float = 1e-6,
        # 注意力頭的數量
        num_attention_heads: int = 1,
        # 交叉注意力維度
        cross_attention_dim: int = 1280,
        # 是否新增上取樣層
        add_upsample: bool = True,
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        # 儲存 ResNet 層的列表
        resnets = []
        # 儲存注意力層的列表
        attentions = []

        # 指示是否使用交叉注意力
        self.has_cross_attention = True
        # 設定注意力頭的數量
        self.num_attention_heads = num_attention_heads

        # 如果是整數,將其轉換為列表,包含 num_layers 個相同的元素
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        # 遍歷每一層,構建 ResNet 和注意力層
        for i in range(num_layers):
            # 根據當前層是否是最後一層來決定跳過連線的通道數
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # 根據當前層決定 ResNet 的輸入通道數
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            # 新增時空 ResNet 塊到列表中
            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                )
            )
            # 新增時空變換器模型到列表中
            attentions.append(
                TransformerSpatioTemporalModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[i],
                    cross_attention_dim=cross_attention_dim,
                )
            )

        # 將注意力層列表轉換為 nn.ModuleList,以便於管理
        self.attentions = nn.ModuleList(attentions)
        # 將 ResNet 層列表轉換為 nn.ModuleList,以便於管理
        self.resnets = nn.ModuleList(resnets)

        # 如果需要,新增上取樣層
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            # 否則上取樣層設定為 None
            self.upsamplers = None

        # 設定梯度檢查點標誌為 False
        self.gradient_checkpointing = False
        # 儲存解析度索引
        self.resolution_idx = resolution_idx

    # 前向傳播方法,定義輸入和輸出
    def forward(
        # 隱藏狀態張量
        hidden_states: torch.Tensor,
        # 上一層隱藏狀態的元組
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        # 可選的時間嵌入張量
        temb: Optional[torch.Tensor] = None,
        # 可選的編碼器隱藏狀態張量
        encoder_hidden_states: Optional[torch.Tensor] = None,
        # 可選的影像指示器張量
        image_only_indicator: Optional[torch.Tensor] = None,
    # 返回一個 torch.Tensor 型別的結果
    ) -> torch.Tensor:
        # 遍歷每個 resnet 和 attention 模組的組合
        for resnet, attn in zip(self.resnets, self.attentions):
            # 從隱藏狀態元組中彈出最後一個 res 隱藏狀態
            res_hidden_states = res_hidden_states_tuple[-1]
            # 更新隱藏狀態元組,去掉最後一個元素
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
    
            # 在指定維度連線當前的 hidden_states 和 res_hidden_states
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
    
            # 如果處於訓練模式且開啟了梯度檢查點
            if self.training and self.gradient_checkpointing:  # TODO
                # 定義一個用於建立自定義前向傳播函式的函式
                def create_custom_forward(module, return_dict=None):
                    # 定義自定義前向傳播邏輯
                    def custom_forward(*inputs):
                        # 根據是否返回字典,呼叫模組的前向傳播
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)
    
                    return custom_forward
    
                # 根據 PyTorch 版本選擇 checkpoint 的引數
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                # 使用檢查點儲存記憶體,呼叫自定義前向傳播
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    image_only_indicator,
                    **ckpt_kwargs,
                )
                # 透過 attention 模組處理 hidden_states
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    return_dict=False,
                )[0]
            else:
                # 如果不使用檢查點,直接透過 resnet 模組處理 hidden_states
                hidden_states = resnet(
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )
                # 透過 attention 模組處理 hidden_states
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    return_dict=False,
                )[0]
    
        # 如果存在上取樣模組,逐個應用於 hidden_states
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)
    
        # 返回處理後的 hidden_states
        return hidden_states

相關文章