将文件存储换成sqlite3 | 对话可以更多花样

This commit is contained in:
w_xiaolizu
2023-05-17 21:48:19 +08:00
parent c5d4f60137
commit 1e28a4feea
4 changed files with 120 additions and 56 deletions

View File

@ -7,6 +7,8 @@ import hashlib
import json
import os.path
import subprocess
import threading
import time
import psutil
import re
import tempfile
@ -22,6 +24,8 @@ from scipy.linalg import norm
import pyperclip
import random
import gradio as gr
import toolbox
from prompt_generator import SqliteHandle
"""contextlib 是 Python 标准库中的一个模块,提供了一些工具函数和装饰器,用于支持编写上下文管理器和处理上下文的常见任务,例如资源管理、异常处理等。
官网https://docs.python.org/3/library/contextlib.html"""
@ -72,6 +76,16 @@ class Shell(object):
self.__temp += i
yield self.__temp
def timeStatistics(func):
def statistics(*args, **kwargs):
startTiem = time.time()
obj = func(*args, **kwargs)
endTiem = time.time()
ums = startTiem - endTiem
print('func:{} > Time-consuming: {}'.format(func, ums))
return obj
return statistics
def context_with(*parms):
"""
一个装饰器,根据传递的参数列表,在类方法上下文中嵌套多个 with 语句。
@ -211,8 +225,8 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15
import difflib
count_dict = {}
if not lst:
lst = YamlHandle().load()
lst.update(YamlHandle(os.path.join(prompt_path, f"ai_private_{hosts}.yaml")).load())
lst = SqliteHandle('ai_common').get_prompt_value()
lst.update(SqliteHandle(f"ai_private_{hosts}").get_prompt_value())
# diff 数据根据precent系数归类数据
for i in lst:
found = False
@ -253,7 +267,7 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15
def search_list(txt, sp=15):
lst = YamlHandle().load()
lst = SqliteHandle('ai_common').get_prompt_value()
dateset_list = []
for key in lst:
index = key.find(txt)
@ -268,37 +282,31 @@ def search_list(txt, sp=15):
def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
hosts = ipaddr.client.host
user_file = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
if file.name.endswith('json'):
upload_data = check_json_format(file.name)
if upload_data != {}:
YamlHandle(user_file).dump_dict(upload_data)
ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']
else:
prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
return prompt.samples, prompt, []
elif file.name.endswith('yaml'):
upload_data = YamlHandle(file.name).load()
if upload_data != {} and type(upload_data) is dict:
YamlHandle(user_file).dump_dict(upload_data)
ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']
else:
prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
return prompt.samples, prompt, []
else:
upload_data = {}
if upload_data != {}:
SqliteHandle(f'prompt_{hosts}').inset_prompt(upload_data)
ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']
else:
prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
return prompt.samples, prompt, []
def prompt_retrieval(is_all, hosts='', search=False):
count_dict = {}
user_path = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
if '所有人' in is_all:
for root, dirs, files in os.walk(prompt_path):
for f in files:
if f.startswith('prompt') and f.endswith('yaml'):
data = YamlHandle(file=os.path.join(root, f)).load()
if data: count_dict.update(data)
for tab in SqliteHandle('ai_common').get_tables():
if tab.startswith('prompt'):
data = SqliteHandle(tab).get_prompt_value()
if data: count_dict.update(data)
elif '个人' in is_all:
data = YamlHandle(file=user_path).load()
data = SqliteHandle(f'prompt_{hosts}').get_prompt_value()
if data: count_dict.update(data)
retrieval = []
if count_dict != {}:
@ -320,8 +328,8 @@ def prompt_reduce(is_all, prompt: gr.Dataset, ipaddr: gr.Request): # is_all, ipa
def prompt_save(txt, name, checkbox, prompt: gr.Dataset, ipaddr: gr.Request):
if txt and name:
yaml_obj = YamlHandle(os.path.join(prompt_path, f'prompt_{ipaddr.client.host}.yaml'))
yaml_obj.update(name, txt)
yaml_obj = SqliteHandle(f'prompt_{ipaddr.client.host}')
yaml_obj.inset_prompt({name: txt})
result = prompt_retrieval(is_all=checkbox, hosts=ipaddr.client.host)
prompt.samples = result
return "", "", ['个人'], prompt.update(samples=result, samples_per_page=10, visible=True), prompt
@ -351,6 +359,16 @@ def show_prompt_result(index, data: gr.Dataset, chatbot):
chatbot.append((click[1], click[2]))
return chatbot
def thread_write_chat(chatbot):
private_key = toolbox.get_conf('private_key')[0]
chat_title = chatbot[0][0].split()
if private_key in chat_title:
SqliteHandle(f'ai_private_{chat_title[-2]}').inset_prompt({chatbot[-1][0]: chatbot[-1][1]})
else:
SqliteHandle(f'ai_common').inset_prompt({chatbot[-1][0]: chatbot[-1][1]})
base_path = os.path.dirname(__file__)
prompt_path = os.path.join(base_path, 'prompt_users')
@ -360,6 +378,7 @@ class YamlHandle:
if not os.path.exists(file):
Shell(f'touch {file}').read()
self.file = file
self._load = self.load()
def load(self) -> dict:
@ -368,7 +387,7 @@ class YamlHandle:
return data
def update(self, key, value):
date = self.load()
date = self._load
if not date:
date = {}
date[key] = value
@ -377,7 +396,7 @@ class YamlHandle:
return date
def dump_dict(self, new_dict):
date = self.load()
date = self._load
if not date:
date = {}
date.update(new_dict)
@ -417,6 +436,7 @@ class FileHandle:
if __name__ == '__main__':
print(YamlHandle().load())
for i in YamlHandle().load():
print(i)