将文件存储换成sqlite3 | 对话可以更多花样

This commit is contained in:
w_xiaolizu
2023-05-17 21:48:19 +08:00
parent c5d4f60137
commit 1e28a4feea
4 changed files with 120 additions and 56 deletions

View File

@ -1,15 +0,0 @@
#! .\venv\
# encoding: utf-8
# @Time : 2023/4/20
# @Author : Spike
# @Descr :
def chat_with_ai(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt):
history = []
pass

View File

@ -7,6 +7,8 @@ import hashlib
import json
import os.path
import subprocess
import threading
import time
import psutil
import re
import tempfile
@ -22,6 +24,8 @@ from scipy.linalg import norm
import pyperclip
import random
import gradio as gr
import toolbox
from prompt_generator import SqliteHandle
"""contextlib 是 Python 标准库中的一个模块,提供了一些工具函数和装饰器,用于支持编写上下文管理器和处理上下文的常见任务,例如资源管理、异常处理等。
官网https://docs.python.org/3/library/contextlib.html"""
@ -72,6 +76,16 @@ class Shell(object):
self.__temp += i
yield self.__temp
def timeStatistics(func):
def statistics(*args, **kwargs):
startTiem = time.time()
obj = func(*args, **kwargs)
endTiem = time.time()
ums = startTiem - endTiem
print('func:{} > Time-consuming: {}'.format(func, ums))
return obj
return statistics
def context_with(*parms):
"""
一个装饰器,根据传递的参数列表,在类方法上下文中嵌套多个 with 语句。
@ -211,8 +225,8 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15
import difflib
count_dict = {}
if not lst:
lst = YamlHandle().load()
lst.update(YamlHandle(os.path.join(prompt_path, f"ai_private_{hosts}.yaml")).load())
lst = SqliteHandle('ai_common').get_prompt_value()
lst.update(SqliteHandle(f"ai_private_{hosts}").get_prompt_value())
# diff 数据根据precent系数归类数据
for i in lst:
found = False
@ -253,7 +267,7 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15
def search_list(txt, sp=15):
lst = YamlHandle().load()
lst = SqliteHandle('ai_common').get_prompt_value()
dateset_list = []
for key in lst:
index = key.find(txt)
@ -268,37 +282,31 @@ def search_list(txt, sp=15):
def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
hosts = ipaddr.client.host
user_file = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
if file.name.endswith('json'):
upload_data = check_json_format(file.name)
if upload_data != {}:
YamlHandle(user_file).dump_dict(upload_data)
ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']
else:
prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
return prompt.samples, prompt, []
elif file.name.endswith('yaml'):
upload_data = YamlHandle(file.name).load()
if upload_data != {} and type(upload_data) is dict:
YamlHandle(user_file).dump_dict(upload_data)
ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']
else:
prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
return prompt.samples, prompt, []
else:
upload_data = {}
if upload_data != {}:
SqliteHandle(f'prompt_{hosts}').inset_prompt(upload_data)
ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']
else:
prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
return prompt.samples, prompt, []
def prompt_retrieval(is_all, hosts='', search=False):
count_dict = {}
user_path = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
if '所有人' in is_all:
for root, dirs, files in os.walk(prompt_path):
for f in files:
if f.startswith('prompt') and f.endswith('yaml'):
data = YamlHandle(file=os.path.join(root, f)).load()
if data: count_dict.update(data)
for tab in SqliteHandle('ai_common').get_tables():
if tab.startswith('prompt'):
data = SqliteHandle(tab).get_prompt_value()
if data: count_dict.update(data)
elif '个人' in is_all:
data = YamlHandle(file=user_path).load()
data = SqliteHandle(f'prompt_{hosts}').get_prompt_value()
if data: count_dict.update(data)
retrieval = []
if count_dict != {}:
@ -320,8 +328,8 @@ def prompt_reduce(is_all, prompt: gr.Dataset, ipaddr: gr.Request): # is_all, ipa
def prompt_save(txt, name, checkbox, prompt: gr.Dataset, ipaddr: gr.Request):
if txt and name:
yaml_obj = YamlHandle(os.path.join(prompt_path, f'prompt_{ipaddr.client.host}.yaml'))
yaml_obj.update(name, txt)
yaml_obj = SqliteHandle(f'prompt_{ipaddr.client.host}')
yaml_obj.inset_prompt({name: txt})
result = prompt_retrieval(is_all=checkbox, hosts=ipaddr.client.host)
prompt.samples = result
return "", "", ['个人'], prompt.update(samples=result, samples_per_page=10, visible=True), prompt
@ -351,6 +359,16 @@ def show_prompt_result(index, data: gr.Dataset, chatbot):
chatbot.append((click[1], click[2]))
return chatbot
def thread_write_chat(chatbot):
private_key = toolbox.get_conf('private_key')[0]
chat_title = chatbot[0][0].split()
if private_key in chat_title:
SqliteHandle(f'ai_private_{chat_title[-2]}').inset_prompt({chatbot[-1][0]: chatbot[-1][1]})
else:
SqliteHandle(f'ai_common').inset_prompt({chatbot[-1][0]: chatbot[-1][1]})
base_path = os.path.dirname(__file__)
prompt_path = os.path.join(base_path, 'prompt_users')
@ -360,6 +378,7 @@ class YamlHandle:
if not os.path.exists(file):
Shell(f'touch {file}').read()
self.file = file
self._load = self.load()
def load(self) -> dict:
@ -368,7 +387,7 @@ class YamlHandle:
return data
def update(self, key, value):
date = self.load()
date = self._load
if not date:
date = {}
date[key] = value
@ -377,7 +396,7 @@ class YamlHandle:
return date
def dump_dict(self, new_dict):
date = self.load()
date = self._load
if not date:
date = {}
date.update(new_dict)
@ -417,6 +436,7 @@ class FileHandle:
if __name__ == '__main__':
print(YamlHandle().load())
for i in YamlHandle().load():
print(i)

View File

@ -3,10 +3,73 @@
# @Time : 2023/4/19
# @Author : Spike
# @Descr :
# 默认的prompt
import os.path
import sqlite3
import threading
import functools
import func_box
# 连接到数据库
base_path = os.path.dirname(__file__)
prompt_path = os.path.join(base_path, 'prompt_users')
def connect_db_close(cls_method):
@functools.wraps(cls_method)
def wrapper(cls=None, *args, **kwargs):
cls._connect_db()
result = cls_method(cls, *args, **kwargs)
cls._close_db()
return result
return wrapper
class SqliteHandle:
def __init__(self, table='ai_common'):
self.__connect = sqlite3.connect(os.path.join(prompt_path, 'ai_prompt.db'))
self.__cursor = self.__connect.cursor()
self.__table = table
if self.__table not in self.get_tables():
self.create_tab()
def new_connect_db(self):
"""多线程操作时每个线程新建独立的connect"""
self.__connect = sqlite3.connect(os.path.join(prompt_path, 'ai_prompt.db'))
self.__cursor = self.__connect.cursor()
def new_close_db(self):
self.__cursor.close()
self.__connect.close()
def create_tab(self):
self.__cursor.execute(f"CREATE TABLE `{self.__table}` ('id' INTEGER PRIMARY KEY AUTOINCREMENT, 'prompt' TEXT, 'result' TEXT)")
def get_tables(self):
all_tab = []
result = self.__cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
for tab in result:
all_tab.append(tab[0])
return all_tab
def get_prompt_value(self):
temp_all = {}
result = self.__cursor.execute(f"SELECT prompt, result FROM `{self.__table}`").fetchall()
for row in result:
temp_all[row[0]] = row[1]
return temp_all
def inset_prompt(self, prompt: dict):
for key in prompt:
self.__cursor.execute(f"INSERT INTO `{self.__table}` (prompt, result) VALUES (?, ?);", (str(key), str(prompt[key])))
self.__connect.commit()
def delete_prompt(self):
self.__cursor.execute(f"DELETE from `{self.__table}` where id BETWEEN 1 AND 21")
self.__connect.commit()
sqlite_handle = SqliteHandle
if __name__ == '__main__':
# print(sqlite_handle('ai_common').inset_prompt(test))
# sqlite_handle('ai_common').delete_prompt()
print(sqlite_handle('ai_common').get_prompt_value())

View File

@ -13,6 +13,7 @@ import shutil
import os
import time
import glob
from concurrent.futures import ThreadPoolExecutor
############################### 插件输入输出接驳区 #######################################
"""
@ -91,20 +92,14 @@ def ArgsGeneralWrapper(f):
return decorated
pool = ThreadPoolExecutor(200)
def update_ui(chatbot, history, msg='正常', txt='', *args): # 刷新界面
"""
刷新用户界面
"""
private_key = get_conf('private_key')[0]
chat_title = chatbot[0][0].split()
if private_key in chat_title:
private_path = os.path.join(func_box.prompt_path, f"ai_private_{chat_title[-2]}.yaml")
func_box.YamlHandle(private_path).update(key=chatbot[-1][0], value=chatbot[-1][1])
else:
func_box.YamlHandle().update(key=chatbot[-1][0], value=chatbot[-1][1])
assert isinstance(chatbot, ChatBotWithCookies), "在传递chatbot的过程中不要将其丢弃。必要时可用clear将其清空然后用for+append循环重新赋值。"
yield chatbot.get_cookies(), chatbot, history, msg, txt
pool.submit(func_box.thread_write_chat, chatbot)
def trimmed_format_exc():
import os, traceback
@ -254,8 +249,8 @@ def text_divide_paragraph(text):
else:
# wtf input
lines = text.split("\n")
for i, line in enumerate(lines):
lines[i] = lines[i].replace(" ", " ")
# for i, line in enumerate(lines):
# lines[i] = lines[i].replace(" ", " ")
text = "</br>".join(lines)
return text
@ -373,6 +368,7 @@ def format_io(self, y):
gpt_reply = close_up_code_segment_during_stream(gpt_reply) # 当代码输出半截的时候,试着补上后个```
y[-1] = (
None if i_ask is None else markdown.markdown(i_ask, extensions=['fenced_code', 'tables']),
#None if i_ask is None else markdown_convertion(i_ask),
None if gpt_reply is None else markdown_convertion(gpt_reply)
)
return y