Initialize repo
This commit is contained in:
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
config.yaml
|
||||
src/
|
||||
data/
|
||||
112
README.md
112
README.md
@ -1,2 +1,110 @@
|
||||
# LawGPT
|
||||
Repo for LawGPT, Llama-7B tuned with Chinese Legal knowledge. 基于中文法律知识的LLaMA微调模型
|
||||
# LaWGPT:基于中文法律知识的大语言模型
|
||||
|
||||
<!--  -->
|
||||

|
||||
|
||||
<p align="center">
|
||||
<a href=""><img src="https://img.shields.io/badge/version-alpha1.0-blue"></a>
|
||||
<a href=""><img src="https://img.shields.io/github/last-commit/pengxiao-song/lawgpt"></a>
|
||||
<a href="https://www.lamda.nju.edu.cn/"><img src="https://img.shields.io/badge/support-NJU--LAMDA-9cf.svg"></a>
|
||||
|
||||
</p>
|
||||
|
||||
LaWGPT 是一系列中文法律知识增强的开源大语言模型。
|
||||
|
||||
该系列模型在 Chinese-LLaMA 的基础上扩充了法律领域词表,并使用大规模中文法律文书、中文法典进行预训练,增强了模型在法律领域的基础语义理解能力。在此基础上,构造多种法律领域对话问答数据集进行指令精调,提升了模型对法律内容的理解和执行能力。
|
||||
|
||||
详细内容请参考技术报告。
|
||||
|
||||
---
|
||||
|
||||
本项目持续开展,后续会相继开源法律领域对话问答数据集及 LaWGPT-13B 的模型。
|
||||
|
||||
|
||||
## 更新
|
||||
|
||||
- 💦 2023/04/25:公开发布 LawGPT-7B alpha1.0(内测版)供初步测试使用
|
||||
- 基于 Chinese-LLaMA 使用 50w 中文裁判文书数据二次预训练
|
||||
|
||||
## 快速开始
|
||||
|
||||
**1. 准备代码,创建环境**
|
||||
|
||||
```bash
|
||||
git clone git@github.com:pengxiao-song/LaWGPT.git
|
||||
cd LawGPT
|
||||
conda env create -f environment.yml
|
||||
conda activate lawgpt
|
||||
```
|
||||
**2. 下载模型权重**
|
||||
|
||||
**3. 启动示例**
|
||||
|
||||
## 项目结构
|
||||
|
||||
|
||||
## 数据构建
|
||||
|
||||
## 模型训练
|
||||
|
||||
中文法律基座模型 LawGPT 的训练过程分为三个阶段:
|
||||
|
||||
1. 第一阶段:扩充法律领域词表,在大规模法律文书及法典数据上预训练 Chinese-LLaMA
|
||||
2. 第二阶段:构造法律领域对话问答数据集,在预训练模型基础上指令精调
|
||||
|
||||
### 计算资源
|
||||
|
||||
8 张 Tesla V100-SXM2-32GB
|
||||
|
||||
### 训练细节
|
||||
|
||||
|
||||
## 模型评估
|
||||
|
||||
评估工作正有序开展,敬请期待。
|
||||
|
||||
## 局限性
|
||||
|
||||
由于计算资源、数据规模等因素限制,当前阶段 LawGPT 存在诸多局限性:
|
||||
|
||||
1. 数据资源有限、模型容量较小,导致其相对较弱的模型记忆和语言能力。因此,在面对事实性知识任务时,可能会生成不正确的结果。
|
||||
2. 该系列模型只进行了初步的人类意图对齐。因此,可能产生不可预测的有害内容以及不符合人类偏好和价值观的内容。
|
||||
3. 自我认知能力存在问题,中文理解能力有待增强。
|
||||
|
||||
请诸君在使用前了解上述问题,以免造成误解和不必要的麻烦。
|
||||
|
||||
## 协作者
|
||||
|
||||
本项目由[南京大学机器学习与数据挖掘研究所(LAMDA)](https://www.lamda.nju.edu.cn/CH.MainPage.ashx)支持。
|
||||
|
||||
如下各位合作开展(按字母序排列):[金苡萱](https://www.lamda.nju.edu.cn/jinyx/)、[宋鹏霄](https://www.lamda.nju.edu.cn/songpx/)、[杨骁文](https://github.com/njuyxw),由[郭兰哲](https://www.lamda.nju.edu.cn/guolz/)老师、[李宇峰](https://cs.nju.edu.cn/liyf/index.htm)老师指导。
|
||||
|
||||
## 免责声明
|
||||
|
||||
请各位严格遵守如下约定:
|
||||
|
||||
1. 本项目任何资源**仅供学术研究使用,严禁任何商业用途**。
|
||||
2. 模型输出受多种不确定性因素影响,本项目当前无法保证其准确性,**严禁用于真实法律场景**。
|
||||
3. 本项目不承担任何法律责任,亦不对因使用相关资源和输出结果而可能产生的任何损失承担责任。
|
||||
|
||||
## 问题反馈
|
||||
|
||||
如有问题,请于 GitHub Issue 中提交。请礼貌讨论,构建和谐交流环境。
|
||||
|
||||
> **协作者科研之余全力推进项目进展,由于人力有限难以实时反馈,给诸君带来不便,敬请谅解!**
|
||||
|
||||
## 致谢
|
||||
|
||||
本项目基于部分开源项目及公开数据集展开,在此对相关项目和研究开发人员表示诚挚的感谢:
|
||||
|
||||
- Chinese-LLaMA-Alpaca: https://github.com/ymcui/Chinese-LLaMA-Alpaca
|
||||
- LLaMA: https://github.com/facebookresearch/llama
|
||||
- Alpaca: https://github.com/tatsu-lab/stanford_alpaca
|
||||
- alpaca-lora: https://github.com/tloen/alpaca-lora
|
||||
- ChatGLM-6B: https://github.com/THUDM/ChatGLM-6B
|
||||
|
||||
|
||||
|
||||
## 引用
|
||||
|
||||
如果您觉得我们的工作对您有所帮助,请考虑引用如下内容
|
||||
BIN
assets/logo/lamda.png
Normal file
BIN
assets/logo/lamda.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 37 KiB |
BIN
assets/logo/lawgpt1.png
Normal file
BIN
assets/logo/lawgpt1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 121 KiB |
BIN
assets/logo/lawgpt2.jpeg
Normal file
BIN
assets/logo/lawgpt2.jpeg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 537 KiB |
59
scripts/generate_instructions.py
Normal file
59
scripts/generate_instructions.py
Normal file
@ -0,0 +1,59 @@
|
||||
import argparse
|
||||
import openai
|
||||
import yaml
|
||||
import sys
|
||||
import random
|
||||
|
||||
|
||||
def return_random_prompt():
|
||||
system_prompt = "你需要尽可能给出多样化的任务指令和对应的回答。我们将用于人工评估ChatGPT模型对指令的完成情况。要求:\n"
|
||||
|
||||
# generate random topics
|
||||
system_prompt += "1. 主题多样化,涵盖法律诉讼的各个领域,例如:刑法、民法、行政法等。\n"
|
||||
|
||||
# generate random tasks
|
||||
task_list = ["开放式生成", "分类", "问答", "编辑", "摘要",
|
||||
"写作", "翻译", "分析", "常识推理", "写信", "抽取", "推荐"]
|
||||
system_prompt += "2. 表述多样化,结合真实问题;指令类型多样化,例如:" + \
|
||||
"、".join(random.sample(task_list, 10)) + "等。\n"
|
||||
|
||||
# other requirements
|
||||
system_prompt += "3. 如果遇到无法处理的指令(只靠文本无法回答),给出无法处理的回复。\n"
|
||||
system_prompt += "4. 除非特别要求,请使用中文,指令可以是命令句、疑问句、或其他合适的类型。\n"
|
||||
system_prompt += "5. 为指令生成一个适当且涉及真实情况的<input>,不应该只包含简单的占位符。<input>应提供实质性的内容,具有挑战性。字数不超过" + \
|
||||
str(random.randint(80, 120)) + "字。\n"
|
||||
system_prompt += "6. <output>应该是对指令的适当且真实的回应,不能只回复答应或拒绝请求。如果需要额外信息才能回复时,请努力预测用户意图并尝试回复。<output>的内容应少于" + \
|
||||
str(random.randint(128, 512)) + "字。\n\n"
|
||||
|
||||
system_prompt += "请给出满足条件的20条JSON格式数据:\n"
|
||||
|
||||
return system_prompt
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--cfg_path', default='../config.yaml', type=str)
|
||||
parser.add_argument('--save_path', default='./output.json', type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.cfg_path, 'r') as f:
|
||||
cfg = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
openai.api_key = cfg['API_KEY']
|
||||
openai.api_base = cfg['API_BASE_URL']
|
||||
|
||||
output_file = open(args.save_path, 'w')
|
||||
|
||||
# number of data to generate (each prompt contains 20 JSON-formatted data)
|
||||
# TODO: 改成流式的,不然会中途断掉
|
||||
MAX_EPOCHS = 1
|
||||
for k in range(MAX_EPOCHS):
|
||||
response = openai.ChatCompletion.create(
|
||||
# here we use `gpt-3.5-turbo` model, while Stanford-Alpaca uses `text-davinci-003`
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[
|
||||
{"role": "user", "content": return_random_prompt()},
|
||||
]
|
||||
)
|
||||
output_file.write(response["choices"][0]["message"]["content"] + '\n')
|
||||
output_file.close()
|
||||
63
scripts/merge_vocabulary.py
Normal file
63
scripts/merge_vocabulary.py
Normal file
@ -0,0 +1,63 @@
|
||||
from transformers import LlamaTokenizer
|
||||
from sentencepiece import sentencepiece_model_pb2 as model
|
||||
import sentencepiece as sp
|
||||
import argparse
|
||||
import os
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Load arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--load_path', default='../src/models/base_model/chinese_llama_7b/tokenizer_chinese.model', type=str)
|
||||
parser.add_argument('--save_dir', default='../src/models/base_model/save_chinese', type=str)
|
||||
parser.add_argument('--voc_path', default='../data/vocabulary/legal_vocab_processed.txt', type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
LOAD_PATH = args.load_path
|
||||
SAVE_DIR = args.save_dir
|
||||
VOC_PATH = args.voc_path
|
||||
|
||||
# Load pre-trained llama tokenizer and sentencepiece model
|
||||
llama_spm = model.ModelProto()
|
||||
llama_spm.ParseFromString(open(LOAD_PATH, "rb").read())
|
||||
|
||||
# show size of llama's vocabulary
|
||||
llama_spm_tokens_set = set(p.piece for p in llama_spm.pieces)
|
||||
print(f"Size of initial llama's vocabulary: {len(llama_spm_tokens_set)}")
|
||||
|
||||
# Load custom vocabulary
|
||||
new_tokens = open(VOC_PATH, "r").read().split("\n")
|
||||
for token in new_tokens:
|
||||
if token not in llama_spm_tokens_set:
|
||||
new_token = model.ModelProto().SentencePiece()
|
||||
new_token.piece = token
|
||||
new_token.score = 0
|
||||
llama_spm.pieces.append(new_token)
|
||||
print(f"Size of merged llama's vocabulary: {len(llama_spm.pieces)}")
|
||||
|
||||
# save
|
||||
os.makedirs(SAVE_DIR, exist_ok=True)
|
||||
SAVE_MODEL_PATH = os.path.join(SAVE_DIR, 'tokenizer.model')
|
||||
SAVE_VOCAB_PATH = os.path.join(SAVE_DIR, 'tokenizer.vocab')
|
||||
with open(SAVE_MODEL_PATH, 'wb') as f:
|
||||
f.write(llama_spm.SerializeToString())
|
||||
with open(SAVE_VOCAB_PATH, 'w') as f:
|
||||
f.writelines([f'{token.piece} {token.score}\n' for token in llama_spm.pieces])
|
||||
tokenizer = LlamaTokenizer(SAVE_MODEL_PATH)
|
||||
tokenizer.save_pretrained(SAVE_DIR)
|
||||
print(f'New llama tokenizer and spm has been saved to {SAVE_DIR}')
|
||||
|
||||
# test
|
||||
llama_tokenizer_old = LlamaTokenizer.from_pretrained(LOAD_PATH)
|
||||
llama_tokenizer_new = LlamaTokenizer.from_pretrained(SAVE_DIR)
|
||||
text = '''登记错误赔偿责任登记等手续登记等手续生效登记机构和登记办法登记机构赔偿后登记机构应当提供登记收费问题'''
|
||||
|
||||
print(f'Size of old vocabulary: {llama_tokenizer_old.vocab_size}')
|
||||
print(f'Size of new vocabulary: {llama_tokenizer_new.vocab_size}')
|
||||
print('All special tokens and ids in new llama:')
|
||||
print(llama_tokenizer_new.all_special_tokens)
|
||||
print(llama_tokenizer_new.all_special_ids)
|
||||
print(llama_tokenizer_new.special_tokens_map)
|
||||
|
||||
print(f'Text:\n{text}')
|
||||
print(f'Tokenized by LLaMA tokenizer:\n {llama_tokenizer_old.tokenize(text)}')
|
||||
print(f'Tokenized by NEW LLaMA tokenizer:\n {llama_tokenizer_new.tokenize(text)}')
|
||||
Reference in New Issue
Block a user