GET程式碼學習

csjywu1發表於2024-10-08

這段程式碼展示瞭如何將資料從 DataFrame 形式轉化為 “blocks” 的過程。具體來說,它透過處理 DataFrame 的每一行(每一行代表一個原子),將它們按殘基(residue)分組,每個組被稱為一個 "block"。以下是對程式碼的詳細分析,解釋如何將資料轉化為 "block"。

主要步驟:

  1. 迴圈遍歷每個原子(DataFrame 的每一行)

    • df.itertuples() 用於按行遍歷 DataFrame,每一行都代表一個原子。
    • 從每一行中提取原子所屬的殘基 ID、原子名稱、元素型別及其空間座標等資訊。
  2. 按殘基(residue)分組

    • res_id 是一個字串,它將殘基 ID 和插入程式碼(insertion code)拼接起來形成獨特的識別符號。如果沒有插入程式碼,只使用殘基 ID。
    • 如果 res_id 發生變化(即遇到新的殘基),意味著當前殘基的原子分組已經結束,形成了一個完整的 "block"。
    • 新的 "block" 使用 Block 類來表示,包含殘基的符號(res_symbol)和對應的原子列表(units)。
  3. 跳過氫原子

    • 如果元素為氫(H),則跳過這一行(continue),因為氫原子在許多情況下被忽略。
  4. 建立 Atom 例項

    • 每個原子被建立為一個 Atom 物件,包含原子的名稱、三維座標(x, y, z),以及元素型別。
    • 這些 Atom 物件會被新增到當前殘基的 "block" 中。
  5. 生成 Block 物件

    • 當遍歷到新的殘基時,前一個殘基的 "block" 完成,建立並儲存到 blocks 列表中。
    • 遍歷結束後,最後一個殘基的 "block" 被加入到 blocks 中。

程式碼分析:

1. df_to_blocks 函式:

這是將 DataFrame 中的原子資料分塊(block)的核心函式。

def df_to_blocks(df, key_residue='residue', key_insertion_code='insertion_code', key_resname='resname',
                 key_atom_name='atom_name', key_element='element', key_x='x', key_y='y', key_z='z') -> List[Block]:
    last_res_id, last_res_symbol = None, None  # 儲存上一個殘基的ID和符號
    blocks, units = [], []  # blocks存放每個block,units存放每個block中的原子
    for row in df.itertuples():  # 遍歷每一行(原子)
        residue = getattr(row, key_residue)  # 提取殘基ID
        if key_insertion_code is None:
            res_id = str(residue)
        else:
            insert_code = getattr(row, key_insertion_code)  # 提取插入程式碼
            res_id = f'{residue}{insert_code}'.rstrip()  # 拼接殘基ID和插入程式碼
        
        # 如果殘基ID發生變化,意味著當前block結束
        if res_id != last_res_id:
            block = Block(last_res_symbol, units)  # 建立一個新的Block物件
            blocks.append(block)  # 將Block新增到blocks列表
            units = []  # 清空units以開始新的Block
            last_res_id = res_id  # 更新為新的殘基ID
            last_res_symbol = VOCAB.abrv_to_symbol(getattr(row, key_resname))  # 透過殘基名稱獲取符號
        
        # 處理當前原子
        atom = getattr(row, key_atom_name)  # 獲取原子名稱
        element = getattr(row, key_element)  # 獲取元素型別
        if element == 'H':  # 跳過氫原子
            continue
        units.append(Atom(atom, [getattr(row, axis) for axis in [key_x, key_y, key_z]], element))  # 新增原子
    blocks = blocks[1:]  # 忽略第一個空block
    blocks.append(Block(last_res_symbol, units))  # 新增最後一個Block
    return blocks

詳細步驟解釋:

  1. 初始化

    • last_res_idlast_res_symbol 用於追蹤前一個殘基的資訊,blocks 用於存放最終的 Block 列表,units 用於暫存當前 Block 的原子。
  2. 按行遍歷 DataFrame

    • df.itertuples() 會將 DataFrame 中每一行作為一個元組返回,透過 getattr() 可以訪問每個屬性(例如 residueatom_name 等)。
  3. 檢查殘基 ID 變化

    • res_id 發生變化時,說明已經處理完當前殘基,需要建立一個新的 Block。
    • 使用 Block(last_res_symbol, units) 建立一個 Block 物件,units 則是該 Block 中的所有原子。
  4. 處理原子資訊

    • 每一行代表一個原子,建立 Atom 物件時,會提取該原子的名稱、三維座標和元素資訊。
  5. 返回結果

    • 最後返回 blocks 列表,包含所有的 Block 物件。

