增加prompt 检索和编辑器|增加prompt 展示
This commit is contained in:
219
func_box.py
219
func_box.py
@ -4,6 +4,7 @@
|
||||
# @Author : Spike
|
||||
# @Descr :
|
||||
import hashlib
|
||||
import json
|
||||
import os.path
|
||||
import subprocess
|
||||
import psutil
|
||||
@ -13,10 +14,14 @@ import shutil
|
||||
from contextlib import ExitStack
|
||||
import logging
|
||||
import yaml
|
||||
import requests
|
||||
logger = logging
|
||||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
import numpy as np
|
||||
from scipy.linalg import norm
|
||||
import pyperclip
|
||||
import random
|
||||
import gradio as gr
|
||||
"""contextlib 是 Python 标准库中的一个模块,提供了一些工具函数和装饰器,用于支持编写上下文管理器和处理上下文的常见任务,例如资源管理、异常处理等。
|
||||
官网:https://docs.python.org/3/library/contextlib.html"""
|
||||
|
||||
@ -120,6 +125,13 @@ def md5_str(st):
|
||||
return result
|
||||
|
||||
|
||||
def html_tag_color(tag, color=None):
|
||||
if not color:
|
||||
rgb = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
|
||||
color = f"rgb{rgb}"
|
||||
tag = f'<span style="background-color: {color}; font-weight: bold; color: black"> {tag} </span>'
|
||||
return tag
|
||||
|
||||
def ipaddr(): # 获取本地ipx
|
||||
ip = psutil.net_if_addrs()
|
||||
for i in ip:
|
||||
@ -172,9 +184,33 @@ def df_similarity(s1, s2):
|
||||
return np.dot(vectors[0], vectors[1]) / (norm(vectors[0]) * norm(vectors[1]))
|
||||
|
||||
|
||||
def diff_list(lst: list, percent=0.70):
|
||||
def check_json_format(file):
|
||||
new_dict = {}
|
||||
data = JsonHandle(file).load()
|
||||
if type(data) is list and len(data) > 0:
|
||||
if type(data[0]) is dict:
|
||||
for i in data:
|
||||
new_dict.update({i['act']: i['prompt']})
|
||||
return new_dict
|
||||
|
||||
def json_convert_dict():
|
||||
new_dict = {}
|
||||
for root, dirs, files in os.walk(prompt_path):
|
||||
for f in files:
|
||||
if f.startswith('prompt') and f.endswith('json'):
|
||||
new_dict.update(check_json_format(f))
|
||||
return new_dict
|
||||
|
||||
def draw_results(txt, prompt: gr.Dataset, percent, switch, ipaddr: gr.Request):
|
||||
data = diff_list(txt, percent=percent, switch=switch, hosts=ipaddr.client.host)
|
||||
prompt.samples = data
|
||||
return prompt.update(samples=data, samples_per_page=10), prompt
|
||||
|
||||
def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15, hosts=''):
|
||||
import difflib
|
||||
count_dict = {}
|
||||
if not lst: lst = YamlHandle().load()
|
||||
# diff 数据,根据precent系数归类数据
|
||||
for i in lst:
|
||||
found = False
|
||||
for key in count_dict.keys():
|
||||
@ -186,16 +222,141 @@ def diff_list(lst: list, percent=0.70):
|
||||
count_dict[key] += 1
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
count_dict[i] = 1
|
||||
return
|
||||
if not found: count_dict[i] = 1
|
||||
sorted_dict = sorted(count_dict.items(), key=lambda x: x[1], reverse=True)
|
||||
if switch:
|
||||
sorted_dict += prompt_retrieval(is_all=switch, hosts=hosts, search=True)
|
||||
dateset_list = []
|
||||
for key in sorted_dict:
|
||||
# 开始匹配关键字
|
||||
index = key[0].find(txt)
|
||||
if index != -1:
|
||||
# sp=split 用于判断在哪里启动、在哪里断开
|
||||
if index-sp > 0: start = index-sp
|
||||
else: start = 0
|
||||
if len(key[0]) > sp * 2: end = key[0][-sp:]
|
||||
else: end = ''
|
||||
# 判断有没有传需要匹配的字符串,有则筛选、无则全返
|
||||
if txt == '' and len(key[0]) >= sp: show = key[0][0:sp] + " . . . " + end
|
||||
elif txt == '' and len(key[0]) < sp: show = key[0][0:sp]
|
||||
else: show = str(key[0][start:index + sp]).replace(txt, html_tag_color(txt))
|
||||
show += f" {html_tag_color(' X ' + str(key[1]))}"
|
||||
if lst.get(key[0]): be_value = lst[key[0]]
|
||||
else: be_value = "这个prompt还没有对话过呢,快去试试吧~"
|
||||
value = be_value
|
||||
dateset_list.append([show, key[0], value])
|
||||
return dateset_list
|
||||
|
||||
|
||||
def search_list(txt, sp=15):
|
||||
lst = YamlHandle().load()
|
||||
dateset_list = []
|
||||
for key in lst:
|
||||
index = key.find(txt)
|
||||
if index != -1:
|
||||
if index-sp > 0: start = index-sp
|
||||
else: start = 0
|
||||
show = str(key[start:index+sp]).replace(txt, html_tag_color(txt))
|
||||
value = lst[key]
|
||||
dateset_list.append([show, key, value])
|
||||
return dateset_list
|
||||
|
||||
|
||||
def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
|
||||
hosts = ipaddr.client.host
|
||||
user_file = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
|
||||
if file.name.endswith('json'):
|
||||
upload_data = check_json_format(file.name)
|
||||
if upload_data != {}:
|
||||
YamlHandle(user_file).dump_dict(upload_data)
|
||||
ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
|
||||
return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']
|
||||
else:
|
||||
prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
|
||||
return prompt.samples, prompt, []
|
||||
elif file.name.endswith('yaml'):
|
||||
upload_data = YamlHandle(file.name).load()
|
||||
if upload_data != {} and type(upload_data) is dict:
|
||||
YamlHandle(user_file).dump_dict(upload_data)
|
||||
ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
|
||||
return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']
|
||||
else:
|
||||
prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
|
||||
return prompt.samples, prompt, []
|
||||
|
||||
def prompt_retrieval(is_all, hosts='', search=False):
|
||||
count_dict = {}
|
||||
user_path = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
|
||||
if '所有人' in is_all:
|
||||
for root, dirs, files in os.walk(prompt_path):
|
||||
for f in files:
|
||||
if f.startswith('prompt') and f.endswith('yaml'):
|
||||
data = YamlHandle(file=os.path.join(root, f)).load()
|
||||
if data: count_dict.update(data)
|
||||
elif '个人' in is_all:
|
||||
data = YamlHandle(file=user_path).load()
|
||||
if data: count_dict.update(data)
|
||||
retrieval = []
|
||||
if count_dict != {}:
|
||||
for key in count_dict:
|
||||
if not search:
|
||||
retrieval.append([key, count_dict[key]])
|
||||
else:
|
||||
retrieval.append([count_dict[key], key])
|
||||
return retrieval
|
||||
else:
|
||||
return retrieval
|
||||
|
||||
|
||||
def prompt_reduce(is_all, prompt: gr.Dataset, ipaddr: gr.Request): # is_all, ipaddr: gr.Request
|
||||
data = prompt_retrieval(is_all=is_all, hosts=ipaddr.client.host)
|
||||
prompt.samples = data
|
||||
return prompt.update(samples=data, samples_per_page=10, visible=True), prompt, is_all
|
||||
|
||||
|
||||
def prompt_save(txt, name, checkbox, prompt, ipaddr: gr.Request):
|
||||
if txt and name:
|
||||
yaml_obj = YamlHandle(os.path.join(prompt_path, f'prompt_{ipaddr.client.host}.yaml'))
|
||||
prompt_data = yaml_obj.load()
|
||||
prompt_data.update({name: txt})
|
||||
yaml_obj.update(name, txt)
|
||||
result = prompt_retrieval(is_all=checkbox, hosts=ipaddr.client.host)
|
||||
prompt.samples = result
|
||||
return "", "", ['个人'], prompt.samples, prompt
|
||||
if not txt or not name:
|
||||
prompt.samples = [[f'{html_tag_color("编辑框 or 名称不能为空!!!!!", color="red")}', '']]
|
||||
return txt, name, checkbox, prompt.samples, prompt
|
||||
|
||||
def prompt_input(txt, index, data: gr.Dataset):
|
||||
data_str = str(data.samples[index][1])
|
||||
if txt:
|
||||
txt = data_str+'\n'+txt
|
||||
else:
|
||||
txt = data_str
|
||||
return txt
|
||||
|
||||
def copy_result(history):
|
||||
if history != []:
|
||||
pyperclip.copy(history[-1])
|
||||
return '已将结果复制到剪切板'
|
||||
else:
|
||||
return "无对话记录,复制错误!!"
|
||||
|
||||
def show_prompt_result(index, data: gr.Dataset, chatbot):
|
||||
click = data.samples[index]
|
||||
chatbot.append((click[1], click[2]))
|
||||
return chatbot
|
||||
|
||||
base_path = os.path.dirname(__file__)
|
||||
prompt_path = os.path.join(base_path, 'prompt_users')
|
||||
class YamlHandle:
|
||||
|
||||
def __init__(self, file='/Users/kilig/Job/Python-project/academic_gpt/logs/ai_prompt.yaml'):
|
||||
def __init__(self, file=os.path.join(prompt_path, 'ai_prompt.yaml')):
|
||||
if not os.path.exists(file):
|
||||
Shell(f'touch {file}').read()
|
||||
self.file = file
|
||||
|
||||
|
||||
def load(self) -> dict:
|
||||
with open(file=self.file, mode='r') as f:
|
||||
data = yaml.safe_load(f)
|
||||
@ -210,13 +371,55 @@ class YamlHandle:
|
||||
yaml.dump(date, f, allow_unicode=True)
|
||||
return date
|
||||
|
||||
def dump_dict(self, new_dict):
|
||||
date = self.load()
|
||||
if not date:
|
||||
date = {}
|
||||
date.update(new_dict)
|
||||
with open(file=self.file, mode='w') as f:
|
||||
yaml.dump(date, f, allow_unicode=True)
|
||||
return date
|
||||
|
||||
class JsonHandle:
|
||||
|
||||
def __init__(self, file=os.path.join(prompt_path, 'prompts-PlexPt.json')):
|
||||
if not os.path.exists(file):
|
||||
Shell(f'touch {file}').read()
|
||||
self.file = file
|
||||
|
||||
def load(self):
|
||||
with open(file=self.file, mode='r') as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
|
||||
class FileHandle:
|
||||
|
||||
def __init__(self, file=None):
|
||||
self.file = file
|
||||
|
||||
def read(self):
|
||||
with open(file=self.file, mode='r') as f:
|
||||
print(f.read())
|
||||
|
||||
|
||||
def read_link(self):
|
||||
link = 'https://github.com/PlexPt/awesome-chatgpt-prompts-zh/blob/main/prompts-zh.json'
|
||||
name = link.split('/')[3]+link.split('/')[-1]
|
||||
new_file = os.path.join(base_path, 'gpt_log', name)
|
||||
response = requests.get(url=link, verify=False)
|
||||
# with open(file=new_file, mode='w') as file:
|
||||
# for i in response.iter_content(chunk_size=1024):
|
||||
# print(i)
|
||||
# file.write(i.decode(''))
|
||||
with open(new_file, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
txt = "Authorization: WPS-2:AqY7ik9XQ92tvO7+NlCRvA==:b2f626f496de9c256605a15985c855a8b3e4be99\nwps-Sid: V02SgISzdeWrYdwvW_xbib-fGlqUIIw00afc5b890008c1976f\nCookie: wpsua=V1BTVUEvMS4wIChhbmRyb2lkLW9mZmljZToxNy41O2FuZHJvaWQ6MTA7ZjIwZDAyNWQzYTM5MmExMDBiYzgxNWI2NmI3Y2E5ODI6ZG1sMmJ5QldNakF5TUVFPSl2aXZvL1YyMDIwQQ=="
|
||||
txt = "Authorization: WPS-2:AqY7ik9XQ92tvO7+NlCRvA==:b2f626f496de9c256605a15985c855a8b3e4be99"
|
||||
txt = "Authorization: WPS-2:AqY7ik9XQ92tvO7+NlCRvA==:b2f626f496de9c256605a15985c855a8b3e4be99 客户发顺丰啦 这是其他文本哦"
|
||||
|
||||
# print(YamlHandle().update(123123213, 2131231231))
|
||||
# json_convert_dict()
|
||||
tree_out()
|
||||
|
||||
diff_list(YamlHandle().load())
|
||||
Reference in New Issue
Block a user