diff --git a/config.py b/config.py
index 4fbb331..f1d96d5 100644
--- a/config.py
+++ b/config.py
@@ -126,4 +126,8 @@ put your new bing cookies here
 # 阿里云实时语音识别 配置难度较高 仅建议高手用户使用 参考 https://help.aliyun.com/document_detail/450255.html
 ENABLE_AUDIO = False
 ALIYUN_TOKEN=""     # 例如 f37f30e0f9934c34a992f6f64f7eba4f
-ALIYUN_APPKEY=""    # 例如 RoPlZrM88DnAFkZK
\ No newline at end of file
+ALIYUN_APPKEY=""    # 例如 RoPlZrM88DnAFkZK
+
+
+# ChatGLM Finetune Model Path
+ChatGLM_PTUNING_CHECKPOINT = ""
\ No newline at end of file
diff --git a/request_llm/bridge_all.py b/request_llm/bridge_all.py
index 13f49bd..ed9ceb0 100644
--- a/request_llm/bridge_all.py
+++ b/request_llm/bridge_all.py
@@ -269,6 +269,24 @@ if "newbing" in AVAIL_LLM_MODELS: # same with newbing-free
         })
     except:
         print(trimmed_format_exc())
+if "chatglmft" in AVAIL_LLM_MODELS: # same with newbing-free
+    try:
+        from .bridge_chatglmft import predict_no_ui_long_connection as chatglmft_noui
+        from .bridge_chatglmft import predict as chatglmft_ui
+        # claude
+        model_info.update({
+            "chatglmft": {
+                "fn_with_ui": chatglmft_ui,
+                "fn_without_ui": chatglmft_noui,
+                "endpoint": None,
+                "max_token": 4096,
+                "tokenizer": tokenizer_gpt35,
+                "token_cnt": get_token_num_gpt35,
+            }
+        })
+    except:
+        print(trimmed_format_exc())
+

 def LLM_CATCH_EXCEPTION(f):
     """
@@ -372,6 +390,6 @@ def predict(inputs, llm_kwargs, *args, **kwargs):
     additional_fn代表点击的哪个按钮,按钮见functional.py
     """

-    method = model_info[llm_kwargs['llm_model']]["fn_with_ui"]
+    method = model_info[llm_kwargs['llm_model']]["fn_with_ui"]  # 如果这里报错,检查config中的AVAIL_LLM_MODELS选项
     yield from method(inputs, llm_kwargs, *args, **kwargs)

diff --git a/request_llm/bridge_chatglmft.py b/request_llm/bridge_chatglmft.py
index 27043fb..6c2fe25 100644
--- a/request_llm/bridge_chatglmft.py
+++ b/request_llm/bridge_chatglmft.py
@@ -59,12 +59,18 @@ class GetGLMFTHandle(Process):
                 if self.chatglmft_model is None:
                     from transformers import AutoConfig
                     import torch
-                    conf = 'request_llm\current_ptune_model.json'
-                    if not os.path.exists(conf): raise RuntimeError('找不到微调模型信息')
-                    with open('request_llm\current_ptune_model.json', 'r', encoding='utf8') as f:
-                        model_args = json.loads(f.read())
-
-                    tokenizer = AutoTokenizer.from_pretrained(
+                    # conf = 'request_llm/current_ptune_model.json'
+                    # if not os.path.exists(conf): raise RuntimeError('找不到微调模型信息')
+                    # with open(conf, 'r', encoding='utf8') as f:
+                    #     model_args = json.loads(f.read())
+                    ChatGLM_PTUNING_CHECKPOINT, = get_conf('ChatGLM_PTUNING_CHECKPOINT')
+                    conf = os.path.join(ChatGLM_PTUNING_CHECKPOINT, "config.json")
+                    with open(conf, 'r', encoding='utf8') as f:
+                        model_args_ = json.loads(f.read())
+                        model_args_.update(model_args)
+                        model_args = model_args_
+
+                    self.chatglmft_tokenizer = AutoTokenizer.from_pretrained(
                         model_args['model_name_or_path'], trust_remote_code=True)
                     config = AutoConfig.from_pretrained(
                         model_args['model_name_or_path'], trust_remote_code=True)
@@ -72,17 +78,14 @@ class GetGLMFTHandle(Process):
                     config.pre_seq_len = model_args['pre_seq_len']
                     config.prefix_projection = model_args['prefix_projection']

-                    if model_args['ptuning_checkpoint'] is not None:
-                        print(f"Loading prefix_encoder weight from {model_args['ptuning_checkpoint']}")
-                        model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
-                        prefix_state_dict = torch.load(os.path.join(model_args['ptuning_checkpoint'], "pytorch_model.bin"))
-                        new_prefix_state_dict = {}
-                        for k, v in prefix_state_dict.items():
-                            if k.startswith("transformer.prefix_encoder."):
-                                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
-                        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
-                    else:
-                        model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
+                    print(f"Loading prefix_encoder weight from {ChatGLM_PTUNING_CHECKPOINT}")
+                    model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
+                    prefix_state_dict = torch.load(os.path.join(ChatGLM_PTUNING_CHECKPOINT, "pytorch_model.bin"))
+                    new_prefix_state_dict = {}
+                    for k, v in prefix_state_dict.items():
+                        if k.startswith("transformer.prefix_encoder."):
+                            new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
+                    model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

                     if model_args['quantization_bit'] is not None:
                         print(f"Quantized to {model_args['quantization_bit']} bit")
@@ -91,13 +94,12 @@ class GetGLMFTHandle(Process):

                     if model_args['pre_seq_len'] is not None:
                         # P-tuning v2
                         model.transformer.prefix_encoder.float()
-
-                    model = model.eval()
+                    self.chatglmft_model = model.eval()
                     break
                 else:
                     break
-            except:
+            except Exception as e:
                 retry += 1
                 if retry > 3:
                     self.child.send('[Local Message] Call ChatGLMFT fail 不能正常加载ChatGLMFT的参数。')
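
Notes on the bridge_all.py changes: the new "chatglmft" entry follows the same registration pattern as the other local models, and the comment added to predict() says, roughly, "if this line raises an error, check the AVAIL_LLM_MODELS option in config". The sketch below illustrates that registry-and-dispatch pattern in isolation; the handler bodies, the None tokenizer, and len() as a token counter are placeholders for illustration only, not the project's real implementations (those live in bridge_chatglmft.py and bridge_all.py).

# Standalone sketch of the model_info dispatch pattern used by bridge_all.py.
# The registry maps a model name (selected via LLM_MODEL / AVAIL_LLM_MODELS in config)
# to its UI and non-UI callables; predict() looks the entry up at call time.

def chatglmft_ui(inputs, llm_kwargs, *args, **kwargs):
    # placeholder: the real bridge streams tokens from the local fine-tuned model
    yield f"[chatglmft placeholder] {inputs}"

def chatglmft_noui(inputs, llm_kwargs, history=None, sys_prompt="", **kwargs):
    # placeholder: the real bridge blocks until the full reply is ready
    return f"[chatglmft placeholder] {inputs}"

model_info = {
    "chatglmft": {
        "fn_with_ui": chatglmft_ui,       # streaming generator, drives the web UI
        "fn_without_ui": chatglmft_noui,  # blocking call, used by plugins
        "endpoint": None,                 # local model, no HTTP endpoint
        "max_token": 4096,
        "tokenizer": None,                # the diff reuses the GPT-3.5 tiktoken tokenizer here
        "token_cnt": len,                 # stand-in for get_token_num_gpt35
    },
}

def predict(inputs, llm_kwargs, *args, **kwargs):
    # A KeyError here means the requested model was never registered;
    # the comment added in the diff says to check AVAIL_LLM_MODELS in config.
    method = model_info[llm_kwargs['llm_model']]["fn_with_ui"]
    yield from method(inputs, llm_kwargs, *args, **kwargs)

if __name__ == "__main__":
    for chunk in predict("hello", {"llm_model": "chatglmft"}):
        print(chunk)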
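Notes on the bridge_chatglmft.py changes: the hard-coded request_llm\current_ptune_model.json lookup (whose error message reads "fine-tuned model information not found") is replaced by the new ChatGLM_PTUNING_CHECKPOINT option read via get_conf. The checkpoint directory's config.json now supplies model_name_or_path, pre_seq_len, prefix_projection and quantization_bit; the base ChatGLM weights are loaded from model_name_or_path, and only the prefix-encoder weights are overwritten from the checkpoint's pytorch_model.bin (P-tuning v2). The loaded tokenizer and model are also stored on the handle (self.chatglmft_tokenizer / self.chatglmft_model), and the bare except becomes except Exception as e. One thing to watch: inside the new with-open block, model_args_.update(model_args) appears to reference model_args before anything in this hunk assigns it, since the lines that previously set it are now comments. Below is a minimal standalone sketch of the intended loading flow, assuming a standard ChatGLM P-tuning v2 checkpoint directory; the function name and checkpoint_dir parameter are illustrative, device placement and the process plumbing are omitted, and quantize() is a method provided by ChatGLM's remote code rather than a generic transformers API.

# Standalone sketch of the P-tuning v2 loading flow implemented in bridge_chatglmft.py.
# Assumes checkpoint_dir was produced by ChatGLM's ptuning trainer and contains
# config.json plus pytorch_model.bin.
import os
import json
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

def load_ptuned_chatglm(checkpoint_dir):
    # config.json records the base model path and the P-tuning hyperparameters
    with open(os.path.join(checkpoint_dir, "config.json"), "r", encoding="utf8") as f:
        model_args = json.loads(f.read())

    tokenizer = AutoTokenizer.from_pretrained(
        model_args["model_name_or_path"], trust_remote_code=True)
    config = AutoConfig.from_pretrained(
        model_args["model_name_or_path"], trust_remote_code=True)
    config.pre_seq_len = model_args["pre_seq_len"]
    config.prefix_projection = model_args["prefix_projection"]

    # Base ChatGLM weights come from the original model; only the prefix encoder
    # is then overwritten with the fine-tuned weights from the checkpoint.
    print(f"Loading prefix_encoder weight from {checkpoint_dir}")
    model = AutoModel.from_pretrained(
        model_args["model_name_or_path"], config=config, trust_remote_code=True)
    prefix_state_dict = torch.load(os.path.join(checkpoint_dir, "pytorch_model.bin"))
    new_prefix_state_dict = {
        k[len("transformer.prefix_encoder."):]: v
        for k, v in prefix_state_dict.items()
        if k.startswith("transformer.prefix_encoder.")
    }
    model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

    # Optional INT4/INT8 quantization; quantize() comes from ChatGLM's remote code.
    if model_args.get("quantization_bit") is not None:
        print(f"Quantized to {model_args['quantization_bit']} bit")
        model = model.quantize(model_args["quantization_bit"])
    # Device placement (e.g. half()/cuda()) is handled elsewhere in the bridge.
    if model_args.get("pre_seq_len") is not None:
        # P-tuning v2 keeps the prefix encoder in fp32
        model.transformer.prefix_encoder.float()
    return tokenizer, model.eval()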