diffusers-原始碼解析-四-

绝不原创的飞龙發表於2024-10-22

diffusers 原始碼解析(四)

.\diffusers\models\attention_flax.py

# 版權宣告,表明該程式碼的版權歸 HuggingFace 團隊所有
# 根據 Apache 2.0 許可證授權使用該檔案,未遵守許可證不得使用
# 許可證獲取連結
# 指出該軟體是以“現狀”分發,不附帶任何明示或暗示的保證
# 具體的許可權和限制請參見許可證

# 匯入 functools 模組,用於函數語言程式設計工具
import functools
# 匯入 math 模組,提供數學相關的功能
import math

# 匯入 flax.linen 模組,作為神經網路構建的工具
import flax.linen as nn
# 匯入 jax 庫,用於加速計算
import jax
# 匯入 jax.numpy 模組,提供類似於 NumPy 的陣列功能
import jax.numpy as jnp


def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
    """多頭點積注意力,查詢數目有限的實現。"""
    # 獲取 key 的維度資訊,包括 key 的數量、頭數和特徵維度
    num_kv, num_heads, k_features = key.shape[-3:]
    # 獲取 value 的特徵維度
    v_features = value.shape[-1]
    # 確保 key_chunk_size 不超過 num_kv
    key_chunk_size = min(key_chunk_size, num_kv)
    # 對查詢進行縮放,防止數值溢位
    query = query / jnp.sqrt(k_features)

    @functools.partial(jax.checkpoint, prevent_cse=False)
    def summarize_chunk(query, key, value):
        # 計算查詢和鍵之間的注意力權重
        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)

        # 獲取每個查詢的最大得分,用於數值穩定性
        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
        # 計算最大得分的梯度不更新
        max_score = jax.lax.stop_gradient(max_score)
        # 計算經過 softmax 的注意力權重
        exp_weights = jnp.exp(attn_weights - max_score)

        # 計算加權後的值
        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
        # 獲取每個查詢的最大得分
        max_score = jnp.einsum("...qhk->...qh", max_score)

        return (exp_values, exp_weights.sum(axis=-1), max_score)

    def chunk_scanner(chunk_idx):
        # 動態切片獲取鍵的部分資料
        key_chunk = jax.lax.dynamic_slice(
            operand=key,
            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
        )

        # 動態切片獲取值的部分資料
        value_chunk = jax.lax.dynamic_slice(
            operand=value,
            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
        )

        return summarize_chunk(query, key_chunk, value_chunk)

    # 對每個鍵塊進行注意力計算
    chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

    # 計算全域性最大得分
    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
    # 計算每個塊與全域性最大得分的差異
    max_diffs = jnp.exp(chunk_max - global_max)

    # 更新值和權重以便於歸一化
    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
    chunk_weights *= max_diffs

    # 計算所有塊的總值和總權重
    all_values = chunk_values.sum(axis=0)
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)

    # 返回歸一化後的總值
    return all_values / all_weights


def jax_memory_efficient_attention(
    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
    r"""
    # Flax 實現的記憶體高效多頭點積注意力機制,相關文獻連結
    Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
    # 相關 GitHub 專案連結
    https://github.com/AminRezaei0x443/memory-efficient-attention

    # 引數說明:
    # query: 輸入的查詢張量,形狀為 (batch..., query_length, head, query_key_depth_per_head)
    Args:
        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
        # key: 輸入的鍵張量,形狀為 (batch..., key_value_length, head, query_key_depth_per_head)
        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
        # value: 輸入的值張量,形狀為 (batch..., key_value_length, head, value_depth_per_head)
        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
        # precision: 計算時的數值精度,預設值為 jax.lax.Precision.HIGHEST
        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
            numerical precision for computation
        # query_chunk_size: 將查詢陣列劃分的塊大小,必須能整除 query_length
        query_chunk_size (`int`, *optional*, defaults to 1024):
            chunk size to divide query array value must divide query_length equally without remainder
        # key_chunk_size: 將鍵和值陣列劃分的塊大小,必須能整除 key_value_length
        key_chunk_size (`int`, *optional*, defaults to 4096):
            chunk size to divide key and value array value must divide key_value_length equally without remainder

    # 返回值為形狀為 (batch..., query_length, head, value_depth_per_head) 的陣列
    Returns:
        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
    """
    # 獲取查詢張量的最後三個維度的大小
    num_q, num_heads, q_features = query.shape[-3:]

    # 定義一個函式,用於掃描處理每個查詢塊
    def chunk_scanner(chunk_idx, _):
        # 從查詢陣列中切片出當前塊
        query_chunk = jax.lax.dynamic_slice(
            # 操作的物件是查詢張量
            operand=query,
            # 起始索引,保持前面的維度不變,從 chunk_idx 開始切片
            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
            # 切片的大小,前面的維度不變,後面根據塊大小取最小值
            slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
        )

        return (
            # 返回未使用的下一個塊索引
            chunk_idx + query_chunk_size,  # unused ignore it
            # 呼叫注意力函式處理當前查詢塊
            _query_chunk_attention(
                query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
            ),
        )

    # 使用 jax.lax.scan 進行塊的掃描處理
    _, res = jax.lax.scan(
        f=chunk_scanner,  # 處理函式
        init=0,  # 初始化塊索引為 0
        xs=None,  # 不需要額外的輸入資料
        # 根據查詢塊大小計算要處理的塊數
        length=math.ceil(num_q / query_chunk_size),  # start counter  # stop counter
    )

    # 將所有塊的結果在第 -3 維度拼接在一起
    return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
# 定義一個 Flax 的多頭注意力模組,遵循文獻中的描述
class FlaxAttention(nn.Module):
    r"""
    Flax多頭注意力模組,詳見: https://arxiv.org/abs/1706.03762

    引數:
        query_dim (:obj:`int`):
            輸入隱藏狀態的維度
        heads (:obj:`int`, *optional*, defaults to 8):
            注意力頭的數量
        dim_head (:obj:`int`, *optional*, defaults to 64):
            每個頭內隱藏狀態的維度
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            dropout比率
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            啟用記憶體高效注意力 https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            是否將頭維度拆分為自注意力計算的新軸。通常情況下,啟用該標誌可以加快Stable Diffusion 2.x和Stable Diffusion XL的計算速度。
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            引數的 `dtype`

    """

    # 定義輸入引數的型別和預設值
    query_dim: int
    heads: int = 8
    dim_head: int = 64
    dropout: float = 0.0
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    dtype: jnp.dtype = jnp.float32

    # 設定模組的初始化函式
    def setup(self):
        # 計算內部維度為每個頭的維度與頭的數量的乘積
        inner_dim = self.dim_head * self.heads
        # 計算縮放因子
        self.scale = self.dim_head**-0.5

        # 建立權重矩陣,使用舊的命名 {to_q, to_k, to_v, to_out}
        self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
        # 建立鍵的權重矩陣
        self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
        # 建立值的權重矩陣
        self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

        # 建立輸出的權重矩陣
        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
        # 建立dropout層
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    # 將張量的頭部維度重塑為批次維度
    def reshape_heads_to_batch_dim(self, tensor):
        # 解構張量的形狀
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        # 重塑張量形狀以分離頭維度
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        # 轉置張量的維度
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        # 進一步重塑為批次與頭維度合併
        tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    # 將張量的批次維度重塑為頭部維度
    def reshape_batch_dim_to_heads(self, tensor):
        # 解構張量的形狀
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        # 重塑張量形狀以合併批次與頭維度
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        # 轉置張量的維度
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        # 進一步重塑為合併批次與頭維度
        tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

# 定義一個 Flax 基礎變換器塊層,使用 GLU 啟用函式,詳見:
class FlaxBasicTransformerBlock(nn.Module):
    r"""
    Flax 變換器塊層,使用 `GLU` (門控線性單元) 啟用函式,詳見:
    https://arxiv.org/abs/1706.03762
    # 引數說明部分
    Parameters:
        dim (:obj:`int`):  # 內部隱藏狀態的維度
            Inner hidden states dimension
        n_heads (:obj:`int`):  # 注意力頭的數量
            Number of heads
        d_head (:obj:`int`):  # 每個頭內部隱藏狀態的維度
            Hidden states dimension inside each head
        dropout (:obj:`float`, *optional*, defaults to 0.0):  # 隨機失活率
            Dropout rate
        only_cross_attention (`bool`, defaults to `False`):  # 是否僅應用交叉注意力
            Whether to only apply cross attention.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):  # 引數資料型別
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):  # 啟用記憶體高效注意力
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):  # 是否將頭維度拆分為新軸
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
    """

    dim: int  # 內部隱藏狀態維度的型別宣告
    n_heads: int  # 注意力頭數量的型別宣告
    d_head: int  # 每個頭的隱藏狀態維度的型別宣告
    dropout: float = 0.0  # 隨機失活率的預設值
    only_cross_attention: bool = False  # 預設不只應用交叉注意力
    dtype: jnp.dtype = jnp.float32  # 預設資料型別為 jnp.float32
    use_memory_efficient_attention: bool = False  # 預設不啟用記憶體高效注意力
    split_head_dim: bool = False  # 預設不拆分頭維度

    def setup(self):
        # 設定自注意力(如果 only_cross_attention 為 True,則為交叉注意力)
        self.attn1 = FlaxAttention(
            self.dim,  # 傳入的內部隱藏狀態維度
            self.n_heads,  # 傳入的注意力頭數量
            self.d_head,  # 傳入的每個頭的隱藏狀態維度
            self.dropout,  # 傳入的隨機失活率
            self.use_memory_efficient_attention,  # 是否使用記憶體高效注意力
            self.split_head_dim,  # 是否拆分頭維度
            dtype=self.dtype,  # 傳入的資料型別
        )
        # 設定交叉注意力
        self.attn2 = FlaxAttention(
            self.dim,  # 傳入的內部隱藏狀態維度
            self.n_heads,  # 傳入的注意力頭數量
            self.d_head,  # 傳入的每個頭的隱藏狀態維度
            self.dropout,  # 傳入的隨機失活率
            self.use_memory_efficient_attention,  # 是否使用記憶體高效注意力
            self.split_head_dim,  # 是否拆分頭維度
            dtype=self.dtype,  # 傳入的資料型別
        )
        # 設定前饋網路
        self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)  # 前饋網路初始化
        # 設定第一個歸一化層
        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)  # 歸一化層初始化
        # 設定第二個歸一化層
        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)  # 歸一化層初始化
        # 設定第三個歸一化層
        self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)  # 歸一化層初始化
        # 設定丟棄層
        self.dropout_layer = nn.Dropout(rate=self.dropout)  # 丟棄層初始化
    # 定義可呼叫物件,接收隱藏狀態、上下文和確定性標誌
    def __call__(self, hidden_states, context, deterministic=True):
        # 儲存輸入的隱藏狀態以供後續殘差連線使用
        residual = hidden_states
        # 如果僅執行交叉注意力,進行相關的處理
        if self.only_cross_attention:
            hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
        else:
            # 否則執行自注意力處理
            hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
        # 將自注意力的輸出與輸入的殘差相加
        hidden_states = hidden_states + residual
    
        # 交叉注意力處理
        residual = hidden_states
        # 處理交叉注意力
        hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
        # 將交叉注意力的輸出與輸入的殘差相加
        hidden_states = hidden_states + residual
    
        # 前饋網路處理
        residual = hidden_states
        # 應用前饋網路
        hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
        # 將前饋網路的輸出與輸入的殘差相加
        hidden_states = hidden_states + residual
    
        # 返回經過 dropout 處理的最終隱藏狀態
        return self.dropout_layer(hidden_states, deterministic=deterministic)
# 定義一個二維的 Flax Transformer 模型,繼承自 nn.Module
class FlaxTransformer2DModel(nn.Module):
    r"""
    A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
    https://arxiv.org/pdf/1506.02025.pdf

    文件字串,描述該類的功能和引數。

    Parameters:
        in_channels (:obj:`int`):
            Input number of channels
        n_heads (:obj:`int`):
            Number of heads
        d_head (:obj:`int`):
            Hidden states dimension inside each head
        depth (:obj:`int`, *optional*, defaults to 1):
            Number of transformers block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_linear_projection (`bool`, defaults to `False`): tbd
        only_cross_attention (`bool`, defaults to `False`): tbd
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
    """
    
    # 定義輸入通道數
    in_channels: int
    # 定義頭的數量
    n_heads: int
    # 定義每個頭的隱藏狀態維度
    d_head: int
    # 定義 Transformer 塊的數量,預設為 1
    depth: int = 1
    # 定義 Dropout 率,預設為 0.0
    dropout: float = 0.0
    # 定義是否使用線性投影,預設為 False
    use_linear_projection: bool = False
    # 定義是否僅使用交叉注意力,預設為 False
    only_cross_attention: bool = False
    # 定義引數的資料型別,預設為 jnp.float32
    dtype: jnp.dtype = jnp.float32
    # 定義是否使用記憶體高效注意力,預設為 False
    use_memory_efficient_attention: bool = False
    # 定義是否將頭維度拆分為新的軸,預設為 False
    split_head_dim: bool = False

    # 設定模型的元件
    def setup(self):
        # 使用 Group Normalization 規範化層,分組數為 32,epsilon 為 1e-5
        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

        # 計算內部維度為頭的數量乘以每個頭的維度
        inner_dim = self.n_heads * self.d_head
        # 根據是否使用線性投影選擇輸入層
        if self.use_linear_projection:
            # 建立一個線性投影層,輸出維度為 inner_dim,資料型別為 self.dtype
            self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
        else:
            # 建立一個卷積層,輸出維度為 inner_dim,卷積核大小為 (1, 1),步幅為 (1, 1),填充方式為 "VALID",資料型別為 self.dtype
            self.proj_in = nn.Conv(
                inner_dim,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

        # 建立一系列 Transformer 塊,數量為 depth
        self.transformer_blocks = [
            FlaxBasicTransformerBlock(
                inner_dim,
                self.n_heads,
                self.d_head,
                dropout=self.dropout,
                only_cross_attention=self.only_cross_attention,
                dtype=self.dtype,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
            )
            for _ in range(self.depth)  # 迴圈生成每個 Transformer 塊
        ]

        # 根據是否使用線性投影選擇輸出層
        if self.use_linear_projection:
            # 建立一個線性投影層,輸出維度為 inner_dim,資料型別為 self.dtype
            self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
        else:
            # 建立一個卷積層,輸出維度為 inner_dim,卷積核大小為 (1, 1),步幅為 (1, 1),填充方式為 "VALID",資料型別為 self.dtype
            self.proj_out = nn.Conv(
                inner_dim,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

        # 建立一個 Dropout 層,Dropout 率為 self.dropout
        self.dropout_layer = nn.Dropout(rate=self.dropout)
    # 定義可呼叫物件的方法,接收隱藏狀態、上下文和確定性標誌
    def __call__(self, hidden_states, context, deterministic=True):
        # 解構隱藏狀態的形狀,獲取批次大小、高度、寬度和通道數
        batch, height, width, channels = hidden_states.shape
        # 儲存原始隱藏狀態以用於殘差連線
        residual = hidden_states
        # 對隱藏狀態進行歸一化處理
        hidden_states = self.norm(hidden_states)
        # 如果使用線性投影,則重塑隱藏狀態
        if self.use_linear_projection:
            # 將隱藏狀態重塑為(batch, height * width, channels)的形狀
            hidden_states = hidden_states.reshape(batch, height * width, channels)
            # 應用輸入投影
            hidden_states = self.proj_in(hidden_states)
        else:
            # 直接應用輸入投影
            hidden_states = self.proj_in(hidden_states)
            # 將隱藏狀態重塑為(batch, height * width, channels)的形狀
            hidden_states = hidden_states.reshape(batch, height * width, channels)
    
        # 遍歷每個變換塊,更新隱藏狀態
        for transformer_block in self.transformer_blocks:
            # 透過變換塊處理隱藏狀態和上下文
            hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
    
        # 如果使用線性投影,則先應用輸出投影
        if self.use_linear_projection:
            hidden_states = self.proj_out(hidden_states)
            # 將隱藏狀態重塑回原來的形狀
            hidden_states = hidden_states.reshape(batch, height, width, channels)
        else:
            # 先重塑隱藏狀態
            hidden_states = hidden_states.reshape(batch, height, width, channels)
            # 再應用輸出投影
            hidden_states = self.proj_out(hidden_states)
    
        # 將隱藏狀態與原始狀態相加,實現殘差連線
        hidden_states = hidden_states + residual
        # 返回經過dropout層處理後的隱藏狀態
        return self.dropout_layer(hidden_states, deterministic=deterministic)
# 定義一個 Flax 的前饋神經網路模組,繼承自 nn.Module
class FlaxFeedForward(nn.Module):
    r"""
    Flax 模組封裝了兩個線性層,中間由一個非線性啟用函式分隔。它是 PyTorch 的
    [`FeedForward`] 類的對應物,具有以下簡化:
    - 啟用函式目前硬編碼為門控線性單元,來自:
    https://arxiv.org/abs/2002.05202
    - `dim_out` 等於 `dim`。
    - 隱藏維度的數量硬編碼為 `dim * 4` 在 [`FlaxGELU`] 中。

    引數:
        dim (:obj:`int`):
            內部隱藏狀態的維度
        dropout (:obj:`float`, *可選*, 預設為 0.0):
            丟棄率
        dtype (:obj:`jnp.dtype`, *可選*, 預設為 jnp.float32):
            引數的資料型別
    """

    # 定義類屬性 dim、dropout 和 dtype,分別表示維度、丟棄率和資料型別
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    # 設定方法,初始化網路層
    def setup(self):
        # 第二個線性層暫時稱為 net_2,以匹配順序層的索引
        self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)  # 初始化 FlaxGEGLU 網路
        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)  # 初始化線性層

    # 定義前向傳播方法
    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.net_0(hidden_states, deterministic=deterministic)  # 透過 net_0 處理隱藏狀態
        hidden_states = self.net_2(hidden_states)  # 透過 net_2 處理隱藏狀態
        return hidden_states  # 返回處理後的隱藏狀態


