.gitignore (vendored, 20 lines changed)
@@ -1 +1,19 @@
./data
__pycache__/
*.npy
*.npz
*.pyc
*.pyd
*.so
*.ipynb
.ipynb_checkpoints
models/base_models/*
!models/base_models/.gitkeep
models/lora_weights/*
!models/lora_weights/.gitkeep
outputs/*
!outputs/.gitkeep
data/*
!data/.gitkeep
wandb/
flagged/
.DS_Store
README.md (77 lines changed)
@@ -24,6 +24,7 @@ LaWGPT is a series of open-source large language models built on Chinese legal knowledge.

This project is under continuous development; legal-domain datasets and further models in the series will be open-sourced over time. Stay tuned.

## Updates

- 🛠️ 2023/05/22: Restructured the project's main branch; see [Project Structure](https://github.com/pengxiao-song/LaWGPT#项目结构)

- 🪴 2023/05/15: Released the [Awesome Chinese Legal Resources](https://github.com/pengxiao-song/awesome-chinese-legal-resources) collection and a [legal-domain vocabulary](https://github.com/pengxiao-song/LaWGPT/blob/main/resources/legal_vocab.txt)
@@ -44,13 +45,25 @@ LaWGPT is a series of open-source large language models built on Chinese legal knowledge.

1. Get the code and create the environment

```bash
# Clone the repository
git clone git@github.com:pengxiao-song/LaWGPT.git
cd LaWGPT

# Create the environment
conda create -n lawgpt python=3.10 -y
conda activate lawgpt
pip install -r requirements.txt

# Launch the web UI script (automatically downloads the pretrained model, about 15 GB)
bash ./scripts/webui.sh
```
2. Merge model weights (optional)
2. Visit http://127.0.0.1:7860 :

<p align="center">
<img src="./assets/demo/demo.png" width="80%" >
</p>
3. Merge model weights (optional)

**If you want to use the LaWGPT-7B-alpha model, you can skip this step and go straight to Step 3.**

@@ -61,44 +74,28 @@ LaWGPT is a series of open-source large language models built on Chinese legal knowledge.

This project provides a [merging guide](https://github.com/pengxiao-song/LaWGPT/wiki/%E6%A8%A1%E5%9E%8B%E5%90%88%E5%B9%B6); please obtain the original base weights yourself and reconstruct the model accordingly.
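As an illustration of what the merge amounts to (a generic sketch only, not the project's documented procedure; the paths below are placeholders, and the wiki linked above remains the authoritative reference), a LoRA adapter can in general be folded into base weights with PEFT:

```python
# Generic LoRA-merge sketch with transformers + peft; placeholder paths only.
# Follow the project's merging guide for the official steps.
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

base = LlamaForCausalLM.from_pretrained(
    "path/to/original_llama_hf", torch_dtype=torch.float16
)
merged = PeftModel.from_pretrained(base, "path/to/lora_weights").merge_and_unload()

merged.save_pretrained("./models/base_models/merged_model")
LlamaTokenizer.from_pretrained("path/to/original_llama_hf").save_pretrained(
    "./models/base_models/merged_model"
)
```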
3. Launch the demo

Start the local service:

```bash
conda activate lawgpt
cd LaWGPT
sh src/scripts/generate.sh
```

Access the service:

<p align="center">
<img src="./assets/demo/demo.png" width="80%" >
</p>
## Project Structure

```bash
```bash
LaWGPT
├── assets # project static assets
├── data # corpora and fine-tuning data
├── tools # data cleaning and other tools
├── assets # static assets
├── resources # project resources
├── models # base models and LoRA weights
│ ├── base_models
│ └── lora_weights
├── outputs # output weights from instruction fine-tuning
├── data # experiment data
├── scripts # script directory
│ ├── finetune.sh # instruction fine-tuning script
│ └── webui.sh # service launch script
├── templates # prompt templates
├── tools # toolkit
├── utils
├── train_clm.py # continued training
├── finetune.py # instruction fine-tuning
├── webui.py # service launcher
├── README.md
├── requirements.txt
└── src # source code
├── finetune.py
├── generate.py
├── models # base models and LoRA weights
│ ├── base_models
│ └── lora_weights
├── outputs
├── scripts # script files
│ ├── finetune.sh # instruction fine-tuning
│ └── generate.sh # service creation
├── templates
└── utils
└── requirements.txt
```
@@ -119,13 +116,13 @@ The LawGPT series is trained in two stages:

### Continued (secondary) training

1. Build the continued-training dataset, following `src/data/example_instruction_train.json`
2. Run `src/scripts/train_lora.sh`
1. Build the continued-training dataset, following `resources/example_instruction_train.json`
2. Run `scripts/train_clm.sh`

### Instruction fine-tuning

1. Build the instruction fine-tuning dataset, following `src/data/example_instruction_tune.json`
2. Run `src/scripts/finetune.sh`
1. Build the instruction fine-tuning dataset, following `resources/example_instruction_tune.json` (the record format is sketched below)
2. Run `scripts/finetune.sh`
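The instruction fine-tuning records follow the Alpaca-style `instruction` / `input` / `output` schema shown in `resources/example_instruction_tune.json`. A minimal sketch of building such a file (the output path matches the one used by `scripts/finetune.sh`; the answer text is a placeholder):

```python
# Minimal sketch: write instruction-tuning records in the
# instruction/input/output format read by finetune.py.
import json

records = [
    {
        "instruction": "酒驾撞人要判多久?",  # the question posed to the model
        "input": "",                          # optional extra context, often empty
        "output": "……",                       # reference answer (placeholder here)
    },
]

with open("data/finetune_law_data.json", "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False, indent=2)
```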
### Compute resources

@@ -222,4 +219,4 @@ The LawGPT series is trained in two stages:

## Citation

If you find our work helpful, please consider citing this project
If you find our work helpful, please consider citing this project
@@ -7,6 +7,12 @@ import torch
import transformers
from datasets import load_dataset

"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""

from peft import (
    LoraConfig,
    get_peft_model,
@@ -15,45 +21,41 @@ from peft import (
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer

from utils.prompter import Prompter


def train(
    # model/data params
    base_model: str = "./models/base_models/your_base_model_dir",
    data_path: str = "./data/your_data.json",
    output_dir: str = "./outputs/your_version_dir",

    base_model: str = "", # the only required argument
    data_path: str = "yahma/alpaca-cleaned",
    output_dir: str = "./lora-alpaca",
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 10,
    num_epochs: int = 3,
    learning_rate: float = 3e-4,
    cutoff_len: int = 512,
    cutoff_len: int = 256,
    val_set_size: int = 2000,

    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = ["q_proj", "v_proj",],

    lora_target_modules: List[str] = [
        "q_proj",
        "v_proj",
    ],
    # llm hyperparams
    train_on_inputs: bool = True, # if False, masks out inputs in loss
    add_eos_token: bool = True,
    group_by_length: bool = False, # faster, but produces an odd training loss curve

    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "", # options: false | gradients | all
    wandb_log_model: str = "", # options: false | true

    # either training checkpoint or final adapter
    resume_from_checkpoint: str = None,

    # The prompt template to use, will default to alpaca.
    prompt_template_name: str = "alpaca",
    resume_from_checkpoint: str = None, # either training checkpoint or final adapter
    prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
):
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
@@ -81,11 +83,13 @@ def train(
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
        )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    prompter = Prompter(prompt_template_name)

    # Configure device and distributed training
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
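A quick note on the schedule implied by the defaults above (an editorial aside, not part of the diff): with `batch_size = 128` and `micro_batch_size = 4`, each optimizer step accumulates gradients over 32 micro-batches.

```python
# Sketch of the effective-batch arithmetic used by train() above.
batch_size = 128        # target examples per optimizer update
micro_batch_size = 4    # examples actually pushed through the model at once

gradient_accumulation_steps = batch_size // micro_batch_size
assert gradient_accumulation_steps == 32
# The Trainer runs with per_device_train_batch_size=4 and accumulates 32
# micro-batches, so one update covers 4 * 32 = 128 examples (single-device case).
```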
@@ -95,8 +99,8 @@ def train(

    # Check if parameter passed or if set within environ
    use_wandb = len(wandb_project) > 0 or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0)

        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    # Only overwrite environ if wandb param passed
    if len(wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = wandb_project
@@ -113,21 +117,13 @@ def train(
    )

    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    bos = tokenizer.bos_token_id
    eos = tokenizer.eos_token_id
    pad = tokenizer.pad_token_id

    print("pre-trained model's BOS EOS and PAD token id:",
          bos, eos, pad, " => It should be 1,2,none")

    tokenizer.pad_token_id = (
        0 # unk. we want this to be different from the eos token
    )
    tokenizer.padding_side = "left" # Allow batched inference

    def tokenize(prompt, add_eos_token=True):
    def tokenize(prompt):
        # there's probably a way to do this with the tokenizer settings
        # but again, gotta move fast
        result = tokenizer(
@@ -212,13 +208,18 @@ def train(
    else:
        print(f"Checkpoint {checkpoint_name} not found")

    # Be more transparent about the % of trainable params.
    model.print_trainable_parameters()
    model.print_trainable_parameters() # Be more transparent about the % of trainable params.

    if val_set_size > 0:
        train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
        train_data = (train_val["train"].shuffle().map(generate_and_tokenize_prompt))
        val_data = (train_val["test"].shuffle().map(generate_and_tokenize_prompt))
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = (
            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        )
        val_data = (
            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
        )
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None
@@ -235,7 +236,7 @@ def train(
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            warmup_ratio=0.1,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
@@ -243,10 +244,10 @@ def train(
            optim="adamw_torch",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=100 if val_set_size > 0 else None,
            save_steps=100,
            eval_steps=50 if val_set_size > 0 else None,
            save_steps=50,
            output_dir=output_dir,
            save_total_limit=3,
            save_total_limit=5,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
@@ -273,7 +274,9 @@ def train(

    model.save_pretrained(output_dir)

    print("\n If there's a warning about missing keys above, please disregard :)")
    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )


if __name__ == "__main__":
resources/example_instruction_tune.json (new file, 12 lines)

@@ -0,0 +1,12 @@
[
    {
        "instruction": "酒驾撞人要判多久?",
        "input": "",
        "output": "《刑法》第一百三十三条规定:违反交通运输管理法规,因而发生重大事故,致人重伤、死亡或者使公私财产遭受重大损失的,处三年以下有期徒刑或者拘役;交通运输肇事后逃逸,致人重伤的,处三年以上七年以下有期徒刑。交通肇事后逃逸致人死亡的,处七年以上有期徒刑。"
    },
    {
        "instruction": "偷狗被派出所抓到,会怎么处理?",
        "input": "",
        "output": "取决于偷盗的狗的价值,可能按盗窃罪的罪名处理。如果价值达到犯罪标准,就构成盗窃罪,要承担刑事责任;如果不到犯罪标准,就是治安处罚、罚款或者拘留治安处罚这会涉嫌构成盗窃。如果不到一千元,则不会构成犯罪。如果超过一千元,则可能会是构成犯罪的。"
    }
]
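For reference, records in this format can be loaded the same way `finetune.py` reads a local JSON `data_path` (a small sketch using the example file above):

```python
# Sketch: load instruction/input/output records with the datasets API,
# mirroring how finetune.py handles a local JSON data_path.
from datasets import load_dataset

data = load_dataset("json", data_files="resources/example_instruction_tune.json")
print(data["train"][0]["instruction"])  # "酒驾撞人要判多久?"
```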
scripts/finetune.sh (new file, 56 lines)

@@ -0,0 +1,56 @@
#!/bin/bash
export WANDB_MODE=disabled # disable wandb

# Fine-tune the chinese-alpaca-plus-7b-merged model on the law_data.json dataset
experiment_name="chinese-alpaca-plus-7b-law-e1"

# Single GPU or model parallelism
python finetune.py \
    --base_model "minlik/chinese-alpaca-plus-7b-merged" \
    --data_path "./data/finetune_law_data.json" \
    --output_dir "./outputs/"${experiment_name} \
    --batch_size 64 \
    --micro_batch_size 8 \
    --num_epochs 20 \
    --learning_rate 3e-4 \
    --cutoff_len 256 \
    --val_set_size 0 \
    --lora_r 8 \
    --lora_alpha 16 \
    --lora_dropout 0.05 \
    --lora_target_modules "[q_proj,v_proj]" \
    --train_on_inputs True \
    --add_eos_token True \
    --group_by_length False \
    --wandb_project \
    --wandb_run_name \
    --wandb_watch \
    --wandb_log_model \
    --resume_from_checkpoint "./outputs/"${experiment_name} \
    --prompt_template_name "alpaca" \


# Multi-GPU data parallelism
# WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1234 finetune.py \
# --base_model "minlik/chinese-alpaca-plus-7b-merged" \
# --data_path "./data/finetune_law_data.json" \
# --output_dir "./outputs/"${experiment_name} \
# --batch_size 64 \
# --micro_batch_size 8 \
# --num_epochs 20 \
# --learning_rate 3e-4 \
# --cutoff_len 256 \
# --val_set_size 0 \
# --lora_r 8 \
# --lora_alpha 16 \
# --lora_dropout 0.05 \
# --lora_target_modules "[q_proj,v_proj]" \
# --train_on_inputs True \
# --add_eos_token True \
# --group_by_length False \
# --wandb_project \
# --wandb_run_name \
# --wandb_watch \
# --wandb_log_model \
# --resume_from_checkpoint "./outputs/"${experiment_name} \
# --prompt_template_name "alpaca" \
@@ -1,9 +1,9 @@
#!/bin/bash

WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1235 train_lora.py \
    --base_model '../models/base_models/chinese_llama_7b' \
    --data_path '' \
    --output_dir '../models/lora_weights' \
WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1235 train_clm.py \
    --base_model './models/base_models/chinese_llama_7b' \
    --data_path './data/train_clm_data.json' \
    --output_dir './outputs/train-clm' \
    --batch_size 128 \
    --micro_batch_size 8 \
    --num_epochs 1 \
@@ -17,4 +17,4 @@ WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --
    --train_on_inputs True \
    --add_eos_token True \
    --group_by_length True \
    --resume_from_checkpoint '../models/lora_weights'
    --resume_from_checkpoint './outputs/train-clm'
scripts/webui.sh (new file, 21 lines)

@@ -0,0 +1,21 @@
#!/bin/bash


# Use the trained model already published on Hugging Face
python webui.py \
    --load_8bit True \
    --base_model 'minlik/chinese-llama-7b-merged' \
    --lora_weights 'entity303/lawgpt-lora-7b' \
    --prompt_template "law_template" \
    --server_name "0.0.0.0" \
    --share_gradio True \


# Use your own fine-tuned LoRA; just put your model in the corresponding directory
# python webui.py \
# --load_8bit True \
# --base_model 'minlik/chinese-alpaca-plus-7b-merged' \
# --lora_weights './outputs/chinese-alpaca-plus-7b-law-e1' \
# --prompt_template "alpaca" \
# --server_name "0.0.0.0" \
# --share_gradio True \
@@ -1,10 +0,0 @@
[
    {
        "instruction": "酒驾撞人要判多久?",
        "input": "",
        "output": "《刑法》第一百三十三条规定:违反交通运输管理法规,因而发生重大事故,致人重伤、死亡或者使公私财产遭受重大损失的,处三年以下有期徒刑或者拘役;交通运输肇事后逃逸,致人重伤的,处三年以上七年以下有期徒刑。交通肇事后逃逸致人死亡的,处七年以上有期徒刑。"
    },
    {"instruction":"偷狗被派出所抓到,会怎么处理?",
    "input":"",
    "output":"取决于偷盗的狗的价值,可能按盗窃罪的罪名处理。如果价值达到犯罪标准,就构成盗窃罪,要承担刑事责任;如果不到犯罪标准,就是治安处罚、罚款或者拘留治安处罚这会涉嫌构成盗窃。如果不到一千元,则不会构成犯罪。如果超过一千元,则可能会是构成犯罪的。"}
]
@@ -1,17 +0,0 @@
#!/bin/bash

WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1234 finetune.py \
    --base_model 'minlik/chinese-llama-7b-merged' \
    --data_path '' \
    --output_dir './outputs/LawGPT' \
    --prompt_template_name 'law_template' \
    --micro_batch_size 16 \
    --batch_size 128 \
    --num_epochs 3 \
    --val_set_size 10000 \
    --lora_target_modules='[q_proj,k_proj,v_proj,o_proj]' \
    --lora_r 16 \
    --lora_alpha 32 \
    --learning_rate 3e-4 \
    --cutoff_len 512 \
    --resume_from_checkpoint './outputs/LawGPT' \
@@ -1,7 +0,0 @@

CUDA_VISIBLE_DEVICES=1 python generate.py \
    --load_8bit \
    --base_model 'minlik/chinese-llama-7b-merged' \
    --lora_weights 'entity303/lawgpt-lora-7b' \
    --prompt_template 'law_template' \
    --share_gradio
templates/alpaca.json (new file, 6 lines)

@@ -0,0 +1,6 @@
{
    "description": "Template used by Alpaca-LoRA.",
    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
    "response_split": "### Response:"
}
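The project's own template handling lives in `utils/prompter.Prompter`, which is not shown in this diff. As a rough sketch of what filling the template amounts to, under the assumption that `{instruction}` and `{input}` are substituted with plain `str.format`:

```python
# Rough sketch of template filling (assumed behavior; the actual logic is in
# utils/prompter.Prompter, which is not part of this diff).
import json

with open("templates/alpaca.json", encoding="utf-8") as f:
    template = json.load(f)

instruction = "酒驾撞人要判多久?"
prompt = template["prompt_no_input"].format(instruction=instruction)
print(prompt)  # ends with "### Response:\n"; the model's completion follows it

# A model reply can then be split off with:
# response = full_output.split(template["response_split"])[1].strip()
```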
utils/__init__.py (new empty file)
@@ -6,7 +6,7 @@ import gradio as gr
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, AutoModel, AutoTokenizer, AutoModelForCausalLM

from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter
@@ -19,14 +19,14 @@ else:
try:
    if torch.backends.mps.is_available():
        device = "mps"
except: # noqa: E722
except:
    pass

def main(
    load_8bit: bool = False,
    base_model: str = "",
    lora_weights: str = "tloen/alpaca-lora-7b",
    lora_weights: str = "",
    prompt_template: str = "", # The prompt template to use, will default to alpaca.
    server_name: str = "0.0.0.0", # Allows to listen on all interfaces by providing '0.
    share_gradio: bool = False,
@@ -45,33 +45,41 @@ def main(
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            torch_dtype=torch.float16,
        )
        try:
            model = PeftModel.from_pretrained(
                model,
                lora_weights,
                torch_dtype=torch.float16,
            )
        except:
            print("*"*50, "\n Attention! No Lora Weights \n", "*"*50)
    elif device == "mps":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
        try:
            model = PeftModel.from_pretrained(
                model,
                lora_weights,
                device_map={"": device},
                torch_dtype=torch.float16,
            )
        except:
            print("*"*50, "\n Attention! No Lora Weights \n", "*"*50)
    else:
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            device_map={"": device}, low_cpu_mem_usage=True
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            device_map={"": device},
            base_model, device_map={"": device}, low_cpu_mem_usage=True
        )
        try:
            model = PeftModel.from_pretrained(
                model,
                lora_weights,
                device_map={"": device},
            )
        except:
            print("*"*50, "\n Attention! No Lora Weights \n", "*"*50)

    # unwind broken decapoda-research config
    model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
@@ -87,15 +95,16 @@ def main(

    def evaluate(
        instruction,
        input=None,
        # input=None,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=1,
        max_new_tokens=256,
        stream_output=True,
        num_beams=4,
        max_new_tokens=128,
        stream_output=False,
        **kwargs,
    ):
        input=None
        prompt = prompter.generate_prompt(instruction, input)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
@@ -144,6 +153,7 @@ def main(
                    break

                yield prompter.get_response(decoded_output)
                print(decoded_output)
            return # early return for stream_output

        # Without streaming
@@ -157,6 +167,7 @@ def main(
        )
        s = generation_output.sequences[0]
        output = tokenizer.decode(s)
        print(output)
        yield prompter.get_response(output)

    gr.Interface(
@@ -165,11 +176,11 @@ def main(
            gr.components.Textbox(
                lines=2,
                label="Instruction",
                placeholder="Tell me about alpacas.",
                placeholder="此处输入法律相关问题",
            ),
            gr.components.Textbox(lines=2, label="Input", placeholder="none"),
            # gr.components.Textbox(lines=2, label="Input", placeholder="none"),
            gr.components.Slider(
                minimum=0, maximum=1, value=1.0, label="Temperature"
                minimum=0, maximum=1, value=0.1, label="Temperature"
            ),
            gr.components.Slider(
                minimum=0, maximum=1, value=0.75, label="Top p"
@@ -178,23 +189,22 @@
                minimum=0, maximum=100, step=1, value=40, label="Top k"
            ),
            gr.components.Slider(
                minimum=1, maximum=4, step=1, value=4, label="Beams"
                minimum=1, maximum=4, step=1, value=1, label="Beams"
            ),
            gr.components.Slider(
                minimum=1, maximum=2000, step=1, value=256, label="Max tokens"
            ),
            gr.components.Checkbox(label="Stream output", value=True),
            gr.components.Checkbox(label="Stream output", value=True),
        ],
        outputs=[
            gr.inputs.Textbox(
                lines=5,
                lines=8,
                label="Output",
            )
        ],
        title="🦙🌲 LaWGPT",
        description="", # noqa: E501
        description="",
    ).queue().launch(server_name="0.0.0.0", share=share_gradio)
    # Old testing code follows.


if __name__ == "__main__":