Ian Whalen
10/19/2023, 8:41 PM
_load method to be pretty simple. Something like:
# Inside my class FAISSDataSet:
def _load(self):
    return FAISS.load_local(self.filepath, self.embeddings)
I have another dataset to handle loading embeddings (see thread), specifically the OpenAIEmbeddings class in langchain.
However, I’m not exactly sure how I’d get this into my hypothetical FAISSDataSet.
Ideally, I could version control the FAISSDataSet.
• OpenAIEmbeddings can be pickled, but my credentials are going to be changing constantly (and it feels icky to pickle something with secrets in it).
• So I’m considering taking in the API credentials to both OpenAIEmbeddingsDataSet and FAISSDataSet.
• Then I can do something hacky like save everything but the credentials of the embedding inside FAISSDataSet._save and then swap them in on a _load to handle versioning.
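A minimal sketch of the "save everything but the credentials, swap them in on load" idea from the bullets above. It uses only the standard library: `StandInEmbeddings` and `StandInStore` are hypothetical stand-ins, not the langchain `OpenAIEmbeddings` or `FAISS` classes, and the helper names are made up for illustration. The assumption is that the non-secret settings are enough to rebuild the embedding object at load time.

```python
import pickle
from dataclasses import dataclass
from typing import Dict, List

# Fields that must never be written to disk.
CREDENTIAL_FIELDS = ("openai_api_base", "openai_api_key")


@dataclass
class StandInEmbeddings:
    """Hypothetical stand-in for an embeddings object that carries secrets."""
    openai_api_base: str
    openai_api_key: str
    model: str = "text-embedding-ada-002"


@dataclass
class StandInStore:
    """Hypothetical stand-in for a vector store that holds an embeddings object."""
    embeddings: StandInEmbeddings
    vectors: List[List[float]]


def dumps_without_credentials(store: StandInStore) -> bytes:
    """Serialize the store, keeping only the non-secret embedding settings."""
    embedding_kwargs = {
        key: value
        for key, value in vars(store.embeddings).items()
        if key not in CREDENTIAL_FIELDS
    }
    state = {"vectors": store.vectors, "embedding_kwargs": embedding_kwargs}
    return pickle.dumps(state)


def loads_with_credentials(payload: bytes, credentials: Dict[str, str]) -> StandInStore:
    """Rebuild the store, injecting whatever credentials are current at load time."""
    state = pickle.loads(payload)
    secrets = {field: credentials[field] for field in CREDENTIAL_FIELDS}
    embeddings = StandInEmbeddings(**state["embedding_kwargs"], **secrets)
    return StandInStore(embeddings=embeddings, vectors=state["vectors"])
```

In a real dataset the two helpers would sit behind `_save` and `_load`, with the credentials coming from the dataset's constructor, so a versioned artifact stays loadable after the key rotates.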
Maybe I’m missing something though. Does this seem logical?
OpenAIEmbeddingsDataSet for those who are interested.
# Imports assumed for kedro 0.18.x and the langchain layout at the time of posting.
from typing import Any, Dict, NoReturn

from kedro.io import AbstractDataSet, DatasetError
from langchain.embeddings import OpenAIEmbeddings


class OpenAIEmbeddingsDataSet(AbstractDataSet[None, OpenAIEmbeddings]):
    """OpenAI Embeddings dataset.

    Must be a dataset to access credentials at runtime.
    """

    def __init__(self, credentials: Dict[str, str], **kwargs):
        """Constructor.

        Args:
            credentials: must contain `openai_api_base` and `openai_api_key`.
            **kwargs: keyword arguments passed to the `OpenAIEmbeddings` class.
        """
        self.openai_api_base = credentials["openai_api_base"]
        self.openai_api_key = credentials["openai_api_key"]
        self.kwargs = kwargs

    def _describe(self) -> Dict[str, Any]:
        # Only the non-secret keyword arguments are exposed here.
        return {**self.kwargs}

    def _save(self, data: None) -> NoReturn:
        raise DatasetError(f"{self.__class__.__name__} is a read only data set type")

    def _load(self) -> OpenAIEmbeddings:
        return OpenAIEmbeddings(
            openai_api_base=self.openai_api_base,
            openai_api_key=self.openai_api_key,
            **self.kwargs,
        )
from kedro.extras.datasets.pickle import PickleDataSet
from langchain.vectorstores import FAISS


class FAISSDataSet(PickleDataSet):
    """Saves and loads a FAISS vector store."""

    # TODO: find a better way to do this dataset.
    # - Using anything but an API will also serialize the embedding model which will be too big.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.openai_api_base = kwargs["credentials"]["openai_api_base"]
        self.openai_api_key = kwargs["credentials"]["openai_api_key"]

    def _load(self) -> FAISS:
        faiss = super()._load()
        # TODO: this assumes we're using an OpenAIEmbeddings embedding function.
        faiss.embedding_function.__self__.openai_api_base = self.openai_api_base
        faiss.embedding_function.__self__.openai_api_key = self.openai_api_key
        return faiss
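For anyone puzzled by the `embedding_function.__self__` lines: a bound method keeps a reference to the instance it belongs to in `__self__`, so the credentials can be overwritten through it. A small standard-library sketch of that trick, with a guard for the case the TODO mentions (the embedding function is not a bound method). `StandInEmbeddings`, `StandInStore`, and `swap_api_key` are hypothetical names for illustration, not langchain APIs.

```python
from typing import Callable


class StandInEmbeddings:
    """Hypothetical stand-in for an embeddings object holding an API key."""

    def __init__(self, openai_api_key: str) -> None:
        self.openai_api_key = openai_api_key

    def embed_query(self, text: str) -> str:
        # Returns the key with the text so the swap is easy to observe.
        return f"{self.openai_api_key}:{text}"


class StandInStore:
    """Hypothetical stand-in for a vector store that only keeps the bound method."""

    def __init__(self, embedding_function: Callable[[str], str]) -> None:
        self.embedding_function = embedding_function


def swap_api_key(store: StandInStore, new_key: str) -> None:
    """Replace the API key on the object that owns the store's embedding function."""
    owner = getattr(store.embedding_function, "__self__", None)
    if owner is None or not hasattr(owner, "openai_api_key"):
        # Plain functions and lambdas have no owning instance to patch.
        raise TypeError("embedding_function is not bound to an object with an API key")
    owner.openai_api_key = new_key
```

The guard turns the silent assumption into an explicit error, which may be easier to debug than an `AttributeError` deep inside `_load`.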
See TODOs as well
Emilio Gagliardi
11/27/2023, 7:31 PM