# 定義 Flax 的 GEGLU 啟用層,繼承自 nn.Module
class FlaxGEGLU(nn.Module):
    r"""
    Flax 實現的線性層後跟門控線性單元啟用函式變體,來自
    https://arxiv.org/abs/2002.05202。

    引數:
        dim (:obj:`int`):
            輸入隱藏狀態的維度
        dropout (:obj:`float`, *可選*, 預設為 0.0):
            丟棄率
        dtype (:obj:`jnp.dtype`, *可選*, 預設為 jnp.float32):
            引數的資料型別
    """

    # 定義類屬性 dim、dropout 和 dtype
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    # 設定方法,初始化網路層
    def setup(self):
        inner_dim = self.dim * 4  # 計算內部維度
        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)  # 初始化線性層
        self.dropout_layer = nn.Dropout(rate=self.dropout)  # 初始化丟棄層

    # 定義前向傳播方法
    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.proj(hidden_states)  # 透過線性層處理隱藏狀態
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)  # 將輸出分為兩個部分
        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)  # 返回帶丟棄的啟用輸出

.\diffusers\models\attention_processor.py

# 版權宣告,標明該檔案的版權歸 HuggingFace 團隊所有
# 該檔案根據 Apache 2.0 許可證進行許可
# 在遵守許可證的情況下,您可以使用該檔案
# 許可證的副本可以在以下網址獲取
# http://www.apache.org/licenses/LICENSE-2.0
# 除非法律要求或書面同意,否則軟體按 "現狀" 提供,不附帶任何明示或暗示的擔保
# 請參閱許可證以瞭解有關許可權和限制的具體資訊

import inspect  # 匯入 inspect 模組,用於獲取物件的資訊
import math  # 匯入 math 模組,提供數學函式
from typing import Callable, List, Optional, Tuple, Union  # 匯入型別提示相關的型別

import torch  # 匯入 PyTorch 庫
import torch.nn.functional as F  # 匯入 PyTorch 中的神經網路功能模組,並重新命名為 F
from torch import nn  # 從 PyTorch 匯入 nn 模組,提供神經網路的構建塊

from ..image_processor import IPAdapterMaskProcessor  # 從上層模組匯入 IPAdapterMaskProcessor
from ..utils import deprecate, logging  # 從上層模組匯入棄用和日誌記錄功能
from ..utils.import_utils import is_torch_npu_available, is_xformers_available  # 匯入檢查 PyTorch NPU 和 xformers 可用性的工具
from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph  # 匯入與 PyTorch 版本和圖形相關的工具

logger = logging.get_logger(__name__)  # 獲取當前模組的日誌記錄器例項,便於記錄日誌資訊

if is_torch_npu_available():  # 檢查是否可以使用 PyTorch NPU
    import torch_npu  # 如果可用,則匯入 torch_npu 模組

if is_xformers_available():  # 檢查是否可以使用 xformers 庫
    import xformers  # 如果可用,匯入 xformers 模組
    import xformers.ops  # 匯入 xformers 中的操作模組
else:  # 如果 xformers 不可用
    xformers = None  # 將 xformers 設為 None

