diff --git a/.gitignore b/.gitignore
index 41da0ad..0353852 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,19 @@
-./data
+__pycache__/
+*.npy
+*.npz
+*.pyc
+*.pyd
+*.so
+*.ipynb
+.ipynb_checkpoints
+models/base_models/*
+!models/base_models/.gitkeep
+models/lora_weights/*
+!models/lora_weights/.gitkeep
+outputs/*
+!outputs/.gitkeep
+data/*
+!data/.gitkeep
+wandb/
+flagged/
+.DS_Store
diff --git a/README.md b/README.md
index 1896590..e8f366a 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,7 @@ LaWGPT 是一系列基于中文法律知识的开源大语言模型。
本项目持续开展,法律领域数据集及系列模型后续相继开源,敬请关注。
## 更新
+- 🛠️ 2023/05/22:项目主分支结构调整,详见[项目结构](https://github.com/pengxiao-song/LaWGPT#项目结构)
- 🪴 2023/05/15:发布 [中文法律数据源汇总(Awesome Chinese Legal Resources)](https://github.com/pengxiao-song/awesome-chinese-legal-resources) 和 [法律领域词表](https://github.com/pengxiao-song/LaWGPT/blob/main/resources/legal_vocab.txt)
@@ -44,13 +45,25 @@ LaWGPT 是一系列基于中文法律知识的开源大语言模型。
1. 准备代码,创建环境
```bash
+ # 下载代码
git clone git@github.com:pengxiao-song/LaWGPT.git
cd LaWGPT
+
+ # 创建环境
+ conda create -n lawgpt python=3.10 -y
conda activate lawgpt
pip install -r requirements.txt
+
+ # 启动可视化脚本(自动下载预训练模型约15GB)
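+   # 服务通过 Gradio 启动,默认监听 0.0.0.0:7860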
+ bash ./scripts/webui.sh
```
-2. 合并模型权重(可选)
+2. 访问 http://127.0.0.1:7860 :
+
+
+
+
+3. 合并模型权重(可选)
**如果您想使用 LaWGPT-7B-alpha 模型,可跳过该步。**
@@ -61,44 +74,28 @@ LaWGPT 是一系列基于中文法律知识的开源大语言模型。
本项目给出[合并方式](https://github.com/pengxiao-song/LaWGPT/wiki/%E6%A8%A1%E5%9E%8B%E5%90%88%E5%B9%B6),请各位获取原版权重后自行重构模型。
-3. 启动示例
-
- 启动本地服务:
-
- ```bash
- conda activate lawgpt
- cd LaWGPT
- sh src/scripts/generate.sh
- ```
-
- 接入服务:
-
-
-
-
-
-
## 项目结构
-```bash
+```bash
LaWGPT
-├── assets # 项目静态资源
-├── data # 语料及精调数据
-├── tools # 数据清洗等工具
+├── assets # 静态资源
+├── resources # 项目资源
+├── models # 基座模型及 lora 权重
+│ ├── base_models
+│ └── lora_weights
+├── outputs # 指令微调的输出权重
+├── data # 实验数据
+├── scripts # 脚本目录
+│ ├── finetune.sh # 指令微调脚本
+│ └── webui.sh # 启动服务脚本
+├── templates # prompt 模板
+├── tools # 工具包
+├── utils
+├── train_clm.py # 二次训练
+├── finetune.py # 指令微调
+├── webui.py # 启动服务
├── README.md
-├── requirements.txt
-└── src # 源码
- ├── finetune.py
- ├── generate.py
- ├── models # 基座模型及 Lora 权重
- │ ├── base_models
- │ └── lora_weights
- ├── outputs
- ├── scripts # 脚本文件
- │ ├── finetune.sh # 指令微调
- │ └── generate.sh # 服务创建
- ├── templates
- └── utils
+└── requirements.txt
```
@@ -119,13 +116,13 @@ LawGPT 系列模型的训练过程分为两个阶段:
### 二次训练流程
-1. 参考 `src/data/example_instruction_train.json` 构造二次训练数据集
-2. 运行 `src/scripts/train_lora.sh`
+1. 参考 `resources/example_instruction_train.json` 构造二次训练数据集
+2. 运行 `scripts/train_clm.sh`
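+
+   其中 `scripts/train_clm.sh` 默认通过 torchrun 以 8 卡数据并行方式启动 `train_clm.py`,可根据实际 GPU 数量调整 `--nproc_per_node` 与 `CUDA_VISIBLE_DEVICES`。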
### 指令精调步骤
-1. 参考 `src/data/example_instruction_tune.json` 构造指令微调数据集
-2. 运行 `src/scripts/finetune.sh`
+1. 参考 `resources/example_instruction_tune.json` 构造指令微调数据集(数据格式示例见下)
+2. 运行 `scripts/finetune.sh`
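+
+   指令微调数据为 JSON 数组,每条样本包含 `instruction`、`input`、`output` 三个字段,示例如下:
+
+   ```json
+   [
+     {
+       "instruction": "酒驾撞人要判多久?",
+       "input": "",
+       "output": "《刑法》第一百三十三条规定:……(完整内容见 resources/example_instruction_tune.json)"
+     }
+   ]
+   ```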
### 计算资源
@@ -222,4 +219,4 @@ LawGPT 系列模型的训练过程分为两个阶段:
## 引用
-如果您觉得我们的工作对您有所帮助,请考虑引用该项目
\ No newline at end of file
+如果您觉得我们的工作对您有所帮助,请考虑引用该项目
diff --git a/src/models/base_models/.gitkeep b/data/.gitkeep
similarity index 100%
rename from src/models/base_models/.gitkeep
rename to data/.gitkeep
diff --git a/src/finetune.py b/finetune.py
similarity index 85%
rename from src/finetune.py
rename to finetune.py
index ff7c0b3..4059fc7 100644
--- a/src/finetune.py
+++ b/finetune.py
@@ -7,6 +7,12 @@ import torch
import transformers
from datasets import load_dataset
+"""
+Unused imports:
+import torch.nn as nn
+import bitsandbytes as bnb
+"""
+
from peft import (
LoraConfig,
get_peft_model,
@@ -15,45 +21,41 @@ from peft import (
set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer
+
from utils.prompter import Prompter
def train(
# model/data params
- base_model: str = "./models/base_models/your_base_model_dir",
- data_path: str = "./data/your_data.json",
- output_dir: str = "./outputs/your_version_dir",
-
+ base_model: str = "", # the only required argument
+ data_path: str = "yahma/alpaca-cleaned",
+ output_dir: str = "./lora-alpaca",
# training hyperparams
batch_size: int = 128,
micro_batch_size: int = 4,
- num_epochs: int = 10,
+ num_epochs: int = 3,
learning_rate: float = 3e-4,
- cutoff_len: int = 512,
+ cutoff_len: int = 256,
val_set_size: int = 2000,
-
# lora hyperparams
lora_r: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
- lora_target_modules: List[str] = ["q_proj", "v_proj",],
-
+ lora_target_modules: List[str] = [
+ "q_proj",
+ "v_proj",
+ ],
# llm hyperparams
train_on_inputs: bool = True, # if False, masks out inputs in loss
add_eos_token: bool = True,
group_by_length: bool = False, # faster, but produces an odd training loss curve
-
# wandb params
wandb_project: str = "",
wandb_run_name: str = "",
wandb_watch: str = "", # options: false | gradients | all
wandb_log_model: str = "", # options: false | true
-
- # either training checkpoint or final adapter
- resume_from_checkpoint: str = None,
-
- # The prompt template to use, will default to alpaca.
- prompt_template_name: str = "alpaca",
+ resume_from_checkpoint: str = None, # either training checkpoint or final adapter
+ prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
):
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
print(
@@ -81,11 +83,13 @@ def train(
f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
f"prompt template: {prompt_template_name}\n"
)
+ assert (
+ base_model
+ ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
gradient_accumulation_steps = batch_size // micro_batch_size
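+    # each optimizer step accumulates batch_size // micro_batch_size micro-batches, keeping the effective batch size at batch_size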
prompter = Prompter(prompt_template_name)
- # Configure device and distributed training
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
@@ -95,8 +99,8 @@ def train(
# Check if parameter passed or if set within environ
use_wandb = len(wandb_project) > 0 or (
- "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0)
-
+ "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
+ )
# Only overwrite environ if wandb param passed
if len(wandb_project) > 0:
os.environ["WANDB_PROJECT"] = wandb_project
@@ -113,21 +117,13 @@ def train(
)
tokenizer = LlamaTokenizer.from_pretrained(base_model)
- tokenizer.bos_token_id = 1
- tokenizer.eos_token_id = 2
- bos = tokenizer.bos_token_id
- eos = tokenizer.eos_token_id
- pad = tokenizer.pad_token_id
-
- print("pre-trained model's BOS EOS and PAD token id:",
- bos, eos, pad, " => It should be 1,2,none")
tokenizer.pad_token_id = (
0 # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left" # Allow batched inference
- def tokenize(prompt, add_eos_token=True):
+ def tokenize(prompt):
# there's probably a way to do this with the tokenizer settings
# but again, gotta move fast
result = tokenizer(
@@ -212,13 +208,18 @@ def train(
else:
print(f"Checkpoint {checkpoint_name} not found")
- # Be more transparent about the % of trainable params.
- model.print_trainable_parameters()
+ model.print_trainable_parameters() # Be more transparent about the % of trainable params.
if val_set_size > 0:
- train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
- train_data = (train_val["train"].shuffle().map(generate_and_tokenize_prompt))
- val_data = (train_val["test"].shuffle().map(generate_and_tokenize_prompt))
+ train_val = data["train"].train_test_split(
+ test_size=val_set_size, shuffle=True, seed=42
+ )
+ train_data = (
+ train_val["train"].shuffle().map(generate_and_tokenize_prompt)
+ )
+ val_data = (
+ train_val["test"].shuffle().map(generate_and_tokenize_prompt)
+ )
else:
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = None
@@ -235,7 +236,7 @@ def train(
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
- warmup_steps=100,
+ warmup_ratio=0.1,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
fp16=True,
@@ -243,10 +244,10 @@ def train(
optim="adamw_torch",
evaluation_strategy="steps" if val_set_size > 0 else "no",
save_strategy="steps",
- eval_steps=100 if val_set_size > 0 else None,
- save_steps=100,
+ eval_steps=50 if val_set_size > 0 else None,
+ save_steps=50,
output_dir=output_dir,
- save_total_limit=3,
+ save_total_limit=5,
load_best_model_at_end=True if val_set_size > 0 else False,
ddp_find_unused_parameters=False if ddp else None,
group_by_length=group_by_length,
@@ -273,7 +274,9 @@ def train(
model.save_pretrained(output_dir)
- print("\n If there's a warning about missing keys above, please disregard :)")
+ print(
+ "\n If there's a warning about missing keys above, please disregard :)"
+ )
if __name__ == "__main__":
diff --git a/src/models/lora_weights/.gitkeep b/models/base_models/.gitkeep
similarity index 100%
rename from src/models/lora_weights/.gitkeep
rename to models/base_models/.gitkeep
diff --git a/src/outputs/.gitkeep b/models/lora_weights/.gitkeep
similarity index 100%
rename from src/outputs/.gitkeep
rename to models/lora_weights/.gitkeep
diff --git a/src/utils/__init__.py b/outputs/.gitkeep
similarity index 100%
rename from src/utils/__init__.py
rename to outputs/.gitkeep
diff --git a/src/data/example_instruction_train.json b/resources/example_instruction_train.json
similarity index 100%
rename from src/data/example_instruction_train.json
rename to resources/example_instruction_train.json
diff --git a/resources/example_instruction_tune.json b/resources/example_instruction_tune.json
new file mode 100644
index 0000000..496d4fd
--- /dev/null
+++ b/resources/example_instruction_tune.json
@@ -0,0 +1,12 @@
+[
+ {
+ "instruction": "酒驾撞人要判多久?",
+ "input": "",
+    "output": "《刑法》第一百三十三条规定:违反交通运输管理法规,因而发生重大事故,致人重伤、死亡或者使公私财产遭受重大损失的,处三年以下有期徒刑或者拘役;交通运输肇事后逃逸,致人重伤的,处三年以上七年以下有期徒刑。交通肇事后逃逸致人死亡的,处七年以上有期徒刑。"
+ },
+ {
+ "instruction": "偷狗被派出所抓到,会怎么处理?",
+ "input": "",
+    "output": "取决于偷盗的狗的价值,可能按盗窃罪的罪名处理。如果价值达到犯罪标准,就构成盗窃罪,要承担刑事责任;如果不到犯罪标准,则按治安管理处罚处理,处以罚款或者拘留。一般而言,涉案价值不到一千元的,不构成犯罪;超过一千元的,则可能构成犯罪。"
+ }
+]
\ No newline at end of file
diff --git a/scripts/finetune.sh b/scripts/finetune.sh
new file mode 100644
index 0000000..e8b5614
--- /dev/null
+++ b/scripts/finetune.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+export WANDB_MODE=disabled # 禁用wandb
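+# 如需记录训练日志,可删除上一行并在下方传入 --wandb_project 等参数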
+
+# 使用chinese-alpaca-plus-7b-merged模型在law_data.json数据集上finetune
+experiment_name="chinese-alpaca-plus-7b-law-e1"
+
+# 单卡或者模型并行
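+# batch_size 为全局(等效)批大小,micro_batch_size 为单次前向/反向的批大小,
+# 二者之比即 finetune.py 中的梯度累积步数(gradient_accumulation_steps)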
+python finetune.py \
+ --base_model "minlik/chinese-alpaca-plus-7b-merged" \
+ --data_path "./data/finetune_law_data.json" \
+ --output_dir "./outputs/"${experiment_name} \
+ --batch_size 64 \
+ --micro_batch_size 8 \
+ --num_epochs 20 \
+ --learning_rate 3e-4 \
+ --cutoff_len 256 \
+ --val_set_size 0 \
+ --lora_r 8 \
+ --lora_alpha 16 \
+ --lora_dropout 0.05 \
+ --lora_target_modules "[q_proj,v_proj]" \
+ --train_on_inputs True \
+ --add_eos_token True \
+ --group_by_length False \
+    --wandb_project "" \
+    --wandb_run_name "" \
+    --wandb_watch "" \
+    --wandb_log_model "" \
+ --resume_from_checkpoint "./outputs/"${experiment_name} \
+ --prompt_template_name "alpaca" \
+
+
+# 多卡数据并行
+# WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1234 finetune.py \
+# --base_model "minlik/chinese-alpaca-plus-7b-merged" \
+# --data_path "./data/finetune_law_data.json" \
+# --output_dir "./outputs/"${experiment_name} \
+# --batch_size 64 \
+# --micro_batch_size 8 \
+# --num_epochs 20 \
+# --learning_rate 3e-4 \
+# --cutoff_len 256 \
+# --val_set_size 0 \
+# --lora_r 8 \
+# --lora_alpha 16 \
+# --lora_dropout 0.05 \
+# --lora_target_modules "[q_proj,v_proj]" \
+# --train_on_inputs True \
+# --add_eos_token True \
+# --group_by_length False \
+#    --wandb_project "" \
+#    --wandb_run_name "" \
+#    --wandb_watch "" \
+#    --wandb_log_model "" \
+# --resume_from_checkpoint "./outputs/"${experiment_name} \
+# --prompt_template_name "alpaca" \
\ No newline at end of file
diff --git a/src/scripts/train.sh b/scripts/train_clm.sh
similarity index 64%
rename from src/scripts/train.sh
rename to scripts/train_clm.sh
index 56532a2..cb45cb6 100644
--- a/src/scripts/train.sh
+++ b/scripts/train_clm.sh
@@ -1,9 +1,9 @@
#!/bin/bash
-WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1235 train_lora.py \
- --base_model '../models/base_models/chinese_llama_7b' \
- --data_path '' \
- --output_dir '../models/lora_weights' \
+WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1235 train_clm.py \
+ --base_model './models/base_models/chinese_llama_7b' \
+ --data_path './data/train_clm_data.json' \
+ --output_dir './outputs/train-clm' \
--batch_size 128 \
--micro_batch_size 8 \
--num_epochs 1 \
@@ -17,4 +17,4 @@ WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --
--train_on_inputs True \
--add_eos_token True \
--group_by_length True \
- --resume_from_checkpoint '../models/lora_weights'
+ --resume_from_checkpoint './outputs/train-clm'
\ No newline at end of file
diff --git a/scripts/webui.sh b/scripts/webui.sh
new file mode 100644
index 0000000..097cd3e
--- /dev/null
+++ b/scripts/webui.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+
+# 使用huggingface上已经训练好的模型
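+# base_model 为 Hugging Face Hub 上已合并的底座模型,lora_weights 为对应的 LoRA 权重,
+# prompt_template 对应 templates/ 目录下的模板文件(此处为 templates/law_template.json)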
+python webui.py \
+ --load_8bit True \
+ --base_model 'minlik/chinese-llama-7b-merged' \
+ --lora_weights 'entity303/lawgpt-lora-7b' \
+ --prompt_template "law_template" \
+ --server_name "0.0.0.0" \
+    --share_gradio True
+
+
+# 使用自己finetune的lora, 把自己的模型放到对应目录即可
+# python webui.py \
+# --load_8bit True \
+# --base_model 'minlik/chinese-alpaca-plus-7b-merged' \
+# --lora_weights './outputs/chinese-alpaca-plus-7b-law-e1' \
+# --prompt_template "alpaca" \
+# --server_name "0.0.0.0" \
+#    --share_gradio True
\ No newline at end of file
diff --git a/src/data/example_instruction_tune.json b/src/data/example_instruction_tune.json
deleted file mode 100644
index f6db36f..0000000
--- a/src/data/example_instruction_tune.json
+++ /dev/null
@@ -1,10 +0,0 @@
-[
- {
- "instruction": "酒驾撞人要判多久?",
- "input": "",
- "output": "《刑法》第一百三十三条规定:违反交通运输管理法规,因而发生重大事故,致人重伤、死广或者使公私财产遭受重大损失的,处三年以下有期徒刑或者拘役;交通运输肇事后逃逸,致人重伤的,处三年以上七年以下有期徒刑。交通肇事后逃逸致人死亡的,处七年以上有期徒刑。"
- },
- {"instruction":"偷狗被派出所抓到,会怎么处理?",
- "input":"",
- "output":"取决于偷盗的狗的价值,可能按盗窃罪的罪名处理。如果价值达到犯罪标准,就构成盗窃罪,要承担刑事责任;如果不到犯罪标准,就是治安处罚、罚款或者拘留治安处罚这会涉嫌构成盗窃。如果不到一千元,则不会构成犯罪。如果超过一千元,则可能会是构成犯罪的。"}
-]
\ No newline at end of file
diff --git a/src/scripts/finetune.sh b/src/scripts/finetune.sh
deleted file mode 100644
index 14bf7a1..0000000
--- a/src/scripts/finetune.sh
+++ /dev/null
@@ -1,17 +0,0 @@
-#!/bin/bash
-
-WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1234 finetune.py \
- --base_model 'minlik/chinese-llama-7b-merged' \
- --data_path '' \
- --output_dir './outputs/LawGPT' \
- --prompt_template_name 'law_template' \
- --micro_batch_size 16 \
- --batch_size 128 \
- --num_epochs 3 \
- --val_set_size 10000 \
- --lora_target_modules='[q_proj,k_proj,v_proj,o_proj]' \
- --lora_r 16 \
- --lora_alpha 32 \
- --learning_rate 3e-4 \
- --cutoff_len 512 \
- --resume_from_checkpoint './outputs/LawGPT' \
\ No newline at end of file
diff --git a/src/scripts/generate.sh b/src/scripts/generate.sh
deleted file mode 100644
index 283007e..0000000
--- a/src/scripts/generate.sh
+++ /dev/null
@@ -1,7 +0,0 @@
-
-CUDA_VISIBLE_DEVICES=1 python generate.py \
- --load_8bit \
- --base_model 'minlik/chinese-llama-7b-merged' \
- --lora_weights 'entity303/lawgpt-lora-7b' \
- --prompt_template 'law_template' \
- --share_gradio
diff --git a/templates/alpaca.json b/templates/alpaca.json
new file mode 100644
index 0000000..e486439
--- /dev/null
+++ b/templates/alpaca.json
@@ -0,0 +1,6 @@
+{
+ "description": "Template used by Alpaca-LoRA.",
+ "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
+ "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
+ "response_split": "### Response:"
+}
diff --git a/src/templates/law_template.json b/templates/law_template.json
similarity index 100%
rename from src/templates/law_template.json
rename to templates/law_template.json
diff --git a/src/train_lora.py b/train_clm.py
similarity index 100%
rename from src/train_lora.py
rename to train_clm.py
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/utils/callbacks.py b/utils/callbacks.py
similarity index 100%
rename from src/utils/callbacks.py
rename to utils/callbacks.py
diff --git a/src/utils/evaluate.py b/utils/evaluate.py
similarity index 100%
rename from src/utils/evaluate.py
rename to utils/evaluate.py
diff --git a/src/utils/merge.py b/utils/merge.py
similarity index 100%
rename from src/utils/merge.py
rename to utils/merge.py
diff --git a/src/utils/prompter.py b/utils/prompter.py
similarity index 100%
rename from src/utils/prompter.py
rename to utils/prompter.py
diff --git a/src/generate.py b/webui.py
similarity index 77%
rename from src/generate.py
rename to webui.py
index bbe5513..453043a 100644
--- a/src/generate.py
+++ b/webui.py
@@ -6,7 +6,7 @@ import gradio as gr
import torch
import transformers
from peft import PeftModel
-from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
+from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, AutoModel, AutoTokenizer, AutoModelForCausalLM
from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter
@@ -19,14 +19,14 @@ else:
try:
if torch.backends.mps.is_available():
device = "mps"
-except: # noqa: E722
+except:
pass
def main(
load_8bit: bool = False,
base_model: str = "",
- lora_weights: str = "tloen/alpaca-lora-7b",
+ lora_weights: str = "",
prompt_template: str = "", # The prompt template to use, will default to alpaca.
server_name: str = "0.0.0.0", # Allows to listen on all interfaces by providing '0.
share_gradio: bool = False,
@@ -45,33 +45,41 @@ def main(
torch_dtype=torch.float16,
device_map="auto",
)
- model = PeftModel.from_pretrained(
- model,
- lora_weights,
- torch_dtype=torch.float16,
- )
+ try:
+ model = PeftModel.from_pretrained(
+ model,
+ lora_weights,
+ torch_dtype=torch.float16,
+ )
+ except:
+ print("*"*50, "\n Attention! No Lora Weights \n", "*"*50)
elif device == "mps":
model = LlamaForCausalLM.from_pretrained(
base_model,
device_map={"": device},
torch_dtype=torch.float16,
)
- model = PeftModel.from_pretrained(
- model,
- lora_weights,
- device_map={"": device},
- torch_dtype=torch.float16,
- )
+ try:
+ model = PeftModel.from_pretrained(
+ model,
+ lora_weights,
+ device_map={"": device},
+ torch_dtype=torch.float16,
+ )
+ except:
+ print("*"*50, "\n Attention! No Lora Weights \n", "*"*50)
else:
model = LlamaForCausalLM.from_pretrained(
- base_model,
- device_map={"": device}, low_cpu_mem_usage=True
- )
- model = PeftModel.from_pretrained(
- model,
- lora_weights,
- device_map={"": device},
+ base_model, device_map={"": device}, low_cpu_mem_usage=True
)
+ try:
+ model = PeftModel.from_pretrained(
+ model,
+ lora_weights,
+ device_map={"": device},
+ )
+ except:
+ print("*"*50, "\n Attention! No Lora Weights \n", "*"*50)
# unwind broken decapoda-research config
model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
@@ -87,15 +95,16 @@ def main(
def evaluate(
instruction,
- input=None,
+ # input=None,
temperature=0.1,
top_p=0.75,
top_k=40,
- num_beams=1,
- max_new_tokens=256,
- stream_output=True,
+ num_beams=4,
+ max_new_tokens=128,
+ stream_output=False,
**kwargs,
):
+ input=None
prompt = prompter.generate_prompt(instruction, input)
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
@@ -144,6 +153,7 @@ def main(
break
yield prompter.get_response(decoded_output)
+ print(decoded_output)
return # early return for stream_output
# Without streaming
@@ -157,6 +167,7 @@ def main(
)
s = generation_output.sequences[0]
output = tokenizer.decode(s)
+ print(output)
yield prompter.get_response(output)
gr.Interface(
@@ -165,11 +176,11 @@ def main(
gr.components.Textbox(
lines=2,
label="Instruction",
- placeholder="Tell me about alpacas.",
+ placeholder="此处输入法律相关问题",
),
- gr.components.Textbox(lines=2, label="Input", placeholder="none"),
+ # gr.components.Textbox(lines=2, label="Input", placeholder="none"),
gr.components.Slider(
- minimum=0, maximum=1, value=1.0, label="Temperature"
+ minimum=0, maximum=1, value=0.1, label="Temperature"
),
gr.components.Slider(
minimum=0, maximum=1, value=0.75, label="Top p"
@@ -178,23 +189,22 @@ def main(
minimum=0, maximum=100, step=1, value=40, label="Top k"
),
gr.components.Slider(
- minimum=1, maximum=4, step=1, value=4, label="Beams"
+ minimum=1, maximum=4, step=1, value=1, label="Beams"
),
gr.components.Slider(
minimum=1, maximum=2000, step=1, value=256, label="Max tokens"
),
- gr.components.Checkbox(label="Stream output", value=True),
+ gr.components.Checkbox(label="Stream output", value=True),
],
outputs=[
gr.inputs.Textbox(
- lines=5,
+ lines=8,
label="Output",
)
],
title="🦙🌲 LaWGPT",
- description="", # noqa: E501
+ description="",
).queue().launch(server_name="0.0.0.0", share=share_gradio)
- # Old testing code follows.
if __name__ == "__main__":