diffusers-原始碼解析-五-

绝不原创的飞龙發表於2024-10-22

diffusers 原始碼解析(五)

.\diffusers\models\autoencoders\autoencoder_asym_kl.py

# 版權宣告,標識該檔案的所有權和使用條款
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根據 Apache 許可證第 2.0 版(“許可證”)進行授權;
# 除非遵循許可證,否則您不得使用此檔案。
# 您可以在以下地址獲取許可證副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律或書面同意,否則根據許可證分發的軟體是以“現狀”基礎提供的,
# 不提供任何形式的明示或暗示的擔保或條件。
# 有關許可證所管理的許可權和限制的具體資訊,請參見許可證。
from typing import Optional, Tuple, Union  # 匯入型別提示模組,用於指定可選型別、元組和聯合型別

import torch  # 匯入 PyTorch 庫,用於深度學習
import torch.nn as nn  # 匯入 PyTorch 神經網路模組

from ...configuration_utils import ConfigMixin, register_to_config  # 從配置工具中匯入配置混合類和註冊函式
from ...utils.accelerate_utils import apply_forward_hook  # 從加速工具中匯入前向鉤子應用函式
from ..modeling_outputs import AutoencoderKLOutput  # 從建模輸出模組匯入自編碼器 KL 輸出類
from ..modeling_utils import ModelMixin  # 從建模工具中匯入模型混合類
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder  # 從 VAE 模組匯入解碼器輸出、對角高斯分佈、編碼器和掩碼條件解碼器類


