diff --git a/auto_functional.py b/auto_functional.py deleted file mode 100644 index c048d76..0000000 --- a/auto_functional.py +++ /dev/null @@ -1,15 +0,0 @@ -#! .\venv\ -# encoding: utf-8 -# @Time : 2023/4/20 -# @Author : Spike -# @Descr : - - - - - -def chat_with_ai(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt): - - history = [] - - pass \ No newline at end of file diff --git a/func_box.py b/func_box.py index f763ced..a3a3edb 100644 --- a/func_box.py +++ b/func_box.py @@ -7,6 +7,8 @@ import hashlib import json import os.path import subprocess +import threading +import time import psutil import re import tempfile @@ -22,6 +24,8 @@ from scipy.linalg import norm import pyperclip import random import gradio as gr +import toolbox +from prompt_generator import SqliteHandle """contextlib 是 Python 标准库中的一个模块,提供了一些工具函数和装饰器,用于支持编写上下文管理器和处理上下文的常见任务,例如资源管理、异常处理等。 官网:https://docs.python.org/3/library/contextlib.html""" @@ -72,6 +76,16 @@ class Shell(object): self.__temp += i yield self.__temp +def timeStatistics(func): + def statistics(*args, **kwargs): + startTiem = time.time() + obj = func(*args, **kwargs) + endTiem = time.time() + ums = startTiem - endTiem + print('func:{} > Time-consuming: {}'.format(func, ums)) + return obj + return statistics + def context_with(*parms): """ 一个装饰器,根据传递的参数列表,在类方法上下文中嵌套多个 with 语句。 @@ -211,8 +225,8 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15 import difflib count_dict = {} if not lst: - lst = YamlHandle().load() - lst.update(YamlHandle(os.path.join(prompt_path, f"ai_private_{hosts}.yaml")).load()) + lst = SqliteHandle('ai_common').get_prompt_value() + lst.update(SqliteHandle(f"ai_private_{hosts}").get_prompt_value()) # diff 数据,根据precent系数归类数据 for i in lst: found = False @@ -253,7 +267,7 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15 def search_list(txt, sp=15): - lst = YamlHandle().load() + lst = SqliteHandle('ai_common').get_prompt_value() dateset_list = [] for key in lst: index = key.find(txt) @@ -268,37 +282,31 @@ def search_list(txt, sp=15): def prompt_upload_refresh(file, prompt, ipaddr: gr.Request): hosts = ipaddr.client.host - user_file = os.path.join(prompt_path, f'prompt_{hosts}.yaml') if file.name.endswith('json'): upload_data = check_json_format(file.name) - if upload_data != {}: - YamlHandle(user_file).dump_dict(upload_data) - ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts) - return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人'] - else: - prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']] - return prompt.samples, prompt, [] elif file.name.endswith('yaml'): upload_data = YamlHandle(file.name).load() - if upload_data != {} and type(upload_data) is dict: - YamlHandle(user_file).dump_dict(upload_data) - ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts) - return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人'] - else: - prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']] - return prompt.samples, prompt, [] + else: + upload_data = {} + if upload_data != {}: + SqliteHandle(f'prompt_{hosts}').inset_prompt(upload_data) + ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts) + return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人'] + else: + prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']] + return prompt.samples, prompt, [] + def prompt_retrieval(is_all, hosts='', search=False): count_dict = {} user_path = os.path.join(prompt_path, f'prompt_{hosts}.yaml') if '所有人' in is_all: - for root, dirs, files in os.walk(prompt_path): - for f in files: - if f.startswith('prompt') and f.endswith('yaml'): - data = YamlHandle(file=os.path.join(root, f)).load() - if data: count_dict.update(data) + for tab in SqliteHandle('ai_common').get_tables(): + if tab.startswith('prompt'): + data = SqliteHandle(tab).get_prompt_value() + if data: count_dict.update(data) elif '个人' in is_all: - data = YamlHandle(file=user_path).load() + data = SqliteHandle(f'prompt_{hosts}').get_prompt_value() if data: count_dict.update(data) retrieval = [] if count_dict != {}: @@ -320,8 +328,8 @@ def prompt_reduce(is_all, prompt: gr.Dataset, ipaddr: gr.Request): # is_all, ipa def prompt_save(txt, name, checkbox, prompt: gr.Dataset, ipaddr: gr.Request): if txt and name: - yaml_obj = YamlHandle(os.path.join(prompt_path, f'prompt_{ipaddr.client.host}.yaml')) - yaml_obj.update(name, txt) + yaml_obj = SqliteHandle(f'prompt_{ipaddr.client.host}') + yaml_obj.inset_prompt({name: txt}) result = prompt_retrieval(is_all=checkbox, hosts=ipaddr.client.host) prompt.samples = result return "", "", ['个人'], prompt.update(samples=result, samples_per_page=10, visible=True), prompt @@ -351,6 +359,16 @@ def show_prompt_result(index, data: gr.Dataset, chatbot): chatbot.append((click[1], click[2])) return chatbot + +def thread_write_chat(chatbot): + private_key = toolbox.get_conf('private_key')[0] + chat_title = chatbot[0][0].split() + if private_key in chat_title: + SqliteHandle(f'ai_private_{chat_title[-2]}').inset_prompt({chatbot[-1][0]: chatbot[-1][1]}) + else: + SqliteHandle(f'ai_common').inset_prompt({chatbot[-1][0]: chatbot[-1][1]}) + + base_path = os.path.dirname(__file__) prompt_path = os.path.join(base_path, 'prompt_users') @@ -360,6 +378,7 @@ class YamlHandle: if not os.path.exists(file): Shell(f'touch {file}').read() self.file = file + self._load = self.load() def load(self) -> dict: @@ -368,7 +387,7 @@ class YamlHandle: return data def update(self, key, value): - date = self.load() + date = self._load if not date: date = {} date[key] = value @@ -377,7 +396,7 @@ class YamlHandle: return date def dump_dict(self, new_dict): - date = self.load() + date = self._load if not date: date = {} date.update(new_dict) @@ -417,6 +436,7 @@ class FileHandle: if __name__ == '__main__': - print(YamlHandle().load()) + for i in YamlHandle().load(): + print(i) diff --git a/prompt_generator.py b/prompt_generator.py index 7924c9a..f04dccd 100644 --- a/prompt_generator.py +++ b/prompt_generator.py @@ -3,10 +3,73 @@ # @Time : 2023/4/19 # @Author : Spike # @Descr : - -# 默认的prompt +import os.path +import sqlite3 +import threading +import functools +import func_box +# 连接到数据库 +base_path = os.path.dirname(__file__) +prompt_path = os.path.join(base_path, 'prompt_users') +def connect_db_close(cls_method): + @functools.wraps(cls_method) + def wrapper(cls=None, *args, **kwargs): + cls._connect_db() + result = cls_method(cls, *args, **kwargs) + cls._close_db() + return result + return wrapper +class SqliteHandle: + def __init__(self, table='ai_common'): + self.__connect = sqlite3.connect(os.path.join(prompt_path, 'ai_prompt.db')) + self.__cursor = self.__connect.cursor() + self.__table = table + if self.__table not in self.get_tables(): + self.create_tab() + def new_connect_db(self): + """多线程操作时,每个线程新建独立的connect""" + self.__connect = sqlite3.connect(os.path.join(prompt_path, 'ai_prompt.db')) + self.__cursor = self.__connect.cursor() + + def new_close_db(self): + self.__cursor.close() + self.__connect.close() + + def create_tab(self): + self.__cursor.execute(f"CREATE TABLE `{self.__table}` ('id' INTEGER PRIMARY KEY AUTOINCREMENT, 'prompt' TEXT, 'result' TEXT)") + + def get_tables(self): + all_tab = [] + result = self.__cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';") + for tab in result: + all_tab.append(tab[0]) + return all_tab + + def get_prompt_value(self): + temp_all = {} + result = self.__cursor.execute(f"SELECT prompt, result FROM `{self.__table}`").fetchall() + for row in result: + temp_all[row[0]] = row[1] + return temp_all + + def inset_prompt(self, prompt: dict): + for key in prompt: + self.__cursor.execute(f"INSERT INTO `{self.__table}` (prompt, result) VALUES (?, ?);", (str(key), str(prompt[key]))) + self.__connect.commit() + + def delete_prompt(self): + self.__cursor.execute(f"DELETE from `{self.__table}` where id BETWEEN 1 AND 21") + self.__connect.commit() + +sqlite_handle = SqliteHandle +if __name__ == '__main__': + + + # print(sqlite_handle('ai_common').inset_prompt(test)) + # sqlite_handle('ai_common').delete_prompt() + print(sqlite_handle('ai_common').get_prompt_value()) diff --git a/toolbox.py b/toolbox.py index 30b1a8d..30580e3 100644 --- a/toolbox.py +++ b/toolbox.py @@ -13,6 +13,7 @@ import shutil import os import time import glob +from concurrent.futures import ThreadPoolExecutor ############################### 插件输入输出接驳区 ####################################### """ @@ -91,20 +92,14 @@ def ArgsGeneralWrapper(f): return decorated +pool = ThreadPoolExecutor(200) def update_ui(chatbot, history, msg='正常', txt='', *args): # 刷新界面 """ 刷新用户界面 """ - private_key = get_conf('private_key')[0] - chat_title = chatbot[0][0].split() - if private_key in chat_title: - private_path = os.path.join(func_box.prompt_path, f"ai_private_{chat_title[-2]}.yaml") - func_box.YamlHandle(private_path).update(key=chatbot[-1][0], value=chatbot[-1][1]) - else: - func_box.YamlHandle().update(key=chatbot[-1][0], value=chatbot[-1][1]) - assert isinstance(chatbot, ChatBotWithCookies), "在传递chatbot的过程中不要将其丢弃。必要时,可用clear将其清空,然后用for+append循环重新赋值。" yield chatbot.get_cookies(), chatbot, history, msg, txt + pool.submit(func_box.thread_write_chat, chatbot) def trimmed_format_exc(): import os, traceback @@ -254,8 +249,8 @@ def text_divide_paragraph(text): else: # wtf input lines = text.split("\n") - for i, line in enumerate(lines): - lines[i] = lines[i].replace(" ", " ") + # for i, line in enumerate(lines): + # lines[i] = lines[i].replace(" ", " ") text = "
".join(lines) return text @@ -373,6 +368,7 @@ def format_io(self, y): gpt_reply = close_up_code_segment_during_stream(gpt_reply) # 当代码输出半截的时候,试着补上后个``` y[-1] = ( None if i_ask is None else markdown.markdown(i_ask, extensions=['fenced_code', 'tables']), + #None if i_ask is None else markdown_convertion(i_ask), None if gpt_reply is None else markdown_convertion(gpt_reply) ) return y