# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import bisect
import csv
import io
import time

import numpy as np
import torch
from torch.utils.data import _utils
from torch.utils.data.dataloader import (
    ExceptionWrapper,
    _DatasetKind,
    _MultiProcessingDataLoaderIter,
)

from src.utils.monitoring import ResourceMonitoringThread


class ConcatIndices:
    """Helper to map indices of concatenated/mixed datasets to the sample index for the corresponding dataset."""

    cumulative_sizes: np.ndarray

    def __init__(self, sizes):
        self.cumulative_sizes = np.cumsum(sizes)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Returns a pair (dataset_idx, sample_idx)
        if idx < 0 or idx >= len(self):
            raise ValueError(f"index must be in [0, {len(self)})")
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            return dataset_idx, idx
        return dataset_idx, idx - self.cumulative_sizes[dataset_idx - 1]
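

# Usage sketch (illustrative, not called by the library): map a flat index over
# three hypothetical datasets of sizes 3, 5 and 2 back to (dataset_idx, sample_idx).
def _example_concat_indices():
    indices = ConcatIndices([3, 5, 2])
    assert len(indices) == 10
    assert indices[0] == (0, 0)  # first dataset, first sample
    assert indices[4] == (1, 1)  # second dataset, second sample (4 - 3)
    assert indices[9] == (2, 1)  # third dataset, second sample (9 - 8)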


class CSVLogger(object):
    """Append-only CSV logger. Rows are buffered in memory and only written to disk on flush()."""

    def __init__(self, fname, header):
        """Write the header row to the internal buffer."""
        self.fname = fname
        self.buffer = io.StringIO()
        self.writer = csv.writer(self.buffer, quoting=csv.QUOTE_NONNUMERIC)
        self.writer.writerow(header)
        self.initialized = False

    def writerow(self, row) -> None:
        """Write a row to the internal buffer."""
        self.writer.writerow(row)

    def flush(self) -> None:
        """Flush the buffer to file: overwrite on the first flush, append afterwards."""
        mode = "a+" if self.initialized else "w"
        with open(self.fname, mode, newline="") as f:
            f.write(self.buffer.getvalue())
        self.buffer = io.StringIO()
        self.writer = csv.writer(self.buffer, quoting=csv.QUOTE_NONNUMERIC)
        self.initialized = True
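

# Usage sketch (illustrative; the path is hypothetical): rows are buffered in
# memory and only hit the filesystem when flush() is called.
def _example_csv_logger():
    logger = CSVLogger("/tmp/worker_stats.csv", header=["step", "rss_bytes"])
    logger.writerow([0, 1024])
    logger.writerow([1, 2048])
    logger.flush()  # first flush overwrites the file, later flushes append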


class MonitoredDataset(torch.utils.data.Dataset):
    """Implement resource monitoring on a per-worker basis.

    Resource usage is sampled every monitor_interval seconds and the log is
    written every log_interval seconds to the file given by log_filename,
    which maps a worker id to a file using the '%w' placeholder.

    Warning: Do not call this dataset before it is consumed in the DataLoader.
    """

    def __init__(
        self, dataset: torch.utils.data.Dataset, log_filename: str, log_interval: float, monitor_interval: float
    ):
        self.dataset = dataset
        self.log_filename = str(log_filename)
        self.log_interval = log_interval
        self.monitor_interval = monitor_interval
        self._csv_log = None
        self._monitoring_thread = None
        self._last_log_time = None
        # Patch __getitems__ dynamically so that batched fetches are monitored too
        if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:

            def __getitems__(index):
                self.maybe_start_resource_monitoring()
                return self.dataset.__getitems__(index)

            # Assigned on the instance; the wrapper captures `self` via the closure,
            # so callers pass only the (possibly batched) index.
            self.__getitems__ = __getitems__

    def __del__(self):
        self.stop_resource_monitoring()

    def __getitem__(self, index):
        self.maybe_start_resource_monitoring()
        return self.dataset.__getitem__(index)

    def __len__(self):
        return len(self.dataset)

    def _elapsed_log_time(self):
        if self._last_log_time is None:
            return float("inf")
        else:
            return time.perf_counter() - self._last_log_time

    def _update_log_time(self):
        self._last_log_time = time.perf_counter()

    def maybe_start_resource_monitoring(self):
        if self._monitoring_thread is None:

            def callback_fn(resource_sample):
                worker_info = torch.utils.data.get_worker_info()
                worker_id = worker_info.id
                if self._csv_log is None:
                    header = [f.name for f in resource_sample.fields()]
                    log_filename = self.log_filename.replace("%w", str(worker_id))
                    self._csv_log = CSVLogger(log_filename, header)
                row_values = resource_sample.as_tuple()
                self._csv_log.writerow(row_values)
                if self._elapsed_log_time() > self.log_interval:
                    self._csv_log.flush()
                    self._update_log_time()

            self._monitoring_thread = ResourceMonitoringThread(
                None, self.monitor_interval, stats_callback_fn=callback_fn
            )
            self._monitoring_thread.start()

    def stop_resource_monitoring(self):
        if self._monitoring_thread:
            self._monitoring_thread.stop()
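

# Usage sketch (illustrative; dataset, intervals and paths are hypothetical):
# wrap a map-style dataset so that each DataLoader worker periodically samples
# its resource usage and logs it to a per-worker CSV ("%w" -> worker id).
def _example_monitored_loader(train_dataset: torch.utils.data.Dataset):
    monitored = MonitoredDataset(
        dataset=train_dataset,
        log_filename="/tmp/resource_usage_worker_%w.csv",
        log_interval=30.0,  # flush the CSV roughly every 30 seconds
        monitor_interval=5.0,  # take a resource sample every 5 seconds
    )
    return torch.utils.data.DataLoader(monitored, batch_size=32, num_workers=4)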


class NondeterministicDataLoader(torch.utils.data.DataLoader):
    """Override the torch DataLoader to return batches out of order."""

    def __init__(self, *args, **kwargs):
        """Pass-through constructor."""
        super().__init__(*args, **kwargs)

    def _get_iterator(self):
        if self.num_workers:
            self.check_worker_number_rationality()
            return _SloppyMultiProcessingDataLoaderIter(self)
        else:
            return super()._get_iterator()
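

# Usage sketch (illustrative; dataset and loader settings are hypothetical):
# a drop-in replacement for torch.utils.data.DataLoader that, with num_workers > 0,
# yields whichever batch arrives first instead of enforcing submission order.
def _example_nondeterministic_loader(dataset: torch.utils.data.Dataset):
    loader = NondeterministicDataLoader(dataset, batch_size=16, num_workers=4)
    for batch in loader:
        pass  # batches may arrive in a different order than the sampler produced them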


class _SloppyMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
    def __init__(self, *args, **kwargs):
        """Pass-through constructor."""
        super().__init__(*args, **kwargs)

    def _next_data(self):
        """Return the next batch, allowing out-of-order returns."""
        while True:
            # If the worker responsible for `self._rcvd_idx` has already ended
            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
            # we try to advance `self._rcvd_idx` to find the next valid index.
            #
            # This part needs to run in the loop because both the `self._get_data()`
            # call and `_IterableDatasetStopIteration` check below can mark
            # extra worker(s) as dead.
            while self._rcvd_idx < self._send_idx:
                info = self._task_info[self._rcvd_idx]
                if info is None:
                    # Found a reordered tombstone: this batch was already returned early
                    del self._task_info[self._rcvd_idx]
                    self._rcvd_idx += 1
                    self._try_put_index()
                else:
                    worker_id = info[0]
                    # has data or is still active
                    if len(info) == 2 or self._workers_status[worker_id]:
                        break
                    del self._task_info[self._rcvd_idx]
                    self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                if not self._persistent_workers:
                    self._shutdown_workers()
                raise StopIteration

            # Now `self._rcvd_idx` is the batch index we want to fetch

            # Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            assert not self._shutdown and self._tasks_outstanding > 0
            idx, data = self._get_data()
            self._tasks_outstanding -= 1
            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    if self._persistent_workers:
                        self._workers_status[data.worker_id] = False
                    else:
                        self._mark_worker_as_unavailable(data.worker_id)
                    self._try_put_index()
                    continue

            if idx != self._rcvd_idx:
                # Tombstone to receive later: return this batch now and mark its slot as done
                self._task_info[idx] = None
                if isinstance(data, ExceptionWrapper):
                    data.reraise()
                return data
            else:
                del self._task_info[idx]
                return self._process_data(data)


def get_worker_info():
    """Return (num_workers, worker_id), defaulting to a single worker in the main process."""
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        num_workers = 1
        worker_id = 0
    else:
        num_workers = worker_info.num_workers
        worker_id = worker_info.id
    return num_workers, worker_id
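

# Usage sketch (illustrative; `samples` is a hypothetical iterable): shard work
# across DataLoader workers inside an IterableDataset, falling back to a single
# shard when iterating in the main process.
def _example_shard_by_worker(samples):
    num_workers, worker_id = get_worker_info()
    for i, sample in enumerate(samples):
        if i % num_workers == worker_id:
            yield sample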