File size: 2,795 Bytes
8cf4db8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import logging
import core.logger as Logger
import data as Data

# Create change detection dataset
import logging
import torch.utils.data

def create_cd_dataloader(dataset, dataset_opt, phase):
    """Wrap a change-detection dataset in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: The dataset to load from (e.g. one built by
            ``create_cd_dataset``).
        dataset_opt: Option mapping; reads ``'batch_size'``,
            ``'use_shuffle'`` and ``'num_workers'``.
        phase: Split name; must be one of ``'train'``, ``'val'`` or
            ``'test'``.

    Returns:
        A ``DataLoader`` over ``dataset`` with ``pin_memory=True``.

    Raises:
        NotImplementedError: If ``phase`` is not a supported split.
    """
    # Use a membership test: a chain like ``phase == 'train' or 'val'`` is
    # always truthy (non-empty string literals), which would make the
    # error branch unreachable and accept any phase.
    if phase not in ('train', 'val', 'test'):
        raise NotImplementedError(
            'Dataloader [{}] is not found'.format(phase)
        )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=dataset_opt['batch_size'],
        shuffle=dataset_opt['use_shuffle'],
        num_workers=dataset_opt['num_workers'],
        pin_memory=True)

def create_cd_dataset(dataset_opt, phase):
    """Build a change-detection ``CDDataset`` for the given split.

    The dataset root is printed to stdout, and a creation message is
    logged on the ``'base'`` logger.

    Args:
        dataset_opt: Option mapping; reads ``'datasetroot'``,
            ``'resolution'``, ``'data_len'`` and ``'name'``.
        phase: Split name, forwarded to ``CDDataset`` as ``split``.

    Returns:
        The constructed ``CDDataset`` instance.
    """
    # Imported lazily, at call time, as in the original code.
    from data.CDDataset import CDDataset

    root = dataset_opt["datasetroot"]
    print(root)
    cd_set = CDDataset(
        root_dir=root,
        resolution=dataset_opt["resolution"],
        split=phase,
        data_len=dataset_opt["data_len"],
    )

    base_logger = logging.getLogger('base')
    message = 'Dataset [{:s} - {:s} - {:s}] is created'.format(
        cd_set.__class__.__name__, dataset_opt['name'], phase)
    base_logger.info(message)
    return cd_set

def create_scd_dataset(dataset_opt, phase):
    """Build an ``SCDDataset`` (from ``data.CDDataset``) for the given split.

    The dataset root is printed to stdout, and a creation message is
    logged on the ``'base'`` logger.

    Args:
        dataset_opt: Option mapping; reads ``'datasetroot'``,
            ``'resolution'``, ``'data_len'`` and ``'name'``.
        phase: Split name, forwarded to ``SCDDataset`` as ``split``.

    Returns:
        The constructed ``SCDDataset`` instance.
    """
    # Imported lazily, at call time, as in the original code.
    from data.CDDataset import SCDDataset

    root = dataset_opt["datasetroot"]
    print(root)
    scd_set = SCDDataset(
        root_dir=root,
        resolution=dataset_opt["resolution"],
        split=phase,
        data_len=dataset_opt["data_len"],
    )

    base_logger = logging.getLogger('base')
    message = 'Dataset [{:s} - {:s} - {:s}] is created'.format(
        scd_set.__class__.__name__, dataset_opt['name'], phase)
    base_logger.info(message)
    return scd_set


if __name__ == "__main__":
    # Script entry point: parse CLI options, load the config, and build the
    # dataloader(s) it describes (the visible part builds only the train one).
    parser = argparse.ArgumentParser()
    # Config file path; the relative default resolves against the current working directory.
    parser.add_argument('-c', '--config', type=str, default='../config/levir.json')
    # Run mode; restricted by argparse to 'train' or 'test'.
    parser.add_argument('-p', '--phase', type=str, choices=['train', 'test'], default='train')
    # GPU ids passed as a raw string; presumably interpreted by Logger.parse — TODO confirm.
    parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)

    args = parser.parse_args()
    # NOTE(review): presumably Logger.parse loads the config named by args.config
    # into a dict; verify against core.logger.
    opt = Logger.parse(args)
    # NOTE(review): presumably makes missing keys read as None; verify in core.logger.
    opt = Logger.dict_to_nonedict(opt)
    print(opt)

    # Walk the per-split dataset options under the config's 'datasets' key.
    for phase, dataset_opt in opt['datasets'].items():
        # The train loader is skipped when the script runs with --phase test.
        if phase == 'train' and args.phase != 'test':
            print("Creating [train] change-detection dataloader.")
            # These calls go through the `data` package (imported as Data),
            # not the functions defined in this file.
            train_set  = Data.create_cd_dataset(dataset_opt, phase)
            train_loader = Data.create_cd_dataloader(train_set, dataset_opt, phase)