gigapath部署以及微調全過程

liujunxi發表於2024-11-09

0.0

什麼是gigapath

gigapath是一個由微軟開發的數字病理學全玻片基礎模型,用於從高解析度影像(如病理切片影像)中提取和處理資訊的深度學習模型架構。

圖中分為abc三個部分

a

首先輸入一張高畫質的病理影像,我們將它拆分成256*256的影像切片,從而可以逐塊處理。

每個影像塊會被輸入到一個基於Vision Transformer(ViT)的編碼器中,提取影像塊級別的特徵,得到影像塊的嵌入表示。

影像塊級別的CLS token(分類標記)用於表徵整個影像塊的全域性資訊

Slide-Level Encoder (LongNet):隨後將這些影像塊的嵌入表示傳遞到一個基於 LongNet(長序列處理網路)的 Slide-Level Encoder 中,該編碼器採用 Dilated Attention(膨脹注意力機制)來捕捉不同影像塊之間的長距離依賴關係,生成整個病理切片影像的嵌入表示。

b

Vision Transformer (Teacher Model):在教師模型中,影像塊被分為多個“全域性切片”(Global crops),用於生成準確的嵌入表示。

Vision Transformer (Student Model):學生模型接收區域性切片(Local crops)和帶掩碼的全域性切片。學生模型和教師模型透過對比學習(Contrastive Loss)進行對齊,確保學生模型在帶掩碼的輸入下也能生成類似的特徵。

c

LongNet-based Decoder:該部分顯示了輸入嵌入和目標嵌入的匹配過程,透過重構損失來指導解碼器學習生成目標嵌入。

Reconstruction Loss(重構損失):透過計算生成嵌入與目標嵌入之間的重構損失,最佳化模型的生成效果。

最後得到嵌入表示的向量,成功將高維、複雜的資料(例如影像、文字、音訊等)轉換為低維向量,可以作為其他任務(如分類、聚類、檢索等)的輸入特徵。這種特徵是由模型學到的,因此具有通用性,能夠適應多種任務需求。

0.1

模型部署

首先從/home/data/hf/Gigapath,將這裡面的檔案cp到本地

開啟終端

命令

cp -a /home/data/hf/Gigapath .

然後根據Prov-GigaPath/Prov-GigaPath ·擁抱臉裡面的教程

利用裡面的environment.yaml建立我們的gigapath虛擬環境

 cd Gigapath/prov-gigapath_github

conda env create -f environment.yaml
conda activate gigapath
pip install -e .

接著我們把cuda設定進環境變數裡面

export CUDA_HOME=/usr/local/cuda/
export PATH=/usr/local/cuda/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
我們新建一個python檔案,命名為test_gigapath.py, 放在Gigapath/prov-gigapath_github目錄下
test_gigapath.py程式碼如下
import os
import torch
import timm
import numpy as np
import gigapath.slide_encoder as slide_encoder
from gigapath.pipeline import run_inference_with_tile_encoder, run_inference_with_slide_encoder

print("................")

# 解壓後輸入圖片所在目錄
slide_dir = "/public/liujx/Gigapath/PANDA/PANDA_sample_tiles/054b6888604d963455bfff551518ece5"
image_paths = [os.path.join(slide_dir, img) for img in os.listdir(slide_dir) if img.endswith('.png')]
print(f"Found {len(image_paths)} image tiles")

# 載入 tile_encoder 模型
model_arch = "vit_giant_patch14_dinov2"
model_path = "/home/data/hf/Gigapath/prov-gigapath_hf/pytorch_model.bin"
tile_encoder = timm.create_model(model_arch,
                                 pretrained=True,
                                 img_size=224,
                                 in_chans=3,
                                 pretrained_cfg_overlay=dict(file=model_path))
  
# 列印引數數量
print("tile_encoder param #", sum(p.numel() for p in tile_encoder.parameters()))

# 載入 slide_encoder 模型
slide_encoder_model = slide_encoder.create_model(
    pretrained="/public/liujx/Gigapath/prov-gigapath_hf/slide_encoder.pth",
    model_arch="gigapath_slide_enc12l768d",
    in_chans=1536,
)
print("slide_encoder param #", sum(p.numel() for p in slide_encoder_model.parameters()))

# 執行 tile_encoder 推理
tile_encoder_outputs = run_inference_with_tile_encoder(image_paths, tile_encoder)

# 列印 tile_encoder 輸出形狀
for k in tile_encoder_outputs.keys():
    print(f"tile_encoder_outs[{k}].shape: {tile_encoder_outputs[k].shape}")

# 執行 slide_encoder 推理
slide_embeds = run_inference_with_slide_encoder(
    slide_encoder_model=slide_encoder_model,
    **tile_encoder_outputs
)
print(slide_embeds.keys())

# 儲存 slide_embeds 輸出到 PyTorch .pt 檔案
save_dir = "/public/liujx/Gigapath/prov-gigapath_github"
os.makedirs(save_dir, exist_ok=True)

# 儲存 slide_embeds 為 .pt 檔案格式
slide_embeds_path = os.path.join(save_dir, "slide_embeds.pt")
torch.save(slide_embeds, slide_embeds_path)
print(f"slide_embeds saved to {slide_embeds_path}")