@maybe_allow_in_graph  # 裝飾器,可能允許在圖中使用該類
class Attention(nn.Module):  # 定義 Attention 類,繼承自 nn.Module
    r"""  # 文件字串,描述該類是一個交叉注意力層
    A cross attention layer.
    """

    def __init__(  # 初始化方法,定義建構函式
        self,
        query_dim: int,  # 查詢維度,型別為整數
        cross_attention_dim: Optional[int] = None,  # 可選的交叉注意力維度,預設為 None
        heads: int = 8,  # 注意力頭的數量,預設為 8
        kv_heads: Optional[int] = None,  # 可選的鍵值頭數量,預設為 None
        dim_head: int = 64,  # 每個頭的維度,預設為 64
        dropout: float = 0.0,  # dropout 機率,預設為 0.0
        bias: bool = False,  # 是否使用偏置,預設為 False
        upcast_attention: bool = False,  # 是否上升注意力精度,預設為 False
        upcast_softmax: bool = False,  # 是否上升 softmax 精度,預設為 False
        cross_attention_norm: Optional[str] = None,  # 可選的交叉注意力歸一化方式,預設為 None
        cross_attention_norm_num_groups: int = 32,  # 交叉注意力歸一化的組數量,預設為 32
        qk_norm: Optional[str] = None,  # 可選的查詢鍵歸一化方式,預設為 None
        added_kv_proj_dim: Optional[int] = None,  # 可選的新增鍵值投影維度,預設為 None
        added_proj_bias: Optional[bool] = True,  # 是否為新增的投影使用偏置,預設為 True
        norm_num_groups: Optional[int] = None,  # 可選的歸一化組數量,預設為 None
        spatial_norm_dim: Optional[int] = None,  # 可選的空間歸一化維度,預設為 None
        out_bias: bool = True,  # 是否使用輸出偏置,預設為 True
        scale_qk: bool = True,  # 是否縮放查詢和鍵,預設為 True
        only_cross_attention: bool = False,  # 是否僅使用交叉注意力,預設為 False
        eps: float = 1e-5,  # 為數值穩定性引入的微小常數,預設為 1e-5
        rescale_output_factor: float = 1.0,  # 輸出重標定因子,預設為 1.0
        residual_connection: bool = False,  # 是否使用殘差連線,預設為 False
        _from_deprecated_attn_block: bool = False,  # 可選引數,指示是否來自棄用的注意力塊,預設為 False
        processor: Optional["AttnProcessor"] = None,  # 可選的處理器,預設為 None
        out_dim: int = None,  # 輸出維度,預設為 None
        context_pre_only=None,  # 上下文前處理,預設為 None
        pre_only=False,  # 是否僅進行前處理,預設為 False
    # 設定是否使用來自 `torch_npu` 的 npu flash attention
    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
        r"""
        設定是否使用來自 `torch_npu` 的 npu flash attention。
        """
        # 如果選擇使用 npu flash attention
        if use_npu_flash_attention:
            # 建立 NPU 注意力處理器例項
            processor = AttnProcessorNPU()
        else:
            # 設定注意力處理器
            # 預設情況下使用 AttnProcessor2_0,當使用 torch 2.x 時,
            # 它利用 torch.nn.functional.scaled_dot_product_attention 進行本地 Flash/記憶體高效注意力
            # 僅在其具有預設 `scale` 引數時適用。TODO: 在遷移到 torch 2.1 時移除 scale_qk 檢查
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
            )
        # 設定當前的處理器
        self.set_processor(processor)
    
    # 設定是否使用記憶體高效的 xformers 注意力
    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        pass  # 此處可能缺少實現
    
    # 設定注意力計算的切片大小
    def set_attention_slice(self, slice_size: int) -> None:
        r"""
        設定注意力計算的切片大小。
    
        引數:
            slice_size (`int`):
                用於注意力計算的切片大小。
        """
        # 如果切片大小不為 None 且大於可切片頭維度
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            # 丟擲值錯誤,切片大小必須小於或等於可切片頭維度
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
    
        # 如果切片大小不為 None 且新增的 kv 投影維度不為 None
        if slice_size is not None and self.added_kv_proj_dim is not None:
            # 建立帶切片大小的 KV 處理器例項
            processor = SlicedAttnAddedKVProcessor(slice_size)
        # 如果切片大小不為 None
        elif slice_size is not None:
            # 建立帶切片大小的注意力處理器例項
            processor = SlicedAttnProcessor(slice_size)
        # 如果新增的 kv 投影維度不為 None
        elif self.added_kv_proj_dim is not None:
            # 建立 KV 注意力處理器例項
            processor = AttnAddedKVProcessor()
        else:
            # 設定注意力處理器
            # 預設情況下使用 AttnProcessor2_0,當使用 torch 2.x 時,
            # 它利用 torch.nn.functional.scaled_dot_product_attention 進行本地 Flash/記憶體高效注意力
            # 僅在其具有預設 `scale` 引數時適用。TODO: 在遷移到 torch 2.1 時移除 scale_qk 檢查
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
            )
    
        # 設定當前的處理器
        self.set_processor(processor)
    # 設定要使用的注意力處理器
    def set_processor(self, processor: "AttnProcessor") -> None:
        r"""
        設定要使用的注意力處理器。
    
        引數:
            processor (`AttnProcessor`):
                要使用的注意力處理器。
        """
        # 如果當前處理器在 `self._modules` 中,且傳入的 `processor` 不在其中,則需要從 `self._modules` 中移除當前處理器
        if (
            hasattr(self, "processor")  # 檢查當前物件是否有處理器屬性
            and isinstance(self.processor, torch.nn.Module)  # 確保當前處理器是一個 PyTorch 模組
            and not isinstance(processor, torch.nn.Module)  # 檢查傳入的處理器不是 PyTorch 模組
        ):
            # 記錄日誌,指出將移除已訓練權重的處理器
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            # 從模組中移除當前處理器
            self._modules.pop("processor")
    
        # 設定當前物件的處理器為傳入的處理器
        self.processor = processor
    
    # 獲取正在使用的注意力處理器
    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
        r"""
        獲取正在使用的注意力處理器。
    
        引數:
            return_deprecated_lora (`bool`, *可選*, 預設為 `False`):
                設定為 `True` 以返回過時的 LoRA 注意力處理器。
    
        返回:
            "AttentionProcessor": 正在使用的注意力處理器。
        """
        # 如果不需要返回過時的 LoRA 處理器,則返回當前處理器
        if not return_deprecated_lora:
            return self.processor
    
    # 前向傳播方法,處理輸入的隱藏狀態
    def forward(
        self,
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態張量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 可選的編碼器隱藏狀態張量
        attention_mask: Optional[torch.Tensor] = None,  # 可選的注意力掩碼張量
        **cross_attention_kwargs,  # 可變引數,用於交叉注意力
    ) -> torch.Tensor:
        r"""  # 文件字串,描述此方法的功能和引數
        The forward method of the `Attention` class.

        Args:  # 引數說明
            hidden_states (`torch.Tensor`):  # 查詢的隱藏狀態,型別為張量
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):  # 編碼器的隱藏狀態,可選引數
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):  # 注意力掩碼,可選引數
                The attention mask to use. If `None`, no mask is applied.
            **cross_attention_kwargs:  # 額外的關鍵字引數,傳遞給交叉注意力
                Additional keyword arguments to pass along to the cross attention.

        Returns:  # 返回值說明
            `torch.Tensor`: The output of the attention layer.  # 返回注意力層的輸出
        """
        # `Attention` 類可以呼叫不同的注意力處理器/函式
        # 這裡我們簡單地將所有張量傳遞給所選的處理器類
        # 對於此處定義的標準處理器,`**cross_attention_kwargs` 是空的

        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())  # 獲取處理器呼叫方法的引數名集合
        quiet_attn_parameters = {"ip_adapter_masks"}  # 定義不需要警告的引數集合
        unused_kwargs = [  # 篩選出未被使用的關鍵字引數
            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
        ]
        if len(unused_kwargs) > 0:  # 如果存在未使用的關鍵字引數
            logger.warning(  # 記錄警告日誌
                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}  # 過濾出有效的關鍵字引數

        return self.processor(  # 呼叫處理器並返回結果
            self,
            hidden_states,  # 傳遞隱藏狀態
            encoder_hidden_states=encoder_hidden_states,  # 傳遞編碼器的隱藏狀態
            attention_mask=attention_mask,  # 傳遞注意力掩碼
            **cross_attention_kwargs,  # 解包有效的額外關鍵字引數
        )

    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:  # 定義方法,輸入張量並返回處理後的張量
        r"""  # 文件字串,描述此方法的功能和引數
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`  # 將張量從 `[batch_size, seq_len, dim]` 重新形狀為 `[batch_size // heads, seq_len, dim * heads]`,`heads` 為初始化時的頭數量
        is the number of heads initialized while constructing the `Attention` class.

        Args:  # 引數說明
            tensor (`torch.Tensor`): The tensor to reshape.  # 要重新形狀的張量

        Returns:  # 返回值說明
            `torch.Tensor`: The reshaped tensor.  # 返回重新形狀後的張量
        """
        head_size = self.heads  # 獲取頭的數量
        batch_size, seq_len, dim = tensor.shape  # 解包輸入張量的形狀
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)  # 重新調整張量的形狀
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)  # 調整維度順序並重新形狀
        return tensor  # 返回處理後的張量
    # 將輸入張量從形狀 `[batch_size, seq_len, dim]` 轉換為 `[batch_size, seq_len, heads, dim // heads]`
    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
            r"""
            將張量從 `[batch_size, seq_len, dim]` 重塑為 `[batch_size, seq_len, heads, dim // heads]`,其中 `heads` 是
            在構造 `Attention` 類時初始化的頭數。
    
            引數:
                tensor (`torch.Tensor`): 要重塑的張量。
                out_dim (`int`, *可選*, 預設值為 `3`): 張量的輸出維度。如果為 `3`,則張量被
                    重塑為 `[batch_size * heads, seq_len, dim // heads]`。
    
            返回:
                `torch.Tensor`: 重塑後的張量。
            """
            # 獲取頭的數量
            head_size = self.heads
            # 檢查輸入張量的維度,如果是三維則提取形狀資訊
            if tensor.ndim == 3:
                batch_size, seq_len, dim = tensor.shape
                extra_dim = 1
            else:
                # 如果不是三維,提取四維形狀資訊
                batch_size, extra_dim, seq_len, dim = tensor.shape
            # 重塑張量為 `[batch_size, seq_len * extra_dim, head_size, dim // head_size]`
            tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
            # 調整張量維度順序為 `[batch_size, heads, seq_len * extra_dim, dim // heads]`
            tensor = tensor.permute(0, 2, 1, 3)
    
            # 如果輸出維度為 3,進一步重塑張量為 `[batch_size * heads, seq_len * extra_dim, dim // heads]`
            if out_dim == 3:
                tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
    
            # 返回重塑後的張量
            return tensor
    
    # 計算注意力得分的函式
    def get_attention_scores(
            self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
        ) -> torch.Tensor:
            r"""
            計算注意力得分。
    
            引數:
                query (`torch.Tensor`): 查詢張量。
                key (`torch.Tensor`): 鍵張量。
                attention_mask (`torch.Tensor`, *可選*): 使用的注意力掩碼。如果為 `None`,則不應用掩碼。
    
            返回:
                `torch.Tensor`: 注意力機率/得分。
            """
            # 獲取查詢張量的資料型別
            dtype = query.dtype
            # 如果需要上升型別,將查詢和鍵張量轉換為浮點型
            if self.upcast_attention:
                query = query.float()
                key = key.float()
    
            # 如果沒有提供注意力掩碼,建立空的輸入張量
            if attention_mask is None:
                baddbmm_input = torch.empty(
                    query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
                )
                # 設定 beta 為 0
                beta = 0
            else:
                # 如果有注意力掩碼,將其用作輸入
                baddbmm_input = attention_mask
                # 設定 beta 為 1
                beta = 1
    
            # 計算注意力得分
            attention_scores = torch.baddbmm(
                baddbmm_input,
                query,
                key.transpose(-1, -2),
                beta=beta,
                alpha=self.scale,
            )
            # 刪除臨時的輸入張量
            del baddbmm_input
    
            # 如果需要上升型別,將注意力得分轉換為浮點型
            if self.upcast_softmax:
                attention_scores = attention_scores.float()
    
            # 計算注意力機率
            attention_probs = attention_scores.softmax(dim=-1)
            # 刪除注意力得分張量
            del attention_scores
    
            # 將注意力機率轉換回原始資料型別
            attention_probs = attention_probs.to(dtype)
    
            # 返回注意力機率
            return attention_probs
    
    # 準備注意力掩碼的函式
    def prepare_attention_mask(
            self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
    ) -> torch.Tensor:  # 定義一個函式的返回型別為 torch.Tensor
        r"""  # 開始文件字串,描述函式的作用和引數
        Prepare the attention mask for the attention computation.  # 準備注意力計算的注意力掩碼
        Args:  # 引數說明
            attention_mask (`torch.Tensor`):  # 輸入引數,注意力掩碼,型別為 torch.Tensor
                The attention mask to prepare.  # 待準備的注意力掩碼
            target_length (`int`):  # 輸入引數,目標長度,型別為 int
                The target length of the attention mask. This is the length of the attention mask after padding.  # 注意力掩碼的目標長度,經過填充後的長度
            batch_size (`int`):  # 輸入引數,批處理大小,型別為 int
                The batch size, which is used to repeat the attention mask.  # 批處理大小,用於重複注意力掩碼
            out_dim (`int`, *optional*, defaults to `3`):  # 可選引數,輸出維度,型別為 int,預設為 3
                The output dimension of the attention mask. Can be either `3` or `4`.  # 注意力掩碼的輸出維度,可以是 3 或 4
        Returns:  # 返回說明
            `torch.Tensor`: The prepared attention mask.  # 返回準備好的注意力掩碼,型別為 torch.Tensor
        """  # 結束文件字串
        head_size = self.heads  # 獲取頭部大小,來自類的屬性 heads
        if attention_mask is None:  # 檢查注意力掩碼是否為 None
            return attention_mask  # 如果是 None,直接返回

        current_length: int = attention_mask.shape[-1]  # 獲取當前注意力掩碼的長度
        if current_length != target_length:  # 檢查當前長度是否與目標長度不匹配
            if attention_mask.device.type == "mps":  # 如果裝置型別是 "mps"
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.  # HACK: MPS 不支援填充超過輸入張量的維度
                # Instead, we can manually construct the padding tensor.  # 所以我們手動構建填充張量
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)  # 定義填充張量的形狀
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)  # 建立全零填充張量
                attention_mask = torch.cat([attention_mask, padding], dim=2)  # 在最後一個維度上拼接填充張量
            else:  # 如果不是 "mps" 裝置
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:  # TODO: 對於如 stable-diffusion 的管道,填充交叉注意力掩碼
                #       we want to instead pad by (0, remaining_length), where remaining_length is:  # 我們希望用 (0, remaining_length) 填充,其中 remaining_length 是
                #       remaining_length: int = target_length - current_length  # remaining_length 的計算
                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding  # TODO: 重新啟用相關測試
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)  # 用零填充注意力掩碼到目標長度

        if out_dim == 3:  # 如果輸出維度是 3
            if attention_mask.shape[0] < batch_size * head_size:  # 檢查注意力掩碼的第一維是否小於批處理大小乘以頭部大小
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)  # 在第一維上重複注意力掩碼
        elif out_dim == 4:  # 如果輸出維度是 4
            attention_mask = attention_mask.unsqueeze(1)  # 在第一維增加一個維度
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)  # 在第二維上重複注意力掩碼

        return attention_mask  # 返回準備好的注意力掩碼
    # 定義一個函式用於規範化編碼器的隱藏狀態,接受一個張量作為輸入並返回一個張量
    def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        r"""
        規範化編碼器隱藏狀態。構造 `Attention` 類時需要指定 `self.norm_cross`。

        引數:
            encoder_hidden_states (`torch.Tensor`): 編碼器的隱藏狀態。

        返回:
            `torch.Tensor`: 規範化後的編碼器隱藏狀態。
        """
        # 確保在呼叫此方法之前已定義 `self.norm_cross`
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        # 檢查 `self.norm_cross` 是否為 LayerNorm 型別
        if isinstance(self.norm_cross, nn.LayerNorm):
            # 對編碼器隱藏狀態進行層歸一化
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        # 檢查 `self.norm_cross` 是否為 GroupNorm 型別
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # GroupNorm 沿通道維度進行歸一化,並期望輸入形狀為 (N, C, *)。
            # 此時我們希望沿隱藏維度進行歸一化,因此需要調整形狀
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)  # 轉置張量以調整維度順序
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)  # 對轉置後的張量進行歸一化
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)  # 再次轉置回原始順序
        else:
            # 如果 `self.norm_cross` 既不是 LayerNorm 也不是 GroupNorm,則觸發斷言失敗
            assert False

        # 返回規範化後的編碼器隱藏狀態
        return encoder_hidden_states

    # 該裝飾器在計算圖中禁止梯度計算,以節省記憶體和加快推理速度
    @torch.no_grad()
    # 定義一個融合投影的方法,預設引數 fuse 為 True
    def fuse_projections(self, fuse=True):
        # 獲取 to_q 權重的裝置資訊
        device = self.to_q.weight.data.device
        # 獲取 to_q 權重的資料型別
        dtype = self.to_q.weight.data.dtype

        # 如果不是交叉注意力
        if not self.is_cross_attention:
            # 獲取權重矩陣的拼接
            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
            # 輸入特徵數為拼接後權重的列數
            in_features = concatenated_weights.shape[1]
            # 輸出特徵數為拼接後權重的行數
            out_features = concatenated_weights.shape[0]

            # 建立一個新的線性投影層並複製權重
            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
            # 複製拼接後的權重到新的層
            self.to_qkv.weight.copy_(concatenated_weights)
            # 如果使用偏置
            if self.use_bias:
                # 拼接 q、k、v 的偏置
                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
                # 複製拼接後的偏置到新的層
                self.to_qkv.bias.copy_(concatenated_bias)

        # 如果是交叉注意力
        else:
            # 獲取 k 和 v 權重的拼接
            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
            # 輸入特徵數為拼接後權重的列數
            in_features = concatenated_weights.shape[1]
            # 輸出特徵數為拼接後權重的行數
            out_features = concatenated_weights.shape[0]

            # 建立一個新的線性投影層並複製權重
            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
            # 複製拼接後的權重到新的層
            self.to_kv.weight.copy_(concatenated_weights)
            # 如果使用偏置
            if self.use_bias:
                # 拼接 k 和 v 的偏置
                concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
                # 複製拼接後的偏置到新的層
                self.to_kv.bias.copy_(concatenated_bias)

        # 處理 SD3 和其他新增的投影
        if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
            # 獲取額外投影的權重拼接
            concatenated_weights = torch.cat(
                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
            )
            # 輸入特徵數為拼接後權重的列數
            in_features = concatenated_weights.shape[1]
            # 輸出特徵數為拼接後權重的行數
            out_features = concatenated_weights.shape[0]

            # 建立一個新的線性投影層並複製權重
            self.to_added_qkv = nn.Linear(
                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
            )
            # 複製拼接後的權重到新的層
            self.to_added_qkv.weight.copy_(concatenated_weights)
            # 如果使用偏置
            if self.added_proj_bias:
                # 拼接額外投影的偏置
                concatenated_bias = torch.cat(
                    [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
                )
                # 複製拼接後的偏置到新的層
                self.to_added_qkv.bias.copy_(concatenated_bias)

        # 將融合狀態儲存到屬性中
        self.fused_projections = fuse
# 定義一個處理器類,用於執行與注意力相關的計算
class AttnProcessor:
    r"""
    預設處理器,用於執行與注意力相關的計算。
    """

    # 實現可呼叫方法,處理注意力計算
    def __call__(
        self,
        attn: Attention,  # 注意力物件
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態張量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 編碼器隱藏狀態(可選)
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩碼(可選)
        temb: Optional[torch.Tensor] = None,  # 額外的時間嵌入(可選)
        *args,  # 額外的位置引數
        **kwargs,  # 額外的關鍵字引數
    ) -> torch.Tensor:  # 返回處理後的張量
        # 檢查是否有額外引數或已棄用的 scale 引數
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 構建棄用警告訊息
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 呼叫棄用處理函式
            deprecate("scale", "1.0.0", deprecation_message)

        # 初始化殘差為隱藏狀態
        residual = hidden_states

        # 如果空間歸一化存在,則應用於隱藏狀態
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # 獲取輸入張量的維度
        input_ndim = hidden_states.ndim

        # 如果輸入是四維的,則調整形狀
        if input_ndim == 4:
            # 解包隱藏狀態的形狀
            batch_size, channel, height, width = hidden_states.shape
            # 重新調整形狀為(batch_size, channel, height*width)並轉置
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # 根據編碼器隱藏狀態的存在與否,獲取批次大小和序列長度
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # 準備注意力掩碼
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # 如果組歸一化存在,則應用於隱藏狀態
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        # 將隱藏狀態轉換為查詢向量
        query = attn.to_q(hidden_states)

        # 如果沒有編碼器隱藏狀態,使用隱藏狀態作為編碼器隱藏狀態
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        # 如果需要規範化編碼器隱藏狀態,則應用規範化
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # 從編碼器隱藏狀態中獲取鍵和值
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # 將查詢、鍵和值轉換為批次維度
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # 計算注意力分數
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        # 透過注意力分數加權求值
        hidden_states = torch.bmm(attention_probs, value)
        # 將隱藏狀態轉換回頭維度
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # 線性投影
        hidden_states = attn.to_out[0](hidden_states)
        # 應用 dropout
        hidden_states = attn.to_out[1](hidden_states)

        # 如果輸入是四維的,調整回原始形狀
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # 如果存在殘差連線,則將殘差加回隱藏狀態
        if attn.residual_connection:
            hidden_states = hidden_states + residual

        # 將隱藏狀態歸一化到輸出因子
        hidden_states = hidden_states / attn.rescale_output_factor

        # 返回最終的隱藏狀態
        return hidden_states


# 定義一個處理器類,用於實現自定義擴散方法的注意力
class CustomDiffusionAttnProcessor(nn.Module):
    r"""
    實現自定義擴散方法的注意力處理器。
    # 定義引數說明
    Args:
        train_kv (`bool`, defaults to `True`):  # 是否重新訓練對應於文字特徵的鍵值矩陣
            Whether to newly train the key and value matrices corresponding to the text features.
        train_q_out (`bool`, defaults to `True`):  # 是否重新訓練對應於潛在影像特徵的查詢矩陣
            Whether to newly train query matrices corresponding to the latent image features.
        hidden_size (`int`, *optional*, defaults to `None`):  # 注意力層的隱藏大小
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*, defaults to `None`):  # 編碼器隱藏狀態中的通道數量
            The number of channels in the `encoder_hidden_states`.
        out_bias (`bool`, defaults to `True`):  # 是否在 `train_q_out` 中包含偏置引數
            Whether to include the bias parameter in `train_q_out`.
        dropout (`float`, *optional*, defaults to 0.0):  # 使用的 dropout 機率
            The dropout probability to use.
    """

    # 初始化方法
    def __init__(
        self,  # 初始化方法的第一個引數,表示物件本身
        train_kv: bool = True,  # 設定鍵值矩陣訓練的預設值為 True
        train_q_out: bool = True,  # 設定查詢矩陣訓練的預設值為 True
        hidden_size: Optional[int] = None,  # 隱藏層大小,預設為 None
        cross_attention_dim: Optional[int] = None,  # 跨注意力維度,預設為 None
        out_bias: bool = True,  # 輸出偏置引數的預設值為 True
        dropout: float = 0.0,  # 預設的 dropout 機率為 0.0
    ):
        super().__init__()  # 呼叫父類的初始化方法
        self.train_kv = train_kv  # 儲存鍵值訓練標誌
        self.train_q_out = train_q_out  # 儲存查詢輸出訓練標誌

        self.hidden_size = hidden_size  # 儲存隱藏層大小
        self.cross_attention_dim = cross_attention_dim  # 儲存跨注意力維度

        # `_custom_diffusion` id 方便序列化和載入
        if self.train_kv:  # 如果需要訓練鍵值
            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)  # 建立鍵的線性層
            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)  # 建立值的線性層
        if self.train_q_out:  # 如果需要訓練查詢輸出
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)  # 建立查詢的線性層
            self.to_out_custom_diffusion = nn.ModuleList([])  # 初始化輸出層的模組列表
            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))  # 新增線性輸出層
            self.to_out_custom_diffusion.append(nn.Dropout(dropout))  # 新增 dropout 層

    # 可呼叫方法
    def __call__(  # 定義物件被呼叫時的行為
        self,  # 第一個引數,表示物件本身
        attn: Attention,  # 注意力物件
        hidden_states: torch.Tensor,  # 隱藏狀態張量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 編碼器隱藏狀態,預設為 None
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩碼,預設為 None
    # 返回型別為 torch.Tensor
    ) -> torch.Tensor:
        # 獲取隱藏狀態的批次大小和序列長度
        batch_size, sequence_length, _ = hidden_states.shape
        # 準備注意力掩碼以適應當前批次和序列長度
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        # 如果需要訓練查詢輸出,則使用自定義擴散進行轉換
        if self.train_q_out:
            query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
        else:
            # 否則使用標準的查詢轉換
            query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
    
        # 檢查編碼器隱藏狀態是否為 None
        if encoder_hidden_states is None:
            # 如果是,則不進行交叉注意力
            crossattn = False
            encoder_hidden_states = hidden_states
        else:
            # 否則,啟用交叉注意力
            crossattn = True
            # 如果需要歸一化編碼器隱藏狀態,則進行歸一化
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
    
        # 如果需要訓練鍵值對
        if self.train_kv:
            # 使用自定義擴散獲取鍵和值
            key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
            value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
            # 將鍵和值轉換為查詢的權重資料型別
            key = key.to(attn.to_q.weight.dtype)
            value = value.to(attn.to_q.weight.dtype)
        else:
            # 否則使用標準的鍵和值轉換
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)
    
        # 如果進行交叉注意力
        if crossattn:
            # 建立與鍵相同形狀的張量以進行detach操作
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :] * 0.0
            # 應用detach邏輯以阻止梯度流動
            key = detach * key + (1 - detach) * key.detach()
            value = detach * value + (1 - detach) * value.detach()
    
        # 將查詢、鍵和值轉換為批次維度
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
    
        # 計算注意力分數
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        # 使用注意力分數和值進行批次矩陣乘法
        hidden_states = torch.bmm(attention_probs, value)
        # 將隱藏狀態轉換回頭維度
        hidden_states = attn.batch_to_head_dim(hidden_states)
    
        # 如果需要訓練查詢輸出
        if self.train_q_out:
            # 線性投影
            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
            # 應用dropout
            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
        else:
            # 否則使用標準的線性投影
            hidden_states = attn.to_out[0](hidden_states)
            # 應用dropout
            hidden_states = attn.to_out[1](hidden_states)
    
        # 返回最終的隱藏狀態
        return hidden_states
