427 lines
15 KiB
Python
427 lines
15 KiB
Python
#! .\venv\
|
||
# encoding: utf-8
|
||
# @Time : 2023/4/18
|
||
# @Author : Spike
|
||
# @Descr :
|
||
import hashlib
|
||
import json
|
||
import os.path
|
||
import subprocess
|
||
import psutil
|
||
import re
|
||
import tempfile
|
||
import shutil
|
||
from contextlib import ExitStack
|
||
import logging
|
||
import yaml
|
||
import requests
|
||
# ``logger`` is an alias of the ``logging`` module itself (not a Logger
# instance), so ``logger.debug(...)`` etc. are the module-level functions.
logger = logging
|
||
from sklearn.feature_extraction.text import CountVectorizer
|
||
import numpy as np
|
||
from scipy.linalg import norm
|
||
import pyperclip
|
||
import random
|
||
import gradio as gr
|
||
"""contextlib 是 Python 标准库中的一个模块,提供了一些工具函数和装饰器,用于支持编写上下文管理器和处理上下文的常见任务,例如资源管理、异常处理等。
|
||
官网:https://docs.python.org/3/library/contextlib.html"""
|
||
|
||
class Shell(object):
    """Run a shell command in a child process and collect its output.

    The command is started as soon as the object is created, through
    ``subprocess.Popen`` with ``shell=True``; stdin, stdout and stderr are all
    pipes decoded as UTF-8 (undecodable bytes are ignored).

    Args:
        args: Command line handed to the shell.
        stream: When true, ``read`` logs stdout line by line while the command
            is still running instead of waiting for it to finish.
    """

    def __init__(self, args, stream=False):
        self.args = args
        self.subp = subprocess.Popen(args, shell=True,
                                     stdin=subprocess.PIPE, stderr=subprocess.PIPE,
                                     stdout=subprocess.PIPE, encoding='utf-8',
                                     errors='ignore', close_fds=True)
        self.__stream = stream
        self.__temp = ''

    def read(self):
        """Collect the command output.

        Returns:
            A ``(code, text)`` tuple:

            * ``(3, stdout + stderr)`` in streaming mode;
            * ``(1, stdout)`` when the command wrote to stdout;
            * ``(0, stderr)`` when it only wrote to stderr;
            * ``(2, '\\n\\n')`` when it wrote nothing at all.
        """
        logger.debug(f'The command being executed is: "{self.args}"')
        if self.__stream:
            try:
                with self.subp.stdout as std:
                    for line in std:
                        logger.info(line.rstrip())
                        self.__temp += line
            except KeyboardInterrupt:
                # Ctrl-C only stops the streaming; what was captured so far is
                # still returned below.
                pass
            # stderr is read exactly once, after the try block. The previous
            # ``return`` inside ``finally`` hid every exception and, after an
            # interrupt, threw away the stderr text that had just been read.
            return 3, self.__temp + self.subp.stderr.read()

        # communicate() drains both pipes concurrently; reading stdout to EOF
        # and only then stderr can block forever once the stderr pipe fills up.
        sysout, syserr = self.subp.communicate()
        if sysout:
            logger.debug(f"{self.args} \n{sysout}")
            return 1, sysout
        if syserr:
            logger.error(f"{self.args} \n{syserr}")
            return 0, syserr
        logger.debug(f"{self.args} \n{[sysout], [syserr]}")
        return 2, '\n{}\n{}'.format(sysout, syserr)

    def sync(self):
        """Yield the accumulated output after every new line.

        All stdout lines are consumed first, then all stderr lines; each yielded
        value is the complete text collected so far.
        """
        logger.debug('The command being executed is: "{}"'.format(self.args))
        for stream in (self.subp.stdout, self.subp.stderr):
            for line in stream:
                logger.debug(line.rstrip())
                self.__temp += line
                yield self.__temp
