diffusers-原始碼解析-二十九-

绝不原创的飞龙發表於2024-10-22

diffusers 原始碼解析(二十九)

.\diffusers\pipelines\deprecated\stable_diffusion_variants\pipeline_stable_diffusion_model_editing.py

# 版權資訊,宣告版權和許可協議
# Copyright 2024 TIME Authors and The HuggingFace Team. All rights reserved."
# 根據 Apache License 2.0 許可協議進行許可
# 此檔案只能在遵守許可證的情況下使用
# 可透過以下網址獲取許可證副本
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律要求或書面同意,否則軟體按“現狀”分發,不提供任何形式的擔保或條件
# 具體許可條款和限制請參見許可證

# 匯入複製模組,用於物件複製
import copy
# 匯入檢查模組,用於檢查物件的資訊
import inspect
# 匯入型別提示相關的模組
from typing import Any, Callable, Dict, List, Optional, Union

# 匯入 PyTorch 庫
import torch
# 從 transformers 庫中匯入影像處理器、文字模型和標記器
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

# 從相對路徑匯入自定義影像處理器
from ....image_processor import VaeImageProcessor
# 從相對路徑匯入自定義載入器混合類
from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
# 從相對路徑匯入自定義模型類
from ....models import AutoencoderKL, UNet2DConditionModel
# 從相對路徑匯入調整 LoRA 規模的函式
from ....models.lora import adjust_lora_scale_text_encoder
# 從相對路徑匯入排程器類
from ....schedulers import PNDMScheduler
# 從排程器工具匯入排程器混合類
from ....schedulers.scheduling_utils import SchedulerMixin
# 從工具庫中匯入多個功能模組
from ....utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
# 從自定義的 PyTorch 工具庫中匯入隨機張量生成函式
from ....utils.torch_utils import randn_tensor
# 從管道工具匯入擴散管道和穩定擴散混合類
from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin
# 從穩定擴散相關模組匯入管道輸出類
from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
# 從穩定擴散安全檢查器模組匯入安全檢查器類
from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker

# 建立一個日誌記錄器,用於記錄當前模組的資訊
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 定義一個常量列表,包含不同的影像描述字首
AUGS_CONST = ["A photo of ", "An image of ", "A picture of "]

# 定義一個穩定擴散模型編輯管道類,繼承自多個基類
class StableDiffusionModelEditingPipeline(
    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
):
    r"""
    文字到影像模型編輯的管道。

    該模型繼承自 [`DiffusionPipeline`]。請查閱超類文件以獲取所有管道實現的通用方法
    (下載、儲存、在特定裝置上執行等)。

    該管道還繼承以下載入方法:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] 用於載入文字反轉嵌入
        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] 用於載入 LoRA 權重
        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] 用於儲存 LoRA 權重
    # 文件字串,描述類或方法的引數
    Args:
        vae ([`AutoencoderKL`]):
            # Variational Auto-Encoder (VAE) 模型,用於將影像編碼和解碼為潛在表示
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            # 凍結的文字編碼器,使用 CLIP 的大型視覺變換模型
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            # 用於文字標記化的 CLIPTokenizer
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            # 用於去噪編碼影像潛在的 UNet2DConditionModel
            A `UNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            # 排程器,與 unet 一起用於去噪編碼的影像潛在,可以是 DDIMScheduler、LMSDiscreteScheduler 或 PNDMScheduler
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            # 分類模組,用於估計生成的影像是否可能被視為冒犯或有害
            Classification module that estimates whether generated images could be considered offensive or harmful.
            # 參閱模型卡以獲取有關模型潛在危害的更多詳細資訊
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            # 用於從生成的影像中提取特徵的 CLIPImageProcessor;作為輸入傳遞給 safety_checker
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
        with_to_k ([`bool`]):
            # 是否在編輯文字到影像模型時編輯鍵投影矩陣與值投影矩陣
            Whether to edit the key projection matrices along with the value projection matrices.
        with_augs ([`list`]):
            # 在編輯文字到影像模型時應用的文字增強,設定為 [] 表示不進行增強
            Textual augmentations to apply while editing the text-to-image model. Set to `[]` for no augmentations.
    """

    # 定義模型的 CPU 解除安裝順序
    model_cpu_offload_seq = "text_encoder->unet->vae"
    # 定義可選元件列表
    _optional_components = ["safety_checker", "feature_extractor"]
    # 定義不參與 CPU 解除安裝的元件
    _exclude_from_cpu_offload = ["safety_checker"]

    # 初始化方法,設定模型和引數
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: SchedulerMixin,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
        with_to_k: bool = True,
        with_augs: list = AUGS_CONST,
    # 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt 複製的程式碼
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        **kwargs,
    ):
        # 定義一個棄用訊息,提示使用者 `_encode_prompt()` 已棄用,建議使用 `encode_prompt()`,並說明輸出格式變化
        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
        # 呼叫 deprecate 函式記錄棄用警告,指定版本和警告資訊,標準警告設定為 False
        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

        # 呼叫 encode_prompt 方法,將引數傳入以獲取提示嵌入的元組
        prompt_embeds_tuple = self.encode_prompt(
            prompt=prompt,  # 使用者輸入的提示文字
            device=device,  # 指定執行裝置
            num_images_per_prompt=num_images_per_prompt,  # 每個提示生成的影像數量
            do_classifier_free_guidance=do_classifier_free_guidance,  # 是否使用無分類器的引導
            negative_prompt=negative_prompt,  # 負面提示文字
            prompt_embeds=prompt_embeds,  # 現有的提示嵌入(如果有的話)
            negative_prompt_embeds=negative_prompt_embeds,  # 負面提示嵌入(如果有的話)
            lora_scale=lora_scale,  # Lora 縮放引數
            **kwargs,  # 額外引數
        )

        # 連線提示嵌入元組的兩個部分,適配舊版本的相容性
        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

        # 返回連線後的提示嵌入
        return prompt_embeds

    # 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt 複製
    def encode_prompt(
        self,
        prompt,  # 使用者輸入的提示文字
        device,  # 指定執行裝置
        num_images_per_prompt,  # 每個提示生成的影像數量
        do_classifier_free_guidance,  # 是否使用無分類器的引導
        negative_prompt=None,  # 負面提示文字,預設為 None
        prompt_embeds: Optional[torch.Tensor] = None,  # 現有的提示嵌入,預設為 None
        negative_prompt_embeds: Optional[torch.Tensor] = None,  # 負面提示嵌入,預設為 None
        lora_scale: Optional[float] = None,  # Lora 縮放引數,預設為 None
        clip_skip: Optional[int] = None,  # 跳過的剪輯層,預設為 None
    # 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker 複製
    def run_safety_checker(self, image, device, dtype):  # 定義安全檢查函式,接收影像、裝置和資料型別
        # 檢查安全檢查器是否存在
        if self.safety_checker is None:
            has_nsfw_concept = None  # 如果沒有安全檢查器,則 NSFW 概念為 None
        else:
            # 如果影像是張量型別,進行後處理,轉換為 PIL 格式
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                # 如果影像是 numpy 陣列,直接轉換為 PIL 格式
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            # 獲取安全檢查器的輸入,轉換為張量並移至指定裝置
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            # 使用安全檢查器檢查影像,返回處理後的影像和 NSFW 概念
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        # 返回處理後的影像和 NSFW 概念
        return image, has_nsfw_concept

    # 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents 複製
    # 解碼潛在向量的方法
    def decode_latents(self, latents):
        # 警告資訊,提示此方法已棄用,將在 1.0.0 中移除
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        # 呼叫 deprecate 函式記錄棄用資訊
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
    
        # 根據配置的縮放因子調整潛在向量
        latents = 1 / self.vae.config.scaling_factor * latents
        # 解碼潛在向量並返回影像資料
        image = self.vae.decode(latents, return_dict=False)[0]
        # 將影像資料歸一化到 [0, 1] 範圍內
        image = (image / 2 + 0.5).clamp(0, 1)
        # 始終將影像資料轉換為 float32 型別,以確保相容性並降低開銷
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        # 返回處理後的影像資料
        return image
    
    # 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 複製
    def prepare_extra_step_kwargs(self, generator, eta):
        # 為排程器步驟準備額外的關鍵字引數,不同排程器的簽名不同
        # eta (η) 僅在 DDIMScheduler 中使用,其他排程器會忽略
        # eta 對應於 DDIM 論文中的 η,範圍應在 [0, 1] 之間
    
        # 檢查排程器步驟的引數是否接受 eta
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 初始化額外步驟關鍵字引數字典
        extra_step_kwargs = {}
        # 如果排程器接受 eta,則將其新增到字典中
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
    
        # 檢查排程器步驟的引數是否接受 generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 如果排程器接受 generator,則將其新增到字典中
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        # 返回額外步驟關鍵字引數字典
        return extra_step_kwargs
    
    # 檢查輸入引數的方法
    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        # 檢查高度和寬度是否能被8整除
        if height % 8 != 0 or width % 8 != 0:
            # 丟擲異常,給出高度和寬度的資訊
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # 檢查回撥步數是否為正整數
        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            # 丟擲異常,給出回撥步數的資訊
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
        # 檢查回撥輸入是否在預期的輸入列表中
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            # 丟擲異常,列出不在預期列表中的輸入
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # 檢查是否同時提供了提示和提示嵌入
        if prompt is not None and prompt_embeds is not None:
            # 丟擲異常,提示只能提供一個
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        # 檢查是否同時未提供提示和提示嵌入
        elif prompt is None and prompt_embeds is None:
            # 丟擲異常,提示至少要提供一個
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        # 檢查提示型別是否合法
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            # 丟擲異常,提示型別不符合要求
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # 檢查是否同時提供了負提示和負提示嵌入
        if negative_prompt is not None and negative_prompt_embeds is not None:
            # 丟擲異常,提示只能提供一個
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # 檢查提示嵌入和負提示嵌入的形狀是否一致
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                # 丟擲異常,給出形狀不一致的資訊
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    # 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 複製
    # 準備潛在變數的函式,接受多個引數以控制形狀和生成方式
        def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
            # 定義潛在變數的形狀,包括批次大小、通道數和調整後的高度寬度
            shape = (
                batch_size,
                num_channels_latents,
                int(height) // self.vae_scale_factor,
                int(width) // self.vae_scale_factor,
            )
            # 檢查生成器列表的長度是否與批次大小一致
            if isinstance(generator, list) and len(generator) != batch_size:
                # 如果不一致,則丟擲值錯誤並提供相關資訊
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )
    
            # 如果潛在變數為空,則生成新的隨機潛在變數
            if latents is None:
                latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            else:
                # 如果已提供潛在變數,則將其轉移到指定裝置
                latents = latents.to(device)
    
            # 根據排程器要求的標準差縮放初始噪聲
            latents = latents * self.scheduler.init_noise_sigma
            # 返回準備好的潛在變數
            return latents
    
        # 裝飾器,指示該函式不需要計算梯度
        @torch.no_grad()
        def edit_model(
            self,
            source_prompt: str,
            destination_prompt: str,
            lamb: float = 0.1,
            restart_params: bool = True,
        # 裝飾器,指示該函式不需要計算梯度
        @torch.no_grad()
        def __call__(
            self,
            # 允許輸入字串或字串列表作為提示
            prompt: Union[str, List[str]] = None,
            # 可選引數,指定生成影像的高度
            height: Optional[int] = None,
            # 可選引數,指定生成影像的寬度
            width: Optional[int] = None,
            # 設定推理步驟的預設數量為50
            num_inference_steps: int = 50,
            # 設定引導比例的預設值為7.5
            guidance_scale: float = 7.5,
            # 可選引數,允許輸入負面提示
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 可選引數,指定每個提示生成的影像數量,預設值為1
            num_images_per_prompt: Optional[int] = 1,
            # 設定eta的預設值為0.0
            eta: float = 0.0,
            # 可選引數,允許輸入生成器或生成器列表
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 可選引數,允許輸入潛在變數的張量
            latents: Optional[torch.Tensor] = None,
            # 可選引數,允許輸入提示嵌入的張量
            prompt_embeds: Optional[torch.Tensor] = None,
            # 可選引數,允許輸入負面提示嵌入的張量
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 可選引數,指定輸出型別,預設為"pil"
            output_type: Optional[str] = "pil",
            # 可選引數,控制是否返回字典形式的輸出,預設為True
            return_dict: bool = True,
            # 可選回撥函式,用於處理生成過程中每一步的資訊
            callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
            # 可選引數,指定回撥的步驟數,預設為1
            callback_steps: int = 1,
            # 可選引數,允許傳入交叉注意力的關鍵字引數
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 可選引數,允許指定跳過的clip層數
            clip_skip: Optional[int] = None,

.\diffusers\pipelines\deprecated\stable_diffusion_variants\pipeline_stable_diffusion_paradigms.py

# 版權所有 2024 ParaDiGMS 作者和 HuggingFace 團隊。保留所有權利。
#
# 根據 Apache 許可證第 2.0 版(“許可證”)許可;
# 除非遵守許可證,否則您不得使用此檔案。
# 您可以在以下地址獲取許可證副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律或書面協議另有約定,依據許可證分發的軟體
# 是按“原樣”提供的,沒有任何形式的明示或暗示的擔保或條件。
# 有關許可證的特定許可權和限制,請參見許可證。

# 匯入 inspect 模組以進行物件檢查
import inspect
# 從 typing 模組匯入型別提示相關的類
from typing import Any, Callable, Dict, List, Optional, Union

# 匯入 PyTorch 庫以進行深度學習操作
import torch
# 從 transformers 庫匯入 CLIP 影像處理器、文字模型和分詞器
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

# 匯入自定義的影像處理器
from ....image_processor import VaeImageProcessor
# 匯入與載入相關的混合類
from ....loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
# 匯入用於自動編碼器和條件 UNet 的模型
from ....models import AutoencoderKL, UNet2DConditionModel
# 從 lora 模組匯入調整 lora 規模的函式
from ....models.lora import adjust_lora_scale_text_encoder
# 匯入 Karras 擴散排程器
from ....schedulers import KarrasDiffusionSchedulers
# 匯入實用工具模組中的各種功能
from ....utils import (
    USE_PEFT_BACKEND,  # 用於 PEFT 後端的標誌
    deprecate,  # 用於標記棄用功能的裝飾器
    logging,  # 日誌記錄功能
    replace_example_docstring,  # 替換示例文件字串的功能
    scale_lora_layers,  # 調整 lora 層規模的功能
    unscale_lora_layers,  # 反調整 lora 層規模的功能
)
# 從 torch_utils 模組匯入生成隨機張量的功能
from ....utils.torch_utils import randn_tensor
# 匯入擴散管道和穩定擴散混合類
from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin
# 匯入穩定擴散管道輸出類
from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
# 匯入穩定擴散安全檢查器
from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker

# 建立日誌記錄器,用於當前模組的日誌
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 示例文件字串,展示用法示例
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import DDPMParallelScheduler
        >>> from diffusers import StableDiffusionParadigmsPipeline

        >>> scheduler = DDPMParallelScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")

        >>> pipe = StableDiffusionParadigmsPipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-v1-5", scheduler=scheduler, torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")

        >>> ngpu, batch_per_device = torch.cuda.device_count(), 5
        >>> pipe.wrapped_unet = torch.nn.DataParallel(pipe.unet, device_ids=[d for d in range(ngpu)])

        >>> prompt = "a photo of an astronaut riding a horse on mars"
        >>> image = pipe(prompt, parallel=ngpu * batch_per_device, num_inference_steps=1000).images[0]
        ```py
"""

# 定義 StableDiffusionParadigmsPipeline 類,繼承多個混合類以實現功能
class StableDiffusionParadigmsPipeline(
    DiffusionPipeline,  # 從擴散管道繼承
    StableDiffusionMixin,  # 從穩定擴散混合類繼承
    TextualInversionLoaderMixin,  # 從文字反轉載入混合類繼承
    StableDiffusionLoraLoaderMixin,  # 從穩定擴散 lora 載入混合類繼承
    FromSingleFileMixin,  # 從單檔案載入混合類繼承
):
    r"""
    用於文字到影像生成的管道,使用穩定擴散的並行化版本。

    此模型繼承自 [`DiffusionPipeline`]。有關通用方法的文件,請檢視超類文件
    # 實現所有管道的功能(下載、儲存、在特定裝置上執行等)。
    
    # 管道還繼承以下載入方法:
    # - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] 用於載入文字反轉嵌入
    # - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] 用於載入 LoRA 權重
    # - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] 用於儲存 LoRA 權重
    # - [`~loaders.FromSingleFileMixin.from_single_file`] 用於載入 `.ckpt` 檔案
    
    # 引數說明:
    # vae ([`AutoencoderKL`]):
    #    變分自編碼器(VAE)模型,用於將影像編碼和解碼為潛在表示。
    # text_encoder ([`~transformers.CLIPTextModel`]):
    #    凍結的文字編碼器([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14))。
    # tokenizer ([`~transformers.CLIPTokenizer`]):
    #    一個 `CLIPTokenizer` 用於對文字進行標記化。
    # unet ([`UNet2DConditionModel`]):
    #    一個 `UNet2DConditionModel` 用於去噪編碼的影像潛在。
    # scheduler ([`SchedulerMixin`]):
    #    用於與 `unet` 結合使用的排程器,用於去噪編碼的影像潛在。可以是
    #    [`DDIMScheduler`], [`LMSDiscreteScheduler`] 或 [`PNDMScheduler`] 中的一個。
    # safety_checker ([`StableDiffusionSafetyChecker`]):
    #    分類模組,估計生成的影像是否可能被認為是冒犯或有害的。
    #    請參考 [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) 獲取更多關於模型潛在危害的詳細資訊。
    # feature_extractor ([`~transformers.CLIPImageProcessor`]):
    #    一個 `CLIPImageProcessor` 用於從生成的影像中提取特徵;作為 `safety_checker` 的輸入。
    
    # 定義模型的 CPU 離線載入順序
    model_cpu_offload_seq = "text_encoder->unet->vae"
    # 定義可選元件列表
    _optional_components = ["safety_checker", "feature_extractor"]
    # 定義排除在 CPU 離線載入之外的元件
    _exclude_from_cpu_offload = ["safety_checker"]
    
    # 初始化方法,接受多個引數
    def __init__(
        self,
        vae: AutoencoderKL,  # 變分自編碼器模型
        text_encoder: CLIPTextModel,  # 文字編碼器
        tokenizer: CLIPTokenizer,  # 文字標記器
        unet: UNet2DConditionModel,  # UNet2D 條件模型
        scheduler: KarrasDiffusionSchedulers,  # 排程器
        safety_checker: StableDiffusionSafetyChecker,  # 安全檢查模組
        feature_extractor: CLIPImageProcessor,  # 特徵提取器
        requires_safety_checker: bool = True,  # 是否需要安全檢查器
    ):
        # 呼叫父類的初始化方法
        super().__init__()

        # 檢查是否禁用安全檢查器,並且需要安全檢查器
        if safety_checker is None and requires_safety_checker:
            # 記錄警告資訊,提醒使用者禁用安全檢查器的風險
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # 檢查是否提供了安全檢查器但未提供特徵提取器
        if safety_checker is not None and feature_extractor is None:
            # 丟擲錯誤,提示使用者需要定義特徵提取器
            raise ValueError(
                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        # 註冊各個模組到當前例項
        self.register_modules(
            vae=vae,  # 變分自編碼器
            text_encoder=text_encoder,  # 文字編碼器
            tokenizer=tokenizer,  # 分詞器
            unet=unet,  # U-Net 模型
            scheduler=scheduler,  # 排程器
            safety_checker=safety_checker,  # 安全檢查器
            feature_extractor=feature_extractor,  # 特徵提取器
        )
        # 計算 VAE 的縮放因子
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        # 初始化影像處理器,使用 VAE 縮放因子
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        # 將是否需要安全檢查器的配置註冊到當前例項
        self.register_to_config(requires_safety_checker=requires_safety_checker)

        # 用於在多個 GPU 上執行多個去噪步驟時,將 unet 包裝為 torch.nn.DataParallel
        self.wrapped_unet = self.unet

    # 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt 複製的函式
    def _encode_prompt(
        self,
        prompt,  # 輸入的提示文字
        device,  # 裝置型別(CPU或GPU)
        num_images_per_prompt,  # 每個提示生成的影像數量
        do_classifier_free_guidance,  # 是否使用無分類器引導
        negative_prompt=None,  # 可選的負面提示文字
        prompt_embeds: Optional[torch.Tensor] = None,  # 可選的提示嵌入
        negative_prompt_embeds: Optional[torch.Tensor] = None,  # 可選的負面提示嵌入
        lora_scale: Optional[float] = None,  # 可選的 LoRA 縮放因子
        **kwargs,  # 其他可選引數
    ):
        # 設定棄用資訊,提醒使用者該方法即將被移除,建議使用新方法
        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
        # 呼叫 deprecate 函式,傳遞棄用資訊和版本號
        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

        # 呼叫 encode_prompt 方法,獲取提示的嵌入元組
        prompt_embeds_tuple = self.encode_prompt(
            prompt=prompt,  # 輸入提示文字
            device=device,  # 計算裝置
            num_images_per_prompt=num_images_per_prompt,  # 每個提示生成的影像數量
            do_classifier_free_guidance=do_classifier_free_guidance,  # 是否使用無分類器引導
            negative_prompt=negative_prompt,  # 負提示文字
            prompt_embeds=prompt_embeds,  # 提示嵌入
            negative_prompt_embeds=negative_prompt_embeds,  # 負提示嵌入
            lora_scale=lora_scale,  # LORA 縮放因子
            **kwargs,  # 其他額外引數
        )

        # 將返回的嵌入元組進行拼接,以支援向後相容
        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

        # 返回拼接後的提示嵌入
        return prompt_embeds

    # 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline 中複製的 encode_prompt 方法
    def encode_prompt(
        self,
        prompt,  # 輸入的提示文字
        device,  # 計算裝置
        num_images_per_prompt,  # 每個提示生成的影像數量
        do_classifier_free_guidance,  # 是否使用無分類器引導
        negative_prompt=None,  # 負提示文字,預設為 None
        prompt_embeds: Optional[torch.Tensor] = None,  # 可選的提示嵌入
        negative_prompt_embeds: Optional[torch.Tensor] = None,  # 可選的負提示嵌入
        lora_scale: Optional[float] = None,  # 可選的 LORA 縮放因子
        clip_skip: Optional[int] = None,  # 可選的跳過剪輯引數
    # 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker 複製的程式碼
    def run_safety_checker(self, image, device, dtype):
        # 檢查是否存在安全檢查器
        if self.safety_checker is None:
            has_nsfw_concept = None  # 如果沒有安全檢查器,則設定無 NSFW 概念為 None
        else:
            # 如果輸入是張量格式,則進行後處理為 PIL 格式
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                # 如果輸入不是張量,轉換為 PIL 格式
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            # 使用特徵提取器處理影像,返回張量形式
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            # 呼叫安全檢查器,檢查影像是否包含 NSFW 概念
            image, has_nsfw_concept = self.safety_checker(
                images=image,  # 輸入影像
                clip_input=safety_checker_input.pixel_values.to(dtype)  # 安全檢查的特徵輸入
            )
        # 返回處理後的影像及是否存在 NSFW 概念
        return image, has_nsfw_concept

    # 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 複製的程式碼
    # 定義一個方法,用於準備額外的引數以供排程器步驟使用
    def prepare_extra_step_kwargs(self, generator, eta):
        # 為排程器步驟準備額外的關鍵字引數,因為並非所有排程器都有相同的簽名
        # eta (η) 僅在 DDIMScheduler 中使用,對於其他排程器將被忽略
        # eta 對應於 DDIM 論文中的 η: https://arxiv.org/abs/2010.02502
        # 其值應在 [0, 1] 之間

        # 檢查排程器的 step 方法是否接受 eta 引數
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 初始化一個字典以儲存額外的步驟引數
        extra_step_kwargs = {}
        # 如果接受 eta 引數,則將其新增到字典中
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # 檢查排程器的 step 方法是否接受 generator 引數
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 如果接受 generator 引數,則將其新增到字典中
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        # 返回準備好的額外步驟引數字典
        return extra_step_kwargs

    # 定義一個方法,用於檢查輸入引數的有效性
    def check_inputs(
        self,
        prompt,  # 輸入的提示文字
        height,  # 影像的高度
        width,   # 影像的寬度
        callback_steps,  # 回撥步驟的頻率
        negative_prompt=None,  # 可選的負面提示文字
        prompt_embeds=None,  # 可選的提示嵌入
        negative_prompt_embeds=None,  # 可選的負面提示嵌入
        callback_on_step_end_tensor_inputs=None,  # 可選的在步驟結束時的回撥張量輸入
    ):
        # 檢查高度和寬度是否為 8 的倍數
        if height % 8 != 0 or width % 8 != 0:
            # 丟擲錯誤,如果高度或寬度不符合要求
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # 檢查回撥步數是否為正整數
        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            # 丟擲錯誤,如果回撥步數不符合要求
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
        # 檢查回撥結束時的張量輸入是否在允許的輸入中
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            # 丟擲錯誤,如果不在允許的輸入中
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # 檢查是否同時提供了提示和提示嵌入
        if prompt is not None and prompt_embeds is not None:
            # 丟擲錯誤,不能同時提供提示和提示嵌入
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        # 檢查提示和提示嵌入是否均未定義
        elif prompt is None and prompt_embeds is None:
            # 丟擲錯誤,必須提供一個提示或提示嵌入
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        # 檢查提示的型別
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            # 丟擲錯誤,如果提示不是字串或列表
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # 檢查是否同時提供了負提示和負提示嵌入
        if negative_prompt is not None and negative_prompt_embeds is not None:
            # 丟擲錯誤,不能同時提供負提示和負提示嵌入
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # 檢查提示嵌入和負提示嵌入的形狀是否相同
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                # 丟擲錯誤,如果形狀不一致
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    # 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 複製的程式碼
    # 準備潛在變數,設定其形狀和屬性
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        # 定義潛在變數的形狀,考慮批次大小、通道數和縮放因子
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        # 檢查生成器列表的長度是否與批次大小匹配
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        # 如果潛在變數未提供,則生成隨機潛在變數
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            # 將提供的潛在變數移動到指定裝置
            latents = latents.to(device)
    
        # 根據排程器要求的標準差縮放初始噪聲
        latents = latents * self.scheduler.init_noise_sigma
        # 返回處理後的潛在變數
        return latents
    
    # 計算輸入張量在指定維度上的累積和
    def _cumsum(self, input, dim, debug=False):
        # 如果除錯模式開啟,則在CPU上執行累積和以確保可重複性
        if debug:
            # cumsum_cuda_kernel沒有確定性實現,故在CPU上執行
            return torch.cumsum(input.cpu().float(), dim=dim).to(input.device)
        else:
            # 在指定維度上直接計算累積和
            return torch.cumsum(input, dim=dim)
    
    # 呼叫方法,接受多種引數以生成輸出
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        parallel: int = 10,
        tolerance: float = 0.1,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        debug: bool = False,
        clip_skip: int = None,

.\diffusers\pipelines\deprecated\stable_diffusion_variants\pipeline_stable_diffusion_pix2pix_zero.py

# 版權宣告,說明版權資訊及持有者
# Copyright 2024 Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved.
#
# 使用 Apache License 2.0 許可協議
# Licensed under the Apache License, Version 2.0 (the "License");
# 該檔案僅在遵循許可協議的情況下使用
# you may not use this file except in compliance with the License.
# 許可協議的獲取連結
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用的法律要求或書面協議,否則軟體按“原樣”分發
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何明示或暗示的保證或條件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 檢視許可協議中關於許可權和限制的詳細資訊
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect  # 匯入 inspect 模組,用於獲取物件資訊
from dataclasses import dataclass  # 從 dataclasses 模組匯入 dataclass 裝飾器
from typing import Any, Callable, Dict, List, Optional, Union  # 匯入型別提示

import numpy as np  # 匯入 numpy 模組,常用於數值計算
import PIL.Image  # 匯入 PIL.Image 模組,用於影像處理
import torch  # 匯入 PyTorch 庫,主要用於深度學習
import torch.nn.functional as F  # 匯入 PyTorch 的功能性神經網路模組
from transformers import (  # 從 transformers 模組匯入以下類
    BlipForConditionalGeneration,  # 匯入用於條件生成的 Blip 模型
    BlipProcessor,  # 匯入 Blip 處理器
    CLIPImageProcessor,  # 匯入 CLIP 影像處理器
    CLIPTextModel,  # 匯入 CLIP 文字模型
    CLIPTokenizer,  # 匯入 CLIP 分詞器
)

from ....image_processor import PipelineImageInput, VaeImageProcessor  # 從自定義模組匯入影像處理相關類
from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin  # 匯入穩定擴散和文字反轉載入器混合類
from ....models import AutoencoderKL, UNet2DConditionModel  # 匯入自動編碼器和 UNet 模型
from ....models.attention_processor import Attention  # 匯入注意力處理器
from ....models.lora import adjust_lora_scale_text_encoder  # 匯入調整 Lora 文字編碼器規模的函式
from ....schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler  # 匯入多種排程器
from ....schedulers.scheduling_ddim_inverse import DDIMInverseScheduler  # 匯入 DDIM 反向排程器
from ....utils import (  # 從自定義工具模組匯入實用函式和常量
    PIL_INTERPOLATION,  # 匯入 PIL 影像插值方法
    USE_PEFT_BACKEND,  # 匯入是否使用 PEFT 後端的常量
    BaseOutput,  # 匯入基礎輸出類
    deprecate,  # 匯入廢棄標記裝飾器
    logging,  # 匯入日誌記錄模組
    replace_example_docstring,  # 匯入替換示例文件字串的函式
    scale_lora_layers,  # 匯入縮放 Lora 層的函式
    unscale_lora_layers,  # 匯入反縮放 Lora 層的函式
)
from ....utils.torch_utils import randn_tensor  # 從 PyTorch 工具模組匯入生成隨機張量的函式
from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin  # 從管道工具模組匯入擴散管道和穩定擴散混合類
from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput  # 匯入穩定擴散管道輸出類
from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker  # 匯入穩定擴散安全檢查器

logger = logging.get_logger(__name__)  # 獲取當前模組的日誌記錄器

@dataclass  # 將下面的類定義為資料類
class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin):  # 定義輸出類,繼承基礎輸出和文字反轉載入器混合類
    """
    輸出類用於穩定擴散管道。

    引數:
        latents (`torch.Tensor`)
            反轉的潛在張量
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            長度為 `batch_size` 的去噪 PIL 影像列表或形狀為 `(batch_size, height, width,
            num_channels)` 的 numpy 陣列。PIL 影像或 numpy 陣列呈現擴散管道的去噪影像。
    """

    latents: torch.Tensor  # 定義潛在張量屬性
    images: Union[List[PIL.Image.Image], np.ndarray]  # 定義影像屬性,可以是影像列表或 numpy 陣列

EXAMPLE_DOC_STRING = """  # 定義示例文件字串的初始部分
```  # 示例文件字串的開始
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
```py  # 示例文件字串的結束
```  # 示例文件字串的結束
    # 示例程式碼展示如何使用 Diffusers 庫進行影像生成
    Examples:
        ```py
        # 匯入所需的庫
        >>> import requests  # 用於傳送 HTTP 請求
        >>> import torch  # 用於處理張量和深度學習模型

        # 從 Diffusers 庫匯入必要的類
        >>> from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline

        # 定義下載嵌入檔案的函式
        >>> def download(embedding_url, local_filepath):
        ...     # 傳送 GET 請求獲取嵌入檔案
        ...     r = requests.get(embedding_url)
        ...     # 以二進位制模式開啟本地檔案並寫入獲取的內容
        ...     with open(local_filepath, "wb") as f:
        ...         f.write(r.content)

        # 定義模型檢查點的名稱
        >>> model_ckpt = "CompVis/stable-diffusion-v1-4"
        # 從預訓練模型載入管道並設定資料型別為 float16
        >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16)
        # 根據管道配置建立 DDIM 排程器
        >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
        # 將模型移動到 GPU
        >>> pipeline.to("cuda")

        # 定義文字提示
        >>> prompt = "a high resolution painting of a cat in the style of van gough"
        # 定義源和目標嵌入檔案的 URL
        >>> source_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt"
        >>> target_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt"

        # 遍歷源和目標嵌入 URL 進行下載
        >>> for url in [source_emb_url, target_emb_url]:
        ...     # 呼叫下載函式,將檔案儲存到本地
        ...     download(url, url.split("/")[-1])

        # 從本地載入源嵌入
        >>> src_embeds = torch.load(source_emb_url.split("/")[-1])
        # 從本地載入目標嵌入
        >>> target_embeds = torch.load(target_emb_url.split("/")[-1])
        # 使用管道生成影像
        >>> images = pipeline(
        ...     prompt,  # 輸入的文字提示
        ...     source_embeds=src_embeds,  # 源嵌入
        ...     target_embeds=target_embeds,  # 目標嵌入
        ...     num_inference_steps=50,  # 推理步驟數
        ...     cross_attention_guidance_amount=0.15,  # 跨注意力引導的強度
        ... ).images  # 生成的影像

        # 儲存生成的第一張影像
        >>> images[0].save("edited_image_dog.png")  # 將影像儲存為 PNG 檔案
"""
# 示例文件字串,提供了使用示例和說明
EXAMPLE_INVERT_DOC_STRING = """
    Examples:
        ```py
        >>> import torch  # 匯入 PyTorch 庫
        >>> from transformers import BlipForConditionalGeneration, BlipProcessor  # 從 transformers 匯入模型和處理器
        >>> from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline  # 從 diffusers 匯入排程器和管道

        >>> import requests  # 匯入 requests 庫,用於傳送網路請求
        >>> from PIL import Image  # 從 PIL 匯入 Image 類,用於處理影像

        >>> captioner_id = "Salesforce/blip-image-captioning-base"  # 定義影像說明生成模型的 ID
        >>> processor = BlipProcessor.from_pretrained(captioner_id)  # 從預訓練模型載入處理器
        >>> model = BlipForConditionalGeneration.from_pretrained(  # 從預訓練模型載入影像說明生成模型
        ...     captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True  # 指定資料型別和低記憶體使用模式
        ... )

        >>> sd_model_ckpt = "CompVis/stable-diffusion-v1-4"  # 定義穩定擴散模型的檢查點 ID
        >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(  # 從預訓練模型載入 Pix2Pix 零管道
        ...     sd_model_ckpt,  # 指定檢查點
        ...     caption_generator=model,  # 指定影像說明生成器
        ...     caption_processor=processor,  # 指定影像說明處理器
        ...     torch_dtype=torch.float16,  # 指定資料型別
        ...     safety_checker=None,  # 關閉安全檢查器
        ... )

        >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)  # 使用排程器配置初始化 DDIM 排程器
        >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)  # 使用排程器配置初始化 DDIM 反向排程器
        >>> pipeline.enable_model_cpu_offload()  # 啟用模型的 CPU 解除安裝

        >>> img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"  # 定義要處理的影像 URL

        >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))  # 從 URL 載入影像並調整大小
        >>> # 生成說明
        >>> caption = pipeline.generate_caption(raw_image)  # 生成影像的說明

        >>> # "a photography of a cat with flowers and dai dai daie - daie - daie kasaii"  # 生成的說明示例
        >>> inv_latents = pipeline.invert(caption, image=raw_image).latents  # 根據說明和原始影像進行反向處理,獲取潛變數
        >>> # 我們需要生成源和目標嵌入

        >>> source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]  # 定義源提示列表

        >>> target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"]  # 定義目標提示列表

        >>> source_embeds = pipeline.get_embeds(source_prompts)  # 獲取源提示的嵌入表示
        >>> target_embeds = pipeline.get_embeds(target_prompts)  # 獲取目標提示的嵌入表示
        >>> # 潛變數可以用於編輯真實影像
        >>> # 在使用穩定擴散 2 或其他使用 v-prediction 的模型時
        >>> # 將 `cross_attention_guidance_amount` 設定為 0.01 或更低,以避免輸入潛變數梯度爆炸

        >>> image = pipeline(  # 使用管道生成新的影像
        ...     caption,  # 使用生成的說明
        ...     source_embeds=source_embeds,  # 傳遞源嵌入
        ...     target_embeds=target_embeds,  # 傳遞目標嵌入
        ...     num_inference_steps=50,  # 指定推理步驟數量
        ...     cross_attention_guidance_amount=0.15,  # 指定交叉注意力指導量
        ...     generator=generator,  # 使用指定的生成器
        ...     latents=inv_latents,  # 傳遞潛變數
        ...     negative_prompt=caption,  # 使用生成的說明作為負提示
        ... ).images[0]  # 獲取生成的影像
        >>> image.save("edited_image.png")  # 儲存生成的影像
        ```py
"""


# 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img 匯入的 preprocess 函式
def preprocess(image):  # 定義 preprocess 函式,接受影像作為引數
    # 設定一個警告資訊,提示使用者 preprocess 方法已被棄用,並將在 diffusers 1.0.0 中移除
    deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
    # 呼叫 deprecate 函式,記錄棄用資訊,設定標準警告為 False
    deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
    # 檢查輸入的 image 是否是一個 Torch 張量
    if isinstance(image, torch.Tensor):
        # 如果是,直接返回該張量
        return image
    # 檢查輸入的 image 是否是一個 PIL 影像
    elif isinstance(image, PIL.Image.Image):
        # 如果是,將其封裝為一個單元素列表
        image = [image]

    # 檢查列表中的第一個元素是否為 PIL 影像
    if isinstance(image[0], PIL.Image.Image):
        # 獲取第一個影像的寬度和高度
        w, h = image[0].size
        # 將寬度和高度調整為 8 的整數倍
        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8

        # 對每個影像進行調整大小,轉換為 numpy 陣列,並在新維度上增加一維
        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        # 將所有影像在第 0 維上連線成一個大的陣列
        image = np.concatenate(image, axis=0)
        # 將資料轉換為 float32 型別並歸一化到 [0, 1] 範圍
        image = np.array(image).astype(np.float32) / 255.0
        # 調整陣列維度順序為 (batch_size, channels, height, width)
        image = image.transpose(0, 3, 1, 2)
        # 將畫素值範圍從 [0, 1] 轉換到 [-1, 1]
        image = 2.0 * image - 1.0
        # 將 numpy 陣列轉換為 Torch 張量
        image = torch.from_numpy(image)
    # 檢查列表中的第一個元素是否為 Torch 張量
    elif isinstance(image[0], torch.Tensor):
        # 將多個張量在第 0 維上連線成一個大的張量
        image = torch.cat(image, dim=0)
    # 返回處理後的影像
    return image
# 準備 UNet 模型以執行 Pix2Pix Zero 最佳化
def prepare_unet(unet: UNet2DConditionModel):
    # 初始化一個空字典,用於儲存 Pix2Pix Zero 注意力處理器
    pix2pix_zero_attn_procs = {}
    # 遍歷 UNet 的注意力處理器的鍵
    for name in unet.attn_processors.keys():
        # 將處理器名稱中的 ".processor" 替換為空
        module_name = name.replace(".processor", "")
        # 獲取 UNet 中對應的子模組
        module = unet.get_submodule(module_name)
        # 如果名稱包含 "attn2"
        if "attn2" in name:
            # 將處理器設定為 Pix2Pix Zero 模式
            pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=True)
            # 允許該模組進行梯度更新
            module.requires_grad_(True)
        else:
            # 設定為非 Pix2Pix Zero 模式
            pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=False)
            # 不允許該模組進行梯度更新
            module.requires_grad_(False)

    # 設定 UNet 的注意力處理器為修改後的處理器字典
    unet.set_attn_processor(pix2pix_zero_attn_procs)
    # 返回修改後的 UNet 模型
    return unet


class Pix2PixZeroL2Loss:
    # 初始化損失類
    def __init__(self):
        # 設定初始損失值為 0
        self.loss = 0.0

    # 計算損失的方法
    def compute_loss(self, predictions, targets):
        # 更新損失值為預測值與目標值之間的均方差
        self.loss += ((predictions - targets) ** 2).sum((1, 2)).mean(0)


class Pix2PixZeroAttnProcessor:
    """注意力處理器類,用於儲存注意力權重。
    在 Pix2Pix Zero 中,該過程發生在交叉注意力塊的計算中。"""

    # 初始化注意力處理器
    def __init__(self, is_pix2pix_zero=False):
        # 記錄是否為 Pix2Pix Zero 模式
        self.is_pix2pix_zero = is_pix2pix_zero
        # 如果是 Pix2Pix Zero 模式,初始化參考交叉注意力對映
        if self.is_pix2pix_zero:
            self.reference_cross_attn_map = {}

    # 定義呼叫方法
    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        timestep=None,
        loss=None,
    ):
        # 獲取隱藏狀態的批次大小和序列長度
        batch_size, sequence_length, _ = hidden_states.shape
        # 準備注意力掩碼
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        # 將隱藏狀態轉換為查詢向量
        query = attn.to_q(hidden_states)

        # 如果沒有編碼器隱藏狀態,則使用隱藏狀態本身
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        # 如果需要進行交叉規範化
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # 將編碼器隱藏狀態轉換為鍵和值
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # 將查詢、鍵和值轉換為批次維度
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # 計算注意力分數
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        # 如果是 Pix2Pix Zero 模式且時間步不為 None
        if self.is_pix2pix_zero and timestep is not None:
            # 新的記錄以儲存注意力權重
            if loss is None:
                self.reference_cross_attn_map[timestep.item()] = attention_probs.detach().cpu()
            # 計算損失
            elif loss is not None:
                # 獲取之前的注意力機率
                prev_attn_probs = self.reference_cross_attn_map.pop(timestep.item())
                # 計算損失
                loss.compute_loss(attention_probs, prev_attn_probs.to(attention_probs.device))

        # 將注意力機率與值相乘以獲得新的隱藏狀態
        hidden_states = torch.bmm(attention_probs, value)
        # 將隱藏狀態轉換回頭維度
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # 線性變換
        hidden_states = attn.to_out[0](hidden_states)
        # 應用 dropout
        hidden_states = attn.to_out[1](hidden_states)

        # 返回新的隱藏狀態
        return hidden_states
# 定義一個用於畫素級影像編輯的管道類,基於 Stable Diffusion
class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
    r"""
    使用 Pix2Pix Zero 進行畫素級影像編輯的管道。基於 Stable Diffusion。

    該模型繼承自 [`DiffusionPipeline`]。請查閱超類文件以獲取庫為所有管道實現的通用方法(例如下載或儲存、在特定裝置上執行等)。

    引數:
        vae ([`AutoencoderKL`]):
            用於將影像編碼和解碼到潛在表示的變分自編碼器(VAE)模型。
        text_encoder ([`CLIPTextModel`]):
            凍結的文字編碼器。Stable Diffusion 使用
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) 的文字部分,
            特別是 [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) 變體。
        tokenizer (`CLIPTokenizer`):
            類的分詞器
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)。
        unet ([`UNet2DConditionModel`]): 用於去噪編碼影像潛在的條件 U-Net 架構。
        scheduler ([`SchedulerMixin`]):
            與 `unet` 一起使用以去噪編碼影像潛在的排程器。可以是
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] 或 [`DDPMScheduler`] 的之一。
        safety_checker ([`StableDiffusionSafetyChecker`]):
            估計生成影像是否可能被視為攻擊性或有害的分類模組。
            請參閱 [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) 以獲取詳細資訊。
        feature_extractor ([`CLIPImageProcessor`]):
            從生成影像中提取特徵的模型,以便作為 `safety_checker` 的輸入。
        requires_safety_checker (bool):
            管道是否需要安全檢查器。如果您公開使用該管道,我們建議將其設定為 True。
    """

    # 定義 CPU 解除安裝的模型元件順序
    model_cpu_offload_seq = "text_encoder->unet->vae"
    # 可選元件列表
    _optional_components = [
        "safety_checker",
        "feature_extractor",
        "caption_generator",
        "caption_processor",
        "inverse_scheduler",
    ]
    # 從 CPU 解除安裝中排除的元件列表
    _exclude_from_cpu_offload = ["safety_checker"]

    # 初始化方法,定義管道的引數
    def __init__(
        self,
        vae: AutoencoderKL,  # 變分自編碼器模型
        text_encoder: CLIPTextModel,  # 文字編碼器模型
        tokenizer: CLIPTokenizer,  # 分詞器模型
        unet: UNet2DConditionModel,  # 條件 U-Net 模型
        scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler],  # 排程器型別
        feature_extractor: CLIPImageProcessor,  # 特徵提取器模型
        safety_checker: StableDiffusionSafetyChecker,  # 安全檢查器模型
        inverse_scheduler: DDIMInverseScheduler,  # 反向排程器
        caption_generator: BlipForConditionalGeneration,  # 描述生成器
        caption_processor: BlipProcessor,  # 描述處理器
        requires_safety_checker: bool = True,  # 是否需要安全檢查器的標誌
    # 定義一個建構函式
        ):
            # 呼叫父類的建構函式
            super().__init__()
    
            # 如果沒有提供安全檢查器且需要安全檢查器,發出警告
            if safety_checker is None and requires_safety_checker:
                logger.warning(
                    # 輸出關於禁用安全檢查器的警告資訊
                    f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                    " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                    " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                    " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                    " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                    " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
                )
    
            # 如果提供了安全檢查器但沒有提供特徵提取器,丟擲錯誤
            if safety_checker is not None and feature_extractor is None:
                raise ValueError(
                    # 提示使用者必須定義特徵提取器以使用安全檢查器
                    "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                    " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
                )
    
            # 註冊模組,設定各個元件
            self.register_modules(
                vae=vae,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                unet=unet,
                scheduler=scheduler,
                safety_checker=safety_checker,
                feature_extractor=feature_extractor,
                caption_processor=caption_processor,
                caption_generator=caption_generator,
                inverse_scheduler=inverse_scheduler,
            )
            # 計算 VAE 的縮放因子
            self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
            # 建立影像處理器,使用 VAE 縮放因子
            self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
            # 將配置項註冊到當前例項
            self.register_to_config(requires_safety_checker=requires_safety_checker)
    
        # 從 StableDiffusionPipeline 類複製的編碼提示的方法
        def _encode_prompt(
            self,
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt=None,
            # 可選引數,表示提示的嵌入
            prompt_embeds: Optional[torch.Tensor] = None,
            # 可選引數,表示負面提示的嵌入
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 可選引數,表示 LORA 的縮放因子
            lora_scale: Optional[float] = None,
            # 接收任意額外引數
            **kwargs,
    # 開始定義一個方法,處理已棄用的編碼提示功能
        ):
            # 定義棄用資訊,說明該方法將被移除,並推薦使用新方法
            deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
            # 呼叫棄用函式,記錄棄用警告
            deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
    
            # 呼叫新的編碼提示方法,獲取結果元組
            prompt_embeds_tuple = self.encode_prompt(
                # 傳入提示文字
                prompt=prompt,
                # 裝置型別(CPU/GPU)
                device=device,
                # 每個提示生成的影像數量
                num_images_per_prompt=num_images_per_prompt,
                # 是否進行分類器自由引導
                do_classifier_free_guidance=do_classifier_free_guidance,
                # 負面提示文字
                negative_prompt=negative_prompt,
                # 提示嵌入
                prompt_embeds=prompt_embeds,
                # 負面提示嵌入
                negative_prompt_embeds=negative_prompt_embeds,
                # Lora縮放因子
                lora_scale=lora_scale,
                # 其他可選引數
                **kwargs,
            )
    
            # 將返回的元組中的兩個嵌入連線起來以相容舊版
            prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
    
            # 返回最終的提示嵌入
            return prompt_embeds
    
        # 從指定的管道中複製的 encode_prompt 方法定義
        def encode_prompt(
            # 提示文字
            self,
            prompt,
            # 裝置型別
            device,
            # 每個提示生成的影像數量
            num_images_per_prompt,
            # 是否進行分類器自由引導
            do_classifier_free_guidance,
            # 負面提示文字(可選)
            negative_prompt=None,
            # 提示嵌入(可選)
            prompt_embeds: Optional[torch.Tensor] = None,
            # 負面提示嵌入(可選)
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            # Lora縮放因子(可選)
            lora_scale: Optional[float] = None,
            # 跳過的clip層數(可選)
            clip_skip: Optional[int] = None,
        # 從指定的管道中複製的 run_safety_checker 方法定義
        def run_safety_checker(self, image, device, dtype):
            # 檢查是否存在安全檢查器
            if self.safety_checker is None:
                # 如果沒有安全檢查器,標記為無概念
                has_nsfw_concept = None
            else:
                # 如果影像是張量,則進行後處理以轉換為PIL格式
                if torch.is_tensor(image):
                    feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
                else:
                    # 如果不是張量,則將其轉換為PIL格式
                    feature_extractor_input = self.image_processor.numpy_to_pil(image)
                # 將處理後的影像提取特徵,準備進行安全檢查
                safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
                # 使用安全檢查器檢查影像,返回影像及其概念狀態
                image, has_nsfw_concept = self.safety_checker(
                    images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
                )
            # 返回檢查後的影像及其概念狀態
            return image, has_nsfw_concept
    
        # 從指定的管道中複製的 decode_latents 方法
    # 解碼潛在向量並返回生成的影像
    def decode_latents(self, latents):
        # 警告使用者該方法已過時,將在1.0.0版本中刪除
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        # 呼叫deprecate函式記錄該方法的棄用資訊
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
    
        # 使用配置的縮放因子對潛在向量進行縮放
        latents = 1 / self.vae.config.scaling_factor * latents
        # 解碼潛在向量,返回生成的影像
        image = self.vae.decode(latents, return_dict=False)[0]
        # 將影像值從[-1, 1]對映到[0, 1]並限制範圍
        image = (image / 2 + 0.5).clamp(0, 1)
        # 將影像轉換為float32格式以確保相容性,並將其轉換為numpy陣列
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        # 返回處理後的影像
        return image
    
    # 從diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs複製
    def prepare_extra_step_kwargs(self, generator, eta):
        # 準備額外的引數以供排程器步驟使用,因排程器的引數簽名可能不同
        # eta (η) 僅在DDIMScheduler中使用,其他排程器將忽略它。
        # eta在DDIM論文中對應於η:https://arxiv.org/abs/2010.02502
        # eta的值應在[0, 1]之間
    
        # 檢查排程器步驟是否接受eta引數
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 初始化額外引數字典
        extra_step_kwargs = {}
        # 如果接受eta,則將其新增到額外引數中
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
    
        # 檢查排程器步驟是否接受generator引數
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 如果接受generator,則將其新增到額外引數中
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        # 返回準備好的額外引數
        return extra_step_kwargs
    
    def check_inputs(
        self,
        prompt,
        source_embeds,
        target_embeds,
        callback_steps,
        prompt_embeds=None,
    ):
        # 檢查callback_steps是否為正整數
        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
        # 確保source_embeds和target_embeds不能同時未定義
        if source_embeds is None and target_embeds is None:
            raise ValueError("`source_embeds` and `target_embeds` cannot be undefined.")
    
        # 檢查prompt和prompt_embeds不能同時被定義
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        # 檢查prompt和prompt_embeds不能同時未定義
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        # 確保prompt的型別為str或list
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    #  從 StableDiffusionPipeline 的 prepare_latents 方法複製的內容
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        # 定義潛在張量的形狀,包括批次大小、通道數、高度和寬度
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        # 檢查生成器是否為列表且其長度與批次大小匹配
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # 如果未提供潛在張量,則生成隨機張量
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            # 如果提供了潛在張量,則將其移動到指定裝置
            latents = latents.to(device)

        # 按排程器所需的標準差縮放初始噪聲
        latents = latents * self.scheduler.init_noise_sigma
        # 返回處理後的潛在張量
        return latents

    @torch.no_grad()
    def generate_caption(self, images):
        """為給定影像生成標題。"""
        # 初始化生成標題的文字
        text = "a photography of"

        # 儲存當前裝置
        prev_device = self.caption_generator.device

        # 獲取執行裝置
        device = self._execution_device
        # 處理輸入影像並轉換為張量
        inputs = self.caption_processor(images, text, return_tensors="pt").to(
            device=device, dtype=self.caption_generator.dtype
        )
        # 將標題生成器移動到指定裝置
        self.caption_generator.to(device)
        # 生成標題輸出
        outputs = self.caption_generator.generate(**inputs, max_new_tokens=128)

        # 將標題生成器移回先前裝置
        self.caption_generator.to(prev_device)

        # 解碼輸出以獲取標題
        caption = self.caption_processor.batch_decode(outputs, skip_special_tokens=True)[0]
        # 返回生成的標題
        return caption

    def construct_direction(self, embs_source: torch.Tensor, embs_target: torch.Tensor):
        """構造用於引導影像生成過程的編輯方向。"""
        # 返回目標和源嵌入的均值之差,並增加一個維度
        return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0)

    @torch.no_grad()
    def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch.Tensor:
        # 獲取提示的數量
        num_prompts = len(prompt)
        # 初始化嵌入列表
        embeds = []
        # 分批處理提示
        for i in range(0, num_prompts, batch_size):
            prompt_slice = prompt[i : i + batch_size]

            # 將提示轉換為輸入 ID,進行填充和截斷
            input_ids = self.tokenizer(
                prompt_slice,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).input_ids

            # 將輸入 ID 移動到文字編碼器裝置
            input_ids = input_ids.to(self.text_encoder.device)
            # 獲取嵌入並追加到列表
            embeds.append(self.text_encoder(input_ids)[0])

        # 將所有嵌入拼接並計算均值
        return torch.cat(embeds, dim=0).mean(0)[None]
    # 準備影像的潛在表示,接收影像和其他引數,返回潛在向量
        def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):
            # 檢查輸入影像的型別是否為有效型別
            if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
                # 丟擲型別錯誤,提示使用者輸入型別不正確
                raise ValueError(
                    f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
                )
    
            # 將影像轉換到指定的裝置和資料型別
            image = image.to(device=device, dtype=dtype)
    
            # 如果影像有四個通道,直接將其作為潛在表示
            if image.shape[1] == 4:
                latents = image
    
            else:
                # 檢查生成器列表的長度是否與批次大小匹配
                if isinstance(generator, list) and len(generator) != batch_size:
                    # 丟擲錯誤,提示生成器列表長度與批次大小不匹配
                    raise ValueError(
                        f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                        f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                    )
    
                # 如果生成器是列表,逐個影像編碼並生成潛在表示
                if isinstance(generator, list):
                    latents = [
                        self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
                    ]
                    # 將潛在表示合併到一個張量中
                    latents = torch.cat(latents, dim=0)
                else:
                    # 使用單個生成器編碼影像並生成潛在表示
                    latents = self.vae.encode(image).latent_dist.sample(generator)
    
                # 根據配置的縮放因子調整潛在表示
                latents = self.vae.config.scaling_factor * latents
    
            # 檢查潛在表示的批次大小是否與請求的匹配
            if batch_size != latents.shape[0]:
                # 如果可以整除,則擴充套件潛在表示以匹配批次大小
                if batch_size % latents.shape[0] == 0:
                    # 構建棄用訊息,提示使用者行為即將被移除
                    deprecation_message = (
                        f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
                        " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                        " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                        " your script to pass as many initial images as text prompts to suppress this warning."
                    )
                    # 觸發棄用警告,提醒使用者修改程式碼
                    deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
                    # 計算每個影像需要複製的次數
                    additional_latents_per_image = batch_size // latents.shape[0]
                    # 將潛在表示按需重複以匹配批次大小
                    latents = torch.cat([latents] * additional_latents_per_image, dim=0)
                else:
                    # 丟擲錯誤,提示無法複製影像以匹配批次大小
                    raise ValueError(
                        f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
                    )
            else:
                # 將潛在表示封裝為一個張量
                latents = torch.cat([latents], dim=0)
    
            # 返回最終的潛在表示
            return latents
    # 定義一個獲取epsilon的函式,輸入為模型輸出、樣本和時間步
        def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int):
            # 獲取反向排程器的預測型別配置
            pred_type = self.inverse_scheduler.config.prediction_type
            # 計算在當前時間步的累積alpha值
            alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep]
    
            # 計算beta值為1減去alpha值
            beta_prod_t = 1 - alpha_prod_t
    
            # 根據預測型別返回相應的結果
            if pred_type == "epsilon":
                return model_output
            elif pred_type == "sample":
                # 根據樣本和模型輸出計算返回值
                return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5)
            elif pred_type == "v_prediction":
                # 根據alpha和beta值結合模型輸出和樣本計算返回值
                return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
            else:
                # 如果預測型別無效,丟擲異常
                raise ValueError(
                    f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`"
                )
    
        # 定義一個自動相關損失計算的函式,輸入為隱藏狀態和可選生成器
        def auto_corr_loss(self, hidden_states, generator=None):
            # 初始化正則化損失為0
            reg_loss = 0.0
            # 遍歷隱藏狀態的每一個維度
            for i in range(hidden_states.shape[0]):
                for j in range(hidden_states.shape[1]):
                    # 選取當前噪聲
                    noise = hidden_states[i : i + 1, j : j + 1, :, :]
                    # 進行迴圈,直到噪聲尺寸小於等於8
                    while True:
                        # 隨機生成滾動的位移量
                        roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
                        # 計算並累加正則化損失
                        reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
                        reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
    
                        # 如果噪聲尺寸小於等於8,跳出迴圈
                        if noise.shape[2] <= 8:
                            break
                        # 對噪聲進行2x2的平均池化
                        noise = F.avg_pool2d(noise, kernel_size=2)
            # 返回計算得到的正則化損失
            return reg_loss
    
        # 定義一個計算KL散度的函式,輸入為隱藏狀態
        def kl_divergence(self, hidden_states):
            # 計算隱藏狀態的均值
            mean = hidden_states.mean()
            # 計算隱藏狀態的方差
            var = hidden_states.var()
            # 返回KL散度的計算結果
            return var + mean**2 - 1 - torch.log(var + 1e-7)
    
        # 定義呼叫函式,使用@torch.no_grad()裝飾器禁止梯度計算
        @torch.no_grad()
        @replace_example_docstring(EXAMPLE_DOC_STRING)
        def __call__(
            # 輸入引數包括提示、源和目標嵌入、影像的高和寬、推理步驟等
            prompt: Optional[Union[str, List[str]]] = None,
            source_embeds: torch.Tensor = None,
            target_embeds: torch.Tensor = None,
            height: Optional[int] = None,
            width: Optional[int] = None,
            num_inference_steps: int = 50,
            guidance_scale: float = 7.5,
            negative_prompt: Optional[Union[str, List[str]]] = None,
            num_images_per_prompt: Optional[int] = 1,
            eta: float = 0.0,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            latents: Optional[torch.Tensor] = None,
            prompt_embeds: Optional[torch.Tensor] = None,
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            cross_attention_guidance_amount: float = 0.1,
            output_type: Optional[str] = "pil",
            return_dict: bool = True,
            callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
            callback_steps: Optional[int] = 1,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            clip_skip: Optional[int] = None,
        # 使用@torch.no_grad()和裝飾器替換文件字串
        @torch.no_grad()
        @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING)
    # 定義一個名為 invert 的方法,包含多個可選引數
        def invert(
            # 輸入提示,預設為 None
            self,
            prompt: Optional[str] = None,
            # 輸入影像,預設為 None
            image: PipelineImageInput = None,
            # 推理步驟的數量,預設為 50
            num_inference_steps: int = 50,
            # 指導比例,預設為 1
            guidance_scale: float = 1,
            # 隨機數生成器,可以是單個或多個生成器,預設為 None
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 潛在變數,預設為 None
            latents: Optional[torch.Tensor] = None,
            # 提示嵌入,預設為 None
            prompt_embeds: Optional[torch.Tensor] = None,
            # 跨注意力引導量,預設為 0.1
            cross_attention_guidance_amount: float = 0.1,
            # 輸出型別,預設為 "pil"
            output_type: Optional[str] = "pil",
            # 是否返回字典,預設為 True
            return_dict: bool = True,
            # 回撥函式,預設為 None
            callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
            # 回撥步驟,預設為 1
            callback_steps: Optional[int] = 1,
            # 跨注意力引數,預設為 None
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 自動相關的權重,預設為 20.0
            lambda_auto_corr: float = 20.0,
            # KL 散度的權重,預設為 20.0
            lambda_kl: float = 20.0,
            # 正則化步驟的數量,預設為 5
            num_reg_steps: int = 5,
            # 自動相關滾動的數量,預設為 5
            num_auto_corr_rolls: int = 5,

.\diffusers\pipelines\deprecated\stable_diffusion_variants\__init__.py

# 從型別檢查模組匯入型別檢查相關功能
from typing import TYPE_CHECKING

# 從 utils 模組匯入所需的工具和常量
from ....utils import (
    DIFFUSERS_SLOW_IMPORT,  # 匯入用於延遲匯入的常量
    OptionalDependencyNotAvailable,  # 匯入可選依賴不可用的異常類
    _LazyModule,  # 匯入延遲模組載入的工具
    get_objects_from_module,  # 匯入從模組獲取物件的函式
    is_torch_available,  # 匯入檢查 Torch 庫是否可用的函式
    is_transformers_available,  # 匯入檢查 Transformers 庫是否可用的函式
)

# 初始化一個空字典用於存放虛擬物件
_dummy_objects = {}
# 初始化一個空字典用於存放模組匯入結構
_import_structure = {}

try:
    # 檢查 Transformers 和 Torch 庫是否都可用
    if not (is_transformers_available() and is_torch_available()):
        # 如果不可用,則丟擲異常
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # 匯入虛擬物件以避免在依賴不可用時導致錯誤
    from ....utils import dummy_torch_and_transformers_objects

    # 更新虛擬物件字典,填充從虛擬物件模組獲取的物件
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    # 如果依賴可用,新增相關的管道到匯入結構字典
    _import_structure["pipeline_cycle_diffusion"] = ["CycleDiffusionPipeline"]
    _import_structure["pipeline_stable_diffusion_inpaint_legacy"] = ["StableDiffusionInpaintPipelineLegacy"]
    _import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"]

    _import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"]
    _import_structure["pipeline_stable_diffusion_pix2pix_zero"] = ["StableDiffusionPix2PixZeroPipeline"]

# 根據型別檢查標誌或慢速匯入標誌進行條件判斷
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        # 再次檢查依賴是否可用
        if not (is_transformers_available() and is_torch_available()):
            # 如果不可用,則丟擲異常
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # 匯入虛擬物件以避免在依賴不可用時導致錯誤
        from ....utils.dummy_torch_and_transformers_objects import *

    else:
        # 匯入具體的管道類,確保它們在依賴可用時被載入
        from .pipeline_cycle_diffusion import CycleDiffusionPipeline
        from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
        from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline
        from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline
        from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline

else:
    # 如果不是型別檢查或慢速匯入,則進行懶載入處理
    import sys

    # 使用懶載入模組構造當前模組,指定匯入結構和模組規格
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    # 遍歷虛擬物件字典,將物件屬性設定到當前模組
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)

