深度學習武器庫-timm-非常好用的pytorch CV模型庫 - 常用模型操作

零度的python武器库發表於2024-08-11

簡要介紹

timm庫,全稱pytorch-image-models,是最前沿的PyTorch影像模型、預訓練權重和實用指令碼的開源集合庫,其中的模型可用於訓練、推理和驗證。

github原始碼連結
https://github.com/huggingface/pytorch-image-models

文件教程
文件:https://huggingface.co/docs/hub/timm
上手教程:https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055


優點

1、方便使用。在python環境中安裝timm庫,即可用幾行程式碼建立網路模型,並可選擇匯入在imagenet等資料集上得到的預訓練權重;無需再去扒每個模型的原始碼,這對於跑模型對比實驗是非常方便的,可以節省大量的時間;

2、靈活性高。匯入模型的原始做法是,直接用.pth等權重檔案匯入,但這通常受到儲存模型方法的限制,可能出現權重鍵名稱不匹配、網路中間張量操作丟失(只能匯入模型權重 卻無法匯入網路層之間的中間張量操作)等問題;而timm中的模型是對於整個模型的封裝,在建立模型後,可以對相應網路層進行改動調整,非常靈活;

3、模型收錄全面且前沿。在CV這個模型日新月異的領域,timm及時更新收錄了最新的模型,比如,2024.8.8 新增了ECCV2024上的新模型RDNet(工作連結:https://github.com/naver-ai/rdnet)。
image

缺點

目前使用中感到一點不方便的是,使用函式請求下載這個庫中的權重檔案時,連結地址大部分是huggingface上的,而huggingface得用US的節點才能有較好的網速 。。。 我的PC可以訪問,但伺服器訪問不了。。。 但這個缺點可以透過一些操作進行規避。

常用模型操作

  • 檢視目前收錄模型
    使用程式碼:
    timm.list_models('*')
    執行效果:
    image
    可以看到目前收錄共有946個模型,檢視目前已收錄的模型,從這個列表中確定要匯入的目標模型名稱

    還可以透過正規表示式匹配目標模型名稱,並透過指定pretrained=True篩選有預訓練權重的模型
    如下匹配預訓練resnet:timm.list_models('resnet*', pretrained=True)
    image

  • 建立模型
    使用程式碼(以resnet50為例):
    timm.create_model('resnet50', pretrained=True, in_chans=3, num_classes=6)
    這裡的主要引數有四個:
    第一個是模型名稱model_name
    第二個是是否預訓練pretrained,
    第三個是輸入影像的通道數in_chans
    第四個是分類類別數num_classes,指最後輸出FC層的維度。

    注意:在建立模型這步中包含了從網路下載模型權重的操作,
    此時就會出現我在“缺點”部分講到的問題:因為網路無法連線huggingface網站,而導致權重下載請求失敗的情況 (多在伺服器端出現)。
    image

    下面是解決方法
    先在能夠連線huggingface網站的PC上,手動下載權重配置檔案,使用程式碼如下:

    backbone_name = 'resnet50'
    
    pretrained_cfg = timm.create_model(backbone_name).default_cfg
    print(pretrained_cfg)
    

    執行後輸出配置資訊:

    {'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth', 'hf_hub_id': 'timm/resnet50.a1_in1k', 'architecture': 'resnet50', 'tag': 'a1_in1k', 'custom_load': False, 'input_size': (3, 224, 224), 'test_input_size': (3, 288, 288), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 0.95, 'test_crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'num_classes': 1000, 'pool_size': (7, 7), 'first_conv': 'conv1', 'classifier': 'fc', 'origin_url': 'https://github.com/huggingface/pytorch-image-models', 'paper_ids': 'arXiv:2110.00476'}
    

    其中,url對應了模型的下載請求地址,直接將這個url複製貼上到瀏覽器中,手動下載權重檔案。
    獲得權重檔案後,再使用timm.create_model方法,透過將pretrained_cfg_overlay引數指定為權重檔案,來建立模型,這樣就是本地建立了:

    backbone_name = 'resnet50'
    ckpt_path = './ckpt/resnet50_a1_0-14fe96d1.pth'
    
    model = timm.create_model(backbone_name,
                                       pretrained=True,
                                       pretrained_cfg_overlay=dict(file=ckpt_path))
    
  • 手動調整模型
    在cv中,最常見的操作是將某個網路的主幹層,用於特徵提取
    timm中有專門的方法可以實現這個目的:
    feature_ouput = model.forward_features(image)
    feature_ouput即為網路在最後的head層之前輸出的特徵向量。
    但這個操作還無法完美解決問題,因為通常認為提取特徵就是排除網路的最後一層,但有的網路最後一層中不僅包括全連線(FC)層,在FC層之前還包含池化層。這就需要更靈活的操作,來調整、構建我們想要的網路。
    下面是我的程式碼(以手動新增池化層為例):

    feature_extract_model = nn.Sequential(*list(model.children())[:-1],
                                               nn.AdaptiveAvgPool2d(1))
    

    另外,也可以自己定製最後的head層(但個人感覺這個用途不多),例如:

    model.fc = nn.Sequential(
        nn.BatchNorm1d(num_in_features),
        nn.Linear(in_features=num_in_features, out_features=512, bias=False),
        nn.ReLU(),
        nn.BatchNorm1d(512),
        nn.Dropout(0.4),
        nn.Linear(in_features=512, out_features=10, bias=False))
    

本期總結

timm是一個非常全面且便捷的CV影像模型庫,能夠大大提升我們跑實驗的效率。我們同樣也能運用其中的部分模組類,用到自己的編碼中,也可以在它的原始碼中學習模型的程式碼寫法。本期筆記只是介紹了本人近期跑對比實驗,使用後感覺最常用的一些方法和操作,timm還有很多功能和用法需要去探索,比如還有資料增強、資料集和最佳化器等等功能。

相關文章