總結:

  • 程式碼中的 df_to_blocks 函式從 DataFrame 中提取原子,並將它們按殘基分組,建立了一個個 Block。每個 Block 包含一個殘基的資訊(符號和原子列表)。
  • 如果 DataFrame 中的資料已經按照這種結構組織好,那麼這段程式碼可以將資料高效地轉換為模型可以處理的塊狀資料結構。

這個分組和轉換過程的核心是識別殘基 ID 的變化,然後將每個殘基的原子聚合成一個 Block。

按殘基分組是什麼意思。也就是說一個block裡面會有多個相同的殘基麼?

“按殘基分組”是指將同一個殘基(residue)中所有的原子歸類到一個 block 中。在蛋白質結構中,殘基通常指氨基酸,而每個氨基酸殘基包含若干原子,例如碳、氫、氧、氮等。

一個 block 代表的是一個殘基及其包含的所有原子。因此,一個 block 裡面不會有多個相同的殘基,而是每個殘基單獨作為一個 block,包含該殘基的所有原子。每個殘基只對應一個 block,多個相同的殘基(例如多個“甘氨酸”殘基)會分別出現在不同的 block 中。

詳細解釋:

  • 在蛋白質結構中,殘基通常指的是氨基酸。每個氨基酸殘基包含多個原子(例如,氮原子、碳原子、氧原子等),這些原子組成了氨基酸的骨架和側鏈。
  • 當你“按殘基分組”時,程式碼會遍歷每個原子,並根據原子所屬的殘基(透過殘基 ID res_id)將同一殘基的所有原子放在一起。
  • 一個 block 就是代表了一個殘基的原子叢集。舉例來說,一個“甘氨酸”殘基,它包含 C、N、O 等原子,這些原子會被歸為同一個 block。

舉例:

假設我們有一個蛋白質片段,包含兩個氨基酸殘基:甘氨酸(Gly)和丙氨酸(Ala)。每個氨基酸殘基包含若干原子:

  1. 甘氨酸殘基(Gly)

    • 原子:N、C、O、H
    • 這些原子被分組到同一個 block 中,表示為 Block 1。
  2. 丙氨酸殘基(Ala)

    • 原子:N、C、O、H、Cβ(側鏈上的原子)
    • 這些原子被分組到另一個 block 中,表示為 Block 2。

程式碼的具體過程:

  • 當程式碼遍歷 DataFrame 時,它會檢查每一行原子所屬的殘基 ID(res_id)。
  • 當它遇到一個新的殘基 ID 時,意味著前一個殘基的所有原子已經處理完畢,生成一個 block。
  • 一個 block 包含當前殘基的所有原子。然後,開始處理下一個殘基,建立新的 block。

總結:

  • 一個 block 只包含一個殘基的所有原子
  • 不同的殘基會分別生成不同的 block
  • 多個相同的殘基(例如多個“甘氨酸”殘基)會出現在不同的 block 中,而不是合併在一起。

希望這能幫助你理解按殘基分組和 block 之間的關係!如果還有其他問題,請告訴我。

if isinstance(items, list):
            data = process_map(self._preprocess, items, max_workers=n_cpu, chunksize=10)  所以一個items包含了兩個蛋白質,將蛋白質1的所有殘基轉換為block1,而將蛋白質2的所有殘基轉換為了block2,是這樣麼?

是的,理解是正確的。在這段程式碼中,items 代表的是輸入資料,其中包含了兩個蛋白質的原子資訊。當 items 是一個列表時,process_map(self._preprocess, items) 會將列表中的每一項傳遞給 self._preprocess 函式進行處理。

具體流程:

  1. items 包含了兩個蛋白質的結構資料:

    • items['atoms_interface1']:包含了蛋白質 1 的所有原子資訊。
    • items['atoms_interface2']:包含了蛋白質 2 的所有原子資訊。
  2. _preprocess 函式中:

    • blocks1 = df_to_blocks(item['atoms_interface1']) 會將 蛋白質 1 的所有殘基轉換為 blocks1,即一個包含每個殘基的“block”的列表。
    • blocks2 = df_to_blocks(item['atoms_interface2']) 會將 蛋白質 2 的所有殘基轉換為 blocks2,同樣也是一個包含每個殘基的“block”的列表。
  3. 最後,blocks1blocks2 分別表示兩個蛋白質的殘基結構,這兩個列表會被傳遞到 blocks_to_data(blocks1, blocks2) 中進行進一步處理,可能是用於對接或相互作用分析。

