from ..models.model_manager import ModelManager
import torch


def tokenize_long_prompt(tokenizer, prompt, max_length=None):
    # Chunk length: default to the tokenizer's own limit (e.g. 77 for CLIP).
    length = tokenizer.model_max_length if max_length is None else max_length

    # Temporarily lift the tokenizer's limit so the full prompt is tokenized
    # without truncation.
    tokenizer.model_max_length = 99999999
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # Round the token count up to the nearest multiple of `length`.
    max_length = (input_ids.shape[1] + length - 1) // length * length

    # Set the limit back to the chunk length, then re-tokenize with padding so
    # the sequence length is an exact multiple of `length`.
    tokenizer.model_max_length = length
    input_ids = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True,
    ).input_ids

    # Split the padded sequence into chunks of `length` tokens each.
    num_sentence = input_ids.shape[1] // length
    input_ids = input_ids.reshape((num_sentence, length))
    return input_ids
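# A minimal usage sketch, assuming a Hugging Face CLIP tokenizer (the
# checkpoint name below is only an example). With model_max_length == 77,
# a 100-token prompt is padded to 154 tokens and split into a (2, 77)
# tensor, one row per 77-token chunk:
#
#     from transformers import CLIPTokenizer
#     tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#     chunks = tokenize_long_prompt(tokenizer, "a very long prompt ...")
#     # chunks.shape == (num_chunks, tokenizer.model_max_length)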


class BasePrompter:
    def __init__(self):
        self.refiners = []
        self.extenders = []

    def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=()):
        # Build each refiner from the shared model manager and register it.
        for refiner_class in refiner_classes:
            refiner = refiner_class.from_model_manager(model_manager)
            self.refiners.append(refiner)

    def load_prompt_extenders(self, model_manager: ModelManager, extender_classes=()):
        # Build each extender from the shared model manager and register it.
        for extender_class in extender_classes:
            extender = extender_class.from_model_manager(model_manager)
            self.extenders.append(extender)

    @torch.no_grad()
    def process_prompt(self, prompt, positive=True):
        # Run every loaded refiner over the prompt; lists of prompts are
        # handled element-wise via recursion.
        if isinstance(prompt, list):
            prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
        else:
            for refiner in self.refiners:
                prompt = refiner(prompt, positive=positive)
        return prompt

    @torch.no_grad()
    def extend_prompt(self, prompt: str, positive=True):
        # Pass the prompt through every loaded extender; extenders consume and
        # return a dict, so they can attach extra fields alongside "prompt".
        extended_prompt = dict(prompt=prompt)
        for extender in self.extenders:
            extended_prompt = extender(extended_prompt)
        return extended_prompt
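# A minimal sketch of the intended call pattern. `model_manager` and the
# refiner class (`MyRefiner` is hypothetical) are assumed to be provided
# elsewhere in the package; this only illustrates the flow:
#
#     prompter = BasePrompter()
#     prompter.load_prompt_refiners(model_manager, refiner_classes=[MyRefiner])
#     prompt = prompter.process_prompt("a photo of a cat", positive=True)
#     extended = prompter.extend_prompt(prompt)  # {"prompt": ..., ...}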