|
|
import sys |
|
|
from model.trainer import Trainer |
|
|
|
|
|
sys.path.insert(0, '.') |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import torch.backends.cudnn as cudnn |
|
|
from torch.nn.parallel import gather |
|
|
import torch.optim.lr_scheduler |
|
|
|
|
|
import dataset.dataset as myDataLoader |
|
|
import dataset.Transforms as myTransforms |
|
|
from model.metric_tool import ConfuseMatrixMeter |
|
|
from model.utils import BCEDiceLoss, init_seed |
|
|
from PIL import Image |
|
|
import os |
|
|
import time |
|
|
import numpy as np |
|
|
from argparse import ArgumentParser |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
def _save_pred_masks(pred_np, file_names, mask_dir, batch_idx):
    """Save each sample's channel-0 prediction as a 0/255 PNG in ``mask_dir``.

    Args:
        pred_np: uint8 array of binary predictions, indexed as
            ``pred_np[i, 0]`` per sample (assumed ``(B, 1, H, W)`` — TODO
            confirm against the model's output shape).
        file_names: Source file paths for this batch, in batch order. When
            the list is shorter than the batch, the name
            ``batch_<batch_idx>_mask_<i>`` is used instead.
        mask_dir: Existing directory to write ``<base>_pred.png`` files into.
        batch_idx: Index of the batch, used only for fallback names and
            error messages.

    Saving is best effort: a failure on one mask is reported and the
    remaining masks of the batch are still written.
    """
    for i, sample in enumerate(pred_np):
        if i < len(file_names):
            base_name = os.path.splitext(os.path.basename(file_names[i]))[0]
        else:
            tqdm.write(f"Warning: Missing filename for mask {i}, using default")
            base_name = f"batch_{batch_idx}_mask_{i}"

        mask_path = f"{mask_dir}/{base_name}_pred.png"
        try:
            single_mask = sample[0]
            if single_mask.ndim != 2:
                raise ValueError(f"Invalid mask shape: {single_mask.shape}")
            # Predictions are {0, 1}; scale to {0, 255} for a viewable PNG.
            Image.fromarray(single_mask * 255).save(mask_path)
        except (ValueError, TypeError, IndexError, KeyError, OSError) as e:
            tqdm.write(f"Error saving {mask_path} (batch {batch_idx}): {e}")