結論:

是的,_preprocess 函式的作用就是將蛋白質 1 的殘基轉換為 blocks1,將蛋白質 2 的殘基轉換為 blocks2,每個 block 都代表一個殘基及其包含的所有原子資訊。

這裡的blocks1和blocks2是指什麼

在這段程式碼中,blocks1blocks2 是從 item['atoms_interface1']item['atoms_interface2'] 中提取出來的資料。它們代表了兩個不同的分子介面(interface1interface2)中的殘基和原子的分組資訊,經過 df_to_blocks() 函式處理後被表示為“block”的形式。

blocks1blocks2 具體指什麼?

  • blocks1:是從 item['atoms_interface1'] 中提取出來的“block”,表示分子介面 1(interface1)中按殘基分組的原子列表。這個分子介面包含了某個蛋白質或分子片段的一部分(通常是參與對接或相互作用的表面區域)。

  • blocks2:是從 item['atoms_interface2'] 中提取出來的“block”,表示分子介面 2(interface2)中按殘基分組的原子列表。這個介面通常是與 blocks1 進行相互作用的另一部分(例如另一個蛋白質或分子)。

每個 block 是由一個殘基的所有原子組成的物件,可能包含殘基的符號和它的原子座標、原子型別等資訊。

df_to_blocks() 函式的作用:

  • df_to_blocks() 函式將 DataFrame 中的原子資訊按殘基分組,生成多個 Block 物件,每個 Block 包含一個殘基的所有原子資訊。
  • blocks1blocks2 代表兩個不同分子介面中的所有殘基,每個介面都被分成若干 Block

示例場景:

假設 item['atoms_interface1']item['atoms_interface2'] 分別包含兩個不同蛋白質(或分子)的介面原子資訊,那麼:

  • blocks1 會包含 interface1 中所有殘基及其原子的分組;
  • blocks2 會包含 interface2 中所有殘基及其原子的分組。

這些分子介面資訊可能被用於計算兩個蛋白質之間的對接相互作用,或者分析它們的幾何結構。

blocks_to_data(blocks1, blocks2) 的作用:

  • 這個函式(雖然程式碼中沒有定義)看起來是將 blocks1blocks2 結合在一起,可能用於計算兩個介面的相互作用或距離矩陣,或者其他結構相關的指標。它會將兩個介面的“block”資料轉化為模型需要的輸入形式。

data['label'] = item['affinity']['neglog_aff']

  • 最後一步將 item['affinity']['neglog_aff'] 賦值給 data['label'],表示這個樣本的目標值(label),通常是一個關於親和力的數值(例如結合自由能的負對數),用作監督學習中的標籤。

總結:

  • blocks1blocks2 是分別從 item['atoms_interface1']item['atoms_interface2'] 中提取出來的“block”物件,代表了按殘基分組的原子資訊。
  • 它們代表了兩個分子介面的結構,這些結構可能用於進一步的相互作用分析或模型輸入。
blocks_to_data的函式

這段程式碼實現了 blocks_to_data 函式,它的作用是將多個 Block 物件的列表轉化為統一的資料格式,通常用於神經網路模型的輸入。具體來說,它將 Block 中的各類資料(如原子型別、原子座標、殘基型別等)整合成模型需要的輸入格式。以下是對每個步驟的詳細分析:

函式功能:

  • 輸入:函式接受若干個 blocks_list 引數,每個引數是一個 Block 物件的列表(即多個殘基的集合)。
  • 輸出:函式返回一個字典,字典中包含了處理後的所有 Block 資料,組織成統一格式,方便後續處理。

