将文件存储换成sqlite3 | 对话可以更多花样
This commit is contained in:
@ -1,15 +0,0 @@
|
||||
#! .\venv\
|
||||
# encoding: utf-8
|
||||
# @Time : 2023/4/20
|
||||
# @Author : Spike
|
||||
# @Descr :
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def chat_with_ai(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt):
|
||||
|
||||
history = []
|
||||
|
||||
pass
|
||||
78
func_box.py
78
func_box.py
@ -7,6 +7,8 @@ import hashlib
|
||||
import json
|
||||
import os.path
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import psutil
|
||||
import re
|
||||
import tempfile
|
||||
@ -22,6 +24,8 @@ from scipy.linalg import norm
|
||||
import pyperclip
|
||||
import random
|
||||
import gradio as gr
|
||||
import toolbox
|
||||
from prompt_generator import SqliteHandle
|
||||
"""contextlib 是 Python 标准库中的一个模块,提供了一些工具函数和装饰器,用于支持编写上下文管理器和处理上下文的常见任务,例如资源管理、异常处理等。
|
||||
官网:https://docs.python.org/3/library/contextlib.html"""
|
||||
|
||||
@ -72,6 +76,16 @@ class Shell(object):
|
||||
self.__temp += i
|
||||
yield self.__temp
|
||||
|
||||
def timeStatistics(func):
|
||||
def statistics(*args, **kwargs):
|
||||
startTiem = time.time()
|
||||
obj = func(*args, **kwargs)
|
||||
endTiem = time.time()
|
||||
ums = startTiem - endTiem
|
||||
print('func:{} > Time-consuming: {}'.format(func, ums))
|
||||
return obj
|
||||
return statistics
|
||||
|
||||
def context_with(*parms):
|
||||
"""
|
||||
一个装饰器,根据传递的参数列表,在类方法上下文中嵌套多个 with 语句。
|
||||
@ -211,8 +225,8 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15
|
||||
import difflib
|
||||
count_dict = {}
|
||||
if not lst:
|
||||
lst = YamlHandle().load()
|
||||
lst.update(YamlHandle(os.path.join(prompt_path, f"ai_private_{hosts}.yaml")).load())
|
||||
lst = SqliteHandle('ai_common').get_prompt_value()
|
||||
lst.update(SqliteHandle(f"ai_private_{hosts}").get_prompt_value())
|
||||
# diff 数据,根据precent系数归类数据
|
||||
for i in lst:
|
||||
found = False
|
||||
@ -253,7 +267,7 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15
|
||||
|
||||
|
||||
def search_list(txt, sp=15):
|
||||
lst = YamlHandle().load()
|
||||
lst = SqliteHandle('ai_common').get_prompt_value()
|
||||
dateset_list = []
|
||||
for key in lst:
|
||||
index = key.find(txt)
|
||||
@ -268,37 +282,31 @@ def search_list(txt, sp=15):
|
||||
|
||||
def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
|
||||
hosts = ipaddr.client.host
|
||||
user_file = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
|
||||
if file.name.endswith('json'):
|
||||
upload_data = check_json_format(file.name)
|
||||
if upload_data != {}:
|
||||
YamlHandle(user_file).dump_dict(upload_data)
|
||||
ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
|
||||
return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']
|
||||
else:
|
||||
prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
|
||||
return prompt.samples, prompt, []
|
||||
elif file.name.endswith('yaml'):
|
||||
upload_data = YamlHandle(file.name).load()
|
||||
if upload_data != {} and type(upload_data) is dict:
|
||||
YamlHandle(user_file).dump_dict(upload_data)
|
||||
ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
|
||||
return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']
|
||||
else:
|
||||
prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
|
||||
return prompt.samples, prompt, []
|
||||
else:
|
||||
upload_data = {}
|
||||
if upload_data != {}:
|
||||
SqliteHandle(f'prompt_{hosts}').inset_prompt(upload_data)
|
||||
ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
|
||||
return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']
|
||||
else:
|
||||
prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
|
||||
return prompt.samples, prompt, []
|
||||
|
||||
|
||||
def prompt_retrieval(is_all, hosts='', search=False):
|
||||
count_dict = {}
|
||||
user_path = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
|
||||
if '所有人' in is_all:
|
||||
for root, dirs, files in os.walk(prompt_path):
|
||||
for f in files:
|
||||
if f.startswith('prompt') and f.endswith('yaml'):
|
||||
data = YamlHandle(file=os.path.join(root, f)).load()
|
||||
if data: count_dict.update(data)
|
||||
for tab in SqliteHandle('ai_common').get_tables():
|
||||
if tab.startswith('prompt'):
|
||||
data = SqliteHandle(tab).get_prompt_value()
|
||||
if data: count_dict.update(data)
|
||||
elif '个人' in is_all:
|
||||
data = YamlHandle(file=user_path).load()
|
||||
data = SqliteHandle(f'prompt_{hosts}').get_prompt_value()
|
||||
if data: count_dict.update(data)
|
||||
retrieval = []
|
||||
if count_dict != {}:
|
||||
@ -320,8 +328,8 @@ def prompt_reduce(is_all, prompt: gr.Dataset, ipaddr: gr.Request): # is_all, ipa
|
||||
|
||||
def prompt_save(txt, name, checkbox, prompt: gr.Dataset, ipaddr: gr.Request):
|
||||
if txt and name:
|
||||
yaml_obj = YamlHandle(os.path.join(prompt_path, f'prompt_{ipaddr.client.host}.yaml'))
|
||||
yaml_obj.update(name, txt)
|
||||
yaml_obj = SqliteHandle(f'prompt_{ipaddr.client.host}')
|
||||
yaml_obj.inset_prompt({name: txt})
|
||||
result = prompt_retrieval(is_all=checkbox, hosts=ipaddr.client.host)
|
||||
prompt.samples = result
|
||||
return "", "", ['个人'], prompt.update(samples=result, samples_per_page=10, visible=True), prompt
|
||||
@ -351,6 +359,16 @@ def show_prompt_result(index, data: gr.Dataset, chatbot):
|
||||
chatbot.append((click[1], click[2]))
|
||||
return chatbot
|
||||
|
||||
|
||||
def thread_write_chat(chatbot):
|
||||
private_key = toolbox.get_conf('private_key')[0]
|
||||
chat_title = chatbot[0][0].split()
|
||||
if private_key in chat_title:
|
||||
SqliteHandle(f'ai_private_{chat_title[-2]}').inset_prompt({chatbot[-1][0]: chatbot[-1][1]})
|
||||
else:
|
||||
SqliteHandle(f'ai_common').inset_prompt({chatbot[-1][0]: chatbot[-1][1]})
|
||||
|
||||
|
||||
base_path = os.path.dirname(__file__)
|
||||
prompt_path = os.path.join(base_path, 'prompt_users')
|
||||
|
||||
@ -360,6 +378,7 @@ class YamlHandle:
|
||||
if not os.path.exists(file):
|
||||
Shell(f'touch {file}').read()
|
||||
self.file = file
|
||||
self._load = self.load()
|
||||
|
||||
|
||||
def load(self) -> dict:
|
||||
@ -368,7 +387,7 @@ class YamlHandle:
|
||||
return data
|
||||
|
||||
def update(self, key, value):
|
||||
date = self.load()
|
||||
date = self._load
|
||||
if not date:
|
||||
date = {}
|
||||
date[key] = value
|
||||
@ -377,7 +396,7 @@ class YamlHandle:
|
||||
return date
|
||||
|
||||
def dump_dict(self, new_dict):
|
||||
date = self.load()
|
||||
date = self._load
|
||||
if not date:
|
||||
date = {}
|
||||
date.update(new_dict)
|
||||
@ -417,6 +436,7 @@ class FileHandle:
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(YamlHandle().load())
|
||||
for i in YamlHandle().load():
|
||||
print(i)
|
||||
|
||||
|
||||
|
||||
@ -3,10 +3,73 @@
|
||||
# @Time : 2023/4/19
|
||||
# @Author : Spike
|
||||
# @Descr :
|
||||
|
||||
# 默认的prompt
|
||||
import os.path
|
||||
import sqlite3
|
||||
import threading
|
||||
import functools
|
||||
import func_box
|
||||
# 连接到数据库
|
||||
base_path = os.path.dirname(__file__)
|
||||
prompt_path = os.path.join(base_path, 'prompt_users')
|
||||
|
||||
|
||||
def connect_db_close(cls_method):
|
||||
@functools.wraps(cls_method)
|
||||
def wrapper(cls=None, *args, **kwargs):
|
||||
cls._connect_db()
|
||||
result = cls_method(cls, *args, **kwargs)
|
||||
cls._close_db()
|
||||
return result
|
||||
return wrapper
|
||||
|
||||
|
||||
class SqliteHandle:
|
||||
def __init__(self, table='ai_common'):
|
||||
self.__connect = sqlite3.connect(os.path.join(prompt_path, 'ai_prompt.db'))
|
||||
self.__cursor = self.__connect.cursor()
|
||||
self.__table = table
|
||||
if self.__table not in self.get_tables():
|
||||
self.create_tab()
|
||||
|
||||
def new_connect_db(self):
|
||||
"""多线程操作时,每个线程新建独立的connect"""
|
||||
self.__connect = sqlite3.connect(os.path.join(prompt_path, 'ai_prompt.db'))
|
||||
self.__cursor = self.__connect.cursor()
|
||||
|
||||
def new_close_db(self):
|
||||
self.__cursor.close()
|
||||
self.__connect.close()
|
||||
|
||||
def create_tab(self):
|
||||
self.__cursor.execute(f"CREATE TABLE `{self.__table}` ('id' INTEGER PRIMARY KEY AUTOINCREMENT, 'prompt' TEXT, 'result' TEXT)")
|
||||
|
||||
def get_tables(self):
|
||||
all_tab = []
|
||||
result = self.__cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
|
||||
for tab in result:
|
||||
all_tab.append(tab[0])
|
||||
return all_tab
|
||||
|
||||
def get_prompt_value(self):
|
||||
temp_all = {}
|
||||
result = self.__cursor.execute(f"SELECT prompt, result FROM `{self.__table}`").fetchall()
|
||||
for row in result:
|
||||
temp_all[row[0]] = row[1]
|
||||
return temp_all
|
||||
|
||||
def inset_prompt(self, prompt: dict):
|
||||
for key in prompt:
|
||||
self.__cursor.execute(f"INSERT INTO `{self.__table}` (prompt, result) VALUES (?, ?);", (str(key), str(prompt[key])))
|
||||
self.__connect.commit()
|
||||
|
||||
def delete_prompt(self):
|
||||
self.__cursor.execute(f"DELETE from `{self.__table}` where id BETWEEN 1 AND 21")
|
||||
self.__connect.commit()
|
||||
|
||||
sqlite_handle = SqliteHandle
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
||||
# print(sqlite_handle('ai_common').inset_prompt(test))
|
||||
# sqlite_handle('ai_common').delete_prompt()
|
||||
print(sqlite_handle('ai_common').get_prompt_value())
|
||||
|
||||
16
toolbox.py
16
toolbox.py
@ -13,6 +13,7 @@ import shutil
|
||||
import os
|
||||
import time
|
||||
import glob
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
############################### 插件输入输出接驳区 #######################################
|
||||
|
||||
"""
|
||||
@ -91,20 +92,14 @@ def ArgsGeneralWrapper(f):
|
||||
return decorated
|
||||
|
||||
|
||||
pool = ThreadPoolExecutor(200)
|
||||
def update_ui(chatbot, history, msg='正常', txt='', *args): # 刷新界面
|
||||
"""
|
||||
刷新用户界面
|
||||
"""
|
||||
private_key = get_conf('private_key')[0]
|
||||
chat_title = chatbot[0][0].split()
|
||||
if private_key in chat_title:
|
||||
private_path = os.path.join(func_box.prompt_path, f"ai_private_{chat_title[-2]}.yaml")
|
||||
func_box.YamlHandle(private_path).update(key=chatbot[-1][0], value=chatbot[-1][1])
|
||||
else:
|
||||
func_box.YamlHandle().update(key=chatbot[-1][0], value=chatbot[-1][1])
|
||||
|
||||
assert isinstance(chatbot, ChatBotWithCookies), "在传递chatbot的过程中不要将其丢弃。必要时,可用clear将其清空,然后用for+append循环重新赋值。"
|
||||
yield chatbot.get_cookies(), chatbot, history, msg, txt
|
||||
pool.submit(func_box.thread_write_chat, chatbot)
|
||||
|
||||
def trimmed_format_exc():
|
||||
import os, traceback
|
||||
@ -254,8 +249,8 @@ def text_divide_paragraph(text):
|
||||
else:
|
||||
# wtf input
|
||||
lines = text.split("\n")
|
||||
for i, line in enumerate(lines):
|
||||
lines[i] = lines[i].replace(" ", " ")
|
||||
# for i, line in enumerate(lines):
|
||||
# lines[i] = lines[i].replace(" ", " ")
|
||||
text = "</br>".join(lines)
|
||||
return text
|
||||
|
||||
@ -373,6 +368,7 @@ def format_io(self, y):
|
||||
gpt_reply = close_up_code_segment_during_stream(gpt_reply) # 当代码输出半截的时候,试着补上后个```
|
||||
y[-1] = (
|
||||
None if i_ask is None else markdown.markdown(i_ask, extensions=['fenced_code', 'tables']),
|
||||
#None if i_ask is None else markdown_convertion(i_ask),
|
||||
None if gpt_reply is None else markdown_convertion(gpt_reply)
|
||||
)
|
||||
return y
|
||||
|
||||
Reference in New Issue
Block a user