将文件存储换成sqlite3 | 对话可以更多花样
This commit is contained in:
@ -1,15 +0,0 @@
|
|||||||
#! .\venv\
|
|
||||||
# encoding: utf-8
|
|
||||||
# @Time : 2023/4/20
|
|
||||||
# @Author : Spike
|
|
||||||
# @Descr :
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def chat_with_ai(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt):
|
|
||||||
|
|
||||||
history = []
|
|
||||||
|
|
||||||
pass
|
|
||||||
78
func_box.py
78
func_box.py
@ -7,6 +7,8 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
import os.path
|
import os.path
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
import psutil
|
import psutil
|
||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
@ -22,6 +24,8 @@ from scipy.linalg import norm
|
|||||||
import pyperclip
|
import pyperclip
|
||||||
import random
|
import random
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import toolbox
|
||||||
|
from prompt_generator import SqliteHandle
|
||||||
"""contextlib 是 Python 标准库中的一个模块,提供了一些工具函数和装饰器,用于支持编写上下文管理器和处理上下文的常见任务,例如资源管理、异常处理等。
|
"""contextlib 是 Python 标准库中的一个模块,提供了一些工具函数和装饰器,用于支持编写上下文管理器和处理上下文的常见任务,例如资源管理、异常处理等。
|
||||||
官网:https://docs.python.org/3/library/contextlib.html"""
|
官网:https://docs.python.org/3/library/contextlib.html"""
|
||||||
|
|
||||||
@ -72,6 +76,16 @@ class Shell(object):
|
|||||||
self.__temp += i
|
self.__temp += i
|
||||||
yield self.__temp
|
yield self.__temp
|
||||||
|
|
||||||
|
def timeStatistics(func):
|
||||||
|
def statistics(*args, **kwargs):
|
||||||
|
startTiem = time.time()
|
||||||
|
obj = func(*args, **kwargs)
|
||||||
|
endTiem = time.time()
|
||||||
|
ums = startTiem - endTiem
|
||||||
|
print('func:{} > Time-consuming: {}'.format(func, ums))
|
||||||
|
return obj
|
||||||
|
return statistics
|
||||||
|
|
||||||
def context_with(*parms):
|
def context_with(*parms):
|
||||||
"""
|
"""
|
||||||
一个装饰器,根据传递的参数列表,在类方法上下文中嵌套多个 with 语句。
|
一个装饰器,根据传递的参数列表,在类方法上下文中嵌套多个 with 语句。
|
||||||
@ -211,8 +225,8 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15
|
|||||||
import difflib
|
import difflib
|
||||||
count_dict = {}
|
count_dict = {}
|
||||||
if not lst:
|
if not lst:
|
||||||
lst = YamlHandle().load()
|
lst = SqliteHandle('ai_common').get_prompt_value()
|
||||||
lst.update(YamlHandle(os.path.join(prompt_path, f"ai_private_{hosts}.yaml")).load())
|
lst.update(SqliteHandle(f"ai_private_{hosts}").get_prompt_value())
|
||||||
# diff 数据,根据precent系数归类数据
|
# diff 数据,根据precent系数归类数据
|
||||||
for i in lst:
|
for i in lst:
|
||||||
found = False
|
found = False
|
||||||
@ -253,7 +267,7 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15
|
|||||||
|
|
||||||
|
|
||||||
def search_list(txt, sp=15):
|
def search_list(txt, sp=15):
|
||||||
lst = YamlHandle().load()
|
lst = SqliteHandle('ai_common').get_prompt_value()
|
||||||
dateset_list = []
|
dateset_list = []
|
||||||
for key in lst:
|
for key in lst:
|
||||||
index = key.find(txt)
|
index = key.find(txt)
|
||||||
@ -268,37 +282,31 @@ def search_list(txt, sp=15):
|
|||||||
|
|
||||||
def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
|
def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
|
||||||
hosts = ipaddr.client.host
|
hosts = ipaddr.client.host
|
||||||
user_file = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
|
|
||||||
if file.name.endswith('json'):
|
if file.name.endswith('json'):
|
||||||
upload_data = check_json_format(file.name)
|
upload_data = check_json_format(file.name)
|
||||||
if upload_data != {}:
|
|
||||||
YamlHandle(user_file).dump_dict(upload_data)
|
|
||||||
ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
|
|
||||||
return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']
|
|
||||||
else:
|
|
||||||
prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
|
|
||||||
return prompt.samples, prompt, []
|
|
||||||
elif file.name.endswith('yaml'):
|
elif file.name.endswith('yaml'):
|
||||||
upload_data = YamlHandle(file.name).load()
|
upload_data = YamlHandle(file.name).load()
|
||||||
if upload_data != {} and type(upload_data) is dict:
|
else:
|
||||||
YamlHandle(user_file).dump_dict(upload_data)
|
upload_data = {}
|
||||||
ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
|
if upload_data != {}:
|
||||||
return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']
|
SqliteHandle(f'prompt_{hosts}').inset_prompt(upload_data)
|
||||||
else:
|
ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
|
||||||
prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
|
return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']
|
||||||
return prompt.samples, prompt, []
|
else:
|
||||||
|
prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
|
||||||
|
return prompt.samples, prompt, []
|
||||||
|
|
||||||
|
|
||||||
def prompt_retrieval(is_all, hosts='', search=False):
|
def prompt_retrieval(is_all, hosts='', search=False):
|
||||||
count_dict = {}
|
count_dict = {}
|
||||||
user_path = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
|
user_path = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
|
||||||
if '所有人' in is_all:
|
if '所有人' in is_all:
|
||||||
for root, dirs, files in os.walk(prompt_path):
|
for tab in SqliteHandle('ai_common').get_tables():
|
||||||
for f in files:
|
if tab.startswith('prompt'):
|
||||||
if f.startswith('prompt') and f.endswith('yaml'):
|
data = SqliteHandle(tab).get_prompt_value()
|
||||||
data = YamlHandle(file=os.path.join(root, f)).load()
|
if data: count_dict.update(data)
|
||||||
if data: count_dict.update(data)
|
|
||||||
elif '个人' in is_all:
|
elif '个人' in is_all:
|
||||||
data = YamlHandle(file=user_path).load()
|
data = SqliteHandle(f'prompt_{hosts}').get_prompt_value()
|
||||||
if data: count_dict.update(data)
|
if data: count_dict.update(data)
|
||||||
retrieval = []
|
retrieval = []
|
||||||
if count_dict != {}:
|
if count_dict != {}:
|
||||||
@ -320,8 +328,8 @@ def prompt_reduce(is_all, prompt: gr.Dataset, ipaddr: gr.Request): # is_all, ipa
|
|||||||
|
|
||||||
def prompt_save(txt, name, checkbox, prompt: gr.Dataset, ipaddr: gr.Request):
|
def prompt_save(txt, name, checkbox, prompt: gr.Dataset, ipaddr: gr.Request):
|
||||||
if txt and name:
|
if txt and name:
|
||||||
yaml_obj = YamlHandle(os.path.join(prompt_path, f'prompt_{ipaddr.client.host}.yaml'))
|
yaml_obj = SqliteHandle(f'prompt_{ipaddr.client.host}')
|
||||||
yaml_obj.update(name, txt)
|
yaml_obj.inset_prompt({name: txt})
|
||||||
result = prompt_retrieval(is_all=checkbox, hosts=ipaddr.client.host)
|
result = prompt_retrieval(is_all=checkbox, hosts=ipaddr.client.host)
|
||||||
prompt.samples = result
|
prompt.samples = result
|
||||||
return "", "", ['个人'], prompt.update(samples=result, samples_per_page=10, visible=True), prompt
|
return "", "", ['个人'], prompt.update(samples=result, samples_per_page=10, visible=True), prompt
|
||||||
@ -351,6 +359,16 @@ def show_prompt_result(index, data: gr.Dataset, chatbot):
|
|||||||
chatbot.append((click[1], click[2]))
|
chatbot.append((click[1], click[2]))
|
||||||
return chatbot
|
return chatbot
|
||||||
|
|
||||||
|
|
||||||
|
def thread_write_chat(chatbot):
|
||||||
|
private_key = toolbox.get_conf('private_key')[0]
|
||||||
|
chat_title = chatbot[0][0].split()
|
||||||
|
if private_key in chat_title:
|
||||||
|
SqliteHandle(f'ai_private_{chat_title[-2]}').inset_prompt({chatbot[-1][0]: chatbot[-1][1]})
|
||||||
|
else:
|
||||||
|
SqliteHandle(f'ai_common').inset_prompt({chatbot[-1][0]: chatbot[-1][1]})
|
||||||
|
|
||||||
|
|
||||||
base_path = os.path.dirname(__file__)
|
base_path = os.path.dirname(__file__)
|
||||||
prompt_path = os.path.join(base_path, 'prompt_users')
|
prompt_path = os.path.join(base_path, 'prompt_users')
|
||||||
|
|
||||||
@ -360,6 +378,7 @@ class YamlHandle:
|
|||||||
if not os.path.exists(file):
|
if not os.path.exists(file):
|
||||||
Shell(f'touch {file}').read()
|
Shell(f'touch {file}').read()
|
||||||
self.file = file
|
self.file = file
|
||||||
|
self._load = self.load()
|
||||||
|
|
||||||
|
|
||||||
def load(self) -> dict:
|
def load(self) -> dict:
|
||||||
@ -368,7 +387,7 @@ class YamlHandle:
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
def update(self, key, value):
|
def update(self, key, value):
|
||||||
date = self.load()
|
date = self._load
|
||||||
if not date:
|
if not date:
|
||||||
date = {}
|
date = {}
|
||||||
date[key] = value
|
date[key] = value
|
||||||
@ -377,7 +396,7 @@ class YamlHandle:
|
|||||||
return date
|
return date
|
||||||
|
|
||||||
def dump_dict(self, new_dict):
|
def dump_dict(self, new_dict):
|
||||||
date = self.load()
|
date = self._load
|
||||||
if not date:
|
if not date:
|
||||||
date = {}
|
date = {}
|
||||||
date.update(new_dict)
|
date.update(new_dict)
|
||||||
@ -417,6 +436,7 @@ class FileHandle:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
print(YamlHandle().load())
|
for i in YamlHandle().load():
|
||||||
|
print(i)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -3,10 +3,73 @@
|
|||||||
# @Time : 2023/4/19
|
# @Time : 2023/4/19
|
||||||
# @Author : Spike
|
# @Author : Spike
|
||||||
# @Descr :
|
# @Descr :
|
||||||
|
import os.path
|
||||||
# 默认的prompt
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
import functools
|
||||||
|
import func_box
|
||||||
|
# 连接到数据库
|
||||||
|
base_path = os.path.dirname(__file__)
|
||||||
|
prompt_path = os.path.join(base_path, 'prompt_users')
|
||||||
|
|
||||||
|
|
||||||
|
def connect_db_close(cls_method):
|
||||||
|
@functools.wraps(cls_method)
|
||||||
|
def wrapper(cls=None, *args, **kwargs):
|
||||||
|
cls._connect_db()
|
||||||
|
result = cls_method(cls, *args, **kwargs)
|
||||||
|
cls._close_db()
|
||||||
|
return result
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteHandle:
|
||||||
|
def __init__(self, table='ai_common'):
|
||||||
|
self.__connect = sqlite3.connect(os.path.join(prompt_path, 'ai_prompt.db'))
|
||||||
|
self.__cursor = self.__connect.cursor()
|
||||||
|
self.__table = table
|
||||||
|
if self.__table not in self.get_tables():
|
||||||
|
self.create_tab()
|
||||||
|
|
||||||
|
def new_connect_db(self):
|
||||||
|
"""多线程操作时,每个线程新建独立的connect"""
|
||||||
|
self.__connect = sqlite3.connect(os.path.join(prompt_path, 'ai_prompt.db'))
|
||||||
|
self.__cursor = self.__connect.cursor()
|
||||||
|
|
||||||
|
def new_close_db(self):
|
||||||
|
self.__cursor.close()
|
||||||
|
self.__connect.close()
|
||||||
|
|
||||||
|
def create_tab(self):
|
||||||
|
self.__cursor.execute(f"CREATE TABLE `{self.__table}` ('id' INTEGER PRIMARY KEY AUTOINCREMENT, 'prompt' TEXT, 'result' TEXT)")
|
||||||
|
|
||||||
|
def get_tables(self):
|
||||||
|
all_tab = []
|
||||||
|
result = self.__cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
|
||||||
|
for tab in result:
|
||||||
|
all_tab.append(tab[0])
|
||||||
|
return all_tab
|
||||||
|
|
||||||
|
def get_prompt_value(self):
|
||||||
|
temp_all = {}
|
||||||
|
result = self.__cursor.execute(f"SELECT prompt, result FROM `{self.__table}`").fetchall()
|
||||||
|
for row in result:
|
||||||
|
temp_all[row[0]] = row[1]
|
||||||
|
return temp_all
|
||||||
|
|
||||||
|
def inset_prompt(self, prompt: dict):
|
||||||
|
for key in prompt:
|
||||||
|
self.__cursor.execute(f"INSERT INTO `{self.__table}` (prompt, result) VALUES (?, ?);", (str(key), str(prompt[key])))
|
||||||
|
self.__connect.commit()
|
||||||
|
|
||||||
|
def delete_prompt(self):
|
||||||
|
self.__cursor.execute(f"DELETE from `{self.__table}` where id BETWEEN 1 AND 21")
|
||||||
|
self.__connect.commit()
|
||||||
|
|
||||||
|
sqlite_handle = SqliteHandle
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
|
||||||
|
# print(sqlite_handle('ai_common').inset_prompt(test))
|
||||||
|
# sqlite_handle('ai_common').delete_prompt()
|
||||||
|
print(sqlite_handle('ai_common').get_prompt_value())
|
||||||
|
|||||||
16
toolbox.py
16
toolbox.py
@ -13,6 +13,7 @@ import shutil
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import glob
|
import glob
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
############################### 插件输入输出接驳区 #######################################
|
############################### 插件输入输出接驳区 #######################################
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -91,20 +92,14 @@ def ArgsGeneralWrapper(f):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
|
pool = ThreadPoolExecutor(200)
|
||||||
def update_ui(chatbot, history, msg='正常', txt='', *args): # 刷新界面
|
def update_ui(chatbot, history, msg='正常', txt='', *args): # 刷新界面
|
||||||
"""
|
"""
|
||||||
刷新用户界面
|
刷新用户界面
|
||||||
"""
|
"""
|
||||||
private_key = get_conf('private_key')[0]
|
|
||||||
chat_title = chatbot[0][0].split()
|
|
||||||
if private_key in chat_title:
|
|
||||||
private_path = os.path.join(func_box.prompt_path, f"ai_private_{chat_title[-2]}.yaml")
|
|
||||||
func_box.YamlHandle(private_path).update(key=chatbot[-1][0], value=chatbot[-1][1])
|
|
||||||
else:
|
|
||||||
func_box.YamlHandle().update(key=chatbot[-1][0], value=chatbot[-1][1])
|
|
||||||
|
|
||||||
assert isinstance(chatbot, ChatBotWithCookies), "在传递chatbot的过程中不要将其丢弃。必要时,可用clear将其清空,然后用for+append循环重新赋值。"
|
assert isinstance(chatbot, ChatBotWithCookies), "在传递chatbot的过程中不要将其丢弃。必要时,可用clear将其清空,然后用for+append循环重新赋值。"
|
||||||
yield chatbot.get_cookies(), chatbot, history, msg, txt
|
yield chatbot.get_cookies(), chatbot, history, msg, txt
|
||||||
|
pool.submit(func_box.thread_write_chat, chatbot)
|
||||||
|
|
||||||
def trimmed_format_exc():
|
def trimmed_format_exc():
|
||||||
import os, traceback
|
import os, traceback
|
||||||
@ -254,8 +249,8 @@ def text_divide_paragraph(text):
|
|||||||
else:
|
else:
|
||||||
# wtf input
|
# wtf input
|
||||||
lines = text.split("\n")
|
lines = text.split("\n")
|
||||||
for i, line in enumerate(lines):
|
# for i, line in enumerate(lines):
|
||||||
lines[i] = lines[i].replace(" ", " ")
|
# lines[i] = lines[i].replace(" ", " ")
|
||||||
text = "</br>".join(lines)
|
text = "</br>".join(lines)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
@ -373,6 +368,7 @@ def format_io(self, y):
|
|||||||
gpt_reply = close_up_code_segment_during_stream(gpt_reply) # 当代码输出半截的时候,试着补上后个```
|
gpt_reply = close_up_code_segment_during_stream(gpt_reply) # 当代码输出半截的时候,试着补上后个```
|
||||||
y[-1] = (
|
y[-1] = (
|
||||||
None if i_ask is None else markdown.markdown(i_ask, extensions=['fenced_code', 'tables']),
|
None if i_ask is None else markdown.markdown(i_ask, extensions=['fenced_code', 'tables']),
|
||||||
|
#None if i_ask is None else markdown_convertion(i_ask),
|
||||||
None if gpt_reply is None else markdown_convertion(gpt_reply)
|
None if gpt_reply is None else markdown_convertion(gpt_reply)
|
||||||
)
|
)
|
||||||
return y
|
return y
|
||||||
|
|||||||
Reference in New Issue
Block a user