主要步驟:

  1. 初始化變數

    • B: 儲存每個 block(殘基)的型別,來自 VOCAB.symbol_to_idx
    • A: 儲存原子的型別,來自 VOCAB.get_atom_global_idx()block.to_data()
    • X: 儲存原子的座標資訊。
    • atom_positions: 儲存原子位置的索引,用於進一步計算原子的幾何資訊。
    • block_lengths: 記錄每個 block(殘基)包含的原子數量。
    • segment_ids: 記錄這些 block 屬於哪個分子或片段(即哪一個 blocks_list)。
  2. 處理每個 blocks_list

    • 外層迴圈 for i, blocks in enumerate(blocks_list) 遍歷輸入的多個 blocks_list(每個介面或分子),並逐一處理其中的 blocks(代表每個殘基)。
  3. 處理每個 block

    • 對於每個 blocks_list,首先建立一個 "global node"(全域性節點),即代表整個分子的虛擬中心點,用於儲存整個分子的全域性資訊。
    • 每個 Block 物件呼叫 block.to_data() 方法,該方法返回 b(殘基型別)、a(原子型別)、x(原子座標)、positions(原子位置索引)和 block_len(該殘基包含的原子數)。
  4. 計算全域性節點的中心點

    • cur_X[0] = np.mean(cur_X[1:], axis=0):該操作將當前 Block 的所有原子座標(去掉全域性節點本身)取平均值,並將結果儲存為全域性節點的座標。也就是說,"全域性節點" 的座標是所有原子座標的質心。
  5. 生成段 ID

    • cur_segment_ids = [i for _ in cur_B]:建立一個 segment_ids 列表,用於標記這些 block 來自於哪個 blocks_list(可以理解為它們屬於哪個分子或片段)。
  6. 將當前 block 的資訊存入全域性列表

    • B.extend(cur_B)A.extend(cur_A) 等操作將處理好的 block 資料新增到全域性變數中,最終將所有 block 的資訊合併在一起。
  7. 返回資料

    • 最終返回的資料是一個字典 data,其中包含了處理好的所有 block 和原子資訊。這些資訊包括:
      • X: 原子的座標陣列。
      • B: 殘基型別的索引。
      • A: 原子型別的索引。
      • atom_positions: 原子的幾何位置資訊。
      • block_lengths: 每個殘基包含的原子數量。
      • segment_ids: 每個殘基所屬的分子片段。

詳細欄位解釋:

  1. X:

    • 形狀為 [Natom, 2, 3],表示每個原子的三維座標。X[0] 是全域性節點的座標(殘基的中心),其餘的是殘基中每個原子的座標。
  2. B:

    • B 表示每個殘基的型別。它是一個列表,每個元素對應一個殘基,表示殘基的種類(由 VOCAB.symbol_to_idx() 轉換)。
  3. A:

    • A 是所有原子的型別列表,每個原子透過 VOCAB.get_atom_global_idx()block.to_data() 得到其型別。
  4. atom_positions:

    • atom_positions 記錄每個原子的位置資訊,它可能是用於更復雜的幾何計算或進一步分析。
  5. block_lengths:

    • block_lengths 記錄每個殘基包含的原子數量。每個殘基有不同數量的原子,這個列表提供了每個殘基的長度資訊。
  6. segment_ids:

    • segment_ids 用於標記這些 block 屬於哪個片段(如不同的分子或不同的介面),使得不同 blocks_list 可以區分開。

總結:

blocks_to_data 函式的核心目的是將 Block 物件的列表轉化為一個統一的資料格式,包括殘基型別、原子型別、座標等資訊。這些資料會被用於後續的神經網路模型或其他分析任務。

形狀為 [Natom, 2, 3] 這裡的維度2是什麼意思

X 的形狀 [Natom, 2, 3] 中,維度 2 對應的是每個原子的兩種座標表示。具體解釋如下:

  1. Natom: 表示所有原子的數量。這個維度的大小取決於輸入資料中總共包含的原子數(包括所有 block 中的原子)。

  2. 2: 這個維度表示兩種不同的座標表示方式:

    • 第一種表示X[:, 0, :])通常是指全域性節點的座標,即代表整個殘基或整個分子片段的質心座標(block 中所有原子座標的平均值)。
    • 第二種表示X[:, 1, :])是原子的實際座標,即每個原子相對於殘基的三維空間位置。
  3. 3: 表示每個原子的三維空間座標(即 x, y, z 三個軸的座標)。這個維度的大小始終為 3,表示三維座標系中的位置。

詳細解釋:

  • X[:, 0, :]: 對應的是每個 block 的全域性座標,即質心(或殘基的中心位置)。這種座標表示是透過 np.mean(cur_X[1:], axis=0) 計算得出的,表示該 block 中所有原子的平均座標。

  • X[:, 1, :]: 對應的是原子的實際座標,表示該原子在三維空間中的 x, y, z 位置。

舉個例子:

假設有 2 個原子,X 的形狀為 [2, 2, 3]

