增加prompt 检索和编辑器|增加prompt 展示

This commit is contained in:
w_xiaolizu
2023-05-11 17:05:19 +08:00
parent 03f0f49847
commit 3cc6eeb314
10 changed files with 1746 additions and 77 deletions

View File

@ -4,6 +4,7 @@
# @Author : Spike
# @Descr :
import hashlib
import json
import os.path
import subprocess
import psutil
@ -13,10 +14,14 @@ import shutil
from contextlib import ExitStack
import logging
import yaml
import requests
logger = logging
from sklearn.feature_extraction.text import CountVectorizer
import numpy as np
from scipy.linalg import norm
import pyperclip
import random
import gradio as gr
"""contextlib 是 Python 标准库中的一个模块,提供了一些工具函数和装饰器,用于支持编写上下文管理器和处理上下文的常见任务,例如资源管理、异常处理等。
官网https://docs.python.org/3/library/contextlib.html"""
@ -120,6 +125,13 @@ def md5_str(st):
return result
def html_tag_color(tag, color=None):
if not color:
rgb = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
color = f"rgb{rgb}"
tag = f'<span style="background-color: {color}; font-weight: bold; color: black">&nbsp;{tag}&ensp;</span>'
return tag
def ipaddr(): # 获取本地ipx
ip = psutil.net_if_addrs()
for i in ip:
@ -172,9 +184,33 @@ def df_similarity(s1, s2):
return np.dot(vectors[0], vectors[1]) / (norm(vectors[0]) * norm(vectors[1]))
def diff_list(lst: list, percent=0.70):
def check_json_format(file):
new_dict = {}
data = JsonHandle(file).load()
if type(data) is list and len(data) > 0:
if type(data[0]) is dict:
for i in data:
new_dict.update({i['act']: i['prompt']})
return new_dict
def json_convert_dict():
new_dict = {}
for root, dirs, files in os.walk(prompt_path):
for f in files:
if f.startswith('prompt') and f.endswith('json'):
new_dict.update(check_json_format(f))
return new_dict
def draw_results(txt, prompt: gr.Dataset, percent, switch, ipaddr: gr.Request):
data = diff_list(txt, percent=percent, switch=switch, hosts=ipaddr.client.host)
prompt.samples = data
return prompt.update(samples=data, samples_per_page=10), prompt
def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15, hosts=''):
import difflib
count_dict = {}
if not lst: lst = YamlHandle().load()
# diff 数据根据precent系数归类数据
for i in lst:
found = False
for key in count_dict.keys():
@ -186,16 +222,141 @@ def diff_list(lst: list, percent=0.70):
count_dict[key] += 1
found = True
break
if not found:
count_dict[i] = 1
return
if not found: count_dict[i] = 1
sorted_dict = sorted(count_dict.items(), key=lambda x: x[1], reverse=True)
if switch:
sorted_dict += prompt_retrieval(is_all=switch, hosts=hosts, search=True)
dateset_list = []
for key in sorted_dict:
# 开始匹配关键字
index = key[0].find(txt)
if index != -1:
# sp=split 用于判断在哪里启动、在哪里断开
if index-sp > 0: start = index-sp
else: start = 0
if len(key[0]) > sp * 2: end = key[0][-sp:]
else: end = ''
# 判断有没有传需要匹配的字符串,有则筛选、无则全返
if txt == '' and len(key[0]) >= sp: show = key[0][0:sp] + " . . . " + end
elif txt == '' and len(key[0]) < sp: show = key[0][0:sp]
else: show = str(key[0][start:index + sp]).replace(txt, html_tag_color(txt))
show += f" {html_tag_color(' X ' + str(key[1]))}"
if lst.get(key[0]): be_value = lst[key[0]]
else: be_value = "这个prompt还没有对话过呢快去试试吧"
value = be_value
dateset_list.append([show, key[0], value])
return dateset_list
def search_list(txt, sp=15):
lst = YamlHandle().load()
dateset_list = []
for key in lst:
index = key.find(txt)
if index != -1:
if index-sp > 0: start = index-sp
else: start = 0
show = str(key[start:index+sp]).replace(txt, html_tag_color(txt))
value = lst[key]
dateset_list.append([show, key, value])
return dateset_list
def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
hosts = ipaddr.client.host
user_file = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
if file.name.endswith('json'):
upload_data = check_json_format(file.name)
if upload_data != {}:
YamlHandle(user_file).dump_dict(upload_data)
ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']
else:
prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
return prompt.samples, prompt, []
elif file.name.endswith('yaml'):
upload_data = YamlHandle(file.name).load()
if upload_data != {} and type(upload_data) is dict:
YamlHandle(user_file).dump_dict(upload_data)
ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
return prompt.update(samples=ret_data, samples_per_page=10, visible=True), prompt, ['个人']
else:
prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
return prompt.samples, prompt, []
def prompt_retrieval(is_all, hosts='', search=False):
count_dict = {}
user_path = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
if '所有人' in is_all:
for root, dirs, files in os.walk(prompt_path):
for f in files:
if f.startswith('prompt') and f.endswith('yaml'):
data = YamlHandle(file=os.path.join(root, f)).load()
if data: count_dict.update(data)
elif '个人' in is_all:
data = YamlHandle(file=user_path).load()
if data: count_dict.update(data)
retrieval = []
if count_dict != {}:
for key in count_dict:
if not search:
retrieval.append([key, count_dict[key]])
else:
retrieval.append([count_dict[key], key])
return retrieval
else:
return retrieval
def prompt_reduce(is_all, prompt: gr.Dataset, ipaddr: gr.Request): # is_all, ipaddr: gr.Request
data = prompt_retrieval(is_all=is_all, hosts=ipaddr.client.host)
prompt.samples = data
return prompt.update(samples=data, samples_per_page=10, visible=True), prompt, is_all
def prompt_save(txt, name, checkbox, prompt, ipaddr: gr.Request):
if txt and name:
yaml_obj = YamlHandle(os.path.join(prompt_path, f'prompt_{ipaddr.client.host}.yaml'))
prompt_data = yaml_obj.load()
prompt_data.update({name: txt})
yaml_obj.update(name, txt)
result = prompt_retrieval(is_all=checkbox, hosts=ipaddr.client.host)
prompt.samples = result
return "", "", ['个人'], prompt.samples, prompt
if not txt or not name:
prompt.samples = [[f'{html_tag_color("编辑框 or 名称不能为空!!!!!", color="red")}', '']]
return txt, name, checkbox, prompt.samples, prompt
def prompt_input(txt, index, data: gr.Dataset):
data_str = str(data.samples[index][1])
if txt:
txt = data_str+'\n'+txt
else:
txt = data_str
return txt
def copy_result(history):
if history != []:
pyperclip.copy(history[-1])
return '已将结果复制到剪切板'
else:
return "无对话记录,复制错误!!"
def show_prompt_result(index, data: gr.Dataset, chatbot):
click = data.samples[index]
chatbot.append((click[1], click[2]))
return chatbot
base_path = os.path.dirname(__file__)
prompt_path = os.path.join(base_path, 'prompt_users')
class YamlHandle:
def __init__(self, file='/Users/kilig/Job/Python-project/academic_gpt/logs/ai_prompt.yaml'):
def __init__(self, file=os.path.join(prompt_path, 'ai_prompt.yaml')):
if not os.path.exists(file):
Shell(f'touch {file}').read()
self.file = file
def load(self) -> dict:
with open(file=self.file, mode='r') as f:
data = yaml.safe_load(f)
@ -210,13 +371,55 @@ class YamlHandle:
yaml.dump(date, f, allow_unicode=True)
return date
def dump_dict(self, new_dict):
date = self.load()
if not date:
date = {}
date.update(new_dict)
with open(file=self.file, mode='w') as f:
yaml.dump(date, f, allow_unicode=True)
return date
class JsonHandle:
def __init__(self, file=os.path.join(prompt_path, 'prompts-PlexPt.json')):
if not os.path.exists(file):
Shell(f'touch {file}').read()
self.file = file
def load(self):
with open(file=self.file, mode='r') as f:
data = json.load(f)
return data
class FileHandle:
def __init__(self, file=None):
self.file = file
def read(self):
with open(file=self.file, mode='r') as f:
print(f.read())
def read_link(self):
link = 'https://github.com/PlexPt/awesome-chatgpt-prompts-zh/blob/main/prompts-zh.json'
name = link.split('/')[3]+link.split('/')[-1]
new_file = os.path.join(base_path, 'gpt_log', name)
response = requests.get(url=link, verify=False)
# with open(file=new_file, mode='w') as file:
# for i in response.iter_content(chunk_size=1024):
# print(i)
# file.write(i.decode(''))
with open(new_file, "wb") as f:
f.write(response.content)
if __name__ == '__main__':
txt = "Authorization: WPS-2:AqY7ik9XQ92tvO7+NlCRvA==:b2f626f496de9c256605a15985c855a8b3e4be99\nwps-Sid: V02SgISzdeWrYdwvW_xbib-fGlqUIIw00afc5b890008c1976f\nCookie: wpsua=V1BTVUEvMS4wIChhbmRyb2lkLW9mZmljZToxNy41O2FuZHJvaWQ6MTA7ZjIwZDAyNWQzYTM5MmExMDBiYzgxNWI2NmI3Y2E5ODI6ZG1sMmJ5QldNakF5TUVFPSl2aXZvL1YyMDIwQQ=="
txt = "Authorization: WPS-2:AqY7ik9XQ92tvO7+NlCRvA==:b2f626f496de9c256605a15985c855a8b3e4be99"
txt = "Authorization: WPS-2:AqY7ik9XQ92tvO7+NlCRvA==:b2f626f496de9c256605a15985c855a8b3e4be99 客户发顺丰啦 这是其他文本哦"
# print(YamlHandle().update(123123213, 2131231231))
# json_convert_dict()
tree_out()
diff_list(YamlHandle().load())