class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):  # 定義不對稱自編碼器 KL 類,繼承模型混合類和配置混合類
    r"""  # 開始文件字串,描述模型的用途和背景
    設計一個更好的不對稱 VQGAN 以用於 StableDiffusion https://arxiv.org/abs/2306.04632。一個具有 KL 損失的 VAE 模型
    用於將影像編碼為潛在表示,並將潛在表示解碼為影像。

    此模型繼承自 [`ModelMixin`]。請檢視超類文件以瞭解其為所有模型實現的通用方法
    (例如下載或儲存)。
    # 引數說明部分,描述類或函式的引數及其預設值
    Parameters:
        # 輸入影像的通道數,預設值為 3
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        # 輸出的通道數,預設值為 3
        out_channels (int,  *optional*, defaults to 3): Number of channels in the output.
        # 下采樣塊型別的元組,預設值為包含一個元素的元組
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        # 下采樣塊輸出通道的元組,預設值為包含一個元素的元組
        down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of down block output channels.
        # 每個下采樣塊的層數,預設值為 1
        layers_per_down_block (`int`, *optional*, defaults to `1`):
            Number layers for down block.
        # 上取樣塊型別的元組,預設值為包含一個元素的元組
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        # 上取樣塊輸出通道的元組,預設值為包含一個元素的元組
        up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of up block output channels.
        # 每個上取樣塊的層數,預設值為 1
        layers_per_up_block (`int`, *optional*, defaults to `1`):
            Number layers for up block.
        # 使用的啟用函式,預設值為 "silu"
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        # 潛在空間的通道數,預設值為 4
        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
        # 輸入樣本的大小,預設值為 32
        sample_size (`int`, *optional*, defaults to 32): Sample input size.
        # ResNet 塊中第一個歸一化層使用的組數,預設值為 32
        norm_num_groups (`int`, *optional*, defaults to 32):
            Number of groups to use for the first normalization layer in ResNet blocks.
        # 訓練潛在空間的分量標準差,預設值為 0.18215
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
    """
    # 定義初始化方法,無返回值
    ) -> None:
            # 呼叫父類初始化方法
            super().__init__()
    
            # 將初始化引數傳遞給編碼器
            self.encoder = Encoder(
                # 輸入通道數
                in_channels=in_channels,
                # 潛在通道數
                out_channels=latent_channels,
                # 下采樣塊型別
                down_block_types=down_block_types,
                # 下采樣塊輸出通道數
                block_out_channels=down_block_out_channels,
                # 每個塊的層數
                layers_per_block=layers_per_down_block,
                # 啟用函式
                act_fn=act_fn,
                # 歸一化的組數
                norm_num_groups=norm_num_groups,
                # 設定雙重潛變數
                double_z=True,
            )
    
            # 將初始化引數傳遞給解碼器
            self.decoder = MaskConditionDecoder(
                # 輸入潛在通道數
                in_channels=latent_channels,
                # 輸出通道數
                out_channels=out_channels,
                # 上取樣塊型別
                up_block_types=up_block_types,
                # 上取樣塊輸出通道數
                block_out_channels=up_block_out_channels,
                # 每個塊的層數
                layers_per_block=layers_per_up_block,
                # 啟用函式
                act_fn=act_fn,
                # 歸一化的組數
                norm_num_groups=norm_num_groups,
            )
    
            # 定義量化卷積層,輸入輸出通道數相同
            self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
            # 定義後量化卷積層,輸入輸出通道數相同
            self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
    
            # 禁用切片功能
            self.use_slicing = False
            # 禁用平鋪功能
            self.use_tiling = False
    
            # 註冊上取樣塊輸出通道數到配置
            self.register_to_config(block_out_channels=up_block_out_channels)
            # 註冊強制上溯引數到配置
            self.register_to_config(force_upcast=False)
    
        # 應用前向鉤子修飾符
        @apply_forward_hook
        def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[torch.Tensor]]:
            # 使用編碼器處理輸入資料
            h = self.encoder(x)
            # 透過量化卷積獲取時刻
            moments = self.quant_conv(h)
            # 建立對角高斯分佈
            posterior = DiagonalGaussianDistribution(moments)
    
            # 檢查是否返回字典
            if not return_dict:
                return (posterior,)
    
            # 返回潛在分佈輸出
            return AutoencoderKLOutput(latent_dist=posterior)
    
        # 定義解碼私有方法
        def _decode(
            self,
            z: torch.Tensor,
            image: Optional[torch.Tensor] = None,
            mask: Optional[torch.Tensor] = None,
            return_dict: bool = True,
        ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
            # 透過後量化卷積處理潛在變數
            z = self.post_quant_conv(z)
            # 使用解碼器生成輸出
            dec = self.decoder(z, image, mask)
    
            # 檢查是否返回字典
            if not return_dict:
                return (dec,)
    
            # 返回解碼器輸出
            return DecoderOutput(sample=dec)
    
        # 應用前向鉤子修飾符
        @apply_forward_hook
        def decode(
            self,
            z: torch.Tensor,
            generator: Optional[torch.Generator] = None,
            image: Optional[torch.Tensor] = None,
            mask: Optional[torch.Tensor] = None,
            return_dict: bool = True,
        ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
            # 呼叫解碼私有方法並獲取樣本
            decoded = self._decode(z, image, mask).sample
    
            # 檢查是否返回字典
            if not return_dict:
                return (decoded,)
    
            # 返回解碼器輸出
            return DecoderOutput(sample=decoded)
    
        # 定義前向傳播方法
        def forward(
            self,
            sample: torch.Tensor,
            mask: Optional[torch.Tensor] = None,
            sample_posterior: bool = False,
            return_dict: bool = True,
            generator: Optional[torch.Generator] = None,
    # 定義一個函式,返回型別為解碼輸出或包含張量的元組
    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        # 函式文件字串,描述輸入引數的含義
        r"""
        Args:
            sample (`torch.Tensor`): 輸入樣本。
            mask (`torch.Tensor`, *optional*, defaults to `None`): 可選的修補掩碼。
            sample_posterior (`bool`, *optional*, defaults to `False`):
                是否從後驗分佈中取樣。
            return_dict (`bool`, *optional*, defaults to `True`):
                是否返回解碼輸出而不是普通元組。
        """
        # 將輸入樣本賦值給變數 x
        x = sample
        # 對輸入樣本進行編碼,獲取潛在分佈
        posterior = self.encode(x).latent_dist
        # 根據標誌決定是從後驗分佈中取樣還是使用眾數
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        # 解碼潛在變數 z,並獲取樣本
        dec = self.decode(z, generator, sample, mask).sample
    
        # 檢查是否返回字典格式的輸出
        if not return_dict:
            # 如果不返回字典,則返回解碼樣本的元組
            return (dec,)
    
        # 返回解碼輸出的例項
        return DecoderOutput(sample=dec)

.\diffusers\models\autoencoders\autoencoder_kl.py

# 版權宣告,表明此檔案的版權所有者及其所有權利
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根據 Apache 許可證 2.0 版本進行許可,宣告該檔案使用條件
# Licensed under the Apache License, Version 2.0 (the "License");
# 只能在符合許可證的情況下使用該檔案
# you may not use this file except in compliance with the License.
# 可以在此網址獲取許可證副本
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律或書面協議另有約定,軟體按 "原樣" 提供,且不附帶任何形式的保證
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不承擔任何明示或暗示的擔保或條件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 檢視許可證以瞭解特定許可權和限制
# See the License for the specific language governing permissions and
# limitations under the License.
# 匯入所需的型別定義
from typing import Dict, Optional, Tuple, Union

# 匯入 PyTorch 庫
import torch
import torch.nn as nn

# 匯入其他模組中的混合類和工具函式
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils.accelerate_utils import apply_forward_hook
# 匯入注意力處理器相關的類和常量
from ..attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    Attention,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
    FusedAttnProcessor2_0,
)
# 匯入模型輸出相關的類
from ..modeling_outputs import AutoencoderKLOutput
# 匯入模型工具類
from ..modeling_utils import ModelMixin
# 匯入變分自編碼器相關的類
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder

# 定義一個變分自編碼器模型,使用 KL 損失編碼影像到潛在空間並解碼
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    一個帶有 KL 損失的 VAE 模型,用於將影像編碼為潛在表示,並將潛在表示解碼為影像。

    該模型繼承自 [`ModelMixin`]。檢視超類文件以瞭解其實現的通用方法
    適用於所有模型(例如下載或儲存)。
    # 引數說明
        Parameters:
            # 輸入影像的通道數,預設為 3
            in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
            # 輸出的通道數,預設為 3
            out_channels (int,  *optional*, defaults to 3): Number of channels in the output.
            # 下采樣塊型別的元組,預設為 ("DownEncoderBlock2D",)
            down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
                Tuple of downsample block types.
            # 上取樣塊型別的元組,預設為 ("UpDecoderBlock2D",)
            up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
                Tuple of upsample block types.
            # 塊輸出通道的元組,預設為 (64,)
            block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
                Tuple of block output channels.
            # 使用的啟用函式,預設為 "silu"
            act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
            # 潛在空間的通道數,預設為 4
            latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
            # 樣本輸入大小,預設為 32
            sample_size (`int`, *optional*, defaults to 32): Sample input size.
            # 訓練潛在空間的分量標準差,預設為 0.18215
            scaling_factor (`float`, *optional*, defaults to 0.18215):
                The component-wise standard deviation of the trained latent space computed using the first batch of the
                training set. This is used to scale the latent space to have unit variance when training the diffusion
                model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
                diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
                / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
                Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
            # 是否強制使用 float32,以適應高解析度管道,預設為 True
            force_upcast (`bool`, *optional*, default to `True`):
                If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
                can be fine-tuned / trained to a lower range without losing too much precision in which case
                `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
            # 是否在 Encoder 和 Decoder 的 mid_block 中新增註意力塊,預設為 True
            mid_block_add_attention (`bool`, *optional*, default to `True`):
                If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
                mid_block will only have resnet blocks
        """
    
        # 支援梯度檢查點
        _supports_gradient_checkpointing = True
        # 不分割的模組列表
        _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
    
        # 註冊到配置中
        @register_to_config
    # 建構函式,初始化模型引數
    def __init__(
        # 輸入通道數,預設值為3
        self,
        in_channels: int = 3,
        # 輸出通道數,預設值為3
        out_channels: int = 3,
        # 下采樣塊的型別,預設為包含一個下采樣編碼塊的元組
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        # 上取樣塊的型別,預設為包含一個上取樣解碼塊的元組
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        # 每個塊的輸出通道數,預設為包含64的元組
        block_out_channels: Tuple[int] = (64,),
        # 每個塊的層數,預設為1
        layers_per_block: int = 1,
        # 啟用函式型別,預設為"silu"
        act_fn: str = "silu",
        # 潛在通道數,預設為4
        latent_channels: int = 4,
        # 歸一化的組數,預設為32
        norm_num_groups: int = 32,
        # 樣本大小,預設為32
        sample_size: int = 32,
        # 縮放因子,預設為0.18215
        scaling_factor: float = 0.18215,
        # 移位因子,預設為None(可選)
        shift_factor: Optional[float] = None,
        # 潛在變數的均值,預設為None(可選)
        latents_mean: Optional[Tuple[float]] = None,
        # 潛在變數的標準差,預設為None(可選)
        latents_std: Optional[Tuple[float]] = None,
        # 強制上溢位,預設為True
        force_upcast: float = True,
        # 使用量化卷積,預設為True
        use_quant_conv: bool = True,
        # 使用後量化卷積,預設為True
        use_post_quant_conv: bool = True,
        # 中間塊是否新增註意力機制,預設為True
        mid_block_add_attention: bool = True,
    ):
        # 呼叫父類建構函式
        super().__init__()

        # 將初始化引數傳遞給編碼器
        self.encoder = Encoder(
            # 輸入通道數
            in_channels=in_channels,
            # 輸出潛在通道數
            out_channels=latent_channels,
            # 下采樣塊的型別
            down_block_types=down_block_types,
            # 每個塊的輸出通道數
            block_out_channels=block_out_channels,
            # 每個塊的層數
            layers_per_block=layers_per_block,
            # 啟用函式型別
            act_fn=act_fn,
            # 歸一化的組數
            norm_num_groups=norm_num_groups,
            # 是否雙重潛在變數
            double_z=True,
            # 中間塊是否新增註意力機制
            mid_block_add_attention=mid_block_add_attention,
        )

        # 將初始化引數傳遞給解碼器
        self.decoder = Decoder(
            # 潛在通道數作為輸入
            in_channels=latent_channels,
            # 輸出通道數
            out_channels=out_channels,
            # 上取樣塊的型別
            up_block_types=up_block_types,
            # 每個塊的輸出通道數
            block_out_channels=block_out_channels,
            # 每個塊的層數
            layers_per_block=layers_per_block,
            # 歸一化的組數
            norm_num_groups=norm_num_groups,
            # 啟用函式型別
            act_fn=act_fn,
            # 中間塊是否新增註意力機制
            mid_block_add_attention=mid_block_add_attention,
        )

        # 根據是否使用量化卷積初始化卷積層
        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
        # 根據是否使用後量化卷積初始化卷積層
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None

        # 是否使用切片,初始值為False
        self.use_slicing = False
        # 是否使用平鋪,初始值為False
        self.use_tiling = False

        # 僅在啟用VAE平鋪時相關
        # 平鋪取樣的最小大小設定為配置中的樣本大小
        self.tile_sample_min_size = self.config.sample_size
        # 獲取樣本大小,如果是列表或元組則取第一個元素
        sample_size = (
            self.config.sample_size[0]
            if isinstance(self.config.sample_size, (list, tuple))
            else self.config.sample_size
        )
        # 計算平鋪潛在變數的最小大小
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
        # 設定平鋪重疊因子
        self.tile_overlap_factor = 0.25

    # 設定梯度檢查點的函式
    def _set_gradient_checkpointing(self, module, value=False):
        # 如果模組是編碼器或解碼器,設定梯度檢查點標誌
        if isinstance(module, (Encoder, Decoder)):
            module.gradient_checkpointing = value

    # 啟用平鋪的函式
    def enable_tiling(self, use_tiling: bool = True):
        r"""
        啟用平鋪VAE解碼。當此選項啟用時,VAE將輸入張量拆分為平鋪塊,以分步計算解碼和編碼。
        這對於節省大量記憶體並允許處理更大影像非常有用。
        """
        # 設定使用平鋪的標誌
        self.use_tiling = use_tiling
    # 定義一個方法用於禁用瓷磚 VAE 解碼
    def disable_tiling(self):
        r""" 
        禁用瓷磚 VAE 解碼。如果之前啟用了 `enable_tiling`,此方法將恢復到一次性解碼計算。
        """
        # 呼叫設定方法,將瓷磚解碼狀態設定為 False
        self.enable_tiling(False)
    
    # 定義一個方法用於啟用切片 VAE 解碼
    def enable_slicing(self):
        r""" 
        啟用切片 VAE 解碼。當此選項啟用時,VAE 將把輸入張量分割成切片,以
        多次計算解碼。這有助於節省一些記憶體並允許更大的批次大小。
        """
        # 設定使用切片的標誌為 True
        self.use_slicing = True
    
    # 定義一個方法用於禁用切片 VAE 解碼
    def disable_slicing(self):
        r""" 
        禁用切片 VAE 解碼。如果之前啟用了 `enable_slicing`,此方法將恢復到一次性解碼計算。
        """
        # 設定使用切片的標誌為 False
        self.use_slicing = False
    
    # 定義一個屬性,用於返回注意力處理器
    @property
    # 複製自 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r""" 
        返回:
            `dict` 的注意力處理器:一個字典,包含模型中所有注意力處理器,按權重名稱索引。
        """
        # 建立一個空字典用於儲存處理器
        processors = {}
    
        # 定義遞迴函式用於新增處理器
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 如果模組具有獲取處理器的方法,則將其新增到字典中
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
    
            # 遍歷模組的所有子模組,並遞迴呼叫處理器新增函式
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
    
            return processors
    
        # 遍歷當前物件的所有子模組,並呼叫遞迴函式
        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)
    
        # 返回所有收集到的處理器
        return processors
    
    # 複製自 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    # 設定用於計算注意力的處理器
        def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
            r"""
            設定用於計算注意力的處理器。
    
            引數:
                processor(`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                    已例項化的處理器類或處理器類的字典,將作為**所有** `Attention` 層的處理器設定。
    
                    如果 `processor` 是字典,則鍵需要定義對應的交叉注意力處理器的路徑。當設定可訓練的注意力處理器時,強烈推薦使用這種方式。
    
            """
            # 獲取當前注意力處理器的數量
            count = len(self.attn_processors.keys())
    
            # 檢查傳入的處理器字典長度是否與注意力層數量匹配
            if isinstance(processor, dict) and len(processor) != count:
                raise ValueError(
                    f"傳入的處理器字典數量 {len(processor)} 與注意力層數量 {count} 不匹配。請確保傳入 {count} 個處理器類。"
                )
    
            # 遞迴設定注意力處理器的函式
            def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
                # 如果模組有設定處理器的方法
                if hasattr(module, "set_processor"):
                    # 如果處理器不是字典,直接設定
                    if not isinstance(processor, dict):
                        module.set_processor(processor)
                    else:
                        # 從字典中彈出對應的處理器並設定
                        module.set_processor(processor.pop(f"{name}.processor"))
    
                # 遍歷模組的子模組,遞迴呼叫
                for sub_name, child in module.named_children():
                    fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
    
            # 對當前例項的每個子模組呼叫遞迴函式
            for name, module in self.named_children():
                fn_recursive_attn_processor(name, module, processor)
    
        # 從 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor 複製
        def set_default_attn_processor(self):
            """
            禁用自定義注意力處理器,並設定預設的注意力實現。
            """
            # 檢查所有處理器是否屬於新增的 KV 注意力處理器
            if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
                processor = AttnAddedKVProcessor()
            # 檢查所有處理器是否屬於交叉注意力處理器
            elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
                processor = AttnProcessor()
            else:
                raise ValueError(
                    f"當注意力處理器的型別為 {next(iter(self.attn_processors.values()))} 時,無法呼叫 `set_default_attn_processor`"
                )
    
            # 呼叫設定處理器的方法
            self.set_attn_processor(processor)
    
        # 應用前向鉤子
        @apply_forward_hook
        def encode(
            self, x: torch.Tensor, return_dict: bool = True
    # 定義返回型別為 AutoencoderKLOutput 或者 DiagonalGaussianDistribution 的函式
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
            """
            編碼一批影像為潛在表示。
    
            引數:
                x (`torch.Tensor`): 輸入影像的批次。
                return_dict (`bool`, *可選*, 預設為 `True`):
                    是否返回 [`~models.autoencoder_kl.AutoencoderKLOutput`] 而非簡單元組。
    
            返回:
                    編碼影像的潛在表示。如果 `return_dict` 為 True,則返回一個
                    [`~models.autoencoder_kl.AutoencoderKLOutput`],否則返回一個普通的 `tuple`。
            """
            # 檢查是否使用平鋪,並且輸入尺寸超過最小平鋪尺寸
            if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
                # 使用平鋪編碼方法處理輸入
                return self.tiled_encode(x, return_dict=return_dict)
    
            # 檢查是否使用切片,並且輸入批次大於1
            if self.use_slicing and x.shape[0] > 1:
                # 對輸入的每個切片進行編碼,並將結果連線起來
                encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
                h = torch.cat(encoded_slices)
            else:
                # 直接編碼整個輸入
                h = self.encoder(x)
    
            # 檢查量化卷積是否存在
            if self.quant_conv is not None:
                # 使用量化卷積處理編碼後的結果
                moments = self.quant_conv(h)
            else:
                # 如果不存在,直接使用編碼結果
                moments = h
    
            # 建立對角高斯分佈的後驗
            posterior = DiagonalGaussianDistribution(moments)
    
            # 如果不返回字典,返回後驗分佈的元組
            if not return_dict:
                return (posterior,)
    
            # 返回 AutoencoderKLOutput 物件,包含潛在分佈
            return AutoencoderKLOutput(latent_dist=posterior)
    
        # 定義解碼函式,返回型別為 DecoderOutput 或 torch.Tensor
        def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
            # 檢查是否使用平鋪,並且潛在向量尺寸超過最小平鋪尺寸
            if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
                # 使用平鋪解碼方法處理潛在向量
                return self.tiled_decode(z, return_dict=return_dict)
    
            # 檢查後量化卷積是否存在
            if self.post_quant_conv is not None:
                # 使用後量化卷積處理潛在向量
                z = self.post_quant_conv(z)
    
            # 透過解碼器解碼潛在向量
            dec = self.decoder(z)
    
            # 如果不返回字典,返回解碼結果的元組
            if not return_dict:
                return (dec,)
    
            # 返回解碼結果的 DecoderOutput 物件
            return DecoderOutput(sample=dec)
    
        # 應用前向鉤子的解碼函式
        @apply_forward_hook
        def decode(
            self, z: torch.FloatTensor, return_dict: bool = True, generator=None
        ) -> Union[DecoderOutput, torch.FloatTensor]:
            """
            解碼一批影像。
    
            引數:
                z (`torch.Tensor`): 輸入潛在向量的批次。
                return_dict (`bool`, *可選*, 預設為 `True`):
                    是否返回 [`~models.vae.DecoderOutput`] 而非簡單元組。
    
            返回:
                [`~models.vae.DecoderOutput`] 或 `tuple`:
                    如果 return_dict 為 True,返回 [`~models.vae.DecoderOutput`],否則返回普通 `tuple`。
            """
            # 檢查是否使用切片,並且潛在向量批次大於1
            if self.use_slicing and z.shape[0] > 1:
                # 對每個切片進行解碼,並將結果連線起來
                decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
                decoded = torch.cat(decoded_slices)
            else:
                # 直接解碼整個潛在向量
                decoded = self._decode(z).sample
    
            # 如果不返回字典,返回解碼結果的元組
            if not return_dict:
                return (decoded,)
    
            # 返回解碼結果的 DecoderOutput 物件
            return DecoderOutput(sample=decoded)
    # 定義一個垂直混合函式,接受兩個張量和混合範圍,返回混合後的張量
        def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
            # 計算實際的混合範圍,確保不超過輸入張量的尺寸
            blend_extent = min(a.shape[2], b.shape[2], blend_extent)
            # 逐行進行混合操作,根據當前行的比例計算混合值
            for y in range(blend_extent):
                b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
            # 返回混合後的張量
            return b
    
    # 定義一個水平混合函式,接受兩個張量和混合範圍,返回混合後的張量
        def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
            # 計算實際的混合範圍,確保不超過輸入張量的尺寸
            blend_extent = min(a.shape[3], b.shape[3], blend_extent)
            # 逐列進行混合操作,根據當前列的比例計算混合值
            for x in range(blend_extent):
                b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
            # 返回混合後的張量
            return b
    # 定義一個函式,用於透過平鋪編碼器對影像批次進行編碼
    def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
        # 文件字串,描述該函式的用途及引數
        r"""Encode a batch of images using a tiled encoder.
    
        當這個選項啟用時,VAE 會將輸入張量分割成多個小塊以進行編碼
        步驟。這對於保持記憶體使用量恆定非常有用。平鋪編碼的最終結果與非平鋪編碼不同,
        因為每個小塊使用不同的編碼器。為了避免平鋪偽影,小塊之間會重疊並混合在一起
        形成平滑的輸出。你可能仍然會看到與小塊大小相關的變化,
        但這些變化應該不那麼明顯。
    
        引數:
            x (`torch.Tensor`): 輸入影像批次。
            return_dict (`bool`, *可選*, 預設為 `True`):
                是否返回一個 [`~models.autoencoder_kl.AutoencoderKLOutput`] 而不是一個普通元組。
    
        返回:
            [`~models.autoencoder_kl.AutoencoderKLOutput`] 或 `tuple`:
                如果 return_dict 為 True,則返回 [`~models.autoencoder_kl.AutoencoderKLOutput`],
                否則返回普通元組。
        """
        # 計算重疊區域的大小
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        # 計算混合的範圍
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        # 計算行限制,確保不會超出範圍
        row_limit = self.tile_latent_min_size - blend_extent
    
        # 初始化一個列表以儲存每一行的編碼結果
        rows = []
        # 遍歷輸入張量的高度,以重疊的方式進行切片
        for i in range(0, x.shape[2], overlap_size):
            # 初始化當前行的編碼結果列表
            row = []
            # 遍歷輸入張量的寬度,以重疊的方式進行切片
            for j in range(0, x.shape[3], overlap_size):
                # 切割當前小塊
                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                # 對當前小塊進行編碼
                tile = self.encoder(tile)
                # 如果配置使用量化卷積,則對小塊進行量化處理
                if self.config.use_quant_conv:
                    tile = self.quant_conv(tile)
                # 將編碼後的小塊新增到當前行中
                row.append(tile)
            # 將當前行的結果新增到 rows 列表中
            rows.append(row)
        # 初始化一個列表以儲存最終的結果行
        result_rows = []
        # 遍歷所有行以進行混合處理
        for i, row in enumerate(rows):
            result_row = []
            # 遍歷當前行的每個小塊
            for j, tile in enumerate(row):
                # 將上方小塊與當前小塊混合
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                # 將左側小塊與當前小塊混合
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                # 將混合後的小塊裁剪至指定大小並新增到結果行
                result_row.append(tile[:, :, :row_limit, :row_limit])
            # 將當前行的結果合併並新增到最終結果中
            result_rows.append(torch.cat(result_row, dim=3))
    
        # 將所有結果行合併為一個張量
        moments = torch.cat(result_rows, dim=2)
        # 建立一個對角高斯分佈以表示後驗分佈
        posterior = DiagonalGaussianDistribution(moments)
    
        # 如果不返回字典,則返回後驗分佈的元組
        if not return_dict:
            return (posterior,)
    
        # 返回包含後驗分佈的 AutoencoderKLOutput 物件
        return AutoencoderKLOutput(latent_dist=posterior)
    # 定義一個方法,用於解碼一批影像,使用平鋪解碼器
    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        使用平鋪解碼器解碼一批影像。

        引數:
            z (`torch.Tensor`): 輸入的潛在向量批次。
            return_dict (`bool`, *可選*, 預設值為 `True`):
                是否返回一個 [`~models.vae.DecoderOutput`] 而不是普通的元組。

        返回:
            [`~models.vae.DecoderOutput`] 或 `tuple`:
                如果 return_dict 為 True,則返回一個 [`~models.vae.DecoderOutput`],
                否則返回普通的 `tuple`。
        """
        # 計算重疊區域的大小
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        # 計算混合區域的範圍
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        # 計算每行的限制大小
        row_limit = self.tile_sample_min_size - blend_extent

        # 將 z 分割成重疊的 64x64 瓦片,並分別解碼
        # 瓦片之間有重疊,以避免瓦片之間的接縫
        rows = []
        # 遍歷潛在向量 z 的高度,按重疊大小步進
        for i in range(0, z.shape[2], overlap_size):
            row = []  # 儲存當前行的解碼結果
            # 遍歷潛在向量 z 的寬度,按重疊大小步進
            for j in range(0, z.shape[3], overlap_size):
                # 從 z 中提取當前瓦片
                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
                # 如果配置中啟用了後量化卷積,則對瓦片進行處理
                if self.config.use_post_quant_conv:
                    tile = self.post_quant_conv(tile)
                # 解碼當前瓦片
                decoded = self.decoder(tile)
                # 將解碼結果新增到當前行中
                row.append(decoded)
            # 將當前行新增到總行列表中
            rows.append(row)
        result_rows = []  # 儲存最終結果的行
        # 遍歷解碼的每一行
        for i, row in enumerate(rows):
            result_row = []  # 儲存當前結果行
            # 遍歷當前行的瓦片
            for j, tile in enumerate(row):
                # 將上方的瓦片與當前瓦片混合
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                # 將左側的瓦片與當前瓦片混合
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                # 將當前瓦片的結果裁剪到限制大小並新增到結果行
                result_row.append(tile[:, :, :row_limit, :row_limit])
            # 將結果行中的瓦片沿著寬度拼接
            result_rows.append(torch.cat(result_row, dim=3))

        # 將所有結果行沿著高度拼接
        dec = torch.cat(result_rows, dim=2)
        # 如果不返回字典,則返回解碼結果的元組
        if not return_dict:
            return (dec,)

        # 返回解碼結果的 DecoderOutput 物件
        return DecoderOutput(sample=dec)

    # 定義前向傳播方法
    def forward(
        # 輸入樣本的張量
        sample: torch.Tensor,
        # 是否對樣本進行後驗取樣
        sample_posterior: bool = False,
        # 是否返回字典形式的結果
        return_dict: bool = True,
        # 隨機數生成器(可選)
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.Tensor]:
        r"""  # 函式的返回型別是 DecoderOutput 或 torch.Tensor 的聯合型別
        Args:  # 引數說明
            sample (`torch.Tensor`): Input sample.  # 輸入樣本,型別為 torch.Tensor
            sample_posterior (`bool`, *optional*, defaults to `False`):  # 是否從後驗分佈進行取樣,預設為 False
                Whether to sample from the posterior.  # 描述引數的用途
            return_dict (`bool`, *optional*, defaults to `True`):  # 是否返回 DecoderOutput 而不是普通元組,預設為 True
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.  # 描述引數的用途
        """
        x = sample  # 將輸入樣本賦值給 x
        posterior = self.encode(x).latent_dist  # 對輸入樣本進行編碼,並獲取其後驗分佈
        if sample_posterior:  # 檢查是否需要從後驗分佈中取樣
            z = posterior.sample(generator=generator)  # 從後驗分佈中進行取樣
        else:  # 否則
            z = posterior.mode()  # 取後驗分佈的眾數
        dec = self.decode(z).sample  # 解碼 z 並獲取樣本

        if not return_dict:  # 如果不需要返回字典
            return (dec,)  # 返回樣本作為元組

        return DecoderOutput(sample=dec)  # 返回 DecoderOutput 物件,包含解碼後的樣本

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
    def fuse_qkv_projections(self):  # 定義融合 QKV 投影的方法
        """  # 方法的文件字串
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)  # 啟用融合的 QKV 投影,適用於自注意力模組
        are fused. For cross-attention modules, key and value projection matrices are fused.  # 適用於交叉注意力模組

        <Tip warning={true}>  # 提示標籤,表示此 API 為實驗性
        This API is 🧪 experimental.  # 提示內容
        </Tip>
        """
        self.original_attn_processors = None  # 初始化原始注意力處理器為 None

        for _, attn_processor in self.attn_processors.items():  # 遍歷當前的注意力處理器
            if "Added" in str(attn_processor.__class__.__name__):  # 檢查處理器類名中是否包含 "Added"
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")  # 丟擲異常,提示不支援融合操作

        self.original_attn_processors = self.attn_processors  # 儲存當前的注意力處理器

        for module in self.modules():  # 遍歷模型中的所有模組
            if isinstance(module, Attention):  # 如果模組是 Attention 型別
                module.fuse_projections(fuse=True)  # 融合其投影

        self.set_attn_processor(FusedAttnProcessor2_0())  # 設定融合的注意力處理器

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):  # 定義取消融合 QKV 投影的方法
        """Disables the fused QKV projection if enabled.  # 如果已啟用,禁用融合的 QKV 投影

        <Tip warning={true}>  # 提示標籤,表示此 API 為實驗性
        This API is 🧪 experimental.  # 提示內容
        </Tip>

        """
        if self.original_attn_processors is not None:  # 如果原始注意力處理器不為 None
            self.set_attn_processor(self.original_attn_processors)  # 恢復原始注意力處理器

.\diffusers\models\autoencoders\autoencoder_kl_cogvideox.py

# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union  # 匯入用於型別註解的 Optional、Tuple 和 Union

import numpy as np  # 匯入 NumPy 庫,用於陣列和矩陣操作
import torch  # 匯入 PyTorch 庫,用於張量操作和深度學習
import torch.nn as nn  # 匯入 PyTorch 的神經網路模組
import torch.nn.functional as F  # 匯入 PyTorch 的函式式 API

from ...configuration_utils import ConfigMixin, register_to_config  # 從配置工具中匯入配置混合類和註冊配置的方法
from ...loaders.single_file_model import FromOriginalModelMixin  # 匯入處理單檔案模型的混合類
from ...utils import logging  # 匯入日誌工具
from ...utils.accelerate_utils import apply_forward_hook  # 匯入用於應用前向鉤子的工具
from ..activations import get_activation  # 匯入獲取啟用函式的方法
from ..downsampling import CogVideoXDownsample3D  # 匯入3D下采樣模組
from ..modeling_outputs import AutoencoderKLOutput  # 匯入自編碼器KL輸出的模組
from ..modeling_utils import ModelMixin  # 匯入模型混合類
from ..upsampling import CogVideoXUpsample3D  # 匯入3D上取樣模組
from .vae import DecoderOutput, DiagonalGaussianDistribution  # 匯入變分自編碼器相關輸出類和分佈類


logger = logging.get_logger(__name__)  # 獲取當前模組的日誌記錄器,禁用 pylint 警告

class CogVideoXSafeConv3d(nn.Conv3d):  # 定義一個繼承自 nn.Conv3d 的類,代表安全的3D卷積層
    r"""A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
    """  # 類文件字串,描述該卷積層的功能

    def forward(self, input: torch.Tensor) -> torch.Tensor:  # 定義前向傳播方法,接收一個張量並返回一個張量
        memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3  # 計算輸入張量的記憶體佔用(GB)

        # Set to 2GB, suitable for CuDNN
        if memory_count > 2:  # 如果記憶體佔用超過2GB
            kernel_size = self.kernel_size[0]  # 獲取卷積核的大小
            part_num = int(memory_count / 2) + 1  # 計算需要拆分的部分數量
            input_chunks = torch.chunk(input, part_num, dim=2)  # 將輸入張量沿著深度維度拆分成多個塊

            if kernel_size > 1:  # 如果卷積核大小大於1
                input_chunks = [input_chunks[0]] + [  # 將第一個塊保留並處理後續塊
                    torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)  # 將前一個塊和當前塊拼接
                    for i in range(1, len(input_chunks))  # 遍歷後續塊
                ]

            output_chunks = []  # 初始化輸出塊的列表
            for input_chunk in input_chunks:  # 遍歷所有輸入塊
                output_chunks.append(super().forward(input_chunk))  # 使用父類的前向方法處理每個輸入塊並儲存結果
            output = torch.cat(output_chunks, dim=2)  # 將所有輸出塊沿著深度維度拼接
            return output  # 返回拼接後的輸出
        else:  # 如果記憶體佔用不超過2GB
            return super().forward(input)  # 直接使用父類的前向方法處理輸入張量


