diff --git a/src/finetune.py b/src/finetune.py
deleted file mode 100644
index ff7c0b3..0000000
--- a/src/finetune.py
+++ /dev/null
@@ -1,280 +0,0 @@
-import os
-import sys
-from typing import List
-
-import fire
-import torch
-import transformers
-from datasets import load_dataset
-
-from peft import (
-    LoraConfig,
-    get_peft_model,
-    get_peft_model_state_dict,
-    prepare_model_for_int8_training,
-    set_peft_model_state_dict,
-)
-from transformers import LlamaForCausalLM, LlamaTokenizer
-from utils.prompter import Prompter
-
-
-def train(
-    # model/data params
-    base_model: str = "./models/base_models/your_base_model_dir",
-    data_path: str = "./data/your_data.json",
-    output_dir: str = "./outputs/your_version_dir",
-
-    # training hyperparams
-    batch_size: int = 128,
-    micro_batch_size: int = 4,
-    num_epochs: int = 10,
-    learning_rate: float = 3e-4,
-    cutoff_len: int = 512,
-    val_set_size: int = 2000,
-
-    # lora hyperparams
-    lora_r: int = 8,
-    lora_alpha: int = 16,
-    lora_dropout: float = 0.05,
-    lora_target_modules: List[str] = ["q_proj", "v_proj",],
-
-    # llm hyperparams
-    train_on_inputs: bool = True,  # if False, masks out inputs in loss
-    add_eos_token: bool = True,
-    group_by_length: bool = False,  # faster, but produces an odd training loss curve
-
-    # wandb params
-    wandb_project: str = "",
-    wandb_run_name: str = "",
-    wandb_watch: str = "",  # options: false | gradients | all
-    wandb_log_model: str = "",  # options: false | true
-
-    # either training checkpoint or final adapter
-    resume_from_checkpoint: str = None,
-
-    # The prompt template to use, will default to alpaca.
-    prompt_template_name: str = "alpaca",
-):
-    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
-        print(
-            f"Training Alpaca-LoRA model with params:\n"
-            f"base_model: {base_model}\n"
-            f"data_path: {data_path}\n"
-            f"output_dir: {output_dir}\n"
-            f"batch_size: {batch_size}\n"
-            f"micro_batch_size: {micro_batch_size}\n"
-            f"num_epochs: {num_epochs}\n"
-            f"learning_rate: {learning_rate}\n"
-            f"cutoff_len: {cutoff_len}\n"
-            f"val_set_size: {val_set_size}\n"
-            f"lora_r: {lora_r}\n"
-            f"lora_alpha: {lora_alpha}\n"
-            f"lora_dropout: {lora_dropout}\n"
-            f"lora_target_modules: {lora_target_modules}\n"
-            f"train_on_inputs: {train_on_inputs}\n"
-            f"add_eos_token: {add_eos_token}\n"
-            f"group_by_length: {group_by_length}\n"
-            f"wandb_project: {wandb_project}\n"
-            f"wandb_run_name: {wandb_run_name}\n"
-            f"wandb_watch: {wandb_watch}\n"
-            f"wandb_log_model: {wandb_log_model}\n"
-            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
-            f"prompt template: {prompt_template_name}\n"
-        )
-    gradient_accumulation_steps = batch_size // micro_batch_size
-
-    prompter = Prompter(prompt_template_name)
-
-    # Configure device and distributed training
-    device_map = "auto"
-    world_size = int(os.environ.get("WORLD_SIZE", 1))
-    ddp = world_size != 1
-    if ddp:
-        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
-        gradient_accumulation_steps = gradient_accumulation_steps // world_size
-
-    # Check if parameter passed or if set within environ
-    use_wandb = len(wandb_project) > 0 or (
-        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0)
-
-    # Only overwrite environ if wandb param passed
-    if len(wandb_project) > 0:
-        os.environ["WANDB_PROJECT"] = wandb_project
-    if len(wandb_watch) > 0:
-        os.environ["WANDB_WATCH"] = wandb_watch
-    if len(wandb_log_model) > 0:
-        os.environ["WANDB_LOG_MODEL"] = wandb_log_model
-
-    model = LlamaForCausalLM.from_pretrained(
-        base_model,
-        load_in_8bit=True,
-        torch_dtype=torch.float16,
-        device_map=device_map,
-    )
-
-    tokenizer = LlamaTokenizer.from_pretrained(base_model)
-    tokenizer.bos_token_id = 1
-    tokenizer.eos_token_id = 2
-    bos = tokenizer.bos_token_id
-    eos = tokenizer.eos_token_id
-    pad = tokenizer.pad_token_id
-
-    print("pre-trained model's BOS EOS and PAD token id:",
-          bos, eos, pad, " => It should be 1,2,none")
-
-    tokenizer.pad_token_id = (
-        0  # unk. we want this to be different from the eos token
-    )
-    tokenizer.padding_side = "left"  # Allow batched inference
-
-    def tokenize(prompt, add_eos_token=True):
-        # there's probably a way to do this with the tokenizer settings
-        # but again, gotta move fast
-        result = tokenizer(
-            prompt,
-            truncation=True,
-            max_length=cutoff_len,
-            padding=False,
-            return_tensors=None,
-        )
-        if (
-            result["input_ids"][-1] != tokenizer.eos_token_id
-            and len(result["input_ids"]) < cutoff_len
-            and add_eos_token
-        ):
-            result["input_ids"].append(tokenizer.eos_token_id)
-            result["attention_mask"].append(1)
-
-        result["labels"] = result["input_ids"].copy()
-
-        return result
-
-    def generate_and_tokenize_prompt(data_point):
-        full_prompt = prompter.generate_prompt(
-            data_point["instruction"],
-            data_point["input"],
-            data_point["output"],
-        )
-        tokenized_full_prompt = tokenize(full_prompt)
-        if not train_on_inputs:
-            user_prompt = prompter.generate_prompt(
-                data_point["instruction"], data_point["input"]
-            )
-            tokenized_user_prompt = tokenize(
-                user_prompt, add_eos_token=add_eos_token
-            )
-            user_prompt_len = len(tokenized_user_prompt["input_ids"])
-
-            if add_eos_token:
-                user_prompt_len -= 1
-
-            tokenized_full_prompt["labels"] = [
-                -100
-            ] * user_prompt_len + tokenized_full_prompt["labels"][
-                user_prompt_len:
-            ]  # could be sped up, probably
-        return tokenized_full_prompt
-
-    model = prepare_model_for_int8_training(model)
-
-    config = LoraConfig(
-        r=lora_r,
-        lora_alpha=lora_alpha,
-        target_modules=lora_target_modules,
-        lora_dropout=lora_dropout,
-        bias="none",
-        task_type="CAUSAL_LM",
-    )
-    model = get_peft_model(model, config)
-
-    if data_path.endswith(".json") or data_path.endswith(".jsonl"):
-        data = load_dataset("json", data_files=data_path)
-    else:
-        data = load_dataset(data_path)
-
-    if resume_from_checkpoint:
-        # Check the available weights and load them
-        checkpoint_name = os.path.join(
-            resume_from_checkpoint, "pytorch_model.bin"
-        )  # Full checkpoint
-        if not os.path.exists(checkpoint_name):
-            checkpoint_name = os.path.join(
-                resume_from_checkpoint, "adapter_model.bin"
-            )  # only LoRA model - LoRA config above has to fit
-            resume_from_checkpoint = (
-                False  # So the trainer won't try loading its state
-            )
-        # The two files above have a different name depending on how they were saved, but are actually the same.
-        if os.path.exists(checkpoint_name):
-            print(f"Restarting from {checkpoint_name}")
-            adapters_weights = torch.load(checkpoint_name)
-            set_peft_model_state_dict(model, adapters_weights)
-        else:
-            print(f"Checkpoint {checkpoint_name} not found")
-
-    # Be more transparent about the % of trainable params.
-    model.print_trainable_parameters()
-
-    if val_set_size > 0:
-        train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
-        train_data = (train_val["train"].shuffle().map(generate_and_tokenize_prompt))
-        val_data = (train_val["test"].shuffle().map(generate_and_tokenize_prompt))
-    else:
-        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
-        val_data = None
-
-    if not ddp and torch.cuda.device_count() > 1:
-        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
-        model.is_parallelizable = True
-        model.model_parallel = True
-
-    trainer = transformers.Trainer(
-        model=model,
-        train_dataset=train_data,
-        eval_dataset=val_data,
-        args=transformers.TrainingArguments(
-            per_device_train_batch_size=micro_batch_size,
-            gradient_accumulation_steps=gradient_accumulation_steps,
-            warmup_steps=100,
-            num_train_epochs=num_epochs,
-            learning_rate=learning_rate,
-            fp16=True,
-            logging_steps=10,
-            optim="adamw_torch",
-            evaluation_strategy="steps" if val_set_size > 0 else "no",
-            save_strategy="steps",
-            eval_steps=100 if val_set_size > 0 else None,
-            save_steps=100,
-            output_dir=output_dir,
-            save_total_limit=3,
-            load_best_model_at_end=True if val_set_size > 0 else False,
-            ddp_find_unused_parameters=False if ddp else None,
-            group_by_length=group_by_length,
-            report_to="wandb" if use_wandb else None,
-            run_name=wandb_run_name if use_wandb else None,
-        ),
-        data_collator=transformers.DataCollatorForSeq2Seq(
-            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
-        ),
-    )
-    model.config.use_cache = False
-
-    old_state_dict = model.state_dict
-    model.state_dict = (
-        lambda self, *_, **__: get_peft_model_state_dict(
-            self, old_state_dict()
-        )
-    ).__get__(model, type(model))
-
-    if torch.__version__ >= "2" and sys.platform != "win32":
-        model = torch.compile(model)
-
-    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
-
-    model.save_pretrained(output_dir)
-
-    print("\n If there's a warning about missing keys above, please disregard :)")
-
-
-if __name__ == "__main__":
-    fire.Fire(train)
diff --git a/src/generate.py b/src/generate.py
deleted file mode 100644
index bbe5513..0000000
--- a/src/generate.py
+++ /dev/null
@@ -1,201 +0,0 @@
-import os
-import sys
-
-import fire
-import gradio as gr
-import torch
-import transformers
-from peft import PeftModel
-from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
-
-from utils.callbacks import Iteratorize, Stream
-from utils.prompter import Prompter
-
-if torch.cuda.is_available():
-    device = "cuda"
-else:
-    device = "cpu"
-
-try:
-    if torch.backends.mps.is_available():
-        device = "mps"
-except:  # noqa: E722
-    pass
-
-
-def main(
-    load_8bit: bool = False,
-    base_model: str = "",
-    lora_weights: str = "tloen/alpaca-lora-7b",
-    prompt_template: str = "",  # The prompt template to use, will default to alpaca.
-    server_name: str = "0.0.0.0",  # Allows to listen on all interfaces by providing '0.
-    share_gradio: bool = False,
-):
-    base_model = base_model or os.environ.get("BASE_MODEL", "")
-    assert (
-        base_model
-    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
-
-    prompter = Prompter(prompt_template)
-    tokenizer = LlamaTokenizer.from_pretrained(base_model)
-    if device == "cuda":
-        model = LlamaForCausalLM.from_pretrained(
-            base_model,
-            load_in_8bit=load_8bit,
-            torch_dtype=torch.float16,
-            device_map="auto",
-        )
-        model = PeftModel.from_pretrained(
-            model,
-            lora_weights,
-            torch_dtype=torch.float16,
-        )
-    elif device == "mps":
-        model = LlamaForCausalLM.from_pretrained(
-            base_model,
-            device_map={"": device},
-            torch_dtype=torch.float16,
-        )
-        model = PeftModel.from_pretrained(
-            model,
-            lora_weights,
-            device_map={"": device},
-            torch_dtype=torch.float16,
-        )
-    else:
-        model = LlamaForCausalLM.from_pretrained(
-            base_model,
-            device_map={"": device}, low_cpu_mem_usage=True
-        )
-        model = PeftModel.from_pretrained(
-            model,
-            lora_weights,
-            device_map={"": device},
-        )
-
-    # unwind broken decapoda-research config
-    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
-    model.config.bos_token_id = 1
-    model.config.eos_token_id = 2
-
-    if not load_8bit:
-        model.half()  # seems to fix bugs for some users.
-
-    model.eval()
-    if torch.__version__ >= "2" and sys.platform != "win32":
-        model = torch.compile(model)
-
-    def evaluate(
-        instruction,
-        input=None,
-        temperature=0.1,
-        top_p=0.75,
-        top_k=40,
-        num_beams=1,
-        max_new_tokens=256,
-        stream_output=True,
-        **kwargs,
-    ):
-        prompt = prompter.generate_prompt(instruction, input)
-        inputs = tokenizer(prompt, return_tensors="pt")
-        input_ids = inputs["input_ids"].to(device)
-        generation_config = GenerationConfig(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            num_beams=num_beams,
-            **kwargs,
-        )
-
-        generate_params = {
-            "input_ids": input_ids,
-            "generation_config": generation_config,
-            "return_dict_in_generate": True,
-            "output_scores": True,
-            "max_new_tokens": max_new_tokens,
-        }
-
-        if stream_output:
-            # Stream the reply 1 token at a time.
-            # This is based on the trick of using 'stopping_criteria' to create an iterator,
-            # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
-
-            def generate_with_callback(callback=None, **kwargs):
-                kwargs.setdefault(
-                    "stopping_criteria", transformers.StoppingCriteriaList()
-                )
-                kwargs["stopping_criteria"].append(
-                    Stream(callback_func=callback)
-                )
-                with torch.no_grad():
-                    model.generate(**kwargs)
-
-            def generate_with_streaming(**kwargs):
-                return Iteratorize(
-                    generate_with_callback, kwargs, callback=None
-                )
-
-            with generate_with_streaming(**generate_params) as generator:
-                for output in generator:
-                    # new_tokens = len(output) - len(input_ids[0])
-                    decoded_output = tokenizer.decode(output)
-
-                    if output[-1] in [tokenizer.eos_token_id]:
-                        break
-
-                    yield prompter.get_response(decoded_output)
-            return  # early return for stream_output
-
-        # Without streaming
-        with torch.no_grad():
-            generation_output = model.generate(
-                input_ids=input_ids,
-                generation_config=generation_config,
-                return_dict_in_generate=True,
-                output_scores=True,
-                max_new_tokens=max_new_tokens,
-            )
-        s = generation_output.sequences[0]
-        output = tokenizer.decode(s)
-        yield prompter.get_response(output)
-
-    gr.Interface(
-        fn=evaluate,
-        inputs=[
-            gr.components.Textbox(
-                lines=2,
-                label="Instruction",
-                placeholder="Tell me about alpacas.",
-            ),
-            gr.components.Textbox(lines=2, label="Input", placeholder="none"),
-            gr.components.Slider(
-                minimum=0, maximum=1, value=1.0, label="Temperature"
-            ),
-            gr.components.Slider(
-                minimum=0, maximum=1, value=0.75, label="Top p"
-            ),
-            gr.components.Slider(
-                minimum=0, maximum=100, step=1, value=40, label="Top k"
-            ),
-            gr.components.Slider(
-                minimum=1, maximum=4, step=1, value=4, label="Beams"
-            ),
-            gr.components.Slider(
-                minimum=1, maximum=2000, step=1, value=256, label="Max tokens"
-            ),
-            gr.components.Checkbox(label="Stream output", value=True),
-        ],
-        outputs=[
-            gr.inputs.Textbox(
-                lines=5,
-                label="Output",
-            )
-        ],
-        title="🦙🌲 LaWGPT",
-        description="",  # noqa: E501
-    ).queue().launch(server_name="0.0.0.0", share=share_gradio)
-    # Old testing code follows.
-
-
-if __name__ == "__main__":
-    fire.Fire(main)
diff --git a/src/train_lora.py b/train_clm.py
similarity index 100%
rename from src/train_lora.py
rename to train_clm.py
diff --git a/webui.py b/webui.py
index 5ae7210..39381d9 100644
--- a/webui.py
+++ b/webui.py
@@ -19,7 +19,7 @@ else:
 try:
     if torch.backends.mps.is_available():
         device = "mps"
-except:  # noqa: E722
+except:
     pass
 
 
@@ -201,28 +201,9 @@ def main(
                 label="Output",
             )
         ],
-        title="🦙🌲 LLM-LoRA",
-        description="",  # noqa: E501
+        title="🦙🌲 LaWGPT",
+        description="",
     ).queue().launch(server_name="0.0.0.0", share=share_gradio)
-    # Old testing code follows.
-
-    """
-    # testing code for readme
-    for instruction in [
-        "Tell me about alpacas.",
-        "Tell me about the president of Mexico in 2019.",
-        "Tell me about the king of France in 2019.",
-        "List all Canadian provinces in alphabetical order.",
-        "Write a Python program that prints the first 10 Fibonacci numbers.",
-        "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.",  # noqa: E501
-        "Tell me five words that rhyme with 'shock'.",
-        "Translate the sentence 'I have no mouth but I must scream' into Spanish.",
-        "Count up from 1 to 500.",
-    ]:
-        print("Instruction:", instruction)
-        print("Response:", evaluate(instruction))
-        print()
-    """
 
 
 if __name__ == "__main__":