# 定義一個帶有額外可學習的鍵和值矩陣的注意力處理器類
class AttnAddedKVProcessor:
    r"""
    處理器,用於執行與文字編碼器相關的注意力計算
    """

    # 定義呼叫方法,以實現注意力計算
    def __call__(
        self,
        attn: Attention,  # 注意力物件
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態張量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 編碼器的隱藏狀態(可選)
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩碼(可選)
        *args,  # 其他位置引數
        **kwargs,  # 其他關鍵字引數
    ) -> torch.Tensor:  # 返回型別為張量
        # 檢查是否傳遞了多餘的引數或已棄用的 scale 引數
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 發出棄用警告
            deprecate("scale", "1.0.0", deprecation_message)

        # 將隱藏狀態賦值給殘差
        residual = hidden_states

        # 重塑隱藏狀態的形狀,並轉置維度
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        # 獲取批大小和序列長度
        batch_size, sequence_length, _ = hidden_states.shape

        # 準備注意力掩碼
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # 如果沒有編碼器隱藏狀態,則使用輸入的隱藏狀態
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        # 如果需要進行歸一化處理
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # 對隱藏狀態進行分組歸一化處理
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        # 將隱藏狀態轉換為查詢
        query = attn.to_q(hidden_states)
        # 將查詢從頭維度轉換為批維度
        query = attn.head_to_batch_dim(query)

        # 將編碼器隱藏狀態投影為鍵和值
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        # 將投影結果轉換為批維度
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        # 如果不是僅進行交叉注意力
        if not attn.only_cross_attention:
            # 將隱藏狀態轉換為鍵和值
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            # 轉換為批維度
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            # 將編碼器鍵和值與當前鍵和值拼接
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            # 僅使用編碼器的鍵和值
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        # 獲取注意力機率
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        # 計算隱藏狀態的新值
        hidden_states = torch.bmm(attention_probs, value)
        # 將隱藏狀態轉換回頭維度
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # 線性投影
        hidden_states = attn.to_out[0](hidden_states)
        # 應用 dropout
        hidden_states = attn.to_out[1](hidden_states)

        # 重塑隱藏狀態,並將殘差加回
        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        # 返回處理後的隱藏狀態
        return hidden_states


# 定義另一個注意力處理器類
class AttnAddedKVProcessor2_0:
    r"""
    # 處理縮放點積注意力的處理器(如果使用 PyTorch 2.0,預設啟用),
    # 其中為文字編碼器新增了額外的可學習的鍵和值矩陣。
        """
    
        # 初始化方法
        def __init__(self):
            # 檢查 F 中是否有 "scaled_dot_product_attention" 屬性
            if not hasattr(F, "scaled_dot_product_attention"):
                # 如果沒有,丟擲 ImportError,提示使用者需要升級到 PyTorch 2.0
                raise ImportError(
                    "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
                )
    
        # 定義呼叫方法
        def __call__(
            self,
            attn: Attention,  # 輸入的注意力機制物件
            hidden_states: torch.Tensor,  # 隱藏狀態張量
            encoder_hidden_states: Optional[torch.Tensor] = None,  # 可選的編碼器隱藏狀態張量
            attention_mask: Optional[torch.Tensor] = None,  # 可選的注意力掩碼張量
            *args,  # 額外的位置引數
            **kwargs,  # 額外的關鍵字引數
    ) -> torch.Tensor:  # 指定函式返回型別為 torch.Tensor
        # 檢查引數是否存在或 scale 引數是否被提供
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 設定棄用訊息,告知 scale 引數將被忽略
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 呼叫 deprecate 函式發出棄用警告
            deprecate("scale", "1.0.0", deprecation_message)

        # 將輸入的 hidden_states 賦值給 residual
        residual = hidden_states

        # 調整 hidden_states 的形狀並進行轉置
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        # 獲取 batch_size 和 sequence_length
        batch_size, sequence_length, _ = hidden_states.shape

        # 準備注意力掩碼
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)

        # 如果沒有提供 encoder_hidden_states,則使用 hidden_states
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        # 如果需要歸一化交叉隱藏狀態
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # 對 hidden_states 進行分組歸一化
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        # 計算查詢向量
        query = attn.to_q(hidden_states)
        # 將查詢向量轉換為批次維度
        query = attn.head_to_batch_dim(query, out_dim=4)

        # 生成 encoder_hidden_states 的鍵和值的投影
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        # 將鍵和值轉換為批次維度
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)

        # 如果不是隻進行交叉注意力
        if not attn.only_cross_attention:
            # 計算當前 hidden_states 的鍵和值
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            # 轉換為批次維度
            key = attn.head_to_batch_dim(key, out_dim=4)
            value = attn.head_to_batch_dim(value, out_dim=4)
            # 將鍵和值與 encoder 的鍵和值連線
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
        else:
            # 如果只進行交叉注意力,使用 encoder 的鍵和值
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        # 計算縮放點積注意力的輸出,形狀為 (batch, num_heads, seq_len, head_dim)
        # TODO: 在遷移到 Torch 2.1 時新增對 attn.scale 的支援
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        # 轉置並重塑 hidden_states
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

        # 進行線性投影
        hidden_states = attn.to_out[0](hidden_states)
        # 進行 dropout
        hidden_states = attn.to_out[1](hidden_states)

        # 轉置並重塑回 residual 的形狀
        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        # 將 residual 加到 hidden_states 上
        hidden_states = hidden_states + residual

        # 返回最終的 hidden_states
        return hidden_states
# 定義一個名為 JointAttnProcessor2_0 的類,用於處理自注意力投影
class JointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    # 初始化方法
    def __init__(self):
        # 檢查 F 是否有 scaled_dot_product_attention 屬性
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果沒有,丟擲匯入錯誤,提示需要升級 PyTorch 到 2.0
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 定義呼叫方法,接受多個引數
    def __call__(
        self,
        attn: Attention,  # 自注意力物件
        hidden_states: torch.FloatTensor,  # 當前隱藏狀態的張量
        encoder_hidden_states: torch.FloatTensor = None,  # 編碼器的隱藏狀態,預設為 None
        attention_mask: Optional[torch.FloatTensor] = None,  # 可選的注意力掩碼,預設為 None
        *args,  # 額外的位置引數
        **kwargs,  # 額外的關鍵字引數
    # 返回一個浮點張量
    ) -> torch.FloatTensor:
        # 儲存輸入的隱藏狀態,以便後續使用
        residual = hidden_states
    
        # 獲取隱藏狀態的維度
        input_ndim = hidden_states.ndim
        # 如果隱藏狀態是四維的
        if input_ndim == 4:
            # 解包隱藏狀態的形狀為批大小、通道、高度和寬度
            batch_size, channel, height, width = hidden_states.shape
            # 將隱藏狀態重塑為三維,並進行轉置
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        # 獲取編碼器隱藏狀態的維度
        context_input_ndim = encoder_hidden_states.ndim
        # 如果編碼器隱藏狀態是四維的
        if context_input_ndim == 4:
            # 解包編碼器隱藏狀態的形狀為批大小、通道、高度和寬度
            batch_size, channel, height, width = encoder_hidden_states.shape
            # 將編碼器隱藏狀態重塑為三維,並進行轉置
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
    
        # 獲取編碼器隱藏狀態的批大小
        batch_size = encoder_hidden_states.shape[0]
    
        # 計算 `sample` 投影
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
    
        # 計算 `context` 投影
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
    
        # 合併注意力查詢、鍵和值
        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
    
        # 獲取鍵的最後一維大小
        inner_dim = key.shape[-1]
        # 計算每個頭的維度
        head_dim = inner_dim // attn.heads
        # 重塑查詢、鍵和值以適應多個頭
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    
        # 計算縮放點積注意力
        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        # 轉置並重塑隱藏狀態
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        # 轉換為查詢的型別
        hidden_states = hidden_states.to(query.dtype)
    
        # 拆分注意力輸出
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],  # 獲取原隱藏狀態的部分
            hidden_states[:, residual.shape[1] :],  # 獲取編碼器隱藏狀態的部分
        )
    
        # 進行線性投影
        hidden_states = attn.to_out[0](hidden_states)
        # 進行 dropout
        hidden_states = attn.to_out[1](hidden_states)
        # 如果上下文不是僅限於編碼器
        if not attn.context_pre_only:
            # 對編碼器隱藏狀態進行額外處理
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
    
        # 如果輸入是四維的,進行轉置和重塑
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        # 如果上下文輸入是四維的,進行轉置和重塑
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
    
        # 返回處理後的隱藏狀態和編碼器隱藏狀態
        return hidden_states, encoder_hidden_states
