[ENH] Update GUI for matching legal grounds
This commit is contained in:
10
webui.py
10
webui.py
@ -10,6 +10,7 @@ from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, Aut
|
|||||||
|
|
||||||
from utils.callbacks import Iteratorize, Stream
|
from utils.callbacks import Iteratorize, Stream
|
||||||
from utils.prompter import Prompter
|
from utils.prompter import Prompter
|
||||||
|
from utils.knowledge import Knowledge
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
@ -37,6 +38,7 @@ def main(
|
|||||||
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
|
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
|
||||||
|
|
||||||
prompter = Prompter(prompt_template)
|
prompter = Prompter(prompt_template)
|
||||||
|
knowledge = Knowledge()
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(base_model)
|
tokenizer = LlamaTokenizer.from_pretrained(base_model)
|
||||||
if device == "cuda":
|
if device == "cuda":
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
@ -106,6 +108,7 @@ def main(
|
|||||||
):
|
):
|
||||||
input=None
|
input=None
|
||||||
prompt = prompter.generate_prompt(instruction, input)
|
prompt = prompter.generate_prompt(instruction, input)
|
||||||
|
legals = knowledge.query_prompt(instruction)
|
||||||
inputs = tokenizer(prompt, return_tensors="pt")
|
inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
input_ids = inputs["input_ids"].to(device)
|
input_ids = inputs["input_ids"].to(device)
|
||||||
generation_config = GenerationConfig(
|
generation_config = GenerationConfig(
|
||||||
@ -152,7 +155,7 @@ def main(
|
|||||||
if output[-1] in [tokenizer.eos_token_id]:
|
if output[-1] in [tokenizer.eos_token_id]:
|
||||||
break
|
break
|
||||||
|
|
||||||
yield prompter.get_response(decoded_output)
|
yield prompter.get_response(decoded_output), knowledge.get_response(legals)
|
||||||
print(decoded_output)
|
print(decoded_output)
|
||||||
return # early return for stream_output
|
return # early return for stream_output
|
||||||
|
|
||||||
@ -168,7 +171,7 @@ def main(
|
|||||||
s = generation_output.sequences[0]
|
s = generation_output.sequences[0]
|
||||||
output = tokenizer.decode(s)
|
output = tokenizer.decode(s)
|
||||||
print(output)
|
print(output)
|
||||||
yield prompter.get_response(output)
|
yield prompter.get_response(output), knowledge.get_response(legals)
|
||||||
|
|
||||||
gr.Interface(
|
gr.Interface(
|
||||||
fn=evaluate,
|
fn=evaluate,
|
||||||
@ -200,6 +203,9 @@ def main(
|
|||||||
gr.inputs.Textbox(
|
gr.inputs.Textbox(
|
||||||
lines=8,
|
lines=8,
|
||||||
label="Output",
|
label="Output",
|
||||||
|
),
|
||||||
|
gr.inputs.Textbox(
|
||||||
|
label="Legal Ground",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
title="🦙🌲 LaWGPT",
|
title="🦙🌲 LaWGPT",
|
||||||
|
|||||||
Reference in New Issue
Block a user