.\diffusers\pipelines\deprecated\stochastic_karras_ve\pipeline_stochastic_karras_ve.py

# 版權宣告,表明此檔案的版權所有者及保留權利
# 
# 根據 Apache 許可證第 2.0 版(“許可證”)進行許可;
# 除非遵循許可證,否則您不得使用此檔案。
# 您可以在以下網址獲取許可證的副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非適用法律或書面同意,軟體在許可證下分發,按“原樣”基礎,
# 不提供任何形式的保證或條件,無論是明示或暗示的。
# 請參見許可證以獲取有關許可權和
# 限制的具體規定。

# 從 typing 模組匯入所需的型別提示
from typing import List, Optional, Tuple, Union

# 匯入 PyTorch 庫
import torch

# 從相對路徑匯入 UNet2DModel 模型
from ....models import UNet2DModel
# 從相對路徑匯入排程器 KarrasVeScheduler
from ....schedulers import KarrasVeScheduler
# 從相對路徑匯入隨機張量生成工具
from ....utils.torch_utils import randn_tensor
# 從相對路徑匯入擴散管道和影像輸出
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput

# 定義 KarrasVePipeline 類,繼承自 DiffusionPipeline
class KarrasVePipeline(DiffusionPipeline):
    r"""
    無條件影像生成的管道。

    引數:
        unet ([`UNet2DModel`]):
            用於去噪編碼影像的 `UNet2DModel`。
        scheduler ([`KarrasVeScheduler`]):
            用於與 `unet` 結合去噪編碼影像的排程器。
    """

    # 為 linting 新增型別提示
    unet: UNet2DModel  # 定義 unet 型別為 UNet2DModel
    scheduler: KarrasVeScheduler  # 定義 scheduler 型別為 KarrasVeScheduler

    # 初始化函式,接受 UNet2DModel 和 KarrasVeScheduler 作為引數
    def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
        # 呼叫父類的初始化函式
        super().__init__()
        # 註冊模組,將 unet 和 scheduler 註冊到當前例項中
        self.register_modules(unet=unet, scheduler=scheduler)

    # 裝飾器,表明此函式不需要梯度計算
    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,  # 定義批處理大小,預設為 1
        num_inference_steps: int = 50,  # 定義推理步驟數量,預設為 50
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,  # 可選生成器
        output_type: Optional[str] = "pil",  # 可選輸出型別,預設為 "pil"
        return_dict: bool = True,  # 是否返回字典,預設為 True
        **kwargs,  # 允許額外的關鍵字引數

