Building your own dataset and data-loading code in mmsegmentation, and running an existing demo

Posted by 宇宙•太陽 on 2022-05-05

Training on your own dataset in mmsegmentation

First, create a Python file under mmseg/datasets; mine is named my_thermal_dataset.py.

Fill it with the following content.

Note the suffix settings here: if your label files use a different suffix from the training images, remember to specify it. My label files carry an extra _label suffix, so don't forget it; with the settings below, an image such as img_dir/train/0001.JPEG (hypothetical name) pairs with the label ann_dir/train/0001_label.png.


import os.path as osp

from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class my_thermal_dataset(CustomDataset):
    """Binary thermal dataset: background vs. person."""
    CLASSES = ('background', 'person')
    PALETTE = [[0, 0, 0], [128, 0, 0]]

    def __init__(self, **kwargs):
        super(my_thermal_dataset, self).__init__(
            img_suffix='.JPEG',           # suffix of the training images
            seg_map_suffix='_label.png',  # label files carry the extra "_label" suffix
            reduce_zero_label=False,      # 0 is a real class (background), do not shift labels
            **kwargs)
        assert osp.exists(self.img_dir)
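One thing the class above implies but does not show: the label files must be single-channel index maps, not RGB colour masks. Each pixel stores the class index directly (0 = background, 1 = person); the PALETTE is only used for visualisation. A minimal sketch of writing such a label (the array contents and file name are hypothetical):

import os
import numpy as np
from PIL import Image

# Single-channel uint8 image, pixel value = class index (0 = background, 1 = person).
label = np.zeros((480, 640), dtype=np.uint8)   # start with all background
label[100:200, 150:300] = 1                    # mark a hypothetical person region

# The file name must end with the seg_map_suffix configured above ('_label.png').
os.makedirs('ann_dir/train', exist_ok=True)
Image.fromarray(label).save('ann_dir/train/0001_label.png')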

Then add your own dataset to the __init__.py file in mmseg/datasets.

# Copyright (c) OpenMMLab. All rights reserved.
from .ade import ADE20KDataset
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .custom import CustomDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
                               RepeatDataset)
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .voc import PascalVOCDataset

from .my_thermal_dataset import my_thermal_dataset

__all__ = [
    'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
    'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
    'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
    'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
    'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
    'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset',
    'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', 'my_thermal_dataset'
]
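A quick way to confirm the registration worked (a small sketch run from the repo root; the lookup only succeeds once __init__.py imports your module):

from mmseg.datasets import DATASETS

# Returns the registered class if the import in __init__.py is in place, otherwise None.
print(DATASETS.get('my_thermal_dataset'))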

Then you need to modify the class_names.py file under /mmsegmentation-master/mmseg/core/evaluation.

Add these two functions:

def my_thermal_dataset_classes():
    return ['background', 'person']

def my_thermal_dataset_palette():
    return [[0,0,0], [128,0,0]]
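Depending on your mmsegmentation version, class_names.py also keeps a dataset_aliases dict that get_classes()/get_palette() use to resolve a dataset name to the functions above. If your copy has it, add an entry there as well (the alias list here is just an assumption, any unique strings work):

# In mmseg/core/evaluation/class_names.py, next to the existing aliases:
dataset_aliases['my_thermal_dataset'] = ['my_thermal_dataset', 'thermal']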

That wraps up the dataset code; next, add a corresponding config file under the configs/_base_/datasets folder.

Add my_thermal_datasets.py:


The content below is based on the ade20k config file, which I modified.

You need to change the corresponding paths: with the settings below, images are expected under thermal_data/img_dir/train and thermal_data/img_dir/val, and labels under thermal_data/ann_dir/train and thermal_data/ann_dir/val.

# dataset settings
dataset_type = 'my_thermal_dataset'
data_root = './thermal_data'
# mmseg loads images as 0-255 arrays, so use ImageNet statistics in that range
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=False),  # background (0) is a real class here
    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='img_dir/train',
        ann_dir='ann_dir/train',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='img_dir/val',
        ann_dir='ann_dir/val',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='img_dir/val',
        ann_dir='ann_dir/val',
        pipeline=test_pipeline))
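The training command below points at a model config that does not ship with the repo, so you need to create that too. A sketch of configs/mae/upernet_mae-base_fp16_512x512_160k_thermal.py, adapted from the MAE ADE20K config under configs/mae/ (the exact _base_ model/schedule file names depend on your mmsegmentation version, so copy them from that config and only swap in the dataset file and the class count):

_base_ = [
    '../_base_/models/upernet_mae.py',            # copy this name from the ADE20K MAE config; may differ by version
    '../_base_/datasets/my_thermal_datasets.py',  # the dataset config added above
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_160k.py',
]
# Two classes: background and person.
model = dict(
    decode_head=dict(num_classes=2),
    auxiliary_head=dict(num_classes=2))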

Start training the model:

GPUS=8 sh tools/slurm_train.sh dev mae_thermal configs/mae/upernet_mae-base_fp16_512x512_160k_thermal.py --work-dir work_dirs/thermal_res/
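This command assumes a Slurm cluster ("dev" is the partition and "mae_thermal" the job name). Without Slurm, the standard mmsegmentation launcher works with the same config, e.g. single-GPU training:

python tools/train.py configs/mae/upernet_mae-base_fp16_512x512_160k_thermal.py --work-dir work_dirs/thermal_res/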