|
||
|
||
def context_with(*parms):
    """Build a decorator that runs a method inside several context managers.

    Args:
        *parms: Attribute names; each one is looked up on the instance the
            decorated method is called on and must be a context manager.

    Returns:
        A decorator for methods.
    """
    import functools

    def decorator(cls_method):
        """Wrap ``cls_method`` so it executes with every named context entered.

        Args:
            cls_method: The method to decorate.

        Returns:
            The wrapped method.
        """
        # functools.wraps keeps the original name, docstring and signature
        # metadata visible on the wrapper.
        @functools.wraps(cls_method)
        def wrapper(cls='', *args, **kwargs):
            """Enter all contexts, call the original method, return its result.

            Args:
                cls: The instance the method is bound to.
                *args: Positional arguments for the method.
                **kwargs: Keyword arguments for the method.
            """
            # Resolve every attribute first so a missing name fails before any
            # context has been entered.
            with_list = [getattr(cls, arg) for arg in parms]
            with ExitStack() as stack:
                for context in with_list:
                    stack.enter_context(context)
                return cls_method(cls, *args, **kwargs)
        return wrapper
    return decorator
|
||
|
||
|
||
def copy_temp_file(file):
    """Copy ``file`` into a newly created temporary directory.

    Returns:
        The path of the copy (same base name as the source), or ``None`` when
        ``file`` does not exist.
    """
    if not os.path.exists(file):
        return None
    target_dir = tempfile.mkdtemp()
    destination = os.path.join(target_dir, os.path.basename(file))
    return shutil.copy(file, destination)
|
||
|
||
def md5_str(st):
    """Return the hexadecimal MD5 digest of ``str(st)`` encoded with the default codec."""
    encoded = str(st).encode()
    return hashlib.md5(encoded).hexdigest()
|
||
|
||
|
||
def html_tag_color(tag, color=None):
    """Wrap ``tag`` in a bold, black-on-colour HTML ``<span>``.

    Args:
        tag: Text to place inside the span.
        color: CSS colour for the background; when falsy a random
            ``rgb(r, g, b)`` value is generated.

    Returns:
        The HTML snippet as a string.
    """
    if not color:
        channels = tuple(random.randint(0, 255) for _ in range(3))
        color = f"rgb{channels}"
    return f'<span style="background-color: {color}; font-weight: bold; color: black"> {tag} </span>'
|
||
|
||
def ipaddr():  # get a local IP address
    """Return field 1 of the first interface entry whose field 3 is truthy.

    Only the first address record of every interface reported by
    ``psutil.net_if_addrs()`` is inspected. Returns ``None`` when no record
    qualifies.
    """
    for records in psutil.net_if_addrs().values():
        first = records[0]
        # NOTE(review): index 3 is presumably the broadcast address and index 1
        # the interface address of psutil's address tuple - confirm.
        if first[3]:
            return first[1]
    return None
|
||
|
||
def encryption_str(txt: str):
    """Mask the value that follows an Authorization / WPS-Sid / Cookie keyword.

    The keyword is matched case-insensitively and must be followed by ``:`` or
    whitespace; the following run of non-blank characters is replaced so the
    result reads ``<keyword>: XXXXXXXX``. Text after that run is kept.
    """
    pattern = re.compile(rf"(Authorization|WPS-Sid|Cookie)(:|\s+)\s*(\S+)[\s\S]*?(?=\n|$|\s)", re.IGNORECASE)

    def _mask(match):
        # Keep the keyword exactly as written in the input.
        return match.group(1) + ": XXXXXXXX"

    return pattern.sub(_mask, str(txt))
|
||
|
||
def tree_out(dir=os.path.dirname(__file__), line=2, more=''):
    """Write the output of the ``tree`` command to ``.tree.md`` next to this file.

    Args:
        dir: Directory to list (defaults to this module's directory).
        line: Depth passed to ``tree -L``.
        more: Extra command-line options appended verbatim.

    The listing is wrapped in a Markdown code fence. The first output line is
    reduced to its last directory component; if it contains no ``/`` (for
    example when ``tree`` printed an error instead) it is written unchanged.
    """
    out = Shell(f'tree {dir} -F -I "__*|.*|venv|*.png|*.xlsx" -L {line} {more}').read()[1]
    localfile = os.path.join(os.path.dirname(__file__), '.tree.md')
    rows = out.splitlines()
    if rows:
        parts = rows[0].split('/')
        # With -F the root line ends in '/', so the directory name is the
        # second-to-last piece; guard against output that has no '/' at all.
        if len(parts) >= 2:
            rows[0] = parts[-2]
    # tree draws with non-ASCII glyphs, so do not rely on the locale encoding.
    with open(localfile, 'w', encoding='utf-8') as f:
        f.write('```\n')
        f.writelines(row + '\n' for row in rows)
        f.write('```\n')