.\diffusers\pipelines\deprecated\stochastic_karras_ve\__init__.py

# 從 typing 模組匯入 TYPE_CHECKING,用於型別檢查
from typing import TYPE_CHECKING

# 從相對路徑匯入工具模組中的 DIFFUSERS_SLOW_IMPORT 和 _LazyModule
from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule

# 定義一個字典,描述要匯入的模組及其對應的類
_import_structure = {"pipeline_stochastic_karras_ve": ["KarrasVePipeline"]}

# 如果正在進行型別檢查或 DIFFUSERS_SLOW_IMPORT 為真
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # 從 pipeline_stochastic_karras_ve 模組匯入 KarrasVePipeline 類
    from .pipeline_stochastic_karras_ve import KarrasVePipeline

# 否則
else:
    # 匯入 sys 模組,用於動態修改模組
    import sys

    # 使用 _LazyModule 建立懶載入模組,並將其賦值給當前模組名稱
    sys.modules[__name__] = _LazyModule(
        __name__,  # 當前模組的名稱
        globals()["__file__"],  # 當前模組的檔案路徑
        _import_structure,  # 模組結構字典
        module_spec=__spec__,  # 當前模組的規格
    )

.\diffusers\pipelines\deprecated\versatile_diffusion\modeling_text_unet.py

# 從 typing 模組匯入各種型別註解
from typing import Any, Dict, List, Optional, Tuple, Union

# 匯入 numpy 庫,用於陣列和矩陣操作
import numpy as np
# 匯入 PyTorch 庫,進行深度學習模型的構建和訓練
import torch
# 匯入 PyTorch 的神經網路模組
import torch.nn as nn
# 匯入 PyTorch 的功能性模組,提供常用操作
import torch.nn.functional as F

# 從 diffusers.utils 模組匯入 deprecate 函式,用於處理棄用警告
from diffusers.utils import deprecate

# 匯入配置相關的類和函式
from ....configuration_utils import ConfigMixin, register_to_config
# 匯入模型相關的基類
from ....models import ModelMixin
# 匯入啟用函式獲取工具
from ....models.activations import get_activation
# 匯入注意力處理器相關元件
from ....models.attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,  # 額外來鍵值注意力處理器
    CROSS_ATTENTION_PROCESSORS,      # 交叉注意力處理器
    Attention,                       # 注意力機制類
    AttentionProcessor,              # 注意力處理器基類
    AttnAddedKVProcessor,            # 額外來鍵值注意力處理器類
    AttnAddedKVProcessor2_0,         # 版本 2.0 的額外來鍵值注意力處理器
    AttnProcessor,                   # 基礎注意力處理器
)
# 匯入嵌入層相關元件
from ....models.embeddings import (
    GaussianFourierProjection,        # 高斯傅立葉投影類
    ImageHintTimeEmbedding,           # 影像提示時間嵌入類
    ImageProjection,                  # 影像投影類
    ImageTimeEmbedding,               # 影像時間嵌入類
    TextImageProjection,              # 文字影像投影類
    TextImageTimeEmbedding,           # 文字影像時間嵌入類
    TextTimeEmbedding,                # 文字時間嵌入類
    TimestepEmbedding,                # 時間步嵌入類
    Timesteps,                        # 時間步類
)
# 匯入 ResNet 相關元件
from ....models.resnet import ResnetBlockCondNorm2D
# 匯入 2D 雙重變換器模型
from ....models.transformers.dual_transformer_2d import DualTransformer2DModel
# 匯入 2D 變換器模型
from ....models.transformers.transformer_2d import Transformer2DModel
# 匯入 2D 條件 UNet 輸出類
from ....models.unets.unet_2d_condition import UNet2DConditionOutput
# 匯入工具函式和常量
from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
# 匯入 PyTorch 相關工具函式
from ....utils.torch_utils import apply_freeu

