Optimize the code: `requests.exceptions.ConnectionError` should be caught around the POST request itself rather than while reading from the response iterator. If the POST request fails, execution never reaches the iterator step, so handling the error there is too late.

This commit is contained in:
kainstan
2023-06-06 16:10:54 +08:00
parent 0e1de5a184
commit 2da36c7667

View File

@ -62,15 +62,16 @@ def predict_no_ui_long_connection(
watch_dog_patience = 5 # 看门狗的耐心, 设置5秒即可 watch_dog_patience = 5 # 看门狗的耐心, 设置5秒即可
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=True) headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=True)
retry = 0 retry = 0
from bridge_all import model_info
while True: while True:
try: try:
# make a POST request to the API endpoint, stream=False # make a POST request to the API endpoint, stream=False
from bridge_all import model_info
endpoint = model_info[llm_kwargs['llm_model']]['endpoint'] endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
response = requests.post(endpoint, headers=headers, proxies=proxies, response = requests.post(endpoint, headers=headers, proxies=proxies,
json=payload, stream=True, timeout=TIMEOUT_SECONDS) json=payload, stream=True, timeout=TIMEOUT_SECONDS)
stream_response = response.iter_lines()
break break
except requests.exceptions.ReadTimeout: except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectionError):
retry += 1 retry += 1
traceback.print_exc() traceback.print_exc()
if retry > MAX_RETRY: if retry > MAX_RETRY:
@ -81,15 +82,14 @@ def predict_no_ui_long_connection(
print(f"出现异常:{e}") print(f"出现异常:{e}")
raise e raise e
stream_response = response.iter_lines()
result = '' result = ''
while True: while True:
try: try:
chunk = next(stream_response).decode() chunk = next(stream_response).decode()
except StopIteration: except StopIteration:
break break
except requests.exceptions.ConnectionError: # except requests.exceptions.ConnectionError:
chunk = next(stream_response).decode() # 失败了,重试一次?再失败就没办法了。 # chunk = next(stream_response).decode() # 失败了,重试一次?再失败就没办法了。
if len(chunk) == 0: if len(chunk) == 0:
continue continue
if not chunk.startswith('data:'): if not chunk.startswith('data:'):
@ -98,11 +98,14 @@ def predict_no_ui_long_connection(
raise ConnectionAbortedError("OpenAI拒绝了请求:" + error_msg) raise ConnectionAbortedError("OpenAI拒绝了请求:" + error_msg)
else: else:
raise RuntimeError("OpenAI拒绝了请求" + error_msg) raise RuntimeError("OpenAI拒绝了请求" + error_msg)
if ('data: [DONE]' in chunk): break # api2d 正常完成 if 'data: [DONE]' in chunk:
break # api2d 正常完成
json_data = json.loads(chunk.lstrip('data:'))['choices'][0] json_data = json.loads(chunk.lstrip('data:'))['choices'][0]
delta = json_data["delta"] delta = json_data["delta"]
if len(delta) == 0: break if len(delta) == 0:
if "role" in delta: continue break
if "role" in delta:
continue
if "content" in delta: if "content" in delta:
result += delta["content"] result += delta["content"]
if not console_slience: if not console_slience: