diff --git a/README.md b/README.md
index d50857d..fa18de2 100644
--- a/README.md
+++ b/README.md
@@ -60,9 +60,15 @@ LaWGPT is a series of open-source large language models based on Chinese legal knowledge.
     bash scripts/webui.sh
     # Open a browser and visit http://127.0.0.1:7860/
-    # Enter a legal question in the Instructions box, click the "Submit" button, and wait for the model to generate an answer
+    # Enter a legal question in the Instruction box, click the "Submit" button, and wait for the model to generate an answer
     ```
 
+    If you want to fine-tune on your own data, see the script `scripts/finetune.sh`:
+    ```bash
+    bash scripts/finetune.sh
+    ```
+
+
 2. Merge model weights (optional)
 
     **If you want to use the LaWGPT-7B-alpha model, you can skip this step and go straight to Step 3.**
 
diff --git a/infer.py b/infer.py
deleted file mode 100644
index 078875b..0000000
--- a/infer.py
+++ /dev/null
@@ -1,126 +0,0 @@
-import sys
-import json
-
-import fire
-import gradio as gr
-import torch
-import transformers
-from peft import PeftModel
-from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
-
-from utils.prompter import Prompter
-
-if torch.cuda.is_available():
-    device = "cuda"
-
-
-def load_instruction(instruct_dir):
-    input_data = []
-    with open(instruct_dir, "r") as f:
-        lines = f.readlines()
-        for line in lines:
-            line = line.strip()
-            d = json.loads(line)
-            input_data.append(d)
-    return input_data
-
-
-def main(
-    load_8bit: bool = True,
-    base_model: str = "",
-    # the infer data, if not exists, infer the default instructions in code
-    instruct_dir: str = "",
-    lora_weights: str = "",
-    # The prompt template to use, will default to med_template.
-    prompt_template: str = "",
-):
-    prompter = Prompter(prompt_template)
-    tokenizer = LlamaTokenizer.from_pretrained(base_model)
-    model = LlamaForCausalLM.from_pretrained(
-        base_model,
-        load_in_8bit=load_8bit,
-        torch_dtype=torch.float16,
-        device_map="auto",
-    )
-    try:
-        print(f"using lora {lora_weights}")
-        model = PeftModel.from_pretrained(
-            model,
-            lora_weights,
-            torch_dtype=torch.float16,
-        )
-    except:
-        print("*" * 50, "\n Attention! No Lora Weights \n", "*" * 50)
-    # unwind broken decapoda-research config
-    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
-    model.config.bos_token_id = 1
-    model.config.eos_token_id = 2
-    if not load_8bit:
-        model.half()  # seems to fix bugs for some users.
-
-    model.eval()
-
-    if torch.__version__ >= "2" and sys.platform != "win32":
-        model = torch.compile(model)
-
-    def evaluate(
-        instruction,
-        input=None,
-        temperature=0.1,
-        top_p=0.75,
-        top_k=40,
-        num_beams=1,
-        max_new_tokens=256,
-        **kwargs,
-    ):
-        prompt = prompter.generate_prompt(instruction, input)
-        inputs = tokenizer(prompt, return_tensors="pt")
-        input_ids = inputs["input_ids"].to(device)
-        generation_config = GenerationConfig(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            num_beams=num_beams,
-            # repetition_penalty=10.0,
-            **kwargs,
-        )
-        with torch.no_grad():
-            generation_output = model.generate(
-                input_ids=input_ids,
-                generation_config=generation_config,
-                return_dict_in_generate=True,
-                output_scores=True,
-                max_new_tokens=max_new_tokens,
-            )
-        s = generation_output.sequences[0]
-        output = tokenizer.decode(s)
-        return prompter.get_response(output)
-
-    def infer_from_json(instruct_dir):
-        input_data = load_instruction(instruct_dir)
-        for d in input_data:
-            instruction = d["instruction"]
-            output = d["output"]
-            print('=' * 100)
-            print(f"###{base_model}-{lora_weights}###")
-            model_output = evaluate(instruction)
-            print("###instruction###")
-            print(instruction)
-            print("###golden output###")
-            print(output)
-            print("###model output###")
-            print(model_output)
-            print('=' * 100)
-
-    if instruct_dir != "":
-        infer_from_json(instruct_dir)
-    else:
-        for instruction in [
-            "民间借贷受国家保护的合法利息是多少?",
-        ]:
-            print("Instruction:", instruction)
-            print("Response:", evaluate(instruction))
-            print()
-
-
-if __name__ == "__main__":
-    fire.Fire(main)
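Note: deleting infer.py removes the repository's command-line entry point for ad-hoc inference. For readers who relied on it, below is a minimal sketch of the equivalent load-and-generate path. The model and adapter names are taken from scripts/generate.sh and scripts/webui.sh in this diff; the original script additionally wrapped each question in a prompt template via `utils.prompter.Prompter`, which is omitted here, so outputs will differ.

```python
# Minimal sketch of the inference path the deleted infer.py provided.
# Model names come from scripts/generate.sh / scripts/webui.sh in this diff;
# the prompt-template step (utils.prompter.Prompter) is intentionally omitted.
import torch
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer

base_model = "minlik/chinese-llama-7b-merged"
lora_weights = "entity303/lawgpt-lora-7b"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = LlamaTokenizer.from_pretrained(base_model)
model = LlamaForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=True,          # same default as the deleted script
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(model, lora_weights, torch_dtype=torch.float16)
model.eval()

question = "民间借贷受国家保护的合法利息是多少?"  # sample question from the deleted file
input_ids = tokenizer(question, return_tensors="pt").input_ids.to(device)
with torch.no_grad():
    output = model.generate(
        input_ids=input_ids,
        generation_config=GenerationConfig(temperature=0.1, top_p=0.75, top_k=40),
        max_new_tokens=256,
    )
print(tokenizer.decode(output[0], skip_special_tokens=True))
```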
diff --git a/scripts/generate.sh b/scripts/generate.sh
deleted file mode 100644
index 283007e..0000000
--- a/scripts/generate.sh
+++ /dev/null
@@ -1,7 +0,0 @@
-
-CUDA_VISIBLE_DEVICES=1 python generate.py \
-    --load_8bit \
-    --base_model 'minlik/chinese-llama-7b-merged' \
-    --lora_weights 'entity303/lawgpt-lora-7b' \
-    --prompt_template 'law_template' \
-    --share_gradio
diff --git a/scripts/infer.sh b/scripts/infer.sh
deleted file mode 100644
index 6bd3211..0000000
--- a/scripts/infer.sh
+++ /dev/null
@@ -1,16 +0,0 @@
-
-# LawGPT
-python infer.py \
-    --base_model 'minlik/chinese-alpaca-plus-7b-merged' \
-    --lora_weights './outputs/chinese-alpaca-plus-7b-law-e1' \
-    --instruct_dir './data/infer_law_data.json' \
-    --prompt_template 'alpaca'
-
-
-# Chinese-Alpaca-plus-7B
-python infer.py \
-    --base_model 'minlik/chinese-alpaca-plus-7b-merged' \
-    --lora_weights '' \
-    --instruct_dir './data/infer_law_data.json' \
-    --prompt_template 'alpaca'
-
diff --git a/scripts/train.sh b/scripts/train.sh
deleted file mode 100644
index 56532a2..0000000
--- a/scripts/train.sh
+++ /dev/null
@@ -1,20 +0,0 @@
-#!/bin/bash
-
-WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1235 train_lora.py \
-    --base_model '../models/base_models/chinese_llama_7b' \
-    --data_path '' \
-    --output_dir '../models/lora_weights' \
-    --batch_size 128 \
-    --micro_batch_size 8 \
-    --num_epochs 1 \
-    --learning_rate 0.0003 \
-    --cutoff_len 1024 \
-    --val_set_size 0 \
-    --lora_r 16 \
-    --lora_alpha 32 \
-    --lora_dropout 0.05 \
-    --lora_target_modules '[q_proj, v_proj, k_proj, o_proj]' \
-    --train_on_inputs True \
-    --add_eos_token True \
-    --group_by_length True \
-    --resume_from_checkpoint '../models/lora_weights'
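Note: the deleted scripts/train.sh above is the only record in this diff of the LoRA hyperparameters used for training. Assuming train_lora.py mapped its flags onto peft in the usual way, they correspond to roughly the configuration below; `bias` and `task_type` are assumptions, not flags from the script.

```python
from peft import LoraConfig

# LoRA hyperparameters as recorded in the deleted scripts/train.sh
lora_config = LoraConfig(
    r=16,                    # --lora_r 16
    lora_alpha=32,           # --lora_alpha 32
    lora_dropout=0.05,       # --lora_dropout 0.05
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],  # --lora_target_modules
    bias="none",             # assumption: peft default
    task_type="CAUSAL_LM",   # assumption: causal LM fine-tuning
)
```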
diff --git a/scripts/webui.sh b/scripts/webui.sh
index a36078f..097cd3e 100644
--- a/scripts/webui.sh
+++ b/scripts/webui.sh
@@ -4,7 +4,7 @@
 # Use the already-trained model hosted on huggingface
 python webui.py \
     --load_8bit True \
-    --base_model 'minlik/chinese-alpaca-plus-7b-merged' \
+    --base_model 'minlik/chinese-llama-7b-merged' \
     --lora_weights 'entity303/lawgpt-lora-7b' \
     --prompt_template "law_template" \
     --server_name "0.0.0.0" \
@@ -16,6 +16,6 @@ python webui.py \
 # --load_8bit True \
 # --base_model 'minlik/chinese-alpaca-plus-7b-merged' \
 # --lora_weights './outputs/chinese-alpaca-plus-7b-law-e1' \
-# --prompt_template "law_template" \
+# --prompt_template "alpaca" \
 # --server_name "0.0.0.0" \
-# --share_gradio Ture \
\ No newline at end of file
+# --share_gradio True \
\ No newline at end of file
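Note: the webui.sh change above pairs the `entity303/lawgpt-lora-7b` adapter with `minlik/chinese-llama-7b-merged` instead of the alpaca-plus merge. A LoRA adapter records the base checkpoint it was trained against, so a quick sanity check like the following sketch (assuming the adapter config is published alongside the weights) can catch such a mismatch before loading the full model:

```python
from peft import PeftConfig

# Reads only the adapter's small config file, not the full weights
config = PeftConfig.from_pretrained("entity303/lawgpt-lora-7b")
print(config.base_model_name_or_path)  # should name the base model used in webui.sh
```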