删除多余代码|增加代码注释
This commit is contained in:
@ -123,7 +123,7 @@ class ChatBot(ChatBotFrame):
|
||||
inputs=[self.pro_prompt_list, self.pro_prompt_state, self.pro_results],
|
||||
outputs=[self.pro_results])
|
||||
self.pro_new_btn.click(fn=func_box.prompt_save,
|
||||
inputs=[self.pro_edit_txt, self.pro_name_txt, self.pro_private_check, self.pro_fp_state],
|
||||
inputs=[self.pro_edit_txt, self.pro_name_txt, self.pro_fp_state],
|
||||
outputs=[self.pro_edit_txt, self.pro_name_txt, self.pro_private_check,
|
||||
self.pro_func_prompt, self.pro_fp_state])
|
||||
|
||||
|
||||
197
func_box.py
197
func_box.py
@ -17,6 +17,7 @@ from contextlib import ExitStack
|
||||
import logging
|
||||
import yaml
|
||||
import requests
|
||||
|
||||
logger = logging
|
||||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
import numpy as np
|
||||
@ -26,9 +27,11 @@ import random
|
||||
import gradio as gr
|
||||
import toolbox
|
||||
from prompt_generator import SqliteHandle
|
||||
|
||||
"""contextlib 是 Python 标准库中的一个模块,提供了一些工具函数和装饰器,用于支持编写上下文管理器和处理上下文的常见任务,例如资源管理、异常处理等。
|
||||
官网:https://docs.python.org/3/library/contextlib.html"""
|
||||
|
||||
|
||||
class Shell(object):
|
||||
def __init__(self, args, stream=False):
|
||||
self.args = args
|
||||
@ -77,7 +80,12 @@ class Shell(object):
|
||||
self.__temp += i
|
||||
yield self.__temp
|
||||
|
||||
|
||||
def timeStatistics(func):
|
||||
"""
|
||||
统计函数执行时常的装饰器
|
||||
"""
|
||||
|
||||
def statistics(*args, **kwargs):
|
||||
startTiem = time.time()
|
||||
obj = func(*args, **kwargs)
|
||||
@ -85,8 +93,10 @@ def timeStatistics(func):
|
||||
ums = startTiem - endTiem
|
||||
print('func:{} > Time-consuming: {}'.format(func, ums))
|
||||
return obj
|
||||
|
||||
return statistics
|
||||
|
||||
|
||||
def context_with(*parms):
|
||||
"""
|
||||
一个装饰器,根据传递的参数列表,在类方法上下文中嵌套多个 with 语句。
|
||||
@ -95,6 +105,7 @@ def context_with(*parms):
|
||||
Returns:
|
||||
一个装饰器函数。
|
||||
"""
|
||||
|
||||
def decorator(cls_method):
|
||||
"""
|
||||
装饰器函数,用于将一个类方法转换为一个嵌套多个 with 语句的方法。
|
||||
@ -103,6 +114,7 @@ def context_with(*parms):
|
||||
Returns:
|
||||
装饰后的类方法。
|
||||
"""
|
||||
|
||||
def wrapper(cls='', *args, **kwargs):
|
||||
"""
|
||||
装饰后的方法,用于嵌套多个 with 语句,并调用原始的类方法。
|
||||
@ -118,7 +130,9 @@ def context_with(*parms):
|
||||
for context in with_list:
|
||||
stack.enter_context(context)
|
||||
return cls_method(cls, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@ -130,6 +144,7 @@ def copy_temp_file(file):
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def md5_str(st):
|
||||
# 创建一个 MD5 对象
|
||||
md5 = hashlib.md5()
|
||||
@ -141,18 +156,24 @@ def md5_str(st):
|
||||
|
||||
|
||||
def html_tag_color(tag, color=None):
|
||||
"""
|
||||
将文本转换为带有高亮提示的html代码
|
||||
"""
|
||||
if not color:
|
||||
rgb = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
|
||||
color = f"rgb{rgb}"
|
||||
tag = f'<span style="background-color: {color}; font-weight: bold; color: black"> {tag} </span>'
|
||||
return tag
|
||||
|
||||
def ipaddr(): # 获取本地ipx
|
||||
|
||||
def ipaddr():
|
||||
# 获取本地ipx
|
||||
ip = psutil.net_if_addrs()
|
||||
for i in ip:
|
||||
if ip[i][0][3]:
|
||||
return ip[i][0][1]
|
||||
|
||||
|
||||
def encryption_str(txt: str):
|
||||
"""(关键字)(加密间隔)匹配机制(关键字间隔)"""
|
||||
txt = str(txt)
|
||||
@ -160,7 +181,11 @@ def encryption_str(txt: str):
|
||||
result = pattern.sub(lambda x: x.group(1) + ": XXXXXXXX", txt)
|
||||
return result
|
||||
|
||||
|
||||
def tree_out(dir=os.path.dirname(__file__), line=2, more=''):
|
||||
"""
|
||||
获取本地文件的树形结构转化为Markdown代码文本
|
||||
"""
|
||||
out = Shell(f'tree {dir} -F -I "__*|.*|venv|*.png|*.xlsx" -L {line} {more}').read()[1]
|
||||
localfile = os.path.join(os.path.dirname(__file__), '.tree.md')
|
||||
with open(localfile, 'w') as f:
|
||||
@ -175,6 +200,9 @@ def tree_out(dir=os.path.dirname(__file__), line=2, more=''):
|
||||
|
||||
|
||||
def chat_history(log: list, split=0):
|
||||
"""
|
||||
auto_gpt 使用的代码,后续会迁移
|
||||
"""
|
||||
if split:
|
||||
log = log[split:]
|
||||
chat = ''
|
||||
@ -189,6 +217,7 @@ def df_similarity(s1, s2):
|
||||
"""弃用,会警告,这个库不会用"""
|
||||
def add_space(s):
|
||||
return ' '.join(list(s))
|
||||
|
||||
# 将字中间加入空格
|
||||
s1, s2 = add_space(s1), add_space(s2)
|
||||
# 转化为TF矩阵
|
||||
@ -200,6 +229,9 @@ def df_similarity(s1, s2):
|
||||
|
||||
|
||||
def check_json_format(file):
|
||||
"""
|
||||
检查上传的Json文件是否符合规范
|
||||
"""
|
||||
new_dict = {}
|
||||
data = JsonHandle(file).load()
|
||||
if type(data) is list and len(data) > 0:
|
||||
@ -208,21 +240,49 @@ def check_json_format(file):
|
||||
new_dict.update({i['act']: i['prompt']})
|
||||
return new_dict
|
||||
|
||||
def json_convert_dict():
|
||||
|
||||
def json_convert_dict(file):
|
||||
"""
|
||||
批量将json转换为字典
|
||||
"""
|
||||
new_dict = {}
|
||||
for root, dirs, files in os.walk(prompt_path):
|
||||
for root, dirs, files in os.walk(file):
|
||||
for f in files:
|
||||
if f.startswith('prompt') and f.endswith('json'):
|
||||
new_dict.update(check_json_format(f))
|
||||
return new_dict
|
||||
|
||||
|
||||
def draw_results(txt, prompt: gr.Dataset, percent, switch, ipaddr: gr.Request):
|
||||
"""
|
||||
绘制搜索结果
|
||||
Args:
|
||||
txt (str): 过滤文本
|
||||
prompt : 原始的dataset对象
|
||||
percent (int): TF系数,用于计算文本相似度
|
||||
switch (list): 过滤个人或所有人的Prompt
|
||||
ipaddr : 请求人信息
|
||||
Returns:
|
||||
注册函数所需的元祖对象
|
||||
"""
|
||||
data = diff_list(txt, percent=percent, switch=switch, hosts=ipaddr.client.host)
|
||||
prompt.samples = data
|
||||
return prompt.update(samples=data, visible=True), prompt
|
||||
|
||||
|
||||
def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15, hosts=''):
|
||||
"""
|
||||
按照搜索结果统计相似度的文本,两组文本相似度>70%的将统计在一起,取最长的作为key
|
||||
Args:
|
||||
txt (str): 过滤文本
|
||||
percent (int): TF系数,用于计算文本相似度
|
||||
switch (list): 过滤个人或所有人的Prompt
|
||||
lst:指定一个列表或字典
|
||||
sp: 截取展示的文本长度
|
||||
hosts : 请求人的ip
|
||||
Returns:
|
||||
返回一个列表
|
||||
"""
|
||||
import difflib
|
||||
count_dict = {}
|
||||
if not lst:
|
||||
@ -251,37 +311,41 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15
|
||||
index = key[0].find(txt)
|
||||
if index != -1:
|
||||
# sp=split 用于判断在哪里启动、在哪里断开
|
||||
if index-sp > 0: start = index-sp
|
||||
else: start = 0
|
||||
if len(key[0]) > sp * 2: end = key[0][-sp:]
|
||||
else: end = ''
|
||||
if index - sp > 0:
|
||||
start = index - sp
|
||||
else:
|
||||
start = 0
|
||||
if len(key[0]) > sp * 2:
|
||||
end = key[0][-sp:]
|
||||
else:
|
||||
end = ''
|
||||
# 判断有没有传需要匹配的字符串,有则筛选、无则全返
|
||||
if txt == '' and len(key[0]) >= sp: show = key[0][0:sp] + " . . . " + end
|
||||
elif txt == '' and len(key[0]) < sp: show = key[0][0:sp]
|
||||
else: show = str(key[0][start:index + sp]).replace(txt, html_tag_color(txt))
|
||||
if txt == '' and len(key[0]) >= sp:
|
||||
show = key[0][0:sp] + " . . . " + end
|
||||
elif txt == '' and len(key[0]) < sp:
|
||||
show = key[0][0:sp]
|
||||
else:
|
||||
show = str(key[0][start:index + sp]).replace(txt, html_tag_color(txt))
|
||||
show += f" {html_tag_color(' X ' + str(key[1]))}"
|
||||
if lst.get(key[0]): be_value = lst[key[0]]
|
||||
else: be_value = "这个prompt还没有对话过呢,快去试试吧~"
|
||||
if lst.get(key[0]):
|
||||
be_value = lst[key[0]]
|
||||
else:
|
||||
be_value = "这个prompt还没有对话过呢,快去试试吧~"
|
||||
value = be_value
|
||||
dateset_list.append([show, key[0], value])
|
||||
return dateset_list
|
||||
|
||||
|
||||
def search_list(txt, sp=15):
|
||||
lst = SqliteHandle('ai_common').get_prompt_value()
|
||||
dateset_list = []
|
||||
for key in lst:
|
||||
index = key.find(txt)
|
||||
if index != -1:
|
||||
if index-sp > 0: start = index-sp
|
||||
else: start = 0
|
||||
show = str(key[start:index+sp]).replace(txt, html_tag_color(txt))
|
||||
value = lst[key]
|
||||
dateset_list.append([show, key, value])
|
||||
return dateset_list
|
||||
|
||||
|
||||
def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
|
||||
"""
|
||||
上传文件,将文件转换为字典,然后存储到数据库,并刷新Prompt区域
|
||||
Args:
|
||||
file: 上传的文件
|
||||
prompt: 原始prompt对象
|
||||
ipaddr:ipaddr用户请求信息
|
||||
Returns:
|
||||
注册函数所需的元祖对象
|
||||
"""
|
||||
hosts = ipaddr.client.host
|
||||
if file.name.endswith('json'):
|
||||
upload_data = check_json_format(file.name)
|
||||
@ -299,6 +363,15 @@ def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
|
||||
|
||||
|
||||
def prompt_retrieval(is_all, hosts='', search=False):
|
||||
"""
|
||||
上传文件,将文件转换为字典,然后存储到数据库,并刷新Prompt区域
|
||||
Args:
|
||||
is_all: prompt类型
|
||||
hosts: 查询的用户ip
|
||||
search:支持搜索,搜索时将key作为key
|
||||
Returns:
|
||||
返回一个列表
|
||||
"""
|
||||
count_dict = {}
|
||||
user_path = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
|
||||
if '所有人' in is_all:
|
||||
@ -322,16 +395,35 @@ def prompt_retrieval(is_all, hosts='', search=False):
|
||||
|
||||
|
||||
def prompt_reduce(is_all, prompt: gr.Dataset, ipaddr: gr.Request): # is_all, ipaddr: gr.Request
|
||||
"""
|
||||
上传文件,将文件转换为字典,然后存储到数据库,并刷新Prompt区域
|
||||
Args:
|
||||
is_all: prompt类型
|
||||
prompt: dataset原始对象
|
||||
ipaddr:请求用户信息
|
||||
Returns:
|
||||
返回注册函数所需的对象
|
||||
"""
|
||||
data = prompt_retrieval(is_all=is_all, hosts=ipaddr.client.host)
|
||||
prompt.samples = data
|
||||
return prompt.update(samples=data, visible=True), prompt, is_all
|
||||
|
||||
|
||||
def prompt_save(txt, name, checkbox, prompt: gr.Dataset, ipaddr: gr.Request):
|
||||
def prompt_save(txt, name, prompt: gr.Dataset, ipaddr: gr.Request):
|
||||
"""
|
||||
编辑和保存Prompt
|
||||
Args:
|
||||
txt: Prompt正文
|
||||
name: Prompt的名字
|
||||
prompt: dataset原始对象
|
||||
ipaddr:请求用户信息
|
||||
Returns:
|
||||
返回注册函数所需的对象
|
||||
"""
|
||||
if txt and name:
|
||||
yaml_obj = SqliteHandle(f'prompt_{ipaddr.client.host}')
|
||||
yaml_obj.inset_prompt({name: txt})
|
||||
result = prompt_retrieval(is_all=checkbox+['个人'], hosts=ipaddr.client.host)
|
||||
result = prompt_retrieval(is_all=['个人'], hosts=ipaddr.client.host)
|
||||
prompt.samples = result
|
||||
return "", "", ['个人'], prompt.update(samples=result, visible=True), prompt
|
||||
elif not txt or not name:
|
||||
@ -339,7 +431,17 @@ def prompt_save(txt, name, checkbox, prompt: gr.Dataset, ipaddr: gr.Request):
|
||||
prompt.samples = [[f'{html_tag_color("编辑框 or 名称不能为空!!!!!", color="red")}', '']]
|
||||
return txt, name, [], prompt.update(samples=result, visible=True), prompt
|
||||
|
||||
|
||||
def prompt_input(txt, index, data: gr.Dataset):
|
||||
"""
|
||||
点击dataset的值使用Prompt
|
||||
Args:
|
||||
txt: 输入框正文
|
||||
index: 点击的Dataset下标
|
||||
data: dataset原始对象
|
||||
Returns:
|
||||
返回注册函数所需的对象
|
||||
"""
|
||||
data_str = str(data.samples[index][1])
|
||||
if txt:
|
||||
txt = data_str + '\n' + txt
|
||||
@ -347,7 +449,9 @@ def prompt_input(txt, index, data: gr.Dataset):
|
||||
txt = data_str
|
||||
return txt
|
||||
|
||||
|
||||
def copy_result(history):
|
||||
"""复制history"""
|
||||
if history != []:
|
||||
pyperclip.copy(history[-1])
|
||||
return '已将结果复制到剪切板'
|
||||
@ -356,12 +460,24 @@ def copy_result(history):
|
||||
|
||||
|
||||
def show_prompt_result(index, data: gr.Dataset, chatbot):
|
||||
"""
|
||||
查看Prompt的对话记录结果
|
||||
Args:
|
||||
index: 点击的Dataset下标
|
||||
data: dataset原始对象
|
||||
chatbot:聊天机器人
|
||||
Returns:
|
||||
返回注册函数所需的对象
|
||||
"""
|
||||
click = data.samples[index]
|
||||
chatbot.append((click[1], click[2]))
|
||||
return chatbot
|
||||
|
||||
|
||||
def thread_write_chat(chatbot):
|
||||
"""
|
||||
对话记录写入数据库
|
||||
"""
|
||||
private_key = toolbox.get_conf('private_key')[0]
|
||||
chat_title = chatbot[0][0].split()
|
||||
i_say = chatbot[-1][0].strip("<p>/p")
|
||||
@ -375,6 +491,7 @@ def thread_write_chat(chatbot):
|
||||
base_path = os.path.dirname(__file__)
|
||||
prompt_path = os.path.join(base_path, 'prompt_users')
|
||||
|
||||
|
||||
class YamlHandle:
|
||||
|
||||
def __init__(self, file=os.path.join(prompt_path, 'ai_common.yaml')):
|
||||
@ -383,7 +500,6 @@ class YamlHandle:
|
||||
self.file = file
|
||||
self._load = self.load()
|
||||
|
||||
|
||||
def load(self) -> dict:
|
||||
with open(file=self.file, mode='r') as f:
|
||||
data = yaml.safe_load(f)
|
||||
@ -407,6 +523,7 @@ class YamlHandle:
|
||||
yaml.dump(date, f, allow_unicode=True)
|
||||
return date
|
||||
|
||||
|
||||
class JsonHandle:
|
||||
|
||||
def __init__(self, file=os.path.join(prompt_path, 'prompts-PlexPt.json')):
|
||||
@ -419,27 +536,7 @@ class JsonHandle:
|
||||
data = json.load(f)
|
||||
return data
|
||||
|
||||
class FileHandle:
|
||||
|
||||
def __init__(self, file=None):
|
||||
self.file = file
|
||||
|
||||
def read(self):
|
||||
with open(file=self.file, mode='r') as f:
|
||||
print(f.read())
|
||||
|
||||
|
||||
def read_link(self):
|
||||
link = 'https://github.com/PlexPt/awesome-chatgpt-prompts-zh/blob/main/prompts-zh.json'
|
||||
name = link.split('/')[3]+link.split('/')[-1]
|
||||
new_file = os.path.join(base_path, 'gpt_log', name)
|
||||
response = requests.get(url=link, verify=False)
|
||||
with open(new_file, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
for i in YamlHandle().load():
|
||||
print(i)
|
||||
|
||||
|
||||
pass
|
||||
@ -458,6 +458,9 @@ def find_recent_files(directory):
|
||||
|
||||
|
||||
def get_user_upload(chatbot, ipaddr: gr.Request):
|
||||
"""
|
||||
获取用户上传过的文件
|
||||
"""
|
||||
private_upload = './private_upload'
|
||||
user_history = os.path.join(private_upload, ipaddr.client.host)
|
||||
history = ''
|
||||
@ -471,6 +474,9 @@ def get_user_upload(chatbot, ipaddr: gr.Request):
|
||||
|
||||
|
||||
def get_user_download(chatbot, link, file):
|
||||
"""
|
||||
将短路径转换为下载链接
|
||||
"""
|
||||
for file_handle in str(file).split('\n'):
|
||||
if os.path.isfile(file_handle):
|
||||
# temp_file = func_box.copy_temp_file(file_handle) 无法使用外部的临时目录
|
||||
|
||||
Reference in New Issue
Block a user