diffusers 原始碼解析(四)
.\diffusers\models\attention_flax.py
# 版權宣告,表明該程式碼的版權歸 HuggingFace 團隊所有
# 根據 Apache 2.0 許可證授權使用該檔案,未遵守許可證不得使用
# 許可證獲取連結
# 指出該軟體是以“現狀”分發,不附帶任何明示或暗示的保證
# 具體的許可權和限制請參見許可證
# 匯入 functools 模組,用於函數語言程式設計工具
import functools
# 匯入 math 模組,提供數學相關的功能
import math
# 匯入 flax.linen 模組,作為神經網路構建的工具
import flax.linen as nn
# 匯入 jax 庫,用於加速計算
import jax
# 匯入 jax.numpy 模組,提供類似於 NumPy 的陣列功能
import jax.numpy as jnp
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
"""多頭點積注意力,查詢數目有限的實現。"""
# 獲取 key 的維度資訊,包括 key 的數量、頭數和特徵維度
num_kv, num_heads, k_features = key.shape[-3:]
# 獲取 value 的特徵維度
v_features = value.shape[-1]
# 確保 key_chunk_size 不超過 num_kv
key_chunk_size = min(key_chunk_size, num_kv)
# 對查詢進行縮放,防止數值溢位
query = query / jnp.sqrt(k_features)
@functools.partial(jax.checkpoint, prevent_cse=False)
def summarize_chunk(query, key, value):
# 計算查詢和鍵之間的注意力權重
attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
# 獲取每個查詢的最大得分,用於數值穩定性
max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
# 計算最大得分的梯度不更新
max_score = jax.lax.stop_gradient(max_score)
# 計算經過 softmax 的注意力權重
exp_weights = jnp.exp(attn_weights - max_score)
# 計算加權後的值
exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
# 獲取每個查詢的最大得分
max_score = jnp.einsum("...qhk->...qh", max_score)
return (exp_values, exp_weights.sum(axis=-1), max_score)
def chunk_scanner(chunk_idx):
# 動態切片獲取鍵的部分資料
key_chunk = jax.lax.dynamic_slice(
operand=key,
start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d]
slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d]
)
# 動態切片獲取值的部分資料
value_chunk = jax.lax.dynamic_slice(
operand=value,
start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d]
slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d]
)
return summarize_chunk(query, key_chunk, value_chunk)
# 對每個鍵塊進行注意力計算
chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
# 計算全域性最大得分
global_max = jnp.max(chunk_max, axis=0, keepdims=True)
# 計算每個塊與全域性最大得分的差異
max_diffs = jnp.exp(chunk_max - global_max)
# 更新值和權重以便於歸一化
chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
chunk_weights *= max_diffs
# 計算所有塊的總值和總權重
all_values = chunk_values.sum(axis=0)
all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
# 返回歸一化後的總值
return all_values / all_weights
def jax_memory_efficient_attention(
query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
r"""
# Flax 實現的記憶體高效多頭點積注意力機制,相關文獻連結
Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
# 相關 GitHub 專案連結
https://github.com/AminRezaei0x443/memory-efficient-attention
# 引數說明:
# query: 輸入的查詢張量,形狀為 (batch..., query_length, head, query_key_depth_per_head)
Args:
query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
# key: 輸入的鍵張量,形狀為 (batch..., key_value_length, head, query_key_depth_per_head)
key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
# value: 輸入的值張量,形狀為 (batch..., key_value_length, head, value_depth_per_head)
value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
# precision: 計算時的數值精度,預設值為 jax.lax.Precision.HIGHEST
precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
numerical precision for computation
# query_chunk_size: 將查詢陣列劃分的塊大小,必須能整除 query_length
query_chunk_size (`int`, *optional*, defaults to 1024):
chunk size to divide query array value must divide query_length equally without remainder
# key_chunk_size: 將鍵和值陣列劃分的塊大小,必須能整除 key_value_length
key_chunk_size (`int`, *optional*, defaults to 4096):
chunk size to divide key and value array value must divide key_value_length equally without remainder
# 返回值為形狀為 (batch..., query_length, head, value_depth_per_head) 的陣列
Returns:
(`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
"""
# 獲取查詢張量的最後三個維度的大小
num_q, num_heads, q_features = query.shape[-3:]
# 定義一個函式,用於掃描處理每個查詢塊
def chunk_scanner(chunk_idx, _):
# 從查詢陣列中切片出當前塊
query_chunk = jax.lax.dynamic_slice(
# 操作的物件是查詢張量
operand=query,
# 起始索引,保持前面的維度不變,從 chunk_idx 開始切片
start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d]
# 切片的大小,前面的維度不變,後面根據塊大小取最小值
slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d]
)
return (
# 返回未使用的下一個塊索引
chunk_idx + query_chunk_size, # unused ignore it
# 呼叫注意力函式處理當前查詢塊
_query_chunk_attention(
query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
),
)
# 使用 jax.lax.scan 進行塊的掃描處理
_, res = jax.lax.scan(
f=chunk_scanner, # 處理函式
init=0, # 初始化塊索引為 0
xs=None, # 不需要額外的輸入資料
# 根據查詢塊大小計算要處理的塊數
length=math.ceil(num_q / query_chunk_size), # start counter # stop counter
)
# 將所有塊的結果在第 -3 維度拼接在一起
return jnp.concatenate(res, axis=-3) # fuse the chunked result back
# 定義一個 Flax 的多頭注意力模組,遵循文獻中的描述
class FlaxAttention(nn.Module):
r"""
Flax多頭注意力模組,詳見: https://arxiv.org/abs/1706.03762
引數:
query_dim (:obj:`int`):
輸入隱藏狀態的維度
heads (:obj:`int`, *optional*, defaults to 8):
注意力頭的數量
dim_head (:obj:`int`, *optional*, defaults to 64):
每個頭內隱藏狀態的維度
dropout (:obj:`float`, *optional*, defaults to 0.0):
dropout比率
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
啟用記憶體高效注意力 https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
是否將頭維度拆分為自注意力計算的新軸。通常情況下,啟用該標誌可以加快Stable Diffusion 2.x和Stable Diffusion XL的計算速度。
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
引數的 `dtype`
"""
# 定義輸入引數的型別和預設值
query_dim: int
heads: int = 8
dim_head: int = 64
dropout: float = 0.0
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
dtype: jnp.dtype = jnp.float32
# 設定模組的初始化函式
def setup(self):
# 計算內部維度為每個頭的維度與頭的數量的乘積
inner_dim = self.dim_head * self.heads
# 計算縮放因子
self.scale = self.dim_head**-0.5
# 建立權重矩陣,使用舊的命名 {to_q, to_k, to_v, to_out}
self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
# 建立鍵的權重矩陣
self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
# 建立值的權重矩陣
self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
# 建立輸出的權重矩陣
self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
# 建立dropout層
self.dropout_layer = nn.Dropout(rate=self.dropout)
# 將張量的頭部維度重塑為批次維度
def reshape_heads_to_batch_dim(self, tensor):
# 解構張量的形狀
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
# 重塑張量形狀以分離頭維度
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
# 轉置張量的維度
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
# 進一步重塑為批次與頭維度合併
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
# 將張量的批次維度重塑為頭部維度
def reshape_batch_dim_to_heads(self, tensor):
# 解構張量的形狀
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
# 重塑張量形狀以合併批次與頭維度
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
# 轉置張量的維度
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
# 進一步重塑為合併批次與頭維度
tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
# 定義一個 Flax 基礎變換器塊層,使用 GLU 啟用函式,詳見:
class FlaxBasicTransformerBlock(nn.Module):
r"""
Flax 變換器塊層,使用 `GLU` (門控線性單元) 啟用函式,詳見:
https://arxiv.org/abs/1706.03762
# 引數說明部分
Parameters:
dim (:obj:`int`): # 內部隱藏狀態的維度
Inner hidden states dimension
n_heads (:obj:`int`): # 注意力頭的數量
Number of heads
d_head (:obj:`int`): # 每個頭內部隱藏狀態的維度
Hidden states dimension inside each head
dropout (:obj:`float`, *optional*, defaults to 0.0): # 隨機失活率
Dropout rate
only_cross_attention (`bool`, defaults to `False`): # 是否僅應用交叉注意力
Whether to only apply cross attention.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): # 引數資料型別
Parameters `dtype`
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): # 啟用記憶體高效注意力
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`): # 是否將頭維度拆分為新軸
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
"""
dim: int # 內部隱藏狀態維度的型別宣告
n_heads: int # 注意力頭數量的型別宣告
d_head: int # 每個頭的隱藏狀態維度的型別宣告
dropout: float = 0.0 # 隨機失活率的預設值
only_cross_attention: bool = False # 預設不只應用交叉注意力
dtype: jnp.dtype = jnp.float32 # 預設資料型別為 jnp.float32
use_memory_efficient_attention: bool = False # 預設不啟用記憶體高效注意力
split_head_dim: bool = False # 預設不拆分頭維度
def setup(self):
# 設定自注意力(如果 only_cross_attention 為 True,則為交叉注意力)
self.attn1 = FlaxAttention(
self.dim, # 傳入的內部隱藏狀態維度
self.n_heads, # 傳入的注意力頭數量
self.d_head, # 傳入的每個頭的隱藏狀態維度
self.dropout, # 傳入的隨機失活率
self.use_memory_efficient_attention, # 是否使用記憶體高效注意力
self.split_head_dim, # 是否拆分頭維度
dtype=self.dtype, # 傳入的資料型別
)
# 設定交叉注意力
self.attn2 = FlaxAttention(
self.dim, # 傳入的內部隱藏狀態維度
self.n_heads, # 傳入的注意力頭數量
self.d_head, # 傳入的每個頭的隱藏狀態維度
self.dropout, # 傳入的隨機失活率
self.use_memory_efficient_attention, # 是否使用記憶體高效注意力
self.split_head_dim, # 是否拆分頭維度
dtype=self.dtype, # 傳入的資料型別
)
# 設定前饋網路
self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) # 前饋網路初始化
# 設定第一個歸一化層
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) # 歸一化層初始化
# 設定第二個歸一化層
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) # 歸一化層初始化
# 設定第三個歸一化層
self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) # 歸一化層初始化
# 設定丟棄層
self.dropout_layer = nn.Dropout(rate=self.dropout) # 丟棄層初始化
# 定義可呼叫物件,接收隱藏狀態、上下文和確定性標誌
def __call__(self, hidden_states, context, deterministic=True):
# 儲存輸入的隱藏狀態以供後續殘差連線使用
residual = hidden_states
# 如果僅執行交叉注意力,進行相關的處理
if self.only_cross_attention:
hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
else:
# 否則執行自注意力處理
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
# 將自注意力的輸出與輸入的殘差相加
hidden_states = hidden_states + residual
# 交叉注意力處理
residual = hidden_states
# 處理交叉注意力
hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
# 將交叉注意力的輸出與輸入的殘差相加
hidden_states = hidden_states + residual
# 前饋網路處理
residual = hidden_states
# 應用前饋網路
hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
# 將前饋網路的輸出與輸入的殘差相加
hidden_states = hidden_states + residual
# 返回經過 dropout 處理的最終隱藏狀態
return self.dropout_layer(hidden_states, deterministic=deterministic)
# 定義一個二維的 Flax Transformer 模型,繼承自 nn.Module
class FlaxTransformer2DModel(nn.Module):
r"""
A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
https://arxiv.org/pdf/1506.02025.pdf
文件字串,描述該類的功能和引數。
Parameters:
in_channels (:obj:`int`):
Input number of channels
n_heads (:obj:`int`):
Number of heads
d_head (:obj:`int`):
Hidden states dimension inside each head
depth (:obj:`int`, *optional*, defaults to 1):
Number of transformers block
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
use_linear_projection (`bool`, defaults to `False`): tbd
only_cross_attention (`bool`, defaults to `False`): tbd
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
"""
# 定義輸入通道數
in_channels: int
# 定義頭的數量
n_heads: int
# 定義每個頭的隱藏狀態維度
d_head: int
# 定義 Transformer 塊的數量,預設為 1
depth: int = 1
# 定義 Dropout 率,預設為 0.0
dropout: float = 0.0
# 定義是否使用線性投影,預設為 False
use_linear_projection: bool = False
# 定義是否僅使用交叉注意力,預設為 False
only_cross_attention: bool = False
# 定義引數的資料型別,預設為 jnp.float32
dtype: jnp.dtype = jnp.float32
# 定義是否使用記憶體高效注意力,預設為 False
use_memory_efficient_attention: bool = False
# 定義是否將頭維度拆分為新的軸,預設為 False
split_head_dim: bool = False
# 設定模型的元件
def setup(self):
# 使用 Group Normalization 規範化層,分組數為 32,epsilon 為 1e-5
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
# 計算內部維度為頭的數量乘以每個頭的維度
inner_dim = self.n_heads * self.d_head
# 根據是否使用線性投影選擇輸入層
if self.use_linear_projection:
# 建立一個線性投影層,輸出維度為 inner_dim,資料型別為 self.dtype
self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
else:
# 建立一個卷積層,輸出維度為 inner_dim,卷積核大小為 (1, 1),步幅為 (1, 1),填充方式為 "VALID",資料型別為 self.dtype
self.proj_in = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
# 建立一系列 Transformer 塊,數量為 depth
self.transformer_blocks = [
FlaxBasicTransformerBlock(
inner_dim,
self.n_heads,
self.d_head,
dropout=self.dropout,
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
)
for _ in range(self.depth) # 迴圈生成每個 Transformer 塊
]
# 根據是否使用線性投影選擇輸出層
if self.use_linear_projection:
# 建立一個線性投影層,輸出維度為 inner_dim,資料型別為 self.dtype
self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
else:
# 建立一個卷積層,輸出維度為 inner_dim,卷積核大小為 (1, 1),步幅為 (1, 1),填充方式為 "VALID",資料型別為 self.dtype
self.proj_out = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
# 建立一個 Dropout 層,Dropout 率為 self.dropout
self.dropout_layer = nn.Dropout(rate=self.dropout)
# 定義可呼叫物件的方法,接收隱藏狀態、上下文和確定性標誌
def __call__(self, hidden_states, context, deterministic=True):
# 解構隱藏狀態的形狀,獲取批次大小、高度、寬度和通道數
batch, height, width, channels = hidden_states.shape
# 儲存原始隱藏狀態以用於殘差連線
residual = hidden_states
# 對隱藏狀態進行歸一化處理
hidden_states = self.norm(hidden_states)
# 如果使用線性投影,則重塑隱藏狀態
if self.use_linear_projection:
# 將隱藏狀態重塑為(batch, height * width, channels)的形狀
hidden_states = hidden_states.reshape(batch, height * width, channels)
# 應用輸入投影
hidden_states = self.proj_in(hidden_states)
else:
# 直接應用輸入投影
hidden_states = self.proj_in(hidden_states)
# 將隱藏狀態重塑為(batch, height * width, channels)的形狀
hidden_states = hidden_states.reshape(batch, height * width, channels)
# 遍歷每個變換塊,更新隱藏狀態
for transformer_block in self.transformer_blocks:
# 透過變換塊處理隱藏狀態和上下文
hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
# 如果使用線性投影,則先應用輸出投影
if self.use_linear_projection:
hidden_states = self.proj_out(hidden_states)
# 將隱藏狀態重塑回原來的形狀
hidden_states = hidden_states.reshape(batch, height, width, channels)
else:
# 先重塑隱藏狀態
hidden_states = hidden_states.reshape(batch, height, width, channels)
# 再應用輸出投影
hidden_states = self.proj_out(hidden_states)
# 將隱藏狀態與原始狀態相加,實現殘差連線
hidden_states = hidden_states + residual
# 返回經過dropout層處理後的隱藏狀態
return self.dropout_layer(hidden_states, deterministic=deterministic)
# 定義一個 Flax 的前饋神經網路模組,繼承自 nn.Module
class FlaxFeedForward(nn.Module):
r"""
Flax 模組封裝了兩個線性層,中間由一個非線性啟用函式分隔。它是 PyTorch 的
[`FeedForward`] 類的對應物,具有以下簡化:
- 啟用函式目前硬編碼為門控線性單元,來自:
https://arxiv.org/abs/2002.05202
- `dim_out` 等於 `dim`。
- 隱藏維度的數量硬編碼為 `dim * 4` 在 [`FlaxGELU`] 中。
引數:
dim (:obj:`int`):
內部隱藏狀態的維度
dropout (:obj:`float`, *可選*, 預設為 0.0):
丟棄率
dtype (:obj:`jnp.dtype`, *可選*, 預設為 jnp.float32):
引數的資料型別
"""
# 定義類屬性 dim、dropout 和 dtype,分別表示維度、丟棄率和資料型別
dim: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
# 設定方法,初始化網路層
def setup(self):
# 第二個線性層暫時稱為 net_2,以匹配順序層的索引
self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype) # 初始化 FlaxGEGLU 網路
self.net_2 = nn.Dense(self.dim, dtype=self.dtype) # 初始化線性層
# 定義前向傳播方法
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.net_0(hidden_states, deterministic=deterministic) # 透過 net_0 處理隱藏狀態
hidden_states = self.net_2(hidden_states) # 透過 net_2 處理隱藏狀態
return hidden_states # 返回處理後的隱藏狀態
# 定義 Flax 的 GEGLU 啟用層,繼承自 nn.Module
class FlaxGEGLU(nn.Module):
r"""
Flax 實現的線性層後跟門控線性單元啟用函式變體,來自
https://arxiv.org/abs/2002.05202。
引數:
dim (:obj:`int`):
輸入隱藏狀態的維度
dropout (:obj:`float`, *可選*, 預設為 0.0):
丟棄率
dtype (:obj:`jnp.dtype`, *可選*, 預設為 jnp.float32):
引數的資料型別
"""
# 定義類屬性 dim、dropout 和 dtype
dim: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
# 設定方法,初始化網路層
def setup(self):
inner_dim = self.dim * 4 # 計算內部維度
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype) # 初始化線性層
self.dropout_layer = nn.Dropout(rate=self.dropout) # 初始化丟棄層
# 定義前向傳播方法
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.proj(hidden_states) # 透過線性層處理隱藏狀態
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2) # 將輸出分為兩個部分
return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic) # 返回帶丟棄的啟用輸出
.\diffusers\models\attention_processor.py
# 版權宣告,標明該檔案的版權歸 HuggingFace 團隊所有
# 該檔案根據 Apache 2.0 許可證進行許可
# 在遵守許可證的情況下,您可以使用該檔案
# 許可證的副本可以在以下網址獲取
# http://www.apache.org/licenses/LICENSE-2.0
# 除非法律要求或書面同意,否則軟體按 "現狀" 提供,不附帶任何明示或暗示的擔保
# 請參閱許可證以瞭解有關許可權和限制的具體資訊
import inspect # 匯入 inspect 模組,用於獲取物件的資訊
import math # 匯入 math 模組,提供數學函式
from typing import Callable, List, Optional, Tuple, Union # 匯入型別提示相關的型別
import torch # 匯入 PyTorch 庫
import torch.nn.functional as F # 匯入 PyTorch 中的神經網路功能模組,並重新命名為 F
from torch import nn # 從 PyTorch 匯入 nn 模組,提供神經網路的構建塊
from ..image_processor import IPAdapterMaskProcessor # 從上層模組匯入 IPAdapterMaskProcessor
from ..utils import deprecate, logging # 從上層模組匯入棄用和日誌記錄功能
from ..utils.import_utils import is_torch_npu_available, is_xformers_available # 匯入檢查 PyTorch NPU 和 xformers 可用性的工具
from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph # 匯入與 PyTorch 版本和圖形相關的工具
logger = logging.get_logger(__name__) # 獲取當前模組的日誌記錄器例項,便於記錄日誌資訊
if is_torch_npu_available(): # 檢查是否可以使用 PyTorch NPU
import torch_npu # 如果可用,則匯入 torch_npu 模組
if is_xformers_available(): # 檢查是否可以使用 xformers 庫
import xformers # 如果可用,匯入 xformers 模組
import xformers.ops # 匯入 xformers 中的操作模組
else: # 如果 xformers 不可用
xformers = None # 將 xformers 設為 None
@maybe_allow_in_graph # 裝飾器,可能允許在圖中使用該類
class Attention(nn.Module): # 定義 Attention 類,繼承自 nn.Module
r""" # 文件字串,描述該類是一個交叉注意力層
A cross attention layer.
"""
def __init__( # 初始化方法,定義建構函式
self,
query_dim: int, # 查詢維度,型別為整數
cross_attention_dim: Optional[int] = None, # 可選的交叉注意力維度,預設為 None
heads: int = 8, # 注意力頭的數量,預設為 8
kv_heads: Optional[int] = None, # 可選的鍵值頭數量,預設為 None
dim_head: int = 64, # 每個頭的維度,預設為 64
dropout: float = 0.0, # dropout 機率,預設為 0.0
bias: bool = False, # 是否使用偏置,預設為 False
upcast_attention: bool = False, # 是否上升注意力精度,預設為 False
upcast_softmax: bool = False, # 是否上升 softmax 精度,預設為 False
cross_attention_norm: Optional[str] = None, # 可選的交叉注意力歸一化方式,預設為 None
cross_attention_norm_num_groups: int = 32, # 交叉注意力歸一化的組數量,預設為 32
qk_norm: Optional[str] = None, # 可選的查詢鍵歸一化方式,預設為 None
added_kv_proj_dim: Optional[int] = None, # 可選的新增鍵值投影維度,預設為 None
added_proj_bias: Optional[bool] = True, # 是否為新增的投影使用偏置,預設為 True
norm_num_groups: Optional[int] = None, # 可選的歸一化組數量,預設為 None
spatial_norm_dim: Optional[int] = None, # 可選的空間歸一化維度,預設為 None
out_bias: bool = True, # 是否使用輸出偏置,預設為 True
scale_qk: bool = True, # 是否縮放查詢和鍵,預設為 True
only_cross_attention: bool = False, # 是否僅使用交叉注意力,預設為 False
eps: float = 1e-5, # 為數值穩定性引入的微小常數,預設為 1e-5
rescale_output_factor: float = 1.0, # 輸出重標定因子,預設為 1.0
residual_connection: bool = False, # 是否使用殘差連線,預設為 False
_from_deprecated_attn_block: bool = False, # 可選引數,指示是否來自棄用的注意力塊,預設為 False
processor: Optional["AttnProcessor"] = None, # 可選的處理器,預設為 None
out_dim: int = None, # 輸出維度,預設為 None
context_pre_only=None, # 上下文前處理,預設為 None
pre_only=False, # 是否僅進行前處理,預設為 False
# 設定是否使用來自 `torch_npu` 的 npu flash attention
def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
r"""
設定是否使用來自 `torch_npu` 的 npu flash attention。
"""
# 如果選擇使用 npu flash attention
if use_npu_flash_attention:
# 建立 NPU 注意力處理器例項
processor = AttnProcessorNPU()
else:
# 設定注意力處理器
# 預設情況下使用 AttnProcessor2_0,當使用 torch 2.x 時,
# 它利用 torch.nn.functional.scaled_dot_product_attention 進行本地 Flash/記憶體高效注意力
# 僅在其具有預設 `scale` 引數時適用。TODO: 在遷移到 torch 2.1 時移除 scale_qk 檢查
processor = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
)
# 設定當前的處理器
self.set_processor(processor)
# 設定是否使用記憶體高效的 xformers 注意力
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
pass # 此處可能缺少實現
# 設定注意力計算的切片大小
def set_attention_slice(self, slice_size: int) -> None:
r"""
設定注意力計算的切片大小。
引數:
slice_size (`int`):
用於注意力計算的切片大小。
"""
# 如果切片大小不為 None 且大於可切片頭維度
if slice_size is not None and slice_size > self.sliceable_head_dim:
# 丟擲值錯誤,切片大小必須小於或等於可切片頭維度
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
# 如果切片大小不為 None 且新增的 kv 投影維度不為 None
if slice_size is not None and self.added_kv_proj_dim is not None:
# 建立帶切片大小的 KV 處理器例項
processor = SlicedAttnAddedKVProcessor(slice_size)
# 如果切片大小不為 None
elif slice_size is not None:
# 建立帶切片大小的注意力處理器例項
processor = SlicedAttnProcessor(slice_size)
# 如果新增的 kv 投影維度不為 None
elif self.added_kv_proj_dim is not None:
# 建立 KV 注意力處理器例項
processor = AttnAddedKVProcessor()
else:
# 設定注意力處理器
# 預設情況下使用 AttnProcessor2_0,當使用 torch 2.x 時,
# 它利用 torch.nn.functional.scaled_dot_product_attention 進行本地 Flash/記憶體高效注意力
# 僅在其具有預設 `scale` 引數時適用。TODO: 在遷移到 torch 2.1 時移除 scale_qk 檢查
processor = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
)
# 設定當前的處理器
self.set_processor(processor)
# 設定要使用的注意力處理器
def set_processor(self, processor: "AttnProcessor") -> None:
r"""
設定要使用的注意力處理器。
引數:
processor (`AttnProcessor`):
要使用的注意力處理器。
"""
# 如果當前處理器在 `self._modules` 中,且傳入的 `processor` 不在其中,則需要從 `self._modules` 中移除當前處理器
if (
hasattr(self, "processor") # 檢查當前物件是否有處理器屬性
and isinstance(self.processor, torch.nn.Module) # 確保當前處理器是一個 PyTorch 模組
and not isinstance(processor, torch.nn.Module) # 檢查傳入的處理器不是 PyTorch 模組
):
# 記錄日誌,指出將移除已訓練權重的處理器
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
# 從模組中移除當前處理器
self._modules.pop("processor")
# 設定當前物件的處理器為傳入的處理器
self.processor = processor
# 獲取正在使用的注意力處理器
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
r"""
獲取正在使用的注意力處理器。
引數:
return_deprecated_lora (`bool`, *可選*, 預設為 `False`):
設定為 `True` 以返回過時的 LoRA 注意力處理器。
返回:
"AttentionProcessor": 正在使用的注意力處理器。
"""
# 如果不需要返回過時的 LoRA 處理器,則返回當前處理器
if not return_deprecated_lora:
return self.processor
# 前向傳播方法,處理輸入的隱藏狀態
def forward(
self,
hidden_states: torch.Tensor, # 輸入的隱藏狀態張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 可選的編碼器隱藏狀態張量
attention_mask: Optional[torch.Tensor] = None, # 可選的注意力掩碼張量
**cross_attention_kwargs, # 可變引數,用於交叉注意力
) -> torch.Tensor:
r""" # 文件字串,描述此方法的功能和引數
The forward method of the `Attention` class.
Args: # 引數說明
hidden_states (`torch.Tensor`): # 查詢的隱藏狀態,型別為張量
The hidden states of the query.
encoder_hidden_states (`torch.Tensor`, *optional*): # 編碼器的隱藏狀態,可選引數
The hidden states of the encoder.
attention_mask (`torch.Tensor`, *optional*): # 注意力掩碼,可選引數
The attention mask to use. If `None`, no mask is applied.
**cross_attention_kwargs: # 額外的關鍵字引數,傳遞給交叉注意力
Additional keyword arguments to pass along to the cross attention.
Returns: # 返回值說明
`torch.Tensor`: The output of the attention layer. # 返回注意力層的輸出
"""
# `Attention` 類可以呼叫不同的注意力處理器/函式
# 這裡我們簡單地將所有張量傳遞給所選的處理器類
# 對於此處定義的標準處理器,`**cross_attention_kwargs` 是空的
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) # 獲取處理器呼叫方法的引數名集合
quiet_attn_parameters = {"ip_adapter_masks"} # 定義不需要警告的引數集合
unused_kwargs = [ # 篩選出未被使用的關鍵字引數
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
]
if len(unused_kwargs) > 0: # 如果存在未使用的關鍵字引數
logger.warning( # 記錄警告日誌
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # 過濾出有效的關鍵字引數
return self.processor( # 呼叫處理器並返回結果
self,
hidden_states, # 傳遞隱藏狀態
encoder_hidden_states=encoder_hidden_states, # 傳遞編碼器的隱藏狀態
attention_mask=attention_mask, # 傳遞注意力掩碼
**cross_attention_kwargs, # 解包有效的額外關鍵字引數
)
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: # 定義方法,輸入張量並返回處理後的張量
r""" # 文件字串,描述此方法的功能和引數
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` # 將張量從 `[batch_size, seq_len, dim]` 重新形狀為 `[batch_size // heads, seq_len, dim * heads]`,`heads` 為初始化時的頭數量
is the number of heads initialized while constructing the `Attention` class.
Args: # 引數說明
tensor (`torch.Tensor`): The tensor to reshape. # 要重新形狀的張量
Returns: # 返回值說明
`torch.Tensor`: The reshaped tensor. # 返回重新形狀後的張量
"""
head_size = self.heads # 獲取頭的數量
batch_size, seq_len, dim = tensor.shape # 解包輸入張量的形狀
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) # 重新調整張量的形狀
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) # 調整維度順序並重新形狀
return tensor # 返回處理後的張量
# 將輸入張量從形狀 `[batch_size, seq_len, dim]` 轉換為 `[batch_size, seq_len, heads, dim // heads]`
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
r"""
將張量從 `[batch_size, seq_len, dim]` 重塑為 `[batch_size, seq_len, heads, dim // heads]`,其中 `heads` 是
在構造 `Attention` 類時初始化的頭數。
引數:
tensor (`torch.Tensor`): 要重塑的張量。
out_dim (`int`, *可選*, 預設值為 `3`): 張量的輸出維度。如果為 `3`,則張量被
重塑為 `[batch_size * heads, seq_len, dim // heads]`。
返回:
`torch.Tensor`: 重塑後的張量。
"""
# 獲取頭的數量
head_size = self.heads
# 檢查輸入張量的維度,如果是三維則提取形狀資訊
if tensor.ndim == 3:
batch_size, seq_len, dim = tensor.shape
extra_dim = 1
else:
# 如果不是三維,提取四維形狀資訊
batch_size, extra_dim, seq_len, dim = tensor.shape
# 重塑張量為 `[batch_size, seq_len * extra_dim, head_size, dim // head_size]`
tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
# 調整張量維度順序為 `[batch_size, heads, seq_len * extra_dim, dim // heads]`
tensor = tensor.permute(0, 2, 1, 3)
# 如果輸出維度為 3,進一步重塑張量為 `[batch_size * heads, seq_len * extra_dim, dim // heads]`
if out_dim == 3:
tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
# 返回重塑後的張量
return tensor
# 計算注意力得分的函式
def get_attention_scores(
self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
r"""
計算注意力得分。
引數:
query (`torch.Tensor`): 查詢張量。
key (`torch.Tensor`): 鍵張量。
attention_mask (`torch.Tensor`, *可選*): 使用的注意力掩碼。如果為 `None`,則不應用掩碼。
返回:
`torch.Tensor`: 注意力機率/得分。
"""
# 獲取查詢張量的資料型別
dtype = query.dtype
# 如果需要上升型別,將查詢和鍵張量轉換為浮點型
if self.upcast_attention:
query = query.float()
key = key.float()
# 如果沒有提供注意力掩碼,建立空的輸入張量
if attention_mask is None:
baddbmm_input = torch.empty(
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
)
# 設定 beta 為 0
beta = 0
else:
# 如果有注意力掩碼,將其用作輸入
baddbmm_input = attention_mask
# 設定 beta 為 1
beta = 1
# 計算注意力得分
attention_scores = torch.baddbmm(
baddbmm_input,
query,
key.transpose(-1, -2),
beta=beta,
alpha=self.scale,
)
# 刪除臨時的輸入張量
del baddbmm_input
# 如果需要上升型別,將注意力得分轉換為浮點型
if self.upcast_softmax:
attention_scores = attention_scores.float()
# 計算注意力機率
attention_probs = attention_scores.softmax(dim=-1)
# 刪除注意力得分張量
del attention_scores
# 將注意力機率轉換回原始資料型別
attention_probs = attention_probs.to(dtype)
# 返回注意力機率
return attention_probs
# 準備注意力掩碼的函式
def prepare_attention_mask(
self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
) -> torch.Tensor: # 定義一個函式的返回型別為 torch.Tensor
r""" # 開始文件字串,描述函式的作用和引數
Prepare the attention mask for the attention computation. # 準備注意力計算的注意力掩碼
Args: # 引數說明
attention_mask (`torch.Tensor`): # 輸入引數,注意力掩碼,型別為 torch.Tensor
The attention mask to prepare. # 待準備的注意力掩碼
target_length (`int`): # 輸入引數,目標長度,型別為 int
The target length of the attention mask. This is the length of the attention mask after padding. # 注意力掩碼的目標長度,經過填充後的長度
batch_size (`int`): # 輸入引數,批處理大小,型別為 int
The batch size, which is used to repeat the attention mask. # 批處理大小,用於重複注意力掩碼
out_dim (`int`, *optional*, defaults to `3`): # 可選引數,輸出維度,型別為 int,預設為 3
The output dimension of the attention mask. Can be either `3` or `4`. # 注意力掩碼的輸出維度,可以是 3 或 4
Returns: # 返回說明
`torch.Tensor`: The prepared attention mask. # 返回準備好的注意力掩碼,型別為 torch.Tensor
""" # 結束文件字串
head_size = self.heads # 獲取頭部大小,來自類的屬性 heads
if attention_mask is None: # 檢查注意力掩碼是否為 None
return attention_mask # 如果是 None,直接返回
current_length: int = attention_mask.shape[-1] # 獲取當前注意力掩碼的長度
if current_length != target_length: # 檢查當前長度是否與目標長度不匹配
if attention_mask.device.type == "mps": # 如果裝置型別是 "mps"
# HACK: MPS: Does not support padding by greater than dimension of input tensor. # HACK: MPS 不支援填充超過輸入張量的維度
# Instead, we can manually construct the padding tensor. # 所以我們手動構建填充張量
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) # 定義填充張量的形狀
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) # 建立全零填充張量
attention_mask = torch.cat([attention_mask, padding], dim=2) # 在最後一個維度上拼接填充張量
else: # 如果不是 "mps" 裝置
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask: # TODO: 對於如 stable-diffusion 的管道,填充交叉注意力掩碼
# we want to instead pad by (0, remaining_length), where remaining_length is: # 我們希望用 (0, remaining_length) 填充,其中 remaining_length 是
# remaining_length: int = target_length - current_length # remaining_length 的計算
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding # TODO: 重新啟用相關測試
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) # 用零填充注意力掩碼到目標長度
if out_dim == 3: # 如果輸出維度是 3
if attention_mask.shape[0] < batch_size * head_size: # 檢查注意力掩碼的第一維是否小於批處理大小乘以頭部大小
attention_mask = attention_mask.repeat_interleave(head_size, dim=0) # 在第一維上重複注意力掩碼
elif out_dim == 4: # 如果輸出維度是 4
attention_mask = attention_mask.unsqueeze(1) # 在第一維增加一個維度
attention_mask = attention_mask.repeat_interleave(head_size, dim=1) # 在第二維上重複注意力掩碼
return attention_mask # 返回準備好的注意力掩碼
# 定義一個函式用於規範化編碼器的隱藏狀態,接受一個張量作為輸入並返回一個張量
def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
r"""
規範化編碼器隱藏狀態。構造 `Attention` 類時需要指定 `self.norm_cross`。
引數:
encoder_hidden_states (`torch.Tensor`): 編碼器的隱藏狀態。
返回:
`torch.Tensor`: 規範化後的編碼器隱藏狀態。
"""
# 確保在呼叫此方法之前已定義 `self.norm_cross`
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
# 檢查 `self.norm_cross` 是否為 LayerNorm 型別
if isinstance(self.norm_cross, nn.LayerNorm):
# 對編碼器隱藏狀態進行層歸一化
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
# 檢查 `self.norm_cross` 是否為 GroupNorm 型別
elif isinstance(self.norm_cross, nn.GroupNorm):
# GroupNorm 沿通道維度進行歸一化,並期望輸入形狀為 (N, C, *)。
# 此時我們希望沿隱藏維度進行歸一化,因此需要調整形狀
# (batch_size, sequence_length, hidden_size) ->
# (batch_size, hidden_size, sequence_length)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2) # 轉置張量以調整維度順序
encoder_hidden_states = self.norm_cross(encoder_hidden_states) # 對轉置後的張量進行歸一化
encoder_hidden_states = encoder_hidden_states.transpose(1, 2) # 再次轉置回原始順序
else:
# 如果 `self.norm_cross` 既不是 LayerNorm 也不是 GroupNorm,則觸發斷言失敗
assert False
# 返回規範化後的編碼器隱藏狀態
return encoder_hidden_states
# 該裝飾器在計算圖中禁止梯度計算,以節省記憶體和加快推理速度
@torch.no_grad()
# 定義一個融合投影的方法,預設引數 fuse 為 True
def fuse_projections(self, fuse=True):
# 獲取 to_q 權重的裝置資訊
device = self.to_q.weight.data.device
# 獲取 to_q 權重的資料型別
dtype = self.to_q.weight.data.dtype
# 如果不是交叉注意力
if not self.is_cross_attention:
# 獲取權重矩陣的拼接
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
# 輸入特徵數為拼接後權重的列數
in_features = concatenated_weights.shape[1]
# 輸出特徵數為拼接後權重的行數
out_features = concatenated_weights.shape[0]
# 建立一個新的線性投影層並複製權重
self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
# 複製拼接後的權重到新的層
self.to_qkv.weight.copy_(concatenated_weights)
# 如果使用偏置
if self.use_bias:
# 拼接 q、k、v 的偏置
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
# 複製拼接後的偏置到新的層
self.to_qkv.bias.copy_(concatenated_bias)
# 如果是交叉注意力
else:
# 獲取 k 和 v 權重的拼接
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
# 輸入特徵數為拼接後權重的列數
in_features = concatenated_weights.shape[1]
# 輸出特徵數為拼接後權重的行數
out_features = concatenated_weights.shape[0]
# 建立一個新的線性投影層並複製權重
self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
# 複製拼接後的權重到新的層
self.to_kv.weight.copy_(concatenated_weights)
# 如果使用偏置
if self.use_bias:
# 拼接 k 和 v 的偏置
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
# 複製拼接後的偏置到新的層
self.to_kv.bias.copy_(concatenated_bias)
# 處理 SD3 和其他新增的投影
if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
# 獲取額外投影的權重拼接
concatenated_weights = torch.cat(
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
)
# 輸入特徵數為拼接後權重的列數
in_features = concatenated_weights.shape[1]
# 輸出特徵數為拼接後權重的行數
out_features = concatenated_weights.shape[0]
# 建立一個新的線性投影層並複製權重
self.to_added_qkv = nn.Linear(
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
)
# 複製拼接後的權重到新的層
self.to_added_qkv.weight.copy_(concatenated_weights)
# 如果使用偏置
if self.added_proj_bias:
# 拼接額外投影的偏置
concatenated_bias = torch.cat(
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
)
# 複製拼接後的偏置到新的層
self.to_added_qkv.bias.copy_(concatenated_bias)
# 將融合狀態儲存到屬性中
self.fused_projections = fuse
# 定義一個處理器類,用於執行與注意力相關的計算
class AttnProcessor:
r"""
預設處理器,用於執行與注意力相關的計算。
"""
# 實現可呼叫方法,處理注意力計算
def __call__(
self,
attn: Attention, # 注意力物件
hidden_states: torch.Tensor, # 輸入的隱藏狀態張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 編碼器隱藏狀態(可選)
attention_mask: Optional[torch.Tensor] = None, # 注意力掩碼(可選)
temb: Optional[torch.Tensor] = None, # 額外的時間嵌入(可選)
*args, # 額外的位置引數
**kwargs, # 額外的關鍵字引數
) -> torch.Tensor: # 返回處理後的張量
# 檢查是否有額外引數或已棄用的 scale 引數
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 構建棄用警告訊息
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 呼叫棄用處理函式
deprecate("scale", "1.0.0", deprecation_message)
# 初始化殘差為隱藏狀態
residual = hidden_states
# 如果空間歸一化存在,則應用於隱藏狀態
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
# 獲取輸入張量的維度
input_ndim = hidden_states.ndim
# 如果輸入是四維的,則調整形狀
if input_ndim == 4:
# 解包隱藏狀態的形狀
batch_size, channel, height, width = hidden_states.shape
# 重新調整形狀為(batch_size, channel, height*width)並轉置
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# 根據編碼器隱藏狀態的存在與否,獲取批次大小和序列長度
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
# 準備注意力掩碼
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# 如果組歸一化存在,則應用於隱藏狀態
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
# 將隱藏狀態轉換為查詢向量
query = attn.to_q(hidden_states)
# 如果沒有編碼器隱藏狀態,使用隱藏狀態作為編碼器隱藏狀態
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
# 如果需要規範化編碼器隱藏狀態,則應用規範化
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# 從編碼器隱藏狀態中獲取鍵和值
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
# 將查詢、鍵和值轉換為批次維度
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
# 計算注意力分數
attention_probs = attn.get_attention_scores(query, key, attention_mask)
# 透過注意力分數加權求值
hidden_states = torch.bmm(attention_probs, value)
# 將隱藏狀態轉換回頭維度
hidden_states = attn.batch_to_head_dim(hidden_states)
# 線性投影
hidden_states = attn.to_out[0](hidden_states)
# 應用 dropout
hidden_states = attn.to_out[1](hidden_states)
# 如果輸入是四維的,調整回原始形狀
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
# 如果存在殘差連線,則將殘差加回隱藏狀態
if attn.residual_connection:
hidden_states = hidden_states + residual
# 將隱藏狀態歸一化到輸出因子
hidden_states = hidden_states / attn.rescale_output_factor
# 返回最終的隱藏狀態
return hidden_states
# 定義一個處理器類,用於實現自定義擴散方法的注意力
class CustomDiffusionAttnProcessor(nn.Module):
r"""
實現自定義擴散方法的注意力處理器。
# 定義引數說明
Args:
train_kv (`bool`, defaults to `True`): # 是否重新訓練對應於文字特徵的鍵值矩陣
Whether to newly train the key and value matrices corresponding to the text features.
train_q_out (`bool`, defaults to `True`): # 是否重新訓練對應於潛在影像特徵的查詢矩陣
Whether to newly train query matrices corresponding to the latent image features.
hidden_size (`int`, *optional*, defaults to `None`): # 注意力層的隱藏大小
The hidden size of the attention layer.
cross_attention_dim (`int`, *optional*, defaults to `None`): # 編碼器隱藏狀態中的通道數量
The number of channels in the `encoder_hidden_states`.
out_bias (`bool`, defaults to `True`): # 是否在 `train_q_out` 中包含偏置引數
Whether to include the bias parameter in `train_q_out`.
dropout (`float`, *optional*, defaults to 0.0): # 使用的 dropout 機率
The dropout probability to use.
"""
# 初始化方法
def __init__(
self, # 初始化方法的第一個引數,表示物件本身
train_kv: bool = True, # 設定鍵值矩陣訓練的預設值為 True
train_q_out: bool = True, # 設定查詢矩陣訓練的預設值為 True
hidden_size: Optional[int] = None, # 隱藏層大小,預設為 None
cross_attention_dim: Optional[int] = None, # 跨注意力維度,預設為 None
out_bias: bool = True, # 輸出偏置引數的預設值為 True
dropout: float = 0.0, # 預設的 dropout 機率為 0.0
):
super().__init__() # 呼叫父類的初始化方法
self.train_kv = train_kv # 儲存鍵值訓練標誌
self.train_q_out = train_q_out # 儲存查詢輸出訓練標誌
self.hidden_size = hidden_size # 儲存隱藏層大小
self.cross_attention_dim = cross_attention_dim # 儲存跨注意力維度
# `_custom_diffusion` id 方便序列化和載入
if self.train_kv: # 如果需要訓練鍵值
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) # 建立鍵的線性層
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) # 建立值的線性層
if self.train_q_out: # 如果需要訓練查詢輸出
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) # 建立查詢的線性層
self.to_out_custom_diffusion = nn.ModuleList([]) # 初始化輸出層的模組列表
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) # 新增線性輸出層
self.to_out_custom_diffusion.append(nn.Dropout(dropout)) # 新增 dropout 層
# 可呼叫方法
def __call__( # 定義物件被呼叫時的行為
self, # 第一個引數,表示物件本身
attn: Attention, # 注意力物件
hidden_states: torch.Tensor, # 隱藏狀態張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 編碼器隱藏狀態,預設為 None
attention_mask: Optional[torch.Tensor] = None, # 注意力掩碼,預設為 None
# 返回型別為 torch.Tensor
) -> torch.Tensor:
# 獲取隱藏狀態的批次大小和序列長度
batch_size, sequence_length, _ = hidden_states.shape
# 準備注意力掩碼以適應當前批次和序列長度
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# 如果需要訓練查詢輸出,則使用自定義擴散進行轉換
if self.train_q_out:
query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
else:
# 否則使用標準的查詢轉換
query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
# 檢查編碼器隱藏狀態是否為 None
if encoder_hidden_states is None:
# 如果是,則不進行交叉注意力
crossattn = False
encoder_hidden_states = hidden_states
else:
# 否則,啟用交叉注意力
crossattn = True
# 如果需要歸一化編碼器隱藏狀態,則進行歸一化
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# 如果需要訓練鍵值對
if self.train_kv:
# 使用自定義擴散獲取鍵和值
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
# 將鍵和值轉換為查詢的權重資料型別
key = key.to(attn.to_q.weight.dtype)
value = value.to(attn.to_q.weight.dtype)
else:
# 否則使用標準的鍵和值轉換
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
# 如果進行交叉注意力
if crossattn:
# 建立與鍵相同形狀的張量以進行detach操作
detach = torch.ones_like(key)
detach[:, :1, :] = detach[:, :1, :] * 0.0
# 應用detach邏輯以阻止梯度流動
key = detach * key + (1 - detach) * key.detach()
value = detach * value + (1 - detach) * value.detach()
# 將查詢、鍵和值轉換為批次維度
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
# 計算注意力分數
attention_probs = attn.get_attention_scores(query, key, attention_mask)
# 使用注意力分數和值進行批次矩陣乘法
hidden_states = torch.bmm(attention_probs, value)
# 將隱藏狀態轉換回頭維度
hidden_states = attn.batch_to_head_dim(hidden_states)
# 如果需要訓練查詢輸出
if self.train_q_out:
# 線性投影
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
# 應用dropout
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
else:
# 否則使用標準的線性投影
hidden_states = attn.to_out[0](hidden_states)
# 應用dropout
hidden_states = attn.to_out[1](hidden_states)
# 返回最終的隱藏狀態
return hidden_states
# 定義一個帶有額外可學習的鍵和值矩陣的注意力處理器類
class AttnAddedKVProcessor:
r"""
處理器,用於執行與文字編碼器相關的注意力計算
"""
# 定義呼叫方法,以實現注意力計算
def __call__(
self,
attn: Attention, # 注意力物件
hidden_states: torch.Tensor, # 輸入的隱藏狀態張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 編碼器的隱藏狀態(可選)
attention_mask: Optional[torch.Tensor] = None, # 注意力掩碼(可選)
*args, # 其他位置引數
**kwargs, # 其他關鍵字引數
) -> torch.Tensor: # 返回型別為張量
# 檢查是否傳遞了多餘的引數或已棄用的 scale 引數
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 發出棄用警告
deprecate("scale", "1.0.0", deprecation_message)
# 將隱藏狀態賦值給殘差
residual = hidden_states
# 重塑隱藏狀態的形狀,並轉置維度
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
# 獲取批大小和序列長度
batch_size, sequence_length, _ = hidden_states.shape
# 準備注意力掩碼
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# 如果沒有編碼器隱藏狀態,則使用輸入的隱藏狀態
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
# 如果需要進行歸一化處理
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# 對隱藏狀態進行分組歸一化處理
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
# 將隱藏狀態轉換為查詢
query = attn.to_q(hidden_states)
# 將查詢從頭維度轉換為批維度
query = attn.head_to_batch_dim(query)
# 將編碼器隱藏狀態投影為鍵和值
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# 將投影結果轉換為批維度
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
# 如果不是僅進行交叉注意力
if not attn.only_cross_attention:
# 將隱藏狀態轉換為鍵和值
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# 轉換為批維度
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
# 將編碼器鍵和值與當前鍵和值拼接
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
else:
# 僅使用編碼器的鍵和值
key = encoder_hidden_states_key_proj
value = encoder_hidden_states_value_proj
# 獲取注意力機率
attention_probs = attn.get_attention_scores(query, key, attention_mask)
# 計算隱藏狀態的新值
hidden_states = torch.bmm(attention_probs, value)
# 將隱藏狀態轉換回頭維度
hidden_states = attn.batch_to_head_dim(hidden_states)
# 線性投影
hidden_states = attn.to_out[0](hidden_states)
# 應用 dropout
hidden_states = attn.to_out[1](hidden_states)
# 重塑隱藏狀態,並將殘差加回
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual
# 返回處理後的隱藏狀態
return hidden_states
# 定義另一個注意力處理器類
class AttnAddedKVProcessor2_0:
r"""
# 處理縮放點積注意力的處理器(如果使用 PyTorch 2.0,預設啟用),
# 其中為文字編碼器新增了額外的可學習的鍵和值矩陣。
"""
# 初始化方法
def __init__(self):
# 檢查 F 中是否有 "scaled_dot_product_attention" 屬性
if not hasattr(F, "scaled_dot_product_attention"):
# 如果沒有,丟擲 ImportError,提示使用者需要升級到 PyTorch 2.0
raise ImportError(
"AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
# 定義呼叫方法
def __call__(
self,
attn: Attention, # 輸入的注意力機制物件
hidden_states: torch.Tensor, # 隱藏狀態張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 可選的編碼器隱藏狀態張量
attention_mask: Optional[torch.Tensor] = None, # 可選的注意力掩碼張量
*args, # 額外的位置引數
**kwargs, # 額外的關鍵字引數
) -> torch.Tensor: # 指定函式返回型別為 torch.Tensor
# 檢查引數是否存在或 scale 引數是否被提供
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 設定棄用訊息,告知 scale 引數將被忽略
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 呼叫 deprecate 函式發出棄用警告
deprecate("scale", "1.0.0", deprecation_message)
# 將輸入的 hidden_states 賦值給 residual
residual = hidden_states
# 調整 hidden_states 的形狀並進行轉置
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
# 獲取 batch_size 和 sequence_length
batch_size, sequence_length, _ = hidden_states.shape
# 準備注意力掩碼
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
# 如果沒有提供 encoder_hidden_states,則使用 hidden_states
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
# 如果需要歸一化交叉隱藏狀態
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# 對 hidden_states 進行分組歸一化
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
# 計算查詢向量
query = attn.to_q(hidden_states)
# 將查詢向量轉換為批次維度
query = attn.head_to_batch_dim(query, out_dim=4)
# 生成 encoder_hidden_states 的鍵和值的投影
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# 將鍵和值轉換為批次維度
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
# 如果不是隻進行交叉注意力
if not attn.only_cross_attention:
# 計算當前 hidden_states 的鍵和值
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# 轉換為批次維度
key = attn.head_to_batch_dim(key, out_dim=4)
value = attn.head_to_batch_dim(value, out_dim=4)
# 將鍵和值與 encoder 的鍵和值連線
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
else:
# 如果只進行交叉注意力,使用 encoder 的鍵和值
key = encoder_hidden_states_key_proj
value = encoder_hidden_states_value_proj
# 計算縮放點積注意力的輸出,形狀為 (batch, num_heads, seq_len, head_dim)
# TODO: 在遷移到 Torch 2.1 時新增對 attn.scale 的支援
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
# 轉置並重塑 hidden_states
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
# 進行線性投影
hidden_states = attn.to_out[0](hidden_states)
# 進行 dropout
hidden_states = attn.to_out[1](hidden_states)
# 轉置並重塑回 residual 的形狀
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
# 將 residual 加到 hidden_states 上
hidden_states = hidden_states + residual
# 返回最終的 hidden_states
return hidden_states
# 定義一個名為 JointAttnProcessor2_0 的類,用於處理自注意力投影
class JointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
# 初始化方法
def __init__(self):
# 檢查 F 是否有 scaled_dot_product_attention 屬性
if not hasattr(F, "scaled_dot_product_attention"):
# 如果沒有,丟擲匯入錯誤,提示需要升級 PyTorch 到 2.0
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
# 定義呼叫方法,接受多個引數
def __call__(
self,
attn: Attention, # 自注意力物件
hidden_states: torch.FloatTensor, # 當前隱藏狀態的張量
encoder_hidden_states: torch.FloatTensor = None, # 編碼器的隱藏狀態,預設為 None
attention_mask: Optional[torch.FloatTensor] = None, # 可選的注意力掩碼,預設為 None
*args, # 額外的位置引數
**kwargs, # 額外的關鍵字引數
# 返回一個浮點張量
) -> torch.FloatTensor:
# 儲存輸入的隱藏狀態,以便後續使用
residual = hidden_states
# 獲取隱藏狀態的維度
input_ndim = hidden_states.ndim
# 如果隱藏狀態是四維的
if input_ndim == 4:
# 解包隱藏狀態的形狀為批大小、通道、高度和寬度
batch_size, channel, height, width = hidden_states.shape
# 將隱藏狀態重塑為三維,並進行轉置
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# 獲取編碼器隱藏狀態的維度
context_input_ndim = encoder_hidden_states.ndim
# 如果編碼器隱藏狀態是四維的
if context_input_ndim == 4:
# 解包編碼器隱藏狀態的形狀為批大小、通道、高度和寬度
batch_size, channel, height, width = encoder_hidden_states.shape
# 將編碼器隱藏狀態重塑為三維,並進行轉置
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# 獲取編碼器隱藏狀態的批大小
batch_size = encoder_hidden_states.shape[0]
# 計算 `sample` 投影
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# 計算 `context` 投影
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# 合併注意力查詢、鍵和值
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
# 獲取鍵的最後一維大小
inner_dim = key.shape[-1]
# 計算每個頭的維度
head_dim = inner_dim // attn.heads
# 重塑查詢、鍵和值以適應多個頭
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# 計算縮放點積注意力
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
# 轉置並重塑隱藏狀態
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# 轉換為查詢的型別
hidden_states = hidden_states.to(query.dtype)
# 拆分注意力輸出
hidden_states, encoder_hidden_states = (
hidden_states[:, : residual.shape[1]], # 獲取原隱藏狀態的部分
hidden_states[:, residual.shape[1] :], # 獲取編碼器隱藏狀態的部分
)
# 進行線性投影
hidden_states = attn.to_out[0](hidden_states)
# 進行 dropout
hidden_states = attn.to_out[1](hidden_states)
# 如果上下文不是僅限於編碼器
if not attn.context_pre_only:
# 對編碼器隱藏狀態進行額外處理
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# 如果輸入是四維的,進行轉置和重塑
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
# 如果上下文輸入是四維的,進行轉置和重塑
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
# 返回處理後的隱藏狀態和編碼器隱藏狀態
return hidden_states, encoder_hidden_states
# 定義一個類,PAGJointAttnProcessor2_0,用於處理自注意力投影
class PAGJointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
# 初始化方法
def __init__(self):
# 檢查是否存在名為"scaled_dot_product_attention"的屬性
if not hasattr(F, "scaled_dot_product_attention"):
# 如果不存在,則丟擲匯入錯誤,提示需要升級PyTorch到2.0
raise ImportError(
"PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
# 可呼叫方法,接受注意力物件和隱藏狀態
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
# 其他可選引數
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
# 定義另一個類,PAGCFGJointAttnProcessor2_0,類似於PAGJointAttnProcessor2_0
class PAGCFGJointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
# 初始化方法
def __init__(self):
# 檢查是否存在名為"scaled_dot_product_attention"的屬性
if not hasattr(F, "scaled_dot_product_attention"):
# 如果不存在,則丟擲匯入錯誤,提示需要升級PyTorch到2.0
raise ImportError(
"PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
# 可呼叫方法,接受注意力物件和隱藏狀態
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
# 其他可選引數
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
# 定義第三個類,FusedJointAttnProcessor2_0,處理自注意力投影
class FusedJointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
# 初始化方法
def __init__(self):
# 檢查是否存在名為"scaled_dot_product_attention"的屬性
if not hasattr(F, "scaled_dot_product_attention"):
# 如果不存在,則丟擲匯入錯誤,提示需要升級PyTorch到2.0
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
# 可呼叫方法,接受注意力物件和隱藏狀態
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
# 其他可選引數
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
# 將隱藏狀態賦值給殘差變數
residual = hidden_states
# 獲取隱藏狀態的維度
input_ndim = hidden_states.ndim
# 如果隱藏狀態是四維的,進行維度變換
if input_ndim == 4:
# 解包隱藏狀態的形狀
batch_size, channel, height, width = hidden_states.shape
# 將隱藏狀態變形為(batch_size, channel, height * width)並轉置
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# 獲取編碼器隱藏狀態的維度
context_input_ndim = encoder_hidden_states.ndim
# 如果編碼器隱藏狀態是四維的,進行維度變換
if context_input_ndim == 4:
# 解包編碼器隱藏狀態的形狀
batch_size, channel, height, width = encoder_hidden_states.shape
# 將編碼器隱藏狀態變形為(batch_size, channel, height * width)並轉置
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# 獲取編碼器隱藏狀態的批次大小
batch_size = encoder_hidden_states.shape[0]
# `sample` 進行投影
qkv = attn.to_qkv(hidden_states)
# 計算每個分量的大小
split_size = qkv.shape[-1] // 3
# 將qkv拆分為query、key和value
query, key, value = torch.split(qkv, split_size, dim=-1)
# `context` 進行投影
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
# 計算編碼器qkv的分量大小
split_size = encoder_qkv.shape[-1] // 3
# 將編碼器qkv拆分為查詢、鍵和值的投影
(
encoder_hidden_states_query_proj,
encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj,
) = torch.split(encoder_qkv, split_size, dim=-1)
# 進行注意力計算
# 將query、key、value進行連線
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
# 獲取key的最後一維大小
inner_dim = key.shape[-1]
# 計算每個頭的維度
head_dim = inner_dim // attn.heads
# 調整query的形狀以適應多頭注意力
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# 調整key的形狀以適應多頭注意力
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# 調整value的形狀以適應多頭注意力
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# 進行縮放點積注意力計算
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
# 調整hidden_states的形狀
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# 將hidden_states轉換為與query相同的資料型別
hidden_states = hidden_states.to(query.dtype)
# 拆分注意力輸出
hidden_states, encoder_hidden_states = (
# 保留殘差形狀的部分
hidden_states[:, : residual.shape[1]],
# 剩餘的部分
hidden_states[:, residual.shape[1] :],
)
# 線性投影
hidden_states = attn.to_out[0](hidden_states)
# 進行dropout
hidden_states = attn.to_out[1](hidden_states)
# 如果不是隻使用上下文,進行編碼器輸出的投影
if not attn.context_pre_only:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# 如果輸入是四維的,調整hidden_states的形狀
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
# 如果上下文輸入是四維的,調整encoder_hidden_states的形狀
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
# 返回hidden_states和encoder_hidden_states
return hidden_states, encoder_hidden_states
# 定義一個用於處理 Aura Flow 的注意力處理器類
class AuraFlowAttnProcessor2_0:
"""Attention processor used typically in processing Aura Flow."""
# 初始化方法
def __init__(self):
# 檢查 F 是否具有 scaled_dot_product_attention 屬性,並確保 PyTorch 版本符合要求
if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
# 如果不滿足條件,丟擲匯入錯誤,提示使用者升級 PyTorch
raise ImportError(
"AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
)
# 可呼叫方法,用於處理輸入的注意力和隱藏狀態
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
*args,
**kwargs,
# 定義一個用於處理 Aura Flow 的融合投影注意力處理器類
class FusedAuraFlowAttnProcessor2_0:
"""Attention processor used typically in processing Aura Flow with fused projections."""
# 初始化方法
def __init__(self):
# 檢查 F 是否具有 scaled_dot_product_attention 屬性,並確保 PyTorch 版本符合要求
if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
# 如果不滿足條件,丟擲匯入錯誤,提示使用者升級 PyTorch
raise ImportError(
"FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
)
# 可呼叫方法,用於處理輸入的注意力和隱藏狀態
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
*args,
**kwargs,
# YiYi 待辦事項:重構與 rope 相關的函式/類
def apply_rope(xq, xk, freqs_cis):
# 將 xq 轉換為浮點型,並重新調整形狀以便處理
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
# 將 xk 轉換為浮點型,並重新調整形狀以便處理
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
# 計算 xq 的輸出,結合頻率複數
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
# 計算 xk 的輸出,結合頻率複數
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
# 返回撥整形狀後的 xq_out 和 xk_out,並確保與原始型別匹配
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
# 定義一個實現縮放點積注意力的處理器類
class FluxSingleAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
# 初始化方法
def __init__(self):
# 檢查 F 是否具有 scaled_dot_product_attention 屬性
if not hasattr(F, "scaled_dot_product_attention"):
# 如果不滿足條件,丟擲匯入錯誤,提示使用者升級 PyTorch
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
# 可呼叫方法,用於處理輸入的注意力和隱藏狀態
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
# 定義函式的返回型別為 torch.Tensor
) -> torch.Tensor:
# 獲取 hidden_states 的維度數量
input_ndim = hidden_states.ndim
# 如果輸入的維度為 4
if input_ndim == 4:
# 解包 hidden_states 的形狀為 batch_size, channel, height, width
batch_size, channel, height, width = hidden_states.shape
# 將 hidden_states 檢視調整為 (batch_size, channel, height * width) 並轉置
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# 如果 encoder_hidden_states 為 None,則獲取 hidden_states 的形狀
# 否則獲取 encoder_hidden_states 的形狀
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# 將 hidden_states 轉換為查詢向量
query = attn.to_q(hidden_states)
# 如果 encoder_hidden_states 為 None,將其設定為 hidden_states
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
# 將 encoder_hidden_states 轉換為鍵向量
key = attn.to_k(encoder_hidden_states)
# 將 encoder_hidden_states 轉換為值向量
value = attn.to_v(encoder_hidden_states)
# 獲取鍵的最後一個維度的大小
inner_dim = key.shape[-1]
# 計算每個頭的維度
head_dim = inner_dim // attn.heads
# 將查詢向量調整檢視為 (batch_size, -1, attn.heads, head_dim) 並轉置
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# 將鍵向量調整檢視為 (batch_size, -1, attn.heads, head_dim) 並轉置
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# 將值向量調整檢視為 (batch_size, -1, attn.heads, head_dim) 並轉置
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# 如果存在規範化查詢的層,則對查詢進行規範化
if attn.norm_q is not None:
query = attn.norm_q(query)
# 如果存在規範化鍵的層,則對鍵進行規範化
if attn.norm_k is not None:
key = attn.norm_k(key)
# 如果需要應用 RoPE
if image_rotary_emb is not None:
# 應用旋轉嵌入到查詢和鍵上
query, key = apply_rope(query, key, image_rotary_emb)
# 計算縮放點積注意力,輸出形狀為 (batch, num_heads, seq_len, head_dim)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
# 轉置並調整 hidden_states 的形狀為 (batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# 將 hidden_states 轉換為與查詢相同的資料型別
hidden_states = hidden_states.to(query.dtype)
# 如果輸入維度為 4,將 hidden_states 轉置並調整形狀回原始維度
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
# 返回處理後的 hidden_states
return hidden_states
# 定義一個名為 FluxAttnProcessor2_0 的類,通常用於處理 SD3 類自注意力投影
class FluxAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
# 初始化方法
def __init__(self):
# 檢查 F 是否有 scaled_dot_product_attention 屬性,如果沒有則丟擲 ImportError
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
# 定義呼叫方法,使類例項可被呼叫
def __call__(
self,
attn: Attention, # 接收 Attention 物件
hidden_states: torch.FloatTensor, # 接收隱藏狀態張量
encoder_hidden_states: torch.FloatTensor = None, # 可選的編碼器隱藏狀態張量
attention_mask: Optional[torch.FloatTensor] = None, # 可選的注意力掩碼張量
image_rotary_emb: Optional[torch.Tensor] = None, # 可選的影像旋轉嵌入張量
):
# 此處將實現自注意力的具體處理邏輯
# 定義一個名為 CogVideoXAttnProcessor2_0 的類,專用於 CogVideoX 模型的縮放點積注意力處理
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
# 初始化方法
def __init__(self):
# 檢查 F 是否有 scaled_dot_product_attention 屬性,如果沒有則丟擲 ImportError
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
# 定義呼叫方法,使類例項可被呼叫
def __call__(
self,
attn: Attention, # 接收 Attention 物件
hidden_states: torch.Tensor, # 接收隱藏狀態張量
encoder_hidden_states: torch.Tensor, # 接收編碼器隱藏狀態張量
attention_mask: Optional[torch.Tensor] = None, # 可選的注意力掩碼張量
image_rotary_emb: Optional[torch.Tensor] = None, # 可選的影像旋轉嵌入張量
):
# 此處將實現自注意力的具體處理邏輯
) -> torch.Tensor: # 函式返回一個張量,表示隱藏狀態
text_seq_length = encoder_hidden_states.size(1) # 獲取編碼器隱藏狀態的序列長度
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) # 在維度1上連線編碼器隱藏狀態和當前隱藏狀態
batch_size, sequence_length, _ = ( # 解包 batch_size 和 sequence_length
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape # 根據編碼器隱藏狀態的存在性決定形狀
)
if attention_mask is not None: # 如果存在注意力掩碼
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # 準備注意力掩碼
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) # 調整注意力掩碼的形狀以適應頭數
query = attn.to_q(hidden_states) # 將隱藏狀態轉換為查詢向量
key = attn.to_k(hidden_states) # 將隱藏狀態轉換為鍵向量
value = attn.to_v(hidden_states) # 將隱藏狀態轉換為值向量
inner_dim = key.shape[-1] # 獲取鍵向量的最後一個維度大小
head_dim = inner_dim // attn.heads # 計算每個頭的維度
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # 調整查詢向量形狀並轉置以適應多頭注意力
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # 調整鍵向量形狀並轉置以適應多頭注意力
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # 調整值向量形狀並轉置以適應多頭注意力
if attn.norm_q is not None: # 如果查詢歸一化層存在
query = attn.norm_q(query) # 對查詢向量進行歸一化
if attn.norm_k is not None: # 如果鍵歸一化層存在
key = attn.norm_k(key) # 對鍵向量進行歸一化
# Apply RoPE if needed # 如果需要應用旋轉位置編碼
if image_rotary_emb is not None: # 如果影像旋轉嵌入存在
from .embeddings import apply_rotary_emb # 匯入應用旋轉嵌入的函式
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) # 應用旋轉嵌入到查詢向量的後半部分
if not attn.is_cross_attention: # 如果不是交叉注意力
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) # 應用旋轉嵌入到鍵向量的後半部分
hidden_states = F.scaled_dot_product_attention( # 計算縮放點積注意力
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False # 輸入查詢、鍵和值,以及注意力掩碼
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) # 轉置和重塑隱藏狀態以合併頭維度
# linear proj # 線性投影
hidden_states = attn.to_out[0](hidden_states) # 對隱藏狀態應用輸出線性變換
# dropout # 進行dropout操作
hidden_states = attn.to_out[1](hidden_states) # 對隱藏狀態應用dropout
encoder_hidden_states, hidden_states = hidden_states.split( # 將隱藏狀態分割為編碼器和當前隱藏狀態
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 # 根據文字序列長度和剩餘部分進行分割
)
return hidden_states, encoder_hidden_states # 返回當前隱藏狀態和編碼器隱藏狀態
# 定義一個用於實現 CogVideoX 模型的縮放點積注意力的處理器類
class FusedCogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
# 初始化方法
def __init__(self):
# 檢查 F 是否具有 scaled_dot_product_attention 屬性,如果沒有則丟擲匯入錯誤
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
# 定義可呼叫方法,處理注意力計算
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# 獲取編碼器隱藏狀態的序列長度
text_seq_length = encoder_hidden_states.size(1)
# 將編碼器和當前隱藏狀態按維度 1 連線
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# 獲取批次大小和序列長度,依據編碼器隱藏狀態是否為 None
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
# 如果提供了注意力掩碼,則準備掩碼
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# 將掩碼調整為適當的形狀
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
# 將隱藏狀態轉換為查詢、鍵、值
qkv = attn.to_qkv(hidden_states)
# 計算每個部分的大小
split_size = qkv.shape[-1] // 3
# 分割成查詢、鍵和值
query, key, value = torch.split(qkv, split_size, dim=-1)
# 獲取鍵的內部維度
inner_dim = key.shape[-1]
# 計算每個頭的維度
head_dim = inner_dim // attn.heads
# 調整查詢、鍵和值的形狀以適應多頭注意力
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# 如果存在查詢的歸一化,則應用歸一化
if attn.norm_q is not None:
query = attn.norm_q(query)
# 如果存在鍵的歸一化,則應用歸一化
if attn.norm_k is not None:
key = attn.norm_k(key)
# 如果需要應用 RoPE
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
# 對查詢的特定部分應用旋轉嵌入
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
# 如果不是交叉注意力,則對鍵的特定部分應用旋轉嵌入
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
# 計算縮放點積注意力
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
# 調整隱藏狀態的形狀以便輸出
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# 線性投影
hidden_states = attn.to_out[0](hidden_states)
# 應用 dropout
hidden_states = attn.to_out[1](hidden_states)
# 將隱藏狀態拆分為編碼器隱藏狀態和當前隱藏狀態
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
# 返回當前隱藏狀態和編碼器隱藏狀態
return hidden_states, encoder_hidden_states
# 定義用於實現記憶體高效注意力的處理器類
class XFormersAttnAddedKVProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
# 文件字串,說明可選引數 attention_op 的作用
Args:
attention_op (`Callable`, *optional*, defaults to `None`):
使用的基本注意力運算子,推薦設定為 `None` 讓 xFormers 選擇最佳運算子
"""
# 建構函式,初始化注意力運算子
def __init__(self, attention_op: Optional[Callable] = None):
# 將傳入的注意力運算子賦值給例項變數
self.attention_op = attention_op
# 可呼叫方法,用於執行注意力計算
def __call__(
self,
attn: Attention, # 注意力物件
hidden_states: torch.Tensor, # 隱藏狀態張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 編碼器隱藏狀態,預設為 None
attention_mask: Optional[torch.Tensor] = None, # 注意力掩碼,預設為 None
) -> torch.Tensor:
# 將當前隱藏狀態儲存為殘差以便後續使用
residual = hidden_states
# 調整隱藏狀態的形狀並轉置
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
# 獲取批次大小和序列長度
batch_size, sequence_length, _ = hidden_states.shape
# 準備注意力掩碼
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# 如果沒有編碼器隱藏狀態,則將其設定為當前的隱藏狀態
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
# 如果需要,則對編碼器隱藏狀態進行歸一化處理
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# 對隱藏狀態進行分組歸一化處理
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
# 生成查詢向量
query = attn.to_q(hidden_states)
# 將查詢向量從頭部維度轉換為批次維度
query = attn.head_to_batch_dim(query)
# 對編碼器隱藏狀態進行鍵和值的投影
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# 將編碼器隱藏狀態的鍵和值轉換為批次維度
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
# 如果不是僅使用交叉注意力
if not attn.only_cross_attention:
# 生成當前隱藏狀態的鍵和值
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# 轉換鍵和值到批次維度
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
# 將編碼器的鍵和值與當前的鍵和值連線起來
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
else:
# 如果僅使用交叉注意力,則直接使用編碼器的鍵和值
key = encoder_hidden_states_key_proj
value = encoder_hidden_states_value_proj
# 計算高效的注意力
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
# 將結果轉換為查詢的 dtype
hidden_states = hidden_states.to(query.dtype)
# 將隱藏狀態從批次維度轉換回頭部維度
hidden_states = attn.batch_to_head_dim(hidden_states)
# 線性變換
hidden_states = attn.to_out[0](hidden_states)
# 應用 dropout
hidden_states = attn.to_out[1](hidden_states)
# 調整隱藏狀態的形狀以匹配殘差
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
# 將當前隱藏狀態與殘差相加
hidden_states = hidden_states + residual
# 返回最終的隱藏狀態
return hidden_states
# 定義一個用於實現基於 xFormers 的記憶體高效注意力的處理器類
class XFormersAttnProcessor:
r"""
處理器,用於實現基於 xFormers 的記憶體高效注意力。
引數:
attention_op (`Callable`, *可選*, 預設為 `None`):
基礎
[運算子](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase),
用作注意力運算子。建議將其設定為 `None`,並讓 xFormers 選擇最佳運算子。
"""
# 初始化方法,接受一個可選的注意力運算子
def __init__(self, attention_op: Optional[Callable] = None):
# 將傳入的注意力運算子賦值給例項變數
self.attention_op = attention_op
# 定義可呼叫方法,用於執行注意力計算
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
*args,
**kwargs,
# 定義一個用於實現 flash attention 的處理器類,使用 torch_npu
class AttnProcessorNPU:
r"""
處理器,用於使用 torch_npu 實現 flash attention。torch_npu 僅支援 fp16 和 bf16 資料型別。如果
使用 fp32,將使用 F.scaled_dot_product_attention 進行計算,但在 NPU 上加速效果不明顯。
"""
# 初始化方法
def __init__(self):
# 檢查是否可用 torch_npu,如果不可用則丟擲異常
if not is_torch_npu_available():
raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
# 定義可呼叫方法,用於執行注意力計算
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
*args,
**kwargs,
# 定義一個用於實現 scaled dot-product attention 的處理器類,預設在 PyTorch 2.0 中啟用
class AttnProcessor2_0:
r"""
處理器,用於實現 scaled dot-product attention(如果您使用的是 PyTorch 2.0,預設啟用)。
"""
# 初始化方法
def __init__(self):
# 檢查 F 中是否有 scaled_dot_product_attention 屬性,如果沒有則丟擲異常
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
# 定義可呼叫方法,用於執行注意力計算
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
*args,
**kwargs,
# 定義一個用於實現 scaled dot-product attention 的處理器類,適用於穩定音訊模型
class StableAudioAttnProcessor2_0:
r"""
處理器,用於實現 scaled dot-product attention(如果您使用的是 PyTorch 2.0,預設啟用)。此處理器用於
穩定音訊模型。它在查詢和鍵向量上應用旋轉嵌入,並允許 MHA、GQA 或 MQA。
"""
# 初始化方法
def __init__(self):
# 檢查 F 中是否有 scaled_dot_product_attention 屬性,如果沒有則丟擲異常
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
# 定義方法,用於應用部分旋轉嵌入
def apply_partial_rotary_emb(
self,
x: torch.Tensor,
freqs_cis: Tuple[torch.Tensor],
# 定義返回型別為 torch.Tensor 的函式
) -> torch.Tensor:
# 從當前模組匯入 apply_rotary_emb 函式
from .embeddings import apply_rotary_emb
# 獲取頻率餘弦的最後一個維度大小,用於旋轉
rot_dim = freqs_cis[0].shape[-1]
# 將輸入張量 x 劃分為需要旋轉和不需要旋轉的部分
x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:]
# 應用旋轉嵌入到需要旋轉的部分
x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2)
# 將旋轉後的部分與未旋轉的部分在最後一個維度上連線
out = torch.cat((x_rotated, x_unrotated), dim=-1)
# 返回連線後的輸出張量
return out
# 定義可呼叫方法,接收注意力和隱藏狀態
def __call__(
self,
# 輸入的注意力物件
attn: Attention,
# 隱藏狀態的張量
hidden_states: torch.Tensor,
# 可選的編碼器隱藏狀態張量
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可選的注意力掩碼張量
attention_mask: Optional[torch.Tensor] = None,
# 可選的旋轉嵌入張量
rotary_emb: Optional[torch.Tensor] = None,
# 定義 HunyuanAttnProcessor2_0 類,處理縮放的點積注意力
class HunyuanAttnProcessor2_0:
r"""
處理器用於實現縮放的點積注意力(如果使用 PyTorch 2.0,預設啟用)。這是
HunyuanDiT 模型中使用的。它在查詢和鍵向量上應用歸一化層和旋轉嵌入。
"""
# 初始化方法
def __init__(self):
# 檢查 F 中是否有 scaled_dot_product_attention 屬性
if not hasattr(F, "scaled_dot_product_attention"):
# 如果沒有,則丟擲匯入錯誤,提示需要升級 PyTorch 到 2.0
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
# 定義呼叫方法
def __call__(
self,
attn: Attention, # 注意力機制例項
hidden_states: torch.Tensor, # 當前隱藏狀態的張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 編碼器隱藏狀態的可選張量
attention_mask: Optional[torch.Tensor] = None, # 注意力掩碼的可選張量
temb: Optional[torch.Tensor] = None, # 時間嵌入的可選張量
image_rotary_emb: Optional[torch.Tensor] = None, # 影像旋轉嵌入的可選張量
class FusedHunyuanAttnProcessor2_0:
r"""
處理器用於實現縮放的點積注意力(如果使用 PyTorch 2.0,預設啟用),帶有融合的
投影層。這是 HunyuanDiT 模型中使用的。它在查詢和鍵向量上應用歸一化層和旋轉嵌入。
"""
# 初始化方法
def __init__(self):
# 檢查 F 中是否有 scaled_dot_product_attention 屬性
if not hasattr(F, "scaled_dot_product_attention"):
# 如果沒有,則丟擲匯入錯誤,提示需要升級 PyTorch 到 2.0
raise ImportError(
"FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
# 定義呼叫方法
def __call__(
self,
attn: Attention, # 注意力機制例項
hidden_states: torch.Tensor, # 當前隱藏狀態的張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 編碼器隱藏狀態的可選張量
attention_mask: Optional[torch.Tensor] = None, # 注意力掩碼的可選張量
temb: Optional[torch.Tensor] = None, # 時間嵌入的可選張量
image_rotary_emb: Optional[torch.Tensor] = None, # 影像旋轉嵌入的可選張量
class PAGHunyuanAttnProcessor2_0:
r"""
處理器用於實現縮放的點積注意力(如果使用 PyTorch 2.0,預設啟用)。這是
HunyuanDiT 模型中使用的。它在查詢和鍵向量上應用歸一化層和旋轉嵌入。該處理器
變體採用了 [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377)。
"""
# 初始化方法
def __init__(self):
# 檢查 F 中是否有 scaled_dot_product_attention 屬性
if not hasattr(F, "scaled_dot_product_attention"):
# 如果沒有,則丟擲匯入錯誤,提示需要升級 PyTorch 到 2.0
raise ImportError(
"PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
# 定義呼叫方法
def __call__(
self,
attn: Attention, # 注意力機制例項
hidden_states: torch.Tensor, # 當前隱藏狀態的張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 編碼器隱藏狀態的可選張量
attention_mask: Optional[torch.Tensor] = None, # 注意力掩碼的可選張量
temb: Optional[torch.Tensor] = None, # 時間嵌入的可選張量
image_rotary_emb: Optional[torch.Tensor] = None, # 影像旋轉嵌入的可選張量
class PAGCFGHunyuanAttnProcessor2_0:
r"""
處理器用於實現縮放的點積注意力(如果使用 PyTorch 2.0,預設啟用)。這是
HunyuanDiT 模型中使用的。它在查詢和鍵向量上應用歸一化層和旋轉嵌入。該處理器
變體採用了 [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377)。
"""
# 初始化方法,用於建立類的例項
def __init__(self):
# 檢查模組 F 是否具有屬性 "scaled_dot_product_attention"
if not hasattr(F, "scaled_dot_product_attention"):
# 如果沒有該屬性,則丟擲 ImportError,提示使用者升級 PyTorch
raise ImportError(
"PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
# 可呼叫方法,允許類的例項像函式一樣被呼叫
def __call__(
self,
attn: Attention, # 注意力機制物件
hidden_states: torch.Tensor, # 當前隱藏狀態的張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 編碼器的隱藏狀態,可選引數
attention_mask: Optional[torch.Tensor] = None, # 注意力掩碼,可選引數
temb: Optional[torch.Tensor] = None, # 時間嵌入,可選引數
image_rotary_emb: Optional[torch.Tensor] = None, # 影像旋轉嵌入,可選引數
# 定義一個用於實現縮放點積注意力的處理器類
class LuminaAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the LuminaNextDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
"""
# 初始化方法
def __init__(self):
# 檢查 PyTorch 是否具有縮放點積注意力功能
if not hasattr(F, "scaled_dot_product_attention"):
# 如果沒有,丟擲匯入錯誤,提示使用者升級 PyTorch 到 2.0
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
# 定義呼叫方法,使類例項可呼叫
def __call__(
self,
# 接收注意力物件
attn: Attention,
# 接收隱藏狀態張量
hidden_states: torch.Tensor,
# 接收編碼器隱藏狀態張量
encoder_hidden_states: torch.Tensor,
# 可選的注意力掩碼張量
attention_mask: Optional[torch.Tensor] = None,
# 可選的查詢旋轉嵌入張量
query_rotary_emb: Optional[torch.Tensor] = None,
# 可選的鍵旋轉嵌入張量
key_rotary_emb: Optional[torch.Tensor] = None,
# 可選的基本序列長度
base_sequence_length: Optional[int] = None,
) -> torch.Tensor: # 函式返回一個張量,表示處理後的隱藏狀態
from .embeddings import apply_rotary_emb # 從當前包匯入應用旋轉嵌入的函式
input_ndim = hidden_states.ndim # 獲取隱藏狀態的維度數
if input_ndim == 4: # 如果隱藏狀態是四維張量
batch_size, channel, height, width = hidden_states.shape # 解包出批次大小、通道、高度和寬度
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) # 重塑並轉置隱藏狀態
batch_size, sequence_length, _ = hidden_states.shape # 解包出批次大小和序列長度
# Get Query-Key-Value Pair # 獲取查詢、鍵、值對
query = attn.to_q(hidden_states) # 將隱藏狀態轉換為查詢張量
key = attn.to_k(encoder_hidden_states) # 將編碼器的隱藏狀態轉換為鍵張量
value = attn.to_v(encoder_hidden_states) # 將編碼器的隱藏狀態轉換為值張量
query_dim = query.shape[-1] # 獲取查詢的最後一個維度(特徵維度)
inner_dim = key.shape[-1] # 獲取鍵的最後一個維度
head_dim = query_dim // attn.heads # 計算每個頭的維度
dtype = query.dtype # 獲取查詢張量的資料型別
# Get key-value heads # 獲取鍵值頭的數量
kv_heads = inner_dim // head_dim # 計算每個頭的鍵值數量
# Apply Query-Key Norm if needed # 如果需要,應用查詢-鍵歸一化
if attn.norm_q is not None: # 如果定義了查詢的歸一化
query = attn.norm_q(query) # 對查詢進行歸一化
if attn.norm_k is not None: # 如果定義了鍵的歸一化
key = attn.norm_k(key) # 對鍵進行歸一化
query = query.view(batch_size, -1, attn.heads, head_dim) # 重塑查詢張量以適應頭的維度
key = key.view(batch_size, -1, kv_heads, head_dim) # 重塑鍵張量以適應頭的維度
value = value.view(batch_size, -1, kv_heads, head_dim) # 重塑值張量以適應頭的維度
# Apply RoPE if needed # 如果需要,應用旋轉位置嵌入
if query_rotary_emb is not None: # 如果定義了查詢的旋轉嵌入
query = apply_rotary_emb(query, query_rotary_emb, use_real=False) # 應用旋轉嵌入到查詢
if key_rotary_emb is not None: # 如果定義了鍵的旋轉嵌入
key = apply_rotary_emb(key, key_rotary_emb, use_real=False) # 應用旋轉嵌入到鍵
query, key = query.to(dtype), key.to(dtype) # 將查詢和鍵轉換為相同的資料型別
# Apply proportional attention if true # 如果為真,應用比例注意力
if key_rotary_emb is None: # 如果沒有鍵的旋轉嵌入
softmax_scale = None # 設定縮放因子為 None
else: # 如果有鍵的旋轉嵌入
if base_sequence_length is not None: # 如果定義了基礎序列長度
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale # 計算縮放因子
else: # 如果沒有定義基礎序列長度
softmax_scale = attn.scale # 使用注意力的縮放因子
# perform Grouped-query Attention (GQA) # 執行分組查詢注意力
n_rep = attn.heads // kv_heads # 計算每個鍵值頭的重複數量
if n_rep >= 1: # 如果重複數量大於等於 1
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) # 擴充套件並重復鍵
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) # 擴充套件並重復值
# scaled_dot_product_attention expects attention_mask shape to be # 縮放點積注意力期望的注意力掩碼形狀
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) # 將注意力掩碼轉換為布林值並調整形狀
attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1) # 擴充套件注意力掩碼以匹配頭的數量
query = query.transpose(1, 2) # 轉置查詢張量
key = key.transpose(1, 2) # 轉置鍵張量
value = value.transpose(1, 2) # 轉置值張量
# the output of sdp = (batch, num_heads, seq_len, head_dim) # 縮放點積注意力的輸出形狀
# TODO: add support for attn.scale when we move to Torch 2.1 # TODO: 在遷移到 Torch 2.1 時支援 attn.scale
hidden_states = F.scaled_dot_product_attention( # 計算縮放點積注意力
query, key, value, attn_mask=attention_mask, scale=softmax_scale # 輸入查詢、鍵、值及注意力掩碼和縮放因子
)
hidden_states = hidden_states.transpose(1, 2).to(dtype) # 轉置輸出並轉換為相應的資料型別
return hidden_states # 返回處理後的隱藏狀態
# 定義一個用於實現縮放點積注意力的處理器類,預設啟用(如果使用 PyTorch 2.0)
class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is currently 🧪 experimental in nature and can change in future.
</Tip>
"""
# 初始化方法
def __init__(self):
# 檢查 F 庫是否具有縮放點積注意力功能
if not hasattr(F, "scaled_dot_product_attention"):
# 如果沒有,丟擲匯入錯誤,提示使用者升級 PyTorch 版本
raise ImportError(
"FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
)
# 呼叫方法,處理注意力計算
def __call__(
self,
attn: Attention, # 注意力模組
hidden_states: torch.Tensor, # 隱藏狀態張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 編碼器的隱藏狀態,可選
attention_mask: Optional[torch.Tensor] = None, # 注意力掩碼,可選
temb: Optional[torch.Tensor] = None, # 時間嵌入,可選
*args, # 可變位置引數
**kwargs, # 可變關鍵字引數
):
pass # 此處省略具體實現
# 定義一個用於實現記憶體高效注意力的處理器類,使用 xFormers 方法
class CustomDiffusionXFormersAttnProcessor(nn.Module):
r"""
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
Args:
train_kv (`bool`, defaults to `True`):
Whether to newly train the key and value matrices corresponding to the text features.
train_q_out (`bool`, defaults to `True`):
Whether to newly train query matrices corresponding to the latent image features.
hidden_size (`int`, *optional*, defaults to `None`):
The hidden size of the attention layer.
cross_attention_dim (`int`, *optional*, defaults to `None`):
The number of channels in the `encoder_hidden_states`.
out_bias (`bool`, defaults to `True`):
Whether to include the bias parameter in `train_q_out`.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
attention_op (`Callable`, *optional*, defaults to `None`):
The base
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
"""
# 初始化方法,設定各種引數
def __init__(
self,
train_kv: bool = True, # 是否訓練與文字特徵對應的鍵值矩陣
train_q_out: bool = False, # 是否訓練與潛在影像特徵對應的查詢矩陣
hidden_size: Optional[int] = None, # 注意力層的隱藏大小
cross_attention_dim: Optional[int] = None, # 編碼器隱藏狀態的通道數
out_bias: bool = True, # 是否在 train_q_out 中包含偏置引數
dropout: float = 0.0, # 使用的丟棄機率
attention_op: Optional[Callable] = None, # 要使用的基礎注意力操作
):
pass # 此處省略具體實現
):
# 呼叫父類的初始化方法
super().__init__()
# 儲存訓練鍵值對的標誌
self.train_kv = train_kv
# 儲存訓練查詢輸出的標誌
self.train_q_out = train_q_out
# 儲存隱藏層大小
self.hidden_size = hidden_size
# 儲存交叉注意力維度
self.cross_attention_dim = cross_attention_dim
# 儲存注意力操作型別
self.attention_op = attention_op
# `_custom_diffusion` id 用於簡化序列化和載入
if self.train_kv:
# 建立線性層,將交叉注意力維度或隱藏層大小對映到隱藏層大小,且不使用偏置
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
# 建立線性層,將交叉注意力維度或隱藏層大小對映到隱藏層大小,且不使用偏置
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
if self.train_q_out:
# 建立線性層,將隱藏層大小對映到隱藏層大小,且不使用偏置
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
# 建立一個空的模組列表以儲存輸出相關的層
self.to_out_custom_diffusion = nn.ModuleList([])
# 將線性層新增到模組列表中,用於輸出對映
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
# 將 Dropout 層新增到模組列表中,用於正則化
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
def __call__(
# 定義呼叫方法,接收注意力物件和隱藏狀態張量
self,
attn: Attention,
hidden_states: torch.Tensor,
# 可選引數:編碼器的隱藏狀態張量
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可選引數:注意力掩碼張量
attention_mask: Optional[torch.Tensor] = None,
# 定義函式的返回型別為 torch.Tensor
) -> torch.Tensor:
# 獲取批次大小和序列長度,根據 encoder_hidden_states 是否為 None 決定來源
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
# 準備注意力掩碼
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# 判斷是否在訓練階段並應用不同的查詢生成方式
if self.train_q_out:
query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
else:
query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
# 判斷是否存在編碼器隱藏狀態,並設定 crossattn 標誌
if encoder_hidden_states is None:
crossattn = False
encoder_hidden_states = hidden_states
else:
crossattn = True
# 如果需要對編碼器隱藏狀態進行歸一化處理
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# 判斷是否在訓練階段並應用不同的鍵值生成方式
if self.train_kv:
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
key = key.to(attn.to_q.weight.dtype)
value = value.to(attn.to_q.weight.dtype)
else:
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
# 如果使用交叉注意力,進行鍵值的分離和處理
if crossattn:
detach = torch.ones_like(key)
detach[:, :1, :] = detach[:, :1, :] * 0.0
key = detach * key + (1 - detach) * key.detach()
value = detach * value + (1 - detach) * value.detach()
# 將查詢、鍵、值轉換為批處理維度並保持連續性
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
# 使用記憶體高效的注意力計算隱藏狀態
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
# 將隱藏狀態轉換為查詢的型別
hidden_states = hidden_states.to(query.dtype)
# 將隱藏狀態轉換回頭部維度
hidden_states = attn.batch_to_head_dim(hidden_states)
# 根據訓練標誌決定輸出的處理方式
if self.train_q_out:
# 線性變換
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
# 進行 dropout 操作
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
else:
# 線性變換
hidden_states = attn.to_out[0](hidden_states)
# 進行 dropout 操作
hidden_states = attn.to_out[1](hidden_states)
# 返回處理後的隱藏狀態
return hidden_states
# 自定義擴散注意力處理器類,繼承自 PyTorch 的 nn.Module
class CustomDiffusionAttnProcessor2_0(nn.Module):
r"""
用於實現自定義擴散方法的注意力處理器,使用 PyTorch 2.0 的記憶體高效縮放
點積注意力。
引數:
train_kv (`bool`, 預設值為 `True`):
是否新訓練與文字特徵對應的鍵和值矩陣。
train_q_out (`bool`, 預設值為 `True`):
是否新訓練與潛在影像特徵對應的查詢矩陣。
hidden_size (`int`, *可選*, 預設值為 `None`):
注意力層的隱藏大小。
cross_attention_dim (`int`, *可選*, 預設值為 `None`):
`encoder_hidden_states` 中的通道數。
out_bias (`bool`, 預設值為 `True`):
是否在 `train_q_out` 中包含偏置引數。
dropout (`float`, *可選*, 預設值為 0.0):
使用的 dropout 機率。
"""
# 初始化方法,設定類的屬性
def __init__(
self,
train_kv: bool = True,
train_q_out: bool = True,
hidden_size: Optional[int] = None,
cross_attention_dim: Optional[int] = None,
out_bias: bool = True,
dropout: float = 0.0,
):
# 呼叫父類的初始化方法
super().__init__()
# 設定是否訓練鍵值矩陣的標誌
self.train_kv = train_kv
# 設定是否訓練查詢輸出矩陣的標誌
self.train_q_out = train_q_out
# 設定隱藏層的大小
self.hidden_size = hidden_size
# 設定交叉注意力的維度
self.cross_attention_dim = cross_attention_dim
# 如果需要訓練鍵值矩陣,則建立對應的線性層
if self.train_kv:
# 建立從交叉注意力維度到隱藏層的線性變換,且不使用偏置
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
# 建立從交叉注意力維度到隱藏層的線性變換,且不使用偏置
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
# 如果需要訓練查詢輸出,則建立對應的線性層
if self.train_q_out:
# 建立從隱藏層到隱藏層的線性變換,且不使用偏置
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
# 建立一個空的模組列表,用於儲存輸出層
self.to_out_custom_diffusion = nn.ModuleList([])
# 將線性層新增到輸出模組列表中
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
# 新增 dropout 層到輸出模組列表中
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
# 定義類的呼叫方法,處理輸入的注意力和隱藏狀態
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor: # 定義返回型別為 torch.Tensor
batch_size, sequence_length, _ = hidden_states.shape # 解包 hidden_states 的形狀,獲取批大小和序列長度
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # 準備注意力掩碼
if self.train_q_out: # 檢查是否在訓練查詢輸出
query = self.to_q_custom_diffusion(hidden_states) # 使用自定義擴散方法生成查詢向量
else: # 否則
query = attn.to_q(hidden_states) # 使用標準方法生成查詢向量
if encoder_hidden_states is None: # 檢查編碼器隱藏狀態是否為空
crossattn = False # 設定交叉注意力標誌為假
encoder_hidden_states = hidden_states # 將編碼器隱藏狀態設定為隱藏狀態
else: # 如果編碼器隱藏狀態不為空
crossattn = True # 設定交叉注意力標誌為真
if attn.norm_cross: # 如果需要歸一化
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) # 歸一化編碼器隱藏狀態
if self.train_kv: # 檢查是否在訓練鍵值對
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype)) # 生成鍵向量
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype)) # 生成值向量
key = key.to(attn.to_q.weight.dtype) # 將鍵向量轉換為查詢權重的資料型別
value = value.to(attn.to_q.weight.dtype) # 將值向量轉換為查詢權重的資料型別
else: # 否則
key = attn.to_k(encoder_hidden_states) # 使用標準方法生成鍵向量
value = attn.to_v(encoder_hidden_states) # 使用標準方法生成值向量
if crossattn: # 如果進行交叉注意力
detach = torch.ones_like(key) # 建立與鍵相同形狀的全1張量
detach[:, :1, :] = detach[:, :1, :] * 0.0 # 將第一時間步的值設定為0
key = detach * key + (1 - detach) * key.detach() # 根據 detach 張量計算鍵的最終值
value = detach * value + (1 - detach) * value.detach() # 根據 detach 張量計算值的最終值
inner_dim = hidden_states.shape[-1] # 獲取隱藏狀態的最後一維大小
head_dim = inner_dim // attn.heads # 計算每個頭的維度
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # 重新調整查詢的形狀並轉置
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # 重新調整鍵的形狀並轉置
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # 重新調整值的形狀並轉置
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention( # 計算縮放點積注意力
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False # 輸入查詢、鍵和值以及注意力掩碼
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) # 轉置並重塑隱藏狀態
hidden_states = hidden_states.to(query.dtype) # 將隱藏狀態轉換為查詢的型別
if self.train_q_out: # 如果在訓練查詢輸出
# linear proj
hidden_states = self.to_out_custom_diffusion[0](hidden_states) # 線性變換
# dropout
hidden_states = self.to_out_custom_diffusion[1](hidden_states) # 應用 dropout
else: # 否則
# linear proj
hidden_states = attn.to_out[0](hidden_states) # 線性變換
# dropout
hidden_states = attn.to_out[1](hidden_states) # 應用 dropout
return hidden_states # 返回最終的隱藏狀態
# 定義一個用於實現切片注意力的處理器類
class SlicedAttnProcessor:
r"""
處理器用於實現切片注意力。
引數:
slice_size (`int`, *可選*):
計算注意力的步驟數量。使用的切片數量為 `attention_head_dim // slice_size`,並且
`attention_head_dim` 必須是 `slice_size` 的整數倍。
"""
# 初始化方法,接受切片大小作為引數
def __init__(self, slice_size: int):
# 將傳入的切片大小儲存為例項變數
self.slice_size = slice_size
# 定義可呼叫方法,以便例項可以像函式一樣被呼叫
def __call__(
self,
attn: Attention, # 輸入的注意力物件
hidden_states: torch.Tensor, # 當前隱藏狀態的張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 編碼器的隱藏狀態,可選引數
attention_mask: Optional[torch.Tensor] = None, # 注意力掩碼,可選引數
) -> torch.Tensor:
# 儲存輸入的隱藏狀態,用於殘差連線
residual = hidden_states
# 獲取隱藏狀態的維度數量
input_ndim = hidden_states.ndim
# 如果輸入維度是4,調整隱藏狀態的形狀
if input_ndim == 4:
# 解包隱藏狀態的形狀為批次大小、通道、高度和寬度
batch_size, channel, height, width = hidden_states.shape
# 將隱藏狀態展平為(batch_size, channel, height * width)並轉置
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# 確定序列長度和批次大小,根據是否有編碼器隱藏狀態決定
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
# 準備注意力掩碼
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# 如果有分組歸一化,應用於隱藏狀態
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
# 將隱藏狀態轉換為查詢向量
query = attn.to_q(hidden_states)
# 獲取查詢向量的最後一維大小
dim = query.shape[-1]
# 將查詢向量轉換為批次維度格式
query = attn.head_to_batch_dim(query)
# 如果沒有編碼器隱藏狀態,使用當前隱藏狀態
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
# 如果需要,歸一化編碼器隱藏狀態
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# 將編碼器隱藏狀態轉換為鍵和值
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
# 將鍵和值轉換為批次維度格式
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
# 獲取查詢向量的批次大小和令牌數量
batch_size_attention, query_tokens, _ = query.shape
# 初始化隱藏狀態張量為零
hidden_states = torch.zeros(
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
# 按切片處理查詢、鍵和值
for i in range((batch_size_attention - 1) // self.slice_size + 1):
# 計算當前切片的起始和結束索引
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
# 獲取當前切片的查詢、鍵和注意力掩碼
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
# 計算當前切片的注意力分數
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
# 將注意力分數與值相乘,獲取注意力結果
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
# 將注意力結果儲存到隱藏狀態中
hidden_states[start_idx:end_idx] = attn_slice
# 將隱藏狀態轉換回頭維度格式
hidden_states = attn.batch_to_head_dim(hidden_states)
# 對隱藏狀態進行線性變換
hidden_states = attn.to_out[0](hidden_states)
# 應用 dropout
hidden_states = attn.to_out[1](hidden_states)
# 如果輸入維度是4,調整隱藏狀態的形狀
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
# 如果需要殘差連線,將殘差加到隱藏狀態中
if attn.residual_connection:
hidden_states = hidden_states + residual
# 根據縮放因子調整輸出
hidden_states = hidden_states / attn.rescale_output_factor
# 返回最終的隱藏狀態
return hidden_states
# 定義一個處理器類,用於實現切片注意力,並額外學習鍵和值矩陣
class SlicedAttnAddedKVProcessor:
r"""
處理器,用於實現帶有額外可學習的鍵和值矩陣的切片注意力,用於文字編碼器。
引數:
slice_size (`int`, *可選*):
計算注意力的步數。使用 `attention_head_dim // slice_size` 的切片數量,
並且 `attention_head_dim` 必須是 `slice_size` 的倍數。
"""
# 初始化方法,接收切片大小作為引數
def __init__(self, slice_size):
# 將傳入的切片大小賦值給例項變數
self.slice_size = slice_size
# 定義呼叫方法,使類的例項可以像函式一樣被呼叫
def __call__(
self,
attn: "Attention", # 接收一個注意力物件
hidden_states: torch.Tensor, # 輸入的隱藏狀態張量
encoder_hidden_states: Optional[torch.Tensor] = None, # 可選的編碼器隱藏狀態張量
attention_mask: Optional[torch.Tensor] = None, # 可選的注意力掩碼張量
temb: Optional[torch.Tensor] = None, # 可選的時間嵌入張量
# 返回型別為 torch.Tensor
) -> torch.Tensor:
# 儲存輸入的隱藏狀態作為殘差
residual = hidden_states
# 如果空間歸一化存在,則應用於隱藏狀態和時間嵌入
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
# 將隱藏狀態重塑為三維張量並轉置維度
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
# 獲取批次大小和序列長度
batch_size, sequence_length, _ = hidden_states.shape
# 準備注意力掩碼
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# 如果沒有編碼器隱藏狀態,則將其設定為當前隱藏狀態
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
# 如果需要歸一化編碼器隱藏狀態
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# 對隱藏狀態應用組歸一化並轉置維度
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
# 生成查詢向量
query = attn.to_q(hidden_states)
# 獲取查詢向量的最後一維大小
dim = query.shape[-1]
# 將查詢向量的維度轉換為批次維度
query = attn.head_to_batch_dim(query)
# 生成編碼器隱藏狀態的鍵和值的投影
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# 將編碼器隱藏狀態的鍵和值轉換為批次維度
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
# 如果不只使用交叉注意力
if not attn.only_cross_attention:
# 生成當前隱藏狀態的鍵和值
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# 將鍵和值轉換為批次維度
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
# 將編碼器鍵與當前鍵拼接
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
# 將編碼器值與當前值拼接
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
else:
# 直接使用編碼器的鍵和值
key = encoder_hidden_states_key_proj
value = encoder_hidden_states_value_proj
# 獲取批次大小、查詢令牌數量和最後一維大小
batch_size_attention, query_tokens, _ = query.shape
# 初始化隱藏狀態為零張量
hidden_states = torch.zeros(
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
# 按切片大小進行迭代處理
for i in range((batch_size_attention - 1) // self.slice_size + 1):
start_idx = i * self.slice_size # 切片起始索引
end_idx = (i + 1) * self.slice_size # 切片結束索引
# 獲取當前查詢切片、鍵切片和注意力掩碼切片
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
# 獲取當前切片的注意力分數
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
# 將注意力分數與當前值進行批次矩陣乘法
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
# 將結果儲存到隱藏狀態
hidden_states[start_idx:end_idx] = attn_slice
# 將隱藏狀態的維度轉換回頭部維度
hidden_states = attn.batch_to_head_dim(hidden_states)
# 線性投影
hidden_states = attn.to_out[0](hidden_states)
# 應用丟棄層
hidden_states = attn.to_out[1](hidden_states)
# 轉置最後兩維並重塑為殘差形狀
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
# 將殘差新增到當前隱藏狀態
hidden_states = hidden_states + residual
# 返回最終的隱藏狀態
return hidden_states
# 定義一個空間歸一化類,繼承自 nn.Module
class SpatialNorm(nn.Module):
"""
空間條件歸一化,定義在 https://arxiv.org/abs/2209.09002 中。
引數:
f_channels (`int`):
輸入到組歸一化層的通道數,以及空間歸一化層的輸出通道數。
zq_channels (`int`):
量化向量的通道數,如論文中所述。
"""
# 初始化方法,接收通道數作為引數
def __init__(
self,
f_channels: int,
zq_channels: int,
):
# 呼叫父類的初始化方法
super().__init__()
# 建立組歸一化層,指定通道數、組數和其他引數
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
# 建立卷積層,輸入通道為 zq_channels,輸出通道為 f_channels
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
# 建立另一個卷積層,功能相同但用於偏置項
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
# 前向傳播方法,定義輸入和輸出
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
# 獲取輸入張量 f 的空間尺寸
f_size = f.shape[-2:]
# 對 zq 進行上取樣,使其尺寸與 f 相同
zq = F.interpolate(zq, size=f_size, mode="nearest")
# 對輸入 f 應用歸一化層
norm_f = self.norm_layer(f)
# 計算新的輸出張量,透過歸一化後的 f 和卷積結果結合
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
# 返回新的張量
return new_f
# 定義一個 IPAdapter 注意力處理器類,繼承自 nn.Module
class IPAdapterAttnProcessor(nn.Module):
r"""
多個 IP-Adapter 的注意力處理器。
引數:
hidden_size (`int`):
注意力層的隱藏尺寸。
cross_attention_dim (`int`):
`encoder_hidden_states` 中的通道數。
num_tokens (`int`, `Tuple[int]` 或 `List[int]`, 預設為 `(4,)`):
影像特徵的上下文長度。
scale (`float` 或 List[`float`], 預設為 1.0):
影像提示的權重縮放。
"""
# 初始化方法,接收多個引數
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
# 呼叫父類的初始化方法
super().__init__()
# 儲存隱藏尺寸
self.hidden_size = hidden_size
# 儲存交叉注意力維度
self.cross_attention_dim = cross_attention_dim
# 確保 num_tokens 為元組或列表
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
# 儲存 num_tokens
self.num_tokens = num_tokens
# 確保 scale 為列表
if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
# 驗證 scale 和 num_tokens 長度相同
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
# 儲存縮放因子
self.scale = scale
# 建立用於鍵的線性變換列表
self.to_k_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
# 建立用於值的線性變換列表
self.to_v_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
# 定義呼叫方法,處理輸入的注意力資訊
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
scale: float = 1.0,
ip_adapter_masks: Optional[torch.Tensor] = None,
class IPAdapterAttnProcessor2_0(torch.nn.Module):
r"""
PyTorch 2.0 的 IP-Adapter 注意力處理器。
# 定義引數說明文件,列出類建構函式的引數及其型別和預設值
Args:
hidden_size (`int`):
注意力層的隱藏層大小
cross_attention_dim (`int`):
編碼器隱藏狀態的通道數
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
影像特徵的上下文長度
scale (`float` or `List[float]`, defaults to 1.0):
影像提示的權重比例
"""
# 初始化類的建構函式,設定類屬性
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
# 呼叫父類的建構函式
super().__init__()
# 檢查 PyTorch 是否支援縮放點積注意力
if not hasattr(F, "scaled_dot_product_attention"):
# 如果不支援,丟擲匯入錯誤
raise ImportError(
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
# 設定隱藏層大小屬性
self.hidden_size = hidden_size
# 設定交叉注意力維度屬性
self.cross_attention_dim = cross_attention_dim
# 如果 num_tokens 不是元組或列表,則將其轉換為列表
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
# 設定 num_tokens 屬性
self.num_tokens = num_tokens
# 如果 scale 不是列表,則建立與 num_tokens 長度相同的列表
if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
# 檢查 scale 的長度是否與 num_tokens 相同
if len(scale) != len(num_tokens):
# 如果不同,丟擲值錯誤
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
# 設定 scale 屬性
self.scale = scale
# 建立一個包含多個線性層的模組列表,用於輸入到 K 的對映
self.to_k_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
# 建立一個包含多個線性層的模組列表,用於輸入到 V 的對映
self.to_v_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
# 定義類的呼叫方法,用於執行注意力計算
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
scale: float = 1.0,
ip_adapter_masks: Optional[torch.Tensor] = None,
# 定義用於實現 PAG 的處理器類,使用縮放點積注意力(預設啟用 PyTorch 2.0)
class PAGIdentitySelfAttnProcessor2_0:
r"""
處理器用於實現 PAG,使用縮放點積注意力(預設在 PyTorch 2.0 中啟用)。
PAG 參考: https://arxiv.org/abs/2403.17377
"""
# 初始化函式
def __init__(self):
# 檢查 F 中是否有縮放點積注意力功能
if not hasattr(F, "scaled_dot_product_attention"):
# 如果沒有,則丟擲匯入錯誤,提示需要升級 PyTorch
raise ImportError(
"PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
# 可呼叫方法,定義了注意力處理的輸入引數
def __call__(
self,
attn: Attention, # 輸入的注意力物件
hidden_states: torch.FloatTensor, # 當前的隱藏狀態
encoder_hidden_states: Optional[torch.FloatTensor] = None, # 編碼器的隱藏狀態(可選)
attention_mask: Optional[torch.FloatTensor] = None, # 注意力掩碼(可選)
temb: Optional[torch.FloatTensor] = None, # 額外的時間嵌入(可選)
class PAGCFGIdentitySelfAttnProcessor2_0:
r"""
處理器用於實現 PAG,使用縮放點積注意力(預設啟用 PyTorch 2.0)。
PAG 參考: https://arxiv.org/abs/2403.17377
"""
# 初始化函式
def __init__(self):
# 檢查 F 中是否有縮放點積注意力功能
if not hasattr(F, "scaled_dot_product_attention"):
# 如果沒有,則丟擲匯入錯誤,提示需要升級 PyTorch
raise ImportError(
"PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
# 可呼叫方法,定義了注意力處理的輸入引數
def __call__(
self,
attn: Attention, # 輸入的注意力物件
hidden_states: torch.FloatTensor, # 當前的隱藏狀態
encoder_hidden_states: Optional[torch.FloatTensor] = None, # 編碼器的隱藏狀態(可選)
attention_mask: Optional[torch.FloatTensor] = None, # 注意力掩碼(可選)
temb: Optional[torch.FloatTensor] = None, # 額外的時間嵌入(可選)
class LoRAAttnProcessor:
# 初始化函式
def __init__(self):
# 該類的建構函式,目前沒有初始化操作
pass
class LoRAAttnProcessor2_0:
# 初始化函式
def __init__(self):
# 該類的建構函式,目前沒有初始化操作
pass
class LoRAXFormersAttnProcessor:
# 初始化函式
def __init__(self):
# 該類的建構函式,目前沒有初始化操作
pass
class LoRAAttnAddedKVProcessor:
# 初始化函式
def __init__(self):
# 該類的建構函式,目前沒有初始化操作
pass
# 定義一個包含新增鍵值注意力處理器的元組
ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor, # 新增鍵值注意力處理器
SlicedAttnAddedKVProcessor, # 切片新增鍵值注意力處理器
AttnAddedKVProcessor2_0, # 新增鍵值注意力處理器版本2.0
XFormersAttnAddedKVProcessor, # XFormers 新增鍵值注意力處理器
)
# 定義一個包含交叉注意力處理器的元組
CROSS_ATTENTION_PROCESSORS = (
AttnProcessor, # 注意力處理器
AttnProcessor2_0, # 注意力處理器版本2.0
XFormersAttnProcessor, # XFormers 注意力處理器
SlicedAttnProcessor, # 切片注意力處理器
IPAdapterAttnProcessor, # IPAdapter 注意力處理器
IPAdapterAttnProcessor2_0, # IPAdapter 注意力處理器版本2.0
)
# 定義一個包含所有注意力處理器的聯合型別
AttentionProcessor = Union[
AttnProcessor, # 注意力處理器
AttnProcessor2_0, # 注意力處理器版本2.0
FusedAttnProcessor2_0, # 融合注意力處理器版本2.0
XFormersAttnProcessor, # XFormers 注意力處理器
SlicedAttnProcessor, # 切片注意力處理器
AttnAddedKVProcessor, # 新增鍵值注意力處理器
SlicedAttnAddedKVProcessor, # 切片新增鍵值注意力處理器
AttnAddedKVProcessor2_0, # 新增鍵值注意力處理器版本2.0
XFormersAttnAddedKVProcessor, # XFormers 新增鍵值注意力處理器
CustomDiffusionAttnProcessor, # 自定義擴散注意力處理器
CustomDiffusionXFormersAttnProcessor, # 自定義擴散 XFormers 注意力處理器
CustomDiffusionAttnProcessor2_0, # 自定義擴散注意力處理器版本2.0
PAGCFGIdentitySelfAttnProcessor2_0, # PAGCFG 身份自注意力處理器版本2.0
PAGIdentitySelfAttnProcessor2_0, # PAG 身份自注意力處理器版本2.0
PAGCFGHunyuanAttnProcessor2_0, # PAGCGHunyuan 注意力處理器版本2.0
PAGHunyuanAttnProcessor2_0, # PAG Hunyuan 注意力處理器版本2.0
]