From dc849bd2824dd46f6a95ca9f5b4f93db6eda45e1 Mon Sep 17 00:00:00 2001
From: WNJXYK
Date: Sat, 27 May 2023 17:04:11 +0800
Subject: [PATCH] [ENH] Update GUI for matching legal grounds

---
 webui.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/webui.py b/webui.py
index 453043a..ff8191c 100644
--- a/webui.py
+++ b/webui.py
@@ -10,6 +10,7 @@ from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, Aut
 
 from utils.callbacks import Iteratorize, Stream
 from utils.prompter import Prompter
+from utils.knowledge import Knowledge
 
 if torch.cuda.is_available():
     device = "cuda"
@@ -37,6 +38,7 @@ def main(
     ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
 
     prompter = Prompter(prompt_template)
+    knowledge = Knowledge()
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
     if device == "cuda":
         model = LlamaForCausalLM.from_pretrained(
@@ -106,6 +108,7 @@ def main(
     ):
         input=None
         prompt = prompter.generate_prompt(instruction, input)
+        legals = knowledge.query_prompt(instruction)
        inputs = tokenizer(prompt, return_tensors="pt")
         input_ids = inputs["input_ids"].to(device)
         generation_config = GenerationConfig(
@@ -152,7 +155,7 @@ def main(
                     if output[-1] in [tokenizer.eos_token_id]:
                         break
 
-                    yield prompter.get_response(decoded_output)
+                    yield prompter.get_response(decoded_output), knowledge.get_response(legals)
                 print(decoded_output)
             return  # early return for stream_output
 
@@ -168,7 +171,7 @@ def main(
         s = generation_output.sequences[0]
         output = tokenizer.decode(s)
         print(output)
-        yield prompter.get_response(output)
+        yield prompter.get_response(output), knowledge.get_response(legals)
 
     gr.Interface(
         fn=evaluate,
@@ -200,6 +203,9 @@ def main(
             gr.inputs.Textbox(
                 lines=8,
                 label="Output",
+            ),
+            gr.inputs.Textbox(
+                label="Legal Ground",
             )
         ],
         title="🦙🌲 LaWGPT",
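
Note: the patch calls into a utils/knowledge.py module that is not included in
this diff. For reference, a minimal sketch of the interface the new call sites
rely on (Knowledge(), query_prompt(instruction), get_response(legals)) could
look like the following; the keyword matching and the data/legal_grounds.json
path are illustrative assumptions, not the actual implementation.

# utils/knowledge.py -- hypothetical sketch, NOT the module shipped with the
# patch. Only the names Knowledge, query_prompt, and get_response are taken
# from the call sites above; everything else is an illustrative assumption.
import json
from typing import Dict, List


class Knowledge:
    """Match an instruction against a local store of legal grounds."""

    def __init__(self, path: str = "data/legal_grounds.json"):
        # Assumed data layout: a JSON list of entries such as
        # {"title": "...", "text": "...", "keywords": ["...", ...]}.
        try:
            with open(path, encoding="utf-8") as f:
                self.entries: List[Dict] = json.load(f)
        except FileNotFoundError:
            self.entries = []

    def query_prompt(self, instruction: str) -> List[Dict]:
        # Naive substring matching; the real module may use embeddings,
        # BM25, or another retriever.
        return [
            entry
            for entry in self.entries
            if any(kw in instruction for kw in entry.get("keywords", []))
        ]

    def get_response(self, legals: List[Dict]) -> str:
        # Render the matched grounds for the "Legal Ground" textbox.
        if not legals:
            return "No matching legal ground found."
        return "\n\n".join(f"{e['title']}: {e['text']}" for e in legals)

Since evaluate now yields a 2-tuple, gr.Interface needs a matching second
output component, which the added "Legal Ground" textbox provides (the outputs
list reuses gr.inputs.Textbox, mirroring the existing "Output" component).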