import torch
from transformers import (
    GPTNeoXTokenizerFast,
    LogitsProcessor,
    AutoConfig,
    AutoModelForCausalLM,
    LogitsProcessorList,
)

from config import config


class StopAfterPlusIsGenerated(LogitsProcessor):
    """Forces EOS as the next token whenever the previous token was the '+' marker."""

    def __init__(self, plus_token_id, eos_token_id):
        super().__init__()
        self.plus_token_id = plus_token_id
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids, scores):
        # Build a logits row that leaves probability mass only on EOS.
        forced_eos = torch.full((scores.size(1),), -float("inf")).to(
            device=scores.device, dtype=scores.dtype
        )
        forced_eos[self.eos_token_id] = 0
        # For every sequence in the batch whose last token is the plus marker,
        # overwrite its logits so the only possible next token is EOS.
        scores[input_ids[:, -1] == self.plus_token_id] = forced_eos
        return scores


class Model:
    def __init__(self):
        name = f"{config['data_dir']}/mpt-30b-drama-ba678"
        # MPT reuses the GPT-NeoX vocabulary; <|endoftext|> doubles as the pad token.
        self.tokenizer = GPTNeoXTokenizerFast.from_pretrained(
            name, pad_token="<|endoftext|>"
        )
        model_config = AutoConfig.from_pretrained(name, trust_remote_code=True)
        model_config.attn_config["attn_impl"] = "triton"  # fused Triton attention kernel
        model_config.init_device = "cuda:0"  # materialize weights directly on the GPU
        model_config.eos_token_id = self.tokenizer.eos_token_id
        self.model = AutoModelForCausalLM.from_pretrained(
            name,
            config=model_config,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        # 559 is the id of the '+' marker token (per the processor's name) that
        # terminates a reply; hitting it forces EOS on the next step.
        self.logits_processor = LogitsProcessorList(
            [StopAfterPlusIsGenerated(559, self.model.config.eos_token_id)]
        )

    def generate(self, prompt):
        with torch.autocast("cuda", dtype=torch.bfloat16):
            encoded = self.tokenizer(
                prompt, return_tensors="pt", padding=True, truncation=True
            ).to("cuda:0")
            gen_tokens = self.model.generate(
                input_ids=encoded.input_ids,
                attention_mask=encoded.attention_mask,
                pad_token_id=0,
                do_sample=True,
                temperature=0.90,
                use_cache=True,
                max_length=8192,
                logits_processor=self.logits_processor,
            )
        # Decode the full sequence, then strip the echoed prompt so only the
        # newly generated completion is returned.
        return self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0][
            len(prompt):
        ]
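
# A minimal usage sketch, not part of the original module: it assumes the
# fine-tuned checkpoint exists under config["data_dir"], that a CUDA device
# with bfloat16 support is available, and that the prompt string below is
# purely illustrative (the real prompt format depends on how the model was
# fine-tuned).
if __name__ == "__main__":
    model = Model()
    completion = model.generate("Hello, world")
    print(completion)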