# 建立日誌記錄器例項
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 定義獲取下采樣塊的函式
def get_down_block(
    down_block_type,                    # 下采樣塊型別
    num_layers,                         # 層數
    in_channels,                        # 輸入通道數
    out_channels,                       # 輸出通道數
    temb_channels,                      # 時間嵌入通道數
    add_downsample,                    # 是否新增下采樣
    resnet_eps,                         # ResNet 中的 epsilon 值
    resnet_act_fn,                     # ResNet 啟用函式
    num_attention_heads,               # 注意力頭數量
    transformer_layers_per_block,      # 每個塊中的變換器層數
    attention_type,                    # 注意力型別
    attention_head_dim,                # 注意力頭維度
    resnet_groups=None,                 # ResNet 組數(可選)
    cross_attention_dim=None,           # 交叉注意力維度(可選)
    downsample_padding=None,            # 下采樣填充(可選)
    dual_cross_attention=False,         # 是否使用雙重交叉注意力
    use_linear_projection=False,        # 是否使用線性投影
    only_cross_attention=False,         # 是否僅使用交叉注意力
    upcast_attention=False,             # 是否上升注意力
    resnet_time_scale_shift="default",  # ResNet 時間縮放偏移
    resnet_skip_time_act=False,         # ResNet 是否跳過時間啟用
    resnet_out_scale_factor=1.0,       # ResNet 輸出縮放因子
    cross_attention_norm=None,          # 交叉注意力歸一化(可選)
    dropout=0.0,                       # dropout 機率
):
    # 如果下采樣塊型別以 "UNetRes" 開頭,則去掉字首
    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    # 如果下采樣塊型別為 "DownBlockFlat",則返回相應的塊例項
    if down_block_type == "DownBlockFlat":
        return DownBlockFlat(
            num_layers=num_layers,        # 層數
            in_channels=in_channels,      # 輸入通道數
            out_channels=out_channels,    # 輸出通道數
            temb_channels=temb_channels,  # 時間嵌入通道數
            dropout=dropout,              # dropout 機率
            add_downsample=add_downsample, # 是否新增下采樣
            resnet_eps=resnet_eps,        # ResNet 中的 epsilon 值
            resnet_act_fn=resnet_act_fn,  # ResNet 啟用函式
            resnet_groups=resnet_groups,   # ResNet 組數(可選)
            downsample_padding=downsample_padding, # 下采樣填充(可選)
            resnet_time_scale_shift=resnet_time_scale_shift, # ResNet 時間縮放偏移
        )
    # 檢查下采樣塊型別是否為 CrossAttnDownBlockFlat
    elif down_block_type == "CrossAttnDownBlockFlat":
        # 如果沒有指定 cross_attention_dim,則丟擲值錯誤
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockFlat")
        # 建立並返回 CrossAttnDownBlockFlat 例項,傳入所需引數
        return CrossAttnDownBlockFlat(
            # 設定網路層數
            num_layers=num_layers,
            # 設定輸入通道數
            in_channels=in_channels,
            # 設定輸出通道數
            out_channels=out_channels,
            # 設定時間嵌入通道數
            temb_channels=temb_channels,
            # 設定 dropout 比率
            dropout=dropout,
            # 設定是否新增下采樣層
            add_downsample=add_downsample,
            # 設定 ResNet 中的 epsilon 引數
            resnet_eps=resnet_eps,
            # 設定 ResNet 啟用函式
            resnet_act_fn=resnet_act_fn,
            # 設定 ResNet 組的數量
            resnet_groups=resnet_groups,
            # 設定下采樣的填充引數
            downsample_padding=downsample_padding,
            # 設定交叉注意力維度
            cross_attention_dim=cross_attention_dim,
            # 設定注意力頭的數量
            num_attention_heads=num_attention_heads,
            # 設定是否使用雙交叉注意力
            dual_cross_attention=dual_cross_attention,
            # 設定是否使用線性投影
            use_linear_projection=use_linear_projection,
            # 設定是否僅使用交叉注意力
            only_cross_attention=only_cross_attention,
            # 設定 ResNet 的時間尺度偏移
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    # 如果下采樣塊型別不被支援,則丟擲值錯誤
    raise ValueError(f"{down_block_type} is not supported.")
# 根據給定引數建立上取樣塊的函式
def get_up_block(
    # 上取樣塊型別
    up_block_type,
    # 網路層數
    num_layers,
    # 輸入通道數
    in_channels,
    # 輸出通道數
    out_channels,
    # 上一層輸出通道數
    prev_output_channel,
    # 條件嵌入通道數
    temb_channels,
    # 是否新增上取樣
    add_upsample,
    # ResNet 的 epsilon 值
    resnet_eps,
    # ResNet 的啟用函式
    resnet_act_fn,
    # 注意力頭數
    num_attention_heads,
    # 每個塊的 Transformer 層數
    transformer_layers_per_block,
    # 解析度索引
    resolution_idx,
    # 注意力型別
    attention_type,
    # 注意力頭維度
    attention_head_dim,
    # ResNet 組數,可選引數
    resnet_groups=None,
    # 跨注意力維度,可選引數
    cross_attention_dim=None,
    # 是否使用雙重跨注意力
    dual_cross_attention=False,
    # 是否使用線性投影
    use_linear_projection=False,
    # 是否僅使用跨注意力
    only_cross_attention=False,
    # 是否上溯注意力
    upcast_attention=False,
    # ResNet 時間尺度偏移,預設為 "default"
    resnet_time_scale_shift="default",
    # ResNet 是否跳過時間啟用
    resnet_skip_time_act=False,
    # ResNet 輸出縮放因子
    resnet_out_scale_factor=1.0,
    # 跨注意力歸一化型別,可選引數
    cross_attention_norm=None,
    # dropout 機率
    dropout=0.0,
):
    # 如果上取樣塊型別以 "UNetRes" 開頭,去掉字首
    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    # 如果塊型別是 "UpBlockFlat",則返回相應的例項
    if up_block_type == "UpBlockFlat":
        return UpBlockFlat(
            # 傳入各個引數
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            dropout=dropout,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    # 如果塊型別是 "CrossAttnUpBlockFlat"
    elif up_block_type == "CrossAttnUpBlockFlat":
        # 檢查跨注意力維度是否指定
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockFlat")
        # 返回相應的跨注意力上取樣塊例項
        return CrossAttnUpBlockFlat(
            # 傳入各個引數
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            dropout=dropout,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    # 如果塊型別不支援,丟擲異常
    raise ValueError(f"{up_block_type} is not supported.")


# 定義一個 Fourier 嵌入器類,繼承自 nn.Module
class FourierEmbedder(nn.Module):
    # 初始化方法,設定頻率和溫度
    def __init__(self, num_freqs=64, temperature=100):
        # 呼叫父類建構函式
        super().__init__()

        # 儲存頻率數
        self.num_freqs = num_freqs
        # 儲存溫度
        self.temperature = temperature

        # 計算頻率帶
        freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
        # 擴充套件維度以便後續操作
        freq_bands = freq_bands[None, None, None]
        # 註冊頻率帶為緩衝區,設為非永續性
        self.register_buffer("freq_bands", freq_bands, persistent=False)

    # 定義呼叫方法,用於處理輸入
    def __call__(self, x):
        # 將輸入與頻率帶相乘
        x = self.freq_bands * x.unsqueeze(-1)
        # 返回處理後的結果,包含正弦和餘弦
        return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)


# 定義 GLIGEN 文字邊界框投影類,繼承自 nn.Module
class GLIGENTextBoundingboxProjection(nn.Module):
    # 初始化方法,設定物件的基本引數
        def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8):
            # 呼叫父類的初始化方法
            super().__init__()
            # 儲存正樣本的長度
            self.positive_len = positive_len
            # 儲存輸出的維度
            self.out_dim = out_dim
    
            # 初始化傅立葉嵌入器,設定頻率數量
            self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
            # 計算位置特徵的維度,包含 sin 和 cos
            self.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy
    
            # 如果輸出維度是元組,取第一個元素
            if isinstance(out_dim, tuple):
                out_dim = out_dim[0]
    
            # 根據特徵型別設定線性層
            if feature_type == "text-only":
                self.linears = nn.Sequential(
                    # 第一層線性變換,輸入為正樣本長度加位置維度
                    nn.Linear(self.positive_len + self.position_dim, 512),
                    # 啟用函式使用 SiLU
                    nn.SiLU(),
                    # 第二層線性變換
                    nn.Linear(512, 512),
                    # 啟用函式使用 SiLU
                    nn.SiLU(),
                    # 輸出層
                    nn.Linear(512, out_dim),
                )
                # 定義一個全為零的引數,用於文字特徵的空值處理
                self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
    
            # 處理文字和影像的特徵型別
            elif feature_type == "text-image":
                self.linears_text = nn.Sequential(
                    # 第一層線性變換
                    nn.Linear(self.positive_len + self.position_dim, 512),
                    # 啟用函式使用 SiLU
                    nn.SiLU(),
                    # 第二層線性變換
                    nn.Linear(512, 512),
                    # 啟用函式使用 SiLU
                    nn.SiLU(),
                    # 輸出層
                    nn.Linear(512, out_dim),
                )
                self.linears_image = nn.Sequential(
                    # 第一層線性變換
                    nn.Linear(self.positive_len + self.position_dim, 512),
                    # 啟用函式使用 SiLU
                    nn.SiLU(),
                    # 第二層線性變換
                    nn.Linear(512, 512),
                    # 啟用函式使用 SiLU
                    nn.SiLU(),
                    # 輸出層
                    nn.Linear(512, out_dim),
                )
                # 定義文字特徵的空值處理引數
                self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
                # 定義影像特徵的空值處理引數
                self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
    
            # 定義位置特徵的空值處理引數
            self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
    
        # 前向傳播方法定義
        def forward(
            self,
            boxes,
            masks,
            positive_embeddings=None,
            phrases_masks=None,
            image_masks=None,
            phrases_embeddings=None,
            image_embeddings=None,
    ):
        # 在最後一維增加一個維度,便於後續操作
        masks = masks.unsqueeze(-1)

        # 透過傅立葉嵌入函式生成 boxes 的嵌入表示
        xyxy_embedding = self.fourier_embedder(boxes)
        # 獲取空白位置的特徵,並調整形狀為 (1, 1, -1)
        xyxy_null = self.null_position_feature.view(1, 1, -1)
        # 計算加權嵌入,結合 masks 和空白位置特徵
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

        # 如果存在正樣本嵌入
        if positive_embeddings:
            # 獲取正樣本的空白特徵,並調整形狀為 (1, 1, -1)
            positive_null = self.null_positive_feature.view(1, 1, -1)
            # 計算正樣本嵌入的加權,結合 masks 和空白特徵
            positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null

            # 將正樣本嵌入與 xyxy 嵌入連線並透過線性層處理
            objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
        else:
            # 在最後一維增加一個維度,便於後續操作
            phrases_masks = phrases_masks.unsqueeze(-1)
            image_masks = image_masks.unsqueeze(-1)

            # 獲取文字和影像的空白特徵,並調整形狀為 (1, 1, -1)
            text_null = self.null_text_feature.view(1, 1, -1)
            image_null = self.null_image_feature.view(1, 1, -1)

            # 計算文字嵌入的加權,結合 phrases_masks 和空白特徵
            phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
            # 計算影像嵌入的加權,結合 image_masks 和空白特徵
            image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null

            # 將文字嵌入與 xyxy 嵌入連線並透過文字線性層處理
            objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
            # 將影像嵌入與 xyxy 嵌入連線並透過影像線性層處理
            objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
            # 將文字和影像的處理結果在維度 1 上連線
            objs = torch.cat([objs_text, objs_image], dim=1)

        # 返回最終的物件結果
        return objs
# 定義一個名為 UNetFlatConditionModel 的類,繼承自 ModelMixin 和 ConfigMixin
class UNetFlatConditionModel(ModelMixin, ConfigMixin):
    r"""
    一個條件 2D UNet 模型,它接收一個有噪聲的樣本、條件狀態和時間步,並返回一個樣本形狀的輸出。

    該模型繼承自 [`ModelMixin`]。請檢視父類文件以瞭解其為所有模型實現的通用方法(例如下載或儲存)。

    """

    # 設定該模型支援梯度檢查點
    _supports_gradient_checkpointing = True
    # 定義不進行拆分的模組名稱列表
    _no_split_modules = ["BasicTransformerBlock", "ResnetBlockFlat", "CrossAttnUpBlockFlat"]

    # 註冊到配置的裝飾器
    @register_to_config
    # 初始化方法,設定類的基本引數
        def __init__(
            # 樣本大小,可選引數
            self,
            sample_size: Optional[int] = None,
            # 輸入通道數,預設為4
            in_channels: int = 4,
            # 輸出通道數,預設為4
            out_channels: int = 4,
            # 是否將輸入樣本居中,預設為False
            center_input_sample: bool = False,
            # 是否將正弦函式翻轉為餘弦函式,預設為True
            flip_sin_to_cos: bool = True,
            # 頻率偏移量,預設為0
            freq_shift: int = 0,
            # 向下取樣塊的型別,預設為三個CrossAttnDownBlockFlat和一個DownBlockFlat
            down_block_types: Tuple[str] = (
                "CrossAttnDownBlockFlat",
                "CrossAttnDownBlockFlat",
                "CrossAttnDownBlockFlat",
                "DownBlockFlat",
            ),
            # 中間塊的型別,預設為UNetMidBlockFlatCrossAttn
            mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn",
            # 向上取樣塊的型別,預設為一個UpBlockFlat和三個CrossAttnUpBlockFlat
            up_block_types: Tuple[str] = (
                "UpBlockFlat",
                "CrossAttnUpBlockFlat",
                "CrossAttnUpBlockFlat",
                "CrossAttnUpBlockFlat",
            ),
            # 是否僅使用交叉注意力,預設為False
            only_cross_attention: Union[bool, Tuple[bool]] = False,
            # 塊輸出通道數,預設為320, 640, 1280, 1280
            block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
            # 每個塊的層數,預設為2
            layers_per_block: Union[int, Tuple[int]] = 2,
            # 向下取樣時的填充大小,預設為1
            downsample_padding: int = 1,
            # 中間塊的縮放因子,預設為1
            mid_block_scale_factor: float = 1,
            # dropout比例,預設為0.0
            dropout: float = 0.0,
            # 啟用函式型別,預設為silu
            act_fn: str = "silu",
            # 歸一化的組數,可選引數,預設為32
            norm_num_groups: Optional[int] = 32,
            # 歸一化的epsilon值,預設為1e-5
            norm_eps: float = 1e-5,
            # 交叉注意力的維度,預設為1280
            cross_attention_dim: Union[int, Tuple[int]] = 1280,
            # 每個塊的變換器層數,預設為1
            transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
            # 反向變換器層數的可選配置
            reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
            # 編碼器隱藏維度的可選引數
            encoder_hid_dim: Optional[int] = None,
            # 編碼器隱藏維度型別的可選引數
            encoder_hid_dim_type: Optional[str] = None,
            # 注意力頭的維度,預設為8
            attention_head_dim: Union[int, Tuple[int]] = 8,
            # 注意力頭數量的可選引數
            num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
            # 是否使用雙交叉注意力,預設為False
            dual_cross_attention: bool = False,
            # 是否使用線性投影,預設為False
            use_linear_projection: bool = False,
            # 類嵌入型別的可選引數
            class_embed_type: Optional[str] = None,
            # 附加嵌入型別的可選引數
            addition_embed_type: Optional[str] = None,
            # 附加時間嵌入維度的可選引數
            addition_time_embed_dim: Optional[int] = None,
            # 類嵌入數量的可選引數
            num_class_embeds: Optional[int] = None,
            # 是否向上投射注意力,預設為False
            upcast_attention: bool = False,
            # ResNet時間縮放偏移的預設值
            resnet_time_scale_shift: str = "default",
            # ResNet跳過時間啟用的設定,預設為False
            resnet_skip_time_act: bool = False,
            # ResNet輸出縮放因子,預設為1.0
            resnet_out_scale_factor: int = 1.0,
            # 時間嵌入型別,預設為positional
            time_embedding_type: str = "positional",
            # 時間嵌入維度的可選引數
            time_embedding_dim: Optional[int] = None,
            # 時間嵌入啟用函式的可選引數
            time_embedding_act_fn: Optional[str] = None,
            # 時間步後啟用的可選引數
            timestep_post_act: Optional[str] = None,
            # 時間條件投影維度的可選引數
            time_cond_proj_dim: Optional[int] = None,
            # 輸入卷積核的大小,預設為3
            conv_in_kernel: int = 3,
            # 輸出卷積核的大小,預設為3
            conv_out_kernel: int = 3,
            # 投影類嵌入輸入維度的可選引數
            projection_class_embeddings_input_dim: Optional[int] = None,
            # 注意力型別,預設為default
            attention_type: str = "default",
            # 類嵌入是否連線,預設為False
            class_embeddings_concat: bool = False,
            # 中間塊是否僅使用交叉注意力的可選引數
            mid_block_only_cross_attention: Optional[bool] = None,
            # 交叉注意力的歸一化型別的可選引數
            cross_attention_norm: Optional[str] = None,
            # 附加嵌入型別的頭數量,預設為64
            addition_embed_type_num_heads=64,
        # 宣告該方法為屬性
        @property
    # 定義一個返回注意力處理器字典的方法
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        返回值:
            `dict` 的注意力處理器: 一個字典,包含模型中使用的所有注意力處理器,以其權重名稱為索引。
        """
        # 初始化一個空字典以遞迴儲存處理器
        processors = {}

        # 定義一個遞迴函式來新增處理器
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 如果模組有獲取處理器的方法,則新增到字典中
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            # 遍歷模組的子模組,遞迴呼叫函式
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            # 返回處理器字典
            return processors

        # 遍歷當前模組的子模組,並呼叫遞迴函式
        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        # 返回所有注意力處理器的字典
        return processors

    # 定義一個設定注意力處理器的方法
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        設定用於計算注意力的處理器。

        引數:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                例項化的處理器類或處理器類的字典,將被設定為所有 `Attention` 層的處理器。

                如果 `processor` 是字典,則鍵需要定義對應的交叉注意力處理器的路徑。
                在設定可訓練的注意力處理器時,強烈推薦這種做法。
        """
        # 計算當前注意力處理器的數量
        count = len(self.attn_processors.keys())

        # 如果傳入的是字典且數量不匹配,則引發錯誤
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"傳入的是處理器字典,但處理器的數量 {len(processor)} 與注意力層的數量 {count} 不匹配。"
                f" 請確保傳入 {count} 個處理器類。"
            )

        # 定義一個遞迴函式來設定注意力處理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 如果模組有設定處理器的方法,則根據傳入的處理器設定
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            # 遍歷模組的子模組,遞迴呼叫函式
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        # 遍歷當前模組的子模組,並呼叫遞迴函式
        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
    # 設定預設的注意力處理器
    def set_default_attn_processor(self):
        """
        禁用自定義注意力處理器,並設定預設的注意力實現。
        """
        # 檢查所有注意力處理器是否屬於已新增的 KV 注意力處理器
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 使用 AttnAddedKVProcessor 作為處理器
            processor = AttnAddedKVProcessor()
        # 檢查所有注意力處理器是否屬於交叉注意力處理器
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 使用 AttnProcessor 作為處理器
            processor = AttnProcessor()
        else:
            # 如果處理器型別不匹配,則引發值錯誤
            raise ValueError(
                f"當注意力處理器的型別為 {next(iter(self.attn_processors.values()))} 時,無法呼叫 `set_default_attn_processor`"
            )

        # 設定選定的注意力處理器
        self.set_attn_processor(processor)

    # 設定梯度檢查點
    def _set_gradient_checkpointing(self, module, value=False):
        # 如果模組具有 gradient_checkpointing 屬性,則設定其值
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    # 啟用 FreeU 機制
    def enable_freeu(self, s1, s2, b1, b2):
        r"""啟用來自 https://arxiv.org/abs/2309.11497 的 FreeU 機制。

        縮放因子的字尾表示應用的階段塊。

        請參考 [官方庫](https://github.com/ChenyangSi/FreeU) 以獲取已知在不同管道(如 Stable Diffusion v1、v2 和 Stable Diffusion XL)中表現良好的值組合。

        引數:
            s1 (`float`):
                階段 1 的縮放因子,用於減弱跳過特徵的貢獻。這是為了減輕增強去噪過程中的“過平滑效應”。
            s2 (`float`):
                階段 2 的縮放因子,用於減弱跳過特徵的貢獻。這是為了減輕增強去噪過程中的“過平滑效應”。
            b1 (`float`): 階段 1 的縮放因子,用於增強主幹特徵的貢獻。
            b2 (`float`): 階段 2 的縮放因子,用於增強主幹特徵的貢獻。
        """
        # 遍歷上取樣塊並設定相應的縮放因子
        for i, upsample_block in enumerate(self.up_blocks):
            setattr(upsample_block, "s1", s1)  # 設定階段 1 的縮放因子
            setattr(upsample_block, "s2", s2)  # 設定階段 2 的縮放因子
            setattr(upsample_block, "b1", b1)  # 設定階段 1 的主幹特徵縮放因子
            setattr(upsample_block, "b2", b2)  # 設定階段 2 的主幹特徵縮放因子

    # 禁用 FreeU 機制
    def disable_freeu(self):
        """禁用 FreeU 機制。"""
        freeu_keys = {"s1", "s2", "b1", "b2"}  # FreeU 機制的關鍵字集合
        # 遍歷上取樣塊並將關鍵字的值設定為 None
        for i, upsample_block in enumerate(self.up_blocks):
            for k in freeu_keys:
                # 如果上取樣塊具有該屬性或屬性值不為 None,則將其設定為 None
                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                    setattr(upsample_block, k, None)
    # 定義一個用於融合 QKV 投影的函式
    def fuse_qkv_projections(self):
        # 文件字串,描述該函式的作用及實驗性質
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.
    
        <Tip warning={true}>
    
        This API is 🧪 experimental.
    
        </Tip>
        """
        # 初始化原始注意力處理器為 None
        self.original_attn_processors = None
    
        # 遍歷所有注意力處理器
        for _, attn_processor in self.attn_processors.items():
            # 檢查處理器的類名是否包含 "Added"
            if "Added" in str(attn_processor.__class__.__name__):
                # 如果是,丟擲錯誤提示不支援融合
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
    
        # 儲存原始的注意力處理器
        self.original_attn_processors = self.attn_processors
    
        # 遍歷所有模組
        for module in self.modules():
            # 檢查模組是否是 Attention 類的例項
            if isinstance(module, Attention):
                # 融合投影
                module.fuse_projections(fuse=True)
    
    # 定義一個用於取消 QKV 投影融合的函式
    def unfuse_qkv_projections(self):
        # 文件字串,描述該函式的作用及實驗性質
        """Disables the fused QKV projection if enabled.
    
        <Tip warning={true}>
    
        This API is 🧪 experimental.
    
        </Tip>
    
        """
        # 檢查原始注意力處理器是否不為 None
        if self.original_attn_processors is not None:
            # 恢復到原始的注意力處理器
            self.set_attn_processor(self.original_attn_processors)
    
    # 定義一個用於解除安裝 LoRA 權重的函式
    def unload_lora(self):
        # 文件字串,描述該函式的作用
        """Unloads LoRA weights."""
        # 發出解除安裝的棄用警告
        deprecate(
            "unload_lora",
            "0.28.0",
            "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
        )
        # 遍歷所有模組
        for module in self.modules():
            # 檢查模組是否具有 set_lora_layer 屬性
            if hasattr(module, "set_lora_layer"):
                # 將 LoRA 層設定為 None
                module.set_lora_layer(None)
    
    # 定義前向傳播函式
    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
# 定義一個繼承自 nn.Linear 的線性多維層
class LinearMultiDim(nn.Linear):
    # 初始化方法,接受輸入特徵、輸出特徵及其他引數
    def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs):
        # 如果 in_features 是整數,則將其轉換為包含三個維度的列表
        in_features = [in_features, second_dim, 1] if isinstance(in_features, int) else list(in_features)
        # 如果未提供 out_features,則將其設定為 in_features
        if out_features is None:
            out_features = in_features
        # 如果 out_features 是整數,則轉換為包含三個維度的列表
        out_features = [out_features, second_dim, 1] if isinstance(out_features, int) else list(out_features)
        # 儲存輸入特徵的多維資訊
        self.in_features_multidim = in_features
        # 儲存輸出特徵的多維資訊
        self.out_features_multidim = out_features
        # 呼叫父類的初始化方法,計算輸入和輸出特徵的總數量
        super().__init__(np.array(in_features).prod(), np.array(out_features).prod())

    # 定義前向傳播方法
    def forward(self, input_tensor, *args, **kwargs):
        # 獲取輸入張量的形狀
        shape = input_tensor.shape
        # 獲取輸入特徵的維度數量
        n_dim = len(self.in_features_multidim)
        # 將輸入張量重塑為適合線性層的形狀
        input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features)
        # 呼叫父類的前向傳播方法,得到輸出張量
        output_tensor = super().forward(input_tensor)
        # 將輸出張量重塑為目標形狀
        output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_features_multidim)
        # 返回輸出張量
        return output_tensor


# 定義一個平坦的殘差塊類,繼承自 nn.Module
class ResnetBlockFlat(nn.Module):
    # 初始化方法,接受多個引數,包括通道數、丟棄率等
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        time_embedding_norm="default",
        use_in_shortcut=None,
        second_dim=4,
        **kwargs,
    # 初始化方法的結束,接收引數
        ):
            # 呼叫父類的初始化方法
            super().__init__()
            # 是否進行預歸一化,設定為傳入的值
            self.pre_norm = pre_norm
            # 將預歸一化設定為 True
            self.pre_norm = True
    
            # 如果輸入通道是整數,則構造一個包含三個維度的列表
            in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels)
            # 計算輸入通道數的乘積
            self.in_channels_prod = np.array(in_channels).prod()
            # 儲存輸入通道的多維資訊
            self.channels_multidim = in_channels
    
            # 如果輸出通道不為 None
            if out_channels is not None:
                # 如果輸出通道是整數,構造一個包含三個維度的列表
                out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels)
                # 計算輸出通道數的乘積
                out_channels_prod = np.array(out_channels).prod()
                # 儲存輸出通道的多維資訊
                self.out_channels_multidim = out_channels
            else:
                # 如果輸出通道為 None,則輸出通道乘積等於輸入通道乘積
                out_channels_prod = self.in_channels_prod
                # 輸出通道的多維資訊與輸入通道相同
                self.out_channels_multidim = self.channels_multidim
            # 儲存時間嵌入的歸一化狀態
            self.time_embedding_norm = time_embedding_norm
    
            # 如果輸出組數為 None,使用傳入的組數
            if groups_out is None:
                groups_out = groups
    
            # 建立第一個歸一化層,使用組歸一化
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True)
            # 建立第一個卷積層,使用輸入通道和輸出通道乘積
            self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0)
    
            # 如果時間嵌入通道不為 None
            if temb_channels is not None:
                # 建立時間嵌入投影層
                self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod)
            else:
                # 如果時間嵌入通道為 None,則不進行投影
                self.time_emb_proj = None
    
            # 建立第二個歸一化層,使用輸出組數和輸出通道乘積
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True)
            # 建立丟棄層,使用傳入的丟棄率
            self.dropout = torch.nn.Dropout(dropout)
            # 建立第二個卷積層,使用輸出通道乘積
            self.conv2 = torch.nn.Conv2d(out_channels_prod, out_channels_prod, kernel_size=1, padding=0)
    
            # 設定非線性啟用函式為 SiLU
            self.nonlinearity = nn.SiLU()
    
            # 檢查是否使用輸入短路,如果短路使用引數為 None,則根據通道數判斷
            self.use_in_shortcut = (
                self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut
            )
    
            # 初始化快捷連線卷積為 None
            self.conv_shortcut = None
            # 如果使用輸入短路
            if self.use_in_shortcut:
                # 建立快捷連線卷積層
                self.conv_shortcut = torch.nn.Conv2d(
                    self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0
                )
    # 定義前向傳播方法,接收輸入張量和時間嵌入
        def forward(self, input_tensor, temb):
            # 獲取輸入張量的形狀
            shape = input_tensor.shape
            # 獲取多維通道的維度數
            n_dim = len(self.channels_multidim)
            # 調整輸入張量形狀,合併通道維度並增加兩個維度
            input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_channels_prod, 1, 1)
            # 將張量檢視轉換為指定形狀,保持通道數並增加兩個維度
            input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1)
    
            # 初始化隱藏狀態為輸入張量
            hidden_states = input_tensor
    
            # 對隱藏狀態進行歸一化處理
            hidden_states = self.norm1(hidden_states)
            # 應用非線性啟用函式
            hidden_states = self.nonlinearity(hidden_states)
            # 透過第一個卷積層處理隱藏狀態
            hidden_states = self.conv1(hidden_states)
    
            # 如果時間嵌入不為空
            if temb is not None:
                # 對時間嵌入進行非線性處理並調整形狀
                temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
                # 將時間嵌入與隱藏狀態相加
                hidden_states = hidden_states + temb
    
            # 對隱藏狀態進行第二次歸一化處理
            hidden_states = self.norm2(hidden_states)
            # 再次應用非線性啟用函式
            hidden_states = self.nonlinearity(hidden_states)
    
            # 對隱藏狀態應用 dropout 操作
            hidden_states = self.dropout(hidden_states)
            # 透過第二個卷積層處理隱藏狀態
            hidden_states = self.conv2(hidden_states)
    
            # 如果存在短路卷積層
            if self.conv_shortcut is not None:
                # 透過短路卷積層處理輸入張量
                input_tensor = self.conv_shortcut(input_tensor)
    
            # 將輸入張量與隱藏狀態相加,生成輸出張量
            output_tensor = input_tensor + hidden_states
    
            # 將輸出張量調整為指定形狀,去掉多餘的維度
            output_tensor = output_tensor.view(*shape[0:-n_dim], -1)
            # 再次調整輸出張量的形狀,匹配輸出通道的多維結構
            output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim)
    
            # 返回最終的輸出張量
            return output_tensor
# 定義一個名為 DownBlockFlat 的類,繼承自 nn.Module
class DownBlockFlat(nn.Module):
    # 初始化方法,接受多個引數用於配置模型
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        out_channels: int,  # 輸出通道數
        temb_channels: int,  # 時間嵌入通道數
        dropout: float = 0.0,  # dropout 機率
        num_layers: int = 1,  # ResNet 層數
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 值
        resnet_time_scale_shift: str = "default",  # ResNet 的時間縮放偏移
        resnet_act_fn: str = "swish",  # ResNet 的啟用函式
        resnet_groups: int = 32,  # ResNet 的分組數
        resnet_pre_norm: bool = True,  # 是否在 ResNet 前進行歸一化
        output_scale_factor: float = 1.0,  # 輸出縮放因子
        add_downsample: bool = True,  # 是否新增下采樣層
        downsample_padding: int = 1,  # 下采樣時的填充
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        # 初始化一個空列表,用於存放 ResNet 層
        resnets = []

        # 迴圈建立指定數量的 ResNet 層
        for i in range(num_layers):
            # 第一層使用輸入通道,之後的層使用輸出通道
            in_channels = in_channels if i == 0 else out_channels
            # 將 ResNet 層新增到列表中
            resnets.append(
                ResnetBlockFlat(
                    in_channels=in_channels,  # 當前層的輸入通道數
                    out_channels=out_channels,  # 當前層的輸出通道數
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    eps=resnet_eps,  # epsilon 值
                    groups=resnet_groups,  # 分組數
                    dropout=dropout,  # dropout 機率
                    time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入歸一化方式
                    non_linearity=resnet_act_fn,  # 啟用函式
                    output_scale_factor=output_scale_factor,  # 輸出縮放因子
                    pre_norm=resnet_pre_norm,  # 是否前歸一化
                )
            )

        # 將 ResNet 層列表轉為 nn.ModuleList 以便於管理
        self.resnets = nn.ModuleList(resnets)

        # 根據引數決定是否新增下采樣層
        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    LinearMultiDim(
                        out_channels,  # 輸入通道數
                        use_conv=True,  # 使用卷積
                        out_channels=out_channels,  # 輸出通道數
                        padding=downsample_padding,  # 填充
                        name="op"  # 下采樣層名稱
                    )
                ]
            )
        else:
            # 如果不新增下采樣層,設定為 None
            self.downsamplers = None

        # 初始化梯度檢查點為 False
        self.gradient_checkpointing = False

    # 定義前向傳播方法
    def forward(
        self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None  # 輸入的隱藏狀態和可選的時間嵌入
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        # 初始化輸出狀態為一個空元組
        output_states = ()

        # 遍歷所有的 ResNet 層
        for resnet in self.resnets:
            # 如果在訓練模式且開啟了梯度檢查點
            if self.training and self.gradient_checkpointing:
                # 定義一個建立自定義前向傳播的方法
                def create_custom_forward(module):
                    # 定義自定義前向傳播函式
                    def custom_forward(*inputs):
                        return module(*inputs)  # 呼叫模組進行前向傳播

                    return custom_forward

                # 檢查 PyTorch 版本,使用不同的呼叫方式
                if is_torch_version(">=", "1.11.0"):
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False  # 進行梯度檢查點的前向傳播
                    )
                else:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet), hidden_states, temb  # 進行梯度檢查點的前向傳播
                    )
            else:
                # 正常呼叫 ResNet 層進行前向傳播
                hidden_states = resnet(hidden_states, temb)

            # 將當前隱藏狀態新增到輸出狀態中
            output_states = output_states + (hidden_states,)

        # 如果存在下采樣層
        if self.downsamplers is not None:
            # 遍歷所有下采樣層
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)  # 對隱藏狀態進行下采樣

            # 將下采樣後的隱藏狀態新增到輸出狀態中
            output_states = output_states + (hidden_states,)

        # 返回最終的隱藏狀態和所有輸出狀態
        return hidden_states, output_states
# 定義一個名為 CrossAttnDownBlockFlat 的類,繼承自 nn.Module
class CrossAttnDownBlockFlat(nn.Module):
    # 初始化方法,定義類的屬性
    def __init__(
        # 輸入通道數
        self,
        in_channels: int,
        # 輸出通道數
        out_channels: int,
        # 時間嵌入通道數
        temb_channels: int,
        # dropout 機率,預設為 0.0
        dropout: float = 0.0,
        # 層數,預設為 1
        num_layers: int = 1,
        # 每個塊的變換器層數,可以是整數或整數元組,預設為 1
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        # ResNet 的 epsilon 值,預設為 1e-6
        resnet_eps: float = 1e-6,
        # ResNet 的時間尺度偏移設定,預設為 "default"
        resnet_time_scale_shift: str = "default",
        # ResNet 的啟用函式,預設為 "swish"
        resnet_act_fn: str = "swish",
        # ResNet 的組數,預設為 32
        resnet_groups: int = 32,
        # 是否使用預歸一化,預設為 True
        resnet_pre_norm: bool = True,
        # 注意力頭的數量,預設為 1
        num_attention_heads: int = 1,
        # 交叉注意力的維度,預設為 1280
        cross_attention_dim: int = 1280,
        # 輸出縮放因子,預設為 1.0
        output_scale_factor: float = 1.0,
        # 下采樣的填充大小,預設為 1
        downsample_padding: int = 1,
        # 是否新增下采樣層,預設為 True
        add_downsample: bool = True,
        # 是否使用雙重交叉注意力,預設為 False
        dual_cross_attention: bool = False,
        # 是否使用線性投影,預設為 False
        use_linear_projection: bool = False,
        # 是否只使用交叉注意力,預設為 False
        only_cross_attention: bool = False,
        # 是否上溯注意力,預設為 False
        upcast_attention: bool = False,
        # 注意力型別,預設為 "default"
        attention_type: str = "default",
    ):
        # 呼叫父類的建構函式以初始化基類
        super().__init__()
        # 初始化儲存 ResNet 塊的列表
        resnets = []
        # 初始化儲存注意力模型的列表
        attentions = []

        # 設定是否使用交叉注意力的標誌
        self.has_cross_attention = True
        # 設定注意力頭的數量
        self.num_attention_heads = num_attention_heads
        # 如果 transformer_layers_per_block 是一個整數,則將其轉換為列表形式
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        # 為每一層構建 ResNet 塊和注意力模型
        for i in range(num_layers):
            # 設定當前層的輸入通道數,第一層使用 in_channels,其他層使用 out_channels
            in_channels = in_channels if i == 0 else out_channels
            # 向 resnets 列表新增一個 ResNet 塊
            resnets.append(
                ResnetBlockFlat(
                    # 設定 ResNet 塊的輸入通道數
                    in_channels=in_channels,
                    # 設定 ResNet 塊的輸出通道數
                    out_channels=out_channels,
                    # 設定時間嵌入通道數
                    temb_channels=temb_channels,
                    # 設定 ResNet 塊的 epsilon 值
                    eps=resnet_eps,
                    # 設定 ResNet 塊的組數
                    groups=resnet_groups,
                    # 設定 dropout 機率
                    dropout=dropout,
                    # 設定時間嵌入的歸一化方法
                    time_embedding_norm=resnet_time_scale_shift,
                    # 設定啟用函式
                    non_linearity=resnet_act_fn,
                    # 設定輸出縮放因子
                    output_scale_factor=output_scale_factor,
                    # 設定是否在前面進行歸一化
                    pre_norm=resnet_pre_norm,
                )
            )
            # 如果不使用雙交叉注意力
            if not dual_cross_attention:
                # 向 attentions 列表新增一個 Transformer 2D 模型
                attentions.append(
                    Transformer2DModel(
                        # 設定注意力頭的數量
                        num_attention_heads,
                        # 設定每個注意力頭的輸出通道數
                        out_channels // num_attention_heads,
                        # 設定輸入通道數
                        in_channels=out_channels,
                        # 設定當前層的 Transformer 層數
                        num_layers=transformer_layers_per_block[i],
                        # 設定交叉注意力的維度
                        cross_attention_dim=cross_attention_dim,
                        # 設定歸一化的組數
                        norm_num_groups=resnet_groups,
                        # 設定是否使用線性投影
                        use_linear_projection=use_linear_projection,
                        # 設定是否僅使用交叉注意力
                        only_cross_attention=only_cross_attention,
                        # 設定是否提高注意力精度
                        upcast_attention=upcast_attention,
                        # 設定注意力型別
                        attention_type=attention_type,
                    )
                )
            else:
                # 向 attentions 列表新增一個雙 Transformer 2D 模型
                attentions.append(
                    DualTransformer2DModel(
                        # 設定注意力頭的數量
                        num_attention_heads,
                        # 設定每個注意力頭的輸出通道數
                        out_channels // num_attention_heads,
                        # 設定輸入通道數
                        in_channels=out_channels,
                        # 固定層數為 1
                        num_layers=1,
                        # 設定交叉注意力的維度
                        cross_attention_dim=cross_attention_dim,
                        # 設定歸一化的組數
                        norm_num_groups=resnet_groups,
                    )
                )
        # 將注意力模型列表轉換為 PyTorch 的 ModuleList
        self.attentions = nn.ModuleList(attentions)
        # 將 ResNet 塊列表轉換為 PyTorch 的 ModuleList
        self.resnets = nn.ModuleList(resnets)

        # 如果需要新增下采樣層
        if add_downsample:
            # 初始化下采樣層為 ModuleList
            self.downsamplers = nn.ModuleList(
                [
                    LinearMultiDim(
                        # 設定輸出通道數
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            # 如果不新增下采樣層,將其設為 None
            self.downsamplers = None

        # 初始化梯度檢查點標誌為 False
        self.gradient_checkpointing = False
    # 定義前向傳播函式,接收隱藏狀態和其他可選引數
        def forward(
            self,
            hidden_states: torch.Tensor,  # 當前隱藏狀態的張量
            temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
            encoder_hidden_states: Optional[torch.Tensor] = None,  # 可選的編碼器隱藏狀態張量
            attention_mask: Optional[torch.Tensor] = None,  # 可選的注意力掩碼
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 可選的交叉注意力引數
            encoder_attention_mask: Optional[torch.Tensor] = None,  # 可選的編碼器注意力掩碼
            additional_residuals: Optional[torch.Tensor] = None,  # 可選的額外殘差張量
        ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:  # 返回隱藏狀態和輸出狀態元組
            output_states = ()  # 初始化輸出狀態元組
    
            blocks = list(zip(self.resnets, self.attentions))  # 將殘差網路和注意力模組配對成塊
    
            for i, (resnet, attn) in enumerate(blocks):  # 遍歷每個塊及其索引
                if self.training and self.gradient_checkpointing:  # 檢查是否在訓練且啟用梯度檢查點
    
                    def create_custom_forward(module, return_dict=None):  # 定義自定義前向傳播函式
                        def custom_forward(*inputs):  # 自定義前向傳播邏輯
                            if return_dict is not None:  # 如果提供了返回字典
                                return module(*inputs, return_dict=return_dict)  # 返回帶字典的結果
                            else:
                                return module(*inputs)  # 否則返回普通結果
    
                        return custom_forward  # 返回自定義前向傳播函式
    
                    # 設定檢查點引數,如果 PyTorch 版本大於等於 1.11.0,則使用非重入模式
                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    # 透過檢查點機制計算當前塊的隱藏狀態
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),  # 建立自定義前向函式的檢查點
                        hidden_states,  # 輸入當前隱藏狀態
                        temb,  # 輸入時間嵌入
                        **ckpt_kwargs,  # 傳遞檢查點引數
                    )
                    # 透過注意力模組處理隱藏狀態並獲取輸出
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,  # 編碼器隱藏狀態
                        cross_attention_kwargs=cross_attention_kwargs,  # 交叉注意力引數
                        attention_mask=attention_mask,  # 注意力掩碼
                        encoder_attention_mask=encoder_attention_mask,  # 編碼器注意力掩碼
                        return_dict=False,  # 不返回字典格式
                    )[0]  # 取出第一個輸出
                else:  # 如果不啟用梯度檢查
                    # 直接透過殘差網路處理隱藏狀態
                    hidden_states = resnet(hidden_states, temb)
                    # 透過注意力模組處理隱藏狀態並獲取輸出
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,  # 編碼器隱藏狀態
                        cross_attention_kwargs=cross_attention_kwargs,  # 交叉注意力引數
                        attention_mask=attention_mask,  # 注意力掩碼
                        encoder_attention_mask=encoder_attention_mask,  # 編碼器注意力掩碼
                        return_dict=False,  # 不返回字典格式
                    )[0]  # 取出第一個輸出
    
                # 如果是最後一個塊並且提供了額外殘差,則將其新增到隱藏狀態
                if i == len(blocks) - 1 and additional_residuals is not None:
                    hidden_states = hidden_states + additional_residuals  # 加上額外殘差
    
                output_states = output_states + (hidden_states,)  # 將當前隱藏狀態新增到輸出狀態元組中
    
            if self.downsamplers is not None:  # 如果存在下采樣器
                for downsampler in self.downsamplers:  # 遍歷每個下采樣器
                    hidden_states = downsampler(hidden_states)  # 處理當前隱藏狀態
    
                output_states = output_states + (hidden_states,)  # 將當前隱藏狀態新增到輸出狀態元組中
    
            return hidden_states, output_states  # 返回最終的隱藏狀態和輸出狀態元組
# 從 diffusers.models.unets.unet_2d_blocks 中複製,替換 UpBlock2D 為 UpBlockFlat,ResnetBlock2D 為 ResnetBlockFlat,Upsample2D 為 LinearMultiDim
class UpBlockFlat(nn.Module):
    # 初始化函式,定義輸入輸出通道及其他引數
    def __init__(
        self,
        in_channels: int,  # 輸入通道數
        prev_output_channel: int,  # 前一層輸出通道數
        out_channels: int,  # 當前層輸出通道數
        temb_channels: int,  # 時間嵌入通道數
        resolution_idx: Optional[int] = None,  # 解析度索引
        dropout: float = 0.0,  # dropout 機率
        num_layers: int = 1,  # 層數
        resnet_eps: float = 1e-6,  # ResNet 中的 epsilon 值
        resnet_time_scale_shift: str = "default",  # 時間尺度偏移設定
        resnet_act_fn: str = "swish",  # 啟用函式型別
        resnet_groups: int = 32,  # 分組數
        resnet_pre_norm: bool = True,  # 是否進行預歸一化
        output_scale_factor: float = 1.0,  # 輸出縮放因子
        add_upsample: bool = True,  # 是否新增上取樣
    ):
        # 呼叫父類建構函式
        super().__init__()
        # 初始化一個空列表儲存 ResNet 塊
        resnets = []

        # 遍歷層數,構建每一層的 ResNet 塊
        for i in range(num_layers):
            # 根據層數決定殘差跳躍通道數
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # 根據當前層數決定輸入通道數
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            # 將 ResNet 塊新增到列表中
            resnets.append(
                ResnetBlockFlat(
                    in_channels=resnet_in_channels + res_skip_channels,  # 輸入通道數
                    out_channels=out_channels,  # 輸出通道數
                    temb_channels=temb_channels,  # 時間嵌入通道數
                    eps=resnet_eps,  # epsilon 值
                    groups=resnet_groups,  # 分組數
                    dropout=dropout,  # dropout 機率
                    time_embedding_norm=resnet_time_scale_shift,  # 時間嵌入歸一化
                    non_linearity=resnet_act_fn,  # 啟用函式
                    output_scale_factor=output_scale_factor,  # 輸出縮放因子
                    pre_norm=resnet_pre_norm,  # 預歸一化
                )
            )

        # 將 ResNet 塊列表轉換為模組列表
        self.resnets = nn.ModuleList(resnets)

        # 如果需要新增上取樣層,則建立上取樣模組
        if add_upsample:
            self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            # 否則設定為 None
            self.upsamplers = None

        # 初始化梯度檢查點標誌
        self.gradient_checkpointing = False
        # 設定解析度索引
        self.resolution_idx = resolution_idx

    # 前向傳播函式
    def forward(
        self,
        hidden_states: torch.Tensor,  # 隱藏狀態張量
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],  # 殘差隱藏狀態元組
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
        upsample_size: Optional[int] = None,  # 可選的上取樣大小
        *args,  # 可變引數
        **kwargs,  # 可變關鍵字引數
    ) -> torch.Tensor:  # 定義一個返回 torch.Tensor 型別的函式
        # 如果引數列表 args 長度大於 0 或 kwargs 中的 scale 引數不為 None
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 設定棄用訊息,提醒使用者 scale 引數已棄用且將被忽略
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 呼叫 deprecate 函式,記錄 scale 引數的棄用
            deprecate("scale", "1.0.0", deprecation_message)

        # 檢查 FreeU 是否啟用,取決於 s1, s2, b1 和 b2 的值
        is_freeu_enabled = (
            getattr(self, "s1", None)  # 獲取 self 中的 s1 屬性
            and getattr(self, "s2", None)  # 獲取 self 中的 s2 屬性
            and getattr(self, "b1", None)  # 獲取 self 中的 b1 屬性
            and getattr(self, "b2", None)  # 獲取 self 中的 b2 屬性
        )

        # 遍歷 self.resnets 中的每個 ResNet 模型
        for resnet in self.resnets:
            # 彈出 res 隱藏狀態的最後一個元素
            res_hidden_states = res_hidden_states_tuple[-1]  
            # 移除 res 隱藏狀態元組的最後一個元素
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]  

            # FreeU: 僅在前兩個階段進行操作
            if is_freeu_enabled:
                # 應用 FreeU 操作,返回更新後的 hidden_states 和 res_hidden_states
                hidden_states, res_hidden_states = apply_freeu(
                    self.resolution_idx,  # 當前解析度索引
                    hidden_states,  # 當前隱藏狀態
                    res_hidden_states,  # 之前的隱藏狀態
                    s1=self.s1,  # s1 引數
                    s2=self.s2,  # s2 引數
                    b1=self.b1,  # b1 引數
                    b2=self.b2,  # b2 引數
                )

            # 將當前的 hidden_states 和 res_hidden_states 在維度 1 上拼接
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)  

            # 如果處於訓練模式並且開啟了梯度檢查點
            if self.training and self.gradient_checkpointing:
                # 定義一個建立自定義前向函式的函式
                def create_custom_forward(module):
                    # 定義自定義前向函式,接收輸入並呼叫模組
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                # 如果 PyTorch 版本大於等於 1.11.0
                if is_torch_version(">=", "1.11.0"):
                    # 使用梯度檢查點來計算 hidden_states
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),  # 使用自定義前向函式
                        hidden_states,  # 當前隱藏狀態
                        temb,  # 傳入的額外輸入
                        use_reentrant=False  # 禁用重入檢查
                    )
                else:
                    # 對於早期版本,使用梯度檢查點計算 hidden_states
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),  # 使用自定義前向函式
                        hidden_states,  # 當前隱藏狀態
                        temb  # 傳入的額外輸入
                    )
            else:
                # 在非訓練模式下直接呼叫 resnet 處理 hidden_states
                hidden_states = resnet(hidden_states, temb)  

        # 如果存在上取樣器
        if self.upsamplers is not None:
            # 遍歷所有上取樣器
            for upsampler in self.upsamplers:
                # 使用上取樣器對 hidden_states 進行處理,指定上取樣尺寸
                hidden_states = upsampler(hidden_states, upsample_size)  

        # 返回處理後的 hidden_states
        return hidden_states  
# 從 diffusers.models.unets.unet_2d_blocks 中複製的程式碼,修改了類名和一些元件
class CrossAttnUpBlockFlat(nn.Module):
    # 初始化方法,定義類的基本屬性和引數
    def __init__(
        # 輸入通道數
        in_channels: int,
        # 輸出通道數
        out_channels: int,
        # 上一層輸出的通道數
        prev_output_channel: int,
        # 額外的時間嵌入通道數
        temb_channels: int,
        # 可選的解析度索引
        resolution_idx: Optional[int] = None,
        # dropout 機率
        dropout: float = 0.0,
        # 層數
        num_layers: int = 1,
        # 每個塊的變換器層數,可以是單個整數或元組
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        # ResNet 的 epsilon 值
        resnet_eps: float = 1e-6,
        # ResNet 時間尺度偏移的型別
        resnet_time_scale_shift: str = "default",
        # ResNet 啟用函式的型別
        resnet_act_fn: str = "swish",
        # ResNet 的組數
        resnet_groups: int = 32,
        # 是否在 ResNet 中使用預歸一化
        resnet_pre_norm: bool = True,
        # 注意力頭的數量
        num_attention_heads: int = 1,
        # 交叉注意力的維度
        cross_attention_dim: int = 1280,
        # 輸出縮放因子
        output_scale_factor: float = 1.0,
        # 是否新增上取樣步驟
        add_upsample: bool = True,
        # 是否使用雙交叉注意力
        dual_cross_attention: bool = False,
        # 是否使用線性投影
        use_linear_projection: bool = False,
        # 是否僅使用交叉注意力
        only_cross_attention: bool = False,
        # 是否上溯注意力
        upcast_attention: bool = False,
        # 注意力型別
        attention_type: str = "default",
    # 定義建構函式的結束部分
        ):
            # 呼叫父類的建構函式
            super().__init__()
            # 初始化一個空列表用於儲存殘差網路塊
            resnets = []
            # 初始化一個空列表用於儲存注意力模型
            attentions = []
    
            # 設定是否使用交叉注意力標誌為真
            self.has_cross_attention = True
            # 設定注意力頭的數量
            self.num_attention_heads = num_attention_heads
    
            # 如果 transformer_layers_per_block 是整數,則將其轉換為相同長度的列表
            if isinstance(transformer_layers_per_block, int):
                transformer_layers_per_block = [transformer_layers_per_block] * num_layers
    
            # 遍歷每一層以構建殘差網路和注意力模型
            for i in range(num_layers):
                # 設定殘差跳過通道數,最後一層使用輸入通道,否則使用輸出通道
                res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
                # 設定殘差網路輸入通道數,第一層使用前一層輸出通道,否則使用當前輸出通道
                resnet_in_channels = prev_output_channel if i == 0 else out_channels
    
                # 新增一個殘差網路塊到 resnets 列表中
                resnets.append(
                    ResnetBlockFlat(
                        # 設定殘差網路輸入通道數
                        in_channels=resnet_in_channels + res_skip_channels,
                        # 設定殘差網路輸出通道數
                        out_channels=out_channels,
                        # 設定時間嵌入通道數
                        temb_channels=temb_channels,
                        # 設定殘差網路的 epsilon 值
                        eps=resnet_eps,
                        # 設定殘差網路的組數
                        groups=resnet_groups,
                        # 設定丟棄率
                        dropout=dropout,
                        # 設定時間嵌入的歸一化方法
                        time_embedding_norm=resnet_time_scale_shift,
                        # 設定非線性啟用函式
                        non_linearity=resnet_act_fn,
                        # 設定輸出縮放因子
                        output_scale_factor=output_scale_factor,
                        # 設定是否進行預歸一化
                        pre_norm=resnet_pre_norm,
                    )
                )
                # 如果不使用雙重交叉注意力
                if not dual_cross_attention:
                    # 新增一個普通的 Transformer2DModel 到 attentions 列表中
                    attentions.append(
                        Transformer2DModel(
                            # 設定注意力頭的數量
                            num_attention_heads,
                            # 設定每個注意力頭的輸出通道數
                            out_channels // num_attention_heads,
                            # 設定輸入通道數
                            in_channels=out_channels,
                            # 設定層數
                            num_layers=transformer_layers_per_block[i],
                            # 設定交叉注意力維度
                            cross_attention_dim=cross_attention_dim,
                            # 設定歸一化組數
                            norm_num_groups=resnet_groups,
                            # 設定是否使用線性投影
                            use_linear_projection=use_linear_projection,
                            # 設定是否僅使用交叉注意力
                            only_cross_attention=only_cross_attention,
                            # 設定是否上溯注意力
                            upcast_attention=upcast_attention,
                            # 設定注意力型別
                            attention_type=attention_type,
                        )
                    )
                else:
                    # 新增一個雙重 Transformer2DModel 到 attentions 列表中
                    attentions.append(
                        DualTransformer2DModel(
                            # 設定注意力頭的數量
                            num_attention_heads,
                            # 設定每個注意力頭的輸出通道數
                            out_channels // num_attention_heads,
                            # 設定輸入通道數
                            in_channels=out_channels,
                            # 設定層數為 1
                            num_layers=1,
                            # 設定交叉注意力維度
                            cross_attention_dim=cross_attention_dim,
                            # 設定歸一化組數
                            norm_num_groups=resnet_groups,
                        )
                    )
            # 將注意力模型列表轉換為 nn.ModuleList
            self.attentions = nn.ModuleList(attentions)
            # 將殘差網路塊列表轉換為 nn.ModuleList
            self.resnets = nn.ModuleList(resnets)
    
            # 如果需要新增上取樣層
            if add_upsample:
                # 將上取樣器新增到 nn.ModuleList 中
                self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)])
            else:
                # 否則將上取樣器設定為 None
                self.upsamplers = None
    
            # 設定梯度檢查點標誌為假
            self.gradient_checkpointing = False
            # 設定解析度索引
            self.resolution_idx = resolution_idx
    # 定義前向傳播函式,接收多個輸入引數
        def forward(
            self,
            # 隱藏狀態,型別為 PyTorch 的張量
            hidden_states: torch.Tensor,
            # 包含殘差隱藏狀態的元組,元素型別為 PyTorch 張量
            res_hidden_states_tuple: Tuple[torch.Tensor, ...],
            # 可選的時間嵌入,型別為 PyTorch 的張量
            temb: Optional[torch.Tensor] = None,
            # 可選的編碼器隱藏狀態,型別為 PyTorch 的張量
            encoder_hidden_states: Optional[torch.Tensor] = None,
            # 可選的交叉注意力引數,型別為字典,包含任意鍵值對
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 可選的上取樣大小,型別為整數
            upsample_size: Optional[int] = None,
            # 可選的注意力掩碼,型別為 PyTorch 的張量
            attention_mask: Optional[torch.Tensor] = None,
            # 可選的編碼器注意力掩碼,型別為 PyTorch 的張量
            encoder_attention_mask: Optional[torch.Tensor] = None,
# 從 diffusers.models.unets.unet_2d_blocks 中複製的 UNetMidBlock2D 程式碼,替換了 UNetMidBlock2D 為 UNetMidBlockFlat,ResnetBlock2D 為 ResnetBlockFlat
class UNetMidBlockFlat(nn.Module):
    """
    2D UNet 中間塊 [`UNetMidBlockFlat`],包含多個殘差塊和可選的注意力塊。

    引數:
        in_channels (`int`): 輸入通道的數量。
        temb_channels (`int`): 時間嵌入通道的數量。
        dropout (`float`, *可選*, 預設值為 0.0): dropout 比率。
        num_layers (`int`, *可選*, 預設值為 1): 殘差塊的數量。
        resnet_eps (`float`, *可選*, 預設值為 1e-6): resnet 塊的 epsilon 值。
        resnet_time_scale_shift (`str`, *可選*, 預設值為 `default`):
            應用於時間嵌入的歸一化型別。這可以幫助提高模型在長範圍時間依賴任務上的效能。
        resnet_act_fn (`str`, *可選*, 預設值為 `swish`): resnet 塊的啟用函式。
        resnet_groups (`int`, *可選*, 預設值為 32):
            resnet 塊的分組歸一化層使用的組數。
        attn_groups (`Optional[int]`, *可選*, 預設值為 None): 注意力塊的組數。
        resnet_pre_norm (`bool`, *可選*, 預設值為 `True`):
            是否在 resnet 塊中使用預歸一化。
        add_attention (`bool`, *可選*, 預設值為 `True`): 是否新增註意力塊。
        attention_head_dim (`int`, *可選*, 預設值為 1):
            單個注意力頭的維度。注意力頭的數量基於此值和輸入通道的數量確定。
        output_scale_factor (`float`, *可選*, 預設值為 1.0): 輸出縮放因子。

    返回:
        `torch.Tensor`: 最後一個殘差塊的輸出,是一個形狀為 `(batch_size, in_channels,
        height, width)` 的張量。

    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # 預設,空間
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        attn_groups: Optional[int] = None,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
    ):
        # 初始化 UNetMidBlockFlat 類,設定各引數的預設值
        super().__init__()  # 呼叫父類的初始化方法
        self.in_channels = in_channels  # 儲存輸入通道數
        self.temb_channels = temb_channels  # 儲存時間嵌入通道數
        self.dropout = dropout  # 儲存 dropout 比率
        self.num_layers = num_layers  # 儲存殘差塊的數量
        self.resnet_eps = resnet_eps  # 儲存 resnet 塊的 epsilon 值
        self.resnet_time_scale_shift = resnet_time_scale_shift  # 儲存時間縮放偏移型別
        self.resnet_act_fn = resnet_act_fn  # 儲存啟用函式型別
        self.resnet_groups = resnet_groups  # 儲存分組數
        self.attn_groups = attn_groups  # 儲存注意力組數
        self.resnet_pre_norm = resnet_pre_norm  # 儲存是否使用預歸一化
        self.add_attention = add_attention  # 儲存是否新增註意力塊
        self.attention_head_dim = attention_head_dim  # 儲存注意力頭的維度
        self.output_scale_factor = output_scale_factor  # 儲存輸出縮放因子

    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        # 定義前向傳播方法,接受隱藏狀態和可選的時間嵌入
        hidden_states = self.resnets[0](hidden_states, temb)  # 透過第一個殘差塊處理隱藏狀態
        for attn, resnet in zip(self.attentions, self.resnets[1:]):  # 遍歷後續的注意力塊和殘差塊
            if attn is not None:  # 如果注意力塊存在
                hidden_states = attn(hidden_states, temb=temb)  # 透過注意力塊處理隱藏狀態
            hidden_states = resnet(hidden_states, temb)  # 透過殘差塊處理隱藏狀態

        return hidden_states  # 返回處理後的隱藏狀態
