Release code repo

This commit is contained in:
songpx
2023-05-13 17:26:44 +08:00
parent 95b621693c
commit 7652c0071d
21 changed files with 1036 additions and 3896 deletions

4
.gitignore vendored

@@ -1,3 +1 @@
config.yaml
src/
data/
data/

README.md

@@ -2,7 +2,7 @@
<p align="center">
<a href="./assets/logo/lawgpt2.jpeg">
<img src="./assets/logo/lawgpt2.jpeg" width="70%" >
<img src="./assets/logo/lawgpt2.jpeg" width="80%" >
</a>
</p>
@@ -10,24 +10,25 @@
<a href=""><img src="https://img.shields.io/badge/version-alpha1.0-blue"></a>
<a href=""><img src="https://img.shields.io/github/last-commit/pengxiao-song/lawgpt"></a>
<a href="https://www.lamda.nju.edu.cn/"><img src="https://img.shields.io/badge/support-NJU--LAMDA-9cf.svg"></a>
</p>
LaWGPT is a series of open-source large language models enhanced with Chinese legal knowledge.
LaWGPT is a series of open-source large language models grounded in Chinese legal knowledge.
The series expands the legal-domain vocabulary on top of Chinese-LLaMA and is further pre-trained on large-scale Chinese judicial documents and statute books, strengthening the model's foundational semantic understanding of the legal domain. On this basis, it is instruction fine-tuned on several constructed legal dialogue QA datasets, improving the model's ability to understand and carry out legal tasks.
The series expands the legal-domain vocabulary on top of general-purpose Chinese base models (e.g., Chinese-LLaMA, ChatGLM) and is **pre-trained on large-scale Chinese legal corpora**, strengthening the model's foundational semantic understanding of the legal domain. On this basis, it is **instruction fine-tuned on constructed legal dialogue QA datasets and a Chinese judicial examination dataset**, improving the model's ability to understand and carry out legal tasks.
For details, please refer to the technical report.
---
This project is ongoing; the legal-domain dialogue QA datasets and the LaWGPT-13B model will be open-sourced later.
This project is ongoing; the legal-domain datasets and the model series will be open-sourced successively. Stay tuned.
## Updates
- 💦 2023/04/25: publicly released LawGPT-7B alpha1.0 (internal beta) for preliminary testing
- Further pre-trained Chinese-LLaMA on 500k Chinese judicial documents
- 🔥🔥🔥 2023/05/13: publicly released legal-base-7b and lawgpt-7b-beta1.0
- legal-base-7b: based on Chinese-LLaMA-7B, further pre-trained on 500k Chinese judicial documents
- lawgpt-7b-beta1.0: based on legal-base-7b, instruction fine-tuned on a constructed set of 300k high-quality legal QA pairs
- 🔥🔥🔥 2023/04/12: internally tested lawgpt-7b-alpha
- lawgpt-7b-alpha: based on Chinese-LLaMA-7B, instruction fine-tuned on a constructed set of 300k legal QA pairs
## Quick Start
@@ -35,7 +36,7 @@ LaWGPT is a series of open-source large language models enhanced with Chinese legal knowledge.
```bash
git clone git@github.com:pengxiao-song/LaWGPT.git
cd LawGPT
cd LaWGPT
conda env create -f environment.yml
conda activate lawgpt
```
@@ -48,6 +49,10 @@ conda activate lawgpt
## Data Construction
This project aggregates Chinese legal data sources from the internet.
Data are generated following the [Stanford_alpaca](https://github.com/tatsu-lab/stanford_alpaca#data-generation-process) and [self-instruct](https://github.com/yizhongw/self-instruct) approaches, as illustrated by the sketch below.
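Each generated sample follows the Alpaca-style `instruction` / `input` / `output` schema, the same shape produced by the JEC-QA conversion notebook removed in this commit. A minimal illustrative record (field values are placeholders, not taken from the dataset):

```python
# Shape of one Alpaca-style instruction record used for fine-tuning.
# Values below are placeholders for illustration only.
record = {
    "instruction": "题干 + 四个选项 (A) ... (B) ... (C) ... (D) ...",  # question plus options
    "input": "",                                                      # unused for this dataset
    "output": "(B) 正确选项的内容",                                    # the correct option(s)
}
```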
## Model Training
The training of the Chinese legal base model LawGPT proceeds in three stages:
@@ -98,7 +103,7 @@ conda activate lawgpt
## Acknowledgements
This project builds on several open-source projects and public datasets; our sincere thanks go to the related projects and their researchers and developers:
This project builds on the following open-source projects; our sincere thanks go to the related projects and their researchers and developers:
- Chinese-LLaMA-Alpaca: https://github.com/ymcui/Chinese-LLaMA-Alpaca
- LLaMA: https://github.com/facebookresearch/llama
@@ -106,8 +111,6 @@ conda activate lawgpt
- alpaca-lora: https://github.com/tloen/alpaca-lora
- ChatGLM-6B: https://github.com/THUDM/ChatGLM-6B
## Citation
If you find our work helpful, please consider citing the following:


@@ -1,6 +0,0 @@
## Instruction Fine-tuning Data
**JEC-QA: the Chinese judicial examination QA dataset**
- https://jecqa.thunlp.org/
- https://github.com/thunlp/jec-qa


@@ -1,476 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import json"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>answer</th>\n",
" <th>option_list</th>\n",
" <th>statement</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>[B]</td>\n",
" <td>{'A': '我国商务部在确定进口橡胶制品是否存在补贴时必须证明出国(地区)政府直接向出口商...</td>\n",
" <td>中国商务部决定对原产于马来西亚等八国的橡胶制品展开反补贴调查。根据我国《反补贴条例》以及相关...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>[D]</td>\n",
" <td>{'A': '该法典体现了“个人最大限度的自由,法律最小限度的干涉”这一立法精神', 'B'...</td>\n",
" <td>1804年的《法国民法典》是世界近代法制史上的第一部民法典是大陆法系的核心和基础。下列关于...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>[D]</td>\n",
" <td>{'A': '“偶语诗书”', 'B': '“以古非今”', 'C': '“非所宜言”', ...</td>\n",
" <td>据史书载,以下均为秦朝刑事罪名。下列哪一选项最不具有秦朝法律文化的专制特色?</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>[A, B]</td>\n",
" <td>{'A': '船舶抵押权的设定', 'B': '同国籍船舶在公海发生碰撞的损害赔偿', 'C...</td>\n",
" <td>根据《中华人民共和国海商法》,在海事关系的法律适用中,旗国法适用于下列哪些情形?</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>[A, B, C]</td>\n",
" <td>{'A': '“君权神授”观念是近代宪法发展的思想条件之一', 'B': '美国宪法是世界上...</td>\n",
" <td>下列有关宪法发展史的论述不正确的有</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13292</th>\n",
" <td>[D]</td>\n",
" <td>{'A': '如依中国法律和甲国法律均构成犯罪,即可准予引渡', 'B': '中国应按照收到...</td>\n",
" <td>中国人高某在甲国探亲期间加入甲国国籍,回中国后健康不佳,也未申请退出中国国籍。后甲国因高某在...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13293</th>\n",
" <td>[B, C]</td>\n",
" <td>{'A': '欣荣公司的请求权已经超过诉讼时效', 'B': '乙的请求权没有超过诉讼时效'...</td>\n",
" <td>欣荣公司于2006年8月1日领取营业执照时股东甲尚有50万元的出资未缴纳。按照出资协议最晚...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13294</th>\n",
" <td>[B]</td>\n",
" <td>{'A': '报同级检察院批准', 'B': '报同级检察院备案', 'C': '报上一级公...</td>\n",
" <td>张某因涉嫌放火罪被批准逮捕。公安机关在侦查过程中,发现张某另有抢劫罪的重大嫌疑,决定依照刑事...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13295</th>\n",
" <td>[A, B, D]</td>\n",
" <td>{'A': '被告人的辩护人申请审判员张某回避', 'B': '被告人收到起诉书后下落不明'...</td>\n",
" <td>在法庭审判过程中,下列哪些情形不可以延期审理?</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13296</th>\n",
" <td>[C]</td>\n",
" <td>{'A': '甲鸳鸯运输公司的船舶发生原油泄漏,导致某海域大面积污染,海鲜捕捞量大幅度减产,...</td>\n",
" <td>以下事实中,在甲乙之间产生民事法律关系的是:</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>21072 rows × 3 columns</p>\n",
"</div>"
],
"text/plain": [
" answer option_list \\\n",
"0 [B] {'A': '我国商务部在确定进口橡胶制品是否存在补贴时必须证明出国(地区)政府直接向出口商... \n",
"1 [D] {'A': '该法典体现了“个人最大限度的自由,法律最小限度的干涉”这一立法精神', 'B'... \n",
"2 [D] {'A': '“偶语诗书”', 'B': '“以古非今”', 'C': '“非所宜言”', ... \n",
"3 [A, B] {'A': '船舶抵押权的设定', 'B': '同国籍船舶在公海发生碰撞的损害赔偿', 'C... \n",
"4 [A, B, C] {'A': '“君权神授”观念是近代宪法发展的思想条件之一', 'B': '美国宪法是世界上... \n",
"... ... ... \n",
"13292 [D] {'A': '如依中国法律和甲国法律均构成犯罪,即可准予引渡', 'B': '中国应按照收到... \n",
"13293 [B, C] {'A': '欣荣公司的请求权已经超过诉讼时效', 'B': '乙的请求权没有超过诉讼时效'... \n",
"13294 [B] {'A': '报同级检察院批准', 'B': '报同级检察院备案', 'C': '报上一级公... \n",
"13295 [A, B, D] {'A': '被告人的辩护人申请审判员张某回避', 'B': '被告人收到起诉书后下落不明'... \n",
"13296 [C] {'A': '甲鸳鸯运输公司的船舶发生原油泄漏,导致某海域大面积污染,海鲜捕捞量大幅度减产,... \n",
"\n",
" statement \n",
"0 中国商务部决定对原产于马来西亚等八国的橡胶制品展开反补贴调查。根据我国《反补贴条例》以及相关... \n",
"1 1804年的《法国民法典》是世界近代法制史上的第一部民法典是大陆法系的核心和基础。下列关于... \n",
"2 据史书载,以下均为秦朝刑事罪名。下列哪一选项最不具有秦朝法律文化的专制特色? \n",
"3 根据《中华人民共和国海商法》,在海事关系的法律适用中,旗国法适用于下列哪些情形? \n",
"4 下列有关宪法发展史的论述不正确的有 \n",
"... ... \n",
"13292 中国人高某在甲国探亲期间加入甲国国籍,回中国后健康不佳,也未申请退出中国国籍。后甲国因高某在... \n",
"13293 欣荣公司于2006年8月1日领取营业执照时股东甲尚有50万元的出资未缴纳。按照出资协议最晚... \n",
"13294 张某因涉嫌放火罪被批准逮捕。公安机关在侦查过程中,发现张某另有抢劫罪的重大嫌疑,决定依照刑事... \n",
"13295 在法庭审判过程中,下列哪些情形不可以延期审理? \n",
"13296 以下事实中,在甲乙之间产生民事法律关系的是: \n",
"\n",
"[21072 rows x 3 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train0_df = pd.read_json('./0_train.json', lines=True)\n",
"train1_df = pd.read_json('./1_train.json', lines=True)\n",
"\n",
"train_df = pd.concat([train0_df, train1_df], axis=0)\n",
"train_df = train_df.loc[:, ['answer', 'option_list', 'statement']]\n",
"train_df"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'instruction': '中国商务部决定对原产于马来西亚等八国的橡胶制品展开反补贴调查。根据我国《反补贴条例》以及相关法律法规,下列关于此次反补贴调查的哪项判断是正确的? (A) 我国商务部在确定进口橡胶制品是否存在补贴时必须证明出国(地区)政府直接向出口商提供了现金形式的财政资助 (B) 在反补贴调查期间,该八国政府或橡胶制品的出口经营者,可以向中国商务部作出承诺,取消、限制补贴或改变价格 (C) 如果我国商务部终局裁定决定对该八国进口橡胶制品征收反补贴税该反补贴税的征收期限不得超过10年 (D) 如果中国橡胶制品进口商对商务部征收反补贴税的终局裁定不服,必须首先向商务部请求行政复审,对行政复审决定还不服,才能向中国有管辖权的法院起诉.',\n",
" 'input': '',\n",
" 'output': '(B) 在反补贴调查期间,该八国政府或橡胶制品的出口经营者,可以向中国商务部作出承诺,取消、限制补贴或改变价格'},\n",
" {'instruction': '1804年的《法国民法典》是世界近代法制史上的第一部民法典是大陆法系的核心和基础。下列关于《法国民法典》的哪一项表述不正确? (A) 该法典体现了“个人最大限度的自由,法律最小限度的干涉”这一立法精神 (B) 该法典具有鲜明的革命性和时代性 (C) 该法典的影响后来传播到美洲、非洲和亚洲广大地区 (D) 该法典首次全面规定了法人制度.',\n",
" 'input': '',\n",
" 'output': '(D) 该法典首次全面规定了法人制度'},\n",
" {'instruction': '据史书载,以下均为秦朝刑事罪名。下列哪一选项最不具有秦朝法律文化的专制特色? (A) “偶语诗书” (B) “以古非今” (C) “非所宜言” (D) “失刑”.',\n",
" 'input': '',\n",
" 'output': '(D) “失刑”'},\n",
" {'instruction': '根据《中华人民共和国海商法》,在海事关系的法律适用中,旗国法适用于下列哪些情形? (A) 船舶抵押权的设定 (B) 同国籍船舶在公海发生碰撞的损害赔偿 (C) 共同海损理算 (D) 海事赔偿责任限制.',\n",
" 'input': '',\n",
" 'output': '(A) 船舶抵押权的设定(B) 同国籍船舶在公海发生碰撞的损害赔偿'},\n",
" {'instruction': '下列有关宪法发展史的论述不正确的有 (A) “君权神授”观念是近代宪法发展的思想条件之一 (B) 美国宪法是世界上最早的宪法 (C) 1918年的《苏联宪法》是第一部社会主义性质的宪法 (D) 《人权宣言》不是法国的第一部宪法.',\n",
" 'input': '',\n",
" 'output': '(A) “君权神授”观念是近代宪法发展的思想条件之一(B) 美国宪法是世界上最早的宪法(C) 1918年的《苏联宪法》是第一部社会主义性质的宪法'},\n",
" {'instruction': '下列按照特别程序审理的案件中,必须由审判员组成合议庭审理的是哪一项? (A) 宣告失踪案件 (B) 认定公民无民事行为能力案件 (C) 选民资格案件 (D) 认定财产无主案件.',\n",
" 'input': '',\n",
" 'output': '(C) 选民资格案件'},\n",
" {'instruction': '关于行政赔偿诉讼,下列哪些选项是正确的? (A) 当事人在提起行政诉讼的同时一并提出行政赔偿请求,法院应分别立案 (B) 除特殊情形外,法院单独受理的一审行政赔偿案件的审理期限为三个月 (C) 如复议决定加重损害,赔偿请求人只对复议机关提出行政赔偿诉讼的,复议机关为被告 (D) 提起行政诉讼时一并提出行政赔偿请求的,可以在提起诉讼后至法院一审判决前提出.',\n",
" 'input': '',\n",
" 'output': '(A) 当事人在提起行政诉讼的同时一并提出行政赔偿请求,法院应分别立案(B) 除特殊情形外,法院单独受理的一审行政赔偿案件的审理期限为三个月(C) 如复议决定加重损害,赔偿请求人只对复议机关提出行政赔偿诉讼的,复议机关为被告'},\n",
" {'instruction': '根据法律规定,下列关于土地使用权出让的表述正确的是: (A) 土地使用权出让的对象只能是国有土地的使用权 (B) 出让商业、旅游、娱乐和豪华住宅用地的使用权的,有条件的,必须采取招标、拍卖的方式 (C) 土地使用权出让的,应该签订书面的出让合同,但是出让合同只对土地使用者具有约束力 (D) 土地使用者以出让的方式取得土地使用权后,应当向县级以上地方人民政府申请登记,经核实后,由同级人民政府土地管理部门颁发土地使用权证.',\n",
" 'input': '',\n",
" 'output': '(A) 土地使用权出让的对象只能是国有土地的使用权(B) 出让商业、旅游、娱乐和豪华住宅用地的使用权的,有条件的,必须采取招标、拍卖的方式'},\n",
" {'instruction': '依我国法律规定,在我国法院受理的涉外离婚案件审理过程中,认定婚姻是否有效应当以下列哪一项为准据法? (A) 婚姻缔结地法 (B) 当事人本国法 (C) 当事人住所地法 (D) 法院地.',\n",
" 'input': '',\n",
" 'output': '(A) 婚姻缔结地法'},\n",
" {'instruction': '法院在审理刑事案件过程中,下列表述不正确的有: (A) 法院经过审理,如果认为定罪证据不足,应当坚持无罪推定的原则,依法宣告被告人无罪 (B) 法院经过审理,如果认为量刑证据存疑,应当在量刑时作出有利于被告人的处理 (C) 为了保障被告人的合法权益,侦查人员在规定的办案场所外讯问取得的供述应当一律依法排除 (D) 庭前会议时,对于控辩双方有异议的证据,应当重点调查,没有异议的,可以简化调查.',\n",
" 'input': '',\n",
" 'output': '(A) 法院经过审理,如果认为定罪证据不足,应当坚持无罪推定的原则,依法宣告被告人无罪(C) 为了保障被告人的合法权益,侦查人员在规定的办案场所外讯问取得的供述应当一律依法排除(D) 庭前会议时,对于控辩双方有异议的证据,应当重点调查,没有异议的,可以简化调查'}]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_instances = []\n",
"for index, row in train_df.iterrows():\n",
" instance_dict = {}\n",
" answers, options, statement = row['answer'], row['option_list'], row['statement']\n",
" opt_a, opt_b, opt_c, opt_d = options.values()\n",
" instruction = f\"{statement} (A) {opt_a} (B) {opt_b} (C) {opt_c} (D) {opt_d}.\"\n",
" output = ''\n",
" for answer in answers:\n",
" if answer in ['A', 'B', 'C', 'D']:\n",
" output += f\"({answer}) {options[answer]}\"\n",
" output = f\"{output}\"\n",
" test_instances.append({'instruction': instruction, 'input': '', 'output': output})\n",
"test_instances[:10]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"with open('qa_thunlp.json', 'w') as f:\n",
" json.dump(test_instances, f, ensure_ascii=False)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>option_list</th>\n",
" <th>statement</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>11781</td>\n",
" <td>{'A': '犯罪的预备阶段', 'B': '犯罪的实行阶段', 'C': '犯罪行为尚未实...</td>\n",
" <td>犯罪中止可以发生在:</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>13516</td>\n",
" <td>{'A': '行为人知道或者应当知道标明密级的事项,而为境外窃取、刺探、收买、非法提供的',...</td>\n",
" <td>下列哪些行为不属于为境外窃取、刺探、收买、非法提供国家秘密罪?</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>14849</td>\n",
" <td>{'A': '单务法律行为', 'B': '双务法律行为', 'C': '实践法律行为', ...</td>\n",
" <td>下列关于赠与合同的说法正确的是:</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1218</td>\n",
" <td>{'A': '人权是基本权利的来源,基本权利是人权宪法化的具体表现', 'B': '基本权利...</td>\n",
" <td>公民基本权利也称宪法权利。关于公民基本权利,下列哪些选项是正确的?</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>12483</td>\n",
" <td>{'A': '政府采购机构签订的采购合同行为', 'B': '公安局作出的行政调解行为', ...</td>\n",
" <td>公民、法人或者其他组织对下列哪些行为可以提起行政复议?</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3298</th>\n",
" <td>314495</td>\n",
" <td>{'A': '对可能判处5年有期徒刑以上刑罚的一般应当组成合议庭进行审判', 'B': '...</td>\n",
" <td>甲盗窃财物,数额较大,某区人民检察院向该区人民法院提起公诉,同时因为案件事实清楚,证据充分,...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3299</th>\n",
" <td>38200</td>\n",
" <td>{'A': '在行政许可的听证程序中,行政机关可以收取合理费用', 'B': '行政机关对行...</td>\n",
" <td>关于收费问题,下列说法正确的是?</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3300</th>\n",
" <td>313314</td>\n",
" <td>{'A': '刘某只能自行收集证据', 'B': '刘某必须经过人民法院准许,方可查阅、摘抄...</td>\n",
" <td>在一起强奸案中,律师刘某作为被害人的诉讼代理人,下列说法正确的是:</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3301</th>\n",
" <td>311215</td>\n",
" <td>{'A': '甲、乙共同购买一台电视,乙要将其共有的份额转让,甲享有优先购买权', 'B':...</td>\n",
" <td>以下关系中,一方享有优先购买权的为哪些选项?</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3302</th>\n",
" <td>31863</td>\n",
" <td>{'A': '郭某是村民小组长在管理本村行政事务时将村集体财产30万元均为己有', 'B...</td>\n",
" <td>下列行为构成贪污罪的有:</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5289 rows × 3 columns</p>\n",
"</div>"
],
"text/plain": [
" id option_list \\\n",
"0 11781 {'A': '犯罪的预备阶段', 'B': '犯罪的实行阶段', 'C': '犯罪行为尚未实... \n",
"1 13516 {'A': '行为人知道或者应当知道标明密级的事项,而为境外窃取、刺探、收买、非法提供的',... \n",
"2 14849 {'A': '单务法律行为', 'B': '双务法律行为', 'C': '实践法律行为', ... \n",
"3 1218 {'A': '人权是基本权利的来源,基本权利是人权宪法化的具体表现', 'B': '基本权利... \n",
"4 12483 {'A': '政府采购机构签订的采购合同行为', 'B': '公安局作出的行政调解行为', ... \n",
"... ... ... \n",
"3298 314495 {'A': '对可能判处5年有期徒刑以上刑罚的一般应当组成合议庭进行审判', 'B': '... \n",
"3299 38200 {'A': '在行政许可的听证程序中,行政机关可以收取合理费用', 'B': '行政机关对行... \n",
"3300 313314 {'A': '刘某只能自行收集证据', 'B': '刘某必须经过人民法院准许,方可查阅、摘抄... \n",
"3301 311215 {'A': '甲、乙共同购买一台电视,乙要将其共有的份额转让,甲享有优先购买权', 'B':... \n",
"3302 31863 {'A': '郭某是村民小组长在管理本村行政事务时将村集体财产30万元均为己有', 'B... \n",
"\n",
" statement \n",
"0 犯罪中止可以发生在: \n",
"1 下列哪些行为不属于为境外窃取、刺探、收买、非法提供国家秘密罪? \n",
"2 下列关于赠与合同的说法正确的是: \n",
"3 公民基本权利也称宪法权利。关于公民基本权利,下列哪些选项是正确的? \n",
"4 公民、法人或者其他组织对下列哪些行为可以提起行政复议? \n",
"... ... \n",
"3298 甲盗窃财物,数额较大,某区人民检察院向该区人民法院提起公诉,同时因为案件事实清楚,证据充分,... \n",
"3299 关于收费问题,下列说法正确的是? \n",
"3300 在一起强奸案中,律师刘某作为被害人的诉讼代理人,下列说法正确的是: \n",
"3301 以下关系中,一方享有优先购买权的为哪些选项? \n",
"3302 下列行为构成贪污罪的有: \n",
"\n",
"[5289 rows x 3 columns]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test0_df = pd.read_json('./0_test.json', lines=True)\n",
"test1_df = pd.read_json('./1_test.json', lines=True)\n",
"\n",
"test_df = pd.concat([test0_df, test1_df], axis=0)\n",
"test_df = test_df.loc[:, ['id', 'option_list', 'statement']]\n",
"test_df"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'instruction': '犯罪中止可以发生在: (A) 犯罪的预备阶段 (B) 犯罪的实行阶段 (C) 犯罪行为尚未实行完毕的情况下 (D) 犯罪行为已经实行完毕的情况下.',\n",
" 'input': '',\n",
" 'output': ''},\n",
" {'instruction': '下列哪些行为不属于为境外窃取、刺探、收买、非法提供国家秘密罪? (A) 行为人知道或者应当知道标明密级的事项,而为境外窃取、刺探、收买、非法提供的 (B) 行为人知道或者应当知道没有标明密级的事项关系国家安全和利益,而为境外窃取、刺探、收买、非法提供的 (C) 通过互联网将国家秘密或者情报非法发送给境外的机构、组织、个人的 (D) 将国家秘密通过互联网予以发布,情节严重的.',\n",
" 'input': '',\n",
" 'output': ''},\n",
" {'instruction': '下列关于赠与合同的说法正确的是: (A) 单务法律行为 (B) 双务法律行为 (C) 实践法律行为 (D) 要式法律行为.',\n",
" 'input': '',\n",
" 'output': ''},\n",
" {'instruction': '公民基本权利也称宪法权利。关于公民基本权利,下列哪些选项是正确的? (A) 人权是基本权利的来源,基本权利是人权宪法化的具体表现 (B) 基本权利的主体主要是公民,在我国法人也可以作为基本权利的主体 (C) 我国公民在行使自由和权利的时候,不得损害国家的、社会的、集体的利益和其他公民的合法的自由和利益 (D) 权利和义务的平等性是我国公民基本权利和义务的重要特点.',\n",
" 'input': '',\n",
" 'output': ''},\n",
" {'instruction': '公民、法人或者其他组织对下列哪些行为可以提起行政复议? (A) 政府采购机构签订的采购合同行为 (B) 公安局作出的行政调解行为 (C) 大学拒绝发放学位证书的行为 (D) 注册会计师协会对注册会计师执业证照不予年检的行为.',\n",
" 'input': '',\n",
" 'output': ''},\n",
" {'instruction': '以下关于环境保护法的民事责任的说法,哪些是正确的? (A) 环境侵权责任应当适用民事侵权的过错原则 (B) 赔偿责任和赔偿金额的纠纷,可以根据当事人的请求,由环境保护行政主管部门处理 (C) 由于不可抗拒的自然灾害,免予承担责任 (D) 因环境污染损害赔偿提起诉讼的时效期间为四年,从当事人知道或者应当知道受到污染损害时起计算.',\n",
" 'input': '',\n",
" 'output': ''},\n",
" {'instruction': '我国《商标法》第三十九条规定:“转让注册商标的,转让人和受让人应当签订转让协议,并共同向商标局提出申请。受让人应当保证使用该注册商标的商品质量。”从表述上看,该法律条文省略了: (A) 假定条件 (B) 处理 (C) 行为模式 (D) 法律后果.',\n",
" 'input': '',\n",
" 'output': ''},\n",
" {'instruction': '根据我国《行政复议法》的规定,公民、法人或其他组织对民事纠纷的仲裁、调解或者处理不服的,不能申请行政复议。但下列哪种情况除外 (A) 继承权纠纷 (B) 干涉婚姻自主权的 (C) 房屋买卖行为 (D) 行政机关对土地、矿产、森林等资源所有权或使用权归属的处理决定.',\n",
" 'input': '',\n",
" 'output': ''},\n",
" {'instruction': '下列有关中国的贸易救济措施的说法正确的是: (A) 反倾销税的征收期限和价格承诺的履行期限不超过6年但是经复审确定终止征收反倾销有可能导致损害的继续或者再度发生的可以适当延长反倾销税的征收期限 (B) 根据反补贴条例,出口国政府或出口经营者,都可以做出承诺,分别承诺取消、限制补贴或其他有关措施,承诺修改价格 (C) 保障措施的实施期限不超过4年符合法律规定的条件的保障措施的实施期限可以适当延长但最长不超过8年。 (D) 适用保障措施要求的产业损害程度重于反倾销或反补贴要求的损害程度,即实质损害而不是严重损害.',\n",
" 'input': '',\n",
" 'output': ''},\n",
" {'instruction': '以下说法不正确的是? (A) 1919年《魏玛宪法》是第一部现代宪法 (B) 1830年法国宪法是钦定宪法 (C) 英国是世界上制定宪法最多的国家 (D) 1787年美国宪法是世界首部成文宪法.',\n",
" 'input': '',\n",
" 'output': ''}]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_instances = []\n",
"for index, row in test_df.iterrows():\n",
" instance_dict = {}\n",
" options, statement = row['option_list'], row['statement']\n",
" opt_a, opt_b, opt_c, opt_d = options.values()\n",
" instruction = f\"{statement} (A) {opt_a} (B) {opt_b} (C) {opt_c} (D) {opt_d}.\"\n",
" output = ''\n",
" # print({'Instruction': instruction, 'Output': output})\n",
" test_instances.append({'instruction': instruction, 'input': '', 'output': output})\n",
"test_instances[:10]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"with open('qa_thunlp_test.json', 'w') as f:\n",
" json.dump(test_instances, f, ensure_ascii=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "legal",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because it is too large

78
scripts/clear_law.py Normal file

@@ -0,0 +1,78 @@
import re
import json
class read_lawfile:
    def __init__(self, chapter_mode=r"第[零一二三四五六七八九十百千万]+章 .+\b", entry_mode=r"第[零一二三四五六七八九十百千万]+条\b"):
        # regexes that recognize chapter headings (章) and article numbers (条)
        self.chapter_mode = chapter_mode
        self.entry_mode = entry_mode
def read_file(self, file_path):
        # read a law file and split it into {chapter: {article: text}}
self.law = {}
f = open(file_path, encoding='utf-8')
content = f.read()
content = content.replace("\n\n", "\n")
content = content.replace("##", "")
# print(content)
chapter_p = re.search(self.chapter_mode, content)
while chapter_p is not None:
c_start = chapter_p.start()
c_end = chapter_p.end()
key = content[c_start:c_end]
content = content[c_end:]
chapter_p = re.search(self.chapter_mode, content)
if chapter_p is not None:
end = chapter_p.start()
c_content = content[:end]
self.law[key] = self.read_entrys(c_content)
# print(content[c_start:c_end])
else:
self.law[key] = self.read_entrys(content)
f.close()
return self.law
def read_entrys(self, content):
entrys = {}
entry_p = re.search(self.entry_mode, content)
while entry_p is not None:
e_start = entry_p.start()
e_end = entry_p.end()
key = content[e_start:e_end]
content = content[e_end+1:]
entry_p = re.search(self.entry_mode, content)
if entry_p is not None:
end = entry_p.start()
e_content = content[:end]
entrys[key] = e_content
else:
entrys[key] = content
return entrys
# entry_p = re.search(entry_mode, content)
# while entry_p is not None:
# start = entry_p.start()
# end = entry_p.end()
# # print(content[start:end])
# content = content[end:]
# law[content[start:end]] = read_entrys(content)
# chapter_p = re.search(chapter_mode, content)
def show(self):
for key in self.law:
print(key, '\n')
for item in self.law[key]:
print(item, ' ', self.law[key][item])
if __name__ == '__main__':
file_path = "D:/11496/Documents/project/Laws-master/经济法/价格法(1997-12-29).md"
r = read_lawfile()
    law_dict = r.read_file(file_path)
    r.show()
    print(law_dict)
    with open('./a.json', 'w', encoding='utf-8') as f:
        json.dump(law_dict, f, ensure_ascii=False)
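A quick sanity check of the two regular expressions above, run against a made-up snippet (the sample text is an assumption, not taken from the repository's data):

```python
import re

chapter_mode = r"第[零一二三四五六七八九十百千万]+章 .+\b"
entry_mode = r"第[零一二三四五六七八九十百千万]+条\b"

sample = "第一章 总则\n第一条 为了规范价格行为。\n第二条 本法的适用范围。\n第二章 经营者的价格行为\n第三条 经营者自主定价。"
print(re.search(chapter_mode, sample).group())  # 第一章 总则
print(re.findall(entry_mode, sample))           # ['第一条', '第二条', '第三条']
```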


@@ -0,0 +1,50 @@
import argparse
import openai
import yaml
import random
def return_random_prompt():
system_prompt = "你需要针对法条内容尽可能联想多样化的场景生成问答数据。我们将用于人工评估 ChatGPT 模型对指令的完成情况。要求:\n"
# generate random tasks
system_prompt += "1. 结合真实问题,表述多样化。\n"
# other requirements
system_prompt += "2. 如果遇到无法处理的指令(只靠文本无法回答),给出无法处理的回复。\n"
system_prompt += "3. 除非特别要求,请使用中文,指令可以是命令句、疑问句、或其他合适的类型。\n"
system_prompt += "4. <Reference>:违反本法规定,对妇女实施性骚扰的,由公安机关给予批评教育或者出具告诫书,并由所在单位依法给予处分。\n学校、用人单位违反本法规定,未采取必要措施预防和制止性骚扰,造成妇女权益受到侵害或者社会影响恶劣的,由上级机关或者主管部门责令改正;拒不改正或者情节严重的,依法对直接负责的主管人员和其他直接责任人员给予处分。\n"
system_prompt += "5. <input>是结合法条内容联想到的真实场景下的问题。要求该场景下存在违法者和受害人\n"
system_prompt += "6. <output>是结合法条内容对该问题的适当且真实的回应,不能只回复答应或拒绝请求。尽可能地指明违法行为可能遭受的惩罚,并向受害者提出维权建议。\n\n"
system_prompt += "请给出满足条件的10条JSON格式数据\n"
return system_prompt
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--cfg_path', default='../config.yaml', type=str)
parser.add_argument('--save_path', default='./output.json', type=str)
args = parser.parse_args()
with open(args.cfg_path, 'r') as f:
cfg = yaml.load(f, Loader=yaml.FullLoader)
openai.api_key = cfg['API_KEY']
openai.api_base = cfg['API_BASE_URL']
output_file = open(args.save_path, 'w')
    # number of requests to issue (each prompt asks for 10 JSON-formatted records)
    # TODO: write results incrementally, otherwise a mid-run failure loses everything (a sketch follows this file)
    MAX_EPOCHS = 1
for k in range(MAX_EPOCHS):
response = openai.ChatCompletion.create(
# here we use `gpt-3.5-turbo` model, while Stanford-Alpaca uses `text-davinci-003`
model="gpt-3.5-turbo",
messages=[
{"role": "user", "content": return_random_prompt()},
]
)
output_file.write(response["choices"][0]["message"]["content"] + '\n')
output_file.close()
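To address the TODO above, results can be written to disk as they arrive, so that a mid-run failure keeps everything generated so far. A minimal sketch, assuming the same `openai` (<1.0) `ChatCompletion` API the script already uses:

```python
# Sketch: flush each response to disk immediately (addresses the TODO above).
import openai

def generate_incrementally(prompt: str, save_path: str, max_epochs: int = 1) -> None:
    with open(save_path, "w", encoding="utf-8") as f:
        for _ in range(max_epochs):
            try:
                response = openai.ChatCompletion.create(
                    model="gpt-3.5-turbo",
                    messages=[{"role": "user", "content": prompt}],
                )
            except Exception as e:  # rate limits, network errors, ...
                print(f"request failed, skipping this round: {e}")
                continue
            f.write(response["choices"][0]["message"]["content"] + "\n")
            f.flush()  # persist what we have before the next request
```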


@@ -6,25 +6,17 @@ import random
def return_random_prompt():
system_prompt = "你需要尽可能给出多样化的任务指令和对应的回答。我们将用于人工评估ChatGPT模型对指令的完成情况。要求:\n"
# generate random topics
system_prompt += "1. 主题多样化,涵盖法律诉讼的各个领域,例如:刑法、民法、行政法等。\n"
system_prompt = "你需要针对输入尽可能给出多样化的任务指令和对应的回答。我们将用于人工评估ChatGPT模型对指令的完成情况。要求:\n"
# generate random tasks
task_list = ["开放式生成", "分类", "问答", "编辑", "摘要",
"写作", "翻译", "分析", "常识推理", "写信", "抽取", "推荐"]
system_prompt += "2. 表述多样化,结合真实问题;指令类型多样化,例如:" + \
"".join(random.sample(task_list, 10)) + "等。\n"
task_list = ["开放式生成", "分类", "问答", "编辑", "摘要", "写作", "分析", "抽取"]
system_prompt += "1. 表述多样化,结合真实问题;指令类型多样化,例如:" + "".join(random.sample(task_list, 7)) + "等。\n"
# other requirements
system_prompt += "3. 如果遇到无法处理的指令(只靠文本无法回答),给出无法处理的回复。\n"
system_prompt += "4. 除非特别要求,请使用中文,指令可以是命令句、疑问句、或其他合适的类型。\n"
system_prompt += "5. 为指令生成一个适当且涉及真实情况的<input>,不应该只包含简单的占位符。<input>应提供实质性的内容,具有挑战性。字数不超过" + \
str(random.randint(80, 120)) + "字。\n"
system_prompt += "6. <output>应该是对指令的适当且真实的回应,不能只回复答应或拒绝请求。如果需要额外信息才能回复时,请努力预测用户意图并尝试回复。<output>的内容应少于" + \
str(random.randint(128, 512)) + "字。\n\n"
system_prompt += "2. 如果遇到无法处理的指令(只靠文本无法回答),给出无法处理的回复。\n"
system_prompt += "3. 除非特别要求,请使用中文,指令可以是命令句、疑问句、或其他合适的类型。\n"
system_prompt += "4. <input>是:'第十三条 一切危害国家主权、领土完整和安全,分裂国家、颠覆人民民主专政的政权和推翻社会主义制度,破坏社会秩序和经济秩序,侵犯国有财产或者劳动群众集体所有的财产,侵犯公民私人所有的财产,侵犯公民的人身权利、民主权利和其他权利,以及其他危害社会的行为,依照法律应当受刑罚处罚的,都是犯罪,但是情节显著轻微危害不大的,不认为是犯罪。'"
system_prompt += "5. <output>应该是对指令的适当且真实的回应,不能只回复答应或拒绝请求。如果需要额外信息才能回复时,请努力预测用户意图并尝试回复。<output>的内容应少于" + str(random.randint(128, 512)) + "字。\n\n"
system_prompt += "请给出满足条件的20条JSON格式数据\n"
return system_prompt

280
src/finetune.py Normal file

@@ -0,0 +1,280 @@
import os
import sys
from typing import List
import fire
import torch
import transformers
from datasets import load_dataset
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer
from utils.prompter import Prompter
def train(
# model/data params
base_model: str = "./models/base_models/your_base_model_dir",
data_path: str = "./data/your_data.json",
output_dir: str = "./outputs/your_version_dir",
# training hyperparams
batch_size: int = 128,
micro_batch_size: int = 4,
num_epochs: int = 10,
learning_rate: float = 3e-4,
cutoff_len: int = 512,
val_set_size: int = 2000,
# lora hyperparams
lora_r: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_target_modules: List[str] = ["q_proj", "v_proj",],
# llm hyperparams
train_on_inputs: bool = True, # if False, masks out inputs in loss
add_eos_token: bool = True,
group_by_length: bool = False, # faster, but produces an odd training loss curve
# wandb params
wandb_project: str = "",
wandb_run_name: str = "",
wandb_watch: str = "", # options: false | gradients | all
wandb_log_model: str = "", # options: false | true
# either training checkpoint or final adapter
resume_from_checkpoint: str = None,
# The prompt template to use, will default to alpaca.
prompt_template_name: str = "alpaca",
):
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
print(
f"Training Alpaca-LoRA model with params:\n"
f"base_model: {base_model}\n"
f"data_path: {data_path}\n"
f"output_dir: {output_dir}\n"
f"batch_size: {batch_size}\n"
f"micro_batch_size: {micro_batch_size}\n"
f"num_epochs: {num_epochs}\n"
f"learning_rate: {learning_rate}\n"
f"cutoff_len: {cutoff_len}\n"
f"val_set_size: {val_set_size}\n"
f"lora_r: {lora_r}\n"
f"lora_alpha: {lora_alpha}\n"
f"lora_dropout: {lora_dropout}\n"
f"lora_target_modules: {lora_target_modules}\n"
f"train_on_inputs: {train_on_inputs}\n"
f"add_eos_token: {add_eos_token}\n"
f"group_by_length: {group_by_length}\n"
f"wandb_project: {wandb_project}\n"
f"wandb_run_name: {wandb_run_name}\n"
f"wandb_watch: {wandb_watch}\n"
f"wandb_log_model: {wandb_log_model}\n"
f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
f"prompt template: {prompt_template_name}\n"
)
gradient_accumulation_steps = batch_size // micro_batch_size
prompter = Prompter(prompt_template_name)
# Configure device and distributed training
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
gradient_accumulation_steps = gradient_accumulation_steps // world_size
# Check if parameter passed or if set within environ
use_wandb = len(wandb_project) > 0 or (
"WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0)
# Only overwrite environ if wandb param passed
if len(wandb_project) > 0:
os.environ["WANDB_PROJECT"] = wandb_project
if len(wandb_watch) > 0:
os.environ["WANDB_WATCH"] = wandb_watch
if len(wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = wandb_log_model
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=True,
torch_dtype=torch.float16,
device_map=device_map,
)
tokenizer = LlamaTokenizer.from_pretrained(base_model)
tokenizer.bos_token_id = 1
tokenizer.eos_token_id = 2
bos = tokenizer.bos_token_id
eos = tokenizer.eos_token_id
pad = tokenizer.pad_token_id
print("pre-trained model's BOS EOS and PAD token id:",
bos, eos, pad, " => It should be 1,2,none")
tokenizer.pad_token_id = (
0 # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left" # Allow batched inference
def tokenize(prompt, add_eos_token=True):
# there's probably a way to do this with the tokenizer settings
# but again, gotta move fast
result = tokenizer(
prompt,
truncation=True,
max_length=cutoff_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != tokenizer.eos_token_id
and len(result["input_ids"]) < cutoff_len
and add_eos_token
):
result["input_ids"].append(tokenizer.eos_token_id)
result["attention_mask"].append(1)
result["labels"] = result["input_ids"].copy()
return result
def generate_and_tokenize_prompt(data_point):
full_prompt = prompter.generate_prompt(
data_point["instruction"],
data_point["input"],
data_point["output"],
)
tokenized_full_prompt = tokenize(full_prompt)
if not train_on_inputs:
user_prompt = prompter.generate_prompt(
data_point["instruction"], data_point["input"]
)
tokenized_user_prompt = tokenize(
user_prompt, add_eos_token=add_eos_token
)
user_prompt_len = len(tokenized_user_prompt["input_ids"])
if add_eos_token:
user_prompt_len -= 1
tokenized_full_prompt["labels"] = [
-100
] * user_prompt_len + tokenized_full_prompt["labels"][
user_prompt_len:
] # could be sped up, probably
return tokenized_full_prompt
model = prepare_model_for_int8_training(model)
config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
if data_path.endswith(".json") or data_path.endswith(".jsonl"):
data = load_dataset("json", data_files=data_path)
else:
data = load_dataset(data_path)
if resume_from_checkpoint:
# Check the available weights and load them
checkpoint_name = os.path.join(
resume_from_checkpoint, "pytorch_model.bin"
) # Full checkpoint
if not os.path.exists(checkpoint_name):
checkpoint_name = os.path.join(
resume_from_checkpoint, "adapter_model.bin"
) # only LoRA model - LoRA config above has to fit
resume_from_checkpoint = (
False # So the trainer won't try loading its state
)
# The two files above have a different name depending on how they were saved, but are actually the same.
if os.path.exists(checkpoint_name):
print(f"Restarting from {checkpoint_name}")
adapters_weights = torch.load(checkpoint_name)
set_peft_model_state_dict(model, adapters_weights)
else:
print(f"Checkpoint {checkpoint_name} not found")
# Be more transparent about the % of trainable params.
model.print_trainable_parameters()
if val_set_size > 0:
train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
train_data = (train_val["train"].shuffle().map(generate_and_tokenize_prompt))
val_data = (train_val["test"].shuffle().map(generate_and_tokenize_prompt))
else:
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = None
if not ddp and torch.cuda.device_count() > 1:
# keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
model.is_parallelizable = True
model.model_parallel = True
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
warmup_steps=100,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
fp16=True,
logging_steps=10,
optim="adamw_torch",
evaluation_strategy="steps" if val_set_size > 0 else "no",
save_strategy="steps",
eval_steps=100 if val_set_size > 0 else None,
save_steps=100,
output_dir=output_dir,
save_total_limit=3,
load_best_model_at_end=True if val_set_size > 0 else False,
ddp_find_unused_parameters=False if ddp else None,
group_by_length=group_by_length,
report_to="wandb" if use_wandb else None,
run_name=wandb_run_name if use_wandb else None,
),
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
)
model.config.use_cache = False
old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(
self, old_state_dict()
)
).__get__(model, type(model))
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
model.save_pretrained(output_dir)
print("\n If there's a warning about missing keys above, please disregard :)")
if __name__ == "__main__":
fire.Fire(train)
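To make the `train_on_inputs=False` branch of `generate_and_tokenize_prompt` concrete, here is a toy illustration of the label masking it performs (token ids are invented for the example):

```python
# Toy illustration of the masking in generate_and_tokenize_prompt when
# train_on_inputs=False. Token ids below are made up for the example.
full_ids = [1, 501, 502, 503, 601, 602, 2]  # BOS + prompt + response + EOS
user_prompt_len = 4                         # tokens that belong to the user prompt
labels = [-100] * user_prompt_len + full_ids[user_prompt_len:]
print(labels)  # [-100, -100, -100, -100, 601, 602, 2]: loss is computed on the response only
```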

201
src/generate.py Normal file

@@ -0,0 +1,201 @@
import os
import sys
import fire
import gradio as gr
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
try:
if torch.backends.mps.is_available():
device = "mps"
except: # noqa: E722
pass
def main(
load_8bit: bool = False,
base_model: str = "",
lora_weights: str = "tloen/alpaca-lora-7b",
prompt_template: str = "", # The prompt template to use, will default to alpaca.
    server_name: str = "0.0.0.0",  # Allows to listen on all interfaces by providing '0.0.0.0'.
share_gradio: bool = False,
):
base_model = base_model or os.environ.get("BASE_MODEL", "")
assert (
base_model
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
prompter = Prompter(prompt_template)
tokenizer = LlamaTokenizer.from_pretrained(base_model)
if device == "cuda":
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=load_8bit,
torch_dtype=torch.float16,
device_map="auto",
)
model = PeftModel.from_pretrained(
model,
lora_weights,
torch_dtype=torch.float16,
)
elif device == "mps":
model = LlamaForCausalLM.from_pretrained(
base_model,
device_map={"": device},
torch_dtype=torch.float16,
)
model = PeftModel.from_pretrained(
model,
lora_weights,
device_map={"": device},
torch_dtype=torch.float16,
)
else:
model = LlamaForCausalLM.from_pretrained(
base_model,
device_map={"": device}, low_cpu_mem_usage=True
)
model = PeftModel.from_pretrained(
model,
lora_weights,
device_map={"": device},
)
# unwind broken decapoda-research config
model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
model.config.bos_token_id = 1
model.config.eos_token_id = 2
if not load_8bit:
model.half() # seems to fix bugs for some users.
model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
def evaluate(
instruction,
input=None,
temperature=0.1,
top_p=0.75,
top_k=40,
num_beams=1,
max_new_tokens=256,
stream_output=True,
**kwargs,
):
prompt = prompter.generate_prompt(instruction, input)
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
generation_config = GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams,
**kwargs,
)
generate_params = {
"input_ids": input_ids,
"generation_config": generation_config,
"return_dict_in_generate": True,
"output_scores": True,
"max_new_tokens": max_new_tokens,
}
if stream_output:
# Stream the reply 1 token at a time.
# This is based on the trick of using 'stopping_criteria' to create an iterator,
# from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
def generate_with_callback(callback=None, **kwargs):
kwargs.setdefault(
"stopping_criteria", transformers.StoppingCriteriaList()
)
kwargs["stopping_criteria"].append(
Stream(callback_func=callback)
)
with torch.no_grad():
model.generate(**kwargs)
def generate_with_streaming(**kwargs):
return Iteratorize(
generate_with_callback, kwargs, callback=None
)
with generate_with_streaming(**generate_params) as generator:
for output in generator:
# new_tokens = len(output) - len(input_ids[0])
decoded_output = tokenizer.decode(output)
if output[-1] in [tokenizer.eos_token_id]:
break
yield prompter.get_response(decoded_output)
return # early return for stream_output
# Without streaming
with torch.no_grad():
generation_output = model.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=max_new_tokens,
)
s = generation_output.sequences[0]
output = tokenizer.decode(s)
yield prompter.get_response(output)
gr.Interface(
fn=evaluate,
inputs=[
gr.components.Textbox(
lines=2,
label="Instruction",
placeholder="Tell me about alpacas.",
),
gr.components.Textbox(lines=2, label="Input", placeholder="none"),
gr.components.Slider(
minimum=0, maximum=1, value=1.0, label="Temperature"
),
gr.components.Slider(
minimum=0, maximum=1, value=0.75, label="Top p"
),
gr.components.Slider(
minimum=0, maximum=100, step=1, value=40, label="Top k"
),
gr.components.Slider(
minimum=1, maximum=4, step=1, value=4, label="Beams"
),
gr.components.Slider(
minimum=1, maximum=2000, step=1, value=256, label="Max tokens"
),
gr.components.Checkbox(label="Stream output", value=True),
],
outputs=[
            gr.components.Textbox(
lines=5,
label="Output",
)
],
title="🦙🌲 LaWGPT",
description="", # noqa: E501
    ).queue().launch(server_name=server_name, share=share_gradio)
if __name__ == "__main__":
fire.Fire(main)

0
src/outputs/.gitkeep Normal file

17
src/scripts/finetune.sh Normal file

@@ -0,0 +1,17 @@
#!/bin/bash
WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1234 finetune.py \
--base_model 'minlik/chinese-llama-7b-merged' \
--data_path '' \
--output_dir './outputs/LawGPT' \
--prompt_template_name 'law_template' \
--micro_batch_size 16 \
--batch_size 128 \
--num_epochs 3 \
--val_set_size 10000 \
--lora_target_modules='[q_proj,k_proj,v_proj,o_proj]' \
--lora_r 16 \
--lora_alpha 32 \
--learning_rate 3e-4 \
--cutoff_len 512 \
--resume_from_checkpoint './outputs/LawGPT'
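For reference, the effective batch sizing implied by these flags, following the arithmetic in src/finetune.py above:

```python
# Mirrors finetune.py: gradient accumulation derived from batch sizes and DDP world size.
batch_size, micro_batch_size, world_size = 128, 16, 8
grad_accum = batch_size // micro_batch_size // world_size  # 128 // 16 // 8 = 1
effective = micro_batch_size * grad_accum * world_size     # 16 * 1 * 8 = 128
print(grad_accum, effective)  # each optimizer step still consumes 128 sequences
```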

7
src/scripts/generate.sh Normal file

@@ -0,0 +1,7 @@
CUDA_VISIBLE_DEVICES=1 python generate.py \
--load_8bit \
--base_model 'minlik/chinese-llama-7b-merged' \
--lora_weights 'entity303/lawgpt-lora-7b' \
--prompt_template 'law_template' \
--share_gradio


@@ -0,0 +1,6 @@
{
"description": "Template used by Law Instruction Tuning",
"prompt_input": "下面是一个问题,运用法学知识来正确回答提问.\n### 问题:\n{instruction}\n### 回答:\n",
"prompt_no_input": "下面是一个问题,运用法学知识来正确回答提问.\n### 问题:\n{instruction}\n### 回答:\n",
"response_split": "### 回答:"
}
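Rendered through the repo's `Prompter` (see src/utils/prompter.py below), the template above produces prompts like the following. A sketch, assuming the file is saved as `templates/law_template.json` relative to the working directory:

```python
# Sketch: render the law template with the repo's Prompter.
from utils.prompter import Prompter  # assumes running from src/

p = Prompter("law_template")
print(p.generate_prompt("什么是正当防卫?"))
# 下面是一个问题,运用法学知识来正确回答提问.
# ### 问题:
# 什么是正当防卫?
# ### 回答:
```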

0
src/utils/__init__.py Normal file

75
src/utils/callbacks.py Normal file

@@ -0,0 +1,75 @@
"""
Helpers to support streaming generate output.
Borrowed from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/callbacks.py
"""
import gc
import traceback
from queue import Queue
from threading import Thread
import torch
import transformers
class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None):
self.callback_func = callback_func
def __call__(self, input_ids, scores) -> bool:
if self.callback_func is not None:
self.callback_func(input_ids[0])
return False
class Iteratorize:
"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""
def __init__(self, func, kwargs={}, callback=None):
self.mfunc = func
self.c_callback = callback
self.q = Queue()
self.sentinel = object()
self.kwargs = kwargs
self.stop_now = False
def _callback(val):
if self.stop_now:
raise ValueError
self.q.put(val)
def gentask():
try:
ret = self.mfunc(callback=_callback, **self.kwargs)
except ValueError:
pass
except:
traceback.print_exc()
pass
self.q.put(self.sentinel)
if self.c_callback:
self.c_callback(ret)
self.thread = Thread(target=gentask)
self.thread.start()
def __iter__(self):
return self
def __next__(self):
obj = self.q.get(True, None)
if obj is self.sentinel:
raise StopIteration
else:
return obj
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop_now = True

196
src/utils/evaluate.py Normal file

@@ -0,0 +1,196 @@
import math
import os
import sys
import fire
from tqdm import tqdm
import pandas as pd
import torch
import transformers
from peft import PeftModel
import datasets
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter
device = "cuda"
def main(
load_8bit: bool = True,
base_model: str = "decapoda-research/llama-7b-hf",
lora_weights: str = "./lora-alpaca",
data_path: str = "./data",
output_path: str = "./output",
eval_rate: float = 0.1,
batch_size: int = 32,
# The prompt template to use, will default to alpaca.
prompt_template: str = "alpaca",
):
base_model = base_model or os.environ.get("BASE_MODEL", "")
assert (base_model), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
prompter = Prompter(prompt_template)
tokenizer = LlamaTokenizer.from_pretrained(base_model)
if device == "cuda":
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=load_8bit,
torch_dtype=torch.float16,
device_map="auto",
)
model = PeftModel.from_pretrained(
model,
lora_weights,
torch_dtype=torch.float16,
)
# unwind broken decapoda-research config
model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
model.config.bos_token_id = 1
model.config.eos_token_id = 2
if not load_8bit:
model.half() # seems to fix bugs for some users.
model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
def evaluate_one(
instruction,
input=None,
temperature=0.1,
top_p=0.75,
top_k=40,
num_beams=2,
max_new_tokens=128,
**kwargs,
):
prompt = prompter.generate_prompt(instruction, input)
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
generation_config = GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams,
**kwargs,
)
# Without streaming
with torch.no_grad():
generation_output = model.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=max_new_tokens,
)
s = generation_output.sequences[0]
output = tokenizer.decode(s, skip_special_tokens=True)
return prompter.get_response(output)
def evaluate_all():
# data = datasets.load_dataset("json", data_files=data_path)
# data = data["train"]
# df = data.to_pandas()
df = pd.read_json(data_path, orient='records')
print(df.info())
        # compute accuracy over the dataset
correct = 0
total = 0
total_step = len(df)
pbar = tqdm(total=total_step, unit='batch')
error = []
for i in range(total_step):
instruction = df['instruction'].iloc[i]
input = df['input'].iloc[i]
label = df['output'].iloc[i]
pred = evaluate_one(instruction=instruction, input=input)
if pred == label:
correct += 1
else:
error.append((label, pred))
total += 1
acc = correct / total
            # update the progress bar
pbar.set_description(
f"Testing: Sample [{total}/{total_step}] Acc: {acc :.4f}")
pbar.update(1)
for e in error:
print(e)
def evaluate_by_batch(
temperature=0.1,
top_p=0.75,
top_k=40,
num_beams=1,
max_new_tokens=32
):
df = pd.read_json(data_path, orient='records')
# df = df.sample(frac=eval_rate).reset_index(drop=True)
df['prompt'] = df.apply(lambda x: prompter.generate_prompt(
x['instruction'], x['input']), axis=1)
tokenizer.padding_side = "left" # Allow batched inference
generation_config = GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams
)
outputs = []
total = 0
total_step = math.ceil(len(df) / batch_size)
pbar = tqdm(total=total_step, unit='batch')
        # run batched generation over the whole dataset
with torch.no_grad():
for i in range(total_step):
batch = df.iloc[i*batch_size:(i+1)*batch_size]
inputs = tokenizer(batch['prompt'].tolist(), return_tensors="pt", padding=True)[
'input_ids'].to(device)
generation_outputs = model.generate(
input_ids=inputs,
generation_config=generation_config,
max_new_tokens=max_new_tokens,
pad_token_id=tokenizer.pad_token_id
)
for g in generation_outputs:
decoded_item = tokenizer.decode(
g, skip_special_tokens=True)
try:
output = prompter.get_response(decoded_item)
except:
output = decoded_item
outputs.append(output)
total += 1
                # update the progress bar
pbar.set_description(f"Testing: Sample [{total}/{len(df)}] ")
pbar.update(1)
df['pred'] = outputs
df['pred'].to_csv(output_path, index=False)
evaluate_by_batch()
if __name__ == "__main__":
# fire.Fire(main)
import yaml
dataset_param = sys.argv[1]
with open("./configs/evaluate_params.yaml", "r") as stream:
# try:
params = yaml.safe_load(stream)
print('=' * 80)
print(params[dataset_param])
print('=' * 80)
# fire.Fire(train)
main(**params[dataset_param])
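One detail worth noting: `evaluate_by_batch` sets `tokenizer.padding_side = "left"` before batching. A small runnable illustration of why left padding matters for batched causal-LM generation (toy token ids, pad id 0 as in the script):

```python
# Toy illustration: causal LMs append new tokens at the right edge, so the
# last real token of every sequence must sit at the right boundary.
pad = 0
batch = [[11, 12, 13], [21, 22]]
right_padded = [seq + [pad] * (3 - len(seq)) for seq in batch]  # [[11,12,13],[21,22,0]]
left_padded = [[pad] * (3 - len(seq)) + seq for seq in batch]   # [[11,12,13],[0,21,22]]
# With right padding, generation for the short sequence would continue after a
# pad token; with left padding it continues after its real last token (22).
print(right_padded, left_padded)
```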

51
src/utils/merge.py Normal file

@@ -0,0 +1,51 @@
import os
import torch
import transformers
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402
BASE_MODEL = os.environ.get("BASE_MODEL", None)
assert (
BASE_MODEL
), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=huggyllama/llama-7b`" # noqa: E501
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
base_model = LlamaForCausalLM.from_pretrained(
BASE_MODEL,
load_in_8bit=False,
torch_dtype=torch.float16,
device_map={"": "cpu"},
)
first_weight = base_model.model.layers[0].self_attn.q_proj.weight
first_weight_old = first_weight.clone()
lora_model = PeftModel.from_pretrained(
base_model,
"../outputs/lora-llama-clm-e2",
device_map={"": "cpu"},
torch_dtype=torch.float16,
)
lora_weight = lora_model.base_model.model.model.layers[0].self_attn.q_proj.weight
assert torch.allclose(first_weight_old, first_weight)
# merge weights - new merging method from peft
lora_model = lora_model.merge_and_unload()
lora_model.train(False)
# did we do anything?
assert not torch.allclose(first_weight_old, first_weight)
lora_model_sd = lora_model.state_dict()
deloreanized_sd = {
k.replace("base_model.model.", ""): v
for k, v in lora_model_sd.items()
if "lora" not in k
}
LlamaForCausalLM.save_pretrained(base_model, '../models/LawGPT_step_1', state_dict=deloreanized_sd, max_shard_size="400MB")

51
src/utils/prompter.py Normal file

@@ -0,0 +1,51 @@
"""
A dedicated helper to manage templates and prompt building.
"""
import json
import os.path as osp
from typing import Union
class Prompter(object):
__slots__ = ("template", "_verbose")
def __init__(self, template_name: str = "", verbose: bool = False):
self._verbose = verbose
if not template_name:
# Enforce the default here, so the constructor can be called with '' and will not break.
template_name = "alpaca"
file_name = osp.join("templates", f"{template_name}.json")
if not osp.exists(file_name):
raise ValueError(f"Can't read {file_name}")
with open(file_name) as fp:
self.template = json.load(fp)
if self._verbose:
print(
f"Using prompt template {template_name}: {self.template['description']}"
)
def generate_prompt(
self,
instruction: str,
input: Union[None, str] = None,
label: Union[None, str] = None,
) -> str:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = self.template["prompt_input"].format(
instruction=instruction, input=input
)
else:
res = self.template["prompt_no_input"].format(
instruction=instruction
)
if label:
res = f"{res}{label}"
if self._verbose:
print(res)
return res
def get_response(self, output: str) -> str:
return output.split(self.template["response_split"])[1].strip()