[ENH] Update GUI for matching legal grounds

This commit is contained in:
WNJXYK
2023-05-27 17:04:11 +08:00
parent 39eed4febe
commit dc849bd282

View File

@ -10,6 +10,7 @@ from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, Aut
from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter
from utils.knowledge import Knowledge
if torch.cuda.is_available():
device = "cuda"
@ -37,6 +38,7 @@ def main(
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
prompter = Prompter(prompt_template)
knowledge = Knowledge()
tokenizer = LlamaTokenizer.from_pretrained(base_model)
if device == "cuda":
model = LlamaForCausalLM.from_pretrained(
@ -106,6 +108,7 @@ def main(
):
input=None
prompt = prompter.generate_prompt(instruction, input)
legals = knowledge.query_prompt(instruction)
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
generation_config = GenerationConfig(
@ -152,7 +155,7 @@ def main(
if output[-1] in [tokenizer.eos_token_id]:
break
yield prompter.get_response(decoded_output)
yield prompter.get_response(decoded_output), knowledge.get_response(legals)
print(decoded_output)
return # early return for stream_output
@ -168,7 +171,7 @@ def main(
s = generation_output.sequences[0]
output = tokenizer.decode(s)
print(output)
yield prompter.get_response(output)
yield prompter.get_response(output), knowledge.get_response(legals)
gr.Interface(
fn=evaluate,
@ -200,6 +203,9 @@ def main(
gr.inputs.Textbox(
lines=8,
label="Output",
),
gr.inputs.Textbox(
label="Legal Ground",
)
],
title="🦙🌲 LaWGPT",