diffusers 原始碼解析(二十九)
.\diffusers\pipelines\deprecated\stable_diffusion_variants\pipeline_stable_diffusion_model_editing.py
# 版權資訊,宣告版權和許可協議
# Copyright 2024 TIME Authors and The HuggingFace Team. All rights reserved."
# 根據 Apache License 2.0 許可協議進行許可
# 此檔案只能在遵守許可證的情況下使用
# 可透過以下網址獲取許可證副本
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律要求或書面同意,否則軟體按“現狀”分發,不提供任何形式的擔保或條件
# 具體許可條款和限制請參見許可證
# 匯入複製模組,用於物件複製
import copy
# 匯入檢查模組,用於檢查物件的資訊
import inspect
# 匯入型別提示相關的模組
from typing import Any, Callable, Dict, List, Optional, Union
# 匯入 PyTorch 庫
import torch
# 從 transformers 庫中匯入影像處理器、文字模型和標記器
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
# 從相對路徑匯入自定義影像處理器
from ....image_processor import VaeImageProcessor
# 從相對路徑匯入自定義載入器混合類
from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
# 從相對路徑匯入自定義模型類
from ....models import AutoencoderKL, UNet2DConditionModel
# 從相對路徑匯入調整 LoRA 規模的函式
from ....models.lora import adjust_lora_scale_text_encoder
# 從相對路徑匯入排程器類
from ....schedulers import PNDMScheduler
# 從排程器工具匯入排程器混合類
from ....schedulers.scheduling_utils import SchedulerMixin
# 從工具庫中匯入多個功能模組
from ....utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
# 從自定義的 PyTorch 工具庫中匯入隨機張量生成函式
from ....utils.torch_utils import randn_tensor
# 從管道工具匯入擴散管道和穩定擴散混合類
from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin
# 從穩定擴散相關模組匯入管道輸出類
from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
# 從穩定擴散安全檢查器模組匯入安全檢查器類
from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker
# 建立一個日誌記錄器,用於記錄當前模組的資訊
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# 定義一個常量列表,包含不同的影像描述字首
AUGS_CONST = ["A photo of ", "An image of ", "A picture of "]
# 定義一個穩定擴散模型編輯管道類,繼承自多個基類
class StableDiffusionModelEditingPipeline(
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
):
r"""
文字到影像模型編輯的管道。
該模型繼承自 [`DiffusionPipeline`]。請查閱超類文件以獲取所有管道實現的通用方法
(下載、儲存、在特定裝置上執行等)。
該管道還繼承以下載入方法:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] 用於載入文字反轉嵌入
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] 用於載入 LoRA 權重
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] 用於儲存 LoRA 權重
# 文件字串,描述類或方法的引數
Args:
vae ([`AutoencoderKL`]):
# Variational Auto-Encoder (VAE) 模型,用於將影像編碼和解碼為潛在表示
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder ([`~transformers.CLIPTextModel`]):
# 凍結的文字編碼器,使用 CLIP 的大型視覺變換模型
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
tokenizer ([`~transformers.CLIPTokenizer`]):
# 用於文字標記化的 CLIPTokenizer
A `CLIPTokenizer` to tokenize text.
unet ([`UNet2DConditionModel`]):
# 用於去噪編碼影像潛在的 UNet2DConditionModel
A `UNet2DConditionModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
# 排程器,與 unet 一起用於去噪編碼的影像潛在,可以是 DDIMScheduler、LMSDiscreteScheduler 或 PNDMScheduler
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
# 分類模組,用於估計生成的影像是否可能被視為冒犯或有害
Classification module that estimates whether generated images could be considered offensive or harmful.
# 參閱模型卡以獲取有關模型潛在危害的更多詳細資訊
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
# 用於從生成的影像中提取特徵的 CLIPImageProcessor;作為輸入傳遞給 safety_checker
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
with_to_k ([`bool`]):
# 是否在編輯文字到影像模型時編輯鍵投影矩陣與值投影矩陣
Whether to edit the key projection matrices along with the value projection matrices.
with_augs ([`list`]):
# 在編輯文字到影像模型時應用的文字增強,設定為 [] 表示不進行增強
Textual augmentations to apply while editing the text-to-image model. Set to `[]` for no augmentations.
"""
# 定義模型的 CPU 解除安裝順序
model_cpu_offload_seq = "text_encoder->unet->vae"
# 定義可選元件列表
_optional_components = ["safety_checker", "feature_extractor"]
# 定義不參與 CPU 解除安裝的元件
_exclude_from_cpu_offload = ["safety_checker"]
# 初始化方法,設定模型和引數
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
with_to_k: bool = True,
with_augs: list = AUGS_CONST,
# 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt 複製的程式碼
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
**kwargs,
):
# 定義一個棄用訊息,提示使用者 `_encode_prompt()` 已棄用,建議使用 `encode_prompt()`,並說明輸出格式變化
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
# 呼叫 deprecate 函式記錄棄用警告,指定版本和警告資訊,標準警告設定為 False
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
# 呼叫 encode_prompt 方法,將引數傳入以獲取提示嵌入的元組
prompt_embeds_tuple = self.encode_prompt(
prompt=prompt, # 使用者輸入的提示文字
device=device, # 指定執行裝置
num_images_per_prompt=num_images_per_prompt, # 每個提示生成的影像數量
do_classifier_free_guidance=do_classifier_free_guidance, # 是否使用無分類器的引導
negative_prompt=negative_prompt, # 負面提示文字
prompt_embeds=prompt_embeds, # 現有的提示嵌入(如果有的話)
negative_prompt_embeds=negative_prompt_embeds, # 負面提示嵌入(如果有的話)
lora_scale=lora_scale, # Lora 縮放引數
**kwargs, # 額外引數
)
# 連線提示嵌入元組的兩個部分,適配舊版本的相容性
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
# 返回連線後的提示嵌入
return prompt_embeds
# 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt 複製
def encode_prompt(
self,
prompt, # 使用者輸入的提示文字
device, # 指定執行裝置
num_images_per_prompt, # 每個提示生成的影像數量
do_classifier_free_guidance, # 是否使用無分類器的引導
negative_prompt=None, # 負面提示文字,預設為 None
prompt_embeds: Optional[torch.Tensor] = None, # 現有的提示嵌入,預設為 None
negative_prompt_embeds: Optional[torch.Tensor] = None, # 負面提示嵌入,預設為 None
lora_scale: Optional[float] = None, # Lora 縮放引數,預設為 None
clip_skip: Optional[int] = None, # 跳過的剪輯層,預設為 None
# 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker 複製
def run_safety_checker(self, image, device, dtype): # 定義安全檢查函式,接收影像、裝置和資料型別
# 檢查安全檢查器是否存在
if self.safety_checker is None:
has_nsfw_concept = None # 如果沒有安全檢查器,則 NSFW 概念為 None
else:
# 如果影像是張量型別,進行後處理,轉換為 PIL 格式
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
# 如果影像是 numpy 陣列,直接轉換為 PIL 格式
feature_extractor_input = self.image_processor.numpy_to_pil(image)
# 獲取安全檢查器的輸入,轉換為張量並移至指定裝置
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
# 使用安全檢查器檢查影像,返回處理後的影像和 NSFW 概念
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
# 返回處理後的影像和 NSFW 概念
return image, has_nsfw_concept
# 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents 複製
# 解碼潛在向量的方法
def decode_latents(self, latents):
# 警告資訊,提示此方法已棄用,將在 1.0.0 中移除
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
# 呼叫 deprecate 函式記錄棄用資訊
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
# 根據配置的縮放因子調整潛在向量
latents = 1 / self.vae.config.scaling_factor * latents
# 解碼潛在向量並返回影像資料
image = self.vae.decode(latents, return_dict=False)[0]
# 將影像資料歸一化到 [0, 1] 範圍內
image = (image / 2 + 0.5).clamp(0, 1)
# 始終將影像資料轉換為 float32 型別,以確保相容性並降低開銷
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
# 返回處理後的影像資料
return image
# 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 複製
def prepare_extra_step_kwargs(self, generator, eta):
# 為排程器步驟準備額外的關鍵字引數,不同排程器的簽名不同
# eta (η) 僅在 DDIMScheduler 中使用,其他排程器會忽略
# eta 對應於 DDIM 論文中的 η,範圍應在 [0, 1] 之間
# 檢查排程器步驟的引數是否接受 eta
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
# 初始化額外步驟關鍵字引數字典
extra_step_kwargs = {}
# 如果排程器接受 eta,則將其新增到字典中
if accepts_eta:
extra_step_kwargs["eta"] = eta
# 檢查排程器步驟的引數是否接受 generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
# 如果排程器接受 generator,則將其新增到字典中
if accepts_generator:
extra_step_kwargs["generator"] = generator
# 返回額外步驟關鍵字引數字典
return extra_step_kwargs
# 檢查輸入引數的方法
def check_inputs(
self,
prompt,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
# 檢查高度和寬度是否能被8整除
if height % 8 != 0 or width % 8 != 0:
# 丟擲異常,給出高度和寬度的資訊
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# 檢查回撥步數是否為正整數
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
# 丟擲異常,給出回撥步數的資訊
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# 檢查回撥輸入是否在預期的輸入列表中
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
# 丟擲異常,列出不在預期列表中的輸入
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
# 檢查是否同時提供了提示和提示嵌入
if prompt is not None and prompt_embeds is not None:
# 丟擲異常,提示只能提供一個
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
# 檢查是否同時未提供提示和提示嵌入
elif prompt is None and prompt_embeds is None:
# 丟擲異常,提示至少要提供一個
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
# 檢查提示型別是否合法
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
# 丟擲異常,提示型別不符合要求
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
# 檢查是否同時提供了負提示和負提示嵌入
if negative_prompt is not None and negative_prompt_embeds is not None:
# 丟擲異常,提示只能提供一個
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
# 檢查提示嵌入和負提示嵌入的形狀是否一致
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
# 丟擲異常,給出形狀不一致的資訊
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 複製
# 準備潛在變數的函式,接受多個引數以控制形狀和生成方式
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
# 定義潛在變數的形狀,包括批次大小、通道數和調整後的高度寬度
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
# 檢查生成器列表的長度是否與批次大小一致
if isinstance(generator, list) and len(generator) != batch_size:
# 如果不一致,則丟擲值錯誤並提供相關資訊
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
# 如果潛在變數為空,則生成新的隨機潛在變數
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
# 如果已提供潛在變數,則將其轉移到指定裝置
latents = latents.to(device)
# 根據排程器要求的標準差縮放初始噪聲
latents = latents * self.scheduler.init_noise_sigma
# 返回準備好的潛在變數
return latents
# 裝飾器,指示該函式不需要計算梯度
@torch.no_grad()
def edit_model(
self,
source_prompt: str,
destination_prompt: str,
lamb: float = 0.1,
restart_params: bool = True,
# 裝飾器,指示該函式不需要計算梯度
@torch.no_grad()
def __call__(
self,
# 允許輸入字串或字串列表作為提示
prompt: Union[str, List[str]] = None,
# 可選引數,指定生成影像的高度
height: Optional[int] = None,
# 可選引數,指定生成影像的寬度
width: Optional[int] = None,
# 設定推理步驟的預設數量為50
num_inference_steps: int = 50,
# 設定引導比例的預設值為7.5
guidance_scale: float = 7.5,
# 可選引數,允許輸入負面提示
negative_prompt: Optional[Union[str, List[str]]] = None,
# 可選引數,指定每個提示生成的影像數量,預設值為1
num_images_per_prompt: Optional[int] = 1,
# 設定eta的預設值為0.0
eta: float = 0.0,
# 可選引數,允許輸入生成器或生成器列表
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
# 可選引數,允許輸入潛在變數的張量
latents: Optional[torch.Tensor] = None,
# 可選引數,允許輸入提示嵌入的張量
prompt_embeds: Optional[torch.Tensor] = None,
# 可選引數,允許輸入負面提示嵌入的張量
negative_prompt_embeds: Optional[torch.Tensor] = None,
# 可選引數,指定輸出型別,預設為"pil"
output_type: Optional[str] = "pil",
# 可選引數,控制是否返回字典形式的輸出,預設為True
return_dict: bool = True,
# 可選回撥函式,用於處理生成過程中每一步的資訊
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
# 可選引數,指定回撥的步驟數,預設為1
callback_steps: int = 1,
# 可選引數,允許傳入交叉注意力的關鍵字引數
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可選引數,允許指定跳過的clip層數
clip_skip: Optional[int] = None,
.\diffusers\pipelines\deprecated\stable_diffusion_variants\pipeline_stable_diffusion_paradigms.py
# 版權所有 2024 ParaDiGMS 作者和 HuggingFace 團隊。保留所有權利。
#
# 根據 Apache 許可證第 2.0 版(“許可證”)許可;
# 除非遵守許可證,否則您不得使用此檔案。
# 您可以在以下地址獲取許可證副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律或書面協議另有約定,依據許可證分發的軟體
# 是按“原樣”提供的,沒有任何形式的明示或暗示的擔保或條件。
# 有關許可證的特定許可權和限制,請參見許可證。
# 匯入 inspect 模組以進行物件檢查
import inspect
# 從 typing 模組匯入型別提示相關的類
from typing import Any, Callable, Dict, List, Optional, Union
# 匯入 PyTorch 庫以進行深度學習操作
import torch
# 從 transformers 庫匯入 CLIP 影像處理器、文字模型和分詞器
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
# 匯入自定義的影像處理器
from ....image_processor import VaeImageProcessor
# 匯入與載入相關的混合類
from ....loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
# 匯入用於自動編碼器和條件 UNet 的模型
from ....models import AutoencoderKL, UNet2DConditionModel
# 從 lora 模組匯入調整 lora 規模的函式
from ....models.lora import adjust_lora_scale_text_encoder
# 匯入 Karras 擴散排程器
from ....schedulers import KarrasDiffusionSchedulers
# 匯入實用工具模組中的各種功能
from ....utils import (
USE_PEFT_BACKEND, # 用於 PEFT 後端的標誌
deprecate, # 用於標記棄用功能的裝飾器
logging, # 日誌記錄功能
replace_example_docstring, # 替換示例文件字串的功能
scale_lora_layers, # 調整 lora 層規模的功能
unscale_lora_layers, # 反調整 lora 層規模的功能
)
# 從 torch_utils 模組匯入生成隨機張量的功能
from ....utils.torch_utils import randn_tensor
# 匯入擴散管道和穩定擴散混合類
from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin
# 匯入穩定擴散管道輸出類
from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
# 匯入穩定擴散安全檢查器
from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker
# 建立日誌記錄器,用於當前模組的日誌
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# 示例文件字串,展示用法示例
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import DDPMParallelScheduler
>>> from diffusers import StableDiffusionParadigmsPipeline
>>> scheduler = DDPMParallelScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
>>> pipe = StableDiffusionParadigmsPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> ngpu, batch_per_device = torch.cuda.device_count(), 5
>>> pipe.wrapped_unet = torch.nn.DataParallel(pipe.unet, device_ids=[d for d in range(ngpu)])
>>> prompt = "a photo of an astronaut riding a horse on mars"
>>> image = pipe(prompt, parallel=ngpu * batch_per_device, num_inference_steps=1000).images[0]
```py
"""
# 定義 StableDiffusionParadigmsPipeline 類,繼承多個混合類以實現功能
class StableDiffusionParadigmsPipeline(
DiffusionPipeline, # 從擴散管道繼承
StableDiffusionMixin, # 從穩定擴散混合類繼承
TextualInversionLoaderMixin, # 從文字反轉載入混合類繼承
StableDiffusionLoraLoaderMixin, # 從穩定擴散 lora 載入混合類繼承
FromSingleFileMixin, # 從單檔案載入混合類繼承
):
r"""
用於文字到影像生成的管道,使用穩定擴散的並行化版本。
此模型繼承自 [`DiffusionPipeline`]。有關通用方法的文件,請檢視超類文件
# 實現所有管道的功能(下載、儲存、在特定裝置上執行等)。
# 管道還繼承以下載入方法:
# - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] 用於載入文字反轉嵌入
# - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] 用於載入 LoRA 權重
# - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] 用於儲存 LoRA 權重
# - [`~loaders.FromSingleFileMixin.from_single_file`] 用於載入 `.ckpt` 檔案
# 引數說明:
# vae ([`AutoencoderKL`]):
# 變分自編碼器(VAE)模型,用於將影像編碼和解碼為潛在表示。
# text_encoder ([`~transformers.CLIPTextModel`]):
# 凍結的文字編碼器([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14))。
# tokenizer ([`~transformers.CLIPTokenizer`]):
# 一個 `CLIPTokenizer` 用於對文字進行標記化。
# unet ([`UNet2DConditionModel`]):
# 一個 `UNet2DConditionModel` 用於去噪編碼的影像潛在。
# scheduler ([`SchedulerMixin`]):
# 用於與 `unet` 結合使用的排程器,用於去噪編碼的影像潛在。可以是
# [`DDIMScheduler`], [`LMSDiscreteScheduler`] 或 [`PNDMScheduler`] 中的一個。
# safety_checker ([`StableDiffusionSafetyChecker`]):
# 分類模組,估計生成的影像是否可能被認為是冒犯或有害的。
# 請參考 [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) 獲取更多關於模型潛在危害的詳細資訊。
# feature_extractor ([`~transformers.CLIPImageProcessor`]):
# 一個 `CLIPImageProcessor` 用於從生成的影像中提取特徵;作為 `safety_checker` 的輸入。
# 定義模型的 CPU 離線載入順序
model_cpu_offload_seq = "text_encoder->unet->vae"
# 定義可選元件列表
_optional_components = ["safety_checker", "feature_extractor"]
# 定義排除在 CPU 離線載入之外的元件
_exclude_from_cpu_offload = ["safety_checker"]
# 初始化方法,接受多個引數
def __init__(
self,
vae: AutoencoderKL, # 變分自編碼器模型
text_encoder: CLIPTextModel, # 文字編碼器
tokenizer: CLIPTokenizer, # 文字標記器
unet: UNet2DConditionModel, # UNet2D 條件模型
scheduler: KarrasDiffusionSchedulers, # 排程器
safety_checker: StableDiffusionSafetyChecker, # 安全檢查模組
feature_extractor: CLIPImageProcessor, # 特徵提取器
requires_safety_checker: bool = True, # 是否需要安全檢查器
):
# 呼叫父類的初始化方法
super().__init__()
# 檢查是否禁用安全檢查器,並且需要安全檢查器
if safety_checker is None and requires_safety_checker:
# 記錄警告資訊,提醒使用者禁用安全檢查器的風險
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
# 檢查是否提供了安全檢查器但未提供特徵提取器
if safety_checker is not None and feature_extractor is None:
# 丟擲錯誤,提示使用者需要定義特徵提取器
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
# 註冊各個模組到當前例項
self.register_modules(
vae=vae, # 變分自編碼器
text_encoder=text_encoder, # 文字編碼器
tokenizer=tokenizer, # 分詞器
unet=unet, # U-Net 模型
scheduler=scheduler, # 排程器
safety_checker=safety_checker, # 安全檢查器
feature_extractor=feature_extractor, # 特徵提取器
)
# 計算 VAE 的縮放因子
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# 初始化影像處理器,使用 VAE 縮放因子
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# 將是否需要安全檢查器的配置註冊到當前例項
self.register_to_config(requires_safety_checker=requires_safety_checker)
# 用於在多個 GPU 上執行多個去噪步驟時,將 unet 包裝為 torch.nn.DataParallel
self.wrapped_unet = self.unet
# 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt 複製的函式
def _encode_prompt(
self,
prompt, # 輸入的提示文字
device, # 裝置型別(CPU或GPU)
num_images_per_prompt, # 每個提示生成的影像數量
do_classifier_free_guidance, # 是否使用無分類器引導
negative_prompt=None, # 可選的負面提示文字
prompt_embeds: Optional[torch.Tensor] = None, # 可選的提示嵌入
negative_prompt_embeds: Optional[torch.Tensor] = None, # 可選的負面提示嵌入
lora_scale: Optional[float] = None, # 可選的 LoRA 縮放因子
**kwargs, # 其他可選引數
):
# 設定棄用資訊,提醒使用者該方法即將被移除,建議使用新方法
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
# 呼叫 deprecate 函式,傳遞棄用資訊和版本號
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
# 呼叫 encode_prompt 方法,獲取提示的嵌入元組
prompt_embeds_tuple = self.encode_prompt(
prompt=prompt, # 輸入提示文字
device=device, # 計算裝置
num_images_per_prompt=num_images_per_prompt, # 每個提示生成的影像數量
do_classifier_free_guidance=do_classifier_free_guidance, # 是否使用無分類器引導
negative_prompt=negative_prompt, # 負提示文字
prompt_embeds=prompt_embeds, # 提示嵌入
negative_prompt_embeds=negative_prompt_embeds, # 負提示嵌入
lora_scale=lora_scale, # LORA 縮放因子
**kwargs, # 其他額外引數
)
# 將返回的嵌入元組進行拼接,以支援向後相容
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
# 返回拼接後的提示嵌入
return prompt_embeds
# 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline 中複製的 encode_prompt 方法
def encode_prompt(
self,
prompt, # 輸入的提示文字
device, # 計算裝置
num_images_per_prompt, # 每個提示生成的影像數量
do_classifier_free_guidance, # 是否使用無分類器引導
negative_prompt=None, # 負提示文字,預設為 None
prompt_embeds: Optional[torch.Tensor] = None, # 可選的提示嵌入
negative_prompt_embeds: Optional[torch.Tensor] = None, # 可選的負提示嵌入
lora_scale: Optional[float] = None, # 可選的 LORA 縮放因子
clip_skip: Optional[int] = None, # 可選的跳過剪輯引數
# 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker 複製的程式碼
def run_safety_checker(self, image, device, dtype):
# 檢查是否存在安全檢查器
if self.safety_checker is None:
has_nsfw_concept = None # 如果沒有安全檢查器,則設定無 NSFW 概念為 None
else:
# 如果輸入是張量格式,則進行後處理為 PIL 格式
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
# 如果輸入不是張量,轉換為 PIL 格式
feature_extractor_input = self.image_processor.numpy_to_pil(image)
# 使用特徵提取器處理影像,返回張量形式
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
# 呼叫安全檢查器,檢查影像是否包含 NSFW 概念
image, has_nsfw_concept = self.safety_checker(
images=image, # 輸入影像
clip_input=safety_checker_input.pixel_values.to(dtype) # 安全檢查的特徵輸入
)
# 返回處理後的影像及是否存在 NSFW 概念
return image, has_nsfw_concept
# 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 複製的程式碼
# 定義一個方法,用於準備額外的引數以供排程器步驟使用
def prepare_extra_step_kwargs(self, generator, eta):
# 為排程器步驟準備額外的關鍵字引數,因為並非所有排程器都有相同的簽名
# eta (η) 僅在 DDIMScheduler 中使用,對於其他排程器將被忽略
# eta 對應於 DDIM 論文中的 η: https://arxiv.org/abs/2010.02502
# 其值應在 [0, 1] 之間
# 檢查排程器的 step 方法是否接受 eta 引數
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
# 初始化一個字典以儲存額外的步驟引數
extra_step_kwargs = {}
# 如果接受 eta 引數,則將其新增到字典中
if accepts_eta:
extra_step_kwargs["eta"] = eta
# 檢查排程器的 step 方法是否接受 generator 引數
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
# 如果接受 generator 引數,則將其新增到字典中
if accepts_generator:
extra_step_kwargs["generator"] = generator
# 返回準備好的額外步驟引數字典
return extra_step_kwargs
# 定義一個方法,用於檢查輸入引數的有效性
def check_inputs(
self,
prompt, # 輸入的提示文字
height, # 影像的高度
width, # 影像的寬度
callback_steps, # 回撥步驟的頻率
negative_prompt=None, # 可選的負面提示文字
prompt_embeds=None, # 可選的提示嵌入
negative_prompt_embeds=None, # 可選的負面提示嵌入
callback_on_step_end_tensor_inputs=None, # 可選的在步驟結束時的回撥張量輸入
):
# 檢查高度和寬度是否為 8 的倍數
if height % 8 != 0 or width % 8 != 0:
# 丟擲錯誤,如果高度或寬度不符合要求
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# 檢查回撥步數是否為正整數
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
# 丟擲錯誤,如果回撥步數不符合要求
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# 檢查回撥結束時的張量輸入是否在允許的輸入中
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
# 丟擲錯誤,如果不在允許的輸入中
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
# 檢查是否同時提供了提示和提示嵌入
if prompt is not None and prompt_embeds is not None:
# 丟擲錯誤,不能同時提供提示和提示嵌入
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
# 檢查提示和提示嵌入是否均未定義
elif prompt is None and prompt_embeds is None:
# 丟擲錯誤,必須提供一個提示或提示嵌入
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
# 檢查提示的型別
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
# 丟擲錯誤,如果提示不是字串或列表
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
# 檢查是否同時提供了負提示和負提示嵌入
if negative_prompt is not None and negative_prompt_embeds is not None:
# 丟擲錯誤,不能同時提供負提示和負提示嵌入
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
# 檢查提示嵌入和負提示嵌入的形狀是否相同
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
# 丟擲錯誤,如果形狀不一致
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 複製的程式碼
# 準備潛在變數,設定其形狀和屬性
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
# 定義潛在變數的形狀,考慮批次大小、通道數和縮放因子
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
# 檢查生成器列表的長度是否與批次大小匹配
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
# 如果潛在變數未提供,則生成隨機潛在變數
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
# 將提供的潛在變數移動到指定裝置
latents = latents.to(device)
# 根據排程器要求的標準差縮放初始噪聲
latents = latents * self.scheduler.init_noise_sigma
# 返回處理後的潛在變數
return latents
# 計算輸入張量在指定維度上的累積和
def _cumsum(self, input, dim, debug=False):
# 如果除錯模式開啟,則在CPU上執行累積和以確保可重複性
if debug:
# cumsum_cuda_kernel沒有確定性實現,故在CPU上執行
return torch.cumsum(input.cpu().float(), dim=dim).to(input.device)
else:
# 在指定維度上直接計算累積和
return torch.cumsum(input, dim=dim)
# 呼叫方法,接受多種引數以生成輸出
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
parallel: int = 10,
tolerance: float = 0.1,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
debug: bool = False,
clip_skip: int = None,
.\diffusers\pipelines\deprecated\stable_diffusion_variants\pipeline_stable_diffusion_pix2pix_zero.py
# 版權宣告,說明版權資訊及持有者
# Copyright 2024 Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved.
#
# 使用 Apache License 2.0 許可協議
# Licensed under the Apache License, Version 2.0 (the "License");
# 該檔案僅在遵循許可協議的情況下使用
# you may not use this file except in compliance with the License.
# 許可協議的獲取連結
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用的法律要求或書面協議,否則軟體按“原樣”分發
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何明示或暗示的保證或條件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 檢視許可協議中關於許可權和限制的詳細資訊
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect # 匯入 inspect 模組,用於獲取物件資訊
from dataclasses import dataclass # 從 dataclasses 模組匯入 dataclass 裝飾器
from typing import Any, Callable, Dict, List, Optional, Union # 匯入型別提示
import numpy as np # 匯入 numpy 模組,常用於數值計算
import PIL.Image # 匯入 PIL.Image 模組,用於影像處理
import torch # 匯入 PyTorch 庫,主要用於深度學習
import torch.nn.functional as F # 匯入 PyTorch 的功能性神經網路模組
from transformers import ( # 從 transformers 模組匯入以下類
BlipForConditionalGeneration, # 匯入用於條件生成的 Blip 模型
BlipProcessor, # 匯入 Blip 處理器
CLIPImageProcessor, # 匯入 CLIP 影像處理器
CLIPTextModel, # 匯入 CLIP 文字模型
CLIPTokenizer, # 匯入 CLIP 分詞器
)
from ....image_processor import PipelineImageInput, VaeImageProcessor # 從自定義模組匯入影像處理相關類
from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin # 匯入穩定擴散和文字反轉載入器混合類
from ....models import AutoencoderKL, UNet2DConditionModel # 匯入自動編碼器和 UNet 模型
from ....models.attention_processor import Attention # 匯入注意力處理器
from ....models.lora import adjust_lora_scale_text_encoder # 匯入調整 Lora 文字編碼器規模的函式
from ....schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler # 匯入多種排程器
from ....schedulers.scheduling_ddim_inverse import DDIMInverseScheduler # 匯入 DDIM 反向排程器
from ....utils import ( # 從自定義工具模組匯入實用函式和常量
PIL_INTERPOLATION, # 匯入 PIL 影像插值方法
USE_PEFT_BACKEND, # 匯入是否使用 PEFT 後端的常量
BaseOutput, # 匯入基礎輸出類
deprecate, # 匯入廢棄標記裝飾器
logging, # 匯入日誌記錄模組
replace_example_docstring, # 匯入替換示例文件字串的函式
scale_lora_layers, # 匯入縮放 Lora 層的函式
unscale_lora_layers, # 匯入反縮放 Lora 層的函式
)
from ....utils.torch_utils import randn_tensor # 從 PyTorch 工具模組匯入生成隨機張量的函式
from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin # 從管道工具模組匯入擴散管道和穩定擴散混合類
from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput # 匯入穩定擴散管道輸出類
from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker # 匯入穩定擴散安全檢查器
logger = logging.get_logger(__name__) # 獲取當前模組的日誌記錄器
@dataclass # 將下面的類定義為資料類
class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin): # 定義輸出類,繼承基礎輸出和文字反轉載入器混合類
"""
輸出類用於穩定擴散管道。
引數:
latents (`torch.Tensor`)
反轉的潛在張量
images (`List[PIL.Image.Image]` or `np.ndarray`)
長度為 `batch_size` 的去噪 PIL 影像列表或形狀為 `(batch_size, height, width,
num_channels)` 的 numpy 陣列。PIL 影像或 numpy 陣列呈現擴散管道的去噪影像。
"""
latents: torch.Tensor # 定義潛在張量屬性
images: Union[List[PIL.Image.Image], np.ndarray] # 定義影像屬性,可以是影像列表或 numpy 陣列
EXAMPLE_DOC_STRING = """ # 定義示例文件字串的初始部分
``` # 示例文件字串的開始
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
```py # 示例文件字串的結束
``` # 示例文件字串的結束
# 示例程式碼展示如何使用 Diffusers 庫進行影像生成
Examples:
```py
# 匯入所需的庫
>>> import requests # 用於傳送 HTTP 請求
>>> import torch # 用於處理張量和深度學習模型
# 從 Diffusers 庫匯入必要的類
>>> from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline
# 定義下載嵌入檔案的函式
>>> def download(embedding_url, local_filepath):
... # 傳送 GET 請求獲取嵌入檔案
... r = requests.get(embedding_url)
... # 以二進位制模式開啟本地檔案並寫入獲取的內容
... with open(local_filepath, "wb") as f:
... f.write(r.content)
# 定義模型檢查點的名稱
>>> model_ckpt = "CompVis/stable-diffusion-v1-4"
# 從預訓練模型載入管道並設定資料型別為 float16
>>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16)
# 根據管道配置建立 DDIM 排程器
>>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
# 將模型移動到 GPU
>>> pipeline.to("cuda")
# 定義文字提示
>>> prompt = "a high resolution painting of a cat in the style of van gough"
# 定義源和目標嵌入檔案的 URL
>>> source_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt"
>>> target_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt"
# 遍歷源和目標嵌入 URL 進行下載
>>> for url in [source_emb_url, target_emb_url]:
... # 呼叫下載函式,將檔案儲存到本地
... download(url, url.split("/")[-1])
# 從本地載入源嵌入
>>> src_embeds = torch.load(source_emb_url.split("/")[-1])
# 從本地載入目標嵌入
>>> target_embeds = torch.load(target_emb_url.split("/")[-1])
# 使用管道生成影像
>>> images = pipeline(
... prompt, # 輸入的文字提示
... source_embeds=src_embeds, # 源嵌入
... target_embeds=target_embeds, # 目標嵌入
... num_inference_steps=50, # 推理步驟數
... cross_attention_guidance_amount=0.15, # 跨注意力引導的強度
... ).images # 生成的影像
# 儲存生成的第一張影像
>>> images[0].save("edited_image_dog.png") # 將影像儲存為 PNG 檔案
"""
# 示例文件字串,提供了使用示例和說明
EXAMPLE_INVERT_DOC_STRING = """
Examples:
```py
>>> import torch # 匯入 PyTorch 庫
>>> from transformers import BlipForConditionalGeneration, BlipProcessor # 從 transformers 匯入模型和處理器
>>> from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline # 從 diffusers 匯入排程器和管道
>>> import requests # 匯入 requests 庫,用於傳送網路請求
>>> from PIL import Image # 從 PIL 匯入 Image 類,用於處理影像
>>> captioner_id = "Salesforce/blip-image-captioning-base" # 定義影像說明生成模型的 ID
>>> processor = BlipProcessor.from_pretrained(captioner_id) # 從預訓練模型載入處理器
>>> model = BlipForConditionalGeneration.from_pretrained( # 從預訓練模型載入影像說明生成模型
... captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True # 指定資料型別和低記憶體使用模式
... )
>>> sd_model_ckpt = "CompVis/stable-diffusion-v1-4" # 定義穩定擴散模型的檢查點 ID
>>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained( # 從預訓練模型載入 Pix2Pix 零管道
... sd_model_ckpt, # 指定檢查點
... caption_generator=model, # 指定影像說明生成器
... caption_processor=processor, # 指定影像說明處理器
... torch_dtype=torch.float16, # 指定資料型別
... safety_checker=None, # 關閉安全檢查器
... )
>>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) # 使用排程器配置初始化 DDIM 排程器
>>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) # 使用排程器配置初始化 DDIM 反向排程器
>>> pipeline.enable_model_cpu_offload() # 啟用模型的 CPU 解除安裝
>>> img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" # 定義要處理的影像 URL
>>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512)) # 從 URL 載入影像並調整大小
>>> # 生成說明
>>> caption = pipeline.generate_caption(raw_image) # 生成影像的說明
>>> # "a photography of a cat with flowers and dai dai daie - daie - daie kasaii" # 生成的說明示例
>>> inv_latents = pipeline.invert(caption, image=raw_image).latents # 根據說明和原始影像進行反向處理,獲取潛變數
>>> # 我們需要生成源和目標嵌入
>>> source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"] # 定義源提示列表
>>> target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"] # 定義目標提示列表
>>> source_embeds = pipeline.get_embeds(source_prompts) # 獲取源提示的嵌入表示
>>> target_embeds = pipeline.get_embeds(target_prompts) # 獲取目標提示的嵌入表示
>>> # 潛變數可以用於編輯真實影像
>>> # 在使用穩定擴散 2 或其他使用 v-prediction 的模型時
>>> # 將 `cross_attention_guidance_amount` 設定為 0.01 或更低,以避免輸入潛變數梯度爆炸
>>> image = pipeline( # 使用管道生成新的影像
... caption, # 使用生成的說明
... source_embeds=source_embeds, # 傳遞源嵌入
... target_embeds=target_embeds, # 傳遞目標嵌入
... num_inference_steps=50, # 指定推理步驟數量
... cross_attention_guidance_amount=0.15, # 指定交叉注意力指導量
... generator=generator, # 使用指定的生成器
... latents=inv_latents, # 傳遞潛變數
... negative_prompt=caption, # 使用生成的說明作為負提示
... ).images[0] # 獲取生成的影像
>>> image.save("edited_image.png") # 儲存生成的影像
```py
"""
# 從 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img 匯入的 preprocess 函式
def preprocess(image): # 定義 preprocess 函式,接受影像作為引數
# 設定一個警告資訊,提示使用者 preprocess 方法已被棄用,並將在 diffusers 1.0.0 中移除
deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
# 呼叫 deprecate 函式,記錄棄用資訊,設定標準警告為 False
deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
# 檢查輸入的 image 是否是一個 Torch 張量
if isinstance(image, torch.Tensor):
# 如果是,直接返回該張量
return image
# 檢查輸入的 image 是否是一個 PIL 影像
elif isinstance(image, PIL.Image.Image):
# 如果是,將其封裝為一個單元素列表
image = [image]
# 檢查列表中的第一個元素是否為 PIL 影像
if isinstance(image[0], PIL.Image.Image):
# 獲取第一個影像的寬度和高度
w, h = image[0].size
# 將寬度和高度調整為 8 的整數倍
w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
# 對每個影像進行調整大小,轉換為 numpy 陣列,並在新維度上增加一維
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
# 將所有影像在第 0 維上連線成一個大的陣列
image = np.concatenate(image, axis=0)
# 將資料轉換為 float32 型別並歸一化到 [0, 1] 範圍
image = np.array(image).astype(np.float32) / 255.0
# 調整陣列維度順序為 (batch_size, channels, height, width)
image = image.transpose(0, 3, 1, 2)
# 將畫素值範圍從 [0, 1] 轉換到 [-1, 1]
image = 2.0 * image - 1.0
# 將 numpy 陣列轉換為 Torch 張量
image = torch.from_numpy(image)
# 檢查列表中的第一個元素是否為 Torch 張量
elif isinstance(image[0], torch.Tensor):
# 將多個張量在第 0 維上連線成一個大的張量
image = torch.cat(image, dim=0)
# 返回處理後的影像
return image
# 準備 UNet 模型以執行 Pix2Pix Zero 最佳化
def prepare_unet(unet: UNet2DConditionModel):
# 初始化一個空字典,用於儲存 Pix2Pix Zero 注意力處理器
pix2pix_zero_attn_procs = {}
# 遍歷 UNet 的注意力處理器的鍵
for name in unet.attn_processors.keys():
# 將處理器名稱中的 ".processor" 替換為空
module_name = name.replace(".processor", "")
# 獲取 UNet 中對應的子模組
module = unet.get_submodule(module_name)
# 如果名稱包含 "attn2"
if "attn2" in name:
# 將處理器設定為 Pix2Pix Zero 模式
pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=True)
# 允許該模組進行梯度更新
module.requires_grad_(True)
else:
# 設定為非 Pix2Pix Zero 模式
pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=False)
# 不允許該模組進行梯度更新
module.requires_grad_(False)
# 設定 UNet 的注意力處理器為修改後的處理器字典
unet.set_attn_processor(pix2pix_zero_attn_procs)
# 返回修改後的 UNet 模型
return unet
class Pix2PixZeroL2Loss:
# 初始化損失類
def __init__(self):
# 設定初始損失值為 0
self.loss = 0.0
# 計算損失的方法
def compute_loss(self, predictions, targets):
# 更新損失值為預測值與目標值之間的均方差
self.loss += ((predictions - targets) ** 2).sum((1, 2)).mean(0)
class Pix2PixZeroAttnProcessor:
"""注意力處理器類,用於儲存注意力權重。
在 Pix2Pix Zero 中,該過程發生在交叉注意力塊的計算中。"""
# 初始化注意力處理器
def __init__(self, is_pix2pix_zero=False):
# 記錄是否為 Pix2Pix Zero 模式
self.is_pix2pix_zero = is_pix2pix_zero
# 如果是 Pix2Pix Zero 模式,初始化參考交叉注意力對映
if self.is_pix2pix_zero:
self.reference_cross_attn_map = {}
# 定義呼叫方法
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
timestep=None,
loss=None,
):
# 獲取隱藏狀態的批次大小和序列長度
batch_size, sequence_length, _ = hidden_states.shape
# 準備注意力掩碼
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# 將隱藏狀態轉換為查詢向量
query = attn.to_q(hidden_states)
# 如果沒有編碼器隱藏狀態,則使用隱藏狀態本身
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
# 如果需要進行交叉規範化
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# 將編碼器隱藏狀態轉換為鍵和值
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
# 將查詢、鍵和值轉換為批次維度
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
# 計算注意力分數
attention_probs = attn.get_attention_scores(query, key, attention_mask)
# 如果是 Pix2Pix Zero 模式且時間步不為 None
if self.is_pix2pix_zero and timestep is not None:
# 新的記錄以儲存注意力權重
if loss is None:
self.reference_cross_attn_map[timestep.item()] = attention_probs.detach().cpu()
# 計算損失
elif loss is not None:
# 獲取之前的注意力機率
prev_attn_probs = self.reference_cross_attn_map.pop(timestep.item())
# 計算損失
loss.compute_loss(attention_probs, prev_attn_probs.to(attention_probs.device))
# 將注意力機率與值相乘以獲得新的隱藏狀態
hidden_states = torch.bmm(attention_probs, value)
# 將隱藏狀態轉換回頭維度
hidden_states = attn.batch_to_head_dim(hidden_states)
# 線性變換
hidden_states = attn.to_out[0](hidden_states)
# 應用 dropout
hidden_states = attn.to_out[1](hidden_states)
# 返回新的隱藏狀態
return hidden_states
# 定義一個用於畫素級影像編輯的管道類,基於 Stable Diffusion
class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
r"""
使用 Pix2Pix Zero 進行畫素級影像編輯的管道。基於 Stable Diffusion。
該模型繼承自 [`DiffusionPipeline`]。請查閱超類文件以獲取庫為所有管道實現的通用方法(例如下載或儲存、在特定裝置上執行等)。
引數:
vae ([`AutoencoderKL`]):
用於將影像編碼和解碼到潛在表示的變分自編碼器(VAE)模型。
text_encoder ([`CLIPTextModel`]):
凍結的文字編碼器。Stable Diffusion 使用
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) 的文字部分,
特別是 [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) 變體。
tokenizer (`CLIPTokenizer`):
類的分詞器
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)。
unet ([`UNet2DConditionModel`]): 用於去噪編碼影像潛在的條件 U-Net 架構。
scheduler ([`SchedulerMixin`]):
與 `unet` 一起使用以去噪編碼影像潛在的排程器。可以是
[`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] 或 [`DDPMScheduler`] 的之一。
safety_checker ([`StableDiffusionSafetyChecker`]):
估計生成影像是否可能被視為攻擊性或有害的分類模組。
請參閱 [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) 以獲取詳細資訊。
feature_extractor ([`CLIPImageProcessor`]):
從生成影像中提取特徵的模型,以便作為 `safety_checker` 的輸入。
requires_safety_checker (bool):
管道是否需要安全檢查器。如果您公開使用該管道,我們建議將其設定為 True。
"""
# 定義 CPU 解除安裝的模型元件順序
model_cpu_offload_seq = "text_encoder->unet->vae"
# 可選元件列表
_optional_components = [
"safety_checker",
"feature_extractor",
"caption_generator",
"caption_processor",
"inverse_scheduler",
]
# 從 CPU 解除安裝中排除的元件列表
_exclude_from_cpu_offload = ["safety_checker"]
# 初始化方法,定義管道的引數
def __init__(
self,
vae: AutoencoderKL, # 變分自編碼器模型
text_encoder: CLIPTextModel, # 文字編碼器模型
tokenizer: CLIPTokenizer, # 分詞器模型
unet: UNet2DConditionModel, # 條件 U-Net 模型
scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler], # 排程器型別
feature_extractor: CLIPImageProcessor, # 特徵提取器模型
safety_checker: StableDiffusionSafetyChecker, # 安全檢查器模型
inverse_scheduler: DDIMInverseScheduler, # 反向排程器
caption_generator: BlipForConditionalGeneration, # 描述生成器
caption_processor: BlipProcessor, # 描述處理器
requires_safety_checker: bool = True, # 是否需要安全檢查器的標誌
# 定義一個建構函式
):
# 呼叫父類的建構函式
super().__init__()
# 如果沒有提供安全檢查器且需要安全檢查器,發出警告
if safety_checker is None and requires_safety_checker:
logger.warning(
# 輸出關於禁用安全檢查器的警告資訊
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
# 如果提供了安全檢查器但沒有提供特徵提取器,丟擲錯誤
if safety_checker is not None and feature_extractor is None:
raise ValueError(
# 提示使用者必須定義特徵提取器以使用安全檢查器
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
# 註冊模組,設定各個元件
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
caption_processor=caption_processor,
caption_generator=caption_generator,
inverse_scheduler=inverse_scheduler,
)
# 計算 VAE 的縮放因子
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# 建立影像處理器,使用 VAE 縮放因子
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# 將配置項註冊到當前例項
self.register_to_config(requires_safety_checker=requires_safety_checker)
# 從 StableDiffusionPipeline 類複製的編碼提示的方法
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
# 可選引數,表示提示的嵌入
prompt_embeds: Optional[torch.Tensor] = None,
# 可選引數,表示負面提示的嵌入
negative_prompt_embeds: Optional[torch.Tensor] = None,
# 可選引數,表示 LORA 的縮放因子
lora_scale: Optional[float] = None,
# 接收任意額外引數
**kwargs,
# 開始定義一個方法,處理已棄用的編碼提示功能
):
# 定義棄用資訊,說明該方法將被移除,並推薦使用新方法
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
# 呼叫棄用函式,記錄棄用警告
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
# 呼叫新的編碼提示方法,獲取結果元組
prompt_embeds_tuple = self.encode_prompt(
# 傳入提示文字
prompt=prompt,
# 裝置型別(CPU/GPU)
device=device,
# 每個提示生成的影像數量
num_images_per_prompt=num_images_per_prompt,
# 是否進行分類器自由引導
do_classifier_free_guidance=do_classifier_free_guidance,
# 負面提示文字
negative_prompt=negative_prompt,
# 提示嵌入
prompt_embeds=prompt_embeds,
# 負面提示嵌入
negative_prompt_embeds=negative_prompt_embeds,
# Lora縮放因子
lora_scale=lora_scale,
# 其他可選引數
**kwargs,
)
# 將返回的元組中的兩個嵌入連線起來以相容舊版
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
# 返回最終的提示嵌入
return prompt_embeds
# 從指定的管道中複製的 encode_prompt 方法定義
def encode_prompt(
# 提示文字
self,
prompt,
# 裝置型別
device,
# 每個提示生成的影像數量
num_images_per_prompt,
# 是否進行分類器自由引導
do_classifier_free_guidance,
# 負面提示文字(可選)
negative_prompt=None,
# 提示嵌入(可選)
prompt_embeds: Optional[torch.Tensor] = None,
# 負面提示嵌入(可選)
negative_prompt_embeds: Optional[torch.Tensor] = None,
# Lora縮放因子(可選)
lora_scale: Optional[float] = None,
# 跳過的clip層數(可選)
clip_skip: Optional[int] = None,
# 從指定的管道中複製的 run_safety_checker 方法定義
def run_safety_checker(self, image, device, dtype):
# 檢查是否存在安全檢查器
if self.safety_checker is None:
# 如果沒有安全檢查器,標記為無概念
has_nsfw_concept = None
else:
# 如果影像是張量,則進行後處理以轉換為PIL格式
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
# 如果不是張量,則將其轉換為PIL格式
feature_extractor_input = self.image_processor.numpy_to_pil(image)
# 將處理後的影像提取特徵,準備進行安全檢查
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
# 使用安全檢查器檢查影像,返回影像及其概念狀態
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
# 返回檢查後的影像及其概念狀態
return image, has_nsfw_concept
# 從指定的管道中複製的 decode_latents 方法
# 解碼潛在向量並返回生成的影像
def decode_latents(self, latents):
# 警告使用者該方法已過時,將在1.0.0版本中刪除
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
# 呼叫deprecate函式記錄該方法的棄用資訊
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
# 使用配置的縮放因子對潛在向量進行縮放
latents = 1 / self.vae.config.scaling_factor * latents
# 解碼潛在向量,返回生成的影像
image = self.vae.decode(latents, return_dict=False)[0]
# 將影像值從[-1, 1]對映到[0, 1]並限制範圍
image = (image / 2 + 0.5).clamp(0, 1)
# 將影像轉換為float32格式以確保相容性,並將其轉換為numpy陣列
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
# 返回處理後的影像
return image
# 從diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs複製
def prepare_extra_step_kwargs(self, generator, eta):
# 準備額外的引數以供排程器步驟使用,因排程器的引數簽名可能不同
# eta (η) 僅在DDIMScheduler中使用,其他排程器將忽略它。
# eta在DDIM論文中對應於η:https://arxiv.org/abs/2010.02502
# eta的值應在[0, 1]之間
# 檢查排程器步驟是否接受eta引數
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
# 初始化額外引數字典
extra_step_kwargs = {}
# 如果接受eta,則將其新增到額外引數中
if accepts_eta:
extra_step_kwargs["eta"] = eta
# 檢查排程器步驟是否接受generator引數
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
# 如果接受generator,則將其新增到額外引數中
if accepts_generator:
extra_step_kwargs["generator"] = generator
# 返回準備好的額外引數
return extra_step_kwargs
def check_inputs(
self,
prompt,
source_embeds,
target_embeds,
callback_steps,
prompt_embeds=None,
):
# 檢查callback_steps是否為正整數
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# 確保source_embeds和target_embeds不能同時未定義
if source_embeds is None and target_embeds is None:
raise ValueError("`source_embeds` and `target_embeds` cannot be undefined.")
# 檢查prompt和prompt_embeds不能同時被定義
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
# 檢查prompt和prompt_embeds不能同時未定義
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
# 確保prompt的型別為str或list
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
# 從 StableDiffusionPipeline 的 prepare_latents 方法複製的內容
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
# 定義潛在張量的形狀,包括批次大小、通道數、高度和寬度
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
# 檢查生成器是否為列表且其長度與批次大小匹配
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
# 如果未提供潛在張量,則生成隨機張量
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
# 如果提供了潛在張量,則將其移動到指定裝置
latents = latents.to(device)
# 按排程器所需的標準差縮放初始噪聲
latents = latents * self.scheduler.init_noise_sigma
# 返回處理後的潛在張量
return latents
@torch.no_grad()
def generate_caption(self, images):
"""為給定影像生成標題。"""
# 初始化生成標題的文字
text = "a photography of"
# 儲存當前裝置
prev_device = self.caption_generator.device
# 獲取執行裝置
device = self._execution_device
# 處理輸入影像並轉換為張量
inputs = self.caption_processor(images, text, return_tensors="pt").to(
device=device, dtype=self.caption_generator.dtype
)
# 將標題生成器移動到指定裝置
self.caption_generator.to(device)
# 生成標題輸出
outputs = self.caption_generator.generate(**inputs, max_new_tokens=128)
# 將標題生成器移回先前裝置
self.caption_generator.to(prev_device)
# 解碼輸出以獲取標題
caption = self.caption_processor.batch_decode(outputs, skip_special_tokens=True)[0]
# 返回生成的標題
return caption
def construct_direction(self, embs_source: torch.Tensor, embs_target: torch.Tensor):
"""構造用於引導影像生成過程的編輯方向。"""
# 返回目標和源嵌入的均值之差,並增加一個維度
return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0)
@torch.no_grad()
def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch.Tensor:
# 獲取提示的數量
num_prompts = len(prompt)
# 初始化嵌入列表
embeds = []
# 分批處理提示
for i in range(0, num_prompts, batch_size):
prompt_slice = prompt[i : i + batch_size]
# 將提示轉換為輸入 ID,進行填充和截斷
input_ids = self.tokenizer(
prompt_slice,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
).input_ids
# 將輸入 ID 移動到文字編碼器裝置
input_ids = input_ids.to(self.text_encoder.device)
# 獲取嵌入並追加到列表
embeds.append(self.text_encoder(input_ids)[0])
# 將所有嵌入拼接並計算均值
return torch.cat(embeds, dim=0).mean(0)[None]
# 準備影像的潛在表示,接收影像和其他引數,返回潛在向量
def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):
# 檢查輸入影像的型別是否為有效型別
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
# 丟擲型別錯誤,提示使用者輸入型別不正確
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
# 將影像轉換到指定的裝置和資料型別
image = image.to(device=device, dtype=dtype)
# 如果影像有四個通道,直接將其作為潛在表示
if image.shape[1] == 4:
latents = image
else:
# 檢查生成器列表的長度是否與批次大小匹配
if isinstance(generator, list) and len(generator) != batch_size:
# 丟擲錯誤,提示生成器列表長度與批次大小不匹配
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
# 如果生成器是列表,逐個影像編碼並生成潛在表示
if isinstance(generator, list):
latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
# 將潛在表示合併到一個張量中
latents = torch.cat(latents, dim=0)
else:
# 使用單個生成器編碼影像並生成潛在表示
latents = self.vae.encode(image).latent_dist.sample(generator)
# 根據配置的縮放因子調整潛在表示
latents = self.vae.config.scaling_factor * latents
# 檢查潛在表示的批次大小是否與請求的匹配
if batch_size != latents.shape[0]:
# 如果可以整除,則擴充套件潛在表示以匹配批次大小
if batch_size % latents.shape[0] == 0:
# 構建棄用訊息,提示使用者行為即將被移除
deprecation_message = (
f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
" your script to pass as many initial images as text prompts to suppress this warning."
)
# 觸發棄用警告,提醒使用者修改程式碼
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
# 計算每個影像需要複製的次數
additional_latents_per_image = batch_size // latents.shape[0]
# 將潛在表示按需重複以匹配批次大小
latents = torch.cat([latents] * additional_latents_per_image, dim=0)
else:
# 丟擲錯誤,提示無法複製影像以匹配批次大小
raise ValueError(
f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
)
else:
# 將潛在表示封裝為一個張量
latents = torch.cat([latents], dim=0)
# 返回最終的潛在表示
return latents
# 定義一個獲取epsilon的函式,輸入為模型輸出、樣本和時間步
def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int):
# 獲取反向排程器的預測型別配置
pred_type = self.inverse_scheduler.config.prediction_type
# 計算在當前時間步的累積alpha值
alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep]
# 計算beta值為1減去alpha值
beta_prod_t = 1 - alpha_prod_t
# 根據預測型別返回相應的結果
if pred_type == "epsilon":
return model_output
elif pred_type == "sample":
# 根據樣本和模型輸出計算返回值
return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5)
elif pred_type == "v_prediction":
# 根據alpha和beta值結合模型輸出和樣本計算返回值
return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
# 如果預測型別無效,丟擲異常
raise ValueError(
f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`"
)
# 定義一個自動相關損失計算的函式,輸入為隱藏狀態和可選生成器
def auto_corr_loss(self, hidden_states, generator=None):
# 初始化正則化損失為0
reg_loss = 0.0
# 遍歷隱藏狀態的每一個維度
for i in range(hidden_states.shape[0]):
for j in range(hidden_states.shape[1]):
# 選取當前噪聲
noise = hidden_states[i : i + 1, j : j + 1, :, :]
# 進行迴圈,直到噪聲尺寸小於等於8
while True:
# 隨機生成滾動的位移量
roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
# 計算並累加正則化損失
reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
# 如果噪聲尺寸小於等於8,跳出迴圈
if noise.shape[2] <= 8:
break
# 對噪聲進行2x2的平均池化
noise = F.avg_pool2d(noise, kernel_size=2)
# 返回計算得到的正則化損失
return reg_loss
# 定義一個計算KL散度的函式,輸入為隱藏狀態
def kl_divergence(self, hidden_states):
# 計算隱藏狀態的均值
mean = hidden_states.mean()
# 計算隱藏狀態的方差
var = hidden_states.var()
# 返回KL散度的計算結果
return var + mean**2 - 1 - torch.log(var + 1e-7)
# 定義呼叫函式,使用@torch.no_grad()裝飾器禁止梯度計算
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
# 輸入引數包括提示、源和目標嵌入、影像的高和寬、推理步驟等
prompt: Optional[Union[str, List[str]]] = None,
source_embeds: torch.Tensor = None,
target_embeds: torch.Tensor = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
cross_attention_guidance_amount: float = 0.1,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: Optional[int] = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
# 使用@torch.no_grad()和裝飾器替換文件字串
@torch.no_grad()
@replace_example_docstring(EXAMPLE_INVERT_DOC_STRING)
# 定義一個名為 invert 的方法,包含多個可選引數
def invert(
# 輸入提示,預設為 None
self,
prompt: Optional[str] = None,
# 輸入影像,預設為 None
image: PipelineImageInput = None,
# 推理步驟的數量,預設為 50
num_inference_steps: int = 50,
# 指導比例,預設為 1
guidance_scale: float = 1,
# 隨機數生成器,可以是單個或多個生成器,預設為 None
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
# 潛在變數,預設為 None
latents: Optional[torch.Tensor] = None,
# 提示嵌入,預設為 None
prompt_embeds: Optional[torch.Tensor] = None,
# 跨注意力引導量,預設為 0.1
cross_attention_guidance_amount: float = 0.1,
# 輸出型別,預設為 "pil"
output_type: Optional[str] = "pil",
# 是否返回字典,預設為 True
return_dict: bool = True,
# 回撥函式,預設為 None
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
# 回撥步驟,預設為 1
callback_steps: Optional[int] = 1,
# 跨注意力引數,預設為 None
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 自動相關的權重,預設為 20.0
lambda_auto_corr: float = 20.0,
# KL 散度的權重,預設為 20.0
lambda_kl: float = 20.0,
# 正則化步驟的數量,預設為 5
num_reg_steps: int = 5,
# 自動相關滾動的數量,預設為 5
num_auto_corr_rolls: int = 5,
.\diffusers\pipelines\deprecated\stable_diffusion_variants\__init__.py
# 從型別檢查模組匯入型別檢查相關功能
from typing import TYPE_CHECKING
# 從 utils 模組匯入所需的工具和常量
from ....utils import (
DIFFUSERS_SLOW_IMPORT, # 匯入用於延遲匯入的常量
OptionalDependencyNotAvailable, # 匯入可選依賴不可用的異常類
_LazyModule, # 匯入延遲模組載入的工具
get_objects_from_module, # 匯入從模組獲取物件的函式
is_torch_available, # 匯入檢查 Torch 庫是否可用的函式
is_transformers_available, # 匯入檢查 Transformers 庫是否可用的函式
)
# 初始化一個空字典用於存放虛擬物件
_dummy_objects = {}
# 初始化一個空字典用於存放模組匯入結構
_import_structure = {}
try:
# 檢查 Transformers 和 Torch 庫是否都可用
if not (is_transformers_available() and is_torch_available()):
# 如果不可用,則丟擲異常
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
# 匯入虛擬物件以避免在依賴不可用時導致錯誤
from ....utils import dummy_torch_and_transformers_objects
# 更新虛擬物件字典,填充從虛擬物件模組獲取的物件
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
# 如果依賴可用,新增相關的管道到匯入結構字典
_import_structure["pipeline_cycle_diffusion"] = ["CycleDiffusionPipeline"]
_import_structure["pipeline_stable_diffusion_inpaint_legacy"] = ["StableDiffusionInpaintPipelineLegacy"]
_import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"]
_import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"]
_import_structure["pipeline_stable_diffusion_pix2pix_zero"] = ["StableDiffusionPix2PixZeroPipeline"]
# 根據型別檢查標誌或慢速匯入標誌進行條件判斷
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
# 再次檢查依賴是否可用
if not (is_transformers_available() and is_torch_available()):
# 如果不可用,則丟擲異常
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
# 匯入虛擬物件以避免在依賴不可用時導致錯誤
from ....utils.dummy_torch_and_transformers_objects import *
else:
# 匯入具體的管道類,確保它們在依賴可用時被載入
from .pipeline_cycle_diffusion import CycleDiffusionPipeline
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline
from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline
from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline
else:
# 如果不是型別檢查或慢速匯入,則進行懶載入處理
import sys
# 使用懶載入模組構造當前模組,指定匯入結構和模組規格
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
# 遍歷虛擬物件字典,將物件屬性設定到當前模組
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
.\diffusers\pipelines\deprecated\stochastic_karras_ve\pipeline_stochastic_karras_ve.py
# 版權宣告,表明此檔案的版權所有者及保留權利
#
# 根據 Apache 許可證第 2.0 版(“許可證”)進行許可;
# 除非遵循許可證,否則您不得使用此檔案。
# 您可以在以下網址獲取許可證的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律或書面同意,軟體在許可證下分發,按“原樣”基礎,
# 不提供任何形式的保證或條件,無論是明示或暗示的。
# 請參見許可證以獲取有關許可權和
# 限制的具體規定。
# 從 typing 模組匯入所需的型別提示
from typing import List, Optional, Tuple, Union
# 匯入 PyTorch 庫
import torch
# 從相對路徑匯入 UNet2DModel 模型
from ....models import UNet2DModel
# 從相對路徑匯入排程器 KarrasVeScheduler
from ....schedulers import KarrasVeScheduler
# 從相對路徑匯入隨機張量生成工具
from ....utils.torch_utils import randn_tensor
# 從相對路徑匯入擴散管道和影像輸出
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
# 定義 KarrasVePipeline 類,繼承自 DiffusionPipeline
class KarrasVePipeline(DiffusionPipeline):
r"""
無條件影像生成的管道。
引數:
unet ([`UNet2DModel`]):
用於去噪編碼影像的 `UNet2DModel`。
scheduler ([`KarrasVeScheduler`]):
用於與 `unet` 結合去噪編碼影像的排程器。
"""
# 為 linting 新增型別提示
unet: UNet2DModel # 定義 unet 型別為 UNet2DModel
scheduler: KarrasVeScheduler # 定義 scheduler 型別為 KarrasVeScheduler
# 初始化函式,接受 UNet2DModel 和 KarrasVeScheduler 作為引數
def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
# 呼叫父類的初始化函式
super().__init__()
# 註冊模組,將 unet 和 scheduler 註冊到當前例項中
self.register_modules(unet=unet, scheduler=scheduler)
# 裝飾器,表明此函式不需要梯度計算
@torch.no_grad()
def __call__(
self,
batch_size: int = 1, # 定義批處理大小,預設為 1
num_inference_steps: int = 50, # 定義推理步驟數量,預設為 50
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, # 可選生成器
output_type: Optional[str] = "pil", # 可選輸出型別,預設為 "pil"
return_dict: bool = True, # 是否返回字典,預設為 True
**kwargs, # 允許額外的關鍵字引數
.\diffusers\pipelines\deprecated\stochastic_karras_ve\__init__.py
# 從 typing 模組匯入 TYPE_CHECKING,用於型別檢查
from typing import TYPE_CHECKING
# 從相對路徑匯入工具模組中的 DIFFUSERS_SLOW_IMPORT 和 _LazyModule
from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule
# 定義一個字典,描述要匯入的模組及其對應的類
_import_structure = {"pipeline_stochastic_karras_ve": ["KarrasVePipeline"]}
# 如果正在進行型別檢查或 DIFFUSERS_SLOW_IMPORT 為真
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
# 從 pipeline_stochastic_karras_ve 模組匯入 KarrasVePipeline 類
from .pipeline_stochastic_karras_ve import KarrasVePipeline
# 否則
else:
# 匯入 sys 模組,用於動態修改模組
import sys
# 使用 _LazyModule 建立懶載入模組,並將其賦值給當前模組名稱
sys.modules[__name__] = _LazyModule(
__name__, # 當前模組的名稱
globals()["__file__"], # 當前模組的檔案路徑
_import_structure, # 模組結構字典
module_spec=__spec__, # 當前模組的規格
)
.\diffusers\pipelines\deprecated\versatile_diffusion\modeling_text_unet.py
# 從 typing 模組匯入各種型別註解
from typing import Any, Dict, List, Optional, Tuple, Union
# 匯入 numpy 庫,用於陣列和矩陣操作
import numpy as np
# 匯入 PyTorch 庫,進行深度學習模型的構建和訓練
import torch
# 匯入 PyTorch 的神經網路模組
import torch.nn as nn
# 匯入 PyTorch 的功能性模組,提供常用操作
import torch.nn.functional as F
# 從 diffusers.utils 模組匯入 deprecate 函式,用於處理棄用警告
from diffusers.utils import deprecate
# 匯入配置相關的類和函式
from ....configuration_utils import ConfigMixin, register_to_config
# 匯入模型相關的基類
from ....models import ModelMixin
# 匯入啟用函式獲取工具
from ....models.activations import get_activation
# 匯入注意力處理器相關元件
from ....models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS, # 額外來鍵值注意力處理器
CROSS_ATTENTION_PROCESSORS, # 交叉注意力處理器
Attention, # 注意力機制類
AttentionProcessor, # 注意力處理器基類
AttnAddedKVProcessor, # 額外來鍵值注意力處理器類
AttnAddedKVProcessor2_0, # 版本 2.0 的額外來鍵值注意力處理器
AttnProcessor, # 基礎注意力處理器
)
# 匯入嵌入層相關元件
from ....models.embeddings import (
GaussianFourierProjection, # 高斯傅立葉投影類
ImageHintTimeEmbedding, # 影像提示時間嵌入類
ImageProjection, # 影像投影類
ImageTimeEmbedding, # 影像時間嵌入類
TextImageProjection, # 文字影像投影類
TextImageTimeEmbedding, # 文字影像時間嵌入類
TextTimeEmbedding, # 文字時間嵌入類
TimestepEmbedding, # 時間步嵌入類
Timesteps, # 時間步類
)
# 匯入 ResNet 相關元件
from ....models.resnet import ResnetBlockCondNorm2D
# 匯入 2D 雙重變換器模型
from ....models.transformers.dual_transformer_2d import DualTransformer2DModel
# 匯入 2D 變換器模型
from ....models.transformers.transformer_2d import Transformer2DModel
# 匯入 2D 條件 UNet 輸出類
from ....models.unets.unet_2d_condition import UNet2DConditionOutput
# 匯入工具函式和常量
from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
# 匯入 PyTorch 相關工具函式
from ....utils.torch_utils import apply_freeu
# 建立日誌記錄器例項
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# 定義獲取下采樣塊的函式
def get_down_block(
down_block_type, # 下采樣塊型別
num_layers, # 層數
in_channels, # 輸入通道數
out_channels, # 輸出通道數
temb_channels, # 時間嵌入通道數
add_downsample, # 是否新增下采樣
resnet_eps, # ResNet 中的 epsilon 值
resnet_act_fn, # ResNet 啟用函式
num_attention_heads, # 注意力頭數量
transformer_layers_per_block, # 每個塊中的變換器層數
attention_type, # 注意力型別
attention_head_dim, # 注意力頭維度
resnet_groups=None, # ResNet 組數(可選)
cross_attention_dim=None, # 交叉注意力維度(可選)
downsample_padding=None, # 下采樣填充(可選)
dual_cross_attention=False, # 是否使用雙重交叉注意力
use_linear_projection=False, # 是否使用線性投影
only_cross_attention=False, # 是否僅使用交叉注意力
upcast_attention=False, # 是否上升注意力
resnet_time_scale_shift="default", # ResNet 時間縮放偏移
resnet_skip_time_act=False, # ResNet 是否跳過時間啟用
resnet_out_scale_factor=1.0, # ResNet 輸出縮放因子
cross_attention_norm=None, # 交叉注意力歸一化(可選)
dropout=0.0, # dropout 機率
):
# 如果下采樣塊型別以 "UNetRes" 開頭,則去掉字首
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
# 如果下采樣塊型別為 "DownBlockFlat",則返回相應的塊例項
if down_block_type == "DownBlockFlat":
return DownBlockFlat(
num_layers=num_layers, # 層數
in_channels=in_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
dropout=dropout, # dropout 機率
add_downsample=add_downsample, # 是否新增下采樣
resnet_eps=resnet_eps, # ResNet 中的 epsilon 值
resnet_act_fn=resnet_act_fn, # ResNet 啟用函式
resnet_groups=resnet_groups, # ResNet 組數(可選)
downsample_padding=downsample_padding, # 下采樣填充(可選)
resnet_time_scale_shift=resnet_time_scale_shift, # ResNet 時間縮放偏移
)
# 檢查下采樣塊型別是否為 CrossAttnDownBlockFlat
elif down_block_type == "CrossAttnDownBlockFlat":
# 如果沒有指定 cross_attention_dim,則丟擲值錯誤
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockFlat")
# 建立並返回 CrossAttnDownBlockFlat 例項,傳入所需引數
return CrossAttnDownBlockFlat(
# 設定網路層數
num_layers=num_layers,
# 設定輸入通道數
in_channels=in_channels,
# 設定輸出通道數
out_channels=out_channels,
# 設定時間嵌入通道數
temb_channels=temb_channels,
# 設定 dropout 比率
dropout=dropout,
# 設定是否新增下采樣層
add_downsample=add_downsample,
# 設定 ResNet 中的 epsilon 引數
resnet_eps=resnet_eps,
# 設定 ResNet 啟用函式
resnet_act_fn=resnet_act_fn,
# 設定 ResNet 組的數量
resnet_groups=resnet_groups,
# 設定下采樣的填充引數
downsample_padding=downsample_padding,
# 設定交叉注意力維度
cross_attention_dim=cross_attention_dim,
# 設定注意力頭的數量
num_attention_heads=num_attention_heads,
# 設定是否使用雙交叉注意力
dual_cross_attention=dual_cross_attention,
# 設定是否使用線性投影
use_linear_projection=use_linear_projection,
# 設定是否僅使用交叉注意力
only_cross_attention=only_cross_attention,
# 設定 ResNet 的時間尺度偏移
resnet_time_scale_shift=resnet_time_scale_shift,
)
# 如果下采樣塊型別不被支援,則丟擲值錯誤
raise ValueError(f"{down_block_type} is not supported.")
# 根據給定引數建立上取樣塊的函式
def get_up_block(
# 上取樣塊型別
up_block_type,
# 網路層數
num_layers,
# 輸入通道數
in_channels,
# 輸出通道數
out_channels,
# 上一層輸出通道數
prev_output_channel,
# 條件嵌入通道數
temb_channels,
# 是否新增上取樣
add_upsample,
# ResNet 的 epsilon 值
resnet_eps,
# ResNet 的啟用函式
resnet_act_fn,
# 注意力頭數
num_attention_heads,
# 每個塊的 Transformer 層數
transformer_layers_per_block,
# 解析度索引
resolution_idx,
# 注意力型別
attention_type,
# 注意力頭維度
attention_head_dim,
# ResNet 組數,可選引數
resnet_groups=None,
# 跨注意力維度,可選引數
cross_attention_dim=None,
# 是否使用雙重跨注意力
dual_cross_attention=False,
# 是否使用線性投影
use_linear_projection=False,
# 是否僅使用跨注意力
only_cross_attention=False,
# 是否上溯注意力
upcast_attention=False,
# ResNet 時間尺度偏移,預設為 "default"
resnet_time_scale_shift="default",
# ResNet 是否跳過時間啟用
resnet_skip_time_act=False,
# ResNet 輸出縮放因子
resnet_out_scale_factor=1.0,
# 跨注意力歸一化型別,可選引數
cross_attention_norm=None,
# dropout 機率
dropout=0.0,
):
# 如果上取樣塊型別以 "UNetRes" 開頭,去掉字首
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
# 如果塊型別是 "UpBlockFlat",則返回相應的例項
if up_block_type == "UpBlockFlat":
return UpBlockFlat(
# 傳入各個引數
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
)
# 如果塊型別是 "CrossAttnUpBlockFlat"
elif up_block_type == "CrossAttnUpBlockFlat":
# 檢查跨注意力維度是否指定
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockFlat")
# 返回相應的跨注意力上取樣塊例項
return CrossAttnUpBlockFlat(
# 傳入各個引數
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
# 如果塊型別不支援,丟擲異常
raise ValueError(f"{up_block_type} is not supported.")
# 定義一個 Fourier 嵌入器類,繼承自 nn.Module
class FourierEmbedder(nn.Module):
# 初始化方法,設定頻率和溫度
def __init__(self, num_freqs=64, temperature=100):
# 呼叫父類建構函式
super().__init__()
# 儲存頻率數
self.num_freqs = num_freqs
# 儲存溫度
self.temperature = temperature
# 計算頻率帶
freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
# 擴充套件維度以便後續操作
freq_bands = freq_bands[None, None, None]
# 註冊頻率帶為緩衝區,設為非永續性
self.register_buffer("freq_bands", freq_bands, persistent=False)
# 定義呼叫方法,用於處理輸入
def __call__(self, x):
# 將輸入與頻率帶相乘
x = self.freq_bands * x.unsqueeze(-1)
# 返回處理後的結果,包含正弦和餘弦
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
# 定義 GLIGEN 文字邊界框投影類,繼承自 nn.Module
class GLIGENTextBoundingboxProjection(nn.Module):
# 初始化方法,設定物件的基本引數
def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8):
# 呼叫父類的初始化方法
super().__init__()
# 儲存正樣本的長度
self.positive_len = positive_len
# 儲存輸出的維度
self.out_dim = out_dim
# 初始化傅立葉嵌入器,設定頻率數量
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
# 計算位置特徵的維度,包含 sin 和 cos
self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
# 如果輸出維度是元組,取第一個元素
if isinstance(out_dim, tuple):
out_dim = out_dim[0]
# 根據特徵型別設定線性層
if feature_type == "text-only":
self.linears = nn.Sequential(
# 第一層線性變換,輸入為正樣本長度加位置維度
nn.Linear(self.positive_len + self.position_dim, 512),
# 啟用函式使用 SiLU
nn.SiLU(),
# 第二層線性變換
nn.Linear(512, 512),
# 啟用函式使用 SiLU
nn.SiLU(),
# 輸出層
nn.Linear(512, out_dim),
)
# 定義一個全為零的引數,用於文字特徵的空值處理
self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
# 處理文字和影像的特徵型別
elif feature_type == "text-image":
self.linears_text = nn.Sequential(
# 第一層線性變換
nn.Linear(self.positive_len + self.position_dim, 512),
# 啟用函式使用 SiLU
nn.SiLU(),
# 第二層線性變換
nn.Linear(512, 512),
# 啟用函式使用 SiLU
nn.SiLU(),
# 輸出層
nn.Linear(512, out_dim),
)
self.linears_image = nn.Sequential(
# 第一層線性變換
nn.Linear(self.positive_len + self.position_dim, 512),
# 啟用函式使用 SiLU
nn.SiLU(),
# 第二層線性變換
nn.Linear(512, 512),
# 啟用函式使用 SiLU
nn.SiLU(),
# 輸出層
nn.Linear(512, out_dim),
)
# 定義文字特徵的空值處理引數
self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
# 定義影像特徵的空值處理引數
self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
# 定義位置特徵的空值處理引數
self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
# 前向傳播方法定義
def forward(
self,
boxes,
masks,
positive_embeddings=None,
phrases_masks=None,
image_masks=None,
phrases_embeddings=None,
image_embeddings=None,
):
# 在最後一維增加一個維度,便於後續操作
masks = masks.unsqueeze(-1)
# 透過傅立葉嵌入函式生成 boxes 的嵌入表示
xyxy_embedding = self.fourier_embedder(boxes)
# 獲取空白位置的特徵,並調整形狀為 (1, 1, -1)
xyxy_null = self.null_position_feature.view(1, 1, -1)
# 計算加權嵌入,結合 masks 和空白位置特徵
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
# 如果存在正樣本嵌入
if positive_embeddings:
# 獲取正樣本的空白特徵,並調整形狀為 (1, 1, -1)
positive_null = self.null_positive_feature.view(1, 1, -1)
# 計算正樣本嵌入的加權,結合 masks 和空白特徵
positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
# 將正樣本嵌入與 xyxy 嵌入連線並透過線性層處理
objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
else:
# 在最後一維增加一個維度,便於後續操作
phrases_masks = phrases_masks.unsqueeze(-1)
image_masks = image_masks.unsqueeze(-1)
# 獲取文字和影像的空白特徵,並調整形狀為 (1, 1, -1)
text_null = self.null_text_feature.view(1, 1, -1)
image_null = self.null_image_feature.view(1, 1, -1)
# 計算文字嵌入的加權,結合 phrases_masks 和空白特徵
phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
# 計算影像嵌入的加權,結合 image_masks 和空白特徵
image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null
# 將文字嵌入與 xyxy 嵌入連線並透過文字線性層處理
objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
# 將影像嵌入與 xyxy 嵌入連線並透過影像線性層處理
objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
# 將文字和影像的處理結果在維度 1 上連線
objs = torch.cat([objs_text, objs_image], dim=1)
# 返回最終的物件結果
return objs
# 定義一個名為 UNetFlatConditionModel 的類,繼承自 ModelMixin 和 ConfigMixin
class UNetFlatConditionModel(ModelMixin, ConfigMixin):
r"""
一個條件 2D UNet 模型,它接收一個有噪聲的樣本、條件狀態和時間步,並返回一個樣本形狀的輸出。
該模型繼承自 [`ModelMixin`]。請檢視父類文件以瞭解其為所有模型實現的通用方法(例如下載或儲存)。
"""
# 設定該模型支援梯度檢查點
_supports_gradient_checkpointing = True
# 定義不進行拆分的模組名稱列表
_no_split_modules = ["BasicTransformerBlock", "ResnetBlockFlat", "CrossAttnUpBlockFlat"]
# 註冊到配置的裝飾器
@register_to_config
# 初始化方法,設定類的基本引數
def __init__(
# 樣本大小,可選引數
self,
sample_size: Optional[int] = None,
# 輸入通道數,預設為4
in_channels: int = 4,
# 輸出通道數,預設為4
out_channels: int = 4,
# 是否將輸入樣本居中,預設為False
center_input_sample: bool = False,
# 是否將正弦函式翻轉為餘弦函式,預設為True
flip_sin_to_cos: bool = True,
# 頻率偏移量,預設為0
freq_shift: int = 0,
# 向下取樣塊的型別,預設為三個CrossAttnDownBlockFlat和一個DownBlockFlat
down_block_types: Tuple[str] = (
"CrossAttnDownBlockFlat",
"CrossAttnDownBlockFlat",
"CrossAttnDownBlockFlat",
"DownBlockFlat",
),
# 中間塊的型別,預設為UNetMidBlockFlatCrossAttn
mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn",
# 向上取樣塊的型別,預設為一個UpBlockFlat和三個CrossAttnUpBlockFlat
up_block_types: Tuple[str] = (
"UpBlockFlat",
"CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat",
),
# 是否僅使用交叉注意力,預設為False
only_cross_attention: Union[bool, Tuple[bool]] = False,
# 塊輸出通道數,預設為320, 640, 1280, 1280
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
# 每個塊的層數,預設為2
layers_per_block: Union[int, Tuple[int]] = 2,
# 向下取樣時的填充大小,預設為1
downsample_padding: int = 1,
# 中間塊的縮放因子,預設為1
mid_block_scale_factor: float = 1,
# dropout比例,預設為0.0
dropout: float = 0.0,
# 啟用函式型別,預設為silu
act_fn: str = "silu",
# 歸一化的組數,可選引數,預設為32
norm_num_groups: Optional[int] = 32,
# 歸一化的epsilon值,預設為1e-5
norm_eps: float = 1e-5,
# 交叉注意力的維度,預設為1280
cross_attention_dim: Union[int, Tuple[int]] = 1280,
# 每個塊的變換器層數,預設為1
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
# 反向變換器層數的可選配置
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
# 編碼器隱藏維度的可選引數
encoder_hid_dim: Optional[int] = None,
# 編碼器隱藏維度型別的可選引數
encoder_hid_dim_type: Optional[str] = None,
# 注意力頭的維度,預設為8
attention_head_dim: Union[int, Tuple[int]] = 8,
# 注意力頭數量的可選引數
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
# 是否使用雙交叉注意力,預設為False
dual_cross_attention: bool = False,
# 是否使用線性投影,預設為False
use_linear_projection: bool = False,
# 類嵌入型別的可選引數
class_embed_type: Optional[str] = None,
# 附加嵌入型別的可選引數
addition_embed_type: Optional[str] = None,
# 附加時間嵌入維度的可選引數
addition_time_embed_dim: Optional[int] = None,
# 類嵌入數量的可選引數
num_class_embeds: Optional[int] = None,
# 是否向上投射注意力,預設為False
upcast_attention: bool = False,
# ResNet時間縮放偏移的預設值
resnet_time_scale_shift: str = "default",
# ResNet跳過時間啟用的設定,預設為False
resnet_skip_time_act: bool = False,
# ResNet輸出縮放因子,預設為1.0
resnet_out_scale_factor: int = 1.0,
# 時間嵌入型別,預設為positional
time_embedding_type: str = "positional",
# 時間嵌入維度的可選引數
time_embedding_dim: Optional[int] = None,
# 時間嵌入啟用函式的可選引數
time_embedding_act_fn: Optional[str] = None,
# 時間步後啟用的可選引數
timestep_post_act: Optional[str] = None,
# 時間條件投影維度的可選引數
time_cond_proj_dim: Optional[int] = None,
# 輸入卷積核的大小,預設為3
conv_in_kernel: int = 3,
# 輸出卷積核的大小,預設為3
conv_out_kernel: int = 3,
# 投影類嵌入輸入維度的可選引數
projection_class_embeddings_input_dim: Optional[int] = None,
# 注意力型別,預設為default
attention_type: str = "default",
# 類嵌入是否連線,預設為False
class_embeddings_concat: bool = False,
# 中間塊是否僅使用交叉注意力的可選引數
mid_block_only_cross_attention: Optional[bool] = None,
# 交叉注意力的歸一化型別的可選引數
cross_attention_norm: Optional[str] = None,
# 附加嵌入型別的頭數量,預設為64
addition_embed_type_num_heads=64,
# 宣告該方法為屬性
@property
# 定義一個返回注意力處理器字典的方法
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
返回值:
`dict` 的注意力處理器: 一個字典,包含模型中使用的所有注意力處理器,以其權重名稱為索引。
"""
# 初始化一個空字典以遞迴儲存處理器
processors = {}
# 定義一個遞迴函式來新增處理器
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# 如果模組有獲取處理器的方法,則新增到字典中
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
# 遍歷模組的子模組,遞迴呼叫函式
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
# 返回處理器字典
return processors
# 遍歷當前模組的子模組,並呼叫遞迴函式
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
# 返回所有注意力處理器的字典
return processors
# 定義一個設定注意力處理器的方法
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
設定用於計算注意力的處理器。
引數:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
例項化的處理器類或處理器類的字典,將被設定為所有 `Attention` 層的處理器。
如果 `processor` 是字典,則鍵需要定義對應的交叉注意力處理器的路徑。
在設定可訓練的注意力處理器時,強烈推薦這種做法。
"""
# 計算當前注意力處理器的數量
count = len(self.attn_processors.keys())
# 如果傳入的是字典且數量不匹配,則引發錯誤
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"傳入的是處理器字典,但處理器的數量 {len(processor)} 與注意力層的數量 {count} 不匹配。"
f" 請確保傳入 {count} 個處理器類。"
)
# 定義一個遞迴函式來設定注意力處理器
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 如果模組有設定處理器的方法,則根據傳入的處理器設定
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
# 遍歷模組的子模組,遞迴呼叫函式
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍歷當前模組的子模組,並呼叫遞迴函式
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# 設定預設的注意力處理器
def set_default_attn_processor(self):
"""
禁用自定義注意力處理器,並設定預設的注意力實現。
"""
# 檢查所有注意力處理器是否屬於已新增的 KV 注意力處理器
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 使用 AttnAddedKVProcessor 作為處理器
processor = AttnAddedKVProcessor()
# 檢查所有注意力處理器是否屬於交叉注意力處理器
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 使用 AttnProcessor 作為處理器
processor = AttnProcessor()
else:
# 如果處理器型別不匹配,則引發值錯誤
raise ValueError(
f"當注意力處理器的型別為 {next(iter(self.attn_processors.values()))} 時,無法呼叫 `set_default_attn_processor`"
)
# 設定選定的注意力處理器
self.set_attn_processor(processor)
# 設定梯度檢查點
def _set_gradient_checkpointing(self, module, value=False):
# 如果模組具有 gradient_checkpointing 屬性,則設定其值
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
# 啟用 FreeU 機制
def enable_freeu(self, s1, s2, b1, b2):
r"""啟用來自 https://arxiv.org/abs/2309.11497 的 FreeU 機制。
縮放因子的字尾表示應用的階段塊。
請參考 [官方庫](https://github.com/ChenyangSi/FreeU) 以獲取已知在不同管道(如 Stable Diffusion v1、v2 和 Stable Diffusion XL)中表現良好的值組合。
引數:
s1 (`float`):
階段 1 的縮放因子,用於減弱跳過特徵的貢獻。這是為了減輕增強去噪過程中的“過平滑效應”。
s2 (`float`):
階段 2 的縮放因子,用於減弱跳過特徵的貢獻。這是為了減輕增強去噪過程中的“過平滑效應”。
b1 (`float`): 階段 1 的縮放因子,用於增強主幹特徵的貢獻。
b2 (`float`): 階段 2 的縮放因子,用於增強主幹特徵的貢獻。
"""
# 遍歷上取樣塊並設定相應的縮放因子
for i, upsample_block in enumerate(self.up_blocks):
setattr(upsample_block, "s1", s1) # 設定階段 1 的縮放因子
setattr(upsample_block, "s2", s2) # 設定階段 2 的縮放因子
setattr(upsample_block, "b1", b1) # 設定階段 1 的主幹特徵縮放因子
setattr(upsample_block, "b2", b2) # 設定階段 2 的主幹特徵縮放因子
# 禁用 FreeU 機制
def disable_freeu(self):
"""禁用 FreeU 機制。"""
freeu_keys = {"s1", "s2", "b1", "b2"} # FreeU 機制的關鍵字集合
# 遍歷上取樣塊並將關鍵字的值設定為 None
for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys:
# 如果上取樣塊具有該屬性或屬性值不為 None,則將其設定為 None
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
# 定義一個用於融合 QKV 投影的函式
def fuse_qkv_projections(self):
# 文件字串,描述該函式的作用及實驗性質
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
# 初始化原始注意力處理器為 None
self.original_attn_processors = None
# 遍歷所有注意力處理器
for _, attn_processor in self.attn_processors.items():
# 檢查處理器的類名是否包含 "Added"
if "Added" in str(attn_processor.__class__.__name__):
# 如果是,丟擲錯誤提示不支援融合
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
# 儲存原始的注意力處理器
self.original_attn_processors = self.attn_processors
# 遍歷所有模組
for module in self.modules():
# 檢查模組是否是 Attention 類的例項
if isinstance(module, Attention):
# 融合投影
module.fuse_projections(fuse=True)
# 定義一個用於取消 QKV 投影融合的函式
def unfuse_qkv_projections(self):
# 文件字串,描述該函式的作用及實驗性質
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
# 檢查原始注意力處理器是否不為 None
if self.original_attn_processors is not None:
# 恢復到原始的注意力處理器
self.set_attn_processor(self.original_attn_processors)
# 定義一個用於解除安裝 LoRA 權重的函式
def unload_lora(self):
# 文件字串,描述該函式的作用
"""Unloads LoRA weights."""
# 發出解除安裝的棄用警告
deprecate(
"unload_lora",
"0.28.0",
"Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
)
# 遍歷所有模組
for module in self.modules():
# 檢查模組是否具有 set_lora_layer 屬性
if hasattr(module, "set_lora_layer"):
# 將 LoRA 層設定為 None
module.set_lora_layer(None)
# 定義前向傳播函式
def forward(
self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
# 定義一個繼承自 nn.Linear 的線性多維層
class LinearMultiDim(nn.Linear):
# 初始化方法,接受輸入特徵、輸出特徵及其他引數
def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs):
# 如果 in_features 是整數,則將其轉換為包含三個維度的列表
in_features = [in_features, second_dim, 1] if isinstance(in_features, int) else list(in_features)
# 如果未提供 out_features,則將其設定為 in_features
if out_features is None:
out_features = in_features
# 如果 out_features 是整數,則轉換為包含三個維度的列表
out_features = [out_features, second_dim, 1] if isinstance(out_features, int) else list(out_features)
# 儲存輸入特徵的多維資訊
self.in_features_multidim = in_features
# 儲存輸出特徵的多維資訊
self.out_features_multidim = out_features
# 呼叫父類的初始化方法,計算輸入和輸出特徵的總數量
super().__init__(np.array(in_features).prod(), np.array(out_features).prod())
# 定義前向傳播方法
def forward(self, input_tensor, *args, **kwargs):
# 獲取輸入張量的形狀
shape = input_tensor.shape
# 獲取輸入特徵的維度數量
n_dim = len(self.in_features_multidim)
# 將輸入張量重塑為適合線性層的形狀
input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features)
# 呼叫父類的前向傳播方法,得到輸出張量
output_tensor = super().forward(input_tensor)
# 將輸出張量重塑為目標形狀
output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_features_multidim)
# 返回輸出張量
return output_tensor
# 定義一個平坦的殘差塊類,繼承自 nn.Module
class ResnetBlockFlat(nn.Module):
# 初始化方法,接受多個引數,包括通道數、丟棄率等
def __init__(
self,
*,
in_channels,
out_channels=None,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
time_embedding_norm="default",
use_in_shortcut=None,
second_dim=4,
**kwargs,
# 初始化方法的結束,接收引數
):
# 呼叫父類的初始化方法
super().__init__()
# 是否進行預歸一化,設定為傳入的值
self.pre_norm = pre_norm
# 將預歸一化設定為 True
self.pre_norm = True
# 如果輸入通道是整數,則構造一個包含三個維度的列表
in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels)
# 計算輸入通道數的乘積
self.in_channels_prod = np.array(in_channels).prod()
# 儲存輸入通道的多維資訊
self.channels_multidim = in_channels
# 如果輸出通道不為 None
if out_channels is not None:
# 如果輸出通道是整數,構造一個包含三個維度的列表
out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels)
# 計算輸出通道數的乘積
out_channels_prod = np.array(out_channels).prod()
# 儲存輸出通道的多維資訊
self.out_channels_multidim = out_channels
else:
# 如果輸出通道為 None,則輸出通道乘積等於輸入通道乘積
out_channels_prod = self.in_channels_prod
# 輸出通道的多維資訊與輸入通道相同
self.out_channels_multidim = self.channels_multidim
# 儲存時間嵌入的歸一化狀態
self.time_embedding_norm = time_embedding_norm
# 如果輸出組數為 None,使用傳入的組數
if groups_out is None:
groups_out = groups
# 建立第一個歸一化層,使用組歸一化
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True)
# 建立第一個卷積層,使用輸入通道和輸出通道乘積
self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0)
# 如果時間嵌入通道不為 None
if temb_channels is not None:
# 建立時間嵌入投影層
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod)
else:
# 如果時間嵌入通道為 None,則不進行投影
self.time_emb_proj = None
# 建立第二個歸一化層,使用輸出組數和輸出通道乘積
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True)
# 建立丟棄層,使用傳入的丟棄率
self.dropout = torch.nn.Dropout(dropout)
# 建立第二個卷積層,使用輸出通道乘積
self.conv2 = torch.nn.Conv2d(out_channels_prod, out_channels_prod, kernel_size=1, padding=0)
# 設定非線性啟用函式為 SiLU
self.nonlinearity = nn.SiLU()
# 檢查是否使用輸入短路,如果短路使用引數為 None,則根據通道數判斷
self.use_in_shortcut = (
self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut
)
# 初始化快捷連線卷積為 None
self.conv_shortcut = None
# 如果使用輸入短路
if self.use_in_shortcut:
# 建立快捷連線卷積層
self.conv_shortcut = torch.nn.Conv2d(
self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0
)
# 定義前向傳播方法,接收輸入張量和時間嵌入
def forward(self, input_tensor, temb):
# 獲取輸入張量的形狀
shape = input_tensor.shape
# 獲取多維通道的維度數
n_dim = len(self.channels_multidim)
# 調整輸入張量形狀,合併通道維度並增加兩個維度
input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_channels_prod, 1, 1)
# 將張量檢視轉換為指定形狀,保持通道數並增加兩個維度
input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1)
# 初始化隱藏狀態為輸入張量
hidden_states = input_tensor
# 對隱藏狀態進行歸一化處理
hidden_states = self.norm1(hidden_states)
# 應用非線性啟用函式
hidden_states = self.nonlinearity(hidden_states)
# 透過第一個卷積層處理隱藏狀態
hidden_states = self.conv1(hidden_states)
# 如果時間嵌入不為空
if temb is not None:
# 對時間嵌入進行非線性處理並調整形狀
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
# 將時間嵌入與隱藏狀態相加
hidden_states = hidden_states + temb
# 對隱藏狀態進行第二次歸一化處理
hidden_states = self.norm2(hidden_states)
# 再次應用非線性啟用函式
hidden_states = self.nonlinearity(hidden_states)
# 對隱藏狀態應用 dropout 操作
hidden_states = self.dropout(hidden_states)
# 透過第二個卷積層處理隱藏狀態
hidden_states = self.conv2(hidden_states)
# 如果存在短路卷積層
if self.conv_shortcut is not None:
# 透過短路卷積層處理輸入張量
input_tensor = self.conv_shortcut(input_tensor)
# 將輸入張量與隱藏狀態相加,生成輸出張量
output_tensor = input_tensor + hidden_states
# 將輸出張量調整為指定形狀,去掉多餘的維度
output_tensor = output_tensor.view(*shape[0:-n_dim], -1)
# 再次調整輸出張量的形狀,匹配輸出通道的多維結構
output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim)
# 返回最終的輸出張量
return output_tensor
# 定義一個名為 DownBlockFlat 的類,繼承自 nn.Module
class DownBlockFlat(nn.Module):
# 初始化方法,接受多個引數用於配置模型
def __init__(
self,
in_channels: int, # 輸入通道數
out_channels: int, # 輸出通道數
temb_channels: int, # 時間嵌入通道數
dropout: float = 0.0, # dropout 機率
num_layers: int = 1, # ResNet 層數
resnet_eps: float = 1e-6, # ResNet 的 epsilon 值
resnet_time_scale_shift: str = "default", # ResNet 的時間縮放偏移
resnet_act_fn: str = "swish", # ResNet 的啟用函式
resnet_groups: int = 32, # ResNet 的分組數
resnet_pre_norm: bool = True, # 是否在 ResNet 前進行歸一化
output_scale_factor: float = 1.0, # 輸出縮放因子
add_downsample: bool = True, # 是否新增下采樣層
downsample_padding: int = 1, # 下采樣時的填充
):
# 呼叫父類的初始化方法
super().__init__()
# 初始化一個空列表,用於存放 ResNet 層
resnets = []
# 迴圈建立指定數量的 ResNet 層
for i in range(num_layers):
# 第一層使用輸入通道,之後的層使用輸出通道
in_channels = in_channels if i == 0 else out_channels
# 將 ResNet 層新增到列表中
resnets.append(
ResnetBlockFlat(
in_channels=in_channels, # 當前層的輸入通道數
out_channels=out_channels, # 當前層的輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 分組數
dropout=dropout, # dropout 機率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入歸一化方式
non_linearity=resnet_act_fn, # 啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 是否前歸一化
)
)
# 將 ResNet 層列表轉為 nn.ModuleList 以便於管理
self.resnets = nn.ModuleList(resnets)
# 根據引數決定是否新增下采樣層
if add_downsample:
self.downsamplers = nn.ModuleList(
[
LinearMultiDim(
out_channels, # 輸入通道數
use_conv=True, # 使用卷積
out_channels=out_channels, # 輸出通道數
padding=downsample_padding, # 填充
name="op" # 下采樣層名稱
)
]
)
else:
# 如果不新增下采樣層,設定為 None
self.downsamplers = None
# 初始化梯度檢查點為 False
self.gradient_checkpointing = False
# 定義前向傳播方法
def forward(
self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None # 輸入的隱藏狀態和可選的時間嵌入
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
# 初始化輸出狀態為一個空元組
output_states = ()
# 遍歷所有的 ResNet 層
for resnet in self.resnets:
# 如果在訓練模式且開啟了梯度檢查點
if self.training and self.gradient_checkpointing:
# 定義一個建立自定義前向傳播的方法
def create_custom_forward(module):
# 定義自定義前向傳播函式
def custom_forward(*inputs):
return module(*inputs) # 呼叫模組進行前向傳播
return custom_forward
# 檢查 PyTorch 版本,使用不同的呼叫方式
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False # 進行梯度檢查點的前向傳播
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb # 進行梯度檢查點的前向傳播
)
else:
# 正常呼叫 ResNet 層進行前向傳播
hidden_states = resnet(hidden_states, temb)
# 將當前隱藏狀態新增到輸出狀態中
output_states = output_states + (hidden_states,)
# 如果存在下采樣層
if self.downsamplers is not None:
# 遍歷所有下采樣層
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states) # 對隱藏狀態進行下采樣
# 將下采樣後的隱藏狀態新增到輸出狀態中
output_states = output_states + (hidden_states,)
# 返回最終的隱藏狀態和所有輸出狀態
return hidden_states, output_states
# 定義一個名為 CrossAttnDownBlockFlat 的類,繼承自 nn.Module
class CrossAttnDownBlockFlat(nn.Module):
# 初始化方法,定義類的屬性
def __init__(
# 輸入通道數
self,
in_channels: int,
# 輸出通道數
out_channels: int,
# 時間嵌入通道數
temb_channels: int,
# dropout 機率,預設為 0.0
dropout: float = 0.0,
# 層數,預設為 1
num_layers: int = 1,
# 每個塊的變換器層數,可以是整數或整數元組,預設為 1
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# ResNet 的 epsilon 值,預設為 1e-6
resnet_eps: float = 1e-6,
# ResNet 的時間尺度偏移設定,預設為 "default"
resnet_time_scale_shift: str = "default",
# ResNet 的啟用函式,預設為 "swish"
resnet_act_fn: str = "swish",
# ResNet 的組數,預設為 32
resnet_groups: int = 32,
# 是否使用預歸一化,預設為 True
resnet_pre_norm: bool = True,
# 注意力頭的數量,預設為 1
num_attention_heads: int = 1,
# 交叉注意力的維度,預設為 1280
cross_attention_dim: int = 1280,
# 輸出縮放因子,預設為 1.0
output_scale_factor: float = 1.0,
# 下采樣的填充大小,預設為 1
downsample_padding: int = 1,
# 是否新增下采樣層,預設為 True
add_downsample: bool = True,
# 是否使用雙重交叉注意力,預設為 False
dual_cross_attention: bool = False,
# 是否使用線性投影,預設為 False
use_linear_projection: bool = False,
# 是否只使用交叉注意力,預設為 False
only_cross_attention: bool = False,
# 是否上溯注意力,預設為 False
upcast_attention: bool = False,
# 注意力型別,預設為 "default"
attention_type: str = "default",
):
# 呼叫父類的建構函式以初始化基類
super().__init__()
# 初始化儲存 ResNet 塊的列表
resnets = []
# 初始化儲存注意力模型的列表
attentions = []
# 設定是否使用交叉注意力的標誌
self.has_cross_attention = True
# 設定注意力頭的數量
self.num_attention_heads = num_attention_heads
# 如果 transformer_layers_per_block 是一個整數,則將其轉換為列表形式
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
# 為每一層構建 ResNet 塊和注意力模型
for i in range(num_layers):
# 設定當前層的輸入通道數,第一層使用 in_channels,其他層使用 out_channels
in_channels = in_channels if i == 0 else out_channels
# 向 resnets 列表新增一個 ResNet 塊
resnets.append(
ResnetBlockFlat(
# 設定 ResNet 塊的輸入通道數
in_channels=in_channels,
# 設定 ResNet 塊的輸出通道數
out_channels=out_channels,
# 設定時間嵌入通道數
temb_channels=temb_channels,
# 設定 ResNet 塊的 epsilon 值
eps=resnet_eps,
# 設定 ResNet 塊的組數
groups=resnet_groups,
# 設定 dropout 機率
dropout=dropout,
# 設定時間嵌入的歸一化方法
time_embedding_norm=resnet_time_scale_shift,
# 設定啟用函式
non_linearity=resnet_act_fn,
# 設定輸出縮放因子
output_scale_factor=output_scale_factor,
# 設定是否在前面進行歸一化
pre_norm=resnet_pre_norm,
)
)
# 如果不使用雙交叉注意力
if not dual_cross_attention:
# 向 attentions 列表新增一個 Transformer 2D 模型
attentions.append(
Transformer2DModel(
# 設定注意力頭的數量
num_attention_heads,
# 設定每個注意力頭的輸出通道數
out_channels // num_attention_heads,
# 設定輸入通道數
in_channels=out_channels,
# 設定當前層的 Transformer 層數
num_layers=transformer_layers_per_block[i],
# 設定交叉注意力的維度
cross_attention_dim=cross_attention_dim,
# 設定歸一化的組數
norm_num_groups=resnet_groups,
# 設定是否使用線性投影
use_linear_projection=use_linear_projection,
# 設定是否僅使用交叉注意力
only_cross_attention=only_cross_attention,
# 設定是否提高注意力精度
upcast_attention=upcast_attention,
# 設定注意力型別
attention_type=attention_type,
)
)
else:
# 向 attentions 列表新增一個雙 Transformer 2D 模型
attentions.append(
DualTransformer2DModel(
# 設定注意力頭的數量
num_attention_heads,
# 設定每個注意力頭的輸出通道數
out_channels // num_attention_heads,
# 設定輸入通道數
in_channels=out_channels,
# 固定層數為 1
num_layers=1,
# 設定交叉注意力的維度
cross_attention_dim=cross_attention_dim,
# 設定歸一化的組數
norm_num_groups=resnet_groups,
)
)
# 將注意力模型列表轉換為 PyTorch 的 ModuleList
self.attentions = nn.ModuleList(attentions)
# 將 ResNet 塊列表轉換為 PyTorch 的 ModuleList
self.resnets = nn.ModuleList(resnets)
# 如果需要新增下采樣層
if add_downsample:
# 初始化下采樣層為 ModuleList
self.downsamplers = nn.ModuleList(
[
LinearMultiDim(
# 設定輸出通道數
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
# 如果不新增下采樣層,將其設為 None
self.downsamplers = None
# 初始化梯度檢查點標誌為 False
self.gradient_checkpointing = False
# 定義前向傳播函式,接收隱藏狀態和其他可選引數
def forward(
self,
hidden_states: torch.Tensor, # 當前隱藏狀態的張量
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 可選的編碼器隱藏狀態張量
attention_mask: Optional[torch.Tensor] = None, # 可選的注意力掩碼
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 可選的交叉注意力引數
encoder_attention_mask: Optional[torch.Tensor] = None, # 可選的編碼器注意力掩碼
additional_residuals: Optional[torch.Tensor] = None, # 可選的額外殘差張量
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: # 返回隱藏狀態和輸出狀態元組
output_states = () # 初始化輸出狀態元組
blocks = list(zip(self.resnets, self.attentions)) # 將殘差網路和注意力模組配對成塊
for i, (resnet, attn) in enumerate(blocks): # 遍歷每個塊及其索引
if self.training and self.gradient_checkpointing: # 檢查是否在訓練且啟用梯度檢查點
def create_custom_forward(module, return_dict=None): # 定義自定義前向傳播函式
def custom_forward(*inputs): # 自定義前向傳播邏輯
if return_dict is not None: # 如果提供了返回字典
return module(*inputs, return_dict=return_dict) # 返回帶字典的結果
else:
return module(*inputs) # 否則返回普通結果
return custom_forward # 返回自定義前向傳播函式
# 設定檢查點引數,如果 PyTorch 版本大於等於 1.11.0,則使用非重入模式
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
# 透過檢查點機制計算當前塊的隱藏狀態
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), # 建立自定義前向函式的檢查點
hidden_states, # 輸入當前隱藏狀態
temb, # 輸入時間嵌入
**ckpt_kwargs, # 傳遞檢查點引數
)
# 透過注意力模組處理隱藏狀態並獲取輸出
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states, # 編碼器隱藏狀態
cross_attention_kwargs=cross_attention_kwargs, # 交叉注意力引數
attention_mask=attention_mask, # 注意力掩碼
encoder_attention_mask=encoder_attention_mask, # 編碼器注意力掩碼
return_dict=False, # 不返回字典格式
)[0] # 取出第一個輸出
else: # 如果不啟用梯度檢查
# 直接透過殘差網路處理隱藏狀態
hidden_states = resnet(hidden_states, temb)
# 透過注意力模組處理隱藏狀態並獲取輸出
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states, # 編碼器隱藏狀態
cross_attention_kwargs=cross_attention_kwargs, # 交叉注意力引數
attention_mask=attention_mask, # 注意力掩碼
encoder_attention_mask=encoder_attention_mask, # 編碼器注意力掩碼
return_dict=False, # 不返回字典格式
)[0] # 取出第一個輸出
# 如果是最後一個塊並且提供了額外殘差,則將其新增到隱藏狀態
if i == len(blocks) - 1 and additional_residuals is not None:
hidden_states = hidden_states + additional_residuals # 加上額外殘差
output_states = output_states + (hidden_states,) # 將當前隱藏狀態新增到輸出狀態元組中
if self.downsamplers is not None: # 如果存在下采樣器
for downsampler in self.downsamplers: # 遍歷每個下采樣器
hidden_states = downsampler(hidden_states) # 處理當前隱藏狀態
output_states = output_states + (hidden_states,) # 將當前隱藏狀態新增到輸出狀態元組中
return hidden_states, output_states # 返回最終的隱藏狀態和輸出狀態元組
# 從 diffusers.models.unets.unet_2d_blocks 中複製,替換 UpBlock2D 為 UpBlockFlat,ResnetBlock2D 為 ResnetBlockFlat,Upsample2D 為 LinearMultiDim
class UpBlockFlat(nn.Module):
# 初始化函式,定義輸入輸出通道及其他引數
def __init__(
self,
in_channels: int, # 輸入通道數
prev_output_channel: int, # 前一層輸出通道數
out_channels: int, # 當前層輸出通道數
temb_channels: int, # 時間嵌入通道數
resolution_idx: Optional[int] = None, # 解析度索引
dropout: float = 0.0, # dropout 機率
num_layers: int = 1, # 層數
resnet_eps: float = 1e-6, # ResNet 中的 epsilon 值
resnet_time_scale_shift: str = "default", # 時間尺度偏移設定
resnet_act_fn: str = "swish", # 啟用函式型別
resnet_groups: int = 32, # 分組數
resnet_pre_norm: bool = True, # 是否進行預歸一化
output_scale_factor: float = 1.0, # 輸出縮放因子
add_upsample: bool = True, # 是否新增上取樣
):
# 呼叫父類建構函式
super().__init__()
# 初始化一個空列表儲存 ResNet 塊
resnets = []
# 遍歷層數,構建每一層的 ResNet 塊
for i in range(num_layers):
# 根據層數決定殘差跳躍通道數
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 根據當前層數決定輸入通道數
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 將 ResNet 塊新增到列表中
resnets.append(
ResnetBlockFlat(
in_channels=resnet_in_channels + res_skip_channels, # 輸入通道數
out_channels=out_channels, # 輸出通道數
temb_channels=temb_channels, # 時間嵌入通道數
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 分組數
dropout=dropout, # dropout 機率
time_embedding_norm=resnet_time_scale_shift, # 時間嵌入歸一化
non_linearity=resnet_act_fn, # 啟用函式
output_scale_factor=output_scale_factor, # 輸出縮放因子
pre_norm=resnet_pre_norm, # 預歸一化
)
)
# 將 ResNet 塊列表轉換為模組列表
self.resnets = nn.ModuleList(resnets)
# 如果需要新增上取樣層,則建立上取樣模組
if add_upsample:
self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)])
else:
# 否則設定為 None
self.upsamplers = None
# 初始化梯度檢查點標誌
self.gradient_checkpointing = False
# 設定解析度索引
self.resolution_idx = resolution_idx
# 前向傳播函式
def forward(
self,
hidden_states: torch.Tensor, # 隱藏狀態張量
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 殘差隱藏狀態元組
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
upsample_size: Optional[int] = None, # 可選的上取樣大小
*args, # 可變引數
**kwargs, # 可變關鍵字引數
) -> torch.Tensor: # 定義一個返回 torch.Tensor 型別的函式
# 如果引數列表 args 長度大於 0 或 kwargs 中的 scale 引數不為 None
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 設定棄用訊息,提醒使用者 scale 引數已棄用且將被忽略
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 呼叫 deprecate 函式,記錄 scale 引數的棄用
deprecate("scale", "1.0.0", deprecation_message)
# 檢查 FreeU 是否啟用,取決於 s1, s2, b1 和 b2 的值
is_freeu_enabled = (
getattr(self, "s1", None) # 獲取 self 中的 s1 屬性
and getattr(self, "s2", None) # 獲取 self 中的 s2 屬性
and getattr(self, "b1", None) # 獲取 self 中的 b1 屬性
and getattr(self, "b2", None) # 獲取 self 中的 b2 屬性
)
# 遍歷 self.resnets 中的每個 ResNet 模型
for resnet in self.resnets:
# 彈出 res 隱藏狀態的最後一個元素
res_hidden_states = res_hidden_states_tuple[-1]
# 移除 res 隱藏狀態元組的最後一個元素
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: 僅在前兩個階段進行操作
if is_freeu_enabled:
# 應用 FreeU 操作,返回更新後的 hidden_states 和 res_hidden_states
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx, # 當前解析度索引
hidden_states, # 當前隱藏狀態
res_hidden_states, # 之前的隱藏狀態
s1=self.s1, # s1 引數
s2=self.s2, # s2 引數
b1=self.b1, # b1 引數
b2=self.b2, # b2 引數
)
# 將當前的 hidden_states 和 res_hidden_states 在維度 1 上拼接
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 如果處於訓練模式並且開啟了梯度檢查點
if self.training and self.gradient_checkpointing:
# 定義一個建立自定義前向函式的函式
def create_custom_forward(module):
# 定義自定義前向函式,接收輸入並呼叫模組
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 如果 PyTorch 版本大於等於 1.11.0
if is_torch_version(">=", "1.11.0"):
# 使用梯度檢查點來計算 hidden_states
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), # 使用自定義前向函式
hidden_states, # 當前隱藏狀態
temb, # 傳入的額外輸入
use_reentrant=False # 禁用重入檢查
)
else:
# 對於早期版本,使用梯度檢查點計算 hidden_states
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), # 使用自定義前向函式
hidden_states, # 當前隱藏狀態
temb # 傳入的額外輸入
)
else:
# 在非訓練模式下直接呼叫 resnet 處理 hidden_states
hidden_states = resnet(hidden_states, temb)
# 如果存在上取樣器
if self.upsamplers is not None:
# 遍歷所有上取樣器
for upsampler in self.upsamplers:
# 使用上取樣器對 hidden_states 進行處理,指定上取樣尺寸
hidden_states = upsampler(hidden_states, upsample_size)
# 返回處理後的 hidden_states
return hidden_states
# 從 diffusers.models.unets.unet_2d_blocks 中複製的程式碼,修改了類名和一些元件
class CrossAttnUpBlockFlat(nn.Module):
# 初始化方法,定義類的基本屬性和引數
def __init__(
# 輸入通道數
in_channels: int,
# 輸出通道數
out_channels: int,
# 上一層輸出的通道數
prev_output_channel: int,
# 額外的時間嵌入通道數
temb_channels: int,
# 可選的解析度索引
resolution_idx: Optional[int] = None,
# dropout 機率
dropout: float = 0.0,
# 層數
num_layers: int = 1,
# 每個塊的變換器層數,可以是單個整數或元組
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# ResNet 的 epsilon 值
resnet_eps: float = 1e-6,
# ResNet 時間尺度偏移的型別
resnet_time_scale_shift: str = "default",
# ResNet 啟用函式的型別
resnet_act_fn: str = "swish",
# ResNet 的組數
resnet_groups: int = 32,
# 是否在 ResNet 中使用預歸一化
resnet_pre_norm: bool = True,
# 注意力頭的數量
num_attention_heads: int = 1,
# 交叉注意力的維度
cross_attention_dim: int = 1280,
# 輸出縮放因子
output_scale_factor: float = 1.0,
# 是否新增上取樣步驟
add_upsample: bool = True,
# 是否使用雙交叉注意力
dual_cross_attention: bool = False,
# 是否使用線性投影
use_linear_projection: bool = False,
# 是否僅使用交叉注意力
only_cross_attention: bool = False,
# 是否上溯注意力
upcast_attention: bool = False,
# 注意力型別
attention_type: str = "default",
# 定義建構函式的結束部分
):
# 呼叫父類的建構函式
super().__init__()
# 初始化一個空列表用於儲存殘差網路塊
resnets = []
# 初始化一個空列表用於儲存注意力模型
attentions = []
# 設定是否使用交叉注意力標誌為真
self.has_cross_attention = True
# 設定注意力頭的數量
self.num_attention_heads = num_attention_heads
# 如果 transformer_layers_per_block 是整數,則將其轉換為相同長度的列表
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
# 遍歷每一層以構建殘差網路和注意力模型
for i in range(num_layers):
# 設定殘差跳過通道數,最後一層使用輸入通道,否則使用輸出通道
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 設定殘差網路輸入通道數,第一層使用前一層輸出通道,否則使用當前輸出通道
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 新增一個殘差網路塊到 resnets 列表中
resnets.append(
ResnetBlockFlat(
# 設定殘差網路輸入通道數
in_channels=resnet_in_channels + res_skip_channels,
# 設定殘差網路輸出通道數
out_channels=out_channels,
# 設定時間嵌入通道數
temb_channels=temb_channels,
# 設定殘差網路的 epsilon 值
eps=resnet_eps,
# 設定殘差網路的組數
groups=resnet_groups,
# 設定丟棄率
dropout=dropout,
# 設定時間嵌入的歸一化方法
time_embedding_norm=resnet_time_scale_shift,
# 設定非線性啟用函式
non_linearity=resnet_act_fn,
# 設定輸出縮放因子
output_scale_factor=output_scale_factor,
# 設定是否進行預歸一化
pre_norm=resnet_pre_norm,
)
)
# 如果不使用雙重交叉注意力
if not dual_cross_attention:
# 新增一個普通的 Transformer2DModel 到 attentions 列表中
attentions.append(
Transformer2DModel(
# 設定注意力頭的數量
num_attention_heads,
# 設定每個注意力頭的輸出通道數
out_channels // num_attention_heads,
# 設定輸入通道數
in_channels=out_channels,
# 設定層數
num_layers=transformer_layers_per_block[i],
# 設定交叉注意力維度
cross_attention_dim=cross_attention_dim,
# 設定歸一化組數
norm_num_groups=resnet_groups,
# 設定是否使用線性投影
use_linear_projection=use_linear_projection,
# 設定是否僅使用交叉注意力
only_cross_attention=only_cross_attention,
# 設定是否上溯注意力
upcast_attention=upcast_attention,
# 設定注意力型別
attention_type=attention_type,
)
)
else:
# 新增一個雙重 Transformer2DModel 到 attentions 列表中
attentions.append(
DualTransformer2DModel(
# 設定注意力頭的數量
num_attention_heads,
# 設定每個注意力頭的輸出通道數
out_channels // num_attention_heads,
# 設定輸入通道數
in_channels=out_channels,
# 設定層數為 1
num_layers=1,
# 設定交叉注意力維度
cross_attention_dim=cross_attention_dim,
# 設定歸一化組數
norm_num_groups=resnet_groups,
)
)
# 將注意力模型列表轉換為 nn.ModuleList
self.attentions = nn.ModuleList(attentions)
# 將殘差網路塊列表轉換為 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 如果需要新增上取樣層
if add_upsample:
# 將上取樣器新增到 nn.ModuleList 中
self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)])
else:
# 否則將上取樣器設定為 None
self.upsamplers = None
# 設定梯度檢查點標誌為假
self.gradient_checkpointing = False
# 設定解析度索引
self.resolution_idx = resolution_idx
# 定義前向傳播函式,接收多個輸入引數
def forward(
self,
# 隱藏狀態,型別為 PyTorch 的張量
hidden_states: torch.Tensor,
# 包含殘差隱藏狀態的元組,元素型別為 PyTorch 張量
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
# 可選的時間嵌入,型別為 PyTorch 的張量
temb: Optional[torch.Tensor] = None,
# 可選的編碼器隱藏狀態,型別為 PyTorch 的張量
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可選的交叉注意力引數,型別為字典,包含任意鍵值對
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可選的上取樣大小,型別為整數
upsample_size: Optional[int] = None,
# 可選的注意力掩碼,型別為 PyTorch 的張量
attention_mask: Optional[torch.Tensor] = None,
# 可選的編碼器注意力掩碼,型別為 PyTorch 的張量
encoder_attention_mask: Optional[torch.Tensor] = None,
# 從 diffusers.models.unets.unet_2d_blocks 中複製的 UNetMidBlock2D 程式碼,替換了 UNetMidBlock2D 為 UNetMidBlockFlat,ResnetBlock2D 為 ResnetBlockFlat
class UNetMidBlockFlat(nn.Module):
"""
2D UNet 中間塊 [`UNetMidBlockFlat`],包含多個殘差塊和可選的注意力塊。
引數:
in_channels (`int`): 輸入通道的數量。
temb_channels (`int`): 時間嵌入通道的數量。
dropout (`float`, *可選*, 預設值為 0.0): dropout 比率。
num_layers (`int`, *可選*, 預設值為 1): 殘差塊的數量。
resnet_eps (`float`, *可選*, 預設值為 1e-6): resnet 塊的 epsilon 值。
resnet_time_scale_shift (`str`, *可選*, 預設值為 `default`):
應用於時間嵌入的歸一化型別。這可以幫助提高模型在長範圍時間依賴任務上的效能。
resnet_act_fn (`str`, *可選*, 預設值為 `swish`): resnet 塊的啟用函式。
resnet_groups (`int`, *可選*, 預設值為 32):
resnet 塊的分組歸一化層使用的組數。
attn_groups (`Optional[int]`, *可選*, 預設值為 None): 注意力塊的組數。
resnet_pre_norm (`bool`, *可選*, 預設值為 `True`):
是否在 resnet 塊中使用預歸一化。
add_attention (`bool`, *可選*, 預設值為 `True`): 是否新增註意力塊。
attention_head_dim (`int`, *可選*, 預設值為 1):
單個注意力頭的維度。注意力頭的數量基於此值和輸入通道的數量確定。
output_scale_factor (`float`, *可選*, 預設值為 1.0): 輸出縮放因子。
返回:
`torch.Tensor`: 最後一個殘差塊的輸出,是一個形狀為 `(batch_size, in_channels,
height, width)` 的張量。
"""
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", # 預設,空間
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
attn_groups: Optional[int] = None,
resnet_pre_norm: bool = True,
add_attention: bool = True,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
):
# 初始化 UNetMidBlockFlat 類,設定各引數的預設值
super().__init__() # 呼叫父類的初始化方法
self.in_channels = in_channels # 儲存輸入通道數
self.temb_channels = temb_channels # 儲存時間嵌入通道數
self.dropout = dropout # 儲存 dropout 比率
self.num_layers = num_layers # 儲存殘差塊的數量
self.resnet_eps = resnet_eps # 儲存 resnet 塊的 epsilon 值
self.resnet_time_scale_shift = resnet_time_scale_shift # 儲存時間縮放偏移型別
self.resnet_act_fn = resnet_act_fn # 儲存啟用函式型別
self.resnet_groups = resnet_groups # 儲存分組數
self.attn_groups = attn_groups # 儲存注意力組數
self.resnet_pre_norm = resnet_pre_norm # 儲存是否使用預歸一化
self.add_attention = add_attention # 儲存是否新增註意力塊
self.attention_head_dim = attention_head_dim # 儲存注意力頭的維度
self.output_scale_factor = output_scale_factor # 儲存輸出縮放因子
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
# 定義前向傳播方法,接受隱藏狀態和可選的時間嵌入
hidden_states = self.resnets[0](hidden_states, temb) # 透過第一個殘差塊處理隱藏狀態
for attn, resnet in zip(self.attentions, self.resnets[1:]): # 遍歷後續的注意力塊和殘差塊
if attn is not None: # 如果注意力塊存在
hidden_states = attn(hidden_states, temb=temb) # 透過注意力塊處理隱藏狀態
hidden_states = resnet(hidden_states, temb) # 透過殘差塊處理隱藏狀態
return hidden_states # 返回處理後的隱藏狀態
# 從 diffusers.models.unets.unet_2d_blocks 中複製,替換 UNetMidBlock2DCrossAttn 為 UNetMidBlockFlatCrossAttn,ResnetBlock2D 為 ResnetBlockFlat
class UNetMidBlockFlatCrossAttn(nn.Module):
# 初始化方法,定義模型引數
def __init__(
self,
# 輸入通道數
in_channels: int,
# 時間嵌入通道數
temb_channels: int,
# 輸出通道數,預設為 None
out_channels: Optional[int] = None,
# Dropout 機率,預設為 0.0
dropout: float = 0.0,
# 層數,預設為 1
num_layers: int = 1,
# 每個塊的 Transformer 層數,預設為 1
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# ResNet 的 epsilon 值,預設為 1e-6
resnet_eps: float = 1e-6,
# ResNet 的時間尺度偏移,預設為 "default"
resnet_time_scale_shift: str = "default",
# ResNet 的啟用函式型別,預設為 "swish"
resnet_act_fn: str = "swish",
# ResNet 的分組數,預設為 32
resnet_groups: int = 32,
# 輸出的 ResNet 分組數,預設為 None
resnet_groups_out: Optional[int] = None,
# 是否使用預歸一化,預設為 True
resnet_pre_norm: bool = True,
# 注意力頭數,預設為 1
num_attention_heads: int = 1,
# 輸出縮放因子,預設為 1.0
output_scale_factor: float = 1.0,
# 交叉注意力維度,預設為 1280
cross_attention_dim: int = 1280,
# 是否使用雙交叉注意力,預設為 False
dual_cross_attention: bool = False,
# 是否使用線性投影,預設為 False
use_linear_projection: bool = False,
# 是否上升注意力計算精度,預設為 False
upcast_attention: bool = False,
# 注意力型別,預設為 "default"
attention_type: str = "default",
# 前向傳播方法,定義模型的前向計算邏輯
def forward(
self,
# 隱藏狀態張量
hidden_states: torch.Tensor,
# 可選的時間嵌入張量,預設為 None
temb: Optional[torch.Tensor] = None,
# 可選的編碼器隱藏狀態張量,預設為 None
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可選的注意力掩碼,預設為 None
attention_mask: Optional[torch.Tensor] = None,
# 可選的交叉注意力引數字典,預設為 None
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可選的編碼器注意力掩碼,預設為 None
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor: # 定義函式的返回型別為 torch.Tensor
if cross_attention_kwargs is not None: # 檢查 cross_attention_kwargs 是否為 None
if cross_attention_kwargs.get("scale", None) is not None: # 檢查 scale 是否在 cross_attention_kwargs 中
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") # 發出警告,提示 scale 引數已過時
hidden_states = self.resnets[0](hidden_states, temb) # 使用第一個殘差網路處理隱藏狀態和時間嵌入
for attn, resnet in zip(self.attentions, self.resnets[1:]): # 遍歷注意力層和後續的殘差網路
if self.training and self.gradient_checkpointing: # 檢查是否在訓練模式且開啟了梯度檢查點
def create_custom_forward(module, return_dict=None): # 定義一個函式以建立自定義前向傳播
def custom_forward(*inputs): # 定義實際的前向傳播函式
if return_dict is not None: # 檢查是否需要返回字典形式的輸出
return module(*inputs, return_dict=return_dict) # 呼叫模組並返回字典
else: # 如果不需要字典形式的輸出
return module(*inputs) # 直接呼叫模組並返回結果
return custom_forward # 返回自定義前向傳播函式
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} # 根據 PyTorch 版本設定檢查點引數
hidden_states = attn( # 使用注意力層處理隱藏狀態
hidden_states, # 輸入隱藏狀態
encoder_hidden_states=encoder_hidden_states, # 輸入編碼器的隱藏狀態
cross_attention_kwargs=cross_attention_kwargs, # 傳遞交叉注意力引數
attention_mask=attention_mask, # 傳遞注意力掩碼
encoder_attention_mask=encoder_attention_mask, # 傳遞編碼器注意力掩碼
return_dict=False, # 不返回字典
)[0] # 取出輸出的第一個元素
hidden_states = torch.utils.checkpoint.checkpoint( # 使用檢查點儲存記憶體
create_custom_forward(resnet), # 建立自定義前向傳播
hidden_states, # 輸入隱藏狀態
temb, # 輸入時間嵌入
**ckpt_kwargs, # 解包檢查點引數
)
else: # 如果不在訓練模式或不使用梯度檢查點
hidden_states = attn( # 使用注意力層處理隱藏狀態
hidden_states, # 輸入隱藏狀態
encoder_hidden_states=encoder_hidden_states, # 輸入編碼器的隱藏狀態
cross_attention_kwargs=cross_attention_kwargs, # 傳遞交叉注意力引數
attention_mask=attention_mask, # 傳遞注意力掩碼
encoder_attention_mask=encoder_attention_mask, # 傳遞編碼器注意力掩碼
return_dict=False, # 不返回字典
)[0] # 取出輸出的第一個元素
hidden_states = resnet(hidden_states, temb) # 使用殘差網路處理隱藏狀態和時間嵌入
return hidden_states # 返回處理後的隱藏狀態
# 從 diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn 複製,替換 UNetMidBlock2DSimpleCrossAttn 為 UNetMidBlockFlatSimpleCrossAttn,ResnetBlock2D 為 ResnetBlockFlat
class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
# 初始化方法,設定各層的輸入輸出引數
def __init__(
# 輸入通道數
in_channels: int,
# 條件嵌入通道數
temb_channels: int,
# Dropout 機率
dropout: float = 0.0,
# 網路層數
num_layers: int = 1,
# ResNet 的 epsilon 值
resnet_eps: float = 1e-6,
# ResNet 的時間縮放偏移方式
resnet_time_scale_shift: str = "default",
# ResNet 啟用函式型別
resnet_act_fn: str = "swish",
# ResNet 中組的數量
resnet_groups: int = 32,
# 是否使用 ResNet 前歸一化
resnet_pre_norm: bool = True,
# 注意力頭的維度
attention_head_dim: int = 1,
# 輸出縮放因子
output_scale_factor: float = 1.0,
# 交叉注意力的維度
cross_attention_dim: int = 1280,
# 是否跳過時間啟用
skip_time_act: bool = False,
# 是否僅使用交叉注意力
only_cross_attention: bool = False,
# 交叉注意力的歸一化方式
cross_attention_norm: Optional[str] = None,
):
# 呼叫父類的初始化方法
super().__init__()
# 設定是否使用交叉注意力機制
self.has_cross_attention = True
# 設定注意力頭的維度
self.attention_head_dim = attention_head_dim
# 確定 ResNet 的組數,若未提供則使用預設值
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# 計算頭的數量
self.num_heads = in_channels // self.attention_head_dim
# 確保至少有一個 ResNet 塊
resnets = [
# 建立一個 ResNet 塊
ResnetBlockFlat(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
]
# 初始化注意力列表
attentions = []
# 根據層數建立對應的注意力機制
for _ in range(num_layers):
# 根據是否支援縮放點積注意力選擇處理器
processor = (
AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
)
# 新增註意力機制到列表
attentions.append(
Attention(
query_dim=in_channels,
cross_attention_dim=in_channels,
heads=self.num_heads,
dim_head=self.attention_head_dim,
added_kv_proj_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
bias=True,
upcast_softmax=True,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
processor=processor,
)
)
# 新增 ResNet 塊到列表
resnets.append(
ResnetBlockFlat(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
)
# 將注意力層存入模組列表
self.attentions = nn.ModuleList(attentions)
# 將 ResNet 塊存入模組列表
self.resnets = nn.ModuleList(resnets)
def forward(
# 定義前向傳播的方法
self,
hidden_states: torch.Tensor,
# 可選的時間嵌入張量
temb: Optional[torch.Tensor] = None,
# 可選的編碼器隱藏狀態張量
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可選的注意力掩碼張量
attention_mask: Optional[torch.Tensor] = None,
# 可選的交叉注意力引數字典
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可選的編碼器注意力掩碼張量
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# 如果傳入的 cross_attention_kwargs 為 None,則初始化為空字典
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 檢查 cross_attention_kwargs 中是否有 'scale',如果有則發出警告,說明該引數已棄用
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 如果 attention_mask 為 None
if attention_mask is None:
# 如果 encoder_hidden_states 被定義:表示我們在進行交叉注意力,因此應該使用交叉注意力掩碼
mask = None if encoder_hidden_states is None else encoder_attention_mask
else:
# 當 attention_mask 被定義時:我們不檢查 encoder_attention_mask
# 這是為了與 UnCLIP 相容,UnCLIP 使用 'attention_mask' 引數作為交叉注意力掩碼
# TODO: UnCLIP 應透過 encoder_attention_mask 引數而不是 attention_mask 引數來表達交叉注意力掩碼
# 然後我們可以簡化整個 if/else 塊為:
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
mask = attention_mask
# 使用第一個殘差網路處理隱藏狀態和時間嵌入
hidden_states = self.resnets[0](hidden_states, temb)
# 遍歷所有注意力層和對應的殘差網路
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# 使用注意力層處理隱藏狀態
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states, # 傳遞編碼器隱藏狀態
attention_mask=mask, # 傳遞掩碼
**cross_attention_kwargs, # 傳遞交叉注意力引數
)
# 使用殘差網路處理隱藏狀態
hidden_states = resnet(hidden_states, temb)
# 返回最終的隱藏狀態
return hidden_states