|
||
|
||
|
||
def chat_history(log: list, split=0):
    """Flatten a list of ``(first, second)`` pairs into two text blobs.

    Args:
        log: Sequence of pairs.
        split: When truthy, only ``log[split:]`` is used.

    Returns:
        ``(chat, history)`` where ``chat`` joins every first element and
        ``history`` every second element, each followed by a blank line.
    """
    entries = log[split:] if split else log
    chat_parts = []
    history_parts = []
    for entry in entries:
        chat_parts.append(f'{entry[0]}\n\n')
        history_parts.append(f'{entry[1]}\n\n')
    return ''.join(chat_parts), ''.join(history_parts)
|
||
|
||
|
||
def df_similarity(s1, s2):
    """Cosine similarity of the character-count vectors of two strings.

    Marked as deprecated by the original author.

    Returns:
        The cosine of the angle between the two count vectors, or ``0.0`` when
        either vector has zero length (instead of dividing by zero).
    """
    def add_space(s):
        return ' '.join(list(s))

    # Put a space between characters so each one becomes a token.
    s1, s2 = add_space(s1), add_space(s2)
    # token_pattern=None: a custom tokenizer is supplied, and leaving the
    # default pattern in place makes scikit-learn warn that it is unused.
    cv = CountVectorizer(tokenizer=lambda s: s.split(), token_pattern=None)
    vectors = cv.fit_transform([s1, s2]).toarray()
    denominator = norm(vectors[0]) * norm(vectors[1])
    if denominator == 0:
        return 0.0
    return np.dot(vectors[0], vectors[1]) / denominator
|
||
|
||
|
||
def check_json_format(file):
    """Convert a JSON prompt file into an ``{act: prompt}`` dictionary.

    The file must hold a non-empty list whose first element is an object.
    Entries that are not objects or lack the ``act`` / ``prompt`` keys are
    skipped.

    Returns:
        The collected mapping; ``{}`` when the file cannot be parsed or does
        not have the expected shape (callers treat ``{}`` as "invalid file").
    """
    new_dict = {}
    try:
        data = JsonHandle(file).load()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        return new_dict
    if isinstance(data, list) and data and isinstance(data[0], dict):
        for item in data:
            if isinstance(item, dict) and 'act' in item and 'prompt' in item:
                new_dict[item['act']] = item['prompt']
    return new_dict
|
||
|
||
def json_convert_dict():
    """Merge every ``prompt*json`` file found under ``prompt_path``.

    Returns:
        One dictionary combining the result of ``check_json_format`` for each
        matching file; later files override earlier ones on key clashes.
    """
    new_dict = {}
    for root, _dirs, files in os.walk(prompt_path):
        for f in files:
            if f.startswith('prompt') and f.endswith('json'):
                # os.walk yields bare file names; join with the directory so
                # the file is resolved where it was found, not against the cwd.
                new_dict.update(check_json_format(os.path.join(root, f)))
    return new_dict
|
||
|
||
def draw_results(txt, prompt: gr.Dataset, percent, switch, ipaddr: gr.Request):
    """Refresh the prompt dataset with the rows ``diff_list`` finds for ``txt``.

    Returns:
        ``(update, prompt)``: the value of ``prompt.update(...)`` for the new
        samples (10 per page) and the dataset object itself, whose ``samples``
        attribute has been replaced.
    """
    rows = diff_list(txt, percent=percent, switch=switch, hosts=ipaddr.client.host)
    prompt.samples = rows
    refreshed = prompt.update(samples=rows, samples_per_page=10)
    return refreshed, prompt
