add fine tune model
@@ -126,4 +126,8 @@ put your new bing cookies here
 # Aliyun real-time speech recognition. Fairly involved to configure; recommended for advanced users only. See https://help.aliyun.com/document_detail/450255.html
 ENABLE_AUDIO = False
 ALIYUN_TOKEN=""     # e.g. f37f30e0f9934c34a992f6f64f7eba4f
 ALIYUN_APPKEY=""    # e.g. RoPlZrM88DnAFkZK
+
+
+# Path to the fine-tuned ChatGLM model checkpoint
+ChatGLM_PTUNING_CHECKPOINT = ""
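
Note: ChatGLM_PTUNING_CHECKPOINT should point at a P-tuning checkpoint directory; the loader in request_llm reads config.json and pytorch_model.bin from that directory. A minimal sanity check under that assumption (the helper below is illustrative, not part of the repo):

    import os

    def check_ptuning_checkpoint(path):
        # the ChatGLMFT bridge expects these two files inside the checkpoint dir
        for fname in ("config.json", "pytorch_model.bin"):
            if not os.path.exists(os.path.join(path, fname)):
                raise FileNotFoundError(f"{fname} missing from {path}")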
@@ -269,6 +269,24 @@ if "newbing" in AVAIL_LLM_MODELS: # same with newbing-free
         })
     except:
         print(trimmed_format_exc())
+if "chatglmft" in AVAIL_LLM_MODELS: # same pattern as newbing-free
+    try:
+        from .bridge_chatglmft import predict_no_ui_long_connection as chatglmft_noui
+        from .bridge_chatglmft import predict as chatglmft_ui
+        # register the fine-tuned ChatGLM model
+        model_info.update({
+            "chatglmft": {
+                "fn_with_ui": chatglmft_ui,
+                "fn_without_ui": chatglmft_noui,
+                "endpoint": None,
+                "max_token": 4096,
+                "tokenizer": tokenizer_gpt35,
+                "token_cnt": get_token_num_gpt35,
+            }
+        })
+    except:
+        print(trimmed_format_exc())
+
 def LLM_CATCH_EXCEPTION(f):
     """
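
The model_info registry maps each model name to its UI and non-UI entry points, and each registration is wrapped in try/except so that a missing optional dependency only disables that one model instead of crashing startup. A sketch of how an entry is consumed once registered (the call is abbreviated from the bridge convention, so treat the argument names as assumptions):

    entry = model_info["chatglmft"]      # KeyError if registration failed above
    fn = entry["fn_without_ui"]          # chatglmft_noui
    reply = fn("hello", llm_kwargs={'llm_model': 'chatglmft'}, history=[], sys_prompt="")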
@@ -372,6 +390,6 @@ def predict(inputs, llm_kwargs, *args, **kwargs):
     additional_fn indicates which button was clicked; see functional.py for the buttons
     """
 
-    method = model_info[llm_kwargs['llm_model']]["fn_with_ui"]
+    method = model_info[llm_kwargs['llm_model']]["fn_with_ui"]  # if this line raises an error, check the AVAIL_LLM_MODELS option in config
     yield from method(inputs, llm_kwargs, *args, **kwargs)
 
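
predict is a generator: it looks up fn_with_ui for the selected model and streams UI updates from whatever that bridge yields. Sketched usage, with the Gradio plumbing reduced to a bare loop (everything past llm_kwargs is passed through *args/**kwargs, so the keyword names below are assumptions about what the bridge expects):

    llm_kwargs = {'llm_model': 'chatglmft'}   # must also appear in AVAIL_LLM_MODELS
    for _update in predict("hello", llm_kwargs, plugin_kwargs={}, chatbot=[], history=[]):
        pass                                  # each yield refreshes the chat UI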
@@ -59,12 +59,18 @@ class GetGLMFTHandle(Process):
                 if self.chatglmft_model is None:
                     from transformers import AutoConfig
                     import torch
-                    conf = 'request_llm\current_ptune_model.json'
-                    if not os.path.exists(conf): raise RuntimeError('fine-tuned model info not found')
-                    with open('request_llm\current_ptune_model.json', 'r', encoding='utf8') as f:
-                        model_args = json.loads(f.read())
-                    tokenizer = AutoTokenizer.from_pretrained(
+                    # conf = 'request_llm/current_ptune_model.json'
+                    # if not os.path.exists(conf): raise RuntimeError('fine-tuned model info not found')
+                    # with open(conf, 'r', encoding='utf8') as f:
+                    #     model_args = json.loads(f.read())
+                    ChatGLM_PTUNING_CHECKPOINT, = get_conf('ChatGLM_PTUNING_CHECKPOINT')
+                    conf = os.path.join(ChatGLM_PTUNING_CHECKPOINT, "config.json")
+                    with open(conf, 'r', encoding='utf8') as f:
+                        model_args = json.loads(f.read())
+
+                    self.chatglmft_tokenizer = AutoTokenizer.from_pretrained(
                         model_args['model_name_or_path'], trust_remote_code=True)
                     config = AutoConfig.from_pretrained(
                         model_args['model_name_or_path'], trust_remote_code=True)
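
The loader now takes config.json from the checkpoint directory itself rather than a hard-coded request_llm path. Judging from the keys read later in this file, that config.json looks roughly like the dict below (all values, including the base-model id, are placeholders):

    example_model_args = {
        "model_name_or_path": "THUDM/chatglm-6b",  # assumed base model id
        "pre_seq_len": 128,                        # P-tuning v2 prefix length
        "prefix_projection": False,
        "quantization_bit": 4,                     # or None for full precision
    }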
@@ -72,17 +78,14 @@ class GetGLMFTHandle(Process):
                     config.pre_seq_len = model_args['pre_seq_len']
                     config.prefix_projection = model_args['prefix_projection']
 
-                    if model_args['ptuning_checkpoint'] is not None:
-                        print(f"Loading prefix_encoder weight from {model_args['ptuning_checkpoint']}")
-                        model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
-                        prefix_state_dict = torch.load(os.path.join(model_args['ptuning_checkpoint'], "pytorch_model.bin"))
-                        new_prefix_state_dict = {}
-                        for k, v in prefix_state_dict.items():
-                            if k.startswith("transformer.prefix_encoder."):
-                                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
-                        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
-                    else:
-                        model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
+                    print(f"Loading prefix_encoder weight from {ChatGLM_PTUNING_CHECKPOINT}")
+                    model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
+                    prefix_state_dict = torch.load(os.path.join(ChatGLM_PTUNING_CHECKPOINT, "pytorch_model.bin"))
+                    new_prefix_state_dict = {}
+                    for k, v in prefix_state_dict.items():
+                        if k.startswith("transformer.prefix_encoder."):
+                            new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
+                    model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
 
                     if model_args['quantization_bit'] is not None:
                         print(f"Quantized to {model_args['quantization_bit']} bit")
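
Only the prefix-encoder weights are read from the checkpoint, because P-tuning v2 trains just the prefix encoder while the base model stays frozen; the loop strips the "transformer.prefix_encoder." key prefix before load_state_dict. A sketch of how a compatible pytorch_model.bin could be produced after training (saving code is not part of this repo; the names mirror the loader above):

    import os
    import torch

    def save_prefix_checkpoint(model, out_dir):
        # keep the "transformer.prefix_encoder." prefix that the loader strips
        os.makedirs(out_dir, exist_ok=True)
        state = {f"transformer.prefix_encoder.{k}": v.cpu()
                 for k, v in model.transformer.prefix_encoder.state_dict().items()}
        torch.save(state, os.path.join(out_dir, "pytorch_model.bin"))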
@@ -91,13 +94,12 @@ class GetGLMFTHandle(Process):
                     if model_args['pre_seq_len'] is not None:
                         # P-tuning v2
                         model.transformer.prefix_encoder.float()
-                    model = model.eval()
+                    self.chatglmft_model = model.eval()
 
                     break
                 else:
                     break
-            except:
+            except Exception as e:
                 retry += 1
                 if retry > 3:
                     self.child.send('[Local Message] Call ChatGLMFT fail: cannot load the ChatGLMFT parameters.')
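
Loading sits inside a retry loop: the handle caches the model on self.chatglmft_model once it loads, and reports failure to the parent process after three attempts. The same pattern in isolation (load_model is a hypothetical stand-in for the checkpoint loading above):

    retry = 0
    while True:
        try:
            load_model()          # hypothetical stand-in for the loading code
            break
        except Exception:
            retry += 1
            if retry > 3:
                raise RuntimeError("cannot load the ChatGLMFT parameters")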