#! .\venv\
# encoding: utf-8
# @Time   : 2023/4/18
# @Author : Spike
# @Descr   :
import difflib
import functools
import hashlib
import json
import os.path
import subprocess
import psutil
import re
import tempfile
import shutil
from contextlib import ExitStack
import logging
import yaml
import requests
from sklearn.feature_extraction.text import CountVectorizer
import numpy as np
from scipy.linalg import norm
import pyperclip
import random
import gradio as gr

logger = logging

# contextlib is a standard-library module offering helper functions and
# decorators for writing context managers and for common context-related tasks
# such as resource management and exception handling.
# Docs: https://docs.python.org/3/library/contextlib.html


class Shell(object):
    """Run a command line in a subprocess and collect its output.

    Args:
        args: The command line, passed to the system shell as-is.
        stream: When true, ``read`` logs stdout line by line while it is
            being produced instead of reading it all at once.
    """

    def __init__(self, args, stream=False):
        self.args = args
        # NOTE(security): shell=True hands ``args`` to the system shell
        # verbatim; callers must never interpolate untrusted text into it.
        self.subp = subprocess.Popen(args, shell=True, stdin=subprocess.PIPE,
                                     stderr=subprocess.PIPE, stdout=subprocess.PIPE,
                                     encoding='utf-8', errors='ignore', close_fds=True)
        self.__stream = stream
        self.__temp = ''

    def read(self):
        """Wait for the command and return ``(status, text)``.

        Returns:
            ``(3, stdout + stderr)`` in streaming mode; otherwise
            ``(1, stdout)`` when stdout is non-empty, ``(0, stderr)`` when only
            stderr is non-empty, and ``(2, '\\n\\n')`` when both are empty.
        """
        logger.debug(f'The command being executed is: "{self.args}"')
        if self.__stream:
            try:
                with self.subp.stdout as std:
                    for line in std:
                        logger.info(line.rstrip())
                        self.__temp += line
            except KeyboardInterrupt:
                # Interrupted by the user: return whatever was captured so far.
                # (A ``return`` inside ``finally`` used to run a second
                # stderr.read() here, which discarded the stderr text.)
                pass
            return 3, self.__temp + self.subp.stderr.read()

        sysout = self.subp.stdout.read()
        syserr = self.subp.stderr.read()
        if sysout:
            logger.debug(f"{self.args} \n{sysout}")
            return 1, sysout
        if syserr:
            logger.error(f"{self.args} \n{syserr}")
            return 0, syserr
        # Both streams are empty here, so the text is just two newlines.
        logger.debug(f"{self.args} \n{[sysout], [syserr]}")
        return 2, '\n{}\n{}'.format(sysout, syserr)

    def sync(self):
        """Yield the accumulated output after each stdout line, then each stderr line."""
        logger.debug('The command being executed is: "{}"'.format(self.args))
        for line in self.subp.stdout:
            logger.debug(line.rstrip())
            self.__temp += line
            yield self.__temp
        for line in self.subp.stderr:
            logger.debug(line.rstrip())
            self.__temp += line
            yield self.__temp


def context_with(*parms):
    """Build a decorator that runs a method inside several nested ``with`` blocks.

    Args:
        *parms: Attribute names on the instance; each attribute must be a
            context manager.

    Returns:
        A decorator for methods.
    """

    def decorator(cls_method):
        """Wrap ``cls_method`` so it executes with every named context entered."""

        @functools.wraps(cls_method)
        def wrapper(cls='', *args, **kwargs):
            """Enter the contexts in order, call the original method, return its result."""
            # Resolve every attribute before entering any context, so a missing
            # attribute fails before anything has been acquired.
            with_list = [getattr(cls, arg) for arg in parms]
            with ExitStack() as stack:
                for context in with_list:
                    stack.enter_context(context)
                return cls_method(cls, *args, **kwargs)

        return wrapper

    return decorator


