[ENH] Update GUI for matching legal grounds

WNJXYK
2023-05-27 17:04:11 +08:00
parent 39eed4febe
commit dc849bd282


@@ -10,6 +10,7 @@ from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, Aut
 from utils.callbacks import Iteratorize, Stream
 from utils.prompter import Prompter
+from utils.knowledge import Knowledge
 if torch.cuda.is_available():
     device = "cuda"
@@ -37,6 +38,7 @@ def main(
     ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
     prompter = Prompter(prompt_template)
+    knowledge = Knowledge()
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
     if device == "cuda":
         model = LlamaForCausalLM.from_pretrained(
@@ -106,6 +108,7 @@ def main(
     ):
         input=None
         prompt = prompter.generate_prompt(instruction, input)
+        legals = knowledge.query_prompt(instruction)
         inputs = tokenizer(prompt, return_tensors="pt")
         input_ids = inputs["input_ids"].to(device)
         generation_config = GenerationConfig(
@@ -152,7 +155,7 @@ def main(
                     if output[-1] in [tokenizer.eos_token_id]:
                         break
-                    yield prompter.get_response(decoded_output)
+                    yield prompter.get_response(decoded_output), knowledge.get_response(legals)
                     print(decoded_output)
             return  # early return for stream_output
@@ -168,7 +171,7 @@ def main(
         s = generation_output.sequences[0]
         output = tokenizer.decode(s)
         print(output)
-        yield prompter.get_response(output)
+        yield prompter.get_response(output), knowledge.get_response(legals)
     gr.Interface(
         fn=evaluate,
@@ -200,6 +203,9 @@ def main(
             gr.inputs.Textbox(
                 lines=8,
                 label="Output",
+            ),
+            gr.inputs.Textbox(
+                label="Legal Ground",
             )
         ],
         title="🦙🌲 LaWGPT",