diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000..39f8308
Binary files /dev/null and b/.DS_Store differ
diff --git a/.gitignore b/.gitignore
index 18d3fb8..1625567 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,15 +2,14 @@
__pycache__/
*.py[cod]
*$py.class
-
# C extensions
*.so
-
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
+plugins/
downloads/
eggs/
.eggs/
@@ -26,7 +25,6 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
-
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
@@ -35,7 +33,6 @@ MANIFEST
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
-
# Unit test / coverage reports
htmlcov/
.tox/
@@ -49,91 +46,64 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
-
# Translations
*.mo
*.pot
-github
-.github
-TEMP
-TRASH
-
# Django stuff:
-*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
-
# Flask stuff:
instance/
.webassets-cache
-
# Scrapy stuff:
.scrapy
-
# Sphinx documentation
docs/_build/
-
+site/
# PyBuilder
target/
-
# Jupyter Notebook
.ipynb_checkpoints
-
# IPython
profile_default/
ipython_config.py
-
# pyenv
.python-version
-
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
-
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
-
# Celery stuff
celerybeat-schedule
celerybeat.pid
-
# SageMath parsed files
*.sage.py
-
# Environments
-.env
+.direnv/
.venv
env/
-venv/
+venv*/
ENV/
env.bak/
-venv.bak/
-
# Spyder project settings
.spyderproject
.spyproject
-
# Rope project settings
.ropeproject
-
# mkdocs documentation
/site
-
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
-
# Pyre type checker
.pyre/
-
-.vscode
.idea
-
history
ssr_conf
config_private.py
@@ -145,8 +115,12 @@ cradle*
debug*
private*
crazy_functions/test_project/pdf_and_word
crazy_functions/test_samples
request_llm/jittorllms
-multi-language
+users_data/*
request_llm/moss
+multi-language
media
+__test.py
\ No newline at end of file
diff --git a/__main__.py b/__main__.py
new file mode 100644
index 0000000..9a4e621
--- /dev/null
+++ b/__main__.py
@@ -0,0 +1,486 @@
+import os
+import gradio as gr
+from request_llm.bridge_all import predict
+from toolbox import format_io, find_free_port, on_file_uploaded, on_report_generated, get_user_upload, \
+ get_conf, ArgsGeneralWrapper, DummyWith
+
+# Query logging; Python 3.9+ recommended (the newer the better)
+import logging
+
+# Some basic function modules
+from core_functional import get_core_functions
+
+functional = get_core_functions()
+
+# Advanced function plugins
+from crazy_functional import get_crazy_functions
+
+crazy_fns = get_crazy_functions()
+
+# Handle markdown text-format conversion
+gr.Chatbot.postprocess = format_io
+
+# Make some appearance and color adjustments
+from theme import adjust_theme, advanced_css, custom_css
+
+set_theme = adjust_theme()
+
+# Proxy and auto-update
+from check_proxy import check_proxy, auto_update, warm_up_modules
+
+import func_box
+
+from check_proxy import get_current_version
+
+os.makedirs("gpt_log", exist_ok=True)
+try:
+ logging.basicConfig(filename="gpt_log/chat_secrets.log", level=logging.INFO, encoding="utf-8")
+except ValueError:  # Python < 3.9: logging.basicConfig() has no 'encoding' argument
+ logging.basicConfig(filename="gpt_log/chat_secrets.log", level=logging.INFO)
+print("All chat queries are logged locally to ./gpt_log/chat_secrets.log, please mind your privacy!")
+
+# We suggest copying config_private.py to hold your secrets, such as API keys and proxy URLs, so they are not accidentally pushed to GitHub
+proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION, LAYOUT, API_KEY, AVAIL_LLM_MODELS, LOCAL_PORT = \
+ get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION', 'LAYOUT',
+ 'API_KEY', 'AVAIL_LLM_MODELS', 'LOCAL_PORT')
+
+proxy_info = check_proxy(proxies)
+# If WEB_PORT is -1, pick a random web port
+PORT = find_free_port() if WEB_PORT <= 0 else WEB_PORT
+if not AUTHENTICATION: AUTHENTICATION = None
+os.environ['no_proxy'] = '*'  # avoid unexpected pollution from the proxy network
+
+
+class ChatBotFrame:
+
+ def __init__(self):
+ self.cancel_handles = []
+        self.initial_prompt = "You will act as a professional and answer me according to my needs."
+        self.title_html = f""
    pre = '<div class="markdown-body">'
    suf = '</div>'
    if txt.startswith(pre) and txt.endswith(suf):
        # print('Warning: the string has already been converted once; converting it again may cause problems')
diff --git a/docs/translate_english.json b/docs/translate_english.json
index 57e008b..13b0869 100644
--- a/docs/translate_english.json
+++ b/docs/translate_english.json
@@ -265,7 +265,7 @@
"例如chatglm&gpt-3.5-turbo&api2d-gpt-4": "e.g. chatglm&gpt-3.5-turbo&api2d-gpt-4",
"先切换模型到openai或api2d": "Switch the model to openai or api2d first",
"在这里输入分辨率": "Enter the resolution here",
- "如256x256": "e.g. 256x256",
+ "如'256x256', '512x512', '1024x1024'": "e.g. '256x256', '512x512', '1024x1024'",
"默认": "Default",
"建议您复制一个config_private.py放自己的秘密": "We suggest you to copy a config_private.py file to keep your secrets, such as API and proxy URLs, from being accidentally uploaded to Github and seen by others.",
"如API和代理网址": "Such as API and proxy URLs",
diff --git a/docs/waifu_plugin/autoload.js b/docs/waifu_plugin/autoload.js
index 3464a5c..04a29e6 100644
--- a/docs/waifu_plugin/autoload.js
+++ b/docs/waifu_plugin/autoload.js
@@ -12,7 +12,7 @@ try {
live2d_settings['waifuTipsSize'] = '187x52';
live2d_settings['canSwitchModel'] = true;
live2d_settings['canSwitchTextures'] = true;
- live2d_settings['canSwitchHitokoto'] = false;
+ live2d_settings['canSwitchHitokoto'] = true;
live2d_settings['canTakeScreenshot'] = false;
live2d_settings['canTurnToHomePage'] = false;
live2d_settings['canTurnToAboutPage'] = false;
diff --git a/docs/waifu_plugin/waifu-tips.json b/docs/waifu_plugin/waifu-tips.json
index 229d5a1..524545c 100644
--- a/docs/waifu_plugin/waifu-tips.json
+++ b/docs/waifu_plugin/waifu-tips.json
@@ -34,10 +34,10 @@
"2": ["来自 Potion Maker 的 Tia 酱 ~"]
},
"hitokoto_api_message": {
-      "lwl12.com": ["这句一言来自 <span style=\"color:#0099cc;\">『{source}』</span>", ",是 <span style=\"color:#0099cc;\">{creator}</span> 投稿的", "。"],
-      "fghrsh.net": ["这句一言出处是 <span style=\"color:#0099cc;\">『{source}』</span>,是 <span style=\"color:#0099cc;\">FGHRSH</span> 在 {date} 收藏的!"],
-      "jinrishici.com": ["这句诗词出自 <span style=\"color:#0099cc;\">《{title}》</span>,是 {dynasty}诗人 <span style=\"color:#0099cc;\">{author}</span> 创作的!"],
-      "hitokoto.cn": ["这句一言来自 <span style=\"color:#0099cc;\">『{source}』</span>,是 <span style=\"color:#0099cc;\">{creator}</span> 在 hitokoto.cn 投稿的。"]
+      "lwl12.com": ["这句一言来自 <span style=\"color:#0099cc;\">『{source}』</span>", ",是 <span style=\"color:#0099cc;\">{creator}</span> 投稿的", "。"],
+      "fghrsh.net": ["这句一言出处是 <span style=\"color:#0099cc;\">『{source}』</span>,是 <span style=\"color:#0099cc;\">FGHRSH</span> 在 {date} 收藏的!"],
+      "jinrishici.com": ["这句诗词出自 <span style=\"color:#0099cc;\">《{title}》</span>,是 {dynasty}诗人 <span style=\"color:#0099cc;\">{author}</span> 创作的!"],
+      "hitokoto.cn": ["这句一言来自 <span style=\"color:#0099cc;\">『{source}』</span>,是 <span style=\"color:#0099cc;\">{creator}</span> 在 hitokoto.cn 投稿的。"]
}
},
"mouseover": [
diff --git a/func_box.py b/func_box.py
new file mode 100644
index 0000000..9924694
--- /dev/null
+++ b/func_box.py
@@ -0,0 +1,778 @@
+#! .\venv\
+# encoding: utf-8
+# @Time : 2023/4/18
+# @Author : Spike
+# @Descr :
+import ast
+import copy
+import hashlib
+import io
+import json
+import os.path
+import subprocess
+import threading
+import time
+from concurrent.futures import ThreadPoolExecutor
+import Levenshtein
+import psutil
+import re
+import tempfile
+import shutil
+from contextlib import ExitStack
+import logging
+import yaml
+import requests
+import tiktoken
+logger = logging
+from sklearn.feature_extraction.text import CountVectorizer
+import numpy as np
+from scipy.linalg import norm
+import pyperclip
+import random
+import gradio as gr
+import toolbox
+from prompt_generator import SqliteHandle
+from bs4 import BeautifulSoup
+
+"""contextlib 是 Python 标准库中的一个模块,提供了一些工具函数和装饰器,用于支持编写上下文管理器和处理上下文的常见任务,例如资源管理、异常处理等。
+官网:https://docs.python.org/3/library/contextlib.html"""
+
+
+class Shell(object):
+ def __init__(self, args, stream=False):
+ self.args = args
+ self.subp = subprocess.Popen(args, shell=True,
+ stdin=subprocess.PIPE, stderr=subprocess.PIPE,
+ stdout=subprocess.PIPE, encoding='utf-8',
+ errors='ignore', close_fds=True)
+ self.__stream = stream
+ self.__temp = ''
+
+ def read(self):
+ logger.debug(f'The command being executed is: "{self.args}"')
+ if self.__stream:
+ sysout = self.subp.stdout
+ try:
+ with sysout as std:
+ for i in std:
+ logger.info(i.rstrip())
+ self.__temp += i
+            except KeyboardInterrupt:
+                pass
+            # A return inside `finally` would swallow any exception, so return after the try block instead
+            return 3, self.__temp + self.subp.stderr.read()
+ else:
+ sysout = self.subp.stdout.read()
+ syserr = self.subp.stderr.read()
+ if sysout:
+ logger.debug(f"{self.args} \n{sysout}")
+ return 1, sysout
+ elif syserr:
+ logger.error(f"{self.args} \n{syserr}")
+ return 0, syserr
+            else:
+                logger.debug(f"{self.args} \n{[sysout], [syserr]}")
+                return 2, '\n{}\n{}'.format(sysout, syserr)
+
+ def sync(self):
+ logger.debug('The command being executed is: "{}"'.format(self.args))
+ for i in self.subp.stdout:
+ logger.debug(i.rstrip())
+ self.__temp += i
+ yield self.__temp
+ for i in self.subp.stderr:
+ logger.debug(i.rstrip())
+ self.__temp += i
+ yield self.__temp
+
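+# Usage sketch (illustrative, not part of the original file): read() blocks and returns a
+# (status, output) tuple where status 1 means stdout, 0 stderr, 2 neither, and 3 is the
+# streaming path; sync() yields the accumulated output incrementally.
+#   status, output = Shell('echo hello').read()
+#   for partial in Shell('ping -c 2 127.0.0.1').sync():
+#       print(partial)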
+
+def timeStatistics(func):
+    """
+    Decorator that measures how long a function takes to execute.
+    """
+
+    def statistics(*args, **kwargs):
+        start_time = time.time()
+        obj = func(*args, **kwargs)
+        end_time = time.time()
+        ums = end_time - start_time
+        print('func:{} > Time-consuming: {}'.format(func, ums))
+        return obj
+
+    return statistics
+
+
+def copy_temp_file(file):
+ if os.path.exists(file):
+ exdir = tempfile.mkdtemp()
+ temp_ = shutil.copy(file, os.path.join(exdir, os.path.basename(file)))
+ return temp_
+ else:
+ return None
+
+
+def md5_str(st):
+    # Create an MD5 object
+    md5 = hashlib.md5()
+    # Feed the content into the MD5 object
+    md5.update(str(st).encode())
+    # Get the resulting hex digest
+ result = md5.hexdigest()
+ return result
+
+
+def html_tag_color(tag, color=None, font='black'):
+    """
+    Wrap text in HTML markup with a highlight color.
+    """
+    if not color:
+        rgb = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+        color = f"rgb{rgb}"
+    tag = f'<span style="background-color: {color}; color: {font}"> {tag} </span>'
+    return tag
+
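+# Example (illustrative): html_tag_color('hit', color='rgb(255,0,0)') wraps the text in a
+# red-highlighted <span>; diff_list() below uses this to mark matched search terms.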
+def html_a_blank(__href, name=''):
+ if not name:
+ name = __href
+    a = f'<a href="{__href}" target="_blank">{name}</a>'
+ return a
+
+def html_view_blank(__href, file_name=''):
+ if os.path.exists(__href):
+ __href = f'/file={__href}'
+ if not file_name:
+ file_name = __href.split('/')[-1]
+    a = f'<a href="{__href}" target="_blank">{file_name}</a>'
+ return a
+
+def html_iframe_code(html_file):
+ proxy, = toolbox.get_conf('LOCAL_PORT')
+ html_file = f'http://{ipaddr()}:{proxy}/file={html_file}'
+    ifr = f'<iframe src="{html_file}"></iframe>'
+ return ifr
+
+
+def html_download_blank(__href, file_name='temp', dir_name=''):
+ if os.path.exists(__href):
+ __href = f'/file={__href}'
+ if not dir_name:
+ dir_name = file_name
+    a = f'<a href="{__href}" target="_blank" download="{dir_name}">{file_name}</a>'
+ return a
+
+def html_local_img(__file):
+    a = f'<div align="center"><img src="file={__file}"></div>'
+    return a
+
+def ipaddr():
+    # Get the local IP address
+ ip = psutil.net_if_addrs()
+ for i in ip:
+ if ip[i][0][3]:
+ return ip[i][0][1]
+
+
+def encryption_str(txt: str):
+    """Matching scheme: (keyword)(separator)(value); the matched value is masked."""
+ txt = str(txt)
+ pattern = re.compile(rf"(Authorization|WPS-Sid|Cookie)(:|\s+)\s*(\S+)[\s\S]*?(?=\n|$|\s)", re.IGNORECASE)
+ result = pattern.sub(lambda x: x.group(1) + ": XXXXXXXX", txt)
+ return result
+
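+# Example (illustrative): the pattern above masks the first token following a credential
+# keyword, e.g. encryption_str('Cookie: session=123') -> 'Cookie: XXXXXXXX'.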
+
+def tree_out(dir=os.path.dirname(__file__), line=2, more=''):
+ """
+    Render the local file tree as Markdown code-block text and write it to .tree.md
+ """
+ out = Shell(f'tree {dir} -F -I "__*|.*|venv|*.png|*.xlsx" -L {line} {more}').read()[1]
+ localfile = os.path.join(os.path.dirname(__file__), '.tree.md')
+ with open(localfile, 'w') as f:
+ f.write('```\n')
+ ll = out.splitlines()
+ for i in range(len(ll)):
+ if i == 0:
+ f.write(ll[i].split('/')[-2] + '\n')
+ else:
+ f.write(ll[i] + '\n')
+ f.write('```\n')
+
+
+def chat_history(log: list, split=0):
+ """
+    Code used by auto_gpt; will be migrated later
+ """
+ if split:
+ log = log[split:]
+ chat = ''
+ history = ''
+ for i in log:
+ chat += f'{i[0]}\n\n'
+ history += f'{i[1]}\n\n'
+ return chat, history
+
+
+def df_similarity(s1, s2):
+    """Deprecated: it raises warnings and this library's behavior here is unclear."""
+    def add_space(s):
+        return ' '.join(list(s))
+
+    # Insert spaces between the characters
+    s1, s2 = add_space(s1), add_space(s2)
+    # Convert to a term-frequency matrix
+    cv = CountVectorizer(tokenizer=lambda s: s.split())
+    corpus = [s1, s2]
+    vectors = cv.fit_transform(corpus).toarray()
+    # Compute the TF cosine similarity
+ return np.dot(vectors[0], vectors[1]) / (norm(vectors[0]) * norm(vectors[1]))
+
+
+def check_json_format(file):
+ """
+    Check whether an uploaded JSON file matches the expected format
+ """
+ new_dict = {}
+ data = JsonHandle(file).load()
+ if type(data) is list and len(data) > 0:
+ if type(data[0]) is dict:
+ for i in data:
+ new_dict.update({i['act']: i['prompt']})
+ return new_dict
+
+
+def json_convert_dict(file):
+    """
+    Batch-convert JSON prompt files under a directory into one dict
+    """
+    new_dict = {}
+    for root, dirs, files in os.walk(file):
+        for f in files:
+            if f.startswith('prompt') and f.endswith('json'):
+                new_dict.update(check_json_format(os.path.join(root, f)))
+    return new_dict
+
+
+def draw_results(txt, prompt: gr.Dataset, percent, switch, ipaddr: gr.Request):
+ """
+    Render the search results
+    Args:
+        txt (str): filter text
+        prompt : the original dataset object
+        percent (int): similarity coefficient used to compare texts
+        switch (list): filter prompts for this user only or for everyone
+        ipaddr : requester info
+    Returns:
+        The tuple expected by the registered callback
+ """
+ data = diff_list(txt, percent=percent, switch=switch, hosts=ipaddr.client.host)
+ prompt.samples = data
+ return prompt.update(samples=data, visible=True), prompt
+
+
+def diff_list(txt='', percent=0.70, switch: list = None, lst: dict = None, sp=15, hosts=''):
+ """
+    Group texts by similarity over the search results: texts that are more than `percent` (default 70%) similar are counted together, with the longest kept as the key
+    Args:
+        txt (str): filter text
+        percent (int): similarity coefficient used to compare texts
+        switch (list): filter prompts for this user only or for everyone
+        lst: an explicit list or dict to search
+        sp: length of the excerpt shown
+        hosts : requester IP
+    Returns:
+        A list
+ """
+ count_dict = {}
+ is_all = toolbox.get_conf('prompt_list')[0]['key'][1]
+ if not lst:
+ lst = {}
+ tabs = SqliteHandle().get_tables()
+ if is_all in switch:
+ lst.update(SqliteHandle(f"ai_common_{hosts}").get_prompt_value(txt))
+ else:
+ for tab in tabs:
+ if tab.startswith('ai_common'):
+ lst.update(SqliteHandle(f"{tab}").get_prompt_value(txt))
+ lst.update(SqliteHandle(f"ai_private_{hosts}").get_prompt_value(txt))
+    # Group the data according to the percent similarity coefficient
+ str_ = time.time()
+ def tf_factor_calcul(i):
+ found = False
+ dict_copy = count_dict.copy()
+ for key in dict_copy.keys():
+ str_tf = Levenshtein.jaro_winkler(i, key)
+ if str_tf >= percent:
+ if len(i) > len(key):
+                if len(i) > len(key):
+                    count_dict[i] = count_dict.pop(key) + 1
+ count_dict[key] += 1
+ found = True
+ break
+ if not found: count_dict[i] = 1
+ with ThreadPoolExecutor(100) as executor:
+ executor.map(tf_factor_calcul, lst)
+    print('similarity grouping took', time.time() - str_)
+ sorted_dict = sorted(count_dict.items(), key=lambda x: x[1], reverse=True)
+ if switch:
+ sorted_dict += prompt_retrieval(is_all=switch, hosts=hosts, search=True)
+    dataset_list = []
+    for key in sorted_dict:
+        # Start matching the keyword
+        index = str(key[0]).lower().find(txt.lower())
+        index_ = str(key[1]).lower().find(txt.lower())
+        if index != -1 or index_ != -1:
+            if index == -1: index = index_  # also match against the prompt name
+            # sp (split) decides where the excerpt starts and where it is cut off
+            if index - sp > 0:
+                start = index - sp
+            else:
+                start = 0
+            if len(key[0]) > sp * 2:
+                end = key[0][-sp:]
+            else:
+                end = ''
+            # If a match string was passed, filter on it; otherwise return everything
+            if txt == '' and len(key[0]) >= sp:
+                show = key[0][0:sp] + " . . . " + end
+                show = show.replace('<', '')
+            elif txt == '' and len(key[0]) < sp:
+                show = key[0][0:sp]
+                show = show.replace('<', '')
+            else:
+                show = str(key[0][start:index + sp]).replace('<', '').replace(txt, html_tag_color(txt))
+            show += f" {html_tag_color(' X ' + str(key[1]))}"
+            if lst.get(key[0]):
+                be_value = lst[key[0]]
+            else:
+                be_value = None
+            value = be_value
+            dataset_list.append([show, key[0], value, key[1]])
+    return dataset_list
+
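+# Worked micro-example (illustrative): with percent=0.9, Jaro-Winkler scores
+# "translate to english" vs "translate to english!" above the threshold, so the two
+# prompts collapse into one row keyed by the longer string with a hit count of 2.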
+
+def prompt_upload_refresh(file, prompt, ipaddr: gr.Request):
+ """
+    Convert an uploaded file into a dict, store it in the database, then refresh the Prompt area
+    Args:
+        file: the uploaded file
+        prompt: the original prompt object
+        ipaddr: requester info
+    Returns:
+        The tuple expected by the registered callback
+ """
+ hosts = ipaddr.client.host
+ if file.name.endswith('json'):
+ upload_data = check_json_format(file.name)
+ elif file.name.endswith('yaml'):
+ upload_data = YamlHandle(file.name).load()
+ else:
+ upload_data = {}
+ if upload_data != {}:
+ SqliteHandle(f'prompt_{hosts}').inset_prompt(upload_data)
+ ret_data = prompt_retrieval(is_all=['个人'], hosts=hosts)
+ return prompt.update(samples=ret_data, visible=True), prompt, ['个人']
+ else:
+ prompt.samples = [[f'{html_tag_color("数据解析失败,请检查文件是否符合规范", color="red")}', '']]
+ return prompt.samples, prompt, []
+
+
+def prompt_retrieval(is_all, hosts='', search=False):
+ """
+    Retrieve prompts from the database for the given scope
+    Args:
+        is_all: prompt scope
+        hosts: IP of the querying user
+        search: search mode; when searching, the stored value becomes the key
+    Returns:
+        A list
+ """
+ count_dict = {}
+ if '所有人' in is_all:
+ for tab in SqliteHandle('ai_common').get_tables():
+ if tab.startswith('prompt'):
+ data = SqliteHandle(tab).get_prompt_value(None)
+ if data: count_dict.update(data)
+ elif '个人' in is_all:
+ data = SqliteHandle(f'prompt_{hosts}').get_prompt_value(None)
+ if data: count_dict.update(data)
+    retrieval = []
+    if count_dict != {}:
+        for key in count_dict:
+            if not search:
+                retrieval.append([key, count_dict[key]])
+            else:
+                retrieval.append([count_dict[key], key])
+    return retrieval
+
+
+def prompt_reduce(is_all, prompt: gr.Dataset, ipaddr: gr.Request): # is_all, ipaddr: gr.Request
+ """
+    Reload prompts of the selected scope and refresh the Prompt area
+    Args:
+        is_all: prompt scope
+        prompt: the original dataset object
+        ipaddr: requester info
+    Returns:
+        The objects expected by the registered callback
+ """
+ data = prompt_retrieval(is_all=is_all, hosts=ipaddr.client.host)
+ prompt.samples = data
+ return prompt.update(samples=data, visible=True), prompt, is_all
+
+
+def prompt_save(txt, name, prompt: gr.Dataset, ipaddr: gr.Request):
+ """
+    Edit and save a prompt
+    Args:
+        txt: prompt body
+        name: prompt name
+        prompt: the original dataset object
+        ipaddr: requester info
+    Returns:
+        The objects expected by the registered callback
+ """
+ if txt and name:
+ yaml_obj = SqliteHandle(f'prompt_{ipaddr.client.host}')
+ yaml_obj.inset_prompt({name: txt})
+ result = prompt_retrieval(is_all=['个人'], hosts=ipaddr.client.host)
+ prompt.samples = result
+ return "", "", ['个人'], prompt.update(samples=result, visible=True), prompt, gr.Tabs.update(selected='chatbot')
+ elif not txt or not name:
+ result = [[f'{html_tag_color("编辑框 or 名称不能为空!!!!!", color="red")}', '']]
+ prompt.samples = [[f'{html_tag_color("编辑框 or 名称不能为空!!!!!", color="red")}', '']]
+ return txt, name, [], prompt.update(samples=result, visible=True), prompt, gr.Tabs.update(selected='chatbot')
+
+
+def prompt_input(txt: str, prompt_str, name_str, index, data: gr.Dataset, tabs_index):
+ """
+    Apply a prompt when a dataset entry is clicked
+    Args:
+        txt: input-box text
+        index: index of the clicked Dataset entry
+        data: the original dataset object
+    Returns:
+        The objects expected by the registered callback
+ """
+ data_str = str(data.samples[index][1])
+ data_name = str(data.samples[index][0])
+ rp_str = '{{{v}}}'
+
+ def str_v_handle(__str):
+ if data_str.find(rp_str) != -1 and __str:
+ txt_temp = data_str.replace(rp_str, __str)
+ elif __str:
+ txt_temp = data_str + '\n' + __str
+ else:
+ txt_temp = data_str
+ return txt_temp
+ if tabs_index == 1:
+ new_txt = str_v_handle(prompt_str)
+ return txt, new_txt, data_name
+ else:
+ new_txt = str_v_handle(txt)
+ return new_txt, prompt_str, name_str
+
+
+def copy_result(history):
+    """Copy the last history entry to the clipboard"""
+ if history != []:
+ pyperclip.copy(history[-1])
+ return '已将结果复制到剪切板'
+ else:
+ return "无对话记录,复制错误!!"
+
+
+def str_is_list(s):
+ try:
+ list_ast = ast.literal_eval(s)
+ return isinstance(list_ast, list)
+ except (SyntaxError, ValueError):
+ return False
+
+
+def show_prompt_result(index, data: gr.Dataset, chatbot, pro_edit, pro_name):
+ """
+    View the conversation record attached to a prompt
+    Args:
+        index: index of the clicked Dataset entry
+        data: the original dataset object
+        chatbot: the chatbot component
+    Returns:
+        The objects expected by the registered callback
+ """
+ click = data.samples[index]
+ if str_is_list(click[2]):
+        list_copy = ast.literal_eval(click[2])  # literal_eval is safer than eval for stored data
+ for i in range(0, len(list_copy), 2):
+            if i + 1 >= len(list_copy):  # if the index would overflow, handle the last element on its own
+ chatbot.append([list_copy[i]])
+ else:
+ chatbot.append([list_copy[i], list_copy[i + 1]])
+ elif click[2] is None and pro_edit == '':
+ pro_edit = click[1]
+ pro_name = click[3]
+ else:
+ chatbot.append((click[1], click[2]))
+ return chatbot, pro_edit, pro_name
+
+
+
+def pattern_html(html):
+ bs = BeautifulSoup(str(html), 'html.parser')
+ md_message = bs.find('div', {'class': 'md-message'})
+ if md_message:
+ return md_message.get_text(separator='')
+ else:
+ return ""
+
+
+def thread_write_chat(chatbot, history):
+ """
+    Write the conversation record to the database
+ """
+ chatbot, history = copy.copy(chatbot), copy.copy(history)
+ private_key = toolbox.get_conf('private_key')[0]
+ chat_title = chatbot[0][1].split()
+ i_say = pattern_html(chatbot[-1][0])
+ if history:
+ gpt_result = history
+    else:  # if no conversation history exists, read from the chat window instead
+ gpt_result = [pattern_html(v) for i in chatbot for v in i]
+ if private_key in chat_title:
+ SqliteHandle(f'ai_private_{chat_title[-2]}').inset_prompt({i_say: gpt_result})
+ else:
+ SqliteHandle(f'ai_common_{chat_title[-2]}').inset_prompt({i_say: gpt_result})
+
+
+base_path = os.path.dirname(__file__)
+prompt_path = os.path.join(base_path, 'users_data')
+users_path = os.path.join(base_path, 'private_upload')
+logs_path = os.path.join(base_path, 'gpt_log')
+
+def reuse_chat(result, chatbot, history, pro_numb, say):
+    """Reuse a saved conversation record"""
+ if result is None or result == []:
+ return chatbot, history, gr.update(), gr.update(), '', gr.Column.update()
+ else:
+ if pro_numb:
+ chatbot += result
+ history += [pattern_html(_) for i in result for _ in i]
+ else:
+ chatbot.append(result[-1])
+ history += [pattern_html(_) for i in result[-2:] for _ in i]
+ print(chatbot[-1][0])
+ return chatbot, history, say, gr.Tabs.update(selected='chatbot'), '', gr.Column.update(visible=False)
+
+
+def num_tokens_from_string(listing: list, encoding_name: str = 'cl100k_base') -> int:
+    """Returns the total number of tokens across a list of text strings."""
+    encoding = tiktoken.get_encoding(encoding_name)
+    count_tokens = 0
+    for i in listing:
+        count_tokens += len(encoding.encode(i))
+    return count_tokens
+
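+# Illustrative usage (assumes tiktoken can load its bundled cl100k_base encoding):
+#   num_tokens_from_string(['你好', 'how are you'])  # -> total tokens across both strings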
+
+def spinner_chatbot_loading(chatbot):
+    loading = ['.' * random.randint(1, 5)]
+    # Convert the tuple to a list so the element can be modified
+ loading_msg = copy.deepcopy(chatbot)
+ temp_list = list(loading_msg[-1])
+
+ temp_list[1] = pattern_html(temp_list[1]) + f'{random.choice(loading)}'
+    # Convert the list back to a tuple and replace the original
+ loading_msg[-1] = tuple(temp_list)
+ return loading_msg
+
+
+def refresh_load_data(chat, history, prompt, crazy_list, request: gr.Request):
+ """
+ Args:
+        chat: the chat component
+        history: conversation history
+        prompt: the prompt dataset component
+
+    Returns:
+        Intended to load the latest data on every page refresh
+ """
+ is_all = toolbox.get_conf('prompt_list')[0]['key'][0]
+ data = prompt_retrieval(is_all=[is_all])
+ prompt.samples = data
+ selected = random.sample(crazy_list, 4)
+ user_agent = request.kwargs['headers']['user-agent'].lower()
+    if user_agent.find('android') != -1 or user_agent.find('iphone') != -1:
+        hide_elem = gr.update(visible=False)
+    else:
+        hide_elem = gr.update()
+    outputs = [prompt.update(samples=data, visible=True), prompt,
+               chat, history, gr.Dataset.update(samples=[[i] for i in selected]), selected,
+               hide_elem, hide_elem]
+ return outputs
+
+
+
+def txt_converter_json(input_string):
+ try:
+ if input_string.startswith("{") and input_string.endswith("}"):
+            # Try to convert a dict-formatted string into a dict object
+ dict_object = ast.literal_eval(input_string)
+ else:
+            # Try to parse the string as a JSON object
+ dict_object = json.loads(input_string)
+ formatted_json_string = json.dumps(dict_object, indent=4, ensure_ascii=False)
+ return formatted_json_string
+ except (ValueError, SyntaxError):
+ return input_string
+
+
+def clean_br_string(s):
+    s = re.sub(r'<\s*br\s*/?>', '\n', s)  # matches <br>, <br/>, <br />, < br> and < br/> in one pass
+    return s
+
+
+def update_btn(self,
+ value: str = None,
+ variant: str = None,
+ visible: bool = None,
+ interactive: bool = None,
+ elem_id: str = None,
+ label: str = None
+):
+ if not variant: variant = self.variant
+ if not visible: visible = self.visible
+ if not value: value = self.value
+ if not interactive: interactive = self.interactive
+ if not elem_id: elem_id = self.elem_id
+    if not label: label = self.label
+ return {
+ "variant": variant,
+ "visible": visible,
+ "value": value,
+ "interactive": interactive,
+ 'elem_id': elem_id,
+ 'label': label,
+ "__type__": "update",
+ }
+
+def update_txt(self,
+               value: str = None,
+               lines: int = None,
+               max_lines: int = None,
+               placeholder: str = None,
+               label: str = None,
+               show_label: bool = None,
+               visible: bool = None,
+               interactive: bool = None,
+               type: str = None,
+               elem_id: str = None
+               ):
+    # Fall back to the component's current attribute whenever an argument is omitted
+    return {
+        "lines": lines if lines is not None else self.lines,
+        "max_lines": max_lines if max_lines is not None else self.max_lines,
+        "placeholder": placeholder if placeholder is not None else self.placeholder,
+        "label": label if label is not None else self.label,
+        "show_label": show_label if show_label is not None else self.show_label,
+        "visible": visible if visible is not None else self.visible,
+        "value": value if value is not None else self.value,
+        "type": type if type is not None else self.type,
+        "interactive": interactive if interactive is not None else self.interactive,
+        "elem_id": elem_id if elem_id is not None else self.elem_id,
+        "__type__": "update",
+    }
+
+
+def get_html(filename):
+ path = os.path.join(base_path, "docs/assets", "html", filename)
+ if os.path.exists(path):
+ with open(path, encoding="utf8") as file:
+ return file.read()
+ return ""
+
+
+def git_log_list():
+ ll = Shell("git log --pretty=format:'%s | %h' -n 10").read()[1].splitlines()
+
+ return [i.split('|') for i in ll if 'branch' not in i][:5]
+
+import qrcode
+from PIL import Image, ImageDraw
+def qr_code_generation(data, icon_path=None, file_name='qc_icon.png'):
+    # Create the qrcode object
+    qr = qrcode.QRCode(version=2, error_correction=qrcode.constants.ERROR_CORRECT_Q, box_size=10, border=2,)
+    qr.add_data(data)
+    # Render the QR code image
+    img = qr.make_image()
+    # Convert the image to RGBA
+    img = img.convert('RGBA')
+    # Get the QR code image size
+    img_w, img_h = img.size
+    # Open the logo
+    if not icon_path:
+        icon_path = os.path.join(base_path, 'docs/assets/PLAI.jpeg')
+    logo = Image.open(icon_path)
+    # Logo is a quarter of the QR code's size
+    logo_w = img_w // 4
+    logo_h = img_h // 4
+    # Resize the logo
+    logo = logo.resize((logo_w, logo_h), Image.LANCZOS)  # or Image.Resampling.LANCZOS
+    # Paste the logo in the middle of the QR code
+    w = (img_w - logo_w) // 2
+    h = (img_h - logo_h) // 2
+    img.paste(logo, (w, h))
+    qr_path = os.path.join(logs_path, file_name)
+    img.save(qr_path)
+    return qr_path
+
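+# Illustrative usage (the URL and file name are placeholders): embeds the default
+# docs/assets/PLAI.jpeg logo, saves the PNG under gpt_log/ and returns its path,
+# which can be fed straight into html_local_img() above.
+#   qr_path = qr_code_generation('https://example.com', file_name='share_qr.png')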
+
+class YamlHandle:
+
+ def __init__(self, file=os.path.join(prompt_path, 'ai_common.yaml')):
+ if not os.path.exists(file):
+ Shell(f'touch {file}').read()
+ self.file = file
+ self._load = self.load()
+
+ def load(self) -> dict:
+ with open(file=self.file, mode='r') as f:
+ data = yaml.safe_load(f)
+ return data
+
+    def update(self, key, value):
+        data = self._load
+        if not data:
+            data = {}
+        data[key] = value
+        with open(file=self.file, mode='w') as f:
+            yaml.dump(data, f, allow_unicode=True)
+        return data
+
+    def dump_dict(self, new_dict):
+        data = self._load
+        if not data:
+            data = {}
+        data.update(new_dict)
+        with open(file=self.file, mode='w') as f:
+            yaml.dump(data, f, allow_unicode=True)
+        return data
+
+
+class JsonHandle:
+
+ def __init__(self, file):
+ self.file = file
+
+ def load(self) -> object:
+ with open(self.file, 'r') as f:
+ data = json.load(f)
+ return data
+
+
+
+if __name__ == '__main__':
+ pass
\ No newline at end of file
diff --git a/main.py b/main.py
index 2cbb27f..b9a0992 100644
--- a/main.py
+++ b/main.py
@@ -130,9 +130,9 @@ def main():
ret.update({plugin_advanced_arg: gr.update(visible=("插件参数区" in a))})
if "底部输入区" in a: ret.update({txt: gr.update(value="")})
return ret
- checkboxes.select(fn_area_visibility, [checkboxes], [area_basic_fn, area_crazy_fn, area_input_primary, area_input_secondary, txt, txt2, clearBtn, clearBtn2, plugin_advanced_arg] )
+ checkboxes.select(fn_area_visibility, [checkboxes], [area_basic_fn, area_crazy_fn, area_input_primary, area_input_secondary, txt, clearBtn, clearBtn2, plugin_advanced_arg] )
# Collect the recurring combinations of component handles
- input_combo = [cookies, max_length_sl, md_dropdown, txt, txt2, top_p, temperature, chatbot, history, system_prompt, plugin_advanced_arg]
+ input_combo = [cookies, max_length_sl, md_dropdown, txt, top_p, temperature, chatbot, history, system_prompt, plugin_advanced_arg]
output_combo = [cookies, chatbot, history, status]
predict_args = dict(fn=ArgsGeneralWrapper(predict), inputs=input_combo, outputs=output_combo)
# Submit button and reset button
@@ -155,7 +155,7 @@ def main():
click_handle = functional[k]["Button"].click(fn=ArgsGeneralWrapper(predict), inputs=[*input_combo, gr.State(True), gr.State(k)], outputs=output_combo)
cancel_handles.append(click_handle)
# File-upload area: interaction with the chatbot after files are received
- file_upload.upload(on_file_uploaded, [file_upload, chatbot, txt, txt2, checkboxes], [chatbot, txt, txt2])
+    file_upload.upload(on_file_uploaded, [file_upload, chatbot, txt], [chatbot, txt])
# Function plugins: fixed button area
for k in crazy_fns:
if not crazy_fns[k].get("AsButton", True): continue
@@ -174,7 +174,7 @@ def main():
dropdown.select(on_dropdown_changed, [dropdown], [switchy_bt, plugin_advanced_arg] )
def on_md_dropdown_changed(k):
return {chatbot: gr.update(label="当前模型:"+k)}
- md_dropdown.select(on_md_dropdown_changed, [md_dropdown], [chatbot] )
+ md_dropdown.select(on_md_dropdown_changed, [md_dropdown], [chatbot])
# Register the callback for the adaptive button
def route(k, *args, **kwargs):
if k in [r"打开插件列表", r"请先从插件列表中选择"]: return
diff --git a/prompt_generator.py b/prompt_generator.py
new file mode 100644
index 0000000..0efdc83
--- /dev/null
+++ b/prompt_generator.py
@@ -0,0 +1,102 @@
+#! .\venv\
+# encoding: utf-8
+# @Time : 2023/4/19
+# @Author : Spike
+# @Descr :
+import os.path
+import sqlite3
+import threading
+import functools
+import func_box
+# Connect to the database
+base_path = os.path.dirname(__file__)
+prompt_path = os.path.join(base_path, 'users_data')
+
+
+def connect_db_close(cls_method):
+ @functools.wraps(cls_method)
+ def wrapper(cls=None, *args, **kwargs):
+ cls._connect_db()
+ result = cls_method(cls, *args, **kwargs)
+ cls._close_db()
+ return result
+ return wrapper
+
+
+class SqliteHandle:
+ def __init__(self, table='ai_common', database='ai_prompt.db'):
+ self.__database = database
+ self.__connect = sqlite3.connect(os.path.join(prompt_path, self.__database))
+ self.__cursor = self.__connect.cursor()
+ self.__table = table
+ if self.__table not in self.get_tables():
+ self.create_tab()
+
+ def new_connect_db(self):
+        """When used across threads, create an independent connection per thread"""
+ self.__connect = sqlite3.connect(os.path.join(prompt_path, self.__database))
+ self.__cursor = self.__connect.cursor()
+
+ def new_close_db(self):
+ self.__cursor.close()
+ self.__connect.close()
+
+ def create_tab(self):
+ self.__cursor.execute(f"CREATE TABLE `{self.__table}` ('prompt' TEXT UNIQUE, 'result' TEXT)")
+
+ def get_tables(self):
+ all_tab = []
+ result = self.__cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
+ for tab in result:
+ all_tab.append(tab[0])
+ return all_tab
+
+    def get_prompt_value(self, find=None):
+        temp_all = {}
+        # Use a parameterized query so user-supplied text cannot inject SQL
+        if find:
+            result = self.__cursor.execute(f"SELECT prompt, result FROM `{self.__table}` WHERE prompt LIKE ?", (f'%{find}%',)).fetchall()
+        else:
+            result = self.__cursor.execute(f"SELECT prompt, result FROM `{self.__table}`").fetchall()
+        for row in result:
+            temp_all[row[0]] = row[1]
+        return temp_all
+
+ def inset_prompt(self, prompt: dict):
+ for key in prompt:
+ self.__cursor.execute(f"REPLACE INTO `{self.__table}` (prompt, result) VALUES (?, ?);", (str(key), str(prompt[key])))
+ self.__connect.commit()
+
+    def delete_prompt(self, name):
+        self.__cursor.execute(f"DELETE from `{self.__table}` where prompt LIKE ?", (name,))
+        self.__connect.commit()
+
+ def delete_tabls(self, tab):
+ self.__cursor.execute(f"DROP TABLE `{tab}`;")
+ self.__connect.commit()
+
+    def find_prompt_result(self, name):
+        query = self.__cursor.execute(f"SELECT result FROM `{self.__table}` WHERE prompt LIKE ?", (name,)).fetchall()
+        if query == []:
+            # Fall back to the built-in prompt table
+            query = self.__cursor.execute("SELECT result FROM `prompt_127.0.0.1` WHERE prompt LIKE ?", (name,)).fetchall()
+        return query[0][0]
+
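+# Usage sketch (illustrative; the table name and prompt are examples): each SqliteHandle
+# binds one table in users_data/ai_prompt.db, and inset_prompt() upserts via REPLACE INTO.
+#   handle = SqliteHandle(table='prompt_127.0.0.1')
+#   handle.inset_prompt({'translator': 'Translate everything I say into English'})
+#   handle.get_prompt_value('trans')  # -> {'translator': 'Translate everything I say into English'}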
+def cp_db_data(include_tab='prompt'):
+    sql_ll = sqlite_handle(database='ai_prompt_cp.db')
+    tabs = sql_ll.get_tables()
+    for i in tabs:
+        if str(i).startswith(include_tab):
+            old_data = sqlite_handle(table=i, database='ai_prompt_cp.db').get_prompt_value()
+            sqlite_handle(table=i).inset_prompt(old_data)
+
+def inset_127_prompt():
+ sql_handle = sqlite_handle(table='prompt_127.0.0.1')
+ prompt_json = os.path.join(prompt_path, 'prompts-PlexPt.json')
+ data_list = func_box.JsonHandle(prompt_json).load()
+ for i in data_list:
+ sql_handle.inset_prompt(prompt={i['act']: i['prompt']})
+
+sqlite_handle = SqliteHandle
+if __name__ == '__main__':
+ cp_db_data()
\ No newline at end of file
diff --git a/request_llm/bridge_all.py b/request_llm/bridge_all.py
index d33f161..d595ad5 100644
--- a/request_llm/bridge_all.py
+++ b/request_llm/bridge_all.py
@@ -13,8 +13,11 @@ from functools import lru_cache
from concurrent.futures import ThreadPoolExecutor
from toolbox import get_conf, trimmed_format_exc
-from .bridge_chatgpt import predict_no_ui_long_connection as chatgpt_noui
-from .bridge_chatgpt import predict as chatgpt_ui
+from request_llm.bridge_chatgpt import predict_no_ui_long_connection as chatgpt_noui
+from request_llm.bridge_chatgpt import predict as chatgpt_ui
+
from .bridge_azure_test import predict_no_ui_long_connection as azure_noui
from .bridge_azure_test import predict as azure_ui
@@ -51,10 +54,11 @@ class LazyloadTiktoken(object):
return encoder.decode(*args, **kwargs)
# Endpoint 重定向
-API_URL_REDIRECT, = get_conf("API_URL_REDIRECT")
+API_URL_REDIRECT, PROXY_API_URL = get_conf("API_URL_REDIRECT", 'PROXY_API_URL')
openai_endpoint = "https://api.openai.com/v1/chat/completions"
api2d_endpoint = "https://openai.api2d.net/v1/chat/completions"
newbing_endpoint = "wss://sydney.bing.com/sydney/ChatHub"
+proxy_endpoint = PROXY_API_URL
# 兼容旧版的配置
try:
API_URL, = get_conf("API_URL")
@@ -69,6 +73,7 @@ if api2d_endpoint in API_URL_REDIRECT: api2d_endpoint = API_URL_REDIRECT[api2d_e
if newbing_endpoint in API_URL_REDIRECT: newbing_endpoint = API_URL_REDIRECT[newbing_endpoint]
+
# 获取tokenizer
tokenizer_gpt35 = LazyloadTiktoken("gpt-3.5-turbo")
tokenizer_gpt4 = LazyloadTiktoken("gpt-4")
@@ -122,6 +127,15 @@ model_info = {
"tokenizer": tokenizer_gpt4,
"token_cnt": get_token_num_gpt4,
},
# azure openai
"azure-gpt35":{
@@ -147,9 +161,9 @@ model_info = {
"fn_with_ui": chatgpt_ui,
"fn_without_ui": chatgpt_noui,
"endpoint": api2d_endpoint,
- "max_token": 8192,
- "tokenizer": tokenizer_gpt4,
- "token_cnt": get_token_num_gpt4,
+ "max_token": 4096,
+ "tokenizer": tokenizer_gpt35,
+ "token_cnt": get_token_num_gpt35,
},
# 将 chatglm 直接对齐到 chatglm2
diff --git a/request_llm/bridge_chatgpt.py b/request_llm/bridge_chatgpt.py
index eef8fbf..cfca56a 100644
--- a/request_llm/bridge_chatgpt.py
+++ b/request_llm/bridge_chatgpt.py
@@ -12,12 +12,14 @@
"""
import json
+import random
import time
import gradio as gr
import logging
import traceback
import requests
import importlib
+import func_box
# Put your secrets, such as API keys and proxy URLs, in config_private.py
# When reading the config, a private config_private.py (not tracked by git) is checked first; if present, it overrides the original config file
@@ -60,7 +62,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
while True:
try:
# make a POST request to the API endpoint, stream=False
- from .bridge_all import model_info
+ from request_llm.bridge_all import model_info
endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
response = requests.post(endpoint, headers=headers, proxies=proxies,
json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
@@ -106,7 +108,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
return result
-def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream = True, additional_fn=None):
+def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
"""
Send to chatGPT and fetch the output as a stream.
Used for basic conversation functionality.
@@ -134,24 +136,22 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
inputs = core_functional[additional_fn]["Prefix"] + inputs + core_functional[additional_fn]["Suffix"]
raw_input = inputs
- logging.info(f'[raw_input] {raw_input}')
+ logging.info(f'[raw_input]_{llm_kwargs["ipaddr"]} {raw_input}')
chatbot.append((inputs, ""))
- yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
-
+ loading_msg = func_box.spinner_chatbot_loading(chatbot)
+    yield from update_ui(chatbot=loading_msg, history=history, msg="等待响应") # refresh the UI
try:
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt, stream)
except RuntimeError as e:
chatbot[-1] = (inputs, f"您提供的api-key不满足要求,不包含任何可用于{llm_kwargs['llm_model']}的api-key。您可能选择了错误的模型或请求源。")
yield from update_ui(chatbot=chatbot, history=history, msg="api-key不满足要求") # refresh the UI
return
-
history.append(inputs); history.append("")
-
retry = 0
while True:
try:
# make a POST request to the API endpoint, stream=True
- from .bridge_all import model_info
+ from request_llm.bridge_all import model_info
endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
response = requests.post(endpoint, headers=headers, proxies=proxies,
json=payload, stream=True, timeout=TIMEOUT_SECONDS);break
@@ -163,7 +163,6 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
if retry > MAX_RETRY: raise TimeoutError
gpt_replying_buffer = ""
-
is_head_of_the_stream = True
if stream:
stream_response = response.iter_lines()
@@ -181,24 +180,26 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
if is_head_of_the_stream and (r'"object":"error"' not in chunk.decode()):
# The first frame of the data stream carries no content
is_head_of_the_stream = False; continue
-
+
if chunk:
try:
chunk_decoded = chunk.decode()
# the former check is API2D's convention
if ('data: [DONE]' in chunk_decoded) or (len(json.loads(chunk_decoded[6:])['choices'][0]["delta"]) == 0):
# the data stream has ended; gpt_replying_buffer is complete
-                        logging.info(f'[response] {gpt_replying_buffer}')
+                        logging.info(f'[response]_{llm_kwargs["ipaddr"]} {gpt_replying_buffer}')
break
# Process the body of the data stream
chunkjson = json.loads(chunk_decoded[6:])
- status_text = f"finish_reason: {chunkjson['choices'][0]['finish_reason']}"
+
# If an exception is raised here, the text is usually too long; see the output of get_full_error for details
gpt_replying_buffer = gpt_replying_buffer + json.loads(chunk_decoded[6:])['choices'][0]["delta"]["content"]
history[-1] = gpt_replying_buffer
chatbot[-1] = (history[-2], history[-1])
+ count_time = round(time.time() - llm_kwargs['start_time'], 3)
+ status_text = f"finish_reason: {chunkjson['choices'][0]['finish_reason']}\t" \
+ f"本次对话耗时: {func_box.html_tag_color(tag=f'{count_time}s')}"
yield from update_ui(chatbot=chatbot, history=history, msg=status_text) # refresh the UI
-
except Exception as e:
traceback.print_exc()
yield from update_ui(chatbot=chatbot, history=history, msg="Json解析不合常规") # refresh the UI
@@ -207,7 +208,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
error_msg = chunk_decoded
if "reduce the length" in error_msg:
if len(history) >= 2: history[-1] = ""; history[-2] = "" # clear the overflowing input: history[-2] is this turn's input, history[-1] is this turn's output
-                        history = clip_history(inputs=inputs, history=history, tokenizer=model_info[llm_kwargs['llm_model']]['tokenizer'],
+                        history = clip_history(inputs=inputs, history=history, tokenizer=model_info[llm_kwargs['llm_model']]['tokenizer'],
max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token'])) # release at least half of the history
chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长, 或历史数据过长. 历史缓存数据已部分释放, 您可以请再次尝试. (若再次失败则更可能是因为输入过长.)")
# history = []  # clear the history
@@ -227,6 +228,9 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
chatbot[-1] = (chatbot[-1][0], f"[Local Message] 异常 \n\n{tb_str} \n\n{regular_txt_to_markdown(chunk_decoded)}")
yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # refresh the UI
return
+ count_tokens = func_box.num_tokens_from_string(listing=history)
+ status_text += f'\t 本次对话使用tokens: {func_box.html_tag_color(count_tokens)}'
+    yield from update_ui(chatbot=chatbot, history=history, msg=status_text) # refresh the UI
def generate_payload(inputs, llm_kwargs, history, system_prompt, stream):
"""
@@ -234,13 +238,18 @@ def generate_payload(inputs, llm_kwargs, history, system_prompt, stream):
"""
if not is_any_api_key(llm_kwargs['api_key']):
raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。")
-
- api_key = select_api_key(llm_kwargs['api_key'], llm_kwargs['llm_model'])
-
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {api_key}"
- }
+    api_key = select_api_key(llm_kwargs['api_key'], llm_kwargs['llm_model'])
+    if llm_kwargs['llm_model'].startswith('proxy-'):
+        headers = {
+            "Content-Type": "application/json",
+            "api-key": f"{api_key}"
+        }
+    else:
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {api_key}"
+        }
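+    # Assumption: 'proxy-' models target an Azure-style endpoint, which authenticates with
+    # an 'api-key' header rather than an 'Authorization: Bearer' header.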
conversation_cnt = len(history) // 2
@@ -277,9 +286,20 @@ def generate_payload(inputs, llm_kwargs, history, system_prompt, stream):
"frequency_penalty": 0,
}
try:
- print(f" {llm_kwargs['llm_model']} : {conversation_cnt} : {inputs[:100]} ..........")
+ print("\033[1;35m", f"{llm_kwargs['llm_model']}_{llm_kwargs['ipaddr']} :", "\033[0m", f"{conversation_cnt} : {inputs[:100]} ..........")
except:
print('The input may contain garbled characters.')
- return headers,payload
-
+ return headers, payload
+if __name__ == '__main__':
+ llm_kwargs = {
+ 'api_key': 'sk-',
+ 'llm_model': 'gpt-3.5-turbo',
+ 'top_p': 1,
+ 'max_length': 512,
+ 'temperature': 1,
+        'ipaddr': '127.0.0.1',  # predict() logs llm_kwargs['ipaddr']
+        'start_time': time.time(),  # used to report per-turn latency
+ }
+ chat = []
+    # predict() is a generator: iterate it to actually run the request
+    for _ in predict('你好', llm_kwargs=llm_kwargs, chatbot=chat, plugin_kwargs={}):
+        pass
+ print(chat)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 1d70323..110e0ae 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,4 +15,11 @@ pymupdf
openai
numpy
arxiv
-rich
\ No newline at end of file
+qrcode
+beautifulsoup4
+pyperclip
+scikit-learn
+psutil
+distro
+python-dotenv
+rich
+Levenshtein
\ No newline at end of file
diff --git a/theme.py b/theme.py
index 5ef7e96..ed39e45 100644
--- a/theme.py
+++ b/theme.py
@@ -1,6 +1,6 @@
import gradio as gr
from toolbox import get_conf
-CODE_HIGHLIGHT, ADD_WAIFU = get_conf('CODE_HIGHLIGHT', 'ADD_WAIFU')
+CODE_HIGHLIGHT, ADD_WAIFU, ADD_CHUANHU = get_conf('CODE_HIGHLIGHT', 'ADD_WAIFU', 'ADD_CHUANHU')
# gradio可用颜色列表
# gr.themes.utils.colors.slate (石板色)
# gr.themes.utils.colors.gray (灰色)
@@ -29,105 +29,185 @@ CODE_HIGHLIGHT, ADD_WAIFU = get_conf('CODE_HIGHLIGHT', 'ADD_WAIFU')
def adjust_theme():
try:
- color_er = gr.themes.utils.colors.fuchsia
- set_theme = gr.themes.Default(
- primary_hue=gr.themes.utils.colors.orange,
- neutral_hue=gr.themes.utils.colors.gray,
- font=["sans-serif", "Microsoft YaHei", "ui-sans-serif", "system-ui",
- "sans-serif", gr.themes.utils.fonts.GoogleFont("Source Sans Pro")],
- font_mono=["ui-monospace", "Consolas", "monospace", gr.themes.utils.fonts.GoogleFont("IBM Plex Mono")])
- set_theme.set(
- # Colors
- input_background_fill_dark="*neutral_800",
- # Transition
- button_transition="none",
- # Shadows
- button_shadow="*shadow_drop",
- button_shadow_hover="*shadow_drop_lg",
- button_shadow_active="*shadow_inset",
- input_shadow="0 0 0 *shadow_spread transparent, *shadow_inset",
- input_shadow_focus="0 0 0 *shadow_spread *secondary_50, *shadow_inset",
- input_shadow_focus_dark="0 0 0 *shadow_spread *neutral_700, *shadow_inset",
- checkbox_label_shadow="*shadow_drop",
- block_shadow="*shadow_drop",
- form_gap_width="1px",
- # Button borders
- input_border_width="1px",
- input_background_fill="white",
- # Gradients
- stat_background_fill="linear-gradient(to right, *primary_400, *primary_200)",
- stat_background_fill_dark="linear-gradient(to right, *primary_400, *primary_600)",
- error_background_fill=f"linear-gradient(to right, {color_er.c100}, *background_fill_secondary)",
- error_background_fill_dark="*background_fill_primary",
- checkbox_label_background_fill="linear-gradient(to top, *neutral_50, white)",
- checkbox_label_background_fill_dark="linear-gradient(to top, *neutral_900, *neutral_800)",
- checkbox_label_background_fill_hover="linear-gradient(to top, *neutral_100, white)",
- checkbox_label_background_fill_hover_dark="linear-gradient(to top, *neutral_900, *neutral_800)",
- button_primary_background_fill="linear-gradient(to bottom right, *primary_100, *primary_300)",
- button_primary_background_fill_dark="linear-gradient(to bottom right, *primary_500, *primary_600)",
- button_primary_background_fill_hover="linear-gradient(to bottom right, *primary_100, *primary_200)",
- button_primary_background_fill_hover_dark="linear-gradient(to bottom right, *primary_500, *primary_500)",
- button_primary_border_color_dark="*primary_500",
- button_secondary_background_fill="linear-gradient(to bottom right, *neutral_100, *neutral_200)",
- button_secondary_background_fill_dark="linear-gradient(to bottom right, *neutral_600, *neutral_700)",
- button_secondary_background_fill_hover="linear-gradient(to bottom right, *neutral_100, *neutral_100)",
- button_secondary_background_fill_hover_dark="linear-gradient(to bottom right, *neutral_600, *neutral_600)",
- button_cancel_background_fill=f"linear-gradient(to bottom right, {color_er.c100}, {color_er.c200})",
- button_cancel_background_fill_dark=f"linear-gradient(to bottom right, {color_er.c600}, {color_er.c700})",
- button_cancel_background_fill_hover=f"linear-gradient(to bottom right, {color_er.c100}, {color_er.c100})",
- button_cancel_background_fill_hover_dark=f"linear-gradient(to bottom right, {color_er.c600}, {color_er.c600})",
- button_cancel_border_color=color_er.c200,
- button_cancel_border_color_dark=color_er.c600,
- button_cancel_text_color=color_er.c600,
- button_cancel_text_color_dark="white",
- )
+ set_theme = gr.themes.Soft(
+ primary_hue=gr.themes.Color(
+ c50="#EBFAF2",
+ c100="#CFF3E1",
+ c200="#A8EAC8",
+ c300="#77DEA9",
+ c400="#3FD086",
+ c500="#02C160",
+ c600="#06AE56",
+ c700="#05974E",
+ c800="#057F45",
+ c900="#04673D",
+ c950="#2E5541",
+ name="small_and_beautiful",
+ ),
+ secondary_hue=gr.themes.Color(
+ c50="#576b95",
+ c100="#576b95",
+ c200="#576b95",
+ c300="#576b95",
+ c400="#576b95",
+ c500="#576b95",
+ c600="#576b95",
+ c700="#576b95",
+ c800="#576b95",
+ c900="#576b95",
+ c950="#576b95",
+ ),
+ neutral_hue=gr.themes.Color(
+ name="gray",
+ c50="#f6f7f8",
+ # c100="#f3f4f6",
+ c100="#F2F2F2",
+ c200="#e5e7eb",
+ c300="#d1d5db",
+ c400="#B2B2B2",
+ c500="#808080",
+ c600="#636363",
+ c700="#515151",
+ c800="#393939",
+ # c900="#272727",
+ c900="#2B2B2B",
+ c950="#171717",
+ ),
+ radius_size=gr.themes.sizes.radius_sm,
+ ).set(
+ button_primary_background_fill="*primary_500",
+ button_primary_background_fill_dark="*primary_600",
+ button_primary_background_fill_hover="*primary_400",
+ button_primary_border_color="*primary_500",
+ button_primary_border_color_dark="*primary_600",
+        button_primary_text_color="white",
+ button_primary_text_color_dark="white",
+ button_secondary_background_fill="*neutral_100",
+ button_secondary_background_fill_hover="*neutral_50",
+ button_secondary_background_fill_dark="*neutral_900",
+ button_secondary_text_color="*neutral_800",
+ button_secondary_text_color_dark="white",
+ background_fill_primary="#F7F7F7",
+ background_fill_primary_dark="#1F1F1F",
+ block_title_text_color="*primary_500",
+ block_title_background_fill_dark="*primary_900",
+ block_label_background_fill_dark="*primary_900",
+ input_background_fill="#F6F6F6",
+ chatbot_code_background_color="*neutral_950",
+ chatbot_code_background_color_dark="*neutral_950",
+ )
+ js = ''
+ if ADD_CHUANHU:
+ with open("./docs/assets/custom.js", "r", encoding="utf-8") as f, \
+ open("./docs/assets/external-scripts.js", "r", encoding="utf-8") as f1:
+ customJS = f.read()
+ externalScripts = f1.read()
+        js += f'<script>{customJS}</script><script>{externalScripts}</script>'
# Add a cute Live2D mascot
if ADD_WAIFU:
-        js = """
+        js += """
        <script src="file=docs/waifu_plugin/jquery.min.js"></script>
        <script src="file=docs/waifu_plugin/jquery-ui.min.js"></script>
        <script src="file=docs/waifu_plugin/autoload.js"></script>
        """
- gradio_original_template_fn = gr.routes.templates.TemplateResponse
- def gradio_new_template_fn(*args, **kwargs):
- res = gradio_original_template_fn(*args, **kwargs)
-        res.body = res.body.replace(b'</html>', f'{js}</html>'.encode("utf8"))