|
||
|
||
|
||
def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15, hosts=''):
    """Cluster similar prompts and build display rows for those containing ``txt``.

    Args:
        txt: Keyword to look for; an empty string selects every entry.
        percent: ``difflib`` similarity ratio at or above which two entries
            are counted as the same prompt.
        switch: When truthy, entries from ``prompt_retrieval(..., search=True)``
            are appended to the candidates.
        lst: Mapping of prompt -> stored reply. When empty or omitted it is
            loaded from the common YAML file and ``ai_private_<hosts>.yaml``.
        sp: Number of context characters shown around the keyword.
        hosts: Client host used to locate the per-user files.

    Returns:
        A list of ``[display_html, prompt, reply]`` rows.
    """
    import difflib

    if not lst:
        # An empty YAML document loads as None; fall back to {} before merging.
        lst = YamlHandle().load() or {}
        lst.update(YamlHandle(os.path.join(prompt_path, f"ai_private_{hosts}.yaml")).load() or {})
    # Stored replies can only be looked up when ``lst`` is a mapping.
    replies = lst if isinstance(lst, dict) else {}

    # Group entries by similarity; the longest member of a group stays as key.
    count_dict = {}
    for i in lst:
        for key in count_dict:
            if difflib.SequenceMatcher(None, i, key).ratio() >= percent:
                if len(i) > len(key):
                    count_dict[i] = count_dict.pop(key) + 1
                else:
                    count_dict[key] += 1
                # The dict was modified, so stop iterating over it right away.
                break
        else:
            count_dict[i] = 1

    sorted_dict = sorted(count_dict.items(), key=lambda x: x[1], reverse=True)
    if switch:
        sorted_dict += prompt_retrieval(is_all=switch, hosts=hosts, search=True)

    dateset_list = []
    for key in sorted_dict:
        text = key[0]
        index = text.find(txt)
        if index == -1:
            continue
        start = max(index - sp, 0)
        tail = text[-sp:] if len(text) > sp * 2 else ''
        if txt == '':
            # No keyword: show the head (and the tail of long entries).
            show = text[0:sp] + " . . . " + tail if len(text) >= sp else text[0:sp]
        else:
            # Extend the window to the end of the keyword when it is longer
            # than ``sp``; otherwise it is cut in half and never highlighted.
            stop = index + max(sp, len(txt))
            show = str(text[start:stop]).replace(txt, html_tag_color(txt))
        show += f" {html_tag_color(' X ' + str(key[1]))}"
        value = replies.get(text) or "这个prompt还没有对话过呢,快去试试吧~"
        dateset_list.append([show, text, value])
    return dateset_list
|
||
|
||
|
||
def search_list(txt, sp=15):
    """Find the prompts in the common YAML file whose key contains ``txt``.

    Args:
        txt: Keyword to search for.
        sp: Number of context characters shown around the keyword.

    Returns:
        A list of ``[display_html, key, value]`` rows; the keyword is
        highlighted inside ``display_html``.
    """
    # An empty YAML document loads as None.
    lst = YamlHandle().load() or {}
    dateset_list = []
    for key in lst:
        index = key.find(txt)
        if index == -1:
            continue
        start = max(index - sp, 0)
        # Keep the whole keyword inside the snippet even when it exceeds sp.
        stop = index + max(sp, len(txt))
        show = str(key[start:stop])
        if txt:
            # str.replace('', x) would insert x between every character.
            show = show.replace(txt, html_tag_color(txt))
        dateset_list.append([show, key, lst[key]])
    return dateset_list
|
||
|
||
|
||
def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
    """Import an uploaded JSON/YAML prompt file into the caller's personal store.

    Args:
        file: Uploaded file object; only its ``name`` (a path) is used.
        prompt: Dataset component to refresh.
        ipaddr: Request whose ``client.host`` selects the per-user YAML file.

    Returns:
        A 3-tuple ``(dataset_value, prompt, checkbox_value)``. On success the
        personal prompts are shown and the checkbox value is ``['个人']``; when
        the file cannot be used (wrong extension, unparsable, empty or not a
        mapping) an error row is shown and the checkbox value is ``[]``.
    """
    hosts = ipaddr.client.host
    user_file = os.path.join(prompt_path, f'prompt_{hosts}.yaml')

    if file.name.endswith('json'):
        upload_data = check_json_format(file.name)
    elif file.name.endswith(('yaml', 'yml')):
        upload_data = YamlHandle(file.name).load()
    else:
        # Unsupported extension: report it like any other unusable file
        # instead of falling through and returning None.
        upload_data = None

    if isinstance(upload_data, dict) and upload_data:
        YamlHandle(user_file).dump_dict(upload_data)
        ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
        # Keep the server-side samples in step with what is displayed, as the
        # other prompt handlers in this file do.
        prompt.samples = ret_data
        return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']

    prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
    return prompt.samples, prompt, []
