Files
gpt_academic/func_box.py

560 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#! .\venv\
# encoding: utf-8
# @Time : 2023/4/18
# @Author : Spike
# @Descr :
import hashlib
import io
import json
import os.path
import subprocess
import threading
import time
import psutil
import re
import tempfile
import shutil
from contextlib import ExitStack
import logging
import yaml
import requests
logger = logging
from sklearn.feature_extraction.text import CountVectorizer
import numpy as np
from scipy.linalg import norm
import pyperclip
import random
import gradio as gr
import toolbox
from prompt_generator import SqliteHandle
"""contextlib 是 Python 标准库中的一个模块,提供了一些工具函数和装饰器,用于支持编写上下文管理器和处理上下文的常见任务,例如资源管理、异常处理等。
官网https://docs.python.org/3/library/contextlib.html"""
class Shell(object):
def __init__(self, args, stream=False):
self.args = args
self.subp = subprocess.Popen(args, shell=True,
stdin=subprocess.PIPE, stderr=subprocess.PIPE,
stdout=subprocess.PIPE, encoding='utf-8',
errors='ignore', close_fds=True)
self.__stream = stream
self.__temp = ''
def read(self):
logger.debug(f'The command being executed is: "{self.args}"')
if self.__stream:
sysout = self.subp.stdout
try:
with sysout as std:
for i in std:
logger.info(i.rstrip())
self.__temp += i
except KeyboardInterrupt as p:
return 3, self.__temp + self.subp.stderr.read()
finally:
return 3, self.__temp + self.subp.stderr.read()
else:
sysout = self.subp.stdout.read()
syserr = self.subp.stderr.read()
self.subp.stdin
if sysout:
logger.debug(f"{self.args} \n{sysout}")
return 1, sysout
elif syserr:
logger.error(f"{self.args} \n{syserr}")
return 0, syserr
else:
logger.debug(f"{self.args} \n{[sysout], [sysout]}")
return 2, '\n{}\n{}'.format(sysout, sysout)
def sync(self):
logger.debug('The command being executed is: "{}"'.format(self.args))
for i in self.subp.stdout:
logger.debug(i.rstrip())
self.__temp += i
yield self.__temp
for i in self.subp.stderr:
logger.debug(i.rstrip())
self.__temp += i
yield self.__temp
def timeStatistics(func):
"""
统计函数执行时常的装饰器
"""
def statistics(*args, **kwargs):
startTiem = time.time()
obj = func(*args, **kwargs)
endTiem = time.time()
ums = startTiem - endTiem
print('func:{} > Time-consuming: {}'.format(func, ums))
return obj
return statistics
def context_with(*parms):
"""
一个装饰器,根据传递的参数列表,在类方法上下文中嵌套多个 with 语句。
Args:
*parms: 参数列表,每个参数都是一个字符串,表示类中的一个属性名。
Returns:
一个装饰器函数。
"""
def decorator(cls_method):
"""
装饰器函数,用于将一个类方法转换为一个嵌套多个 with 语句的方法。
Args:
cls_method: 要装饰的类方法。
Returns:
装饰后的类方法。
"""
def wrapper(cls='', *args, **kwargs):
"""
装饰后的方法,用于嵌套多个 with 语句,并调用原始的类方法。
Args:
cls: 类的实例对象。
*args: 位置参数。
**kwargs: 关键字参数。
Returns:
原始的类方法返回的结果。
"""
with_list = [getattr(cls, arg) for arg in parms]
with ExitStack() as stack:
for context in with_list:
stack.enter_context(context)
return cls_method(cls, *args, **kwargs)
return wrapper
return decorator
def copy_temp_file(file):
if os.path.exists(file):
exdir = tempfile.mkdtemp()
temp_ = shutil.copy(file, os.path.join(exdir, os.path.basename(file)))
return temp_
else:
return None
def md5_str(st):
# 创建一个 MD5 对象
md5 = hashlib.md5()
# 更新 MD5 对象的内容
md5.update(str(st).encode())
# 获取加密后的结果
result = md5.hexdigest()
return result
def html_tag_color(tag, color=None, font='black'):
"""
将文本转换为带有高亮提示的html代码
"""
if not color:
rgb = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
color = f"rgb{rgb}"
tag = f'<span style="background-color: {color}; font-weight: bold; color: {font}">&nbsp;{tag}&ensp;</span>'
return tag
def ipaddr():
# 获取本地ipx
ip = psutil.net_if_addrs()
for i in ip:
if ip[i][0][3]:
return ip[i][0][1]
def encryption_str(txt: str):
"""(关键字)(加密间隔)匹配机制(关键字间隔)"""
txt = str(txt)
pattern = re.compile(rf"(Authorization|WPS-Sid|Cookie)(:|\s+)\s*(\S+)[\s\S]*?(?=\n|$|\s)", re.IGNORECASE)
result = pattern.sub(lambda x: x.group(1) + ": XXXXXXXX", txt)
return result
def tree_out(dir=os.path.dirname(__file__), line=2, more=''):
"""
获取本地文件的树形结构转化为Markdown代码文本
"""
out = Shell(f'tree {dir} -F -I "__*|.*|venv|*.png|*.xlsx" -L {line} {more}').read()[1]
localfile = os.path.join(os.path.dirname(__file__), '.tree.md')
with open(localfile, 'w') as f:
f.write('```\n')
ll = out.splitlines()
for i in range(len(ll)):
if i == 0:
f.write(ll[i].split('/')[-2] + '\n')
else:
f.write(ll[i] + '\n')
f.write('```\n')
def chat_history(log: list, split=0):
"""
auto_gpt 使用的代码,后续会迁移
"""
if split:
log = log[split:]
chat = ''
history = ''
for i in log:
chat += f'{i[0]}\n\n'
history += f'{i[1]}\n\n'
return chat, history
def df_similarity(s1, s2):
"""弃用,会警告,这个库不会用"""
def add_space(s):
return ' '.join(list(s))
# 将字中间加入空格
s1, s2 = add_space(s1), add_space(s2)
# 转化为TF矩阵
cv = CountVectorizer(tokenizer=lambda s: s.split())
corpus = [s1, s2]
vectors = cv.fit_transform(corpus).toarray()
# 计算TF系数
return np.dot(vectors[0], vectors[1]) / (norm(vectors[0]) * norm(vectors[1]))
def check_json_format(file):
"""
检查上传的Json文件是否符合规范
"""
new_dict = {}
data = JsonHandle(file).load()
if type(data) is list and len(data) > 0:
if type(data[0]) is dict:
for i in data:
new_dict.update({i['act']: i['prompt']})
return new_dict
def json_convert_dict(file):
"""
批量将json转换为字典
"""
new_dict = {}
for root, dirs, files in os.walk(file):
for f in files:
if f.startswith('prompt') and f.endswith('json'):
new_dict.update(check_json_format(f))
return new_dict
def draw_results(txt, prompt: gr.Dataset, percent, switch, ipaddr: gr.Request):
"""
绘制搜索结果
Args:
txt (str): 过滤文本
prompt : 原始的dataset对象
percent (int): TF系数用于计算文本相似度
switch (list): 过滤个人或所有人的Prompt
ipaddr : 请求人信息
Returns:
注册函数所需的元祖对象
"""
data = diff_list(txt, percent=percent, switch=switch, hosts=ipaddr.client.host)
prompt.samples = data
return prompt.update(samples=data, visible=True), prompt
def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15, hosts=''):
"""
按照搜索结果统计相似度的文本,两组文本相似度>70%的将统计在一起取最长的作为key
Args:
txt (str): 过滤文本
percent (int): TF系数用于计算文本相似度
switch (list): 过滤个人或所有人的Prompt
lst指定一个列表或字典
sp: 截取展示的文本长度
hosts : 请求人的ip
Returns:
返回一个列表
"""
import difflib
count_dict = {}
if not lst:
lst = SqliteHandle('ai_common').get_prompt_value(txt)
lst.update(SqliteHandle(f"ai_private_{hosts}").get_prompt_value(txt))
# diff 数据根据precent系数归类数据
for i in lst:
found = False
for key in count_dict.keys():
str_tf = difflib.SequenceMatcher(None, i, key).ratio()
if str_tf >= percent:
if len(i) > len(key):
count_dict[i] = count_dict[key] + 1
count_dict.pop(key)
else:
count_dict[key] += 1
found = True
break
if not found: count_dict[i] = 1
sorted_dict = sorted(count_dict.items(), key=lambda x: x[1], reverse=True)
if switch:
sorted_dict += prompt_retrieval(is_all=switch, hosts=hosts, search=True)
dateset_list = []
for key in sorted_dict:
# 开始匹配关键字
index = key[0].find(txt)
if index != -1:
# sp=split 用于判断在哪里启动、在哪里断开
if index - sp > 0:
start = index - sp
else:
start = 0
if len(key[0]) > sp * 2:
end = key[0][-sp:]
else:
end = ''
# 判断有没有传需要匹配的字符串,有则筛选、无则全返
if txt == '' and len(key[0]) >= sp:
show = key[0][0:sp] + " . . . " + end
elif txt == '' and len(key[0]) < sp:
show = key[0][0:sp]
else:
show = str(key[0][start:index + sp]).replace(txt, html_tag_color(txt))
show += f" {html_tag_color(' X ' + str(key[1]))}"
if lst.get(key[0]):
be_value = lst[key[0]]
else:
be_value = "这个prompt还没有对话过呢快去试试吧"
value = be_value
dateset_list.append([show, key[0], value])
return dateset_list
def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
"""
上传文件将文件转换为字典然后存储到数据库并刷新Prompt区域
Args:
file 上传的文件
prompt 原始prompt对象
ipaddripaddr用户请求信息
Returns:
注册函数所需的元祖对象
"""
hosts = ipaddr.client.host
if file.name.endswith('json'):
upload_data = check_json_format(file.name)
elif file.name.endswith('yaml'):
upload_data = YamlHandle(file.name).load()
else:
upload_data = {}
if upload_data != {}:
SqliteHandle(f'prompt_{hosts}').inset_prompt(upload_data)
ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
return prompt.update(samples=ret_data, visible=True), prompt, ['个人']
else:
prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
return prompt.samples, prompt, []
def prompt_retrieval(is_all, hosts='', search=False):
"""
上传文件将文件转换为字典然后存储到数据库并刷新Prompt区域
Args:
is_all prompt类型
hosts 查询的用户ip
search支持搜索搜索时将key作为key
Returns:
返回一个列表
"""
count_dict = {}
user_path = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
if '所有人' in is_all:
for tab in SqliteHandle('ai_common').get_tables():
if tab.startswith('prompt'):
data = SqliteHandle(tab).get_prompt_value(None)
if data: count_dict.update(data)
elif '个人' in is_all:
data = SqliteHandle(f'prompt_{hosts}').get_prompt_value(None)
if data: count_dict.update(data)
retrieval = []
if count_dict != {}:
for key in count_dict:
if not search:
retrieval.append([key, count_dict[key]])
else:
retrieval.append([count_dict[key], key])
return retrieval
else:
return retrieval
def prompt_reduce(is_all, prompt: gr.Dataset, ipaddr: gr.Request): # is_all, ipaddr: gr.Request
"""
上传文件将文件转换为字典然后存储到数据库并刷新Prompt区域
Args:
is_all prompt类型
prompt dataset原始对象
ipaddr请求用户信息
Returns:
返回注册函数所需的对象
"""
data = prompt_retrieval(is_all=is_all, hosts=ipaddr.client.host)
prompt.samples = data
return prompt.update(samples=data, visible=True), prompt, is_all
def prompt_save(txt, name, prompt: gr.Dataset, ipaddr: gr.Request):
"""
编辑和保存Prompt
Args:
txt Prompt正文
name Prompt的名字
prompt dataset原始对象
ipaddr请求用户信息
Returns:
返回注册函数所需的对象
"""
if txt and name:
yaml_obj = SqliteHandle(f'prompt_{ipaddr.client.host}')
yaml_obj.inset_prompt({name: txt})
result = prompt_retrieval(is_all=['个人'], hosts=ipaddr.client.host)
prompt.samples = result
return "", "", ['个人'], prompt.update(samples=result, visible=True), prompt
elif not txt or not name:
result = [[f'{html_tag_color("编辑框 or 名称不能为空!!!!!", color="red")}', '']]
prompt.samples = [[f'{html_tag_color("编辑框 or 名称不能为空!!!!!", color="red")}', '']]
return txt, name, [], prompt.update(samples=result, visible=True), prompt
def prompt_input(txt, index, data: gr.Dataset):
"""
点击dataset的值使用Prompt
Args:
txt 输入框正文
index 点击的Dataset下标
data dataset原始对象
Returns:
返回注册函数所需的对象
"""
data_str = str(data.samples[index][1])
if txt:
txt = data_str + '\n' + txt
else:
txt = data_str
return txt
def copy_result(history):
"""复制history"""
if history != []:
pyperclip.copy(history[-1])
return '已将结果复制到剪切板'
else:
return "无对话记录,复制错误!!"
def show_prompt_result(index, data: gr.Dataset, chatbot):
"""
查看Prompt的对话记录结果
Args:
index 点击的Dataset下标
data dataset原始对象
chatbot聊天机器人
Returns:
返回注册函数所需的对象
"""
click = data.samples[index]
chatbot.append((click[1], click[2]))
return chatbot
def thread_write_chat(chatbot):
"""
对话记录写入数据库
"""
private_key = toolbox.get_conf('private_key')[0]
chat_title = chatbot[0][1].split()
pattern = re.compile(r'^<div class="markdown-body"><p>|<\/p><\/div>$')
i_say = pattern.sub('', chatbot[-1][0])
gpt_result = pattern.sub('', chatbot[-1][1])
if private_key in chat_title:
SqliteHandle(f'ai_private_{chat_title[-2]}').inset_prompt({i_say: gpt_result})
else:
SqliteHandle(f'ai_common').inset_prompt({i_say: gpt_result})
base_path = os.path.dirname(__file__)
prompt_path = os.path.join(base_path, 'prompt_users')
def reuse_chat(result, chatbot, history):
"""复用对话记录"""
if result is None or result == []:
pass
else:
chatbot.append(result[-1])
history += result[-1]
pattern = re.compile(r'^<div class="markdown-body"><p>|<\/p><\/div>$')
i_say = pattern.sub('', chatbot[-1][0])
return chatbot, history, i_say, gr.Tabs.update(selected='chatbot')
class YamlHandle:
def __init__(self, file=os.path.join(prompt_path, 'ai_common.yaml')):
if not os.path.exists(file):
Shell(f'touch {file}').read()
self.file = file
self._load = self.load()
def load(self) -> dict:
with open(file=self.file, mode='r') as f:
data = yaml.safe_load(f)
return data
def update(self, key, value):
date = self._load
if not date:
date = {}
date[key] = value
with open(file=self.file, mode='w') as f:
yaml.dump(date, f, allow_unicode=True)
return date
def dump_dict(self, new_dict):
date = self._load
if not date:
date = {}
date.update(new_dict)
with open(file=self.file, mode='w') as f:
yaml.dump(date, f, allow_unicode=True)
return date
class JsonHandle:
def __init__(self, file):
if os.path.exists(file):
with open(file=file, mode='r') as self.file_obj:
pass
else:
self.file_obj = io.StringIO() # 创建空白文本对象
self.file_obj.write('{}') # 向文本对象写入有有效 JSON 格式的数据
self.file_obj.seek(0) # 将文本对象的光标重置到开头
def load(self):
data = json.load(self.file_obj)
return data
if __name__ == '__main__':
pass