diffusers-原始碼解析-二十三-

绝不原创的飞龙發表於2024-10-22

diffusers 原始碼解析(二十三)

.\diffusers\pipelines\controlnet\pipeline_controlnet_sd_xl_img2img.py

# 版權所有 2024 HuggingFace 團隊。保留所有權利。
#
# 根據 Apache 許可證第 2.0 版(“許可證”)許可;
# 除非遵守許可證,否則您不得使用此檔案。
# 您可以在以下網址獲得許可證副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律或書面協議另有規定,軟體
# 在“按原樣”基礎上分發,不提供任何形式的保證或條件,
# 無論是明示或暗示的。
# 請參閱許可證以瞭解管理許可權的具體語言和
# 限制條款。


import inspect  # 匯入 inspect 模組,用於獲取物件的摘要資訊
from typing import Any, Callable, Dict, List, Optional, Tuple, Union  # 匯入型別註解模組

import numpy as np  # 匯入 numpy,用於陣列和矩陣計算
import PIL.Image  # 匯入 PIL.Image,用於處理影像
import torch  # 匯入 PyTorch,用於深度學習
import torch.nn.functional as F  # 匯入 PyTorch 的函式式 API
from transformers import (  # 從 transformers 匯入模型和處理器
    CLIPImageProcessor,  # 匯入 CLIP 影像處理器
    CLIPTextModel,  # 匯入 CLIP 文字模型
    CLIPTextModelWithProjection,  # 匯入帶投影的 CLIP 文字模型
    CLIPTokenizer,  # 匯入 CLIP 分詞器
    CLIPVisionModelWithProjection,  # 匯入帶投影的 CLIP 視覺模型
)

from diffusers.utils.import_utils import is_invisible_watermark_available  # 匯入檢查是否可用的隱形水印功能

from ...callbacks import MultiPipelineCallbacks, PipelineCallback  # 匯入多管道回撥和管道回撥類
from ...image_processor import PipelineImageInput, VaeImageProcessor  # 匯入影像處理相關類
from ...loaders import (  # 匯入載入器相關類
    FromSingleFileMixin,  # 從單檔案載入的混合類
    IPAdapterMixin,  # 影像處理介面卡混合類
    StableDiffusionXLLoraLoaderMixin,  # StableDiffusionXL Lora 載入混合類
    TextualInversionLoaderMixin,  # 文字反轉載入混合類
)
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel  # 匯入不同模型
from ...models.attention_processor import (  # 匯入注意力處理器
    AttnProcessor2_0,  # 注意力處理器版本 2.0
    XFormersAttnProcessor,  # XFormers 注意力處理器
)
from ...models.lora import adjust_lora_scale_text_encoder  # 匯入調整 Lora 標度文字編碼器的函式
from ...schedulers import KarrasDiffusionSchedulers  # 匯入 Karras 擴散排程器
from ...utils import (  # 匯入常用工具
    USE_PEFT_BACKEND,  # 指示是否使用 PEFT 後端的常量
    deprecate,  # 匯入棄用裝飾器
    logging,  # 匯入日誌記錄模組
    replace_example_docstring,  # 匯入替換示例文件字串的工具
    scale_lora_layers,  # 匯入縮放 Lora 層的工具
    unscale_lora_layers,  # 匯入反縮放 Lora 層的工具
)
from ...utils.torch_utils import is_compiled_module, randn_tensor  # 匯入與 PyTorch 相關的工具
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin  # 匯入擴散管道和穩定擴散混合類
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput  # 匯入穩定擴散 XL 管道輸出類


if is_invisible_watermark_available():  # 如果隱形水印功能可用
    from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker  # 匯入穩定擴散 XL 水印類

from .multicontrolnet import MultiControlNetModel  # 匯入多控制網模型


logger = logging.get_logger(__name__)  # 獲取當前模組的日誌記錄器,禁止 pylint 檢查


EXAMPLE_DOC_STRING = """  # 示例文件字串的空模板
"""


