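"""Text-generation wrapper around an MPT-30B checkpoint.

Loads the tokenizer and model, and installs a custom logits processor that
forces end-of-sequence as soon as the model emits a "+" token.
"""
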
import torch
import transformers
from transformers import (
    GPTNeoXTokenizerFast,
    LogitsProcessor,
    AutoConfig,
    AutoModelForCausalLM,
    LogitsProcessorList,
)

from config import config


class StopAfterPlusIsGenerated(LogitsProcessor):
    """Forces EOS as the next token whenever the last generated token is "+"."""

    def __init__(self, plus_token_id, eos_token_id):
        super().__init__()
        self.plus_token_id = plus_token_id
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids, scores):
        # A logits row in which EOS is the only token with non-zero probability.
        forced_eos = torch.full(
            (scores.size(1),), -float("inf"), device=scores.device, dtype=scores.dtype
        )
        forced_eos[self.eos_token_id] = 0

        # Overwrite the scores of every sequence whose last token is "+".
        scores[input_ids[:, -1] == self.plus_token_id] = forced_eos
        return scores


class Model:
    def __init__(self):
        name = f"{config['data_dir']}/mpt-30b-drama-ba678"

        self.tokenizer = GPTNeoXTokenizerFast.from_pretrained(
            name, pad_token="<|endoftext|>"
        )

        # Configure MPT to use the Triton attention implementation and to
        # initialize its weights directly on the GPU.
        model_config = AutoConfig.from_pretrained(name, trust_remote_code=True)
        model_config.attn_config["attn_impl"] = "triton"
        model_config.init_device = "cuda:0"
        model_config.eos_token_id = self.tokenizer.eos_token_id

        self.model = AutoModelForCausalLM.from_pretrained(
            name,
            config=model_config,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )

        # 559 is the id of the "+" token in this tokenizer.
        self.logits_processor = LogitsProcessorList(
            [StopAfterPlusIsGenerated(559, self.model.config.eos_token_id)]
        )

    def generate(self, prompt):
        with torch.autocast("cuda", dtype=torch.bfloat16):
            encoded = self.tokenizer(
                prompt, return_tensors="pt", padding=True, truncation=True
            ).to("cuda:0")

            gen_tokens = self.model.generate(
                input_ids=encoded.input_ids,
                attention_mask=encoded.attention_mask,
                pad_token_id=0,
                do_sample=True,
                temperature=0.90,
                use_cache=True,
                max_length=8192,
                logits_processor=self.logits_processor,
            )
            # Decode the whole batch, take the first sequence, and strip the
            # prompt prefix so only the newly generated text is returned.
            return self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0][
                len(prompt) :
            ]
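

# Hypothetical usage sketch (not part of the original file; the real entry
# point and prompt format presumably live elsewhere in the service):
if __name__ == "__main__":
    model = Model()
    print(model.generate("Hello"))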