# 定義一個類,PAGJointAttnProcessor2_0,用於處理自注意力投影
class PAGJointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    # 初始化方法
    def __init__(self):
        # 檢查是否存在名為"scaled_dot_product_attention"的屬性
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果不存在,則丟擲匯入錯誤,提示需要升級PyTorch到2.0
            raise ImportError(
                "PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    # 可呼叫方法,接受注意力物件和隱藏狀態
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        # 其他可選引數
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
# 定義另一個類,PAGCFGJointAttnProcessor2_0,類似於PAGJointAttnProcessor2_0
class PAGCFGJointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    # 初始化方法
    def __init__(self):
        # 檢查是否存在名為"scaled_dot_product_attention"的屬性
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果不存在,則丟擲匯入錯誤,提示需要升級PyTorch到2.0
            raise ImportError(
                "PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    # 可呼叫方法,接受注意力物件和隱藏狀態
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        # 其他可選引數
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
# 定義第三個類,FusedJointAttnProcessor2_0,處理自注意力投影
class FusedJointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    # 初始化方法
    def __init__(self):
        # 檢查是否存在名為"scaled_dot_product_attention"的屬性
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果不存在,則丟擲匯入錯誤,提示需要升級PyTorch到2.0
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 可呼叫方法,接受注意力物件和隱藏狀態
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        # 其他可選引數
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        # 將隱藏狀態賦值給殘差變數
        residual = hidden_states

        # 獲取隱藏狀態的維度
        input_ndim = hidden_states.ndim
        # 如果隱藏狀態是四維的,進行維度變換
        if input_ndim == 4:
            # 解包隱藏狀態的形狀
            batch_size, channel, height, width = hidden_states.shape
            # 將隱藏狀態變形為(batch_size, channel, height * width)並轉置
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        # 獲取編碼器隱藏狀態的維度
        context_input_ndim = encoder_hidden_states.ndim
        # 如果編碼器隱藏狀態是四維的,進行維度變換
        if context_input_ndim == 4:
            # 解包編碼器隱藏狀態的形狀
            batch_size, channel, height, width = encoder_hidden_states.shape
            # 將編碼器隱藏狀態變形為(batch_size, channel, height * width)並轉置
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # 獲取編碼器隱藏狀態的批次大小
        batch_size = encoder_hidden_states.shape[0]

        # `sample` 進行投影
        qkv = attn.to_qkv(hidden_states)
        # 計算每個分量的大小
        split_size = qkv.shape[-1] // 3
        # 將qkv拆分為query、key和value
        query, key, value = torch.split(qkv, split_size, dim=-1)

        # `context` 進行投影
        encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
        # 計算編碼器qkv的分量大小
        split_size = encoder_qkv.shape[-1] // 3
        # 將編碼器qkv拆分為查詢、鍵和值的投影
        (
            encoder_hidden_states_query_proj,
            encoder_hidden_states_key_proj,
            encoder_hidden_states_value_proj,
        ) = torch.split(encoder_qkv, split_size, dim=-1)

        # 進行注意力計算
        # 將query、key、value進行連線
        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

        # 獲取key的最後一維大小
        inner_dim = key.shape[-1]
        # 計算每個頭的維度
        head_dim = inner_dim // attn.heads
        # 調整query的形狀以適應多頭注意力
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # 調整key的形狀以適應多頭注意力
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # 調整value的形狀以適應多頭注意力
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # 進行縮放點積注意力計算
        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        # 調整hidden_states的形狀
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        # 將hidden_states轉換為與query相同的資料型別
        hidden_states = hidden_states.to(query.dtype)

        # 拆分注意力輸出
        hidden_states, encoder_hidden_states = (
            # 保留殘差形狀的部分
            hidden_states[:, : residual.shape[1]],
            # 剩餘的部分
            hidden_states[:, residual.shape[1] :],
        )

        # 線性投影
        hidden_states = attn.to_out[0](hidden_states)
        # 進行dropout
        hidden_states = attn.to_out[1](hidden_states)
        # 如果不是隻使用上下文,進行編碼器輸出的投影
        if not attn.context_pre_only:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        # 如果輸入是四維的,調整hidden_states的形狀
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        # 如果上下文輸入是四維的,調整encoder_hidden_states的形狀
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # 返回hidden_states和encoder_hidden_states
        return hidden_states, encoder_hidden_states
# 定義一個用於處理 Aura Flow 的注意力處理器類
class AuraFlowAttnProcessor2_0:
    """Attention processor used typically in processing Aura Flow."""

    # 初始化方法
    def __init__(self):
        # 檢查 F 是否具有 scaled_dot_product_attention 屬性,並確保 PyTorch 版本符合要求
        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
            # 如果不滿足條件,丟擲匯入錯誤,提示使用者升級 PyTorch
            raise ImportError(
                "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
            )

    # 可呼叫方法,用於處理輸入的注意力和隱藏狀態
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        *args,
        **kwargs,
# 定義一個用於處理 Aura Flow 的融合投影注意力處理器類
class FusedAuraFlowAttnProcessor2_0:
    """Attention processor used typically in processing Aura Flow with fused projections."""

    # 初始化方法
    def __init__(self):
        # 檢查 F 是否具有 scaled_dot_product_attention 屬性,並確保 PyTorch 版本符合要求
        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
            # 如果不滿足條件,丟擲匯入錯誤,提示使用者升級 PyTorch
            raise ImportError(
                "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
            )

    # 可呼叫方法,用於處理輸入的注意力和隱藏狀態
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        *args,
        **kwargs,
# YiYi 待辦事項:重構與 rope 相關的函式/類
def apply_rope(xq, xk, freqs_cis):
    # 將 xq 轉換為浮點型,並重新調整形狀以便處理
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    # 將 xk 轉換為浮點型,並重新調整形狀以便處理
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    # 計算 xq 的輸出,結合頻率複數
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    # 計算 xk 的輸出,結合頻率複數
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    # 返回撥整形狀後的 xq_out 和 xk_out,並確保與原始型別匹配
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


# 定義一個實現縮放點積注意力的處理器類
class FluxSingleAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    # 初始化方法
    def __init__(self):
        # 檢查 F 是否具有 scaled_dot_product_attention 屬性
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果不滿足條件,丟擲匯入錯誤,提示使用者升級 PyTorch
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 可呼叫方法,用於處理輸入的注意力和隱藏狀態
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    # 定義函式的返回型別為 torch.Tensor
    ) -> torch.Tensor:
        # 獲取 hidden_states 的維度數量
        input_ndim = hidden_states.ndim
    
        # 如果輸入的維度為 4
        if input_ndim == 4:
            # 解包 hidden_states 的形狀為 batch_size, channel, height, width
            batch_size, channel, height, width = hidden_states.shape
            # 將 hidden_states 檢視調整為 (batch_size, channel, height * width) 並轉置
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
    
        # 如果 encoder_hidden_states 為 None,則獲取 hidden_states 的形狀
        # 否則獲取 encoder_hidden_states 的形狀
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
    
        # 將 hidden_states 轉換為查詢向量
        query = attn.to_q(hidden_states)
        # 如果 encoder_hidden_states 為 None,將其設定為 hidden_states
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
    
        # 將 encoder_hidden_states 轉換為鍵向量
        key = attn.to_k(encoder_hidden_states)
        # 將 encoder_hidden_states 轉換為值向量
        value = attn.to_v(encoder_hidden_states)
    
        # 獲取鍵的最後一個維度的大小
        inner_dim = key.shape[-1]
        # 計算每個頭的維度
        head_dim = inner_dim // attn.heads
    
        # 將查詢向量調整檢視為 (batch_size, -1, attn.heads, head_dim) 並轉置
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    
        # 將鍵向量調整檢視為 (batch_size, -1, attn.heads, head_dim) 並轉置
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # 將值向量調整檢視為 (batch_size, -1, attn.heads, head_dim) 並轉置
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    
        # 如果存在規範化查詢的層,則對查詢進行規範化
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        # 如果存在規範化鍵的層,則對鍵進行規範化
        if attn.norm_k is not None:
            key = attn.norm_k(key)
    
        # 如果需要應用 RoPE
        if image_rotary_emb is not None:
            # 應用旋轉嵌入到查詢和鍵上
            query, key = apply_rope(query, key, image_rotary_emb)
    
        # 計算縮放點積注意力,輸出形狀為 (batch, num_heads, seq_len, head_dim)
        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
    
        # 轉置並調整 hidden_states 的形狀為 (batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        # 將 hidden_states 轉換為與查詢相同的資料型別
        hidden_states = hidden_states.to(query.dtype)
    
        # 如果輸入維度為 4,將 hidden_states 轉置並調整形狀回原始維度
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
    
        # 返回處理後的 hidden_states
        return hidden_states
# 定義一個名為 FluxAttnProcessor2_0 的類,通常用於處理 SD3 類自注意力投影
class FluxAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    # 初始化方法
    def __init__(self):
        # 檢查 F 是否有 scaled_dot_product_attention 屬性,如果沒有則丟擲 ImportError
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 定義呼叫方法,使類例項可被呼叫
    def __call__(
        self,
        attn: Attention,  # 接收 Attention 物件
        hidden_states: torch.FloatTensor,  # 接收隱藏狀態張量
        encoder_hidden_states: torch.FloatTensor = None,  # 可選的編碼器隱藏狀態張量
        attention_mask: Optional[torch.FloatTensor] = None,  # 可選的注意力掩碼張量
        image_rotary_emb: Optional[torch.Tensor] = None,  # 可選的影像旋轉嵌入張量
    ):
        # 此處將實現自注意力的具體處理邏輯

# 定義一個名為 CogVideoXAttnProcessor2_0 的類,專用於 CogVideoX 模型的縮放點積注意力處理
class CogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    # 初始化方法
    def __init__(self):
        # 檢查 F 是否有 scaled_dot_product_attention 屬性,如果沒有則丟擲 ImportError
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 定義呼叫方法,使類例項可被呼叫
    def __call__(
        self,
        attn: Attention,  # 接收 Attention 物件
        hidden_states: torch.Tensor,  # 接收隱藏狀態張量
        encoder_hidden_states: torch.Tensor,  # 接收編碼器隱藏狀態張量
        attention_mask: Optional[torch.Tensor] = None,  # 可選的注意力掩碼張量
        image_rotary_emb: Optional[torch.Tensor] = None,  # 可選的影像旋轉嵌入張量
    ):
        # 此處將實現自注意力的具體處理邏輯
    ) -> torch.Tensor:  # 函式返回一個張量,表示隱藏狀態
        text_seq_length = encoder_hidden_states.size(1)  # 獲取編碼器隱藏狀態的序列長度

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)  # 在維度1上連線編碼器隱藏狀態和當前隱藏狀態

        batch_size, sequence_length, _ = (  # 解包 batch_size 和 sequence_length
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape  # 根據編碼器隱藏狀態的存在性決定形狀
        )

        if attention_mask is not None:  # 如果存在注意力掩碼
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)  # 準備注意力掩碼
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])  # 調整注意力掩碼的形狀以適應頭數

        query = attn.to_q(hidden_states)  # 將隱藏狀態轉換為查詢向量
        key = attn.to_k(hidden_states)  # 將隱藏狀態轉換為鍵向量
        value = attn.to_v(hidden_states)  # 將隱藏狀態轉換為值向量

        inner_dim = key.shape[-1]  # 獲取鍵向量的最後一個維度大小
        head_dim = inner_dim // attn.heads  # 計算每個頭的維度

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # 調整查詢向量形狀並轉置以適應多頭注意力
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # 調整鍵向量形狀並轉置以適應多頭注意力
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # 調整值向量形狀並轉置以適應多頭注意力

        if attn.norm_q is not None:  # 如果查詢歸一化層存在
            query = attn.norm_q(query)  # 對查詢向量進行歸一化
        if attn.norm_k is not None:  # 如果鍵歸一化層存在
            key = attn.norm_k(key)  # 對鍵向量進行歸一化

        # Apply RoPE if needed  # 如果需要應用旋轉位置編碼
        if image_rotary_emb is not None:  # 如果影像旋轉嵌入存在
            from .embeddings import apply_rotary_emb  # 匯入應用旋轉嵌入的函式

            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)  # 應用旋轉嵌入到查詢向量的後半部分
            if not attn.is_cross_attention:  # 如果不是交叉注意力
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)  # 應用旋轉嵌入到鍵向量的後半部分

        hidden_states = F.scaled_dot_product_attention(  # 計算縮放點積注意力
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False  # 輸入查詢、鍵和值,以及注意力掩碼
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)  # 轉置和重塑隱藏狀態以合併頭維度

        # linear proj  # 線性投影
        hidden_states = attn.to_out[0](hidden_states)  # 對隱藏狀態應用輸出線性變換
        # dropout  # 進行dropout操作
        hidden_states = attn.to_out[1](hidden_states)  # 對隱藏狀態應用dropout

        encoder_hidden_states, hidden_states = hidden_states.split(  # 將隱藏狀態分割為編碼器和當前隱藏狀態
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1  # 根據文字序列長度和剩餘部分進行分割
        )
        return hidden_states, encoder_hidden_states  # 返回當前隱藏狀態和編碼器隱藏狀態
# 定義一個用於實現 CogVideoX 模型的縮放點積注意力的處理器類
class FusedCogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    # 初始化方法
    def __init__(self):
        # 檢查 F 是否具有 scaled_dot_product_attention 屬性,如果沒有則丟擲匯入錯誤
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 定義可呼叫方法,處理注意力計算
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # 獲取編碼器隱藏狀態的序列長度
        text_seq_length = encoder_hidden_states.size(1)

        # 將編碼器和當前隱藏狀態按維度 1 連線
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # 獲取批次大小和序列長度,依據編碼器隱藏狀態是否為 None
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        # 如果提供了注意力掩碼,則準備掩碼
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # 將掩碼調整為適當的形狀
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        # 將隱藏狀態轉換為查詢、鍵、值
        qkv = attn.to_qkv(hidden_states)
        # 計算每個部分的大小
        split_size = qkv.shape[-1] // 3
        # 分割成查詢、鍵和值
        query, key, value = torch.split(qkv, split_size, dim=-1)

        # 獲取鍵的內部維度
        inner_dim = key.shape[-1]
        # 計算每個頭的維度
        head_dim = inner_dim // attn.heads

        # 調整查詢、鍵和值的形狀以適應多頭注意力
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # 如果存在查詢的歸一化,則應用歸一化
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        # 如果存在鍵的歸一化,則應用歸一化
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # 如果需要應用 RoPE
        if image_rotary_emb is not None:
            from .embeddings import apply_rotary_emb

            # 對查詢的特定部分應用旋轉嵌入
            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            # 如果不是交叉注意力,則對鍵的特定部分應用旋轉嵌入
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        # 計算縮放點積注意力
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # 調整隱藏狀態的形狀以便輸出
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # 線性投影
        hidden_states = attn.to_out[0](hidden_states)
        # 應用 dropout
        hidden_states = attn.to_out[1](hidden_states)

        # 將隱藏狀態拆分為編碼器隱藏狀態和當前隱藏狀態
        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        # 返回當前隱藏狀態和編碼器隱藏狀態
        return hidden_states, encoder_hidden_states


# 定義用於實現記憶體高效注意力的處理器類
class XFormersAttnAddedKVProcessor:
    r"""
    Processor for implementing memory efficient attention using xFormers.
    # 文件字串,說明可選引數 attention_op 的作用
        Args:
            attention_op (`Callable`, *optional*, defaults to `None`):
                使用的基本注意力運算子,推薦設定為 `None` 讓 xFormers 選擇最佳運算子
        """
    
        # 建構函式,初始化注意力運算子
        def __init__(self, attention_op: Optional[Callable] = None):
            # 將傳入的注意力運算子賦值給例項變數
            self.attention_op = attention_op
    
        # 可呼叫方法,用於執行注意力計算
        def __call__(
            self,
            attn: Attention,  # 注意力物件
            hidden_states: torch.Tensor,  # 隱藏狀態張量
            encoder_hidden_states: Optional[torch.Tensor] = None,  # 編碼器隱藏狀態,預設為 None
            attention_mask: Optional[torch.Tensor] = None,  # 注意力掩碼,預設為 None
        ) -> torch.Tensor:
            # 將當前隱藏狀態儲存為殘差以便後續使用
            residual = hidden_states
            # 調整隱藏狀態的形狀並轉置
            hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
            # 獲取批次大小和序列長度
            batch_size, sequence_length, _ = hidden_states.shape
    
            # 準備注意力掩碼
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
    
            # 如果沒有編碼器隱藏狀態,則將其設定為當前的隱藏狀態
            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            # 如果需要,則對編碼器隱藏狀態進行歸一化處理
            elif attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
    
            # 對隱藏狀態進行分組歸一化處理
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
    
            # 生成查詢向量
            query = attn.to_q(hidden_states)
            # 將查詢向量從頭部維度轉換為批次維度
            query = attn.head_to_batch_dim(query)
    
            # 對編碼器隱藏狀態進行鍵和值的投影
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
            # 將編碼器隱藏狀態的鍵和值轉換為批次維度
            encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
            encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
    
            # 如果不是僅使用交叉注意力
            if not attn.only_cross_attention:
                # 生成當前隱藏狀態的鍵和值
                key = attn.to_k(hidden_states)
                value = attn.to_v(hidden_states)
                # 轉換鍵和值到批次維度
                key = attn.head_to_batch_dim(key)
                value = attn.head_to_batch_dim(value)
                # 將編碼器的鍵和值與當前的鍵和值連線起來
                key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
                value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
            else:
                # 如果僅使用交叉注意力,則直接使用編碼器的鍵和值
                key = encoder_hidden_states_key_proj
                value = encoder_hidden_states_value_proj
    
            # 計算高效的注意力
            hidden_states = xformers.ops.memory_efficient_attention(
                query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
            )
            # 將結果轉換為查詢的 dtype
            hidden_states = hidden_states.to(query.dtype)
            # 將隱藏狀態從批次維度轉換回頭部維度
            hidden_states = attn.batch_to_head_dim(hidden_states)
    
            # 線性變換
            hidden_states = attn.to_out[0](hidden_states)
            # 應用 dropout
            hidden_states = attn.to_out[1](hidden_states)
    
            # 調整隱藏狀態的形狀以匹配殘差
            hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
            # 將當前隱藏狀態與殘差相加
            hidden_states = hidden_states + residual
    
            # 返回最終的隱藏狀態
            return hidden_states