|
||
|
||
def prompt_retrieval(is_all, hosts='', search=False):
    """Collect stored prompts as a list of two-element rows.

    Args:
        is_all: Selection list. ``'所有人'`` reads every ``prompt*yaml`` file
            under ``prompt_path``; otherwise ``'个人'`` reads only
            ``prompt_<hosts>.yaml``.
        hosts: Client host used in the personal file name.
        search: When true rows are ``[value, key]``, else ``[key, value]``.

    Returns:
        The list of rows (empty when nothing was selected or found).
    """
    collected = {}
    if '所有人' in is_all:
        for root, _dirs, files in os.walk(prompt_path):
            for filename in files:
                if not (filename.startswith('prompt') and filename.endswith('yaml')):
                    continue
                data = YamlHandle(file=os.path.join(root, filename)).load()
                if data:
                    collected.update(data)
    elif '个人' in is_all:
        user_path = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
        data = YamlHandle(file=user_path).load()
        if data:
            collected.update(data)

    if search:
        return [[value, key] for key, value in collected.items()]
    return [[key, value] for key, value in collected.items()]
|
||
|
||
|
||
def prompt_reduce(is_all, prompt: gr.Dataset, ipaddr: gr.Request):  # is_all, ipaddr: gr.Request
    """Reload the dataset with the prompts selected by ``is_all``.

    Returns:
        ``(update, prompt, is_all)``: the dataset update (10 samples per page,
        visible), the dataset object with its ``samples`` replaced, and the
        selection passed in.
    """
    rows = prompt_retrieval(is_all=is_all, hosts=ipaddr.client.host)
    prompt.samples = rows
    refreshed = prompt.update(samples=rows, samples_per_page=10, visible=True)
    return refreshed, prompt, is_all
|
||
|
||
|
||
def prompt_save(txt, name, checkbox, prompt: gr.Dataset, ipaddr: gr.Request):
    """Store ``txt`` under ``name`` in the caller's personal prompt file.

    Returns:
        A 5-tuple ``(txt, name, checkbox, update, prompt)``. After a successful
        save the two text values are cleared and the checkbox value becomes
        ``['个人']``; when ``txt`` or ``name`` is empty nothing is saved, the
        inputs are returned unchanged and the dataset shows an error row.
    """
    if not (txt and name):
        warning = html_tag_color("编辑框 or 名称不能为空!!!!!", color="red")
        result = [[f'{warning}', '']]
        prompt.samples = [[f'{warning}', '']]
        return txt, name, checkbox, prompt.update(samples=result, samples_per_page=10, visible=True), prompt

    store = YamlHandle(os.path.join(prompt_path, f'prompt_{ipaddr.client.host}.yaml'))
    store.update(name, txt)
    result = prompt_retrieval(is_all=checkbox, hosts=ipaddr.client.host)
    prompt.samples = result
    return "", "", ['个人'], prompt.update(samples=result, samples_per_page=10, visible=True), prompt
|
||
|
||
def prompt_input(txt, index, data: gr.Dataset):
    """Return the second field of the selected sample, followed by ``txt``.

    When ``txt`` is empty only the sample text is returned; otherwise the two
    are joined with a newline, sample text first.
    """
    selected = str(data.samples[index][1])
    if not txt:
        return selected
    return selected + '\n' + txt
|
||
|
||
def copy_result(history):
    """Copy the last history entry to the system clipboard.

    Args:
        history: Sequence of conversation entries; may be empty or ``None``.

    Returns:
        A status message telling whether something was copied.
    """
    # Truthiness also covers None and empty tuples, which ``!= []`` let through
    # to ``history[-1]``.
    if not history:
        return "无对话记录,复制错误!!"
    pyperclip.copy(history[-1])
    return '已将结果复制到剪切板'
|
||
|
||
|
||
def show_prompt_result(index, data: gr.Dataset, chatbot):
    """Append fields 1 and 2 of the selected sample to ``chatbot`` as one pair.

    Returns:
        The same ``chatbot`` list, modified in place.
    """
    sample = data.samples[index]
    pair = (sample[1], sample[2])
    chatbot.append(pair)
    return chatbot