X = [
    [[x_global1, y_global1, z_global1], [x_atom1, y_atom1, z_atom1]],
    [[x_global2, y_global2, z_global2], [x_atom2, y_atom2, z_atom2]]
]
  • 對於第一個原子:

    • X[0, 0, :] 是全域性節點的座標(殘基的中心或質心座標)。
    • X[0, 1, :] 是該原子的具體座標。
  • 對於第二個原子:

    • X[1, 0, :] 是另一個 block 的全域性座標。
    • X[1, 1, :] 是第二個原子的具體座標。

總結:

  • 維度 2 中的第一個位置儲存全域性節點的座標(質心),而第二個位置儲存實際的原子座標。透過這種方式,每個原子都有兩種座標表示,既有它本身的座標,也有全域性座標。

index裡面這些變數是什麼

從你提供的截圖來看,self.indexes 是一個包含字典的列表,每個字典對應一個蛋白質-配體對的結合資料。這些變數儲存的是蛋白質和配體的親和力及其相關的物理化學資料。下面是這些變數的具體含義:

每個字典中的鍵值解釋:

  1. affinity:

    • 包含結合親和力的相關資訊,通常用於評估蛋白質與配體之間的結合強度。
    • Kd:解離常數(dissociation constant),表示蛋白質和配體在平衡時的結合強度。Kd 值越小,表示結合強度越高。
    • dG:自由能變化(Gibbs free energy change),通常是由解離常數計算得出,代表了結合反應的能量變化。通常,dG 值越低,結合越穩定。
    • neglog_aff:親和力的負對數值,可能用於處理 Kd 值較小時的數值穩定性。
  2. id:

    • 這是蛋白質-配體對的唯一識別符號(通常是 PDB ID),每個 PDB ID 對應一種特定的蛋白質結構或蛋白質-配體複合體。

具體示例:

  • affinity

    • 'Kd': 5.59e-07:表示蛋白質-配體的解離常數為 5.59 × 10^-7 M,表明蛋白質和配體結合較緊密。
    • 'dG': -8.53:自由能變化為 -8.53 kcal/mol,結合的穩定性較高。
    • 'neglog_aff': 6.25:可能是親和力的負對數值,通常用於評估結合強度。
  • id

    • 例如,'id': '3bpl':這是蛋白質-配體對的唯一識別符號,在 PDB 資料庫中可能查詢到這個複合物的結構。

總結:

self.indexes 是一個包含多個蛋白質-配體對的資料列表,每個字典包含了:

  • 親和力相關的數值(Kd, dG, neglog_aff)。
  • 蛋白質-配體對的唯一識別符號(id)。
    這些資訊可能用於分析蛋白質與配體的相互作用,以及模型的輸入或標籤(label)。

是的,data['label'] = item['affinity']['neglog_aff'] 這行程式碼的作用是將結合強度(以負對數親和力 neglog_aff 的形式)提取出來,並將其作為 data['label'],用於模型的標籤(label)。

具體解釋:

  • item['affinity']['neglog_aff']:這裡的 neglog_aff 是親和力的負對數(通常是 -log(Kd),其中 Kd 是解離常數),用於衡量蛋白質與配體的結合強度。

    • Kd (解離常數):Kd 值越小,表明蛋白質和配體結合越緊密。為了便於數值計算和處理,Kd 常常以負對數形式表示,因此 neglog_aff 越大,表示結合強度越強。
  • data['label']:將提取的 neglog_aff 值賦值給 data['label'],這個 label 將作為模型的目標值,用於監督學習或預測蛋白質-配體之間的結合強度。

為什麼使用 neglog_aff

使用親和力的負對數(neglog_aff)而不是 Kd 值的原因是:

  1. 數值穩定性:Kd 值的範圍可能非常廣,從 10^-9 到 10^-3 甚至更大。直接使用 Kd 值進行計算可能導致數值不穩定。透過對 Kd 取負對數,可以縮小數值範圍,使其更適合用於機器學習模型。
  2. 線性關係:在某些情況下,-log(Kd) 與結合強度的物理化學過程有更好的線性關係,因此更適合用於迴歸模型。

總結:

這行程式碼是提取蛋白質-配體結合強度的 neglog_aff 值,並將其作為標籤(label),用於模型訓練或評估結合強度的任務。

以batch為單位的資料