def copy_temp_file(file):
    """Copy ``file`` into a fresh temporary directory.

    Returns:
        The path of the copy, or ``None`` when ``file`` does not exist.
    """
    if not os.path.exists(file):
        return None
    exdir = tempfile.mkdtemp()
    return shutil.copy(file, os.path.join(exdir, os.path.basename(file)))


def md5_str(st):
    """Return the hexadecimal MD5 digest of ``str(st)``."""
    return hashlib.md5(str(st).encode()).hexdigest()


def html_tag_color(tag, color=None):
    """Wrap ``tag`` in an HTML span rendered in ``color``.

    Args:
        tag: Text to highlight.
        color: Any CSS colour value; a random ``rgb(...)`` colour is picked
            when omitted.

    Returns:
        The HTML snippet.
    """
    if not color:
        rgb = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        color = f"rgb{rgb}"
    # The colour used to be computed and then dropped; apply it to the text.
    return f'<span style="color: {color};"> {tag} </span>'


def ipaddr():
    """Return a local IP address, or ``None`` when no interface qualifies.

    The first interface whose first address entry has a broadcast address
    (field 3 of the psutil address tuple) wins; its address (field 1) is returned.
    """
    for addrs in psutil.net_if_addrs().values():
        if addrs and addrs[0][3]:
            return addrs[0][1]
    return None


def encryption_str(txt: str):
    """Mask credential values in ``txt``.

    Every ``Authorization``, ``WPS-Sid`` or ``Cookie`` keyword (case-insensitive)
    followed by ``:`` or whitespace and a value is rewritten as ``<keyword>: XXXXXXXX``.
    """
    txt = str(txt)
    pattern = re.compile(rf"(Authorization|WPS-Sid|Cookie)(:|\s+)\s*(\S+)[\s\S]*?(?=\n|$|\s)", re.IGNORECASE)
    return pattern.sub(lambda m: m.group(1) + ": XXXXXXXX", txt)


def tree_out(dir=os.path.dirname(__file__), line=2, more=''):
    """Write the output of the ``tree`` command to ``.tree.md`` next to this file.

    Args:
        dir: Directory to list.
        line: Maximum depth passed to ``tree -L``.
        more: Extra command-line options appended verbatim.
    """
    out = Shell(f'tree {dir} -F -I "__*|.*|venv|*.png|*.xlsx" -L {line} {more}').read()[1]
    localfile = os.path.join(os.path.dirname(__file__), '.tree.md')
    with open(localfile, 'w') as f:
        f.write('```\n')
        for idx, text in enumerate(out.splitlines()):
            if idx == 0:
                # ``tree -F`` prints the root as "/path/to/dir/"; keep only the
                # directory name, falling back to the raw line if it has no "/".
                parts = text.split('/')
                text = parts[-2] if len(parts) >= 2 else text
            f.write(text + '\n')
        f.write('```\n')


def chat_history(log: list, split=0):
    """Flatten a list of ``(question, answer)`` pairs into two text blobs.

    Args:
        log: Conversation pairs.
        split: When non-zero, only ``log[split:]`` is used.

    Returns:
        ``(chat, history)``: all first elements and all second elements, each
        entry followed by a blank line.
    """
    if split:
        log = log[split:]
    chat = ''.join(f'{turn[0]}\n\n' for turn in log)
    history = ''.join(f'{turn[1]}\n\n' for turn in log)
    return chat, history


def df_similarity(s1, s2):
    """Return the cosine similarity of the character counts of two strings.

    Deprecated: the underlying library call emits warnings.
    """

    def add_space(s):
        return ' '.join(list(s))

    # Put a space between characters so each character becomes one token.
    s1, s2 = add_space(s1), add_space(s2)
    # Build the term-frequency matrix.
    cv = CountVectorizer(tokenizer=lambda s: s.split())
    corpus = [s1, s2]
    vectors = cv.fit_transform(corpus).toarray()
    # Cosine similarity of the two term-frequency vectors.
    return np.dot(vectors[0], vectors[1]) / (norm(vectors[0]) * norm(vectors[1]))