class CogVideoXCausalConv3d(nn.Module):  # 定義一個3D因果卷積層的類,繼承自 nn.Module
    r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
    """  # 類文件字串,描述該因果卷積層的功能
    # 引數說明文件
    Args:
        in_channels (`int`): 輸入張量的通道數。
        out_channels (`int`): 卷積生成的輸出通道數。
        kernel_size (`int` or `Tuple[int, int, int]`): 卷積核的大小。
        stride (`int`, defaults to `1`): 卷積的步幅。
        dilation (`int`, defaults to `1`): 卷積的擴張率。
        pad_mode (`str`, defaults to `"constant"`): 填充模式。
    """

    # 初始化方法
    def __init__(
        # 輸入通道數
        self,
        in_channels: int,
        # 輸出通道數
        out_channels: int,
        # 卷積核大小
        kernel_size: Union[int, Tuple[int, int, int]],
        # 步幅,預設為1
        stride: int = 1,
        # 擴張率,預設為1
        dilation: int = 1,
        # 填充模式,預設為"constant"
        pad_mode: str = "constant",
    ):
        # 呼叫父類建構函式
        super().__init__()

        # 如果卷積核大小是整數,則擴充套件為三維元組
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3

        # 解包卷積核的時間、高度和寬度尺寸
        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        # 設定填充模式
        self.pad_mode = pad_mode
        # 計算時間維度的填充量
        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
        # 計算高度和寬度的填充量
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        # 儲存填充量
        self.height_pad = height_pad
        self.width_pad = width_pad
        self.time_pad = time_pad
        # 設定因果填充引數
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        # 設定時間維度索引
        self.temporal_dim = 2
        # 儲存時間卷積核大小
        self.time_kernel_size = time_kernel_size

        # 將步幅和擴張轉換為三維元組
        stride = (stride, 1, 1)
        dilation = (dilation, 1, 1)
        # 建立三維卷積層物件
        self.conv = CogVideoXSafeConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
        )

        # 初始化卷積快取為None
        self.conv_cache = None

    # 假上下文並行前向傳播方法
    def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # 獲取時間卷積核大小
        kernel_size = self.time_kernel_size
        # 如果卷積核大小大於1,進行快取處理
        if kernel_size > 1:
            # 使用快取的輸入,或者用當前輸入的首個切片填充
            cached_inputs = (
                [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
            )
            # 將快取輸入和當前輸入連線在一起
            inputs = torch.cat(cached_inputs + [inputs], dim=2)
        # 返回處理後的輸入
        return inputs

    # 清除假上下文並行快取的方法
    def _clear_fake_context_parallel_cache(self):
        # 刪除卷積快取
        del self.conv_cache
        # 將卷積快取設定為None
        self.conv_cache = None

    # 前向傳播方法
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # 進行假上下文並行前向傳播
        inputs = self.fake_context_parallel_forward(inputs)

        # 清除卷積快取
        self._clear_fake_context_parallel_cache()
        # 注意:可以將這些資料移動到CPU以降低記憶體使用,但目前僅幾百兆,不考慮
        # 快取輸入的最後幾幀資料
        self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()

        # 設定二維填充引數
        padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
        # 對輸入進行填充
        inputs = F.pad(inputs, padding_2d, mode="constant", value=0)

        # 透過卷積層處理輸入
        output = self.conv(inputs)
        # 返回卷積結果
        return output
# 定義一個用於空間條件歸一化的3D影片處理模型
class CogVideoXSpatialNorm3D(nn.Module):
    r"""
    根據 https://arxiv.org/abs/2209.09002 中定義的空間條件歸一化,專門針對3D影片資料的實現。

    使用 CogVideoXSafeConv3d 替代 nn.Conv3d,以避免在 CogVideoX 模型中出現記憶體不足問題。

    引數:
        f_channels (`int`):
            輸入到組歸一化層的通道數,以及空間歸一化層的輸出通道數。
        zq_channels (`int`):
            論文中描述的量化向量的通道數。
        groups (`int`):
            用於將通道分組的組數。
    """

    # 初始化模型
    def __init__(
        self,
        f_channels: int,
        zq_channels: int,
        groups: int = 32,
    ):
        # 呼叫父類建構函式
        super().__init__()
        # 建立組歸一化層
        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
        # 建立因果卷積層用於Y通道
        self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
        # 建立因果卷積層用於B通道
        self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)

    # 前向傳播定義
    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
        # 檢查輸入形狀,確保處理的邏輯正確
        if f.shape[2] > 1 and f.shape[2] % 2 == 1:
            # 分離第一個幀和其餘幀
            f_first, f_rest = f[:, :, :1], f[:, :, 1:]
            # 獲取各部分的大小
            f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
            # 分離量化向量
            z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
            # 進行插值調整大小
            z_first = F.interpolate(z_first, size=f_first_size)
            z_rest = F.interpolate(z_rest, size=f_rest_size)
            # 合併調整後的量化向量
            zq = torch.cat([z_first, z_rest], dim=2)
        else:
            # 對量化向量進行插值以匹配輸入形狀
            zq = F.interpolate(zq, size=f.shape[-3:])

        # 對輸入進行歸一化
        norm_f = self.norm_layer(f)
        # 計算新的輸出
        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        # 返回處理後的結果
        return new_f