# 從 diffusers.models.unets.unet_2d_blocks 中複製,替換 UNetMidBlock2DCrossAttn 為 UNetMidBlockFlatCrossAttn,ResnetBlock2D 為 ResnetBlockFlat
class UNetMidBlockFlatCrossAttn(nn.Module):
    # 初始化方法,定義模型引數
    def __init__(
        self,
        # 輸入通道數
        in_channels: int,
        # 時間嵌入通道數
        temb_channels: int,
        # 輸出通道數,預設為 None
        out_channels: Optional[int] = None,
        # Dropout 機率,預設為 0.0
        dropout: float = 0.0,
        # 層數,預設為 1
        num_layers: int = 1,
        # 每個塊的 Transformer 層數,預設為 1
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        # ResNet 的 epsilon 值,預設為 1e-6
        resnet_eps: float = 1e-6,
        # ResNet 的時間尺度偏移,預設為 "default"
        resnet_time_scale_shift: str = "default",
        # ResNet 的啟用函式型別,預設為 "swish"
        resnet_act_fn: str = "swish",
        # ResNet 的分組數,預設為 32
        resnet_groups: int = 32,
        # 輸出的 ResNet 分組數,預設為 None
        resnet_groups_out: Optional[int] = None,
        # 是否使用預歸一化,預設為 True
        resnet_pre_norm: bool = True,
        # 注意力頭數,預設為 1
        num_attention_heads: int = 1,
        # 輸出縮放因子,預設為 1.0
        output_scale_factor: float = 1.0,
        # 交叉注意力維度,預設為 1280
        cross_attention_dim: int = 1280,
        # 是否使用雙交叉注意力,預設為 False
        dual_cross_attention: bool = False,
        # 是否使用線性投影,預設為 False
        use_linear_projection: bool = False,
        # 是否上升注意力計算精度,預設為 False
        upcast_attention: bool = False,
        # 注意力型別,預設為 "default"
        attention_type: str = "default",
    # 前向傳播方法,定義模型的前向計算邏輯
    def forward(
        self,
        # 隱藏狀態張量
        hidden_states: torch.Tensor,
        # 可選的時間嵌入張量,預設為 None
        temb: Optional[torch.Tensor] = None,
        # 可選的編碼器隱藏狀態張量,預設為 None
        encoder_hidden_states: Optional[torch.Tensor] = None,
        # 可選的注意力掩碼,預設為 None
        attention_mask: Optional[torch.Tensor] = None,
        # 可選的交叉注意力引數字典,預設為 None
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        # 可選的編碼器注意力掩碼,預設為 None
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:  # 定義函式的返回型別為 torch.Tensor
        if cross_attention_kwargs is not None:  # 檢查 cross_attention_kwargs 是否為 None
            if cross_attention_kwargs.get("scale", None) is not None:  # 檢查 scale 是否在 cross_attention_kwargs 中
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")  # 發出警告,提示 scale 引數已過時

        hidden_states = self.resnets[0](hidden_states, temb)  # 使用第一個殘差網路處理隱藏狀態和時間嵌入
        for attn, resnet in zip(self.attentions, self.resnets[1:]):  # 遍歷注意力層和後續的殘差網路
            if self.training and self.gradient_checkpointing:  # 檢查是否在訓練模式且開啟了梯度檢查點

                def create_custom_forward(module, return_dict=None):  # 定義一個函式以建立自定義前向傳播
                    def custom_forward(*inputs):  # 定義實際的前向傳播函式
                        if return_dict is not None:  # 檢查是否需要返回字典形式的輸出
                            return module(*inputs, return_dict=return_dict)  # 呼叫模組並返回字典
                        else:  # 如果不需要字典形式的輸出
                            return module(*inputs)  # 直接呼叫模組並返回結果

                    return custom_forward  # 返回自定義前向傳播函式

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}  # 根據 PyTorch 版本設定檢查點引數
                hidden_states = attn(  # 使用注意力層處理隱藏狀態
                    hidden_states,  # 輸入隱藏狀態
                    encoder_hidden_states=encoder_hidden_states,  # 輸入編碼器的隱藏狀態
                    cross_attention_kwargs=cross_attention_kwargs,  # 傳遞交叉注意力引數
                    attention_mask=attention_mask,  # 傳遞注意力掩碼
                    encoder_attention_mask=encoder_attention_mask,  # 傳遞編碼器注意力掩碼
                    return_dict=False,  # 不返回字典
                )[0]  # 取出輸出的第一個元素
                hidden_states = torch.utils.checkpoint.checkpoint(  # 使用檢查點儲存記憶體
                    create_custom_forward(resnet),  # 建立自定義前向傳播
                    hidden_states,  # 輸入隱藏狀態
                    temb,  # 輸入時間嵌入
                    **ckpt_kwargs,  # 解包檢查點引數
                )
            else:  # 如果不在訓練模式或不使用梯度檢查點
                hidden_states = attn(  # 使用注意力層處理隱藏狀態
                    hidden_states,  # 輸入隱藏狀態
                    encoder_hidden_states=encoder_hidden_states,  # 輸入編碼器的隱藏狀態
                    cross_attention_kwargs=cross_attention_kwargs,  # 傳遞交叉注意力引數
                    attention_mask=attention_mask,  # 傳遞注意力掩碼
                    encoder_attention_mask=encoder_attention_mask,  # 傳遞編碼器注意力掩碼
                    return_dict=False,  # 不返回字典
                )[0]  # 取出輸出的第一個元素
                hidden_states = resnet(hidden_states, temb)  # 使用殘差網路處理隱藏狀態和時間嵌入

        return hidden_states  # 返回處理後的隱藏狀態
# 從 diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn 複製,替換 UNetMidBlock2DSimpleCrossAttn 為 UNetMidBlockFlatSimpleCrossAttn,ResnetBlock2D 為 ResnetBlockFlat
class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
    # 初始化方法,設定各層的輸入輸出引數
    def __init__(
        # 輸入通道數
        in_channels: int,
        # 條件嵌入通道數
        temb_channels: int,
        # Dropout 機率
        dropout: float = 0.0,
        # 網路層數
        num_layers: int = 1,
        # ResNet 的 epsilon 值
        resnet_eps: float = 1e-6,
        # ResNet 的時間縮放偏移方式
        resnet_time_scale_shift: str = "default",
        # ResNet 啟用函式型別
        resnet_act_fn: str = "swish",
        # ResNet 中組的數量
        resnet_groups: int = 32,
        # 是否使用 ResNet 前歸一化
        resnet_pre_norm: bool = True,
        # 注意力頭的維度
        attention_head_dim: int = 1,
        # 輸出縮放因子
        output_scale_factor: float = 1.0,
        # 交叉注意力的維度
        cross_attention_dim: int = 1280,
        # 是否跳過時間啟用
        skip_time_act: bool = False,
        # 是否僅使用交叉注意力
        only_cross_attention: bool = False,
        # 交叉注意力的歸一化方式
        cross_attention_norm: Optional[str] = None,
    ):
        # 呼叫父類的初始化方法
        super().__init__()

        # 設定是否使用交叉注意力機制
        self.has_cross_attention = True

        # 設定注意力頭的維度
        self.attention_head_dim = attention_head_dim
        # 確定 ResNet 的組數,若未提供則使用預設值
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # 計算頭的數量
        self.num_heads = in_channels // self.attention_head_dim

        # 確保至少有一個 ResNet 塊
        resnets = [
            # 建立一個 ResNet 塊
            ResnetBlockFlat(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                skip_time_act=skip_time_act,
            )
        ]
        # 初始化注意力列表
        attentions = []

        # 根據層數建立對應的注意力機制
        for _ in range(num_layers):
            # 根據是否支援縮放點積注意力選擇處理器
            processor = (
                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
            )

            # 新增註意力機制到列表
            attentions.append(
                Attention(
                    query_dim=in_channels,
                    cross_attention_dim=in_channels,
                    heads=self.num_heads,
                    dim_head=self.attention_head_dim,
                    added_kv_proj_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    bias=True,
                    upcast_softmax=True,
                    only_cross_attention=only_cross_attention,
                    cross_attention_norm=cross_attention_norm,
                    processor=processor,
                )
            )
            # 新增 ResNet 塊到列表
            resnets.append(
                ResnetBlockFlat(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    skip_time_act=skip_time_act,
                )
            )

        # 將注意力層存入模組列表
        self.attentions = nn.ModuleList(attentions)
        # 將 ResNet 塊存入模組列表
        self.resnets = nn.ModuleList(resnets)

    def forward(
        # 定義前向傳播的方法
        self,
        hidden_states: torch.Tensor,
        # 可選的時間嵌入張量
        temb: Optional[torch.Tensor] = None,
        # 可選的編碼器隱藏狀態張量
        encoder_hidden_states: Optional[torch.Tensor] = None,
        # 可選的注意力掩碼張量
        attention_mask: Optional[torch.Tensor] = None,
        # 可選的交叉注意力引數字典
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        # 可選的編碼器注意力掩碼張量
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # 如果傳入的 cross_attention_kwargs 為 None,則初始化為空字典
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        # 檢查 cross_attention_kwargs 中是否有 'scale',如果有則發出警告,說明該引數已棄用
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # 如果 attention_mask 為 None
        if attention_mask is None:
            # 如果 encoder_hidden_states 被定義:表示我們在進行交叉注意力,因此應該使用交叉注意力掩碼
            mask = None if encoder_hidden_states is None else encoder_attention_mask
        else:
            # 當 attention_mask 被定義時:我們不檢查 encoder_attention_mask
            # 這是為了與 UnCLIP 相容,UnCLIP 使用 'attention_mask' 引數作為交叉注意力掩碼
            # TODO: UnCLIP 應透過 encoder_attention_mask 引數而不是 attention_mask 引數來表達交叉注意力掩碼
            #       然後我們可以簡化整個 if/else 塊為:
            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
            mask = attention_mask

        # 使用第一個殘差網路處理隱藏狀態和時間嵌入
        hidden_states = self.resnets[0](hidden_states, temb)
        # 遍歷所有注意力層和對應的殘差網路
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            # 使用注意力層處理隱藏狀態
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,  # 傳遞編碼器隱藏狀態
                attention_mask=mask,  # 傳遞掩碼
                **cross_attention_kwargs,  # 傳遞交叉注意力引數
            )

            # 使用殘差網路處理隱藏狀態
            hidden_states = resnet(hidden_states, temb)

        # 返回最終的隱藏狀態
        return hidden_states

相關文章