| | import argparse
|
| | from torch.utils.data import DataLoader
|
| | import lightning as L
|
| | from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
|
| |
|
| | from datasets import dataset_dict, RandomConcatSampler
|
| | from model import PL_RelPose
|
| | from utils import seed_torch
|
| | from configs.default import get_cfg_defaults
|
| |
|
| |
|
def main(args):
    """Train a ``PL_RelPose`` model as described by a YAML config file.

    Args:
        args: Parsed command-line namespace. The following attributes are read:
            ``config``  - path of the YAML file merged over the default config;
            ``weights`` - checkpoint used to initialise the model, or ``None``
                          to build a fresh model;
            ``resume``  - forwarded to ``Trainer.fit`` as ``ckpt_path``.
    """
    config = get_cfg_defaults()
    config.merge_from_file(args.config)

    task = config.DATASET.TASK
    dataset = config.DATASET.DATA_SOURCE

    trainer_cfg = config.TRAINER
    model_cfg = config.MODEL
    epochs = trainer_cfg.EPOCHS

    seed = config.RANDOM_SEED
    seed_torch(seed)

    build_fn = dataset_dict[task][dataset]
    trainset = build_fn('train', config)
    validset = build_fn('val', config)

    # Options shared by every DataLoader created below.
    loader_kwargs = dict(
        batch_size=trainer_cfg.BATCH_SIZE,
        num_workers=trainer_cfg.NUM_WORKERS,
        pin_memory=trainer_cfg.PIN_MEMORY,
    )

    # These data sources are drawn through RandomConcatSampler; every other
    # source falls back to plain shuffling. DataLoader does not accept
    # ``sampler`` together with ``shuffle=True``, hence the two branches.
    sampled_sources = ('scannet', 'megadepth', 'linemod', 'ho3d', 'mapfree')
    if dataset in sampled_sources:
        sampler = RandomConcatSampler(
            trainset,
            n_samples_per_subset=trainer_cfg.N_SAMPLES_PER_SUBSET,
            subset_replacement=True,
            shuffle=True,
            seed=seed,
        )
        trainloader = DataLoader(trainset, sampler=sampler, **loader_kwargs)
    else:
        trainloader = DataLoader(trainset, shuffle=True, **loader_kwargs)

    validloader = DataLoader(validset, **loader_kwargs)

    # One keyword set for both construction paths, so that the config value of
    # ``features`` is honoured when starting from --weights as well (it was
    # previously forwarded only when building a fresh model).
    model_kwargs = dict(
        task=task,
        lr=trainer_cfg.LEARNING_RATE,
        epochs=epochs,
        pct_start=trainer_cfg.PCT_START,
        n_layers=model_cfg.N_LAYERS,
        num_heads=model_cfg.NUM_HEADS,
        num_keypoints=model_cfg.NUM_KEYPOINTS,
        features=model_cfg.FEATURES,
    )
    if args.weights is None:
        pl_relpose = PL_RelPose(**model_kwargs)
    else:
        pl_relpose = PL_RelPose.load_from_checkpoint(
            checkpoint_path=args.weights,
            **model_kwargs,
        )

    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    # Checkpoint callback with default settings (no monitored metric).
    latest_checkpoint_callback = ModelCheckpoint()
    # Second callback tracks the maximum of the 'valid/auc@20' metric.
    best_checkpoint_callback = ModelCheckpoint(monitor='valid/auc@20', mode='max')

    trainer = L.Trainer(
        devices=[0],
        max_epochs=epochs,
        callbacks=[lr_monitor, latest_checkpoint_callback, best_checkpoint_callback],
        precision="bf16-mixed",
    )

    trainer.fit(pl_relpose, trainloader, validloader, ckpt_path=args.resume)
|
| |
|
| |
|
def get_parser():
    """Build the command-line parser for the training script.

    Returns:
        argparse.ArgumentParser: parser with one positional ``config`` argument
        and two optional string flags, ``--resume`` and ``--weights``, both
        defaulting to ``None``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('config', type=str, help='.yaml configure file path')

    # Both optional flags take a string path and are unset by default.
    for flag in ('--resume', '--weights'):
        cli.add_argument(flag, type=str, default=None)

    return cli
|
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: parse the command line and start training.
    main(get_parser().parse_args())
|
| |
|