diff --git a/func_box.py b/func_box.py index a26d641..a4e6260 100644 --- a/func_box.py +++ b/func_box.py @@ -4,6 +4,7 @@ # @Author : Spike # @Descr : import ast +import copy import hashlib import io import json @@ -101,45 +102,6 @@ def timeStatistics(func): return statistics -def context_with(*parms): - """ - 一个装饰器,根据传递的参数列表,在类方法上下文中嵌套多个 with 语句。 - Args: - *parms: 参数列表,每个参数都是一个字符串,表示类中的一个属性名。 - Returns: - 一个装饰器函数。 - """ - - def decorator(cls_method): - """ - 装饰器函数,用于将一个类方法转换为一个嵌套多个 with 语句的方法。 - Args: - cls_method: 要装饰的类方法。 - Returns: - 装饰后的类方法。 - """ - - def wrapper(cls='', *args, **kwargs): - """ - 装饰后的方法,用于嵌套多个 with 语句,并调用原始的类方法。 - Args: - cls: 类的实例对象。 - *args: 位置参数。 - **kwargs: 关键字参数。 - Returns: - 原始的类方法返回的结果。 - """ - with_list = [getattr(cls, arg) for arg in parms] - with ExitStack() as stack: - for context in with_list: - stack.enter_context(context) - return cls_method(cls, *args, **kwargs) - - return wrapper - - return decorator - - def copy_temp_file(file): if os.path.exists(file): exdir = tempfile.mkdtemp() @@ -497,15 +459,15 @@ def show_prompt_result(index, data: gr.Dataset, chatbot): return chatbot - +pattern_markdown = re.compile(r'^

|<\/p><\/div>$') +pattern_markdown_p = re.compile(r'^

|<\/div>$') def thread_write_chat(chatbot, history): """ 对话记录写入数据库 """ private_key = toolbox.get_conf('private_key')[0] chat_title = chatbot[0][1].split() - pattern = re.compile(r'^

|<\/p><\/div>$') - i_say = pattern.sub('', chatbot[-1][0]) + i_say = pattern_markdown.sub('', chatbot[-1][0]) gpt_result = history if private_key in chat_title: SqliteHandle(f'ai_private_{chat_title[-2]}').inset_prompt({i_say: gpt_result}) @@ -519,17 +481,16 @@ prompt_path = os.path.join(base_path, 'prompt_users') def reuse_chat(result, chatbot, history, pro_numb): """复用对话记录""" - pattern = re.compile(r'^

|<\/p><\/div>$') if result is None or result == []: pass else: if pro_numb: chatbot.append(result) - history += [pattern.sub('', i) for i in result] + history += [pattern_markdown.sub('', i) for i in result] else: chatbot.append(result[-1]) - history += [pattern.sub('', i) for i in result[-2:]] - i_say = pattern.sub('', chatbot[-1][0]) + history += [pattern_markdown.sub('', i) for i in result[-2:]] + i_say = pattern_markdown.sub('', chatbot[-1][0]) return chatbot, history, i_say, gr.Tabs.update(selected='chatbot'), '' @@ -541,6 +502,19 @@ def num_tokens_from_string(listing: list, encoding_name: str = 'cl100k_base') -> count_tokens += len(encoding.encode(i)) return count_tokens +def spinner_chatbot_loading(chatbot): + loading = [''.join(['.' * random.randint(1, 5)])] + # 将元组转换为列表并修改元素 + loading_msg = copy.deepcopy(chatbot) + temp_list = list(loading_msg[-1]) + if pattern_markdown.match(temp_list[1]): + temp_list[1] = pattern_markdown.sub('', temp_list[1]) + f'{random.choice(loading)}' + else: + temp_list[1] = pattern_markdown_p.sub('', temp_list[1]) + f'{random.choice(loading)}' + # 将列表转换回元组并替换原始元组 + loading_msg[-1] = tuple(temp_list) + return loading_msg + class YamlHandle: def __init__(self, file=os.path.join(prompt_path, 'ai_common.yaml')): @@ -591,5 +565,5 @@ class JsonHandle: if __name__ == '__main__': - num = num_tokens_from_string(['你好', '您好,请问有什么可以帮助您的吗?']) - print(num) \ No newline at end of file + loading = [''.join(['.' * random.randint(1, 5)])] + print(loading) \ No newline at end of file diff --git a/request_llm/bridge_chatgpt.py b/request_llm/bridge_chatgpt.py index aa2099f..bb21e3f 100644 --- a/request_llm/bridge_chatgpt.py +++ b/request_llm/bridge_chatgpt.py @@ -12,6 +12,7 @@ """ import json +import random import time import gradio as gr import logging @@ -137,8 +138,8 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp raw_input = inputs logging.info(f'[raw_input]_{llm_kwargs["ipaddr"]} {raw_input}') chatbot.append((inputs, "")) - yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面 - + loading_msg = func_box.spinner_chatbot_loading(chatbot) + yield from update_ui(chatbot=loading_msg, history=history, msg="等待响应") # 刷新界面 try: headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt, stream) except RuntimeError as e: @@ -162,12 +163,13 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp if retry > MAX_RETRY: raise TimeoutError gpt_replying_buffer = "" - is_head_of_the_stream = True if stream: stream_response = response.iter_lines() while True: try: + loading_msg = func_box.spinner_chatbot_loading(chatbot) + yield from update_ui(chatbot=loading_msg, history=history) chunk = next(stream_response) except StopIteration: # 非OpenAI官方接口的出现这样的报错,OpenAI和API2D不会走这里