add fine-tuned model

binary-husky
2023-07-10 03:00:49 +08:00
parent 3cda37a6a6
commit b585ae6835
3 changed files with 46 additions and 22 deletions

config.py

@@ -126,4 +126,8 @@ put your new bing cookies here
# Alibaba Cloud real-time speech recognition; configuration is fairly involved and recommended for advanced users only. See https://help.aliyun.com/document_detail/450255.html
ENABLE_AUDIO = False
ALIYUN_TOKEN="" # e.g. f37f30e0f9934c34a992f6f64f7eba4f
ALIYUN_APPKEY="" # e.g. RoPlZrM88DnAFkZK
# ChatGLM fine-tuned model path (P-Tuning checkpoint directory)
ChatGLM_PTUNING_CHECKPOINT = ""
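The handler below reads this option back through the project's get_conf helper. A minimal sketch of how a caller might validate it before loading; it assumes get_conf is importable from the project's toolbox module (as in the other bridge files), and the assert is illustrative rather than part of the commit:

    import os
    from toolbox import get_conf  # the project's config reader

    ChatGLM_PTUNING_CHECKPOINT, = get_conf('ChatGLM_PTUNING_CHECKPOINT')
    # bridge_chatglmft.py expects the directory to contain the training config.json
    conf = os.path.join(ChatGLM_PTUNING_CHECKPOINT, "config.json")
    assert os.path.exists(conf), "point ChatGLM_PTUNING_CHECKPOINT at a P-Tuning checkpoint directory"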

request_llm/bridge_all.py

@@ -269,6 +269,24 @@ if "newbing" in AVAIL_LLM_MODELS: # same with newbing-free
        })
    except:
        print(trimmed_format_exc())
if "chatglmft" in AVAIL_LLM_MODELS: # same with newbing-free
try:
from .bridge_chatglmft import predict_no_ui_long_connection as chatglmft_noui
from .bridge_chatglmft import predict as chatglmft_ui
# claude
model_info.update({
"chatglmft": {
"fn_with_ui": chatglmft_ui,
"fn_without_ui": chatglmft_noui,
"endpoint": None,
"max_token": 4096,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
}
})
except:
print(trimmed_format_exc())
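Registration alone is not enough: the guard above only fires when "chatglmft" is listed in AVAIL_LLM_MODELS. A hypothetical config.py excerpt (the neighbouring entries are illustrative):

    AVAIL_LLM_MODELS = ["gpt-3.5-turbo", "chatglm", "chatglmft"]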
def LLM_CATCH_EXCEPTION(f):
"""
@@ -372,6 +390,6 @@ def predict(inputs, llm_kwargs, *args, **kwargs):
    additional_fn indicates which button was clicked; see functional.py for the buttons
    """
    method = model_info[llm_kwargs['llm_model']]["fn_with_ui"]
    method = model_info[llm_kwargs['llm_model']]["fn_with_ui"] # if this line raises an error, check the AVAIL_LLM_MODELS option in config
    yield from method(inputs, llm_kwargs, *args, **kwargs)
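Illustratively, dispatch then reduces to a dictionary lookup; the llm_kwargs literal below is a hypothetical stand-in for the real kwargs object:

    llm_kwargs = {'llm_model': 'chatglmft'}  # hypothetical
    method = model_info[llm_kwargs['llm_model']]["fn_with_ui"]  # resolves to chatglmft_ui
    # A KeyError here means the model was never registered, i.e. it is missing
    # from AVAIL_LLM_MODELS in config (hence the comment added above).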

request_llm/bridge_chatglmft.py

@@ -59,12 +59,18 @@ class GetGLMFTHandle(Process):
    if self.chatglmft_model is None:
        from transformers import AutoConfig
        import torch
        conf = 'request_llm\current_ptune_model.json'
        if not os.path.exists(conf): raise RuntimeError('Cannot find the fine-tuned model info')
        with open('request_llm\current_ptune_model.json', 'r', encoding='utf8') as f:
            model_args = json.loads(f.read())
        tokenizer = AutoTokenizer.from_pretrained(
        # conf = 'request_llm/current_ptune_model.json'
        # if not os.path.exists(conf): raise RuntimeError('Cannot find the fine-tuned model info')
        # with open(conf, 'r', encoding='utf8') as f:
        #     model_args = json.loads(f.read())
        ChatGLM_PTUNING_CHECKPOINT, = get_conf('ChatGLM_PTUNING_CHECKPOINT')
        conf = os.path.join(ChatGLM_PTUNING_CHECKPOINT, "config.json")
        with open(conf, 'r', encoding='utf8') as f:
            model_args = json.loads(f.read())
        self.chatglmft_tokenizer = AutoTokenizer.from_pretrained(
            model_args['model_name_or_path'], trust_remote_code=True)
        config = AutoConfig.from_pretrained(
            model_args['model_name_or_path'], trust_remote_code=True)
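The keys the handler consults from the checkpoint's config.json are model_name_or_path, pre_seq_len, prefix_projection and quantization_bit. A hypothetical example of its shape, written as the resulting Python dict (all values illustrative):

    model_args = {
        "model_name_or_path": "THUDM/chatglm-6b",  # assumption: the base model id
        "pre_seq_len": 128,
        "prefix_projection": False,
        "quantization_bit": 4,
    }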
@@ -72,17 +78,14 @@ class GetGLMFTHandle(Process):
        config.pre_seq_len = model_args['pre_seq_len']
        config.prefix_projection = model_args['prefix_projection']
        if model_args['ptuning_checkpoint'] is not None:
            print(f"Loading prefix_encoder weight from {model_args['ptuning_checkpoint']}")
            model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
            prefix_state_dict = torch.load(os.path.join(model_args['ptuning_checkpoint'], "pytorch_model.bin"))
            new_prefix_state_dict = {}
            for k, v in prefix_state_dict.items():
                if k.startswith("transformer.prefix_encoder."):
                    new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
            model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
        else:
            model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
        print(f"Loading prefix_encoder weight from {ChatGLM_PTUNING_CHECKPOINT}")
        model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
        prefix_state_dict = torch.load(os.path.join(ChatGLM_PTUNING_CHECKPOINT, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
        if model_args['quantization_bit'] is not None:
            print(f"Quantized to {model_args['quantization_bit']} bit")
@@ -91,13 +94,12 @@ class GetGLMFTHandle(Process):
        if model_args['pre_seq_len'] is not None:
            # P-tuning v2
            model.transformer.prefix_encoder.float()
        model = model.eval()
        self.chatglmft_model = model.eval()
        break
    else:
        break
except:
except Exception as e:
    retry += 1
    if retry > 3:
        self.child.send('[Local Message] Call ChatGLMFT fail: cannot load ChatGLMFT parameters properly.')
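For orientation, the fragments above live inside the handler's load-with-retry loop; a sketch of the inferred shape (assumed from the visible break/except lines, not shown in full by this diff):

    retry = 0
    while True:
        try:
            if self.chatglmft_model is None:
                ...  # load tokenizer, config and prefix weights as above
            break
        except Exception as e:
            retry += 1
            if retry > 3:
                self.child.send('[Local Message] Call ChatGLMFT fail: cannot load ChatGLMFT parameters properly.')
                break  # assumption: give up after a few attempts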