72 lines
2.2 KiB
Python
72 lines
2.2 KiB
Python
|
import torch
|
||
|
import transformers
|
||
|
from transformers import (
|
||
|
GPTNeoXTokenizerFast,
|
||
|
LogitsProcessor,
|
||
|
AutoConfig,
|
||
|
AutoModelForCausalLM,
|
||
|
LogitsProcessorList,
|
||
|
)
|
||
|
from config import config
|
||
|
|
||
|
|
||
|
class StopAfterPlusIsGenerated(LogitsProcessor):
    """Force end-of-sequence right after the trigger ("plus") token is emitted.

    For every batch row whose most recent token equals ``plus_token_id``, the
    next-token scores are overwritten so that only ``eos_token_id`` keeps a
    finite score. Rows ending in any other token are left untouched.

    Args:
        plus_token_id: Token id that triggers the forced stop.
        eos_token_id: Token id that remains selectable once the stop is
            forced. It is used directly as an index into the score row.

    Raises:
        ValueError: If ``eos_token_id`` is ``None``.
    """

    def __init__(self, plus_token_id, eos_token_id):
        super().__init__()

        # Indexing a tensor with ``None`` addresses the whole tensor, so a
        # missing EOS id would set *every* score to 0 instead of forcing EOS.
        if eos_token_id is None:
            raise ValueError("eos_token_id must not be None")

        self.plus_token_id = plus_token_id
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids, scores):
        """Overwrite the scores of rows whose last token is the trigger token.

        Args:
            input_ids: Token ids generated so far; indexed as
                ``input_ids[:, -1]`` to read the last token of each row.
            scores: Next-token scores; ``scores.size(1)`` is taken as the
                length of one score row. Modified in place.

        Returns:
            The same ``scores`` tensor, with the affected rows replaced by a
            row that is ``-inf`` everywhere except ``0`` at ``eos_token_id``.
        """
        # Build the replacement row directly with the scores' dtype and device
        # so no host allocation / device copy happens on each decoding step.
        forced_eos = scores.new_full((scores.size(1),), float("-inf"))
        forced_eos[self.eos_token_id] = 0

        ends_with_plus = input_ids[:, -1] == self.plus_token_id
        scores[ends_with_plus] = forced_eos
        return scores
|
||
|
|
||
|
|
||
|
class Model:
    """Load the ``mpt-30b-drama-ba678`` checkpoint and sample continuations.

    The checkpoint is read from ``config['data_dir']``. The model is loaded in
    bfloat16 with the "triton" attention implementation and initialised on
    ``cuda:0``. Generation uses ``StopAfterPlusIsGenerated`` so that the model
    is forced to emit EOS right after the trigger token.
    """

    # Token id handed to StopAfterPlusIsGenerated as its trigger.
    # NOTE(review): presumably the tokenizer's "+" token - confirm.
    _PLUS_TOKEN_ID = 559

    # Device used both for model initialisation and for the encoded prompt.
    _DEVICE = "cuda:0"

    def __init__(self):
        """Load tokenizer, config and model weights, and build the processor list."""
        name = f"{config['data_dir']}/mpt-30b-drama-ba678"
        self.tokenizer = GPTNeoXTokenizerFast.from_pretrained(
            name, pad_token="<|endoftext|>"
        )

        model_config = AutoConfig.from_pretrained(name, trust_remote_code=True)
        model_config.attn_config["attn_impl"] = "triton"
        model_config.init_device = self._DEVICE
        # Keep the model's notion of EOS in sync with the tokenizer's.
        model_config.eos_token_id = self.tokenizer.eos_token_id

        self.model = AutoModelForCausalLM.from_pretrained(
            name,
            config=model_config,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )

        self.logits_processor = LogitsProcessorList(
            [
                StopAfterPlusIsGenerated(
                    self._PLUS_TOKEN_ID, self.model.config.eos_token_id
                )
            ]
        )

    def generate(self, prompt):
        """Sample a continuation for ``prompt`` and return only the new text.

        Args:
            prompt: The text to continue.

        Returns:
            The decoded continuation (special tokens skipped), without the
            prompt.
        """
        encoded = self.tokenizer(
            prompt, return_tensors="pt", padding=True, truncation=True
        ).to(self._DEVICE)
        prompt_length = encoded.input_ids.shape[1]

        with torch.autocast("cuda", dtype=torch.bfloat16):
            gen_tokens = self.model.generate(
                input_ids=encoded.input_ids,
                attention_mask=encoded.attention_mask,
                # NOTE(review): assumed to be the id of the "<|endoftext|>"
                # pad token configured on the tokenizer - confirm.
                pad_token_id=0,
                do_sample=True,
                temperature=0.90,
                use_cache=True,
                max_length=8192,
                logits_processor=self.logits_processor,
            )

        # Drop the prompt at the token level. Slicing the decoded string by
        # len(prompt) is wrong whenever the decoded prompt differs from the
        # input text (skipped special tokens, tokenization clean-up, or a
        # prompt shortened by truncation).
        new_tokens = gen_tokens[:, prompt_length:]
        return self.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
|