# 定義用於CogVideoX模型的3D ResNet塊
class CogVideoXResnetBlock3D(nn.Module):
    r"""
    CogVideoX模型中使用的3D ResNet塊。

    引數:
        in_channels (`int`):
            輸入通道數。
        out_channels (`int`, *可選*):
            輸出通道數。如果為 None,預設與 `in_channels` 相同。
        dropout (`float`, 預設值為 `0.0`):
            Dropout比率。
        temb_channels (`int`, 預設值為 `512`):
            時間嵌入通道數。
        groups (`int`, 預設值為 `32`):
            用於將通道分組的組數。
        eps (`float`, 預設值為 `1e-6`):
            歸一化層的 epsilon 值。
        non_linearity (`str`, 預設值為 `"swish"`):
            使用的啟用函式。
        conv_shortcut (bool, 預設值為 `False`):
            是否使用卷積快捷連線。
        spatial_norm_dim (`int`, *可選*):
            如果使用空間歸一化而非組歸一化時的維度。
        pad_mode (str, 預設值為 `"first"`):
            填充模式。
    """
    # 初始化方法,設定神經網路層的引數
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        out_channels: Optional[int] = None,  # 輸出通道數(可選,預設為 None)
        dropout: float = 0.0,  # 丟棄率(預設為 0.0)
        temb_channels: int = 512,  # 時間嵌入通道數(預設為 512)
        groups: int = 32,  # 分組數(預設為 32)
        eps: float = 1e-6,  # 為數值穩定性而新增的小常數(預設為 1e-6)
        non_linearity: str = "swish",  # 非線性啟用函式的型別(預設為 "swish")
        conv_shortcut: bool = False,  # 是否使用卷積快捷連線(預設為 False)
        spatial_norm_dim: Optional[int] = None,  # 空間歸一化維度(可選)
        pad_mode: str = "first",  # 填充模式(預設為 "first")
    ):
        # 呼叫父類初始化方法
        super().__init__()

        # 如果未提供輸出通道數,則將其設定為輸入通道數
        out_channels = out_channels or in_channels

        # 儲存輸入和輸出通道數
        self.in_channels = in_channels
        self.out_channels = out_channels
        # 獲取指定的非線性啟用函式
        self.nonlinearity = get_activation(non_linearity)
        # 儲存是否使用卷積快捷連線的標誌
        self.use_conv_shortcut = conv_shortcut

        # 根據空間歸一化維度選擇歸一化方法
        if spatial_norm_dim is None:
            # 建立第一個歸一化層,使用分組歸一化
            self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
            # 建立第二個歸一化層,使用分組歸一化
            self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
        else:
            # 建立第一個歸一化層,使用空間歸一化
            self.norm1 = CogVideoXSpatialNorm3D(
                f_channels=in_channels,
                zq_channels=spatial_norm_dim,
                groups=groups,
            )
            # 建立第二個歸一化層,使用空間歸一化
            self.norm2 = CogVideoXSpatialNorm3D(
                f_channels=out_channels,
                zq_channels=spatial_norm_dim,
                groups=groups,
            )

        # 建立第一個卷積層
        self.conv1 = CogVideoXCausalConv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
        )

        # 如果時間嵌入通道數大於 0,則建立時間嵌入投影層
        if temb_channels > 0:
            self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)

        # 建立丟棄層
        self.dropout = nn.Dropout(dropout)
        # 建立第二個卷積層
        self.conv2 = CogVideoXCausalConv3d(
            in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
        )

        # 如果輸入通道數與輸出通道數不相同,則建立快捷連線
        if self.in_channels != self.out_channels:
            # 如果使用卷積快捷連線
            if self.use_conv_shortcut:
                # 建立卷積快捷連線層
                self.conv_shortcut = CogVideoXCausalConv3d(
                    in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
                )
            else:
                # 建立安全卷積快捷連線層
                self.conv_shortcut = CogVideoXSafeConv3d(
                    in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
                )

    # 前向傳播方法,定義模型如何處理輸入資料
    def forward(
        self,
        inputs: torch.Tensor,  # 輸入張量
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
        zq: Optional[torch.Tensor] = None,  # 可選的 zq 張量
    # 定義函式的返回型別為 torch.Tensor
        ) -> torch.Tensor:
            # 初始化隱藏狀態為輸入
            hidden_states = inputs
    
            # 如果 zq 不為 None,則對隱藏狀態應用 norm1 歸一化
            if zq is not None:
                hidden_states = self.norm1(hidden_states, zq)
            # 否則僅對隱藏狀態應用 norm1 歸一化
            else:
                hidden_states = self.norm1(hidden_states)
    
            # 應用非線性啟用函式
            hidden_states = self.nonlinearity(hidden_states)
            # 透過卷積層 conv1 處理隱藏狀態
            hidden_states = self.conv1(hidden_states)
    
            # 如果 temb 不為 None,則將其透過投影與隱藏狀態相加
            if temb is not None:
                hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
    
            # 如果 zq 不為 None,則對隱藏狀態應用 norm2 歸一化
            if zq is not None:
                hidden_states = self.norm2(hidden_states, zq)
            # 否則僅對隱藏狀態應用 norm2 歸一化
            else:
                hidden_states = self.norm2(hidden_states)
    
            # 應用非線性啟用函式
            hidden_states = self.nonlinearity(hidden_states)
            # 進行 dropout 操作以防止過擬合
            hidden_states = self.dropout(hidden_states)
            # 透過卷積層 conv2 處理隱藏狀態
            hidden_states = self.conv2(hidden_states)
    
            # 如果輸入通道數與輸出通道數不相等,應用卷積快捷連線
            if self.in_channels != self.out_channels:
                inputs = self.conv_shortcut(inputs)
    
            # 將輸入與隱藏狀態相加以形成最終的隱藏狀態
            hidden_states = hidden_states + inputs
            # 返回最終的隱藏狀態
            return hidden_states
# 定義一個用於 CogVideoX 模型的下采樣模組
class CogVideoXDownBlock3D(nn.Module):
    r"""
    CogVideoX 模型中使用的下采樣塊。

    Args:
        in_channels (`int`):
            輸入通道數。
        out_channels (`int`, *可選*):
            輸出通道數。如果為 None,預設為 `in_channels`。
        temb_channels (`int`, defaults to `512`):
            時間嵌入通道數。
        num_layers (`int`, defaults to `1`):
            ResNet 層數。
        dropout (`float`, defaults to `0.0`):
            Dropout 率。
        resnet_eps (`float`, defaults to `1e-6`):
            歸一化層的 epsilon 值。
        resnet_act_fn (`str`, defaults to `"swish"`):
            使用的啟用函式。
        resnet_groups (`int`, defaults to `32`):
            用於組歸一化的通道組數。
        add_downsample (`bool`, defaults to `True`):
            是否使用下采樣層。如果不使用,輸出維度將與輸入維度相同。
        compress_time (`bool`, defaults to `False`):
            是否在時間維度上進行下采樣。
        pad_mode (str, defaults to `"first"`):
            填充模式。
    """

    # 支援梯度檢查點
    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        add_downsample: bool = True,
        downsample_padding: int = 0,
        compress_time: bool = False,
        pad_mode: str = "first",
    ):
        # 初始化父類 nn.Module
        super().__init__()

        resnets = []  # 建立一個空列表以儲存 ResNet 層
        for i in range(num_layers):
            # 確定當前層的輸入通道數
            in_channel = in_channels if i == 0 else out_channels
            # 將 ResNet 層新增到列表中
            resnets.append(
                CogVideoXResnetBlock3D(
                    in_channels=in_channel,
                    out_channels=out_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    pad_mode=pad_mode,
                )
            )

        # 將 ResNet 層列表轉換為 nn.ModuleList 以便於管理
        self.resnets = nn.ModuleList(resnets)
        self.downsamplers = None  # 初始化下采樣層為 None

        # 如果需要下采樣,則新增下采樣層
        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    CogVideoXDownsample3D(
                        out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
                    )
                ]
            )

        self.gradient_checkpointing = False  # 初始化梯度檢查點為 False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:  # 指定返回型別為 torch.Tensor
        for resnet in self.resnets:  # 遍歷每個 ResNet 模組
            if self.training and self.gradient_checkpointing:  # 檢查是否在訓練模式且啟用梯度檢查點

                def create_custom_forward(module):  # 定義建立自定義前向傳播的函式
                    def create_forward(*inputs):  # 定義前向傳播的具體實現
                        return module(*inputs)  # 呼叫傳入的模組進行前向傳播

                    return create_forward  # 返回前向傳播函式

                hidden_states = torch.utils.checkpoint.checkpoint(  # 使用檢查點機制計算前向傳播
                    create_custom_forward(resnet), hidden_states, temb, zq  # 呼叫自定義前向函式並傳入引數
                )
            else:  # 如果不滿足前面條件
                hidden_states = resnet(hidden_states, temb, zq)  # 直接透過 ResNet 模組進行前向傳播

        if self.downsamplers is not None:  # 檢查是否存在下采樣模組
            for downsampler in self.downsamplers:  # 遍歷每個下采樣模組
                hidden_states = downsampler(hidden_states)  # 透過下采樣模組處理隱藏狀態

        return hidden_states  # 返回處理後的隱藏狀態
# 定義 CogVideoX 模型中的一箇中間模組,繼承自 nn.Module
class CogVideoXMidBlock3D(nn.Module):
    r"""
    CogVideoX 模型中使用的中間塊。

    引數:
        in_channels (`int`):
            輸入通道的數量。
        temb_channels (`int`, defaults to `512`):
            時間嵌入通道的數量。
        dropout (`float`, defaults to `0.0`):
            dropout 比率。
        num_layers (`int`, defaults to `1`):
            ResNet 層的數量。
        resnet_eps (`float`, defaults to `1e-6`):
            歸一化層的 epsilon 值。
        resnet_act_fn (`str`, defaults to `"swish"`):
            要使用的啟用函式。
        resnet_groups (`int`, defaults to `32`):
            用於組歸一化的通道分組數量。
        spatial_norm_dim (`int`, *optional*):
            如果使用空間歸一化而不是組歸一化,則使用的維度。
        pad_mode (str, defaults to `"first"`):
            填充模式。
    """

    # 指示是否支援梯度檢查點
    _supports_gradient_checkpointing = True

    # 初始化方法
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        temb_channels: int,  # 時間嵌入通道數
        dropout: float = 0.0,  # dropout 比率
        num_layers: int = 1,  # ResNet 層數
        resnet_eps: float = 1e-6,  # 歸一化層的 epsilon 值
        resnet_act_fn: str = "swish",  # 啟用函式
        resnet_groups: int = 32,  # 組歸一化的組數
        spatial_norm_dim: Optional[int] = None,  # 空間歸一化的維度
        pad_mode: str = "first",  # 填充模式
    ):
        super().__init__()  # 呼叫父類的初始化方法

        resnets = []  # 初始化一個空列表以儲存 ResNet 層
        for _ in range(num_layers):  # 根據層數迴圈
            resnets.append(  # 將新的 ResNet 層新增到列表中
                CogVideoXResnetBlock3D(  # 例項化 ResNet 層
                    in_channels=in_channels,  # 輸入通道數
                    out_channels=in_channels,  # 輸出通道數與輸入相同
                    dropout=dropout,  # dropout 比率
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    groups=resnet_groups,  # 組歸一化的組數
                    eps=resnet_eps,  # epsilon 值
                    spatial_norm_dim=spatial_norm_dim,  # 空間歸一化的維度
                    non_linearity=resnet_act_fn,  # 啟用函式
                    pad_mode=pad_mode,  # 填充模式
                )
            )
        self.resnets = nn.ModuleList(resnets)  # 將 ResNet 層列表轉換為 ModuleList

        self.gradient_checkpointing = False  # 初始化梯度檢查點標誌為 False

    # 前向傳播方法
    def forward(
        self,
        hidden_states: torch.Tensor,  # 隱藏狀態的輸入張量
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
        zq: Optional[torch.Tensor] = None,  # 可選的 zq 張量
    ) -> torch.Tensor:  # 返回張量
        for resnet in self.resnets:  # 遍歷每個 ResNet 層
            if self.training and self.gradient_checkpointing:  # 如果在訓練中且支援梯度檢查點

                # 建立一個自定義前向傳播的函式
                def create_custom_forward(module):
                    def create_forward(*inputs):  # 定義前向傳播函式
                        return module(*inputs)  # 呼叫模組的前向傳播

                    return create_forward  # 返回前向傳播函式

                # 使用檢查點機制執行前向傳播以節省記憶體
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),  # 傳入自定義前向函式
                    hidden_states,  # 隱藏狀態
                    temb,  # 時間嵌入
                    zq  # zq 張量
                )
            else:
                hidden_states = resnet(hidden_states, temb, zq)  # 直接呼叫 ResNet 層的前向傳播

        return hidden_states  # 返回隱藏狀態的輸出


