Francis Duval
01/08/2024, 5:02 PM
casa:
  type: pandas.CSVDataset
  filepath: data/01_raw/casa.csv
  load_args:
    encoding: 'ISO-8859-1'
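For readers new to the catalog syntax: as far as I know, `pandas.CSVDataset` forwards `load_args` to `pandas.read_csv`, so the `casa` entry should behave roughly like the sketch below. The helper name `load_casa` and the sample data are made up for illustration.

```python
import tempfile
from pathlib import Path

import pandas as pd


def load_casa(filepath: str) -> pd.DataFrame:
    # Mirrors `load_args: {encoding: 'ISO-8859-1'}` from the catalog entry
    return pd.read_csv(filepath, encoding="ISO-8859-1")


# Hypothetical sample file containing a Latin-1 character (é)
with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "casa.csv"
    sample.write_bytes("ville,prix\nMontréal,100\n".encode("ISO-8859-1"))
    df = load_casa(str(sample))
    print(df)
```

Without the `encoding` argument, `read_csv` defaults to UTF-8 and would typically fail on the Latin-1 byte for "é".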
Do you know how I can log a model? For instance, I need to log a SentenceTransformer model:
st_model = SentenceTransformer('models/finetuned/mpnet_16_26epochs')
So I need something like:
st_model:
  type: SentenceTransformerModel
  filepath: data/01_raw/st_model.csv
datajoely
01/08/2024, 5:03 PM
type: pickle.PickleDataset
filepath: ...
backend: pickle|dill|joblib|cloudpickle
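As a rough illustration of what the plain `pickle` option does, here is a stdlib-only sketch of saving and loading an arbitrary Python object to a file. It is not Kedro's implementation, and `save_pickle`/`load_pickle` are hypothetical names; the other listed options (`dill`, `joblib`, `cloudpickle`) are separate packages that must be installed.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any


def save_pickle(obj: Any, filepath: Path) -> None:
    # Serialize any picklable object (e.g. a fitted model) to disk
    filepath.parent.mkdir(parents=True, exist_ok=True)
    with filepath.open("wb") as f:
        pickle.dump(obj, f)


def load_pickle(filepath: Path) -> Any:
    # Only unpickle files you trust: loading can execute arbitrary code
    with filepath.open("rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "06_models" / "model.pkl"
    save_pickle({"weights": [0.1, 0.2]}, path)
    restored = load_pickle(path)
    print(restored)
```

Pickled models are tied to the library versions that produced them, which is one reason a model-specific save/load (as in the custom dataset later in this thread) can be preferable.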
Francis Duval
01/08/2024, 5:04 PM
Francis Duval
01/08/2024, 6:08 PM
datajoely
01/08/2024, 6:12 PM
save() method and then contribute it back to Kedro!
datajoely
01/08/2024, 6:12 PM
Francis Duval
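01/08/2024, 7:16 PM
from typing import Any, Dict

from kedro.io import AbstractDataset
from sentence_transformers import SentenceTransformer


class STDataset(AbstractDataset):
    """Loads/saves a SentenceTransformer model from/to a local folder."""

    def __init__(self, filepath: str):
        # A folder, not a single file: SentenceTransformer.save() writes several files
        self.filepath = filepath

    def _load(self) -> SentenceTransformer:
        return SentenceTransformer(self.filepath)

    def _save(self, model: SentenceTransformer) -> None:
        model.save(self.filepath)

    def _describe(self) -> Dict[str, Any]:
        return {"filepath": self.filepath}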
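The same pattern generalizes to any model that persists itself to a folder. Below is a dependency-free sketch (no Kedro, no sentence-transformers) of the load/save contract `STDataset` implements; `FolderModelDataset`, `DummyModel`, and the `loader` callable are hypothetical stand-ins, with `DummyModel.save` playing the role of `SentenceTransformer.save`.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict


class DummyModel:
    """Hypothetical stand-in for a model that saves itself to a folder."""

    def __init__(self, name: str):
        self.name = name

    def save(self, path: str) -> None:
        Path(path).mkdir(parents=True, exist_ok=True)
        (Path(path) / "config.json").write_text(json.dumps({"name": self.name}))

    @classmethod
    def from_folder(cls, path: str) -> "DummyModel":
        config = json.loads((Path(path) / "config.json").read_text())
        return cls(config["name"])


class FolderModelDataset:
    """Same shape as STDataset: filepath is a folder, not a single file."""

    def __init__(self, filepath: str, loader: Callable[[str], Any]):
        self.filepath = filepath
        self.loader = loader  # plays the role of SentenceTransformer(...)

    def _load(self) -> Any:
        return self.loader(self.filepath)

    def _save(self, model: Any) -> None:
        model.save(self.filepath)

    def _describe(self) -> Dict[str, Any]:
        return {"filepath": self.filepath}


with tempfile.TemporaryDirectory() as tmp:
    ds = FolderModelDataset(str(Path(tmp) / "model_folder"), DummyModel.from_folder)
    # Kedro would call these through the public save()/load(); called directly here
    ds._save(DummyModel("mpnet"))
    reloaded = ds._load()
    print(reloaded.name, ds._describe())
```

Passing the loader in keeps the dataset agnostic to the model library; `STDataset` simply hard-codes `SentenceTransformer` as that loader.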
And then, in the catalog:
st_model:
  type: kedro_tuto.datasets.st_dataset.STDataset
  filepath: data/06_models/model_folder
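The `type` value is a dotted import path, which is why the custom class can live inside the project package (`kedro_tuto.datasets.st_dataset`). Below is a simplified sketch of turning such an entry into an instance with `importlib`; Kedro's real resolution also tries default package prefixes (for example `kedro_datasets.`), and `resolve_type`/`build_dataset` are hypothetical names. A standard-library class stands in for the dataset so the example runs anywhere.

```python
import importlib
from typing import Any, Dict


def resolve_type(dotted_path: str) -> Any:
    # "pkg.module.ClassName" -> import pkg.module, then fetch ClassName
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


def build_dataset(config: Dict[str, Any]) -> Any:
    # Remaining keys (e.g. filepath) become constructor keyword arguments
    params = dict(config)  # copy so the original entry is not mutated
    cls = resolve_type(params.pop("type"))
    return cls(**params)


# types.SimpleNamespace stands in for STDataset here
entry = {"type": "types.SimpleNamespace", "filepath": "data/06_models/model_folder"}
dataset = build_dataset(entry)
print(type(dataset).__name__, dataset.filepath)
```

If the module path in `type` is wrong or the package is not importable, `import_module` raises `ModuleNotFoundError`, which is a common cause of "dataset type not found" style errors.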
datajoely
01/08/2024, 11:08 PM
datajoely
01/08/2024, 11:08 PM
Francis Duval
01/08/2024, 11:11 PM
Juan Luis
01/09/2024, 2:34 PM
datajoely
01/09/2024, 3:10 PM
SentenceTransformers work with this?
Francis Duval
01/09/2024, 3:11 PM
Francis Duval
01/09/2024, 3:27 PM
sentence_transformer_model:
  type: huggingface.HFTransformerPipelineDataset
  model_name: data/06_models/mpnet_16_26epochs
The following might work, but I get an SSL error. I think this is because my organization blocks direct model downloads from Hugging Face/SentenceTransformers.
sentence_transformer_model:
  type: huggingface.HFTransformerPipelineDataset
  model_name: sentence-transformers/all-mpnet-base-v2
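One way to reason about the two entries: Hugging Face loaders generally treat `model_name` as a local path when that directory exists, and otherwise as a Hub repo id to download, which is the step an SSL-blocking proxy would interfere with. The hypothetical helper below only makes that distinction explicit so you can check, before running the pipeline, whether an entry needs network access; it does not call `transformers` itself.

```python
import tempfile
from pathlib import Path


def needs_download(model_name: str) -> bool:
    # Existing local directory -> loaded from disk; otherwise treated as a Hub id
    return not Path(model_name).is_dir()


with tempfile.TemporaryDirectory() as tmp:
    local_dir = Path(tmp) / "06_models" / "mpnet_16_26epochs"
    local_dir.mkdir(parents=True)
    local_needs = needs_download(str(local_dir))

hub_needs = needs_download("sentence-transformers/all-mpnet-base-v2")
print(local_needs, hub_needs)
```

Note that a relative path such as `data/06_models/mpnet_16_26epochs` is resolved against the current working directory, so running from outside the project root can silently turn a local path into a (failing) Hub lookup.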
Juan Luis
01/09/2024, 11:54 PM