# gpt2-quantzed-gguf / main1.py
import re
import time
import warnings

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, GPT2TokenizerFast
from datasets import load_dataset

warnings.filterwarnings("ignore")
device = "cuda"
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
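
# Rough baseline (back-of-the-envelope, not measured here): GPT-2 small has
# ~124M parameters, so the fp32 footprint reported by get_memory_footprint()
# below should be on the order of 500 MB (124M * 4 bytes) before quantization.
# print(f'{model.get_memory_footprint() / 10**6:.0f} MB')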
test = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
# print(len(test))
encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")
def run_experiment(model):
    # Sliding-window perplexity evaluation on the WikiText-2 validation set.
    print(f'Memory usage of model alone = {model.get_memory_footprint() / 10**6} MB')
    max_length = model.config.n_positions
    stride = 512
    seq_len = encodings.input_ids.size(1)

    nlls = []
    start_time = time.time()
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100  # ignore tokens already scored by the previous window

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            # loss is calculated using CrossEntropyLoss which averages over valid labels
            neg_log_likelihood = outputs.loss
            if begin_loc == 0:
                print(f'Memory usage at forward pass = {torch.cuda.memory_allocated(0) / 10**6} MB')

        nlls.append(neg_log_likelihood)
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    ppl = torch.exp(torch.stack(nlls).mean())
    print(f'Perplexity = {ppl.item()}')
    print(f'Time taken: {time.time() - start_time} s')
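
# Worked micro-example (illustration only, never called by the script) of the
# label masking inside run_experiment: tokens already scored by the previous
# window are set to -100 so CrossEntropyLoss ignores them and every token
# contributes to the final perplexity exactly once.
def _masking_demo():
    window = torch.tensor([[5, 7, 9, 11]])  # pretend the first 2 tokens overlap
    targets = window.clone()
    trg_len = 2                              # only the last 2 tokens are new
    targets[:, :-trg_len] = -100             # -> tensor([[-100, -100, 9, 11]])
    return targets
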
from quant import perform_quantization
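
# The `quant` module ships alongside this script and is not reproduced here.
# Purely as a hedged sketch of what a perform_quantization-style helper could
# look like (an assumption, not the actual implementation): symmetric
# per-tensor int8 fake quantization of every weight whose qualified name
# matches `regex`. The real module presumably stores genuine int8 tensors,
# which is why get_memory_footprint() shrinks after quantization; this sketch
# only simulates the precision loss. `_perform_quantization_sketch` is a
# hypothetical name and is never called below.
def _perform_quantization_sketch(model, regex=r".*"):
    pattern = re.compile(regex)
    for name, param in model.named_parameters():
        # re.match anchors at the start, so a module-name pattern such as
        # transformer\.h\.\d+\.[a-zA-Z]+ also selects that module's parameters
        if pattern.match(name) is None or param.dim() < 2:
            continue  # skip non-matching params and 1-D biases/LayerNorm weights
        scale = param.detach().abs().max() / 127.0
        if scale == 0:
            continue  # all-zero tensor, nothing to quantize
        q = torch.round(param.detach() / scale).clamp_(-128, 127)
        param.data = (q * scale).to(param.dtype)  # quantize-dequantize in place
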
model_type = 0
## Normal
if model_type == 0:
    print('Normal model')
    run_experiment(model)
    print()
## Full model quant (including lm_head)
if model_type == 0:
    print('Full model quant')
    perform_quantization(model)
    torch.save(model, 'q1-full-quant.pt')
    # print(model)
    run_experiment(model)
    print()
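
# Note: torch.save on a whole model pickles its classes by reference, so
# reloading 'q1-full-quant.pt' later requires the `quant` module (and
# transformers) to be importable in the loading environment, e.g.:
# reloaded = torch.load('q1-full-quant.pt').to(device)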
# Without lm_head
if model_type == 0:
    print('Full model without lm_head')
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)  # fresh unquantized copy
    perform_quantization(model, regex=r"transformer\.h\.\d+\.[a-zA-Z]+")
    # print(model)
    run_experiment(model)
    print()
# Only lm_head
if model_type == 0:
    print('Only LM head')
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
    perform_quantization(model, regex=r"[\w.]*lm_head[\w.]*")
    # print(gc.collect())
    # print(model)
    run_experiment(model)
    print()
# Last 4 transformer blocks (the regex matches attention and MLP submodules of h.8-h.11)
if model_type == 0:
    print('Last 4 attention layers')
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
    perform_quantization(model, regex=r"transformer\.h\.(8|9|10|11)\.[a-zA-Z]+")
    # print(gc.collect())
    # print(model)
    run_experiment(model)
    print()
# Only attention (the fused q,k,v projection c_attn, plus attn.c_proj)
if model_type == 0:
    print('Only q,k,v')
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
    perform_quantization(model, regex=r"[\w.]*attn[\w.]*")
    # print(model)
    run_experiment(model)
    print()
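
# The regex mechanism extends to any submodule group; for example, a
# hypothetical MLP-only variant (not part of the experiments above) would be:
# model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
# perform_quantization(model, regex=r"[\w.]*mlp[\w.]*")
# run_experiment(model)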