# 定義 CogVideoX 模型中的一個上取樣模組,繼承自 nn.Module
class CogVideoXUpBlock3D(nn.Module):
    r"""
    CogVideoX 模型中使用的上取樣塊。
    # 引數說明
    Args:
        in_channels (`int`):  # 輸入通道的數量
            Number of input channels.
        out_channels (`int`, *optional*):  # 輸出通道的數量,如果為 None,則預設為 `in_channels`
            Number of output channels. If None, defaults to `in_channels`.
        temb_channels (`int`, defaults to `512`):  # 時間嵌入通道的數量
            Number of time embedding channels.
        dropout (`float`, defaults to `0.0`):  # dropout 率
            Dropout rate.
        num_layers (`int`, defaults to `1`):  # ResNet 層的數量
            Number of resnet layers.
        resnet_eps (`float`, defaults to `1e-6`):  # 歸一化層的 epsilon 值
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):  # 使用的啟用函式
            Activation function to use.
        resnet_groups (`int`, defaults to `32`):  # 用於組歸一化的通道組數
            Number of groups to separate the channels into for group normalization.
        spatial_norm_dim (`int`, defaults to `16`):  # 用於空間歸一化的維度
            The dimension to use for spatial norm if it is to be used instead of group norm.
        add_upsample (`bool`, defaults to `True`):  # 是否使用上取樣層
            Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension.
        compress_time (`bool`, defaults to `False`):  # 是否在時間維度上進行下采樣
            Whether or not to downsample across temporal dimension.
        pad_mode (str, defaults to `"first"`):  # 填充模式
            Padding mode.
    """

    def __init__(  # 初始化方法
        self,
        in_channels: int,  # 輸入通道數量
        out_channels: int,  # 輸出通道數量
        temb_channels: int,  # 時間嵌入通道數量
        dropout: float = 0.0,  # dropout 率
        num_layers: int = 1,  # ResNet 層數量
        resnet_eps: float = 1e-6,  # 歸一化 epsilon 值
        resnet_act_fn: str = "swish",  # 啟用函式
        resnet_groups: int = 32,  # 組歸一化的組數
        spatial_norm_dim: int = 16,  # 空間歸一化維度
        add_upsample: bool = True,  # 是否新增上取樣層
        upsample_padding: int = 1,  # 上取樣時的填充
        compress_time: bool = False,  # 是否壓縮時間維度
        pad_mode: str = "first",  # 填充模式
    ):
        super().__init__()  # 呼叫父類初始化方法

        resnets = []  # 初始化空列表以儲存 ResNet 層
        for i in range(num_layers):  # 遍歷每一層
            in_channel = in_channels if i == 0 else out_channels  # 確定當前層的輸入通道數量
            resnets.append(  # 將新的 ResNet 塊新增到列表中
                CogVideoXResnetBlock3D(
                    in_channels=in_channel,  # 設定輸入通道數量
                    out_channels=out_channels,  # 設定輸出通道數量
                    dropout=dropout,  # 設定 dropout 率
                    temb_channels=temb_channels,  # 設定時間嵌入通道數量
                    groups=resnet_groups,  # 設定組數量
                    eps=resnet_eps,  # 設定 epsilon 值
                    non_linearity=resnet_act_fn,  # 設定非線性啟用函式
                    spatial_norm_dim=spatial_norm_dim,  # 設定空間歸一化維度
                    pad_mode=pad_mode,  # 設定填充模式
                )
            )

        self.resnets = nn.ModuleList(resnets)  # 將 ResNet 層列表轉換為 ModuleList
        self.upsamplers = None  # 初始化上取樣器為 None

        if add_upsample:  # 如果需要新增上取樣層
            self.upsamplers = nn.ModuleList(  # 建立上取樣器的 ModuleList
                [
                    CogVideoXUpsample3D(  # 新增上取樣層
                        out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
                    )
                ]
            )

        self.gradient_checkpointing = False  # 初始化梯度檢查點標誌為 False

    def forward(  # 前向傳播方法
        self,
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
        zq: Optional[torch.Tensor] = None,  # 可選的額外張量
    # CogVideoXUpBlock3D 類的前向傳播方法
    ) -> torch.Tensor:
            r"""Forward method of the `CogVideoXUpBlock3D` class."""
            # 遍歷類中的每個 ResNet 模組
            for resnet in self.resnets:
                # 如果處於訓練模式並且啟用了梯度檢查點
                if self.training and self.gradient_checkpointing:
                    # 定義一個自定義前向傳播函式
                    def create_custom_forward(module):
                        # 建立接受輸入的前向傳播函式
                        def create_forward(*inputs):
                            return module(*inputs)
                        # 返回自定義的前向傳播函式
                        return create_forward
    
                    # 使用梯度檢查點機制計算隱藏狀態
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet), hidden_states, temb, zq
                    )
                else:
                    # 直接透過 ResNet 模組處理隱藏狀態
                    hidden_states = resnet(hidden_states, temb, zq)
    
            # 如果存在上取樣器
            if self.upsamplers is not None:
                # 遍歷每個上取樣器
                for upsampler in self.upsamplers:
                    # 透過上取樣器處理隱藏狀態
                    hidden_states = upsampler(hidden_states)
    
            # 返回最終的隱藏狀態
            return hidden_states