根據你提供的截圖內容,下面是對各個變數的解釋以及它們可能對應的資料含義:

  1. A:

    • 內容: tensor([2, 9, 8, ..., 8, 10, 8], device='cuda:0')
    • 解釋: 這是一個表示原子索引的 Tensor,通常對應的是某種型別的原子屬性或索引。它可能是每個原子所對應的型別(例如不同元素的類別)。
  2. B:

    • 內容: tensor([3, 2, 7, ..., 12, 14, 5], device='cuda:0')
    • 解釋: 這個變數通常表示塊(block)或殘基(residue)的索引或型別。它可能是表示不同殘基或塊的分類資訊。
  3. Z:

    • 內容: tensor([[[0.2207, 17.2229, -14.8194], [-11.4810, 22.9620, -30.7930]]], device='cuda:0')
    • 解釋: Z 表示原子的三維座標資訊。這個 Tensor 的形狀是 (8548, 1, 3),表示 8548 個原子的三維座標。這通常用於計算分子之間的距離或作為模型輸入。
  4. atom_positions:

    • 內容: tensor([2, 3, 4, ..., 3, 3, 5], device='cuda:0')
    • 解釋: 這是每個原子的位置資訊,可能是原子在結構中的順序或與三維座標有關的索引。結合 Z 使用時,它可以幫助定位每個原子的具體位置。
  5. block_lengths:

    • 內容: tensor([1, 8, 8, ..., 8, 8, 5], device='cuda:0')
    • 解釋: 這是一個記錄每個塊(block)或殘基(residue)包含多少原子的 Tensor。它可以幫助定義每個塊的大小和結構。
  6. label:

    • 內容: tensor([10.3841, 5.7959, 5.0969, ..., 9.7696], device='cuda:0')
    • 解釋: 這是目標變數的標籤資訊,通常用於表示結合親和力(affinity)或其他分子間屬性。這些標籤是模型訓練時的目標值。
  7. lengths:

    • 內容: tensor([92, 111, 82, ..., 47, 175, 39], device='cuda:0')
    • 解釋: 這是每個樣本(分子或塊)的長度資訊,表示每個分子或殘基塊的長度,用於動態批處理或其他模型輸入處理。
  8. return_loss:

    • 內容: False
    • 解釋: 這是一個布林值,表示是否返回損失值(loss)。它通常用於控制在訓練過程中是否需要計算並返回損失值。
  9. return_noise:

    • 內容: False
    • 解釋: 這是一個布林值,表示是否返回噪聲資訊。它可能用於控制在一些模型中是否引入噪聲以進行資料增強或模擬不同情況。
  10. segment_ids:

    • 內容: tensor([0, 0, 0, ..., 1, 1, 1], device='cuda:0')
    • 解釋: 這是一個表示分段資訊的 Tensor,用於標記資料中不同片段的邊界。它可能用於區分來自不同分子或殘基的部分。
  11. self (AffinityPredictor):

    • 解釋: 這是模型的例項物件(AffinityPredictor),它可能是你正在使用的模型類,專門用於預測分子親和力或其他相關任務。這個模型可能包含了一個塊嵌入層(BlockEmbedding),用於將分子塊嵌入為向量表示。

總結:

這些變數是模型輸入資料的不同部分。大致可以歸類為以下幾類:

  • 幾何和結構資訊:如 Z(三維座標)、atom_positionsblock_lengthslengths
  • 標籤和目標值:如 label(親和力標籤)。
  • 模型的控制和處理標誌:如 return_lossreturn_noise
  • 其他輔助資訊:如 segment_idsAB(分別表示原子和塊的索引或型別)。

這些資訊結合起來構成了用於分子建模和預測的輸入資料,以及用於控制訓練過程的標誌。

邊序號的構建

在這段程式碼中,construct_edges 函式的主要作用是根據輸入的塊(block)、批次(batch)、段(segment)等資訊,利用 edge_constructor 來構建 邊集(edges),這些邊表示分子結構或其他複雜結構中各節點(如原子、塊)之間的連線關係。

