|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import torch |
|
|
|
|
|
def dice(x, y):
    """Return the Dice overlap coefficient ``2*|x*y| / (|x| + |y|)``.

    Every input is reduced over all of its elements with a single ``np.sum``,
    so arrays of any rank work as long as ``x * y`` is a valid product.

    Args:
        x: prediction array (presumably a binary mask -- not enforced here).
        y: reference array with a shape compatible with ``x``.

    Returns:
        The Dice score. ``0.0`` is returned whenever ``y`` sums to zero
        (an empty reference), regardless of ``x``.
    """
    # Computed first so that incompatible shapes still raise, even for an empty y.
    intersect = np.sum(x * y)
    y_sum = np.sum(y)
    if y_sum == 0:
        return 0.0
    x_sum = np.sum(x)
    return 2 * intersect / (x_sum + y_sum)
|
|
|
|
|
class AverageMeter:
    """Track the latest value together with a running, count-weighted average.

    ``val`` and ``n`` passed to :meth:`update` may be scalars or NumPy arrays.
    With arrays, sums, counts and averages are kept element-wise. Wherever the
    accumulated count is not positive, ``avg`` falls back to the accumulated
    sum instead of dividing by zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0  # most recent value passed to update()
        self.avg = 0  # running weighted mean (sum / count where count > 0)
        self.sum = 0  # running sum of val * n
        self.count = 0  # running sum of the weights n

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: latest value (scalar or array).
            n: weight / number of samples ``val`` stands for; may be an array
                broadcastable against ``val`` to weight elements individually.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # np.where evaluates both branches before selecting, so the division
        # must never raise: a plain `/` on Python scalars throws
        # ZeroDivisionError when count == 0. Dividing through NumPy with the
        # warnings silenced yields nan/inf there, which np.where then discards.
        with np.errstate(divide="ignore", invalid="ignore"):
            mean = np.true_divide(self.sum, self.count)
        self.avg = np.where(self.count > 0, mean, self.sum)
|
|
|
|
|
def distributed_all_gather( |
|
tensor_list, valid_batch_size=None, out_numpy=False, world_size=None, no_barrier=False, is_valid=None |
|
): |
|
if world_size is None: |
|
world_size = torch.distributed.get_world_size() |
|
if valid_batch_size is not None: |
|
valid_batch_size = min(valid_batch_size, world_size) |
|
elif is_valid is not None: |
|
is_valid = torch.tensor(bool(is_valid), dtype=torch.bool, device=tensor_list[0].device) |
|
if not no_barrier: |
|
torch.distributed.barrier() |
|
tensor_list_out = [] |
|
with torch.no_grad(): |
|
if is_valid is not None: |
|
is_valid_list = [torch.zeros_like(is_valid) for _ in range(world_size)] |
|
torch.distributed.all_gather(is_valid_list, is_valid) |
|
is_valid = [x.item() for x in is_valid_list] |
|
for tensor in tensor_list: |
|
gather_list = [torch.zeros_like(tensor) for _ in range(world_size)] |
|
torch.distributed.all_gather(gather_list, tensor) |
|
if valid_batch_size is not None: |
|
gather_list = gather_list[:valid_batch_size] |
|
elif is_valid is not None: |
|
gather_list = [g for g, v in zip(gather_list, is_valid_list) if v] |
|
if out_numpy: |
|
gather_list = [t.cpu().numpy() for t in gather_list] |
|
tensor_list_out.append(gather_list) |
|
return tensor_list_out |
|
|