对话时增加等待过渡
This commit is contained in:
70
func_box.py
70
func_box.py
@ -4,6 +4,7 @@
|
|||||||
# @Author : Spike
|
# @Author : Spike
|
||||||
# @Descr :
|
# @Descr :
|
||||||
import ast
|
import ast
|
||||||
|
import copy
|
||||||
import hashlib
|
import hashlib
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
@ -101,45 +102,6 @@ def timeStatistics(func):
|
|||||||
return statistics
|
return statistics
|
||||||
|
|
||||||
|
|
||||||
def context_with(*parms):
|
|
||||||
"""
|
|
||||||
一个装饰器,根据传递的参数列表,在类方法上下文中嵌套多个 with 语句。
|
|
||||||
Args:
|
|
||||||
*parms: 参数列表,每个参数都是一个字符串,表示类中的一个属性名。
|
|
||||||
Returns:
|
|
||||||
一个装饰器函数。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(cls_method):
|
|
||||||
"""
|
|
||||||
装饰器函数,用于将一个类方法转换为一个嵌套多个 with 语句的方法。
|
|
||||||
Args:
|
|
||||||
cls_method: 要装饰的类方法。
|
|
||||||
Returns:
|
|
||||||
装饰后的类方法。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def wrapper(cls='', *args, **kwargs):
|
|
||||||
"""
|
|
||||||
装饰后的方法,用于嵌套多个 with 语句,并调用原始的类方法。
|
|
||||||
Args:
|
|
||||||
cls: 类的实例对象。
|
|
||||||
*args: 位置参数。
|
|
||||||
**kwargs: 关键字参数。
|
|
||||||
Returns:
|
|
||||||
原始的类方法返回的结果。
|
|
||||||
"""
|
|
||||||
with_list = [getattr(cls, arg) for arg in parms]
|
|
||||||
with ExitStack() as stack:
|
|
||||||
for context in with_list:
|
|
||||||
stack.enter_context(context)
|
|
||||||
return cls_method(cls, *args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def copy_temp_file(file):
|
def copy_temp_file(file):
|
||||||
if os.path.exists(file):
|
if os.path.exists(file):
|
||||||
exdir = tempfile.mkdtemp()
|
exdir = tempfile.mkdtemp()
|
||||||
@ -497,15 +459,15 @@ def show_prompt_result(index, data: gr.Dataset, chatbot):
|
|||||||
return chatbot
|
return chatbot
|
||||||
|
|
||||||
|
|
||||||
|
pattern_markdown = re.compile(r'^<div class="markdown-body"><p>|<\/p><\/div>$')
|
||||||
|
pattern_markdown_p = re.compile(r'^<div class="markdown-body">|<\/div>$')
|
||||||
def thread_write_chat(chatbot, history):
|
def thread_write_chat(chatbot, history):
|
||||||
"""
|
"""
|
||||||
对话记录写入数据库
|
对话记录写入数据库
|
||||||
"""
|
"""
|
||||||
private_key = toolbox.get_conf('private_key')[0]
|
private_key = toolbox.get_conf('private_key')[0]
|
||||||
chat_title = chatbot[0][1].split()
|
chat_title = chatbot[0][1].split()
|
||||||
pattern = re.compile(r'^<div class="markdown-body"><p>|<\/p><\/div>$')
|
i_say = pattern_markdown.sub('', chatbot[-1][0])
|
||||||
i_say = pattern.sub('', chatbot[-1][0])
|
|
||||||
gpt_result = history
|
gpt_result = history
|
||||||
if private_key in chat_title:
|
if private_key in chat_title:
|
||||||
SqliteHandle(f'ai_private_{chat_title[-2]}').inset_prompt({i_say: gpt_result})
|
SqliteHandle(f'ai_private_{chat_title[-2]}').inset_prompt({i_say: gpt_result})
|
||||||
@ -519,17 +481,16 @@ prompt_path = os.path.join(base_path, 'prompt_users')
|
|||||||
|
|
||||||
def reuse_chat(result, chatbot, history, pro_numb):
|
def reuse_chat(result, chatbot, history, pro_numb):
|
||||||
"""复用对话记录"""
|
"""复用对话记录"""
|
||||||
pattern = re.compile(r'^<div class="markdown-body"><p>|<\/p><\/div>$')
|
|
||||||
if result is None or result == []:
|
if result is None or result == []:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if pro_numb:
|
if pro_numb:
|
||||||
chatbot.append(result)
|
chatbot.append(result)
|
||||||
history += [pattern.sub('', i) for i in result]
|
history += [pattern_markdown.sub('', i) for i in result]
|
||||||
else:
|
else:
|
||||||
chatbot.append(result[-1])
|
chatbot.append(result[-1])
|
||||||
history += [pattern.sub('', i) for i in result[-2:]]
|
history += [pattern_markdown.sub('', i) for i in result[-2:]]
|
||||||
i_say = pattern.sub('', chatbot[-1][0])
|
i_say = pattern_markdown.sub('', chatbot[-1][0])
|
||||||
return chatbot, history, i_say, gr.Tabs.update(selected='chatbot'), ''
|
return chatbot, history, i_say, gr.Tabs.update(selected='chatbot'), ''
|
||||||
|
|
||||||
|
|
||||||
@ -541,6 +502,19 @@ def num_tokens_from_string(listing: list, encoding_name: str = 'cl100k_base') ->
|
|||||||
count_tokens += len(encoding.encode(i))
|
count_tokens += len(encoding.encode(i))
|
||||||
return count_tokens
|
return count_tokens
|
||||||
|
|
||||||
|
def spinner_chatbot_loading(chatbot):
|
||||||
|
loading = [''.join(['.' * random.randint(1, 5)])]
|
||||||
|
# 将元组转换为列表并修改元素
|
||||||
|
loading_msg = copy.deepcopy(chatbot)
|
||||||
|
temp_list = list(loading_msg[-1])
|
||||||
|
if pattern_markdown.match(temp_list[1]):
|
||||||
|
temp_list[1] = pattern_markdown.sub('', temp_list[1]) + f'{random.choice(loading)}'
|
||||||
|
else:
|
||||||
|
temp_list[1] = pattern_markdown_p.sub('', temp_list[1]) + f'{random.choice(loading)}'
|
||||||
|
# 将列表转换回元组并替换原始元组
|
||||||
|
loading_msg[-1] = tuple(temp_list)
|
||||||
|
return loading_msg
|
||||||
|
|
||||||
class YamlHandle:
|
class YamlHandle:
|
||||||
|
|
||||||
def __init__(self, file=os.path.join(prompt_path, 'ai_common.yaml')):
|
def __init__(self, file=os.path.join(prompt_path, 'ai_common.yaml')):
|
||||||
@ -591,5 +565,5 @@ class JsonHandle:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
num = num_tokens_from_string(['你好', '您好,请问有什么可以帮助您的吗?'])
|
loading = [''.join(['.' * random.randint(1, 5)])]
|
||||||
print(num)
|
print(loading)
|
||||||
@ -12,6 +12,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import random
|
||||||
import time
|
import time
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import logging
|
import logging
|
||||||
@ -137,8 +138,8 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
raw_input = inputs
|
raw_input = inputs
|
||||||
logging.info(f'[raw_input]_{llm_kwargs["ipaddr"]} {raw_input}')
|
logging.info(f'[raw_input]_{llm_kwargs["ipaddr"]} {raw_input}')
|
||||||
chatbot.append((inputs, ""))
|
chatbot.append((inputs, ""))
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
|
loading_msg = func_box.spinner_chatbot_loading(chatbot)
|
||||||
|
yield from update_ui(chatbot=loading_msg, history=history, msg="等待响应") # 刷新界面
|
||||||
try:
|
try:
|
||||||
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt, stream)
|
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt, stream)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
@ -162,12 +163,13 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
if retry > MAX_RETRY: raise TimeoutError
|
if retry > MAX_RETRY: raise TimeoutError
|
||||||
|
|
||||||
gpt_replying_buffer = ""
|
gpt_replying_buffer = ""
|
||||||
|
|
||||||
is_head_of_the_stream = True
|
is_head_of_the_stream = True
|
||||||
if stream:
|
if stream:
|
||||||
stream_response = response.iter_lines()
|
stream_response = response.iter_lines()
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
loading_msg = func_box.spinner_chatbot_loading(chatbot)
|
||||||
|
yield from update_ui(chatbot=loading_msg, history=history)
|
||||||
chunk = next(stream_response)
|
chunk = next(stream_response)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
# 非OpenAI官方接口的出现这样的报错,OpenAI和API2D不会走这里
|
# 非OpenAI官方接口的出现这样的报错,OpenAI和API2D不会走这里
|
||||||
|
|||||||
Reference in New Issue
Block a user