增加刷新页面时重载Prompt 数据,关联Prompt和搜索条件,当选择个人时,只会搜到自己的记录

This commit is contained in:
w_xiaolizu
2023-06-18 03:11:23 +08:00
parent 0624bf6e89
commit c7add2dd11
2 changed files with 32 additions and 6 deletions

View File

@ -236,7 +236,7 @@ def draw_results(txt, prompt: gr.Dataset, percent, switch, ipaddr: gr.Request):
return prompt.update(samples=data, visible=True), prompt
def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15, hosts=''):
def diff_list(txt='', percent=0.70, switch: list = None, lst: dict = None, sp=15, hosts=''):
"""
按照搜索结果统计相似度的文本,两组文本相似度>70%的将统计在一起取最长的作为key
Args:
@ -250,8 +250,16 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15
返回一个列表
"""
count_dict = {}
is_all = toolbox.get_conf('prompt_list')[0]['key'][1]
if not lst:
lst = SqliteHandle('ai_common').get_prompt_value(txt)
lst = {}
tabs = SqliteHandle().get_tables()
if is_all in switch:
lst.update(SqliteHandle(f"ai_common_{hosts}").get_prompt_value(txt))
else:
for tab in tabs:
if tab.startswith('ai_common'):
lst.update(SqliteHandle(f"{tab}").get_prompt_value(txt))
lst.update(SqliteHandle(f"ai_private_{hosts}").get_prompt_value(txt))
# diff 数据根据precent系数归类数据
str_ = time.time()
@ -269,7 +277,7 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15
found = True
break
if not found: count_dict[i] = 1
with ThreadPoolExecutor(1000) as executor:
with ThreadPoolExecutor(100) as executor:
executor.map(tf_factor_calcul, lst)
print('计算耗时', time.time()-str_)
sorted_dict = sorted(count_dict.items(), key=lambda x: x[1], reverse=True)
@ -481,7 +489,7 @@ def thread_write_chat(chatbot, history):
if private_key in chat_title:
SqliteHandle(f'ai_private_{chat_title[-2]}').inset_prompt({i_say: gpt_result})
else:
SqliteHandle(f'ai_common').inset_prompt({i_say: gpt_result})
SqliteHandle(f'ai_common_{chat_title[-2]}').inset_prompt({i_say: gpt_result})
base_path = os.path.dirname(__file__)
@ -527,6 +535,23 @@ def spinner_chatbot_loading(chatbot):
return loading_msg
def refresh_load_data(chat, history, prompt):
"""
Args:
chat: 聊天组件
history: 对话记录
prompt: prompt dataset组件
Returns:
预期是每次刷新页面,加载最新
"""
is_all = toolbox.get_conf('prompt_list')[0]['key'][0]
data = prompt_retrieval(is_all=[is_all])
prompt.samples = data
return prompt.update(samples=data, visible=True), prompt, chat, history
def txt_converter_json(input_string):
try:
if input_string.startswith("{") and input_string.endswith("}"):