增加刷新页面时重载Prompt 数据,关联Prompt和搜索条件,当选择个人时,只会搜到自己的记录
This commit is contained in:
@ -153,8 +153,8 @@ class ChatBot(ChatBotFrame):
|
|||||||
prompt_list, devs_document = get_conf('prompt_list', 'devs_document')
|
prompt_list, devs_document = get_conf('prompt_list', 'devs_document')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
# self.cpopyBtn = gr.Button("复制回答", variant="secondary").style(size="sm")
|
# self.cpopyBtn = gr.Button("复制回答", variant="secondary").style(size="sm")
|
||||||
self.resetBtn = gr.Button("重置Chatbot", variant="secondary", elem_id='empty_btn').style(size="sm")
|
self.resetBtn = gr.Button("新建对话", variant="secondary", elem_id='empty_btn').style(size="sm")
|
||||||
self.stopBtn = gr.Button("停止", variant="stop").style(size="sm")
|
self.stopBtn = gr.Button("中止对话", variant="stop").style(size="sm")
|
||||||
with gr.Tab('Function'):
|
with gr.Tab('Function'):
|
||||||
with gr.Accordion("基础功能区", open=True) as self.area_basic_fn:
|
with gr.Accordion("基础功能区", open=True) as self.area_basic_fn:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -395,6 +395,7 @@ class ChatBot(ChatBotFrame):
|
|||||||
self.signals_public()
|
self.signals_public()
|
||||||
self.signals_prompt_edit()
|
self.signals_prompt_edit()
|
||||||
# self.signals_auto_input()
|
# self.signals_auto_input()
|
||||||
|
self.demo.load(fn=func_box.refresh_load_data, postprocess=False, inputs=[self.chatbot, self.history, self.pro_fp_state], outputs=[self.pro_func_prompt, self.pro_fp_state, self.chatbot, self.history, ])
|
||||||
|
|
||||||
# Start
|
# Start
|
||||||
self.auto_opentab_delay()
|
self.auto_opentab_delay()
|
||||||
|
|||||||
33
func_box.py
33
func_box.py
@ -236,7 +236,7 @@ def draw_results(txt, prompt: gr.Dataset, percent, switch, ipaddr: gr.Request):
|
|||||||
return prompt.update(samples=data, visible=True), prompt
|
return prompt.update(samples=data, visible=True), prompt
|
||||||
|
|
||||||
|
|
||||||
def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15, hosts=''):
|
def diff_list(txt='', percent=0.70, switch: list = None, lst: dict = None, sp=15, hosts=''):
|
||||||
"""
|
"""
|
||||||
按照搜索结果统计相似度的文本,两组文本相似度>70%的将统计在一起,取最长的作为key
|
按照搜索结果统计相似度的文本,两组文本相似度>70%的将统计在一起,取最长的作为key
|
||||||
Args:
|
Args:
|
||||||
@ -250,8 +250,16 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15
|
|||||||
返回一个列表
|
返回一个列表
|
||||||
"""
|
"""
|
||||||
count_dict = {}
|
count_dict = {}
|
||||||
|
is_all = toolbox.get_conf('prompt_list')[0]['key'][1]
|
||||||
if not lst:
|
if not lst:
|
||||||
lst = SqliteHandle('ai_common').get_prompt_value(txt)
|
lst = {}
|
||||||
|
tabs = SqliteHandle().get_tables()
|
||||||
|
if is_all in switch:
|
||||||
|
lst.update(SqliteHandle(f"ai_common_{hosts}").get_prompt_value(txt))
|
||||||
|
else:
|
||||||
|
for tab in tabs:
|
||||||
|
if tab.startswith('ai_common'):
|
||||||
|
lst.update(SqliteHandle(f"{tab}").get_prompt_value(txt))
|
||||||
lst.update(SqliteHandle(f"ai_private_{hosts}").get_prompt_value(txt))
|
lst.update(SqliteHandle(f"ai_private_{hosts}").get_prompt_value(txt))
|
||||||
# diff 数据,根据precent系数归类数据
|
# diff 数据,根据precent系数归类数据
|
||||||
str_ = time.time()
|
str_ = time.time()
|
||||||
@ -269,7 +277,7 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15
|
|||||||
found = True
|
found = True
|
||||||
break
|
break
|
||||||
if not found: count_dict[i] = 1
|
if not found: count_dict[i] = 1
|
||||||
with ThreadPoolExecutor(1000) as executor:
|
with ThreadPoolExecutor(100) as executor:
|
||||||
executor.map(tf_factor_calcul, lst)
|
executor.map(tf_factor_calcul, lst)
|
||||||
print('计算耗时', time.time()-str_)
|
print('计算耗时', time.time()-str_)
|
||||||
sorted_dict = sorted(count_dict.items(), key=lambda x: x[1], reverse=True)
|
sorted_dict = sorted(count_dict.items(), key=lambda x: x[1], reverse=True)
|
||||||
@ -481,7 +489,7 @@ def thread_write_chat(chatbot, history):
|
|||||||
if private_key in chat_title:
|
if private_key in chat_title:
|
||||||
SqliteHandle(f'ai_private_{chat_title[-2]}').inset_prompt({i_say: gpt_result})
|
SqliteHandle(f'ai_private_{chat_title[-2]}').inset_prompt({i_say: gpt_result})
|
||||||
else:
|
else:
|
||||||
SqliteHandle(f'ai_common').inset_prompt({i_say: gpt_result})
|
SqliteHandle(f'ai_common_{chat_title[-2]}').inset_prompt({i_say: gpt_result})
|
||||||
|
|
||||||
|
|
||||||
base_path = os.path.dirname(__file__)
|
base_path = os.path.dirname(__file__)
|
||||||
@ -527,6 +535,23 @@ def spinner_chatbot_loading(chatbot):
|
|||||||
return loading_msg
|
return loading_msg
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_load_data(chat, history, prompt):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
chat: 聊天组件
|
||||||
|
history: 对话记录
|
||||||
|
prompt: prompt dataset组件
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
预期是每次刷新页面,加载最新
|
||||||
|
"""
|
||||||
|
is_all = toolbox.get_conf('prompt_list')[0]['key'][0]
|
||||||
|
data = prompt_retrieval(is_all=[is_all])
|
||||||
|
prompt.samples = data
|
||||||
|
return prompt.update(samples=data, visible=True), prompt, chat, history
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def txt_converter_json(input_string):
|
def txt_converter_json(input_string):
|
||||||
try:
|
try:
|
||||||
if input_string.startswith("{") and input_string.endswith("}"):
|
if input_string.startswith("{") and input_string.endswith("}"):
|
||||||
|
|||||||
Reference in New Issue
Block a user