from typing import Any, overload

import numpy as np
import openai

from autogpt.config import Config
from autogpt.llm.utils import metered, retry_openai_api
from autogpt.logs import logger

Embedding = list[np.float32] | np.ndarray[Any, np.dtype[np.float32]]
"""Embedding vector"""

TText = list[int]
"""Token array representing text"""


@overload
def get_embedding(input: str | TText) -> Embedding:
    ...


@overload
def get_embedding(input: list[str] | list[TText]) -> list[Embedding]:
    ...


@metered
@retry_openai_api()
def get_embedding(
    input: str | TText | list[str] | list[TText],
) -> Embedding | list[Embedding]:
    """Get an embedding from the configured embedding model.

    Args:
        input: Input text to get embeddings for, encoded as a string or array of tokens.
            Multiple inputs may be given as a list of strings or token arrays.

    Returns:
        Embedding | list[Embedding]: The embedding(s) for the given input(s).
    """
    cfg = Config()
    # A list is treated as multiple inputs unless it is a single token array (list[int]).
    multiple = isinstance(input, list) and all(not isinstance(i, int) for i in input)

    # Replace newlines with spaces; they can negatively affect embedding quality.
    if isinstance(input, str):
        input = input.replace("\n", " ")
    elif multiple and isinstance(input[0], str):
        input = [text.replace("\n", " ") for text in input]

    model = cfg.embedding_model
    if cfg.use_azure:
        # Azure OpenAI addresses models through a deployment id, passed as `engine`.
        kwargs = {"engine": cfg.get_azure_deployment_id_for_model(model)}
    else:
        kwargs = {"model": model}

    logger.debug(
        f"Getting embedding{f's for {len(input)} inputs' if multiple else ''}"
        f" with model '{model}'"
        + (f" via Azure deployment '{kwargs['engine']}'" if cfg.use_azure else "")
    )

    embeddings = openai.Embedding.create(
        input=input,
        api_key=cfg.openai_api_key,
        **kwargs,
    ).data

    if not multiple:
        return embeddings[0]["embedding"]

    # Results may come back out of order; sort by index to restore the input order.
    embeddings = sorted(embeddings, key=lambda x: x["index"])
    return [d["embedding"] for d in embeddings]
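

# Minimal usage sketch (illustrative only, not part of the module's public API): it
# assumes the usual AutoGPT environment is configured with a valid OpenAI API key,
# and it performs real API calls when this file is run directly.
if __name__ == "__main__":
    single: Embedding = get_embedding("hello world")
    batch: list[Embedding] = get_embedding(["first input", "second input"])
    print(f"single embedding dimensions: {len(single)}")
    print(f"batch embeddings returned: {len(batch)}")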