Compare commits
24 Commits
version3.4
...
finetune-l
| Author | SHA1 | Date | |
|---|---|---|---|
| b585ae6835 | |||
| 3cda37a6a6 | |||
| 127b1e5646 | |||
| ee23e8a851 | |||
| acddb86f3a | |||
| 4fde0120ab | |||
| 592a354eef | |||
| bd66cf3d8b | |||
| e6e5174734 | |||
| 13ade82677 | |||
| ce9eb8d20a | |||
| dd47c0a284 | |||
| f725ab1b31 | |||
| 91aee50ea7 | |||
| e5ccedf491 | |||
| f620666a58 | |||
| 594c63e5d6 | |||
| b082b5eb1b | |||
| 9648d78453 | |||
| 2dc8718041 | |||
| a330d6636e | |||
| 322c4be145 | |||
| a3596ff60d | |||
| e11d8132f8 |
1
.gitignore
vendored
1
.gitignore
vendored
@ -150,3 +150,4 @@ request_llm/jittorllms
|
||||
multi-language
|
||||
request_llm/moss
|
||||
media
|
||||
flagged
|
||||
|
||||
@ -304,7 +304,12 @@ gpt_academic开发者QQ群-2:610599535
|
||||
- 某些浏览器翻译插件干扰此软件前端的运行
|
||||
- 官方Gradio目前有很多兼容性Bug,请务必使用`requirement.txt`安装Gradio
|
||||
|
||||
### III:参考与学习
|
||||
### III:主题
|
||||
|
||||
1. `Chuanhu-Small-and-Beautiful` [网址](https://github.com/GaiZhenbiao/ChuanhuChatGPT/)
|
||||
|
||||
|
||||
### IV:参考与学习
|
||||
|
||||
```
|
||||
代码中参考了很多其他优秀项目中的设计,顺序不分先后:
|
||||
|
||||
12
config.py
12
config.py
@ -85,6 +85,8 @@ CONCURRENT_COUNT = 100
|
||||
# 是否在提交时自动清空输入框
|
||||
AUTO_CLEAR_TXT = False
|
||||
|
||||
# 色彩主体,可选 ["Default", "Chuanhu-Small-and-Beautiful"]
|
||||
THEME = "Default"
|
||||
|
||||
# 加一个live2d装饰
|
||||
ADD_WAIFU = False
|
||||
@ -119,3 +121,13 @@ NEWBING_STYLE = "creative" # ["creative", "balanced", "precise"]
|
||||
NEWBING_COOKIES = """
|
||||
put your new bing cookies here
|
||||
"""
|
||||
|
||||
|
||||
# 阿里云实时语音识别 配置难度较高 仅建议高手用户使用 参考 https://help.aliyun.com/document_detail/450255.html
|
||||
ENABLE_AUDIO = False
|
||||
ALIYUN_TOKEN="" # 例如 f37f30e0f9934c34a992f6f64f7eba4f
|
||||
ALIYUN_APPKEY="" # 例如 RoPlZrM88DnAFkZK
|
||||
|
||||
|
||||
# ChatGLM Finetune Model Path
|
||||
ChatGLM_PTUNING_CHECKPOINT = ""
|
||||
@ -416,18 +416,34 @@ def get_crazy_functions():
|
||||
except:
|
||||
print('Load function plugin failed')
|
||||
|
||||
# try:
|
||||
# from crazy_functions.虚空终端 import 终端
|
||||
# function_plugins.update({
|
||||
# "超级终端": {
|
||||
# "Color": "stop",
|
||||
# "AsButton": False,
|
||||
# # "AdvancedArgs": True,
|
||||
# # "ArgsReminder": "",
|
||||
# "Function": HotReload(终端)
|
||||
# }
|
||||
# })
|
||||
# except:
|
||||
# print('Load function plugin failed')
|
||||
|
||||
try:
|
||||
from toolbox import get_conf
|
||||
ENABLE_AUDIO, = get_conf('ENABLE_AUDIO')
|
||||
if ENABLE_AUDIO:
|
||||
from crazy_functions.语音助手 import 语音助手
|
||||
function_plugins.update({
|
||||
"实时音频采集": {
|
||||
"Color": "stop",
|
||||
"AsButton": True,
|
||||
"Function": HotReload(语音助手)
|
||||
}
|
||||
})
|
||||
except:
|
||||
print('Load function plugin failed')
|
||||
|
||||
try:
|
||||
from crazy_functions.虚空终端 import 终端
|
||||
function_plugins.update({
|
||||
"超级终端": {
|
||||
"Color": "stop",
|
||||
"AsButton": False,
|
||||
# "AdvancedArgs": True,
|
||||
# "ArgsReminder": "",
|
||||
"Function": HotReload(终端)
|
||||
}
|
||||
})
|
||||
except:
|
||||
print('Load function plugin failed')
|
||||
|
||||
return function_plugins
|
||||
|
||||
@ -69,3 +69,57 @@ def 微调数据集生成(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
|
||||
|
||||
promote_file_to_downloadzone(txt+'.generated.json', rename_file='generated.json', chatbot=chatbot)
|
||||
return
|
||||
|
||||
|
||||
|
||||
def 启动微调(arguments):
|
||||
"""
|
||||
txt 输入栏用户输入的文本,例如需要翻译的一段话,再例如一个包含了待处理文件的路径
|
||||
llm_kwargs gpt模型参数,如温度和top_p等,一般原样传递下去就行
|
||||
plugin_kwargs 插件模型的参数
|
||||
chatbot 聊天显示框的句柄,用于显示给用户
|
||||
history 聊天历史,前情提要
|
||||
system_prompt 给gpt的静默提醒
|
||||
web_port 当前软件运行的端口号
|
||||
"""
|
||||
history = [] # 清空历史,以免输入溢出
|
||||
import subprocess
|
||||
PRE_SEQ_LEN = 128
|
||||
LR = 2e-2
|
||||
NUM_GPUS = 1
|
||||
JSON_FILE = 't_code.json'
|
||||
tune_work_path = '/home/hmp/ChatGLM2-6B/ptuning'
|
||||
|
||||
command = f"torchrun --standalone --nnodes=1 --nproc-per-node={NUM_GPUS} main.py \
|
||||
--do_train \
|
||||
--train_file AdvertiseGen/{JSON_FILE} \
|
||||
--validation_file AdvertiseGen/{JSON_FILE} \
|
||||
--preprocessing_num_workers 20 \
|
||||
--prompt_column content \
|
||||
--response_column summary \
|
||||
--overwrite_cache \
|
||||
--model_name_or_path THUDM/chatglm2-6b \
|
||||
--output_dir output/clothgen-chatglm2-6b-pt-{PRE_SEQ_LEN}-{LR} \
|
||||
--overwrite_output_dir \
|
||||
--max_source_length 256 \
|
||||
--max_target_length 256 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 16 \
|
||||
--predict_with_generate \
|
||||
--max_steps 100 \
|
||||
--logging_steps 10 \
|
||||
--save_steps 20 \
|
||||
--learning_rate {LR} \
|
||||
--pre_seq_len {PRE_SEQ_LEN} \
|
||||
--quantization_bit 4"
|
||||
|
||||
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=tune_work_path)
|
||||
try:
|
||||
stdout, stderr = process.communicate(timeout=3600*5)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
stdout, stderr = process.communicate()
|
||||
print("Process timed out!")
|
||||
return False
|
||||
return
|
||||
|
||||
89
crazy_functions/live_audio/aliyunASR.py
Normal file
89
crazy_functions/live_audio/aliyunASR.py
Normal file
@ -0,0 +1,89 @@
|
||||
import time, threading, json
|
||||
|
||||
|
||||
class AliyunASR():
|
||||
|
||||
def test_on_sentence_begin(self, message, *args):
|
||||
# print("test_on_sentence_begin:{}".format(message))
|
||||
pass
|
||||
|
||||
def test_on_sentence_end(self, message, *args):
|
||||
# print("test_on_sentence_end:{}".format(message))
|
||||
message = json.loads(message)
|
||||
self.parsed_sentence = message['payload']['result']
|
||||
self.event_on_entence_end.set()
|
||||
print(self.parsed_sentence)
|
||||
|
||||
def test_on_start(self, message, *args):
|
||||
# print("test_on_start:{}".format(message))
|
||||
pass
|
||||
|
||||
def test_on_error(self, message, *args):
|
||||
# print("on_error args=>{}".format(args))
|
||||
pass
|
||||
|
||||
def test_on_close(self, *args):
|
||||
# print("on_close: args=>{}".format(args))
|
||||
pass
|
||||
|
||||
def test_on_result_chg(self, message, *args):
|
||||
# print("test_on_chg:{}".format(message))
|
||||
message = json.loads(message)
|
||||
self.parsed_text = message['payload']['result']
|
||||
self.event_on_result_chg.set()
|
||||
|
||||
def test_on_completed(self, message, *args):
|
||||
# print("on_completed:args=>{} message=>{}".format(args, message))
|
||||
pass
|
||||
|
||||
|
||||
def audio_convertion_thread(self, uuid):
|
||||
# 在一个异步线程中采集音频
|
||||
import nls # pip install git+https://github.com/aliyun/alibabacloud-nls-python-sdk.git
|
||||
import tempfile
|
||||
from scipy import io
|
||||
from toolbox import get_conf
|
||||
from .audio_io import change_sample_rate
|
||||
from .audio_io import RealtimeAudioDistribution
|
||||
NEW_SAMPLERATE = 16000
|
||||
rad = RealtimeAudioDistribution()
|
||||
rad.clean_up()
|
||||
temp_folder = tempfile.gettempdir()
|
||||
TOKEN, APPKEY = get_conf('ALIYUN_TOKEN', 'ALIYUN_APPKEY')
|
||||
|
||||
URL="wss://nls-gateway.aliyuncs.com/ws/v1"
|
||||
sr = nls.NlsSpeechTranscriber(
|
||||
url=URL,
|
||||
token=TOKEN,
|
||||
appkey=APPKEY,
|
||||
on_sentence_begin=self.test_on_sentence_begin,
|
||||
on_sentence_end=self.test_on_sentence_end,
|
||||
on_start=self.test_on_start,
|
||||
on_result_changed=self.test_on_result_chg,
|
||||
on_completed=self.test_on_completed,
|
||||
on_error=self.test_on_error,
|
||||
on_close=self.test_on_close,
|
||||
callback_args=[uuid.hex]
|
||||
)
|
||||
|
||||
r = sr.start(aformat="pcm",
|
||||
enable_intermediate_result=True,
|
||||
enable_punctuation_prediction=True,
|
||||
enable_inverse_text_normalization=True)
|
||||
|
||||
while not self.stop:
|
||||
# time.sleep(self.capture_interval)
|
||||
audio = rad.read(uuid.hex)
|
||||
if audio is not None:
|
||||
# convert to pcm file
|
||||
temp_file = f'{temp_folder}/{uuid.hex}.pcm' #
|
||||
dsdata = change_sample_rate(audio, rad.rate, NEW_SAMPLERATE) # 48000 --> 16000
|
||||
io.wavfile.write(temp_file, NEW_SAMPLERATE, dsdata)
|
||||
# read pcm binary
|
||||
with open(temp_file, "rb") as f: data = f.read()
|
||||
# print('audio len:', len(audio), '\t ds len:', len(dsdata), '\t need n send:', len(data)//640)
|
||||
slices = zip(*(iter(data),) * 640) # 640个字节为一组
|
||||
for i in slices: sr.send_audio(bytes(i))
|
||||
else:
|
||||
time.sleep(0.1)
|
||||
r = sr.stop()
|
||||
51
crazy_functions/live_audio/audio_io.py
Normal file
51
crazy_functions/live_audio/audio_io.py
Normal file
@ -0,0 +1,51 @@
|
||||
import numpy as np
|
||||
from scipy import interpolate
|
||||
|
||||
def Singleton(cls):
|
||||
_instance = {}
|
||||
|
||||
def _singleton(*args, **kargs):
|
||||
if cls not in _instance:
|
||||
_instance[cls] = cls(*args, **kargs)
|
||||
return _instance[cls]
|
||||
|
||||
return _singleton
|
||||
|
||||
|
||||
@Singleton
|
||||
class RealtimeAudioDistribution():
|
||||
def __init__(self) -> None:
|
||||
self.data = {}
|
||||
self.max_len = 1024*1024
|
||||
self.rate = 48000 # 只读,每秒采样数量
|
||||
|
||||
def clean_up(self):
|
||||
self.data = {}
|
||||
|
||||
def feed(self, uuid, audio):
|
||||
self.rate, audio_ = audio
|
||||
# print('feed', len(audio_), audio_[-25:])
|
||||
if uuid not in self.data:
|
||||
self.data[uuid] = audio_
|
||||
else:
|
||||
new_arr = np.concatenate((self.data[uuid], audio_))
|
||||
if len(new_arr) > self.max_len: new_arr = new_arr[-self.max_len:]
|
||||
self.data[uuid] = new_arr
|
||||
|
||||
def read(self, uuid):
|
||||
if uuid in self.data:
|
||||
res = self.data.pop(uuid)
|
||||
print('\r read-', len(res), '-', max(res), end='', flush=True)
|
||||
else:
|
||||
res = None
|
||||
return res
|
||||
|
||||
def change_sample_rate(audio, old_sr, new_sr):
|
||||
duration = audio.shape[0] / old_sr
|
||||
|
||||
time_old = np.linspace(0, duration, audio.shape[0])
|
||||
time_new = np.linspace(0, duration, int(audio.shape[0] * new_sr / old_sr))
|
||||
|
||||
interpolator = interpolate.interp1d(time_old, audio.T)
|
||||
new_audio = interpolator(time_new).T
|
||||
return new_audio.astype(np.int16)
|
||||
@ -1,87 +0,0 @@
|
||||
#include "libipc/buffer.h"
|
||||
#include "libipc/utility/pimpl.h"
|
||||
|
||||
#include <cstring>
|
||||
|
||||
namespace ipc {
|
||||
|
||||
bool operator==(buffer const & b1, buffer const & b2) {
|
||||
return (b1.size() == b2.size()) && (std::memcmp(b1.data(), b2.data(), b1.size()) == 0);
|
||||
}
|
||||
|
||||
bool operator!=(buffer const & b1, buffer const & b2) {
|
||||
return !(b1 == b2);
|
||||
}
|
||||
|
||||
class buffer::buffer_ : public pimpl<buffer_> {
|
||||
public:
|
||||
void* p_;
|
||||
std::size_t s_;
|
||||
void* a_;
|
||||
buffer::destructor_t d_;
|
||||
|
||||
buffer_(void* p, std::size_t s, buffer::destructor_t d, void* a)
|
||||
: p_(p), s_(s), a_(a), d_(d) {
|
||||
}
|
||||
|
||||
~buffer_() {
|
||||
if (d_ == nullptr) return;
|
||||
d_((a_ == nullptr) ? p_ : a_, s_);
|
||||
}
|
||||
};
|
||||
|
||||
buffer::buffer()
|
||||
: buffer(nullptr, 0, nullptr, nullptr) {
|
||||
}
|
||||
|
||||
buffer::buffer(void* p, std::size_t s, destructor_t d)
|
||||
: p_(p_->make(p, s, d, nullptr)) {
|
||||
}
|
||||
|
||||
buffer::buffer(void* p, std::size_t s, destructor_t d, void* additional)
|
||||
: p_(p_->make(p, s, d, additional)) {
|
||||
}
|
||||
|
||||
buffer::buffer(void* p, std::size_t s)
|
||||
: buffer(p, s, nullptr) {
|
||||
}
|
||||
|
||||
buffer::buffer(char const & c)
|
||||
: buffer(const_cast<char*>(&c), 1) {
|
||||
}
|
||||
|
||||
buffer::buffer(buffer&& rhs)
|
||||
: buffer() {
|
||||
swap(rhs);
|
||||
}
|
||||
|
||||
buffer::~buffer() {
|
||||
p_->clear();
|
||||
}
|
||||
|
||||
void buffer::swap(buffer& rhs) {
|
||||
std::swap(p_, rhs.p_);
|
||||
}
|
||||
|
||||
buffer& buffer::operator=(buffer rhs) {
|
||||
swap(rhs);
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool buffer::empty() const noexcept {
|
||||
return (impl(p_)->p_ == nullptr) || (impl(p_)->s_ == 0);
|
||||
}
|
||||
|
||||
void* buffer::data() noexcept {
|
||||
return impl(p_)->p_;
|
||||
}
|
||||
|
||||
void const * buffer::data() const noexcept {
|
||||
return impl(p_)->p_;
|
||||
}
|
||||
|
||||
std::size_t buffer::size() const noexcept {
|
||||
return impl(p_)->s_;
|
||||
}
|
||||
|
||||
} // namespace ipc
|
||||
@ -1,701 +0,0 @@
|
||||
|
||||
#include <type_traits>
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
#include <utility> // std::pair, std::move, std::forward
|
||||
#include <atomic>
|
||||
#include <type_traits> // aligned_storage_t
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
|
||||
#include "libipc/ipc.h"
|
||||
#include "libipc/def.h"
|
||||
#include "libipc/shm.h"
|
||||
#include "libipc/pool_alloc.h"
|
||||
#include "libipc/queue.h"
|
||||
#include "libipc/policy.h"
|
||||
#include "libipc/rw_lock.h"
|
||||
#include "libipc/waiter.h"
|
||||
|
||||
#include "libipc/utility/log.h"
|
||||
#include "libipc/utility/id_pool.h"
|
||||
#include "libipc/utility/scope_guard.h"
|
||||
#include "libipc/utility/utility.h"
|
||||
|
||||
#include "libipc/memory/resource.h"
|
||||
#include "libipc/platform/detail.h"
|
||||
#include "libipc/circ/elem_array.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using msg_id_t = std::uint32_t;
|
||||
using acc_t = std::atomic<msg_id_t>;
|
||||
|
||||
template <std::size_t DataSize, std::size_t AlignSize>
|
||||
struct msg_t;
|
||||
|
||||
template <std::size_t AlignSize>
|
||||
struct msg_t<0, AlignSize> {
|
||||
msg_id_t cc_id_;
|
||||
msg_id_t id_;
|
||||
std::int32_t remain_;
|
||||
bool storage_;
|
||||
};
|
||||
|
||||
template <std::size_t DataSize, std::size_t AlignSize>
|
||||
struct msg_t : msg_t<0, AlignSize> {
|
||||
std::aligned_storage_t<DataSize, AlignSize> data_ {};
|
||||
|
||||
msg_t() = default;
|
||||
msg_t(msg_id_t cc_id, msg_id_t id, std::int32_t remain, void const * data, std::size_t size)
|
||||
: msg_t<0, AlignSize> {cc_id, id, remain, (data == nullptr) || (size == 0)} {
|
||||
if (this->storage_) {
|
||||
if (data != nullptr) {
|
||||
// copy storage-id
|
||||
*reinterpret_cast<ipc::storage_id_t*>(&data_) =
|
||||
*static_cast<ipc::storage_id_t const *>(data);
|
||||
}
|
||||
}
|
||||
else std::memcpy(&data_, data, size);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
ipc::buff_t make_cache(T& data, std::size_t size) {
|
||||
auto ptr = ipc::mem::alloc(size);
|
||||
std::memcpy(ptr, &data, (ipc::detail::min)(sizeof(data), size));
|
||||
return { ptr, size, ipc::mem::free };
|
||||
}
|
||||
|
||||
struct cache_t {
|
||||
std::size_t fill_;
|
||||
ipc::buff_t buff_;
|
||||
|
||||
cache_t(std::size_t f, ipc::buff_t && b)
|
||||
: fill_(f), buff_(std::move(b))
|
||||
{}
|
||||
|
||||
void append(void const * data, std::size_t size) {
|
||||
if (fill_ >= buff_.size() || data == nullptr || size == 0) return;
|
||||
auto new_fill = (ipc::detail::min)(fill_ + size, buff_.size());
|
||||
std::memcpy(static_cast<ipc::byte_t*>(buff_.data()) + fill_, data, new_fill - fill_);
|
||||
fill_ = new_fill;
|
||||
}
|
||||
};
|
||||
|
||||
auto cc_acc() {
|
||||
static ipc::shm::handle acc_h("__CA_CONN__", sizeof(acc_t));
|
||||
return static_cast<acc_t*>(acc_h.get());
|
||||
}
|
||||
|
||||
IPC_CONSTEXPR_ std::size_t align_chunk_size(std::size_t size) noexcept {
|
||||
return (((size - 1) / ipc::large_msg_align) + 1) * ipc::large_msg_align;
|
||||
}
|
||||
|
||||
IPC_CONSTEXPR_ std::size_t calc_chunk_size(std::size_t size) noexcept {
|
||||
return ipc::make_align(alignof(std::max_align_t), align_chunk_size(
|
||||
ipc::make_align(alignof(std::max_align_t), sizeof(std::atomic<ipc::circ::cc_t>)) + size));
|
||||
}
|
||||
|
||||
struct chunk_t {
|
||||
std::atomic<ipc::circ::cc_t> &conns() noexcept {
|
||||
return *reinterpret_cast<std::atomic<ipc::circ::cc_t> *>(this);
|
||||
}
|
||||
|
||||
void *data() noexcept {
|
||||
return reinterpret_cast<ipc::byte_t *>(this)
|
||||
+ ipc::make_align(alignof(std::max_align_t), sizeof(std::atomic<ipc::circ::cc_t>));
|
||||
}
|
||||
};
|
||||
|
||||
struct chunk_info_t {
|
||||
ipc::id_pool<> pool_;
|
||||
ipc::spin_lock lock_;
|
||||
|
||||
IPC_CONSTEXPR_ static std::size_t chunks_mem_size(std::size_t chunk_size) noexcept {
|
||||
return ipc::id_pool<>::max_count * chunk_size;
|
||||
}
|
||||
|
||||
ipc::byte_t *chunks_mem() noexcept {
|
||||
return reinterpret_cast<ipc::byte_t *>(this + 1);
|
||||
}
|
||||
|
||||
chunk_t *at(std::size_t chunk_size, ipc::storage_id_t id) noexcept {
|
||||
if (id < 0) return nullptr;
|
||||
return reinterpret_cast<chunk_t *>(chunks_mem() + (chunk_size * id));
|
||||
}
|
||||
};
|
||||
|
||||
auto& chunk_storages() {
|
||||
class chunk_handle_t {
|
||||
ipc::shm::handle handle_;
|
||||
|
||||
public:
|
||||
chunk_info_t *get_info(std::size_t chunk_size) {
|
||||
if (!handle_.valid() &&
|
||||
!handle_.acquire( ("__CHUNK_INFO__" + ipc::to_string(chunk_size)).c_str(),
|
||||
sizeof(chunk_info_t) + chunk_info_t::chunks_mem_size(chunk_size) )) {
|
||||
ipc::error("[chunk_storages] chunk_shm.id_info_.acquire failed: chunk_size = %zd\n", chunk_size);
|
||||
return nullptr;
|
||||
}
|
||||
auto info = static_cast<chunk_info_t*>(handle_.get());
|
||||
if (info == nullptr) {
|
||||
ipc::error("[chunk_storages] chunk_shm.id_info_.get failed: chunk_size = %zd\n", chunk_size);
|
||||
return nullptr;
|
||||
}
|
||||
return info;
|
||||
}
|
||||
};
|
||||
static ipc::map<std::size_t, chunk_handle_t> chunk_hs;
|
||||
return chunk_hs;
|
||||
}
|
||||
|
||||
chunk_info_t *chunk_storage_info(std::size_t chunk_size) {
|
||||
auto &storages = chunk_storages();
|
||||
std::decay_t<decltype(storages)>::iterator it;
|
||||
{
|
||||
static ipc::rw_lock lock;
|
||||
IPC_UNUSED_ std::shared_lock<ipc::rw_lock> guard {lock};
|
||||
if ((it = storages.find(chunk_size)) == storages.end()) {
|
||||
using chunk_handle_t = std::decay_t<decltype(storages)>::value_type::second_type;
|
||||
guard.unlock();
|
||||
IPC_UNUSED_ std::lock_guard<ipc::rw_lock> guard {lock};
|
||||
it = storages.emplace(chunk_size, chunk_handle_t{}).first;
|
||||
}
|
||||
}
|
||||
return it->second.get_info(chunk_size);
|
||||
}
|
||||
|
||||
std::pair<ipc::storage_id_t, void*> acquire_storage(std::size_t size, ipc::circ::cc_t conns) {
|
||||
std::size_t chunk_size = calc_chunk_size(size);
|
||||
auto info = chunk_storage_info(chunk_size);
|
||||
if (info == nullptr) return {};
|
||||
|
||||
info->lock_.lock();
|
||||
info->pool_.prepare();
|
||||
// got an unique id
|
||||
auto id = info->pool_.acquire();
|
||||
info->lock_.unlock();
|
||||
|
||||
auto chunk = info->at(chunk_size, id);
|
||||
if (chunk == nullptr) return {};
|
||||
chunk->conns().store(conns, std::memory_order_relaxed);
|
||||
return { id, chunk->data() };
|
||||
}
|
||||
|
||||
void *find_storage(ipc::storage_id_t id, std::size_t size) {
|
||||
if (id < 0) {
|
||||
ipc::error("[find_storage] id is invalid: id = %ld, size = %zd\n", (long)id, size);
|
||||
return nullptr;
|
||||
}
|
||||
std::size_t chunk_size = calc_chunk_size(size);
|
||||
auto info = chunk_storage_info(chunk_size);
|
||||
if (info == nullptr) return nullptr;
|
||||
return info->at(chunk_size, id)->data();
|
||||
}
|
||||
|
||||
void release_storage(ipc::storage_id_t id, std::size_t size) {
|
||||
if (id < 0) {
|
||||
ipc::error("[release_storage] id is invalid: id = %ld, size = %zd\n", (long)id, size);
|
||||
return;
|
||||
}
|
||||
std::size_t chunk_size = calc_chunk_size(size);
|
||||
auto info = chunk_storage_info(chunk_size);
|
||||
if (info == nullptr) return;
|
||||
info->lock_.lock();
|
||||
info->pool_.release(id);
|
||||
info->lock_.unlock();
|
||||
}
|
||||
|
||||
template <ipc::relat Rp, ipc::relat Rc>
|
||||
bool sub_rc(ipc::wr<Rp, Rc, ipc::trans::unicast>,
|
||||
std::atomic<ipc::circ::cc_t> &/*conns*/, ipc::circ::cc_t /*curr_conns*/, ipc::circ::cc_t /*conn_id*/) noexcept {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <ipc::relat Rp, ipc::relat Rc>
|
||||
bool sub_rc(ipc::wr<Rp, Rc, ipc::trans::broadcast>,
|
||||
std::atomic<ipc::circ::cc_t> &conns, ipc::circ::cc_t curr_conns, ipc::circ::cc_t conn_id) noexcept {
|
||||
auto last_conns = curr_conns & ~conn_id;
|
||||
for (unsigned k = 0;;) {
|
||||
auto chunk_conns = conns.load(std::memory_order_acquire);
|
||||
if (conns.compare_exchange_weak(chunk_conns, chunk_conns & last_conns, std::memory_order_release)) {
|
||||
return (chunk_conns & last_conns) == 0;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Flag>
|
||||
void recycle_storage(ipc::storage_id_t id, std::size_t size, ipc::circ::cc_t curr_conns, ipc::circ::cc_t conn_id) {
|
||||
if (id < 0) {
|
||||
ipc::error("[recycle_storage] id is invalid: id = %ld, size = %zd\n", (long)id, size);
|
||||
return;
|
||||
}
|
||||
std::size_t chunk_size = calc_chunk_size(size);
|
||||
auto info = chunk_storage_info(chunk_size);
|
||||
if (info == nullptr) return;
|
||||
|
||||
auto chunk = info->at(chunk_size, id);
|
||||
if (chunk == nullptr) return;
|
||||
|
||||
if (!sub_rc(Flag{}, chunk->conns(), curr_conns, conn_id)) {
|
||||
return;
|
||||
}
|
||||
info->lock_.lock();
|
||||
info->pool_.release(id);
|
||||
info->lock_.unlock();
|
||||
}
|
||||
|
||||
template <typename MsgT>
|
||||
bool clear_message(void* p) {
|
||||
auto msg = static_cast<MsgT*>(p);
|
||||
if (msg->storage_) {
|
||||
std::int32_t r_size = static_cast<std::int32_t>(ipc::data_length) + msg->remain_;
|
||||
if (r_size <= 0) {
|
||||
ipc::error("[clear_message] invalid msg size: %d\n", (int)r_size);
|
||||
return true;
|
||||
}
|
||||
release_storage(
|
||||
*reinterpret_cast<ipc::storage_id_t*>(&msg->data_),
|
||||
static_cast<std::size_t>(r_size));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
struct conn_info_head {
|
||||
|
||||
ipc::string name_;
|
||||
msg_id_t cc_id_; // connection-info id
|
||||
ipc::detail::waiter cc_waiter_, wt_waiter_, rd_waiter_;
|
||||
ipc::shm::handle acc_h_;
|
||||
|
||||
conn_info_head(char const * name)
|
||||
: name_ {name}
|
||||
, cc_id_ {(cc_acc() == nullptr) ? 0 : cc_acc()->fetch_add(1, std::memory_order_relaxed)}
|
||||
, cc_waiter_{("__CC_CONN__" + name_).c_str()}
|
||||
, wt_waiter_{("__WT_CONN__" + name_).c_str()}
|
||||
, rd_waiter_{("__RD_CONN__" + name_).c_str()}
|
||||
, acc_h_ {("__AC_CONN__" + name_).c_str(), sizeof(acc_t)} {
|
||||
}
|
||||
|
||||
void quit_waiting() {
|
||||
cc_waiter_.quit_waiting();
|
||||
wt_waiter_.quit_waiting();
|
||||
rd_waiter_.quit_waiting();
|
||||
}
|
||||
|
||||
auto acc() {
|
||||
return static_cast<acc_t*>(acc_h_.get());
|
||||
}
|
||||
|
||||
auto& recv_cache() {
|
||||
thread_local ipc::unordered_map<msg_id_t, cache_t> tls;
|
||||
return tls;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename W, typename F>
|
||||
bool wait_for(W& waiter, F&& pred, std::uint64_t tm) {
|
||||
if (tm == 0) return !pred();
|
||||
for (unsigned k = 0; pred();) {
|
||||
bool ret = true;
|
||||
ipc::sleep(k, [&k, &ret, &waiter, &pred, tm] {
|
||||
ret = waiter.wait_if(std::forward<F>(pred), tm);
|
||||
k = 0;
|
||||
});
|
||||
if (!ret) return false; // timeout or fail
|
||||
if (k == 0) break; // k has been reset
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Policy,
|
||||
std::size_t DataSize = ipc::data_length,
|
||||
std::size_t AlignSize = (ipc::detail::min)(DataSize, alignof(std::max_align_t))>
|
||||
struct queue_generator {
|
||||
|
||||
using queue_t = ipc::queue<msg_t<DataSize, AlignSize>, Policy>;
|
||||
|
||||
struct conn_info_t : conn_info_head {
|
||||
queue_t que_;
|
||||
|
||||
conn_info_t(char const * name)
|
||||
: conn_info_head{name}
|
||||
, que_{("__QU_CONN__" +
|
||||
ipc::to_string(DataSize) + "__" +
|
||||
ipc::to_string(AlignSize) + "__" + name).c_str()} {
|
||||
}
|
||||
|
||||
void disconnect_receiver() {
|
||||
bool dis = que_.disconnect();
|
||||
this->quit_waiting();
|
||||
if (dis) {
|
||||
this->recv_cache().clear();
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Policy>
|
||||
struct detail_impl {
|
||||
|
||||
using policy_t = Policy;
|
||||
using flag_t = typename policy_t::flag_t;
|
||||
using queue_t = typename queue_generator<policy_t>::queue_t;
|
||||
using conn_info_t = typename queue_generator<policy_t>::conn_info_t;
|
||||
|
||||
constexpr static conn_info_t* info_of(ipc::handle_t h) noexcept {
|
||||
return static_cast<conn_info_t*>(h);
|
||||
}
|
||||
|
||||
constexpr static queue_t* queue_of(ipc::handle_t h) noexcept {
|
||||
return (info_of(h) == nullptr) ? nullptr : &(info_of(h)->que_);
|
||||
}
|
||||
|
||||
/* API implementations */
|
||||
|
||||
static void disconnect(ipc::handle_t h) {
|
||||
auto que = queue_of(h);
|
||||
if (que == nullptr) {
|
||||
return;
|
||||
}
|
||||
que->shut_sending();
|
||||
assert(info_of(h) != nullptr);
|
||||
info_of(h)->disconnect_receiver();
|
||||
}
|
||||
|
||||
static bool reconnect(ipc::handle_t * ph, bool start_to_recv) {
|
||||
assert(ph != nullptr);
|
||||
assert(*ph != nullptr);
|
||||
auto que = queue_of(*ph);
|
||||
if (que == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (start_to_recv) {
|
||||
que->shut_sending();
|
||||
if (que->connect()) { // wouldn't connect twice
|
||||
info_of(*ph)->cc_waiter_.broadcast();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
// start_to_recv == false
|
||||
if (que->connected()) {
|
||||
info_of(*ph)->disconnect_receiver();
|
||||
}
|
||||
return que->ready_sending();
|
||||
}
|
||||
|
||||
static bool connect(ipc::handle_t * ph, char const * name, bool start_to_recv) {
|
||||
assert(ph != nullptr);
|
||||
if (*ph == nullptr) {
|
||||
*ph = ipc::mem::alloc<conn_info_t>(name);
|
||||
}
|
||||
return reconnect(ph, start_to_recv);
|
||||
}
|
||||
|
||||
static void destroy(ipc::handle_t h) {
|
||||
disconnect(h);
|
||||
ipc::mem::free(info_of(h));
|
||||
}
|
||||
|
||||
static std::size_t recv_count(ipc::handle_t h) noexcept {
|
||||
auto que = queue_of(h);
|
||||
if (que == nullptr) {
|
||||
return ipc::invalid_value;
|
||||
}
|
||||
return que->conn_count();
|
||||
}
|
||||
|
||||
static bool wait_for_recv(ipc::handle_t h, std::size_t r_count, std::uint64_t tm) {
|
||||
auto que = queue_of(h);
|
||||
if (que == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return wait_for(info_of(h)->cc_waiter_, [que, r_count] {
|
||||
return que->conn_count() < r_count;
|
||||
}, tm);
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
static bool send(F&& gen_push, ipc::handle_t h, void const * data, std::size_t size) {
|
||||
if (data == nullptr || size == 0) {
|
||||
ipc::error("fail: send(%p, %zd)\n", data, size);
|
||||
return false;
|
||||
}
|
||||
auto que = queue_of(h);
|
||||
if (que == nullptr) {
|
||||
ipc::error("fail: send, queue_of(h) == nullptr\n");
|
||||
return false;
|
||||
}
|
||||
if (que->elems() == nullptr) {
|
||||
ipc::error("fail: send, queue_of(h)->elems() == nullptr\n");
|
||||
return false;
|
||||
}
|
||||
if (!que->ready_sending()) {
|
||||
ipc::error("fail: send, que->ready_sending() == false\n");
|
||||
return false;
|
||||
}
|
||||
ipc::circ::cc_t conns = que->elems()->connections(std::memory_order_relaxed);
|
||||
if (conns == 0) {
|
||||
ipc::error("fail: send, there is no receiver on this connection.\n");
|
||||
return false;
|
||||
}
|
||||
// calc a new message id
|
||||
auto acc = info_of(h)->acc();
|
||||
if (acc == nullptr) {
|
||||
ipc::error("fail: send, info_of(h)->acc() == nullptr\n");
|
||||
return false;
|
||||
}
|
||||
auto msg_id = acc->fetch_add(1, std::memory_order_relaxed);
|
||||
auto try_push = std::forward<F>(gen_push)(info_of(h), que, msg_id);
|
||||
if (size > ipc::large_msg_limit) {
|
||||
auto dat = acquire_storage(size, conns);
|
||||
void * buf = dat.second;
|
||||
if (buf != nullptr) {
|
||||
std::memcpy(buf, data, size);
|
||||
return try_push(static_cast<std::int32_t>(size) -
|
||||
static_cast<std::int32_t>(ipc::data_length), &(dat.first), 0);
|
||||
}
|
||||
// try using message fragment
|
||||
//ipc::log("fail: shm::handle for big message. msg_id: %zd, size: %zd\n", msg_id, size);
|
||||
}
|
||||
// push message fragment
|
||||
std::int32_t offset = 0;
|
||||
for (std::int32_t i = 0; i < static_cast<std::int32_t>(size / ipc::data_length); ++i, offset += ipc::data_length) {
|
||||
if (!try_push(static_cast<std::int32_t>(size) - offset - static_cast<std::int32_t>(ipc::data_length),
|
||||
static_cast<ipc::byte_t const *>(data) + offset, ipc::data_length)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// if remain > 0, this is the last message fragment
|
||||
std::int32_t remain = static_cast<std::int32_t>(size) - offset;
|
||||
if (remain > 0) {
|
||||
if (!try_push(remain - static_cast<std::int32_t>(ipc::data_length),
|
||||
static_cast<ipc::byte_t const *>(data) + offset,
|
||||
static_cast<std::size_t>(remain))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool send(ipc::handle_t h, void const * data, std::size_t size, std::uint64_t tm) {
|
||||
return send([tm](auto info, auto que, auto msg_id) {
|
||||
return [tm, info, que, msg_id](std::int32_t remain, void const * data, std::size_t size) {
|
||||
if (!wait_for(info->wt_waiter_, [&] {
|
||||
return !que->push(
|
||||
[](void*) { return true; },
|
||||
info->cc_id_, msg_id, remain, data, size);
|
||||
}, tm)) {
|
||||
ipc::log("force_push: msg_id = %zd, remain = %d, size = %zd\n", msg_id, remain, size);
|
||||
if (!que->force_push(
|
||||
clear_message<typename queue_t::value_t>,
|
||||
info->cc_id_, msg_id, remain, data, size)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
info->rd_waiter_.broadcast();
|
||||
return true;
|
||||
};
|
||||
}, h, data, size);
|
||||
}
|
||||
|
||||
static bool try_send(ipc::handle_t h, void const * data, std::size_t size, std::uint64_t tm) {
|
||||
return send([tm](auto info, auto que, auto msg_id) {
|
||||
return [tm, info, que, msg_id](std::int32_t remain, void const * data, std::size_t size) {
|
||||
if (!wait_for(info->wt_waiter_, [&] {
|
||||
return !que->push(
|
||||
[](void*) { return true; },
|
||||
info->cc_id_, msg_id, remain, data, size);
|
||||
}, tm)) {
|
||||
return false;
|
||||
}
|
||||
info->rd_waiter_.broadcast();
|
||||
return true;
|
||||
};
|
||||
}, h, data, size);
|
||||
}
|
||||
|
||||
static ipc::buff_t recv(ipc::handle_t h, std::uint64_t tm) {
|
||||
auto que = queue_of(h);
|
||||
if (que == nullptr) {
|
||||
ipc::error("fail: recv, queue_of(h) == nullptr\n");
|
||||
return {};
|
||||
}
|
||||
if (!que->connected()) {
|
||||
// hasn't connected yet, just return.
|
||||
return {};
|
||||
}
|
||||
auto& rc = info_of(h)->recv_cache();
|
||||
for (;;) {
|
||||
// pop a new message
|
||||
typename queue_t::value_t msg;
|
||||
if (!wait_for(info_of(h)->rd_waiter_, [que, &msg] {
|
||||
return !que->pop(msg);
|
||||
}, tm)) {
|
||||
// pop failed, just return.
|
||||
return {};
|
||||
}
|
||||
info_of(h)->wt_waiter_.broadcast();
|
||||
if ((info_of(h)->acc() != nullptr) && (msg.cc_id_ == info_of(h)->cc_id_)) {
|
||||
continue; // ignore message to self
|
||||
}
|
||||
// msg.remain_ may minus & abs(msg.remain_) < data_length
|
||||
std::int32_t r_size = static_cast<std::int32_t>(ipc::data_length) + msg.remain_;
|
||||
if (r_size <= 0) {
|
||||
ipc::error("fail: recv, r_size = %d\n", (int)r_size);
|
||||
return {};
|
||||
}
|
||||
std::size_t msg_size = static_cast<std::size_t>(r_size);
|
||||
// large message
|
||||
if (msg.storage_) {
|
||||
ipc::storage_id_t buf_id = *reinterpret_cast<ipc::storage_id_t*>(&msg.data_);
|
||||
void* buf = find_storage(buf_id, msg_size);
|
||||
if (buf != nullptr) {
|
||||
struct recycle_t {
|
||||
ipc::storage_id_t storage_id;
|
||||
ipc::circ::cc_t curr_conns;
|
||||
ipc::circ::cc_t conn_id;
|
||||
} *r_info = ipc::mem::alloc<recycle_t>(recycle_t{
|
||||
buf_id, que->elems()->connections(std::memory_order_relaxed), que->connected_id()
|
||||
});
|
||||
if (r_info == nullptr) {
|
||||
ipc::log("fail: ipc::mem::alloc<recycle_t>.\n");
|
||||
return ipc::buff_t{buf, msg_size}; // no recycle
|
||||
} else {
|
||||
return ipc::buff_t{buf, msg_size, [](void* p_info, std::size_t size) {
|
||||
auto r_info = static_cast<recycle_t *>(p_info);
|
||||
IPC_UNUSED_ auto finally = ipc::guard([r_info] {
|
||||
ipc::mem::free(r_info);
|
||||
});
|
||||
recycle_storage<flag_t>(r_info->storage_id, size, r_info->curr_conns, r_info->conn_id);
|
||||
}, r_info};
|
||||
}
|
||||
} else {
|
||||
ipc::log("fail: shm::handle for large message. msg_id: %zd, buf_id: %zd, size: %zd\n", msg.id_, buf_id, msg_size);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// find cache with msg.id_
|
||||
auto cac_it = rc.find(msg.id_);
|
||||
if (cac_it == rc.end()) {
|
||||
if (msg_size <= ipc::data_length) {
|
||||
return make_cache(msg.data_, msg_size);
|
||||
}
|
||||
// gc
|
||||
if (rc.size() > 1024) {
|
||||
std::vector<msg_id_t> need_del;
|
||||
for (auto const & pair : rc) {
|
||||
auto cmp = std::minmax(msg.id_, pair.first);
|
||||
if (cmp.second - cmp.first > 8192) {
|
||||
need_del.push_back(pair.first);
|
||||
}
|
||||
}
|
||||
for (auto id : need_del) rc.erase(id);
|
||||
}
|
||||
// cache the first message fragment
|
||||
rc.emplace(msg.id_, cache_t { ipc::data_length, make_cache(msg.data_, msg_size) });
|
||||
}
|
||||
// has cached before this message
|
||||
else {
|
||||
auto& cac = cac_it->second;
|
||||
// this is the last message fragment
|
||||
if (msg.remain_ <= 0) {
|
||||
cac.append(&(msg.data_), msg_size);
|
||||
// finish this message, erase it from cache
|
||||
auto buff = std::move(cac.buff_);
|
||||
rc.erase(cac_it);
|
||||
return buff;
|
||||
}
|
||||
// there are remain datas after this message
|
||||
cac.append(&(msg.data_), ipc::data_length);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static ipc::buff_t try_recv(ipc::handle_t h) {
|
||||
return recv(h, 0);
|
||||
}
|
||||
|
||||
}; // detail_impl<Policy>
|
||||
|
||||
template <typename Flag>
|
||||
using policy_t = ipc::policy::choose<ipc::circ::elem_array, Flag>;
|
||||
|
||||
} // internal-linkage
|
||||
|
||||
namespace ipc {
|
||||
|
||||
template <typename Flag>
|
||||
ipc::handle_t chan_impl<Flag>::inited() {
|
||||
ipc::detail::waiter::init();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <typename Flag>
|
||||
bool chan_impl<Flag>::connect(ipc::handle_t * ph, char const * name, unsigned mode) {
|
||||
return detail_impl<policy_t<Flag>>::connect(ph, name, mode & receiver);
|
||||
}
|
||||
|
||||
template <typename Flag>
|
||||
bool chan_impl<Flag>::reconnect(ipc::handle_t * ph, unsigned mode) {
|
||||
return detail_impl<policy_t<Flag>>::reconnect(ph, mode & receiver);
|
||||
}
|
||||
|
||||
template <typename Flag>
|
||||
void chan_impl<Flag>::disconnect(ipc::handle_t h) {
|
||||
detail_impl<policy_t<Flag>>::disconnect(h);
|
||||
}
|
||||
|
||||
template <typename Flag>
|
||||
void chan_impl<Flag>::destroy(ipc::handle_t h) {
|
||||
detail_impl<policy_t<Flag>>::destroy(h);
|
||||
}
|
||||
|
||||
template <typename Flag>
|
||||
char const * chan_impl<Flag>::name(ipc::handle_t h) {
|
||||
auto info = detail_impl<policy_t<Flag>>::info_of(h);
|
||||
return (info == nullptr) ? nullptr : info->name_.c_str();
|
||||
}
|
||||
|
||||
template <typename Flag>
|
||||
std::size_t chan_impl<Flag>::recv_count(ipc::handle_t h) {
|
||||
return detail_impl<policy_t<Flag>>::recv_count(h);
|
||||
}
|
||||
|
||||
template <typename Flag>
|
||||
bool chan_impl<Flag>::wait_for_recv(ipc::handle_t h, std::size_t r_count, std::uint64_t tm) {
|
||||
return detail_impl<policy_t<Flag>>::wait_for_recv(h, r_count, tm);
|
||||
}
|
||||
|
||||
template <typename Flag>
|
||||
bool chan_impl<Flag>::send(ipc::handle_t h, void const * data, std::size_t size, std::uint64_t tm) {
|
||||
return detail_impl<policy_t<Flag>>::send(h, data, size, tm);
|
||||
}
|
||||
|
||||
template <typename Flag>
|
||||
buff_t chan_impl<Flag>::recv(ipc::handle_t h, std::uint64_t tm) {
|
||||
return detail_impl<policy_t<Flag>>::recv(h, tm);
|
||||
}
|
||||
|
||||
template <typename Flag>
|
||||
bool chan_impl<Flag>::try_send(ipc::handle_t h, void const * data, std::size_t size, std::uint64_t tm) {
|
||||
return detail_impl<policy_t<Flag>>::try_send(h, data, size, tm);
|
||||
}
|
||||
|
||||
template <typename Flag>
|
||||
buff_t chan_impl<Flag>::try_recv(ipc::handle_t h) {
|
||||
return detail_impl<policy_t<Flag>>::try_recv(h);
|
||||
}
|
||||
|
||||
template struct chan_impl<ipc::wr<relat::single, relat::single, trans::unicast >>;
|
||||
// template struct chan_impl<ipc::wr<relat::single, relat::multi , trans::unicast >>; // TBD
|
||||
// template struct chan_impl<ipc::wr<relat::multi , relat::multi , trans::unicast >>; // TBD
|
||||
template struct chan_impl<ipc::wr<relat::single, relat::multi , trans::broadcast>>;
|
||||
template struct chan_impl<ipc::wr<relat::multi , relat::multi , trans::broadcast>>;
|
||||
|
||||
} // namespace ipc
|
||||
@ -1,25 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "libipc/def.h"
|
||||
#include "libipc/prod_cons.h"
|
||||
|
||||
#include "libipc/circ/elem_array.h"
|
||||
|
||||
namespace ipc {
|
||||
namespace policy {
|
||||
|
||||
template <template <typename, std::size_t...> class Elems, typename Flag>
|
||||
struct choose;
|
||||
|
||||
template <typename Flag>
|
||||
struct choose<circ::elem_array, Flag> {
|
||||
using flag_t = Flag;
|
||||
|
||||
template <std::size_t DataSize, std::size_t AlignSize>
|
||||
using elems_t = circ::elem_array<ipc::prod_cons_impl<flag_t>, DataSize, AlignSize>;
|
||||
};
|
||||
|
||||
} // namespace policy
|
||||
} // namespace ipc
|
||||
@ -1,17 +0,0 @@
|
||||
#include "libipc/pool_alloc.h"
|
||||
|
||||
#include "libipc/memory/resource.h"
|
||||
|
||||
namespace ipc {
|
||||
namespace mem {
|
||||
|
||||
void* pool_alloc::alloc(std::size_t size) {
|
||||
return async_pool_alloc::alloc(size);
|
||||
}
|
||||
|
||||
void pool_alloc::free(void* p, std::size_t size) {
|
||||
async_pool_alloc::free(p, size);
|
||||
}
|
||||
|
||||
} // namespace mem
|
||||
} // namespace ipc
|
||||
@ -1,433 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <utility>
|
||||
#include <cstring>
|
||||
#include <type_traits>
|
||||
#include <cstdint>
|
||||
|
||||
#include "libipc/def.h"
|
||||
|
||||
#include "libipc/platform/detail.h"
|
||||
#include "libipc/circ/elem_def.h"
|
||||
#include "libipc/utility/log.h"
|
||||
#include "libipc/utility/utility.h"
|
||||
|
||||
namespace ipc {
|
||||
|
||||
////////////////////////////////////////////////////////////////
|
||||
/// producer-consumer implementation
|
||||
////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Flag>
|
||||
struct prod_cons_impl;
|
||||
|
||||
template <>
|
||||
struct prod_cons_impl<wr<relat::single, relat::single, trans::unicast>> {
|
||||
|
||||
template <std::size_t DataSize, std::size_t AlignSize>
|
||||
struct elem_t {
|
||||
std::aligned_storage_t<DataSize, AlignSize> data_ {};
|
||||
};
|
||||
|
||||
alignas(cache_line_size) std::atomic<circ::u2_t> rd_; // read index
|
||||
alignas(cache_line_size) std::atomic<circ::u2_t> wt_; // write index
|
||||
|
||||
constexpr circ::u2_t cursor() const noexcept {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename E>
|
||||
bool push(W* /*wrapper*/, F&& f, E* elems) {
|
||||
auto cur_wt = circ::index_of(wt_.load(std::memory_order_relaxed));
|
||||
if (cur_wt == circ::index_of(rd_.load(std::memory_order_acquire) - 1)) {
|
||||
return false; // full
|
||||
}
|
||||
std::forward<F>(f)(&(elems[cur_wt].data_));
|
||||
wt_.fetch_add(1, std::memory_order_release);
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* In single-single-unicast, 'force_push' means 'no reader' or 'the only one reader is dead'.
|
||||
* So we could just disconnect all connections of receiver, and return false.
|
||||
*/
|
||||
template <typename W, typename F, typename E>
|
||||
bool force_push(W* wrapper, F&&, E*) {
|
||||
wrapper->elems()->disconnect_receiver(~static_cast<circ::cc_t>(0u));
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename R, typename E>
|
||||
bool pop(W* /*wrapper*/, circ::u2_t& /*cur*/, F&& f, R&& out, E* elems) {
|
||||
auto cur_rd = circ::index_of(rd_.load(std::memory_order_relaxed));
|
||||
if (cur_rd == circ::index_of(wt_.load(std::memory_order_acquire))) {
|
||||
return false; // empty
|
||||
}
|
||||
std::forward<F>(f)(&(elems[cur_rd].data_));
|
||||
std::forward<R>(out)(true);
|
||||
rd_.fetch_add(1, std::memory_order_release);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct prod_cons_impl<wr<relat::single, relat::multi , trans::unicast>>
|
||||
: prod_cons_impl<wr<relat::single, relat::single, trans::unicast>> {
|
||||
|
||||
template <typename W, typename F, typename E>
|
||||
bool force_push(W* wrapper, F&&, E*) {
|
||||
wrapper->elems()->disconnect_receiver(1);
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename R,
|
||||
template <std::size_t, std::size_t> class E, std::size_t DS, std::size_t AS>
|
||||
bool pop(W* /*wrapper*/, circ::u2_t& /*cur*/, F&& f, R&& out, E<DS, AS>* elems) {
|
||||
byte_t buff[DS];
|
||||
for (unsigned k = 0;;) {
|
||||
auto cur_rd = rd_.load(std::memory_order_relaxed);
|
||||
if (circ::index_of(cur_rd) ==
|
||||
circ::index_of(wt_.load(std::memory_order_acquire))) {
|
||||
return false; // empty
|
||||
}
|
||||
std::memcpy(buff, &(elems[circ::index_of(cur_rd)].data_), sizeof(buff));
|
||||
if (rd_.compare_exchange_weak(cur_rd, cur_rd + 1, std::memory_order_release)) {
|
||||
std::forward<F>(f)(buff);
|
||||
std::forward<R>(out)(true);
|
||||
return true;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct prod_cons_impl<wr<relat::multi , relat::multi, trans::unicast>>
|
||||
: prod_cons_impl<wr<relat::single, relat::multi, trans::unicast>> {
|
||||
|
||||
using flag_t = std::uint64_t;
|
||||
|
||||
template <std::size_t DataSize, std::size_t AlignSize>
|
||||
struct elem_t {
|
||||
std::aligned_storage_t<DataSize, AlignSize> data_ {};
|
||||
std::atomic<flag_t> f_ct_ { 0 }; // commit flag
|
||||
};
|
||||
|
||||
alignas(cache_line_size) std::atomic<circ::u2_t> ct_; // commit index
|
||||
|
||||
template <typename W, typename F, typename E>
|
||||
bool push(W* /*wrapper*/, F&& f, E* elems) {
|
||||
circ::u2_t cur_ct, nxt_ct;
|
||||
for (unsigned k = 0;;) {
|
||||
cur_ct = ct_.load(std::memory_order_relaxed);
|
||||
if (circ::index_of(nxt_ct = cur_ct + 1) ==
|
||||
circ::index_of(rd_.load(std::memory_order_acquire))) {
|
||||
return false; // full
|
||||
}
|
||||
if (ct_.compare_exchange_weak(cur_ct, nxt_ct, std::memory_order_acq_rel)) {
|
||||
break;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
auto* el = elems + circ::index_of(cur_ct);
|
||||
std::forward<F>(f)(&(el->data_));
|
||||
// set flag & try update wt
|
||||
el->f_ct_.store(~static_cast<flag_t>(cur_ct), std::memory_order_release);
|
||||
while (1) {
|
||||
auto cac_ct = el->f_ct_.load(std::memory_order_acquire);
|
||||
if (cur_ct != wt_.load(std::memory_order_relaxed)) {
|
||||
return true;
|
||||
}
|
||||
if ((~cac_ct) != cur_ct) {
|
||||
return true;
|
||||
}
|
||||
if (!el->f_ct_.compare_exchange_strong(cac_ct, 0, std::memory_order_relaxed)) {
|
||||
return true;
|
||||
}
|
||||
wt_.store(nxt_ct, std::memory_order_release);
|
||||
cur_ct = nxt_ct;
|
||||
nxt_ct = cur_ct + 1;
|
||||
el = elems + circ::index_of(cur_ct);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename E>
|
||||
bool force_push(W* wrapper, F&&, E*) {
|
||||
wrapper->elems()->disconnect_receiver(1);
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename R,
|
||||
template <std::size_t, std::size_t> class E, std::size_t DS, std::size_t AS>
|
||||
bool pop(W* /*wrapper*/, circ::u2_t& /*cur*/, F&& f, R&& out, E<DS, AS>* elems) {
|
||||
byte_t buff[DS];
|
||||
for (unsigned k = 0;;) {
|
||||
auto cur_rd = rd_.load(std::memory_order_relaxed);
|
||||
auto cur_wt = wt_.load(std::memory_order_acquire);
|
||||
auto id_rd = circ::index_of(cur_rd);
|
||||
auto id_wt = circ::index_of(cur_wt);
|
||||
if (id_rd == id_wt) {
|
||||
auto* el = elems + id_wt;
|
||||
auto cac_ct = el->f_ct_.load(std::memory_order_acquire);
|
||||
if ((~cac_ct) != cur_wt) {
|
||||
return false; // empty
|
||||
}
|
||||
if (el->f_ct_.compare_exchange_weak(cac_ct, 0, std::memory_order_relaxed)) {
|
||||
wt_.store(cur_wt + 1, std::memory_order_release);
|
||||
}
|
||||
k = 0;
|
||||
}
|
||||
else {
|
||||
std::memcpy(buff, &(elems[circ::index_of(cur_rd)].data_), sizeof(buff));
|
||||
if (rd_.compare_exchange_weak(cur_rd, cur_rd + 1, std::memory_order_release)) {
|
||||
std::forward<F>(f)(buff);
|
||||
std::forward<R>(out)(true);
|
||||
return true;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct prod_cons_impl<wr<relat::single, relat::multi, trans::broadcast>> {
|
||||
|
||||
using rc_t = std::uint64_t;
|
||||
|
||||
enum : rc_t {
|
||||
ep_mask = 0x00000000ffffffffull,
|
||||
ep_incr = 0x0000000100000000ull
|
||||
};
|
||||
|
||||
template <std::size_t DataSize, std::size_t AlignSize>
|
||||
struct elem_t {
|
||||
std::aligned_storage_t<DataSize, AlignSize> data_ {};
|
||||
std::atomic<rc_t> rc_ { 0 }; // read-counter
|
||||
};
|
||||
|
||||
alignas(cache_line_size) std::atomic<circ::u2_t> wt_; // write index
|
||||
alignas(cache_line_size) rc_t epoch_ { 0 }; // only one writer
|
||||
|
||||
circ::u2_t cursor() const noexcept {
|
||||
return wt_.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename E>
|
||||
bool push(W* wrapper, F&& f, E* elems) {
|
||||
E* el;
|
||||
for (unsigned k = 0;;) {
|
||||
circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
|
||||
if (cc == 0) return false; // no reader
|
||||
el = elems + circ::index_of(wt_.load(std::memory_order_relaxed));
|
||||
// check all consumers have finished reading this element
|
||||
auto cur_rc = el->rc_.load(std::memory_order_acquire);
|
||||
circ::cc_t rem_cc = cur_rc & ep_mask;
|
||||
if ((cc & rem_cc) && ((cur_rc & ~ep_mask) == epoch_)) {
|
||||
return false; // has not finished yet
|
||||
}
|
||||
// consider rem_cc to be 0 here
|
||||
if (el->rc_.compare_exchange_weak(
|
||||
cur_rc, epoch_ | static_cast<rc_t>(cc), std::memory_order_release)) {
|
||||
break;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
std::forward<F>(f)(&(el->data_));
|
||||
wt_.fetch_add(1, std::memory_order_release);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename E>
|
||||
bool force_push(W* wrapper, F&& f, E* elems) {
|
||||
E* el;
|
||||
epoch_ += ep_incr;
|
||||
for (unsigned k = 0;;) {
|
||||
circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
|
||||
if (cc == 0) return false; // no reader
|
||||
el = elems + circ::index_of(wt_.load(std::memory_order_relaxed));
|
||||
// check all consumers have finished reading this element
|
||||
auto cur_rc = el->rc_.load(std::memory_order_acquire);
|
||||
circ::cc_t rem_cc = cur_rc & ep_mask;
|
||||
if (cc & rem_cc) {
|
||||
ipc::log("force_push: k = %u, cc = %u, rem_cc = %u\n", k, cc, rem_cc);
|
||||
cc = wrapper->elems()->disconnect_receiver(rem_cc); // disconnect all invalid readers
|
||||
if (cc == 0) return false; // no reader
|
||||
}
|
||||
// just compare & exchange
|
||||
if (el->rc_.compare_exchange_weak(
|
||||
cur_rc, epoch_ | static_cast<rc_t>(cc), std::memory_order_release)) {
|
||||
break;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
std::forward<F>(f)(&(el->data_));
|
||||
wt_.fetch_add(1, std::memory_order_release);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename R, typename E>
|
||||
bool pop(W* wrapper, circ::u2_t& cur, F&& f, R&& out, E* elems) {
|
||||
if (cur == cursor()) return false; // acquire
|
||||
auto* el = elems + circ::index_of(cur++);
|
||||
std::forward<F>(f)(&(el->data_));
|
||||
for (unsigned k = 0;;) {
|
||||
auto cur_rc = el->rc_.load(std::memory_order_acquire);
|
||||
if ((cur_rc & ep_mask) == 0) {
|
||||
std::forward<R>(out)(true);
|
||||
return true;
|
||||
}
|
||||
auto nxt_rc = cur_rc & ~static_cast<rc_t>(wrapper->connected_id());
|
||||
if (el->rc_.compare_exchange_weak(cur_rc, nxt_rc, std::memory_order_release)) {
|
||||
std::forward<R>(out)((nxt_rc & ep_mask) == 0);
|
||||
return true;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct prod_cons_impl<wr<relat::multi, relat::multi, trans::broadcast>> {
|
||||
|
||||
using rc_t = std::uint64_t;
|
||||
using flag_t = std::uint64_t;
|
||||
|
||||
enum : rc_t {
|
||||
rc_mask = 0x00000000ffffffffull,
|
||||
ep_mask = 0x00ffffffffffffffull,
|
||||
ep_incr = 0x0100000000000000ull,
|
||||
ic_mask = 0xff000000ffffffffull,
|
||||
ic_incr = 0x0000000100000000ull
|
||||
};
|
||||
|
||||
template <std::size_t DataSize, std::size_t AlignSize>
|
||||
struct elem_t {
|
||||
std::aligned_storage_t<DataSize, AlignSize> data_ {};
|
||||
std::atomic<rc_t > rc_ { 0 }; // read-counter
|
||||
std::atomic<flag_t> f_ct_ { 0 }; // commit flag
|
||||
};
|
||||
|
||||
alignas(cache_line_size) std::atomic<circ::u2_t> ct_; // commit index
|
||||
alignas(cache_line_size) std::atomic<rc_t> epoch_ { 0 };
|
||||
|
||||
circ::u2_t cursor() const noexcept {
|
||||
return ct_.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
constexpr static rc_t inc_rc(rc_t rc) noexcept {
|
||||
return (rc & ic_mask) | ((rc + ic_incr) & ~ic_mask);
|
||||
}
|
||||
|
||||
constexpr static rc_t inc_mask(rc_t rc) noexcept {
|
||||
return inc_rc(rc) & ~rc_mask;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename E>
|
||||
bool push(W* wrapper, F&& f, E* elems) {
|
||||
E* el;
|
||||
circ::u2_t cur_ct;
|
||||
rc_t epoch = epoch_.load(std::memory_order_acquire);
|
||||
for (unsigned k = 0;;) {
|
||||
circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
|
||||
if (cc == 0) return false; // no reader
|
||||
el = elems + circ::index_of(cur_ct = ct_.load(std::memory_order_relaxed));
|
||||
// check all consumers have finished reading this element
|
||||
auto cur_rc = el->rc_.load(std::memory_order_relaxed);
|
||||
circ::cc_t rem_cc = cur_rc & rc_mask;
|
||||
if ((cc & rem_cc) && ((cur_rc & ~ep_mask) == epoch)) {
|
||||
return false; // has not finished yet
|
||||
}
|
||||
else if (!rem_cc) {
|
||||
auto cur_fl = el->f_ct_.load(std::memory_order_acquire);
|
||||
if ((cur_fl != cur_ct) && cur_fl) {
|
||||
return false; // full
|
||||
}
|
||||
}
|
||||
// consider rem_cc to be 0 here
|
||||
if (el->rc_.compare_exchange_weak(
|
||||
cur_rc, inc_mask(epoch | (cur_rc & ep_mask)) | static_cast<rc_t>(cc), std::memory_order_relaxed) &&
|
||||
epoch_.compare_exchange_weak(epoch, epoch, std::memory_order_acq_rel)) {
|
||||
break;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
// only one thread/process would touch here at one time
|
||||
ct_.store(cur_ct + 1, std::memory_order_release);
|
||||
std::forward<F>(f)(&(el->data_));
|
||||
// set flag & try update wt
|
||||
el->f_ct_.store(~static_cast<flag_t>(cur_ct), std::memory_order_release);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename E>
|
||||
bool force_push(W* wrapper, F&& f, E* elems) {
|
||||
E* el;
|
||||
circ::u2_t cur_ct;
|
||||
rc_t epoch = epoch_.fetch_add(ep_incr, std::memory_order_release) + ep_incr;
|
||||
for (unsigned k = 0;;) {
|
||||
circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
|
||||
if (cc == 0) return false; // no reader
|
||||
el = elems + circ::index_of(cur_ct = ct_.load(std::memory_order_relaxed));
|
||||
// check all consumers have finished reading this element
|
||||
auto cur_rc = el->rc_.load(std::memory_order_acquire);
|
||||
circ::cc_t rem_cc = cur_rc & rc_mask;
|
||||
if (cc & rem_cc) {
|
||||
ipc::log("force_push: k = %u, cc = %u, rem_cc = %u\n", k, cc, rem_cc);
|
||||
cc = wrapper->elems()->disconnect_receiver(rem_cc); // disconnect all invalid readers
|
||||
if (cc == 0) return false; // no reader
|
||||
}
|
||||
// just compare & exchange
|
||||
if (el->rc_.compare_exchange_weak(
|
||||
cur_rc, inc_mask(epoch | (cur_rc & ep_mask)) | static_cast<rc_t>(cc), std::memory_order_relaxed)) {
|
||||
if (epoch == epoch_.load(std::memory_order_acquire)) {
|
||||
break;
|
||||
}
|
||||
else if (push(wrapper, std::forward<F>(f), elems)) {
|
||||
return true;
|
||||
}
|
||||
epoch = epoch_.fetch_add(ep_incr, std::memory_order_release) + ep_incr;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
// only one thread/process would touch here at one time
|
||||
ct_.store(cur_ct + 1, std::memory_order_release);
|
||||
std::forward<F>(f)(&(el->data_));
|
||||
// set flag & try update wt
|
||||
el->f_ct_.store(~static_cast<flag_t>(cur_ct), std::memory_order_release);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename R, typename E, std::size_t N>
|
||||
bool pop(W* wrapper, circ::u2_t& cur, F&& f, R&& out, E(& elems)[N]) {
|
||||
auto* el = elems + circ::index_of(cur);
|
||||
auto cur_fl = el->f_ct_.load(std::memory_order_acquire);
|
||||
if (cur_fl != ~static_cast<flag_t>(cur)) {
|
||||
return false; // empty
|
||||
}
|
||||
++cur;
|
||||
std::forward<F>(f)(&(el->data_));
|
||||
for (unsigned k = 0;;) {
|
||||
auto cur_rc = el->rc_.load(std::memory_order_acquire);
|
||||
if ((cur_rc & rc_mask) == 0) {
|
||||
std::forward<R>(out)(true);
|
||||
el->f_ct_.store(cur + N - 1, std::memory_order_release);
|
||||
return true;
|
||||
}
|
||||
auto nxt_rc = inc_rc(cur_rc) & ~static_cast<rc_t>(wrapper->connected_id());
|
||||
bool last_one = false;
|
||||
if ((last_one = (nxt_rc & rc_mask) == 0)) {
|
||||
el->f_ct_.store(cur + N - 1, std::memory_order_release);
|
||||
}
|
||||
if (el->rc_.compare_exchange_weak(cur_rc, nxt_rc, std::memory_order_release)) {
|
||||
std::forward<R>(out)(last_one);
|
||||
return true;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ipc
|
||||
@ -1,216 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include <new>
|
||||
#include <utility> // [[since C++14]]: std::exchange
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <tuple>
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
#include <cassert> // assert
|
||||
|
||||
#include "libipc/def.h"
|
||||
#include "libipc/shm.h"
|
||||
#include "libipc/rw_lock.h"
|
||||
|
||||
#include "libipc/utility/log.h"
|
||||
#include "libipc/platform/detail.h"
|
||||
#include "libipc/circ/elem_def.h"
|
||||
|
||||
namespace ipc {
|
||||
namespace detail {
|
||||
|
||||
class queue_conn {
|
||||
protected:
|
||||
circ::cc_t connected_ = 0;
|
||||
shm::handle elems_h_;
|
||||
|
||||
template <typename Elems>
|
||||
Elems* open(char const * name) {
|
||||
if (name == nullptr || name[0] == '\0') {
|
||||
ipc::error("fail open waiter: name is empty!\n");
|
||||
return nullptr;
|
||||
}
|
||||
if (!elems_h_.acquire(name, sizeof(Elems))) {
|
||||
return nullptr;
|
||||
}
|
||||
auto elems = static_cast<Elems*>(elems_h_.get());
|
||||
if (elems == nullptr) {
|
||||
ipc::error("fail acquire elems: %s\n", name);
|
||||
return nullptr;
|
||||
}
|
||||
elems->init();
|
||||
return elems;
|
||||
}
|
||||
|
||||
void close() {
|
||||
elems_h_.release();
|
||||
}
|
||||
|
||||
public:
|
||||
queue_conn() = default;
|
||||
queue_conn(const queue_conn&) = delete;
|
||||
queue_conn& operator=(const queue_conn&) = delete;
|
||||
|
||||
bool connected() const noexcept {
|
||||
return connected_ != 0;
|
||||
}
|
||||
|
||||
circ::cc_t connected_id() const noexcept {
|
||||
return connected_;
|
||||
}
|
||||
|
||||
template <typename Elems>
|
||||
auto connect(Elems* elems) noexcept
|
||||
/*needs 'optional' here*/
|
||||
-> std::tuple<bool, bool, decltype(std::declval<Elems>().cursor())> {
|
||||
if (elems == nullptr) return {};
|
||||
// if it's already connected, just return
|
||||
if (connected()) return {connected(), false, 0};
|
||||
connected_ = elems->connect_receiver();
|
||||
return {connected(), true, elems->cursor()};
|
||||
}
|
||||
|
||||
template <typename Elems>
|
||||
bool disconnect(Elems* elems) noexcept {
|
||||
if (elems == nullptr) return false;
|
||||
// if it's already disconnected, just return false
|
||||
if (!connected()) return false;
|
||||
elems->disconnect_receiver(std::exchange(connected_, 0));
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Elems>
|
||||
class queue_base : public queue_conn {
|
||||
using base_t = queue_conn;
|
||||
|
||||
public:
|
||||
using elems_t = Elems;
|
||||
using policy_t = typename elems_t::policy_t;
|
||||
|
||||
protected:
|
||||
elems_t * elems_ = nullptr;
|
||||
decltype(std::declval<elems_t>().cursor()) cursor_ = 0;
|
||||
bool sender_flag_ = false;
|
||||
|
||||
public:
|
||||
using base_t::base_t;
|
||||
|
||||
queue_base() = default;
|
||||
|
||||
explicit queue_base(char const * name)
|
||||
: queue_base{} {
|
||||
elems_ = open<elems_t>(name);
|
||||
}
|
||||
|
||||
explicit queue_base(elems_t * elems) noexcept
|
||||
: queue_base{} {
|
||||
assert(elems != nullptr);
|
||||
elems_ = elems;
|
||||
}
|
||||
|
||||
/* not virtual */ ~queue_base() {
|
||||
base_t::close();
|
||||
}
|
||||
|
||||
elems_t * elems() noexcept { return elems_; }
|
||||
elems_t const * elems() const noexcept { return elems_; }
|
||||
|
||||
bool ready_sending() noexcept {
|
||||
if (elems_ == nullptr) return false;
|
||||
return sender_flag_ || (sender_flag_ = elems_->connect_sender());
|
||||
}
|
||||
|
||||
void shut_sending() noexcept {
|
||||
if (elems_ == nullptr) return;
|
||||
if (!sender_flag_) return;
|
||||
elems_->disconnect_sender();
|
||||
}
|
||||
|
||||
bool connect() noexcept {
|
||||
auto tp = base_t::connect(elems_);
|
||||
if (std::get<0>(tp) && std::get<1>(tp)) {
|
||||
cursor_ = std::get<2>(tp);
|
||||
return true;
|
||||
}
|
||||
return std::get<0>(tp);
|
||||
}
|
||||
|
||||
bool disconnect() noexcept {
|
||||
return base_t::disconnect(elems_);
|
||||
}
|
||||
|
||||
std::size_t conn_count() const noexcept {
|
||||
return (elems_ == nullptr) ? static_cast<std::size_t>(invalid_value) : elems_->conn_count();
|
||||
}
|
||||
|
||||
bool valid() const noexcept {
|
||||
return elems_ != nullptr;
|
||||
}
|
||||
|
||||
bool empty() const noexcept {
|
||||
return !valid() || (cursor_ == elems_->cursor());
|
||||
}
|
||||
|
||||
template <typename T, typename F, typename... P>
|
||||
bool push(F&& prep, P&&... params) {
|
||||
if (elems_ == nullptr) return false;
|
||||
return elems_->push(this, [&](void* p) {
|
||||
if (prep(p)) ::new (p) T(std::forward<P>(params)...);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T, typename F, typename... P>
|
||||
bool force_push(F&& prep, P&&... params) {
|
||||
if (elems_ == nullptr) return false;
|
||||
return elems_->force_push(this, [&](void* p) {
|
||||
if (prep(p)) ::new (p) T(std::forward<P>(params)...);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T, typename F>
|
||||
bool pop(T& item, F&& out) {
|
||||
if (elems_ == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return elems_->pop(this, &(this->cursor_), [&item](void* p) {
|
||||
::new (&item) T(std::move(*static_cast<T*>(p)));
|
||||
}, std::forward<F>(out));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename T, typename Policy>
|
||||
class queue final : public detail::queue_base<typename Policy::template elems_t<sizeof(T), alignof(T)>> {
|
||||
using base_t = detail::queue_base<typename Policy::template elems_t<sizeof(T), alignof(T)>>;
|
||||
|
||||
public:
|
||||
using value_t = T;
|
||||
|
||||
using base_t::base_t;
|
||||
|
||||
template <typename... P>
|
||||
bool push(P&&... params) {
|
||||
return base_t::template push<T>(std::forward<P>(params)...);
|
||||
}
|
||||
|
||||
template <typename... P>
|
||||
bool force_push(P&&... params) {
|
||||
return base_t::template force_push<T>(std::forward<P>(params)...);
|
||||
}
|
||||
|
||||
bool pop(T& item) {
|
||||
return base_t::pop(item, [](bool) {});
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
bool pop(T& item, F&& out) {
|
||||
return base_t::pop(item, std::forward<F>(out));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ipc
|
||||
@ -1,103 +0,0 @@
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "libipc/shm.h"
|
||||
|
||||
#include "libipc/utility/pimpl.h"
|
||||
#include "libipc/memory/resource.h"
|
||||
|
||||
namespace ipc {
|
||||
namespace shm {
|
||||
|
||||
class handle::handle_ : public pimpl<handle_> {
|
||||
public:
|
||||
shm::id_t id_ = nullptr;
|
||||
void* m_ = nullptr;
|
||||
|
||||
ipc::string n_;
|
||||
std::size_t s_ = 0;
|
||||
};
|
||||
|
||||
handle::handle()
|
||||
: p_(p_->make()) {
|
||||
}
|
||||
|
||||
handle::handle(char const * name, std::size_t size, unsigned mode)
|
||||
: handle() {
|
||||
acquire(name, size, mode);
|
||||
}
|
||||
|
||||
handle::handle(handle&& rhs)
|
||||
: handle() {
|
||||
swap(rhs);
|
||||
}
|
||||
|
||||
handle::~handle() {
|
||||
release();
|
||||
p_->clear();
|
||||
}
|
||||
|
||||
void handle::swap(handle& rhs) {
|
||||
std::swap(p_, rhs.p_);
|
||||
}
|
||||
|
||||
handle& handle::operator=(handle rhs) {
|
||||
swap(rhs);
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool handle::valid() const noexcept {
|
||||
return impl(p_)->m_ != nullptr;
|
||||
}
|
||||
|
||||
std::size_t handle::size() const noexcept {
|
||||
return impl(p_)->s_;
|
||||
}
|
||||
|
||||
char const * handle::name() const noexcept {
|
||||
return impl(p_)->n_.c_str();
|
||||
}
|
||||
|
||||
std::int32_t handle::ref() const noexcept {
|
||||
return shm::get_ref(impl(p_)->id_);
|
||||
}
|
||||
|
||||
void handle::sub_ref() noexcept {
|
||||
shm::sub_ref(impl(p_)->id_);
|
||||
}
|
||||
|
||||
bool handle::acquire(char const * name, std::size_t size, unsigned mode) {
|
||||
release();
|
||||
impl(p_)->id_ = shm::acquire((impl(p_)->n_ = name).c_str(), size, mode);
|
||||
impl(p_)->m_ = shm::get_mem(impl(p_)->id_, &(impl(p_)->s_));
|
||||
return valid();
|
||||
}
|
||||
|
||||
std::int32_t handle::release() {
|
||||
if (impl(p_)->id_ == nullptr) return -1;
|
||||
return shm::release(detach());
|
||||
}
|
||||
|
||||
void* handle::get() const {
|
||||
return impl(p_)->m_;
|
||||
}
|
||||
|
||||
void handle::attach(id_t id) {
|
||||
if (id == nullptr) return;
|
||||
release();
|
||||
impl(p_)->id_ = id;
|
||||
impl(p_)->m_ = shm::get_mem(impl(p_)->id_, &(impl(p_)->s_));
|
||||
}
|
||||
|
||||
id_t handle::detach() {
|
||||
auto old = impl(p_)->id_;
|
||||
impl(p_)->id_ = nullptr;
|
||||
impl(p_)->m_ = nullptr;
|
||||
impl(p_)->s_ = 0;
|
||||
impl(p_)->n_.clear();
|
||||
return old;
|
||||
}
|
||||
|
||||
} // namespace shm
|
||||
} // namespace ipc
|
||||
@ -1,83 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <mutex>
|
||||
#include <atomic>
|
||||
|
||||
#include "libipc/def.h"
|
||||
#include "libipc/mutex.h"
|
||||
#include "libipc/condition.h"
|
||||
#include "libipc/platform/detail.h"
|
||||
|
||||
namespace ipc {
|
||||
namespace detail {
|
||||
|
||||
class waiter {
|
||||
ipc::sync::condition cond_;
|
||||
ipc::sync::mutex lock_;
|
||||
std::atomic<bool> quit_ {false};
|
||||
|
||||
public:
|
||||
static void init();
|
||||
|
||||
waiter() = default;
|
||||
waiter(char const *name) {
|
||||
open(name);
|
||||
}
|
||||
|
||||
~waiter() {
|
||||
close();
|
||||
}
|
||||
|
||||
bool valid() const noexcept {
|
||||
return cond_.valid() && lock_.valid();
|
||||
}
|
||||
|
||||
bool open(char const *name) noexcept {
|
||||
quit_.store(false, std::memory_order_relaxed);
|
||||
if (!cond_.open((std::string{"_waiter_cond_"} + name).c_str())) {
|
||||
return false;
|
||||
}
|
||||
if (!lock_.open((std::string{"_waiter_lock_"} + name).c_str())) {
|
||||
cond_.close();
|
||||
return false;
|
||||
}
|
||||
return valid();
|
||||
}
|
||||
|
||||
void close() noexcept {
|
||||
cond_.close();
|
||||
lock_.close();
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
bool wait_if(F &&pred, std::uint64_t tm = ipc::invalid_value) noexcept {
|
||||
IPC_UNUSED_ std::lock_guard<ipc::sync::mutex> guard {lock_};
|
||||
while ([this, &pred] {
|
||||
return !quit_.load(std::memory_order_relaxed)
|
||||
&& std::forward<F>(pred)();
|
||||
}()) {
|
||||
if (!cond_.wait(lock_, tm)) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool notify() noexcept {
|
||||
std::lock_guard<ipc::sync::mutex>{lock_}; // barrier
|
||||
return cond_.notify(lock_);
|
||||
}
|
||||
|
||||
bool broadcast() noexcept {
|
||||
std::lock_guard<ipc::sync::mutex>{lock_}; // barrier
|
||||
return cond_.broadcast(lock_);
|
||||
}
|
||||
|
||||
bool quit_waiting() {
|
||||
quit_.store(true, std::memory_order_release);
|
||||
return broadcast();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace ipc
|
||||
@ -1,3 +0,0 @@
|
||||
https://github.com/mutouyun/cpp-ipc
|
||||
|
||||
A high-performance inter-process communication library using shared memory on Linux/Windows.
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,316 +0,0 @@
|
||||
// jpgd.h - C++ class for JPEG decompression.
|
||||
// Public domain, Rich Geldreich <richgel99@gmail.com>
|
||||
#ifndef JPEG_DECODER_H
|
||||
#define JPEG_DECODER_H
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include <setjmp.h>
|
||||
|
||||
namespace jpgd
|
||||
{
|
||||
typedef unsigned char uint8;
|
||||
typedef signed short int16;
|
||||
typedef unsigned short uint16;
|
||||
typedef unsigned int uint;
|
||||
typedef signed int int32;
|
||||
|
||||
// Loads a JPEG image from a memory buffer or a file.
|
||||
// req_comps can be 1 (grayscale), 3 (RGB), or 4 (RGBA).
|
||||
// On return, width/height will be set to the image's dimensions, and actual_comps will be set to the either 1 (grayscale) or 3 (RGB).
|
||||
// Notes: For more control over where and how the source data is read, see the decompress_jpeg_image_from_stream() function below, or call the jpeg_decoder class directly.
|
||||
// Requesting a 8 or 32bpp image is currently a little faster than 24bpp because the jpeg_decoder class itself currently always unpacks to either 8 or 32bpp.
|
||||
// BEGIN EPIC MOD
|
||||
//unsigned char *decompress_jpeg_image_from_memory(const unsigned char *pSrc_data, int src_data_size, int *width, int *height, int *actual_comps, int req_comps);
|
||||
unsigned char *decompress_jpeg_image_from_memory(const unsigned char *pSrc_data, int src_data_size, int *width, int *height, int *actual_comps, int req_comps, int format);
|
||||
// END EPIC MOD
|
||||
unsigned char *decompress_jpeg_image_from_file(const char *pSrc_filename, int *width, int *height, int *actual_comps, int req_comps);
|
||||
|
||||
// Success/failure error codes.
|
||||
enum jpgd_status
|
||||
{
|
||||
JPGD_SUCCESS = 0, JPGD_FAILED = -1, JPGD_DONE = 1,
|
||||
JPGD_BAD_DHT_COUNTS = -256, JPGD_BAD_DHT_INDEX, JPGD_BAD_DHT_MARKER, JPGD_BAD_DQT_MARKER, JPGD_BAD_DQT_TABLE,
|
||||
JPGD_BAD_PRECISION, JPGD_BAD_HEIGHT, JPGD_BAD_WIDTH, JPGD_TOO_MANY_COMPONENTS,
|
||||
JPGD_BAD_SOF_LENGTH, JPGD_BAD_VARIABLE_MARKER, JPGD_BAD_DRI_LENGTH, JPGD_BAD_SOS_LENGTH,
|
||||
JPGD_BAD_SOS_COMP_ID, JPGD_W_EXTRA_BYTES_BEFORE_MARKER, JPGD_NO_ARITHMITIC_SUPPORT, JPGD_UNEXPECTED_MARKER,
|
||||
JPGD_NOT_JPEG, JPGD_UNSUPPORTED_MARKER, JPGD_BAD_DQT_LENGTH, JPGD_TOO_MANY_BLOCKS,
|
||||
JPGD_UNDEFINED_QUANT_TABLE, JPGD_UNDEFINED_HUFF_TABLE, JPGD_NOT_SINGLE_SCAN, JPGD_UNSUPPORTED_COLORSPACE,
|
||||
JPGD_UNSUPPORTED_SAMP_FACTORS, JPGD_DECODE_ERROR, JPGD_BAD_RESTART_MARKER, JPGD_ASSERTION_ERROR,
|
||||
JPGD_BAD_SOS_SPECTRAL, JPGD_BAD_SOS_SUCCESSIVE, JPGD_STREAM_READ, JPGD_NOTENOUGHMEM
|
||||
};
|
||||
|
||||
// Input stream interface.
|
||||
// Derive from this class to read input data from sources other than files or memory. Set m_eof_flag to true when no more data is available.
|
||||
// The decoder is rather greedy: it will keep on calling this method until its internal input buffer is full, or until the EOF flag is set.
|
||||
// It the input stream contains data after the JPEG stream's EOI (end of image) marker it will probably be pulled into the internal buffer.
|
||||
// Call the get_total_bytes_read() method to determine the actual size of the JPEG stream after successful decoding.
|
||||
class jpeg_decoder_stream
|
||||
{
|
||||
public:
|
||||
jpeg_decoder_stream() { }
|
||||
virtual ~jpeg_decoder_stream() { }
|
||||
|
||||
// The read() method is called when the internal input buffer is empty.
|
||||
// Parameters:
|
||||
// pBuf - input buffer
|
||||
// max_bytes_to_read - maximum bytes that can be written to pBuf
|
||||
// pEOF_flag - set this to true if at end of stream (no more bytes remaining)
|
||||
// Returns -1 on error, otherwise return the number of bytes actually written to the buffer (which may be 0).
|
||||
// Notes: This method will be called in a loop until you set *pEOF_flag to true or the internal buffer is full.
|
||||
virtual int read(uint8 *pBuf, int max_bytes_to_read, bool *pEOF_flag) = 0;
|
||||
};
|
||||
|
||||
// stdio FILE stream class.
|
||||
class jpeg_decoder_file_stream : public jpeg_decoder_stream
|
||||
{
|
||||
jpeg_decoder_file_stream(const jpeg_decoder_file_stream &);
|
||||
jpeg_decoder_file_stream &operator =(const jpeg_decoder_file_stream &);
|
||||
|
||||
FILE *m_pFile;
|
||||
bool m_eof_flag, m_error_flag;
|
||||
|
||||
public:
|
||||
jpeg_decoder_file_stream();
|
||||
virtual ~jpeg_decoder_file_stream();
|
||||
|
||||
bool open(const char *Pfilename);
|
||||
void close();
|
||||
|
||||
virtual int read(uint8 *pBuf, int max_bytes_to_read, bool *pEOF_flag);
|
||||
};
|
||||
|
||||
// Memory stream class.
|
||||
class jpeg_decoder_mem_stream : public jpeg_decoder_stream
|
||||
{
|
||||
const uint8 *m_pSrc_data;
|
||||
uint m_ofs, m_size;
|
||||
|
||||
public:
|
||||
jpeg_decoder_mem_stream() : m_pSrc_data(NULL), m_ofs(0), m_size(0) { }
|
||||
jpeg_decoder_mem_stream(const uint8 *pSrc_data, uint size) : m_pSrc_data(pSrc_data), m_ofs(0), m_size(size) { }
|
||||
|
||||
virtual ~jpeg_decoder_mem_stream() { }
|
||||
|
||||
bool open(const uint8 *pSrc_data, uint size);
|
||||
void close() { m_pSrc_data = NULL; m_ofs = 0; m_size = 0; }
|
||||
|
||||
virtual int read(uint8 *pBuf, int max_bytes_to_read, bool *pEOF_flag);
|
||||
};
|
||||
|
||||
// Loads JPEG file from a jpeg_decoder_stream.
|
||||
unsigned char *decompress_jpeg_image_from_stream(jpeg_decoder_stream *pStream, int *width, int *height, int *actual_comps, int req_comps);
|
||||
|
||||
enum
|
||||
{
|
||||
JPGD_IN_BUF_SIZE = 8192, JPGD_MAX_BLOCKS_PER_MCU = 10, JPGD_MAX_HUFF_TABLES = 8, JPGD_MAX_QUANT_TABLES = 4,
|
||||
JPGD_MAX_COMPONENTS = 4, JPGD_MAX_COMPS_IN_SCAN = 4, JPGD_MAX_BLOCKS_PER_ROW = 8192, JPGD_MAX_HEIGHT = 16384, JPGD_MAX_WIDTH = 16384
|
||||
};
|
||||
|
||||
typedef int16 jpgd_quant_t;
|
||||
typedef int16 jpgd_block_t;
|
||||
|
||||
class jpeg_decoder
|
||||
{
|
||||
public:
|
||||
// Call get_error_code() after constructing to determine if the stream is valid or not. You may call the get_width(), get_height(), etc.
|
||||
// methods after the constructor is called. You may then either destruct the object, or begin decoding the image by calling begin_decoding(), then decode() on each scanline.
|
||||
jpeg_decoder(jpeg_decoder_stream *pStream);
|
||||
|
||||
~jpeg_decoder();
|
||||
|
||||
// Call this method after constructing the object to begin decompression.
|
||||
// If JPGD_SUCCESS is returned you may then call decode() on each scanline.
|
||||
int begin_decoding();
|
||||
|
||||
// Returns the next scan line.
|
||||
// For grayscale images, pScan_line will point to a buffer containing 8-bit pixels (get_bytes_per_pixel() will return 1).
|
||||
// Otherwise, it will always point to a buffer containing 32-bit RGBA pixels (A will always be 255, and get_bytes_per_pixel() will return 4).
|
||||
// Returns JPGD_SUCCESS if a scan line has been returned.
|
||||
// Returns JPGD_DONE if all scan lines have been returned.
|
||||
// Returns JPGD_FAILED if an error occurred. Call get_error_code() for a more info.
|
||||
int decode(const void** pScan_line, uint* pScan_line_len);
|
||||
|
||||
inline jpgd_status get_error_code() const { return m_error_code; }
|
||||
|
||||
inline int get_width() const { return m_image_x_size; }
|
||||
inline int get_height() const { return m_image_y_size; }
|
||||
|
||||
inline int get_num_components() const { return m_comps_in_frame; }
|
||||
|
||||
inline int get_bytes_per_pixel() const { return m_dest_bytes_per_pixel; }
|
||||
inline int get_bytes_per_scan_line() const { return m_image_x_size * get_bytes_per_pixel(); }
|
||||
|
||||
// Returns the total number of bytes actually consumed by the decoder (which should equal the actual size of the JPEG file).
|
||||
inline int get_total_bytes_read() const { return m_total_bytes_read; }
|
||||
|
||||
private:
|
||||
jpeg_decoder(const jpeg_decoder &);
|
||||
jpeg_decoder &operator =(const jpeg_decoder &);
|
||||
|
||||
typedef void (*pDecode_block_func)(jpeg_decoder *, int, int, int);
|
||||
|
||||
struct huff_tables
|
||||
{
|
||||
bool ac_table;
|
||||
uint look_up[256];
|
||||
uint look_up2[256];
|
||||
uint8 code_size[256];
|
||||
uint tree[512];
|
||||
};
|
||||
|
||||
struct coeff_buf
|
||||
{
|
||||
uint8 *pData;
|
||||
int block_num_x, block_num_y;
|
||||
int block_len_x, block_len_y;
|
||||
int block_size;
|
||||
};
|
||||
|
||||
struct mem_block
|
||||
{
|
||||
mem_block *m_pNext;
|
||||
size_t m_used_count;
|
||||
size_t m_size;
|
||||
char m_data[1];
|
||||
};
|
||||
|
||||
jmp_buf m_jmp_state;
|
||||
mem_block *m_pMem_blocks;
|
||||
int m_image_x_size;
|
||||
int m_image_y_size;
|
||||
jpeg_decoder_stream *m_pStream;
|
||||
int m_progressive_flag;
|
||||
uint8 m_huff_ac[JPGD_MAX_HUFF_TABLES];
|
||||
uint8* m_huff_num[JPGD_MAX_HUFF_TABLES]; // pointer to number of Huffman codes per bit size
|
||||
uint8* m_huff_val[JPGD_MAX_HUFF_TABLES]; // pointer to Huffman codes per bit size
|
||||
jpgd_quant_t* m_quant[JPGD_MAX_QUANT_TABLES]; // pointer to quantization tables
|
||||
int m_scan_type; // Gray, Yh1v1, Yh1v2, Yh2v1, Yh2v2 (CMYK111, CMYK4114 no longer supported)
|
||||
int m_comps_in_frame; // # of components in frame
|
||||
int m_comp_h_samp[JPGD_MAX_COMPONENTS]; // component's horizontal sampling factor
|
||||
int m_comp_v_samp[JPGD_MAX_COMPONENTS]; // component's vertical sampling factor
|
||||
int m_comp_quant[JPGD_MAX_COMPONENTS]; // component's quantization table selector
|
||||
int m_comp_ident[JPGD_MAX_COMPONENTS]; // component's ID
|
||||
int m_comp_h_blocks[JPGD_MAX_COMPONENTS];
|
||||
int m_comp_v_blocks[JPGD_MAX_COMPONENTS];
|
||||
int m_comps_in_scan; // # of components in scan
|
||||
int m_comp_list[JPGD_MAX_COMPS_IN_SCAN]; // components in this scan
|
||||
int m_comp_dc_tab[JPGD_MAX_COMPONENTS]; // component's DC Huffman coding table selector
|
||||
int m_comp_ac_tab[JPGD_MAX_COMPONENTS]; // component's AC Huffman coding table selector
|
||||
int m_spectral_start; // spectral selection start
|
||||
int m_spectral_end; // spectral selection end
|
||||
int m_successive_low; // successive approximation low
|
||||
int m_successive_high; // successive approximation high
|
||||
int m_max_mcu_x_size; // MCU's max. X size in pixels
|
||||
int m_max_mcu_y_size; // MCU's max. Y size in pixels
|
||||
int m_blocks_per_mcu;
|
||||
int m_max_blocks_per_row;
|
||||
int m_mcus_per_row, m_mcus_per_col;
|
||||
int m_mcu_org[JPGD_MAX_BLOCKS_PER_MCU];
|
||||
int m_total_lines_left; // total # lines left in image
|
||||
int m_mcu_lines_left; // total # lines left in this MCU
|
||||
int m_real_dest_bytes_per_scan_line;
|
||||
int m_dest_bytes_per_scan_line; // rounded up
|
||||
int m_dest_bytes_per_pixel; // 4 (RGB) or 1 (Y)
|
||||
huff_tables* m_pHuff_tabs[JPGD_MAX_HUFF_TABLES];
|
||||
coeff_buf* m_dc_coeffs[JPGD_MAX_COMPONENTS];
|
||||
coeff_buf* m_ac_coeffs[JPGD_MAX_COMPONENTS];
|
||||
int m_eob_run;
|
||||
int m_block_y_mcu[JPGD_MAX_COMPONENTS];
|
||||
uint8* m_pIn_buf_ofs;
|
||||
int m_in_buf_left;
|
||||
int m_tem_flag;
|
||||
bool m_eof_flag;
|
||||
uint8 m_in_buf_pad_start[128];
|
||||
uint8 m_in_buf[JPGD_IN_BUF_SIZE + 128];
|
||||
uint8 m_in_buf_pad_end[128];
|
||||
int m_bits_left;
|
||||
uint m_bit_buf;
|
||||
int m_restart_interval;
|
||||
int m_restarts_left;
|
||||
int m_next_restart_num;
|
||||
int m_max_mcus_per_row;
|
||||
int m_max_blocks_per_mcu;
|
||||
int m_expanded_blocks_per_mcu;
|
||||
int m_expanded_blocks_per_row;
|
||||
int m_expanded_blocks_per_component;
|
||||
bool m_freq_domain_chroma_upsample;
|
||||
int m_max_mcus_per_col;
|
||||
uint m_last_dc_val[JPGD_MAX_COMPONENTS];
|
||||
jpgd_block_t* m_pMCU_coefficients;
|
||||
int m_mcu_block_max_zag[JPGD_MAX_BLOCKS_PER_MCU];
|
||||
uint8* m_pSample_buf;
|
||||
int m_crr[256];
|
||||
int m_cbb[256];
|
||||
int m_crg[256];
|
||||
int m_cbg[256];
|
||||
uint8* m_pScan_line_0;
|
||||
uint8* m_pScan_line_1;
|
||||
jpgd_status m_error_code;
|
||||
bool m_ready_flag;
|
||||
int m_total_bytes_read;
|
||||
|
||||
void free_all_blocks();
|
||||
// BEGIN EPIC MOD
|
||||
UE_NORETURN void stop_decoding(jpgd_status status);
|
||||
// END EPIC MOD
|
||||
void *alloc(size_t n, bool zero = false);
|
||||
void word_clear(void *p, uint16 c, uint n);
|
||||
void prep_in_buffer();
|
||||
void read_dht_marker();
|
||||
void read_dqt_marker();
|
||||
void read_sof_marker();
|
||||
void skip_variable_marker();
|
||||
void read_dri_marker();
|
||||
void read_sos_marker();
|
||||
int next_marker();
|
||||
int process_markers();
|
||||
void locate_soi_marker();
|
||||
void locate_sof_marker();
|
||||
int locate_sos_marker();
|
||||
void init(jpeg_decoder_stream * pStream);
|
||||
void create_look_ups();
|
||||
void fix_in_buffer();
|
||||
void transform_mcu(int mcu_row);
|
||||
void transform_mcu_expand(int mcu_row);
|
||||
coeff_buf* coeff_buf_open(int block_num_x, int block_num_y, int block_len_x, int block_len_y);
|
||||
inline jpgd_block_t *coeff_buf_getp(coeff_buf *cb, int block_x, int block_y);
|
||||
void load_next_row();
|
||||
void decode_next_row();
|
||||
void make_huff_table(int index, huff_tables *pH);
|
||||
void check_quant_tables();
|
||||
void check_huff_tables();
|
||||
void calc_mcu_block_order();
|
||||
int init_scan();
|
||||
void init_frame();
|
||||
void process_restart();
|
||||
void decode_scan(pDecode_block_func decode_block_func);
|
||||
void init_progressive();
|
||||
void init_sequential();
|
||||
void decode_start();
|
||||
void decode_init(jpeg_decoder_stream * pStream);
|
||||
void H2V2Convert();
|
||||
void H2V1Convert();
|
||||
void H1V2Convert();
|
||||
void H1V1Convert();
|
||||
void gray_convert();
|
||||
void expanded_convert();
|
||||
void find_eoi();
|
||||
inline uint get_char();
|
||||
inline uint get_char(bool *pPadding_flag);
|
||||
inline void stuff_char(uint8 q);
|
||||
inline uint8 get_octet();
|
||||
inline uint get_bits(int num_bits);
|
||||
inline uint get_bits_no_markers(int numbits);
|
||||
inline int huff_decode(huff_tables *pH);
|
||||
inline int huff_decode(huff_tables *pH, int& extrabits);
|
||||
static inline uint8 clamp(int i);
|
||||
static void decode_block_dc_first(jpeg_decoder *pD, int component_id, int block_x, int block_y);
|
||||
static void decode_block_dc_refine(jpeg_decoder *pD, int component_id, int block_x, int block_y);
|
||||
static void decode_block_ac_first(jpeg_decoder *pD, int component_id, int block_x, int block_y);
|
||||
static void decode_block_ac_refine(jpeg_decoder *pD, int component_id, int block_x, int block_y);
|
||||
};
|
||||
|
||||
} // namespace jpgd
|
||||
|
||||
#endif // JPEG_DECODER_H
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,172 +0,0 @@
|
||||
|
||||
// jpge.h - C++ class for JPEG compression.
|
||||
// Public domain, Rich Geldreich <richgel99@gmail.com>
|
||||
// Alex Evans: Added RGBA support, linear memory allocator.
|
||||
#ifndef JPEG_ENCODER_H
|
||||
#define JPEG_ENCODER_H
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace jpge
|
||||
{
|
||||
typedef unsigned char uint8;
|
||||
typedef signed short int16;
|
||||
typedef signed int int32;
|
||||
typedef unsigned short uint16;
|
||||
typedef unsigned int uint32;
|
||||
typedef unsigned int uint;
|
||||
|
||||
// JPEG chroma subsampling factors. Y_ONLY (grayscale images) and H2V2 (color images) are the most common.
|
||||
enum subsampling_t { Y_ONLY = 0, H1V1 = 1, H2V1 = 2, H2V2 = 3 };
|
||||
|
||||
// JPEG compression parameters structure.
|
||||
struct params
|
||||
{
|
||||
inline params() : m_quality(85), m_subsampling(H2V2), m_no_chroma_discrim_flag(false), m_two_pass_flag(false) { }
|
||||
|
||||
inline bool check_valid() const
|
||||
{
|
||||
if ((m_quality < 1) || (m_quality > 100)) return false;
|
||||
if ((uint)m_subsampling > (uint)H2V2) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Quality: 1-100, higher is better. Typical values are around 50-95.
|
||||
int m_quality;
|
||||
|
||||
// m_subsampling:
|
||||
// 0 = Y (grayscale) only
|
||||
// 1 = YCbCr, no subsampling (H1V1, YCbCr 1x1x1, 3 blocks per MCU)
|
||||
// 2 = YCbCr, H2V1 subsampling (YCbCr 2x1x1, 4 blocks per MCU)
|
||||
// 3 = YCbCr, H2V2 subsampling (YCbCr 4x1x1, 6 blocks per MCU-- very common)
|
||||
subsampling_t m_subsampling;
|
||||
|
||||
// Disables CbCr discrimination - only intended for testing.
|
||||
// If true, the Y quantization table is also used for the CbCr channels.
|
||||
bool m_no_chroma_discrim_flag;
|
||||
|
||||
bool m_two_pass_flag;
|
||||
};
|
||||
|
||||
// Writes JPEG image to a file.
|
||||
// num_channels must be 1 (Y) or 3 (RGB), image pitch must be width*num_channels.
|
||||
bool compress_image_to_jpeg_file(const char *pFilename, int64_t width, int64_t height, int64_t num_channels, const uint8 *pImage_data, const params &comp_params = params());
|
||||
|
||||
// Writes JPEG image to memory buffer.
|
||||
// On entry, buf_size is the size of the output buffer pointed at by pBuf, which should be at least ~1024 bytes.
|
||||
// If return value is true, buf_size will be set to the size of the compressed data.
|
||||
bool compress_image_to_jpeg_file_in_memory(void *pBuf, int64_t &buf_size, int64_t width, int64_t height, int64_t num_channels, const uint8 *pImage_data, const params &comp_params = params());
|
||||
|
||||
// Output stream abstract class - used by the jpeg_encoder class to write to the output stream.
|
||||
// put_buf() is generally called with len==JPGE_OUT_BUF_SIZE bytes, but for headers it'll be called with smaller amounts.
|
||||
class output_stream
|
||||
{
|
||||
public:
|
||||
virtual ~output_stream() { };
|
||||
virtual bool put_buf(const void* Pbuf, int64_t len) = 0;
|
||||
template<class T> inline bool put_obj(const T& obj) { return put_buf(&obj, sizeof(T)); }
|
||||
};
|
||||
|
||||
// Lower level jpeg_encoder class - useful if more control is needed than the above helper functions.
|
||||
class jpeg_encoder
|
||||
{
|
||||
public:
|
||||
jpeg_encoder();
|
||||
~jpeg_encoder();
|
||||
|
||||
// Initializes the compressor.
|
||||
// pStream: The stream object to use for writing compressed data.
|
||||
// params - Compression parameters structure, defined above.
|
||||
// width, height - Image dimensions.
|
||||
// channels - May be 1, or 3. 1 indicates grayscale, 3 indicates RGB source data.
|
||||
// Returns false on out of memory or if a stream write fails.
|
||||
bool init(output_stream *pStream, int64_t width, int64_t height, int64_t src_channels, const params &comp_params = params());
|
||||
|
||||
const params &get_params() const { return m_params; }
|
||||
|
||||
// Deinitializes the compressor, freeing any allocated memory. May be called at any time.
|
||||
void deinit();
|
||||
|
||||
uint get_total_passes() const { return m_params.m_two_pass_flag ? 2 : 1; }
|
||||
inline uint get_cur_pass() { return m_pass_num; }
|
||||
|
||||
// Call this method with each source scanline.
|
||||
// width * src_channels bytes per scanline is expected (RGB or Y format).
|
||||
// You must call with NULL after all scanlines are processed to finish compression.
|
||||
// Returns false on out of memory or if a stream write fails.
|
||||
bool process_scanline(const void* pScanline);
|
||||
|
||||
private:
|
||||
jpeg_encoder(const jpeg_encoder &);
|
||||
jpeg_encoder &operator =(const jpeg_encoder &);
|
||||
|
||||
typedef int32 sample_array_t;
|
||||
|
||||
output_stream *m_pStream;
|
||||
params m_params;
|
||||
uint8 m_num_components;
|
||||
uint8 m_comp_h_samp[3], m_comp_v_samp[3];
|
||||
int m_image_x, m_image_y, m_image_bpp, m_image_bpl;
|
||||
int m_image_x_mcu, m_image_y_mcu;
|
||||
int m_image_bpl_xlt, m_image_bpl_mcu;
|
||||
int m_mcus_per_row;
|
||||
int m_mcu_x, m_mcu_y;
|
||||
uint8 *m_mcu_lines[16];
|
||||
uint8 m_mcu_y_ofs;
|
||||
sample_array_t m_sample_array[64];
|
||||
int16 m_coefficient_array[64];
|
||||
int32 m_quantization_tables[2][64];
|
||||
uint m_huff_codes[4][256];
|
||||
uint8 m_huff_code_sizes[4][256];
|
||||
uint8 m_huff_bits[4][17];
|
||||
uint8 m_huff_val[4][256];
|
||||
uint32 m_huff_count[4][256];
|
||||
int m_last_dc_val[3];
|
||||
enum { JPGE_OUT_BUF_SIZE = 2048 };
|
||||
uint8 m_out_buf[JPGE_OUT_BUF_SIZE];
|
||||
uint8 *m_pOut_buf;
|
||||
uint m_out_buf_left;
|
||||
uint32 m_bit_buffer;
|
||||
uint m_bits_in;
|
||||
uint8 m_pass_num;
|
||||
bool m_all_stream_writes_succeeded;
|
||||
|
||||
void optimize_huffman_table(int table_num, int table_len);
|
||||
void emit_byte(uint8 i);
|
||||
void emit_word(uint i);
|
||||
void emit_marker(int marker);
|
||||
void emit_jfif_app0();
|
||||
void emit_dqt();
|
||||
void emit_sof();
|
||||
void emit_dht(uint8 *bits, uint8 *val, int index, bool ac_flag);
|
||||
void emit_dhts();
|
||||
void emit_sos();
|
||||
void emit_markers();
|
||||
void compute_huffman_table(uint *codes, uint8 *code_sizes, uint8 *bits, uint8 *val);
|
||||
void compute_quant_table(int32 *dst, int16 *src);
|
||||
void adjust_quant_table(int32 *dst, int32 *src);
|
||||
void first_pass_init();
|
||||
bool second_pass_init();
|
||||
bool jpg_open(int p_x_res, int p_y_res, int src_channels);
|
||||
void load_block_8_8_grey(int x);
|
||||
void load_block_8_8(int x, int y, int c);
|
||||
void load_block_16_8(int x, int c);
|
||||
void load_block_16_8_8(int x, int c);
|
||||
void load_quantized_coefficients(int component_num);
|
||||
void flush_output_buffer();
|
||||
void put_bits(uint bits, uint len);
|
||||
void code_coefficients_pass_one(int component_num);
|
||||
void code_coefficients_pass_two(int component_num);
|
||||
void code_block(int component_num);
|
||||
void process_mcu_row();
|
||||
bool terminate_pass_one();
|
||||
bool terminate_pass_two();
|
||||
bool process_end_of_image();
|
||||
void load_mcu(const void* src);
|
||||
void clear();
|
||||
void init();
|
||||
};
|
||||
|
||||
} // namespace jpge
|
||||
|
||||
#endif // JPEG_ENCODER
|
||||
@ -1,3 +0,0 @@
|
||||
jpge.h - C++ class for JPEG compression.
|
||||
Public domain, Rich Geldreich <richgel99@gmail.com>
|
||||
Alex Evans: Added RGBA support, linear memory allocator.
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,433 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <utility>
|
||||
#include <cstring>
|
||||
#include <type_traits>
|
||||
#include <cstdint>
|
||||
|
||||
#include "libipc/def.h"
|
||||
|
||||
#include "libipc/platform/detail.h"
|
||||
#include "libipc/circ/elem_def.h"
|
||||
#include "libipc/utility/log.h"
|
||||
#include "libipc/utility/utility.h"
|
||||
|
||||
namespace ipc {
|
||||
|
||||
////////////////////////////////////////////////////////////////
|
||||
/// producer-consumer implementation
|
||||
////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Flag>
|
||||
struct prod_cons_impl;
|
||||
|
||||
template <>
|
||||
struct prod_cons_impl<wr<relat::single, relat::single, trans::unicast>> {
|
||||
|
||||
template <std::size_t DataSize, std::size_t AlignSize>
|
||||
struct elem_t {
|
||||
std::aligned_storage_t<DataSize, AlignSize> data_ {};
|
||||
};
|
||||
|
||||
alignas(cache_line_size) std::atomic<circ::u2_t> rd_; // read index
|
||||
alignas(cache_line_size) std::atomic<circ::u2_t> wt_; // write index
|
||||
|
||||
constexpr circ::u2_t cursor() const noexcept {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename E>
|
||||
bool push(W* /*wrapper*/, F&& f, E* elems) {
|
||||
auto cur_wt = circ::index_of(wt_.load(std::memory_order_relaxed));
|
||||
if (cur_wt == circ::index_of(rd_.load(std::memory_order_acquire) - 1)) {
|
||||
return false; // full
|
||||
}
|
||||
std::forward<F>(f)(&(elems[cur_wt].data_));
|
||||
wt_.fetch_add(1, std::memory_order_release);
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* In single-single-unicast, 'force_push' means 'no reader' or 'the only one reader is dead'.
|
||||
* So we could just disconnect all connections of receiver, and return false.
|
||||
*/
|
||||
template <typename W, typename F, typename E>
|
||||
bool force_push(W* wrapper, F&&, E*) {
|
||||
wrapper->elems()->disconnect_receiver(~static_cast<circ::cc_t>(0u));
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename R, typename E>
|
||||
bool pop(W* /*wrapper*/, circ::u2_t& /*cur*/, F&& f, R&& out, E* elems) {
|
||||
auto cur_rd = circ::index_of(rd_.load(std::memory_order_relaxed));
|
||||
if (cur_rd == circ::index_of(wt_.load(std::memory_order_acquire))) {
|
||||
return false; // empty
|
||||
}
|
||||
std::forward<F>(f)(&(elems[cur_rd].data_));
|
||||
std::forward<R>(out)(true);
|
||||
rd_.fetch_add(1, std::memory_order_release);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct prod_cons_impl<wr<relat::single, relat::multi , trans::unicast>>
|
||||
: prod_cons_impl<wr<relat::single, relat::single, trans::unicast>> {
|
||||
|
||||
template <typename W, typename F, typename E>
|
||||
bool force_push(W* wrapper, F&&, E*) {
|
||||
wrapper->elems()->disconnect_receiver(1);
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename R,
|
||||
template <std::size_t, std::size_t> class E, std::size_t DS, std::size_t AS>
|
||||
bool pop(W* /*wrapper*/, circ::u2_t& /*cur*/, F&& f, R&& out, E<DS, AS>* elems) {
|
||||
byte_t buff[DS];
|
||||
for (unsigned k = 0;;) {
|
||||
auto cur_rd = rd_.load(std::memory_order_relaxed);
|
||||
if (circ::index_of(cur_rd) ==
|
||||
circ::index_of(wt_.load(std::memory_order_acquire))) {
|
||||
return false; // empty
|
||||
}
|
||||
std::memcpy(buff, &(elems[circ::index_of(cur_rd)].data_), sizeof(buff));
|
||||
if (rd_.compare_exchange_weak(cur_rd, cur_rd + 1, std::memory_order_release)) {
|
||||
std::forward<F>(f)(buff);
|
||||
std::forward<R>(out)(true);
|
||||
return true;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct prod_cons_impl<wr<relat::multi , relat::multi, trans::unicast>>
|
||||
: prod_cons_impl<wr<relat::single, relat::multi, trans::unicast>> {
|
||||
|
||||
using flag_t = std::uint64_t;
|
||||
|
||||
template <std::size_t DataSize, std::size_t AlignSize>
|
||||
struct elem_t {
|
||||
std::aligned_storage_t<DataSize, AlignSize> data_ {};
|
||||
std::atomic<flag_t> f_ct_ { 0 }; // commit flag
|
||||
};
|
||||
|
||||
alignas(cache_line_size) std::atomic<circ::u2_t> ct_; // commit index
|
||||
|
||||
template <typename W, typename F, typename E>
|
||||
bool push(W* /*wrapper*/, F&& f, E* elems) {
|
||||
circ::u2_t cur_ct, nxt_ct;
|
||||
for (unsigned k = 0;;) {
|
||||
cur_ct = ct_.load(std::memory_order_relaxed);
|
||||
if (circ::index_of(nxt_ct = cur_ct + 1) ==
|
||||
circ::index_of(rd_.load(std::memory_order_acquire))) {
|
||||
return false; // full
|
||||
}
|
||||
if (ct_.compare_exchange_weak(cur_ct, nxt_ct, std::memory_order_acq_rel)) {
|
||||
break;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
auto* el = elems + circ::index_of(cur_ct);
|
||||
std::forward<F>(f)(&(el->data_));
|
||||
// set flag & try update wt
|
||||
el->f_ct_.store(~static_cast<flag_t>(cur_ct), std::memory_order_release);
|
||||
while (1) {
|
||||
auto cac_ct = el->f_ct_.load(std::memory_order_acquire);
|
||||
if (cur_ct != wt_.load(std::memory_order_relaxed)) {
|
||||
return true;
|
||||
}
|
||||
if ((~cac_ct) != cur_ct) {
|
||||
return true;
|
||||
}
|
||||
if (!el->f_ct_.compare_exchange_strong(cac_ct, 0, std::memory_order_relaxed)) {
|
||||
return true;
|
||||
}
|
||||
wt_.store(nxt_ct, std::memory_order_release);
|
||||
cur_ct = nxt_ct;
|
||||
nxt_ct = cur_ct + 1;
|
||||
el = elems + circ::index_of(cur_ct);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename E>
|
||||
bool force_push(W* wrapper, F&&, E*) {
|
||||
wrapper->elems()->disconnect_receiver(1);
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename R,
|
||||
template <std::size_t, std::size_t> class E, std::size_t DS, std::size_t AS>
|
||||
bool pop(W* /*wrapper*/, circ::u2_t& /*cur*/, F&& f, R&& out, E<DS, AS>* elems) {
|
||||
byte_t buff[DS];
|
||||
for (unsigned k = 0;;) {
|
||||
auto cur_rd = rd_.load(std::memory_order_relaxed);
|
||||
auto cur_wt = wt_.load(std::memory_order_acquire);
|
||||
auto id_rd = circ::index_of(cur_rd);
|
||||
auto id_wt = circ::index_of(cur_wt);
|
||||
if (id_rd == id_wt) {
|
||||
auto* el = elems + id_wt;
|
||||
auto cac_ct = el->f_ct_.load(std::memory_order_acquire);
|
||||
if ((~cac_ct) != cur_wt) {
|
||||
return false; // empty
|
||||
}
|
||||
if (el->f_ct_.compare_exchange_weak(cac_ct, 0, std::memory_order_relaxed)) {
|
||||
wt_.store(cur_wt + 1, std::memory_order_release);
|
||||
}
|
||||
k = 0;
|
||||
}
|
||||
else {
|
||||
std::memcpy(buff, &(elems[circ::index_of(cur_rd)].data_), sizeof(buff));
|
||||
if (rd_.compare_exchange_weak(cur_rd, cur_rd + 1, std::memory_order_release)) {
|
||||
std::forward<F>(f)(buff);
|
||||
std::forward<R>(out)(true);
|
||||
return true;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct prod_cons_impl<wr<relat::single, relat::multi, trans::broadcast>> {
|
||||
|
||||
using rc_t = std::uint64_t;
|
||||
|
||||
enum : rc_t {
|
||||
ep_mask = 0x00000000ffffffffull,
|
||||
ep_incr = 0x0000000100000000ull
|
||||
};
|
||||
|
||||
template <std::size_t DataSize, std::size_t AlignSize>
|
||||
struct elem_t {
|
||||
std::aligned_storage_t<DataSize, AlignSize> data_ {};
|
||||
std::atomic<rc_t> rc_ { 0 }; // read-counter
|
||||
};
|
||||
|
||||
alignas(cache_line_size) std::atomic<circ::u2_t> wt_; // write index
|
||||
alignas(cache_line_size) rc_t epoch_ { 0 }; // only one writer
|
||||
|
||||
circ::u2_t cursor() const noexcept {
|
||||
return wt_.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename E>
|
||||
bool push(W* wrapper, F&& f, E* elems) {
|
||||
E* el;
|
||||
for (unsigned k = 0;;) {
|
||||
circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
|
||||
if (cc == 0) return false; // no reader
|
||||
el = elems + circ::index_of(wt_.load(std::memory_order_relaxed));
|
||||
// check all consumers have finished reading this element
|
||||
auto cur_rc = el->rc_.load(std::memory_order_acquire);
|
||||
circ::cc_t rem_cc = cur_rc & ep_mask;
|
||||
if ((cc & rem_cc) && ((cur_rc & ~ep_mask) == epoch_)) {
|
||||
return false; // has not finished yet
|
||||
}
|
||||
// consider rem_cc to be 0 here
|
||||
if (el->rc_.compare_exchange_weak(
|
||||
cur_rc, epoch_ | static_cast<rc_t>(cc), std::memory_order_release)) {
|
||||
break;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
std::forward<F>(f)(&(el->data_));
|
||||
wt_.fetch_add(1, std::memory_order_release);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename E>
|
||||
bool force_push(W* wrapper, F&& f, E* elems) {
|
||||
E* el;
|
||||
epoch_ += ep_incr;
|
||||
for (unsigned k = 0;;) {
|
||||
circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
|
||||
if (cc == 0) return false; // no reader
|
||||
el = elems + circ::index_of(wt_.load(std::memory_order_relaxed));
|
||||
// check all consumers have finished reading this element
|
||||
auto cur_rc = el->rc_.load(std::memory_order_acquire);
|
||||
circ::cc_t rem_cc = cur_rc & ep_mask;
|
||||
if (cc & rem_cc) {
|
||||
ipc::log("force_push: k = %u, cc = %u, rem_cc = %u\n", k, cc, rem_cc);
|
||||
cc = wrapper->elems()->disconnect_receiver(rem_cc); // disconnect all invalid readers
|
||||
if (cc == 0) return false; // no reader
|
||||
}
|
||||
// just compare & exchange
|
||||
if (el->rc_.compare_exchange_weak(
|
||||
cur_rc, epoch_ | static_cast<rc_t>(cc), std::memory_order_release)) {
|
||||
break;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
std::forward<F>(f)(&(el->data_));
|
||||
wt_.fetch_add(1, std::memory_order_release);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename R, typename E>
|
||||
bool pop(W* wrapper, circ::u2_t& cur, F&& f, R&& out, E* elems) {
|
||||
if (cur == cursor()) return false; // acquire
|
||||
auto* el = elems + circ::index_of(cur++);
|
||||
std::forward<F>(f)(&(el->data_));
|
||||
for (unsigned k = 0;;) {
|
||||
auto cur_rc = el->rc_.load(std::memory_order_acquire);
|
||||
if ((cur_rc & ep_mask) == 0) {
|
||||
std::forward<R>(out)(true);
|
||||
return true;
|
||||
}
|
||||
auto nxt_rc = cur_rc & ~static_cast<rc_t>(wrapper->connected_id());
|
||||
if (el->rc_.compare_exchange_weak(cur_rc, nxt_rc, std::memory_order_release)) {
|
||||
std::forward<R>(out)((nxt_rc & ep_mask) == 0);
|
||||
return true;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct prod_cons_impl<wr<relat::multi, relat::multi, trans::broadcast>> {
|
||||
|
||||
using rc_t = std::uint64_t;
|
||||
using flag_t = std::uint64_t;
|
||||
|
||||
enum : rc_t {
|
||||
rc_mask = 0x00000000ffffffffull,
|
||||
ep_mask = 0x00ffffffffffffffull,
|
||||
ep_incr = 0x0100000000000000ull,
|
||||
ic_mask = 0xff000000ffffffffull,
|
||||
ic_incr = 0x0000000100000000ull
|
||||
};
|
||||
|
||||
template <std::size_t DataSize, std::size_t AlignSize>
|
||||
struct elem_t {
|
||||
std::aligned_storage_t<DataSize, AlignSize> data_ {};
|
||||
std::atomic<rc_t > rc_ { 0 }; // read-counter
|
||||
std::atomic<flag_t> f_ct_ { 0 }; // commit flag
|
||||
};
|
||||
|
||||
alignas(cache_line_size) std::atomic<circ::u2_t> ct_; // commit index
|
||||
alignas(cache_line_size) std::atomic<rc_t> epoch_ { 0 };
|
||||
|
||||
circ::u2_t cursor() const noexcept {
|
||||
return ct_.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
constexpr static rc_t inc_rc(rc_t rc) noexcept {
|
||||
return (rc & ic_mask) | ((rc + ic_incr) & ~ic_mask);
|
||||
}
|
||||
|
||||
constexpr static rc_t inc_mask(rc_t rc) noexcept {
|
||||
return inc_rc(rc) & ~rc_mask;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename E>
|
||||
bool push(W* wrapper, F&& f, E* elems) {
|
||||
E* el;
|
||||
circ::u2_t cur_ct;
|
||||
rc_t epoch = epoch_.load(std::memory_order_acquire);
|
||||
for (unsigned k = 0;;) {
|
||||
circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
|
||||
if (cc == 0) return false; // no reader
|
||||
el = elems + circ::index_of(cur_ct = ct_.load(std::memory_order_relaxed));
|
||||
// check all consumers have finished reading this element
|
||||
auto cur_rc = el->rc_.load(std::memory_order_relaxed);
|
||||
circ::cc_t rem_cc = cur_rc & rc_mask;
|
||||
if ((cc & rem_cc) && ((cur_rc & ~ep_mask) == epoch)) {
|
||||
return false; // has not finished yet
|
||||
}
|
||||
else if (!rem_cc) {
|
||||
auto cur_fl = el->f_ct_.load(std::memory_order_acquire);
|
||||
if ((cur_fl != cur_ct) && cur_fl) {
|
||||
return false; // full
|
||||
}
|
||||
}
|
||||
// consider rem_cc to be 0 here
|
||||
if (el->rc_.compare_exchange_weak(
|
||||
cur_rc, inc_mask(epoch | (cur_rc & ep_mask)) | static_cast<rc_t>(cc), std::memory_order_relaxed) &&
|
||||
epoch_.compare_exchange_weak(epoch, epoch, std::memory_order_acq_rel)) {
|
||||
break;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
// only one thread/process would touch here at one time
|
||||
ct_.store(cur_ct + 1, std::memory_order_release);
|
||||
std::forward<F>(f)(&(el->data_));
|
||||
// set flag & try update wt
|
||||
el->f_ct_.store(~static_cast<flag_t>(cur_ct), std::memory_order_release);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename E>
|
||||
bool force_push(W* wrapper, F&& f, E* elems) {
|
||||
E* el;
|
||||
circ::u2_t cur_ct;
|
||||
rc_t epoch = epoch_.fetch_add(ep_incr, std::memory_order_release) + ep_incr;
|
||||
for (unsigned k = 0;;) {
|
||||
circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
|
||||
if (cc == 0) return false; // no reader
|
||||
el = elems + circ::index_of(cur_ct = ct_.load(std::memory_order_relaxed));
|
||||
// check all consumers have finished reading this element
|
||||
auto cur_rc = el->rc_.load(std::memory_order_acquire);
|
||||
circ::cc_t rem_cc = cur_rc & rc_mask;
|
||||
if (cc & rem_cc) {
|
||||
ipc::log("force_push: k = %u, cc = %u, rem_cc = %u\n", k, cc, rem_cc);
|
||||
cc = wrapper->elems()->disconnect_receiver(rem_cc); // disconnect all invalid readers
|
||||
if (cc == 0) return false; // no reader
|
||||
}
|
||||
// just compare & exchange
|
||||
if (el->rc_.compare_exchange_weak(
|
||||
cur_rc, inc_mask(epoch | (cur_rc & ep_mask)) | static_cast<rc_t>(cc), std::memory_order_relaxed)) {
|
||||
if (epoch == epoch_.load(std::memory_order_acquire)) {
|
||||
break;
|
||||
}
|
||||
else if (push(wrapper, std::forward<F>(f), elems)) {
|
||||
return true;
|
||||
}
|
||||
epoch = epoch_.fetch_add(ep_incr, std::memory_order_release) + ep_incr;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
// only one thread/process would touch here at one time
|
||||
ct_.store(cur_ct + 1, std::memory_order_release);
|
||||
std::forward<F>(f)(&(el->data_));
|
||||
// set flag & try update wt
|
||||
el->f_ct_.store(~static_cast<flag_t>(cur_ct), std::memory_order_release);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename W, typename F, typename R, typename E, std::size_t N>
|
||||
bool pop(W* wrapper, circ::u2_t& cur, F&& f, R&& out, E(& elems)[N]) {
|
||||
auto* el = elems + circ::index_of(cur);
|
||||
auto cur_fl = el->f_ct_.load(std::memory_order_acquire);
|
||||
if (cur_fl != ~static_cast<flag_t>(cur)) {
|
||||
return false; // empty
|
||||
}
|
||||
++cur;
|
||||
std::forward<F>(f)(&(el->data_));
|
||||
for (unsigned k = 0;;) {
|
||||
auto cur_rc = el->rc_.load(std::memory_order_acquire);
|
||||
if ((cur_rc & rc_mask) == 0) {
|
||||
std::forward<R>(out)(true);
|
||||
el->f_ct_.store(cur + N - 1, std::memory_order_release);
|
||||
return true;
|
||||
}
|
||||
auto nxt_rc = inc_rc(cur_rc) & ~static_cast<rc_t>(wrapper->connected_id());
|
||||
bool last_one = false;
|
||||
if ((last_one = (nxt_rc & rc_mask) == 0)) {
|
||||
el->f_ct_.store(cur + N - 1, std::memory_order_release);
|
||||
}
|
||||
if (el->rc_.compare_exchange_weak(cur_rc, nxt_rc, std::memory_order_release)) {
|
||||
std::forward<R>(out)(last_one);
|
||||
return true;
|
||||
}
|
||||
ipc::yield(k);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ipc
|
||||
@ -1,58 +0,0 @@
|
||||
The goal of reducing sequential computation also forms the foundation of the Extended Neural GPU \citep{extendedngpu}, ByteNet \citep{NalBytenet2017} and ConvS2S \citep{JonasFaceNet2017}, all of which use convolutional neural networks as basic building block, computing hidden representations in parallel for all input and output positions. In these models, the number of operations required to relate signals from two arbitrary input or output positions grows in the distance between positions, linearly for ConvS2S and logarithmically for ByteNet. This makes it more difficult to learn dependencies between distant positions \citep{hochreiter2001gradient}. In the Transformer this is reduced to a constant number of operations, albeit at the cost of reduced effective resolution due to averaging attention-weighted positions, an effect we counteract with Multi-Head Attention as described in section~\ref{sec:attention}.
|
||||
|
||||
Self-attention, sometimes called intra-attention is an attention mechanism relating different positions of a single sequence in order to compute a representation of the sequence. Self-attention has been used successfully in a variety of tasks including reading comprehension, abstractive summarization, textual entailment and learning task-independent sentence representations \citep{cheng2016long, decomposableAttnModel, paulus2017deep, lin2017structured}.
|
||||
|
||||
End-to-end memory networks are based on a recurrent attention mechanism instead of sequence-aligned recurrence and have been shown to perform well on simple-language question answering and language modeling tasks \citep{sukhbaatar2015}.
|
||||
|
||||
To the best of our knowledge, however, the Transformer is the first transduction model relying entirely on self-attention to compute representations of its input and output without using sequence-aligned RNNs or convolution.
|
||||
In the following sections, we will describe the Transformer, motivate self-attention and discuss its advantages over models such as \citep{neural_gpu, NalBytenet2017} and \citep{JonasFaceNet2017}.
|
||||
|
||||
|
||||
%\citep{JonasFaceNet2017} report new SOTA on machine translation for English-to-German (EnDe), Enlish-to-French (EnFr) and English-to-Romanian language pairs.
|
||||
|
||||
%For example,! in MT, we must draw information from both input and previous output words to translate an output word accurately. An attention layer \citep{bahdanau2014neural} can connect a very large number of positions at low computation cost, making it an essential ingredient in competitive recurrent models for machine translation.
|
||||
|
||||
%A natural question to ask then is, "Could we replace recurrence with attention?". \marginpar{Don't know if it's the most natural question to ask given the previous statements. Also, need to say that the complexity table summarizes these statements} Such a model would be blessed with the computational efficiency of attention and the power of cross-positional communication. In this work, show that pure attention models work remarkably well for MT, achieving new SOTA results on EnDe and EnFr, and can be trained in under $2$ days on xyz architecture.
|
||||
|
||||
%After the seminal models introduced in \citep{sutskever14, bahdanau2014neural, cho2014learning}, recurrent models have become the dominant solution for both sequence modeling and sequence-to-sequence transduction. Many efforts such as \citep{wu2016google,luong2015effective,jozefowicz2016exploring} have pushed the boundaries of machine translation (MT) and language modeling with recurrent endoder-decoder and recurrent language models. Recent effort \citep{shazeer2017outrageously} has successfully combined the power of conditional computation with sequence models to train very large models for MT, pushing SOTA at lower computational cost.
|
||||
|
||||
%Recurrent models compute a vector of hidden states $h_t$, for each time step $t$ of computation. $h_t$ is a function of both the input at time $t$ and the previous hidden state $h_t$. This dependence on the previous hidden state precludes processing all timesteps at once, instead requiring long sequences of sequential operations. In practice, this results in greatly reduced computational efficiency, as on modern computing hardware, a single operation on a large batch is much faster than a large number of operations on small batches. The problem gets worse at longer sequence lengths. Although sequential computation is not a severe bottleneck at inference time, as autoregressively generating each output requires all previous outputs, the inability to compute scores at all output positions at once hinders us from rapidly training our models over large datasets. Although impressive work such as \citep{Kuchaiev2017Factorization} is able to significantly accelerate the training of LSTMs with factorization tricks, we are still bound by the linear dependence on sequence length.
|
||||
|
||||
%If the model could compute hidden states at each time step using only the inputs and outputs, it would be liberated from the dependence on results from previous time steps during training. This line of thought is the foundation of recent efforts such as the Markovian neural GPU \citep{neural_gpu}, ByteNet \citep{NalBytenet2017} and ConvS2S \citep{JonasFaceNet2017}, all of which use convolutional neural networks as a building block to compute hidden representations simultaneously for all timesteps, resulting in $O(1)$ sequential time complexity. \citep{JonasFaceNet2017} report new SOTA on machine translation for English-to-German (EnDe), Enlish-to-French (EnFr) and English-to-Romanian language pairs.
|
||||
|
||||
%A crucial component for accurate sequence prediction is modeling cross-positional communication. For example, in MT, we must draw information from both input and previous output words to translate an output word accurately. An attention layer \citep{bahdanau2014neural} can connect a very large number of positions at a low computation cost, also $O(1)$ sequential time complexity, making it an essential ingredient in recurrent encoder-decoder architectures for MT. A natural question to ask then is, "Could we replace recurrence with attention?". \marginpar{Don't know if it's the most natural question to ask given the previous statements. Also, need to say that the complexity table summarizes these statements} Such a model would be blessed with the computational efficiency of attention and the power of cross-positional communication. In this work, show that pure attention models work remarkably well for MT, achieving new SOTA results on EnDe and EnFr, and can be trained in under $2$ days on xyz architecture.
|
||||
|
||||
|
||||
|
||||
%Note: Facebook model is no better than RNNs in this regard, since it requires a number of layers proportional to the distance you want to communicate. Bytenet is more promising, since it requires a logarithmnic number of layers (does bytenet have SOTA results)?
|
||||
|
||||
%Note: An attention layer can connect a very large number of positions at a low computation cost in O(1) sequential operations. This is why encoder-decoder attention has been so successful in seq-to-seq models so far. It is only natural, then, to also use attention to connect the timesteps of the same sequence.
|
||||
|
||||
%Note: I wouldn't say that long sequences are not a problem during inference. It would be great if we could infer with no long sequences. We could just say later on that, while our training graph is constant-depth, our model still requires sequential operations in the decoder part during inference due to the autoregressive nature of the model.
|
||||
|
||||
%\begin{table}[h!]
|
||||
%\caption{Attention models are quite efficient for cross-positional communications when sequence length is smaller than channel depth. $n$ represents the sequence length and $d$ represents the channel depth.}
|
||||
%\label{tab:op_complexities}
|
||||
%\begin{center}
|
||||
%\vspace{-5pt}
|
||||
%\scalebox{0.75}{
|
||||
|
||||
%\begin{tabular}{l|c|c|c}
|
||||
%\hline \hline
|
||||
%Layer Type & Receptive & Complexity & Sequential \\
|
||||
% & Field & & Operations \\
|
||||
%\hline
|
||||
%Pointwise Feed-Forward & $1$ & $O(n \cdot d^2)$ & $O(1)$ \\
|
||||
%\hline
|
||||
%Recurrent & $n$ & $O(n \cdot d^2)$ & $O(n)$ \\
|
||||
%\hline
|
||||
%Convolutional & $r$ & $O(r \cdot n \cdot d^2)$ & $O(1)$ \\
|
||||
%\hline
|
||||
%Convolutional (separable) & $r$ & $O(r \cdot n \cdot d + n %\cdot d^2)$ & $O(1)$ \\
|
||||
%\hline
|
||||
%Attention & $r$ & $O(r \cdot n \cdot d)$ & $O(1)$ \\
|
||||
%\hline \hline
|
||||
%\end{tabular}
|
||||
%}
|
||||
%\end{center}
|
||||
%\end{table}
|
||||
@ -1,18 +0,0 @@
|
||||
Recurrent neural networks, long short-term memory \citep{hochreiter1997} and gated recurrent \citep{gruEval14} neural networks in particular, have been firmly established as state of the art approaches in sequence modeling and transduction problems such as language modeling and machine translation \citep{sutskever14, bahdanau2014neural, cho2014learning}. Numerous efforts have since continued to push the boundaries of recurrent language models and encoder-decoder architectures \citep{wu2016google,luong2015effective,jozefowicz2016exploring}.
|
||||
|
||||
Recurrent models typically factor computation along the symbol positions of the input and output sequences. Aligning the positions to steps in computation time, they generate a sequence of hidden states $h_t$, as a function of the previous hidden state $h_{t-1}$ and the input for position $t$. This inherently sequential nature precludes parallelization within training examples, which becomes critical at longer sequence lengths, as memory constraints limit batching across examples.
|
||||
%\marginpar{not sure if the memory constraints are understandable here}
|
||||
Recent work has achieved significant improvements in computational efficiency through factorization tricks \citep{Kuchaiev2017Factorization} and conditional computation \citep{shazeer2017outrageously}, while also improving model performance in case of the latter. The fundamental constraint of sequential computation, however, remains.
|
||||
|
||||
%\marginpar{@all: there is work on analyzing what attention really does in seq2seq models, couldn't find it right away}
|
||||
|
||||
Attention mechanisms have become an integral part of compelling sequence modeling and transduction models in various tasks, allowing modeling of dependencies without regard to their distance in the input or output sequences \citep{bahdanau2014neural, structuredAttentionNetworks}. In all but a few cases \citep{decomposableAttnModel}, however, such attention mechanisms are used in conjunction with a recurrent network.
|
||||
|
||||
%\marginpar{not sure if "cross-positional communication" is understandable without explanation}
|
||||
%\marginpar{insert exact training times and stats for the model that reaches sota earliest, maybe even a single GPU model?}
|
||||
|
||||
In this work we propose the Transformer, a model architecture eschewing recurrence and instead relying entirely on an attention mechanism to draw global dependencies between input and output. The Transformer allows for significantly more parallelization and can reach a new state of the art in translation quality after being trained for as little as twelve hours on eight P100 GPUs.
|
||||
%\marginpar{you removed the constant number of repetitions part. I wrote it because I wanted to make it clear that the model does not only perform attention once, while it's also not recurrent. I thought that might be important to get across early.}
|
||||
|
||||
% Just a standard paragraph with citations, rewrite.
|
||||
%After the seminal papers of \citep{sutskever14}, \citep{bahdanau2014neural}, and \citep{cho2014learning}, recurrent models have become the dominant solution for both sequence modeling and sequence-to-sequence transduction. Many efforts such as \citep{wu2016google,luong2015effective,jozefowicz2016exploring} have pushed the boundaries of machine translation and language modeling with recurrent sequence models. Recent effort \citep{shazeer2017outrageously} has combined the power of conditional computation with sequence models to train very large models for machine translation, pushing SOTA at lower computational cost. Recurrent models compute a vector of hidden states $h_t$, for each time step $t$ of computation. $h_t$ is a function of both the input at time $t$ and the previous hidden state $h_t$. This dependence on the previous hidden state encumbers recurrnet models to process multiple inputs at once, and their time complexity is a linear function of the length of the input and output, both during training and inference. [What I want to say here is that although this is fine during decoding, at training time, we are given both input and output and this linear nature does not allow the RNN to process all inputs and outputs simultaneously and haven't been used on datasets that are the of the scale of the web. What's the largest dataset we have ? . Talk about Nividia and possibly other's effors to speed up things, and possibly other efforts that alleviate this, but are still limited by it's comptuational nature]. Rest of the intro: What if you could construct the state based on the actual inputs and outputs, then you could construct them all at once. This has been the foundation of many promising recent efforts, bytenet,facenet (Also talk about quasi rnn here). Now we talk about attention!! Along with cell architectures such as long short-term meory (LSTM) \citep{hochreiter1997}, and gated recurrent units (GRUs) \citep{cho2014learning}, attention has emerged as an essential ingredient in successful sequence models, in particular for machine translation. In recent years, many, if not all, state-of-the-art (SOTA) results in machine translation have been achieved with attention-based sequence models \citep{wu2016google,luong2015effective,jozefowicz2016exploring}. Talk about the neon work on how it played with attention to do self attention! Then talk about what we do.
|
||||
@ -1,155 +0,0 @@
|
||||
|
||||
\begin{figure}
|
||||
\centering
|
||||
\includegraphics[scale=0.6]{Figures/ModalNet-21}
|
||||
\caption{The Transformer - model architecture.}
|
||||
\label{fig:model-arch}
|
||||
\end{figure}
|
||||
|
||||
% Although the primary workhorse of our model is attention,
|
||||
%Our model maintains the encoder-decoder structure that is common to many so-called sequence-to-sequence models \citep{bahdanau2014neural,sutskever14}. As in all such architectures, the encoder computes a representation of the input sequence, and the decoder consumes these representations along with the output tokens to autoregressively produce the output sequence. Where, traditionally, the encoder and decoder contain stacks of recurrent or convolutional layers, our encoder and decoder stacks are composed of attention layers and position-wise feed-forward layers (Figure~\ref{fig:model-arch}). The following sections describe the gross architecture and these particular components in detail.
|
||||
|
||||
Most competitive neural sequence transduction models have an encoder-decoder structure \citep{cho2014learning,bahdanau2014neural,sutskever14}. Here, the encoder maps an input sequence of symbol representations $(x_1, ..., x_n)$ to a sequence of continuous representations $\mathbf{z} = (z_1, ..., z_n)$. Given $\mathbf{z}$, the decoder then generates an output sequence $(y_1,...,y_m)$ of symbols one element at a time. At each step the model is auto-regressive \citep{graves2013generating}, consuming the previously generated symbols as additional input when generating the next.
|
||||
|
||||
The Transformer follows this overall architecture using stacked self-attention and point-wise, fully connected layers for both the encoder and decoder, shown in the left and right halves of Figure~\ref{fig:model-arch}, respectively.
|
||||
|
||||
\subsection{Encoder and Decoder Stacks}
|
||||
|
||||
\paragraph{Encoder:}The encoder is composed of a stack of $N=6$ identical layers. Each layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position-wise fully connected feed-forward network. We employ a residual connection \citep{he2016deep} around each of the two sub-layers, followed by layer normalization \cite{layernorm2016}. That is, the output of each sub-layer is $\mathrm{LayerNorm}(x + \mathrm{Sublayer}(x))$, where $\mathrm{Sublayer}(x)$ is the function implemented by the sub-layer itself. To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension $\dmodel=512$.
|
||||
|
||||
\paragraph{Decoder:}The decoder is also composed of a stack of $N=6$ identical layers. In addition to the two sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head attention over the output of the encoder stack. Similar to the encoder, we employ residual connections around each of the sub-layers, followed by layer normalization. We also modify the self-attention sub-layer in the decoder stack to prevent positions from attending to subsequent positions. This masking, combined with fact that the output embeddings are offset by one position, ensures that the predictions for position $i$ can depend only on the known outputs at positions less than $i$.
|
||||
|
||||
% In our model (Figure~\ref{fig:model-arch}), the encoder and decoder are composed of stacks of alternating self-attention layers (for cross-positional communication) and position-wise feed-forward layers (for in-place computation). In addition, the decoder stack contains encoder-decoder attention layers. Since attention is agnostic to the distances between words, our model requires a "positional encoding" to be added to the encoder and decoder input. The following sections describe all of these components in detail.
|
||||
|
||||
\subsection{Attention} \label{sec:attention}
|
||||
An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.
|
||||
|
||||
\subsubsection{Scaled Dot-Product Attention} \label{sec:scaled-dot-prod}
|
||||
|
||||
% \begin{figure}
|
||||
% \centering
|
||||
% \includegraphics[scale=0.6]{Figures/ModalNet-19}
|
||||
% \caption{Scaled Dot-Product Attention.}
|
||||
% \label{fig:multi-head-att}
|
||||
% \end{figure}
|
||||
|
||||
We call our particular attention "Scaled Dot-Product Attention" (Figure~\ref{fig:multi-head-att}). The input consists of queries and keys of dimension $d_k$, and values of dimension $d_v$. We compute the dot products of the query with all keys, divide each by $\sqrt{d_k}$, and apply a softmax function to obtain the weights on the values.
|
||||
|
||||
In practice, we compute the attention function on a set of queries simultaneously, packed together into a matrix $Q$. The keys and values are also packed together into matrices $K$ and $V$. We compute the matrix of outputs as:
|
||||
|
||||
\begin{equation}
|
||||
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V
|
||||
\end{equation}
|
||||
|
||||
The two most commonly used attention functions are additive attention \citep{bahdanau2014neural}, and dot-product (multiplicative) attention. Dot-product attention is identical to our algorithm, except for the scaling factor of $\frac{1}{\sqrt{d_k}}$. Additive attention computes the compatibility function using a feed-forward network with a single hidden layer. While the two are similar in theoretical complexity, dot-product attention is much faster and more space-efficient in practice, since it can be implemented using highly optimized matrix multiplication code.
|
||||
|
||||
%We scale the dot products by $1/\sqrt{d_k}$ to limit the magnitude of the dot products, which works well in practice. Otherwise, we found applying the softmax to often result in weights very close to 0 or 1, and hence minuscule gradients.
|
||||
|
||||
% Already described in the subsequent section
|
||||
%When used as part of decoder self-attention, an optional mask function is applied just before the softmax to prevent positions from attending to subsequent positions. This mask simply sets the logits corresponding to all illegal connections (those outside of the lower triangle) to $-\infty$.
|
||||
|
||||
%\paragraph{Comparison to Additive Attention: } We choose dot product attention over additive attention \citep{bahdanau2014neural} since it can be computed using highly optimized matrix multiplication code. This optimization is particularly important to us, as we employ many attention layers in our model.
|
||||
|
||||
While for small values of $d_k$ the two mechanisms perform similarly, additive attention outperforms dot product attention without scaling for larger values of $d_k$ \citep{DBLP:journals/corr/BritzGLL17}. We suspect that for large values of $d_k$, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients \footnote{To illustrate why the dot products get large, assume that the components of $q$ and $k$ are independent random variables with mean $0$ and variance $1$. Then their dot product, $q \cdot k = \sum_{i=1}^{d_k} q_ik_i$, has mean $0$ and variance $d_k$.}. To counteract this effect, we scale the dot products by $\frac{1}{\sqrt{d_k}}$.
|
||||
|
||||
|
||||
%We suspect this to be caused by the dot products growing too large in magnitude to result in useful gradients after applying the softmax function. To counteract this, we scale the dot product by $1/\sqrt{d_k}$.
|
||||
|
||||
|
||||
\subsubsection{Multi-Head Attention} \label{sec:multihead}
|
||||
|
||||
\begin{figure}
|
||||
\begin{minipage}[t]{0.5\textwidth}
|
||||
\centering
|
||||
Scaled Dot-Product Attention \\
|
||||
\vspace{0.5cm}
|
||||
\includegraphics[scale=0.6]{Figures/ModalNet-19}
|
||||
\end{minipage}
|
||||
\begin{minipage}[t]{0.5\textwidth}
|
||||
\centering
|
||||
Multi-Head Attention \\
|
||||
\vspace{0.1cm}
|
||||
\includegraphics[scale=0.6]{Figures/ModalNet-20}
|
||||
\end{minipage}
|
||||
|
||||
|
||||
% \centering
|
||||
|
||||
\caption{(left) Scaled Dot-Product Attention. (right) Multi-Head Attention consists of several attention layers running in parallel.}
|
||||
\label{fig:multi-head-att}
|
||||
\end{figure}
|
||||
|
||||
Instead of performing a single attention function with $\dmodel$-dimensional keys, values and queries, we found it beneficial to linearly project the queries, keys and values $h$ times with different, learned linear projections to $d_k$, $d_k$ and $d_v$ dimensions, respectively.
|
||||
On each of these projected versions of queries, keys and values we then perform the attention function in parallel, yielding $d_v$-dimensional output values. These are concatenated and once again projected, resulting in the final values, as depicted in Figure~\ref{fig:multi-head-att}.
|
||||
|
||||
Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.
|
||||
|
||||
\begin{align*}
|
||||
\mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}(\mathrm{head_1}, ..., \mathrm{head_h})W^O\\
|
||||
% \mathrm{where} \mathrm{head_i} &= \mathrm{Attention}(QW_Q_i^{\dmodel \times d_q}, KW_K_i^{\dmodel \times d_k}, VW^V_i^{\dmodel \times d_v})\\
|
||||
\text{where}~\mathrm{head_i} &= \mathrm{Attention}(QW^Q_i, KW^K_i, VW^V_i)\\
|
||||
\end{align*}
|
||||
|
||||
Where the projections are parameter matrices $W^Q_i \in \mathbb{R}^{\dmodel \times d_k}$, $W^K_i \in \mathbb{R}^{\dmodel \times d_k}$, $W^V_i \in \mathbb{R}^{\dmodel \times d_v}$ and $W^O \in \mathbb{R}^{hd_v \times \dmodel}$.
|
||||
|
||||
|
||||
%find it better (and no more expensive) to have multiple parallel attention layers (each over the full set of positions) with proportionally lower-dimensional keys, values and queries. We call this "Multi-Head Attention" (Figure~\ref{fig:multi-head-att}). The keys, values, and queries for each of these parallel attention layers are computed by learned linear transformations of the inputs to the multi-head attention. We use different linear transformations across different parallel attention layers. The output of the parallel attention layers are concatenated, and then passed through a final learned linear transformation.
|
||||
|
||||
In this work we employ $h=8$ parallel attention layers, or heads. For each of these we use $d_k=d_v=\dmodel/h=64$.
|
||||
Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.
|
||||
|
||||
\subsubsection{Applications of Attention in our Model}
|
||||
|
||||
The Transformer uses multi-head attention in three different ways:
|
||||
\begin{itemize}
|
||||
\item In "encoder-decoder attention" layers, the queries come from the previous decoder layer, and the memory keys and values come from the output of the encoder. This allows every position in the decoder to attend over all positions in the input sequence. This mimics the typical encoder-decoder attention mechanisms in sequence-to-sequence models such as \citep{wu2016google, bahdanau2014neural,JonasFaceNet2017}.
|
||||
|
||||
\item The encoder contains self-attention layers. In a self-attention layer all of the keys, values and queries come from the same place, in this case, the output of the previous layer in the encoder. Each position in the encoder can attend to all positions in the previous layer of the encoder.
|
||||
|
||||
\item Similarly, self-attention layers in the decoder allow each position in the decoder to attend to all positions in the decoder up to and including that position. We need to prevent leftward information flow in the decoder to preserve the auto-regressive property. We implement this inside of scaled dot-product attention by masking out (setting to $-\infty$) all values in the input of the softmax which correspond to illegal connections. See Figure~\ref{fig:multi-head-att}.
|
||||
|
||||
\end{itemize}
|
||||
|
||||
\subsection{Position-wise Feed-Forward Networks}\label{sec:ffn}
|
||||
|
||||
In addition to attention sub-layers, each of the layers in our encoder and decoder contains a fully connected feed-forward network, which is applied to each position separately and identically. This consists of two linear transformations with a ReLU activation in between.
|
||||
|
||||
\begin{equation}
|
||||
\mathrm{FFN}(x)=\max(0, xW_1 + b_1) W_2 + b_2
|
||||
\end{equation}
|
||||
|
||||
While the linear transformations are the same across different positions, they use different parameters from layer to layer. Another way of describing this is as two convolutions with kernel size 1. The dimensionality of input and output is $\dmodel=512$, and the inner-layer has dimensionality $d_{ff}=2048$.
|
||||
|
||||
|
||||
|
||||
%In the appendix, we describe how the position-wise feed-forward network can also be seen as a form of attention.
|
||||
|
||||
%from Jakob: The number of operations required for the model to relate signals from two arbitrary input or output positions grows in the distance between positions in input or output, linearly for ConvS2S and logarithmically for ByteNet, making it harder to learn dependencies between these positions \citep{hochreiter2001gradient}. In the transformer this is reduced to a constant number of operations, albeit at the cost of effective resolution caused by averaging attention-weighted positions, an effect we aim to counteract with multi-headed attention.
|
||||
|
||||
|
||||
%Figure~\ref{fig:simple-att} presents a simple attention function, $A$, with a single head, that forms the basis of our multi-head attention. $A$ takes a query key vector $\kq$, matrices of memory keys $\km$ and memory values $\vm$ ,and produces a query value vector $\vq$ as
|
||||
%\begin{equation*} \label{eq:attention}
|
||||
% A(\kq, \km, \vm) = {\vm}^T (Softmax(\km \kq).
|
||||
%\end{equation*}
|
||||
%We linearly transform $\kq,\,\km$, and $\vm$ with learned matrices ${\Wkq \text{,} \, \Wkm}$, and ${\Wvm}$ before calling the attention function, and transform the output query with $\Wvq$ before handing it to the feed forward layer. Each attention layer has it's own set of transformation matrices, which are shared across all query positions. $A$ is applied in parallel for each query position, and is implemented very efficiently as a batch of matrix multiplies. The self-attention and encoder-decoder attention layers use $A$, but with different arguments. For example, in encdoder self-attention, queries in encoder layer $i$ attention to memories in encoder layer $i-1$. To ensure that decoder self-attention layers do not look at future words, we add $- \inf$ to the softmax logits in positions $j+1$ to query length for query position $l$.
|
||||
|
||||
%In simple attention, the query value is a weighted combination of the memory values where the attention weights sum to one. Although this function performs well in practice, the constraint on attention weights can restrict the amount of information that flows from memories to queries because the query cannot focus on multiple memory positions at once, which might be desirable when translating long sequences. \marginpar{@usz, could you think of an example of this ?} We remedy this by maintaining multiple attention heads at each query position that attend to all memory positions in parallel, with a different set of parameters per attention head $h$.
|
||||
%\marginpar{}
|
||||
|
||||
\subsection{Embeddings and Softmax}
|
||||
Similarly to other sequence transduction models, we use learned embeddings to convert the input tokens and output tokens to vectors of dimension $\dmodel$. We also use the usual learned linear transformation and softmax function to convert the decoder output to predicted next-token probabilities. In our model, we share the same weight matrix between the two embedding layers and the pre-softmax linear transformation, similar to \citep{press2016using}. In the embedding layers, we multiply those weights by $\sqrt{\dmodel}$.
|
||||
|
||||
|
||||
\subsection{Positional Encoding}
|
||||
Since our model contains no recurrence and no convolution, in order for the model to make use of the order of the sequence, we must inject some information about the relative or absolute position of the tokens in the sequence. To this end, we add "positional encodings" to the input embeddings at the bottoms of the encoder and decoder stacks. The positional encodings have the same dimension $\dmodel$ as the embeddings, so that the two can be summed. There are many choices of positional encodings, learned and fixed \citep{JonasFaceNet2017}.
|
||||
|
||||
In this work, we use sine and cosine functions of different frequencies:
|
||||
|
||||
\begin{align*}
|
||||
PE_{(pos,2i)} = sin(pos / 10000^{2i/\dmodel}) \\
|
||||
PE_{(pos,2i+1)} = cos(pos / 10000^{2i/\dmodel})
|
||||
\end{align*}
|
||||
|
||||
where $pos$ is the position and $i$ is the dimension. That is, each dimension of the positional encoding corresponds to a sinusoid. The wavelengths form a geometric progression from $2\pi$ to $10000 \cdot 2\pi$. We chose this function because we hypothesized it would allow the model to easily learn to attend by relative positions, since for any fixed offset $k$, $PE_{pos+k}$ can be represented as a linear function of $PE_{pos}$.
|
||||
|
||||
We also experimented with using learned positional embeddings \citep{JonasFaceNet2017} instead, and found that the two versions produced nearly identical results (see Table~\ref{tab:variations} row (E)). We chose the sinusoidal version because it may allow the model to extrapolate to sequence lengths longer than the ones encountered during training.
|
||||
@ -1,45 +0,0 @@
|
||||
\pagebreak
|
||||
\section*{Two Feed-Forward Layers = Attention over Parameters}\label{sec:parameter_attention}
|
||||
|
||||
In addition to attention layers, our model contains position-wise feed-forward networks (Section \ref{sec:ffn}), which consist of two linear transformations with a ReLU activation in between. In fact, these networks too can be seen as a form of attention. Compare the formula for such a network with the formula for a simple dot-product attention layer (biases and scaling factors omitted):
|
||||
|
||||
\begin{align*}
|
||||
FFN(x, W_1, W_2) = ReLU(xW_1)W_2 \\
|
||||
A(q, K, V) = Softmax(qK^T)V
|
||||
\end{align*}
|
||||
|
||||
Based on the similarity of these formulae, the two-layer feed-forward network can be seen as a kind of attention, where the keys and values are the rows of the trainable parameter matrices $W_1$ and $W_2$, and where we use ReLU instead of Softmax in the compatibility function.
|
||||
|
||||
%the compatablity function is $compat(q, k_i) = ReLU(q \cdot k_i)$ instead of $Softmax(qK_T)_i$.
|
||||
|
||||
Given this similarity, we experimented with replacing the position-wise feed-forward networks with attention layers similar to the ones we use everywhere else our model. The multi-head-attention-over-parameters sublayer is identical to the multi-head attention described in \ref{sec:multihead}, except that the "keys" and "values" inputs to each attention head are trainable model parameters, as opposed to being linear projections of a previous layer. These parameters are scaled up by a factor of $\sqrt{d_{model}}$ in order to be more similar to activations.
|
||||
|
||||
In our first experiment, we replaced each position-wise feed-forward network with a multi-head-attention-over-parameters sublayer with $h_p=8$ heads, key-dimensionality $d_{pk}=64$, and value-dimensionality $d_{pv}=64$, using $n_p=1536$ key-value pairs for each attention head. The sublayer has a total of $2097152$ parameters, including the parameters in the query projection and the output projection. This matches the number of parameters in the position-wise feed-forward network that we replaced. While the theoretical amount of computation is also the same, in practice, the attention version caused the step times to be about 30\% longer.
|
||||
|
||||
In our second experiment, we used $h_p=8$ heads, and $n_p=512$ key-value pairs for each attention head, again matching the total number of parameters in the base model.
|
||||
|
||||
Results for the first experiment were slightly worse than for the base model, and results for the second experiment were slightly better, see Table~\ref{tab:parameter_attention}.
|
||||
|
||||
\begin{table}[h]
|
||||
\caption{Replacing the position-wise feed-forward networks with multihead-attention-over-parameters produces similar results to the base model. All metrics are on the English-to-German translation development set, newstest2013.}
|
||||
\label{tab:parameter_attention}
|
||||
\begin{center}
|
||||
\vspace{-2mm}
|
||||
%\scalebox{1.0}{
|
||||
\begin{tabular}{c|cccccc|cccc}
|
||||
\hline\rule{0pt}{2.0ex}
|
||||
& \multirow{2}{*}{$\dmodel$} & \multirow{2}{*}{$\dff$} &
|
||||
\multirow{2}{*}{$h_p$} & \multirow{2}{*}{$d_{pk}$} & \multirow{2}{*}{$d_{pv}$} &
|
||||
\multirow{2}{*}{$n_p$} &
|
||||
PPL & BLEU & params & training\\
|
||||
& & & & & & & (dev) & (dev) & $\times10^6$ & time \\
|
||||
\hline\rule{0pt}{2.0ex}
|
||||
base & 512 & 2048 & & & & & 4.92 & 25.8 & 65 & 12 hours\\
|
||||
\hline\rule{0pt}{2.0ex}
|
||||
AOP$_1$ & 512 & & 8 & 64 & 64 & 1536 & 4.92& 25.5 & 65 & 16 hours\\
|
||||
AOP$_2$ & 512 & & 16 & 64 & 64 & 512 & \textbf{4.86} & \textbf{25.9} & 65 & 16 hours \\
|
||||
\hline
|
||||
\end{tabular}
|
||||
%}
|
||||
\end{center}
|
||||
\end{table}
|
||||
@ -1,8 +0,0 @@
|
||||
chatgpt的老祖宗《Attention is all you need》
|
||||
|
||||
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin
|
||||
|
||||
真实的摘要如下
|
||||
The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.
|
||||
|
||||
https://arxiv.org/abs/1706.03762
|
||||
@ -1,2 +0,0 @@
|
||||
from stable_baselines3.dqn.dqn import DQN
|
||||
from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy
|
||||
@ -1,245 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from torch.nn import functional as F
|
||||
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.preprocessing import maybe_transpose
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
|
||||
from stable_baselines3.dqn.policies import DQNPolicy
|
||||
|
||||
|
||||
class DQN(OffPolicyAlgorithm):
|
||||
"""
|
||||
Deep Q-Network (DQN)
|
||||
|
||||
Paper: https://arxiv.org/abs/1312.5602, https://www.nature.com/articles/nature14236
|
||||
Default hyperparameters are taken from the nature paper,
|
||||
except for the optimizer and learning rate that were taken from Stable Baselines defaults.
|
||||
|
||||
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||
:param env: The environment to learn from (if registered in Gym, can be str)
|
||||
:param learning_rate: The learning rate, it can be a function
|
||||
of the current progress remaining (from 1 to 0)
|
||||
:param buffer_size: size of the replay buffer
|
||||
:param learning_starts: how many steps of the model to collect transitions for before learning starts
|
||||
:param batch_size: Minibatch size for each gradient update
|
||||
:param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
|
||||
:param gamma: the discount factor
|
||||
:param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
|
||||
like ``(5, "step")`` or ``(2, "episode")``.
|
||||
:param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
|
||||
Set to ``-1`` means to do as many gradient steps as steps done in the environment
|
||||
during the rollout.
|
||||
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
|
||||
at a cost of more complexity.
|
||||
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
|
||||
:param target_update_interval: update the target network every ``target_update_interval``
|
||||
environment steps.
|
||||
:param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
|
||||
:param exploration_initial_eps: initial value of random action probability
|
||||
:param exploration_final_eps: final value of random action probability
|
||||
:param max_grad_norm: The maximum value for the gradient clipping
|
||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||
:param seed: Seed for the pseudo random generators
|
||||
:param device: Device (cpu, cuda, ...) on which the code should be run.
|
||||
Setting it to auto, the code will be run on the GPU if possible.
|
||||
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: Union[str, Type[DQNPolicy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, Schedule] = 1e-4,
|
||||
buffer_size: int = 1000000,
|
||||
learning_starts: int = 50000,
|
||||
batch_size: Optional[int] = 32,
|
||||
tau: float = 1.0,
|
||||
gamma: float = 0.99,
|
||||
train_freq: Union[int, Tuple[int, str]] = 4,
|
||||
gradient_steps: int = 1,
|
||||
optimize_memory_usage: bool = False,
|
||||
target_update_interval: int = 10000,
|
||||
exploration_fraction: float = 0.1,
|
||||
exploration_initial_eps: float = 1.0,
|
||||
exploration_final_eps: float = 0.05,
|
||||
max_grad_norm: float = 10,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
create_eval_env: bool = False,
|
||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||
verbose: int = 0,
|
||||
seed: Optional[int] = None,
|
||||
device: Union[th.device, str] = "auto",
|
||||
_init_setup_model: bool = True,
|
||||
):
|
||||
|
||||
super(DQN, self).__init__(
|
||||
policy,
|
||||
env,
|
||||
DQNPolicy,
|
||||
learning_rate,
|
||||
buffer_size,
|
||||
learning_starts,
|
||||
batch_size,
|
||||
tau,
|
||||
gamma,
|
||||
train_freq,
|
||||
gradient_steps,
|
||||
action_noise=None, # No action noise
|
||||
policy_kwargs=policy_kwargs,
|
||||
tensorboard_log=tensorboard_log,
|
||||
verbose=verbose,
|
||||
device=device,
|
||||
create_eval_env=create_eval_env,
|
||||
seed=seed,
|
||||
sde_support=False,
|
||||
optimize_memory_usage=optimize_memory_usage,
|
||||
supported_action_spaces=(gym.spaces.Discrete,),
|
||||
)
|
||||
|
||||
self.exploration_initial_eps = exploration_initial_eps
|
||||
self.exploration_final_eps = exploration_final_eps
|
||||
self.exploration_fraction = exploration_fraction
|
||||
self.target_update_interval = target_update_interval
|
||||
self.max_grad_norm = max_grad_norm
|
||||
# "epsilon" for the epsilon-greedy exploration
|
||||
self.exploration_rate = 0.0
|
||||
# Linear schedule will be defined in `_setup_model()`
|
||||
self.exploration_schedule = None
|
||||
self.q_net, self.q_net_target = None, None
|
||||
|
||||
if _init_setup_model:
|
||||
self._setup_model()
|
||||
|
||||
def _setup_model(self) -> None:
|
||||
super(DQN, self)._setup_model()
|
||||
self._create_aliases()
|
||||
self.exploration_schedule = get_linear_fn(
|
||||
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
|
||||
)
|
||||
|
||||
def _create_aliases(self) -> None:
|
||||
self.q_net = self.policy.q_net
|
||||
self.q_net_target = self.policy.q_net_target
|
||||
|
||||
def _on_step(self) -> None:
|
||||
"""
|
||||
Update the exploration rate and target network if needed.
|
||||
This method is called in ``collect_rollouts()`` after each step in the environment.
|
||||
"""
|
||||
if self.num_timesteps % self.target_update_interval == 0:
|
||||
polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau)
|
||||
|
||||
self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
|
||||
logger.record("rollout/exploration rate", self.exploration_rate)
|
||||
|
||||
def train(self, gradient_steps: int, batch_size: int = 100) -> None:
|
||||
# Update learning rate according to schedule
|
||||
self._update_learning_rate(self.policy.optimizer)
|
||||
|
||||
losses = []
|
||||
for _ in range(gradient_steps):
|
||||
# Sample replay buffer
|
||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
||||
|
||||
with th.no_grad():
|
||||
# Compute the next Q-values using the target network
|
||||
next_q_values = self.q_net_target(replay_data.next_observations)
|
||||
# Follow greedy policy: use the one with the highest value
|
||||
next_q_values, _ = next_q_values.max(dim=1)
|
||||
# Avoid potential broadcast issue
|
||||
next_q_values = next_q_values.reshape(-1, 1)
|
||||
# 1-step TD target
|
||||
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
|
||||
|
||||
# Get current Q-values estimates
|
||||
current_q_values = self.q_net(replay_data.observations)
|
||||
|
||||
# Retrieve the q-values for the actions from the replay buffer
|
||||
current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())
|
||||
|
||||
# Compute Huber loss (less sensitive to outliers)
|
||||
loss = F.smooth_l1_loss(current_q_values, target_q_values)
|
||||
losses.append(loss.item())
|
||||
|
||||
# Optimize the policy
|
||||
self.policy.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
# Clip gradient norm
|
||||
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
||||
self.policy.optimizer.step()
|
||||
|
||||
# Increase update counter
|
||||
self._n_updates += gradient_steps
|
||||
|
||||
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
||||
logger.record("train/loss", np.mean(losses))
|
||||
|
||||
def predict(
|
||||
self,
|
||||
observation: np.ndarray,
|
||||
state: Optional[np.ndarray] = None,
|
||||
mask: Optional[np.ndarray] = None,
|
||||
deterministic: bool = False,
|
||||
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
||||
"""
|
||||
Overrides the base_class predict function to include epsilon-greedy exploration.
|
||||
|
||||
:param observation: the input observation
|
||||
:param state: The last states (can be None, used in recurrent policies)
|
||||
:param mask: The last masks (can be None, used in recurrent policies)
|
||||
:param deterministic: Whether or not to return deterministic actions.
|
||||
:return: the model's action and the next state
|
||||
(used in recurrent policies)
|
||||
"""
|
||||
if not deterministic and np.random.rand() < self.exploration_rate:
|
||||
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
|
||||
n_batch = observation.shape[0]
|
||||
action = np.array([self.action_space.sample() for _ in range(n_batch)])
|
||||
else:
|
||||
action = np.array(self.action_space.sample())
|
||||
else:
|
||||
action, state = self.policy.predict(observation, state, mask, deterministic)
|
||||
return action, state
|
||||
|
||||
def learn(
|
||||
self,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
eval_env: Optional[GymEnv] = None,
|
||||
eval_freq: int = -1,
|
||||
n_eval_episodes: int = 5,
|
||||
tb_log_name: str = "DQN",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> OffPolicyAlgorithm:
|
||||
|
||||
return super(DQN, self).learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
log_interval=log_interval,
|
||||
eval_env=eval_env,
|
||||
eval_freq=eval_freq,
|
||||
n_eval_episodes=n_eval_episodes,
|
||||
tb_log_name=tb_log_name,
|
||||
eval_log_path=eval_log_path,
|
||||
reset_num_timesteps=reset_num_timesteps,
|
||||
)
|
||||
|
||||
def _excluded_save_params(self) -> List[str]:
|
||||
return super(DQN, self)._excluded_save_params() + ["q_net", "q_net_target"]
|
||||
|
||||
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
||||
state_dicts = ["policy", "policy.optimizer"]
|
||||
|
||||
return state_dicts, []
|
||||
@ -1,237 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
from torch import nn
|
||||
|
||||
from stable_baselines3.common.policies import BasePolicy, register_policy
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
|
||||
from stable_baselines3.common.type_aliases import Schedule
|
||||
|
||||
|
||||
class QNetwork(BasePolicy):
|
||||
"""
|
||||
Action-Value (Q-Value) network for DQN
|
||||
|
||||
:param observation_space: Observation space
|
||||
:param action_space: Action space
|
||||
:param net_arch: The specification of the policy and value networks.
|
||||
:param activation_fn: Activation function
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
features_extractor: nn.Module,
|
||||
features_dim: int,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
normalize_images: bool = True,
|
||||
):
|
||||
super(QNetwork, self).__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
features_extractor=features_extractor,
|
||||
normalize_images=normalize_images,
|
||||
)
|
||||
|
||||
if net_arch is None:
|
||||
net_arch = [64, 64]
|
||||
|
||||
self.net_arch = net_arch
|
||||
self.activation_fn = activation_fn
|
||||
self.features_extractor = features_extractor
|
||||
self.features_dim = features_dim
|
||||
self.normalize_images = normalize_images
|
||||
action_dim = self.action_space.n # number of actions
|
||||
q_net = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn)
|
||||
self.q_net = nn.Sequential(*q_net)
|
||||
|
||||
def forward(self, obs: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
Predict the q-values.
|
||||
|
||||
:param obs: Observation
|
||||
:return: The estimated Q-Value for each action.
|
||||
"""
|
||||
return self.q_net(self.extract_features(obs))
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
q_values = self.forward(observation)
|
||||
# Greedy action
|
||||
action = q_values.argmax(dim=1).reshape(-1)
|
||||
return action
|
||||
|
||||
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||
data = super()._get_constructor_parameters()
|
||||
|
||||
data.update(
|
||||
dict(
|
||||
net_arch=self.net_arch,
|
||||
features_dim=self.features_dim,
|
||||
activation_fn=self.activation_fn,
|
||||
features_extractor=self.features_extractor,
|
||||
)
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class DQNPolicy(BasePolicy):
|
||||
"""
|
||||
Policy class with Q-Value Net and target net for DQN
|
||||
|
||||
:param observation_space: Observation space
|
||||
:param action_space: Action space
|
||||
:param lr_schedule: Learning rate schedule (could be constant)
|
||||
:param net_arch: The specification of the policy and value networks.
|
||||
:param activation_fn: Activation function
|
||||
:param features_extractor_class: Features extractor to use.
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the features extractor.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super(DQNPolicy, self).__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
features_extractor_class,
|
||||
features_extractor_kwargs,
|
||||
optimizer_class=optimizer_class,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
)
|
||||
|
||||
if net_arch is None:
|
||||
if features_extractor_class == FlattenExtractor:
|
||||
net_arch = [64, 64]
|
||||
else:
|
||||
net_arch = []
|
||||
|
||||
self.net_arch = net_arch
|
||||
self.activation_fn = activation_fn
|
||||
self.normalize_images = normalize_images
|
||||
|
||||
self.net_args = {
|
||||
"observation_space": self.observation_space,
|
||||
"action_space": self.action_space,
|
||||
"net_arch": self.net_arch,
|
||||
"activation_fn": self.activation_fn,
|
||||
"normalize_images": normalize_images,
|
||||
}
|
||||
|
||||
self.q_net, self.q_net_target = None, None
|
||||
self._build(lr_schedule)
|
||||
|
||||
def _build(self, lr_schedule: Schedule) -> None:
|
||||
"""
|
||||
Create the network and the optimizer.
|
||||
|
||||
:param lr_schedule: Learning rate schedule
|
||||
lr_schedule(1) is the initial learning rate
|
||||
"""
|
||||
|
||||
self.q_net = self.make_q_net()
|
||||
self.q_net_target = self.make_q_net()
|
||||
self.q_net_target.load_state_dict(self.q_net.state_dict())
|
||||
|
||||
# Setup optimizer with initial learning rate
|
||||
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
|
||||
def make_q_net(self) -> QNetwork:
|
||||
# Make sure we always have separate networks for features extractors etc
|
||||
net_args = self._update_features_extractor(self.net_args, features_extractor=None)
|
||||
return QNetwork(**net_args).to(self.device)
|
||||
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
return self._predict(obs, deterministic=deterministic)
|
||||
|
||||
def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
return self.q_net._predict(obs, deterministic=deterministic)
|
||||
|
||||
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||
data = super()._get_constructor_parameters()
|
||||
|
||||
data.update(
|
||||
dict(
|
||||
net_arch=self.net_args["net_arch"],
|
||||
activation_fn=self.net_args["activation_fn"],
|
||||
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
||||
optimizer_class=self.optimizer_class,
|
||||
optimizer_kwargs=self.optimizer_kwargs,
|
||||
features_extractor_class=self.features_extractor_class,
|
||||
features_extractor_kwargs=self.features_extractor_kwargs,
|
||||
)
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
MlpPolicy = DQNPolicy
|
||||
|
||||
|
||||
class CnnPolicy(DQNPolicy):
|
||||
"""
|
||||
Policy class for DQN when using images as input.
|
||||
|
||||
:param observation_space: Observation space
|
||||
:param action_space: Action space
|
||||
:param lr_schedule: Learning rate schedule (could be constant)
|
||||
:param net_arch: The specification of the policy and value networks.
|
||||
:param activation_fn: Activation function
|
||||
:param features_extractor_class: Features extractor to use.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super(CnnPolicy, self).__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
lr_schedule,
|
||||
net_arch,
|
||||
activation_fn,
|
||||
features_extractor_class,
|
||||
features_extractor_kwargs,
|
||||
normalize_images,
|
||||
optimizer_class,
|
||||
optimizer_kwargs,
|
||||
)
|
||||
|
||||
|
||||
register_policy("MlpPolicy", MlpPolicy)
|
||||
register_policy("CnnPolicy", CnnPolicy)
|
||||
@ -1,2 +0,0 @@
|
||||
github stablebaseline3
|
||||
https://github.com/DLR-RM/stable-baselines3
|
||||
@ -1,27 +0,0 @@
|
||||
"In practice, we found that a high-entropy initial state is more likely to increase the speed of training.
|
||||
The entropy is calculated by:
|
||||
$$H=-\sum_{k= 1}^{n_k} p(k) \cdot \log p(k), p(k)=\frac{|A_k|}{|\mathcal{A}|}$$
|
||||
where $H$ is the entropy, $|A_k|$ is the number of agent nodes in $k$-th cluster, $|\mathcal{A}|$ is the total number of agents.
|
||||
To ensure the Cooperation Graph initialization has higher entropy,
|
||||
we will randomly generate multiple initial states,
|
||||
rank by their entropy and then pick the one with maximum $H$."
|
||||
|
||||
```
|
||||
FROM ubuntu:latest
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y python3 python3-pip && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN echo '[global]' > /etc/pip.conf && \
|
||||
echo 'index-url = https://mirrors.aliyun.com/pypi/simple/' >> /etc/pip.conf && \
|
||||
echo 'trusted-host = mirrors.aliyun.com' >> /etc/pip.conf
|
||||
|
||||
RUN pip3 install gradio requests[socks] mdtex2html
|
||||
|
||||
COPY . /gpt
|
||||
WORKDIR /gpt
|
||||
|
||||
|
||||
CMD ["python3", "main.py"]
|
||||
```
|
||||
@ -12,7 +12,7 @@ def write_chat_to_file(chatbot, history=None, file_name=None):
|
||||
file_name = 'chatGPT对话历史' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.html'
|
||||
os.makedirs('./gpt_log/', exist_ok=True)
|
||||
with open(f'./gpt_log/{file_name}', 'w', encoding='utf8') as f:
|
||||
from theme import advanced_css
|
||||
from theme.theme import advanced_css
|
||||
f.write(f'<!DOCTYPE html><head><meta charset="utf-8"><title>对话历史</title><style>{advanced_css}</style></head>')
|
||||
for i, contents in enumerate(chatbot):
|
||||
for j, content in enumerate(contents):
|
||||
|
||||
188
crazy_functions/语音助手.py
Normal file
188
crazy_functions/语音助手.py
Normal file
@ -0,0 +1,188 @@
|
||||
from toolbox import update_ui
|
||||
from toolbox import CatchException, get_conf, markdown_convertion
|
||||
from crazy_functions.crazy_utils import input_clipping
|
||||
from request_llm.bridge_all import predict_no_ui_long_connection
|
||||
import threading, time
|
||||
import numpy as np
|
||||
from .live_audio.aliyunASR import AliyunASR
|
||||
import json
|
||||
|
||||
class WatchDog():
|
||||
def __init__(self, timeout, bark_fn, interval=3, msg="") -> None:
|
||||
self.last_feed = None
|
||||
self.timeout = timeout
|
||||
self.bark_fn = bark_fn
|
||||
self.interval = interval
|
||||
self.msg = msg
|
||||
|
||||
def watch(self):
|
||||
while True:
|
||||
if time.time() - self.last_feed > self.timeout:
|
||||
if len(self.msg) > 0: print(self.msg)
|
||||
self.bark_fn()
|
||||
break
|
||||
time.sleep(self.interval)
|
||||
|
||||
def begin_watch(self):
|
||||
self.last_feed = time.time()
|
||||
th = threading.Thread(target=self.watch)
|
||||
th.daemon = True
|
||||
th.start()
|
||||
|
||||
def feed(self):
|
||||
self.last_feed = time.time()
|
||||
|
||||
def chatbot2history(chatbot):
|
||||
history = []
|
||||
for c in chatbot:
|
||||
for q in c:
|
||||
if q not in ["[请讲话]", "[等待GPT响应]", "[正在等您说完问题]"]:
|
||||
history.append(q.strip('<div class="markdown-body">').strip('</div>').strip('<p>').strip('</p>'))
|
||||
return history
|
||||
|
||||
class AsyncGptTask():
|
||||
def __init__(self) -> None:
|
||||
self.observe_future = []
|
||||
self.observe_future_chatbot_index = []
|
||||
|
||||
def gpt_thread_worker(self, i_say, llm_kwargs, history, sys_prompt, observe_window, index):
|
||||
try:
|
||||
MAX_TOKEN_ALLO = 2560
|
||||
i_say, history = input_clipping(i_say, history, max_token_limit=MAX_TOKEN_ALLO)
|
||||
gpt_say_partial = predict_no_ui_long_connection(inputs=i_say, llm_kwargs=llm_kwargs, history=history, sys_prompt=sys_prompt,
|
||||
observe_window=observe_window[index], console_slience=True)
|
||||
except ConnectionAbortedError as token_exceed_err:
|
||||
print('至少一个线程任务Token溢出而失败', e)
|
||||
except Exception as e:
|
||||
print('至少一个线程任务意外失败', e)
|
||||
|
||||
def add_async_gpt_task(self, i_say, chatbot_index, llm_kwargs, history, system_prompt):
|
||||
self.observe_future.append([""])
|
||||
self.observe_future_chatbot_index.append(chatbot_index)
|
||||
cur_index = len(self.observe_future)-1
|
||||
th_new = threading.Thread(target=self.gpt_thread_worker, args=(i_say, llm_kwargs, history, system_prompt, self.observe_future, cur_index))
|
||||
th_new.daemon = True
|
||||
th_new.start()
|
||||
|
||||
def update_chatbot(self, chatbot):
|
||||
for of, ofci in zip(self.observe_future, self.observe_future_chatbot_index):
|
||||
try:
|
||||
chatbot[ofci] = list(chatbot[ofci])
|
||||
chatbot[ofci][1] = markdown_convertion(of[0])
|
||||
except:
|
||||
self.observe_future = []
|
||||
self.observe_future_chatbot_index = []
|
||||
return chatbot
|
||||
|
||||
class InterviewAssistant(AliyunASR):
|
||||
def __init__(self):
|
||||
self.capture_interval = 0.5 # second
|
||||
self.stop = False
|
||||
self.parsed_text = ""
|
||||
self.parsed_sentence = ""
|
||||
self.buffered_sentence = ""
|
||||
self.event_on_result_chg = threading.Event()
|
||||
self.event_on_entence_end = threading.Event()
|
||||
self.event_on_commit_question = threading.Event()
|
||||
|
||||
def __del__(self):
|
||||
self.stop = True
|
||||
|
||||
def init(self, chatbot):
|
||||
# 初始化音频采集线程
|
||||
self.captured_audio = np.array([])
|
||||
self.keep_latest_n_second = 10
|
||||
self.commit_after_pause_n_second = 1.5
|
||||
self.ready_audio_flagment = None
|
||||
self.stop = False
|
||||
self.plugin_wd = WatchDog(timeout=5, bark_fn=self.__del__, msg="程序终止")
|
||||
self.aut = threading.Thread(target=self.audio_convertion_thread, args=(chatbot._cookies['uuid'],))
|
||||
self.aut.daemon = True
|
||||
self.aut.start()
|
||||
# th2 = threading.Thread(target=self.audio2txt_thread, args=(chatbot._cookies['uuid'],))
|
||||
# th2.daemon = True
|
||||
# th2.start()
|
||||
|
||||
def no_audio_for_a_while(self):
|
||||
if len(self.buffered_sentence) < 7: # 如果一句话小于7个字,暂不提交
|
||||
self.commit_wd.begin_watch()
|
||||
else:
|
||||
self.event_on_commit_question.set()
|
||||
|
||||
def begin(self, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||
# main plugin function
|
||||
self.init(chatbot)
|
||||
chatbot.append(["[请讲话]", "[正在等您说完问题]"])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
self.plugin_wd.begin_watch()
|
||||
self.agt = AsyncGptTask()
|
||||
self.commit_wd = WatchDog(timeout=self.commit_after_pause_n_second, bark_fn=self.no_audio_for_a_while, interval=0.2)
|
||||
self.commit_wd.begin_watch()
|
||||
|
||||
while True:
|
||||
self.event_on_result_chg.wait(timeout=0.25) # run once every 0.25 second
|
||||
chatbot = self.agt.update_chatbot(chatbot) # 将子线程的gpt结果写入chatbot
|
||||
history = chatbot2history(chatbot)
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
self.plugin_wd.feed()
|
||||
|
||||
if self.event_on_result_chg.is_set():
|
||||
# update audio decode result
|
||||
self.event_on_result_chg.clear()
|
||||
chatbot[-1] = list(chatbot[-1])
|
||||
chatbot[-1][0] = self.buffered_sentence + self.parsed_text
|
||||
history = chatbot2history(chatbot)
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
self.commit_wd.feed()
|
||||
|
||||
if self.event_on_entence_end.is_set():
|
||||
# called when a sentence has ended
|
||||
self.event_on_entence_end.clear()
|
||||
self.parsed_text = self.parsed_sentence
|
||||
self.buffered_sentence += self.parsed_sentence
|
||||
|
||||
if self.event_on_commit_question.is_set():
|
||||
# called when a question should be commited
|
||||
self.event_on_commit_question.clear()
|
||||
if len(self.buffered_sentence) == 0: raise RuntimeError
|
||||
|
||||
self.commit_wd.begin_watch()
|
||||
chatbot[-1] = list(chatbot[-1])
|
||||
chatbot[-1] = [self.buffered_sentence, "[等待GPT响应]"]
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
# add gpt task 创建子线程请求gpt,避免线程阻塞
|
||||
history = chatbot2history(chatbot)
|
||||
self.agt.add_async_gpt_task(self.buffered_sentence, len(chatbot)-1, llm_kwargs, history, system_prompt)
|
||||
|
||||
self.buffered_sentence = ""
|
||||
chatbot.append(["[请讲话]", "[正在等您说完问题]"])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
|
||||
|
||||
|
||||
@CatchException
|
||||
def 语音助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
# pip install -U openai-whisper
|
||||
chatbot.append(["对话助手函数插件:使用时,双手离开鼠标键盘吧", "音频助手, 正在听您讲话(点击“停止”键可终止程序)..."])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
# 尝试导入依赖,如果缺少依赖,则给出安装建议
|
||||
try:
|
||||
import nls
|
||||
from scipy import io
|
||||
except:
|
||||
chatbot.append(["导入依赖失败", "使用该模块需要额外依赖, 安装方法:```pip install --upgrade pyOpenSSL scipy git+https://github.com/aliyun/alibabacloud-nls-python-sdk.git```"])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
return
|
||||
|
||||
TOKEN, APPKEY = get_conf('ALIYUN_TOKEN', 'ALIYUN_APPKEY')
|
||||
if TOKEN == "" or APPKEY == "":
|
||||
chatbot.append(["导入依赖失败", "没有阿里云语音识别APPKEY和TOKEN, 详情见https://help.aliyun.com/document_detail/450255.html"])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
return
|
||||
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
ia = InterviewAssistant()
|
||||
yield from ia.begin(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
|
||||
|
||||
@ -1890,7 +1890,6 @@
|
||||
"暗色模式 / 亮色模式": "Dark mode / Light mode",
|
||||
"检测到arxiv文档连接": "Detected arXiv document link",
|
||||
"此插件Windows支持最佳": "This plugin has best support for Windows",
|
||||
"from crazy_functions.虚空终端 import 终端": "from crazy_functions.null_terminal import Terminal",
|
||||
"本地论文翻译": "Local paper translation",
|
||||
"输出html调试文件": "Output HTML debugging file",
|
||||
"以下所有配置也都支持利用环境变量覆写": "All the following configurations can also be overridden using environment variables",
|
||||
@ -1956,5 +1955,184 @@
|
||||
"填入ENGINE": "Fill in ENGINE",
|
||||
"填入api版本": "Fill in the API version",
|
||||
"中文Bing版": "Chinese Bing version",
|
||||
"当前支持的格式包括": "Currently supported formats include"
|
||||
"当前支持的格式包括": "Currently supported formats include",
|
||||
"交互功能模板函数": "InteractiveFunctionTemplateFunction",
|
||||
"交互功能函数模板": "InteractiveFunctionFunctionTemplate",
|
||||
"语音助手": "VoiceAssistant",
|
||||
"微调数据集生成": "FineTuneDatasetGeneration",
|
||||
"chatglm微调工具": "ChatGLMFineTuningTool",
|
||||
"解析一个Rect项目": "ParseARectProject",
|
||||
"用于灵活调整复杂功能的各种参数": "Various parameters for flexible adjustment of complex functions",
|
||||
"阿里云实时语音识别 配置难度较高 仅建议高手用户使用 参考 https": "Aliyun real-time speech recognition has high configuration difficulty and is only recommended for advanced users. Refer to https",
|
||||
"/* 设定聊天气泡的样式": "/* Set the style of chat bubbles",
|
||||
"解析整个React项目": "Parsing the entire React project",
|
||||
"后来也传播到了其他游戏中": "Later it spread to other games",
|
||||
"多线程demo": "Multithreading demo",
|
||||
"最初起源于游戏《CSGO》": "Originally originated from the game 'CSGO'",
|
||||
"首先你在中文语境下通读整篇论文": "First, you read the entire paper in a Chinese context",
|
||||
"避免线程阻塞": "Avoid thread blocking",
|
||||
"其中第1份搜索结果中的一篇文章也指出": "One article in the first search result also points out",
|
||||
"3. 更新速度": "3. Update speed",
|
||||
"极少数情况下": "In very few cases",
|
||||
"例如 RoPlZrM88DnAFkZK": "For example, RoPlZrM88DnAFkZK",
|
||||
"每秒采样数量": "Number of samples per second",
|
||||
"以及其中不少玩家以极端立场宣扬的想法和言论": "And many players promote extreme ideas and opinions",
|
||||
"音频助手": "Audio assistant",
|
||||
"还需要填写组织": "Also need to fill in the organization",
|
||||
"计算平方和": "Calculate the sum of squares",
|
||||
"请向下翻": "Please scroll down",
|
||||
"起源于哪里": "Where does it originate from",
|
||||
"并且网友们还 给这位UP主起了“苍蝇侠”的外号": "And netizens also gave this UP host the nickname 'Flyman'",
|
||||
"第一次调用": "First call",
|
||||
"点击“停止”键可终止程序": "Click the 'Stop' button to terminate the program",
|
||||
"在执行完成之后": "After execution is complete",
|
||||
"老六是网络流行语": "Lao Liu is an internet slang",
|
||||
"请先将.doc文档转换为.docx文档": "Please convert the .doc document to .docx document first",
|
||||
"从第0份、第1份、第2份搜索结果可以看出": "It can be seen from the 0th, 1st, and 2nd search results",
|
||||
"如 绿帽子*深蓝色衬衫*黑色运动裤": "Such as green hat * dark blue shirt * black sports pants",
|
||||
"对这个人外貌、身处的环境、内心世界、过去经历进行描写": "Describing the appearance, environment, inner world, and past experiences of this person",
|
||||
"沙特和伊朗于3月10日达成了恢复两国外交关系的协议": "Saudi Arabia and Iran reached an agreement on March 10th to restore diplomatic relations between the two countries",
|
||||
"例如 f37f30e0f9934c34a992f6f64f7eba4f": "For example, f37f30e0f9934c34a992f6f64f7eba4f",
|
||||
"避免多用户干扰": "Avoiding interference from multiple users",
|
||||
"AutoGPT是一个基于GPT-4语言模型的开源应用程序": "AutoGPT is an open-source application based on the GPT-4 language model",
|
||||
"开始最终总结": "Start the final summary",
|
||||
"包括事件分析、营销方案撰写、代码编程、数学运算等等": "Including event analysis, writing marketing plans, coding, mathematical calculations, etc.",
|
||||
"包括背景颜色、内、外边距、圆角": "Including background color, padding, margin, and border-radius",
|
||||
"只读": "Read-only",
|
||||
"/* 设置表头单元格的内边距为0.5em和0.2em. */": "/* Set the padding of the table header cells to 0.5em and 0.2em. */",
|
||||
"游戏中的一个情节": "A plot in the game",
|
||||
"如果一句话小于7个字": "If a sentence is less than 7 words",
|
||||
"赋予插件状态": "Assigning plugin status",
|
||||
"用第二人称": "Using the second person",
|
||||
"上传本地文件/压缩包供函数插件调用": "Upload local files/compressed packages for function plugin calls",
|
||||
"要求": "Requirements",
|
||||
"正在等您说完问题": "Waiting for you to finish the question",
|
||||
"后来逐渐演变成指玩得比较阴险的玩家": "Later it gradually evolved to refer to players who play in a more cunning way",
|
||||
"当下一次用户提交时": "When the next user submits",
|
||||
"色彩主体": "Color theme",
|
||||
"例如 “Espe-": "For example, 'Espe-'",
|
||||
"为每一位访问的用户赋予一个独一无二的uuid编码": "Assign a unique UUID code to each visiting user",
|
||||
"清空label": "Clear label",
|
||||
"对话助手函数插件": "Conversation assistant function plugin",
|
||||
"Chuanhu-Small-and-Beautiful主题": "Chuanhu-Small-and-Beautiful theme",
|
||||
"初始化音频采集线程": "Initialize the audio capture thread",
|
||||
"“我们称之为高效”这 一用语来源于群星": "The phrase 'we call it efficient' comes from the stars",
|
||||
"给观众留下了等待的时间过长的印象": "Leaving the audience with the impression of waiting for too long",
|
||||
"在一个异步线程中采集音频": "Collect audio in an asynchronous thread",
|
||||
"正在听您讲话": "Listening to you speak",
|
||||
"指游戏中玩家中独来独往、游离于队伍之外的“自由人”或玩得比较菜或者玩得比较阴险的人": "Refers to players in the game who are independent, play poorly, or play cunningly and stay away from the team",
|
||||
"函数插件模板Demo": "Function plugin template Demo",
|
||||
"比如巨像": "Such as giant",
|
||||
"格式如org-123456789abcdefghijklmno的": "The format is like org-123456789abcdefghijklmno",
|
||||
"插件锁定中": "Plugin locked",
|
||||
"/* 去掉列表前缀的默认间距": "/* Remove the default spacing of the list prefix",
|
||||
"内部单元格之间边框合并": "Merge borders between internal cells",
|
||||
"可以将自身的状态存储到cookie中": "Can store its own state in a cookie",
|
||||
"解除插件锁定": "Unlock plugin",
|
||||
"/* 设定代码块的样式": "/* Set the style of the code block",
|
||||
"右下角更换模型菜单中可切换openai": "You can switch openai in the lower right corner model menu",
|
||||
"颜色为--border-color-primary. */": "The color is --border-color-primary. */",
|
||||
"引用了一群游戏玩家对于需要对P社玩家进行枪毙的讨论": "Quoted a discussion among a group of gamers about the need to execute P community players",
|
||||
"这个话题的本质是玩家们对于P 社游戏中的政治与历史元素的不同看法": "The essence of this topic is the different views of players on the political and historical elements in P community games",
|
||||
"我将为您查找相关壁纸": "I will find related wallpapers for you",
|
||||
"640个字节为一组": "640 bytes per group",
|
||||
"赋予插件锁定 锁定插件回调路径": "Assign plugin lock, lock plugin callback path",
|
||||
"等游戏": "Waiting for the game",
|
||||
"将文件添加到chatbot cookie中": "Add file to chatbot cookie",
|
||||
"处理个别特殊插件的锁定状态": "Handle the lock status of individual special plugins",
|
||||
"/* 设置表头背景颜色为rgba": "/* Set the background color of the table header to rgba",
|
||||
"读 docs\\use_azure.md": "Read docs\\use_azure.md",
|
||||
"边框粗细为1.2px": "The border thickness is 1.2px",
|
||||
"这个用语最初可能是在群星": "This term may have originated in the stars",
|
||||
"请输入关键词": "Please enter keywords",
|
||||
"给出实现的步骤和实现细节": "Provide implementation steps and details",
|
||||
"可以得知失败的man是指一位在B站购买了蜘蛛侠COS服后穿上后被网友嘲笑的UP主": "The 'failed man' refers to a UP main who bought a Spiderman COS costume on Bilibili and was mocked by netizens after wearing it.",
|
||||
"最近它在GitHub上爆火": "Recently, it has become popular on GitHub.",
|
||||
"提取总结": "Extract summary",
|
||||
"正在锁定插件": "Locking the plugin",
|
||||
"给出指令": "Give instructions",
|
||||
"避免遗忘导致死锁": "Avoid forgetting causing deadlock",
|
||||
"建议直接在API_KEY处填写": "It is recommended to fill in directly at API_KEY",
|
||||
"详情见https": "For details, see https",
|
||||
"因此有人就以枪毙这些玩家来回应此类言论": "Therefore, some people respond to such comments by executing these players",
|
||||
"而“失败的man”是蜘蛛侠英文名“spiderman”的谐音梗": "And 'failed man' is a pun on the English name 'Spiderman'",
|
||||
"格式如org-xxxxxxxxxxxxxxxxxxxxxxxx": "The format is like org-xxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"根据第1份搜索结果": "According to the first search result",
|
||||
"GPT 学术优化": "GPT academic optimization",
|
||||
"等待用户的再次调用": "Waiting for user's call again",
|
||||
"把文件复制过去": "Copy the file over",
|
||||
"该选项即将被弃用": "This option will be deprecated",
|
||||
"对这个人外貌、身处的环境、内心世界、人设进行描写": "Describe this person's appearance, environment, inner world, and character setting",
|
||||
"例如您可以将以下命令复制到下方": "For example, you can copy the following command below",
|
||||
"使用": "Use",
|
||||
"找不到": "Not found",
|
||||
"正常状态": "Normal state",
|
||||
"记住当前的label": "Remember the current label",
|
||||
"成为了业内最热门的项目之一": "Has become one of the hottest projects in the industry",
|
||||
"它可以根据用户需求自主执行任务": "It can independently perform tasks according to user needs",
|
||||
"时而快时而慢": "Sometimes fast, sometimes slow",
|
||||
"它们都是关于一个知乎用户所发的帖子": "They are all about a post made by a Zhihu user",
|
||||
"解决插件锁定时的界面显示问题": "Resolve interface display issues when the plugin is locked",
|
||||
"azure和api2d请求源": "Azure and api2d request sources",
|
||||
"并不应该被当做真实的态度或者观点": "Should not be taken as a real attitude or viewpoint",
|
||||
"没有阿里云语音识别APPKEY和TOKEN": "No Alibaba Cloud Speech Recognition APPKEY and TOKEN",
|
||||
"Rainbow Six Siege 游戏中 Smoke 的 Canister 中装有何种物质相关的官方信息": "Official information related to the substance contained in Smoke's Canister in the game Rainbow Six Siege",
|
||||
"此处填API密钥": "Fill in API key here",
|
||||
"使其与文本线对齐. */": "Align it with the text line. */",
|
||||
"先删除": "Delete first",
|
||||
"100字以内": "Within 100 characters",
|
||||
"并完全不需要用户插手": "And does not require user intervention at all",
|
||||
"插件可读取“输入区”文本/路径作为参数": "The plugin can read the text/path in the 'input area' as a parameter",
|
||||
"但是这个话题本身并没有实质内容": "But this topic itself has no substantive content",
|
||||
"会直接转到该函数": "Will directly go to that function",
|
||||
"因此这种说法没有实际意义": "Therefore, this statement has no practical meaning",
|
||||
"它可以自己思考": "It can think for itself",
|
||||
"/* 行内代码的背景设为淡灰色": "/* Set the background of inline code to light gray",
|
||||
"请直接提交即可": "Please submit directly",
|
||||
"计算欧几里得距离矩阵": "Calculate the Euclidean distance matrix",
|
||||
"空单元格显示. */": "Display empty cells. */",
|
||||
"openai的官方KEY需要伴随组织编码": "The official KEY of OpenAI needs to be accompanied by organizational coding",
|
||||
"最近在中国的斡旋下": "Recently, under the mediation of China",
|
||||
"将子线程的gpt结果写入chatbot": "Write the GPT results of the sub-thread into the chatbot",
|
||||
"罗小黑战记的更新时间不定": "The update time of Luo Xiaohei's Adventure is uncertain",
|
||||
"/* 设置表格单元格的内边距为5px": "/* Set the inner padding of table cells to 5px",
|
||||
"透明度为0.2. */": "Opacity is 0.2. */",
|
||||
"暂不提交": "Do not submit temporarily",
|
||||
"用户们用来形容一些游戏策略或行为非常高效且能够带来好的效果的用语": "Terms used by users to describe game strategies or behaviors that are very efficient and can bring good results",
|
||||
"等待GPT响应": "Waiting for GPT response",
|
||||
"计算输入数组X中所有样本点的两两距离矩阵": "Calculate the pairwise distance matrix of all sample points in the input array X",
|
||||
"因此": "Therefore",
|
||||
"包括圆角、最大宽度和阴影等. */": "Including rounded corners, maximum width, and shadows, etc. */",
|
||||
"初始化插件状态": "Initialize plugin status",
|
||||
"请讲话": "Please speak",
|
||||
"可选": "Optional",
|
||||
"甚至可以自问自答执 行任务": "You can even ask and answer questions to perform tasks",
|
||||
"add gpt task 创建子线程请求gpt": "add gpt task Create a sub-thread request gpt",
|
||||
"失败的man是指这位UP主在穿上蜘蛛侠COS服后被网友嘲笑的情况": "The failed man refers to the situation where this UP master was mocked by netizens after wearing Spider-Man COS clothes",
|
||||
"这个游戏里面流行起来的": "Popular in this game",
|
||||
"获取关键词": "Get keywords",
|
||||
"设定圆角和间距. */": "Set rounded corners and spacing. */",
|
||||
"/* 设置表格的外边距为1em": "/* Set the margin of the table to 1em",
|
||||
"只是一个玩笑或者恶搞": "Just a joke or a prank",
|
||||
"双手离开鼠标键盘吧": "Take your hands off the mouse and keyboard",
|
||||
"上传文件自动修正路径": "Automatically correct the path when uploading files",
|
||||
"找不到任何Rect文件": "Cannot find any Rect files",
|
||||
"先上传数据集": "Upload the dataset first",
|
||||
"如果已经存在": "If it already exists",
|
||||
"使用时": "When using",
|
||||
"解除插件状态": "Release plugin status",
|
||||
"cially” 转换为 “Especially”": "Convert 'cially' to 'Especially'",
|
||||
"这表明两国关系已经重新回到正常化状态": "This indicates that the relationship between the two countries has returned to a normal state",
|
||||
"由于提问含不合规内容被Azure过滤": "Due to the Azure filtering of questions containing non-compliant content",
|
||||
"填入你亲手写的部署名": "Fill in the deployment name you wrote by hand",
|
||||
"想象一个穿着者": "Imagine a wearer",
|
||||
"建议使用英文单词": "It is recommended to use English words",
|
||||
"没给定指令": "No instruction given",
|
||||
"实时音频采集": "Real-time audio collection",
|
||||
"“我们称之为高效”是指在游戏社区中": "'We call it efficient' refers to the game community",
|
||||
"找 API_ORG 设置项": "Find API_ORG settings",
|
||||
"虚空终端": "VoidTerminal",
|
||||
"请查看terminal的输出或耐心等待": "Please check the output of the terminal or wait patiently",
|
||||
"发送到openai音频解析terminal": "Send to openai audio parsing terminal",
|
||||
"超级terminal": "Super terminal"
|
||||
}
|
||||
12021
ft_ds.json
Normal file
12021
ft_ds.json
Normal file
File diff suppressed because it is too large
Load Diff
29
main.py
29
main.py
@ -8,18 +8,19 @@ def main():
|
||||
# 建议您复制一个config_private.py放自己的秘密, 如API和代理网址, 避免不小心传github被别人看到
|
||||
proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION, CHATBOT_HEIGHT, LAYOUT, AVAIL_LLM_MODELS, AUTO_CLEAR_TXT = \
|
||||
get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION', 'CHATBOT_HEIGHT', 'LAYOUT', 'AVAIL_LLM_MODELS', 'AUTO_CLEAR_TXT')
|
||||
|
||||
ENABLE_AUDIO, AUTO_CLEAR_TXT = get_conf('ENABLE_AUDIO', 'AUTO_CLEAR_TXT')
|
||||
# 如果WEB_PORT是-1, 则随机选取WEB端口
|
||||
PORT = find_free_port() if WEB_PORT <= 0 else WEB_PORT
|
||||
if not AUTHENTICATION: AUTHENTICATION = None
|
||||
|
||||
from check_proxy import get_current_version
|
||||
from theme.theme import adjust_theme, advanced_css, theme_declaration
|
||||
initial_prompt = "Serve me as a writing and programming assistant."
|
||||
title_html = f"<h1 align=\"center\">ChatGPT 学术优化 {get_current_version()}</h1>"
|
||||
title_html = f"<h1 align=\"center\">GPT 学术优化 {get_current_version()}</h1>{theme_declaration}"
|
||||
description = """代码开源和更新[地址🚀](https://github.com/binary-husky/chatgpt_academic),感谢热情的[开发者们❤️](https://github.com/binary-husky/chatgpt_academic/graphs/contributors)"""
|
||||
|
||||
# 问询记录, python 版本建议3.9+(越新越好)
|
||||
import logging
|
||||
import logging, uuid
|
||||
os.makedirs("gpt_log", exist_ok=True)
|
||||
try:logging.basicConfig(filename="gpt_log/chat_secrets.log", level=logging.INFO, encoding="utf-8")
|
||||
except:logging.basicConfig(filename="gpt_log/chat_secrets.log", level=logging.INFO)
|
||||
@ -37,7 +38,6 @@ def main():
|
||||
gr.Chatbot.postprocess = format_io
|
||||
|
||||
# 做一些外观色彩上的调整
|
||||
from theme import adjust_theme, advanced_css
|
||||
set_theme = adjust_theme()
|
||||
|
||||
# 代理与自动更新
|
||||
@ -70,6 +70,9 @@ def main():
|
||||
resetBtn = gr.Button("重置", variant="secondary"); resetBtn.style(size="sm")
|
||||
stopBtn = gr.Button("停止", variant="secondary"); stopBtn.style(size="sm")
|
||||
clearBtn = gr.Button("清除", variant="secondary", visible=False); clearBtn.style(size="sm")
|
||||
if ENABLE_AUDIO:
|
||||
with gr.Row():
|
||||
audio_mic = gr.Audio(source="microphone", type="numpy", streaming=True, show_label=False)
|
||||
with gr.Row():
|
||||
status = gr.Markdown(f"Tip: 按Enter提交, 按Shift+Enter换行。当前模型: {LLM_MODEL} \n {proxy_info}", elem_id="state-panel")
|
||||
with gr.Accordion("基础功能区", open=True, elem_id="basic-panel") as area_basic_fn:
|
||||
@ -80,7 +83,7 @@ def main():
|
||||
functional[k]["Button"] = gr.Button(k, variant=variant)
|
||||
with gr.Accordion("函数插件区", open=True, elem_id="plugin-panel") as area_crazy_fn:
|
||||
with gr.Row():
|
||||
gr.Markdown("注意:以下“红颜色”标识的函数插件需从输入区读取路径作为参数.")
|
||||
gr.Markdown("插件可读取“输入区”文本/路径作为参数(上传文件自动修正路径)")
|
||||
with gr.Row():
|
||||
for k in crazy_fns:
|
||||
if not crazy_fns[k].get("AsButton", True): continue
|
||||
@ -91,14 +94,14 @@ def main():
|
||||
with gr.Accordion("更多函数插件", open=True):
|
||||
dropdown_fn_list = [k for k in crazy_fns.keys() if not crazy_fns[k].get("AsButton", True)]
|
||||
with gr.Row():
|
||||
dropdown = gr.Dropdown(dropdown_fn_list, value=r"打开插件列表", label="").style(container=False)
|
||||
dropdown = gr.Dropdown(dropdown_fn_list, value=r"打开插件列表", label="", show_label=False).style(container=False)
|
||||
with gr.Row():
|
||||
plugin_advanced_arg = gr.Textbox(show_label=True, label="高级参数输入区", visible=False,
|
||||
placeholder="这里是特殊函数插件的高级参数输入区").style(container=False)
|
||||
with gr.Row():
|
||||
switchy_bt = gr.Button(r"请先从插件列表中选择", variant="secondary")
|
||||
with gr.Row():
|
||||
with gr.Accordion("点击展开“文件上传区”。上传本地文件可供红色函数插件调用。", open=False) as area_file_up:
|
||||
with gr.Accordion("点击展开“文件上传区”。上传本地文件/压缩包供函数插件调用。", open=False) as area_file_up:
|
||||
file_upload = gr.Files(label="任何文件, 但推荐上传压缩文件(zip, tar)", file_count="multiple")
|
||||
with gr.Accordion("更换模型 & SysPrompt & 交互界面布局", open=(LAYOUT == "TOP-DOWN"), elem_id="interact-panel"):
|
||||
system_prompt = gr.Textbox(show_label=True, placeholder=f"System Prompt", label="System prompt", value=initial_prompt)
|
||||
@ -185,6 +188,18 @@ def main():
|
||||
# 终止按钮的回调函数注册
|
||||
stopBtn.click(fn=None, inputs=None, outputs=None, cancels=cancel_handles)
|
||||
stopBtn2.click(fn=None, inputs=None, outputs=None, cancels=cancel_handles)
|
||||
if ENABLE_AUDIO:
|
||||
from crazy_functions.live_audio.audio_io import RealtimeAudioDistribution
|
||||
rad = RealtimeAudioDistribution()
|
||||
def deal_audio(audio, cookies):
|
||||
rad.feed(cookies['uuid'].hex, audio)
|
||||
audio_mic.stream(deal_audio, inputs=[audio_mic, cookies])
|
||||
|
||||
def init_cookie(cookies, chatbot):
|
||||
# 为每一位访问的用户赋予一个独一无二的uuid编码
|
||||
cookies.update({'uuid': uuid.uuid4()})
|
||||
return cookies
|
||||
demo.load(init_cookie, inputs=[cookies, chatbot], outputs=[cookies])
|
||||
demo.load(lambda: 0, inputs=None, outputs=None, _js='()=>{ChatBotHeight();}')
|
||||
|
||||
# gradio的inbrowser触发不太稳定,回滚代码到原始的浏览器打开函数
|
||||
|
||||
BIN
objdump.tmp
Normal file
BIN
objdump.tmp
Normal file
Binary file not shown.
@ -269,6 +269,24 @@ if "newbing" in AVAIL_LLM_MODELS: # same with newbing-free
|
||||
})
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
if "chatglmft" in AVAIL_LLM_MODELS: # same with newbing-free
|
||||
try:
|
||||
from .bridge_chatglmft import predict_no_ui_long_connection as chatglmft_noui
|
||||
from .bridge_chatglmft import predict as chatglmft_ui
|
||||
# claude
|
||||
model_info.update({
|
||||
"chatglmft": {
|
||||
"fn_with_ui": chatglmft_ui,
|
||||
"fn_without_ui": chatglmft_noui,
|
||||
"endpoint": None,
|
||||
"max_token": 4096,
|
||||
"tokenizer": tokenizer_gpt35,
|
||||
"token_cnt": get_token_num_gpt35,
|
||||
}
|
||||
})
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
|
||||
|
||||
def LLM_CATCH_EXCEPTION(f):
|
||||
"""
|
||||
@ -372,6 +390,6 @@ def predict(inputs, llm_kwargs, *args, **kwargs):
|
||||
additional_fn代表点击的哪个按钮,按钮见functional.py
|
||||
"""
|
||||
|
||||
method = model_info[llm_kwargs['llm_model']]["fn_with_ui"]
|
||||
method = model_info[llm_kwargs['llm_model']]["fn_with_ui"] # 如果这里报错,检查config中的AVAIL_LLM_MODELS选项
|
||||
yield from method(inputs, llm_kwargs, *args, **kwargs)
|
||||
|
||||
|
||||
210
request_llm/bridge_chatglmft.py
Normal file
210
request_llm/bridge_chatglmft.py
Normal file
@ -0,0 +1,210 @@
|
||||
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
import time
|
||||
import os
|
||||
import json
|
||||
import threading
|
||||
import importlib
|
||||
from toolbox import update_ui, get_conf
|
||||
from multiprocessing import Process, Pipe
|
||||
|
||||
load_message = "ChatGLMFT尚未加载,加载需要一段时间。注意,取决于`config.py`的配置,ChatGLMFT消耗大量的内存(CPU)或显存(GPU),也许会导致低配计算机卡死 ……"
|
||||
|
||||
def string_to_options(arguments):
|
||||
import argparse
|
||||
import shlex
|
||||
# Create an argparse.ArgumentParser instance
|
||||
parser = argparse.ArgumentParser()
|
||||
# Add command-line arguments
|
||||
parser.add_argument("--llm_to_learn", type=str, help="LLM model to learn", default="gpt-3.5-turbo")
|
||||
parser.add_argument("--prompt_prefix", type=str, help="Prompt prefix", default='')
|
||||
parser.add_argument("--system_prompt", type=str, help="System prompt", default='')
|
||||
parser.add_argument("--batch", type=int, help="System prompt", default=50)
|
||||
# Parse the arguments
|
||||
args = parser.parse_args(shlex.split(arguments))
|
||||
return args
|
||||
|
||||
|
||||
#################################################################################
|
||||
class GetGLMFTHandle(Process):
|
||||
def __init__(self):
|
||||
super().__init__(daemon=True)
|
||||
self.parent, self.child = Pipe()
|
||||
self.chatglmft_model = None
|
||||
self.chatglmft_tokenizer = None
|
||||
self.info = ""
|
||||
self.success = True
|
||||
self.check_dependency()
|
||||
self.start()
|
||||
self.threadLock = threading.Lock()
|
||||
|
||||
def check_dependency(self):
|
||||
try:
|
||||
import sentencepiece
|
||||
self.info = "依赖检测通过"
|
||||
self.success = True
|
||||
except:
|
||||
self.info = "缺少ChatGLMFT的依赖,如果要使用ChatGLMFT,除了基础的pip依赖以外,您还需要运行`pip install -r request_llm/requirements_chatglm.txt`安装ChatGLM的依赖。"
|
||||
self.success = False
|
||||
|
||||
def ready(self):
|
||||
return self.chatglmft_model is not None
|
||||
|
||||
def run(self):
|
||||
# 子进程执行
|
||||
# 第一次运行,加载参数
|
||||
retry = 0
|
||||
while True:
|
||||
try:
|
||||
if self.chatglmft_model is None:
|
||||
from transformers import AutoConfig
|
||||
import torch
|
||||
# conf = 'request_llm/current_ptune_model.json'
|
||||
# if not os.path.exists(conf): raise RuntimeError('找不到微调模型信息')
|
||||
# with open(conf, 'r', encoding='utf8') as f:
|
||||
# model_args = json.loads(f.read())
|
||||
ChatGLM_PTUNING_CHECKPOINT, = get_conf('ChatGLM_PTUNING_CHECKPOINT')
|
||||
conf = os.path.join(ChatGLM_PTUNING_CHECKPOINT, "config.json")
|
||||
with open(conf, 'r', encoding='utf8') as f:
|
||||
model_args_ = json.loads(f.read())
|
||||
model_args_.update(model_args)
|
||||
model_args = model_args_
|
||||
|
||||
self.chatglmft_tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args['model_name_or_path'], trust_remote_code=True)
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args['model_name_or_path'], trust_remote_code=True)
|
||||
|
||||
config.pre_seq_len = model_args['pre_seq_len']
|
||||
config.prefix_projection = model_args['prefix_projection']
|
||||
|
||||
print(f"Loading prefix_encoder weight from {ChatGLM_PTUNING_CHECKPOINT}")
|
||||
model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
|
||||
prefix_state_dict = torch.load(os.path.join(ChatGLM_PTUNING_CHECKPOINT, "pytorch_model.bin"))
|
||||
new_prefix_state_dict = {}
|
||||
for k, v in prefix_state_dict.items():
|
||||
if k.startswith("transformer.prefix_encoder."):
|
||||
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
|
||||
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
|
||||
|
||||
if model_args['quantization_bit'] is not None:
|
||||
print(f"Quantized to {model_args['quantization_bit']} bit")
|
||||
model = model.quantize(model_args['quantization_bit'])
|
||||
model = model.cuda()
|
||||
if model_args['pre_seq_len'] is not None:
|
||||
# P-tuning v2
|
||||
model.transformer.prefix_encoder.float()
|
||||
self.chatglmft_model = model.eval()
|
||||
|
||||
break
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
retry += 1
|
||||
if retry > 3:
|
||||
self.child.send('[Local Message] Call ChatGLMFT fail 不能正常加载ChatGLMFT的参数。')
|
||||
raise RuntimeError("不能正常加载ChatGLMFT的参数!")
|
||||
|
||||
while True:
|
||||
# 进入任务等待状态
|
||||
kwargs = self.child.recv()
|
||||
# 收到消息,开始请求
|
||||
try:
|
||||
for response, history in self.chatglmft_model.stream_chat(self.chatglmft_tokenizer, **kwargs):
|
||||
self.child.send(response)
|
||||
# # 中途接收可能的终止指令(如果有的话)
|
||||
# if self.child.poll():
|
||||
# command = self.child.recv()
|
||||
# if command == '[Terminate]': break
|
||||
except:
|
||||
from toolbox import trimmed_format_exc
|
||||
self.child.send('[Local Message] Call ChatGLMFT fail.' + '\n```\n' + trimmed_format_exc() + '\n```\n')
|
||||
# 请求处理结束,开始下一个循环
|
||||
self.child.send('[Finish]')
|
||||
|
||||
def stream_chat(self, **kwargs):
|
||||
# 主进程执行
|
||||
self.threadLock.acquire()
|
||||
self.parent.send(kwargs)
|
||||
while True:
|
||||
res = self.parent.recv()
|
||||
if res != '[Finish]':
|
||||
yield res
|
||||
else:
|
||||
break
|
||||
self.threadLock.release()
|
||||
|
||||
global glmft_handle
|
||||
glmft_handle = None
|
||||
#################################################################################
|
||||
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=[], console_slience=False):
|
||||
"""
|
||||
多线程方法
|
||||
函数的说明请见 request_llm/bridge_all.py
|
||||
"""
|
||||
global glmft_handle
|
||||
if glmft_handle is None:
|
||||
glmft_handle = GetGLMFTHandle()
|
||||
if len(observe_window) >= 1: observe_window[0] = load_message + "\n\n" + glmft_handle.info
|
||||
if not glmft_handle.success:
|
||||
error = glmft_handle.info
|
||||
glmft_handle = None
|
||||
raise RuntimeError(error)
|
||||
|
||||
# chatglmft 没有 sys_prompt 接口,因此把prompt加入 history
|
||||
history_feedin = []
|
||||
history_feedin.append(["What can I do?", sys_prompt])
|
||||
for i in range(len(history)//2):
|
||||
history_feedin.append([history[2*i], history[2*i+1]] )
|
||||
|
||||
watch_dog_patience = 5 # 看门狗 (watchdog) 的耐心, 设置5秒即可
|
||||
response = ""
|
||||
for response in glmft_handle.stream_chat(query=inputs, history=history_feedin, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
|
||||
if len(observe_window) >= 1: observe_window[0] = response
|
||||
if len(observe_window) >= 2:
|
||||
if (time.time()-observe_window[1]) > watch_dog_patience:
|
||||
raise RuntimeError("程序终止。")
|
||||
return response
|
||||
|
||||
|
||||
|
||||
def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream = True, additional_fn=None):
|
||||
"""
|
||||
单线程方法
|
||||
函数的说明请见 request_llm/bridge_all.py
|
||||
"""
|
||||
chatbot.append((inputs, ""))
|
||||
|
||||
global glmft_handle
|
||||
if glmft_handle is None:
|
||||
glmft_handle = GetGLMFTHandle()
|
||||
chatbot[-1] = (inputs, load_message + "\n\n" + glmft_handle.info)
|
||||
yield from update_ui(chatbot=chatbot, history=[])
|
||||
if not glmft_handle.success:
|
||||
glmft_handle = None
|
||||
return
|
||||
|
||||
if additional_fn is not None:
|
||||
import core_functional
|
||||
importlib.reload(core_functional) # 热更新prompt
|
||||
core_functional = core_functional.get_core_functions()
|
||||
if "PreProcess" in core_functional[additional_fn]: inputs = core_functional[additional_fn]["PreProcess"](inputs) # 获取预处理函数(如果有的话)
|
||||
inputs = core_functional[additional_fn]["Prefix"] + inputs + core_functional[additional_fn]["Suffix"]
|
||||
|
||||
# 处理历史信息
|
||||
history_feedin = []
|
||||
history_feedin.append(["What can I do?", system_prompt] )
|
||||
for i in range(len(history)//2):
|
||||
history_feedin.append([history[2*i], history[2*i+1]] )
|
||||
|
||||
# 开始接收chatglmft的回复
|
||||
response = "[Local Message]: 等待ChatGLMFT响应中 ..."
|
||||
for response in glmft_handle.stream_chat(query=inputs, history=history_feedin, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
|
||||
chatbot[-1] = (inputs, response)
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 总结输出
|
||||
if response == "[Local Message]: 等待ChatGLMFT响应中 ...":
|
||||
response = "[Local Message]: ChatGLMFT响应异常 ..."
|
||||
history.extend([inputs, response])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
5
request_llm/current_ptune_model.json
Normal file
5
request_llm/current_ptune_model.json
Normal file
@ -0,0 +1,5 @@
|
||||
{
|
||||
"model_name_or_path": "THUDM/chatglm2-6b",
|
||||
"pre_seq_len": "128",
|
||||
"ptuning_checkpoint": "/home/hmp/ChatGLM2-6B/ptuning/output/clothgen-chatglm2-6b-pt-128-2e-2/checkpoint-100"
|
||||
}
|
||||
@ -10,7 +10,8 @@ def validate_path():
|
||||
|
||||
validate_path() # validate path so you can run from base directory
|
||||
if __name__ == "__main__":
|
||||
from request_llm.bridge_newbingfree import predict_no_ui_long_connection
|
||||
from request_llm.bridge_chatglmft import predict_no_ui_long_connection
|
||||
# from request_llm.bridge_newbingfree import predict_no_ui_long_connection
|
||||
# from request_llm.bridge_moss import predict_no_ui_long_connection
|
||||
# from request_llm.bridge_jittorllms_pangualpha import predict_no_ui_long_connection
|
||||
# from request_llm.bridge_jittorllms_llama import predict_no_ui_long_connection
|
||||
@ -27,52 +28,3 @@ if __name__ == "__main__":
|
||||
sys_prompt="")
|
||||
print('final result:', result)
|
||||
|
||||
|
||||
result = predict_no_ui_long_connection(inputs="what is a hero?",
|
||||
llm_kwargs=llm_kwargs,
|
||||
history=["hello world"],
|
||||
sys_prompt="")
|
||||
print('final result:', result)
|
||||
|
||||
result = predict_no_ui_long_connection(inputs="如何理解传奇?",
|
||||
llm_kwargs=llm_kwargs,
|
||||
history=[],
|
||||
sys_prompt="")
|
||||
print('final result:', result)
|
||||
|
||||
# # print(result)
|
||||
# from multiprocessing import Process, Pipe
|
||||
# class GetGLMHandle(Process):
|
||||
# def __init__(self):
|
||||
# super().__init__(daemon=True)
|
||||
# pass
|
||||
# def run(self):
|
||||
# # 子进程执行
|
||||
# # 第一次运行,加载参数
|
||||
# def validate_path():
|
||||
# import os, sys
|
||||
# dir_name = os.path.dirname(__file__)
|
||||
# root_dir_assume = os.path.abspath(os.path.dirname(__file__) + '/..')
|
||||
# os.chdir(root_dir_assume + '/request_llm/jittorllms')
|
||||
# sys.path.append(root_dir_assume + '/request_llm/jittorllms')
|
||||
# validate_path() # validate path so you can run from base directory
|
||||
|
||||
# jittorllms_model = None
|
||||
# import types
|
||||
# try:
|
||||
# if jittorllms_model is None:
|
||||
# from models import get_model
|
||||
# # availabel_models = ["chatglm", "pangualpha", "llama", "chatrwkv"]
|
||||
# args_dict = {'model': 'chatrwkv'}
|
||||
# print('self.jittorllms_model = get_model(types.SimpleNamespace(**args_dict))')
|
||||
# jittorllms_model = get_model(types.SimpleNamespace(**args_dict))
|
||||
# print('done get model')
|
||||
# except:
|
||||
# # self.child.send('[Local Message] Call jittorllms fail 不能正常加载jittorllms的参数。')
|
||||
# raise RuntimeError("不能正常加载jittorllms的参数!")
|
||||
|
||||
# x = GetGLMHandle()
|
||||
# x.start()
|
||||
|
||||
|
||||
# input()
|
||||
47
theme/common.js
Normal file
47
theme/common.js
Normal file
@ -0,0 +1,47 @@
|
||||
function ChatBotHeight() {
|
||||
function update_height(){
|
||||
var { panel_height_target, chatbot_height, chatbot } = get_elements();
|
||||
if (panel_height_target!=chatbot_height)
|
||||
{
|
||||
var pixelString = panel_height_target.toString() + 'px';
|
||||
chatbot.style.maxHeight = pixelString; chatbot.style.height = pixelString;
|
||||
}
|
||||
}
|
||||
|
||||
function update_height_slow(){
|
||||
var { panel_height_target, chatbot_height, chatbot } = get_elements();
|
||||
if (panel_height_target!=chatbot_height)
|
||||
{
|
||||
new_panel_height = (panel_height_target - chatbot_height)*0.5 + chatbot_height;
|
||||
if (Math.abs(new_panel_height - panel_height_target) < 10){
|
||||
new_panel_height = panel_height_target;
|
||||
}
|
||||
// console.log(chatbot_height, panel_height_target, new_panel_height);
|
||||
var pixelString = new_panel_height.toString() + 'px';
|
||||
chatbot.style.maxHeight = pixelString; chatbot.style.height = pixelString;
|
||||
}
|
||||
}
|
||||
|
||||
update_height();
|
||||
setInterval(function() {
|
||||
update_height_slow()
|
||||
}, 50); // 每100毫秒执行一次
|
||||
}
|
||||
|
||||
function get_elements() {
|
||||
var chatbot = document.querySelector('#gpt-chatbot > div.wrap.svelte-18telvq');
|
||||
if (!chatbot) {
|
||||
chatbot = document.querySelector('#gpt-chatbot');
|
||||
}
|
||||
const panel1 = document.querySelector('#input-panel');
|
||||
const panel2 = document.querySelector('#basic-panel');
|
||||
const panel3 = document.querySelector('#plugin-panel');
|
||||
const panel4 = document.querySelector('#interact-panel');
|
||||
const panel5 = document.querySelector('#input-panel2');
|
||||
const panel_active = document.querySelector('#state-panel');
|
||||
var panel_height_target = (20-panel_active.offsetHeight) + panel1.offsetHeight + panel2.offsetHeight + panel3.offsetHeight + panel4.offsetHeight + panel5.offsetHeight + 21;
|
||||
var panel_height_target = parseInt(panel_height_target);
|
||||
var chatbot_height = chatbot.style.height;
|
||||
var chatbot_height = parseInt(chatbot_height);
|
||||
return { panel_height_target, chatbot_height, chatbot };
|
||||
}
|
||||
@ -1,164 +1,3 @@
|
||||
import gradio as gr
|
||||
from toolbox import get_conf
|
||||
CODE_HIGHLIGHT, ADD_WAIFU, LAYOUT = get_conf('CODE_HIGHLIGHT', 'ADD_WAIFU', 'LAYOUT')
|
||||
# gradio可用颜色列表
|
||||
# gr.themes.utils.colors.slate (石板色)
|
||||
# gr.themes.utils.colors.gray (灰色)
|
||||
# gr.themes.utils.colors.zinc (锌色)
|
||||
# gr.themes.utils.colors.neutral (中性色)
|
||||
# gr.themes.utils.colors.stone (石头色)
|
||||
# gr.themes.utils.colors.red (红色)
|
||||
# gr.themes.utils.colors.orange (橙色)
|
||||
# gr.themes.utils.colors.amber (琥珀色)
|
||||
# gr.themes.utils.colors.yellow (黄色)
|
||||
# gr.themes.utils.colors.lime (酸橙色)
|
||||
# gr.themes.utils.colors.green (绿色)
|
||||
# gr.themes.utils.colors.emerald (祖母绿)
|
||||
# gr.themes.utils.colors.teal (青蓝色)
|
||||
# gr.themes.utils.colors.cyan (青色)
|
||||
# gr.themes.utils.colors.sky (天蓝色)
|
||||
# gr.themes.utils.colors.blue (蓝色)
|
||||
# gr.themes.utils.colors.indigo (靛蓝色)
|
||||
# gr.themes.utils.colors.violet (紫罗兰色)
|
||||
# gr.themes.utils.colors.purple (紫色)
|
||||
# gr.themes.utils.colors.fuchsia (洋红色)
|
||||
# gr.themes.utils.colors.pink (粉红色)
|
||||
# gr.themes.utils.colors.rose (玫瑰色)
|
||||
|
||||
|
||||
def adjust_theme():
|
||||
|
||||
try:
|
||||
color_er = gr.themes.utils.colors.fuchsia
|
||||
set_theme = gr.themes.Default(
|
||||
primary_hue=gr.themes.utils.colors.orange,
|
||||
neutral_hue=gr.themes.utils.colors.gray,
|
||||
font=["sans-serif", "Microsoft YaHei", "ui-sans-serif", "system-ui",
|
||||
"sans-serif", gr.themes.utils.fonts.GoogleFont("Source Sans Pro")],
|
||||
font_mono=["ui-monospace", "Consolas", "monospace", gr.themes.utils.fonts.GoogleFont("IBM Plex Mono")])
|
||||
set_theme.set(
|
||||
# Colors
|
||||
input_background_fill_dark="*neutral_800",
|
||||
# Transition
|
||||
button_transition="none",
|
||||
# Shadows
|
||||
button_shadow="*shadow_drop",
|
||||
button_shadow_hover="*shadow_drop_lg",
|
||||
button_shadow_active="*shadow_inset",
|
||||
input_shadow="0 0 0 *shadow_spread transparent, *shadow_inset",
|
||||
input_shadow_focus="0 0 0 *shadow_spread *secondary_50, *shadow_inset",
|
||||
input_shadow_focus_dark="0 0 0 *shadow_spread *neutral_700, *shadow_inset",
|
||||
checkbox_label_shadow="*shadow_drop",
|
||||
block_shadow="*shadow_drop",
|
||||
form_gap_width="1px",
|
||||
# Button borders
|
||||
input_border_width="1px",
|
||||
input_background_fill="white",
|
||||
# Gradients
|
||||
stat_background_fill="linear-gradient(to right, *primary_400, *primary_200)",
|
||||
stat_background_fill_dark="linear-gradient(to right, *primary_400, *primary_600)",
|
||||
error_background_fill=f"linear-gradient(to right, {color_er.c100}, *background_fill_secondary)",
|
||||
error_background_fill_dark="*background_fill_primary",
|
||||
checkbox_label_background_fill="linear-gradient(to top, *neutral_50, white)",
|
||||
checkbox_label_background_fill_dark="linear-gradient(to top, *neutral_900, *neutral_800)",
|
||||
checkbox_label_background_fill_hover="linear-gradient(to top, *neutral_100, white)",
|
||||
checkbox_label_background_fill_hover_dark="linear-gradient(to top, *neutral_900, *neutral_800)",
|
||||
button_primary_background_fill="linear-gradient(to bottom right, *primary_100, *primary_300)",
|
||||
button_primary_background_fill_dark="linear-gradient(to bottom right, *primary_500, *primary_600)",
|
||||
button_primary_background_fill_hover="linear-gradient(to bottom right, *primary_100, *primary_200)",
|
||||
button_primary_background_fill_hover_dark="linear-gradient(to bottom right, *primary_500, *primary_500)",
|
||||
button_primary_border_color_dark="*primary_500",
|
||||
button_secondary_background_fill="linear-gradient(to bottom right, *neutral_100, *neutral_200)",
|
||||
button_secondary_background_fill_dark="linear-gradient(to bottom right, *neutral_600, *neutral_700)",
|
||||
button_secondary_background_fill_hover="linear-gradient(to bottom right, *neutral_100, *neutral_100)",
|
||||
button_secondary_background_fill_hover_dark="linear-gradient(to bottom right, *neutral_600, *neutral_600)",
|
||||
button_cancel_background_fill=f"linear-gradient(to bottom right, {color_er.c100}, {color_er.c200})",
|
||||
button_cancel_background_fill_dark=f"linear-gradient(to bottom right, {color_er.c600}, {color_er.c700})",
|
||||
button_cancel_background_fill_hover=f"linear-gradient(to bottom right, {color_er.c100}, {color_er.c100})",
|
||||
button_cancel_background_fill_hover_dark=f"linear-gradient(to bottom right, {color_er.c600}, {color_er.c600})",
|
||||
button_cancel_border_color=color_er.c200,
|
||||
button_cancel_border_color_dark=color_er.c600,
|
||||
button_cancel_text_color=color_er.c600,
|
||||
button_cancel_text_color_dark="white",
|
||||
)
|
||||
|
||||
# Layout = "LEFT-RIGHT"
|
||||
js = """
|
||||
<script>
|
||||
function ChatBotHeight() {
|
||||
function update_height(){
|
||||
var { panel_height_target, chatbot_height, chatbot } = get_elements();
|
||||
if (panel_height_target!=chatbot_height)
|
||||
{
|
||||
var pixelString = panel_height_target.toString() + 'px';
|
||||
chatbot.style.maxHeight = pixelString; chatbot.style.height = pixelString;
|
||||
}
|
||||
}
|
||||
|
||||
function update_height_slow(){
|
||||
var { panel_height_target, chatbot_height, chatbot } = get_elements();
|
||||
if (panel_height_target!=chatbot_height)
|
||||
{
|
||||
new_panel_height = (panel_height_target - chatbot_height)*0.5 + chatbot_height;
|
||||
if (Math.abs(new_panel_height - panel_height_target) < 10){
|
||||
new_panel_height = panel_height_target;
|
||||
}
|
||||
// console.log(chatbot_height, panel_height_target, new_panel_height);
|
||||
var pixelString = new_panel_height.toString() + 'px';
|
||||
chatbot.style.maxHeight = pixelString; chatbot.style.height = pixelString;
|
||||
}
|
||||
}
|
||||
|
||||
update_height();
|
||||
setInterval(function() {
|
||||
update_height_slow()
|
||||
}, 50); // 每100毫秒执行一次
|
||||
}
|
||||
|
||||
function get_elements() {
|
||||
var chatbot = document.querySelector('#gpt-chatbot > div.wrap.svelte-18telvq');
|
||||
if (!chatbot) {
|
||||
chatbot = document.querySelector('#gpt-chatbot');
|
||||
}
|
||||
const panel1 = document.querySelector('#input-panel');
|
||||
const panel2 = document.querySelector('#basic-panel');
|
||||
const panel3 = document.querySelector('#plugin-panel');
|
||||
const panel4 = document.querySelector('#interact-panel');
|
||||
const panel5 = document.querySelector('#input-panel2');
|
||||
const panel_active = document.querySelector('#state-panel');
|
||||
var panel_height_target = (20-panel_active.offsetHeight) + panel1.offsetHeight + panel2.offsetHeight + panel3.offsetHeight + panel4.offsetHeight + panel5.offsetHeight + 21;
|
||||
var panel_height_target = parseInt(panel_height_target);
|
||||
var chatbot_height = chatbot.style.height;
|
||||
var chatbot_height = parseInt(chatbot_height);
|
||||
return { panel_height_target, chatbot_height, chatbot };
|
||||
}
|
||||
</script>
|
||||
"""
|
||||
|
||||
if LAYOUT=="TOP-DOWN":
|
||||
js = ""
|
||||
|
||||
# 添加一个萌萌的看板娘
|
||||
if ADD_WAIFU:
|
||||
js += """
|
||||
<script src="file=docs/waifu_plugin/jquery.min.js"></script>
|
||||
<script src="file=docs/waifu_plugin/jquery-ui.min.js"></script>
|
||||
<script src="file=docs/waifu_plugin/autoload.js"></script>
|
||||
"""
|
||||
gradio_original_template_fn = gr.routes.templates.TemplateResponse
|
||||
def gradio_new_template_fn(*args, **kwargs):
|
||||
res = gradio_original_template_fn(*args, **kwargs)
|
||||
res.body = res.body.replace(b'</html>', f'{js}</html>'.encode("utf8"))
|
||||
res.init_headers()
|
||||
return res
|
||||
gr.routes.templates.TemplateResponse = gradio_new_template_fn # override gradio template
|
||||
except:
|
||||
set_theme = None
|
||||
print('gradio版本较旧, 不能自定义字体和颜色')
|
||||
return set_theme
|
||||
|
||||
|
||||
advanced_css = """
|
||||
.markdown-body table {
|
||||
margin: 1em 0;
|
||||
border-collapse: collapse;
|
||||
@ -243,10 +82,13 @@ advanced_css = """
|
||||
margin: 1em 2em 1em 0.5em;
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
if CODE_HIGHLIGHT:
|
||||
advanced_css += """
|
||||
.app.svelte-1mya07g.svelte-1mya07g {
|
||||
max-width: 95%;
|
||||
position: relative;
|
||||
padding: var(--size-4);
|
||||
width: 95%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
.codehilite .hll { background-color: #6e7681 }
|
||||
.codehilite .c { color: #8b949e; font-style: italic } /* Comment */
|
||||
@ -406,4 +248,3 @@ if CODE_HIGHLIGHT:
|
||||
.dark .codehilite .vm { color: #82AAFF } /* Name.Variable.Magic */
|
||||
.dark .codehilite .il { color: #F78C6C } /* Literal.Number.Integer.Long */
|
||||
|
||||
"""
|
||||
87
theme/default.py
Normal file
87
theme/default.py
Normal file
@ -0,0 +1,87 @@
|
||||
import gradio as gr
|
||||
from toolbox import get_conf
|
||||
CODE_HIGHLIGHT, ADD_WAIFU, LAYOUT = get_conf('CODE_HIGHLIGHT', 'ADD_WAIFU', 'LAYOUT')
|
||||
|
||||
def adjust_theme():
|
||||
|
||||
try:
|
||||
color_er = gr.themes.utils.colors.fuchsia
|
||||
set_theme = gr.themes.Default(
|
||||
primary_hue=gr.themes.utils.colors.orange,
|
||||
neutral_hue=gr.themes.utils.colors.gray,
|
||||
font=["sans-serif", "Microsoft YaHei", "ui-sans-serif", "system-ui",
|
||||
"sans-serif", gr.themes.utils.fonts.GoogleFont("Source Sans Pro")],
|
||||
font_mono=["ui-monospace", "Consolas", "monospace", gr.themes.utils.fonts.GoogleFont("IBM Plex Mono")])
|
||||
set_theme.set(
|
||||
# Colors
|
||||
input_background_fill_dark="*neutral_800",
|
||||
# Transition
|
||||
button_transition="none",
|
||||
# Shadows
|
||||
button_shadow="*shadow_drop",
|
||||
button_shadow_hover="*shadow_drop_lg",
|
||||
button_shadow_active="*shadow_inset",
|
||||
input_shadow="0 0 0 *shadow_spread transparent, *shadow_inset",
|
||||
input_shadow_focus="0 0 0 *shadow_spread *secondary_50, *shadow_inset",
|
||||
input_shadow_focus_dark="0 0 0 *shadow_spread *neutral_700, *shadow_inset",
|
||||
checkbox_label_shadow="*shadow_drop",
|
||||
block_shadow="*shadow_drop",
|
||||
form_gap_width="1px",
|
||||
# Button borders
|
||||
input_border_width="1px",
|
||||
input_background_fill="white",
|
||||
# Gradients
|
||||
stat_background_fill="linear-gradient(to right, *primary_400, *primary_200)",
|
||||
stat_background_fill_dark="linear-gradient(to right, *primary_400, *primary_600)",
|
||||
error_background_fill=f"linear-gradient(to right, {color_er.c100}, *background_fill_secondary)",
|
||||
error_background_fill_dark="*background_fill_primary",
|
||||
checkbox_label_background_fill="linear-gradient(to top, *neutral_50, white)",
|
||||
checkbox_label_background_fill_dark="linear-gradient(to top, *neutral_900, *neutral_800)",
|
||||
checkbox_label_background_fill_hover="linear-gradient(to top, *neutral_100, white)",
|
||||
checkbox_label_background_fill_hover_dark="linear-gradient(to top, *neutral_900, *neutral_800)",
|
||||
button_primary_background_fill="linear-gradient(to bottom right, *primary_100, *primary_300)",
|
||||
button_primary_background_fill_dark="linear-gradient(to bottom right, *primary_500, *primary_600)",
|
||||
button_primary_background_fill_hover="linear-gradient(to bottom right, *primary_100, *primary_200)",
|
||||
button_primary_background_fill_hover_dark="linear-gradient(to bottom right, *primary_500, *primary_500)",
|
||||
button_primary_border_color_dark="*primary_500",
|
||||
button_secondary_background_fill="linear-gradient(to bottom right, *neutral_100, *neutral_200)",
|
||||
button_secondary_background_fill_dark="linear-gradient(to bottom right, *neutral_600, *neutral_700)",
|
||||
button_secondary_background_fill_hover="linear-gradient(to bottom right, *neutral_100, *neutral_100)",
|
||||
button_secondary_background_fill_hover_dark="linear-gradient(to bottom right, *neutral_600, *neutral_600)",
|
||||
button_cancel_background_fill=f"linear-gradient(to bottom right, {color_er.c100}, {color_er.c200})",
|
||||
button_cancel_background_fill_dark=f"linear-gradient(to bottom right, {color_er.c600}, {color_er.c700})",
|
||||
button_cancel_background_fill_hover=f"linear-gradient(to bottom right, {color_er.c100}, {color_er.c100})",
|
||||
button_cancel_background_fill_hover_dark=f"linear-gradient(to bottom right, {color_er.c600}, {color_er.c600})",
|
||||
button_cancel_border_color=color_er.c200,
|
||||
button_cancel_border_color_dark=color_er.c600,
|
||||
button_cancel_text_color=color_er.c600,
|
||||
button_cancel_text_color_dark="white",
|
||||
)
|
||||
|
||||
if LAYOUT=="TOP-DOWN":
|
||||
js = ""
|
||||
else:
|
||||
with open('theme/common.js', 'r', encoding='utf8') as f:
|
||||
js = f"<script>{f.read()}</script>"
|
||||
|
||||
# 添加一个萌萌的看板娘
|
||||
if ADD_WAIFU:
|
||||
js += """
|
||||
<script src="file=docs/waifu_plugin/jquery.min.js"></script>
|
||||
<script src="file=docs/waifu_plugin/jquery-ui.min.js"></script>
|
||||
<script src="file=docs/waifu_plugin/autoload.js"></script>
|
||||
"""
|
||||
gradio_original_template_fn = gr.routes.templates.TemplateResponse
|
||||
def gradio_new_template_fn(*args, **kwargs):
|
||||
res = gradio_original_template_fn(*args, **kwargs)
|
||||
res.body = res.body.replace(b'</html>', f'{js}</html>'.encode("utf8"))
|
||||
res.init_headers()
|
||||
return res
|
||||
gr.routes.templates.TemplateResponse = gradio_new_template_fn # override gradio template
|
||||
except:
|
||||
set_theme = None
|
||||
print('gradio版本较旧, 不能自定义字体和颜色')
|
||||
return set_theme
|
||||
|
||||
with open("theme/default.css", "r", encoding="utf-8") as f:
|
||||
advanced_css = f.read()
|
||||
806
theme/green.css
Normal file
806
theme/green.css
Normal file
@ -0,0 +1,806 @@
|
||||
:root {
|
||||
--chatbot-color-light: #000000;
|
||||
--chatbot-color-dark: #FFFFFF;
|
||||
--chatbot-background-color-light: #F3F3F3;
|
||||
--chatbot-background-color-dark: #121111;
|
||||
--message-user-background-color-light: #95EC69;
|
||||
--message-user-background-color-dark: #26B561;
|
||||
--message-bot-background-color-light: #FFFFFF;
|
||||
--message-bot-background-color-dark: #2C2C2C;
|
||||
}
|
||||
mspace {
|
||||
display: block;
|
||||
}
|
||||
@media only screen and (max-width: 767px) {
|
||||
#column_1 {
|
||||
display: none !important;
|
||||
}
|
||||
}
|
||||
@keyframes highlight {
|
||||
0%, 100% {
|
||||
border: 2px solid transparent;
|
||||
}
|
||||
50% {
|
||||
border-color: yellow;
|
||||
}
|
||||
}
|
||||
|
||||
#highlight_update {
|
||||
animation-name: highlight;
|
||||
animation-duration: 0.75s;
|
||||
animation-iteration-count: 3;
|
||||
}
|
||||
|
||||
.table-wrap.svelte-13hsdno.svelte-13hsdno.svelte-13hsdno {
|
||||
border: 0px solid var(--border-color-primary) !important;
|
||||
}
|
||||
|
||||
#examples_col {
|
||||
z-index: 2;
|
||||
position: absolute;
|
||||
bottom: 0;
|
||||
left: 0;
|
||||
width: 100%;
|
||||
margin-bottom: 30% !important;
|
||||
}
|
||||
#hide_examples {
|
||||
z-index: 0;
|
||||
}
|
||||
|
||||
#debug_mes {
|
||||
position: absolute;
|
||||
display: flex;
|
||||
bottom: 0;
|
||||
left: 0;
|
||||
z-index: 1; /* 设置更高的 z-index 值 */
|
||||
margin-bottom: -4px !important;
|
||||
align-self: flex-end;
|
||||
}
|
||||
#chat_box {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
overflow-y: visible !important;
|
||||
z-index: 3;
|
||||
flex-grow: 1; /* 自动填充剩余空间 */
|
||||
position: absolute;
|
||||
bottom: 0;
|
||||
left: 0;
|
||||
width: 100%;
|
||||
margin-bottom: 30px !important;
|
||||
border: 1px solid var(--border-color-primary);
|
||||
}
|
||||
.toast-body {
|
||||
z-index: 5 !important;
|
||||
}
|
||||
.chat_input {
|
||||
|
||||
}
|
||||
.sm_btn {
|
||||
position: relative;
|
||||
bottom: 5px;
|
||||
height: 10%;
|
||||
border-radius: 20px!important;
|
||||
min-width: min(10%,100%) !important;
|
||||
overflow: hidden;
|
||||
}
|
||||
.sm_select {
|
||||
position: relative !important;
|
||||
z-index: 5 !important;
|
||||
bottom: 5px;
|
||||
min-width: min(20%,100%) !important;
|
||||
border-radius: 20px!important;
|
||||
}
|
||||
.sm_checkbox {
|
||||
position: relative !important;
|
||||
z-index: 5 !important;
|
||||
bottom: 5px;
|
||||
padding: 0 !important;
|
||||
}
|
||||
.sm_select .wrap-inner.svelte-aqlk7e.svelte-aqlk7e.svelte-aqlk7e {
|
||||
padding: 0 !important;
|
||||
}
|
||||
.sm_select .block.svelte-mppz8v {
|
||||
width: 10% !important;
|
||||
}
|
||||
|
||||
/* usage_display */
|
||||
.insert_block {
|
||||
position: relative;
|
||||
bottom: 2px;
|
||||
min-width: min(55px,100%) !important;
|
||||
}
|
||||
|
||||
.submit_btn {
|
||||
flex-direction: column-reverse;
|
||||
overflow-y: auto !important;
|
||||
position: absolute;
|
||||
bottom: 0;
|
||||
right: 10px;
|
||||
margin-bottom: 10px !important;
|
||||
min-width: min(50px,100%) !important;
|
||||
}
|
||||
|
||||
textarea {
|
||||
resize: none;
|
||||
height: 100%; /* 填充父元素的高度 */
|
||||
}
|
||||
#main_chatbot {
|
||||
height: 75vh !important;
|
||||
max-height: 75vh !important;
|
||||
/* overflow: auto !important; */
|
||||
z-index: 2;
|
||||
transform: translateZ(0) !important;
|
||||
backface-visibility: hidden !important;
|
||||
will-change: transform !important;
|
||||
}
|
||||
#prompt_result{
|
||||
height: 60vh !important;
|
||||
max-height: 60vh !important;
|
||||
}
|
||||
|
||||
#app_title {
|
||||
font-weight: var(--prose-header-text-weight);
|
||||
font-size: var(--text-xxl);
|
||||
line-height: 1.3;
|
||||
text-align: left;
|
||||
margin-top: 6px;
|
||||
white-space: nowrap;
|
||||
}
|
||||
#description {
|
||||
text-align: center;
|
||||
margin: 32px 0 4px 0;
|
||||
}
|
||||
|
||||
/* gradio的页脚信息 */
|
||||
footer {
|
||||
/* display: none !important; */
|
||||
margin-top: .2em !important;
|
||||
font-size: 85%;
|
||||
}
|
||||
#footer {
|
||||
text-align: center;
|
||||
}
|
||||
#footer div {
|
||||
display: inline-block;
|
||||
}
|
||||
#footer .versions{
|
||||
font-size: 85%;
|
||||
opacity: 0.60;
|
||||
}
|
||||
/* user_info */
|
||||
|
||||
#float_display {
|
||||
position: absolute;
|
||||
max-height: 30px;
|
||||
}
|
||||
/* user_info */
|
||||
#user_info {
|
||||
white-space: nowrap;
|
||||
position: absolute; left: 8em; top: .2em;
|
||||
z-index: var(--layer-2);
|
||||
box-shadow: var(--block-shadow);
|
||||
border: none; border-radius: var(--block-label-radius);
|
||||
background: var(--color-accent);
|
||||
padding: var(--block-label-padding);
|
||||
font-size: var(--block-label-text-size); line-height: var(--line-sm);
|
||||
width: auto; min-height: 30px !important;
|
||||
opacity: 1;
|
||||
transition: opacity 0.3s ease-in-out;
|
||||
}
|
||||
textarea.svelte-1pie7s6 {
|
||||
background: #e7e6e6 !important;
|
||||
width: 96% !important;
|
||||
}
|
||||
|
||||
.dark textarea.svelte-1pie7s6 {
|
||||
background: var(--input-background-fill) !important;
|
||||
width: 96% !important;
|
||||
}
|
||||
|
||||
.dark input[type=number].svelte-1cl284s {
|
||||
background: #393939 !important;
|
||||
border: var(--input-border-width) solid var(--input-border-color) !important;
|
||||
}
|
||||
.dark input[type="range"] {
|
||||
background: #393939 !important;
|
||||
}
|
||||
#user_info .wrap {
|
||||
opacity: 0;
|
||||
}
|
||||
#user_info p {
|
||||
color: white;
|
||||
font-weight: var(--block-label-text-weight);
|
||||
}
|
||||
#user_info.hideK {
|
||||
opacity: 0;
|
||||
transition: opacity 1s ease-in-out;
|
||||
}
|
||||
[class *= "message"] {
|
||||
gap: 7px !important;
|
||||
border-radius: var(--radius-xl) !important
|
||||
}
|
||||
/* debug_mes */
|
||||
#debug_mes {
|
||||
min-height: 2em;
|
||||
align-items: flex-end;
|
||||
justify-content: flex-end;
|
||||
}
|
||||
#debug_mes p {
|
||||
font-size: .85em;
|
||||
font-family: ui-monospace, "SF Mono", "SFMono-Regular", "Menlo", "Consolas", "Liberation Mono", "Microsoft Yahei UI", "Microsoft Yahei", monospace;
|
||||
/* Windows下中文的monospace会fallback为新宋体,实在太丑,这里折中使用微软雅黑 */
|
||||
color: #000000;
|
||||
}
|
||||
.dark #debug_mes p {
|
||||
color: #ee65ed;
|
||||
}
|
||||
|
||||
#debug_mes {
|
||||
transition: all 0.6s;
|
||||
}
|
||||
#main_chatbot {
|
||||
transition: height 0.3s ease;
|
||||
}
|
||||
|
||||
/* .wrap.svelte-18telvq.svelte-18telvq {
|
||||
padding: var(--block-padding) !important;
|
||||
height: 100% !important;
|
||||
max-height: 95% !important;
|
||||
overflow-y: auto !important;
|
||||
}*/
|
||||
.app.svelte-1mya07g.svelte-1mya07g {
|
||||
max-width: 100%;
|
||||
position: relative;
|
||||
padding: var(--size-4);
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
.gradio-container-3-32-2 h1 {
|
||||
font-weight: 700 !important;
|
||||
font-size: 28px !important;
|
||||
}
|
||||
|
||||
|
||||
.gradio-container-3-32-2 h2 {
|
||||
font-weight: 600 !important;
|
||||
font-size: 24px !important;
|
||||
}
|
||||
.gradio-container-3-32-2 h3 {
|
||||
font-weight: 500 !important;
|
||||
font-size: 20px !important;
|
||||
}
|
||||
.gradio-container-3-32-2 h4 {
|
||||
font-weight: 400 !important;
|
||||
font-size: 16px !important;
|
||||
}
|
||||
.gradio-container-3-32-2 h5 {
|
||||
font-weight: 300 !important;
|
||||
font-size: 14px !important;
|
||||
}
|
||||
.gradio-container-3-32-2 h6 {
|
||||
font-weight: 200 !important;
|
||||
font-size: 12px !important;
|
||||
}
|
||||
|
||||
|
||||
#usage_display p, #usage_display span {
|
||||
margin: 0;
|
||||
font-size: .85em;
|
||||
color: var(--body-text-color-subdued);
|
||||
}
|
||||
.progress-bar {
|
||||
background-color: var(--input-background-fill);;
|
||||
margin: .5em 0 !important;
|
||||
height: 20px;
|
||||
border-radius: 10px;
|
||||
overflow: hidden;
|
||||
}
|
||||
.progress {
|
||||
background-color: var(--block-title-background-fill);
|
||||
height: 100%;
|
||||
border-radius: 10px;
|
||||
text-align: right;
|
||||
transition: width 0.5s ease-in-out;
|
||||
}
|
||||
.progress-text {
|
||||
/* color: white; */
|
||||
color: var(--color-accent) !important;
|
||||
font-size: 1em !important;
|
||||
font-weight: bold;
|
||||
padding-right: 10px;
|
||||
line-height: 20px;
|
||||
}
|
||||
|
||||
.apSwitch {
|
||||
top: 2px;
|
||||
display: inline-block;
|
||||
height: 24px;
|
||||
position: relative;
|
||||
width: 48px;
|
||||
border-radius: 12px;
|
||||
}
|
||||
.apSwitch input {
|
||||
display: none !important;
|
||||
}
|
||||
.apSlider {
|
||||
background-color: var(--neutral-200);
|
||||
bottom: 0;
|
||||
cursor: pointer;
|
||||
left: 0;
|
||||
position: absolute;
|
||||
right: 0;
|
||||
top: 0;
|
||||
transition: .4s;
|
||||
font-size: 18px;
|
||||
border-radius: 7px;
|
||||
}
|
||||
.apSlider::before {
|
||||
bottom: -1.5px;
|
||||
left: 1px;
|
||||
position: absolute;
|
||||
transition: .4s;
|
||||
content: "🌞";
|
||||
}
|
||||
hr.append-display {
|
||||
margin: 8px 0;
|
||||
border: none;
|
||||
height: 1px;
|
||||
border-top-width: 0;
|
||||
background-image: linear-gradient(to right, rgba(50,50,50, 0.1), rgba(150, 150, 150, 0.8), rgba(50,50,50, 0.1));
|
||||
}
|
||||
.source-a {
|
||||
font-size: 0.8em;
|
||||
max-width: 100%;
|
||||
margin: 0;
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
flex-wrap: wrap;
|
||||
align-items: center;
|
||||
/* background-color: #dddddd88; */
|
||||
border-radius: 1.5rem;
|
||||
padding: 0.2em;
|
||||
}
|
||||
.source-a a {
|
||||
display: inline-block;
|
||||
background-color: #aaaaaa50;
|
||||
border-radius: 1rem;
|
||||
padding: 0.5em;
|
||||
text-align: center;
|
||||
text-overflow: ellipsis;
|
||||
overflow: hidden;
|
||||
min-width: 20%;
|
||||
white-space: nowrap;
|
||||
margin: 0.2rem 0.1rem;
|
||||
text-decoration: none !important;
|
||||
flex: 1;
|
||||
transition: flex 0.5s;
|
||||
}
|
||||
.source-a a:hover {
|
||||
background-color: #aaaaaa20;
|
||||
flex: 2;
|
||||
}
|
||||
input:checked + .apSlider {
|
||||
background-color: var(--primary-600);
|
||||
}
|
||||
input:checked + .apSlider::before {
|
||||
transform: translateX(23px);
|
||||
content:"🌚";
|
||||
}
|
||||
|
||||
/* Override Slider Styles (for webkit browsers like Safari and Chrome)
|
||||
* 好希望这份提案能早日实现 https://github.com/w3c/csswg-drafts/issues/4410
|
||||
* 进度滑块在各个平台还是太不统一了
|
||||
*/
|
||||
input[type="range"] {
|
||||
-webkit-appearance: none;
|
||||
height: 4px;
|
||||
background: var(--input-background-fill);
|
||||
border-radius: 5px;
|
||||
background-image: linear-gradient(var(--primary-500),var(--primary-500));
|
||||
background-size: 0% 100%;
|
||||
background-repeat: no-repeat;
|
||||
}
|
||||
input[type="range"]::-webkit-slider-thumb {
|
||||
-webkit-appearance: none;
|
||||
height: 20px;
|
||||
width: 20px;
|
||||
border-radius: 50%;
|
||||
border: solid 0.5px #ddd;
|
||||
background-color: white;
|
||||
cursor: ew-resize;
|
||||
box-shadow: var(--input-shadow);
|
||||
transition: background-color .1s ease;
|
||||
}
|
||||
input[type="range"]::-webkit-slider-thumb:hover {
|
||||
background: var(--neutral-50);
|
||||
}
|
||||
input[type=range]::-webkit-slider-runnable-track {
|
||||
-webkit-appearance: none;
|
||||
box-shadow: none;
|
||||
border: none;
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.submit_btn, #cancel_btn {
|
||||
height: 42px !important;
|
||||
}
|
||||
.submit_btn::before {
|
||||
content: url("data:image/svg+xml, %3Csvg width='21px' height='20px' viewBox='0 0 21 20' version='1.1' xmlns='http://www.w3.org/2000/svg' xmlns:xlink='http://www.w3.org/1999/xlink'%3E %3Cg id='page' stroke='none' stroke-width='1' fill='none' fill-rule='evenodd'%3E %3Cg id='send' transform='translate(0.435849, 0.088463)' fill='%23FFFFFF' fill-rule='nonzero'%3E %3Cpath d='M0.579148261,0.0428666046 C0.301105539,-0.0961547561 -0.036517765,0.122307382 0.0032026237,0.420210298 L1.4927172,18.1553639 C1.5125774,18.4334066 1.79062012,18.5922882 2.04880264,18.4929872 L8.24518329,15.8913017 L11.6412765,19.7441794 C11.8597387,19.9825018 12.2370824,19.8832008 12.3165231,19.5852979 L13.9450591,13.4882182 L19.7839562,11.0255541 C20.0619989,10.8865327 20.0818591,10.4694687 19.7839562,10.3105871 L0.579148261,0.0428666046 Z M11.6138902,17.0883151 L9.85385903,14.7195502 L0.718169621,0.618812241 L12.69945,12.9346347 L11.6138902,17.0883151 Z' id='shape'%3E%3C/path%3E %3C/g%3E %3C/g%3E %3C/svg%3E");
|
||||
height: 21px;
|
||||
}
|
||||
|
||||
#cancel_btn::before {
|
||||
content: url("data:image/svg+xml,%3Csvg width='21px' height='21px' viewBox='0 0 21 21' version='1.1' xmlns='http://www.w3.org/2000/svg' xmlns:xlink='http://www.w3.org/1999/xlink'%3E %3Cg id='pg' stroke='none' stroke-width='1' fill='none' fill-rule='evenodd'%3E %3Cpath d='M10.2072007,20.088463 C11.5727865,20.088463 12.8594566,19.8259823 14.067211,19.3010209 C15.2749653,18.7760595 16.3386126,18.0538087 17.2581528,17.1342685 C18.177693,16.2147282 18.8982283,15.1527965 19.4197586,13.9484733 C19.9412889,12.7441501 20.202054,11.4557644 20.202054,10.0833163 C20.202054,8.71773046 19.9395733,7.43106036 19.4146119,6.22330603 C18.8896505,5.01555169 18.1673997,3.95018885 17.2478595,3.0272175 C16.3283192,2.10424615 15.2646719,1.3837109 14.0569176,0.865611739 C12.8491633,0.34751258 11.5624932,0.088463 10.1969073,0.088463 C8.83132146,0.088463 7.54636692,0.34751258 6.34204371,0.865611739 C5.1377205,1.3837109 4.07407321,2.10424615 3.15110186,3.0272175 C2.22813051,3.95018885 1.5058797,5.01555169 0.984349419,6.22330603 C0.46281914,7.43106036 0.202054,8.71773046 0.202054,10.0833163 C0.202054,11.4557644 0.4645347,12.7441501 0.9894961,13.9484733 C1.5144575,15.1527965 2.23670831,16.2147282 3.15624854,17.1342685 C4.07578877,18.0538087 5.1377205,18.7760595 6.34204371,19.3010209 C7.54636692,19.8259823 8.83475258,20.088463 10.2072007,20.088463 Z M10.2072007,18.2562448 C9.07493099,18.2562448 8.01471483,18.0452309 7.0265522,17.6232031 C6.03838956,17.2011753 5.17031614,16.6161693 4.42233192,15.8681851 C3.6743477,15.1202009 3.09105726,14.2521274 2.67246059,13.2639648 C2.25386392,12.2758022 2.04456558,11.215586 2.04456558,10.0833163 C2.04456558,8.95104663 2.25386392,7.89083047 2.67246059,6.90266784 C3.09105726,5.9145052 3.6743477,5.04643178 4.42233192,4.29844756 C5.17031614,3.55046334 6.036674,2.9671729 7.02140552,2.54857623 C8.00613703,2.12997956 9.06463763,1.92068122 10.1969073,1.92068122 C11.329177,1.92068122 12.3911087,2.12997956 13.3827025,2.54857623 C14.3742962,2.9671729 15.2440852,3.55046334 15.9920694,4.29844756 C16.7400537,5.04643178 17.3233441,5.9145052 17.7419408,6.90266784 C18.1605374,7.89083047 18.3698358,8.95104663 18.3698358,10.0833163 C18.3698358,11.215586 18.1605374,12.2758022 17.7419408,13.2639648 C17.3233441,14.2521274 16.7400537,15.1202009 15.9920694,15.8681851 C15.2440852,16.6161693 14.3760118,17.2011753 13.3878492,17.6232031 C12.3996865,18.0452309 11.3394704,18.2562448 10.2072007,18.2562448 Z M7.65444721,13.6242324 L12.7496608,13.6242324 C13.0584616,13.6242324 13.3003556,13.5384544 13.4753427,13.3668984 C13.6503299,13.1953424 13.7378234,12.9585951 13.7378234,12.6566565 L13.7378234,7.49968276 C13.7378234,7.19774418 13.6503299,6.96099688 13.4753427,6.78944087 C13.3003556,6.61788486 13.0584616,6.53210685 12.7496608,6.53210685 L7.65444721,6.53210685 C7.33878414,6.53210685 7.09345904,6.61788486 6.91847191,6.78944087 C6.74348478,6.96099688 6.65599121,7.19774418 6.65599121,7.49968276 L6.65599121,12.6566565 C6.65599121,12.9585951 6.74348478,13.1953424 6.91847191,13.3668984 C7.09345904,13.5384544 7.33878414,13.6242324 7.65444721,13.6242324 Z' id='shape' fill='%23FF3B30' fill-rule='nonzero'%3E%3C/path%3E %3C/g%3E %3C/svg%3E");
|
||||
height: 21px;
|
||||
}
|
||||
/* list */
|
||||
ol:not(.options), ul:not(.options) {
|
||||
padding-inline-start: 2em !important;
|
||||
}
|
||||
|
||||
/* 亮色(默认) */
|
||||
#main_chatbot {
|
||||
background-color: var(--chatbot-background-color-light) !important;
|
||||
color: var(--chatbot-color-light) !important;
|
||||
}
|
||||
/* 暗色 */
|
||||
.dark #main_chatbot {
|
||||
background-color: var(--block-background-fill) !important;
|
||||
color: var(--chatbot-color-dark) !important;
|
||||
}
|
||||
|
||||
/* 屏幕宽度大于等于500px的设备 */
|
||||
/* update on 2023.4.8: 高度的细致调整已写入JavaScript */
|
||||
@media screen and (min-width: 500px) {
|
||||
#main_chatbot {
|
||||
height: calc(100vh - 200px);
|
||||
}
|
||||
#main_chatbot .wrap {
|
||||
max-height: calc(100vh - 200px - var(--line-sm)*1rem - 2*var(--block-label-margin) );
|
||||
}
|
||||
}
|
||||
/* 屏幕宽度小于500px的设备 */
|
||||
@media screen and (max-width: 499px) {
|
||||
#main_chatbot {
|
||||
height: calc(100vh - 140px);
|
||||
}
|
||||
#main_chatbot .wrap {
|
||||
max-height: calc(100vh - 140px - var(--line-sm)*1rem - 2*var(--block-label-margin) );
|
||||
}
|
||||
[data-testid = "bot"] {
|
||||
max-width: 95% !important;
|
||||
}
|
||||
#app_title h1{
|
||||
letter-spacing: -1px; font-size: 22px;
|
||||
}
|
||||
}
|
||||
#main_chatbot .wrap {
|
||||
overflow-x: hidden
|
||||
}
|
||||
/* 对话气泡 */
|
||||
.message {
|
||||
border-radius: var(--radius-xl) !important;
|
||||
border: none;
|
||||
padding: var(--spacing-xl) !important;
|
||||
font-size: 15px !important;
|
||||
line-height: var(--line-md) !important;
|
||||
min-height: calc(var(--text-md)*var(--line-md) + 2*var(--spacing-xl));
|
||||
min-width: calc(var(--text-md)*var(--line-md) + 2*var(--spacing-xl));
|
||||
}
|
||||
[data-testid = "bot"] {
|
||||
max-width: 85%;
|
||||
border-bottom-left-radius: 0 !important;
|
||||
}
|
||||
[data-testid = "user"] {
|
||||
max-width: 85%;
|
||||
width: auto !important;
|
||||
border-bottom-right-radius: 0 !important;
|
||||
}
|
||||
|
||||
.message p {
|
||||
margin-top: 0.6em !important;
|
||||
margin-bottom: 0.6em !important;
|
||||
}
|
||||
.message p:first-child { margin-top: 0 !important; }
|
||||
.message p:last-of-type { margin-bottom: 0 !important; }
|
||||
|
||||
.message .md-message {
|
||||
display: block;
|
||||
padding: 0 !important;
|
||||
}
|
||||
.message .raw-message {
|
||||
display: block;
|
||||
padding: 0 !important;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
.raw-message.hideM, .md-message.hideM {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* custom buttons */
|
||||
.chuanhu-btn {
|
||||
border-radius: 5px;
|
||||
/* background-color: #E6E6E6 !important; */
|
||||
color: rgba(120, 120, 120, 0.64) !important;
|
||||
padding: 4px !important;
|
||||
position: absolute;
|
||||
right: -22px;
|
||||
cursor: pointer !important;
|
||||
transition: color .2s ease, background-color .2s ease;
|
||||
}
|
||||
.chuanhu-btn:hover {
|
||||
background-color: rgba(167, 167, 167, 0.25) !important;
|
||||
color: unset !important;
|
||||
}
|
||||
.chuanhu-btn:active {
|
||||
background-color: rgba(167, 167, 167, 0.5) !important;
|
||||
}
|
||||
.chuanhu-btn:focus {
|
||||
outline: none;
|
||||
}
|
||||
.copy-bot-btn {
|
||||
/* top: 18px; */
|
||||
bottom: 0;
|
||||
}
|
||||
.toggle-md-btn {
|
||||
/* top: 0; */
|
||||
bottom: 20px;
|
||||
}
|
||||
.copy-code-btn {
|
||||
position: relative;
|
||||
float: right;
|
||||
font-size: 1em;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.message-wrap>div img{
|
||||
border-radius: 10px !important;
|
||||
}
|
||||
|
||||
/* history message */
|
||||
.wrap>.history-message {
|
||||
padding: 10px !important;
|
||||
}
|
||||
.history-message {
|
||||
/* padding: 0 !important; */
|
||||
opacity: 80%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
.history-message>.history-message {
|
||||
padding: 0 !important;
|
||||
}
|
||||
.history-message>.message-wrap {
|
||||
padding: 0 !important;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.history-message>.message {
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.wrap>.history-message::after {
|
||||
content: "";
|
||||
display: block;
|
||||
height: 2px;
|
||||
background-color: var(--body-text-color-subdued);
|
||||
margin-bottom: 10px;
|
||||
margin-top: -10px;
|
||||
clear: both;
|
||||
}
|
||||
.wrap>.history-message>:last-child::after {
|
||||
content: "仅供查看";
|
||||
display: block;
|
||||
text-align: center;
|
||||
color: var(--body-text-color-subdued);
|
||||
font-size: 0.8em;
|
||||
}
|
||||
|
||||
/* 表格 */
|
||||
table {
|
||||
margin: 1em 0;
|
||||
border-collapse: collapse;
|
||||
empty-cells: show;
|
||||
}
|
||||
td,th {
|
||||
border: 1.2px solid var(--border-color-primary) !important;
|
||||
padding: 0.2em;
|
||||
}
|
||||
thead {
|
||||
background-color: rgba(175,184,193,0.2);
|
||||
}
|
||||
thead th {
|
||||
padding: .5em .2em;
|
||||
}
|
||||
/* 行内代码 */
|
||||
.message :not(pre) code {
|
||||
display: inline;
|
||||
white-space: break-spaces;
|
||||
border-radius: 6px;
|
||||
margin: 0 2px 0 2px;
|
||||
padding: .2em .4em .1em .4em;
|
||||
background-color: rgba(175,184,193,0.2);
|
||||
}
|
||||
/* 代码块 */
|
||||
.message pre code {
|
||||
display: block;
|
||||
overflow: auto;
|
||||
white-space: pre;
|
||||
background-color: hsla(0, 0%, 7%, 70%)!important;
|
||||
border-radius: 10px;
|
||||
padding: 1.2em 1em 0em .5em;
|
||||
margin: 0.6em 2em 1em 0.2em;
|
||||
color: #FFF;
|
||||
box-shadow: 6px 6px 16px hsla(0, 0%, 0%, 0.2);
|
||||
}
|
||||
.dark .message pre code {
|
||||
background-color: hsla(0, 0%, 20%, 300%)!important;
|
||||
}
|
||||
.message pre {
|
||||
padding: 0 !important;
|
||||
}
|
||||
.message pre code div.highlight {
|
||||
background-color: unset !important;
|
||||
}
|
||||
|
||||
button.copy-button {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* 代码高亮样式 */
|
||||
.codehilite .hll { background-color: #6e7681 }
|
||||
.codehilite .c { color: #8b949e; font-style: italic } /* Comment */
|
||||
.codehilite .err { color: #f85149 } /* Error */
|
||||
.codehilite .esc { color: #c9d1d9 } /* Escape */
|
||||
.codehilite .g { color: #c9d1d9 } /* Generic */
|
||||
.codehilite .k { color: #ff7b72 } /* Keyword */
|
||||
.codehilite .l { color: #a5d6ff } /* Literal */
|
||||
.codehilite .n { color: #c9d1d9 } /* Name */
|
||||
.codehilite .o { color: #ff7b72; font-weight: bold } /* Operator */
|
||||
.codehilite .x { color: #c9d1d9 } /* Other */
|
||||
.codehilite .p { color: #c9d1d9 } /* Punctuation */
|
||||
.codehilite .ch { color: #8b949e; font-style: italic } /* Comment.Hashbang */
|
||||
.codehilite .cm { color: #8b949e; font-style: italic } /* Comment.Multiline */
|
||||
.codehilite .cp { color: #8b949e; font-weight: bold; font-style: italic } /* Comment.Preproc */
|
||||
.codehilite .cpf { color: #8b949e; font-style: italic } /* Comment.PreprocFile */
|
||||
.codehilite .c1 { color: #8b949e; font-style: italic } /* Comment.Single */
|
||||
.codehilite .cs { color: #8b949e; font-weight: bold; font-style: italic } /* Comment.Special */
|
||||
.codehilite .gd { color: #ffa198; background-color: #490202 } /* Generic.Deleted */
|
||||
.codehilite .ge { color: #c9d1d9; font-style: italic } /* Generic.Emph */
|
||||
.codehilite .gr { color: #ffa198 } /* Generic.Error */
|
||||
.codehilite .gh { color: #79c0ff; font-weight: bold } /* Generic.Heading */
|
||||
.codehilite .gi { color: #56d364; background-color: #0f5323 } /* Generic.Inserted */
|
||||
.codehilite .go { color: #8b949e } /* Generic.Output */
|
||||
.codehilite .gp { color: #8b949e } /* Generic.Prompt */
|
||||
.codehilite .gs { color: #c9d1d9; font-weight: bold } /* Generic.Strong */
|
||||
.codehilite .gu { color: #79c0ff } /* Generic.Subheading */
|
||||
.codehilite .gt { color: #ff7b72 } /* Generic.Traceback */
|
||||
.codehilite .g-Underline { color: #c9d1d9; text-decoration: underline } /* Generic.Underline */
|
||||
.codehilite .kc { color: #79c0ff } /* Keyword.Constant */
|
||||
.codehilite .kd { color: #ff7b72 } /* Keyword.Declaration */
|
||||
.codehilite .kn { color: #ff7b72 } /* Keyword.Namespace */
|
||||
.codehilite .kp { color: #79c0ff } /* Keyword.Pseudo */
|
||||
.codehilite .kr { color: #ff7b72 } /* Keyword.Reserved */
|
||||
.codehilite .kt { color: #ff7b72 } /* Keyword.Type */
|
||||
.codehilite .ld { color: #79c0ff } /* Literal.Date */
|
||||
.codehilite .m { color: #a5d6ff } /* Literal.Number */
|
||||
.codehilite .s { color: #a5d6ff } /* Literal.String */
|
||||
.codehilite .na { color: #c9d1d9 } /* Name.Attribute */
|
||||
.codehilite .nb { color: #c9d1d9 } /* Name.Builtin */
|
||||
.codehilite .nc { color: #f0883e; font-weight: bold } /* Name.Class */
|
||||
.codehilite .no { color: #79c0ff; font-weight: bold } /* Name.Constant */
|
||||
.codehilite .nd { color: #d2a8ff; font-weight: bold } /* Name.Decorator */
|
||||
.codehilite .ni { color: #ffa657 } /* Name.Entity */
|
||||
.codehilite .ne { color: #f0883e; font-weight: bold } /* Name.Exception */
|
||||
.codehilite .nf { color: #d2a8ff; font-weight: bold } /* Name.Function */
|
||||
.codehilite .nl { color: #79c0ff; font-weight: bold } /* Name.Label */
|
||||
.codehilite .nn { color: #ff7b72 } /* Name.Namespace */
|
||||
.codehilite .nx { color: #c9d1d9 } /* Name.Other */
|
||||
.codehilite .py { color: #79c0ff } /* Name.Property */
|
||||
.codehilite .nt { color: #7ee787 } /* Name.Tag */
|
||||
.codehilite .nv { color: #79c0ff } /* Name.Variable */
|
||||
.codehilite .ow { color: #ff7b72; font-weight: bold } /* Operator.Word */
|
||||
.codehilite .pm { color: #c9d1d9 } /* Punctuation.Marker */
|
||||
.codehilite .w { color: #6e7681 } /* Text.Whitespace */
|
||||
.codehilite .mb { color: #a5d6ff } /* Literal.Number.Bin */
|
||||
.codehilite .mf { color: #a5d6ff } /* Literal.Number.Float */
|
||||
.codehilite .mh { color: #a5d6ff } /* Literal.Number.Hex */
|
||||
.codehilite .mi { color: #a5d6ff } /* Literal.Number.Integer */
|
||||
.codehilite .mo { color: #a5d6ff } /* Literal.Number.Oct */
|
||||
.codehilite .sa { color: #79c0ff } /* Literal.String.Affix */
|
||||
.codehilite .sb { color: #a5d6ff } /* Literal.String.Backtick */
|
||||
.codehilite .sc { color: #a5d6ff } /* Literal.String.Char */
|
||||
.codehilite .dl { color: #79c0ff } /* Literal.String.Delimiter */
|
||||
.codehilite .sd { color: #a5d6ff } /* Literal.String.Doc */
|
||||
.codehilite .s2 { color: #a5d6ff } /* Literal.String.Double */
|
||||
.codehilite .se { color: #79c0ff } /* Literal.String.Escape */
|
||||
.codehilite .sh { color: #79c0ff } /* Literal.String.Heredoc */
|
||||
.codehilite .si { color: #a5d6ff } /* Literal.String.Interpol */
|
||||
.codehilite .sx { color: #a5d6ff } /* Literal.String.Other */
|
||||
.codehilite .sr { color: #79c0ff } /* Literal.String.Regex */
|
||||
.codehilite .s1 { color: #a5d6ff } /* Literal.String.Single */
|
||||
.codehilite .ss { color: #a5d6ff } /* Literal.String.Symbol */
|
||||
.codehilite .bp { color: #c9d1d9 } /* Name.Builtin.Pseudo */
|
||||
.codehilite .fm { color: #d2a8ff; font-weight: bold } /* Name.Function.Magic */
|
||||
.codehilite .vc { color: #79c0ff } /* Name.Variable.Class */
|
||||
.codehilite .vg { color: #79c0ff } /* Name.Variable.Global */
|
||||
.codehilite .vi { color: #79c0ff } /* Name.Variable.Instance */
|
||||
.codehilite .vm { color: #79c0ff } /* Name.Variable.Magic */
|
||||
.codehilite .il { color: #a5d6ff } /* Literal.Number.Integer.Long */
|
||||
|
||||
.dark .codehilite .hll { background-color: #2C3B41 }
|
||||
.dark .codehilite .c { color: #79d618; font-style: italic } /* Comment */
|
||||
.dark .codehilite .err { color: #FF5370 } /* Error */
|
||||
.dark .codehilite .esc { color: #89DDFF } /* Escape */
|
||||
.dark .codehilite .g { color: #EEFFFF } /* Generic */
|
||||
.dark .codehilite .k { color: #BB80B3 } /* Keyword */
|
||||
.dark .codehilite .l { color: #C3E88D } /* Literal */
|
||||
.dark .codehilite .n { color: #EEFFFF } /* Name */
|
||||
.dark .codehilite .o { color: #89DDFF } /* Operator */
|
||||
.dark .codehilite .p { color: #89DDFF } /* Punctuation */
|
||||
.dark .codehilite .ch { color: #79d618; font-style: italic } /* Comment.Hashbang */
|
||||
.dark .codehilite .cm { color: #79d618; font-style: italic } /* Comment.Multiline */
|
||||
.dark .codehilite .cp { color: #79d618; font-style: italic } /* Comment.Preproc */
|
||||
.dark .codehilite .cpf { color: #79d618; font-style: italic } /* Comment.PreprocFile */
|
||||
.dark .codehilite .c1 { color: #79d618; font-style: italic } /* Comment.Single */
|
||||
.dark .codehilite .cs { color: #79d618; font-style: italic } /* Comment.Special */
|
||||
.dark .codehilite .gd { color: #FF5370 } /* Generic.Deleted */
|
||||
.dark .codehilite .ge { color: #89DDFF } /* Generic.Emph */
|
||||
.dark .codehilite .gr { color: #FF5370 } /* Generic.Error */
|
||||
.dark .codehilite .gh { color: #C3E88D } /* Generic.Heading */
|
||||
.dark .codehilite .gi { color: #C3E88D } /* Generic.Inserted */
|
||||
.dark .codehilite .go { color: #79d618 } /* Generic.Output */
|
||||
.dark .codehilite .gp { color: #FFCB6B } /* Generic.Prompt */
|
||||
.dark .codehilite .gs { color: #FF5370 } /* Generic.Strong */
|
||||
.dark .codehilite .gu { color: #89DDFF } /* Generic.Subheading */
|
||||
.dark .codehilite .gt { color: #FF5370 } /* Generic.Traceback */
|
||||
.dark .codehilite .kc { color: #89DDFF } /* Keyword.Constant */
|
||||
.dark .codehilite .kd { color: #BB80B3 } /* Keyword.Declaration */
|
||||
.dark .codehilite .kn { color: #89DDFF; font-style: italic } /* Keyword.Namespace */
|
||||
.dark .codehilite .kp { color: #89DDFF } /* Keyword.Pseudo */
|
||||
.dark .codehilite .kr { color: #BB80B3 } /* Keyword.Reserved */
|
||||
.dark .codehilite .kt { color: #BB80B3 } /* Keyword.Type */
|
||||
.dark .codehilite .ld { color: #C3E88D } /* Literal.Date */
|
||||
.dark .codehilite .m { color: #F78C6C } /* Literal.Number */
|
||||
.dark .codehilite .s { color: #C3E88D } /* Literal.String */
|
||||
.dark .codehilite .na { color: #BB80B3 } /* Name.Attribute */
|
||||
.dark .codehilite .nb { color: #82AAFF } /* Name.Builtin */
|
||||
.dark .codehilite .nc { color: #FFCB6B } /* Name.Class */
|
||||
.dark .codehilite .no { color: #EEFFFF } /* Name.Constant */
|
||||
.dark .codehilite .nd { color: #82AAFF } /* Name.Decorator */
|
||||
.dark .codehilite .ni { color: #89DDFF } /* Name.Entity */
|
||||
.dark .codehilite .ne { color: #FFCB6B } /* Name.Exception */
|
||||
.dark .codehilite .nf { color: #82AAFF } /* Name.Function */
|
||||
.dark .codehilite .nl { color: #82AAFF } /* Name.Label */
|
||||
.dark .codehilite .nn { color: #FFCB6B } /* Name.Namespace */
|
||||
.dark .codehilite .nx { color: #EEFFFF } /* Name.Other */
|
||||
.dark .codehilite .py { color: #FFCB6B } /* Name.Property */
|
||||
.dark .codehilite .nt { color: #FF5370 } /* Name.Tag */
|
||||
.dark .codehilite .nv { color: #89DDFF } /* Name.Variable */
|
||||
.dark .codehilite .ow { color: #89DDFF; font-style: italic } /* Operator.Word */
|
||||
.dark .codehilite .pm { color: #89DDFF } /* Punctuation.Marker */
|
||||
.dark .codehilite .w { color: #EEFFFF } /* Text.Whitespace */
|
||||
.dark .codehilite .mb { color: #F78C6C } /* Literal.Number.Bin */
|
||||
.dark .codehilite .mf { color: #F78C6C } /* Literal.Number.Float */
|
||||
.dark .codehilite .mh { color: #F78C6C } /* Literal.Number.Hex */
|
||||
.dark .codehilite .mi { color: #F78C6C } /* Literal.Number.Integer */
|
||||
.dark .codehilite .mo { color: #F78C6C } /* Literal.Number.Oct */
|
||||
.dark .codehilite .sa { color: #BB80B3 } /* Literal.String.Affix */
|
||||
.dark .codehilite .sb { color: #C3E88D } /* Literal.String.Backtick */
|
||||
.dark .codehilite .sc { color: #C3E88D } /* Literal.String.Char */
|
||||
.dark .codehilite .dl { color: #EEFFFF } /* Literal.String.Delimiter */
|
||||
.dark .codehilite .sd { color: #79d618; font-style: italic } /* Literal.String.Doc */
|
||||
.dark .codehilite .s2 { color: #C3E88D } /* Literal.String.Double */
|
||||
.dark .codehilite .se { color: #EEFFFF } /* Literal.String.Escape */
|
||||
.dark .codehilite .sh { color: #C3E88D } /* Literal.String.Heredoc */
|
||||
.dark .codehilite .si { color: #89DDFF } /* Literal.String.Interpol */
|
||||
.dark .codehilite .sx { color: #C3E88D } /* Literal.String.Other */
|
||||
.dark .codehilite .sr { color: #89DDFF } /* Literal.String.Regex */
|
||||
.dark .codehilite .s1 { color: #C3E88D } /* Literal.String.Single */
|
||||
.dark .codehilite .ss { color: #89DDFF } /* Literal.String.Symbol */
|
||||
.dark .codehilite .bp { color: #89DDFF } /* Name.Builtin.Pseudo */
|
||||
.dark .codehilite .fm { color: #82AAFF } /* Name.Function.Magic */
|
||||
.dark .codehilite .vc { color: #89DDFF } /* Name.Variable.Class */
|
||||
.dark .codehilite .vg { color: #89DDFF } /* Name.Variable.Global */
|
||||
.dark .codehilite .vi { color: #89DDFF } /* Name.Variable.Instance */
|
||||
.dark .codehilite .vm { color: #82AAFF } /* Name.Variable.Magic */
|
||||
.dark .codehilite .il { color: #F78C6C } /* Literal.Number.Integer.Long */
|
||||
104
theme/green.py
Normal file
104
theme/green.py
Normal file
@ -0,0 +1,104 @@
|
||||
import gradio as gr
|
||||
from toolbox import get_conf
|
||||
CODE_HIGHLIGHT, ADD_WAIFU, LAYOUT = get_conf('CODE_HIGHLIGHT', 'ADD_WAIFU', 'LAYOUT')
|
||||
|
||||
def adjust_theme():
|
||||
try:
|
||||
set_theme = gr.themes.Soft(
|
||||
primary_hue=gr.themes.Color(
|
||||
c50="#EBFAF2",
|
||||
c100="#CFF3E1",
|
||||
c200="#A8EAC8",
|
||||
c300="#77DEA9",
|
||||
c400="#3FD086",
|
||||
c500="#02C160",
|
||||
c600="#06AE56",
|
||||
c700="#05974E",
|
||||
c800="#057F45",
|
||||
c900="#04673D",
|
||||
c950="#2E5541",
|
||||
name="small_and_beautiful",
|
||||
),
|
||||
secondary_hue=gr.themes.Color(
|
||||
c50="#576b95",
|
||||
c100="#576b95",
|
||||
c200="#576b95",
|
||||
c300="#576b95",
|
||||
c400="#576b95",
|
||||
c500="#576b95",
|
||||
c600="#576b95",
|
||||
c700="#576b95",
|
||||
c800="#576b95",
|
||||
c900="#576b95",
|
||||
c950="#576b95",
|
||||
),
|
||||
neutral_hue=gr.themes.Color(
|
||||
name="gray",
|
||||
c50="#f6f7f8",
|
||||
# c100="#f3f4f6",
|
||||
c100="#F2F2F2",
|
||||
c200="#e5e7eb",
|
||||
c300="#d1d5db",
|
||||
c400="#B2B2B2",
|
||||
c500="#808080",
|
||||
c600="#636363",
|
||||
c700="#515151",
|
||||
c800="#393939",
|
||||
# c900="#272727",
|
||||
c900="#2B2B2B",
|
||||
c950="#171717",
|
||||
),
|
||||
|
||||
radius_size=gr.themes.sizes.radius_sm,
|
||||
).set(
|
||||
button_primary_background_fill="*primary_500",
|
||||
button_primary_background_fill_dark="*primary_600",
|
||||
button_primary_background_fill_hover="*primary_400",
|
||||
button_primary_border_color="*primary_500",
|
||||
button_primary_border_color_dark="*primary_600",
|
||||
button_primary_text_color="wihte",
|
||||
button_primary_text_color_dark="white",
|
||||
button_secondary_background_fill="*neutral_100",
|
||||
button_secondary_background_fill_hover="*neutral_50",
|
||||
button_secondary_background_fill_dark="*neutral_900",
|
||||
button_secondary_text_color="*neutral_800",
|
||||
button_secondary_text_color_dark="white",
|
||||
background_fill_primary="#F7F7F7",
|
||||
background_fill_primary_dark="#1F1F1F",
|
||||
block_title_text_color="*primary_500",
|
||||
block_title_background_fill_dark="*primary_900",
|
||||
block_label_background_fill_dark="*primary_900",
|
||||
input_background_fill="#F6F6F6",
|
||||
chatbot_code_background_color="*neutral_950",
|
||||
chatbot_code_background_color_dark="*neutral_950",
|
||||
)
|
||||
|
||||
js = ''
|
||||
if LAYOUT=="TOP-DOWN":
|
||||
js = ""
|
||||
else:
|
||||
with open('theme/common.js', 'r', encoding='utf8') as f:
|
||||
js = f"<script>{f.read()}</script>"
|
||||
|
||||
# 添加一个萌萌的看板娘
|
||||
if ADD_WAIFU:
|
||||
js += """
|
||||
<script src="file=docs/waifu_plugin/jquery.min.js"></script>
|
||||
<script src="file=docs/waifu_plugin/jquery-ui.min.js"></script>
|
||||
<script src="file=docs/waifu_plugin/autoload.js"></script>
|
||||
"""
|
||||
gradio_original_template_fn = gr.routes.templates.TemplateResponse
|
||||
def gradio_new_template_fn(*args, **kwargs):
|
||||
res = gradio_original_template_fn(*args, **kwargs)
|
||||
res.body = res.body.replace(b'</html>', f'{js}</html>'.encode("utf8"))
|
||||
res.init_headers()
|
||||
return res
|
||||
gr.routes.templates.TemplateResponse = gradio_new_template_fn # override gradio template
|
||||
except:
|
||||
set_theme = None
|
||||
print('gradio版本较旧, 不能自定义字体和颜色')
|
||||
return set_theme
|
||||
|
||||
|
||||
with open("theme/green.css", "r", encoding="utf-8") as f:
|
||||
advanced_css = f.read()
|
||||
12
theme/theme.py
Normal file
12
theme/theme.py
Normal file
@ -0,0 +1,12 @@
|
||||
import gradio as gr
|
||||
from toolbox import get_conf
|
||||
THEME, = get_conf('THEME')
|
||||
|
||||
if THEME == 'Chuanhu-Small-and-Beautiful':
|
||||
from .green import adjust_theme, advanced_css
|
||||
theme_declaration = "<h2 align=\"center\" class=\"small\">[Chuanhu-Small-and-Beautiful主题]</h2>"
|
||||
else:
|
||||
from .default import adjust_theme, advanced_css
|
||||
theme_declaration = ""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user