然後我們開啟存放sample_tiles.zip的資料夾,在GIgapath/PANDA目錄下,將這個zip包解壓
命令如下
unzip sample_tiles.zip
複製這個路徑,修改為以上程式碼中解壓後輸入圖片所在目錄
然後將29行的pretrained改成
/public/(你的使用者名稱)/Gigapath/prov-gigapath_hf/slide_encoder.pth
50行save_dir改成
/public/(你的使用者名稱)/Gigapath/prov-gigapath_github
然後新建一個檔案,命名為run_gigapath.pbs
程式碼如下
#!/bin/bash
#PBS -N test
#PBS -o test_$PBS_JOBID.log
#PBS -e test_$PBS_JOBID.err
#PBS -l nodes=1:ppn=12
#PBS -q gpu
cd $PBS_O_WORKDIR


module add gcc/11.2.0
source /home/data/software/python/3.12.7/gigapath/bin/activate

echo "test"
python3 test_gigapath.py

接著由於原有的專案程式碼有bug,我們需要進行修改

vim /public/liujx/Gigapath/prov-gigapath_github/gigapath/torchscale/model/../../torchscale/architecture/config.py

輸入i修改,在這個檔案的第一行加入

import numpy as np

然後按Esc,再輸入 :wq 儲存

儲存後回到剛剛你專案的地址,cd到你專案下

輸入

qsub run_gigapath.pbs
就會輸出一個作業號,然後等待大約3分鐘訓練,就會輸出一個名為
slide_embeds.pt
的檔案以及對應的err 和 log檔案
如果err檔案輸出如下
/home/data/software/python/3.12.7/gigapath/lib/python3.12/site-packages/timm/models/registry.py:4: FutureWarning: Importing from timm.models.registry is deprecated, please import via timm.models
  warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
/home/data/software/python/3.12.7/gigapath/lib/python3.12/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
  warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
/public/liujx/Gigapath/prov-gigapath_github/gigapath/slide_encoder.py:236: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(local_path, map_location="cpu")["model"]
/public/liujx/Gigapath/prov-gigapath_github/gigapath/pipeline.py:102: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with torch.cuda.amp.autocast(dtype=torch.float16):

Running inference with tile encoder:   0%|          | 0/7 [00:00<?, ?it/s]
Running inference with tile encoder:  14%|█▍        | 1/7 [00:17<01:47, 17.90s/it]
Running inference with tile encoder:  29%|██▊       | 2/7 [00:35<01:29, 17.83s/it]
Running inference with tile encoder:  43%|████▎     | 3/7 [00:56<01:16, 19.09s/it]
Running inference with tile encoder:  57%|█████▋    | 4/7 [01:16<00:58, 19.58s/it]
Running inference with tile encoder:  71%|███████▏  | 5/7 [01:33<00:37, 18.56s/it]
Running inference with tile encoder:  86%|████████▌ | 6/7 [01:47<00:17, 17.14s/it]
Running inference with tile encoder: 100%|██████████| 7/7 [01:54<00:00, 13.70s/it]
Running inference with tile encoder: 100%|██████████| 7/7 [01:54<00:00, 16.34s/it]
/public/liujx/Gigapath/prov-gigapath_github/gigapath/pipeline.py:130: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with torch.cuda.amp.autocast(dtype=torch.float16):

並且log檔案如下

test
................
Found 810 image tiles
tile_encoder param # 1134769664
dilated_ratio:  [1, 2, 4, 8, 16]
segment_length:  [np.int64(1024), np.int64(5792), np.int64(32768), np.int64(185363), np.int64(1048576)]
Number of trainable LongNet parameters:  85148160
Global Pooling: False
[92m Successfully Loaded Pretrained GigaPath model from /public/liujx/Gigapath/prov-gigapath_hf/slide_encoder.pth [00m
slide_encoder param # 86330880
tile_encoder_outs[tile_embeds].shape: torch.Size([810, 1536])
tile_encoder_outs[coords].shape: torch.Size([810, 2])
dict_keys(['layer_0_embed', 'layer_1_embed', 'layer_2_embed', 'layer_3_embed', 'layer_4_embed', 'layer_5_embed', 'layer_6_embed', 'layer_7_embed', 'layer_8_embed', 'layer_9_embed', 'layer_10_embed', 'layer_11_embed', 'layer_12_embed', 'last_layer_embed'])
slide_embeds saved to /public/liujx/Gigapath/prov-gigapath_github/slide_embeds.pt

說明你已經成功載入號預訓練的模型,並且成功執行起來了

這時你的檔案目錄下會多出一個slide_embeds.pt的檔案,這是一個pytorch模式下的張量檔案,我們可以透過一下程式碼開啟並檢視這個檔案

import torch
import pandas as pd

# 載入 .pt 檔案
slide_embeds_path = "/public/liujx/Gigapath/prov-gigapath_github/slide_embeds.pt"
slide_embeds = torch.load(slide_embeds_path)

# 建立一個空的 DataFrame 來儲存結果
data_dict = {}

# 將每個層的嵌入轉換為展平的一維陣列,並新增到字典中
for key, tensor in slide_embeds.items():
    # 如果是二維張量,展平為一維
    flattened_tensor = tensor.cpu().numpy().flatten()
    data_dict[key] = flattened_tensor

# 將字典轉換為 DataFrame
df = pd.DataFrame(data_dict)

# 儲存 DataFrame 為 CSV 檔案
csv_file_path = "/public/liujx/Gigapath/prov-gigapath_github/slide_embeds.csv"
df.to_csv(csv_file_path, index=False)

# 列印 CSV 檔案路徑
print(f"CSV file saved at {csv_file_path}")

將這段程式碼儲存在.pt檔案所在的目錄下,執行,可以將這個.pt檔案轉換為更為直觀,可以透過本地excel直接開啟的csv檔案。

接著我們進行微調,

相關文章