# 定義一個名為 `CogVideoXEncoder3D` 的類,繼承自 `nn.Module`,用於變分自編碼器
class CogVideoXEncoder3D(nn.Module):
    r"""
    `CogVideoXEncoder3D` 層用於將輸入編碼為潛在表示的變分自編碼器。

    引數:
        in_channels (`int`, *可選*, 預設值為 3):
            輸入通道的數量。
        out_channels (`int`, *可選*, 預設值為 3):
            輸出通道的數量。
        down_block_types (`Tuple[str, ...]`, *可選*, 預設值為 `("DownEncoderBlock2D",)`):
            使用的下采樣塊型別。有關可用選項,請參見 `~diffusers.models.unet_2d_blocks.get_down_block`。
        block_out_channels (`Tuple[int, ...]`, *可選*, 預設值為 `(64,)`):
            每個塊的輸出通道數量。
        act_fn (`str`, *可選*, 預設值為 `"silu"`):
            要使用的啟用函式。有關可用選項,請參見 `~diffusers.models.activations.get_activation`。
        layers_per_block (`int`, *可選*, 預設值為 2):
            每個塊的層數。
        norm_num_groups (`int`, *可選*, 預設值為 32):
            歸一化的組數。
    """

    # 設定類屬性以支援梯度檢查點
    _supports_gradient_checkpointing = True

    # 初始化方法,設定類的引數
    def __init__(
        self,
        in_channels: int = 3,  # 輸入通道數,預設為 3
        out_channels: int = 16,  # 輸出通道數,預設為 16
        down_block_types: Tuple[str, ...] = (  # 下采樣塊型別的元組
            "CogVideoXDownBlock3D",  # 第一個下采樣塊型別
            "CogVideoXDownBlock3D",  # 第二個下采樣塊型別
            "CogVideoXDownBlock3D",  # 第三個下采樣塊型別
            "CogVideoXDownBlock3D",  # 第四個下采樣塊型別
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),  # 每個塊的輸出通道數的元組
        layers_per_block: int = 3,  # 每個塊的層數,預設為 3
        act_fn: str = "silu",  # 啟用函式,預設為 "silu"
        norm_eps: float = 1e-6,  # 歸一化的 epsilon 值,預設為 1e-6
        norm_num_groups: int = 32,  # 歸一化組數,預設為 32
        dropout: float = 0.0,  # dropout 機率,預設為 0.0
        pad_mode: str = "first",  # 填充模式,預設為 "first"
        temporal_compression_ratio: float = 4,  # 時間壓縮比,預設為 4
    ):
        # 呼叫父類的初始化方法
        super().__init__()

        # 計算時間壓縮等級的對數(以2為底)
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        # 建立一個三維卷積層,輸入通道數為in_channels,輸出通道數為block_out_channels[0],卷積核大小為3
        self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
        # 初始化一個空的ModuleList,用於儲存下采樣模組
        self.down_blocks = nn.ModuleList([])

        # 設定初始輸出通道數為第一個塊的輸出通道數
        output_channel = block_out_channels[0]
        # 遍歷下采樣模組的型別,i為索引,down_block_type為型別
        for i, down_block_type in enumerate(down_block_types):
            # 輸入通道數為當前輸出通道數
            input_channel = output_channel
            # 更新輸出通道數為當前塊的輸出通道數
            output_channel = block_out_channels[i]
            # 判斷是否為最後一個下采樣塊
            is_final_block = i == len(block_out_channels) - 1
            # 判斷當前塊是否需要壓縮時間
            compress_time = i < temporal_compress_level

            # 如果下采樣模組的型別為CogVideoXDownBlock3D
            if down_block_type == "CogVideoXDownBlock3D":
                # 建立下采樣塊,設定輸入輸出通道、丟棄率等引數
                down_block = CogVideoXDownBlock3D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    # 如果不是最後一個塊則新增下采樣
                    add_downsample=not is_final_block,
                    compress_time=compress_time,
                )
            else:
                # 如果下采樣模組型別無效,則丟擲異常
                raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")

            # 將建立的下采樣塊新增到down_blocks列表中
            self.down_blocks.append(down_block)

        # 建立中間塊
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=block_out_channels[-1],
            temb_channels=0,
            dropout=dropout,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            pad_mode=pad_mode,
        )

        # 建立歸一化層,使用GroupNorm
        self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
        # 建立啟用函式層,使用SiLU啟用函式
        self.conv_act = nn.SiLU()
        # 建立輸出卷積層,將最後一個塊的輸出通道數轉換為2倍的out_channels
        self.conv_out = CogVideoXCausalConv3d(
            block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
        )

        # 初始化梯度檢查點為False
        self.gradient_checkpointing = False
    # 定義 `CogVideoXEncoder3D` 類的前向傳播方法,接收輸入樣本和可選的時間嵌入
    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""`CogVideoXEncoder3D` 類的前向方法。"""
        # 透過輸入樣本進行初始卷積,得到隱藏狀態
        hidden_states = self.conv_in(sample)

        # 檢查是否在訓練模式並且啟用梯度檢查點
        if self.training and self.gradient_checkpointing:

            # 定義一個建立自定義前向傳播函式的內部函式
            def create_custom_forward(module):
                # 自定義前向傳播,傳入可變引數
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 1. 向下取樣
            # 遍歷下采樣塊,並應用檢查點以減少記憶體使用
            for down_block in self.down_blocks:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block), hidden_states, temb, None
                )

            # 2. 中間塊
            # 對中間塊進行檢查點處理
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block), hidden_states, temb, None
            )
        else:
            # 如果不是訓練模式,直接執行前向傳播
            
            # 1. 向下取樣
            # 遍歷下采樣塊,直接應用每個下采樣塊的前向傳播
            for down_block in self.down_blocks:
                hidden_states = down_block(hidden_states, temb, None)

            # 2. 中間塊
            # 直接應用中間塊的前向傳播
            hidden_states = self.mid_block(hidden_states, temb, None)

        # 3. 後處理
        # 對隱藏狀態進行歸一化處理
        hidden_states = self.norm_out(hidden_states)
        # 應用啟用函式
        hidden_states = self.conv_act(hidden_states)
        # 透過最後的卷積層輸出結果
        hidden_states = self.conv_out(hidden_states)
        # 返回最終的隱藏狀態
        return hidden_states
# 定義一個名為 `CogVideoXDecoder3D` 的類,繼承自 `nn.Module`
class CogVideoXDecoder3D(nn.Module):
    r"""
    `CogVideoXDecoder3D` 是一個變分自編碼器的層,用於將潛在表示解碼為輸出樣本。

    引數:
        in_channels (`int`, *可選*, 預設為 3):
            輸入通道的數量。
        out_channels (`int`, *可選*, 預設為 3):
            輸出通道的數量。
        up_block_types (`Tuple[str, ...]`, *可選*, 預設為 `("UpDecoderBlock2D",)`):
            要使用的上取樣塊型別。請參見 `~diffusers.models.unet_2d_blocks.get_up_block` 獲取可用選項。
        block_out_channels (`Tuple[int, ...]`, *可選*, 預設為 `(64,)`):
            每個塊的輸出通道數量。
        act_fn (`str`, *可選*, 預設為 `"silu"`):
            要使用的啟用函式。請參見 `~diffusers.models.activations.get_activation` 獲取可用選項。
        layers_per_block (`int`, *可選*, 預設為 2):
            每個塊的層數。
        norm_num_groups (`int`, *可選*, 預設為 32):
            歸一化的組數。
    """

    # 定義一個類屬性,表示支援梯度檢查點
    _supports_gradient_checkpointing = True

    # 初始化方法,定義類的建構函式
    def __init__(
        # 輸入通道數量,預設為 16
        in_channels: int = 16,
        # 輸出通道數量,預設為 3
        out_channels: int = 3,
        # 上取樣塊型別的元組,包含四個 'CogVideoXUpBlock3D'
        up_block_types: Tuple[str, ...] = (
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
        ),
        # 每個塊的輸出通道數量,指定為 128, 256, 256, 512
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        # 每個塊的層數,預設為 3
        layers_per_block: int = 3,
        # 啟用函式名稱,預設為 "silu"
        act_fn: str = "silu",
        # 歸一化的 epsilon 值,預設為 1e-6
        norm_eps: float = 1e-6,
        # 歸一化的組數,預設為 32
        norm_num_groups: int = 32,
        # dropout 比例,預設為 0.0
        dropout: float = 0.0,
        # 填充模式,預設為 "first"
        pad_mode: str = "first",
        # 時間壓縮比,預設為 4
        temporal_compression_ratio: float = 4,
    ):
        # 呼叫父類的初始化方法
        super().__init__()

        # 反轉輸出通道列表,以便後續處理
        reversed_block_out_channels = list(reversed(block_out_channels))

        # 建立輸入卷積層,使用反轉後的輸出通道的第一個元素
        self.conv_in = CogVideoXCausalConv3d(
            in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
        )

        # 建立中間塊
        self.mid_block = CogVideoXMidBlock3D(
            # 使用反轉後的輸出通道的第一個元素作為輸入通道
            in_channels=reversed_block_out_channels[0],
            temb_channels=0,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            spatial_norm_dim=in_channels,
            pad_mode=pad_mode,
        )

        # 初始化上取樣塊的模組列表
        self.up_blocks = nn.ModuleList([])

        # 設定當前的輸出通道為反轉後的輸出通道的第一個元素
        output_channel = reversed_block_out_channels[0]
        # 計算時間壓縮級別
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        # 遍歷每種上取樣塊型別
        for i, up_block_type in enumerate(up_block_types):
            # 儲存前一個輸出通道
            prev_output_channel = output_channel
            # 更新當前輸出通道為反轉後的輸出通道
            output_channel = reversed_block_out_channels[i]
            # 判斷當前塊是否為最後一個塊
            is_final_block = i == len(block_out_channels) - 1
            # 判斷是否需要時間壓縮
            compress_time = i < temporal_compress_level

            # 如果塊型別為指定的上取樣塊型別
            if up_block_type == "CogVideoXUpBlock3D":
                # 建立上取樣塊
                up_block = CogVideoXUpBlock3D(
                    in_channels=prev_output_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block + 1,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    spatial_norm_dim=in_channels,
                    add_upsample=not is_final_block,
                    compress_time=compress_time,
                    pad_mode=pad_mode,
                )
                # 更新前一個輸出通道
                prev_output_channel = output_channel
            else:
                # 如果上取樣塊型別不合法,丟擲錯誤
                raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")

            # 將建立的上取樣塊新增到模組列表中
            self.up_blocks.append(up_block)

        # 建立輸出的空間歸一化層
        self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
        # 建立啟用函式層
        self.conv_act = nn.SiLU()
        # 建立輸出卷積層
        self.conv_out = CogVideoXCausalConv3d(
            reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
        )

        # 初始化梯度檢查點標誌為 False
        self.gradient_checkpointing = False
    # 定義 `CogVideoXDecoder3D` 類的前向傳播方法
    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        # 方法文件字串,描述該方法的功能
        r"""The forward method of the `CogVideoXDecoder3D` class."""
        # 對輸入樣本應用初始卷積,生成隱藏狀態
        hidden_states = self.conv_in(sample)
    
        # 如果處於訓練模式且啟用梯度檢查點
        if self.training and self.gradient_checkpointing:
    
            # 建立自定義前向傳播函式
            def create_custom_forward(module):
                # 定義接受輸入並呼叫模組的函式
                def custom_forward(*inputs):
                    return module(*inputs)
    
                return custom_forward
    
            # 1. 中間塊處理
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block), hidden_states, temb, sample
            )
    
            # 2. 上取樣塊處理
            for up_block in self.up_blocks:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block), hidden_states, temb, sample
                )
        else:
            # 1. 中間塊處理
            hidden_states = self.mid_block(hidden_states, temb, sample)
    
            # 2. 上取樣塊處理
            for up_block in self.up_blocks:
                hidden_states = up_block(hidden_states, temb, sample)
    
        # 3. 後處理
        hidden_states = self.norm_out(hidden_states, sample)  # 歸一化輸出
        hidden_states = self.conv_act(hidden_states)          # 應用啟用函式
        hidden_states = self.conv_out(hidden_states)          # 應用最終卷積
        return hidden_states                                   # 返回處理後的隱藏狀態
# 定義一個名為 AutoencoderKLCogVideoX 的類,繼承自 ModelMixin、ConfigMixin 和 FromOriginalModelMixin
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    一個具有 KL 損失的變分自編碼器(VAE)模型,用於將影像編碼為潛在表示並解碼潛在表示為影像。用於
    [CogVideoX](https://github.com/THUDM/CogVideo)。

    該模型繼承自 [`ModelMixin`]。有關所有模型實現的通用方法(例如下載或儲存)的詳細資訊,請檢視父類文件。

    引數:
        in_channels (int, *可選*,預設值為 3):輸入影像的通道數。
        out_channels (int,  *可選*,預設值為 3):輸出的通道數。
        down_block_types (`Tuple[str]`, *可選*,預設值為 `("DownEncoderBlock2D",)`):
            下采樣塊型別的元組。
        up_block_types (`Tuple[str]`, *可選*,預設值為 `("UpDecoderBlock2D",)`):
            上取樣塊型別的元組。
        block_out_channels (`Tuple[int]`, *可選*,預設值為 `(64,)`):
            塊輸出通道數的元組。
        act_fn (`str`, *可選*,預設值為 `"silu"`):使用的啟用函式。
        sample_size (`int`, *可選*,預設值為 `32`):樣本輸入大小。
        scaling_factor (`float`, *可選*,預設值為 `1.15258426`):
            使用訓練集的第一批計算的訓練潛在空間的逐分量標準差。用於在訓練擴散模型時將潛在空間縮放到單位方差。潛在表示在傳遞給擴散模型之前使用公式 `z = z * scaling_factor` 進行縮放。在解碼時,潛在表示使用公式 `z = 1 / scaling_factor * z` 縮放回原始比例。有關詳細資訊,請參閱 [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) 論文的第 4.3.2 節和 D.1 節。
        force_upcast (`bool`, *可選*,預設值為 `True`):
            如果啟用,它將強制 VAE 在 float32 中執行,以支援高影像解析度管道,例如 SD-XL。VAE 可以在不失去太多精度的情況下微調/訓練到較低範圍,在這種情況下可以將 `force_upcast` 設定為 `False` - 參見:https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
    """

    # 設定支援梯度檢查點
    _supports_gradient_checkpointing = True
    # 定義不進行拆分的模組列表
    _no_split_modules = ["CogVideoXResnetBlock3D"]

    # 用於將類註冊到配置中的裝飾器
    @register_to_config
    # 初始化方法,用於設定類的屬性
        def __init__(
            # 輸入通道數,預設為3
            in_channels: int = 3,
            # 輸出通道數,預設為3
            out_channels: int = 3,
            # 下采樣塊型別的元組
            down_block_types: Tuple[str] = (
                "CogVideoXDownBlock3D",  # 第一個下采樣塊型別
                "CogVideoXDownBlock3D",  # 第二個下采樣塊型別
                "CogVideoXDownBlock3D",  # 第三個下采樣塊型別
                "CogVideoXDownBlock3D",  # 第四個下采樣塊型別
            ),
            # 上取樣塊型別的元組
            up_block_types: Tuple[str] = (
                "CogVideoXUpBlock3D",    # 第一個上取樣塊型別
                "CogVideoXUpBlock3D",    # 第二個上取樣塊型別
                "CogVideoXUpBlock3D",    # 第三個上取樣塊型別
                "CogVideoXUpBlock3D",    # 第四個上取樣塊型別
            ),
            # 每個塊的輸出通道數的元組
            block_out_channels: Tuple[int] = (128, 256, 256, 512),
            # 潛在通道數,預設為16
            latent_channels: int = 16,
            # 每個塊的層數,預設為3
            layers_per_block: int = 3,
            # 啟用函式型別,預設為"silu"
            act_fn: str = "silu",
            # 歸一化的epsilon,預設為1e-6
            norm_eps: float = 1e-6,
            # 歸一化的組數,預設為32
            norm_num_groups: int = 32,
            # 時間壓縮比,預設為4
            temporal_compression_ratio: float = 4,
            # 樣本高度,預設為480
            sample_height: int = 480,
            # 樣本寬度,預設為720
            sample_width: int = 720,
            # 縮放因子,預設為1.15258426
            scaling_factor: float = 1.15258426,
            # 位移因子,預設為None
            shift_factor: Optional[float] = None,
            # 潛在均值,預設為None
            latents_mean: Optional[Tuple[float]] = None,
            # 潛在標準差,預設為None
            latents_std: Optional[Tuple[float]] = None,
            # 強制上位數,預設為True
            force_upcast: float = True,
            # 是否使用量化卷積,預設為False
            use_quant_conv: bool = False,
            # 是否使用後量化卷積,預設為False
            use_post_quant_conv: bool = False,
        # 設定梯度檢查點的方法
        def _set_gradient_checkpointing(self, module, value=False):
            # 檢查模組是否為特定型別以設定梯度檢查點
            if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
                # 設定模組的梯度檢查點標誌
                module.gradient_checkpointing = value
    
        # 清理偽上下文並行快取的方法
        def _clear_fake_context_parallel_cache(self):
            # 遍歷所有命名模組
            for name, module in self.named_modules():
                # 檢查模組是否為特定型別
                if isinstance(module, CogVideoXCausalConv3d):
                    # 記錄清理操作
                    logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
                    # 清理模組的偽上下文並行快取
                    module._clear_fake_context_parallel_cache()
    
        # 啟用平鋪的方法
        def enable_tiling(
            # 平鋪樣本最小高度,預設為None
            tile_sample_min_height: Optional[int] = None,
            # 平鋪樣本最小寬度,預設為None
            tile_sample_min_width: Optional[int] = None,
            # 平鋪重疊因子高度,預設為None
            tile_overlap_factor_height: Optional[float] = None,
            # 平鋪重疊因子寬度,預設為None
            tile_overlap_factor_width: Optional[float] = None,
    # 該方法用於啟用分塊的 VAE 解碼
    ) -> None:
            r"""
            啟用分塊 VAE 解碼。啟用後,VAE 將輸入張量分割成多個塊來進行解碼和編碼。
            這有助於節省大量記憶體並允許處理更大影像。
    
            引數:
                tile_sample_min_height (`int`, *可選*):
                    樣本在高度維度上分塊所需的最小高度。
                tile_sample_min_width (`int`, *可選*):
                    樣本在寬度維度上分塊所需的最小寬度。
                tile_overlap_factor_height (`int`, *可選*):
                    兩個連續垂直塊之間的最小重疊量。以確保在高度維度上沒有塊狀偽影。必須在 0 和 1 之間。設定較高的值可能導致處理更多塊,從而減慢解碼過程。
                tile_overlap_factor_width (`int`, *可選*):
                    兩個連續水平塊之間的最小重疊量。以確保在寬度維度上沒有塊狀偽影。必須在 0 和 1 之間。設定較高的值可能導致處理更多塊,從而減慢解碼過程。
            """
            # 啟用分塊處理
            self.use_tiling = True
            # 設定最小高度,使用提供的值或預設值
            self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
            # 設定最小寬度,使用提供的值或預設值
            self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
            # 計算最小潛在高度,根據配置的塊通道數調整
            self.tile_latent_min_height = int(
                self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
            )
            # 計算最小潛在寬度,根據配置的塊通道數調整
            self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
            # 設定高度重疊因子,使用提供的值或預設值
            self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
            # 設定寬度重疊因子,使用提供的值或預設值
            self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
    
        # 該方法用於禁用分塊的 VAE 解碼
        def disable_tiling(self) -> None:
            r"""
            禁用分塊 VAE 解碼。如果之前啟用了 `enable_tiling`,該方法將返回到一步解碼。
            """
            # 將分塊處理狀態設定為禁用
            self.use_tiling = False
    
        # 該方法用於啟用切片的 VAE 解碼
        def enable_slicing(self) -> None:
            r"""
            啟用切片 VAE 解碼。啟用後,VAE 將輸入張量切割為切片以進行多步解碼。
            這有助於節省記憶體並允許更大的批處理大小。
            """
            # 啟用切片處理
            self.use_slicing = True
    
        # 該方法用於禁用切片的 VAE 解碼
        def disable_slicing(self) -> None:
            r"""
            禁用切片 VAE 解碼。如果之前啟用了 `enable_slicing`,該方法將返回到一步解碼。
            """
            # 將切片處理狀態設定為禁用
            self.use_slicing = False
    # 定義編碼函式,輸入為一個 Torch 張量,輸出為編碼後的 Torch 張量
    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        # 獲取輸入張量的維度資訊,包括批大小、通道數、幀數、高度和寬度
        batch_size, num_channels, num_frames, height, width = x.shape
    
        # 檢查是否使用切片和輸入影像尺寸是否超過最小切片尺寸
        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
            # 如果條件滿足,則呼叫切片編碼函式
            return self.tiled_encode(x)
    
        # 設定每個批次處理的幀數
        frame_batch_size = self.num_sample_frames_batch_size
        # 計算批次數,期望幀數為 1 或批大小的整數倍
        num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
        # 初始化編碼結果列表
        enc = []
        # 遍歷每個批次
        for i in range(num_batches):
            # 計算剩餘幀數
            remaining_frames = num_frames % frame_batch_size
            # 計算當前批次的起始和結束幀索引
            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
            end_frame = frame_batch_size * (i + 1) + remaining_frames
            # 從輸入張量中提取當前批次的幀
            x_intermediate = x[:, :, start_frame:end_frame]
            # 對當前批次進行編碼
            x_intermediate = self.encoder(x_intermediate)
            # 如果存在量化卷積,則對結果進行量化
            if self.quant_conv is not None:
                x_intermediate = self.quant_conv(x_intermediate)
            # 將當前批次的編碼結果新增到結果列表中
            enc.append(x_intermediate)
    
        # 清除假上下文的並行快取
        self._clear_fake_context_parallel_cache()
        # 將所有批次的編碼結果沿時間維度連線
        enc = torch.cat(enc, dim=2)
    
        # 返回最終的編碼結果
        return enc
    
    # 應用前向鉤子裝飾器
    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        將一批影像編碼為潛在表示。
    
        引數:
            x (`torch.Tensor`): 輸入影像批次。
            return_dict (`bool`, *可選*, 預設為 `True`):
                是否返回 [`~models.autoencoder_kl.AutoencoderKLOutput`] 而不是普通元組。
    
        返回:
            編碼影片的潛在表示。如果 `return_dict` 為 True,則返回一個
            [`~models.autoencoder_kl.AutoencoderKLOutput`],否則返回普通元組。
        """
        # 如果使用切片且輸入批次大於 1,進行切片編碼
        if self.use_slicing and x.shape[0] > 1:
            # 針對每個切片呼叫編碼函式,並收集結果
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            # 將所有切片的結果連線
            h = torch.cat(encoded_slices)
        else:
            # 否則直接編碼整個輸入
            h = self._encode(x)
    
        # 使用編碼結果建立對角高斯分佈
        posterior = DiagonalGaussianDistribution(h)
    
        # 根據返回字典標誌決定返回結果的型別
        if not return_dict:
            return (posterior,)
        # 返回編碼輸出物件
        return AutoencoderKLOutput(latent_dist=posterior)
    # 解碼給定的潛在張量 z,並選擇返回字典或張量格式
    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        # 獲取輸入張量的批次大小、通道數、幀數、高度和寬度
        batch_size, num_channels, num_frames, height, width = z.shape

        # 如果啟用平鋪解碼且寬度或高度超過最小平鋪尺寸,則呼叫平鋪解碼函式
        if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
            return self.tiled_decode(z, return_dict=return_dict)

        # 設定每批潛在幀的大小
        frame_batch_size = self.num_latent_frames_batch_size
        # 計算總的批次數
        num_batches = num_frames // frame_batch_size
        # 建立用於儲存解碼結果的列表
        dec = []
        # 遍歷每個批次
        for i in range(num_batches):
            # 計算剩餘幀數
            remaining_frames = num_frames % frame_batch_size
            # 計算當前批次的起始幀和結束幀
            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
            end_frame = frame_batch_size * (i + 1) + remaining_frames
            # 獲取當前批次的潛在張量
            z_intermediate = z[:, :, start_frame:end_frame]
            # 如果存在後量化卷積,則對當前潛在張量進行處理
            if self.post_quant_conv is not None:
                z_intermediate = self.post_quant_conv(z_intermediate)
            # 將潛在張量解碼為輸出
            z_intermediate = self.decoder(z_intermediate)
            # 將解碼結果新增到列表中
            dec.append(z_intermediate)

        # 清除假上下文並行快取
        self._clear_fake_context_parallel_cache()
        # 將所有解碼結果沿著幀維度拼接
        dec = torch.cat(dec, dim=2)

        # 如果不需要返回字典,直接返回解碼結果
        if not return_dict:
            return (dec,)

        # 返回解碼結果的字典形式
        return DecoderOutput(sample=dec)

    # 應用前向鉤子修飾器
    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        """
        解碼一批影像。

        引數:
            z (`torch.Tensor`): 輸入的潛在向量批次。
            return_dict (`bool`, *可選*, 預設為 `True`):
                是否返回 [`~models.vae.DecoderOutput`] 而不是普通元組。

        返回:
            [`~models.vae.DecoderOutput`] 或 `tuple`:
                如果 return_dict 為 True,返回 [`~models.vae.DecoderOutput`],否則返回普通元組。
        """
        # 如果啟用切片解碼且輸入的批次大小大於 1
        if self.use_slicing and z.shape[0] > 1:
            # 遍歷每個切片並解碼,收集解碼結果
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            # 將所有解碼結果拼接
            decoded = torch.cat(decoded_slices)
        else:
            # 對整個輸入進行解碼並獲取解碼樣本
            decoded = self._decode(z).sample

        # 如果不需要返回字典,直接返回解碼結果
        if not return_dict:
            return (decoded,)
        # 返回解碼結果的字典形式
        return DecoderOutput(sample=decoded)

    # 垂直混合兩個張量 a 和 b,並指定混合範圍
    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        # 確定混合範圍的最小值
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        # 在混合範圍內遍歷每一行
        for y in range(blend_extent):
            # 混合張量 a 的底部與張量 b 的頂部
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                y / blend_extent
            )
        # 返回混合後的張量 b
        return b

    # 水平混合兩個張量 a 和 b,並指定混合範圍
    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        # 確定混合範圍的最小值
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        # 在混合範圍內遍歷每一列
        for x in range(blend_extent):
            # 混合張量 a 的右側與張量 b 的左側
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        # 返回混合後的張量 b
        return b
    # 定義前向傳播函式,接收輸入樣本及其他引數
    def forward(
            self,
            sample: torch.Tensor,  # 輸入樣本,型別為張量
            sample_posterior: bool = False,  # 是否從後驗分佈中取樣,預設值為 False
            return_dict: bool = True,  # 是否以字典形式返回結果,預設值為 True
            generator: Optional[torch.Generator] = None,  # 可選的隨機數生成器
        ) -> Union[torch.Tensor, torch.Tensor]:  # 返回型別為張量
            x = sample  # 將輸入樣本賦值給變數 x
            posterior = self.encode(x).latent_dist  # 編碼輸入樣本並獲取後驗分佈
            if sample_posterior:  # 檢查是否需要從後驗分佈中取樣
                z = posterior.sample(generator=generator)  # 從後驗分佈中取樣
            else:
                z = posterior.mode()  # 使用後驗分佈的眾數作為 z 的值
            dec = self.decode(z)  # 解碼 z 得到解碼結果 dec
            if not return_dict:  # 檢查是否需要返回字典形式的結果
                return (dec,)  # 以元組形式返回解碼結果
            return dec  # 返回解碼結果

.\diffusers\models\autoencoders\autoencoder_kl_temporal_decoder.py

# 版權宣告,指明該檔案由 HuggingFace 團隊建立,版權歸其所有
# 
# 根據 Apache 許可證第 2.0 版(“許可證”)進行許可;
# 除非遵循許可證,否則您不得使用此檔案。
# 您可以在以下網址獲取許可證副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非適用法律要求或書面協議另有約定,軟體
# 在許可證下以“原樣”方式分發,不附帶任何明示或暗示的擔保或條件。
# 有關許可證下的具體許可權和
# 限制,請參見許可證。
# 匯入所需的型別定義
from typing import Dict, Optional, Tuple, Union

# 匯入 PyTorch 和神經網路模組
import torch
import torch.nn as nn

# 從配置和工具模組中匯入必要的功能
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version
from ...utils.accelerate_utils import apply_forward_hook
# 從注意力處理模組中匯入相關處理器
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
# 從模型輸出模組中匯入 Autoencoder 的輸出型別
from ..modeling_outputs import AutoencoderKLOutput
# 從模型工具模組中匯入模型的混合功能
from ..modeling_utils import ModelMixin
# 從 3D U-Net 模組中匯入解碼器塊
from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
# 從變分自編碼器模組中匯入解碼器輸出和相關分佈
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder


# 定義時間解碼器類,繼承自 PyTorch 的 nn.Module
class TemporalDecoder(nn.Module):
    # 初始化時間解碼器,設定輸入輸出通道及塊引數
    def __init__(
        self,
        in_channels: int = 4,  # 輸入通道數,預設值為 4
        out_channels: int = 3,  # 輸出通道數,預設值為 3
        block_out_channels: Tuple[int] = (128, 256, 512, 512),  # 每個塊的輸出通道數,預設值為指定的元組
        layers_per_block: int = 2,  # 每個塊的層數,預設值為 2
    ):
        # 初始化父類
        super().__init__()
        # 設定每個塊的層數
        self.layers_per_block = layers_per_block

        # 建立輸入卷積層,接受 in_channels 通道,輸出 block_out_channels[-1] 通道
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
        # 建立中間塊的時間解碼器,傳入層數、輸入通道、輸出通道和注意力頭維度
        self.mid_block = MidBlockTemporalDecoder(
            num_layers=self.layers_per_block,
            in_channels=block_out_channels[-1],
            out_channels=block_out_channels[-1],
            attention_head_dim=block_out_channels[-1],
        )

        # 建立上取樣塊的列表
        self.up_blocks = nn.ModuleList([])
        # 反轉輸出通道列表
        reversed_block_out_channels = list(reversed(block_out_channels))
        # 獲取第一個輸出通道
        output_channel = reversed_block_out_channels[0]
        # 遍歷每個輸出通道
        for i in range(len(block_out_channels)):
            # 儲存前一個輸出通道
            prev_output_channel = output_channel
            # 更新當前輸出通道
            output_channel = reversed_block_out_channels[i]

            # 判斷是否為最後一個塊
            is_final_block = i == len(block_out_channels) - 1
            # 建立上取樣塊的時間解碼器
            up_block = UpBlockTemporalDecoder(
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                add_upsample=not is_final_block,
            )
            # 將上取樣塊新增到列表中
            self.up_blocks.append(up_block)
            # 更新前一個輸出通道
            prev_output_channel = output_channel

        # 建立輸出的歸一化卷積層
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-6)

        # 建立啟用函式層,使用 SiLU 啟用函式
        self.conv_act = nn.SiLU()
        # 建立輸出卷積層,將輸入通道轉換為輸出通道
        self.conv_out = torch.nn.Conv2d(
            in_channels=block_out_channels[0],
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
        )

        # 定義卷積輸出的核大小
        conv_out_kernel_size = (3, 1, 1)
        # 計算填充
        padding = [int(k // 2) for k in conv_out_kernel_size]
        # 建立 3D 卷積層,進行時間卷積
        self.time_conv_out = torch.nn.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=conv_out_kernel_size,
            padding=padding,
        )

        # 初始化梯度檢查點標誌為 False
        self.gradient_checkpointing = False

    def forward(
        # 定義前向傳播方法的輸入引數
        self,
        sample: torch.Tensor,
        image_only_indicator: torch.Tensor,
        num_frames: int = 1,
    ) -> torch.Tensor:
        r"""`Decoder` 類的前向傳播方法。"""

        # 對輸入樣本進行初始卷積處理
        sample = self.conv_in(sample)

        # 獲取上取樣塊引數的 dtype,用於後續轉換
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        
        # 如果處於訓練模式並啟用了梯度檢查點
        if self.training and self.gradient_checkpointing:

            # 建立自定義前向傳播函式
            def create_custom_forward(module):
                # 定義自定義前向傳播,接受輸入並返回模組的輸出
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 如果 PyTorch 版本大於等於 1.11.0
            if is_torch_version(">=", "1.11.0"):
                # 中間處理
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),  # 使用自定義前向傳播處理中間塊
                    sample,  # 輸入樣本
                    image_only_indicator,  # 指示符
                    use_reentrant=False,  # 不使用可重入的檢查點
                )
                # 轉換樣本的 dtype
                sample = sample.to(upscale_dtype)

                # 上取樣處理
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),  # 使用自定義前向傳播處理每個上取樣塊
                        sample,  # 當前樣本
                        image_only_indicator,  # 指示符
                        use_reentrant=False,  # 不使用可重入的檢查點
                    )
            else:
                # 中間處理
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),  # 使用自定義前向傳播處理中間塊
                    sample,  # 輸入樣本
                    image_only_indicator,  # 指示符
                )
                # 轉換樣本的 dtype
                sample = sample.to(upscale_dtype)

                # 上取樣處理
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),  # 使用自定義前向傳播處理每個上取樣塊
                        sample,  # 當前樣本
                        image_only_indicator,  # 指示符
                    )
        else:
            # 如果不在訓練模式
            # 中間處理
            sample = self.mid_block(sample, image_only_indicator=image_only_indicator)  # 處理樣本
            # 轉換樣本的 dtype
            sample = sample.to(upscale_dtype)

            # 上取樣處理
            for up_block in self.up_blocks:
                sample = up_block(sample, image_only_indicator=image_only_indicator)  # 處理樣本

        # 後處理步驟
        sample = self.conv_norm_out(sample)  # 正則化輸出樣本
        sample = self.conv_act(sample)  # 應用啟用函式
        sample = self.conv_out(sample)  # 生成最終輸出樣本

        # 獲取樣本的形狀資訊
        batch_frames, channels, height, width = sample.shape
        # 計算批大小
        batch_size = batch_frames // num_frames
        # 重新排列樣本的形狀以適應時間維度
        sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
        # 應用時間卷積層
        sample = self.time_conv_out(sample)

        # 還原樣本的維度
        sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)

        # 返回最終處理後的樣本
        return sample
# 定義一個類 AutoencoderKLTemporalDecoder,繼承自 ModelMixin 和 ConfigMixin
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
    r"""
    一個具有 KL 損失的 VAE 模型,用於將影像編碼為潛在表示,並將潛在表示解碼為影像。

    該模型繼承自 [`ModelMixin`]。有關所有模型的通用方法(如下載或儲存)的詳細資訊,請查閱超類文件。

    引數:
        in_channels (int, *可選*, 預設為 3): 輸入影像的通道數。
        out_channels (int, *可選*, 預設為 3): 輸出的通道數。
        down_block_types (`Tuple[str]`, *可選*, 預設為 `("DownEncoderBlock2D",)`):
            下采樣塊型別的元組。
        block_out_channels (`Tuple[int]`, *可選*, 預設為 `(64,)`):
            塊輸出通道的元組。
        layers_per_block: (`int`, *可選*, 預設為 1): 每個塊的層數。
        latent_channels (`int`, *可選*, 預設為 4): 潛在空間中的通道數。
        sample_size (`int`, *可選*, 預設為 32): 輸入樣本大小。
        scaling_factor (`float`, *可選*, 預設為 0.18215):
            使用訓練集的第一批計算的訓練潛在空間的逐分量標準差。用於將潛在空間縮放到單位方差,當訓練擴散模型時,潛在變數按公式 `z = z * scaling_factor` 縮放。解碼時,潛在變數透過公式 `z = 1 / scaling_factor * z` 縮放回原始比例。有關詳細資訊,請參見 [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) 論文的第 4.3.2 節和 D.1 節。
        force_upcast (`bool`, *可選*, 預設為 `True`):
            如果啟用,將強制 VAE 在 float32 中執行,以適應高影像解析度管道,例如 SD-XL。VAE 可以進行微調/訓練到較低範圍,而不會失去太多精度,在這種情況下 `force_upcast` 可以設定為 `False` - 參見: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
    """

    # 設定支援梯度檢查點的標誌為真
    _supports_gradient_checkpointing = True

    # 註冊到配置的方法,初始化類的例項
    @register_to_config
    def __init__(
        # 輸入通道的數量,預設為 3
        self,
        in_channels: int = 3,
        # 輸出通道的數量,預設為 3
        out_channels: int = 3,
        # 下采樣塊的型別,預設為一個包含 "DownEncoderBlock2D" 的元組
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        # 塊輸出通道的數量,預設為一個包含 64 的元組
        block_out_channels: Tuple[int] = (64,),
        # 每個塊的層數,預設為 1
        layers_per_block: int = 1,
        # 潛在通道的數量,預設為 4
        latent_channels: int = 4,
        # 樣本輸入大小,預設為 32
        sample_size: int = 32,
        # 縮放因子,預設為 0.18215
        scaling_factor: float = 0.18215,
        # 強制使用 float32 的標誌,預設為 True
        force_upcast: float = True,
    ):
        # 呼叫父類的建構函式進行初始化
        super().__init__()

        # 將初始化引數傳遞給編碼器(Encoder)
        self.encoder = Encoder(
            # 輸入通道數
            in_channels=in_channels,
            # 潛在通道數
            out_channels=latent_channels,
            # 下采樣塊的型別
            down_block_types=down_block_types,
            # 每個塊的輸出通道數
            block_out_channels=block_out_channels,
            # 每個塊的層數
            layers_per_block=layers_per_block,
            # 是否雙重潛在變數
            double_z=True,
        )

        # 將初始化引數傳遞給解碼器(Decoder)
        self.decoder = TemporalDecoder(
            # 潛在通道數作為輸入
            in_channels=latent_channels,
            # 輸出通道數
            out_channels=out_channels,
            # 每個塊的輸出通道數
            block_out_channels=block_out_channels,
            # 每個塊的層數
            layers_per_block=layers_per_block,
        )

        # 建立一個卷積層,用於量化
        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)

        # 獲取樣本大小,支援列表或元組形式
        sample_size = (
            self.config.sample_size[0]  # 如果是列表或元組,取第一個元素
            if isinstance(self.config.sample_size, (list, tuple))
            else self.config.sample_size  # 否則直接使用樣本大小
        )
        # 計算最小的平鋪潛在大小
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
        # 設定平鋪重疊因子
        self.tile_overlap_factor = 0.25

    # 定義一個私有方法,用於設定梯度檢查點
    def _set_gradient_checkpointing(self, module, value=False):
        # 如果模組是編碼器或解碼器,設定其梯度檢查點
        if isinstance(module, (Encoder, TemporalDecoder)):
            module.gradient_checkpointing = value

    # 使用@property裝飾器定義一個屬性
    @property
    # 從 UNet2DConditionModel 的 attn_processors 複製而來
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        返回:
            `dict` 型別的注意力處理器:一個包含模型中所有注意力處理器的字典,按權重名稱索引。
        """
        # 初始化一個空字典以遞迴儲存處理器
        processors = {}

        # 定義一個遞迴函式,用於新增處理器
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 如果模組有獲取處理器的方法,則新增到字典中
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            # 遍歷模組的所有子模組
            for sub_name, child in module.named_children():
                # 遞迴呼叫自身以處理子模組
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors  # 返回處理器字典

        # 遍歷當前模組的所有子模組
        for name, module in self.named_children():
            # 呼叫遞迴函式新增處理器
            fn_recursive_add_processors(name, module, processors)

        return processors  # 返回所有處理器的字典

    # 從 UNet2DConditionModel 的 set_attn_processor 複製而來
    # 定義一個方法用於設定注意力處理器,接受處理器引數
        def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
            r"""
            設定用於計算注意力的處理器。
    
            引數:
                processor (`dict` 或 `AttentionProcessor`):
                    例項化的處理器類或將被設定為**所有**`Attention`層的處理器類字典。
    
                    如果 `processor` 是字典,則鍵需要定義相應交叉注意力處理器的路徑。強烈建議在設定可訓練的注意力處理器時使用。
    
            """
            # 計算當前注意力處理器的數量
            count = len(self.attn_processors.keys())
    
            # 如果傳入的是字典且其長度不匹配注意力層的數量,則丟擲異常
            if isinstance(processor, dict) and len(processor) != count:
                raise ValueError(
                    f"傳入了處理器字典,但處理器數量 {len(processor)} 與注意力層數量 {count} 不匹配。請確保傳入 {count} 個處理器類。"
                )
    
            # 定義一個遞迴函式以設定每個子模組的處理器
            def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
                # 如果模組具有設定處理器的方法,則進行設定
                if hasattr(module, "set_processor"):
                    if not isinstance(processor, dict):
                        module.set_processor(processor)  # 設定單個處理器
                    else:
                        # 從字典中移除並設定對應處理器
                        module.set_processor(processor.pop(f"{name}.processor"))
    
                # 遍歷模組的子模組並遞迴呼叫
                for sub_name, child in module.named_children():
                    fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
    
            # 遍歷當前物件的所有子模組,設定處理器
            for name, module in self.named_children():
                fn_recursive_attn_processor(name, module, processor)
    
        # 定義一個方法用於設定預設的注意力處理器
        def set_default_attn_processor(self):
            """
            禁用自定義注意力處理器並設定預設的注意力實現。
            """
            # 如果所有處理器都是交叉注意力處理器,則建立預設處理器
            if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
                processor = AttnProcessor()  # 建立預設處理器例項
            else:
                # 否則丟擲異常,說明當前處理器型別不相容
                raise ValueError(
                    f"當注意力處理器的型別為 {next(iter(self.attn_processors.values()))} 時,無法呼叫 `set_default_attn_processor`。"
                )
    
            # 呼叫設定處理器的方法
            self.set_attn_processor(processor)
    
        # 應用前向鉤子,定義編碼方法
        @apply_forward_hook
        def encode(
            self, x: torch.Tensor, return_dict: bool = True
    # 定義編碼器輸出的返回型別,包含兩種可能的輸出格式
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        將一批影像編碼為潛在表示。
    
        引數:
            x (`torch.Tensor`): 輸入的影像批次。
            return_dict (`bool`, *可選*, 預設值為 `True`):
                是否返回 [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] 而不是普通元組。
    
        返回:
                編碼影像的潛在表示。如果 `return_dict` 為 True,則返回
                [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`],否則返回普通 `tuple`。
        """
        # 使用編碼器對輸入影像進行編碼,得到中間表示
        h = self.encoder(x)
        # 對編碼結果進行量化,得到矩(均值和方差)
        moments = self.quant_conv(h)
        # 根據矩生成對角高斯分佈的後驗
        posterior = DiagonalGaussianDistribution(moments)
    
        # 檢查是否需要返回普通元組
        if not return_dict:
            # 返回後驗分佈
            return (posterior,)
    
        # 返回封裝後的潛在分佈
        return AutoencoderKLOutput(latent_dist=posterior)
    
    # 應用前向鉤子裝飾器
    @apply_forward_hook
    def decode(
        # 輸入的潛在向量
        z: torch.Tensor,
        # 輸入幀的數量
        num_frames: int,
        # 是否返回字典格式的結果
        return_dict: bool = True,
    ) -> Union[DecoderOutput, torch.Tensor]:
        """
        解碼一批影像。
    
        引數:
            z (`torch.Tensor`): 輸入的潛在向量批次。
            return_dict (`bool`, *可選*, 預設值為 `True`):
                是否返回 [`~models.vae.DecoderOutput`] 而不是普通元組。
    
        返回:
            [`~models.vae.DecoderOutput`] 或 `tuple`:
                如果 return_dict 為 True,返回 [`~models.vae.DecoderOutput`],否則返回普通 `tuple`。
        """
        # 計算批次大小
        batch_size = z.shape[0] // num_frames
        # 建立影像指示器,初始為零
        image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device)
        # 解碼潛在向量,生成影像
        decoded = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator)
    
        # 檢查是否需要返回普通元組
        if not return_dict:
            # 返回解碼結果
            return (decoded,)
    
        # 返回解碼後的結果封裝
        return DecoderOutput(sample=decoded)
    
    def forward(
        # 輸入樣本
        sample: torch.Tensor,
        # 是否從後驗分佈中取樣
        sample_posterior: bool = False,
        # 是否返回字典格式的結果
        return_dict: bool = True,
        # 隨機生成器
        generator: Optional[torch.Generator] = None,
        # 輸入幀的數量
        num_frames: int = 1,
    ) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        引數:
            sample (`torch.Tensor`): 輸入樣本。
            sample_posterior (`bool`, *可選*, 預設值為 `False`):
                是否從後驗分佈中取樣。
            return_dict (`bool`, *可選*, 預設值為 `True`):
                是否返回 [`DecoderOutput`] 而不是普通元組。
        """
        # 直接將樣本賦值給 x
        x = sample
        # 編碼樣本以獲取潛在分佈
        posterior = self.encode(x).latent_dist
        # 判斷是否需要從後驗中取樣
        if sample_posterior:
            # 從後驗中取樣潛在向量
            z = posterior.sample(generator=generator)
        else:
            # 使用後驗的模值作為潛在向量
            z = posterior.mode()
    
        # 解碼潛在向量以生成影像
        dec = self.decode(z, num_frames=num_frames).sample
    
        # 檢查是否需要返回普通元組
        if not return_dict:
            # 返回解碼結果
            return (dec,)
    
        # 返回解碼結果的封裝
        return DecoderOutput(sample=dec)

相關文章