def check_json_format(file):
    """Load a prompt JSON file and convert it to ``{act: prompt}``.

    Returns:
        The converted dict, or ``{}`` when the file is not a non-empty list of dicts.
    """
    new_dict = {}
    data = JsonHandle(file).load()
    if isinstance(data, list) and data and isinstance(data[0], dict):
        for item in data:
            new_dict[item['act']] = item['prompt']
    return new_dict


def json_convert_dict():
    """Merge every ``prompt*json`` file found under ``prompt_path`` into one dict."""
    new_dict = {}
    for root, dirs, files in os.walk(prompt_path):
        for f in files:
            if f.startswith('prompt') and f.endswith('json'):
                # os.walk yields bare file names; join with ``root`` so the
                # file is opened where it was found, not in the current directory.
                new_dict.update(check_json_format(os.path.join(root, f)))
    return new_dict


def draw_results(txt, prompt: gr.Dataset, percent, switch, ipaddr: gr.Request):
    """Search the stored prompts for ``txt`` and refresh the dataset component."""
    data = diff_list(txt, percent=percent, switch=switch, hosts=ipaddr.client.host)
    prompt.samples = data
    return prompt.update(samples=data, samples_per_page=10), prompt


def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15, hosts=''):
    """Group similar prompts, then return the ones that contain ``txt``.

    Args:
        txt: Substring to look for; an empty string matches every prompt.
        percent: Similarity ratio at or above which two prompts are counted
            as the same entry.
        switch: Scope selection forwarded to ``prompt_retrieval``; when truthy
            the retrieved prompts are appended to the results.
        lst: Prompt mapping to search. When empty, the common YAML file and
            the per-host private YAML file are loaded instead.
        sp: Number of characters of context shown around a match.
        hosts: Client host used to locate per-user files.

    Returns:
        A list of ``[display_html, prompt, stored_value]`` rows.
    """
    count_dict = {}
    if not lst:
        # A missing or empty YAML file loads as None; treat it as an empty mapping.
        lst = YamlHandle().load() or {}
        lst.update(YamlHandle(os.path.join(prompt_path, f"ai_private_{hosts}.yaml")).load() or {})
    # Only a mapping can supply stored values for the result rows.
    stored = lst if isinstance(lst, dict) else {}

    # Cluster the prompts by similarity; the longest member represents its cluster.
    for i in lst:
        found = False
        for key in count_dict.keys():
            str_tf = difflib.SequenceMatcher(None, i, key).ratio()
            if str_tf >= percent:
                if len(i) > len(key):
                    count_dict[i] = count_dict[key] + 1
                    count_dict.pop(key)
                else:
                    count_dict[key] += 1
                found = True
                # Leave the loop immediately: the dict may just have been mutated.
                break
        if not found:
            count_dict[i] = 1

    sorted_dict = sorted(count_dict.items(), key=lambda x: x[1], reverse=True)
    if switch:
        sorted_dict += prompt_retrieval(is_all=switch, hosts=hosts, search=True)

    dateset_list = []
    for key in sorted_dict:
        text = key[0]
        # Look for the keyword.
        index = text.find(txt)
        if index == -1:
            continue
        # ``sp`` controls where the preview window starts and where it is cut.
        start = max(index - sp, 0)
        end = text[-sp:] if len(text) > sp * 2 else ''
        # No keyword: preview every prompt; otherwise highlight the match.
        if txt == '':
            show = text[0:sp]
            if len(text) >= sp:
                show += " . . . " + end
        else:
            show = str(text[start:index + sp]).replace(txt, html_tag_color(txt))
        show += f" {html_tag_color(' X ' + str(key[1]))}"
        value = stored.get(text) or "这个prompt还没有对话过呢,快去试试吧~"
        dateset_list.append([show, text, value])
    return dateset_list


