|
"""Callback to measure training and testing time of a PyTorch Lightning module.""" |
|
|
|
|
|
|
|
|
|
import logging |
|
import time |
|
|
|
import torch |
|
from lightning.pytorch import Callback, LightningModule, Trainer |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class TimerCallback(Callback): |
|
"""Callback that measures the training and testing time of a PyTorch Lightning module. |
|
|
|
Examples: |
|
>>> from anomalib.callbacks import TimerCallback |
|
>>> from anomalib.engine import Engine |
|
... |
|
>>> callbacks = [TimerCallback()] |
|
>>> engine = Engine(callbacks=callbacks) |
|
""" |
|
|
|
def __init__(self) -> None: |
|
self.start: float |
|
self.num_images: int = 0 |
|
|
|
def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: |
|
"""Call when fit begins. |
|
|
|
Sets the start time to the time training started. |
|
|
|
Args: |
|
trainer (Trainer): PyTorch Lightning trainer. |
|
pl_module (LightningModule): Current training module. |
|
|
|
Returns: |
|
None |
|
""" |
|
del trainer, pl_module |
|
|
|
self.start = time.time() |
|
|
|
def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: |
|
"""Call when fit ends. |
|
|
|
Prints the time taken for training. |
|
|
|
Args: |
|
trainer (Trainer): PyTorch Lightning trainer. |
|
pl_module (LightningModule): Current training module. |
|
|
|
Returns: |
|
None |
|
""" |
|
del trainer, pl_module |
|
logger.info("Training took %5.2f seconds", (time.time() - self.start)) |
|
|
|
def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None: |
|
"""Call when the test begins. |
|
|
|
Sets the start time to the time testing started. |
|
Goes over all the test dataloaders and adds the number of images in each. |
|
|
|
Args: |
|
trainer (Trainer): PyTorch Lightning trainer. |
|
pl_module (LightningModule): Current training module. |
|
|
|
Returns: |
|
None |
|
""" |
|
del pl_module |
|
|
|
self.start = time.time() |
|
self.num_images = 0 |
|
|
|
if trainer.test_dataloaders is not None: |
|
if isinstance(trainer.test_dataloaders, torch.utils.data.dataloader.DataLoader): |
|
self.num_images += len(trainer.test_dataloaders.dataset) |
|
else: |
|
for dataloader in trainer.test_dataloaders: |
|
self.num_images += len(dataloader.dataset) |
|
|
|
def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None: |
|
"""Call when the test ends. |
|
|
|
Prints the time taken for testing and the throughput in frames per second. |
|
|
|
Args: |
|
trainer (Trainer): PyTorch Lightning trainer. |
|
pl_module (LightningModule): Current training module. |
|
|
|
Returns: |
|
None |
|
""" |
|
del pl_module |
|
|
|
testing_time = time.time() - self.start |
|
output = f"Testing took {testing_time} seconds\nThroughput " |
|
if trainer.test_dataloaders is not None: |
|
if isinstance(trainer.test_dataloaders, torch.utils.data.dataloader.DataLoader): |
|
test_data_loader = trainer.test_dataloaders |
|
else: |
|
test_data_loader = trainer.test_dataloaders[0] |
|
output += f"(batch_size={test_data_loader.batch_size})" |
|
output += f" : {self.num_images/testing_time} FPS" |
|
logger.info(output) |
|
|