[MNT] Fix infer scripts
infer.py
@@ -12,43 +12,53 @@ if torch.cuda.is_available():
     device = "cuda"


-def main(
-    load_8bit: bool = False,
-    base_model: str = "",
-    lora_weights: str = "",
-    infer_data_path: str = "",
-    prompt_template: str = "",  # The prompt template to use, will default to alpaca.
-):
-    prompter = Prompter(prompt_template)
-    tokenizer = LlamaTokenizer.from_pretrained(base_model)
-    model = LlamaForCausalLM.from_pretrained(
-        base_model,
-        load_in_8bit=load_8bit,
-        torch_dtype=torch.float16,
-        device_map="auto",
-    )
-    try:
-        print(f"Using lora {lora_weights}")
-        model = PeftModel.from_pretrained(
-            model,
-            lora_weights,
-            torch_dtype=torch.float16,
-        )
-    except:
-        print("*"*50, "\n Attention! No Lora Weights \n", "*"*50)
-    # unwind broken decapoda-research config
-    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
-    model.config.bos_token_id = 1
-    model.config.eos_token_id = 2
-    if not load_8bit:
-        model.half()  # seems to fix bugs for some users.
+class Infer():
+    def __init__(
+        self,
+        load_8bit: bool = False,
+        base_model: str = "",
+        lora_weights: str = "",
+        prompt_template: str = "",  # The prompt template to use, will default to alpaca.
+    ):
+        prompter = Prompter(prompt_template)
+        tokenizer = LlamaTokenizer.from_pretrained(base_model)
+        model = LlamaForCausalLM.from_pretrained(
+            base_model,
+            load_in_8bit=load_8bit,
+            torch_dtype=torch.float16,
+            device_map="auto",
+        )
+        try:
+            print(f"Using lora {lora_weights}")
+            model = PeftModel.from_pretrained(
+                model,
+                lora_weights,
+                torch_dtype=torch.float16,
+            )
+        except:
+            print("*"*50, "\n Attention! No Lora Weights \n", "*"*50)
+
+        # unwind broken decapoda-research config
+        model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
+        model.config.bos_token_id = 1
+        model.config.eos_token_id = 2
+        if not load_8bit:
+            model.half()  # seems to fix bugs for some users.

-    model.eval()
+        model.eval()

-    if torch.__version__ >= "2" and sys.platform != "win32":
-        model = torch.compile(model)
+        if torch.__version__ >= "2" and sys.platform != "win32":
+            model = torch.compile(model)

-    def evaluate(
-        instruction,
-        input=None,
-        temperature=0.1,
+        self.base_model = base_model
+        self.lora_weights = lora_weights
+        self.model = model
+        self.prompter = prompter
+        self.tokenizer = tokenizer
+
+    def generate_output(
+        self,
+        instruction,
+        input=None,
+        temperature=0.1,
@@ -58,8 +68,8 @@ def main(
         max_new_tokens=256,
         **kwargs,
     ):
-        prompt = prompter.generate_prompt(instruction, input)
-        inputs = tokenizer(prompt, return_tensors="pt")
+        prompt = self.prompter.generate_prompt(instruction, input)
+        inputs = self.tokenizer(prompt, return_tensors="pt")
         input_ids = inputs["input_ids"].to(device)
         generation_config = GenerationConfig(
             temperature=temperature,
@@ -70,7 +80,7 @@ def main(
             **kwargs,
         )
         with torch.no_grad():
-            generation_output = model.generate(
+            generation_output = self.model.generate(
                 input_ids=input_ids,
                 generation_config=generation_config,
                 return_dict_in_generate=True,
@@ -78,32 +88,47 @@ def main(
                 max_new_tokens=max_new_tokens,
             )
         s = generation_output.sequences[0]
-        output = tokenizer.decode(s)
-        return prompter.get_response(output)
+        output = self.tokenizer.decode(s)
+        return self.prompter.get_response(output)

-    def infer_from_file():
+    def infer_from_file(self, infer_data_path):
         with open(infer_data_path) as f:
             for line in f:
                 data = json.loads(line)
                 instruction = data["instruction"]
                 output = data["output"]
                 print('=' * 100)
-                print(f"Base Model: {base_model} Lora Weights: {lora_weights}")
+                print(f"Base Model: {self.base_model} Lora Weights: {self.lora_weights}")
                 print("Instruction:\n", instruction)
-                model_output = evaluate(instruction)
+                model_output = self.generate_output(instruction)
                 print("Model Output:\n", model_output)
                 print("Ground Truth:\n", output)
                 print('=' * 100)

+
+def main(
+    load_8bit: bool = False,
+    base_model: str = "",
+    lora_weights: str = "",
+    prompt_template: str = "",  # The prompt template to use, will default to alpaca.
+    infer_data_path: str = "",
+):
+    infer = Infer(
+        load_8bit=load_8bit,
+        base_model=base_model,
+        lora_weights=lora_weights,
+        prompt_template=prompt_template
+    )
+
     try:
-        infer_from_file()
-    except:
-        print("Read infer_data_path Failed! Now Interactive Mode: ")
+        infer.infer_from_file(infer_data_path)
+    except Exception as e:
+        print(e, "Read infer_data_path Failed! Now Interactive Mode: ")
         while True:
             print('=' * 100)
             instruction = input("请输入您的问题: ")
             print("LaWGPT:")
-            print(evaluate(instruction))
+            print(infer.generate_output(instruction))
             print('=' * 100)
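The new infer_from_file parses each line of the file at infer_data_path with json.loads, so the input is expected to be JSON Lines with at least the "instruction" and "output" keys read in the loop above. A hypothetical record (the actual contents of example_infer_data.json are not shown in this commit):

{"instruction": "What remedies are available for unpaid wages?", "output": "..."}

The example launcher script gets a matching tweak, swapping its last two flags so that --infer_data_path comes last: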
@@ -3,5 +3,5 @@ python infer.py \
     --load_8bit True \
     --base_model 'minlik/chinese-llama-7b-merged' \
     --lora_weights 'entity303/lawgpt-lora-7b' \
-    --infer_data_path './resources/example_infer_data.json' \
-    --prompt_template 'law_template'
+    --prompt_template 'law_template' \
+    --infer_data_path './resources/example_infer_data.json'
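Taken together, the refactor turns the one-shot main() into a reusable Infer class. A minimal usage sketch, assuming infer.py is importable from the working directory and reusing the checkpoints from the script above (the question string is a placeholder):

from infer import Infer

# Load once: the merged Chinese LLaMA base model plus the LaWGPT LoRA weights.
infer = Infer(
    load_8bit=True,
    base_model="minlik/chinese-llama-7b-merged",
    lora_weights="entity303/lawgpt-lora-7b",
    prompt_template="law_template",
)

# Single-question inference, as in the interactive loop of main().
print(infer.generate_output("What are the remedies for breach of contract?"))

# Batch inference over a JSON Lines file, as in the non-interactive path of main().
infer.infer_from_file("./resources/example_infer_data.json")

Because loading happens once in __init__ and the handles are kept on self, the same loaded model now serves both the file-based and the interactive paths, and can be embedded in other scripts without going through main().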