删除多余代码|增加代码注释

This commit is contained in:
w_xiaolizu
2023-05-22 16:51:46 +08:00
parent 331ab3be0d
commit 9b2bc97b67
3 changed files with 165 additions and 62 deletions

View File

@ -123,7 +123,7 @@ class ChatBot(ChatBotFrame):
inputs=[self.pro_prompt_list, self.pro_prompt_state, self.pro_results],
outputs=[self.pro_results])
self.pro_new_btn.click(fn=func_box.prompt_save,
inputs=[self.pro_edit_txt, self.pro_name_txt, self.pro_private_check, self.pro_fp_state],
inputs=[self.pro_edit_txt, self.pro_name_txt, self.pro_fp_state],
outputs=[self.pro_edit_txt, self.pro_name_txt, self.pro_private_check,
self.pro_func_prompt, self.pro_fp_state])

View File

@ -17,6 +17,7 @@ from contextlib import ExitStack
import logging
import yaml
import requests
logger = logging
from sklearn.feature_extraction.text import CountVectorizer
import numpy as np
@ -26,14 +27,16 @@ import random
import gradio as gr
import toolbox
from prompt_generator import SqliteHandle
"""contextlib 是 Python 标准库中的一个模块,提供了一些工具函数和装饰器,用于支持编写上下文管理器和处理上下文的常见任务,例如资源管理、异常处理等。
官网https://docs.python.org/3/library/contextlib.html"""
class Shell(object):
def __init__(self, args, stream=False):
self.args = args
self.subp = subprocess.Popen(args, shell=True,
stdin=subprocess.PIPE, stderr=subprocess.PIPE,
stdin=subprocess.PIPE, stderr=subprocess.PIPE,
stdout=subprocess.PIPE, encoding='utf-8',
errors='ignore', close_fds=True)
self.__stream = stream
@ -49,9 +52,9 @@ class Shell(object):
logger.info(i.rstrip())
self.__temp += i
except KeyboardInterrupt as p:
return 3, self.__temp+self.subp.stderr.read()
return 3, self.__temp + self.subp.stderr.read()
finally:
return 3, self.__temp+self.subp.stderr.read()
return 3, self.__temp + self.subp.stderr.read()
else:
sysout = self.subp.stdout.read()
syserr = self.subp.stderr.read()
@ -77,7 +80,12 @@ class Shell(object):
self.__temp += i
yield self.__temp
def timeStatistics(func):
"""
Decorator that measures how long the wrapped function takes to run.
"""
def statistics(*args, **kwargs):
startTiem = time.time()
obj = func(*args, **kwargs)
@ -85,8 +93,10 @@ def timeStatistics(func):
ums = startTiem - endTiem
print('func:{} > Time-consuming: {}'.format(func, ums))
return obj
return statistics
def context_with(*parms):
"""
一个装饰器,根据传递的参数列表,在类方法上下文中嵌套多个 with 语句。
@ -95,6 +105,7 @@ def context_with(*parms):
Returns:
一个装饰器函数。
"""
def decorator(cls_method):
"""
装饰器函数,用于将一个类方法转换为一个嵌套多个 with 语句的方法。
@ -103,6 +114,7 @@ def context_with(*parms):
Returns:
装饰后的类方法。
"""
def wrapper(cls='', *args, **kwargs):
"""
装饰后的方法,用于嵌套多个 with 语句,并调用原始的类方法。
@ -118,7 +130,9 @@ def context_with(*parms):
for context in with_list:
stack.enter_context(context)
return cls_method(cls, *args, **kwargs)
return wrapper
return decorator
@ -130,6 +144,7 @@ def copy_temp_file(file):
else:
return None
def md5_str(st):
# 创建一个 MD5 对象
md5 = hashlib.md5()
@ -141,18 +156,24 @@ def md5_str(st):
def html_tag_color(tag, color=None):
    """Wrap ``tag`` in a bold, black-on-colour HTML ``<span>`` highlight.

    Args:
        tag: Text to highlight; it is inserted into the markup as-is
            (no HTML escaping is applied).
        color: CSS colour used for the background. When falsy, a random
            ``rgb(r, g, b)`` colour is generated instead.

    Returns:
        The HTML snippet as a string.
    """
    if not color:
        channels = tuple(random.randint(0, 255) for _ in range(3))
        # Formatting a tuple renders "(r, g, b)", giving "rgb(r, g, b)".
        color = f"rgb{channels}"
    return f'<span style="background-color: {color}; font-weight: bold; color: black">&nbsp;{tag}&ensp;</span>'
def ipaddr():
    """Return a local IP address taken from the network interfaces.

    Walks the interfaces reported by ``psutil.net_if_addrs()`` and, for the
    first interface whose first address entry has a truthy field at index 3,
    returns that entry's field at index 1.
    NOTE(review): in psutil's address tuples index 1 is presumably the
    address and index 3 the broadcast address; confirm against psutil docs.

    Returns:
        The selected address, or ``None`` (implicitly) when no interface
        matches.
    """
    for entries in psutil.net_if_addrs().values():
        first = entries[0]
        if first[3]:
            return first[1]
def encryption_str(txt: str):
"""(关键字)(加密间隔)匹配机制(关键字间隔)"""
txt = str(txt)
@ -160,7 +181,11 @@ def encryption_str(txt: str):
result = pattern.sub(lambda x: x.group(1) + ": XXXXXXXX", txt)
return result
def tree_out(dir=os.path.dirname(__file__), line=2, more=''):
"""
获取本地文件的树形结构转化为Markdown代码文本
"""
out = Shell(f'tree {dir} -F -I "__*|.*|venv|*.png|*.xlsx" -L {line} {more}').read()[1]
localfile = os.path.join(os.path.dirname(__file__), '.tree.md')
with open(localfile, 'w') as f:
@ -168,13 +193,16 @@ def tree_out(dir=os.path.dirname(__file__), line=2, more=''):
ll = out.splitlines()
for i in range(len(ll)):
if i == 0:
f.write(ll[i].split('/')[-2]+'\n')
f.write(ll[i].split('/')[-2] + '\n')
else:
f.write(ll[i]+'\n')
f.write(ll[i] + '\n')
f.write('```\n')
def chat_history(log: list, split=0):
"""
auto_gpt 使用的代码,后续会迁移
"""
if split:
log = log[split:]
chat = ''
@ -189,6 +217,7 @@ def df_similarity(s1, s2):
"""弃用,会警告,这个库不会用"""
def add_space(s):
return ' '.join(list(s))
# 将字中间加入空格
s1, s2 = add_space(s1), add_space(s2)
# 转化为TF矩阵
@ -200,6 +229,9 @@ def df_similarity(s1, s2):
def check_json_format(file):
"""
检查上传的Json文件是否符合规范
"""
new_dict = {}
data = JsonHandle(file).load()
if type(data) is list and len(data) > 0:
@ -208,21 +240,49 @@ def check_json_format(file):
new_dict.update({i['act']: i['prompt']})
return new_dict
def json_convert_dict(file=None):
    """Merge every ``prompt*json`` file found under a directory into one dict.

    Args:
        file: Directory to walk recursively. When omitted, the module-level
            ``prompt_path`` is used, so argument-less callers keep working.

    Returns:
        A dict built by updating it with the result of
        ``check_json_format`` for each matching file; later files override
        earlier ones on duplicate keys. Empty when nothing matches.
    """
    if file is None:
        file = prompt_path
    new_dict = {}
    for root, _dirs, files in os.walk(file):
        for f in files:
            if f.startswith('prompt') and f.endswith('json'):
                # os.walk yields bare file names; join with the directory so
                # the file is opened where it lives, not relative to the CWD.
                new_dict.update(check_json_format(os.path.join(root, f)))
    return new_dict
def draw_results(txt, prompt: gr.Dataset, percent, switch, ipaddr: gr.Request):
    """Refresh the prompt dataset with the search results for ``txt``.

    Args:
        txt: Filter text, forwarded to ``diff_list``.
        prompt: The dataset object whose samples are replaced.
        percent: Similarity coefficient, forwarded to ``diff_list``.
        switch: Prompt-scope selection, forwarded to ``diff_list``.
        ipaddr: Request object; ``ipaddr.client.host`` is passed as the
            requesting host.

    Returns:
        A tuple ``(prompt.update(...), prompt)`` for the registered outputs.
    """
    requester = ipaddr.client.host
    matches = diff_list(txt, percent=percent, switch=switch, hosts=requester)
    # Keep the in-memory dataset in sync before building the UI update.
    prompt.samples = matches
    refreshed = prompt.update(samples=matches, visible=True)
    return refreshed, prompt
def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15, hosts=''):
"""
按照搜索结果统计相似度的文本,两组文本相似度>70%的将统计在一起取最长的作为key
Args:
txt (str): 过滤文本
percent (int): TF系数用于计算文本相似度
switch (list): 过滤个人或所有人的Prompt
lst指定一个列表或字典
sp: 截取展示的文本长度
hosts : 请求人的ip
Returns:
返回一个列表
"""
import difflib
count_dict = {}
if not lst:
@ -251,37 +311,41 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15
index = key[0].find(txt)
if index != -1:
# sp=split 用于判断在哪里启动、在哪里断开
if index-sp > 0: start = index-sp
else: start = 0
if len(key[0]) > sp * 2: end = key[0][-sp:]
else: end = ''
if index - sp > 0:
start = index - sp
else:
start = 0
if len(key[0]) > sp * 2:
end = key[0][-sp:]
else:
end = ''
# 判断有没有传需要匹配的字符串,有则筛选、无则全返
if txt == '' and len(key[0]) >= sp: show = key[0][0:sp] + " . . . " + end
elif txt == '' and len(key[0]) < sp: show = key[0][0:sp]
else: show = str(key[0][start:index + sp]).replace(txt, html_tag_color(txt))
if txt == '' and len(key[0]) >= sp:
show = key[0][0:sp] + " . . . " + end
elif txt == '' and len(key[0]) < sp:
show = key[0][0:sp]
else:
show = str(key[0][start:index + sp]).replace(txt, html_tag_color(txt))
show += f" {html_tag_color(' X ' + str(key[1]))}"
if lst.get(key[0]): be_value = lst[key[0]]
else: be_value = "这个prompt还没有对话过呢快去试试吧"
if lst.get(key[0]):
be_value = lst[key[0]]
else:
be_value = "这个prompt还没有对话过呢快去试试吧"
value = be_value
dateset_list.append([show, key[0], value])
return dateset_list
def search_list(txt, sp=15):
    """Find stored 'ai_common' prompts whose key contains ``txt``.

    Args:
        txt: Substring to look for in each key.
        sp: Number of characters kept on each side of the match position
            when building the preview.

    Returns:
        A list of ``[preview, key, value]`` entries, where ``preview`` is a
        slice of the key around the match with ``txt`` replaced by its
        highlighted HTML form.
    """
    lst = SqliteHandle('ai_common').get_prompt_value()
    dateset_list = []
    for key in lst:
        hit = key.find(txt)
        if hit == -1:
            continue
        # Clamp the left edge of the preview window at the string start.
        left = max(hit - sp, 0)
        preview = str(key[left:hit + sp]).replace(txt, html_tag_color(txt))
        dateset_list.append([preview, key, lst[key]])
    return dateset_list
def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
"""
上传文件将文件转换为字典然后存储到数据库并刷新Prompt区域
Args:
file 上传的文件
prompt 原始prompt对象
ipaddripaddr用户请求信息
Returns:
注册函数所需的元祖对象
"""
hosts = ipaddr.client.host
if file.name.endswith('json'):
upload_data = check_json_format(file.name)
@ -299,6 +363,15 @@ def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
def prompt_retrieval(is_all, hosts='', search=False):
"""
Retrieve the stored prompts for the selected scope (`is_all`) from the database and return them as a list of pairs.
Args:
is_all prompt类型
hosts 查询的用户ip
search支持搜索搜索时将key作为key
Returns:
返回一个列表
"""
count_dict = {}
user_path = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
if '所有人' in is_all:
@ -307,12 +380,12 @@ def prompt_retrieval(is_all, hosts='', search=False):
data = SqliteHandle(tab).get_prompt_value()
if data: count_dict.update(data)
elif '个人' in is_all:
data = SqliteHandle(f'prompt_{hosts}').get_prompt_value()
if data: count_dict.update(data)
data = SqliteHandle(f'prompt_{hosts}').get_prompt_value()
if data: count_dict.update(data)
retrieval = []
if count_dict != {}:
for key in count_dict:
if not search:
if not search:
retrieval.append([key, count_dict[key]])
else:
retrieval.append([count_dict[key], key])
@ -321,17 +394,36 @@ def prompt_retrieval(is_all, hosts='', search=False):
return retrieval
def prompt_reduce(is_all, prompt: gr.Dataset, ipaddr: gr.Request):
    """Reload the prompt dataset for the selected prompt scope.

    Args:
        is_all: Prompt-scope selection, forwarded to ``prompt_retrieval``.
        prompt: The dataset object whose samples are replaced.
        ipaddr: Request object; ``ipaddr.client.host`` identifies the user.

    Returns:
        A tuple ``(prompt.update(...), prompt, is_all)`` for the registered
        outputs.
    """
    entries = prompt_retrieval(is_all=is_all, hosts=ipaddr.client.host)
    # Keep the in-memory dataset in sync before building the UI update.
    prompt.samples = entries
    refreshed = prompt.update(samples=entries, visible=True)
    return refreshed, prompt, is_all
def prompt_save(txt, name, checkbox, prompt: gr.Dataset, ipaddr: gr.Request):
def prompt_save(txt, name, prompt: gr.Dataset, ipaddr: gr.Request):
"""
编辑和保存Prompt
Args:
txt Prompt正文
name Prompt的名字
prompt dataset原始对象
ipaddr请求用户信息
Returns:
返回注册函数所需的对象
"""
if txt and name:
yaml_obj = SqliteHandle(f'prompt_{ipaddr.client.host}')
yaml_obj.inset_prompt({name: txt})
result = prompt_retrieval(is_all=checkbox+['个人'], hosts=ipaddr.client.host)
result = prompt_retrieval(is_all=['个人'], hosts=ipaddr.client.host)
prompt.samples = result
return "", "", ['个人'], prompt.update(samples=result, visible=True), prompt
elif not txt or not name:
@ -339,15 +431,27 @@ def prompt_save(txt, name, checkbox, prompt: gr.Dataset, ipaddr: gr.Request):
prompt.samples = [[f'{html_tag_color("编辑框 or 名称不能为空!!!!!", color="red")}', '']]
return txt, name, [], prompt.update(samples=result, visible=True), prompt
def prompt_input(txt, index, data: gr.Dataset):
    """Insert the clicked dataset prompt into the input text.

    Args:
        txt: Current content of the input box.
        index: Index of the clicked sample in ``data.samples``.
        data: Dataset object; the second field of the selected sample is
            used as the prompt text.

    Returns:
        The prompt followed by a newline and ``txt`` when ``txt`` is
        non-empty, otherwise just the prompt.
    """
    selected = str(data.samples[index][1])
    if not txt:
        return selected
    return selected + '\n' + txt
def copy_result(history):
"""复制history"""
if history != []:
pyperclip.copy(history[-1])
return '已将结果复制到剪切板'
@ -356,12 +460,24 @@ def copy_result(history):
def show_prompt_result(index, data: gr.Dataset, chatbot):
    """Append the clicked prompt and its recorded result to the chat.

    Args:
        index: Index of the clicked sample in ``data.samples``.
        data: Dataset object holding the samples.
        chatbot: Chat history list; it is mutated in place.

    Returns:
        The same ``chatbot`` list with one ``(sample[1], sample[2])`` tuple
        appended.
    """
    sample = data.samples[index]
    pair = (sample[1], sample[2])
    chatbot.append(pair)
    return chatbot
def thread_write_chat(chatbot):
"""
对话记录写入数据库
"""
private_key = toolbox.get_conf('private_key')[0]
chat_title = chatbot[0][0].split()
i_say = chatbot[-1][0].strip("<p>/p")
@ -375,6 +491,7 @@ def thread_write_chat(chatbot):
base_path = os.path.dirname(__file__)
prompt_path = os.path.join(base_path, 'prompt_users')
class YamlHandle:
def __init__(self, file=os.path.join(prompt_path, 'ai_common.yaml')):
@ -383,7 +500,6 @@ class YamlHandle:
self.file = file
self._load = self.load()
def load(self) -> dict:
with open(file=self.file, mode='r') as f:
data = yaml.safe_load(f)
@ -407,6 +523,7 @@ class YamlHandle:
yaml.dump(date, f, allow_unicode=True)
return date
class JsonHandle:
def __init__(self, file=os.path.join(prompt_path, 'prompts-PlexPt.json')):
@ -419,27 +536,7 @@ class JsonHandle:
data = json.load(f)
return data
class FileHandle:
    """Helper for printing a local file and downloading a prompt file."""

    def __init__(self, file=None):
        # Path of the file that read() opens.
        self.file = file

    def read(self):
        """Print the whole content of ``self.file`` to stdout."""
        with open(file=self.file, mode='r') as handle:
            print(handle.read())

    def read_link(self):
        """Download the hard-coded prompts URL into ``<base_path>/gpt_log``.

        The local file name is the 4th '/'-separated piece of the URL
        concatenated with its last piece.
        """
        link = 'https://github.com/PlexPt/awesome-chatgpt-prompts-zh/blob/main/prompts-zh.json'
        pieces = link.split('/')
        name = pieces[3] + pieces[-1]
        new_file = os.path.join(base_path, 'gpt_log', name)
        # NOTE(review): certificate verification is disabled (verify=False).
        response = requests.get(url=link, verify=False)
        with open(new_file, "wb") as out:
            out.write(response.content)
if __name__ == '__main__':
for i in YamlHandle().load():
print(i)
pass

View File

@ -458,6 +458,9 @@ def find_recent_files(directory):
def get_user_upload(chatbot, ipaddr: gr.Request):
"""
获取用户上传过的文件
"""
private_upload = './private_upload'
user_history = os.path.join(private_upload, ipaddr.client.host)
history = ''
@ -471,6 +474,9 @@ def get_user_upload(chatbot, ipaddr: gr.Request):
def get_user_download(chatbot, link, file):
"""
将短路径转换为下载链接
"""
for file_handle in str(file).split('\n'):
if os.path.isfile(file_handle):
# temp_file = func_box.copy_temp_file(file_handle) 无法使用外部的临时目录