[MNT] Update webui.py

This commit is contained in:
herobrine19
2023-05-22 17:18:21 +08:00
parent 2fbc20a080
commit 5807e73cfc
2 changed files with 6 additions and 4 deletions

1
.gitignore vendored
View File

@ -15,4 +15,5 @@ outputs/*
data/* data/*
!data/.gitkeep !data/.gitkeep
wandb/ wandb/
flagged/
.DS_Store .DS_Store

View File

@ -95,7 +95,7 @@ def main(
def evaluate( def evaluate(
instruction, instruction,
input=None, # input=None,
temperature=0.1, temperature=0.1,
top_p=0.75, top_p=0.75,
top_k=40, top_k=40,
@ -104,6 +104,7 @@ def main(
stream_output=False, stream_output=False,
**kwargs, **kwargs,
): ):
input=None
prompt = prompter.generate_prompt(instruction, input) prompt = prompter.generate_prompt(instruction, input)
inputs = tokenizer(prompt, return_tensors="pt") inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device) input_ids = inputs["input_ids"].to(device)
@ -175,9 +176,9 @@ def main(
gr.components.Textbox( gr.components.Textbox(
lines=2, lines=2,
label="Instruction", label="Instruction",
placeholder="Tell me about alpacas.", placeholder="此处输入法律相关问题",
), ),
gr.components.Textbox(lines=2, label="Input", placeholder="none"), # gr.components.Textbox(lines=2, label="Input", placeholder="none"),
gr.components.Slider( gr.components.Slider(
minimum=0, maximum=1, value=0.1, label="Temperature" minimum=0, maximum=1, value=0.1, label="Temperature"
), ),
@ -197,7 +198,7 @@ def main(
], ],
outputs=[ outputs=[
gr.inputs.Textbox( gr.inputs.Textbox(
lines=5, lines=8,
label="Output", label="Output",
) )
], ],