diff --git a/infer.py b/infer.py
index 3d90202..79dbcdf 100644
--- a/infer.py
+++ b/infer.py
@@ -12,43 +12,53 @@ if torch.cuda.is_available():
     device = "cuda"
 
 
-def main(
-    load_8bit: bool = False,
-    base_model: str = "",
-    lora_weights: str = "",
-    infer_data_path: str = "",
-    prompt_template: str = "",  # The prompt template to use, will default to alpaca.
-):
-    prompter = Prompter(prompt_template)
-    tokenizer = LlamaTokenizer.from_pretrained(base_model)
-    model = LlamaForCausalLM.from_pretrained(
-        base_model,
-        load_in_8bit=load_8bit,
-        torch_dtype=torch.float16,
-        device_map="auto",
-    )
-    try:
-        print(f"Using lora {lora_weights}")
-        model = PeftModel.from_pretrained(
-            model,
-            lora_weights,
+class Infer():
+    def __init__(
+        self,
+        load_8bit: bool = False,
+        base_model: str = "",
+        lora_weights: str = "",
+        prompt_template: str = "",  # The prompt template to use, will default to alpaca.
+    ):
+        prompter = Prompter(prompt_template)
+        tokenizer = LlamaTokenizer.from_pretrained(base_model)
+        model = LlamaForCausalLM.from_pretrained(
+            base_model,
+            load_in_8bit=load_8bit,
             torch_dtype=torch.float16,
+            device_map="auto",
         )
-    except:
-        print("*"*50, "\n Attention! No Lora Weights \n", "*"*50)
-    # unwind broken decapoda-research config
-    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
-    model.config.bos_token_id = 1
-    model.config.eos_token_id = 2
-    if not load_8bit:
-        model.half()  # seems to fix bugs for some users.
+
+        try:
+            print(f"Using lora {lora_weights}")
+            model = PeftModel.from_pretrained(
+                model,
+                lora_weights,
+                torch_dtype=torch.float16,
+            )
+        except:
+            print("*"*50, "\n Attention! No Lora Weights \n", "*"*50)
+
+        # unwind broken decapoda-research config
+        model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
+        model.config.bos_token_id = 1
+        model.config.eos_token_id = 2
+        if not load_8bit:
+            model.half()  # seems to fix bugs for some users.
 
-    model.eval()
+        model.eval()
 
-    if torch.__version__ >= "2" and sys.platform != "win32":
-        model = torch.compile(model)
+        if torch.__version__ >= "2" and sys.platform != "win32":
+            model = torch.compile(model)
+
+        self.base_model = base_model
+        self.lora_weights = lora_weights
+        self.model = model
+        self.prompter = prompter
+        self.tokenizer = tokenizer
 
-    def evaluate(
+    def generate_output(
+        self,
         instruction,
         input=None,
         temperature=0.1,
@@ -58,8 +68,8 @@ def main(
         max_new_tokens=256,
         **kwargs,
     ):
-        prompt = prompter.generate_prompt(instruction, input)
-        inputs = tokenizer(prompt, return_tensors="pt")
+        prompt = self.prompter.generate_prompt(instruction, input)
+        inputs = self.tokenizer(prompt, return_tensors="pt")
         input_ids = inputs["input_ids"].to(device)
         generation_config = GenerationConfig(
             temperature=temperature,
@@ -70,7 +80,7 @@ def main(
             **kwargs,
         )
         with torch.no_grad():
-            generation_output = model.generate(
+            generation_output = self.model.generate(
                 input_ids=input_ids,
                 generation_config=generation_config,
                 return_dict_in_generate=True,
@@ -78,34 +88,49 @@ def main(
                 max_new_tokens=max_new_tokens,
             )
         s = generation_output.sequences[0]
-        output = tokenizer.decode(s)
-        return prompter.get_response(output)
+        output = self.tokenizer.decode(s)
+        return self.prompter.get_response(output)
 
-    def infer_from_file():
+    def infer_from_file(self, infer_data_path):
         with open(infer_data_path) as f:
             for line in f:
                 data = json.loads(line)
                 instruction = data["instruction"]
                 output = data["output"]
                 print('=' * 100)
-                print(f"Base Model: {base_model} Lora Weights: {lora_weights}")
+                print(f"Base Model: {self.base_model} Lora Weights: {self.lora_weights}")
                 print("Instruction:\n", instruction)
-                model_output = evaluate(instruction)
+                model_output = self.generate_output(instruction)
                 print("Model Output:\n", model_output)
                 print("Ground Truth:\n", output)
                 print('=' * 100)
 
+
+def main(
+    load_8bit: bool = False,
+    base_model: str = "",
+    lora_weights: str = "",
+    prompt_template: str = "",  # The prompt template to use, will default to alpaca.
+    infer_data_path: str = "",
+):
+    infer = Infer(
+        load_8bit=load_8bit,
+        base_model=base_model,
+        lora_weights=lora_weights,
+        prompt_template=prompt_template
+    )
+
     try:
-        infer_from_file()
-    except:
-        print("Read infer_data_path Failed! Now Interactive Mode: ")
+        infer.infer_from_file(infer_data_path)
+    except Exception as e:
+        print(e, "Read infer_data_path Failed! Now Interactive Mode: ")
     while True:
         print('=' * 100)
         instruction = input("请输入您的问题: ")
         print("LaWGPT:")
-        print(evaluate(instruction))
+        print(infer.generate_output(instruction))
         print('=' * 100)
 
 
 if __name__ == "__main__":
-    fire.Fire(main)
+    fire.Fire(main)
\ No newline at end of file
diff --git a/scripts/infer.sh b/scripts/infer.sh
index 11f5cd4..4386487 100644
--- a/scripts/infer.sh
+++ b/scripts/infer.sh
@@ -3,5 +3,5 @@ python infer.py \
     --load_8bit True \
     --base_model 'minlik/chinese-llama-7b-merged' \
     --lora_weights 'entity303/lawgpt-lora-7b' \
-    --infer_data_path './resources/example_infer_data.json' \
-    --prompt_template 'law_template'
+    --prompt_template 'law_template' \
+    --infer_data_path './resources/example_infer_data.json'
\ No newline at end of file
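
Usage note (not part of the patch): the refactor turns the old script-only main() into an importable Infer class, so inference can also be driven directly from Python. The following is a minimal sketch assuming infer.py is importable from the repository root; the model, LoRA, and template names simply mirror scripts/infer.sh, and the example question is illustrative.

    # Sketch of the class-based API introduced in this diff (values from scripts/infer.sh).
    from infer import Infer

    infer = Infer(
        load_8bit=True,
        base_model="minlik/chinese-llama-7b-merged",
        lora_weights="entity303/lawgpt-lora-7b",
        prompt_template="law_template",
    )

    # Single instruction; temperature/max_new_tokens match the defaults shown in the diff.
    # The question below is only an illustrative placeholder.
    print(infer.generate_output("什么是劳动合同?", temperature=0.1, max_new_tokens=256))

    # Batch inference over a file of JSON lines with "instruction"/"output" fields,
    # which is the format infer_from_file expects.
    infer.infer_from_file("./resources/example_infer_data.json")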