# 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 複製的函式
def retrieve_latents(  # 定義函式以檢索潛在變數
    encoder_output: torch.Tensor,  # 輸入為編碼器輸出的張量
    generator: Optional[torch.Generator] = None,  # 可選的隨機數生成器
    sample_mode: str = "sample"  # 取樣模式,預設為“sample”
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":  # 如果編碼器輸出有潛在分佈並且模式為取樣
        return encoder_output.latent_dist.sample(generator)  # 從潛在分佈中取樣並返回
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":  # 如果編碼器輸出有潛在分佈並且模式為“argmax”
        return encoder_output.latent_dist.mode()  # 返回潛在分佈的眾數
    elif hasattr(encoder_output, "latents"):  # 如果編碼器輸出有潛在變數
        return encoder_output.latents  # 直接返回潛在變數
    else:  # 如果以上條件都不滿足
        raise AttributeError("Could not access latents of provided encoder_output")  # 丟擲屬性錯誤,說明無法訪問潛在變數


class StableDiffusionXLControlNetImg2ImgPipeline(  # 定義 StableDiffusionXL 控制網路影像到影像的管道類
    DiffusionPipeline,  # 繼承自擴散管道
    # 繼承穩定擴散模型的混合類
        StableDiffusionMixin,
        # 繼承文字反轉載入器的混合類
        TextualInversionLoaderMixin,
        # 繼承穩定擴散 XL Lora 載入器的混合類
        StableDiffusionXLLoraLoaderMixin,
        # 繼承單檔案載入器的混合類
        FromSingleFileMixin,
        # 繼承 IP 介面卡的混合類
        IPAdapterMixin,
# 文件字串,描述使用 ControlNet 指導的影像生成管道
    r"""
    Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

    """

    # 定義模型在 CPU 上解除安裝的順序
    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
    # 定義可選元件的列表,用於管道的初始化
    _optional_components = [
        "tokenizer",  # 詞彙表,用於文字編碼
        "tokenizer_2",  # 第二個詞彙表,用於文字編碼
        "text_encoder",  # 文字編碼器,用於生成文字嵌入
        "text_encoder_2",  # 第二個文字編碼器,可能有不同的功能
        "feature_extractor",  # 特徵提取器,用於影像特徵的提取
        "image_encoder",  # 影像編碼器,將影像轉換為嵌入
    ]
    # 定義回撥張量輸入的列表,用於處理管道中的輸入
    _callback_tensor_inputs = [
        "latents",  # 潛在變數,用於生成模型的輸入
        "prompt_embeds",  # 正向提示的嵌入表示
        "negative_prompt_embeds",  # 負向提示的嵌入表示
        "add_text_embeds",  # 額外文字嵌入,用於補充輸入
        "add_time_ids",  # 附加的時間識別符號,用於時間相關的處理
        "negative_pooled_prompt_embeds",  # 負向池化提示的嵌入表示
        "add_neg_time_ids",  # 附加的負向時間識別符號
    ]

    # 建構函式,初始化管道所需的元件
    def __init__(
        self,  # 建構函式的第一個引數,指向類的例項
        vae: AutoencoderKL,  # 變分自編碼器,用於影像的重建
        text_encoder: CLIPTextModel,  # 文字編碼器,使用 CLIP 模型
        text_encoder_2: CLIPTextModelWithProjection,  # 第二個文字編碼器,帶投影功能的 CLIP 模型
        tokenizer: CLIPTokenizer,  # 第一個 CLIP 詞彙表
        tokenizer_2: CLIPTokenizer,  # 第二個 CLIP 詞彙表
        unet: UNet2DConditionModel,  # U-Net 模型,用於生成影像
        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],  # 控制網路模型,用於引導生成
        scheduler: KarrasDiffusionSchedulers,  # 排程器,控制擴散過程
        requires_aesthetics_score: bool = False,  # 是否需要美學評分,預設為 False
        force_zeros_for_empty_prompt: bool = True,  # 對於空提示強制使用零值,預設為 True
        add_watermarker: Optional[bool] = None,  # 是否新增水印,預設為 None
        feature_extractor: CLIPImageProcessor = None,  # 特徵提取器,預設為 None
        image_encoder: CLIPVisionModelWithProjection = None,  # 影像編碼器,預設為 None
    ):
        # 呼叫父類的建構函式進行初始化
        super().__init__()

        # 檢查 controlnet 是否為列表或元組,如果是則將其封裝為 MultiControlNetModel 物件
        if isinstance(controlnet, (list, tuple)):
            controlnet = MultiControlNetModel(controlnet)

        # 註冊多個模組,包括 VAE、文字編碼器、tokenizer、UNet 等
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        # 計算 VAE 的縮放因子,通常用於影像尺寸調整
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        # 建立 VAE 影像處理器,設定縮放因子並開啟 RGB 轉換
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
        # 建立控制影像處理器,設定縮放因子,開啟 RGB 轉換,但不進行標準化
        self.control_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
        )
        # 根據輸入引數或預設值確定是否新增水印
        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()

        # 如果需要水印,則初始化水印物件
        if add_watermarker:
            self.watermark = StableDiffusionXLWatermarker()
        else:
            # 否則將水印設定為 None
            self.watermark = None

        # 註冊配置,強制空提示使用零值
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        # 註冊配置,標記是否需要美學評分
        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)

    # 從 StableDiffusionXLPipeline 複製的 encode_prompt 方法
    def encode_prompt(
        self,
        # 定義 prompt 字串及其相關引數
        prompt: str,
        prompt_2: Optional[str] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[str] = None,
        negative_prompt_2: Optional[str] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    # 從 StableDiffusionPipeline 複製的 encode_image 方法
    # 定義一個方法來編碼影像,引數包括影像、裝置、每個提示的影像數量和可選的隱藏狀態輸出
        def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
            # 獲取影像編碼器引數的資料型別
            dtype = next(self.image_encoder.parameters()).dtype
    
            # 檢查輸入的影像是否為張量型別
            if not isinstance(image, torch.Tensor):
                # 如果不是,將其轉換為張量,並提取畫素值
                image = self.feature_extractor(image, return_tensors="pt").pixel_values
    
            # 將影像移動到指定裝置並轉換為相應的資料型別
            image = image.to(device=device, dtype=dtype)
            # 檢查是否需要輸出隱藏狀態
            if output_hidden_states:
                # 獲取影像編碼器的隱藏狀態,選擇倒數第二個隱藏層
                image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
                # 將隱藏狀態按每個提示的影像數量重複
                image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
                # 獲取無條件影像編碼的隱藏狀態,使用全零張量作為輸入
                uncond_image_enc_hidden_states = self.image_encoder(
                    torch.zeros_like(image), output_hidden_states=True
                ).hidden_states[-2]
                # 將無條件隱藏狀態按每個提示的影像數量重複
                uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
                    num_images_per_prompt, dim=0
                )
                # 返回影像編碼的隱藏狀態和無條件影像編碼的隱藏狀態
                return image_enc_hidden_states, uncond_image_enc_hidden_states
            else:
                # 獲取影像編碼的嵌入表示
                image_embeds = self.image_encoder(image).image_embeds
                # 將嵌入表示按每個提示的影像數量重複
                image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
                # 建立與影像嵌入同樣形狀的全零張量作為無條件嵌入
                uncond_image_embeds = torch.zeros_like(image_embeds)
    
                # 返回影像嵌入和無條件影像嵌入
                return image_embeds, uncond_image_embeds
    
        # 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds 複製的方法
        def prepare_ip_adapter_image_embeds(
            # 定義方法的引數,包括 IP 介面卡影像、影像嵌入、裝置、每個提示的影像數量和分類器自由引導的標誌
            self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        # 初始化一個空列表,用於儲存影像嵌入
        image_embeds = []
        # 如果啟用了無分類器自由引導,則初始化負影像嵌入列表
        if do_classifier_free_guidance:
            negative_image_embeds = []
        # 如果輸入介面卡影像嵌入為 None
        if ip_adapter_image_embeds is None:
            # 檢查輸入介面卡影像是否為列表型別,如果不是,則轉換為列表
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            # 檢查輸入介面卡影像的長度是否與 IP 介面卡數量相等
            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                # 如果不相等,丟擲值錯誤
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            # 遍歷輸入介面卡影像和相應的影像投影層
            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                # 確定是否輸出隱藏狀態,依據影像投影層的型別
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                # 編碼單個影像,獲取嵌入和負嵌入
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )

                # 將影像嵌入新增到列表中,增加一個維度
                image_embeds.append(single_image_embeds[None, :])
                # 如果啟用了無分類器自由引導,則將負影像嵌入新增到列表中
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            # 如果輸入介面卡影像嵌入已存在
            for single_image_embeds in ip_adapter_image_embeds:
                # 如果啟用了無分類器自由引導,將嵌入分成負嵌入和正嵌入
                if do_classifier_free_guidance:
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    # 新增負影像嵌入到列表中
                    negative_image_embeds.append(single_negative_image_embeds)
                # 新增正影像嵌入到列表中
                image_embeds.append(single_image_embeds)

        # 初始化一個空列表,用於儲存處理後的輸入介面卡影像嵌入
        ip_adapter_image_embeds = []
        # 遍歷影像嵌入,執行重複操作以匹配每個提示的影像數量
        for i, single_image_embeds in enumerate(image_embeds):
            # 將單個影像嵌入沿著維度 0 重複指定次數
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            # 如果啟用了無分類器自由引導,處理負嵌入
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                # 將負嵌入與正嵌入合併
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            # 將嵌入移動到指定的裝置
            single_image_embeds = single_image_embeds.to(device=device)
            # 將處理後的嵌入新增到列表中
            ip_adapter_image_embeds.append(single_image_embeds)

        # 返回處理後的輸入介面卡影像嵌入列表
        return ip_adapter_image_embeds

    # 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 複製的內容
    # 準備額外的引數用於排程器步驟,因為並非所有排程器具有相同的引數簽名
        def prepare_extra_step_kwargs(self, generator, eta):
            # eta (η) 僅在 DDIMScheduler 中使用,其他排程器會忽略此引數
            # eta 對應於 DDIM 論文中的 η: https://arxiv.org/abs/2010.02502
            # 該值應在 [0, 1] 之間
    
            # 檢查排程器的步驟方法是否接受 eta 引數
            accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
            # 建立一個空字典用於存放額外引數
            extra_step_kwargs = {}
            # 如果排程器接受 eta,則將其新增到額外引數中
            if accepts_eta:
                extra_step_kwargs["eta"] = eta
    
            # 檢查排程器的步驟方法是否接受 generator 引數
            accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
            # 如果排程器接受 generator,則將其新增到額外引數中
            if accepts_generator:
                extra_step_kwargs["generator"] = generator
            # 返回包含額外引數的字典
            return extra_step_kwargs
    
        # 檢查輸入引數的有效性
        def check_inputs(
            self,
            prompt,
            prompt_2,
            image,
            strength,
            num_inference_steps,
            callback_steps,
            negative_prompt=None,
            negative_prompt_2=None,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            pooled_prompt_embeds=None,
            negative_pooled_prompt_embeds=None,
            ip_adapter_image=None,
            ip_adapter_image_embeds=None,
            controlnet_conditioning_scale=1.0,
            control_guidance_start=0.0,
            control_guidance_end=1.0,
            callback_on_step_end_tensor_inputs=None,
        # 從 diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image 複製的引數
    # 檢查輸入影像的型別和形狀,確保與提示的批次大小一致
    def check_image(self, image, prompt, prompt_embeds):
        # 判斷輸入是否為 PIL 影像
        image_is_pil = isinstance(image, PIL.Image.Image)
        # 判斷輸入是否為 PyTorch 張量
        image_is_tensor = isinstance(image, torch.Tensor)
        # 判斷輸入是否為 NumPy 陣列
        image_is_np = isinstance(image, np.ndarray)
        # 判斷輸入是否為 PIL 影像列表
        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
        # 判斷輸入是否為 PyTorch 張量列表
        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
        # 判斷輸入是否為 NumPy 陣列列表
        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
    
        # 如果輸入不符合任何型別,丟擲型別錯誤
        if (
            not image_is_pil
            and not image_is_tensor
            and not image_is_np
            and not image_is_pil_list
            and not image_is_tensor_list
            and not image_is_np_list
        ):
            raise TypeError(
                f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
            )
    
        # 如果輸入為 PIL 影像,設定批次大小為 1
        if image_is_pil:
            image_batch_size = 1
        else:
            # 否則,根據輸入的長度確定批次大小
            image_batch_size = len(image)
    
        # 如果提示不為 None 且為字串,設定提示批次大小為 1
        if prompt is not None and isinstance(prompt, str):
            prompt_batch_size = 1
        # 如果提示為列表,根據列表長度設定批次大小
        elif prompt is not None and isinstance(prompt, list):
            prompt_batch_size = len(prompt)
        # 如果提示嵌入不為 None,使用其第一維的大小作為批次大小
        elif prompt_embeds is not None:
            prompt_batch_size = prompt_embeds.shape[0]
    
        # 如果影像批次大小不為 1,且與提示批次大小不一致,丟擲值錯誤
        if image_batch_size != 1 and image_batch_size != prompt_batch_size:
            raise ValueError(
                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
            )
    
        # 從 diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl 匯入的 prepare_image 方法
        def prepare_control_image(
            self,
            image,
            width,
            height,
            batch_size,
            num_images_per_prompt,
            device,
            dtype,
            do_classifier_free_guidance=False,
            guess_mode=False,
        ):
            # 預處理輸入影像並轉換為指定的資料型別
            image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
            # 獲取影像批次大小
            image_batch_size = image.shape[0]
    
            # 如果影像批次大小為 1,重複次數設定為 batch_size
            if image_batch_size == 1:
                repeat_by = batch_size
            else:
                # 如果影像批次大小與提示批次大小相同,設定重複次數為每個提示的影像數量
                repeat_by = num_images_per_prompt
    
            # 重複影像以匹配所需的批次大小
            image = image.repeat_interleave(repeat_by, dim=0)
    
            # 將影像轉移到指定裝置和資料型別
            image = image.to(device=device, dtype=dtype)
    
            # 如果啟用分類器自由引導並且不在猜測模式下,複製影像以增加維度
            if do_classifier_free_guidance and not guess_mode:
                image = torch.cat([image] * 2)
    
            # 返回處理後的影像
            return image
    
        # 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img 匯入的 get_timesteps 方法
    # 獲取時間步的函式,接收推理步驟數、強度和裝置引數
        def get_timesteps(self, num_inference_steps, strength, device):
            # 計算原始時間步,使用 init_timestep,確保不超過推理步驟數
            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    
            # 計算開始時間步,確保不小於零
            t_start = max(num_inference_steps - init_timestep, 0)
            # 從排程器獲取時間步,擷取從 t_start 開始的所有時間步
            timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
            # 如果排程器具有設定開始索引的方法,則呼叫該方法
            if hasattr(self.scheduler, "set_begin_index"):
                self.scheduler.set_begin_index(t_start * self.scheduler.order)
    
            # 返回時間步和剩餘的推理步驟數
            return timesteps, num_inference_steps - t_start
    
        # 從 StableDiffusionXLImg2ImgPipeline 複製的準備潛在變數的函式
        def prepare_latents(
            self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
        # 從 StableDiffusionXLImg2ImgPipeline 複製的獲取附加時間 ID 的函式
        def _get_add_time_ids(
            self,
            original_size,
            crops_coords_top_left,
            target_size,
            aesthetic_score,
            negative_aesthetic_score,
            negative_original_size,
            negative_crops_coords_top_left,
            negative_target_size,
            dtype,
            text_encoder_projection_dim=None,
    ):
        # 檢查配置是否需要美學評分
        if self.config.requires_aesthetics_score:
            # 建立包含原始大小、裁剪座標及美學評分的列表
            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
            # 建立包含負樣本原始大小、裁剪座標及負美學評分的列表
            add_neg_time_ids = list(
                negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
            )
        else:
            # 建立包含原始大小、裁剪座標和目標大小的列表
            add_time_ids = list(original_size + crops_coords_top_left + target_size)
            # 建立包含負樣本原始大小、裁剪座標及負目標大小的列表
            add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)

        # 計算透過新增時間嵌入維度和文字編碼器投影維度得到的透過嵌入維度
        passed_add_embed_dim = (
            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
        )
        # 獲取模型期望的新增嵌入維度
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        # 檢查期望的嵌入維度是否大於傳遞的嵌入維度,並符合特定條件
        if (
            expected_add_embed_dim > passed_add_embed_dim
            and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
        ):
            # 丟擲值錯誤,說明建立的嵌入維度不符合預期
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
            )
        # 檢查期望的嵌入維度是否小於傳遞的嵌入維度,並符合特定條件
        elif (
            expected_add_embed_dim < passed_add_embed_dim
            and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
        ):
            # 丟擲值錯誤,說明建立的嵌入維度不符合預期
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
            )
        # 檢查期望的嵌入維度是否與傳遞的嵌入維度不相等
        elif expected_add_embed_dim != passed_add_embed_dim:
            # 丟擲值錯誤,說明模型配置不正確
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        # 將新增的時間 ID 轉換為張量,並指定資料型別
        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        # 將新增的負時間 ID 轉換為張量,並指定資料型別
        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)

        # 返回新增的時間 ID 和新增的負時間 ID
        return add_time_ids, add_neg_time_ids

    # 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae 複製而來
    # 定義一個方法,用於將 VAE 模型的引數型別提升
    def upcast_vae(self):
        # 獲取當前 VAE 模型的資料型別
        dtype = self.vae.dtype
        # 將 VAE 模型轉換為 float32 資料型別
        self.vae.to(dtype=torch.float32)
        # 檢查 VAE 解碼器中第一個注意力處理器的型別,以確定是否使用了特定版本的處理器
        use_torch_2_0_or_xformers = isinstance(
            self.vae.decoder.mid_block.attentions[0].processor,
            (
                AttnProcessor2_0,
                XFormersAttnProcessor,
            ),
        )
        # 如果使用了 xformers 或 torch_2_0,注意力塊不需要為 float32 型別,從而節省大量記憶體
        if use_torch_2_0_or_xformers:
            # 將後量化卷積層轉換為原始資料型別
            self.vae.post_quant_conv.to(dtype)
            # 將解碼器輸入卷積層轉換為原始資料型別
            self.vae.decoder.conv_in.to(dtype)
            # 將解碼器中間塊轉換為原始資料型別
            self.vae.decoder.mid_block.to(dtype)

    # 定義一個屬性,返回當前的引導縮放比例
    @property
    def guidance_scale(self):
        # 返回內部儲存的引導縮放比例
        return self._guidance_scale

    # 定義一個屬性,返回當前的剪輯跳過值
    @property
    def clip_skip(self):
        # 返回內部儲存的剪輯跳過值
        return self._clip_skip

    # 定義一個屬性,用於判斷是否進行無分類器引導,依據是引導縮放比例是否大於 1
    # 此屬性的定義參考了 Imagen 論文中的方程 (2)
    # 當 `guidance_scale = 1` 時,相當於不進行無分類器引導
    @property
    def do_classifier_free_guidance(self):
        # 如果引導縮放比例大於 1,返回 True,否則返回 False
        return self._guidance_scale > 1

    # 定義一個屬性,返回當前的交叉注意力引數
    @property
    def cross_attention_kwargs(self):
        # 返回內部儲存的交叉注意力引數
        return self._cross_attention_kwargs

    # 定義一個屬性,返回當前的時間步數
    @property
    def num_timesteps(self):
        # 返回內部儲存的時間步數
        return self._num_timesteps

    # 裝飾器,表示在執行下面的方法時不計算梯度
    @torch.no_grad()
    # 裝飾器,用於替換示例文件字串
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    # 定義一個可呼叫的類方法,接受多個引數用於處理影像生成
    def __call__(
        # 主提示字串或字串列表,預設為 None
        self,
        prompt: Union[str, List[str]] = None,
        # 第二個提示字串或字串列表,預設為 None
        prompt_2: Optional[Union[str, List[str]]] = None,
        # 輸入影像,用於影像生成的基礎,預設為 None
        image: PipelineImageInput = None,
        # 控制影像,用於影響生成的影像,預設為 None
        control_image: PipelineImageInput = None,
        # 輸出影像的高度,預設為 None
        height: Optional[int] = None,
        # 輸出影像的寬度,預設為 None
        width: Optional[int] = None,
        # 影像生成的強度,預設為 0.8
        strength: float = 0.8,
        # 進行推理的步數,預設為 50
        num_inference_steps: int = 50,
        # 引導尺度,控制影像生成的引導程度,預設為 5.0
        guidance_scale: float = 5.0,
        # 負面提示字串或字串列表,預設為 None
        negative_prompt: Optional[Union[str, List[str]]] = None,
        # 第二個負面提示字串或字串列表,預設為 None
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        # 每個提示生成的影像數量,預設為 1
        num_images_per_prompt: Optional[int] = 1,
        # 取樣的 eta 值,預設為 0.0
        eta: float = 0.0,
        # 隨機數生成器,可選,預設為 None
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        # 潛在變數,預設為 None
        latents: Optional[torch.Tensor] = None,
        # 提示的嵌入向量,預設為 None
        prompt_embeds: Optional[torch.Tensor] = None,
        # 負面提示的嵌入向量,預設為 None
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        # 聚合的提示嵌入向量,預設為 None
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        # 負面聚合提示嵌入向量,預設為 None
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        # 輸入介面卡影像,預設為 None
        ip_adapter_image: Optional[PipelineImageInput] = None,
        # 輸入介面卡影像的嵌入向量,預設為 None
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        # 輸出型別,預設為 "pil"
        output_type: Optional[str] = "pil",
        # 是否返回字典,預設為 True
        return_dict: bool = True,
        # 交叉注意力引數,預設為 None
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        # 控制網路的條件縮放,預設為 0.8
        controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
        # 猜測模式,預設為 False
        guess_mode: bool = False,
        # 控制引導的開始位置,預設為 0.0
        control_guidance_start: Union[float, List[float]] = 0.0,
        # 控制引導的結束位置,預設為 1.0
        control_guidance_end: Union[float, List[float]] = 1.0,
        # 原始影像的尺寸,預設為 None
        original_size: Tuple[int, int] = None,
        # 裁剪座標的左上角,預設為 (0, 0)
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        # 目標尺寸,預設為 None
        target_size: Tuple[int, int] = None,
        # 負面原始影像的尺寸,預設為 None
        negative_original_size: Optional[Tuple[int, int]] = None,
        # 負面裁剪座標的左上角,預設為 (0, 0)
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        # 負目標尺寸,預設為 None
        negative_target_size: Optional[Tuple[int, int]] = None,
        # 審美分數,預設為 6.0
        aesthetic_score: float = 6.0,
        # 負面審美分數,預設為 2.5
        negative_aesthetic_score: float = 2.5,
        # 跳過的剪輯層數,預設為 None
        clip_skip: Optional[int] = None,
        # 步驟結束時的回撥函式,可選,預設為 None
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        # 結束步驟時的張量輸入回撥,預設為 ["latents"]
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        # 其他額外引數,預設為空
        **kwargs,

.\diffusers\pipelines\controlnet\pipeline_flax_controlnet.py

# 版權所有 2024 HuggingFace 團隊。保留所有權利。
#
# 根據 Apache 許可證,版本 2.0(“許可證”)授權;
# 除非遵守許可證,否則不得使用此檔案。
# 可以在以下網址獲取許可證副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律要求或書面協議另有規定,
# 否則根據許可證分發的軟體是“按原樣”提供的,
# 不提供任何形式的擔保或條件,無論是明示或暗示。
# 有關許可證下的特定語言的許可權和限制,請參見許可證。

import warnings  # 匯入警告模組,用於處理警告資訊
from functools import partial  # 從 functools 匯入 partial,用於部分函式應用
from typing import Dict, List, Optional, Union  # 匯入型別提示,方便函式引數和返回值的型別註釋

import jax  # 匯入 JAX,用於高效能數值計算
import jax.numpy as jnp  # 匯入 JAX 的 NumPy 介面,提供陣列操作功能
import numpy as np  # 匯入 NumPy,提供數值計算功能
from flax.core.frozen_dict import FrozenDict  # 從 flax 匯入 FrozenDict,用於不可變字典
from flax.jax_utils import unreplicate  # 從 flax 匯入 unreplicate,用於在 JAX 中處理裝置資料
from flax.training.common_utils import shard  # 從 flax 匯入 shard,用於資料並行
from PIL import Image  # 從 PIL 匯入 Image,用於影像處理
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel  # 匯入 CLIP 相關模組,處理影像和文字

from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel  # 匯入模型定義
from ...schedulers import (  # 匯入排程器,用於訓練過程中的控制
    FlaxDDIMScheduler,
    FlaxDPMSolverMultistepScheduler,
    FlaxLMSDiscreteScheduler,
    FlaxPNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring  # 匯入工具函式和常量
from ..pipeline_flax_utils import FlaxDiffusionPipeline  # 匯入擴散管道
from ..stable_diffusion import FlaxStableDiffusionPipelineOutput  # 匯入穩定擴散管道輸出
from ..stable_diffusion.safety_checker_flax import FlaxStableDiffusionSafetyChecker  # 匯入安全檢查器

logger = logging.get_logger(__name__)  # 獲取當前模組的日誌記錄器,方便除錯和資訊輸出

# 設定為 True 以使用 Python 迴圈而不是 jax.fori_loop,以便於除錯
DEBUG = False  # 除錯模式標誌,預設為關閉狀態

EXAMPLE_DOC_STRING = """  # 示例文件字串,可能用於文件生成或示例展示


```  # 示例結束標誌
    Examples:
        ```py
        >>> import jax  # 匯入 JAX 庫,用於高效能數值計算
        >>> import numpy as np  # 匯入 NumPy 庫,支援陣列操作
        >>> import jax.numpy as jnp  # 匯入 JAX 的 NumPy,支援自動微分和GPU加速
        >>> from flax.jax_utils import replicate  # 從 Flax 匯入 replicate 函式,用於引數複製
        >>> from flax.training.common_utils import shard  # 從 Flax 匯入 shard 函式,用於資料分片
        >>> from diffusers.utils import load_image, make_image_grid  # 從 diffusers 匯入影像載入和網格生成工具
        >>> from PIL import Image  # 匯入 PIL 庫,用於影像處理
        >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel  # 匯入用於穩定擴散模型和控制網的類

        >>> def create_key(seed=0):  # 定義函式建立隨機數生成器的金鑰
        ...     return jax.random.PRNGKey(seed)  # 返回一個以 seed 為種子的 PRNG 金鑰

        >>> rng = create_key(0)  # 建立隨機數生成器的金鑰,種子為 0

        >>> # get canny image  # 獲取 Canny 邊緣檢測影像
        >>> canny_image = load_image(  # 使用 load_image 函式載入影像
        ...     "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"  # 指定影像的 URL
        ... )

        >>> prompts = "best quality, extremely detailed"  # 定義用於生成影像的正向提示
        >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality"  # 定義生成影像時要避免的負向提示

        >>> # load control net and stable diffusion v1-5  # 載入控制網路和穩定擴散模型 v1-5
        >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(  # 從預訓練模型載入控制網路及其引數
        ...     "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32  # 指定模型名稱、來源及資料型別
        ... )
        >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(  # 從預訓練模型載入穩定擴散管道及其引數
        ...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32  # 指定模型名稱、控制網、版本和資料型別
        ... )
        >>> params["controlnet"] = controlnet_params  # 將控制網引數存入管道引數中

        >>> num_samples = jax.device_count()  # 獲取當前裝置的數量,設定樣本數量
        >>> rng = jax.random.split(rng, jax.device_count())  # 將隨機數生成器的金鑰根據裝置數量進行分割

        >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)  # 準備正向提示的輸入,針對每個樣本複製
        >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)  # 準備負向提示的輸入,針對每個樣本複製
        >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)  # 準備處理後的影像輸入,針對每個樣本複製

        >>> p_params = replicate(params)  # 複製引數以便在多個裝置上使用
        >>> prompt_ids = shard(prompt_ids)  # 將正向提示的輸入資料進行分片
        >>> negative_prompt_ids = shard(negative_prompt_ids)  # 將負向提示的輸入資料進行分片
        >>> processed_image = shard(processed_image)  # 將處理後的影像輸入資料進行分片

        >>> output = pipe(  # 呼叫管道生成輸出
        ...     prompt_ids=prompt_ids,  # 傳入正向提示 ID
        ...     image=processed_image,  # 傳入處理後的影像
        ...     params=p_params,  # 傳入複製的引數
        ...     prng_seed=rng,  # 傳入隨機數生成器的金鑰
        ...     num_inference_steps=50,  # 設定推理的步驟數
        ...     neg_prompt_ids=negative_prompt_ids,  # 傳入負向提示 ID
        ...     jit=True,  # 啟用 JIT 編譯
        ... ).images  # 獲取生成的影像

        >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))  # 將輸出影像轉換為 PIL 格式
        >>> output_images = make_image_grid(output_images, num_samples // 4, 4)  # 將影像生成網格格式,指定每行顯示的影像數量
        >>> output_images.save("generated_image.png")  # 儲存生成的影像為 PNG 檔案
        ``` 
# 定義一個類,基於 Flax 實現 Stable Diffusion 的控制網文字到影像生成管道
class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
    r"""
    基於 Flax 的管道,用於使用 Stable Diffusion 和 ControlNet 指導進行文字到影像生成。

    此模型繼承自 [`FlaxDiffusionPipeline`]。有關所有管道實現的通用方法(下載、儲存、在特定裝置上執行等),請檢視超類文件。

    引數:
        vae ([`FlaxAutoencoderKL`]):
            用於將影像編碼和解碼為潛在表示的變分自編碼器(VAE)模型。
        text_encoder ([`~transformers.FlaxCLIPTextModel`]):
            凍結的文字編碼器([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14))。
        tokenizer ([`~transformers.CLIPTokenizer`]):
            用於對文字進行分詞的 `CLIPTokenizer`。
        unet ([`FlaxUNet2DConditionModel`]):
            一個 `FlaxUNet2DConditionModel`,用於去噪編碼後的影像潛在表示。
        controlnet ([`FlaxControlNetModel`]):
            在去噪過程中為 `unet` 提供額外的條件資訊。
        scheduler ([`SchedulerMixin`]):
            用於與 `unet` 結合使用的排程器,以去噪編碼的影像潛在表示。可以是
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`] 或
            [`FlaxDPMSolverMultistepScheduler`] 中的一個。
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            分類模組,評估生成的影像是否可能被視為冒犯或有害。
            有關模型潛在危害的更多細節,請參閱 [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5)。
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            一個 `CLIPImageProcessor`,用於提取生成影像的特徵;用於 `safety_checker` 的輸入。
    """

    # 初始化方法,定義所需引數及其型別
    def __init__(
        # 變分自編碼器(VAE)模型,用於影像編碼和解碼
        vae: FlaxAutoencoderKL,
        # 凍結的文字編碼器模型
        text_encoder: FlaxCLIPTextModel,
        # 文字分詞器
        tokenizer: CLIPTokenizer,
        # 去噪模型
        unet: FlaxUNet2DConditionModel,
        # 控制網模型
        controlnet: FlaxControlNetModel,
        # 影像去噪的排程器
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        # 安全檢查模組
        safety_checker: FlaxStableDiffusionSafetyChecker,
        # 特徵提取器
        feature_extractor: CLIPImageProcessor,
        # 資料型別,預設為 32 位浮點數
        dtype: jnp.dtype = jnp.float32,
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        # 設定資料型別屬性
        self.dtype = dtype

        # 檢查安全檢查器是否為 None
        if safety_checker is None:
            # 記錄警告,告知使用者已禁用安全檢查器
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # 註冊各個模組,方便後續使用
        self.register_modules(
            vae=vae,  # 變分自編碼器
            text_encoder=text_encoder,  # 文字編碼器
            tokenizer=tokenizer,  # 分詞器
            unet=unet,  # UNet 模型
            controlnet=controlnet,  # 控制網路
            scheduler=scheduler,  # 排程器
            safety_checker=safety_checker,  # 安全檢查器
            feature_extractor=feature_extractor,  # 特徵提取器
        )
        # 計算 VAE 的縮放因子
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def prepare_text_inputs(self, prompt: Union[str, List[str]]):
        # 檢查 prompt 型別是否為字串或列表
        if not isinstance(prompt, (str, list)):
            # 如果型別不符,丟擲值錯誤
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # 使用分詞器處理輸入文字
        text_input = self.tokenizer(
            prompt,  # 輸入的提示文字
            padding="max_length",  # 填充到最大長度
            max_length=self.tokenizer.model_max_length,  # 設定最大長度為分詞器的最大模型長度
            truncation=True,  # 如果超過最大長度,則截斷
            return_tensors="np",  # 返回 NumPy 格式的張量
        )

        # 返回處理後的輸入 ID
        return text_input.input_ids

    def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]):
        # 檢查影像型別是否為 PIL.Image 或列表
        if not isinstance(image, (Image.Image, list)):
            # 如果型別不符,丟擲值錯誤
            raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")

        # 如果輸入是單個影像,將其轉換為列表
        if isinstance(image, Image.Image):
            image = [image]

        # 對所有影像進行預處理,併合併為一個陣列
        processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])

        # 返回處理後的影像陣列
        return processed_images

    def _get_has_nsfw_concepts(self, features, params):
        # 使用安全檢查器檢查是否存在不適當內容概念
        has_nsfw_concepts = self.safety_checker(features, params)
        # 返回檢查結果
        return has_nsfw_concepts
    # 定義一個安全檢查的私有方法,接收影像、模型引數和是否使用 JIT 編譯的標誌
    def _run_safety_checker(self, images, safety_model_params, jit=False):
        # 當 jit 為 True 時,safety_model_params 應該已經被複制
        # 將輸入的影像陣列轉換為 PIL 影像格式
        pil_images = [Image.fromarray(image) for image in images]
        # 使用特徵提取器處理 PIL 影像,返回其畫素值
        features = self.feature_extractor(pil_images, return_tensors="np").pixel_values

        # 如果啟用 JIT 編譯
        if jit:
            # 對特徵進行分片處理
            features = shard(features)
            # 檢查特徵中是否存在 NSFW 概念
            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
            # 取消特徵的分片
            has_nsfw_concepts = unshard(has_nsfw_concepts)
            # 取消模型引數的複製
            safety_model_params = unreplicate(safety_model_params)
        else:
            # 否則,直接獲取 NSFW 概念的檢查結果
            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)

        # 初始化一個標誌,指示影像是否已經被複制
        images_was_copied = False
        # 遍歷每個 NSFW 概念的檢查結果
        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            # 如果檢測到 NSFW 概念
            if has_nsfw_concept:
                # 如果還沒有複製影像
                if not images_was_copied:
                    # 標記為已複製,並進行影像複製
                    images_was_copied = True
                    images = images.copy()

                # 將對應的影像替換為全黑影像
                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image

            # 如果存在任何 NSFW 概念
            if any(has_nsfw_concepts):
                # 發出警告,提示可能檢測到不適宜內容
                warnings.warn(
                    "Potential NSFW content was detected in one or more images. A black image will be returned"
                    " instead. Try again with a different prompt and/or seed."
                )

        # 返回處理後的影像和 NSFW 概念的檢查結果
        return images, has_nsfw_concepts

    # 定義一個生成影像的私有方法,接收多個引數以控制生成過程
    def _generate(
        self,
        prompt_ids: jnp.ndarray,  # 輸入的提示 ID 陣列
        image: jnp.ndarray,  # 輸入的影像資料
        params: Union[Dict, FrozenDict],  # 模型引數,可能是字典或不可變字典
        prng_seed: jax.Array,  # 隨機種子,用於隨機數生成
        num_inference_steps: int,  # 推理步驟的數量
        guidance_scale: float,  # 指導比例,用於控制生成質量
        latents: Optional[jnp.ndarray] = None,  # 潛在變數,預設值為 None
        neg_prompt_ids: Optional[jnp.ndarray] = None,  # 負提示 ID,預設值為 None
        controlnet_conditioning_scale: float = 1.0,  # 控制網路的條件縮放比例
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    # 定義可呼叫的方法,接收多個引數以控制生成過程
    def __call__(
        self,
        prompt_ids: jnp.ndarray,  # 輸入的提示 ID 陣列
        image: jnp.ndarray,  # 輸入的影像資料
        params: Union[Dict, FrozenDict],  # 模型引數,可能是字典或不可變字典
        prng_seed: jax.Array,  # 隨機種子,用於隨機數生成
        num_inference_steps: int = 50,  # 預設推理步驟的數量為 50
        guidance_scale: Union[float, jnp.ndarray] = 7.5,  # 預設指導比例為 7.5
        latents: jnp.ndarray = None,  # 潛在變數,預設值為 None
        neg_prompt_ids: jnp.ndarray = None,  # 負提示 ID,預設值為 None
        controlnet_conditioning_scale: Union[float, jnp.ndarray] = 1.0,  # 預設控制網路的條件縮放比例為 1.0
        return_dict: bool = True,  # 預設返回字典格式
        jit: bool = False,  # 預設不啟用 JIT 編譯
# 靜態引數為 pipe 和 num_inference_steps,任何更改都會觸發重新編譯。
# 非靜態引數是(分片)輸入張量,這些張量在它們的第一維上被對映(因此為 `0`)。
@partial(
    jax.pmap,  # 使用 JAX 的 pmap 並行對映功能
    in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0),  # 指定輸入張量的軸
    static_broadcasted_argnums=(0, 5),  # 指定靜態廣播引數的索引
)
def _p_generate(  # 定義生成函式
    pipe,  # 生成管道物件
    prompt_ids,  # 提示 ID
    image,  # 輸入影像
    params,  # 生成引數
    prng_seed,  # 隨機數生成種子
    num_inference_steps,  # 推理步驟數
    guidance_scale,  # 指導尺度
    latents,  # 潛在變數
    neg_prompt_ids,  # 負提示 ID
    controlnet_conditioning_scale,  # 控制網條件尺度
):
    return pipe._generate(  # 呼叫生成管道的生成方法
        prompt_ids,  # 提示 ID
        image,  # 輸入影像
        params,  # 生成引數
        prng_seed,  # 隨機數生成種子
        num_inference_steps,  # 推理步驟數
        guidance_scale,  # 指導尺度
        latents,  # 潛在變數
        neg_prompt_ids,  # 負提示 ID
        controlnet_conditioning_scale,  # 控制網條件尺度
    )


@partial(jax.pmap, static_broadcasted_argnums=(0,))  # 使用 JAX 的 pmap,並指定靜態廣播引數
def _p_get_has_nsfw_concepts(pipe, features, params):  # 定義檢查是否有 NSFW 概念的函式
    return pipe._get_has_nsfw_concepts(features, params)  # 呼叫管道的相關方法


def unshard(x: jnp.ndarray):  # 定義反分片函式,接受一個張量
    # einops.rearrange(x, 'd b ... -> (d b) ...')  # 註釋掉的排列操作
    num_devices, batch_size = x.shape[:2]  # 獲取裝置數量和批次大小
    rest = x.shape[2:]  # 獲取其餘維度
    return x.reshape(num_devices * batch_size, *rest)  # 重新調整形狀以合併裝置和批次維度


def preprocess(image, dtype):  # 定義影像預處理函式
    image = image.convert("RGB")  # 將影像轉換為 RGB 模式
    w, h = image.size  # 獲取影像的寬和高
    w, h = (x - x % 64 for x in (w, h))  # 將寬高調整為64的整數倍
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])  # 調整影像大小,使用 Lanczos 插值法
    image = jnp.array(image).astype(dtype) / 255.0  # 轉換為 NumPy 陣列並歸一化到 [0, 1]
    image = image[None].transpose(0, 3, 1, 2)  # 新增新維度並調整通道順序
    return image  # 返回處理後的影像

.\diffusers\pipelines\controlnet\__init__.py

# 匯入型別檢查工具
from typing import TYPE_CHECKING

# 從 utils 模組匯入必要的工具和常量
from ...utils import (
    DIFFUSERS_SLOW_IMPORT,  # 匯入慢匯入標誌
    OptionalDependencyNotAvailable,  # 匯入可選依賴不可用異常
    _LazyModule,  # 匯入延遲模組工具
    get_objects_from_module,  # 匯入從模組獲取物件的函式
    is_flax_available,  # 匯入檢查 Flax 可用性的函式
    is_torch_available,  # 匯入檢查 Torch 可用性的函式
    is_transformers_available,  # 匯入檢查 Transformers 可用性的函式
)

# 初始化一個空字典用於儲存虛擬物件
_dummy_objects = {}
# 初始化一個空字典用於儲存匯入結構
_import_structure = {}

try:
    # 檢查 Transformers 和 Torch 是否可用
    if not (is_transformers_available() and is_torch_available()):
        # 如果不可用,丟擲異常
        raise OptionalDependencyNotAvailable()
# 捕獲可選依賴不可用的異常
except OptionalDependencyNotAvailable:
    # 從 utils 匯入虛擬物件
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    # 更新虛擬物件字典
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    # 如果依賴可用,更新匯入結構字典
    _import_structure["multicontrolnet"] = ["MultiControlNetModel"]
    _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
    _import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
    _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
    _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
    _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
    _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
    _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]

try:
    # 檢查 Transformers 和 Flax 是否可用
    if not (is_transformers_available() and is_flax_available()):
        # 如果不可用,丟擲異常
        raise OptionalDependencyNotAvailable()
# 捕獲可選依賴不可用的異常
except OptionalDependencyNotAvailable:
    # 從 utils 匯入虛擬 Flax 和 Transformers 物件
    from ...utils import dummy_flax_and_transformers_objects  # noqa F403

    # 更新虛擬物件字典
    _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
    # 如果依賴可用,更新匯入結構字典
    _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]

# 如果型別檢查或慢匯入標誌被設定
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        # 檢查 Transformers 和 Torch 是否可用
        if not (is_transformers_available() and is_torch_available()):
            # 如果不可用,丟擲異常
            raise OptionalDependencyNotAvailable()

    # 捕獲可選依賴不可用的異常
    except OptionalDependencyNotAvailable:
        # 匯入虛擬的 Torch 和 Transformers 物件
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        # 如果依賴可用,匯入相應模組
        from .multicontrolnet import MultiControlNetModel
        from .pipeline_controlnet import StableDiffusionControlNetPipeline
        from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
        from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
        from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
        from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
        from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
        from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline

    try:
        # 檢查 Transformers 和 Flax 是否可用
        if not (is_transformers_available() and is_flax_available()):
            # 如果不可用,丟擲異常
            raise OptionalDependencyNotAvailable()
    # 捕獲可選依賴項不可用的異常
        except OptionalDependencyNotAvailable:
            # 從 dummy 模組匯入所有內容,忽略 F403 警告
            from ...utils.dummy_flax_and_transformers_objects import *  # noqa F403
        else:
            # 從 pipeline_flax_controlnet 模組匯入 FlaxStableDiffusionControlNetPipeline
            from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
# 如果之前的條件不滿足,執行以下程式碼
else:
    # 匯入 sys 模組,用於訪問和操作 Python 直譯器的執行時環境
    import sys

    # 將當前模組替換為一個延遲載入的模組物件
    sys.modules[__name__] = _LazyModule(
        __name__,  # 模組名稱
        globals()["__file__"],  # 當前檔案的路徑
        _import_structure,  # 模組的匯入結構
        module_spec=__spec__,  # 模組的規格物件
    )
    # 遍歷假物件字典,將每個物件屬性設定到當前模組
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)  # 設定模組屬性

.\diffusers\pipelines\controlnet_hunyuandit\pipeline_hunyuandit_controlnet.py

# 版權宣告,指明檔案的版權歸 HunyuanDiT 和 HuggingFace 團隊所有
# 本檔案在 Apache 2.0 許可證下授權使用
# 除非遵循許可證,否則不能使用此檔案
# 許可證的副本可以在以下網址獲取
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律規定或書面協議另有約定,否則軟體在"按現狀"基礎上提供,不附帶任何明示或暗示的保證
# 檢視許可證以瞭解特定語言的許可權和限制

# 匯入用於獲取函式資訊的 inspect 模組
import inspect
# 匯入型別提示所需的型別
from typing import Callable, Dict, List, Optional, Tuple, Union

# 匯入 numpy 庫
import numpy as np
# 匯入 PyTorch 庫
import torch
# 從 transformers 庫匯入相關模型和分詞器
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel

# 從 diffusers 庫匯入 StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

# 匯入多管道回撥類
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
# 匯入影像處理類
from ...image_processor import PipelineImageInput, VaeImageProcessor
# 匯入自動編碼器和模型
from ...models import AutoencoderKL, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel
# 匯入 2D 旋轉位置嵌入函式
from ...models.embeddings import get_2d_rotary_pos_embed
# 匯入穩定擴散安全檢查器
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
# 匯入擴散排程器
from ...schedulers import DDPMScheduler
# 匯入實用工具函式
from ...utils import (
    is_torch_xla_available,  # 檢查是否可用 XLA
    logging,  # 匯入日誌記錄模組
    replace_example_docstring,  # 替換示例文件字串的工具
)
# 匯入 PyTorch 相關的隨機張量函式
from ...utils.torch_utils import randn_tensor
# 匯入擴散管道工具類
from ..pipeline_utils import DiffusionPipeline

# 檢查是否可用 XLA,並根據結果匯入相應模組
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm  # 匯入 XLA 核心模型

    XLA_AVAILABLE = True  # 設定 XLA 可用標誌為 True
else:
    XLA_AVAILABLE = False  # 設定 XLA 可用標誌為 False

# 建立一個日誌記錄器例項,記錄當前模組的日誌
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 示例文件字串,用於說明使用方法
EXAMPLE_DOC_STRING = """
# 示例程式碼展示如何使用 HunyuanDiT 進行影像生成
    Examples:
        ```py
        # 從 diffusers 庫匯入所需的模型和管道
        from diffusers import HunyuanDiT2DControlNetModel, HunyuanDiTControlNetPipeline
        # 匯入 PyTorch 庫
        import torch

        # 從預訓練模型載入 HunyuanDiT2DControlNetModel,並指定資料型別為 float16
        controlnet = HunyuanDiT2DControlNetModel.from_pretrained(
            "Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny", torch_dtype=torch.float16
        )

        # 從預訓練模型載入 HunyuanDiTControlNetPipeline,傳入 controlnet 和資料型別
        pipe = HunyuanDiTControlNetPipeline.from_pretrained(
            "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
        )
        # 將管道移動到 CUDA 裝置以加速處理
        pipe.to("cuda")

        # 從 diffusers.utils 匯入載入影像的工具
        from diffusers.utils import load_image

        # 從指定 URL 載入條件影像
        cond_image = load_image(
            "https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny/resolve/main/canny.jpg?download=true"
        )

        ## HunyuanDiT 支援英語和中文提示,因此也可以使用英文提示
        # 定義影像生成的提示內容,描述夜晚的場景
        prompt = "在夜晚的酒店門前,一座古老的中國風格的獅子雕像矗立著,它的眼睛閃爍著光芒,彷彿在守護著這座建築。背景是夜晚的酒店前,構圖方式是特寫,平視,居中構圖。這張照片呈現了真實攝影風格,蘊含了中國雕塑文化,同時展現了神秘氛圍"
        # prompt="At night, an ancient Chinese-style lion statue stands in front of the hotel, its eyes gleaming as if guarding the building. The background is the hotel entrance at night, with a close-up, eye-level, and centered composition. This photo presents a realistic photographic style, embodies Chinese sculpture culture, and reveals a mysterious atmosphere."
        # 使用提示、影像尺寸、條件影像和推理步驟生成影像,並獲取生成的第一張影像
        image = pipe(
            prompt,
            height=1024,
            width=1024,
            control_image=cond_image,
            num_inference_steps=50,
        ).images[0]
        ```  

"""

文件字串,通常用於描述模組或類的功能

"""

定義一個標準寬高比的 NumPy 陣列

STANDARD_RATIO = np.array(
[
1.0, # 1:1
4.0 / 3.0, # 4:3
3.0 / 4.0, # 3:4
16.0 / 9.0, # 16:9
9.0 / 16.0, # 9:16
]
)

定義一個標準尺寸的列表,每個比例對應不同的寬高組合

STANDARD_SHAPE = [
[(1024, 1024), (1280, 1280)], # 1:1
[(1024, 768), (1152, 864), (1280, 960)], # 4:3
[(768, 1024), (864, 1152), (960, 1280)], # 3:4
[(1280, 768)], # 16:9
[(768, 1280)], # 9:16
]

根據標準尺寸計算每個形狀的面積,並將結果儲存在 NumPy 陣列中

STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE]

定義一個支援的尺寸列表,包含不同的寬高組合

SUPPORTED_SHAPE = [
(1024, 1024),
(1280, 1280), # 1:1
(1024, 768),
(1152, 864),
(1280, 960), # 4:3
(768, 1024),
(864, 1152),
(960, 1280), # 3:4
(1280, 768), # 16:9
(768, 1280), # 9:16
]

定義一個函式,用於將目標寬高對映到標準形狀

def map_to_standard_shapes(target_width, target_height):
# 計算目標寬高比
target_ratio = target_width / target_height
# 找到與目標寬高比最接近的標準寬高比的索引
closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
# 找到與目標面積最接近的標準形狀的索引
closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
# 獲取對應的標準寬和高
width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
# 返回標準寬和高
return width, height

定義一個函式,用於計算源影像的縮放裁剪區域以適應目標大小

def get_resize_crop_region_for_grid(src, tgt_size):
# 獲取目標尺寸的高度和寬度
th = tw = tgt_size
# 獲取源影像的高度和寬度
h, w = src

# 計算源影像的寬高比
r = h / w

# 根據寬高比決定縮放方式
# 如果高度大於寬度
if r > 1:
    # 將目標高度作為縮放高度
    resize_height = th
    # 根據高度縮放計算對應的寬度
    resize_width = int(round(th / h * w))
else:
    # 否則,將目標寬度作為縮放寬度
    resize_width = tw
    # 根據寬度縮放計算對應的高度
    resize_height = int(round(tw / w * h))

# 計算裁剪區域的頂部和左邊位置
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))

# 返回裁剪區域的起始和結束座標
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)

從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg 複製的函式

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
根據 guidance_rescalenoise_cfg 進行重新縮放。基於論文Common Diffusion Noise Schedules and
Sample Steps are Flawed
中的發現。見第3.4節
"""
# 計算噪聲預測文字的標準差
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
# 計算噪聲配置的標準差
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# 重新縮放來自引導的結果(修復過度曝光問題)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# 按照引導縮放因子與原始引導結果進行混合,以避免生成“單調”的影像
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
# 返回重新縮放後的噪聲配置
return noise_cfg

定義 HunyuanDiT 控制網路管道類,繼承自 DiffusionPipeline

class HunyuanDiTControlNetPipeline(DiffusionPipeline):
r"""
使用 HunyuanDiT 進行英語/中文到影像生成的管道。

該模型繼承自 [`DiffusionPipeline`]. 請檢視超類文件以獲取庫為所有管道實現的通用方法
(例如下載或儲存,在特定裝置上執行等)。

HunyuanDiT 使用兩個文字編碼器:[mT5](https://huggingface.co/google/mt5-base) 和 [雙語 CLIP](自行微調)
"""
# 引數說明
Args:
    vae ([`AutoencoderKL`]):  # 變分自編碼器模型,用於將影像編碼和解碼為潛在表示,這裡使用'sdxl-vae-fp16-fix'
        Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use
        `sdxl-vae-fp16-fix`.
    text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):  # 凍結的文字編碼器,使用CLIP模型
        Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 
        HunyuanDiT uses a fine-tuned [bilingual CLIP].
    tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):  # 文字標記化器,可以是BertTokenizer或CLIPTokenizer
        A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
    transformer ([`HunyuanDiT2DModel`]):  # HunyuanDiT模型,由騰訊Hunyuan設計
        The HunyuanDiT model designed by Tencent Hunyuan.
    text_encoder_2 (`T5EncoderModel`):  # mT5嵌入模型,特別是't5-v1_1-xxl'
        The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
    tokenizer_2 (`MT5Tokenizer`):  # mT5嵌入模型的標記化器
        The tokenizer for the mT5 embedder.
    scheduler ([`DDPMScheduler`]):  # 排程器,用於與HunyuanDiT結合,去噪編碼的影像潛在表示
        A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
    controlnet ([`HunyuanDiT2DControlNetModel`] or `List[HunyuanDiT2DControlNetModel]` or [`HunyuanDiT2DControlNetModel`]):  # 提供額外的條件資訊以輔助去噪過程
        Provides additional conditioning to the `unet` during the denoising process. If you set multiple
        ControlNets as a list, the outputs from each ControlNet are added together to create one combined
        additional conditioning.
"""

# 定義模型在CPU上解除安裝的順序
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
# 可選元件列表,可能會在初始化中使用
_optional_components = [
    "safety_checker",  # 安全檢查器
    "feature_extractor",  # 特徵提取器
    "text_encoder_2",  # 第二個文字編碼器
    "tokenizer_2",  # 第二個標記化器
    "text_encoder",  # 第一個文字編碼器
    "tokenizer",  # 第一個標記化器
]
# 從CPU解除安裝中排除的元件
_exclude_from_cpu_offload = ["safety_checker"]  # 不允許解除安裝安全檢查器
# 回撥張量輸入的列表,用於傳遞給模型
_callback_tensor_inputs = [
    "latents",  # 潛在變數
    "prompt_embeds",  # 提示的嵌入表示
    "negative_prompt_embeds",  # 負提示的嵌入表示
    "prompt_embeds_2",  # 第二個提示的嵌入表示
    "negative_prompt_embeds_2",  # 第二個負提示的嵌入表示
]

# 初始化方法定義,接收多個引數以構造模型
def __init__(
    self,
    vae: AutoencoderKL,  # 變分自編碼器模型
    text_encoder: BertModel,  # 文字編碼器
    tokenizer: BertTokenizer,  # 文字標記化器
    transformer: HunyuanDiT2DModel,  # HunyuanDiT模型
    scheduler: DDPMScheduler,  # 排程器
    safety_checker: StableDiffusionSafetyChecker,  # 安全檢查器
    feature_extractor: CLIPImageProcessor,  # 特徵提取器
    controlnet: Union[  # 控制網路,可以是單個或多個模型
        HunyuanDiT2DControlNetModel,
        List[HunyuanDiT2DControlNetModel],
        Tuple[HunyuanDiT2DControlNetModel],
        HunyuanDiT2DMultiControlNetModel,
    ],
    text_encoder_2=T5EncoderModel,  # 第二個文字編碼器,預設使用T5模型
    tokenizer_2=MT5Tokenizer,  # 第二個標記化器,預設使用MT5標記化器
    requires_safety_checker: bool = True,  # 是否需要安全檢查器,預設是True
# 初始化父類
):
    super().__init__()

    # 註冊多個模組,提供必要的元件以供使用
    self.register_modules(
        vae=vae,  # 註冊變分自編碼器
        text_encoder=text_encoder,  # 註冊文字編碼器
        tokenizer=tokenizer,  # 註冊分詞器
        tokenizer_2=tokenizer_2,  # 註冊第二個分詞器
        transformer=transformer,  # 註冊變換器
        scheduler=scheduler,  # 註冊排程器
        safety_checker=safety_checker,  # 註冊安全檢查器
        feature_extractor=feature_extractor,  # 註冊特徵提取器
        text_encoder_2=text_encoder_2,  # 註冊第二個文字編碼器
        controlnet=controlnet,  # 註冊控制網路
    )

    # 檢查安全檢查器是否為 None 並且需要使用安全檢查器
    if safety_checker is None and requires_safety_checker:
        # 記錄警告資訊,提醒使用者禁用安全檢查器的後果
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    # 檢查安全檢查器不為 None 且特徵提取器為 None
    if safety_checker is not None and feature_extractor is None:
        # 丟擲錯誤,提醒使用者必須定義特徵提取器
        raise ValueError(
            "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # 計算 VAE 的縮放因子,如果存在 VAE 配置則使用其通道數量,否則預設為 8
    self.vae_scale_factor = (
        2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
    )
    # 初始化影像處理器,傳入 VAE 縮放因子
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    # 註冊到配置中,指明是否需要安全檢查器
    self.register_to_config(requires_safety_checker=requires_safety_checker)
    # 設定預設樣本大小,根據變換器配置或預設為 128
    self.default_sample_size = (
        self.transformer.config.sample_size
        if hasattr(self, "transformer") and self.transformer is not None
        else 128
    )

# 從其他模組複製的方法,用於編碼提示
def encode_prompt(
    self,
    prompt: str,  # 輸入的提示文字
    device: torch.device = None,  # 裝置引數,指定在哪個裝置上處理
    dtype: torch.dtype = None,  # 資料型別引數,指定張量的資料型別
    num_images_per_prompt: int = 1,  # 每個提示生成的影像數量
    do_classifier_free_guidance: bool = True,  # 是否執行無分類器的引導
    negative_prompt: Optional[str] = None,  # 可選的負面提示文字
    prompt_embeds: Optional[torch.Tensor] = None,  # 可選的提示嵌入張量
    negative_prompt_embeds: Optional[torch.Tensor] = None,  # 可選的負面提示嵌入張量
    prompt_attention_mask: Optional[torch.Tensor] = None,  # 可選的提示注意力掩碼
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,  # 可選的負面提示注意力掩碼
    max_sequence_length: Optional[int] = None,  # 可選的最大序列長度
    text_encoder_index: int = 0,  # 文字編碼器索引,預設值為 0
# 從其他模組複製的方法,用於執行安全檢查器
# 定義執行安全檢查器的方法,接收影像、裝置和資料型別作為引數
def run_safety_checker(self, image, device, dtype):
    # 如果安全檢查器未定義,設定無敏感內容標誌為 None
    if self.safety_checker is None:
        has_nsfw_concept = None
    else:
        # 如果輸入影像是張量,則後處理為 PIL 格式
        if torch.is_tensor(image):
            feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
        else:
            # 如果輸入影像不是張量,則將其轉換為 PIL 格式
            feature_extractor_input = self.image_processor.numpy_to_pil(image)
        # 使用特徵提取器處理輸入影像並將其轉換為指定裝置上的張量
        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
        # 呼叫安全檢查器,返回處理後的影像和無敏感內容標誌
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
        )
    # 返回處理後的影像和無敏感內容標誌
    return image, has_nsfw_concept

    # 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 複製
    def prepare_extra_step_kwargs(self, generator, eta):
        # 為排程器步驟準備額外引數,因為不是所有排程器的引數簽名相同
        # eta(η)僅在 DDIMScheduler 中使用,其他排程器將忽略
        # eta 在 DDIM 論文中的對應關係:https://arxiv.org/abs/2010.02502
        # 其值應在 [0, 1] 之間

        # 檢查排程器的步驟是否接受 eta 引數
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        # 如果接受 eta,則將其新增到額外引數字典中
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # 檢查排程器的步驟是否接受 generator 引數
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 如果接受 generator,則將其新增到額外引數字典中
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        # 返回額外引數字典
        return extra_step_kwargs

    # 從 diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.check_inputs 複製
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_attention_mask=None,
        negative_prompt_attention_mask=None,
        prompt_embeds_2=None,
        negative_prompt_embeds_2=None,
        prompt_attention_mask_2=None,
        negative_prompt_attention_mask_2=None,
        callback_on_step_end_tensor_inputs=None,
    # 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 複製
# 準備潛在變數的函式,接收多個引數以配置潛在變數的形狀和生成方式
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        # 根據批大小、通道數、高度和寬度計算潛在變數的形狀
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        # 檢查生成器列表的長度是否與批大小匹配
        if isinstance(generator, list) and len(generator) != batch_size:
            # 如果不匹配,丟擲一個值錯誤
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # 如果潛在變數為 None,則生成隨機潛在變數
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            # 如果已提供潛在變數,將其移動到指定裝置
            latents = latents.to(device)

        # 按排程器要求的標準差縮放初始噪聲
        latents = latents * self.scheduler.init_noise_sigma
        # 返回處理後的潛在變數
        return latents

    # 準備影像的函式,從外部呼叫
    def prepare_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        # 檢查影像是否為張量,如果是則不處理
        if isinstance(image, torch.Tensor):
            pass
        else:
            # 否則對影像進行預處理,調整為指定的高度和寬度
            image = self.image_processor.preprocess(image, height=height, width=width)

        # 獲取影像的批大小
        image_batch_size = image.shape[0]

        # 如果影像批大小為1,則重複次數為批大小
        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # 否則影像批大小與提示批大小相同
            repeat_by = num_images_per_prompt

        # 沿著維度0重複影像
        image = image.repeat_interleave(repeat_by, dim=0)

        # 將影像移動到指定裝置,並轉換為指定資料型別
        image = image.to(device=device, dtype=dtype)

        # 如果啟用了無分類器自由引導,並且未啟用猜測模式,則將影像複製兩次
        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)

        # 返回處理後的影像
        return image

    # 獲取指導比例的屬性
    @property
    def guidance_scale(self):
        # 返回當前的指導比例
        return self._guidance_scale

    # 獲取指導重標定的屬性
    @property
    def guidance_rescale(self):
        # 返回當前的指導重標定值
        return self._guidance_rescale

    # 此屬性定義了類似於論文中指導權重的定義
    @property
    def do_classifier_free_guidance(self):
        # 如果指導比例大於1,則啟用無分類器自由引導
        return self._guidance_scale > 1

    # 獲取時間步數的屬性
    @property
    def num_timesteps(self):
        # 返回當前的時間步數
        return self._num_timesteps

    # 獲取中斷狀態的屬性
    @property
    def interrupt(self):
        # 返回當前中斷狀態
        return self._interrupt

    # 在不計算梯度的情況下執行,替換示例文件字串
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
# 定義一個可呼叫的類方法
    def __call__(
        # 提示內容,可以是字串或字串列表
        self,
        prompt: Union[str, List[str]] = None,
        # 輸出影像的高度
        height: Optional[int] = None,
        # 輸出影像的寬度
        width: Optional[int] = None,
        # 推理步驟的數量,預設為 50
        num_inference_steps: Optional[int] = 50,
        # 引導比例,預設為 5.0
        guidance_scale: Optional[float] = 5.0,
        # 控制影像輸入,預設為 None
        control_image: PipelineImageInput = None,
        # 控制網條件比例,可以是單一值或值列表,預設為 1.0
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        # 負提示內容,可以是字串或字串列表,預設為 None
        negative_prompt: Optional[Union[str, List[str]]] = None,
        # 每個提示生成的影像數量,預設為 1
        num_images_per_prompt: Optional[int] = 1,
        # 用於生成的隨機性,預設為 0.0
        eta: Optional[float] = 0.0,
        # 隨機數生成器,可以是單個或列表,預設為 None
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        # 潛在變數,預設為 None
        latents: Optional[torch.Tensor] = None,
        # 提示的嵌入,預設為 None
        prompt_embeds: Optional[torch.Tensor] = None,
        # 第二組提示的嵌入,預設為 None
        prompt_embeds_2: Optional[torch.Tensor] = None,
        # 負提示的嵌入,預設為 None
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        # 第二組負提示的嵌入,預設為 None
        negative_prompt_embeds_2: Optional[torch.Tensor] = None,
        # 提示的注意力掩碼,預設為 None
        prompt_attention_mask: Optional[torch.Tensor] = None,
        # 第二組提示的注意力掩碼,預設為 None
        prompt_attention_mask_2: Optional[torch.Tensor] = None,
        # 負提示的注意力掩碼,預設為 None
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        # 第二組負提示的注意力掩碼,預設為 None
        negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
        # 輸出型別,預設為 "pil"
        output_type: Optional[str] = "pil",
        # 是否返回字典格式,預設為 True
        return_dict: bool = True,
        # 在步驟結束時的回撥函式
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        # 回撥時的張量輸入列表,預設為 ["latents"]
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        # 引導重標定,預設為 0.0
        guidance_rescale: float = 0.0,
        # 原始影像大小,預設為 (1024, 1024)
        original_size: Optional[Tuple[int, int]] = (1024, 1024),
        # 目標影像大小,預設為 None
        target_size: Optional[Tuple[int, int]] = None,
        # 裁剪座標,預設為 (0, 0)
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        # 是否使用解析度分箱,預設為 True
        use_resolution_binning: bool = True,

# `.\diffusers\pipelines\controlnet_hunyuandit\__init__.py`

```py
# 從 typing 模組匯入 TYPE_CHECKING,用於靜態型別檢查
from typing import TYPE_CHECKING

# 從父模組的 utils 匯入多個工具函式和常量
from ...utils import (
    DIFFUSERS_SLOW_IMPORT,  # 匯入慢匯入的標誌
    OptionalDependencyNotAvailable,  # 匯入可選依賴不可用的異常
    _LazyModule,  # 匯入延遲載入模組的類
    get_objects_from_module,  # 匯入從模組獲取物件的函式
    is_torch_available,  # 匯入檢查 PyTorch 是否可用的函式
    is_transformers_available,  # 匯入檢查 Transformers 是否可用的函式
)

# 建立一個空字典,用於儲存假物件
_dummy_objects = {}
# 建立一個空字典,用於儲存匯入結構
_import_structure = {}

# 嘗試檢查是否可用的依賴
try:
    # 如果 Transformers 和 Torch 不可用,丟擲異常
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
# 捕獲可選依賴不可用的異常
except OptionalDependencyNotAvailable:
    # 從 utils 匯入假物件(dummy objects),避免直接依賴
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    # 更新 _dummy_objects 字典,包含假物件
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
# 如果依賴可用,更新匯入結構
else:
    # 將 HunyuanDiTControlNetPipeline 加入匯入結構
    _import_structure["pipeline_hunyuandit_controlnet"] = ["HunyuanDiTControlNetPipeline"]

# 檢查型別是否在檢查模式或是否需要慢匯入
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # 嘗試檢查是否可用的依賴
    try:
        # 如果 Transformers 和 Torch 不可用,丟擲異常
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    # 捕獲可選依賴不可用的異常
    except OptionalDependencyNotAvailable:
        # 從 utils 匯入所有假物件,避免直接依賴
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        # 匯入真實的 HunyuanDiTControlNetPipeline 類
        from .pipeline_hunyuandit_controlnet import HunyuanDiTControlNetPipeline

# 如果不在型別檢查或不需要慢匯入
else:
    # 匯入 sys 模組
    import sys

    # 將當前模組替換為延遲載入模組,傳遞必要的資訊
    sys.modules[__name__] = _LazyModule(
        __name__,  # 模組名稱
        globals()["__file__"],  # 模組檔案路徑
        _import_structure,  # 匯入結構
        module_spec=__spec__,  # 模組規格
    )

    # 遍歷假物件並設定到當前模組
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)

相關文章