删除多余代码|增加代码注释
This commit is contained in:
@ -123,7 +123,7 @@ class ChatBot(ChatBotFrame):
|
|||||||
inputs=[self.pro_prompt_list, self.pro_prompt_state, self.pro_results],
|
inputs=[self.pro_prompt_list, self.pro_prompt_state, self.pro_results],
|
||||||
outputs=[self.pro_results])
|
outputs=[self.pro_results])
|
||||||
self.pro_new_btn.click(fn=func_box.prompt_save,
|
self.pro_new_btn.click(fn=func_box.prompt_save,
|
||||||
inputs=[self.pro_edit_txt, self.pro_name_txt, self.pro_private_check, self.pro_fp_state],
|
inputs=[self.pro_edit_txt, self.pro_name_txt, self.pro_fp_state],
|
||||||
outputs=[self.pro_edit_txt, self.pro_name_txt, self.pro_private_check,
|
outputs=[self.pro_edit_txt, self.pro_name_txt, self.pro_private_check,
|
||||||
self.pro_func_prompt, self.pro_fp_state])
|
self.pro_func_prompt, self.pro_fp_state])
|
||||||
|
|
||||||
|
|||||||
219
func_box.py
219
func_box.py
@ -17,6 +17,7 @@ from contextlib import ExitStack
|
|||||||
import logging
|
import logging
|
||||||
import yaml
|
import yaml
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
logger = logging
|
logger = logging
|
||||||
from sklearn.feature_extraction.text import CountVectorizer
|
from sklearn.feature_extraction.text import CountVectorizer
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -26,14 +27,16 @@ import random
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import toolbox
|
import toolbox
|
||||||
from prompt_generator import SqliteHandle
|
from prompt_generator import SqliteHandle
|
||||||
|
|
||||||
"""contextlib 是 Python 标准库中的一个模块,提供了一些工具函数和装饰器,用于支持编写上下文管理器和处理上下文的常见任务,例如资源管理、异常处理等。
|
"""contextlib 是 Python 标准库中的一个模块,提供了一些工具函数和装饰器,用于支持编写上下文管理器和处理上下文的常见任务,例如资源管理、异常处理等。
|
||||||
官网:https://docs.python.org/3/library/contextlib.html"""
|
官网:https://docs.python.org/3/library/contextlib.html"""
|
||||||
|
|
||||||
|
|
||||||
class Shell(object):
|
class Shell(object):
|
||||||
def __init__(self, args, stream=False):
|
def __init__(self, args, stream=False):
|
||||||
self.args = args
|
self.args = args
|
||||||
self.subp = subprocess.Popen(args, shell=True,
|
self.subp = subprocess.Popen(args, shell=True,
|
||||||
stdin=subprocess.PIPE, stderr=subprocess.PIPE,
|
stdin=subprocess.PIPE, stderr=subprocess.PIPE,
|
||||||
stdout=subprocess.PIPE, encoding='utf-8',
|
stdout=subprocess.PIPE, encoding='utf-8',
|
||||||
errors='ignore', close_fds=True)
|
errors='ignore', close_fds=True)
|
||||||
self.__stream = stream
|
self.__stream = stream
|
||||||
@ -49,9 +52,9 @@ class Shell(object):
|
|||||||
logger.info(i.rstrip())
|
logger.info(i.rstrip())
|
||||||
self.__temp += i
|
self.__temp += i
|
||||||
except KeyboardInterrupt as p:
|
except KeyboardInterrupt as p:
|
||||||
return 3, self.__temp+self.subp.stderr.read()
|
return 3, self.__temp + self.subp.stderr.read()
|
||||||
finally:
|
finally:
|
||||||
return 3, self.__temp+self.subp.stderr.read()
|
return 3, self.__temp + self.subp.stderr.read()
|
||||||
else:
|
else:
|
||||||
sysout = self.subp.stdout.read()
|
sysout = self.subp.stdout.read()
|
||||||
syserr = self.subp.stderr.read()
|
syserr = self.subp.stderr.read()
|
||||||
@ -77,7 +80,12 @@ class Shell(object):
|
|||||||
self.__temp += i
|
self.__temp += i
|
||||||
yield self.__temp
|
yield self.__temp
|
||||||
|
|
||||||
|
|
||||||
def timeStatistics(func):
|
def timeStatistics(func):
|
||||||
|
"""
|
||||||
|
统计函数执行时常的装饰器
|
||||||
|
"""
|
||||||
|
|
||||||
def statistics(*args, **kwargs):
|
def statistics(*args, **kwargs):
|
||||||
startTiem = time.time()
|
startTiem = time.time()
|
||||||
obj = func(*args, **kwargs)
|
obj = func(*args, **kwargs)
|
||||||
@ -85,8 +93,10 @@ def timeStatistics(func):
|
|||||||
ums = startTiem - endTiem
|
ums = startTiem - endTiem
|
||||||
print('func:{} > Time-consuming: {}'.format(func, ums))
|
print('func:{} > Time-consuming: {}'.format(func, ums))
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
return statistics
|
return statistics
|
||||||
|
|
||||||
|
|
||||||
def context_with(*parms):
|
def context_with(*parms):
|
||||||
"""
|
"""
|
||||||
一个装饰器,根据传递的参数列表,在类方法上下文中嵌套多个 with 语句。
|
一个装饰器,根据传递的参数列表,在类方法上下文中嵌套多个 with 语句。
|
||||||
@ -95,6 +105,7 @@ def context_with(*parms):
|
|||||||
Returns:
|
Returns:
|
||||||
一个装饰器函数。
|
一个装饰器函数。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(cls_method):
|
def decorator(cls_method):
|
||||||
"""
|
"""
|
||||||
装饰器函数,用于将一个类方法转换为一个嵌套多个 with 语句的方法。
|
装饰器函数,用于将一个类方法转换为一个嵌套多个 with 语句的方法。
|
||||||
@ -103,6 +114,7 @@ def context_with(*parms):
|
|||||||
Returns:
|
Returns:
|
||||||
装饰后的类方法。
|
装饰后的类方法。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper(cls='', *args, **kwargs):
|
def wrapper(cls='', *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
装饰后的方法,用于嵌套多个 with 语句,并调用原始的类方法。
|
装饰后的方法,用于嵌套多个 with 语句,并调用原始的类方法。
|
||||||
@ -118,7 +130,9 @@ def context_with(*parms):
|
|||||||
for context in with_list:
|
for context in with_list:
|
||||||
stack.enter_context(context)
|
stack.enter_context(context)
|
||||||
return cls_method(cls, *args, **kwargs)
|
return cls_method(cls, *args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
@ -130,6 +144,7 @@ def copy_temp_file(file):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def md5_str(st):
|
def md5_str(st):
|
||||||
# 创建一个 MD5 对象
|
# 创建一个 MD5 对象
|
||||||
md5 = hashlib.md5()
|
md5 = hashlib.md5()
|
||||||
@ -141,18 +156,24 @@ def md5_str(st):
|
|||||||
|
|
||||||
|
|
||||||
def html_tag_color(tag, color=None):
|
def html_tag_color(tag, color=None):
|
||||||
|
"""
|
||||||
|
将文本转换为带有高亮提示的html代码
|
||||||
|
"""
|
||||||
if not color:
|
if not color:
|
||||||
rgb = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
|
rgb = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
|
||||||
color = f"rgb{rgb}"
|
color = f"rgb{rgb}"
|
||||||
tag = f'<span style="background-color: {color}; font-weight: bold; color: black"> {tag} </span>'
|
tag = f'<span style="background-color: {color}; font-weight: bold; color: black"> {tag} </span>'
|
||||||
return tag
|
return tag
|
||||||
|
|
||||||
def ipaddr(): # 获取本地ipx
|
|
||||||
|
def ipaddr():
|
||||||
|
# 获取本地ipx
|
||||||
ip = psutil.net_if_addrs()
|
ip = psutil.net_if_addrs()
|
||||||
for i in ip:
|
for i in ip:
|
||||||
if ip[i][0][3]:
|
if ip[i][0][3]:
|
||||||
return ip[i][0][1]
|
return ip[i][0][1]
|
||||||
|
|
||||||
|
|
||||||
def encryption_str(txt: str):
|
def encryption_str(txt: str):
|
||||||
"""(关键字)(加密间隔)匹配机制(关键字间隔)"""
|
"""(关键字)(加密间隔)匹配机制(关键字间隔)"""
|
||||||
txt = str(txt)
|
txt = str(txt)
|
||||||
@ -160,7 +181,11 @@ def encryption_str(txt: str):
|
|||||||
result = pattern.sub(lambda x: x.group(1) + ": XXXXXXXX", txt)
|
result = pattern.sub(lambda x: x.group(1) + ": XXXXXXXX", txt)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def tree_out(dir=os.path.dirname(__file__), line=2, more=''):
|
def tree_out(dir=os.path.dirname(__file__), line=2, more=''):
|
||||||
|
"""
|
||||||
|
获取本地文件的树形结构转化为Markdown代码文本
|
||||||
|
"""
|
||||||
out = Shell(f'tree {dir} -F -I "__*|.*|venv|*.png|*.xlsx" -L {line} {more}').read()[1]
|
out = Shell(f'tree {dir} -F -I "__*|.*|venv|*.png|*.xlsx" -L {line} {more}').read()[1]
|
||||||
localfile = os.path.join(os.path.dirname(__file__), '.tree.md')
|
localfile = os.path.join(os.path.dirname(__file__), '.tree.md')
|
||||||
with open(localfile, 'w') as f:
|
with open(localfile, 'w') as f:
|
||||||
@ -168,13 +193,16 @@ def tree_out(dir=os.path.dirname(__file__), line=2, more=''):
|
|||||||
ll = out.splitlines()
|
ll = out.splitlines()
|
||||||
for i in range(len(ll)):
|
for i in range(len(ll)):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
f.write(ll[i].split('/')[-2]+'\n')
|
f.write(ll[i].split('/')[-2] + '\n')
|
||||||
else:
|
else:
|
||||||
f.write(ll[i]+'\n')
|
f.write(ll[i] + '\n')
|
||||||
f.write('```\n')
|
f.write('```\n')
|
||||||
|
|
||||||
|
|
||||||
def chat_history(log: list, split=0):
|
def chat_history(log: list, split=0):
|
||||||
|
"""
|
||||||
|
auto_gpt 使用的代码,后续会迁移
|
||||||
|
"""
|
||||||
if split:
|
if split:
|
||||||
log = log[split:]
|
log = log[split:]
|
||||||
chat = ''
|
chat = ''
|
||||||
@ -189,6 +217,7 @@ def df_similarity(s1, s2):
|
|||||||
"""弃用,会警告,这个库不会用"""
|
"""弃用,会警告,这个库不会用"""
|
||||||
def add_space(s):
|
def add_space(s):
|
||||||
return ' '.join(list(s))
|
return ' '.join(list(s))
|
||||||
|
|
||||||
# 将字中间加入空格
|
# 将字中间加入空格
|
||||||
s1, s2 = add_space(s1), add_space(s2)
|
s1, s2 = add_space(s1), add_space(s2)
|
||||||
# 转化为TF矩阵
|
# 转化为TF矩阵
|
||||||
@ -200,6 +229,9 @@ def df_similarity(s1, s2):
|
|||||||
|
|
||||||
|
|
||||||
def check_json_format(file):
|
def check_json_format(file):
|
||||||
|
"""
|
||||||
|
检查上传的Json文件是否符合规范
|
||||||
|
"""
|
||||||
new_dict = {}
|
new_dict = {}
|
||||||
data = JsonHandle(file).load()
|
data = JsonHandle(file).load()
|
||||||
if type(data) is list and len(data) > 0:
|
if type(data) is list and len(data) > 0:
|
||||||
@ -208,21 +240,49 @@ def check_json_format(file):
|
|||||||
new_dict.update({i['act']: i['prompt']})
|
new_dict.update({i['act']: i['prompt']})
|
||||||
return new_dict
|
return new_dict
|
||||||
|
|
||||||
def json_convert_dict():
|
|
||||||
|
def json_convert_dict(file):
|
||||||
|
"""
|
||||||
|
批量将json转换为字典
|
||||||
|
"""
|
||||||
new_dict = {}
|
new_dict = {}
|
||||||
for root, dirs, files in os.walk(prompt_path):
|
for root, dirs, files in os.walk(file):
|
||||||
for f in files:
|
for f in files:
|
||||||
if f.startswith('prompt') and f.endswith('json'):
|
if f.startswith('prompt') and f.endswith('json'):
|
||||||
new_dict.update(check_json_format(f))
|
new_dict.update(check_json_format(f))
|
||||||
return new_dict
|
return new_dict
|
||||||
|
|
||||||
|
|
||||||
def draw_results(txt, prompt: gr.Dataset, percent, switch, ipaddr: gr.Request):
|
def draw_results(txt, prompt: gr.Dataset, percent, switch, ipaddr: gr.Request):
|
||||||
|
"""
|
||||||
|
绘制搜索结果
|
||||||
|
Args:
|
||||||
|
txt (str): 过滤文本
|
||||||
|
prompt : 原始的dataset对象
|
||||||
|
percent (int): TF系数,用于计算文本相似度
|
||||||
|
switch (list): 过滤个人或所有人的Prompt
|
||||||
|
ipaddr : 请求人信息
|
||||||
|
Returns:
|
||||||
|
注册函数所需的元祖对象
|
||||||
|
"""
|
||||||
data = diff_list(txt, percent=percent, switch=switch, hosts=ipaddr.client.host)
|
data = diff_list(txt, percent=percent, switch=switch, hosts=ipaddr.client.host)
|
||||||
prompt.samples = data
|
prompt.samples = data
|
||||||
return prompt.update(samples=data, visible=True), prompt
|
return prompt.update(samples=data, visible=True), prompt
|
||||||
|
|
||||||
|
|
||||||
def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15, hosts=''):
|
def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15, hosts=''):
|
||||||
|
"""
|
||||||
|
按照搜索结果统计相似度的文本,两组文本相似度>70%的将统计在一起,取最长的作为key
|
||||||
|
Args:
|
||||||
|
txt (str): 过滤文本
|
||||||
|
percent (int): TF系数,用于计算文本相似度
|
||||||
|
switch (list): 过滤个人或所有人的Prompt
|
||||||
|
lst:指定一个列表或字典
|
||||||
|
sp: 截取展示的文本长度
|
||||||
|
hosts : 请求人的ip
|
||||||
|
Returns:
|
||||||
|
返回一个列表
|
||||||
|
"""
|
||||||
import difflib
|
import difflib
|
||||||
count_dict = {}
|
count_dict = {}
|
||||||
if not lst:
|
if not lst:
|
||||||
@ -251,37 +311,41 @@ def diff_list(txt='', percent=0.70, switch: list = None, lst: list = None, sp=15
|
|||||||
index = key[0].find(txt)
|
index = key[0].find(txt)
|
||||||
if index != -1:
|
if index != -1:
|
||||||
# sp=split 用于判断在哪里启动、在哪里断开
|
# sp=split 用于判断在哪里启动、在哪里断开
|
||||||
if index-sp > 0: start = index-sp
|
if index - sp > 0:
|
||||||
else: start = 0
|
start = index - sp
|
||||||
if len(key[0]) > sp * 2: end = key[0][-sp:]
|
else:
|
||||||
else: end = ''
|
start = 0
|
||||||
|
if len(key[0]) > sp * 2:
|
||||||
|
end = key[0][-sp:]
|
||||||
|
else:
|
||||||
|
end = ''
|
||||||
# 判断有没有传需要匹配的字符串,有则筛选、无则全返
|
# 判断有没有传需要匹配的字符串,有则筛选、无则全返
|
||||||
if txt == '' and len(key[0]) >= sp: show = key[0][0:sp] + " . . . " + end
|
if txt == '' and len(key[0]) >= sp:
|
||||||
elif txt == '' and len(key[0]) < sp: show = key[0][0:sp]
|
show = key[0][0:sp] + " . . . " + end
|
||||||
else: show = str(key[0][start:index + sp]).replace(txt, html_tag_color(txt))
|
elif txt == '' and len(key[0]) < sp:
|
||||||
|
show = key[0][0:sp]
|
||||||
|
else:
|
||||||
|
show = str(key[0][start:index + sp]).replace(txt, html_tag_color(txt))
|
||||||
show += f" {html_tag_color(' X ' + str(key[1]))}"
|
show += f" {html_tag_color(' X ' + str(key[1]))}"
|
||||||
if lst.get(key[0]): be_value = lst[key[0]]
|
if lst.get(key[0]):
|
||||||
else: be_value = "这个prompt还没有对话过呢,快去试试吧~"
|
be_value = lst[key[0]]
|
||||||
|
else:
|
||||||
|
be_value = "这个prompt还没有对话过呢,快去试试吧~"
|
||||||
value = be_value
|
value = be_value
|
||||||
dateset_list.append([show, key[0], value])
|
dateset_list.append([show, key[0], value])
|
||||||
return dateset_list
|
return dateset_list
|
||||||
|
|
||||||
|
|
||||||
def search_list(txt, sp=15):
|
|
||||||
lst = SqliteHandle('ai_common').get_prompt_value()
|
|
||||||
dateset_list = []
|
|
||||||
for key in lst:
|
|
||||||
index = key.find(txt)
|
|
||||||
if index != -1:
|
|
||||||
if index-sp > 0: start = index-sp
|
|
||||||
else: start = 0
|
|
||||||
show = str(key[start:index+sp]).replace(txt, html_tag_color(txt))
|
|
||||||
value = lst[key]
|
|
||||||
dateset_list.append([show, key, value])
|
|
||||||
return dateset_list
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
|
def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
|
||||||
|
"""
|
||||||
|
上传文件,将文件转换为字典,然后存储到数据库,并刷新Prompt区域
|
||||||
|
Args:
|
||||||
|
file: 上传的文件
|
||||||
|
prompt: 原始prompt对象
|
||||||
|
ipaddr:ipaddr用户请求信息
|
||||||
|
Returns:
|
||||||
|
注册函数所需的元祖对象
|
||||||
|
"""
|
||||||
hosts = ipaddr.client.host
|
hosts = ipaddr.client.host
|
||||||
if file.name.endswith('json'):
|
if file.name.endswith('json'):
|
||||||
upload_data = check_json_format(file.name)
|
upload_data = check_json_format(file.name)
|
||||||
@ -299,6 +363,15 @@ def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
|
|||||||
|
|
||||||
|
|
||||||
def prompt_retrieval(is_all, hosts='', search=False):
|
def prompt_retrieval(is_all, hosts='', search=False):
|
||||||
|
"""
|
||||||
|
上传文件,将文件转换为字典,然后存储到数据库,并刷新Prompt区域
|
||||||
|
Args:
|
||||||
|
is_all: prompt类型
|
||||||
|
hosts: 查询的用户ip
|
||||||
|
search:支持搜索,搜索时将key作为key
|
||||||
|
Returns:
|
||||||
|
返回一个列表
|
||||||
|
"""
|
||||||
count_dict = {}
|
count_dict = {}
|
||||||
user_path = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
|
user_path = os.path.join(prompt_path, f'prompt_{hosts}.yaml')
|
||||||
if '所有人' in is_all:
|
if '所有人' in is_all:
|
||||||
@ -307,12 +380,12 @@ def prompt_retrieval(is_all, hosts='', search=False):
|
|||||||
data = SqliteHandle(tab).get_prompt_value()
|
data = SqliteHandle(tab).get_prompt_value()
|
||||||
if data: count_dict.update(data)
|
if data: count_dict.update(data)
|
||||||
elif '个人' in is_all:
|
elif '个人' in is_all:
|
||||||
data = SqliteHandle(f'prompt_{hosts}').get_prompt_value()
|
data = SqliteHandle(f'prompt_{hosts}').get_prompt_value()
|
||||||
if data: count_dict.update(data)
|
if data: count_dict.update(data)
|
||||||
retrieval = []
|
retrieval = []
|
||||||
if count_dict != {}:
|
if count_dict != {}:
|
||||||
for key in count_dict:
|
for key in count_dict:
|
||||||
if not search:
|
if not search:
|
||||||
retrieval.append([key, count_dict[key]])
|
retrieval.append([key, count_dict[key]])
|
||||||
else:
|
else:
|
||||||
retrieval.append([count_dict[key], key])
|
retrieval.append([count_dict[key], key])
|
||||||
@ -321,17 +394,36 @@ def prompt_retrieval(is_all, hosts='', search=False):
|
|||||||
return retrieval
|
return retrieval
|
||||||
|
|
||||||
|
|
||||||
def prompt_reduce(is_all, prompt: gr.Dataset, ipaddr: gr.Request): # is_all, ipaddr: gr.Request
|
def prompt_reduce(is_all, prompt: gr.Dataset, ipaddr: gr.Request): # is_all, ipaddr: gr.Request
|
||||||
|
"""
|
||||||
|
上传文件,将文件转换为字典,然后存储到数据库,并刷新Prompt区域
|
||||||
|
Args:
|
||||||
|
is_all: prompt类型
|
||||||
|
prompt: dataset原始对象
|
||||||
|
ipaddr:请求用户信息
|
||||||
|
Returns:
|
||||||
|
返回注册函数所需的对象
|
||||||
|
"""
|
||||||
data = prompt_retrieval(is_all=is_all, hosts=ipaddr.client.host)
|
data = prompt_retrieval(is_all=is_all, hosts=ipaddr.client.host)
|
||||||
prompt.samples = data
|
prompt.samples = data
|
||||||
return prompt.update(samples=data, visible=True), prompt, is_all
|
return prompt.update(samples=data, visible=True), prompt, is_all
|
||||||
|
|
||||||
|
|
||||||
def prompt_save(txt, name, checkbox, prompt: gr.Dataset, ipaddr: gr.Request):
|
def prompt_save(txt, name, prompt: gr.Dataset, ipaddr: gr.Request):
|
||||||
|
"""
|
||||||
|
编辑和保存Prompt
|
||||||
|
Args:
|
||||||
|
txt: Prompt正文
|
||||||
|
name: Prompt的名字
|
||||||
|
prompt: dataset原始对象
|
||||||
|
ipaddr:请求用户信息
|
||||||
|
Returns:
|
||||||
|
返回注册函数所需的对象
|
||||||
|
"""
|
||||||
if txt and name:
|
if txt and name:
|
||||||
yaml_obj = SqliteHandle(f'prompt_{ipaddr.client.host}')
|
yaml_obj = SqliteHandle(f'prompt_{ipaddr.client.host}')
|
||||||
yaml_obj.inset_prompt({name: txt})
|
yaml_obj.inset_prompt({name: txt})
|
||||||
result = prompt_retrieval(is_all=checkbox+['个人'], hosts=ipaddr.client.host)
|
result = prompt_retrieval(is_all=['个人'], hosts=ipaddr.client.host)
|
||||||
prompt.samples = result
|
prompt.samples = result
|
||||||
return "", "", ['个人'], prompt.update(samples=result, visible=True), prompt
|
return "", "", ['个人'], prompt.update(samples=result, visible=True), prompt
|
||||||
elif not txt or not name:
|
elif not txt or not name:
|
||||||
@ -339,15 +431,27 @@ def prompt_save(txt, name, checkbox, prompt: gr.Dataset, ipaddr: gr.Request):
|
|||||||
prompt.samples = [[f'{html_tag_color("编辑框 or 名称不能为空!!!!!", color="red")}', '']]
|
prompt.samples = [[f'{html_tag_color("编辑框 or 名称不能为空!!!!!", color="red")}', '']]
|
||||||
return txt, name, [], prompt.update(samples=result, visible=True), prompt
|
return txt, name, [], prompt.update(samples=result, visible=True), prompt
|
||||||
|
|
||||||
|
|
||||||
def prompt_input(txt, index, data: gr.Dataset):
|
def prompt_input(txt, index, data: gr.Dataset):
|
||||||
|
"""
|
||||||
|
点击dataset的值使用Prompt
|
||||||
|
Args:
|
||||||
|
txt: 输入框正文
|
||||||
|
index: 点击的Dataset下标
|
||||||
|
data: dataset原始对象
|
||||||
|
Returns:
|
||||||
|
返回注册函数所需的对象
|
||||||
|
"""
|
||||||
data_str = str(data.samples[index][1])
|
data_str = str(data.samples[index][1])
|
||||||
if txt:
|
if txt:
|
||||||
txt = data_str+'\n'+txt
|
txt = data_str + '\n' + txt
|
||||||
else:
|
else:
|
||||||
txt = data_str
|
txt = data_str
|
||||||
return txt
|
return txt
|
||||||
|
|
||||||
|
|
||||||
def copy_result(history):
|
def copy_result(history):
|
||||||
|
"""复制history"""
|
||||||
if history != []:
|
if history != []:
|
||||||
pyperclip.copy(history[-1])
|
pyperclip.copy(history[-1])
|
||||||
return '已将结果复制到剪切板'
|
return '已将结果复制到剪切板'
|
||||||
@ -356,12 +460,24 @@ def copy_result(history):
|
|||||||
|
|
||||||
|
|
||||||
def show_prompt_result(index, data: gr.Dataset, chatbot):
|
def show_prompt_result(index, data: gr.Dataset, chatbot):
|
||||||
|
"""
|
||||||
|
查看Prompt的对话记录结果
|
||||||
|
Args:
|
||||||
|
index: 点击的Dataset下标
|
||||||
|
data: dataset原始对象
|
||||||
|
chatbot:聊天机器人
|
||||||
|
Returns:
|
||||||
|
返回注册函数所需的对象
|
||||||
|
"""
|
||||||
click = data.samples[index]
|
click = data.samples[index]
|
||||||
chatbot.append((click[1], click[2]))
|
chatbot.append((click[1], click[2]))
|
||||||
return chatbot
|
return chatbot
|
||||||
|
|
||||||
|
|
||||||
def thread_write_chat(chatbot):
|
def thread_write_chat(chatbot):
|
||||||
|
"""
|
||||||
|
对话记录写入数据库
|
||||||
|
"""
|
||||||
private_key = toolbox.get_conf('private_key')[0]
|
private_key = toolbox.get_conf('private_key')[0]
|
||||||
chat_title = chatbot[0][0].split()
|
chat_title = chatbot[0][0].split()
|
||||||
i_say = chatbot[-1][0].strip("<p>/p")
|
i_say = chatbot[-1][0].strip("<p>/p")
|
||||||
@ -375,6 +491,7 @@ def thread_write_chat(chatbot):
|
|||||||
base_path = os.path.dirname(__file__)
|
base_path = os.path.dirname(__file__)
|
||||||
prompt_path = os.path.join(base_path, 'prompt_users')
|
prompt_path = os.path.join(base_path, 'prompt_users')
|
||||||
|
|
||||||
|
|
||||||
class YamlHandle:
|
class YamlHandle:
|
||||||
|
|
||||||
def __init__(self, file=os.path.join(prompt_path, 'ai_common.yaml')):
|
def __init__(self, file=os.path.join(prompt_path, 'ai_common.yaml')):
|
||||||
@ -383,7 +500,6 @@ class YamlHandle:
|
|||||||
self.file = file
|
self.file = file
|
||||||
self._load = self.load()
|
self._load = self.load()
|
||||||
|
|
||||||
|
|
||||||
def load(self) -> dict:
|
def load(self) -> dict:
|
||||||
with open(file=self.file, mode='r') as f:
|
with open(file=self.file, mode='r') as f:
|
||||||
data = yaml.safe_load(f)
|
data = yaml.safe_load(f)
|
||||||
@ -407,6 +523,7 @@ class YamlHandle:
|
|||||||
yaml.dump(date, f, allow_unicode=True)
|
yaml.dump(date, f, allow_unicode=True)
|
||||||
return date
|
return date
|
||||||
|
|
||||||
|
|
||||||
class JsonHandle:
|
class JsonHandle:
|
||||||
|
|
||||||
def __init__(self, file=os.path.join(prompt_path, 'prompts-PlexPt.json')):
|
def __init__(self, file=os.path.join(prompt_path, 'prompts-PlexPt.json')):
|
||||||
@ -419,27 +536,7 @@ class JsonHandle:
|
|||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
class FileHandle:
|
|
||||||
|
|
||||||
def __init__(self, file=None):
|
|
||||||
self.file = file
|
|
||||||
|
|
||||||
def read(self):
|
|
||||||
with open(file=self.file, mode='r') as f:
|
|
||||||
print(f.read())
|
|
||||||
|
|
||||||
|
|
||||||
def read_link(self):
|
|
||||||
link = 'https://github.com/PlexPt/awesome-chatgpt-prompts-zh/blob/main/prompts-zh.json'
|
|
||||||
name = link.split('/')[3]+link.split('/')[-1]
|
|
||||||
new_file = os.path.join(base_path, 'gpt_log', name)
|
|
||||||
response = requests.get(url=link, verify=False)
|
|
||||||
with open(new_file, "wb") as f:
|
|
||||||
f.write(response.content)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
for i in YamlHandle().load():
|
pass
|
||||||
print(i)
|
|
||||||
|
|
||||||
|
|
||||||
@ -458,6 +458,9 @@ def find_recent_files(directory):
|
|||||||
|
|
||||||
|
|
||||||
def get_user_upload(chatbot, ipaddr: gr.Request):
|
def get_user_upload(chatbot, ipaddr: gr.Request):
|
||||||
|
"""
|
||||||
|
获取用户上传过的文件
|
||||||
|
"""
|
||||||
private_upload = './private_upload'
|
private_upload = './private_upload'
|
||||||
user_history = os.path.join(private_upload, ipaddr.client.host)
|
user_history = os.path.join(private_upload, ipaddr.client.host)
|
||||||
history = ''
|
history = ''
|
||||||
@ -471,6 +474,9 @@ def get_user_upload(chatbot, ipaddr: gr.Request):
|
|||||||
|
|
||||||
|
|
||||||
def get_user_download(chatbot, link, file):
|
def get_user_download(chatbot, link, file):
|
||||||
|
"""
|
||||||
|
将短路径转换为下载链接
|
||||||
|
"""
|
||||||
for file_handle in str(file).split('\n'):
|
for file_handle in str(file).split('\n'):
|
||||||
if os.path.isfile(file_handle):
|
if os.path.isfile(file_handle):
|
||||||
# temp_file = func_box.copy_temp_file(file_handle) 无法使用外部的临时目录
|
# temp_file = func_box.copy_temp_file(file_handle) 无法使用外部的临时目录
|
||||||
|
|||||||
Reference in New Issue
Block a user