diff --git a/README.md b/README.md
index 8ece76a..28c1422 100644
--- a/README.md
+++ b/README.md
@@ -120,7 +120,7 @@ Training of the LawGPT model series is split into two stages:
 ### Secondary training procedure
 
 1. Build an instruction fine-tuning dataset following the format of `data/example_instruction_train.json`
-2. Run `src/scripts/finetune.sh`
+2. Run `src/scripts/train.sh`
 
 ### Instruction fine-tuning steps
diff --git a/src/scripts/train.sh b/src/scripts/train.sh
new file mode 100644
index 0000000..3826276
--- /dev/null
+++ b/src/scripts/train.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1235 train.py \
+    --base_model '../models/base_models/chinese_llama_7b' \
+    --data_path '' \
+    --output_dir '../models/lora_weights' \
+    --batch_size 128 \
+    --micro_batch_size 8 \
+    --num_epochs 1 \
+    --learning_rate 0.0003 \
+    --cutoff_len 1024 \
+    --val_set_size 0 \
+    --lora_r 16 \
+    --lora_alpha 32 \
+    --lora_dropout 0.05 \
+    --lora_target_modules '[q_proj, v_proj, k_proj, o_proj]' \
+    --train_on_inputs True \
+    --add_eos_token True \
+    --group_by_length True \
+    --resume_from_checkpoint '../models/lora_weights'
diff --git a/src/train.py b/src/train.py
new file mode 100644
index 0000000..967060c
--- /dev/null
+++ b/src/train.py
@@ -0,0 +1,259 @@
+import os
+import sys
+from typing import List
+
+import fire
+import torch
+import transformers
+from datasets import load_dataset
+
+from peft import (
+    LoraConfig,
+    get_peft_model,
+    get_peft_model_state_dict,
+    prepare_model_for_int8_training,
+    set_peft_model_state_dict,
+)
+from transformers import LlamaForCausalLM, LlamaTokenizer
+from utils.prompter import Prompter
+
+
+def train(
+    # model/data params
+    base_model: str = "./models/base_models/your_base_model_dir",
+    data_path: str = "./data/your_data.json",
+    output_dir: str = "./outputs/your_version_dir",
+
+    # training hyperparams
+    batch_size: int = 128,
+    micro_batch_size: int = 4,
+    num_epochs: int = 10,
+    learning_rate: float = 3e-4,
+    cutoff_len: int = 512,
+    val_set_size: int = 2000,
+
+    # lora hyperparams
+    lora_r: int = 8,
+    lora_alpha: int = 16,
+    lora_dropout: float = 0.05,
+    lora_target_modules: List[str] = ["q_proj", "v_proj"],
+
+    # llm hyperparams
+    train_on_inputs: bool = True,  # if False, masks out inputs in loss
+    add_eos_token: bool = True,
+    group_by_length: bool = False,  # faster, but produces an odd training loss curve
+
+    # wandb params
+    wandb_project: str = "",
+    wandb_run_name: str = "",
+    wandb_watch: str = "",  # options: false | gradients | all
+    wandb_log_model: str = "",  # options: false | true
+
+    # either training checkpoint or final adapter
+    resume_from_checkpoint: str = None,
+
+    # The prompt template to use, will default to alpaca.
+    prompt_template_name: str = "alpaca",
+):
+    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
+        print(
+            f"Training Alpaca-LoRA model with params:\n"
+            f"base_model: {base_model}\n"
+            f"data_path: {data_path}\n"
+            f"output_dir: {output_dir}\n"
+            f"batch_size: {batch_size}\n"
+            f"micro_batch_size: {micro_batch_size}\n"
+            f"num_epochs: {num_epochs}\n"
+            f"learning_rate: {learning_rate}\n"
+            f"cutoff_len: {cutoff_len}\n"
+            f"val_set_size: {val_set_size}\n"
+            f"lora_r: {lora_r}\n"
+            f"lora_alpha: {lora_alpha}\n"
+            f"lora_dropout: {lora_dropout}\n"
+            f"lora_target_modules: {lora_target_modules}\n"
+            f"train_on_inputs: {train_on_inputs}\n"
+            f"add_eos_token: {add_eos_token}\n"
+            f"group_by_length: {group_by_length}\n"
+            f"wandb_project: {wandb_project}\n"
+            f"wandb_run_name: {wandb_run_name}\n"
+            f"wandb_watch: {wandb_watch}\n"
+            f"wandb_log_model: {wandb_log_model}\n"
+            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
+            f"prompt template: {prompt_template_name}\n"
+        )
+    gradient_accumulation_steps = batch_size // micro_batch_size
+
+    prompter = Prompter(prompt_template_name)
+
+    # Configure device and distributed training
+    device_map = "auto"
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    ddp = world_size != 1
+    if ddp:
+        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
+        gradient_accumulation_steps = gradient_accumulation_steps // world_size
+
+    # Check if parameter passed or if set within environ
+    use_wandb = len(wandb_project) > 0 or (
+        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
+    )
+
+    # Only overwrite environ if wandb param passed
+    if len(wandb_project) > 0:
+        os.environ["WANDB_PROJECT"] = wandb_project
+    if len(wandb_watch) > 0:
+        os.environ["WANDB_WATCH"] = wandb_watch
+    if len(wandb_log_model) > 0:
+        os.environ["WANDB_LOG_MODEL"] = wandb_log_model
+
+    model = LlamaForCausalLM.from_pretrained(
+        base_model,
+        load_in_8bit=True,
+        torch_dtype=torch.float16,
+        device_map=device_map,
+    )
+
+    tokenizer = LlamaTokenizer.from_pretrained(base_model)
+    tokenizer.bos_token_id = 1
+    tokenizer.eos_token_id = 2
+    bos = tokenizer.bos_token_id
+    eos = tokenizer.eos_token_id
+    pad = tokenizer.pad_token_id
+
+    print("pre-trained model's BOS EOS and PAD token id:",
+          bos, eos, pad, " => It should be 1,2,none")
+
+    tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
+    tokenizer.padding_side = "left"  # Allow batched inference
+
+    def tokenize(prompt, add_eos_token=True):
+        # there's probably a way to do this with the tokenizer settings
+        # but again, gotta move fast
+        result = tokenizer(
+            prompt,
+            truncation=True,
+            max_length=cutoff_len,
+            padding=False,
+            return_tensors=None,
+        )
+        if (
+            result["input_ids"][-1] != tokenizer.eos_token_id
+            and len(result["input_ids"]) < cutoff_len
+            and add_eos_token
+        ):
+            result["input_ids"].append(tokenizer.eos_token_id)
+            result["attention_mask"].append(1)
+
+        result["labels"] = result["input_ids"].copy()
+
+        return result
+
+    def generate_and_tokenize_prompt(data_point):
+        # each record is expected to carry its full training text in the "content" field
+        text = data_point["content"]
+        tokenized_full_prompt = tokenize(text)
+        return tokenized_full_prompt
+
+    model = prepare_model_for_int8_training(model)
+
+    config = LoraConfig(
+        r=lora_r,
+        lora_alpha=lora_alpha,
+        target_modules=lora_target_modules,
+        lora_dropout=lora_dropout,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+    model = get_peft_model(model, config)
+
+    if data_path.endswith(".json") or data_path.endswith(".jsonl"):
+        data = load_dataset("json", data_files=data_path)
+    else:
+        data = load_dataset(data_path)
+
+    if resume_from_checkpoint:
+        # Check the available weights and load them
+        checkpoint_name = os.path.join(
+            resume_from_checkpoint, "pytorch_model.bin"
+        )  # Full checkpoint
+        if not os.path.exists(checkpoint_name):
+            checkpoint_name = os.path.join(
+                resume_from_checkpoint, "adapter_model.bin"
+            )  # only LoRA model - LoRA config above has to fit
+            resume_from_checkpoint = (
+                False  # So the trainer won't try loading its state
+            )
+        # The two files above have a different name depending on how they were saved, but are actually the same.
+        if os.path.exists(checkpoint_name):
+            print(f"Restarting from {checkpoint_name}")
+            adapters_weights = torch.load(checkpoint_name)
+            set_peft_model_state_dict(model, adapters_weights)
+        else:
+            print(f"Checkpoint {checkpoint_name} not found")
+
+    # Be more transparent about the % of trainable params.
+    model.print_trainable_parameters()
+
+    if val_set_size > 0:
+        train_val = data["train"].train_test_split(
+            test_size=val_set_size, shuffle=True, seed=42
+        )
+        train_data = (
+            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
+        )
+        val_data = (
+            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
+        )
+    else:
+        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
+        val_data = None
+
+    if not ddp and torch.cuda.device_count() > 1:
+        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
+        model.is_parallelizable = True
+        model.model_parallel = True
+
+    trainer = transformers.Trainer(
+        model=model,
+        train_dataset=train_data,
+        eval_dataset=val_data,
+        args=transformers.TrainingArguments(
+            per_device_train_batch_size=micro_batch_size,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            warmup_steps=100,
+            num_train_epochs=num_epochs,
+            learning_rate=learning_rate,
+            fp16=True,
+            logging_steps=10,
+            optim="adamw_torch",
+            evaluation_strategy="steps" if val_set_size > 0 else "no",
+            save_strategy="steps",
+            eval_steps=100 if val_set_size > 0 else None,
+            save_steps=100,
+            output_dir=output_dir,
+            save_total_limit=3,
+            load_best_model_at_end=True if val_set_size > 0 else False,
+            ddp_find_unused_parameters=False if ddp else None,
+            group_by_length=group_by_length,
+            report_to="wandb" if use_wandb else None,
+            run_name=wandb_run_name if use_wandb else None,
+        ),
+        data_collator=transformers.DataCollatorForSeq2Seq(
+            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
+        ),
+    )
+    model.config.use_cache = False
+
+    old_state_dict = model.state_dict
+    model.state_dict = (
+        lambda self, *_, **__: get_peft_model_state_dict(
+            self, old_state_dict()
+        )
+    ).__get__(model, type(model))
+
+    if torch.__version__ >= "2" and sys.platform != "win32":
+        model = torch.compile(model)
+
+    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+
+    model.save_pretrained(output_dir)
+
+    print("\n If there's a warning about missing keys above, please disregard :)")
+
+
+if __name__ == "__main__":
+    fire.Fire(train)
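
A note for anyone following step 1 of the README: `generate_and_tokenize_prompt` in `train.py` reads only the `content` field of each record (the `Prompter` instance is created but not used to format the data), and any `--data_path` ending in `.json`/`.jsonl` is loaded with `load_dataset("json", ...)`. The sketch below is a hypothetical illustration of a compatible file; the file name `toy_train.json` and the placeholder texts are invented for this example, and the real `data/example_instruction_train.json` may carry additional fields.

```python
# Hypothetical helper, not part of this PR: writes a minimal dataset in the
# shape train.py consumes -- a JSON array whose records each hold the full
# training text (prompt and response already merged) under the "content" key.
import json

records = [
    {"content": "Instruction: <a legal question>\nResponse: <the reference answer>"},
    {"content": "Instruction: <another legal question>\nResponse: <another answer>"},
]

with open("toy_train.json", "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False, indent=2)

# Pointing --data_path at this file makes load_dataset("json", ...) expose it
# as data["train"]; tokenize() then truncates each text to cutoff_len tokens
# and appends the EOS token when add_eos_token is set.
```

For the settings in `train.sh` (batch_size 128, micro_batch_size 8, WORLD_SIZE 8), `gradient_accumulation_steps` works out to 128 // 8 // 8 = 2 per device.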