# File size: 38,083 Bytes | 9d199d5
"""
Neural network modules for the HiFi-GAN: Generative Adversarial Networks for
Efficient and High Fidelity Speech Synthesis
For more details: https://arxiv.org/pdf/2010.05646.pdf, https://arxiv.org/abs/2406.10735
Authors
* Jarod Duret 2021
* Yingzhi WANG 2022
"""
# Adapted from https://github.com/jik876/hifi-gan/ and https://github.com/coqui-ai/TTS/
# MIT License
# Copyright (c) 2020 Jungil Kong
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import json
import logging
import math
import os
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# Negative slope passed to F.leaky_relu inside the residual blocks and before
# each upsampling layer of the generator.
LRELU_SLOPE = 0.1
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """Compute the symmetric padding amounts for a convolution axis.

    Arguments
    ---------
    L_in : int
        Length of the input along the axis being padded.
    stride : int
        Stride of the convolution.
    kernel_size : int
        Size of the convolution kernel.
    dilation : int
        Dilation of the convolution kernel.

    Returns
    -------
    padding : list of int
        Two-element list ``[left, right]``; both entries are always equal.
    """
    if stride > 1:
        # Strided case: half the kernel on each side, regardless of L_in.
        side = math.floor(kernel_size / 2)
    else:
        # Otherwise pad by half of the length the convolution would remove.
        receptive_span = dilation * (kernel_size - 1)
        L_out = math.floor((L_in - receptive_span - 1) / stride) + 1
        side = math.floor((L_in - L_out) / 2)
    return [side, side]
def get_padding_elem_transposed(
    L_out: int,
    L_in: int,
    stride: int,
    kernel_size: int,
    dilation: int,
    output_padding: int,
):
    """Compute the padding a transposed convolution needs to reach ``L_out``.

    Arguments
    ---------
    L_out : int
        Desired output length.
    L_in : int
        Input length.
    stride : int
        Stride of the transposed convolution.
    kernel_size : int
        Size of the convolution kernel.
    dilation : int
        Dilation of the convolution kernel.
    output_padding : int
        Extra length added to one side of the output.

    Returns
    -------
    padding : int
        Per-side padding, truncated toward zero when the exact value is
        fractional.
    """
    # Output length the transposed convolution would have with padding == 0.
    unpadded_len = (
        (L_in - 1) * stride
        + dilation * (kernel_size - 1)
        + output_padding
        + 1
    )
    # Each unit of padding trims two samples, hence the factor of one half.
    padding = -0.5 * (L_out - unpadded_len)
    return int(padding)
