import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel

from .configuration_gpt import GPTConfig


class GPT(nn.Module):
    """
    The GPT language model:
    - Embeddings (token + positional)
    - Stack of Transformer blocks
    - Final LayerNorm + Linear head for output logits
    """

    def __init__(
        self,
        block_size: int = 1024,
        vocab_size: int = 50304,
        n_layer: int = 12,
        n_head: int = 12,
        n_embd: int = 768,
    ):
        super().__init__()

        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(self.vocab_size, self.n_embd),
                wpe=nn.Embedding(self.block_size, self.n_embd),
                h=nn.ModuleList(
                    [self.Block(self.n_embd, self.n_head) for _ in range(self.n_layer)]
                ),
                ln_f=nn.LayerNorm(self.n_embd),
            )
        )

        self.lm_head = nn.Linear(self.n_embd, self.vocab_size, bias=False)

        # Weight tying: the token embedding and the output projection share one matrix.
        self.transformer.wte.weight = self.lm_head.weight

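    # Back-of-the-envelope parameter count for the defaults above (an estimate that assumes
    # the weight tying above and ignores biases and LayerNorm parameters):
    #   wte / lm_head (tied): 50304 * 768                  ~ 38.6M
    #   wpe:                   1024 * 768                  ~  0.8M
    #   blocks:  12 * (4 + 8) * 768**2 = 144 * 768**2      ~ 84.9M
    #   total                                              ~ 124M parameters (GPT-2 "small" scale)
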
    def forward(self, x):
        B, T = x.shape
        assert T <= self.block_size, "Cannot forward sequence longer than block size"

        tok_emb = self.transformer.wte(x)  # (B, T, n_embd)
        pos_emb = self.transformer.wpe(torch.arange(T, device=x.device))  # (T, n_embd)
        x = tok_emb + pos_emb.unsqueeze(0)

        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        return logits

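    # Note: `GPTModelForTextGeneration.generate` below consumes only the final position of
    # these logits, turning logits[:, -1, :] into a next-token distribution via temperature
    # scaling, top-k/top-p filtering, and a softmax.
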
    class CausalSelfAttention(nn.Module):
        """
        Multi-head self-attention with causal masking.
        """

        def __init__(self, n_embd, n_head):
            super().__init__()
            assert (
                n_embd % n_head == 0
            ), "Embedding dimension must be divisible by number of heads"
            self.n_head = n_head
            self.n_embd = n_embd

            # One projection for queries, keys, and values, plus the output projection.
            self.c_attn = nn.Linear(n_embd, 3 * n_embd)
            self.c_proj = nn.Linear(n_embd, n_embd)

        def forward(self, x):
            B, T, C = x.size()
            qkv = self.c_attn(x)
            q, k, v = qkv.split(self.n_embd, dim=2)

            # Reshape to (B, n_head, T, head_dim) for per-head attention.
            k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

            # Re-assemble the heads and project back to the residual stream.
            y = y.transpose(1, 2).contiguous().view(B, T, C)
            y = self.c_proj(y)
            return y

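        # For reference, the fused F.scaled_dot_product_attention call above is equivalent
        # (up to kernel choice and numerics) to the explicit masked form sketched below;
        # this is illustration only and is not executed anywhere in this file:
        #
        #     att = (q @ k.transpose(-2, -1)) * (k.size(-1) ** -0.5)
        #     mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        #     att = att.masked_fill(~mask, float("-inf"))
        #     y = F.softmax(att, dim=-1) @ v
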
    class MLP(nn.Module):
        """
        Feed-forward network block used in Transformer architectures.
        """

        def __init__(self, n_embd):
            super().__init__()
            self.c_fc = nn.Linear(n_embd, 4 * n_embd)
            self.gelu = nn.GELU(approximate="tanh")
            self.c_proj = nn.Linear(4 * n_embd, n_embd)

        def forward(self, x):
            return self.c_proj(self.gelu(self.c_fc(x)))

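        # For reference, nn.GELU(approximate="tanh") implements the usual tanh approximation
        # gelu(x) ~ 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))).
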
    class Block(nn.Module):
        """
        A single Transformer block.
        """

        def __init__(self, n_embd, n_head):
            super().__init__()
            self.ln_1 = nn.LayerNorm(n_embd)
            self.attn = GPT.CausalSelfAttention(n_embd, n_head)
            self.ln_2 = nn.LayerNorm(n_embd)
            self.mlp = GPT.MLP(n_embd)

        def forward(self, x):
            # Pre-LayerNorm residual form: normalize, transform, then add back to the stream.
            x = x + self.attn(self.ln_1(x))
            x = x + self.mlp(self.ln_2(x))
            return x


class GPTModelForTextGeneration(PreTrainedModel):
    """
    A wrapper class for GPT-based text generation.

    This integrates the GPT model above with the Hugging Face `PreTrainedModel` framework.
    """

    config_class = GPTConfig

    def __init__(self, config):
        super().__init__(config)

        self.model = GPT(
            block_size=config.block_size,
            vocab_size=config.vocab_size,
            n_layer=config.n_layer,
            n_head=config.n_head,
            n_embd=config.n_embd,
        )

    def forward(self, input_ids: torch.Tensor):
        assert isinstance(input_ids, torch.Tensor), "input_ids must be a PyTorch tensor"

        tokens = input_ids.clone()
        tokens = tokens.unsqueeze(0) if tokens.dim() == 1 else tokens

        assert (
            tokens.ndim == 2 and tokens.shape[0] == 1
        ), "input_ids must have 2 dimensions: (1, sequence_length)"

        # Valid token ids lie in [0, vocab_size).
        assert torch.all(
            (tokens >= 0) & (tokens < self.model.vocab_size)
        ), "input_ids contain invalid token values"

        logits = self.model.forward(tokens)

        return {"logits": logits}

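    # Illustrative call pattern (a sketch only; `model` stands for an instance of this class,
    # and the token ids would come from whichever tokenizer the caller pairs with it):
    #
    #     out = model.forward(torch.tensor([[10, 20, 30]]))
    #     next_token_logits = out["logits"][:, -1, :]
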
    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 50,
        do_sample: bool = True,
        top_k: int = 50,
        top_p: float = 0.95,
        temperature: float = 0.9,
        device: str = "cpu",
    ):
        """
        Generates text using autoregressive sampling with top-k, top-p (nucleus), and temperature.
        `max_length` is the total sequence length, including the prompt tokens.
        """

        # Validate the requested device ("cpu", "cuda", or "cuda:N").
        if device.startswith("cuda"):
            assert torch.cuda.is_available(), "CUDA is not available, please use 'cpu'"
            if device != "cuda":
                try:
                    device_index = int(device.split(":")[1])
                    assert (
                        0 <= device_index < torch.cuda.device_count()
                    ), f"Invalid CUDA device index: {device_index}"
                except (IndexError, ValueError):
                    raise ValueError(
                        "Invalid device format. Use 'cpu', 'cuda', or 'cuda:N' where N is an integer."
                    )
        elif device != "cpu":
            raise ValueError("Invalid device. Use 'cpu', 'cuda', or 'cuda:N'.")

        assert isinstance(input_ids, torch.Tensor), "input_ids must be a PyTorch tensor"
        input_ids = input_ids.to(device)
        self.model.to(device)

        # Normalize the prompt to shape (1, sequence_length) and validate its contents.
        tokens = input_ids.clone()
        tokens = tokens.unsqueeze(0) if tokens.dim() == 1 else tokens

        assert (
            tokens.ndim == 2 and tokens.shape[0] == 1
        ), "input_ids must have 2 dimensions: (1, sequence_length)"

        assert torch.all(
            (tokens >= 0) & (tokens < self.model.vocab_size)
        ), "input_ids contain invalid token values"

        # Validate the decoding hyperparameters.
        assert (
            isinstance(max_length, int) and max_length >= 1
        ), "max_length must be a positive integer"
        assert (
            max_length <= self.model.block_size
        ), f"max_length must be in range [1, {self.model.block_size}]"

        assert isinstance(top_k, int) and top_k >= 1, "top_k must be a positive integer"

        assert (
            isinstance(top_p, (int, float)) and 0.0 <= top_p <= 1.0
        ), "top_p must be in range [0, 1]"

        assert (
            isinstance(temperature, (int, float)) and 0.0 <= temperature <= 1.0
        ), "temperature must be in range [0, 1]"

        tokens = tokens.to(device)

        # Autoregressive decoding: append one token per step until max_length is reached.
        while tokens.size(1) < max_length:
            logits = self.forward(tokens)["logits"][:, -1, :]
            logits = logits / max(0.01, temperature)

            if do_sample:
                # Top-k filtering: keep only the k highest-scoring tokens.
                top_k = min(top_k, logits.size(-1))
                indices_to_remove = (
                    logits < torch.topk(logits, top_k, dim=1)[0][..., -1, None]
                )
                logits[indices_to_remove] = float("-inf")

                # Top-p (nucleus) filtering: keep the smallest set of top tokens whose
                # cumulative probability exceeds top_p.
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift the mask right so the token that crosses the threshold is kept,
                # and never remove the single most likely token.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                sorted_logits[sorted_indices_to_remove] = float("-inf")

                # Undo the sort so the filtered logits line up with vocabulary indices again.
                logits = torch.gather(sorted_logits, 1, sorted_indices.argsort(-1))

                next_tokens = torch.multinomial(F.softmax(logits, -1), 1)
            else:
                # Greedy decoding: always take the most likely next token.
                next_tokens = torch.argmax(logits, dim=-1, keepdim=True)

            tokens = torch.cat((tokens, next_tokens), dim=1)

        return tokens.flatten()


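# Example usage (an illustrative sketch, not executed by this module; it assumes that
# GPTConfig accepts these fields as keyword arguments and that the caller provides token
# ids from a tokenizer of their choice):
#
#     config = GPTConfig(block_size=1024, vocab_size=50304, n_layer=12, n_head=12, n_embd=768)
#     model = GPTModelForTextGeneration(config)
#     prompt_ids = torch.tensor([464, 3280, 318])  # placeholder ids, not from a real tokenizer
#     output_ids = model.generate(prompt_ids, max_length=20, do_sample=True, top_k=50, top_p=0.95)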