Bussy-boy/model.py

72 lines
2.2 KiB
Python

import torch
import transformers
from transformers import (
GPTNeoXTokenizerFast,
LogitsProcessor,
AutoConfig,
AutoModelForCausalLM,
LogitsProcessorList,
)
from data import config
class StopAfterPlusIsGenerated(LogitsProcessor):
    """Logits processor that forces EOS right after a trigger token.

    For every row of the batch whose most recent token equals
    ``plus_token_id``, the next-token scores are overwritten so that
    ``eos_token_id`` scores 0 and every other token scores ``-inf``.
    Rows whose last token is anything else are left untouched.
    """

    def __init__(self, plus_token_id: int, eos_token_id: int) -> None:
        """Remember the trigger token id and the EOS id to force after it."""
        super().__init__()
        self.plus_token_id = plus_token_id
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        """Force EOS on the rows that have just produced the trigger token.

        Args:
            input_ids: Token ids produced so far; only the last column
                (``input_ids[:, -1]``) is inspected.
            scores: Next-token scores, one row per sequence; ``scores.size(1)``
                is used as the length of the replacement row.

        Returns:
            The same ``scores`` tensor, modified in place.
        """
        # Allocate the replacement row directly on the scores' device and in
        # their dtype; building it on the CPU first would cost a host-to-device
        # copy on every decoding step.
        forced_eos = torch.full(
            (scores.size(1),),
            -float("inf"),
            device=scores.device,
            dtype=scores.dtype,
        )
        forced_eos[self.eos_token_id] = 0
        # Boolean row mask: the row vector is broadcast over all matching rows.
        scores[input_ids[:, -1] == self.plus_token_id] = forced_eos
        return scores
class Model:
def __init__(self):
name = f"{config['data_dir']}/mpt-30b-drama"
self.tokenizer = GPTNeoXTokenizerFast.from_pretrained(
name, pad_token="<|endoftext|>"
)
model_config = AutoConfig.from_pretrained(name, trust_remote_code=True)
model_config.attn_config["attn_impl"] = "triton"
model_config.init_device = "cuda:0"
model_config.eos_token_id = self.tokenizer.eos_token_id
self.model = AutoModelForCausalLM.from_pretrained(
name,
config=model_config,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
)
self.logits_processor = LogitsProcessorList(
[StopAfterPlusIsGenerated(559, self.model.config.eos_token_id)]
)
def generate(self, prompt):
with torch.autocast("cuda", dtype=torch.bfloat16):
encoded = self.tokenizer(
prompt, return_tensors="pt", padding=True, truncation=True
).to("cuda:0")
gen_tokens = self.model.generate(
input_ids=encoded.input_ids,
attention_mask=encoded.attention_mask,
pad_token_id=0,
do_sample=True,
temperature=0.90,
use_cache=True,
max_length=8192,
logits_processor=self.logits_processor,
)
return self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0][
len(prompt) :
]