# 定義一個用於實現基於 xFormers 的記憶體高效注意力的處理器類
class XFormersAttnProcessor:
    r"""
    處理器,用於實現基於 xFormers 的記憶體高效注意力。

    引數:
        attention_op (`Callable`, *可選*, 預設為 `None`):
            基礎
            [運算子](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase),
            用作注意力運算子。建議將其設定為 `None`,並讓 xFormers 選擇最佳運算子。
    """

    # 初始化方法,接受一個可選的注意力運算子
    def __init__(self, attention_op: Optional[Callable] = None):
        # 將傳入的注意力運算子賦值給例項變數
        self.attention_op = attention_op

    # 定義可呼叫方法,用於執行注意力計算
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
# 定義一個用於實現 flash attention 的處理器類,使用 torch_npu
class AttnProcessorNPU:
    r"""
    處理器,用於使用 torch_npu 實現 flash attention。torch_npu 僅支援 fp16 和 bf16 資料型別。如果
    使用 fp32,將使用 F.scaled_dot_product_attention 進行計算,但在 NPU 上加速效果不明顯。

    """

    # 初始化方法
    def __init__(self):
        # 檢查是否可用 torch_npu,如果不可用則丟擲異常
        if not is_torch_npu_available():
            raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")

    # 定義可呼叫方法,用於執行注意力計算
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
# 定義一個用於實現 scaled dot-product attention 的處理器類,預設在 PyTorch 2.0 中啟用
class AttnProcessor2_0:
    r"""
    處理器,用於實現 scaled dot-product attention(如果您使用的是 PyTorch 2.0,預設啟用)。
    """

    # 初始化方法
    def __init__(self):
        # 檢查 F 中是否有 scaled_dot_product_attention 屬性,如果沒有則丟擲異常
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 定義可呼叫方法,用於執行注意力計算
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
# 定義一個用於實現 scaled dot-product attention 的處理器類,適用於穩定音訊模型
class StableAudioAttnProcessor2_0:
    r"""
    處理器,用於實現 scaled dot-product attention(如果您使用的是 PyTorch 2.0,預設啟用)。此處理器用於
    穩定音訊模型。它在查詢和鍵向量上應用旋轉嵌入,並允許 MHA、GQA 或 MQA。
    """

    # 初始化方法
    def __init__(self):
        # 檢查 F 中是否有 scaled_dot_product_attention 屬性,如果沒有則丟擲異常
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    # 定義方法,用於應用部分旋轉嵌入
    def apply_partial_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Tuple[torch.Tensor],
    # 定義返回型別為 torch.Tensor 的函式
    ) -> torch.Tensor:
        # 從當前模組匯入 apply_rotary_emb 函式
        from .embeddings import apply_rotary_emb
    
        # 獲取頻率餘弦的最後一個維度大小,用於旋轉
        rot_dim = freqs_cis[0].shape[-1]
        # 將輸入張量 x 劃分為需要旋轉和不需要旋轉的部分
        x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:]
    
        # 應用旋轉嵌入到需要旋轉的部分
        x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2)
    
        # 將旋轉後的部分與未旋轉的部分在最後一個維度上連線
        out = torch.cat((x_rotated, x_unrotated), dim=-1)
        # 返回連線後的輸出張量
        return out
    
    # 定義可呼叫方法,接收注意力和隱藏狀態
    def __call__(
        self,
        # 輸入的注意力物件
        attn: Attention,
        # 隱藏狀態的張量
        hidden_states: torch.Tensor,
        # 可選的編碼器隱藏狀態張量
        encoder_hidden_states: Optional[torch.Tensor] = None,
        # 可選的注意力掩碼張量
        attention_mask: Optional[torch.Tensor] = None,
        # 可選的旋轉嵌入張量
        rotary_emb: Optional[torch.Tensor] = None,
# 定義 HunyuanAttnProcessor2_0 類,處理縮放的點積注意力
class HunyuanAttnProcessor2_0:
    r"""
    處理器用於實現縮放的點積注意力(如果使用 PyTorch 2.0,預設啟用)。這是
    HunyuanDiT 模型中使用的。它在查詢和鍵向量上應用歸一化層和旋轉嵌入。
    """

    # 初始化方法
    def __init__(self):
        # 檢查 F 中是否有 scaled_dot_product_attention 屬性
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果沒有,則丟擲匯入錯誤,提示需要升級 PyTorch 到 2.0
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 定義呼叫方法
    def __call__(
        self,
        attn: Attention,  # 注意力機制例項
        hidden_states: torch.Tensor,  # 當前隱藏狀態的張量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 編碼器隱藏狀態的可選張量
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩碼的可選張量
        temb: Optional[torch.Tensor] = None,  # 時間嵌入的可選張量
        image_rotary_emb: Optional[torch.Tensor] = None,  # 影像旋轉嵌入的可選張量
class FusedHunyuanAttnProcessor2_0:
    r"""
    處理器用於實現縮放的點積注意力(如果使用 PyTorch 2.0,預設啟用),帶有融合的
    投影層。這是 HunyuanDiT 模型中使用的。它在查詢和鍵向量上應用歸一化層和旋轉嵌入。
    """

    # 初始化方法
    def __init__(self):
        # 檢查 F 中是否有 scaled_dot_product_attention 屬性
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果沒有,則丟擲匯入錯誤,提示需要升級 PyTorch 到 2.0
            raise ImportError(
                "FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    # 定義呼叫方法
    def __call__(
        self,
        attn: Attention,  # 注意力機制例項
        hidden_states: torch.Tensor,  # 當前隱藏狀態的張量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 編碼器隱藏狀態的可選張量
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩碼的可選張量
        temb: Optional[torch.Tensor] = None,  # 時間嵌入的可選張量
        image_rotary_emb: Optional[torch.Tensor] = None,  # 影像旋轉嵌入的可選張量
class PAGHunyuanAttnProcessor2_0:
    r"""
    處理器用於實現縮放的點積注意力(如果使用 PyTorch 2.0,預設啟用)。這是
    HunyuanDiT 模型中使用的。它在查詢和鍵向量上應用歸一化層和旋轉嵌入。該處理器
    變體採用了 [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377)。
    """

    # 初始化方法
    def __init__(self):
        # 檢查 F 中是否有 scaled_dot_product_attention 屬性
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果沒有,則丟擲匯入錯誤,提示需要升級 PyTorch 到 2.0
            raise ImportError(
                "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    # 定義呼叫方法
    def __call__(
        self,
        attn: Attention,  # 注意力機制例項
        hidden_states: torch.Tensor,  # 當前隱藏狀態的張量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 編碼器隱藏狀態的可選張量
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩碼的可選張量
        temb: Optional[torch.Tensor] = None,  # 時間嵌入的可選張量
        image_rotary_emb: Optional[torch.Tensor] = None,  # 影像旋轉嵌入的可選張量
class PAGCFGHunyuanAttnProcessor2_0:
    r"""
    處理器用於實現縮放的點積注意力(如果使用 PyTorch 2.0,預設啟用)。這是
    HunyuanDiT 模型中使用的。它在查詢和鍵向量上應用歸一化層和旋轉嵌入。該處理器
    變體採用了 [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377)。
    """
    # 初始化方法,用於建立類的例項
        def __init__(self):
            # 檢查模組 F 是否具有屬性 "scaled_dot_product_attention"
            if not hasattr(F, "scaled_dot_product_attention"):
                # 如果沒有該屬性,則丟擲 ImportError,提示使用者升級 PyTorch
                raise ImportError(
                    "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
                )
    
    # 可呼叫方法,允許類的例項像函式一樣被呼叫
        def __call__(
            self,
            attn: Attention,  # 注意力機制物件
            hidden_states: torch.Tensor,  # 當前隱藏狀態的張量
            encoder_hidden_states: Optional[torch.Tensor] = None,  # 編碼器的隱藏狀態,可選引數
            attention_mask: Optional[torch.Tensor] = None,  # 注意力掩碼,可選引數
            temb: Optional[torch.Tensor] = None,  # 時間嵌入,可選引數
            image_rotary_emb: Optional[torch.Tensor] = None,  # 影像旋轉嵌入,可選引數
# 定義一個用於實現縮放點積注意力的處理器類
class LuminaAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the LuminaNextDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
    """

    # 初始化方法
    def __init__(self):
        # 檢查 PyTorch 是否具有縮放點積注意力功能
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果沒有,丟擲匯入錯誤,提示使用者升級 PyTorch 到 2.0
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 定義呼叫方法,使類例項可呼叫
    def __call__(
        self,
        # 接收注意力物件
        attn: Attention,
        # 接收隱藏狀態張量
        hidden_states: torch.Tensor,
        # 接收編碼器隱藏狀態張量
        encoder_hidden_states: torch.Tensor,
        # 可選的注意力掩碼張量
        attention_mask: Optional[torch.Tensor] = None,
        # 可選的查詢旋轉嵌入張量
        query_rotary_emb: Optional[torch.Tensor] = None,
        # 可選的鍵旋轉嵌入張量
        key_rotary_emb: Optional[torch.Tensor] = None,
        # 可選的基本序列長度
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:  # 函式返回一個張量,表示處理後的隱藏狀態
        from .embeddings import apply_rotary_emb  # 從當前包匯入應用旋轉嵌入的函式

        input_ndim = hidden_states.ndim  # 獲取隱藏狀態的維度數

        if input_ndim == 4:  # 如果隱藏狀態是四維張量
            batch_size, channel, height, width = hidden_states.shape  # 解包出批次大小、通道、高度和寬度
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)  # 重塑並轉置隱藏狀態

        batch_size, sequence_length, _ = hidden_states.shape  # 解包出批次大小和序列長度

        # Get Query-Key-Value Pair  # 獲取查詢、鍵、值對
        query = attn.to_q(hidden_states)  # 將隱藏狀態轉換為查詢張量
        key = attn.to_k(encoder_hidden_states)  # 將編碼器的隱藏狀態轉換為鍵張量
        value = attn.to_v(encoder_hidden_states)  # 將編碼器的隱藏狀態轉換為值張量

        query_dim = query.shape[-1]  # 獲取查詢的最後一個維度(特徵維度)
        inner_dim = key.shape[-1]  # 獲取鍵的最後一個維度
        head_dim = query_dim // attn.heads  # 計算每個頭的維度
        dtype = query.dtype  # 獲取查詢張量的資料型別

        # Get key-value heads  # 獲取鍵值頭的數量
        kv_heads = inner_dim // head_dim  # 計算每個頭的鍵值數量

        # Apply Query-Key Norm if needed  # 如果需要,應用查詢-鍵歸一化
        if attn.norm_q is not None:  # 如果定義了查詢的歸一化
            query = attn.norm_q(query)  # 對查詢進行歸一化
        if attn.norm_k is not None:  # 如果定義了鍵的歸一化
            key = attn.norm_k(key)  # 對鍵進行歸一化

        query = query.view(batch_size, -1, attn.heads, head_dim)  # 重塑查詢張量以適應頭的維度

        key = key.view(batch_size, -1, kv_heads, head_dim)  # 重塑鍵張量以適應頭的維度
        value = value.view(batch_size, -1, kv_heads, head_dim)  # 重塑值張量以適應頭的維度

        # Apply RoPE if needed  # 如果需要,應用旋轉位置嵌入
        if query_rotary_emb is not None:  # 如果定義了查詢的旋轉嵌入
            query = apply_rotary_emb(query, query_rotary_emb, use_real=False)  # 應用旋轉嵌入到查詢
        if key_rotary_emb is not None:  # 如果定義了鍵的旋轉嵌入
            key = apply_rotary_emb(key, key_rotary_emb, use_real=False)  # 應用旋轉嵌入到鍵

        query, key = query.to(dtype), key.to(dtype)  # 將查詢和鍵轉換為相同的資料型別

        # Apply proportional attention if true  # 如果為真,應用比例注意力
        if key_rotary_emb is None:  # 如果沒有鍵的旋轉嵌入
            softmax_scale = None  # 設定縮放因子為 None
        else:  # 如果有鍵的旋轉嵌入
            if base_sequence_length is not None:  # 如果定義了基礎序列長度
                softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale  # 計算縮放因子
            else:  # 如果沒有定義基礎序列長度
                softmax_scale = attn.scale  # 使用注意力的縮放因子

        # perform Grouped-query Attention (GQA)  # 執行分組查詢注意力
        n_rep = attn.heads // kv_heads  # 計算每個鍵值頭的重複數量
        if n_rep >= 1:  # 如果重複數量大於等於 1
            key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)  # 擴充套件並重復鍵
            value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)  # 擴充套件並重復值

        # scaled_dot_product_attention expects attention_mask shape to be  # 縮放點積注意力期望的注意力掩碼形狀
        # (batch, heads, source_length, target_length)
        attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)  # 將注意力掩碼轉換為布林值並調整形狀
        attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)  # 擴充套件注意力掩碼以匹配頭的數量

        query = query.transpose(1, 2)  # 轉置查詢張量
        key = key.transpose(1, 2)  # 轉置鍵張量
        value = value.transpose(1, 2)  # 轉置值張量

        # the output of sdp = (batch, num_heads, seq_len, head_dim)  # 縮放點積注意力的輸出形狀
        # TODO: add support for attn.scale when we move to Torch 2.1  # TODO: 在遷移到 Torch 2.1 時支援 attn.scale
        hidden_states = F.scaled_dot_product_attention(  # 計算縮放點積注意力
            query, key, value, attn_mask=attention_mask, scale=softmax_scale  # 輸入查詢、鍵、值及注意力掩碼和縮放因子
        )
        hidden_states = hidden_states.transpose(1, 2).to(dtype)  # 轉置輸出並轉換為相應的資料型別

        return hidden_states  # 返回處理後的隱藏狀態
# 定義一個用於實現縮放點積注意力的處理器類,預設啟用(如果使用 PyTorch 2.0)
class FusedAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
    fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
    For cross-attention modules, key and value projection matrices are fused.

    <Tip warning={true}>

    This API is currently 🧪 experimental in nature and can change in future.

    </Tip>
    """

    # 初始化方法
    def __init__(self):
        # 檢查 F 庫是否具有縮放點積注意力功能
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果沒有,丟擲匯入錯誤,提示使用者升級 PyTorch 版本
            raise ImportError(
                "FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
            )

    # 呼叫方法,處理注意力計算
    def __call__(
        self,
        attn: Attention,  # 注意力模組
        hidden_states: torch.Tensor,  # 隱藏狀態張量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 編碼器的隱藏狀態,可選
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩碼,可選
        temb: Optional[torch.Tensor] = None,  # 時間嵌入,可選
        *args,  # 可變位置引數
        **kwargs,  # 可變關鍵字引數
    ):
        pass  # 此處省略具體實現

# 定義一個用於實現記憶體高效注意力的處理器類,使用 xFormers 方法
class CustomDiffusionXFormersAttnProcessor(nn.Module):
    r"""
    Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.

    Args:
    train_kv (`bool`, defaults to `True`):
        Whether to newly train the key and value matrices corresponding to the text features.
    train_q_out (`bool`, defaults to `True`):
        Whether to newly train query matrices corresponding to the latent image features.
    hidden_size (`int`, *optional*, defaults to `None`):
        The hidden size of the attention layer.
    cross_attention_dim (`int`, *optional*, defaults to `None`):
        The number of channels in the `encoder_hidden_states`.
    out_bias (`bool`, defaults to `True`):
        Whether to include the bias parameter in `train_q_out`.
    dropout (`float`, *optional*, defaults to 0.0):
        The dropout probability to use.
    attention_op (`Callable`, *optional*, defaults to `None`):
        The base
        [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
        as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
    """

    # 初始化方法,設定各種引數
    def __init__(
        self,
        train_kv: bool = True,  # 是否訓練與文字特徵對應的鍵值矩陣
        train_q_out: bool = False,  # 是否訓練與潛在影像特徵對應的查詢矩陣
        hidden_size: Optional[int] = None,  # 注意力層的隱藏大小
        cross_attention_dim: Optional[int] = None,  # 編碼器隱藏狀態的通道數
        out_bias: bool = True,  # 是否在 train_q_out 中包含偏置引數
        dropout: float = 0.0,  # 使用的丟棄機率
        attention_op: Optional[Callable] = None,  # 要使用的基礎注意力操作
    ):
        pass  # 此處省略具體實現
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        # 儲存訓練鍵值對的標誌
        self.train_kv = train_kv
        # 儲存訓練查詢輸出的標誌
        self.train_q_out = train_q_out

        # 儲存隱藏層大小
        self.hidden_size = hidden_size
        # 儲存交叉注意力維度
        self.cross_attention_dim = cross_attention_dim
        # 儲存注意力操作型別
        self.attention_op = attention_op

        # `_custom_diffusion` id 用於簡化序列化和載入
        if self.train_kv:
            # 建立線性層,將交叉注意力維度或隱藏層大小對映到隱藏層大小,且不使用偏置
            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
            # 建立線性層,將交叉注意力維度或隱藏層大小對映到隱藏層大小,且不使用偏置
            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        if self.train_q_out:
            # 建立線性層,將隱藏層大小對映到隱藏層大小,且不使用偏置
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            # 建立一個空的模組列表以儲存輸出相關的層
            self.to_out_custom_diffusion = nn.ModuleList([])
            # 將線性層新增到模組列表中,用於輸出對映
            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
            # 將 Dropout 層新增到模組列表中,用於正則化
            self.to_out_custom_diffusion.append(nn.Dropout(dropout))

    def __call__(
        # 定義呼叫方法,接收注意力物件和隱藏狀態張量
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        # 可選引數:編碼器的隱藏狀態張量
        encoder_hidden_states: Optional[torch.Tensor] = None,
        # 可選引數:注意力掩碼張量
        attention_mask: Optional[torch.Tensor] = None,
    # 定義函式的返回型別為 torch.Tensor
        ) -> torch.Tensor:
            # 獲取批次大小和序列長度,根據 encoder_hidden_states 是否為 None 決定來源
            batch_size, sequence_length, _ = (
                hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
            )
    
            # 準備注意力掩碼
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
    
            # 判斷是否在訓練階段並應用不同的查詢生成方式
            if self.train_q_out:
                query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
            else:
                query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
    
            # 判斷是否存在編碼器隱藏狀態,並設定 crossattn 標誌
            if encoder_hidden_states is None:
                crossattn = False
                encoder_hidden_states = hidden_states
            else:
                crossattn = True
                # 如果需要對編碼器隱藏狀態進行歸一化處理
                if attn.norm_cross:
                    encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
    
            # 判斷是否在訓練階段並應用不同的鍵值生成方式
            if self.train_kv:
                key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
                value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
                key = key.to(attn.to_q.weight.dtype)
                value = value.to(attn.to_q.weight.dtype)
            else:
                key = attn.to_k(encoder_hidden_states)
                value = attn.to_v(encoder_hidden_states)
    
            # 如果使用交叉注意力,進行鍵值的分離和處理
            if crossattn:
                detach = torch.ones_like(key)
                detach[:, :1, :] = detach[:, :1, :] * 0.0
                key = detach * key + (1 - detach) * key.detach()
                value = detach * value + (1 - detach) * value.detach()
    
            # 將查詢、鍵、值轉換為批處理維度並保持連續性
            query = attn.head_to_batch_dim(query).contiguous()
            key = attn.head_to_batch_dim(key).contiguous()
            value = attn.head_to_batch_dim(value).contiguous()
    
            # 使用記憶體高效的注意力計算隱藏狀態
            hidden_states = xformers.ops.memory_efficient_attention(
                query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
            )
            # 將隱藏狀態轉換為查詢的型別
            hidden_states = hidden_states.to(query.dtype)
            # 將隱藏狀態轉換回頭部維度
            hidden_states = attn.batch_to_head_dim(hidden_states)
    
            # 根據訓練標誌決定輸出的處理方式
            if self.train_q_out:
                # 線性變換
                hidden_states = self.to_out_custom_diffusion[0](hidden_states)
                # 進行 dropout 操作
                hidden_states = self.to_out_custom_diffusion[1](hidden_states)
            else:
                # 線性變換
                hidden_states = attn.to_out[0](hidden_states)
                # 進行 dropout 操作
                hidden_states = attn.to_out[1](hidden_states)
    
            # 返回處理後的隱藏狀態
            return hidden_states
# 自定義擴散注意力處理器類,繼承自 PyTorch 的 nn.Module
class CustomDiffusionAttnProcessor2_0(nn.Module):
    r"""
    用於實現自定義擴散方法的注意力處理器,使用 PyTorch 2.0 的記憶體高效縮放
    點積注意力。

    引數:
        train_kv (`bool`, 預設值為 `True`):
            是否新訓練與文字特徵對應的鍵和值矩陣。
        train_q_out (`bool`, 預設值為 `True`):
            是否新訓練與潛在影像特徵對應的查詢矩陣。
        hidden_size (`int`, *可選*, 預設值為 `None`):
            注意力層的隱藏大小。
        cross_attention_dim (`int`, *可選*, 預設值為 `None`):
            `encoder_hidden_states` 中的通道數。
        out_bias (`bool`, 預設值為 `True`):
            是否在 `train_q_out` 中包含偏置引數。
        dropout (`float`, *可選*, 預設值為 0.0):
            使用的 dropout 機率。
    """

    # 初始化方法,設定類的屬性
    def __init__(
        self,
        train_kv: bool = True,
        train_q_out: bool = True,
        hidden_size: Optional[int] = None,
        cross_attention_dim: Optional[int] = None,
        out_bias: bool = True,
        dropout: float = 0.0,
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        # 設定是否訓練鍵值矩陣的標誌
        self.train_kv = train_kv
        # 設定是否訓練查詢輸出矩陣的標誌
        self.train_q_out = train_q_out

        # 設定隱藏層的大小
        self.hidden_size = hidden_size
        # 設定交叉注意力的維度
        self.cross_attention_dim = cross_attention_dim

        # 如果需要訓練鍵值矩陣,則建立對應的線性層
        if self.train_kv:
            # 建立從交叉注意力維度到隱藏層的線性變換,且不使用偏置
            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
            # 建立從交叉注意力維度到隱藏層的線性變換,且不使用偏置
            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        # 如果需要訓練查詢輸出,則建立對應的線性層
        if self.train_q_out:
            # 建立從隱藏層到隱藏層的線性變換,且不使用偏置
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            # 建立一個空的模組列表,用於儲存輸出層
            self.to_out_custom_diffusion = nn.ModuleList([])
            # 將線性層新增到輸出模組列表中
            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
            # 新增 dropout 層到輸出模組列表中
            self.to_out_custom_diffusion.append(nn.Dropout(dropout))

    # 定義類的呼叫方法,處理輸入的注意力和隱藏狀態
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:  # 定義返回型別為 torch.Tensor
        batch_size, sequence_length, _ = hidden_states.shape  # 解包 hidden_states 的形狀,獲取批大小和序列長度
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)  # 準備注意力掩碼
        if self.train_q_out:  # 檢查是否在訓練查詢輸出
            query = self.to_q_custom_diffusion(hidden_states)  # 使用自定義擴散方法生成查詢向量
        else:  # 否則
            query = attn.to_q(hidden_states)  # 使用標準方法生成查詢向量

        if encoder_hidden_states is None:  # 檢查編碼器隱藏狀態是否為空
            crossattn = False  # 設定交叉注意力標誌為假
            encoder_hidden_states = hidden_states  # 將編碼器隱藏狀態設定為隱藏狀態
        else:  # 如果編碼器隱藏狀態不為空
            crossattn = True  # 設定交叉注意力標誌為真
            if attn.norm_cross:  # 如果需要歸一化
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)  # 歸一化編碼器隱藏狀態

        if self.train_kv:  # 檢查是否在訓練鍵值對
            key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))  # 生成鍵向量
            value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))  # 生成值向量
            key = key.to(attn.to_q.weight.dtype)  # 將鍵向量轉換為查詢權重的資料型別
            value = value.to(attn.to_q.weight.dtype)  # 將值向量轉換為查詢權重的資料型別
        else:  # 否則
            key = attn.to_k(encoder_hidden_states)  # 使用標準方法生成鍵向量
            value = attn.to_v(encoder_hidden_states)  # 使用標準方法生成值向量

        if crossattn:  # 如果進行交叉注意力
            detach = torch.ones_like(key)  # 建立與鍵相同形狀的全1張量
            detach[:, :1, :] = detach[:, :1, :] * 0.0  # 將第一時間步的值設定為0
            key = detach * key + (1 - detach) * key.detach()  # 根據 detach 張量計算鍵的最終值
            value = detach * value + (1 - detach) * value.detach()  # 根據 detach 張量計算值的最終值

        inner_dim = hidden_states.shape[-1]  # 獲取隱藏狀態的最後一維大小

        head_dim = inner_dim // attn.heads  # 計算每個頭的維度
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # 重新調整查詢的形狀並轉置
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # 重新調整鍵的形狀並轉置
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # 重新調整值的形狀並轉置

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(  # 計算縮放點積注意力
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False  # 輸入查詢、鍵和值以及注意力掩碼
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)  # 轉置並重塑隱藏狀態
        hidden_states = hidden_states.to(query.dtype)  # 將隱藏狀態轉換為查詢的型別

        if self.train_q_out:  # 如果在訓練查詢輸出
            # linear proj
            hidden_states = self.to_out_custom_diffusion[0](hidden_states)  # 線性變換
            # dropout
            hidden_states = self.to_out_custom_diffusion[1](hidden_states)  # 應用 dropout
        else:  # 否則
            # linear proj
            hidden_states = attn.to_out[0](hidden_states)  # 線性變換
            # dropout
            hidden_states = attn.to_out[1](hidden_states)  # 應用 dropout

        return hidden_states  # 返回最終的隱藏狀態
# 定義一個用於實現切片注意力的處理器類
class SlicedAttnProcessor:
    r"""
    處理器用於實現切片注意力。

    引數:
        slice_size (`int`, *可選*):
            計算注意力的步驟數量。使用的切片數量為 `attention_head_dim // slice_size`,並且
            `attention_head_dim` 必須是 `slice_size` 的整數倍。
    """

    # 初始化方法,接受切片大小作為引數
    def __init__(self, slice_size: int):
        # 將傳入的切片大小儲存為例項變數
        self.slice_size = slice_size

    # 定義可呼叫方法,以便例項可以像函式一樣被呼叫
    def __call__(
        self,
        attn: Attention,  # 輸入的注意力物件
        hidden_states: torch.Tensor,  # 當前隱藏狀態的張量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 編碼器的隱藏狀態,可選引數
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩碼,可選引數
    ) -> torch.Tensor:
        # 儲存輸入的隱藏狀態,用於殘差連線
        residual = hidden_states

        # 獲取隱藏狀態的維度數量
        input_ndim = hidden_states.ndim

        # 如果輸入維度是4,調整隱藏狀態的形狀
        if input_ndim == 4:
            # 解包隱藏狀態的形狀為批次大小、通道、高度和寬度
            batch_size, channel, height, width = hidden_states.shape
            # 將隱藏狀態展平為(batch_size, channel, height * width)並轉置
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # 確定序列長度和批次大小,根據是否有編碼器隱藏狀態決定
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # 準備注意力掩碼
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # 如果有分組歸一化,應用於隱藏狀態
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        # 將隱藏狀態轉換為查詢向量
        query = attn.to_q(hidden_states)
        # 獲取查詢向量的最後一維大小
        dim = query.shape[-1]
        # 將查詢向量轉換為批次維度格式
        query = attn.head_to_batch_dim(query)

        # 如果沒有編碼器隱藏狀態,使用當前隱藏狀態
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        # 如果需要,歸一化編碼器隱藏狀態
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # 將編碼器隱藏狀態轉換為鍵和值
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        # 將鍵和值轉換為批次維度格式
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # 獲取查詢向量的批次大小和令牌數量
        batch_size_attention, query_tokens, _ = query.shape
        # 初始化隱藏狀態張量為零
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # 按切片處理查詢、鍵和值
        for i in range((batch_size_attention - 1) // self.slice_size + 1):
            # 計算當前切片的起始和結束索引
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            # 獲取當前切片的查詢、鍵和注意力掩碼
            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            # 計算當前切片的注意力分數
            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            # 將注意力分數與值相乘,獲取注意力結果
            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            # 將注意力結果儲存到隱藏狀態中
            hidden_states[start_idx:end_idx] = attn_slice

        # 將隱藏狀態轉換回頭維度格式
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # 對隱藏狀態進行線性變換
        hidden_states = attn.to_out[0](hidden_states)
        # 應用 dropout
        hidden_states = attn.to_out[1](hidden_states)

        # 如果輸入維度是4,調整隱藏狀態的形狀
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # 如果需要殘差連線,將殘差加到隱藏狀態中
        if attn.residual_connection:
            hidden_states = hidden_states + residual

        # 根據縮放因子調整輸出
        hidden_states = hidden_states / attn.rescale_output_factor

        # 返回最終的隱藏狀態
        return hidden_states
# 定義一個處理器類,用於實現切片注意力,並額外學習鍵和值矩陣
class SlicedAttnAddedKVProcessor:
    r"""
    處理器,用於實現帶有額外可學習的鍵和值矩陣的切片注意力,用於文字編碼器。

    引數:
        slice_size (`int`, *可選*):
            計算注意力的步數。使用 `attention_head_dim // slice_size` 的切片數量,
            並且 `attention_head_dim` 必須是 `slice_size` 的倍數。
    """

    # 初始化方法,接收切片大小作為引數
    def __init__(self, slice_size):
        # 將傳入的切片大小賦值給例項變數
        self.slice_size = slice_size

    # 定義呼叫方法,使類的例項可以像函式一樣被呼叫
    def __call__(
        self,
        attn: "Attention",  # 接收一個注意力物件
        hidden_states: torch.Tensor,  # 輸入的隱藏狀態張量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 可選的編碼器隱藏狀態張量
        attention_mask: Optional[torch.Tensor] = None,  # 可選的注意力掩碼張量
        temb: Optional[torch.Tensor] = None,  # 可選的時間嵌入張量
    # 返回型別為 torch.Tensor
        ) -> torch.Tensor:
            # 儲存輸入的隱藏狀態作為殘差
            residual = hidden_states
    
            # 如果空間歸一化存在,則應用於隱藏狀態和時間嵌入
            if attn.spatial_norm is not None:
                hidden_states = attn.spatial_norm(hidden_states, temb)
    
            # 將隱藏狀態重塑為三維張量並轉置維度
            hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
    
            # 獲取批次大小和序列長度
            batch_size, sequence_length, _ = hidden_states.shape
    
            # 準備注意力掩碼
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
    
            # 如果沒有編碼器隱藏狀態,則將其設定為當前隱藏狀態
            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            # 如果需要歸一化編碼器隱藏狀態
            elif attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
    
            # 對隱藏狀態應用組歸一化並轉置維度
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
    
            # 生成查詢向量
            query = attn.to_q(hidden_states)
            # 獲取查詢向量的最後一維大小
            dim = query.shape[-1]
            # 將查詢向量的維度轉換為批次維度
            query = attn.head_to_batch_dim(query)
    
            # 生成編碼器隱藏狀態的鍵和值的投影
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
    
            # 將編碼器隱藏狀態的鍵和值轉換為批次維度
            encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
            encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
    
            # 如果不只使用交叉注意力
            if not attn.only_cross_attention:
                # 生成當前隱藏狀態的鍵和值
                key = attn.to_k(hidden_states)
                value = attn.to_v(hidden_states)
                # 將鍵和值轉換為批次維度
                key = attn.head_to_batch_dim(key)
                value = attn.head_to_batch_dim(value)
                # 將編碼器鍵與當前鍵拼接
                key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
                # 將編碼器值與當前值拼接
                value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
            else:
                # 直接使用編碼器的鍵和值
                key = encoder_hidden_states_key_proj
                value = encoder_hidden_states_value_proj
    
            # 獲取批次大小、查詢令牌數量和最後一維大小
            batch_size_attention, query_tokens, _ = query.shape
            # 初始化隱藏狀態為零張量
            hidden_states = torch.zeros(
                (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
            )
    
            # 按切片大小進行迭代處理
            for i in range((batch_size_attention - 1) // self.slice_size + 1):
                start_idx = i * self.slice_size  # 切片起始索引
                end_idx = (i + 1) * self.slice_size  # 切片結束索引
    
                # 獲取當前查詢切片、鍵切片和注意力掩碼切片
                query_slice = query[start_idx:end_idx]
                key_slice = key[start_idx:end_idx]
                attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
    
                # 獲取當前切片的注意力分數
                attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
    
                # 將注意力分數與當前值進行批次矩陣乘法
                attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
    
                # 將結果儲存到隱藏狀態
                hidden_states[start_idx:end_idx] = attn_slice
    
            # 將隱藏狀態的維度轉換回頭部維度
            hidden_states = attn.batch_to_head_dim(hidden_states)
    
            # 線性投影
            hidden_states = attn.to_out[0](hidden_states)
            # 應用丟棄層
            hidden_states = attn.to_out[1](hidden_states)
    
            # 轉置最後兩維並重塑為殘差形狀
            hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
            # 將殘差新增到當前隱藏狀態
            hidden_states = hidden_states + residual
    
            # 返回最終的隱藏狀態
            return hidden_states
# 定義一個空間歸一化類,繼承自 nn.Module
class SpatialNorm(nn.Module):
    """
    空間條件歸一化,定義在 https://arxiv.org/abs/2209.09002 中。

    引數:
        f_channels (`int`):
            輸入到組歸一化層的通道數,以及空間歸一化層的輸出通道數。
        zq_channels (`int`):
            量化向量的通道數,如論文中所述。
    """

    # 初始化方法,接收通道數作為引數
    def __init__(
        self,
        f_channels: int,
        zq_channels: int,
    ):
        # 呼叫父類的初始化方法
        super().__init__()
        # 建立組歸一化層,指定通道數、組數和其他引數
        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
        # 建立卷積層,輸入通道為 zq_channels,輸出通道為 f_channels
        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
        # 建立另一個卷積層,功能相同但用於偏置項
        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)

    # 前向傳播方法,定義輸入和輸出
    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
        # 獲取輸入張量 f 的空間尺寸
        f_size = f.shape[-2:]
        # 對 zq 進行上取樣,使其尺寸與 f 相同
        zq = F.interpolate(zq, size=f_size, mode="nearest")
        # 對輸入 f 應用歸一化層
        norm_f = self.norm_layer(f)
        # 計算新的輸出張量,透過歸一化後的 f 和卷積結果結合
        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        # 返回新的張量
        return new_f


# 定義一個 IPAdapter 注意力處理器類,繼承自 nn.Module
class IPAdapterAttnProcessor(nn.Module):
    r"""
    多個 IP-Adapter 的注意力處理器。

    引數:
        hidden_size (`int`):
            注意力層的隱藏尺寸。
        cross_attention_dim (`int`):
            `encoder_hidden_states` 中的通道數。
        num_tokens (`int`, `Tuple[int]` 或 `List[int]`, 預設為 `(4,)`):
            影像特徵的上下文長度。
        scale (`float` 或 List[`float`], 預設為 1.0):
            影像提示的權重縮放。
    """

    # 初始化方法,接收多個引數
    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
        # 呼叫父類的初始化方法
        super().__init__()

        # 儲存隱藏尺寸
        self.hidden_size = hidden_size
        # 儲存交叉注意力維度
        self.cross_attention_dim = cross_attention_dim

        # 確保 num_tokens 為元組或列表
        if not isinstance(num_tokens, (tuple, list)):
            num_tokens = [num_tokens]
        # 儲存 num_tokens
        self.num_tokens = num_tokens

        # 確保 scale 為列表
        if not isinstance(scale, list):
            scale = [scale] * len(num_tokens)
        # 驗證 scale 和 num_tokens 長度相同
        if len(scale) != len(num_tokens):
            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
        # 儲存縮放因子
        self.scale = scale

        # 建立用於鍵的線性變換列表
        self.to_k_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )
        # 建立用於值的線性變換列表
        self.to_v_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )

    # 定義呼叫方法,處理輸入的注意力資訊
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        scale: float = 1.0,
        ip_adapter_masks: Optional[torch.Tensor] = None,
class IPAdapterAttnProcessor2_0(torch.nn.Module):
    r"""
    PyTorch 2.0 的 IP-Adapter 注意力處理器。
    # 定義引數說明文件,列出類建構函式的引數及其型別和預設值
        Args:
            hidden_size (`int`):
                注意力層的隱藏層大小
            cross_attention_dim (`int`):
                編碼器隱藏狀態的通道數
            num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
                影像特徵的上下文長度
            scale (`float` or `List[float]`, defaults to 1.0):
                影像提示的權重比例
        """
    
        # 初始化類的建構函式,設定類屬性
        def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
            # 呼叫父類的建構函式
            super().__init__()
    
            # 檢查 PyTorch 是否支援縮放點積注意力
            if not hasattr(F, "scaled_dot_product_attention"):
                # 如果不支援,丟擲匯入錯誤
                raise ImportError(
                    f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
                )
    
            # 設定隱藏層大小屬性
            self.hidden_size = hidden_size
            # 設定交叉注意力維度屬性
            self.cross_attention_dim = cross_attention_dim
    
            # 如果 num_tokens 不是元組或列表,則將其轉換為列表
            if not isinstance(num_tokens, (tuple, list)):
                num_tokens = [num_tokens]
            # 設定 num_tokens 屬性
            self.num_tokens = num_tokens
    
            # 如果 scale 不是列表,則建立與 num_tokens 長度相同的列表
            if not isinstance(scale, list):
                scale = [scale] * len(num_tokens)
            # 檢查 scale 的長度是否與 num_tokens 相同
            if len(scale) != len(num_tokens):
                # 如果不同,丟擲值錯誤
                raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
            # 設定 scale 屬性
            self.scale = scale
    
            # 建立一個包含多個線性層的模組列表,用於輸入到 K 的對映
            self.to_k_ip = nn.ModuleList(
                [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
            )
            # 建立一個包含多個線性層的模組列表,用於輸入到 V 的對映
            self.to_v_ip = nn.ModuleList(
                [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
            )
    
        # 定義類的呼叫方法,用於執行注意力計算
        def __call__(
            self,
            attn: Attention,
            hidden_states: torch.Tensor,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            temb: Optional[torch.Tensor] = None,
            scale: float = 1.0,
            ip_adapter_masks: Optional[torch.Tensor] = None,
# 定義用於實現 PAG 的處理器類,使用縮放點積注意力(預設啟用 PyTorch 2.0)
class PAGIdentitySelfAttnProcessor2_0:
    r"""
    處理器用於實現 PAG,使用縮放點積注意力(預設在 PyTorch 2.0 中啟用)。
    PAG 參考: https://arxiv.org/abs/2403.17377
    """

    # 初始化函式
    def __init__(self):
        # 檢查 F 中是否有縮放點積注意力功能
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果沒有,則丟擲匯入錯誤,提示需要升級 PyTorch
            raise ImportError(
                "PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    # 可呼叫方法,定義了注意力處理的輸入引數
    def __call__(
        self,
        attn: Attention,  # 輸入的注意力物件
        hidden_states: torch.FloatTensor,  # 當前的隱藏狀態
        encoder_hidden_states: Optional[torch.FloatTensor] = None,  # 編碼器的隱藏狀態(可選)
        attention_mask: Optional[torch.FloatTensor] = None,  # 注意力掩碼(可選)
        temb: Optional[torch.FloatTensor] = None,  # 額外的時間嵌入(可選)
class PAGCFGIdentitySelfAttnProcessor2_0:
    r"""
    處理器用於實現 PAG,使用縮放點積注意力(預設啟用 PyTorch 2.0)。
    PAG 參考: https://arxiv.org/abs/2403.17377
    """

    # 初始化函式
    def __init__(self):
        # 檢查 F 中是否有縮放點積注意力功能
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果沒有,則丟擲匯入錯誤,提示需要升級 PyTorch
            raise ImportError(
                "PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    # 可呼叫方法,定義了注意力處理的輸入引數
    def __call__(
        self,
        attn: Attention,  # 輸入的注意力物件
        hidden_states: torch.FloatTensor,  # 當前的隱藏狀態
        encoder_hidden_states: Optional[torch.FloatTensor] = None,  # 編碼器的隱藏狀態(可選)
        attention_mask: Optional[torch.FloatTensor] = None,  # 注意力掩碼(可選)
        temb: Optional[torch.FloatTensor] = None,  # 額外的時間嵌入(可選)
class LoRAAttnProcessor:
    # 初始化函式
    def __init__(self):
        # 該類的建構函式,目前沒有初始化操作
        pass


class LoRAAttnProcessor2_0:
    # 初始化函式
    def __init__(self):
        # 該類的建構函式,目前沒有初始化操作
        pass


class LoRAXFormersAttnProcessor:
    # 初始化函式
    def __init__(self):
        # 該類的建構函式,目前沒有初始化操作
        pass


class LoRAAttnAddedKVProcessor:
    # 初始化函式
    def __init__(self):
        # 該類的建構函式,目前沒有初始化操作
        pass


# 定義一個包含新增鍵值注意力處理器的元組
ADDED_KV_ATTENTION_PROCESSORS = (
    AttnAddedKVProcessor,  # 新增鍵值注意力處理器
    SlicedAttnAddedKVProcessor,  # 切片新增鍵值注意力處理器
    AttnAddedKVProcessor2_0,  # 新增鍵值注意力處理器版本2.0
    XFormersAttnAddedKVProcessor,  # XFormers 新增鍵值注意力處理器
)

# 定義一個包含交叉注意力處理器的元組
CROSS_ATTENTION_PROCESSORS = (
    AttnProcessor,  # 注意力處理器
    AttnProcessor2_0,  # 注意力處理器版本2.0
    XFormersAttnProcessor,  # XFormers 注意力處理器
    SlicedAttnProcessor,  # 切片注意力處理器
    IPAdapterAttnProcessor,  # IPAdapter 注意力處理器
    IPAdapterAttnProcessor2_0,  # IPAdapter 注意力處理器版本2.0
)

# 定義一個包含所有注意力處理器的聯合型別
AttentionProcessor = Union[
    AttnProcessor,  # 注意力處理器
    AttnProcessor2_0,  # 注意力處理器版本2.0
    FusedAttnProcessor2_0,  # 融合注意力處理器版本2.0
    XFormersAttnProcessor,  # XFormers 注意力處理器
    SlicedAttnProcessor,  # 切片注意力處理器
    AttnAddedKVProcessor,  # 新增鍵值注意力處理器
    SlicedAttnAddedKVProcessor,  # 切片新增鍵值注意力處理器
    AttnAddedKVProcessor2_0,  # 新增鍵值注意力處理器版本2.0
    XFormersAttnAddedKVProcessor,  # XFormers 新增鍵值注意力處理器
    CustomDiffusionAttnProcessor,  # 自定義擴散注意力處理器
    CustomDiffusionXFormersAttnProcessor,  # 自定義擴散 XFormers 注意力處理器
    CustomDiffusionAttnProcessor2_0,  # 自定義擴散注意力處理器版本2.0
    PAGCFGIdentitySelfAttnProcessor2_0,  # PAGCFG 身份自注意力處理器版本2.0
    PAGIdentitySelfAttnProcessor2_0,  # PAG 身份自注意力處理器版本2.0
    PAGCFGHunyuanAttnProcessor2_0,  # PAGCGHunyuan 注意力處理器版本2.0
    PAGHunyuanAttnProcessor2_0,  # PAG Hunyuan 注意力處理器版本2.0
]

相關文章