邊的構建步驟

  1. 不進行切片的邊構建:
    complexity == -1 時,不進行任何切片操作,直接呼叫 edge_constructor 來構建所有邊:

    intra_edges, inter_edges, global_global_edges, global_normal_edges, _ = edge_constructor(B, batch_id, segment_ids, X=X, block_id=block_id)
    return intra_edges, inter_edges, global_global_edges, global_normal_edges
    

    在這裡,edge_constructor 會直接生成以下幾種邊:

    • intra_edges: 在塊(block)內部的邊。
    • inter_edges: 塊與塊之間的邊。
    • global_global_edges: 全域性節點之間的邊。
    • global_normal_edges: 全域性節點與普通節點之間的邊。
  2. 進行切片的邊構建:
    complexity != -1 時,會根據指定的複雜度 complexity 對資料進行切片。每次處理資料的一個小批次(mini-batch),併為每個小批次構建邊,最後再合併這些邊。

    具體過程如下:

    • 初始化部分變數:

      offset, bs_id_start, bs_id_end = 0, 0, 0
      mini_intra_edges, mini_inter_edges, mini_global_global_edges, mini_global_normal_edges = [], [], [], []
      batch_size = batch_id.max() + 1
      unit_batch_id = batch_id[block_id]
      lengths = scatter_sum(torch.ones_like(batch_id), batch_id, dim=0)
      

      這裡 batch_size 是指總的批次數量,unit_batch_id 表示每個原子的批次ID,lengths 計算每個批次中的塊的數量。

    • 逐個批次處理:

      while bs_id_end < batch_size:
          bs_id_start = bs_id_end
          bs_id_end += 1
          while bs_id_end + 1 <= batch_size and \
                (lengths[bs_id_start:bs_id_end + 1] * lengths[bs_id_start:bs_id_end + 1].max()).sum() < complexity:
              bs_id_end += 1
      

      這部分程式碼用於迭代處理每個批次的資料,同時根據 complexity 控制每個小批次的複雜度。lengths[bs_id_start:bs_id_end + 1] 計算當前批次中塊的數量,用於決定是否要繼續向下一個批次擴充套件。

    • 選擇當前批次中的塊和原子:

      block_is_in = (batch_id >= bs_id_start) & (batch_id < bs_id_end)
      unit_is_in = (unit_batch_id >= bs_id_start) & (unit_batch_id < bs_id_end)
      B_mini, batch_id_mini, segment_ids_mini = B[block_is_in], batch_id[block_is_in], segment_ids[block_is_in]
      X_mini, block_id_mini = X[unit_is_in], block_id[unit_is_in]
      

      這裡 block_is_inunit_is_in 用於選擇當前批次中的塊和原子。B_minibatch_id_mini 等變數儲存的是當前小批次中的塊、批次ID、段ID、位置等資訊。

    • 為當前批次構建邊:

      intra_edges, inter_edges, global_global_edges, global_normal_edges, _ = edge_constructor(
          B_mini, batch_id_mini - bs_id_start, segment_ids_mini, X=X_mini, block_id=block_id_mini - offset)
      

      對於當前小批次,透過 edge_constructor 構建邊集,包括塊內邊、塊間邊、全域性邊等。

    • 累加偏移量:

      offset += B_mini.shape[0]
      

      由於在每個小批次之間的索引會重新計算,因此需要透過 offset 來保持原始的全域性索引。

    • 將小批次的邊合併:
      透過檢查 edge_constructor 是否有提供 given_intra_edgesgiven_inter_edges,如果沒有的話,將小批次中的邊新增到合併後的列表中,最終將這些小批次的邊合併成全域性的邊:

      if hasattr(edge_constructor, 'given_intra_edges'):
          intra_edges = edge_constructor.given_intra_edges
      else:
          intra_edges = torch.cat(mini_intra_edges, dim=1)
      if hasattr(edge_constructor, 'given_inter_edges'):
          inter_edges = edge_constructor.given_inter_edges
      else:
          inter_edges = torch.cat(mini_inter_edges, dim=1)
      if global_global_edges is not None:
          global_global_edges = torch.cat(mini_global_global_edges, dim=1)
      if global_normal_edges is not None:
          global_normal_edges = torch.cat(mini_global_normal_edges, dim=1)
      

邊的型別總結

最終返回的邊包括以下幾種型別:

  • intra_edges: 塊內部的邊,即塊中原子之間的邊。
  • inter_edges: 塊之間的邊,即不同塊中的原子之間的邊。
  • global_global_edges: 全域性節點之間的邊,用於捕捉全域性資訊。
  • global_normal_edges: 全域性節點與普通節點之間的邊,用於連線全域性和區域性資訊。

這些邊將被用於進一步的圖神經網路中,用來建模分子結構的幾何和拓撲資訊。

如果你有關於這個過程的具體問題,歡迎進一步提問!

相關文章