Files
gpt_academic/autogpt/memory/local.py
2023-04-28 09:53:51 +08:00

127 lines
3.2 KiB
Python

from __future__ import annotations
import dataclasses
from pathlib import Path
from typing import Any, List
import numpy as np
import orjson
from autogpt.llm_utils import create_embedding_with_ada
from autogpt.memory.base import MemoryProviderSingleton
# Number of columns in the embeddings matrix, i.e. the length of one embedding
# vector. NOTE(review): presumably the output size of the ada embedding model
# used by create_embedding_with_ada -- confirm before changing.
EMBED_DIM = 1536
# orjson option flags used when persisting the cache: let orjson serialise
# numpy arrays and dataclass instances natively.
SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS
def create_default_embeddings():
    """Return an empty float32 matrix with zero rows and EMBED_DIM columns.

    Used as the starting value of the embeddings store, to which one row is
    appended per stored text.
    """
    empty_shape = (0, EMBED_DIM)
    return np.zeros(empty_shape, dtype=np.float32)
@dataclasses.dataclass
class CacheContent:
    """In-memory contents of the local cache: texts plus their embeddings."""

    # Stored texts, in insertion order.
    texts: List[str] = dataclasses.field(default_factory=list)
    # One embedding row per stored text; starts out with zero rows.
    embeddings: np.ndarray = dataclasses.field(default_factory=create_default_embeddings)
class LocalCache(MemoryProviderSingleton):
    """A class that stores the memory in a local file.

    Texts and their embeddings are held in memory in a ``CacheContent``; the
    whole cache is re-serialised to ``<workspace_path>/<memory_index>.json``
    after every successful ``add``.
    """

    def __init__(self, cfg) -> None:
        """Initialize a class instance.

        The cache file is (re)created containing an empty JSON object, so any
        previous file content is discarded and the instance starts empty.

        Args:
            cfg: Config object; ``workspace_path`` and ``memory_index`` are
                read from it to build the cache file path.

        Returns:
            None
        """
        workspace_path = Path(cfg.workspace_path)
        self.filename = workspace_path / f"{cfg.memory_index}.json"

        # write_bytes creates the file if needed and truncates it otherwise,
        # which is what the former touch() + open("w+b") pair amounted to.
        self.filename.write_bytes(b"{}")

        self.data = CacheContent()

    def add(self, text: str) -> str:
        """Add text to the list of texts and its embedding to the matrix.

        The updated cache is then written to the cache file.

        Args:
            text: The text to remember.

        Returns:
            The text that was stored, or an empty string when the text
            contains "Command Error:" (such texts are not stored).
        """
        if "Command Error:" in text:
            return ""

        # Build the new embeddings matrix *before* touching self.data: if the
        # embedding call or the concatenation raises, texts and embeddings
        # must not end up with different lengths.
        embedding = create_embedding_with_ada(text)
        vector = np.array(embedding).astype(np.float32)[np.newaxis, :]
        updated_embeddings = np.concatenate(
            [self.data.embeddings, vector],
            axis=0,
        )

        self.data.texts.append(text)
        self.data.embeddings = updated_embeddings

        # Serialise first so a serialisation error cannot leave a truncated
        # (empty) cache file behind.
        out = orjson.dumps(self.data, option=SAVE_OPTIONS)
        with open(self.filename, "wb") as f:
            f.write(out)
        return text

    def clear(self) -> str:
        """Clear the in-memory cache.

        Only the in-memory contents are reset; the cache file is rewritten on
        the next ``add``.

        Returns:
            A message indicating that the memory has been cleared.
        """
        self.data = CacheContent()
        return "Obliviated"

    def get(self, data: str) -> list[Any] | None:
        """Get the stored text that is most relevant to the given data.

        Args:
            data: The data to compare to.

        Returns:
            A list holding the single most relevant text (empty when nothing
            is stored).
        """
        return self.get_relevant(data, 1)

    def get_relevant(self, text: str, k: int) -> list[Any]:
        """Return the ``k`` stored texts most relevant to ``text``.

        Computes a matrix-vector product between the embeddings matrix and
        the embedding of ``text`` to score every stored row, picks the indices
        of the top-k scores and returns the texts at those indices, best
        match first.

        Args:
            text: The query text.
            k: Maximum number of texts to return.

        Returns:
            Up to ``k`` stored texts; an empty list when ``k`` is not positive
            or nothing is stored.
        """
        # A slice of [-0:] selects *everything*, so k == 0 must be handled
        # explicitly; an empty store needs no embedding call at all.
        if k <= 0 or not self.data.texts:
            return []

        embedding = create_embedding_with_ada(text)
        scores = np.dot(self.data.embeddings, embedding)
        top_k_indices = np.argsort(scores)[-k:][::-1]
        return [self.data.texts[i] for i in top_k_indices]

    def get_stats(self) -> tuple[int, tuple[int, ...]]:
        """Return the stats of the local cache.

        Returns:
            A pair of the number of stored texts and the shape of the
            embeddings matrix.
        """
        return len(self.data.texts), self.data.embeddings.shape