@torch.no_grad()
def validate(args, val_loader, model, save_masks=False):
    """Run ``model`` over ``val_loader`` and compute loss and confusion-matrix scores.

    Args:
        args: Parsed CLI namespace. Reads ``args.onGPU`` (move tensors to
            CUDA), ``args.batch_size`` (fallback when the loader has no fixed
            ``batch_size``) and, if ``save_masks`` is true, ``args.savedir``.
        val_loader: DataLoader yielding ``(img, target)`` pairs. ``img`` is
            split along dim 1 into ``img[:, 0:3]`` (pre image) and
            ``img[:, 3:6]`` (post image). When saving masks, file names are
            read positionally from ``val_loader.dataset.file_list``. The
            loader must therefore iterate the dataset in order
            (``shuffle=False``).
        model: Network called as ``model(pre_img, post_img)``. Its output is
            thresholded at 0.5 to obtain the binary prediction.
        save_masks: If True, write predicted masks to
            ``<args.savedir>/pred_masks/<name>_pred.png``.

    Returns:
        ``(average_loss, scores)``: the mean of the per-batch ``BCEDiceLoss``
        values and the dict returned by ``ConfuseMatrixMeter.get_scores()``.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    # eval() recursively puts every submodule (BatchNorm included) into
    # inference mode, so no per-layer handling is needed.
    model.eval()

    sal_eval = ConfuseMatrixMeter(n_class=2)
    epoch_loss = []

    mask_dir = None
    file_list = None
    step = None
    if save_masks:
        mask_dir = f"{args.savedir}/pred_masks"
        os.makedirs(mask_dir, exist_ok=True)
        print(f"Saving prediction masks to: {mask_dir}")
        # Names are recovered by position, so use the loader's own batch size.
        file_list = val_loader.dataset.file_list
        step = val_loader.batch_size or args.batch_size

    pbar = tqdm(enumerate(val_loader), total=len(val_loader), desc="Validating")
    for batch_idx, (img, target) in pbar:
        pre_img = img[:, 0:3]
        post_img = img[:, 3:6]

        if args.onGPU:
            pre_img = pre_img.cuda()
            post_img = post_img.cuda()
            target = target.cuda()

        target = target.float()
        output = model(pre_img, post_img)
        loss = BCEDiceLoss(output, target)
        pred_np = (output > 0.5).long().cpu().numpy()

        if save_masks:
            batch_file_names = file_list[batch_idx * step:(batch_idx + 1) * step]
            _save_pred_masks(pred_np.astype(np.uint8), batch_file_names,
                             mask_dir, batch_idx)

        f1 = sal_eval.update_cm(pr=pred_np, gt=target.cpu().numpy())
        loss_value = loss.item()
        epoch_loss.append(loss_value)
        pbar.set_postfix({'Loss': f"{loss_value:.4f}", 'F1': f"{f1:.4f}"})

    if not epoch_loss:
        raise ValueError("validate() received an empty val_loader; nothing to evaluate")

    average_loss = sum(epoch_loss) / len(epoch_loss)
    return average_loss, sal_eval.get_scores()
|
|
|
|
|
def _append_test_log(log_path, scores):
    """Append a 'Test' row of ``scores`` to ``log_path``, writing the header if the file is new."""
    # Check existence BEFORE opening: opening in 'a' mode creates the file.
    is_new = not os.path.exists(log_path)
    with open(log_path, 'a') as logger:
        if is_new:
            logger.write("\n%s\t%s\t%s\t%s\t%s\t%s\t%s" %
                         ('Epoch', 'Kappa', 'IoU', 'F1', 'Recall', 'Precision', 'OA'))
        logger.write("\n%s\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f" %
                     ('Test', scores['Kappa'], scores['IoU'], scores['F1'],
                      scores['recall'], scores['precision'], scores['OA']))


def ValidateSegmentation(args):
    """Run the full test-set evaluation pipeline.

    Steps: configure the device and seed, locate the run directory, build the
    model and test loader, and load ``best_model.pth`` from the run directory.
    It then evaluates, prints the scores, and appends them to ``args.logFile``
    inside the run directory.

    Args:
        args: Parsed CLI namespace (see the ``__main__`` argument parser).
            ``args.savedir`` and ``args.file_root`` are rewritten in place to
            the resolved run directory and dataset path.

    Raises:
        FileNotFoundError: If ``best_model.pth`` is missing from the run directory.
    """
    # Must be set before CUDA is first initialised to take effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    torch.backends.cudnn.benchmark = True
    init_seed(args.seed)

    # The run directory is named after the short dataset key (e.g. "LEVIR"),
    # so it is built before file_root is mapped to a filesystem path.
    args.savedir = os.path.join(args.savedir,
                                f"{args.file_root}_iter_{args.max_steps}_lr_{args.lr}")
    os.makedirs(args.savedir, exist_ok=True)

    dataset_mapping = {
        'LEVIR': './levir_cd_256',
        'WHU': './whu_cd_256',
        'CLCD': './clcd_256',
        'SYSU': './sysu_256',
        'OSCD': './oscd_256',
    }
    # Unknown keys are treated as a literal path.
    args.file_root = dataset_mapping.get(args.file_root, args.file_root)

    model = Trainer(args.model_type).float()
    if args.onGPU:
        model = model.cuda()

    # Per-channel stats for the 6-channel (pre + post) input. These look like
    # ImageNet statistics in reversed channel order — TODO confirm they match
    # the normalisation used during training.
    mean = [0.406, 0.456, 0.485, 0.406, 0.456, 0.485]
    std = [0.225, 0.224, 0.229, 0.225, 0.224, 0.229]

    valDataset = myTransforms.Compose([
        myTransforms.Normalize(mean=mean, std=std),
        myTransforms.Scale(args.inWidth, args.inHeight),
        myTransforms.ToTensor(),
    ])

    test_data = myDataLoader.Dataset(file_root=args.file_root, mode="test", transform=valDataset)
    # shuffle=False is required: validate() maps batches back to file names by position.
    testLoader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=bool(args.onGPU),  # pinning only helps host->GPU copies
    )

    model_file_name = os.path.join(args.savedir, 'best_model.pth')
    if not os.path.exists(model_file_name):
        raise FileNotFoundError(f"Model file not found: {model_file_name}")

    # Load onto CPU first so GPU-saved checkpoints also work with --onGPU False;
    # load_state_dict copies the tensors onto the model's own device.
    state_dict = torch.load(model_file_name, map_location='cpu')
    model.load_state_dict(state_dict)
    print(f"Loaded model from {model_file_name}")

    loss_test, score_test = validate(args, testLoader, model, save_masks=args.save_masks)

    print("\nTest Results:")
    print(f"Loss: {loss_test:.4f}")
    for label, key in (('Kappa', 'Kappa'), ('IoU', 'IoU'), ('F1', 'F1'),
                       ('Recall', 'recall'), ('Precision', 'precision'), ('OA', 'OA')):
        print(f"{label}: {score_test[key]:.4f}")

    _append_test_log(os.path.join(args.savedir, args.logFile), score_test)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # CLI for evaluating a trained checkpoint on the test split.
    parser = ArgumentParser()
    add_arg = parser.add_argument

    # Data and input geometry.
    add_arg('--file_root', default="LEVIR",
            help='Data directory | LEVIR | WHU | CLCD | SYSU | OSCD')
    add_arg('--inWidth', type=int, default=256, help='Width of input image')
    add_arg('--inHeight', type=int, default=256, help='Height of input image')

    # max_steps and lr only reconstruct the run-directory name; nothing is trained.
    add_arg('--max_steps', type=int, default=80000,
            help='Max. number of iterations (for path naming)')
    add_arg('--num_workers', type=int, default=4,
            help='Number of data loading workers')
    add_arg('--model_type', type=str, default='small',
            help='Model type | tiny | small')
    add_arg('--batch_size', type=int, default=16,
            help='Batch size for validation')
    add_arg('--lr', type=float, default=2e-4,
            help='Learning rate (for path naming)')
    add_arg('--seed', type=int, default=16,
            help='Random seed for reproducibility')

    # Output locations.
    add_arg('--savedir', default='./results',
            help='Base directory to save results')
    add_arg('--logFile', default='testLog.txt',
            help='File to save validation logs')

    # Device selection: anything other than a case-insensitive "true" disables the GPU.
    add_arg('--onGPU', default=True,
            type=lambda flag: str(flag).lower() == 'true',
            help='Run on GPU if True')
    add_arg('--gpu_id', type=int, default=0,
            help='GPU device id')
    add_arg('--save_masks', action='store_true',
            help='Save predicted masks to disk')

    args = parser.parse_args()
    print('Validation with args:')
    print(args)
    ValidateSegmentation(args)
|
|
|
|
|
|