[ENH] Restructure the project.

Author: herobrine19
Date: 2023-05-22 02:39:04 +08:00
parent f329433cad
commit 46a9cbc21e
22 changed files with 755 additions and 36 deletions

.gitignore (19 lines changed)

@@ -1 +1,18 @@
./data
__pycache__/
*.npy
*.npz
*.pyc
*.pyd
*.so
*.ipynb
.ipynb_checkpoints
models/*
!models/.gitkeep
!models/base_models/.gitkeep
!models/lora_weights/.gitkeep
outputs/*
!outputs/.gitkeep
data/*
!data/.gitkeep
wandb/
**/.DS_Store

finetune.py (new file, 283 lines)

@@ -0,0 +1,283 @@
import os
import sys
from typing import List
import fire
import torch
import transformers
from datasets import load_dataset
"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer
from utils.prompter import Prompter
def train(
# model/data params
base_model: str = "", # the only required argument
data_path: str = "yahma/alpaca-cleaned",
output_dir: str = "./lora-alpaca",
# training hyperparams
batch_size: int = 128,
micro_batch_size: int = 4,
num_epochs: int = 3,
learning_rate: float = 3e-4,
cutoff_len: int = 256,
val_set_size: int = 2000,
# lora hyperparams
lora_r: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_target_modules: List[str] = [
"q_proj",
"v_proj",
],
# llm hyperparams
train_on_inputs: bool = True, # if False, masks out inputs in loss
add_eos_token: bool = True,
group_by_length: bool = False, # faster, but produces an odd training loss curve
# wandb params
wandb_project: str = "",
wandb_run_name: str = "",
wandb_watch: str = "", # options: false | gradients | all
wandb_log_model: str = "", # options: false | true
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
):
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
print(
f"Training Alpaca-LoRA model with params:\n"
f"base_model: {base_model}\n"
f"data_path: {data_path}\n"
f"output_dir: {output_dir}\n"
f"batch_size: {batch_size}\n"
f"micro_batch_size: {micro_batch_size}\n"
f"num_epochs: {num_epochs}\n"
f"learning_rate: {learning_rate}\n"
f"cutoff_len: {cutoff_len}\n"
f"val_set_size: {val_set_size}\n"
f"lora_r: {lora_r}\n"
f"lora_alpha: {lora_alpha}\n"
f"lora_dropout: {lora_dropout}\n"
f"lora_target_modules: {lora_target_modules}\n"
f"train_on_inputs: {train_on_inputs}\n"
f"add_eos_token: {add_eos_token}\n"
f"group_by_length: {group_by_length}\n"
f"wandb_project: {wandb_project}\n"
f"wandb_run_name: {wandb_run_name}\n"
f"wandb_watch: {wandb_watch}\n"
f"wandb_log_model: {wandb_log_model}\n"
f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
f"prompt template: {prompt_template_name}\n"
)
assert (
base_model
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
gradient_accumulation_steps = batch_size // micro_batch_size
prompter = Prompter(prompt_template_name)
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
gradient_accumulation_steps = gradient_accumulation_steps // world_size
# Check if parameter passed or if set within environ
use_wandb = len(wandb_project) > 0 or (
"WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
)
# Only overwrite environ if wandb param passed
if len(wandb_project) > 0:
os.environ["WANDB_PROJECT"] = wandb_project
if len(wandb_watch) > 0:
os.environ["WANDB_WATCH"] = wandb_watch
if len(wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = wandb_log_model
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=True,
torch_dtype=torch.float16,
device_map=device_map,
)
tokenizer = LlamaTokenizer.from_pretrained(base_model)
tokenizer.pad_token_id = (
0 # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left" # Allow batched inference
def tokenize(prompt, add_eos_token=add_eos_token):  # accept the keyword passed in generate_and_tokenize_prompt below
# there's probably a way to do this with the tokenizer settings
# but again, gotta move fast
result = tokenizer(
prompt,
truncation=True,
max_length=cutoff_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != tokenizer.eos_token_id
and len(result["input_ids"]) < cutoff_len
and add_eos_token
):
result["input_ids"].append(tokenizer.eos_token_id)
result["attention_mask"].append(1)
result["labels"] = result["input_ids"].copy()
return result
def generate_and_tokenize_prompt(data_point):
full_prompt = prompter.generate_prompt(
data_point["instruction"],
data_point["input"],
data_point["output"],
)
tokenized_full_prompt = tokenize(full_prompt)
if not train_on_inputs:
user_prompt = prompter.generate_prompt(
data_point["instruction"], data_point["input"]
)
tokenized_user_prompt = tokenize(
user_prompt, add_eos_token=add_eos_token
)
user_prompt_len = len(tokenized_user_prompt["input_ids"])
if add_eos_token:
user_prompt_len -= 1
tokenized_full_prompt["labels"] = [
-100
] * user_prompt_len + tokenized_full_prompt["labels"][
user_prompt_len:
] # could be sped up, probably
return tokenized_full_prompt
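# prepare_model_for_int8_training freezes the 8-bit base weights, upcasts layer norms / output head for numerical stability and enables gradient checkpointing.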
model = prepare_model_for_int8_training(model)
config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM",
)
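# get_peft_model wraps the frozen base model and injects trainable low-rank (r=lora_r) adapter matrices into the target q_proj/v_proj projections.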
model = get_peft_model(model, config)
if data_path.endswith(".json") or data_path.endswith(".jsonl"):
data = load_dataset("json", data_files=data_path)
else:
data = load_dataset(data_path)
if resume_from_checkpoint:
# Check the available weights and load them
checkpoint_name = os.path.join(
resume_from_checkpoint, "pytorch_model.bin"
) # Full checkpoint
if not os.path.exists(checkpoint_name):
checkpoint_name = os.path.join(
resume_from_checkpoint, "adapter_model.bin"
) # only LoRA model - LoRA config above has to fit
resume_from_checkpoint = (
False # So the trainer won't try loading its state
)
# The two files above have a different name depending on how they were saved, but are actually the same.
if os.path.exists(checkpoint_name):
print(f"Restarting from {checkpoint_name}")
adapters_weights = torch.load(checkpoint_name)
set_peft_model_state_dict(model, adapters_weights)
else:
print(f"Checkpoint {checkpoint_name} not found")
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
if val_set_size > 0:
train_val = data["train"].train_test_split(
test_size=val_set_size, shuffle=True, seed=42
)
train_data = (
train_val["train"].shuffle().map(generate_and_tokenize_prompt)
)
val_data = (
train_val["test"].shuffle().map(generate_and_tokenize_prompt)
)
else:
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = None
if not ddp and torch.cuda.device_count() > 1:
# keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
model.is_parallelizable = True
model.model_parallel = True
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
warmup_ratio=0.1,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
fp16=True,
logging_steps=10,
optim="adamw_torch",
evaluation_strategy="steps" if val_set_size > 0 else "no",
save_strategy="steps",
eval_steps=50 if val_set_size > 0 else None,
save_steps=50,
output_dir=output_dir,
save_total_limit=5,
load_best_model_at_end=True if val_set_size > 0 else False,
ddp_find_unused_parameters=False if ddp else None,
group_by_length=group_by_length,
report_to="wandb" if use_wandb else None,
run_name=wandb_run_name if use_wandb else None,
),
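# DataCollatorForSeq2Seq pads input_ids/attention_mask dynamically (to a multiple of 8 for tensor-core efficiency) and pads labels with -100 so padding is ignored by the loss.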
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
)
model.config.use_cache = False
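# Patch state_dict so that only the LoRA adapter weights are returned; checkpoints written by the Trainer and save_pretrained then contain just the small adapter.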
old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(
self, old_state_dict()
)
).__get__(model, type(model))
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
model.save_pretrained(output_dir)
print(
"\n If there's a warning about missing keys above, please disregard :)"
)
if __name__ == "__main__":
fire.Fire(train)

infer.py (new file, 126 lines)

@@ -0,0 +1,126 @@
import sys
import json
import fire
import gradio as gr
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from utils.prompter import Prompter
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_instruction(instruct_dir):
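# Each line of instruct_dir is parsed as its own JSON object, i.e. the file is expected to be in JSON Lines format.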
input_data = []
with open(instruct_dir, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip()
d = json.loads(line)
input_data.append(d)
return input_data
def main(
load_8bit: bool = True,
base_model: str = "",
# path to the inference data; if empty, run the default instructions defined below
instruct_dir: str = "",
lora_weights: str = "",
# The prompt template to use; will default to alpaca when empty.
prompt_template: str = "",
):
prompter = Prompter(prompt_template)
tokenizer = LlamaTokenizer.from_pretrained(base_model)
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=load_8bit,
torch_dtype=torch.float16,
device_map="auto",
)
try:
print(f"using lora {lora_weights}")
model = PeftModel.from_pretrained(
model,
lora_weights,
torch_dtype=torch.float16,
)
except:
print("*"*50, "\n Attention! No Lora Weights \n", "*"*50)
# unwind broken decapoda-research config
model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
model.config.bos_token_id = 1
model.config.eos_token_id = 2
if not load_8bit:
model.half() # seems to fix bugs for some users.
model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
def evaluate(
instruction,
input=None,
temperature=0.1,
top_p=0.75,
top_k=40,
num_beams=1,
max_new_tokens=256,
**kwargs,
):
prompt = prompter.generate_prompt(instruction, input)
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
generation_config = GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams,
# repetition_penalty=10.0,
**kwargs,
)
with torch.no_grad():
generation_output = model.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=max_new_tokens,
)
s = generation_output.sequences[0]
output = tokenizer.decode(s)
return prompter.get_response(output)
def infer_from_json(instruct_dir):
input_data = load_instruction(instruct_dir)
for d in input_data:
instruction = d["instruction"]
output = d["output"]
print('=' * 100)
print(f"###{base_model}-{lora_weights}###")
model_output = evaluate(instruction)
print("###instruction###")
print(instruction)
print("###golden output###")
print(output)
print("###model output###")
print(model_output)
print('=' * 100)
if instruct_dir != "":
infer_from_json(instruct_dir)
else:
for instruction in [
"民间借贷受国家保护的合法利息是多少?",
]:
print("Instruction:", instruction)
print("Response:", evaluate(instruction))
print()
if __name__ == "__main__":
fire.Fire(main)

scripts/finetune.sh (new file, 56 lines)

@@ -0,0 +1,56 @@
#!/bin/bash
export WANDB_MODE=disabled  # disable wandb
# Finetune the chinese-alpaca-plus-7b-merged model on the law_data.json dataset
experiment_name="chinese-alpaca-plus-7b-law-e1"
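# With batch_size=64 and micro_batch_size=8, finetune.py accumulates gradients over 64 // 8 = 8 micro-batches per step.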
# Single GPU or model parallelism
python finetune.py \
--base_model "minlik/chinese-alpaca-plus-7b-merged" \
--data_path "./data/finetune_law_data.json" \
--output_dir "./outputs/"${experiment_name} \
--batch_size 64 \
--micro_batch_size 8 \
--num_epochs 20 \
--learning_rate 3e-4 \
--cutoff_len 256 \
--val_set_size 0 \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--lora_target_modules "[q_proj,v_proj]" \
--train_on_inputs True \
--add_eos_token True \
--group_by_length False \
--wandb_project "" \
--wandb_run_name "" \
--wandb_watch "" \
--wandb_log_model "" \
--resume_from_checkpoint "./outputs/"${experiment_name} \
--prompt_template_name "alpaca" \
# Multi-GPU data parallelism
# WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1234 finetune.py \
# --base_model "minlik/chinese-alpaca-plus-7b-merged" \
# --data_path "./data/finetune_law_data.json" \
# --output_dir "./outputs/"${experiment_name} \
# --batch_size 64 \
# --micro_batch_size 8 \
# --num_epochs 20 \
# --learning_rate 3e-4 \
# --cutoff_len 256 \
# --val_set_size 0 \
# --lora_r 8 \
# --lora_alpha 16 \
# --lora_dropout 0.05 \
# --lora_target_modules "[q_proj,v_proj]" \
# --train_on_inputs True \
# --add_eos_token True \
# --group_by_length False \
# --wandb_project "" \
# --wandb_run_name "" \
# --wandb_watch "" \
# --wandb_log_model "" \
# --resume_from_checkpoint "./outputs/"${experiment_name} \
# --prompt_template_name "alpaca" \

scripts/infer-law.sh (new file, 16 lines)

@@ -0,0 +1,16 @@
#!/bin/bash
# LawGPT
python infer.py \
--base_model 'minlik/chinese-alpaca-plus-7b-merged' \
--lora_weights './outputs/chinese-alpaca-plus-7b-law-e1' \
--instruct_dir './data/infer_law_data.json' \
--prompt_template 'alpaca'
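# If --lora_weights is empty or cannot be loaded, infer.py prints a warning and falls back to the plain base model (see the try/except around PeftModel.from_pretrained).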
# Chinese-Alpaca-plus-7B
python infer.py \
--base_model 'minlik/chinese-alpaca-plus-7b-merged' \
--lora_weights '' \
--instruct_dir './data/infer_law_data.json' \
--prompt_template 'alpaca'

scripts/webui.sh (new file, 21 lines)

@@ -0,0 +1,21 @@
#!/bin/bash
# Use the LoRA weights already trained and published on Hugging Face
python webui.py \
--load_8bit True \
--base_model 'minlik/chinese-alpaca-plus-7b-merged' \
--lora_weights 'entity303/lawgpt-lora-7b' \
--prompt_template "law_template" \
--server_name "0.0.0.0" \
--share_gradio True \
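# --share_gradio True additionally asks Gradio to create a temporary public share link.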
# Use your own finetuned LoRA instead: just place your model in the corresponding directory
# python webui.py \
# --load_8bit True \
# --base_model 'minlik/chinese-alpaca-plus-7b-merged' \
# --lora_weights './outputs/chinese-alpaca-plus-7b-law-e1' \
# --prompt_template "law_template" \
# --server_name "0.0.0.0" \
# --share_gradio True \

(deleted file)

@@ -1,8 +0,0 @@
[
{
"content": "中华人民共和国最高人民法院 再 审 决 定 书2022最高法刑申136号 原审被告人张某某犯挪用资金罪和伪造、变造国家机关公文罪一案山西省运城市盐湖区人民法院于2012年5月2日以2012运盐刑初字第69号刑事判决认定张克云犯贪污罪判处有期徒刑十二年犯伪造、变造国家机关公文罪判处有期徒刑三年决定执行有期徒刑十三年。宣判后张克云不服提出上诉。山西省运城市中级人民法院于2012年11月12日以2012运中刑二终字第125号刑事裁定驳回上诉维持原判。裁判生效后张克云不服提出申诉。运城市中级人民法院于2013年1月7日以2013运中刑申字第3号驳回申诉通知驳回其申诉。山西省高级人民法院于2017年7月13日以2013晋刑监字第8号再审决定提审本案并于2019年12月24日以2017晋刑再第2号刑事判决认定张克云犯挪用资金罪判处有期徒刑七年六个月与原判伪造、变造国家机关公文罪被判处的有期徒刑三年数罪并罚决定执行有期徒刑十年。张克云仍不服以原审认定事实错误其作为学校董事长、全资投资人有权决定学校相关款项用途学校仍欠其债务个人账户用于学校经费开支没有挪用资金的动机和行为不构成挪用资金罪等为由向本院提出申诉。本院经审查认为原审生效裁判对挪用资金罪定罪量刑的证据不确实、不充分依法应当予以排除。依照《中华人民共和国刑事诉讼法》第二百五十三条第二项、第二百五十四条第二款、第二百五十五条的规定决定如下指令河南省高级人民法院对本案进行再审。二二二年十二月二十九日"
},
{
"content":"中华人民共和国最高人民法院 驳 回 申 诉 通 知 书2022最高法刑申122号 袁某银、袁某财你们因原审被告人袁德银故意伤害一案对江苏省南京市溧水区人民法院2014溧刑初字第268号刑事判决、南京市中级人民法院2015宁刑终字第433号刑事裁定不服以被害人朱宽荣住院期间的136678号报告并未显示其左侧4、5、6、7、8肋骨骨折出院记录及137470号、143006号报告均系伪造江苏省高级人民法院2019苏刑申172号驳回申诉通知书对137470号报告的形成时间认定错误为由向本院提出申诉请求撤销原判依法重新审理本案。本院依法组成合议庭认真审查后认为原审认定原审被告人袁德银因邻里纠纷殴打被害人朱宽荣致其左胸多发肋骨骨折构成轻伤二级其行为构成故意伤害罪并无不当。关于你们提出的原审认定被害人朱宽荣轻伤二级的证据系伪造的申诉理由。首先根据你们提供的136678号报告朱宽荣于2015年2月12日入院时经检查被诊断为左侧多发肋骨骨折该份报告虽未明确骨折的具体位置但与出院记录记载的“左侧多发肋骨骨折4、5、6、7、8”并不矛盾。根据《人体损伤程度鉴定标准》的规定肋骨骨折2处以上即构成轻伤二级故鉴定意见认定朱宽荣的损伤构成轻伤二级并无不当。其次你们主张出院记录、137470号和143006号报告系伪造但并未提供证据证实。即使排除这两份报告也不影响对本案事实的认定。再次我院对申诉进行复查的对象是已经发生法律效力的判决、裁定本案的生效裁判是2014溧刑初字第268号刑事判决和2015宁刑终字第433号刑事裁定江苏省高级人民法院作出的驳回申诉通知书不属于我院复查的范围。综上本院认为你们的申诉不符合《中华人民共和国刑事诉讼法》第二百五十三条规定的应当重新审判的情形予以驳回。望你们尊重人民法院的生效裁判自觉服判息诉。特此通知。二二二年十二月十四日"
}
]

(deleted file)

@@ -1,10 +0,0 @@
[
{
"instruction": "酒驾撞人要判多久?",
"input": "",
"output": "《刑法》第一百三十三条规定:违反交通运输管理法规,因而发生重大事故,致人重伤、死广或者使公私财产遭受重大损失的,处三年以下有期徒刑或者拘役;交通运输肇事后逃逸,致人重伤的,处三年以上七年以下有期徒刑。交通肇事后逃逸致人死亡的,处七年以上有期徒刑。"
},
{"instruction":"偷狗被派出所抓到,会怎么处理?",
"input":"",
"output":"取决于偷盗的狗的价值,可能按盗窃罪的罪名处理。如果价值达到犯罪标准,就构成盗窃罪,要承担刑事责任;如果不到犯罪标准,就是治安处罚、罚款或者拘留治安处罚这会涉嫌构成盗窃。如果不到一千元,则不会构成犯罪。如果超过一千元,则可能会是构成犯罪的。"}
]

(deleted file)

@@ -1,17 +0,0 @@
#!/bin/bash
WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1234 finetune.py \
--base_model 'minlik/chinese-llama-7b-merged' \
--data_path '' \
--output_dir './outputs/LawGPT' \
--prompt_template_name 'law_template' \
--micro_batch_size 16 \
--batch_size 128 \
--num_epochs 3 \
--val_set_size 10000 \
--lora_target_modules='[q_proj,k_proj,v_proj,o_proj]' \
--lora_r 16 \
--lora_alpha 32 \
--learning_rate 3e-4 \
--cutoff_len 512 \
--resume_from_checkpoint './outputs/LawGPT' \

templates/alpaca.json (new file, 6 lines)

@@ -0,0 +1,6 @@
{
"description": "Template used by Alpaca-LoRA.",
"prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
"prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
"response_split": "### Response:"
}
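finetune.py, infer.py and webui.py all consume this template through utils.prompter.Prompter, which is not included in this commit. The sketch below is only an illustration of how such a template is typically used (the class name matches the imports above, but the constructor and file handling are assumptions; the project's real utils/prompter.py may differ):

# Hypothetical sketch (not part of this commit): a Prompter compatible with the template above.
import json
import os.path as osp


class Prompter:
    def __init__(self, template_name=""):
        # An empty name falls back to "alpaca", as the comments in finetune.py suggest.
        template_name = template_name or "alpaca"
        with open(osp.join("templates", f"{template_name}.json")) as f:
            self.template = json.load(f)

    def generate_prompt(self, instruction, input=None, label=None):
        # Pick the variant with or without the "### Input:" block.
        if input:
            prompt = self.template["prompt_input"].format(
                instruction=instruction, input=input
            )
        else:
            prompt = self.template["prompt_no_input"].format(instruction=instruction)
        # During training the expected output is appended after "### Response:".
        if label:
            prompt = f"{prompt}{label}"
        return prompt

    def get_response(self, output):
        # Everything after "### Response:" is the model's answer.
        return output.split(self.template["response_split"])[1].strip()

generate_prompt() assembles the training or inference prompt, optionally appending the gold output after "### Response:", and get_response() uses response_split to strip the prompt prefix from the generated text.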

webui.py (new file, 229 lines)

@@ -0,0 +1,229 @@
import os
import sys
import fire
import gradio as gr
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, AutoModel, AutoTokenizer, AutoModelForCausalLM
from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
try:
if torch.backends.mps.is_available():
device = "mps"
except: # noqa: E722
pass
def main(
load_8bit: bool = False,
base_model: str = "",
lora_weights: str = "",
prompt_template: str = "", # The prompt template to use, will default to alpaca.
server_name: str = "0.0.0.0", # Allows listening on all interfaces by providing '0.0.0.0'.
share_gradio: bool = False,
):
base_model = base_model or os.environ.get("BASE_MODEL", "")
assert (
base_model
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
prompter = Prompter(prompt_template)
tokenizer = LlamaTokenizer.from_pretrained(base_model)
if device == "cuda":
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=load_8bit,
torch_dtype=torch.float16,
device_map="auto",
)
try:
model = PeftModel.from_pretrained(
model,
lora_weights,
torch_dtype=torch.float16,
)
except:
print("*"*50, "\n Attention! No Lora Weights \n", "*"*50)
elif device == "mps":
model = LlamaForCausalLM.from_pretrained(
base_model,
device_map={"": device},
torch_dtype=torch.float16,
)
try:
model = PeftModel.from_pretrained(
model,
lora_weights,
device_map={"": device},
torch_dtype=torch.float16,
)
except:
print("*"*50, "\n Attention! No Lora Weights \n", "*"*50)
else:
model = LlamaForCausalLM.from_pretrained(
base_model, device_map={"": device}, low_cpu_mem_usage=True
)
try:
model = PeftModel.from_pretrained(
model,
lora_weights,
device_map={"": device},
)
except:
print("*"*50, "\n Attention! No Lora Weights \n", "*"*50)
# unwind broken decapoda-research config
model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
model.config.bos_token_id = 1
model.config.eos_token_id = 2
if not load_8bit:
model.half() # seems to fix bugs for some users.
model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
def evaluate(
instruction,
input=None,
temperature=0.1,
top_p=0.75,
top_k=40,
num_beams=4,
max_new_tokens=128,
stream_output=False,
**kwargs,
):
prompt = prompter.generate_prompt(instruction, input)
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
generation_config = GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams,
**kwargs,
)
generate_params = {
"input_ids": input_ids,
"generation_config": generation_config,
"return_dict_in_generate": True,
"output_scores": True,
"max_new_tokens": max_new_tokens,
}
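# Two generation paths follow: streaming (tokens yielded as they are produced via a stopping-criteria callback) or a single blocking generate() call.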
if stream_output:
# Stream the reply 1 token at a time.
# This is based on the trick of using 'stopping_criteria' to create an iterator,
# from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
def generate_with_callback(callback=None, **kwargs):
kwargs.setdefault(
"stopping_criteria", transformers.StoppingCriteriaList()
)
kwargs["stopping_criteria"].append(
Stream(callback_func=callback)
)
with torch.no_grad():
model.generate(**kwargs)
def generate_with_streaming(**kwargs):
return Iteratorize(
generate_with_callback, kwargs, callback=None
)
with generate_with_streaming(**generate_params) as generator:
for output in generator:
# new_tokens = len(output) - len(input_ids[0])
decoded_output = tokenizer.decode(output)
if output[-1] in [tokenizer.eos_token_id]:
break
yield prompter.get_response(decoded_output)
print(decoded_output)
return # early return for stream_output
# Without streaming
with torch.no_grad():
generation_output = model.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=max_new_tokens,
)
s = generation_output.sequences[0]
output = tokenizer.decode(s)
print(output)
yield prompter.get_response(output)
gr.Interface(
fn=evaluate,
inputs=[
gr.components.Textbox(
lines=2,
label="Instruction",
placeholder="Tell me about alpacas.",
),
gr.components.Textbox(lines=2, label="Input", placeholder="none"),
gr.components.Slider(
minimum=0, maximum=1, value=0.1, label="Temperature"
),
gr.components.Slider(
minimum=0, maximum=1, value=0.75, label="Top p"
),
gr.components.Slider(
minimum=0, maximum=100, step=1, value=40, label="Top k"
),
gr.components.Slider(
minimum=1, maximum=4, step=1, value=1, label="Beams"
),
gr.components.Slider(
minimum=1, maximum=2000, step=1, value=256, label="Max tokens"
),
gr.components.Checkbox(label="Stream output", value=True),
],
outputs=[
gr.components.Textbox(
lines=5,
label="Output",
)
],
title="🦙🌲 LLM-LoRA",
description="", # noqa: E501
).queue().launch(server_name=server_name, share=share_gradio)
# Old testing code follows.
"""
# testing code for readme
for instruction in [
"Tell me about alpacas.",
"Tell me about the president of Mexico in 2019.",
"Tell me about the king of France in 2019.",
"List all Canadian provinces in alphabetical order.",
"Write a Python program that prints the first 10 Fibonacci numbers.",
"Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.", # noqa: E501
"Tell me five words that rhyme with 'shock'.",
"Translate the sentence 'I have no mouth but I must scream' into Spanish.",
"Count up from 1 to 500.",
]:
print("Instruction:", instruction)
print("Response:", evaluate(instruction))
print()
"""
if __name__ == "__main__":
fire.Fire(main)