class Conv1d(nn.Module):
    """This class implements a 1d convolution on top of ``torch.nn.Conv1d``.

    Padding for the "same" and "causal" modes is applied explicitly with
    ``F.pad`` before the convolution is run.

    Arguments
    ---------
    out_channels : int
        It is the number of output channels.
    kernel_size : int
        Kernel size of the convolutional filters.
    input_shape : tuple
        The shape of the input. Alternatively use ``in_channels``.
    in_channels : int
        The number of input channels. Alternatively use ``input_shape``.
    stride : int
        Stride factor of the convolutional filters. When the stride factor > 1,
        a decimation in time is performed.
    dilation : int
        Dilation factor of the convolutional filters.
    padding : str
        (same, valid, causal). If "valid", no padding is performed.
        If "same" and stride is 1, output shape is the same as the input shape.
        "causal" results in causal (dilated) convolutions.
    groups : int
        Number of blocked connections from input channels to output channels.
    bias : bool
        Whether to add a bias term to convolution operation.
    padding_mode : str
        Padding mode handed to ``F.pad`` when ``padding`` is "same". See
        torch.nn documentation for more information.
    skip_transpose : bool
        If False, uses batch x time x channel convention of speechbrain.
        If True, uses batch x channel x time convention.
    weight_norm : bool
        If True, use weight normalization,
        to be removed with self.remove_weight_norm() at inference
    conv_init : str
        Weight initialization for the convolution network
        ("kaiming", "zero" or "normal"; any other value keeps the pytorch
        default initialization).
    default_padding: str or int
        This sets the default padding mode that will be used by the pytorch Conv1d backend.

    Example
    -------
    >>> inp_tensor = torch.rand([10, 40, 16])
    >>> cnn_1d = Conv1d(
    ...     input_shape=inp_tensor.shape, out_channels=8, kernel_size=5
    ... )
    >>> out_tensor = cnn_1d(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 40, 8])
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        input_shape=None,
        in_channels=None,
        stride=1,
        dilation=1,
        padding="same",
        groups=1,
        bias=True,
        padding_mode="reflect",
        skip_transpose=False,
        weight_norm=False,
        conv_init=None,
        default_padding=0,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode
        self.unsqueeze = False
        self.skip_transpose = skip_transpose

        if input_shape is None and in_channels is None:
            raise ValueError("Must provide one of input_shape or in_channels")

        if in_channels is None:
            # May also flip self.unsqueeze for 2d inputs.
            in_channels = self._check_input_shape(input_shape)

        self.in_channels = in_channels

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=default_padding,
            groups=groups,
            bias=bias,
        )

        if conv_init == "kaiming":
            nn.init.kaiming_normal_(self.conv.weight)
        elif conv_init == "zero":
            nn.init.zeros_(self.conv.weight)
        elif conv_init == "normal":
            nn.init.normal_(self.conv.weight, std=1e-6)

        if weight_norm:
            self.conv = nn.utils.weight_norm(self.conv)

    def forward(self, x):
        """Returns the output of the convolution.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channel)
            input to convolve. 2d or 3d tensors are expected. With
            ``skip_transpose=True`` the layout is (batch, channel, time).

        Returns
        -------
        wx : torch.Tensor
            The convolved outputs.

        Raises
        ------
        ValueError
            If ``self.padding`` is not one of "same", "valid" or "causal".
        """
        if not self.skip_transpose:
            x = x.transpose(1, -1)

        if self.unsqueeze:
            x = x.unsqueeze(1)

        if self.padding == "same":
            x = self._manage_padding(
                x, self.kernel_size, self.dilation, self.stride
            )
        elif self.padding == "causal":
            # Left-pad only, so an output frame never sees future frames.
            num_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (num_pad, 0))
        elif self.padding == "valid":
            pass
        else:
            # str() keeps this a ValueError even when padding is not a string.
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got "
                + str(self.padding)
            )

        wx = self.conv(x)

        if self.unsqueeze:
            wx = wx.squeeze(1)

        if not self.skip_transpose:
            wx = wx.transpose(1, -1)

        return wx

    def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
        """Pads the last (time) axis of ``x`` using ``self.padding_mode``
        such that, for stride 1, its length is unchanged after the convolution.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        kernel_size : int
            Size of kernel.
        dilation : int
            Dilation used.
        stride : int
            Stride.

        Returns
        -------
        x : torch.Tensor
            The padded outputs.
        """
        # Length of the axis being padded. The padding amount computed by
        # get_padding_elem does not actually depend on this value, but the
        # time length is what the helper's L_in argument stands for.
        L_in = x.shape[-1]

        # Time padding
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)

        # Applying padding
        x = F.pad(x, padding, mode=self.padding_mode)

        return x

    def _check_input_shape(self, shape):
        """Checks the input shape and returns the number of input channels.

        A 2d shape is treated as single-channel and enables the
        unsqueeze/squeeze handling in ``forward``.
        """
        if len(shape) == 2:
            self.unsqueeze = True
            in_channels = 1
        elif self.skip_transpose:
            in_channels = shape[1]
        elif len(shape) == 3:
            in_channels = shape[2]
        else:
            raise ValueError(
                "conv1d expects 2d, 3d inputs. Got " + str(len(shape))
            )

        # Kernel size must be odd
        if not self.padding == "valid" and self.kernel_size % 2 == 0:
            raise ValueError(
                "The field kernel size must be an odd number. Got %s."
                % (self.kernel_size)
            )

        return in_channels

    def remove_weight_norm(self):
        """Removes weight normalization at inference if used during training."""
        self.conv = nn.utils.remove_weight_norm(self.conv)
class Conv2d(nn.Module):
    """This class implements a 2d convolution on top of ``torch.nn.Conv2d``.

    Arguments
    ---------
    out_channels : int
        It is the number of output channels.
    kernel_size : tuple
        Kernel size of the 2d convolutional filters over time and frequency
        axis. An int is expanded to a pair.
    input_shape : tuple
        The shape of the input. Alternatively use ``in_channels``.
    in_channels : int
        The number of input channels. Alternatively use ``input_shape``.
    stride: int
        Stride factor of the 2d convolutional filters over time and frequency
        axis. An int is expanded to a pair.
    dilation : int
        Dilation factor of the 2d convolutional filters over time and
        frequency axis. An int is expanded to a pair.
    padding : str
        (same, valid, causal).
        If "valid", no padding is performed.
        If "same" and stride is 1, output shape is same as input shape.
        If "causal" then proper padding is inserted to simulate causal convolution on the first spatial dimension.
        (spatial dim 1 is dim 3 for both skip_transpose=False and skip_transpose=True)
    groups : int
        This option specifies the convolutional groups. See torch.nn
        documentation for more information.
    bias : bool
        If True, the additive bias b is adopted.
    padding_mode : str
        Padding mode handed to ``F.pad`` when ``padding`` is "same". See
        torch.nn documentation for more information.
    max_norm : float
        kernel max-norm.
    swap : bool
        If True, the convolution is done with the format (B, C, W, H).
        If False, the convolution is done with (B, H, W, C).
        Active only if skip_transpose is False.
    skip_transpose : bool
        If False, uses batch x spatial.dim2 x spatial.dim1 x channel convention of speechbrain.
        If True, uses batch x channel x spatial.dim1 x spatial.dim2 convention.
    weight_norm : bool
        If True, use weight normalization,
        to be removed with self.remove_weight_norm() at inference
    conv_init : str
        Weight initialization for the convolution network
        ("kaiming" or "zero"; any other value keeps the pytorch default).

    Example
    -------
    >>> inp_tensor = torch.rand([10, 40, 16, 8])
    >>> cnn_2d = Conv2d(
    ...     input_shape=inp_tensor.shape, out_channels=5, kernel_size=(7, 3)
    ... )
    >>> out_tensor = cnn_2d(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 40, 16, 5])
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        input_shape=None,
        in_channels=None,
        stride=(1, 1),
        dilation=(1, 1),
        padding="same",
        groups=1,
        bias=True,
        padding_mode="reflect",
        max_norm=None,
        swap=False,
        skip_transpose=False,
        weight_norm=False,
        conv_init=None,
    ):
        super().__init__()

        # handle the case if some parameter is int
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride)
        if isinstance(dilation, int):
            dilation = (dilation, dilation)

        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode
        self.unsqueeze = False
        self.max_norm = max_norm
        self.swap = swap
        self.skip_transpose = skip_transpose

        if input_shape is None and in_channels is None:
            raise ValueError("Must provide one of input_shape or in_channels")

        if in_channels is None:
            # May also flip self.unsqueeze for 3d inputs.
            in_channels = self._check_input(input_shape)

        self.in_channels = in_channels

        # Weights are initialized following pytorch approach
        self.conv = nn.Conv2d(
            self.in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            padding=0,
            dilation=self.dilation,
            groups=groups,
            bias=bias,
        )

        if conv_init == "kaiming":
            nn.init.kaiming_normal_(self.conv.weight)
        elif conv_init == "zero":
            nn.init.zeros_(self.conv.weight)

        if weight_norm:
            self.conv = nn.utils.weight_norm(self.conv)

    def forward(self, x):
        """Returns the output of the convolution.

        Arguments
        ---------
        x : torch.Tensor
            input to convolve. 3d or 4d tensors are expected; the axis layout
            follows ``skip_transpose`` / ``swap`` as described on the class.

        Returns
        -------
        x : torch.Tensor
            The output of the convolution.

        Raises
        ------
        ValueError
            If ``self.padding`` is not one of "same", "valid" or "causal".
        """
        if not self.skip_transpose:
            x = x.transpose(1, -1)
            if self.swap:
                x = x.transpose(-1, -2)

        if self.unsqueeze:
            x = x.unsqueeze(1)

        if self.padding == "same":
            x = self._manage_padding(
                x, self.kernel_size, self.dilation, self.stride
            )
        elif self.padding == "causal":
            # The pad below extends the second-to-last axis, which the
            # convolution covers with kernel_size[0] and dilation[0]; both
            # indices must therefore refer to that same axis.
            num_pad = (self.kernel_size[0] - 1) * self.dilation[0]
            x = F.pad(x, (0, 0, num_pad, 0))
        elif self.padding == "valid":
            pass
        else:
            # str() keeps this a ValueError even when padding is not a string.
            raise ValueError(
                "Padding must be 'same','valid' or 'causal'. Got "
                + str(self.padding)
            )

        if self.max_norm is not None:
            # Re-project the kernels onto the max-norm ball before each use.
            self.conv.weight.data = torch.renorm(
                self.conv.weight.data, p=2, dim=0, maxnorm=self.max_norm
            )

        wx = self.conv(x)

        if self.unsqueeze:
            wx = wx.squeeze(1)

        if not self.skip_transpose:
            wx = wx.transpose(1, -1)
            if self.swap:
                wx = wx.transpose(1, 2)

        return wx

    def _manage_padding(
        self,
        x,
        kernel_size: Tuple[int, int],
        dilation: Tuple[int, int],
        stride: Tuple[int, int],
    ):
        """Pads the last two axes of ``x`` using ``self.padding_mode``
        such that, for stride 1, their lengths are unchanged after the
        convolution.

        Arguments
        ---------
        x : torch.Tensor
            Input to be padded
        kernel_size : tuple of int
            Size of the kernel for computing padding
        dilation : tuple of int
            Dilation rate for computing padding
        stride: tuple of int
            Stride for computing padding

        Returns
        -------
        x : torch.Tensor
            The padded outputs.
        """
        # F.pad consumes pad amounts starting from the last axis, so the
        # entries for the last axis come first. The amounts computed by
        # get_padding_elem do not depend on the length argument; the real
        # axis lengths are passed because that is what L_in stands for.
        padding_time = get_padding_elem(
            x.shape[-1], stride[-1], kernel_size[-1], dilation[-1]
        )
        padding_freq = get_padding_elem(
            x.shape[-2], stride[-2], kernel_size[-2], dilation[-2]
        )
        padding = padding_time + padding_freq

        # Applying padding
        x = F.pad(x, padding, mode=self.padding_mode)

        return x

    def _check_input(self, shape):
        """Checks the input shape and returns the number of input channels.

        A 3d shape is treated as single-channel and enables the
        unsqueeze/squeeze handling in ``forward``.
        """
        if len(shape) == 3:
            self.unsqueeze = True
            in_channels = 1
        elif len(shape) == 4:
            # Channels come second in the skip_transpose layout and last
            # otherwise, as described in the class docstring.
            in_channels = shape[1] if self.skip_transpose else shape[3]
        else:
            raise ValueError(
                "Expected 3d or 4d inputs. Got " + str(len(shape))
            )

        # Kernel size must be odd
        if not self.padding == "valid" and (
            self.kernel_size[0] % 2 == 0 or self.kernel_size[1] % 2 == 0
        ):
            # kernel_size is a tuple: wrap it so "%" formats it as one value.
            raise ValueError(
                "The field kernel size must be an odd number. Got %s."
                % (self.kernel_size,)
            )

        return in_channels

    def remove_weight_norm(self):
        """Removes weight normalization at inference if used during training."""
        self.conv = nn.utils.remove_weight_norm(self.conv)