def search_list(txt, sp=15):
    """Return ``[display_html, prompt, value]`` rows for common prompts containing ``txt``."""
    # An empty YAML file loads as None; treat it as having no prompts.
    lst = YamlHandle().load() or {}
    dateset_list = []
    for key in lst:
        index = key.find(txt)
        if index == -1:
            continue
        start = max(index - sp, 0)
        show = str(key[start:index + sp]).replace(txt, html_tag_color(txt))
        dateset_list.append([show, key, lst[key]])
    return dateset_list


def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
    """Import an uploaded ``json`` or ``yaml`` prompt file into the user's prompt file.

    Returns:
        ``(dataset_update, prompt, ['个人'])`` on success. When the upload cannot
        be parsed into a non-empty dict (including unsupported extensions) the
        dataset samples are replaced by an error row and
        ``(prompt.samples, prompt, [])`` is returned.
    """
    hosts = ipaddr.client.host
    user_file = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
    if file.name.endswith('json'):
        upload_data = check_json_format(file.name)
    elif file.name.endswith('yaml'):
        upload_data = YamlHandle(file.name).load()
    else:
        # Unsupported extension: report it as a parse failure instead of
        # silently returning None.
        upload_data = None

    if upload_data and isinstance(upload_data, dict):
        YamlHandle(user_file).dump_dict(upload_data)
        ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
        return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']

    prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
    return prompt.samples, prompt, []


def prompt_retrieval(is_all, hosts='', search=False):
    """Load stored prompts for the selected scope.

    Args:
        is_all: Scope selection; ``'所有人'`` loads every ``prompt*yaml`` file
            under ``prompt_path``, otherwise ``'个人'`` loads only this host's file.
        hosts: Client host used to locate the per-user file.
        search: When true each row is ``[value, key]`` instead of ``[key, value]``.

    Returns:
        A list of two-element rows; empty when nothing was found.
    """
    is_all = is_all or []
    count_dict = {}
    user_path = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
    if '所有人' in is_all:
        for root, dirs, files in os.walk(prompt_path):
            for f in files:
                if f.startswith('prompt') and f.endswith('yaml'):
                    data = YamlHandle(file=os.path.join(root, f)).load()
                    if data:
                        count_dict.update(data)
    elif '个人' in is_all:
        data = YamlHandle(file=user_path).load()
        if data:
            count_dict.update(data)

    if search:
        return [[value, key] for key, value in count_dict.items()]
    return [[key, value] for key, value in count_dict.items()]


def prompt_reduce(is_all, prompt: gr.Dataset, ipaddr: gr.Request):
    """Reload the dataset component with the prompts of the selected scope."""
    data = prompt_retrieval(is_all=is_all, hosts=ipaddr.client.host)
    prompt.samples = data
    return prompt.update(samples=data, samples_per_page=10, visible=True), prompt, is_all


def prompt_save(txt, name, checkbox, prompt: gr.Dataset, ipaddr: gr.Request):
    """Save ``txt`` under ``name`` in the user's prompt file and refresh the dataset.

    Returns:
        ``("", "", ['个人'], dataset_update, prompt)`` after a successful save,
        or the unchanged inputs plus an error row when ``txt`` or ``name`` is empty.
    """
    if not txt or not name:
        result = [[f'{html_tag_color("编辑框 or 名称不能为空!!!!!", color="red")}', '']]
        prompt.samples = result
        return txt, name, checkbox, prompt.update(samples=result, samples_per_page=10, visible=True), prompt

    yaml_obj = YamlHandle(os.path.join(prompt_path, f'prompt_{ipaddr.client.host}.yaml'))
    yaml_obj.update(name, txt)
    result = prompt_retrieval(is_all=checkbox, hosts=ipaddr.client.host)
    prompt.samples = result
    return "", "", ['个人'], prompt.update(samples=result, samples_per_page=10, visible=True), prompt


