# LaWGPT/infer.py
import sys
import json

import fire
import torch
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer

from utils.prompter import Prompter
# Pick the device once at import time; fall back to CPU so `device` is always defined.
device = "cuda" if torch.cuda.is_available() else "cpu"

def load_instruction(instruct_dir):
    """Read a JSONL instruction file: one JSON object per line."""
    input_data = []
    with open(instruct_dir, "r") as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank lines
                continue
            input_data.append(json.loads(line))
    return input_data
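
# A sketch of the JSONL layout this script expects, inferred from the fields
# infer_from_json reads below; the file name and values are placeholders:
#
#   # data/instructions.jsonl
#   {"instruction": "...", "output": "..."}
#   {"instruction": "...", "output": "..."}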

def main(
    load_8bit: bool = True,
    base_model: str = "",
    # Path to a JSONL file of instructions; if empty, run the default
    # instructions hard-coded below.
    instruct_dir: str = "",
    lora_weights: str = "",
    # The prompt template to use; defaults to med_template.
    prompt_template: str = "",
):
    prompter = Prompter(prompt_template)
    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    model = LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=load_8bit,  # 8-bit loading requires bitsandbytes
        torch_dtype=torch.float16,
        device_map="auto",
    )
    try:
        print(f"using lora {lora_weights}")
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            torch_dtype=torch.float16,
        )
    except Exception:
        print("*" * 50, "\n Attention! No LoRA weights \n", "*" * 50)

    # unwind broken decapoda-research config
    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2

    if not load_8bit:
        model.half()  # seems to fix bugs for some users

    model.eval()
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    def evaluate(
        instruction,
        input=None,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=1,
        max_new_tokens=256,
        **kwargs,
    ):
        """Build a prompt from the template, generate, and return the model's reply."""
        prompt = prompter.generate_prompt(instruction, input)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            # repetition_penalty=10.0,
            **kwargs,
        )
        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )
        s = generation_output.sequences[0]
        output = tokenizer.decode(s)
        return prompter.get_response(output)

    def infer_from_json(instruct_dir):
        input_data = load_instruction(instruct_dir)
        for d in input_data:
            instruction = d["instruction"]
            output = d["output"]
            print("=" * 100)
            print(f"###{base_model}-{lora_weights}###")
            model_output = evaluate(instruction)
            print("###instruction###")
            print(instruction)
            print("###golden output###")
            print(output)
            print("###model output###")
            print(model_output)
            print("=" * 100)

    if instruct_dir != "":
        infer_from_json(instruct_dir)
    else:
        for instruction in [
            # "What is the lawful interest on private lending protected by the state?"
            "民间借贷受国家保护的合法利息是多少?",
        ]:
            print("Instruction:", instruction)
            print("Response:", evaluate(instruction))
            print()

if __name__ == "__main__":
    fire.Fire(main)
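
# A minimal usage sketch: fire.Fire(main) exposes each keyword argument of
# main() as a CLI flag. The paths and template name below are placeholders,
# not published checkpoints:
#
#   python infer.py \
#       --base_model "path/to/llama-base" \
#       --lora_weights "path/to/lora-adapter" \
#       --instruct_dir "data/instructions.jsonl" \
#       --prompt_template "law_template"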