class ConvTranspose1d(nn.Module):
    """This class implements 1d transposed convolution with speechbrain.
    Transpose convolution is normally used to perform upsampling.

    Arguments
    ---------
    out_channels : int
        It is the number of output channels.
    kernel_size : int
        Kernel size of the convolutional filters.
    input_shape : tuple
        The shape of the input. Alternatively use ``in_channels``.
        Required when ``padding`` is "same" or "factor", because the padding
        amount is derived from the input length.
    in_channels : int
        The number of input channels. Alternatively use ``input_shape``.
    stride : int
        Stride factor of the convolutional filters. When the stride factor > 1,
        upsampling in time is performed.
    dilation : int
        Dilation factor of the convolutional filters.
    padding : str or int
        To have in output the target dimension, we suggest tuning the kernel
        size and the padding properly. We also support the following function
        to have some control over the padding and the corresponding output
        dimensionality.
        if "valid", no padding is applied
        if "same", padding amount is inferred so that the output size is closest
        to possible to input size. Note that for some kernel_size / stride combinations
        it is not possible to obtain the exact same size, but we return the closest
        possible size.
        if "factor", padding amount is inferred so that the output size is closest
        to inputsize*stride. Note that for some kernel_size / stride combinations
        it is not possible to obtain the exact size, but we return the closest
        possible size.
        if an integer value is entered, a custom padding is used.
    output_padding : int,
        Additional size added to one side of the output shape
    groups: int
        Number of blocked connections from input channels to output channels.
        Default: 1
    bias: bool
        If True, adds a learnable bias to the output
    skip_transpose : bool
        If False, uses batch x time x channel convention of speechbrain.
        If True, uses batch x channel x time convention.
    weight_norm : bool
        If True, use weight normalization,
        to be removed with self.remove_weight_norm() at inference

    Example
    -------
    >>> inp_tensor = torch.rand([10, 12, 40]) #[batch, time, fea]
    >>> convtranspose_1d = ConvTranspose1d(
    ...     input_shape=inp_tensor.shape, out_channels=8, kernel_size=3, stride=2
    ... )
    >>> out_tensor = convtranspose_1d(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 25, 8])

    >>> # Combination of Conv1d and ConvTranspose1d
    >>> signal = torch.rand([1,100]) #[batch, time]
    >>> conv1d = Conv1d(input_shape=signal.shape, out_channels=1, kernel_size=3, stride=2)
    >>> conv_out = conv1d(signal)
    >>> conv_t = ConvTranspose1d(input_shape=conv_out.shape, out_channels=1, kernel_size=3, stride=2, padding=1)
    >>> signal_rec = conv_t(conv_out, output_size=[100])
    >>> signal_rec.shape
    torch.Size([1, 100])

    >>> signal = torch.rand([1,115]) #[batch, time]
    >>> conv_t = ConvTranspose1d(input_shape=signal.shape, out_channels=1, kernel_size=3, stride=2, padding='same')
    >>> signal_rec = conv_t(signal)
    >>> signal_rec.shape
    torch.Size([1, 115])

    >>> signal = torch.rand([1,115]) #[batch, time]
    >>> conv_t = ConvTranspose1d(input_shape=signal.shape, out_channels=1, kernel_size=7, stride=2, padding='valid')
    >>> signal_rec = conv_t(signal)
    >>> signal_rec.shape
    torch.Size([1, 235])

    >>> signal = torch.rand([1,115]) #[batch, time]
    >>> conv_t = ConvTranspose1d(input_shape=signal.shape, out_channels=1, kernel_size=7, stride=2, padding='factor')
    >>> signal_rec = conv_t(signal)
    >>> signal_rec.shape
    torch.Size([1, 231])

    >>> signal = torch.rand([1,115]) #[batch, time]
    >>> conv_t = ConvTranspose1d(input_shape=signal.shape, out_channels=1, kernel_size=3, stride=2, padding=10)
    >>> signal_rec = conv_t(signal)
    >>> signal_rec.shape
    torch.Size([1, 211])
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        input_shape=None,
        in_channels=None,
        stride=1,
        dilation=1,
        padding=0,
        output_padding=0,
        groups=1,
        bias=True,
        skip_transpose=False,
        weight_norm=False,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.unsqueeze = False
        self.skip_transpose = skip_transpose

        if input_shape is None and in_channels is None:
            raise ValueError("Must provide one of input_shape or in_channels")

        if in_channels is None:
            # May also flip self.unsqueeze for 2d inputs.
            in_channels = self._check_input_shape(input_shape)

        if self.padding == "same":
            L_in = self._time_length(input_shape)
            padding_value = get_padding_elem_transposed(
                L_in,
                L_in,
                stride=stride,
                kernel_size=kernel_size,
                dilation=dilation,
                output_padding=output_padding,
            )
        elif self.padding == "factor":
            L_in = self._time_length(input_shape)
            padding_value = get_padding_elem_transposed(
                L_in * stride,
                L_in,
                stride=stride,
                kernel_size=kernel_size,
                dilation=dilation,
                output_padding=output_padding,
            )
        elif self.padding == "valid":
            padding_value = 0
        elif isinstance(self.padding, int):
            padding_value = padding
        else:
            raise ValueError("Not supported padding type")

        # output_padding is forwarded so the layer really produces the extra
        # samples that the padding computation above already accounts for.
        self.conv = nn.ConvTranspose1d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=padding_value,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
        )

        if weight_norm:
            self.conv = nn.utils.weight_norm(self.conv)

    def _time_length(self, input_shape):
        """Returns the time length stored in ``input_shape``.

        Raises
        ------
        ValueError
            If ``input_shape`` is None; the "same" and "factor" padding modes
            cannot be resolved from ``in_channels`` alone.
        """
        if input_shape is None:
            raise ValueError(
                "input_shape is required when padding is 'same' or 'factor'"
            )
        return input_shape[-1] if self.skip_transpose else input_shape[1]

    def forward(self, x, output_size=None):
        """Returns the output of the convolution.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channel)
            input to convolve. 2d or 3d tensors are expected. With
            ``skip_transpose=True`` the layout is (batch, channel, time).
        output_size : list of int, optional
            The requested size of the output, passed through to
            ``torch.nn.ConvTranspose1d``.

        Returns
        -------
        x : torch.Tensor
            The convolved output
        """
        if not self.skip_transpose:
            x = x.transpose(1, -1)

        if self.unsqueeze:
            x = x.unsqueeze(1)

        wx = self.conv(x, output_size=output_size)

        if self.unsqueeze:
            wx = wx.squeeze(1)

        if not self.skip_transpose:
            wx = wx.transpose(1, -1)

        return wx

    def _check_input_shape(self, shape):
        """Checks the input shape and returns the number of input channels.

        A 2d shape is treated as single-channel and enables the
        unsqueeze/squeeze handling in ``forward``.
        """
        if len(shape) == 2:
            self.unsqueeze = True
            in_channels = 1
        elif self.skip_transpose:
            in_channels = shape[1]
        elif len(shape) == 3:
            in_channels = shape[2]
        else:
            raise ValueError(
                "conv1d expects 2d, 3d inputs. Got " + str(len(shape))
            )

        return in_channels

    def remove_weight_norm(self):
        """Removes weight normalization at inference if used during training."""
        self.conv = nn.utils.remove_weight_norm(self.conv)
class ResBlock1(torch.nn.Module):
    """
    Residual Block Type 1: three residual stages, each made of a dilated
    convolution followed by a non-dilated one.

    Arguments
    ---------
    channels : int
        number of hidden channels for the convolutional layers.
    kernel_size : int
        size of the convolution filter in each layer.
    dilation : list
        dilation value of the first convolution in each of the three stages.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        def _make_conv(rate):
            # Every layer keeps the channel count and the time length.
            return Conv1d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=kernel_size,
                stride=1,
                dilation=rate,
                padding="same",
                skip_transpose=True,
                weight_norm=True,
            )

        # First convolution of each stage: one dilation rate per stage.
        self.convs1 = nn.ModuleList(
            [_make_conv(dilation[stage]) for stage in range(3)]
        )
        # Second convolution of each stage: always dilation 1.
        self.convs2 = nn.ModuleList([_make_conv(1) for _ in range(3)])

    def forward(self, x):
        """Returns the output of ResBlock1

        Arguments
        ---------
        x : torch.Tensor (batch, channel, time)
            input tensor.

        Returns
        -------
        The ResBlock outputs
        """
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            hidden = dilated_conv(F.leaky_relu(x, LRELU_SLOPE))
            hidden = plain_conv(F.leaky_relu(hidden, LRELU_SLOPE))
            # Skip connection around the pair of convolutions.
            x = hidden + x
        return x

    def remove_weight_norm(self):
        """This functions removes weight normalization during inference."""
        for layer in [*self.convs1, *self.convs2]:
            layer.remove_weight_norm()
