bots/runpod/bussyboy/model.py

import torch
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    GPTNeoXTokenizerFast,
    LogitsProcessor,
    LogitsProcessorList,
)


class StopAfterPlusIsGenerated(LogitsProcessor):
    """Forces EOS as soon as the previously generated token is the "+" token."""

    def __init__(self, plus_token_id, eos_token_id):
        super().__init__()
        self.plus_token_id = plus_token_id
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids, scores):
        # Build a score vector in which EOS is the only token with finite
        # probability mass.
        forced_eos = torch.full((scores.size(1),), -float("inf")).to(
            device=scores.device, dtype=scores.dtype
        )
        forced_eos[self.eos_token_id] = 0

        # For every sequence in the batch whose last token is "+", overwrite
        # the scores so that EOS must be emitted next.
        scores[input_ids[:, -1] == self.plus_token_id] = forced_eos
        return scores
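
# Note: the "+" token id is hard-coded as 559 where the processor is built
# below. A sketch of how it could be derived from the tokenizer instead
# (assuming "+" encodes to a single token in this vocabulary):
#
#     plus_token_id = tokenizer.encode("+")[0]
#     processor = StopAfterPlusIsGenerated(plus_token_id, tokenizer.eos_token_id)
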
class Model:
    def __init__(self):
        name = "/workspace/models/mpt-30b-drama"
        self.tokenizer = GPTNeoXTokenizerFast.from_pretrained(
            name, pad_token="<|endoftext|>"
        )

        # Kept for reference: manual config for the triton attention
        # implementation (would also require importing AutoConfig).
        # model_config = AutoConfig.from_pretrained(name, trust_remote_code=True)
        # model_config.attn_config["attn_impl"] = "triton"
        # model_config.init_device = "cuda:0"
        # model_config.eos_token_id = self.tokenizer.eos_token_id

        # Load the 30B model with 4-bit quantization and bfloat16 compute to
        # reduce its memory footprint.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            name,
            device_map="auto",
            quantization_config=quantization_config,
            trust_remote_code=True,
        )

        # 559 is the id of the "+" delimiter token; see the note above.
        self.logits_processor = LogitsProcessorList(
            [StopAfterPlusIsGenerated(559, self.tokenizer.eos_token_id)]
        )
    def generate(self, prompt):
        encoded = self.tokenizer(
            prompt, return_tensors="pt", padding=True, truncation=True
        ).to("cuda")
        gen_tokens = self.model.generate(
            input_ids=encoded.input_ids,
            attention_mask=encoded.attention_mask,
            pad_token_id=0,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.90,
            use_cache=True,
            max_new_tokens=512,
            logits_processor=self.logits_processor,
        )
        # Decode the full sequence and strip the prompt, returning only the
        # newly generated completion.
        return self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0][
            len(prompt) :
        ]
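
# Minimal usage sketch (assumes a CUDA-capable GPU and model weights present
# at /workspace/models/mpt-30b-drama):
#
#     model = Model()
#     completion = model.generate("some prompt text")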