|
||
|
||
# Directory containing this module.
base_path = os.path.split(__file__)[0]
# Folder, next to this module, that holds the prompt files.
prompt_path = os.path.join(base_path, 'prompt_users')
|
||
|
||
class YamlHandle:
    """Read and update a YAML file that stores one mapping.

    Files are read and written as UTF-8, because ``dump`` is called with
    ``allow_unicode=True`` and therefore emits non-ASCII text verbatim.
    NOTE(review): files previously written under a different locale encoding
    would need converting - confirm none exist.
    """

    def __init__(self, file=os.path.join(prompt_path, 'ai_common.yaml')):
        """Remember ``file`` and create it (empty) when it does not exist."""
        if not os.path.exists(file):
            # Create the file with Python instead of shelling out to ``touch``:
            # that is portable and never passes the path through a shell.
            try:
                parent = os.path.dirname(file)
                if parent:
                    os.makedirs(parent, exist_ok=True)
                with open(file, mode='a', encoding='utf-8'):
                    pass
            except OSError as err:
                # Creation stays best-effort; a later open() reports the
                # problem to the caller.
                logger.error(f"Unable to create {file}: {err}")
        self.file = file

    def load(self) -> dict:
        """Return the parsed document; an empty file yields ``{}``."""
        with open(file=self.file, mode='r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
        # safe_load returns None for an empty document.
        return {} if data is None else data

    def _merge_and_save(self, changes):
        """Merge ``changes`` into the stored mapping, write it back, return it."""
        data = self.load()
        if not data:
            data = {}
        data.update(changes)
        with open(file=self.file, mode='w', encoding='utf-8') as f:
            yaml.dump(data, f, allow_unicode=True)
        return data

    def update(self, key, value):
        """Set ``key`` to ``value`` in the file and return the full mapping."""
        return self._merge_and_save({key: value})

    def dump_dict(self, new_dict):
        """Merge ``new_dict`` into the file and return the full mapping."""
        return self._merge_and_save(new_dict)
|
||
|
||
class JsonHandle:
    """Load a JSON file, creating an empty one first when it is missing."""

    def __init__(self, file=os.path.join(prompt_path, 'prompts-PlexPt.json')):
        """Remember ``file`` and create it (empty) when it does not exist."""
        if not os.path.exists(file):
            # Create the file with Python instead of shelling out to ``touch``:
            # that is portable and never passes the path through a shell.
            try:
                parent = os.path.dirname(file)
                if parent:
                    os.makedirs(parent, exist_ok=True)
                with open(file, mode='a', encoding='utf-8'):
                    pass
            except OSError as err:
                # Creation stays best-effort; load() reports the problem.
                logger.error(f"Unable to create {file}: {err}")
        self.file = file

    def load(self):
        """Parse the file as UTF-8 JSON and return the result.

        Raises:
            json.JSONDecodeError: If the file is empty or not valid JSON.
        """
        with open(file=self.file, mode='r', encoding='utf-8') as f:
            return json.load(f)
|
||
|
||
class FileHandle:
    """Small helpers for printing a local file and downloading the prompt list."""

    def __init__(self, file=None):
        self.file = file

    def read(self):
        """Print the whole content of ``self.file``."""
        with open(file=self.file, mode='r') as f:
            print(f.read())

    def read_link(self):
        """Download the prompt list into ``<base_path>/gpt_log``.

        Returns:
            The path of the written file.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        # NOTE(review): this is a GitHub "blob" page URL, which likely returns
        # the HTML viewer page rather than the raw JSON - confirm.
        link = 'https://github.com/PlexPt/awesome-chatgpt-prompts-zh/blob/main/prompts-zh.json'
        pieces = link.split('/')
        name = pieces[3] + pieces[-1]
        target_dir = os.path.join(base_path, 'gpt_log')
        # The directory may not exist yet on a fresh checkout.
        os.makedirs(target_dir, exist_ok=True)
        new_file = os.path.join(target_dir, name)
        # NOTE(review): verify=False turns off TLS certificate checks; kept to
        # preserve behaviour, but it should be removed if not required.
        # A timeout keeps a stalled connection from hanging the caller.
        response = requests.get(url=link, verify=False, timeout=30)
        # Do not save an error page as if it were the data.
        response.raise_for_status()
        with open(new_file, "wb") as f:
            f.write(response.content)
        return new_file
|
||
|
||
|
||
if __name__ == '__main__':
    # Sample text with a credential-style header. The earlier multi-line sample
    # was overwritten before use, so only the effective value is kept.
    txt = "Authorization: WPS-2:AqY7ik9XQ92tvO7+NlCRvA==:b2f626f496de9c256605a15985c855a8b3e4be99 客户发顺丰啦 这是其他文本哦"

    print(YamlHandle().load())
|
||
|
||
|