# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import warnings
from logging import getLogger

import torch
# Module-level logger; getLogger() called without a name returns the root logger.
logger = getLogger()
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with samples from a truncated normal distribution.

    Samples follow ``N(mean, std**2)`` restricted to ``[a, b]``. They are
    produced by drawing uniformly between the CDF values of the two bounds and
    mapping the result through the inverse normal CDF (via ``erfinv``).

    Args:
        tensor: tensor whose contents are overwritten in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower truncation bound.
        b: upper truncation bound.

    Returns:
        The same ``tensor`` object that was passed in.

    Warns:
        UserWarning: if ``mean`` lies more than ``2 * std`` outside ``[a, b]``.
            In that regime the CDF values of both bounds are nearly equal, so
            the samples end up concentrated at one bound after clamping.
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    # Same sanity check as the upstream PyTorch implementation: do not fail,
    # but make a degenerate configuration visible instead of silently
    # returning a badly distributed sample.
    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [lower, upper], then translate to
        # [2*lower-1, 2*upper-1] (the domain expected by erfinv).
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range (erfinv can yield +/-inf
        # at the exact endpoints and rounding can overshoot slightly).
        tensor.clamp_(min=a, max=b)
        return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # type: (Tensor, float, float, float, float) -> Tensor
    """Fill ``tensor`` in place with truncated-normal samples and return it.

    Public entry point that forwards all arguments unchanged to
    ``_no_grad_trunc_normal_``.

    Args:
        tensor: tensor to overwrite in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower truncation bound.
        b: upper truncation bound.

    Returns:
        Whatever ``_no_grad_trunc_normal_`` returns (the input ``tensor``).
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
def repeat_interleave_batch(x, B, repeat):
    """Repeat each consecutive group of ``B`` entries of ``x`` ``repeat`` times.

    ``x`` is split along its first dimension into ``len(x) // B`` groups of
    ``B`` entries. Each group is emitted ``repeat`` times back to back before
    the next group starts, and everything is concatenated along dim 0.
    Entries past the last full group are not included in the result.

    Args:
        x: tensor to slice along its first dimension.
        B: number of leading-dimension entries per group.
        repeat: how many times each group is repeated.

    Returns:
        A new tensor with the repeated groups concatenated along dim 0.
    """
    num_groups = len(x) // B
    pieces = []
    for g in range(num_groups):
        group = x[g * B : (g + 1) * B]
        # Emit the same slice `repeat` times in a row.
        pieces.extend([group] * repeat)
    return torch.cat(pieces, dim=0)