Maksym-Lysyi's picture
initial commit
e3641b1
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import os
import os.path as osp
import time
import warnings
import click
import yaml
from glob import glob
import torch
import torch.distributed as dist
from vit_utils.util import init_random_seed, set_random_seed
from vit_utils.dist_util import get_dist_info, init_dist
from vit_utils.logging import get_root_logger
import configs.ViTPose_small_coco_256x192 as s_cfg
import configs.ViTPose_base_coco_256x192 as b_cfg
import configs.ViTPose_large_coco_256x192 as l_cfg
import configs.ViTPose_huge_coco_256x192 as h_cfg
from vit_models.model import ViTPose
from datasets.COCO import COCODataset
from vit_utils.train_valid_fn import train_model
CUR_PATH = osp.dirname(__file__)
@click.command()
@click.option('--config-path', type=click.Path(exists=True), default='config.yaml', required=True, help='train config file path')
@click.option('--model-name', type=str, default='b', required=True, help='[b: ViT-B, l: ViT-L, h: ViT-H]')
def main(config_path, model_name):
cfg = {'b':b_cfg,
's':s_cfg,
'l':l_cfg,
'h':h_cfg}.get(model_name.lower())
# Load config.yaml
with open(config_path, 'r') as f:
cfg_yaml = yaml.load(f, Loader=yaml.SafeLoader)
for k, v in cfg_yaml.items():
if hasattr(cfg, k):
raise ValueError(f"Already exists {k} in config")
else:
cfg.__setattr__(k, v)
# set cudnn_benchmark
if cfg.cudnn_benchmark:
torch.backends.cudnn.benchmark = True
# Set work directory (session-level)
if not hasattr(cfg, 'work_dir'):
cfg.__setattr__('work_dir', f"{CUR_PATH}/runs/train")
if not osp.exists(cfg.work_dir):
os.makedirs(cfg.work_dir)
session_list = sorted(glob(f"{cfg.work_dir}/*"))
if len(session_list) == 0:
session = 1
else:
session = int(os.path.basename(session_list[-1])) + 1
session_dir = osp.join(cfg.work_dir, str(session).zfill(3))
os.makedirs(session_dir)
cfg.__setattr__('work_dir', session_dir)
if cfg.autoscale_lr:
# apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
# init distributed env first, since logger depends on the dist info.
if cfg.launcher == 'none':
distributed = False
if len(cfg.gpu_ids) > 1:
warnings.warn(
f"We treat {cfg['gpu_ids']} as gpu-ids, and reset to "
f"{cfg['gpu_ids'][0:1]} as gpu-ids to avoid potential error in "
"non-distribute training time.")
cfg.gpu_ids = cfg.gpu_ids[0:1]
else:
distributed = True
init_dist(cfg.launcher, **cfg.dist_params)
# re-set gpu_ids with distributed training mode
_, world_size = get_dist_info()
cfg.gpu_ids = range(world_size)
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(session_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file)
# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta = dict()
# log some basic info
logger.info(f'Distributed training: {distributed}')
# set random seeds
seed = init_random_seed(cfg.seed)
logger.info(f"Set random seed to {seed}, "
f"deterministic: {cfg.deterministic}")
set_random_seed(seed, deterministic=cfg.deterministic)
meta['seed'] = seed
# Set model
model = ViTPose(cfg.model)
if cfg.resume_from:
# Load ckpt partially
ckpt_state = torch.load(cfg.resume_from)['state_dict']
ckpt_state.pop('keypoint_head.final_layer.bias')
ckpt_state.pop('keypoint_head.final_layer.weight')
model.load_state_dict(ckpt_state, strict=False)
# freeze the backbone, leave the head to be finetuned
model.backbone.frozen_stages = model.backbone.depth - 1
model.backbone.freeze_ffn = True
model.backbone.freeze_attn = True
model.backbone._freeze_stages()
# Set dataset
datasets_train = COCODataset(
root_path=cfg.data_root,
data_version="feet_train",
is_train=True,
use_gt_bboxes=True,
image_width=192,
image_height=256,
scale=True,
scale_factor=0.35,
flip_prob=0.5,
rotate_prob=0.5,
rotation_factor=45.,
half_body_prob=0.3,
use_different_joints_weight=True,
heatmap_sigma=3,
soft_nms=False
)
datasets_valid = COCODataset(
root_path=cfg.data_root,
data_version="feet_val",
is_train=False,
use_gt_bboxes=True,
image_width=192,
image_height=256,
scale=False,
scale_factor=0.35,
flip_prob=0.5,
rotate_prob=0.5,
rotation_factor=45.,
half_body_prob=0.3,
use_different_joints_weight=True,
heatmap_sigma=3,
soft_nms=False
)
train_model(
model=model,
datasets_train=datasets_train,
datasets_valid=datasets_valid,
cfg=cfg,
distributed=distributed,
validate=cfg.validate,
timestamp=timestamp,
meta=meta
)
if __name__ == '__main__':
main()