def prompt_input(txt, index, data: gr.Dataset):
    """Return the selected sample's prompt, followed by ``txt`` on a new line if non-empty."""
    data_str = str(data.samples[index][1])
    if txt:
        return data_str + '\n' + txt
    return data_str


def copy_result(history):
    """Copy the last entry of ``history`` to the clipboard and return a status message."""
    if history:
        pyperclip.copy(history[-1])
        return '已将结果复制到剪切板'
    return "无对话记录,复制错误!!"
def show_prompt_result(index, data: gr.Dataset, chatbot):
    """Append the selected sample's (prompt, value) pair to ``chatbot`` and return it."""
    click = data.samples[index]
    chatbot.append((click[1], click[2]))
    return chatbot


base_path = os.path.dirname(__file__)
prompt_path = os.path.join(base_path, 'prompt_users')


def _ensure_file(file):
    """Create ``file`` as an empty file, with its parent directory, if it is missing."""
    if os.path.exists(file):
        return
    parent = os.path.dirname(file)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # Append mode creates the file without truncating it if another process
    # created it in the meantime.
    with open(file, 'a'):
        pass


class YamlHandle:
    """Read and update a YAML file that stores a flat mapping."""

    def __init__(self, file=os.path.join(prompt_path, 'ai_common.yaml')):
        # Create the file directly rather than through a shell ``touch``: that
        # was not portable and interpolated the path into a shell command line.
        _ensure_file(file)
        self.file = file

    def load(self) -> dict:
        """Return the parsed file content; an empty file yields ``{}``."""
        with open(file=self.file, mode='r') as f:
            data = yaml.safe_load(f)
        # safe_load returns None for an empty document.
        return {} if data is None else data

    def update(self, key, value):
        """Set ``key`` to ``value`` in the file and return the resulting mapping."""
        return self.dump_dict({key: value})

    def dump_dict(self, new_dict):
        """Merge ``new_dict`` into the file content, rewrite the file, return the mapping."""
        data = self.load()
        if not data:
            data = {}
        data.update(new_dict)
        with open(file=self.file, mode='w') as f:
            yaml.dump(data, f, allow_unicode=True)
        return data


class JsonHandle:
    """Read a JSON file."""

    def __init__(self, file=os.path.join(prompt_path, 'prompts-PlexPt.json')):
        # A missing file is created empty, as before; ``load`` will then raise
        # a JSON decode error because an empty file is not valid JSON.
        _ensure_file(file)
        self.file = file

    def load(self):
        """Return the parsed JSON content of the file."""
        with open(file=self.file, mode='r') as f:
            data = json.load(f)
        return data


class FileHandle:
    """Small helpers for reading a local file and downloading the prompt list."""

    def __init__(self, file=None):
        self.file = file

    def read(self):
        """Print the content of ``self.file``."""
        with open(file=self.file, mode='r') as f:
            print(f.read())

    def read_link(self):
        """Download the prompt list and store it under ``gpt_log`` next to this file."""
        # NOTE(review): this is a GitHub "blob" URL, which presumably serves an
        # HTML page rather than the raw JSON - confirm this is the intended link.
        link = 'https://github.com/PlexPt/awesome-chatgpt-prompts-zh/blob/main/prompts-zh.json'
        parts = link.split('/')
        name = parts[3] + parts[-1]
        target_dir = os.path.join(base_path, 'gpt_log')
        os.makedirs(target_dir, exist_ok=True)
        new_file = os.path.join(target_dir, name)
        # NOTE(security): verify=False disables TLS certificate validation.
        # The timeout keeps an unresponsive server from blocking forever.
        response = requests.get(url=link, verify=False, timeout=30)
        with open(new_file, "wb") as f:
            f.write(response.content)


if __name__ == '__main__':
    print(YamlHandle().load())