增加刷新页面时重载Prompt 数据,关联Prompt和搜索条件,当选择个人时,只会搜到自己的记录

This commit is contained in:
w_xiaolizu
2023-06-18 03:11:23 +08:00
parent 0624bf6e89
commit c7add2dd11
2 changed files with 32 additions and 6 deletions

View File

@ -153,8 +153,8 @@ class ChatBot(ChatBotFrame):
prompt_list, devs_document = get_conf('prompt_list', 'devs_document') prompt_list, devs_document = get_conf('prompt_list', 'devs_document')
with gr.Row(): with gr.Row():
# self.cpopyBtn = gr.Button("复制回答", variant="secondary").style(size="sm") # self.cpopyBtn = gr.Button("复制回答", variant="secondary").style(size="sm")
self.resetBtn = gr.Button("重置Chatbot", variant="secondary", elem_id='empty_btn').style(size="sm") self.resetBtn = gr.Button("新建对话", variant="secondary", elem_id='empty_btn').style(size="sm")
self.stopBtn = gr.Button("停止", variant="stop").style(size="sm") self.stopBtn = gr.Button("中止对话", variant="stop").style(size="sm")
with gr.Tab('Function'): with gr.Tab('Function'):
with gr.Accordion("基础功能区", open=True) as self.area_basic_fn: with gr.Accordion("基础功能区", open=True) as self.area_basic_fn:
with gr.Row(): with gr.Row():
@ -395,6 +395,7 @@ class ChatBot(ChatBotFrame):
self.signals_public() self.signals_public()
self.signals_prompt_edit() self.signals_prompt_edit()
# self.signals_auto_input() # self.signals_auto_input()
self.demo.load(fn=func_box.refresh_load_data, postprocess=False, inputs=[self.chatbot, self.history, self.pro_fp_state], outputs=[self.pro_func_prompt, self.pro_fp_state, self.chatbot, self.history, ])
# Start # Start
self.auto_opentab_delay() self.auto_opentab_delay()

View File

@ -236,7 +236,7 @@ def draw_results(txt, prompt: gr.Dataset, percent, switch, ipaddr: gr.Request):
return prompt.update(samples=data, visible=True), prompt return prompt.update(samples=data, visible=True), prompt
def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15, hosts=''): def diff_list(txt='', percent=0.70, switch: list = None, lst: dict = None, sp=15, hosts=''):
""" """
按照搜索结果统计相似度的文本,两组文本相似度>70%的将统计在一起取最长的作为key 按照搜索结果统计相似度的文本,两组文本相似度>70%的将统计在一起取最长的作为key
Args: Args:
@ -250,8 +250,16 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15
返回一个列表 返回一个列表
""" """
count_dict = {} count_dict = {}
is_all = toolbox.get_conf('prompt_list')[0]['key'][1]
if not lst: if not lst:
lst = SqliteHandle('ai_common').get_prompt_value(txt) lst = {}
tabs = SqliteHandle().get_tables()
if is_all in switch:
lst.update(SqliteHandle(f"ai_common_{hosts}").get_prompt_value(txt))
else:
for tab in tabs:
if tab.startswith('ai_common'):
lst.update(SqliteHandle(f"{tab}").get_prompt_value(txt))
lst.update(SqliteHandle(f"ai_private_{hosts}").get_prompt_value(txt)) lst.update(SqliteHandle(f"ai_private_{hosts}").get_prompt_value(txt))
# diff 数据根据precent系数归类数据 # diff 数据根据precent系数归类数据
str_ = time.time() str_ = time.time()
@ -269,7 +277,7 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15
found = True found = True
break break
if not found: count_dict[i] = 1 if not found: count_dict[i] = 1
with ThreadPoolExecutor(1000) as executor: with ThreadPoolExecutor(100) as executor:
executor.map(tf_factor_calcul, lst) executor.map(tf_factor_calcul, lst)
print('计算耗时', time.time()-str_) print('计算耗时', time.time()-str_)
sorted_dict = sorted(count_dict.items(), key=lambda x: x[1], reverse=True) sorted_dict = sorted(count_dict.items(), key=lambda x: x[1], reverse=True)
@ -481,7 +489,7 @@ def thread_write_chat(chatbot, history):
if private_key in chat_title: if private_key in chat_title:
SqliteHandle(f'ai_private_{chat_title[-2]}').inset_prompt({i_say: gpt_result}) SqliteHandle(f'ai_private_{chat_title[-2]}').inset_prompt({i_say: gpt_result})
else: else:
SqliteHandle(f'ai_common').inset_prompt({i_say: gpt_result}) SqliteHandle(f'ai_common_{chat_title[-2]}').inset_prompt({i_say: gpt_result})
base_path = os.path.dirname(__file__) base_path = os.path.dirname(__file__)
@ -527,6 +535,23 @@ def spinner_chatbot_loading(chatbot):
return loading_msg return loading_msg
def refresh_load_data(chat, history, prompt):
"""
Args:
chat: 聊天组件
history: 对话记录
prompt: prompt dataset组件
Returns:
预期是每次刷新页面,加载最新
"""
is_all = toolbox.get_conf('prompt_list')[0]['key'][0]
data = prompt_retrieval(is_all=[is_all])
prompt.samples = data
return prompt.update(samples=data, visible=True), prompt, chat, history
def txt_converter_json(input_string): def txt_converter_json(input_string):
try: try:
if input_string.startswith("{") and input_string.endswith("}"): if input_string.startswith("{") and input_string.endswith("}"):