diff --git a/__main__.py b/__main__.py index 3f57400..5a66094 100644 --- a/__main__.py +++ b/__main__.py @@ -153,8 +153,8 @@ class ChatBot(ChatBotFrame): prompt_list, devs_document = get_conf('prompt_list', 'devs_document') with gr.Row(): # self.cpopyBtn = gr.Button("复制回答", variant="secondary").style(size="sm") - self.resetBtn = gr.Button("重置Chatbot", variant="secondary", elem_id='empty_btn').style(size="sm") - self.stopBtn = gr.Button("停止", variant="stop").style(size="sm") + self.resetBtn = gr.Button("新建对话", variant="secondary", elem_id='empty_btn').style(size="sm") + self.stopBtn = gr.Button("中止对话", variant="stop").style(size="sm") with gr.Tab('Function'): with gr.Accordion("基础功能区", open=True) as self.area_basic_fn: with gr.Row(): @@ -395,6 +395,7 @@ class ChatBot(ChatBotFrame): self.signals_public() self.signals_prompt_edit() # self.signals_auto_input() + self.demo.load(fn=func_box.refresh_load_data, postprocess=False, inputs=[self.chatbot, self.history, self.pro_fp_state], outputs=[self.pro_func_prompt, self.pro_fp_state, self.chatbot, self.history, ]) # Start self.auto_opentab_delay() diff --git a/func_box.py b/func_box.py index e9b6d9f..f2e60d1 100644 --- a/func_box.py +++ b/func_box.py @@ -236,7 +236,7 @@ def draw_results(txt, prompt: gr.Dataset, percent, switch, ipaddr: gr.Request): return prompt.update(samples=data, visible=True), prompt -def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15, hosts=''): +def diff_list(txt='', percent=0.70, switch: list = None, lst: dict = None, sp=15, hosts=''): """ 按照搜索结果统计相似度的文本,两组文本相似度>70%的将统计在一起,取最长的作为key Args: @@ -250,8 +250,16 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15 返回一个列表 """ count_dict = {} + is_all = toolbox.get_conf('prompt_list')[0]['key'][1] if not lst: - lst = SqliteHandle('ai_common').get_prompt_value(txt) + lst = {} + tabs = SqliteHandle().get_tables() + if is_all in switch: + lst.update(SqliteHandle(f"ai_common_{hosts}").get_prompt_value(txt)) + else: + for tab in tabs: + if tab.startswith('ai_common'): + lst.update(SqliteHandle(f"{tab}").get_prompt_value(txt)) lst.update(SqliteHandle(f"ai_private_{hosts}").get_prompt_value(txt)) # diff 数据,根据precent系数归类数据 str_ = time.time() @@ -269,7 +277,7 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15 found = True break if not found: count_dict[i] = 1 - with ThreadPoolExecutor(1000) as executor: + with ThreadPoolExecutor(100) as executor: executor.map(tf_factor_calcul, lst) print('计算耗时', time.time()-str_) sorted_dict = sorted(count_dict.items(), key=lambda x: x[1], reverse=True) @@ -481,7 +489,7 @@ def thread_write_chat(chatbot, history): if private_key in chat_title: SqliteHandle(f'ai_private_{chat_title[-2]}').inset_prompt({i_say: gpt_result}) else: - SqliteHandle(f'ai_common').inset_prompt({i_say: gpt_result}) + SqliteHandle(f'ai_common_{chat_title[-2]}').inset_prompt({i_say: gpt_result}) base_path = os.path.dirname(__file__) @@ -527,6 +535,23 @@ def spinner_chatbot_loading(chatbot): return loading_msg +def refresh_load_data(chat, history, prompt): + """ + Args: + chat: 聊天组件 + history: 对话记录 + prompt: prompt dataset组件 + + Returns: + 预期是每次刷新页面,加载最新 + """ + is_all = toolbox.get_conf('prompt_list')[0]['key'][0] + data = prompt_retrieval(is_all=[is_all]) + prompt.samples = data + return prompt.update(samples=data, visible=True), prompt, chat, history + + + def txt_converter_json(input_string): try: if input_string.startswith("{") and input_string.endswith("}"):