Running training from an .ipynb notebook #5

Open
jsh0551 opened this issue Apr 28, 2022 · 0 comments

Contributor

jsh0551 commented Apr 28, 2022

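Put each block below into its own notebook cell and run the cells in order.
I plan to add code for checking the augmentations later.
If you run into errors or want any additional features, let me know!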

```python
import os

import mmcv
from mmcv import Config
from mmseg.apis import set_random_seed, train_segmentor
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
```

Set the data and model paths

```python
data_root = '/opt/ml/input/mmseg'
model_path = '/opt/ml/mmsegmentation/_MyModel/_models_'
model_name = 'deeplabv3_r50-d8_512x1024_40k_cityscapes.py'
cfg = Config.fromfile(os.path.join(model_path, model_name))
```
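`Config.fromfile` also pulls in the `_base_` files the model config inherits from. The sketch below is a simplified, stdlib-only illustration of that merge rule, not mmcv's own code: nested dicts are merged key by key, and a child dict containing `_delete_=True` replaces the base dict instead. `merge_config` is a hypothetical name. The same key-by-key behaviour is worth keeping in mind when assigning into an existing `cfg` dict: setting `cfg.runner.max_epochs` on a runner that already has `max_iters` keeps both keys.

```python
from copy import deepcopy


def merge_config(base: dict, child: dict) -> dict:
    """Merge child over base (simplified sketch of _base_ inheritance)."""
    merged = deepcopy(base)  # never mutate the base config
    for key, value in child.items():
        if isinstance(value, dict):
            value = dict(value)  # shallow copy so _delete_ can be popped safely
            replace = value.pop('_delete_', False)
            old = merged.get(key)
            # _delete_=True (or no dict in the base) starts from an empty dict
            start = {} if replace or not isinstance(old, dict) else old
            merged[key] = merge_config(start, value)  # recurse into nested dicts
        else:
            merged[key] = deepcopy(value)  # scalars and lists simply overwrite
    return merged


if __name__ == '__main__':
    base = {'runner': {'type': 'IterBasedRunner', 'max_iters': 40000}}
    print(merge_config(base, {'runner': {'max_epochs': 80}}))
    print(merge_config(base, {'runner': {'_delete_': True,
                                         'type': 'EpochBasedRunner',
                                         'max_epochs': 80}}))
```

The first call keeps `max_iters` next to the new `max_epochs`; only the `_delete_` form yields a clean epoch-based runner dict.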

Configure the experiment

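```python
# Plain BN instead of SyncBN: SyncBN needs a distributed process group,
# and this notebook trains on a single GPU with distributed=False.
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg

# number of segmentation classes in our dataset
cfg.model.decode_head.num_classes = 11
cfg.model.auxiliary_head.num_classes = 11

# data root
cfg.data_root = data_root

# batch size and dataloader workers (both per GPU)
cfg.data.samples_per_gpu = 16
cfg.data.workers_per_gpu = 8

cfg.data.train.data_root = cfg.data_root
cfg.data.train.img_dir = 'images'
cfg.data.train.ann_dir = 'annotations'
cfg.data.train.split = 'splits/train.txt'

cfg.data.val.data_root = cfg.data_root
cfg.data.val.img_dir = 'images'
cfg.data.val.ann_dir = 'annotations'
cfg.data.val.split = 'splits/valid.txt'

# where checkpoints are saved
cfg.work_dir = './work_dirs/deeplabv3_r50-d8'

# max epochs
# The stock *_40k_* schedule uses IterBasedRunner with max_iters=40000.
# Adding max_epochs to that dict makes the runner raise "Only one of
# `max_epochs` or `max_iters` can be set", so replace the runner dict.
# Also check cfg.evaluation.interval: the 40k schedule sets it in iterations.
cfg.runner = dict(type='EpochBasedRunner', max_epochs=80)

# how often (in iterations) training info is logged
cfg.log_config.interval = 50
# clip gradients by their global L2 norm
cfg.optimizer_config.grad_clip = dict(max_norm=35, norm_type=2)

# max_keep_ckpts: maximum number of .pth files to keep
# interval: how often a .pth file is saved (in epochs with EpochBasedRunner)
cfg.checkpoint_config = dict(max_keep_ckpts=10, interval=2)

# Set seed to facilitate reproducing the result
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)
```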
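Epoch-based and iteration-based settings are linked through the number of iterations per epoch. A minimal sketch, assuming the non-distributed dataloader consumes `samples_per_gpu * num_gpus` images per iteration and drops the last incomplete batch (I believe `train_segmentor` builds its training loader with `drop_last=True`, but check your mmseg version). The function names are hypothetical helpers, not mmseg APIs.

```python
import math


def iters_per_epoch(num_images: int, samples_per_gpu: int,
                    num_gpus: int = 1, drop_last: bool = True) -> int:
    """Number of training iterations in one pass over the dataset."""
    batch_size = samples_per_gpu * num_gpus  # images consumed per iteration
    if drop_last:
        return num_images // batch_size  # the incomplete last batch is skipped
    return math.ceil(num_images / batch_size)


def total_iters(max_epochs: int, num_images: int, samples_per_gpu: int,
                num_gpus: int = 1) -> int:
    """Iterations an epoch-based run of max_epochs epochs performs."""
    return max_epochs * iters_per_epoch(num_images, samples_per_gpu, num_gpus)


def epochs_equivalent(max_iters: int, num_images: int, samples_per_gpu: int,
                      num_gpus: int = 1) -> float:
    """How many epochs an iteration budget such as max_iters amounts to."""
    return max_iters / iters_per_epoch(num_images, samples_per_gpu, num_gpus)


if __name__ == '__main__':
    n = 2617  # made-up dataset size, for illustration only
    print(iters_per_epoch(n, 16))
    print(total_iters(80, n, 16))
    print(round(epochs_equivalent(40000, n, 16), 1))
```

For that made-up dataset of 2,617 images with `samples_per_gpu=16` on one GPU, an epoch is 163 iterations, 80 epochs is 13,040 iterations, and the stock 40,000-iteration schedule corresponds to roughly 245 epochs.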

Build the dataset and model

```python
datasets = build_dataset(cfg.data.train)
model = build_segmentor(cfg.model)
model.init_weights()
```
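As far as I know, `build_dataset` only reads the split file and derives file names from it, so a missing image or mask first shows up inside the training loop. A quick pre-check, assuming mmseg's `CustomDataset` naming rule (each split line is a file stem, completed with `img_suffix` and `seg_map_suffix`) and its default suffixes `.jpg` and `.png`; pass the suffixes your config actually uses. `find_missing_files` is a hypothetical helper, not an mmseg API.

```python
import os


def find_missing_files(data_root, split_file, img_dir='images',
                       ann_dir='annotations', img_suffix='.jpg',
                       seg_map_suffix='.png'):
    """Return image/annotation paths implied by a split file that don't exist.

    Assumes each non-empty line of the split file is a file stem relative
    to img_dir and ann_dir (the naming rule described above).
    """
    missing = []
    with open(os.path.join(data_root, split_file)) as f:
        for line in f:
            stem = line.strip()
            if not stem:  # ignore blank lines
                continue
            img_path = os.path.join(data_root, img_dir, stem + img_suffix)
            ann_path = os.path.join(data_root, ann_dir, stem + seg_map_suffix)
            for path in (img_path, ann_path):
                if not os.path.isfile(path):
                    missing.append(path)
    return missing
```

With the paths above, `find_missing_files(data_root, 'splits/train.txt')` and `find_missing_files(data_root, 'splits/valid.txt')` should both return empty lists before you start training.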

Train

```python
train_segmentor(model, datasets, cfg, distributed=False, validate=True, meta=dict())
```
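With `interval=2` and `max_keep_ckpts=10`, not every checkpoint survives training. A simplified model of the retention rule, assuming an `EpochBasedRunner`, files named `epoch_{N}.pth`, and a `max_epochs` that is a multiple of `interval` (as 80 and 2 are): a checkpoint is written every `interval` epochs and only the newest `max_keep_ckpts` files are kept. `kept_checkpoints` is a hypothetical helper that only predicts the outcome; it does not touch `work_dir`.

```python
def kept_checkpoints(max_epochs: int, interval: int, max_keep_ckpts: int) -> list:
    """Checkpoint files left in work_dir under the simplified retention rule."""
    # a checkpoint is written at epochs interval, 2*interval, ..., max_epochs
    saved = [f'epoch_{e}.pth' for e in range(interval, max_epochs + 1, interval)]
    if max_keep_ckpts <= 0:  # a non-positive limit keeps everything
        return saved
    return saved[-max_keep_ckpts:]  # only the newest files survive


if __name__ == '__main__':
    kept = kept_checkpoints(max_epochs=80, interval=2, max_keep_ckpts=10)
    print(kept[0], kept[-1], len(kept))
```

For the settings above, 40 checkpoints are written over training and `epoch_62.pth` through `epoch_80.pth` remain.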
