"""
Gradient-based Subword Tokenization (GBST) layer implementation.

Based on the lucidrains/charformer-pytorch implementation,
which is distributed under the MIT License.

Original code location:
https://github.com/lucidrains/charformer-pytorch/charformer_pytorch.py

Copyright (c) 2023~, ETRI LIRS. Jong-hun Shin.
"""

import math
import functools

import torch
import torch.nn.functional as F

from typing import Optional

from torch import einsum, nn, Tensor
from transformers.utils import logging
from einops.layers.torch import Rearrange
from einops import rearrange, repeat


logger = logging.get_logger(__name__)


_BLOCKS = (
    (1, 0), (2, 0), (3, 0), (4, 0),
    (6, 0), (9, 0),
)


@torch.jit.script
def pad_to_multiple(in_tensor: Tensor, multiple: int, seq_dim: int,
                    dim: int, value: Optional[float]):
    """Pad `in_tensor` along `seq_dim` so that its length becomes a multiple of `multiple`."""
    seqlen = in_tensor.shape[seq_dim]
    padded_len = math.ceil(seqlen / multiple) * multiple
    if seqlen == padded_len:
        return in_tensor
    # F.pad() expects (left, right) pairs starting from the last dimension;
    # pad_offset supplies zero pairs for the dimensions that follow `dim`.
    pad_offset = (0,) * (-1 - dim) * 2
    if len(pad_offset) == 0:
        return F.pad(in_tensor, (0, padded_len - seqlen), value=value)

    d1, d2 = pad_offset
    return F.pad(in_tensor, (d1, d2, 0, padded_len - seqlen), value=value)
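

# Illustrative shape check for pad_to_multiple() (a sketch, not part of the
# original code; the sizes are assumptions): padding a (batch, seq, dim)
# embedding tensor along its sequence axis (seq_dim=1, dim=-2) rounds the
# sequence length up to the nearest multiple, e.g.:
#
#   x = torch.randn(2, 5, 8)
#   y = pad_to_multiple(x, multiple=4, seq_dim=1, dim=-2, value=0.0)
#   assert y.shape == (2, 8, 8)   # 5 -> ceil(5 / 4) * 4 = 8, zero-padded on the right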


class Depthwise1dConv(nn.Module):
    """Depthwise 1-D convolution followed by a pointwise (1x1) projection."""
    def __init__(self, in_dim, out_dim, krnl_size, use_bn=False):
        super().__init__()
        self.use_bn = use_bn
        # groups=in_dim makes the first convolution depthwise.
        self.convol = nn.Conv1d(in_dim, out_dim, krnl_size, groups=in_dim)

        if self.use_bn:
            self.bn = nn.BatchNorm1d(out_dim, eps=1e-05)
        self.proj = nn.Conv1d(out_dim, out_dim, 1)

    @torch.cuda.amp.autocast(enabled=False, dtype=torch.float32)
    def forward(self, in_tensor):
        in_tensor = self.convol(in_tensor)
        if self.use_bn:
            in_tensor = self.bn(in_tensor)
        return self.proj(in_tensor)

    def _init_weights(self, factor: float = 0.05):
        logger.debug(f"1dConv-Weight initialize called, before: {self.convol.weight.data}")
        self.convol.weight.data.normal_(mean=0.0, std=factor * 1.0)
        self.proj.weight.data.normal_(mean=0.0, std=factor * 1.0)
        logger.debug(f"1dConv-Weight initialize called, after: {self.convol.weight.data}")


class Padding(nn.Module):
    """Thin wrapper around F.pad() so it can be used inside nn.Sequential."""
    def __init__(self, padding, value=0):
        super().__init__()
        self.padding = padding
        self.value = value

    def forward(self, in_tensor):
        return F.pad(in_tensor, self.padding, value=self.value)


class GBSWT(nn.Module):
    """ Gradient-based Sub-Word Tokenizer implementation. """
    def __init__(self, embed_tokens,
                 max_block_size=None,
                 blocks=_BLOCKS,
                 downsample_factor=1,
                 score_consensus_attn=True,
                 use_bn=False,):
        super().__init__()
        num_tokens, dim = embed_tokens.weight.shape

        assert (max_block_size is not None) ^ (blocks is not None), \
            'either max_block_size or blocks must be given, but not both.'
        if blocks is None:
            # enumerate every block size from 1 to max_block_size, with offset 0.
            self.blocks = tuple(map(lambda elem: (elem, 0), range(1, max_block_size+1)))
        else:
            if not isinstance(blocks, tuple):
                raise ValueError('blocks must be assigned as a tuple')
            # canonicalize each entry into a (block_size, offset) pair.
            self.blocks = tuple(map(lambda elem: elem if isinstance(elem, tuple) else (elem, 0), blocks))
            if not all([(offset < block_size) for block_size, offset in self.blocks]):
                raise ValueError('Offset must be smaller than given block size.')
            max_block_size = max(list(map(lambda x: x[0], self.blocks)))

        assert downsample_factor <= max_block_size, \
            'downsample factor must be less than or equal to the max_block_size.'

        self.downsample_factor = downsample_factor
        self.score_consensus_attn = score_consensus_attn
        self.use_bn = use_bn
        logger.debug(f"GBSWT Subword Block Combinations: {self.blocks}")
        logger.debug(f"GBSWT Downsampling factor: {self.downsample_factor}, use BatchNorm: {self.use_bn}")

        def lcm(*num):
            return int(functools.reduce(lambda x, y: int((x * y) / math.gcd(x, y)), num, 1))

        # every candidate block size must evenly divide the padded sequence length.
        self.block_pad_multiple = lcm(*[block_size for block_size, _ in self.blocks])

        self.embeds = embed_tokens
        # depthwise convolution that mixes positional information into the
        # character embeddings before block scoring.
        self.positional_convol = nn.Sequential(
            Padding((0, 0, 0, max_block_size-1)),
            Rearrange('b s d -> b d s'),
            Depthwise1dConv(dim, dim, krnl_size=max_block_size, use_bn=self.use_bn,),
            Rearrange('b d s -> b s d'))
        # scores each candidate block representation with a single linear unit.
        self.cand_scoring = nn.Sequential(
            nn.Linear(dim, 1),
            Rearrange('... () -> ...'))
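
    # Example of the canonicalization above (a sketch, not from the original
    # code): mixed int/tuple block specs are normalized into (size, offset)
    # pairs, and the padding multiple is the LCM of the block sizes, e.g.
    #
    #   blocks=(1, 2, (3, 1))  ->  self.blocks == ((1, 0), (2, 0), (3, 1)),
    #                              self.block_pad_multiple == lcm(1, 2, 3) == 6
    #
    # With the default _BLOCKS above, block_pad_multiple == 36.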

    def _init_weights(self, factor: float = 0.05):
        self.positional_convol[2]._init_weights(factor)
        self.cand_scoring[0].weight.data.normal_(mean=0.0, std=factor * 1.0)

    def get_blocks(self):
        """ return GBST candidate blocking list. """
        return self.blocks

    @torch.cuda.amp.autocast()
    def forward(self, in_tensor, attention_mask=None):
        b, s = in_tensor.shape

        mask = attention_mask
        block_multi, ds_factor = self.block_pad_multiple, self.downsample_factor

        # embed the tokens, mix in positional information, then pad the
        # sequence so that every candidate block size divides it evenly.
        in_tensor = self.embeds(in_tensor)
        in_tensor = self.positional_convol(in_tensor)
        in_tensor = pad_to_multiple(in_tensor, block_multi,
                                    seq_dim=1, dim=-2, value=0.0)
        if mask is not None:
            mask = pad_to_multiple(mask, block_multi,
                                   seq_dim=1, dim=-1, value=False)

        def _masked_mean(in_tensor: Tensor, mask: Tensor, dim: int = -1):
            # mean over `dim`, counting only the positions where mask is True.
            len_diff = len(in_tensor.shape) - len(mask.shape)
            mask = torch.unsqueeze(mask, dim=-len_diff)
            in_tensor.masked_fill_(~(mask.bool()), 0.)

            total_elems = mask.sum(dim=dim)
            mean = in_tensor.sum(dim=dim) / total_elems.clamp(min=1.)
            mean.masked_fill_((total_elems == 0), 0.)
            return mean.float()

        block_reprs, block_masks = [], []

        # build one candidate representation per (block_size, offset) pair:
        # average the embeddings within each block, then broadcast the block
        # mean back to every position inside that block.
        for block_size, offset in self.blocks:
            block_in = in_tensor.clone()
            if mask is not None:
                block_mask = mask.clone()
            need_padding = offset > 0

            if need_padding:
                loff, roff = (block_size - offset), offset
                block_in = F.pad(block_in, (0, 0, loff, roff), value=0.0)
                if mask is not None:
                    # mask is (batch, seq): pad only the sequence (last) dimension.
                    block_mask = F.pad(block_mask, (loff, roff), value=False)

            blks = rearrange(block_in, 'b (s m) d -> b s m d', m=block_size)
            if mask is not None:
                mask_blks = rearrange(block_mask, 'b (s m) -> b s m', m=block_size)
                blk_repr = _masked_mean(blks, mask_blks, dim=-2)
            else:
                blk_repr = blks.mean(dim=-2)

            blk_repr = repeat(blk_repr, 'b s d -> b (s m) d', m=block_size)

            if need_padding:
                blk_repr = blk_repr[:, loff:-roff]

            block_reprs.append(blk_repr)

            if mask is not None:
                mask_blks = torch.any(mask_blks, dim=-1)
                mask_blks = repeat(mask_blks, 'b s -> b (s m)', m=block_size)
                if need_padding:
                    mask_blks = mask_blks[:, loff:-roff]
                block_masks.append(mask_blks)
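
        # Shape sketch (illustrative, not from the original code): with the
        # default _BLOCKS and a sequence padded to length S, every blk_repr is
        # (batch, S, dim), so the stack below yields block_reprs of shape
        # (batch, S, len(self.blocks), dim) and candidate scores of shape
        # (batch, S, len(self.blocks)).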

        # score each candidate block representation and softmax over candidates.
        block_reprs = torch.stack(block_reprs, dim=2)
        scores = self.cand_scoring(block_reprs)

        if mask is not None:
            block_masks = torch.stack(block_masks, dim=2)
            max_neg_val = -torch.finfo(scores.dtype).max
            scores = scores.masked_fill(~block_masks, max_neg_val)

        scores = scores.softmax(dim=2)

        # optionally let every position attend to the score distributions of
        # the other positions and average them in (score consensus attention).
        if self.score_consensus_attn:
            score_sim = einsum('b i d, b j d -> b i j', scores, scores)

            if mask is not None:
                cross_mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j')
                max_neg_val = -torch.finfo(score_sim.dtype).max
                score_sim = score_sim.masked_fill(~(cross_mask.bool()), max_neg_val)

            score_attn = score_sim.softmax(dim=-1)
            scores = einsum('b i j, b j m -> b i m', score_attn, scores)

        # weighted sum of the candidate block representations per position.
        scores = rearrange(scores, 'b n m -> b n m ()')
        in_tensor = (block_reprs * scores).sum(dim=2)

        # downsample by averaging over non-overlapping windows of ds_factor
        # positions; first truncate so the length is divisible by ds_factor.
        @torch.jit.script
        def _reshape_input_tensor(in_tensor: Tensor, s: int, d: int):
            # keep only the first ceil(s / d) * d positions of the padded sequence.
            m = int(math.ceil(s / d) * d)
            return in_tensor[:, :m]

        in_tensor = _reshape_input_tensor(in_tensor, s, ds_factor)
        if mask is not None:
            mask = _reshape_input_tensor(mask, s, ds_factor)

        in_tensor = rearrange(in_tensor, 'b (n m) d -> b n m d', m=ds_factor)
        if mask is not None:
            mask = rearrange(mask, 'b (n m) -> b n m', m=ds_factor)
            in_tensor = _masked_mean(in_tensor, mask, dim=2)
            mask = torch.any(mask, dim=-1)
        else:
            in_tensor = in_tensor.mean(dim=-2)

        return in_tensor, mask
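

# Minimal usage sketch (illustrative only; the vocabulary size, embedding
# dimension, and downsample factor below are assumptions, not values from the
# original code):
#
#   embed = nn.Embedding(384, 64)               # byte/char vocabulary -> dim 64
#   gbst = GBSWT(embed, downsample_factor=2)
#   ids = torch.randint(0, 384, (2, 10))        # (batch, char_seq_len)
#   attn = torch.ones(2, 10, dtype=torch.bool)
#   reprs, new_mask = gbst(ids, attention_mask=attn)
#   # reprs: (2, ceil(10 / 2), 64) latent subword representations,
#   # new_mask: (2, ceil(10 / 2)) downsampled attention mask.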