Release code repo

This commit is contained in:
songpx
2023-05-13 17:26:44 +08:00
parent 95b621693c
commit 7652c0071d
21 changed files with 1036 additions and 3896 deletions

280
src/finetune.py Normal file

@@ -0,0 +1,280 @@
import os
import sys
from typing import List
import fire
import torch
import transformers
from datasets import load_dataset
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer
from utils.prompter import Prompter
def train(
# model/data params
base_model: str = "./models/base_models/your_base_model_dir",
data_path: str = "./data/your_data.json",
output_dir: str = "./outputs/your_version_dir",
# training hyperparams
batch_size: int = 128,
micro_batch_size: int = 4,
num_epochs: int = 10,
learning_rate: float = 3e-4,
cutoff_len: int = 512,
val_set_size: int = 2000,
# lora hyperparams
lora_r: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_target_modules: List[str] = ["q_proj", "v_proj",],
# llm hyperparams
train_on_inputs: bool = True, # if False, masks out inputs in loss
add_eos_token: bool = True,
group_by_length: bool = False, # faster, but produces an odd training loss curve
# wandb params
wandb_project: str = "",
wandb_run_name: str = "",
wandb_watch: str = "", # options: false | gradients | all
wandb_log_model: str = "", # options: false | true
# either training checkpoint or final adapter
resume_from_checkpoint: str = None,
# The prompt template to use, will default to alpaca.
prompt_template_name: str = "alpaca",
):
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
print(
f"Training Alpaca-LoRA model with params:\n"
f"base_model: {base_model}\n"
f"data_path: {data_path}\n"
f"output_dir: {output_dir}\n"
f"batch_size: {batch_size}\n"
f"micro_batch_size: {micro_batch_size}\n"
f"num_epochs: {num_epochs}\n"
f"learning_rate: {learning_rate}\n"
f"cutoff_len: {cutoff_len}\n"
f"val_set_size: {val_set_size}\n"
f"lora_r: {lora_r}\n"
f"lora_alpha: {lora_alpha}\n"
f"lora_dropout: {lora_dropout}\n"
f"lora_target_modules: {lora_target_modules}\n"
f"train_on_inputs: {train_on_inputs}\n"
f"add_eos_token: {add_eos_token}\n"
f"group_by_length: {group_by_length}\n"
f"wandb_project: {wandb_project}\n"
f"wandb_run_name: {wandb_run_name}\n"
f"wandb_watch: {wandb_watch}\n"
f"wandb_log_model: {wandb_log_model}\n"
f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
f"prompt template: {prompt_template_name}\n"
)
gradient_accumulation_steps = batch_size // micro_batch_size
prompter = Prompter(prompt_template_name)
# Configure device and distributed training
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
gradient_accumulation_steps = gradient_accumulation_steps // world_size
# Check if parameter passed or if set within environ
use_wandb = len(wandb_project) > 0 or (
"WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0)
# Only overwrite environ if wandb param passed
if len(wandb_project) > 0:
os.environ["WANDB_PROJECT"] = wandb_project
if len(wandb_watch) > 0:
os.environ["WANDB_WATCH"] = wandb_watch
if len(wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = wandb_log_model
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=True,
torch_dtype=torch.float16,
device_map=device_map,
)
tokenizer = LlamaTokenizer.from_pretrained(base_model)
tokenizer.bos_token_id = 1
tokenizer.eos_token_id = 2
bos = tokenizer.bos_token_id
eos = tokenizer.eos_token_id
pad = tokenizer.pad_token_id
print("pre-trained model's BOS EOS and PAD token id:",
bos, eos, pad, " => It should be 1,2,none")
tokenizer.pad_token_id = (
0 # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left" # Allow batched inference
def tokenize(prompt, add_eos_token=True):
# there's probably a way to do this with the tokenizer settings
# but again, gotta move fast
result = tokenizer(
prompt,
truncation=True,
max_length=cutoff_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != tokenizer.eos_token_id
and len(result["input_ids"]) < cutoff_len
and add_eos_token
):
result["input_ids"].append(tokenizer.eos_token_id)
result["attention_mask"].append(1)
result["labels"] = result["input_ids"].copy()
return result
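# Build and tokenize the full prompt for one sample; when train_on_inputs is False, the prompt tokens are masked with -100 so the loss covers only the response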
def generate_and_tokenize_prompt(data_point):
full_prompt = prompter.generate_prompt(
data_point["instruction"],
data_point["input"],
data_point["output"],
)
tokenized_full_prompt = tokenize(full_prompt)
if not train_on_inputs:
user_prompt = prompter.generate_prompt(
data_point["instruction"], data_point["input"]
)
tokenized_user_prompt = tokenize(
user_prompt, add_eos_token=add_eos_token
)
user_prompt_len = len(tokenized_user_prompt["input_ids"])
if add_eos_token:
user_prompt_len -= 1
tokenized_full_prompt["labels"] = [
-100
] * user_prompt_len + tokenized_full_prompt["labels"][
user_prompt_len:
] # could be sped up, probably
return tokenized_full_prompt
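# Prepare the 8-bit model for training and attach the LoRA adapters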
model = prepare_model_for_int8_training(model)
config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
if data_path.endswith(".json") or data_path.endswith(".jsonl"):
data = load_dataset("json", data_files=data_path)
else:
data = load_dataset(data_path)
if resume_from_checkpoint:
# Check the available weights and load them
checkpoint_name = os.path.join(
resume_from_checkpoint, "pytorch_model.bin"
) # Full checkpoint
if not os.path.exists(checkpoint_name):
checkpoint_name = os.path.join(
resume_from_checkpoint, "adapter_model.bin"
) # only LoRA model - LoRA config above has to fit
resume_from_checkpoint = (
False # So the trainer won't try loading its state
)
# The two files above have a different name depending on how they were saved, but are actually the same.
if os.path.exists(checkpoint_name):
print(f"Restarting from {checkpoint_name}")
adapters_weights = torch.load(checkpoint_name)
set_peft_model_state_dict(model, adapters_weights)
else:
print(f"Checkpoint {checkpoint_name} not found")
# Be more transparent about the % of trainable params.
model.print_trainable_parameters()
if val_set_size > 0:
train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
train_data = (train_val["train"].shuffle().map(generate_and_tokenize_prompt))
val_data = (train_val["test"].shuffle().map(generate_and_tokenize_prompt))
else:
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = None
if not ddp and torch.cuda.device_count() > 1:
# keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
model.is_parallelizable = True
model.model_parallel = True
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
warmup_steps=100,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
fp16=True,
logging_steps=10,
optim="adamw_torch",
evaluation_strategy="steps" if val_set_size > 0 else "no",
save_strategy="steps",
eval_steps=100 if val_set_size > 0 else None,
save_steps=100,
output_dir=output_dir,
save_total_limit=3,
load_best_model_at_end=True if val_set_size > 0 else False,
ddp_find_unused_parameters=False if ddp else None,
group_by_length=group_by_length,
report_to="wandb" if use_wandb else None,
run_name=wandb_run_name if use_wandb else None,
),
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
)
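# Disable the KV cache for training and patch state_dict so that save_pretrained() stores only the LoRA adapter weights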
model.config.use_cache = False
old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(
self, old_state_dict()
)
).__get__(model, type(model))
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
model.save_pretrained(output_dir)
print("\n If there's a warning about missing keys above, please disregard :)")
if __name__ == "__main__":
fire.Fire(train)

201
src/generate.py Normal file

@@ -0,0 +1,201 @@
import os
import sys
import fire
import gradio as gr
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
try:
if torch.backends.mps.is_available():
device = "mps"
except: # noqa: E722
pass
def main(
load_8bit: bool = False,
base_model: str = "",
lora_weights: str = "tloen/alpaca-lora-7b",
prompt_template: str = "", # The prompt template to use, will default to alpaca.
server_name: str = "0.0.0.0",  # Allows listening on all interfaces by providing '0.0.0.0'
share_gradio: bool = False,
):
base_model = base_model or os.environ.get("BASE_MODEL", "")
assert (
base_model
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
prompter = Prompter(prompt_template)
tokenizer = LlamaTokenizer.from_pretrained(base_model)
if device == "cuda":
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=load_8bit,
torch_dtype=torch.float16,
device_map="auto",
)
model = PeftModel.from_pretrained(
model,
lora_weights,
torch_dtype=torch.float16,
)
elif device == "mps":
model = LlamaForCausalLM.from_pretrained(
base_model,
device_map={"": device},
torch_dtype=torch.float16,
)
model = PeftModel.from_pretrained(
model,
lora_weights,
device_map={"": device},
torch_dtype=torch.float16,
)
else:
model = LlamaForCausalLM.from_pretrained(
base_model,
device_map={"": device}, low_cpu_mem_usage=True
)
model = PeftModel.from_pretrained(
model,
lora_weights,
device_map={"": device},
)
# unwind broken decapoda-research config
model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
model.config.bos_token_id = 1
model.config.eos_token_id = 2
if not load_8bit:
model.half() # seems to fix bugs for some users.
model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
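# Gradio callback: build the prompt, run generation, then either stream partial responses token by token or yield the full response once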
def evaluate(
instruction,
input=None,
temperature=0.1,
top_p=0.75,
top_k=40,
num_beams=1,
max_new_tokens=256,
stream_output=True,
**kwargs,
):
prompt = prompter.generate_prompt(instruction, input)
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
generation_config = GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams,
**kwargs,
)
generate_params = {
"input_ids": input_ids,
"generation_config": generation_config,
"return_dict_in_generate": True,
"output_scores": True,
"max_new_tokens": max_new_tokens,
}
if stream_output:
# Stream the reply 1 token at a time.
# This is based on the trick of using 'stopping_criteria' to create an iterator,
# from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
def generate_with_callback(callback=None, **kwargs):
kwargs.setdefault(
"stopping_criteria", transformers.StoppingCriteriaList()
)
kwargs["stopping_criteria"].append(
Stream(callback_func=callback)
)
with torch.no_grad():
model.generate(**kwargs)
def generate_with_streaming(**kwargs):
return Iteratorize(
generate_with_callback, kwargs, callback=None
)
with generate_with_streaming(**generate_params) as generator:
for output in generator:
# new_tokens = len(output) - len(input_ids[0])
decoded_output = tokenizer.decode(output)
if output[-1] in [tokenizer.eos_token_id]:
break
yield prompter.get_response(decoded_output)
return # early return for stream_output
# Without streaming
with torch.no_grad():
generation_output = model.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=max_new_tokens,
)
s = generation_output.sequences[0]
output = tokenizer.decode(s)
yield prompter.get_response(output)
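# Wire evaluate() into a Gradio UI with instruction/input textboxes and sampling controls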
gr.Interface(
fn=evaluate,
inputs=[
gr.components.Textbox(
lines=2,
label="Instruction",
placeholder="Tell me about alpacas.",
),
gr.components.Textbox(lines=2, label="Input", placeholder="none"),
gr.components.Slider(
minimum=0, maximum=1, value=1.0, label="Temperature"
),
gr.components.Slider(
minimum=0, maximum=1, value=0.75, label="Top p"
),
gr.components.Slider(
minimum=0, maximum=100, step=1, value=40, label="Top k"
),
gr.components.Slider(
minimum=1, maximum=4, step=1, value=4, label="Beams"
),
gr.components.Slider(
minimum=1, maximum=2000, step=1, value=256, label="Max tokens"
),
gr.components.Checkbox(label="Stream output", value=True),
],
outputs=[
gr.components.Textbox(
lines=5,
label="Output",
)
],
title="🦙🌲 LaWGPT",
description="", # noqa: E501
).queue().launch(server_name=server_name, share=share_gradio)
if __name__ == "__main__":
fire.Fire(main)

0
src/outputs/.gitkeep Normal file

17
src/scripts/finetune.sh Normal file

@@ -0,0 +1,17 @@
#!/bin/bash
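# LoRA fine-tuning on 8 GPUs via torchrun, using the law instruction template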
WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1234 finetune.py \
--base_model 'minlik/chinese-llama-7b-merged' \
--data_path '' \
--output_dir './outputs/LawGPT' \
--prompt_template_name 'law_template' \
--micro_batch_size 16 \
--batch_size 128 \
--num_epochs 3 \
--val_set_size 10000 \
--lora_target_modules='[q_proj,k_proj,v_proj,o_proj]' \
--lora_r 16 \
--lora_alpha 32 \
--learning_rate 3e-4 \
--cutoff_len 512 \
--resume_from_checkpoint './outputs/LawGPT'

7
src/scripts/generate.sh Normal file

@@ -0,0 +1,7 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=1 python generate.py \
--load_8bit \
--base_model 'minlik/chinese-llama-7b-merged' \
--lora_weights 'entity303/lawgpt-lora-7b' \
--prompt_template 'law_template' \
--share_gradio

6
src/templates/law_template.json Normal file

@@ -0,0 +1,6 @@
{
"description": "Template used by Law Instruction Tuning",
"prompt_input": "下面是一个问题,运用法学知识来正确回答提问.\n### 问题:\n{instruction}\n### 回答:\n",
"prompt_no_input": "下面是一个问题,运用法学知识来正确回答提问.\n### 问题:\n{instruction}\n### 回答:\n",
"response_split": "### 回答:"
}

0
src/utils/__init__.py Normal file

75
src/utils/callbacks.py Normal file

@@ -0,0 +1,75 @@
"""
Helpers to support streaming generate output.
Borrowed from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/callbacks.py
"""
import gc
import traceback
from queue import Queue
from threading import Thread
import torch
import transformers
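# A StoppingCriteria that never stops generation; it only forwards the tokens generated so far to a callback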
class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None):
self.callback_func = callback_func
def __call__(self, input_ids, scores) -> bool:
if self.callback_func is not None:
self.callback_func(input_ids[0])
return False
class Iteratorize:
"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""
def __init__(self, func, kwargs={}, callback=None):
self.mfunc = func
self.c_callback = callback
self.q = Queue()
self.sentinel = object()
self.kwargs = kwargs
self.stop_now = False
def _callback(val):
if self.stop_now:
raise ValueError
self.q.put(val)
def gentask():
try:
ret = self.mfunc(callback=_callback, **self.kwargs)
except ValueError:
pass
except:
traceback.print_exc()
pass
self.q.put(self.sentinel)
if self.c_callback:
self.c_callback(ret)
self.thread = Thread(target=gentask)
self.thread.start()
def __iter__(self):
return self
def __next__(self):
obj = self.q.get(True, None)
if obj is self.sentinel:
raise StopIteration
else:
return obj
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop_now = True

196
src/utils/evaluate.py Normal file

@@ -0,0 +1,196 @@
import math
import os
import sys
import fire
from tqdm import tqdm
import pandas as pd
import torch
import transformers
from peft import PeftModel
import datasets
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter
device = "cuda"
def main(
load_8bit: bool = True,
base_model: str = "decapoda-research/llama-7b-hf",
lora_weights: str = "./lora-alpaca",
data_path: str = "./data",
output_path: str = "./output",
eval_rate: float = 0.1,
batch_size: int = 32,
# The prompt template to use, will default to alpaca.
prompt_template: str = "alpaca",
):
base_model = base_model or os.environ.get("BASE_MODEL", "")
assert (base_model), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
prompter = Prompter(prompt_template)
tokenizer = LlamaTokenizer.from_pretrained(base_model)
if device == "cuda":
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=load_8bit,
torch_dtype=torch.float16,
device_map="auto",
)
model = PeftModel.from_pretrained(
model,
lora_weights,
torch_dtype=torch.float16,
)
# unwind broken decapoda-research config
model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
model.config.bos_token_id = 1
model.config.eos_token_id = 2
if not load_8bit:
model.half() # seems to fix bugs for some users.
model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
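# Generate a single completion for one instruction/input pair (no streaming)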
def evaluate_one(
instruction,
input=None,
temperature=0.1,
top_p=0.75,
top_k=40,
num_beams=2,
max_new_tokens=128,
**kwargs,
):
prompt = prompter.generate_prompt(instruction, input)
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
generation_config = GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams,
**kwargs,
)
# Without streaming
with torch.no_grad():
generation_output = model.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=max_new_tokens,
)
s = generation_output.sequences[0]
output = tokenizer.decode(s, skip_special_tokens=True)
return prompter.get_response(output)
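# Evaluate one sample at a time and report exact-match accuracy; main() below uses evaluate_by_batch() instead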
def evaluate_all():
# data = datasets.load_dataset("json", data_files=data_path)
# data = data["train"]
# df = data.to_pandas()
df = pd.read_json(data_path, orient='records')
print(df.info())
# Compute exact-match accuracy
correct = 0
total = 0
total_step = len(df)
pbar = tqdm(total=total_step, unit='batch')
error = []
for i in range(total_step):
instruction = df['instruction'].iloc[i]
input = df['input'].iloc[i]
label = df['output'].iloc[i]
pred = evaluate_one(instruction=instruction, input=input)
if pred == label:
correct += 1
else:
error.append((label, pred))
total += 1
acc = correct / total
# Update the progress bar
pbar.set_description(
f"Testing: Sample [{total}/{total_step}] Acc: {acc :.4f}")
pbar.update(1)
for e in error:
print(e)
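# Run batched inference over the whole dataset and write the predictions to output_path as CSV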
def evaluate_by_batch(
temperature=0.1,
top_p=0.75,
top_k=40,
num_beams=1,
max_new_tokens=32
):
df = pd.read_json(data_path, orient='records')
# df = df.sample(frac=eval_rate).reset_index(drop=True)
df['prompt'] = df.apply(lambda x: prompter.generate_prompt(
x['instruction'], x['input']), axis=1)
tokenizer.padding_side = "left" # Allow batched inference
generation_config = GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams
)
outputs = []
total = 0
total_step = math.ceil(len(df) / batch_size)
pbar = tqdm(total=total_step, unit='batch')
# Generate predictions batch by batch
with torch.no_grad():
for i in range(total_step):
batch = df.iloc[i*batch_size:(i+1)*batch_size]
inputs = tokenizer(batch['prompt'].tolist(), return_tensors="pt", padding=True)[
'input_ids'].to(device)
generation_outputs = model.generate(
input_ids=inputs,
generation_config=generation_config,
max_new_tokens=max_new_tokens,
pad_token_id=tokenizer.pad_token_id
)
for g in generation_outputs:
decoded_item = tokenizer.decode(
g, skip_special_tokens=True)
try:
output = prompter.get_response(decoded_item)
except:
output = decoded_item
outputs.append(output)
total += 1
# Update the progress bar
pbar.set_description(f"Testing: Sample [{total}/{len(df)}] ")
pbar.update(1)
df['pred'] = outputs
df['pred'].to_csv(output_path, index=False)
evaluate_by_batch()
if __name__ == "__main__":
# fire.Fire(main)
import yaml
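# Usage: python evaluate.py <param_name>, where <param_name> selects a block of keyword arguments for main() from configs/evaluate_params.yaml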
dataset_param = sys.argv[1]
with open("./configs/evaluate_params.yaml", "r") as stream:
params = yaml.safe_load(stream)
print('=' * 80)
print(params[dataset_param])
print('=' * 80)
main(**params[dataset_param])

51
src/utils/merge.py Normal file

@@ -0,0 +1,51 @@
import os
import torch
import transformers
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402
BASE_MODEL = os.environ.get("BASE_MODEL", None)
assert (
BASE_MODEL
), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=huggyllama/llama-7b`" # noqa: E501
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
base_model = LlamaForCausalLM.from_pretrained(
BASE_MODEL,
load_in_8bit=False,
torch_dtype=torch.float16,
device_map={"": "cpu"},
)
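# Keep a copy of one attention weight so we can check below that merging actually changed the model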
first_weight = base_model.model.layers[0].self_attn.q_proj.weight
first_weight_old = first_weight.clone()
lora_model = PeftModel.from_pretrained(
base_model,
"../outputs/lora-llama-clm-e2",
device_map={"": "cpu"},
torch_dtype=torch.float16,
)
lora_weight = lora_model.base_model.model.model.layers[0].self_attn.q_proj.weight
assert torch.allclose(first_weight_old, first_weight)
# merge weights - new merging method from peft
lora_model = lora_model.merge_and_unload()
lora_model.train(False)
# did we do anything?
assert not torch.allclose(first_weight_old, first_weight)
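# Strip the "base_model.model." prefix and drop the LoRA-specific tensors so the merged weights load as a plain LLaMA checkpoint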
lora_model_sd = lora_model.state_dict()
deloreanized_sd = {
k.replace("base_model.model.", ""): v
for k, v in lora_model_sd.items()
if "lora" not in k
}
LlamaForCausalLM.save_pretrained(base_model, '../models/LawGPT_step_1', state_dict=deloreanized_sd, max_shard_size="400MB")

51
src/utils/prompter.py Normal file

@@ -0,0 +1,51 @@
"""
A dedicated helper to manage templates and prompt building.
"""
import json
import os.path as osp
from typing import Union
class Prompter(object):
__slots__ = ("template", "_verbose")
def __init__(self, template_name: str = "", verbose: bool = False):
self._verbose = verbose
if not template_name:
# Enforce the default here, so the constructor can be called with '' and will not break.
template_name = "alpaca"
file_name = osp.join("templates", f"{template_name}.json")
if not osp.exists(file_name):
raise ValueError(f"Can't read {file_name}")
with open(file_name) as fp:
self.template = json.load(fp)
if self._verbose:
print(
f"Using prompt template {template_name}: {self.template['description']}"
)
def generate_prompt(
self,
instruction: str,
input: Union[None, str] = None,
label: Union[None, str] = None,
) -> str:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = self.template["prompt_input"].format(
instruction=instruction, input=input
)
else:
res = self.template["prompt_no_input"].format(
instruction=instruction
)
if label:
res = f"{res}{label}"
if self._verbose:
print(res)
return res
def get_response(self, output: str) -> str:
return output.split(self.template["response_split"])[1].strip()