diffusers 原始碼解析(十五)
.\diffusers\models\unets\unet_3d_condition.py
# 版權宣告,宣告此程式碼的版權資訊和所有權
# Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
# 版權宣告,宣告此程式碼的版權資訊和所有權
# Copyright 2024 The ModelScope Team.
#
# 許可宣告,宣告本程式碼使用的 Apache 許可證 2.0 版本
# Licensed under the Apache License, Version 2.0 (the "License");
# 使用此檔案前需遵守許可證規定
# you may not use this file except in compliance with the License.
# 可在以下網址獲取許可證副本
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 免責宣告,說明軟體在許可下按 "原樣" 提供,不附加任何明示或暗示的保證
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 許可證中規定的許可權和限制說明
# See the License for the specific language governing permissions and
# limitations under the License.
# 從 dataclasses 模組匯入 dataclass 裝飾器
from dataclasses import dataclass
# 從 typing 模組匯入所需的型別提示
from typing import Any, Dict, List, Optional, Tuple, Union
# 匯入 PyTorch 庫
import torch
# 匯入 PyTorch 神經網路模組
import torch.nn as nn
# 匯入 PyTorch 的檢查點工具
import torch.utils.checkpoint
# 匯入配置相關的工具類和函式
from ...configuration_utils import ConfigMixin, register_to_config
# 匯入 UNet2D 條件載入器混合類
from ...loaders import UNet2DConditionLoadersMixin
# 匯入基本輸出類和日誌工具
from ...utils import BaseOutput, logging
# 匯入啟用函式獲取工具
from ..activations import get_activation
# 匯入各種注意力處理器相關元件
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS, # 匯入新增鍵值對注意力處理器
CROSS_ATTENTION_PROCESSORS, # 匯入交叉注意力處理器
Attention, # 匯入注意力類
AttentionProcessor, # 匯入注意力處理器基類
AttnAddedKVProcessor, # 匯入新增鍵值對的注意力處理器
AttnProcessor, # 匯入普通注意力處理器
FusedAttnProcessor2_0, # 匯入融合注意力處理器
)
# 匯入時間步嵌入和時間步類
from ..embeddings import TimestepEmbedding, Timesteps
# 匯入模型混合類
from ..modeling_utils import ModelMixin
# 匯入時間變換器模型
from ..transformers.transformer_temporal import TransformerTemporalModel
# 匯入 3D UNet 相關的塊
from .unet_3d_blocks import (
CrossAttnDownBlock3D, # 匯入交叉注意力下采樣塊
CrossAttnUpBlock3D, # 匯入交叉注意力上取樣塊
DownBlock3D, # 匯入下采樣塊
UNetMidBlock3DCrossAttn, # 匯入 UNet 中間交叉注意力塊
UpBlock3D, # 匯入上取樣塊
get_down_block, # 匯入獲取下采樣塊的函式
get_up_block, # 匯入獲取上取樣塊的函式
)
# 建立日誌記錄器,使用當前模組的名稱
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# 定義 UNet3DConditionOutput 資料類,繼承自 BaseOutput
@dataclass
class UNet3DConditionOutput(BaseOutput):
"""
[`UNet3DConditionModel`] 的輸出類。
引數:
sample (`torch.Tensor` 的形狀為 `(batch_size, num_channels, num_frames, height, width)`):
基於 `encoder_hidden_states` 輸入的隱藏狀態輸出。模型最後一層的輸出。
"""
sample: torch.Tensor # 定義樣本輸出,型別為 PyTorch 張量
# 定義 UNet3DConditionModel 類,繼承自多個混合類
class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
條件 3D UNet 模型,接受噪聲樣本、條件狀態和時間步,並返回形狀為樣本的輸出。
此模型繼承自 [`ModelMixin`]。有關其通用方法的文件,請參閱超類文件(如下載或儲存)。
# 引數說明部分
Parameters:
# 輸入/輸出樣本的高度和寬度,型別可以為整數或元組,預設為 None
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
# 輸入樣本的通道數,預設為 4
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
# 輸出的通道數,預設為 4
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
# 使用的下采樣塊型別的元組,預設為指定的四種塊
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
The tuple of downsample blocks to use.
# 使用的上取樣塊型別的元組,預設為指定的四種塊
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
The tuple of upsample blocks to use.
# 每個塊的輸出通道數的元組,預設為 (320, 640, 1280, 1280)
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
# 每個塊的層數,預設為 2
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
# 下采樣卷積使用的填充,預設為 1
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
# 中間塊使用的縮放因子,預設為 1.0
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
# 使用的啟用函式,預設為 "silu"
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
# 用於歸一化的組數,預設為 32;如果為 None,則跳過歸一化和啟用層
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, normalization and activation layers is skipped in post-processing.
# 歸一化使用的 epsilon 值,預設為 1e-5
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
# 交叉注意力特徵的維度,預設為 1024
cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
# 注意力頭的維度,預設為 64
attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
# 注意力頭的數量,型別為整數,預設為 None
num_attention_heads (`int`, *optional*): The number of attention heads.
# 時間條件投影層的維度,預設為 None
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
The dimension of `cond_proj` layer in the timestep embedding.
"""
# 是否支援梯度檢查點,預設為 False
_supports_gradient_checkpointing = False
# 將此類註冊到配置中
@register_to_config
# 初始化方法,用於建立類的例項
def __init__(
# 樣本大小,預設為 None
self,
sample_size: Optional[int] = None,
# 輸入通道數量,預設為 4
in_channels: int = 4,
# 輸出通道數量,預設為 4
out_channels: int = 4,
# 下采樣塊型別的元組,定義模型的下采樣結構
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
),
# 上取樣塊型別的元組,定義模型的上取樣結構
up_block_types: Tuple[str, ...] = (
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
),
# 每個塊的輸出通道數量,定義模型每個層的通道設定
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
# 每個塊的層數,預設為 2
layers_per_block: int = 2,
# 下采樣時的填充大小,預設為 1
downsample_padding: int = 1,
# 中間塊的縮放因子,預設為 1
mid_block_scale_factor: float = 1,
# 啟用函式型別,預設為 "silu"
act_fn: str = "silu",
# 歸一化組的數量,預設為 32
norm_num_groups: Optional[int] = 32,
# 歸一化的 epsilon 值,預設為 1e-5
norm_eps: float = 1e-5,
# 跨注意力維度,預設為 1024
cross_attention_dim: int = 1024,
# 注意力頭的維度,可以是單一整數或整數元組,預設為 64
attention_head_dim: Union[int, Tuple[int]] = 64,
# 注意力頭的數量,可選引數
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
# 時間條件投影維度,可選引數
time_cond_proj_dim: Optional[int] = None,
@property
# 從 UNet2DConditionModel 複製的屬性,獲取注意力處理器
# 返回所有注意力處理器的字典,以權重名稱為索引
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# 初始化處理器字典
processors = {}
# 遞迴新增處理器的函式
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# 如果模組有獲取處理器的方法,新增到處理器字典中
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
# 遍歷子模組,遞迴呼叫該函式
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
# 返回處理器字典
return processors
# 遍歷當前類的子模組,呼叫遞迴新增處理器的函式
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
# 返回所有處理器
return processors
# 從 UNet2DConditionModel 複製的設定注意力切片的方法
# 從 UNet2DConditionModel 複製的設定注意力處理器的方法
# 定義一個方法用於設定注意力處理器
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
設定用於計算注意力的處理器。
引數:
processor(`dict` of `AttentionProcessor` 或僅 `AttentionProcessor`):
例項化的處理器類或一個處理器類的字典,將作為所有 `Attention` 層的處理器。
如果 `processor` 是一個字典,鍵需要定義相應的交叉注意力處理器的路徑。
在設定可訓練的注意力處理器時,強烈推薦這樣做。
"""
# 獲取當前注意力處理器的數量
count = len(self.attn_processors.keys())
# 如果傳入的處理器是字典,且數量不等於注意力層數量,丟擲錯誤
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"傳入了一個處理器字典,但處理器的數量 {len(processor)} 與"
f" 注意力層的數量 {count} 不匹配。請確保傳入 {count} 個處理器類。"
)
# 定義一個遞迴函式來設定每個模組的處理器
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 如果模組有設定處理器的方法
if hasattr(module, "set_processor"):
# 如果處理器不是字典,直接設定處理器
if not isinstance(processor, dict):
module.set_processor(processor)
else:
# 從字典中獲取相應的處理器並設定
module.set_processor(processor.pop(f"{name}.processor"))
# 遍歷子模組並遞迴呼叫處理器設定
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍歷當前物件的所有子模組,並呼叫遞迴設定函式
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# 定義一個方法來啟用前饋層的分塊處理
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
設定注意力處理器以使用 [前饋分塊](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers)。
引數:
chunk_size (`int`, *可選*):
前饋層的分塊大小。如果未指定,將對維度為`dim`的每個張量單獨執行前饋層。
dim (`int`, *可選*, 預設為`0`):
應對哪個維度進行前饋計算的分塊。可以選擇 dim=0(批次)或 dim=1(序列長度)。
"""
# 確保 dim 引數為 0 或 1
if dim not in [0, 1]:
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
# 預設的分塊大小為 1
chunk_size = chunk_size or 1
# 定義一個遞迴函式來設定每個模組的分塊前饋處理
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
# 如果模組具有設定分塊前饋的屬性,則設定它
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
# 遍歷子模組,遞迴呼叫函式
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
# 遍歷當前例項的子模組,應用遞迴函式
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
# 定義一個方法來禁用前饋層的分塊處理
def disable_forward_chunking(self):
# 定義一個遞迴函式來禁用分塊前饋處理
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
# 如果模組具有設定分塊前饋的屬性,則設定為 None
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
# 遍歷子模組,遞迴呼叫函式
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
# 遍歷當前例項的子模組,應用遞迴函式,禁用分塊
for module in self.children():
fn_recursive_feed_forward(module, None, 0)
# 從 diffusers.models.unets.unet_2d_condition 中複製的方法,設定預設注意力處理器
def set_default_attn_processor(self):
"""
禁用自定義注意力處理器並設定預設注意力實現。
"""
# 檢查所有注意力處理器是否為新增的 KV 處理器
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor() # 設定為新增的 KV 處理器
# 檢查所有注意力處理器是否為交叉注意力處理器
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor() # 設定為普通注意力處理器
else:
# 丟擲異常,若注意力處理器型別不符合預期
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
# 設定選定的注意力處理器
self.set_attn_processor(processor)
# 定義一個私有方法來設定模組的梯度檢查點
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
# 檢查模組是否屬於特定型別
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
module.gradient_checkpointing = value # 設定梯度檢查點值
# 從 diffusers.models.unets.unet_2d_condition 中複製的方法,啟用自由度
# 啟用 FreeU 機制,引數為兩個縮放因子和兩個增強因子的值
def enable_freeu(self, s1, s2, b1, b2):
r"""從 https://arxiv.org/abs/2309.11497 啟用 FreeU 機制。
縮放因子的字尾表示它們應用的階段塊。
請參考 [官方倉庫](https://github.com/ChenyangSi/FreeU) 以獲取在不同管道(如 Stable Diffusion v1、v2 和 Stable Diffusion XL)中已知效果良好的值組合。
Args:
s1 (`float`):
第1階段的縮放因子,用於減弱跳躍特徵的貢獻,以減輕增強去噪過程中的“過平滑效應”。
s2 (`float`):
第2階段的縮放因子,用於減弱跳躍特徵的貢獻,以減輕增強去噪過程中的“過平滑效應”。
b1 (`float`): 第1階段的縮放因子,用於增強骨幹特徵的貢獻。
b2 (`float`): 第2階段的縮放因子,用於增強骨幹特徵的貢獻。
"""
# 遍歷上取樣塊,給每個塊設定縮放因子和增強因子
for i, upsample_block in enumerate(self.up_blocks):
# 設定第1階段的縮放因子
setattr(upsample_block, "s1", s1)
# 設定第2階段的縮放因子
setattr(upsample_block, "s2", s2)
# 設定第1階段的增強因子
setattr(upsample_block, "b1", b1)
# 設定第2階段的增強因子
setattr(upsample_block, "b2", b2)
# 從 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu 複製
# 禁用 FreeU 機制
def disable_freeu(self):
"""禁用 FreeU 機制。"""
# 定義 FreeU 機制的關鍵屬性
freeu_keys = {"s1", "s2", "b1", "b2"}
# 遍歷上取樣塊
for i, upsample_block in enumerate(self.up_blocks):
# 遍歷 FreeU 關鍵屬性
for k in freeu_keys:
# 如果上取樣塊有該屬性,或者該屬性值不為 None
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
# 將屬性值設定為 None,禁用 FreeU
setattr(upsample_block, k, None)
# 從 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections 複製
# 啟用融合的 QKV 投影
def fuse_qkv_projections(self):
"""
啟用融合的 QKV 投影。對於自注意力模組,所有投影矩陣(即查詢、鍵、值)都被融合。對於交叉注意力模組,鍵和值投影矩陣被融合。
<Tip warning={true}>
此 API 是 🧪 實驗性的。
</Tip>
"""
# 儲存原始的注意力處理器
self.original_attn_processors = None
# 遍歷注意力處理器
for _, attn_processor in self.attn_processors.items():
# 如果注意力處理器的類名中包含“Added”
if "Added" in str(attn_processor.__class__.__name__):
# 丟擲錯誤,表示不支援具有附加 KV 投影的模型
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
# 儲存當前的注意力處理器
self.original_attn_processors = self.attn_processors
# 遍歷所有模組
for module in self.modules():
# 如果模組是 Attention 型別
if isinstance(module, Attention):
# 融合投影
module.fuse_projections(fuse=True)
# 設定注意力處理器為融合的注意力處理器
self.set_attn_processor(FusedAttnProcessor2_0())
# 從 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 複製
# 定義一個方法,用於禁用已啟用的融合 QKV 投影
def unfuse_qkv_projections(self):
"""禁用已啟用的融合 QKV 投影。
<Tip warning={true}>
該 API 是 🧪 實驗性的。
</Tip>
"""
# 如果存在原始的注意力處理器,則設定當前的注意力處理器為原始處理器
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
# 定義前向傳播方法,接受多個引數進行計算
def forward(
self,
sample: torch.Tensor, # 輸入樣本,張量格式
timestep: Union[torch.Tensor, float, int], # 當前時間步,可以是張量、浮點數或整數
encoder_hidden_states: torch.Tensor, # 編碼器的隱藏狀態,張量格式
class_labels: Optional[torch.Tensor] = None, # 類別標籤,預設為 None
timestep_cond: Optional[torch.Tensor] = None, # 時間步條件,預設為 None
attention_mask: Optional[torch.Tensor] = None, # 注意力掩碼,預設為 None
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 跨注意力的關鍵字引數,預設為 None
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, # 降級塊的附加殘差,預設為 None
mid_block_additional_residual: Optional[torch.Tensor] = None, # 中間塊的附加殘差,預設為 None
return_dict: bool = True, # 是否返回字典格式的結果,預設為 True
.\diffusers\models\unets\unet_i2vgen_xl.py
# 版權宣告,表明版權歸2024年阿里巴巴DAMO-VILAB和HuggingFace團隊所有
# 提供Apache許可證2.0版本的使用條款
# 說明只能在遵循許可證的情況下使用此檔案
# 可在指定網址獲取許可證副本
#
# 除非適用法律或書面協議另有約定,否則軟體按“原樣”分發
# 不提供任何形式的擔保或條件
# 請參見許可證以獲取與許可權和限制相關的具體資訊
from typing import Any, Dict, Optional, Tuple, Union # 匯入型別提示工具,用於型別註解
import torch # 匯入PyTorch庫
import torch.nn as nn # 匯入PyTorch的神經網路模組
import torch.utils.checkpoint # 匯入PyTorch的檢查點工具
from ...configuration_utils import ConfigMixin, register_to_config # 從配置工具匯入類和函式
from ...loaders import UNet2DConditionLoadersMixin # 匯入2D條件載入器混合類
from ...utils import logging # 匯入日誌工具
from ..activations import get_activation # 匯入啟用函式獲取工具
from ..attention import Attention, FeedForward # 匯入注意力機制和前饋網路
from ..attention_processor import ( # 從注意力處理器模組匯入多個處理器
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
FusedAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps # 匯入時間步嵌入和時間步類
from ..modeling_utils import ModelMixin # 匯入模型混合類
from ..transformers.transformer_temporal import TransformerTemporalModel # 匯入時間變換器模型
from .unet_3d_blocks import ( # 從3D U-Net塊模組匯入多個類
CrossAttnDownBlock3D,
CrossAttnUpBlock3D,
DownBlock3D,
UNetMidBlock3DCrossAttn,
UpBlock3D,
get_down_block,
get_up_block,
)
from .unet_3d_condition import UNet3DConditionOutput # 匯入3D條件輸出類
logger = logging.get_logger(__name__) # 建立日誌記錄器,用於記錄當前模組的資訊
class I2VGenXLTransformerTemporalEncoder(nn.Module): # 定義一個名為I2VGenXLTransformerTemporalEncoder的類,繼承自nn.Module
def __init__( # 建構函式,用於初始化類的例項
self,
dim: int, # 輸入的特徵維度
num_attention_heads: int, # 注意力頭的數量
attention_head_dim: int, # 每個注意力頭的維度
activation_fn: str = "geglu", # 啟用函式型別,預設使用geglu
upcast_attention: bool = False, # 是否提升注意力計算的精度
ff_inner_dim: Optional[int] = None, # 前饋網路的內部維度,預設為None
dropout: int = 0.0, # dropout機率,預設為0.0
):
super().__init__() # 呼叫父類建構函式
self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5) # 初始化層歸一化層
self.attn1 = Attention( # 初始化注意力層
query_dim=dim, # 查詢維度
heads=num_attention_heads, # 注意力頭數量
dim_head=attention_head_dim, # 每個頭的維度
dropout=dropout, # dropout機率
bias=False, # 不使用偏置
upcast_attention=upcast_attention, # 是否提升注意力計算精度
out_bias=True, # 輸出使用偏置
)
self.ff = FeedForward( # 初始化前饋網路
dim, # 輸入維度
dropout=dropout, # dropout機率
activation_fn=activation_fn, # 啟用函式型別
final_dropout=False, # 最後層不使用dropout
inner_dim=ff_inner_dim, # 內部維度
bias=True, # 使用偏置
)
def forward( # 定義前向傳播方法
self,
hidden_states: torch.Tensor, # 輸入的隱藏狀態
# 該方法返回處理後的隱藏狀態張量
) -> torch.Tensor:
# 對隱藏狀態進行歸一化處理
norm_hidden_states = self.norm1(hidden_states)
# 計算注意力輸出,使用歸一化後的隱藏狀態
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
# 將注意力輸出與原始隱藏狀態相加,更新隱藏狀態
hidden_states = attn_output + hidden_states
# 如果隱藏狀態是四維,則去掉第一維
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 透過前饋網路處理隱藏狀態
ff_output = self.ff(hidden_states)
# 將前饋輸出與當前隱藏狀態相加,更新隱藏狀態
hidden_states = ff_output + hidden_states
# 如果隱藏狀態是四維,則去掉第一維
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 返回最終的隱藏狀態
return hidden_states
# 定義 I2VGenXL UNet 類,繼承多個混入類以增加功能
class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
I2VGenXL UNet。一個條件3D UNet模型,接收噪聲樣本、條件狀態和時間步,
返回與樣本形狀相同的輸出。
該模型繼承自 [`ModelMixin`]。有關所有模型實現的通用方法(如下載或儲存),
請檢視超類文件。
引數:
sample_size (`int` 或 `Tuple[int, int]`, *可選*, 預設值為 `None`):
輸入/輸出樣本的高度和寬度。
in_channels (`int`, *可選*, 預設值為 4): 輸入樣本的通道數。
out_channels (`int`, *可選*, 預設值為 4): 輸出樣本的通道數。
down_block_types (`Tuple[str]`, *可選*, 預設值為 `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
使用的下采樣塊的元組。
up_block_types (`Tuple[str]`, *可選*, 預設值為 `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
使用的上取樣塊的元組。
block_out_channels (`Tuple[int]`, *可選*, 預設值為 `(320, 640, 1280, 1280)`):
每個塊的輸出通道元組。
layers_per_block (`int`, *可選*, 預設值為 2): 每個塊的層數。
norm_num_groups (`int`, *可選*, 預設值為 32): 用於歸一化的組數。
如果為 `None`,則跳過後處理中的歸一化和啟用層。
cross_attention_dim (`int`, *可選*, 預設值為 1280): 跨注意力特徵的維度。
attention_head_dim (`int`, *可選*, 預設值為 64): 注意力頭的維度。
num_attention_heads (`int`, *可選*): 注意力頭的數量。
"""
# 設定不支援梯度檢查點的屬性為 False
_supports_gradient_checkpointing = False
@register_to_config
# 初始化方法,接受多種可選引數以設定模型配置
def __init__(
self,
sample_size: Optional[int] = None, # 輸入/輸出樣本大小,預設為 None
in_channels: int = 4, # 輸入樣本的通道數,預設為 4
out_channels: int = 4, # 輸出樣本的通道數,預設為 4
down_block_types: Tuple[str, ...] = ( # 下采樣塊的型別,預設為指定的元組
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
),
up_block_types: Tuple[str, ...] = ( # 上取樣塊的型別,預設為指定的元組
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), # 每個塊的輸出通道,預設為指定的元組
layers_per_block: int = 2, # 每個塊的層數,預設為 2
norm_num_groups: Optional[int] = 32, # 歸一化組數,預設為 32
cross_attention_dim: int = 1024, # 跨注意力特徵的維度,預設為 1024
attention_head_dim: Union[int, Tuple[int]] = 64, # 注意力頭的維度,預設為 64
num_attention_heads: Optional[Union[int, Tuple[int]]] = None, # 注意力頭的數量,預設為 None
@property
# 該屬性從 UNet2DConditionModel 的 attn_processors 複製
# 定義返回注意力處理器的函式,返回型別為字典,鍵為字串,值為 AttentionProcessor 物件
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# 建立一個空字典,用於儲存處理器
processors = {}
# 定義一個遞迴函式,用於新增處理器到字典
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# 檢查模組是否有 get_processor 方法
if hasattr(module, "get_processor"):
# 將處理器新增到字典中,鍵為名稱加上 ".processor"
processors[f"{name}.processor"] = module.get_processor()
# 遍歷模組的子模組
for sub_name, child in module.named_children():
# 遞迴呼叫,處理子模組
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
# 返回更新後的處理器字典
return processors
# 遍歷當前物件的所有子模組
for name, module in self.named_children():
# 呼叫遞迴函式,將處理器新增到字典中
fn_recursive_add_processors(name, module, processors)
# 返回包含所有處理器的字典
return processors
# 從 diffusers.models.unets.unet_2d_condition 中複製的設定注意力處理器的函式
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
# 獲取當前注意力處理器的數量
count = len(self.attn_processors.keys())
# 如果傳入的是字典且數量不匹配,則引發錯誤
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
# 定義一個遞迴函式,用於設定處理器
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 檢查模組是否有 set_processor 方法
if hasattr(module, "set_processor"):
# 如果傳入的處理器不是字典,直接設定
if not isinstance(processor, dict):
module.set_processor(processor)
else:
# 從字典中移除並設定對應的處理器
module.set_processor(processor.pop(f"{name}.processor"))
# 遍歷模組的子模組
for sub_name, child in module.named_children():
# 遞迴呼叫,設定子模組的處理器
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍歷當前物件的所有子模組
for name, module in self.named_children():
# 呼叫遞迴函式,為每個模組設定處理器
fn_recursive_attn_processor(name, module, processor)
# 從 diffusers.models.unets.unet_3d_condition 中複製的啟用前向分塊的函式
# 啟用前饋層的分塊處理
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
設定注意力處理器使用[前饋分塊](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers)。
引數:
chunk_size (`int`, *可選*):
前饋層的塊大小。如果未指定,將對維度為`dim`的每個張量單獨執行前饋層。
dim (`int`, *可選*, 預設為`0`):
前饋計算應分塊的維度。可以選擇dim=0(批次)或dim=1(序列長度)。
"""
# 檢查維度是否在有效範圍內
if dim not in [0, 1]:
# 丟擲錯誤,確保dim只為0或1
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
# 預設塊大小為1
chunk_size = chunk_size or 1
# 定義遞迴函式,用於設定每個模組的前饋分塊
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
# 如果模組有設定分塊前饋的方法,呼叫該方法
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
# 遞迴遍歷子模組
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
# 對當前物件的所有子模組應用前饋分塊設定
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
# 從diffusers.models.unets.unet_3d_condition.UNet3DConditionModel複製的禁用前饋分塊的方法
def disable_forward_chunking(self):
# 定義遞迴函式,用於禁用模組的前饋分塊
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
# 如果模組有設定分塊前饋的方法,呼叫該方法
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
# 遞迴遍歷子模組
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
# 對當前物件的所有子模組應用禁用前饋分塊設定
for module in self.children():
fn_recursive_feed_forward(module, None, 0)
# 從diffusers.models.unets.unet_2d_condition.UNet2DConditionModel複製的設定預設注意力處理器的方法
def set_default_attn_processor(self):
"""
禁用自定義注意力處理器並設定預設的注意力實現。
"""
# 檢查所有注意力處理器是否屬於已新增的KV注意力處理器類
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 如果是,則設定為已新增KV處理器
processor = AttnAddedKVProcessor()
# 檢查所有注意力處理器是否屬於交叉注意力處理器類
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 如果是,則設定為標準注意力處理器
processor = AttnProcessor()
else:
# 丟擲錯誤,說明當前的注意力處理器型別不被支援
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
# 設定當前物件的注意力處理器為選擇的處理器
self.set_attn_processor(processor)
# 從diffusers.models.unets.unet_3d_condition.UNet3DConditionModel複製的設定梯度檢查點的方法
# 設定梯度檢查點,指定模組和布林值
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
# 檢查模組是否為指定的型別之一
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
# 設定模組的梯度檢查點屬性為指定值
module.gradient_checkpointing = value
# 從 UNet2DConditionModel 中複製的啟用 FreeU 方法
def enable_freeu(self, s1, s2, b1, b2):
r"""啟用 FreeU 機制,詳情見 https://arxiv.org/abs/2309.11497.
字尾表示縮放因子應用的階段塊。
請參考 [官方庫](https://github.com/ChenyangSi/FreeU) 以獲取適用於不同管道(如 Stable Diffusion v1, v2 和 Stable Diffusion XL)的有效值組合。
引數:
s1 (`float`):
階段 1 的縮放因子,用於減弱跳過特徵的貢獻,以緩解增強去噪過程中的“過平滑效應”。
s2 (`float`):
階段 2 的縮放因子,用於減弱跳過特徵的貢獻,以緩解增強去噪過程中的“過平滑效應”。
b1 (`float`): 階段 1 的縮放因子,用於放大主幹特徵的貢獻。
b2 (`float`): 階段 2 的縮放因子,用於放大主幹特徵的貢獻。
"""
# 遍歷上取樣塊,索引 i 和塊物件 upsample_block
for i, upsample_block in enumerate(self.up_blocks):
# 設定上取樣塊的屬性 s1 為給定值 s1
setattr(upsample_block, "s1", s1)
# 設定上取樣塊的屬性 s2 為給定值 s2
setattr(upsample_block, "s2", s2)
# 設定上取樣塊的屬性 b1 為給定值 b1
setattr(upsample_block, "b1", b1)
# 設定上取樣塊的屬性 b2 為給定值 b2
setattr(upsample_block, "b2", b2)
# 從 UNet2DConditionModel 中複製的禁用 FreeU 方法
def disable_freeu(self):
"""禁用 FreeU 機制。"""
# 定義 FreeU 相關的屬性鍵
freeu_keys = {"s1", "s2", "b1", "b2"}
# 遍歷上取樣塊,索引 i 和塊物件 upsample_block
for i, upsample_block in enumerate(self.up_blocks):
# 遍歷 FreeU 屬性鍵
for k in freeu_keys:
# 如果上取樣塊具有該屬性或屬性值不為 None
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
# 將上取樣塊的該屬性設定為 None
setattr(upsample_block, k, None)
# 從 UNet2DConditionModel 中複製的融合 QKV 投影方法
# 定義一個方法,用於啟用融合的 QKV 投影
def fuse_qkv_projections(self):
# 提供方法的文件字串,描述其功能和警告資訊
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
# 初始化原始注意力處理器為 None
self.original_attn_processors = None
# 遍歷當前物件的注意力處理器
for _, attn_processor in self.attn_processors.items():
# 檢查處理器類名中是否包含 "Added"
if "Added" in str(attn_processor.__class__.__name__):
# 如果包含,丟擲異常提示不支援
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
# 儲存當前的注意力處理器以備後用
self.original_attn_processors = self.attn_processors
# 遍歷當前物件的所有模組
for module in self.modules():
# 檢查模組是否為 Attention 型別
if isinstance(module, Attention):
# 呼叫模組的方法,啟用融合投影
module.fuse_projections(fuse=True)
# 設定注意力處理器為 FusedAttnProcessor2_0 的例項
self.set_attn_processor(FusedAttnProcessor2_0())
# 從 UNet2DConditionModel 複製的方法,用於禁用融合的 QKV 投影
def unfuse_qkv_projections(self):
# 提供方法的文件字串,描述其功能和警告資訊
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
# 檢查原始注意力處理器是否不為 None
if self.original_attn_processors is not None:
# 如果不為 None,恢復原始的注意力處理器
self.set_attn_processor(self.original_attn_processors)
# 定義前向傳播方法,接受多個輸入引數
def forward(
self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
fps: torch.Tensor,
image_latents: torch.Tensor,
image_embeddings: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
.\diffusers\models\unets\unet_kandinsky3.py
# 版權宣告,指明該檔案屬於 HuggingFace 團隊,所有權利保留
#
# 根據 Apache License 2.0 版(“許可證”)授權;
# 除非遵循許可證,否則不得使用此檔案。
# 可以在以下網址獲取許可證的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律要求或書面同意,否則根據許可證分發的軟體
# 是在“原樣”基礎上分發的,不附帶任何形式的保證或條件。
# 有關特定語言的許可條款和條件,請參見許可證。
from dataclasses import dataclass # 從 dataclasses 模組匯入 dataclass 裝飾器
from typing import Dict, Tuple, Union # 匯入用於型別提示的字典、元組和聯合型別
import torch # 匯入 PyTorch 庫
import torch.utils.checkpoint # 匯入 PyTorch 的檢查點工具
from torch import nn # 從 PyTorch 匯入神經網路模組
from ...configuration_utils import ConfigMixin, register_to_config # 從配置工具匯入混合類和註冊函式
from ...utils import BaseOutput, logging # 從工具模組匯入基礎輸出類和日誌功能
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor # 匯入注意力處理器相關類
from ..embeddings import TimestepEmbedding, Timesteps # 匯入時間步嵌入相關類
from ..modeling_utils import ModelMixin # 匯入模型混合類
logger = logging.get_logger(__name__) # 建立一個記錄器,用於當前模組的日誌記錄
@dataclass # 將該類標記為資料類,以簡化初始化和表示
class Kandinsky3UNetOutput(BaseOutput): # 定義 Kandinsky3UNetOutput 類,繼承自 BaseOutput
sample: torch.Tensor = None # 定義輸出樣本,預設為 None
class Kandinsky3EncoderProj(nn.Module): # 定義 Kandinsky3EncoderProj 類,繼承自 nn.Module
def __init__(self, encoder_hid_dim, cross_attention_dim): # 初始化方法,接收隱藏維度和交叉注意力維度
super().__init__() # 呼叫父類的初始化方法
self.projection_linear = nn.Linear(encoder_hid_dim, cross_attention_dim, bias=False) # 定義線性投影層,不使用偏置
self.projection_norm = nn.LayerNorm(cross_attention_dim) # 定義層歸一化層
def forward(self, x): # 定義前向傳播方法
x = self.projection_linear(x) # 透過線性層處理輸入
x = self.projection_norm(x) # 透過層歸一化處理輸出
return x # 返回處理後的結果
class Kandinsky3UNet(ModelMixin, ConfigMixin): # 定義 Kandinsky3UNet 類,繼承自 ModelMixin 和 ConfigMixin
@register_to_config # 將該方法註冊到配置中
def __init__( # 初始化方法
self,
in_channels: int = 4, # 輸入通道數,預設值為 4
time_embedding_dim: int = 1536, # 時間嵌入維度,預設值為 1536
groups: int = 32, # 組數,預設值為 32
attention_head_dim: int = 64, # 注意力頭維度,預設值為 64
layers_per_block: Union[int, Tuple[int]] = 3, # 每個塊的層數,預設值為 3,可以是整數或元組
block_out_channels: Tuple[int] = (384, 768, 1536, 3072), # 塊輸出通道,預設為指定元組
cross_attention_dim: Union[int, Tuple[int]] = 4096, # 交叉注意力維度,預設值為 4096
encoder_hid_dim: int = 4096, # 編碼器隱藏維度,預設值為 4096
@property # 定義一個屬性
def attn_processors(self) -> Dict[str, AttentionProcessor]: # 返回注意力處理器字典
r""" # 文件字串,描述該方法的功能
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# 設定一個空字典以遞迴儲存處理器
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): # 定義遞迴函式新增處理器
if hasattr(module, "set_processor"): # 檢查模組是否具有 set_processor 屬性
processors[f"{name}.processor"] = module.processor # 將處理器新增到字典中
for sub_name, child in module.named_children(): # 遍歷模組的所有子模組
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) # 遞迴呼叫自身
return processors # 返回更新後的處理器字典
for name, module in self.named_children(): # 遍歷當前類的所有子模組
fn_recursive_add_processors(name, module, processors) # 呼叫遞迴函式
return processors # 返回包含所有處理器的字典
# 定義設定注意力處理器的方法,引數為處理器,可以是 AttentionProcessor 類或其字典形式
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
設定用於計算注意力的處理器。
引數:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
例項化的處理器類或處理器類的字典,將作為 **所有** `Attention` 層的處理器。
如果 `processor` 是一個字典,鍵需要定義相應交叉注意力處理器的路徑。這在設定可訓練注意力處理器時強烈推薦。
"""
# 獲取當前注意力處理器的數量
count = len(self.attn_processors.keys())
# 如果傳入的是字典且其長度與注意力層的數量不匹配,則丟擲錯誤
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"傳入了處理器字典,但處理器的數量 {len(processor)} 與"
f" 注意力層的數量 {count} 不匹配。請確保傳入 {count} 個處理器類。"
)
# 定義遞迴設定注意力處理器的方法
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 如果模組有設定處理器的方法
if hasattr(module, "set_processor"):
# 如果處理器不是字典,則直接設定
if not isinstance(processor, dict):
module.set_processor(processor)
else:
# 從字典中獲取對應的處理器並設定
module.set_processor(processor.pop(f"{name}.processor"))
# 遍歷模組的所有子模組
for sub_name, child in module.named_children():
# 遞迴呼叫處理子模組
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍歷當前物件的所有子模組
for name, module in self.named_children():
# 遞迴設定每個子模組的處理器
fn_recursive_attn_processor(name, module, processor)
# 定義設定預設注意力處理器的方法
def set_default_attn_processor(self):
"""
禁用自定義注意力處理器,並設定預設的注意力實現。
"""
# 呼叫設定注意力處理器的方法,使用預設的 AttnProcessor 例項
self.set_attn_processor(AttnProcessor())
# 定義設定梯度檢查點的方法
def _set_gradient_checkpointing(self, module, value=False):
# 如果模組有梯度檢查點的屬性
if hasattr(module, "gradient_checkpointing"):
# 設定該屬性為指定的值
module.gradient_checkpointing = value
# 定義前向傳播函式,接收樣本、時間步以及可選的編碼器隱藏狀態和注意力掩碼
def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
# 如果存在編碼器注意力掩碼,則進行調整以適應後續計算
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
# 增加一個維度,以便後續處理
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 檢查時間步是否為張量型別
if not torch.is_tensor(timestep):
# 根據時間步型別確定資料型別
dtype = torch.float32 if isinstance(timestep, float) else torch.int32
# 將時間步轉換為張量並指定裝置
timestep = torch.tensor([timestep], dtype=dtype, device=sample.device)
# 如果時間步為標量,則擴充套件為一維張量
elif len(timestep.shape) == 0:
timestep = timestep[None].to(sample.device)
# 擴充套件時間步到與批次維度相容的形狀
timestep = timestep.expand(sample.shape[0])
# 透過時間投影獲取時間嵌入輸入並轉換為樣本的資料型別
time_embed_input = self.time_proj(timestep).to(sample.dtype)
# 獲取時間嵌入
time_embed = self.time_embedding(time_embed_input)
# 對編碼器隱藏狀態進行線性變換
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
# 如果存在編碼器隱藏狀態,則將時間嵌入與隱藏狀態結合
if encoder_hidden_states is not None:
time_embed = self.add_time_condition(time_embed, encoder_hidden_states, encoder_attention_mask)
# 初始化隱藏狀態列表
hidden_states = []
# 對輸入樣本進行初步卷積處理
sample = self.conv_in(sample)
# 遍歷下采樣塊
for level, down_sample in enumerate(self.down_blocks):
# 透過下采樣塊處理樣本
sample = down_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
# 如果不是最後一個層級,記錄當前樣本狀態
if level != self.num_levels - 1:
hidden_states.append(sample)
# 遍歷上取樣塊
for level, up_sample in enumerate(self.up_blocks):
# 如果不是第一個層級,則拼接當前樣本與之前的隱藏狀態
if level != 0:
sample = torch.cat([sample, hidden_states.pop()], dim=1)
# 透過上取樣塊處理樣本
sample = up_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
# 進行輸出卷積規範化
sample = self.conv_norm_out(sample)
# 進行輸出啟用
sample = self.conv_act_out(sample)
# 進行最終輸出卷積
sample = self.conv_out(sample)
# 根據返回標誌返回相應的結果
if not return_dict:
return (sample,)
# 返回結果物件
return Kandinsky3UNetOutput(sample=sample)
# 定義 Kandinsky3UpSampleBlock 類,繼承自 nn.Module
class Kandinsky3UpSampleBlock(nn.Module):
# 初始化方法,設定各引數
def __init__(
self,
in_channels, # 輸入通道數
cat_dim, # 拼接維度
out_channels, # 輸出通道數
time_embed_dim, # 時間嵌入維度
context_dim=None, # 上下文維度,可選
num_blocks=3, # 塊的數量
groups=32, # 分組數
head_dim=64, # 頭維度
expansion_ratio=4, # 擴充套件比例
compression_ratio=2, # 壓縮比例
up_sample=True, # 是否上取樣
self_attention=True, # 是否使用自注意力
):
# 呼叫父類初始化方法
super().__init__()
# 設定上取樣解析度
up_resolutions = [[None, True if up_sample else None, None, None]] + [[None] * 4] * (num_blocks - 1)
# 設定隱藏通道數
hidden_channels = (
[(in_channels + cat_dim, in_channels)] # 第一層的通道
+ [(in_channels, in_channels)] * (num_blocks - 2) # 中間層的通道
+ [(in_channels, out_channels)] # 最後一層的通道
)
attentions = [] # 用於儲存注意力塊
resnets_in = [] # 用於儲存輸入 ResNet 塊
resnets_out = [] # 用於儲存輸出 ResNet 塊
# 設定自注意力和上下文維度
self.self_attention = self_attention
self.context_dim = context_dim
# 如果使用自注意力,新增註意力塊
if self_attention:
attentions.append(
Kandinsky3AttentionBlock(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
)
else:
attentions.append(nn.Identity()) # 否則新增身份對映
# 遍歷隱藏通道和上取樣解析度
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
# 新增輸入 ResNet 塊
resnets_in.append(
Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution)
)
# 如果上下文維度不為 None,新增註意力塊
if context_dim is not None:
attentions.append(
Kandinsky3AttentionBlock(
in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
)
)
else:
attentions.append(nn.Identity()) # 否則新增身份對映
# 新增輸出 ResNet 塊
resnets_out.append(
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
)
# 將注意力塊和 ResNet 塊轉換為模組列表
self.attentions = nn.ModuleList(attentions)
self.resnets_in = nn.ModuleList(resnets_in)
self.resnets_out = nn.ModuleList(resnets_out)
# 前向傳播方法
def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
# 遍歷注意力塊和 ResNet 塊進行前向計算
for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out):
x = resnet_in(x, time_embed) # 輸入經過 ResNet 塊
if self.context_dim is not None: # 如果上下文維度存在
x = attention(x, time_embed, context, context_mask, image_mask) # 應用注意力塊
x = resnet_out(x, time_embed) # 輸出經過 ResNet 塊
# 如果使用自注意力,應用首個注意力塊
if self.self_attention:
x = self.attentions[0](x, time_embed, image_mask=image_mask)
return x # 返回處理後的結果
# 定義 Kandinsky3DownSampleBlock 類,繼承自 nn.Module
class Kandinsky3DownSampleBlock(nn.Module):
# 初始化方法,設定各引數
def __init__(
self,
in_channels, # 輸入通道數
out_channels, # 輸出通道數
time_embed_dim, # 時間嵌入維度
context_dim=None, # 上下文維度,可選
num_blocks=3, # 塊的數量
groups=32, # 分組數
head_dim=64, # 頭維度
expansion_ratio=4, # 擴充套件比例
compression_ratio=2, # 壓縮比例
down_sample=True, # 是否下采樣
self_attention=True, # 是否使用自注意力
):
# 呼叫父類的初始化方法
super().__init__()
# 初始化注意力模組列表
attentions = []
# 初始化輸入殘差塊列表
resnets_in = []
# 初始化輸出殘差塊列表
resnets_out = []
# 儲存自注意力標誌
self.self_attention = self_attention
# 儲存上下文維度
self.context_dim = context_dim
# 如果啟用自注意力
if self_attention:
# 新增 Kandinsky3AttentionBlock 到注意力列表
attentions.append(
Kandinsky3AttentionBlock(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
)
else:
# 否則新增身份層(不改變輸入)
attentions.append(nn.Identity())
# 生成上取樣解析度列表
up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, False if down_sample else None, None]]
# 生成隱藏通道的元組列表
hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)
# 遍歷隱藏通道和上取樣解析度
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
# 新增輸入殘差塊到列表
resnets_in.append(
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
)
# 如果上下文維度不為 None
if context_dim is not None:
# 新增 Kandinsky3AttentionBlock 到注意力列表
attentions.append(
Kandinsky3AttentionBlock(
out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
)
)
else:
# 否則新增身份層(不改變輸入)
attentions.append(nn.Identity())
# 新增輸出殘差塊到列表
resnets_out.append(
Kandinsky3ResNetBlock(
out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution
)
)
# 將注意力模組列表轉換為 nn.ModuleList 以便管理
self.attentions = nn.ModuleList(attentions)
# 將輸入殘差塊列表轉換為 nn.ModuleList 以便管理
self.resnets_in = nn.ModuleList(resnets_in)
# 將輸出殘差塊列表轉換為 nn.ModuleList 以便管理
self.resnets_out = nn.ModuleList(resnets_out)
# 定義前向傳播方法
def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
# 如果啟用自注意力
if self.self_attention:
# 使用第一個注意力模組處理輸入
x = self.attentions[0](x, time_embed, image_mask=image_mask)
# 遍歷剩餘的注意力模組、輸入和輸出殘差塊
for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out):
# 透過輸入殘差塊處理輸入
x = resnet_in(x, time_embed)
# 如果上下文維度不為 None
if self.context_dim is not None:
# 使用當前注意力模組處理輸入
x = attention(x, time_embed, context, context_mask, image_mask)
# 透過輸出殘差塊處理輸入
x = resnet_out(x, time_embed)
# 返回處理後的輸出
return x
# 定義 Kandinsky3ConditionalGroupNorm 類,繼承自 nn.Module
class Kandinsky3ConditionalGroupNorm(nn.Module):
# 初始化方法,設定分組數、標準化形狀和上下文維度
def __init__(self, groups, normalized_shape, context_dim):
# 呼叫父類建構函式
super().__init__()
# 建立分組歸一化層,不使用仿射變換
self.norm = nn.GroupNorm(groups, normalized_shape, affine=False)
# 定義上下文多層感知機,包含 SiLU 啟用和線性層
self.context_mlp = nn.Sequential(nn.SiLU(), nn.Linear(context_dim, 2 * normalized_shape))
# 將線性層的權重初始化為零
self.context_mlp[1].weight.data.zero_()
# 將線性層的偏置初始化為零
self.context_mlp[1].bias.data.zero_()
# 前向傳播方法,接收輸入和上下文
def forward(self, x, context):
# 透過上下文多層感知機處理上下文
context = self.context_mlp(context)
# 為了匹配輸入的維度,逐層擴充套件上下文
for _ in range(len(x.shape[2:])):
context = context.unsqueeze(-1)
# 將上下文分割為縮放和偏移量
scale, shift = context.chunk(2, dim=1)
# 應用歸一化並進行縮放和偏移
x = self.norm(x) * (scale + 1.0) + shift
# 返回處理後的輸入
return x
# 定義 Kandinsky3Block 類,繼承自 nn.Module
class Kandinsky3Block(nn.Module):
# 初始化方法,設定輸入通道、輸出通道、時間嵌入維度等引數
def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
# 呼叫父類建構函式
super().__init__()
# 建立條件分組歸一化層
self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
# 定義 SiLU 啟用函式
self.activation = nn.SiLU()
# 如果需要上取樣,使用轉置卷積進行上取樣
if up_resolution is not None and up_resolution:
self.up_sample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
else:
# 否則使用恆等對映
self.up_sample = nn.Identity()
# 根據卷積核大小確定填充
padding = int(kernel_size > 1)
# 定義卷積投影層
self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
# 如果不需要上取樣,定義下采樣卷積層
if up_resolution is not None and not up_resolution:
self.down_sample = nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
else:
# 否則使用恆等對映
self.down_sample = nn.Identity()
# 前向傳播方法,接收輸入和時間嵌入
def forward(self, x, time_embed):
# 透過條件分組歸一化處理輸入
x = self.group_norm(x, time_embed)
# 應用啟用函式
x = self.activation(x)
# 進行上取樣
x = self.up_sample(x)
# 透過卷積投影層處理輸入
x = self.projection(x)
# 進行下采樣
x = self.down_sample(x)
# 返回處理後的輸出
return x
# 定義 Kandinsky3ResNetBlock 類,繼承自 nn.Module
class Kandinsky3ResNetBlock(nn.Module):
# 初始化方法,設定輸入通道、輸出通道、時間嵌入維度等引數
def __init__(
self, in_channels, out_channels, time_embed_dim, norm_groups=32, compression_ratio=2, up_resolutions=4 * [None]
# 初始化父類
):
super().__init__()
# 定義卷積核的大小
kernel_sizes = [1, 3, 3, 1]
# 計算隱藏通道數
hidden_channel = max(in_channels, out_channels) // compression_ratio
# 構建隱藏通道的元組列表
hidden_channels = (
[(in_channels, hidden_channel)] + [(hidden_channel, hidden_channel)] * 2 + [(hidden_channel, out_channels)]
)
# 建立包含多個 Kandinsky3Block 的模組列表
self.resnet_blocks = nn.ModuleList(
[
Kandinsky3Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution)
# 將隱藏通道、卷積核大小和上取樣解析度組合在一起
for (in_channel, out_channel), kernel_size, up_resolution in zip(
hidden_channels, kernel_sizes, up_resolutions
)
]
)
# 定義上取樣的快捷連線
self.shortcut_up_sample = (
nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
# 如果存在上取樣解析度,則使用反摺積;否則使用恆等對映
if True in up_resolutions
else nn.Identity()
)
# 定義通道數不同時的投影連線
self.shortcut_projection = (
nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()
)
# 定義下采樣的快捷連線
self.shortcut_down_sample = (
nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
# 如果存在下采樣解析度,則使用卷積;否則使用恆等對映
if False in up_resolutions
else nn.Identity()
)
# 前向傳播方法
def forward(self, x, time_embed):
# 初始化輸出為輸入
out = x
# 依次透過每個 ResNet 塊
for resnet_block in self.resnet_blocks:
out = resnet_block(out, time_embed)
# 上取樣輸入
x = self.shortcut_up_sample(x)
# 投影輸入到輸出通道
x = self.shortcut_projection(x)
# 下采樣輸入
x = self.shortcut_down_sample(x)
# 將輸出與處理後的輸入相加
x = x + out
# 返回最終輸出
return x
# 定義 Kandinsky3AttentionPooling 類,繼承自 nn.Module
class Kandinsky3AttentionPooling(nn.Module):
# 初始化方法,接受通道數、上下文維度和頭維度
def __init__(self, num_channels, context_dim, head_dim=64):
# 呼叫父類建構函式
super().__init__()
# 建立注意力機制物件,指定輸入和輸出維度及其他引數
self.attention = Attention(
context_dim,
context_dim,
dim_head=head_dim,
out_dim=num_channels,
out_bias=False,
)
# 前向傳播方法
def forward(self, x, context, context_mask=None):
# 將上下文掩碼轉換為與上下文相同的資料型別
context_mask = context_mask.to(dtype=context.dtype)
# 使用注意力機制計算上下文與其平均值的加權和
context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask)
# 返回輸入與上下文的和
return x + context.squeeze(1)
# 定義 Kandinsky3AttentionBlock 類,繼承自 nn.Module
class Kandinsky3AttentionBlock(nn.Module):
# 初始化方法,接受多種引數
def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
# 呼叫父類建構函式
super().__init__()
# 建立條件組歸一化物件,用於輸入規範化
self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
# 建立注意力機制物件,指定輸入和輸出維度及其他引數
self.attention = Attention(
num_channels,
context_dim or num_channels,
dim_head=head_dim,
out_dim=num_channels,
out_bias=False,
)
# 計算隱藏通道數,作為擴充套件比和通道數的乘積
hidden_channels = expansion_ratio * num_channels
# 建立條件組歸一化物件,用於輸出規範化
self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
# 定義前饋網路,包含兩個卷積層和啟用函式
self.feed_forward = nn.Sequential(
nn.Conv2d(num_channels, hidden_channels, kernel_size=1, bias=False),
nn.SiLU(),
nn.Conv2d(hidden_channels, num_channels, kernel_size=1, bias=False),
)
# 前向傳播方法
def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
# 獲取輸入的高度和寬度
height, width = x.shape[-2:]
# 對輸入進行歸一化處理
out = self.in_norm(x, time_embed)
# 將輸出重塑為適合注意力機制的形狀
out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1)
# 如果沒有上下文,則使用當前的輸出作為上下文
context = context if context is not None else out
# 如果存在上下文掩碼,轉換為與上下文相同的資料型別
if context_mask is not None:
context_mask = context_mask.to(dtype=context.dtype)
# 使用注意力機制處理輸出和上下文
out = self.attention(out, context, context_mask)
# 重塑輸出為原始輸入形狀
out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width)
# 將處理後的輸出與原輸入相加
x = x + out
# 對相加後的結果進行輸出歸一化
out = self.out_norm(x, time_embed)
# 透過前饋網路處理歸一化輸出
out = self.feed_forward(out)
# 將處理後的輸出與相加後的輸入相加
x = x + out
# 返回最終輸出
return x
.\diffusers\models\unets\unet_motion_model.py
# 版權宣告,表明該檔案的所有權及相關使用條款
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根據 Apache License, Version 2.0 (“許可證”) 授權;
# 除非遵守許可證,否則您不得使用此檔案。
# 您可以在以下網址獲取許可證副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非適用法律或書面協議另有規定,否則根據許可證分發的軟體
# 是在“按原樣”基礎上分發的,不提供任何形式的保證或條件,
# 無論是明示還是暗示。
# 有關許可證所管轄的許可權和限制,請參見許可證。
#
# 匯入所需的庫和模組
from dataclasses import dataclass # 匯入資料類裝飾器
from typing import Any, Dict, Optional, Tuple, Union # 匯入型別提示相關的型別
import torch # 匯入 PyTorch 庫
import torch.nn as nn # 匯入 PyTorch 的神經網路模組
import torch.nn.functional as F # 匯入 PyTorch 的功能性神經網路模組
import torch.utils.checkpoint # 匯入 PyTorch 的檢查點功能
# 匯入自定義的配置和載入工具
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput, deprecate, is_torch_version, logging # 匯入常用的工具函式
from ...utils.torch_utils import apply_freeu # 匯入應用 FreeU 的工具函式
from ..attention import BasicTransformerBlock # 匯入基礎變換器模組
from ..attention_processor import ( # 匯入注意力處理器相關的類
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
AttnProcessor2_0,
FusedAttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps # 匯入時間步嵌入相關的類
from ..modeling_utils import ModelMixin # 匯入模型混合工具
from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D # 匯入 ResNet 相關的模組
from ..transformers.dual_transformer_2d import DualTransformer2DModel # 匯入雙重變換器模型
from ..transformers.transformer_2d import Transformer2DModel # 匯入 2D 變換器模型
from .unet_2d_blocks import UNetMidBlock2DCrossAttn # 匯入 U-Net 中間塊
from .unet_2d_condition import UNet2DConditionModel # 匯入條件 U-Net 模型
logger = logging.get_logger(__name__) # 獲取當前模組的日誌記錄器,便於除錯和日誌輸出
@dataclass
class UNetMotionOutput(BaseOutput): # 定義 UNetMotionOutput 資料類,繼承自 BaseOutput
"""
[`UNetMotionOutput`] 的輸出。
引數:
sample (`torch.Tensor` 的形狀為 `(batch_size, num_channels, num_frames, height, width)`):
基於 `encoder_hidden_states` 輸入的隱藏狀態輸出。模型最後一層的輸出。
"""
sample: torch.Tensor # 定義 sample 屬性,型別為 torch.Tensor
class AnimateDiffTransformer3D(nn.Module): # 定義 AnimateDiffTransformer3D 類,繼承自 nn.Module
"""
一個用於影片類資料的變換器模型。
# 引數說明部分,描述初始化函式中每個引數的用途
Parameters:
# 多頭注意力機制中頭的數量,預設為16
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
# 每個頭中的通道數,預設為88
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
# 輸入和輸出的通道數,如果輸入是**連續**,則需要指定
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
# Transformer塊的層數,預設為1
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
# dropout機率,預設為0.0
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
# 使用的`encoder_hidden_states`維度數
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
# 配置`TransformerBlock`的注意力是否包含偏置引數
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlock` attention should contain a bias parameter.
# 潛在影像的寬度,如果輸入是**離散**,則需要指定
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
# 該值在訓練期間固定,用於學習位置嵌入的數量
This is fixed during training since it is used to learn a number of position embeddings.
# 前饋中的啟用函式,預設為"geglu"
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
activation functions.
# 配置`TransformerBlock`是否使用可學習的逐元素仿射引數進行歸一化
norm_elementwise_affine (`bool`, *optional*):
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
# 配置每個`TransformerBlock`是否包含兩個自注意力層
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
# 應用到序列輸入的位置資訊嵌入的型別
positional_embeddings: (`str`, *optional*):
The type of positional embeddings to apply to the sequence input before passing use.
# 應用位置嵌入的最大序列長度
num_positional_embeddings: (`int`, *optional*):
The maximum length of the sequence over which to apply positional embeddings.
"""
# 初始化方法定義
def __init__(
# 多頭注意力機制中頭的數量,預設為16
self,
num_attention_heads: int = 16,
# 每個頭中的通道數,預設為88
attention_head_dim: int = 88,
# 輸入通道數,可選
in_channels: Optional[int] = None,
# 輸出通道數,可選
out_channels: Optional[int] = None,
# Transformer塊的層數,預設為1
num_layers: int = 1,
# dropout機率,預設為0.0
dropout: float = 0.0,
# 歸一化分組數,預設為32
norm_num_groups: int = 32,
# 使用的`encoder_hidden_states`維度數,可選
cross_attention_dim: Optional[int] = None,
# 注意力是否包含偏置引數,預設為False
attention_bias: bool = False,
# 潛在影像的寬度,可選
sample_size: Optional[int] = None,
# 前饋中的啟用函式,預設為"geglu"
activation_fn: str = "geglu",
# 歸一化是否使用可學習的逐元素仿射引數,預設為True
norm_elementwise_affine: bool = True,
# 每個`TransformerBlock`是否包含兩個自注意力層,預設為True
double_self_attention: bool = True,
# 位置資訊嵌入的型別,可選
positional_embeddings: Optional[str] = None,
# 應用位置嵌入的最大序列長度,可選
num_positional_embeddings: Optional[int] = None,
):
# 呼叫父類的建構函式以初始化父類的屬性
super().__init__()
# 設定注意力頭的數量
self.num_attention_heads = num_attention_heads
# 設定每個注意力頭的維度
self.attention_head_dim = attention_head_dim
# 計算內部維度,等於注意力頭數量與每個注意力頭維度的乘積
inner_dim = num_attention_heads * attention_head_dim
# 設定輸入通道數
self.in_channels = in_channels
# 定義歸一化層,使用組歸一化,允許可學習的偏移
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
# 定義輸入線性變換層,將輸入通道對映到內部維度
self.proj_in = nn.Linear(in_channels, inner_dim)
# 3. 定義變換器塊
self.transformer_blocks = nn.ModuleList(
[
# 建立指定數量的基本變換器塊
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
double_self_attention=double_self_attention,
norm_elementwise_affine=norm_elementwise_affine,
positional_embeddings=positional_embeddings,
num_positional_embeddings=num_positional_embeddings,
)
# 遍歷建立 num_layers 個基本變換器塊
for _ in range(num_layers)
]
)
# 定義輸出線性變換層,將內部維度對映回輸入通道數
self.proj_out = nn.Linear(inner_dim, in_channels)
def forward(
# 定義前向傳播方法的輸入引數
self,
hidden_states: torch.Tensor, # 輸入的隱藏狀態張量
encoder_hidden_states: Optional[torch.LongTensor] = None, # 編碼器的隱藏狀態,預設為 None
timestep: Optional[torch.LongTensor] = None, # 時間步,預設為 None
class_labels: Optional[torch.LongTensor] = None, # 類標籤,預設為 None
num_frames: int = 1, # 幀數,預設值為 1
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 跨注意力引數,預設為 None
# 該方法用於 [`AnimateDiffTransformer3D`] 的前向傳播
) -> torch.Tensor:
"""
方法引數說明:
hidden_states (`torch.LongTensor`): 輸入的隱狀態,形狀為 `(batch size, num latent pixels)` 或 `(batch size, channel, height, width)`
encoder_hidden_states ( `torch.LongTensor`, *可選*):
交叉注意力層的條件嵌入。如果未提供,交叉注意力將預設使用自注意力。
timestep ( `torch.LongTensor`, *可選*):
用於指示去噪步驟的時間戳。
class_labels ( `torch.LongTensor`, *可選*):
用於指示類別標籤的條件嵌入。
num_frames (`int`, *可選*, 預設為 1):
每個批次處理的幀數,用於重新形狀隱狀態。
cross_attention_kwargs (`dict`, *可選*):
可選的關鍵字字典,傳遞給 `AttentionProcessor`。
返回值:
torch.Tensor:
輸出張量。
"""
# 1. 輸入
# 獲取輸入隱狀態的形狀資訊
batch_frames, channel, height, width = hidden_states.shape
# 計算批次大小
batch_size = batch_frames // num_frames
# 將隱狀態保留用於殘差連線
residual = hidden_states
# 調整隱狀態的形狀以適應批次和幀數
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
# 調整維度順序以便後續處理
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
# 對隱狀態進行規範化
hidden_states = self.norm(hidden_states)
# 再次調整維度順序並重塑為適當的形狀
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
# 輸入層投影
hidden_states = self.proj_in(hidden_states)
# 2. 處理塊
# 遍歷每個變換塊以處理隱狀態
for block in self.transformer_blocks:
hidden_states = block(
hidden_states, # 當前的隱狀態
encoder_hidden_states=encoder_hidden_states, # 可選的編碼器隱狀態
timestep=timestep, # 可選的時間戳
cross_attention_kwargs=cross_attention_kwargs, # 可選的交叉注意力引數
class_labels=class_labels, # 可選的類標籤
)
# 3. 輸出
# 輸出層投影
hidden_states = self.proj_out(hidden_states)
# 調整輸出張量的形狀
hidden_states = (
hidden_states[None, None, :] # 新增維度
.reshape(batch_size, height, width, num_frames, channel) # 重塑為適當形狀
.permute(0, 3, 4, 1, 2) # 調整維度順序
.contiguous() # 確保記憶體連續性
)
# 最終調整輸出的形狀
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
# 將殘差新增到輸出中以形成最終輸出
output = hidden_states + residual
# 返回最終的輸出張量
return output
# 定義一個名為 DownBlockMotion 的類,繼承自 nn.Module
class DownBlockMotion(nn.Module):
# 初始化方法,定義多個引數,包括輸入輸出通道、dropout 率等
def __init__(
self,
in_channels: int, # 輸入通道數量
out_channels: int, # 輸出通道數量
temb_channels: int, # 時間嵌入通道數量
dropout: float = 0.0, # dropout 率,預設為 0
num_layers: int = 1, # 網路層數,預設為 1
resnet_eps: float = 1e-6, # ResNet 的 epsilon 引數
resnet_time_scale_shift: str = "default", # ResNet 時間尺度偏移
resnet_act_fn: str = "swish", # ResNet 啟用函式,預設為 swish
resnet_groups: int = 32, # ResNet 組數,預設為 32
resnet_pre_norm: bool = True, # ResNet 是否使用預歸一化
output_scale_factor: float = 1.0, # 輸出縮放因子
add_downsample: bool = True, # 是否新增下采樣層
downsample_padding: int = 1, # 下采樣時的填充
temporal_num_attention_heads: Union[int, Tuple[int]] = 1, # 時間注意力頭數
temporal_cross_attention_dim: Optional[int] = None, # 時間交叉注意力維度
temporal_max_seq_length: int = 32, # 最大序列長度
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 每個塊的變換器層數
temporal_double_self_attention: bool = True, # 是否雙重自注意力
):
# 前向傳播方法,接收隱藏狀態和時間嵌入等引數
def forward(
self,
hidden_states: torch.Tensor, # 輸入的隱藏狀態張量
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
num_frames: int = 1, # 幀數,預設為 1
*args, # 接受任意位置引數
**kwargs, # 接受任意關鍵字引數
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: # 返回隱藏狀態和輸出狀態的張量或元組
# 檢查位置引數或關鍵字引數中的 scale 是否被傳遞
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定義棄用資訊,提示使用者 scale 引數將被忽略
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 呼叫棄用函式,發出警告
deprecate("scale", "1.0.0", deprecation_message)
# 初始化輸出狀態為一個空元組
output_states = ()
# 將 ResNet 和運動模組進行配對
blocks = zip(self.resnets, self.motion_modules)
# 遍歷每對 ResNet 和運動模組
for resnet, motion_module in blocks:
# 如果處於訓練模式且啟用了梯度檢查點
if self.training and self.gradient_checkpointing:
# 定義一個自定義前向傳播函式
def create_custom_forward(module):
def custom_forward(*inputs): # 自定義前向函式,接受任意輸入
return module(*inputs) # 返回模組的輸出
return custom_forward # 返回自定義前向函式
# 如果 PyTorch 版本大於等於 1.11.0
if is_torch_version(">=", "1.11.0"):
# 使用檢查點機制來節省記憶體
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), # 建立的自定義前向函式
hidden_states, # 輸入的隱藏狀態
temb, # 輸入的時間嵌入
use_reentrant=False, # 不使用重入
)
else:
# 在較早版本中也使用檢查點機制
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
# 如果不是訓練模式,直接透過 ResNet 處理隱藏狀態
hidden_states = resnet(hidden_states, temb)
# 使用運動模組處理當前的隱藏狀態
hidden_states = motion_module(hidden_states, num_frames=num_frames)
# 將當前隱藏狀態新增到輸出狀態中
output_states = output_states + (hidden_states,)
# 如果下采樣器不為空
if self.downsamplers is not None:
# 遍歷每個下采樣器
for downsampler in self.downsamplers:
# 透過下采樣器處理隱藏狀態
hidden_states = downsampler(hidden_states)
# 將下采樣後的隱藏狀態新增到輸出狀態中
output_states = output_states + (hidden_states,)
# 返回最終的隱藏狀態和輸出狀態
return hidden_states, output_states
# 初始化方法,用於設定網路的引數
def __init__(
# 輸入通道數量
self,
in_channels: int,
# 輸出通道數量
out_channels: int,
# 時間嵌入通道數量
temb_channels: int,
# dropout 機率,預設為 0.0
dropout: float = 0.0,
# 網路層數,預設為 1
num_layers: int = 1,
# 每個塊中的變換器層數,預設為 1
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# ResNet 中的 epsilon 值,預設為 1e-6
resnet_eps: float = 1e-6,
# ResNet 時間尺度偏移,預設為 "default"
resnet_time_scale_shift: str = "default",
# ResNet 啟用函式,預設為 "swish"
resnet_act_fn: str = "swish",
# ResNet 中的組數,預設為 32
resnet_groups: int = 32,
# 是否在 ResNet 中使用預歸一化,預設為 True
resnet_pre_norm: bool = True,
# 注意力頭的數量,預設為 1
num_attention_heads: int = 1,
# 交叉注意力維度,預設為 1280
cross_attention_dim: int = 1280,
# 輸出縮放因子,預設為 1.0
output_scale_factor: float = 1.0,
# 下采樣填充,預設為 1
downsample_padding: int = 1,
# 是否新增下采樣層,預設為 True
add_downsample: bool = True,
# 是否使用雙交叉注意力,預設為 False
dual_cross_attention: bool = False,
# 是否使用線性投影,預設為 False
use_linear_projection: bool = False,
# 是否僅使用交叉注意力,預設為 False
only_cross_attention: bool = False,
# 是否提升注意力計算精度,預設為 False
upcast_attention: bool = False,
# 注意力型別,預設為 "default"
attention_type: str = "default",
# 時間交叉注意力維度,可選引數
temporal_cross_attention_dim: Optional[int] = None,
# 時間注意力頭數量,預設為 8
temporal_num_attention_heads: int = 8,
# 時間序列的最大長度,預設為 32
temporal_max_seq_length: int = 32,
# 時間變換器塊中的層數,預設為 1
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# 是否使用雙重自注意力,預設為 True
temporal_double_self_attention: bool = True,
# 前向傳播方法,定義如何透過模型傳遞資料
def forward(
# 隱藏狀態張量,輸入到模型中的主要資料
self,
hidden_states: torch.Tensor,
# 可選的時間嵌入張量
temb: Optional[torch.Tensor] = None,
# 可選的編碼器隱藏狀態
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可選的注意力掩碼
attention_mask: Optional[torch.Tensor] = None,
# 每次處理的幀數,預設為 1
num_frames: int = 1,
# 可選的編碼器注意力掩碼
encoder_attention_mask: Optional[torch.Tensor] = None,
# 可選的交叉注意力引數
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可選的額外殘差連線
additional_residuals: Optional[torch.Tensor] = None,
):
# 檢查 cross_attention_kwargs 是否不為空
if cross_attention_kwargs is not None:
# 檢查 scale 引數是否存在,若存在則發出警告
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 初始化輸出狀態為空元組
output_states = ()
# 將自殘差網路、注意力模組和運動模組組合成一個列表
blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
# 遍歷組合後的模組及其索引
for i, (resnet, attn, motion_module) in enumerate(blocks):
# 如果處於訓練狀態且啟用了梯度檢查點
if self.training and self.gradient_checkpointing:
# 定義自定義前向傳播函式
def create_custom_forward(module, return_dict=None):
# 定義實際的前向傳播邏輯
def custom_forward(*inputs):
# 根據 return_dict 的值選擇返回方式
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
# 定義檢查點引數字典,根據 PyTorch 版本設定
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
# 使用檢查點機制計算隱藏狀態
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
# 透過注意力模組處理隱藏狀態
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
# 在非訓練模式下直接透過殘差網路處理隱藏狀態
hidden_states = resnet(hidden_states, temb)
# 透過注意力模組處理隱藏狀態
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
# 透過運動模組處理隱藏狀態
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
# 如果是最後一對模組且有額外殘差,則將其應用到隱藏狀態
if i == len(blocks) - 1 and additional_residuals is not None:
hidden_states = hidden_states + additional_residuals
# 將當前隱藏狀態新增到輸出狀態中
output_states = output_states + (hidden_states,)
# 如果存在下采樣模組,則依次應用它們
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
# 將下采樣後的隱藏狀態新增到輸出狀態中
output_states = output_states + (hidden_states,)
# 返回最終的隱藏狀態和輸出狀態
return hidden_states, output_states
# 定義一個繼承自 nn.Module 的類,用於交叉注意力上取樣塊
class CrossAttnUpBlockMotion(nn.Module):
# 初始化方法,設定各層的引數
def __init__(
self,
in_channels: int, # 輸入通道數
out_channels: int, # 輸出通道數
prev_output_channel: int, # 前一層輸出的通道數
temb_channels: int, # 時間嵌入通道數
resolution_idx: Optional[int] = None, # 解析度索引,預設為 None
dropout: float = 0.0, # dropout 機率
num_layers: int = 1, # 層數
transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 每個塊的變換器層數
resnet_eps: float = 1e-6, # ResNet 的 epsilon 值
resnet_time_scale_shift: str = "default", # ResNet 時間縮放偏移
resnet_act_fn: str = "swish", # ResNet 啟用函式
resnet_groups: int = 32, # ResNet 組數
resnet_pre_norm: bool = True, # 是否在前面進行歸一化
num_attention_heads: int = 1, # 注意力頭的數量
cross_attention_dim: int = 1280, # 交叉注意力的維度
output_scale_factor: float = 1.0, # 輸出縮放因子
add_upsample: bool = True, # 是否新增上取樣
dual_cross_attention: bool = False, # 是否使用雙重交叉注意力
use_linear_projection: bool = False, # 是否使用線性投影
only_cross_attention: bool = False, # 是否僅使用交叉注意力
upcast_attention: bool = False, # 是否上浮注意力
attention_type: str = "default", # 注意力型別
temporal_cross_attention_dim: Optional[int] = None, # 時間交叉注意力維度,預設為 None
temporal_num_attention_heads: int = 8, # 時間注意力頭數量
temporal_max_seq_length: int = 32, # 時間序列的最大長度
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 時間塊的變換器層數
# 定義前向傳播方法
def forward(
self,
hidden_states: torch.Tensor, # 輸入的隱藏狀態張量
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 之前隱藏狀態的元組
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 可選的編碼器隱藏狀態
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 交叉注意力的可選引數
upsample_size: Optional[int] = None, # 可選的上取樣大小
attention_mask: Optional[torch.Tensor] = None, # 可選的注意力掩碼
encoder_attention_mask: Optional[torch.Tensor] = None, # 可選的編碼器注意力掩碼
num_frames: int = 1, # 幀數,預設為 1
# 定義一個繼承自 nn.Module 的類,用於上取樣塊
class UpBlockMotion(nn.Module):
# 初始化方法,設定各層的引數
def __init__(
self,
in_channels: int, # 輸入通道數
prev_output_channel: int, # 前一層輸出的通道數
out_channels: int, # 輸出通道數
temb_channels: int, # 時間嵌入通道數
resolution_idx: Optional[int] = None, # 解析度索引,預設為 None
dropout: float = 0.0, # dropout 機率
num_layers: int = 1, # 層數
resnet_eps: float = 1e-6, # ResNet 的 epsilon 值
resnet_time_scale_shift: str = "default", # ResNet 時間縮放偏移
resnet_act_fn: str = "swish", # ResNet 啟用函式
resnet_groups: int = 32, # ResNet 組數
resnet_pre_norm: bool = True, # 是否在前面進行歸一化
output_scale_factor: float = 1.0, # 輸出縮放因子
add_upsample: bool = True, # 是否新增上取樣
temporal_cross_attention_dim: Optional[int] = None, # 時間交叉注意力維度,預設為 None
temporal_num_attention_heads: int = 8, # 時間注意力頭數量
temporal_max_seq_length: int = 32, # 時間序列的最大長度
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 時間塊的變換器層數
):
# 呼叫父類的初始化方法
super().__init__()
# 初始化空列表,用於存放 ResNet 模組
resnets = []
# 初始化空列表,用於存放運動模組
motion_modules = []
# 支援每個時間塊的變換層數量為變數
if isinstance(temporal_transformer_layers_per_block, int):
# 將單個整數轉換為與層數相同的元組
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
# 檢查傳入的層數是否與預期一致
raise ValueError(
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
)
# 遍歷每層,構建 ResNet 和運動模組
for i in range(num_layers):
# 設定跳過連線的通道數
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 設定當前層的輸入通道數
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 新增 ResNetBlock2D 模組到 resnets 列表
resnets.append(
ResnetBlock2D(
# 輸入通道數為當前層的輸入和跳過連線的通道數之和
in_channels=resnet_in_channels + res_skip_channels,
# 輸出通道數設定
out_channels=out_channels,
# 時間嵌入通道數
temb_channels=temb_channels,
# 小常數以避免除零
eps=resnet_eps,
# 組歸一化的組數
groups=resnet_groups,
# Dropout 率
dropout=dropout,
# 時間嵌入的歸一化方式
time_embedding_norm=resnet_time_scale_shift,
# 啟用函式設定
non_linearity=resnet_act_fn,
# 輸出尺度因子
output_scale_factor=output_scale_factor,
# 是否使用預歸一化
pre_norm=resnet_pre_norm,
)
)
# 新增 AnimateDiffTransformer3D 模組到 motion_modules 列表
motion_modules.append(
AnimateDiffTransformer3D(
# 注意力頭的數量
num_attention_heads=temporal_num_attention_heads,
# 輸入通道數
in_channels=out_channels,
# 當前層的變換層數量
num_layers=temporal_transformer_layers_per_block[i],
# 組歸一化的組數
norm_num_groups=resnet_groups,
# 跨注意力維度
cross_attention_dim=temporal_cross_attention_dim,
# 是否使用注意力偏置
attention_bias=False,
# 啟用函式型別
activation_fn="geglu",
# 位置資訊嵌入型別
positional_embeddings="sinusoidal",
# 位置資訊嵌入數量
num_positional_embeddings=temporal_max_seq_length,
# 每個注意力頭的維度
attention_head_dim=out_channels // temporal_num_attention_heads,
)
)
# 將 ResNet 模組列表轉換為 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 將運動模組列表轉換為 nn.ModuleList
self.motion_modules = nn.ModuleList(motion_modules)
# 如果需要上取樣,則初始化上取樣模組
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
# 否則,設定為 None
self.upsamplers = None
# 設定梯度檢查點標誌為 False
self.gradient_checkpointing = False
# 儲存解析度索引
self.resolution_idx = resolution_idx
def forward(
# 前向傳播方法的引數定義
self,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
# 可選的時間嵌入
temb: Optional[torch.Tensor] = None,
# 上取樣大小
upsample_size=None,
# 幀數,預設為 1
num_frames: int = 1,
# 額外的引數
*args,
**kwargs,
# 函式返回型別為 torch.Tensor
) -> torch.Tensor:
# 檢查傳入引數是否存在或 "scale" 引數是否為非 None
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定義棄用提示資訊
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 呼叫 deprecate 函式記錄棄用警告
deprecate("scale", "1.0.0", deprecation_message)
# 檢查 FreeU 是否啟用,確保相關屬性均不為 None
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
# 將自定義模組打包成元組,方便遍歷
blocks = zip(self.resnets, self.motion_modules)
# 遍歷每一對 resnet 和 motion_module
for resnet, motion_module in blocks:
# 從隱藏狀態元組中彈出最後一個隱藏狀態
res_hidden_states = res_hidden_states_tuple[-1]
# 更新隱藏狀態元組,移除最後一個元素
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 如果啟用 FreeU,則僅對前兩個階段進行操作
if is_freeu_enabled:
# 應用 FreeU 函式獲取新的隱藏狀態
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
# 將當前隱藏狀態和殘差隱藏狀態在維度 1 上拼接
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 如果在訓練模式並且啟用了梯度檢查點
if self.training and self.gradient_checkpointing:
# 定義建立自定義前向傳播函式
def create_custom_forward(module):
# 定義自定義前向傳播的實現
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 如果 torch 版本大於等於 1.11.0
if is_torch_version(">=", "1.11.0"):
# 使用檢查點機制儲存記憶體
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
use_reentrant=False,
)
else:
# 否則使用舊版檢查點機制
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
# 否則直接透過 resnet 計算隱藏狀態
hidden_states = resnet(hidden_states, temb)
# 透過 motion_module 處理隱藏狀態,傳入幀數
hidden_states = motion_module(hidden_states, num_frames=num_frames)
# 如果存在上取樣器,則對每個上取樣器進行處理
if self.upsamplers is not None:
for upsampler in self.upsamplers:
# 透過上取樣器處理隱藏狀態,傳入上取樣大小
hidden_states = upsampler(hidden_states, upsample_size)
# 返回最終處理後的隱藏狀態
return hidden_states
# 定義 UNetMidBlockCrossAttnMotion 類,繼承自 nn.Module
class UNetMidBlockCrossAttnMotion(nn.Module):
# 初始化方法,定義類的引數
def __init__(
self,
in_channels: int, # 輸入通道數
temb_channels: int, # 時間嵌入通道數
dropout: float = 0.0, # Dropout 率
num_layers: int = 1, # 層數
transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 每個塊的變換層數
resnet_eps: float = 1e-6, # ResNet 的 epsilon 值
resnet_time_scale_shift: str = "default", # ResNet 時間尺度偏移
resnet_act_fn: str = "swish", # ResNet 啟用函式型別
resnet_groups: int = 32, # ResNet 組數
resnet_pre_norm: bool = True, # 是否進行前置歸一化
num_attention_heads: int = 1, # 注意力頭數量
output_scale_factor: float = 1.0, # 輸出縮放因子
cross_attention_dim: int = 1280, # 交叉注意力維度
dual_cross_attention: bool = False, # 是否使用雙重交叉注意力
use_linear_projection: bool = False, # 是否使用線性投影
upcast_attention: bool = False, # 是否上升注意力精度
attention_type: str = "default", # 注意力型別
temporal_num_attention_heads: int = 1, # 時間注意力頭數量
temporal_cross_attention_dim: Optional[int] = None, # 時間交叉注意力維度
temporal_max_seq_length: int = 32, # 時間序列最大長度
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 時間塊的變換層數
# 前向傳播方法,定義輸入和輸出
def forward(
self,
hidden_states: torch.Tensor, # 隱藏狀態的輸入張量
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 可選的編碼器隱藏狀態
attention_mask: Optional[torch.Tensor] = None, # 可選的注意力掩碼
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 可選的交叉注意力引數
encoder_attention_mask: Optional[torch.Tensor] = None, # 可選的編碼器注意力掩碼
num_frames: int = 1, # 幀數
# 該函式的返回型別為 torch.Tensor
) -> torch.Tensor:
# 檢查交叉注意力引數是否不為 None
if cross_attention_kwargs is not None:
# 如果引數中包含 "scale",發出警告,說明該引數已棄用
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 透過第一個殘差網路處理隱藏狀態
hidden_states = self.resnets[0](hidden_states, temb)
# 將注意力層、殘差網路和運動模組打包在一起
blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
# 遍歷每個注意力層、殘差網路和運動模組
for attn, resnet, motion_module in blocks:
# 如果在訓練模式下並且啟用了梯度檢查點
if self.training and self.gradient_checkpointing:
# 建立自定義前向函式
def create_custom_forward(module, return_dict=None):
# 定義自定義前向函式,接受任意輸入
def custom_forward(*inputs):
# 如果返回字典不為 None,使用返回字典呼叫模組
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
# 否則直接呼叫模組
return module(*inputs)
return custom_forward
# 根據 PyTorch 版本設定檢查點引數
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
# 呼叫注意力模組並獲取輸出的第一個元素
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
# 使用梯度檢查點對運動模組進行前向傳播
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states,
temb,
**ckpt_kwargs,
)
# 使用梯度檢查點對殘差網路進行前向傳播
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
else:
# 在非訓練模式下直接呼叫注意力模組
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
# 呼叫運動模組,傳入隱藏狀態和幀數
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
# 呼叫殘差網路處理隱藏狀態
hidden_states = resnet(hidden_states, temb)
# 返回處理後的隱藏狀態
return hidden_states
# 定義一個繼承自 nn.Module 的運動模組類
class MotionModules(nn.Module):
# 初始化方法,接收多個引數配置運動模組
def __init__(
self,
in_channels: int, # 輸入通道數
layers_per_block: int = 2, # 每個模組塊的層數,預設是 2
transformer_layers_per_block: Union[int, Tuple[int]] = 8, # 每個塊中的變換層數
num_attention_heads: Union[int, Tuple[int]] = 8, # 注意力頭的數量
attention_bias: bool = False, # 是否使用注意力偏差
cross_attention_dim: Optional[int] = None, # 交叉注意力維度
activation_fn: str = "geglu", # 啟用函式,預設使用 "geglu"
norm_num_groups: int = 32, # 歸一化組的數量
max_seq_length: int = 32, # 最大序列長度
):
# 呼叫父類初始化方法
super().__init__()
# 初始化運動模組列表
self.motion_modules = nn.ModuleList([])
# 如果變換層數是整數,重複為每個模組塊填充
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * layers_per_block
# 檢查變換層數與塊層數是否匹配
elif len(transformer_layers_per_block) != layers_per_block:
raise ValueError(
f"The number of transformer layers per block must match the number of layers per block, "
f"got {layers_per_block} and {len(transformer_layers_per_block)}"
)
# 遍歷每個模組塊
for i in range(layers_per_block):
# 向運動模組列表新增 AnimateDiffTransformer3D 例項
self.motion_modules.append(
AnimateDiffTransformer3D(
in_channels=in_channels, # 輸入通道數
num_layers=transformer_layers_per_block[i], # 當前塊的變換層數
norm_num_groups=norm_num_groups, # 歸一化組的數量
cross_attention_dim=cross_attention_dim, # 交叉注意力維度
activation_fn=activation_fn, # 啟用函式
attention_bias=attention_bias, # 注意力偏差
num_attention_heads=num_attention_heads, # 注意力頭數量
attention_head_dim=in_channels // num_attention_heads, # 每個注意力頭的維度
positional_embeddings="sinusoidal", # 使用正弦波的位置嵌入
num_positional_embeddings=max_seq_length, # 位置嵌入的數量
)
)
# 定義一個運動介面卡類,結合多個混合類
class MotionAdapter(ModelMixin, ConfigMixin, FromOriginalModelMixin):
@register_to_config
# 初始化方法,配置多個運動介面卡引數
def __init__(
self,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), # 塊輸出通道
motion_layers_per_block: Union[int, Tuple[int]] = 2, # 每個運動塊的層數
motion_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]] = 1, # 每個運動塊中的變換層數
motion_mid_block_layers_per_block: int = 1, # 中間塊的層數
motion_transformer_layers_per_mid_block: Union[int, Tuple[int]] = 1, # 中間塊中的變換層數
motion_num_attention_heads: Union[int, Tuple[int]] = 8, # 中間塊的注意力頭數量
motion_norm_num_groups: int = 32, # 中間塊的歸一化組數量
motion_max_seq_length: int = 32, # 中間塊的最大序列長度
use_motion_mid_block: bool = True, # 是否使用中間塊
conv_in_channels: Optional[int] = None, # 輸入通道數
):
pass # 前向傳播方法,尚未實現
# 定義一個修改後的條件 2D UNet 模型
class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
r"""
一個修改後的條件 2D UNet 模型,接收嘈雜樣本、條件狀態和時間步,返回形狀輸出。
該模型繼承自 [`ModelMixin`]。檢視超類文件以獲取所有模型的通用方法實現(如下載或儲存)。
"""
# 支援梯度檢查點
_supports_gradient_checkpointing = True
@register_to_config
# 初始化方法,用於建立類的例項
def __init__(
# 可選引數,樣本大小,預設為 None
self,
sample_size: Optional[int] = None,
# 輸入通道數,預設為 4
in_channels: int = 4,
# 輸出通道數,預設為 4
out_channels: int = 4,
# 下采樣塊的型別元組
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockMotion", # 第一個下采樣塊型別
"CrossAttnDownBlockMotion", # 第二個下采樣塊型別
"CrossAttnDownBlockMotion", # 第三個下采樣塊型別
"DownBlockMotion", # 第四個下采樣塊型別
),
# 上取樣塊的型別元組
up_block_types: Tuple[str, ...] = (
"UpBlockMotion", # 第一個上取樣塊型別
"CrossAttnUpBlockMotion", # 第二個上取樣塊型別
"CrossAttnUpBlockMotion", # 第三個上取樣塊型別
"CrossAttnUpBlockMotion", # 第四個上取樣塊型別
),
# 塊的輸出通道數元組
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
# 每個塊的層數,預設為 2
layers_per_block: Union[int, Tuple[int]] = 2,
# 下采樣填充,預設為 1
downsample_padding: int = 1,
# 中間塊的縮放因子,預設為 1
mid_block_scale_factor: float = 1,
# 啟用函式型別,預設為 "silu"
act_fn: str = "silu",
# 歸一化的組數,預設為 32
norm_num_groups: int = 32,
# 歸一化的 epsilon 值,預設為 1e-5
norm_eps: float = 1e-5,
# 交叉注意力的維度,預設為 1280
cross_attention_dim: int = 1280,
# 每個塊的變換器層數,預設為 1
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
# 可選引數,反向變換器層數,預設為 None
reverse_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
# 時間變換器的層數,預設為 1
temporal_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
# 可選引數,反向時間變換器層數,預設為 None
reverse_temporal_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
# 每個中間塊的變換器層數,預設為 None
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
# 每個中間塊的時間變換器層數,預設為 1
temporal_transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = 1,
# 是否使用線性投影,預設為 False
use_linear_projection: bool = False,
# 注意力頭的數量,預設為 8
num_attention_heads: Union[int, Tuple[int, ...]] = 8,
# 動作最大序列長度,預設為 32
motion_max_seq_length: int = 32,
# 動作注意力頭的數量,預設為 8
motion_num_attention_heads: Union[int, Tuple[int, ...]] = 8,
# 可選引數,反向動作注意力頭的數量,預設為 None
reverse_motion_num_attention_heads: Optional[Union[int, Tuple[int, ...], Tuple[Tuple[int, ...], ...]]] = None,
# 是否使用動作中間塊,預設為 True
use_motion_mid_block: bool = True,
# 中間塊的層數,預設為 1
mid_block_layers: int = 1,
# 編碼器隱藏層維度,預設為 None
encoder_hid_dim: Optional[int] = None,
# 編碼器隱藏層型別,預設為 None
encoder_hid_dim_type: Optional[str] = None,
# 可選引數,附加嵌入型別,預設為 None
addition_embed_type: Optional[str] = None,
# 可選引數,附加時間嵌入維度,預設為 None
addition_time_embed_dim: Optional[int] = None,
# 可選引數,投影類別嵌入的輸入維度,預設為 None
projection_class_embeddings_input_dim: Optional[int] = None,
# 可選引數,時間條件投影維度,預設為 None
time_cond_proj_dim: Optional[int] = None,
# 類方法,用於從 UNet2DConditionModel 建立物件
@classmethod
def from_unet2d(
cls,
# UNet2DConditionModel 物件
unet: UNet2DConditionModel,
# 可選的運動介面卡,預設為 None
motion_adapter: Optional[MotionAdapter] = None,
# 是否載入權重,預設為 True
load_weights: bool = True,
# 凍結 UNet2DConditionModel 的權重,只保留運動模組可訓練,便於微調
def freeze_unet2d_params(self) -> None:
"""Freeze the weights of just the UNet2DConditionModel, and leave the motion modules
unfrozen for fine tuning.
"""
# 凍結所有引數
for param in self.parameters():
# 將引數的 requires_grad 屬性設定為 False,禁止梯度更新
param.requires_grad = False
# 解凍運動模組
for down_block in self.down_blocks:
# 獲取當前下采樣塊的運動模組
motion_modules = down_block.motion_modules
for param in motion_modules.parameters():
# 將運動模組引數的 requires_grad 屬性設定為 True,允許梯度更新
param.requires_grad = True
for up_block in self.up_blocks:
# 獲取當前上取樣塊的運動模組
motion_modules = up_block.motion_modules
for param in motion_modules.parameters():
# 將運動模組引數的 requires_grad 屬性設定為 True,允許梯度更新
param.requires_grad = True
# 檢查中間塊是否具有運動模組
if hasattr(self.mid_block, "motion_modules"):
# 獲取中間塊的運動模組
motion_modules = self.mid_block.motion_modules
for param in motion_modules.parameters():
# 將運動模組引數的 requires_grad 屬性設定為 True,允許梯度更新
param.requires_grad = True
# 載入運動模組的狀態字典
def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]) -> None:
# 遍歷運動介面卡的下采樣塊
for i, down_block in enumerate(motion_adapter.down_blocks):
# 載入下采樣塊的運動模組狀態字典
self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict())
# 遍歷運動介面卡的上取樣塊
for i, up_block in enumerate(motion_adapter.up_blocks):
# 載入上取樣塊的運動模組狀態字典
self.up_blocks[i].motion_modules.load_state_dict(up_block.motion_modules.state_dict())
# 支援沒有中間塊的舊運動模組
if hasattr(self.mid_block, "motion_modules"):
# 載入中間塊的運動模組狀態字典
self.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict())
# 儲存運動模組的狀態
def save_motion_modules(
self,
save_directory: str,
is_main_process: bool = True,
safe_serialization: bool = True,
variant: Optional[str] = None,
push_to_hub: bool = False,
**kwargs,
) -> None:
# 獲取當前模型的狀態字典
state_dict = self.state_dict()
# 提取所有運動模組的狀態
motion_state_dict = {}
for k, v in state_dict.items():
# 篩選出包含 "motion_modules" 的鍵值對
if "motion_modules" in k:
motion_state_dict[k] = v
# 建立運動介面卡例項
adapter = MotionAdapter(
block_out_channels=self.config["block_out_channels"],
motion_layers_per_block=self.config["layers_per_block"],
motion_norm_num_groups=self.config["norm_num_groups"],
motion_num_attention_heads=self.config["motion_num_attention_heads"],
motion_max_seq_length=self.config["motion_max_seq_length"],
use_motion_mid_block=self.config["use_motion_mid_block"],
)
# 載入運動狀態字典
adapter.load_state_dict(motion_state_dict)
# 儲存介面卡的預訓練狀態
adapter.save_pretrained(
save_directory=save_directory,
is_main_process=is_main_process,
safe_serialization=safe_serialization,
variant=variant,
push_to_hub=push_to_hub,
**kwargs,
)
@property
# 從 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 複製的屬性
# 定義一個方法,返回模型中所有注意力處理器的字典
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
返回值:
`dict` 型別的注意力處理器: 包含模型中所有注意力處理器的字典,
按照其權重名稱索引。
"""
# 初始化一個空字典,用於儲存注意力處理器
processors = {}
# 定義一個遞迴函式,用於新增註意力處理器
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# 檢查模組是否具有 'get_processor' 方法
if hasattr(module, "get_processor"):
# 將處理器新增到字典中,鍵為處理器名稱
processors[f"{name}.processor"] = module.get_processor()
# 遍歷模組的所有子模組
for sub_name, child in module.named_children():
# 遞迴呼叫,繼續新增子模組的處理器
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
# 返回處理器字典
return processors
# 遍歷當前物件的所有子模組
for name, module in self.named_children():
# 呼叫遞迴函式新增所有處理器
fn_recursive_add_processors(name, module, processors)
# 返回最終的處理器字典
return processors
# 從 diffusers.models.unets.unet_2d_condition 中複製的方法,用於設定注意力處理器
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
設定用於計算注意力的注意力處理器。
引數:
processor (`dict` of `AttentionProcessor` 或僅 `AttentionProcessor`):
例項化的處理器類或處理器類的字典,將被設定為**所有** `Attention` 層的處理器。
如果 `processor` 是字典,鍵需要定義相應的交叉注意力處理器的路徑。
當設定可訓練的注意力處理器時,強烈推薦這樣做。
"""
# 獲取當前注意力處理器字典的鍵數量
count = len(self.attn_processors.keys())
# 檢查傳入的處理器字典長度是否與注意力層數量匹配
if isinstance(processor, dict) and len(processor) != count:
# 如果不匹配,丟擲錯誤
raise ValueError(
f"傳入了處理器字典,但處理器數量 {len(processor)} 與"
f" 注意力層數量 {count} 不匹配。請確保傳入 {count} 個處理器類。"
)
# 定義一個遞迴函式,用於設定注意力處理器
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 檢查模組是否具有 'set_processor' 方法
if hasattr(module, "set_processor"):
# 如果處理器不是字典,直接設定處理器
if not isinstance(processor, dict):
module.set_processor(processor)
else:
# 從字典中彈出對應的處理器並設定
module.set_processor(processor.pop(f"{name}.processor"))
# 遍歷模組的所有子模組
for sub_name, child in module.named_children():
# 遞迴呼叫,繼續設定子模組的處理器
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍歷當前物件的所有子模組
for name, module in self.named_children():
# 呼叫遞迴函式設定所有處理器
fn_recursive_attn_processor(name, module, processor)
# 定義一個方法以啟用前向分塊處理
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
設定注意力處理器以使用[前饋分塊](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers)。
引數:
chunk_size (`int`, *可選*):
前饋層的塊大小。如果未指定,將單獨對維度為`dim`的每個張量執行前饋層。
dim (`int`, *可選*, 預設為`0`):
前饋計算應分塊的維度。選擇dim=0(批次)或dim=1(序列長度)。
"""
# 檢查dim引數是否在有效範圍內(0或1)
if dim not in [0, 1]:
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
# 預設塊大小為1
chunk_size = chunk_size or 1
# 定義遞迴前饋函式以設定模組的分塊前饋處理
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
# 如果模組有set_chunk_feed_forward屬性,設定塊大小和維度
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
# 遍歷模組的子模組
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
# 遍歷當前物件的子模組,應用遞迴前饋函式
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
# 定義一個方法以禁用前向分塊處理
def disable_forward_chunking(self) -> None:
# 定義遞迴前饋函式以設定模組的分塊前饋處理為None
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
# 如果模組有set_chunk_feed_forward屬性,設定塊大小和維度為None
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
# 遍歷模組的子模組
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
# 遍歷當前物件的子模組,應用遞迴前饋函式
for module in self.children():
fn_recursive_feed_forward(module, None, 0)
# 從diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor複製的方法
def set_default_attn_processor(self) -> None:
"""
禁用自定義注意力處理器並設定預設的注意力實現。
"""
# 如果所有注意力處理器都是ADDED_KV_ATTENTION_PROCESSORS型別
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 設定處理器為AttnAddedKVProcessor
processor = AttnAddedKVProcessor()
# 如果所有注意力處理器都是CROSS_ATTENTION_PROCESSORS型別
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 設定處理器為AttnProcessor
processor = AttnProcessor()
else:
# 丟擲錯誤,表示不能在不匹配的注意力處理器型別下呼叫該方法
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
# 設定當前物件的注意力處理器
self.set_attn_processor(processor)
# 定義一個方法以設定模組的梯度檢查點
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
# 檢查模組是否為特定型別,如果是則設定其梯度檢查點屬性
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
module.gradient_checkpointing = value
# 從diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu複製的方法
# 啟用 FreeU 機制,接受四個浮點型縮放因子作為引數
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
# 文件字串,描述該方法的作用及引數含義
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
# 遍歷上取樣塊,併為每個塊設定縮放因子
for i, upsample_block in enumerate(self.up_blocks):
# 為上取樣塊設定階段1的縮放因子
setattr(upsample_block, "s1", s1)
# 為上取樣塊設定階段2的縮放因子
setattr(upsample_block, "s2", s2)
# 為上取樣塊設定階段1的主幹特徵縮放因子
setattr(upsample_block, "b1", b1)
# 為上取樣塊設定階段2的主幹特徵縮放因子
setattr(upsample_block, "b2", b2)
# 禁用 FreeU 機制
def disable_freeu(self) -> None:
# 文件字串,描述該方法的作用
"""Disables the FreeU mechanism."""
# 定義 FreeU 相關的鍵名集合
freeu_keys = {"s1", "s2", "b1", "b2"}
# 遍歷上取樣塊
for i, upsample_block in enumerate(self.up_blocks):
# 遍歷 FreeU 鍵名
for k in freeu_keys:
# 檢查上取樣塊是否具有該屬性或該屬性是否不為 None
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
# 將上取樣塊的該屬性設定為 None
setattr(upsample_block, k, None)
# 啟用融合的 QKV 投影
def fuse_qkv_projections(self):
# 文件字串,描述該方法的作用
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
# 初始化原始注意力處理器為 None
self.original_attn_processors = None
# 遍歷注意力處理器
for _, attn_processor in self.attn_processors.items():
# 檢查注意力處理器類名中是否包含 "Added"
if "Added" in str(attn_processor.__class__.__name__):
# 丟擲異常,說明不支援該操作
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
# 儲存原始的注意力處理器
self.original_attn_processors = self.attn_processors
# 遍歷所有模組
for module in self.modules():
# 檢查模組是否為 Attention 型別
if isinstance(module, Attention):
# 融合投影
module.fuse_projections(fuse=True)
# 設定融合後的注意力處理器
self.set_attn_processor(FusedAttnProcessor2_0())
# 解融合 QKV 投影的方法(省略具體實現)
# 定義一個禁用融合 QKV 投影的方法
def unfuse_qkv_projections(self):
"""如果啟用了,禁用融合 QKV 投影。
<Tip warning={true}>
此 API 是 🧪 實驗性。
</Tip>
"""
# 檢查原始注意力處理器是否不為 None
if self.original_attn_processors is not None:
# 設定當前注意力處理器為原始的注意力處理器
self.set_attn_processor(self.original_attn_processors)
# 定義前向傳播方法,接收多個引數
def forward(
self,
# 輸入樣本張量
sample: torch.Tensor,
# 時間步,可以是張量、浮點數或整數
timestep: Union[torch.Tensor, float, int],
# 編碼器隱藏狀態張量
encoder_hidden_states: torch.Tensor,
# 可選的時間步條件張量
timestep_cond: Optional[torch.Tensor] = None,
# 可選的注意力掩碼張量
attention_mask: Optional[torch.Tensor] = None,
# 可選的交叉注意力引數字典
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可選的附加條件引數字典
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
# 可選的下塊附加殘差元組
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
# 可選的中間塊附加殘差張量
mid_block_additional_residual: Optional[torch.Tensor] = None,
# 是否返回字典格式的結果,預設為 True
return_dict: bool = True,