class ResBlock2(torch.nn.Module):
    """
    Residual Block Type 2: two residual stages, each made of a single dilated
    convolution.

    Arguments
    ---------
    channels : int
        number of hidden channels for the convolutional layers.
    kernel_size : int
        size of the convolution filter in each layer.
    dilation : list
        dilation value of the convolution in each of the two stages.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()
        stages = []
        for stage in range(2):
            stages.append(
                Conv1d(
                    in_channels=channels,
                    out_channels=channels,
                    kernel_size=kernel_size,
                    stride=1,
                    dilation=dilation[stage],
                    padding="same",
                    skip_transpose=True,
                    weight_norm=True,
                )
            )
        self.convs = nn.ModuleList(stages)

    def forward(self, x):
        """Returns the output of ResBlock2

        Arguments
        ---------
        x : torch.Tensor (batch, channel, time)
            input tensor.

        Returns
        -------
        The ResBlock outputs
        """
        for conv in self.convs:
            hidden = conv(F.leaky_relu(x, LRELU_SLOPE))
            # Skip connection around the convolution.
            x = hidden + x
        return x

    def remove_weight_norm(self):
        """This functions removes weight normalization during inference."""
        for conv in self.convs:
            conv.remove_weight_norm()
class HiFiGANArabicGenerator(torch.nn.Module):
    """HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)

    Arguments
    ---------
    in_channels : int
        number of input tensor channels.
    out_channels : int
        number of output tensor channels.
        NOTE(review): currently unused; the final convolution is built with a
        fixed single output channel.
    resblock_type : str
        type of the `ResBlock`. '1' selects `ResBlock1`; any other value
        selects `ResBlock2`.
    resblock_dilation_sizes : List[List[int]]
        list of dilation values in each layer of a `ResBlock`.
    resblock_kernel_sizes : List[int]
        list of kernel sizes for each `ResBlock`.
    upsample_kernel_sizes : List[int]
        list of kernel sizes for each transposed convolution.
    upsample_initial_channel : int
        number of channels for the first upsampling layer. This is divided by 2
        for each consecutive upsampling layer.
    upsample_factors : List[int]
        upsampling factors (stride) for each upsampling layer.
    inference_padding : int
        number of frames added on each side of the input (replicating the
        edge frames) by ``inference``. Defaults to 5.
    cond_channels : int
        If greater than 0, adds a conditioning conv layer whose output is
        summed with the output of the first convolution in ``forward``.
    conv_post_bias : bool
        Whether to add a bias term to the final conv.

    Example
    -------
    >>> inp_tensor = torch.rand([4, 80, 33])
    >>> hifigan_generator= HiFiGANArabicGenerator(
    ...    in_channels = 80,
    ...    out_channels = 1,
    ...    resblock_type = "1",
    ...    resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    ...    resblock_kernel_sizes = [3, 7, 11],
    ...    upsample_kernel_sizes = [16, 16, 4, 4],
    ...    upsample_initial_channel = 512,
    ...    upsample_factors = [8, 8, 2, 2],
    ... )
    >>> out_tensor = hifigan_generator(inp_tensor)
    >>> out_tensor.shape
    torch.Size([4, 1, 8448])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        resblock_type,
        resblock_dilation_sizes,
        resblock_kernel_sizes,
        upsample_kernel_sizes,
        upsample_initial_channel,
        upsample_factors,
        inference_padding=5,
        cond_channels=0,
        conv_post_bias=True,
    ):
        super().__init__()
        self.inference_padding = inference_padding
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_factors)

        # initial convolution, lifting the features to the widest channel count
        self.conv_pre = Conv1d(
            in_channels=in_channels,
            out_channels=upsample_initial_channel,
            kernel_size=7,
            stride=1,
            padding="same",
            skip_transpose=True,
            weight_norm=True,
        )
        resblock = ResBlock1 if resblock_type == "1" else ResBlock2

        # upsampling layers: each one halves the channel count
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(
            zip(upsample_factors, upsample_kernel_sizes)
        ):
            self.ups.append(
                ConvTranspose1d(
                    in_channels=upsample_initial_channel // (2**i),
                    out_channels=upsample_initial_channel // (2 ** (i + 1)),
                    kernel_size=k,
                    stride=u,
                    padding=(k - u) // 2,
                    skip_transpose=True,
                    weight_norm=True,
                )
            )

        # MRF blocks: num_kernels residual blocks per upsampling layer, laid
        # out flat so block j of layer i sits at index i * num_kernels + j.
        self.resblocks = nn.ModuleList()
        # Fallback so conv_post still gets a channel count when there are no
        # upsampling layers at all.
        ch = upsample_initial_channel
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock(ch, k, d))

        # post convolution layer
        self.conv_post = Conv1d(
            in_channels=ch,
            out_channels=1,
            kernel_size=7,
            stride=1,
            padding="same",
            skip_transpose=True,
            bias=conv_post_bias,
            weight_norm=True,
        )

        if cond_channels > 0:
            # NOTE(review): built with Conv1d's default skip_transpose=False,
            # i.e. the (batch, time, channel) layout, unlike every other layer
            # here - confirm this matches the conditioning tensor layout.
            self.cond_layer = Conv1d(
                in_channels=cond_channels,
                out_channels=upsample_initial_channel,
                kernel_size=1,
            )

    def forward(self, x, g=None):
        """Runs the generator on a batch of features.

        Arguments
        ---------
        x : torch.Tensor (batch, channel, time)
            feature input tensor.
        g : torch.Tensor (batch, 1, time)
            global conditioning input tensor. Required when the model was
            built with ``cond_channels > 0``; ignored otherwise.

        Returns
        -------
        The generator outputs

        Raises
        ------
        ValueError
            If the model has a conditioning layer but ``g`` is None.
        """
        o = self.conv_pre(x)
        if hasattr(self, "cond_layer"):
            if g is None:
                raise ValueError(
                    "g must be provided when the generator is built with "
                    "cond_channels > 0"
                )
            o = o + self.cond_layer(g)

        for i in range(self.num_upsamples):
            o = F.leaky_relu(o, LRELU_SLOPE)
            o = self.ups[i](o)
            # Multi-receptive-field fusion: average the outputs of all the
            # residual blocks attached to this upsampling layer.
            z_sum = None
            for j in range(self.num_kernels):
                if z_sum is None:
                    z_sum = self.resblocks[i * self.num_kernels + j](o)
                else:
                    z_sum += self.resblocks[i * self.num_kernels + j](o)
            o = z_sum / self.num_kernels

        # This activation uses F.leaky_relu's default slope, not LRELU_SLOPE.
        o = F.leaky_relu(o)
        o = self.conv_post(o)
        o = torch.tanh(o)
        return o

    def remove_weight_norm(self):
        """This functions removes weight normalization during inference."""
        for layer in self.ups:
            layer.remove_weight_norm()
        for layer in self.resblocks:
            layer.remove_weight_norm()
        self.conv_pre.remove_weight_norm()
        self.conv_post.remove_weight_norm()

    @torch.no_grad()
    def inference(self, c, padding=True):
        """The inference function performs a padding and runs the forward method.

        Arguments
        ---------
        c : torch.Tensor (batch, channel, time)
            feature input tensor.
        padding : bool
            Whether to pad tensor before forward.

        Returns
        -------
        The generator outputs
        """
        if padding:
            c = torch.nn.functional.pad(
                c, (self.inference_padding, self.inference_padding), "replicate"
            )
        return self.forward(c)

    @classmethod
    def from_pretrained(cls, checkpoint_path, config_path=None, device='cpu'):
        """Builds a generator from a JSON config and loads its weights.

        Arguments
        ---------
        checkpoint_path : str
            Path to a checkpoint holding the model state dict.
        config_path : str, optional
            Path to a JSON file whose keys are the constructor arguments.
            Defaults to ``config.json`` next to this module.
        device : str or torch.device
            Device the model is moved to after loading.

        Returns
        -------
        The generator in eval mode on ``device``.
        """
        if config_path is None:
            config_path = os.path.join(os.path.dirname(__file__), "config.json")
        with open(config_path, "r", encoding="utf-8") as file:
            config = json.load(file)

        model = cls(**config)
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(ckpt)
        return model.eval().to(device)
if __name__ == '__main__':
    # Smoke test: expects "generator.ckpt" and "config.json" in the current
    # working directory.
    gen = HiFiGANArabicGenerator.from_pretrained("generator.ckpt", "config.json")
    # Random stand-in for a single 80-channel, 122-frame feature input.
    x = torch.rand(1, 80, 122)
    wav = gen(x)