import torch
import torch.optim as optim
import data as Data
import models as Model
import torch.nn as nn
import argparse
import logging
import core.logger as Logger
import os
import numpy as np
from misc.metric_tools import ConfuseMatrixMeter
from models.loss import *
from collections import OrderedDict
import core.metrics as Metrics
from misc.torchutils import get_scheduler, save_network

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./config/whu/whu.json',
                        help='JSON configuration file for training')
    parser.add_argument('--phase', type=str, default='train', choices=['train', 'test'],
                        help='Choose between training or testing')
    parser.add_argument('--gpu_ids', type=str, default=None, help='Specify GPU device')
    parser.add_argument('-log_eval', action='store_true', help='Whether to log evaluation')
    args = parser.parse_args()

    opt = Logger.parse(args)
    opt = Logger.dict_to_nonedict(opt)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    Logger.setup_logger(logger_name=None, root=opt['path_cd']['log'], phase='train',
                        level=logging.INFO, screen=True)
    Logger.setup_logger(logger_name='test', root=opt['path_cd']['log'], phase='test',
                        level=logging.INFO)
    logger = logging.getLogger('base')
    logger.info(Logger.dict2str(opt))

    # Create the change-detection dataloaders requested by the config.
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train' and args.phase != 'test':
            print("Creating [train] change-detection dataloader")
            train_set = Data.create_cd_dataset(dataset_opt=dataset_opt, phase=phase)
            train_loader = Data.create_cd_dataloader(train_set, dataset_opt, phase)
            opt['len_train_dataloader'] = len(train_loader)
        elif phase == 'val' and args.phase != 'test':
            print("Creating [val] change-detection dataloader")
            val_set = Data.create_cd_dataset(dataset_opt=dataset_opt, phase=phase)
            val_loader = Data.create_cd_dataloader(val_set, dataset_opt, phase)
            opt['len_val_dataloader'] = len(val_loader)
        elif phase == 'test':
            print("Creating [test] change-detection dataloader")
            test_set = Data.create_cd_dataset(dataset_opt=dataset_opt, phase=phase)
            test_loader = Data.create_cd_dataloader(test_set, dataset_opt, phase)
            opt['len_test_dataloader'] = len(test_loader)
    logger.info('Dataset initialization completed')

    cd_model = Model.create_CD_model(opt)

    # Loss function selection.
    if opt['model']['loss'] == 'ce_dice':
        loss_fun = ce_dice
    elif opt['model']['loss'] == 'ce':
        loss_fun = cross_entropy
    elif opt['model']['loss'] == 'dice':
        loss_fun = dice
    elif opt['model']['loss'] == 'ce2_dice1':
        loss_fun = ce2_dice1
    elif opt['model']['loss'] == 'ce1_dice2':
        loss_fun = ce1_dice2

    # Optimizer selection.
    if opt['train']['optimizer']['type'] == 'adam':
        optimer = optim.Adam(cd_model.parameters(), lr=opt['train']['optimizer']['lr'])
    elif opt['train']['optimizer']['type'] == 'adamw':
        optimer = optim.AdamW(cd_model.parameters(), lr=opt['train']['optimizer']['lr'])
    elif opt['train']['optimizer']['type'] == 'sgd':
        optimer = optim.SGD(cd_model.parameters(), lr=opt['train']['optimizer']['lr'],
                            momentum=0.9, weight_decay=5e-4)

    device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
    cd_model.to(device)
    if len(opt['gpu_ids']) > 0:
        cd_model = nn.DataParallel(cd_model)

    metric = ConfuseMatrixMeter(n_class=2)
    log_dict = OrderedDict()
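    # Example invocation (the script filename below is an assumption; it is not
    # fixed by this code):
    #   python train.py --config ./config/whu/whu.json --phase train
    #   python train.py --config ./config/whu/whu.json --phase test
    #
    # The parsed config is expected to provide at least the keys this script reads:
    #   phase, gpu_ids, path_cd.{log,result,resume_state}, datasets.{train,val,test},
    #   model.loss, train.optimizer.{type,lr}, train.n_epoch, train.val_freq,
    #   train.train_print_iter, train.val_print_iter.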
    if opt['phase'] == 'train':
        best_mF1 = 0.0
        for current_epoch in range(0, opt['train']['n_epoch']):
            print("......Training Started......")
            metric.clear()
            cd_model.train()
            train_result_path = '{}/train/{}'.format(opt['path_cd']['result'], current_epoch)
            os.makedirs(train_result_path, exist_ok=True)

            message = 'Current learning rate: %0.7f\n \n' % optimer.param_groups[0]['lr']
            logger.info(message)

            for current_step, train_data in enumerate(train_loader):
                train_im1 = train_data['A'].to(device)
                train_im2 = train_data['B'].to(device)
                pred_img = cd_model(train_im1, train_im2)
                gt = train_data['L'].to(device).long()

                train_loss = loss_fun(pred_img, gt)
                optimer.zero_grad()
                train_loss.backward()
                optimer.step()
                log_dict['loss'] = train_loss.item()

                G_pred = pred_img.detach()
                G_pred = torch.argmax(G_pred, dim=1)
                current_score = metric.update_cm(pr=G_pred.cpu().numpy(), gt=gt.detach().cpu().numpy())
                log_dict['running_acc'] = current_score.item()

                if current_step % opt['train']['train_print_iter'] == 0:
                    logs = log_dict
                    message = '[Training Change Detection]. Epoch: [%d/%d]. Iteration: [%d/%d], Loss: %.5f, Current mF1: %.5f\n' % \
                              (current_epoch, opt['train']['n_epoch'] - 1, current_step, len(train_loader),
                               logs['loss'], logs['running_acc'])
                    logger.info(message)

            scores = metric.get_scores()
            epoch_acc = scores['mf1']
            log_dict['epoch_acc'] = epoch_acc.item()
            for k, v in scores.items():
                log_dict[k] = v
            logs = log_dict
            message = '[Training Change Detection (Epoch Summary)]: Epoch: [%d/%d]. Current mF1=%.5f \n' % \
                      (current_epoch, opt['train']['n_epoch'] - 1, logs['epoch_acc'])
            for k, v in logs.items():
                message += '{:s}: {:.4e} '.format(k, v)
            message += '\n'
            logger.info(message)
            metric.clear()

            cd_model.eval()
            with torch.no_grad():
                if current_epoch % opt['train']['val_freq'] == 0:
                    val_result_path = '{}/val/{}'.format(opt['path_cd']['result'], current_epoch)
                    os.makedirs(val_result_path, exist_ok=True)

                    for current_step, val_data in enumerate(val_loader):
                        val_img1 = val_data['A'].to(device)
                        val_img2 = val_data['B'].to(device)
                        pred_img = cd_model(val_img1, val_img2)
                        gt = val_data['L'].to(device).long()
                        val_loss = loss_fun(pred_img, gt)
                        log_dict['loss'] = val_loss.item()

                        G_pred = pred_img.detach()
                        G_pred = torch.argmax(G_pred, dim=1)
                        current_score = metric.update_cm(pr=G_pred.cpu().numpy(), gt=gt.detach().cpu().numpy())
                        log_dict['running_acc'] = current_score.item()
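                        # Every 'val_print_iter' steps, log the running mF1 and save
                        # qualitative results for this batch (separate images or one
                        # A/B/pred/gt grid, depending on img_mode).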
                        if current_step % opt['train']['val_print_iter'] == 0:
                            logs = log_dict
                            message = '[Validation Change Detection]. Epoch: [%d/%d]. Iteration: [%d/%d], Current mF1: %.5f\n' % \
                                      (current_epoch, opt['train']['n_epoch'] - 1, current_step, len(val_loader), logs['running_acc'])
                            logger.info(message)

                            out_dict = OrderedDict()
                            out_dict['pred_cm'] = torch.argmax(pred_img, dim=1, keepdim=False)
                            out_dict['gt_cm'] = gt
                            visuals = out_dict

                            img_mode = "grid"
                            if img_mode == "single":
                                img_A = Metrics.tensor2img(val_data['A'], out_type=np.uint8, min_max=(-1, 1))  # uint8
                                img_B = Metrics.tensor2img(val_data['B'], out_type=np.uint8, min_max=(-1, 1))  # uint8
                                gt_cm = Metrics.tensor2img(visuals['gt_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                                           out_type=np.uint8, min_max=(0, 1))  # uint8
                                pred_cm = Metrics.tensor2img(visuals['pred_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                                             out_type=np.uint8, min_max=(0, 1))  # uint8
                                Metrics.save_img(
                                    img_A, '{}/img_A_e{}_b{}.png'.format(val_result_path, current_epoch, current_step))
                                Metrics.save_img(
                                    img_B, '{}/img_B_e{}_b{}.png'.format(val_result_path, current_epoch, current_step))
                                Metrics.save_img(
                                    pred_cm, '{}/img_pred_e{}_b{}.png'.format(val_result_path, current_epoch, current_step))
                                Metrics.save_img(
                                    gt_cm, '{}/img_gt_e{}_b{}.png'.format(val_result_path, current_epoch, current_step))
                            else:
                                visuals['pred_cm'] = visuals['pred_cm'] * 2.0 - 1.0
                                visuals['gt_cm'] = visuals['gt_cm'] * 2.0 - 1.0
                                grid_img = torch.cat((val_data['A'].to(device),
                                                      val_data['B'].to(device),
                                                      visuals['pred_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                                      visuals['gt_cm'].unsqueeze(1).repeat(1, 3, 1, 1)),
                                                     dim=0)
                                grid_img = Metrics.tensor2img(grid_img)  # uint8
                                Metrics.save_img(
                                    grid_img, '{}/img_A_B_pred_gt_e{}_b{}.png'.format(val_result_path, current_epoch, current_step))
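                    # Aggregate the confusion matrix accumulated over the validation
                    # set into epoch-level scores and track the best mF1 checkpoint.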
                    scores = metric.get_scores()
                    epoch_acc = scores['mf1']
                    log_dict['epoch_acc'] = epoch_acc.item()
                    for k, v in scores.items():
                        log_dict[k] = v
                    logs = log_dict
                    message = '[Validation Change Detection (Epoch Summary)]: Epoch: [%d/%d]. Epoch mF1=%.5f \n' % \
                              (current_epoch, opt['train']['n_epoch'], logs['epoch_acc'])
                    for k, v in logs.items():
                        message += '{:s}: {:.4e} '.format(k, v)
                    message += '\n'
                    logger.info(message)

                    if logs['epoch_acc'] > best_mF1:
                        is_best_model = True
                        best_mF1 = logs['epoch_acc']
                        logger.info('[Validation CD Phase] Best model updated, saving current best model and training state.')
                        save_network(opt, current_epoch, cd_model, optimer, is_best_model)
                    else:
                        is_best_model = False
                        logger.info('[Validation CD Phase] Saving current change detection model and training state.')

                    logger.info('--- Proceed to next epoch ----\n \n')
                    metric.clear()

            # Note: a fresh scheduler object is created every epoch and stepped once.
            get_scheduler(optimizer=optimer, args=opt['train']).step()
        logger.info('Training finished.')
    else:
        logger.info('Begin Model Evaluation (testing).')
        test_result_path = '{}/test/'.format(opt['path_cd']['result'])
        os.makedirs(test_result_path, exist_ok=True)
        logger_test = logging.getLogger('test')

        load_path = opt['path_cd']['resume_state']
        print(load_path)
        if load_path is not None:
            logger.info('Loading pretrained model for CD model [{:s}] ...'.format(load_path))
            gen_path = '{}_gen.pth'.format(load_path)
            # Only the generator weights are restored; opt_path is built but not used below.
            opt_path = '{}_opt.pth'.format(load_path)
            cd_model = Model.create_CD_model(opt)
            cd_model.load_state_dict(torch.load(gen_path), strict=True)
            cd_model.to(device)

        metric.clear()
        cd_model.eval()
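        # Single pass over the test set: accumulate the confusion matrix, log the
        # running mF1 per batch, and save the corresponding visual results.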
        with torch.no_grad():
            for current_step, test_data in enumerate(test_loader):
                test_img1 = test_data['A'].to(device)
                test_img2 = test_data['B'].to(device)
                pred_img = cd_model(test_img1, test_img2)
                gt = test_data['L'].to(device).long()

                G_pred = pred_img.detach()
                G_pred = torch.argmax(G_pred, dim=1)
                current_score = metric.update_cm(pr=G_pred.cpu().numpy(), gt=gt.detach().cpu().numpy())
                log_dict['running_acc'] = current_score.item()
                logs = log_dict
                message = '[Testing Change Detection]. Iteration: [%d/%d], running mF1: %.5f\n' % \
                          (current_step, len(test_loader), logs['running_acc'])
                logger_test.info(message)

                out_dict = OrderedDict()
                out_dict['pred_cm'] = torch.argmax(pred_img, dim=1, keepdim=False)
                out_dict['gt_cm'] = gt
                visuals = out_dict

                img_mode = 'single'
                if img_mode == 'single':
                    visuals['pred_cm'] = visuals['pred_cm'] * 2.0 - 1.0
                    visuals['gt_cm'] = visuals['gt_cm'] * 2.0 - 1.0
                    img_A = Metrics.tensor2img(test_data['A'], out_type=np.uint8, min_max=(-1, 1))  # uint8
                    img_B = Metrics.tensor2img(test_data['B'], out_type=np.uint8, min_max=(-1, 1))  # uint8
                    gt_cm = Metrics.tensor2img(visuals['gt_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                               out_type=np.uint8, min_max=(0, 1))  # uint8
                    pred_cm = Metrics.tensor2img(visuals['pred_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                                 out_type=np.uint8, min_max=(0, 1))  # uint8
                    Metrics.save_img(
                        img_A, '{}/img_A_{}.png'.format(test_result_path, current_step))
                    Metrics.save_img(
                        img_B, '{}/img_B_{}.png'.format(test_result_path, current_step))
                    Metrics.save_img(
                        pred_cm, '{}/img_pred_cm{}.png'.format(test_result_path, current_step))
                    Metrics.save_img(
                        gt_cm, '{}/img_gt_cm{}.png'.format(test_result_path, current_step))
                else:
                    visuals['pred_cm'] = visuals['pred_cm'] * 2.0 - 1.0
                    visuals['gt_cm'] = visuals['gt_cm'] * 2.0 - 1.0
                    grid_img = torch.cat((test_data['A'],
                                          test_data['B'],
                                          visuals['pred_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                          visuals['gt_cm'].unsqueeze(1).repeat(1, 3, 1, 1)),
                                         dim=0)
                    grid_img = Metrics.tensor2img(grid_img)  # uint8
                    Metrics.save_img(
                        grid_img, '{}/img_A_B_pred_gt_{}.png'.format(test_result_path, current_step))

            scores = metric.get_scores()
            epoch_acc = scores['mf1']
            log_dict['epoch_acc'] = epoch_acc.item()
            for k, v in scores.items():
                log_dict[k] = v
            logs = log_dict
            message = '[Test Change Detection Summary]: Test mF1=%.5f \n' % (logs['epoch_acc'])
            for k, v in logs.items():
                message += '{:s}: {:.4e} '.format(k, v)
            message += '\n'
            logger_test.info(message)

        logger.info('End of testing...')