From 9b2bc97b6710fc958e0b75d8f9cd4c9b77a72468 Mon Sep 17 00:00:00 2001 From: w_xiaolizu Date: Mon, 22 May 2023 16:51:46 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4=E5=A4=9A=E4=BD=99=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=EF=BD=9C=E5=A2=9E=E5=8A=A0=E4=BB=A3=E7=A0=81=E6=B3=A8?= =?UTF-8?q?=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __main__.py | 2 +- func_box.py | 219 +++++++++++++++++++++++++++++++++++++--------------- toolbox.py | 6 ++ 3 files changed, 165 insertions(+), 62 deletions(-) diff --git a/__main__.py b/__main__.py index eba428e..ca9b8a1 100644 --- a/__main__.py +++ b/__main__.py @@ -123,7 +123,7 @@ class ChatBot(ChatBotFrame): inputs=[self.pro_prompt_list, self.pro_prompt_state, self.pro_results], outputs=[self.pro_results]) self.pro_new_btn.click(fn=func_box.prompt_save, - inputs=[self.pro_edit_txt, self.pro_name_txt, self.pro_private_check, self.pro_fp_state], + inputs=[self.pro_edit_txt, self.pro_name_txt, self.pro_fp_state], outputs=[self.pro_edit_txt, self.pro_name_txt, self.pro_private_check, self.pro_func_prompt, self.pro_fp_state]) diff --git a/func_box.py b/func_box.py index 2d01871..02fb412 100644 --- a/func_box.py +++ b/func_box.py @@ -17,6 +17,7 @@ from contextlib import ExitStack import logging import yaml import requests + logger = logging from sklearn.feature_extraction.text import CountVectorizer import numpy as np @@ -26,14 +27,16 @@ import random import gradio as gr import toolbox from prompt_generator import SqliteHandle + """contextlib 是 Python 标准库中的一个模块,提供了一些工具函数和装饰器,用于支持编写上下文管理器和处理上下文的常见任务,例如资源管理、异常处理等。 官网:https://docs.python.org/3/library/contextlib.html""" + class Shell(object): def __init__(self, args, stream=False): self.args = args self.subp = subprocess.Popen(args, shell=True, - stdin=subprocess.PIPE, stderr=subprocess.PIPE, + stdin=subprocess.PIPE, stderr=subprocess.PIPE, stdout=subprocess.PIPE, encoding='utf-8', errors='ignore', close_fds=True) self.__stream = stream @@ -49,9 +52,9 @@ class Shell(object): logger.info(i.rstrip()) self.__temp += i except KeyboardInterrupt as p: - return 3, self.__temp+self.subp.stderr.read() + return 3, self.__temp + self.subp.stderr.read() finally: - return 3, self.__temp+self.subp.stderr.read() + return 3, self.__temp + self.subp.stderr.read() else: sysout = self.subp.stdout.read() syserr = self.subp.stderr.read() @@ -77,7 +80,12 @@ class Shell(object): self.__temp += i yield self.__temp + def timeStatistics(func): + """ + 统计函数执行时常的装饰器 + """ + def statistics(*args, **kwargs): startTiem = time.time() obj = func(*args, **kwargs) @@ -85,8 +93,10 @@ def timeStatistics(func): ums = startTiem - endTiem print('func:{} > Time-consuming: {}'.format(func, ums)) return obj + return statistics + def context_with(*parms): """ 一个装饰器,根据传递的参数列表,在类方法上下文中嵌套多个 with 语句。 @@ -95,6 +105,7 @@ def context_with(*parms): Returns: 一个装饰器函数。 """ + def decorator(cls_method): """ 装饰器函数,用于将一个类方法转换为一个嵌套多个 with 语句的方法。 @@ -103,6 +114,7 @@ def context_with(*parms): Returns: 装饰后的类方法。 """ + def wrapper(cls='', *args, **kwargs): """ 装饰后的方法,用于嵌套多个 with 语句,并调用原始的类方法。 @@ -118,7 +130,9 @@ def context_with(*parms): for context in with_list: stack.enter_context(context) return cls_method(cls, *args, **kwargs) + return wrapper + return decorator @@ -130,6 +144,7 @@ def copy_temp_file(file): else: return None + def md5_str(st): # 创建一个 MD5 对象 md5 = hashlib.md5() @@ -141,18 +156,24 @@ def md5_str(st): def html_tag_color(tag, color=None): + """ + 将文本转换为带有高亮提示的html代码 + """ if not color: rgb = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) color = f"rgb{rgb}" tag = f' {tag} ' return tag -def ipaddr(): # 获取本地ipx + +def ipaddr(): + # 获取本地ipx ip = psutil.net_if_addrs() for i in ip: if ip[i][0][3]: return ip[i][0][1] + def encryption_str(txt: str): """(关键字)(加密间隔)匹配机制(关键字间隔)""" txt = str(txt) @@ -160,7 +181,11 @@ def encryption_str(txt: str): result = pattern.sub(lambda x: x.group(1) + ": XXXXXXXX", txt) return result + def tree_out(dir=os.path.dirname(__file__), line=2, more=''): + """ + 获取本地文件的树形结构转化为Markdown代码文本 + """ out = Shell(f'tree {dir} -F -I "__*|.*|venv|*.png|*.xlsx" -L {line} {more}').read()[1] localfile = os.path.join(os.path.dirname(__file__), '.tree.md') with open(localfile, 'w') as f: @@ -168,13 +193,16 @@ def tree_out(dir=os.path.dirname(__file__), line=2, more=''): ll = out.splitlines() for i in range(len(ll)): if i == 0: - f.write(ll[i].split('/')[-2]+'\n') + f.write(ll[i].split('/')[-2] + '\n') else: - f.write(ll[i]+'\n') + f.write(ll[i] + '\n') f.write('```\n') def chat_history(log: list, split=0): + """ + auto_gpt 使用的代码,后续会迁移 + """ if split: log = log[split:] chat = '' @@ -189,6 +217,7 @@ def df_similarity(s1, s2): """弃用,会警告,这个库不会用""" def add_space(s): return ' '.join(list(s)) + # 将字中间加入空格 s1, s2 = add_space(s1), add_space(s2) # 转化为TF矩阵 @@ -200,6 +229,9 @@ def df_similarity(s1, s2): def check_json_format(file): + """ + 检查上传的Json文件是否符合规范 + """ new_dict = {} data = JsonHandle(file).load() if type(data) is list and len(data) > 0: @@ -208,21 +240,49 @@ def check_json_format(file): new_dict.update({i['act']: i['prompt']}) return new_dict -def json_convert_dict(): + +def json_convert_dict(file): + """ + 批量将json转换为字典 + """ new_dict = {} - for root, dirs, files in os.walk(prompt_path): + for root, dirs, files in os.walk(file): for f in files: if f.startswith('prompt') and f.endswith('json'): new_dict.update(check_json_format(f)) return new_dict + def draw_results(txt, prompt: gr.Dataset, percent, switch, ipaddr: gr.Request): + """ + 绘制搜索结果 + Args: + txt (str): 过滤文本 + prompt : 原始的dataset对象 + percent (int): TF系数,用于计算文本相似度 + switch (list): 过滤个人或所有人的Prompt + ipaddr : 请求人信息 + Returns: + 注册函数所需的元祖对象 + """ data = diff_list(txt, percent=percent, switch=switch, hosts=ipaddr.client.host) prompt.samples = data return prompt.update(samples=data, visible=True), prompt def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15, hosts=''): + """ + 按照搜索结果统计相似度的文本,两组文本相似度>70%的将统计在一起,取最长的作为key + Args: + txt (str): 过滤文本 + percent (int): TF系数,用于计算文本相似度 + switch (list): 过滤个人或所有人的Prompt + lst:指定一个列表或字典 + sp: 截取展示的文本长度 + hosts : 请求人的ip + Returns: + 返回一个列表 + """ import difflib count_dict = {} if not lst: @@ -251,37 +311,41 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15 index = key[0].find(txt) if index != -1: # sp=split 用于判断在哪里启动、在哪里断开 - if index-sp > 0: start = index-sp - else: start = 0 - if len(key[0]) > sp * 2: end = key[0][-sp:] - else: end = '' + if index - sp > 0: + start = index - sp + else: + start = 0 + if len(key[0]) > sp * 2: + end = key[0][-sp:] + else: + end = '' # 判断有没有传需要匹配的字符串,有则筛选、无则全返 - if txt == '' and len(key[0]) >= sp: show = key[0][0:sp] + " . . . " + end - elif txt == '' and len(key[0]) < sp: show = key[0][0:sp] - else: show = str(key[0][start:index + sp]).replace(txt, html_tag_color(txt)) + if txt == '' and len(key[0]) >= sp: + show = key[0][0:sp] + " . . . " + end + elif txt == '' and len(key[0]) < sp: + show = key[0][0:sp] + else: + show = str(key[0][start:index + sp]).replace(txt, html_tag_color(txt)) show += f" {html_tag_color(' X ' + str(key[1]))}" - if lst.get(key[0]): be_value = lst[key[0]] - else: be_value = "这个prompt还没有对话过呢,快去试试吧~" + if lst.get(key[0]): + be_value = lst[key[0]] + else: + be_value = "这个prompt还没有对话过呢,快去试试吧~" value = be_value dateset_list.append([show, key[0], value]) return dateset_list -def search_list(txt, sp=15): - lst = SqliteHandle('ai_common').get_prompt_value() - dateset_list = [] - for key in lst: - index = key.find(txt) - if index != -1: - if index-sp > 0: start = index-sp - else: start = 0 - show = str(key[start:index+sp]).replace(txt, html_tag_color(txt)) - value = lst[key] - dateset_list.append([show, key, value]) - return dateset_list - - def prompt_upload_refresh(file, prompt, ipaddr: gr.Request): + """ + 上传文件,将文件转换为字典,然后存储到数据库,并刷新Prompt区域 + Args: + file: 上传的文件 + prompt: 原始prompt对象 + ipaddr:ipaddr用户请求信息 + Returns: + 注册函数所需的元祖对象 + """ hosts = ipaddr.client.host if file.name.endswith('json'): upload_data = check_json_format(file.name) @@ -299,6 +363,15 @@ def prompt_upload_refresh(file, prompt, ipaddr: gr.Request): def prompt_retrieval(is_all, hosts='', search=False): + """ + 上传文件,将文件转换为字典,然后存储到数据库,并刷新Prompt区域 + Args: + is_all: prompt类型 + hosts: 查询的用户ip + search:支持搜索,搜索时将key作为key + Returns: + 返回一个列表 + """ count_dict = {} user_path = os.path.join(prompt_path, f'prompt_{hosts}.yaml') if '所有人' in is_all: @@ -307,12 +380,12 @@ def prompt_retrieval(is_all, hosts='', search=False): data = SqliteHandle(tab).get_prompt_value() if data: count_dict.update(data) elif '个人' in is_all: - data = SqliteHandle(f'prompt_{hosts}').get_prompt_value() - if data: count_dict.update(data) + data = SqliteHandle(f'prompt_{hosts}').get_prompt_value() + if data: count_dict.update(data) retrieval = [] if count_dict != {}: for key in count_dict: - if not search: + if not search: retrieval.append([key, count_dict[key]]) else: retrieval.append([count_dict[key], key]) @@ -321,17 +394,36 @@ def prompt_retrieval(is_all, hosts='', search=False): return retrieval -def prompt_reduce(is_all, prompt: gr.Dataset, ipaddr: gr.Request): # is_all, ipaddr: gr.Request +def prompt_reduce(is_all, prompt: gr.Dataset, ipaddr: gr.Request): # is_all, ipaddr: gr.Request + """ + 上传文件,将文件转换为字典,然后存储到数据库,并刷新Prompt区域 + Args: + is_all: prompt类型 + prompt: dataset原始对象 + ipaddr:请求用户信息 + Returns: + 返回注册函数所需的对象 + """ data = prompt_retrieval(is_all=is_all, hosts=ipaddr.client.host) prompt.samples = data return prompt.update(samples=data, visible=True), prompt, is_all -def prompt_save(txt, name, checkbox, prompt: gr.Dataset, ipaddr: gr.Request): +def prompt_save(txt, name, prompt: gr.Dataset, ipaddr: gr.Request): + """ + 编辑和保存Prompt + Args: + txt: Prompt正文 + name: Prompt的名字 + prompt: dataset原始对象 + ipaddr:请求用户信息 + Returns: + 返回注册函数所需的对象 + """ if txt and name: yaml_obj = SqliteHandle(f'prompt_{ipaddr.client.host}') yaml_obj.inset_prompt({name: txt}) - result = prompt_retrieval(is_all=checkbox+['个人'], hosts=ipaddr.client.host) + result = prompt_retrieval(is_all=['个人'], hosts=ipaddr.client.host) prompt.samples = result return "", "", ['个人'], prompt.update(samples=result, visible=True), prompt elif not txt or not name: @@ -339,15 +431,27 @@ def prompt_save(txt, name, checkbox, prompt: gr.Dataset, ipaddr: gr.Request): prompt.samples = [[f'{html_tag_color("编辑框 or 名称不能为空!!!!!", color="red")}', '']] return txt, name, [], prompt.update(samples=result, visible=True), prompt + def prompt_input(txt, index, data: gr.Dataset): + """ + 点击dataset的值使用Prompt + Args: + txt: 输入框正文 + index: 点击的Dataset下标 + data: dataset原始对象 + Returns: + 返回注册函数所需的对象 + """ data_str = str(data.samples[index][1]) if txt: - txt = data_str+'\n'+txt - else: + txt = data_str + '\n' + txt + else: txt = data_str return txt + def copy_result(history): + """复制history""" if history != []: pyperclip.copy(history[-1]) return '已将结果复制到剪切板' @@ -356,12 +460,24 @@ def copy_result(history): def show_prompt_result(index, data: gr.Dataset, chatbot): + """ + 查看Prompt的对话记录结果 + Args: + index: 点击的Dataset下标 + data: dataset原始对象 + chatbot:聊天机器人 + Returns: + 返回注册函数所需的对象 + """ click = data.samples[index] chatbot.append((click[1], click[2])) return chatbot def thread_write_chat(chatbot): + """ + 对话记录写入数据库 + """ private_key = toolbox.get_conf('private_key')[0] chat_title = chatbot[0][0].split() i_say = chatbot[-1][0].strip("

/p") @@ -375,6 +491,7 @@ def thread_write_chat(chatbot): base_path = os.path.dirname(__file__) prompt_path = os.path.join(base_path, 'prompt_users') + class YamlHandle: def __init__(self, file=os.path.join(prompt_path, 'ai_common.yaml')): @@ -383,7 +500,6 @@ class YamlHandle: self.file = file self._load = self.load() - def load(self) -> dict: with open(file=self.file, mode='r') as f: data = yaml.safe_load(f) @@ -407,6 +523,7 @@ class YamlHandle: yaml.dump(date, f, allow_unicode=True) return date + class JsonHandle: def __init__(self, file=os.path.join(prompt_path, 'prompts-PlexPt.json')): @@ -419,27 +536,7 @@ class JsonHandle: data = json.load(f) return data -class FileHandle: - - def __init__(self, file=None): - self.file = file - - def read(self): - with open(file=self.file, mode='r') as f: - print(f.read()) - - - def read_link(self): - link = 'https://github.com/PlexPt/awesome-chatgpt-prompts-zh/blob/main/prompts-zh.json' - name = link.split('/')[3]+link.split('/')[-1] - new_file = os.path.join(base_path, 'gpt_log', name) - response = requests.get(url=link, verify=False) - with open(new_file, "wb") as f: - f.write(response.content) if __name__ == '__main__': - for i in YamlHandle().load(): - print(i) - - + pass \ No newline at end of file diff --git a/toolbox.py b/toolbox.py index 5181b42..4019544 100644 --- a/toolbox.py +++ b/toolbox.py @@ -458,6 +458,9 @@ def find_recent_files(directory): def get_user_upload(chatbot, ipaddr: gr.Request): + """ + 获取用户上传过的文件 + """ private_upload = './private_upload' user_history = os.path.join(private_upload, ipaddr.client.host) history = '' @@ -471,6 +474,9 @@ def get_user_upload(chatbot, ipaddr: gr.Request): def get_user_download(chatbot, link, file): + """ + 将短路径转换为下载链接 + """ for file_handle in str(file).split('\n'): if os.path.isfile(file_handle): # temp_file = func_box.copy_temp_file(file_handle) 无法使用外部的临时目录