meharji arumilli
12/27/2022, 1:42 PM
For non-Spark objects I used to save/read from the catalog as:
lightgbm_model:
  type: pickle.PickleDataSet
  filepath: s3://bucket/data/lightgbm_model.pkl
  backend: pickle
How can I save it if the 'lightgbm_model' comes from a Spark pipeline?
William Caicedo
12/27/2022, 7:27 PM
from kedro.extras.datasets.spark import SparkDataSet
from pyspark.ml import PipelineModel
from pyspark.sql.utils import AnalysisException


class PySparkMLPipelineDataSet(SparkDataSet):
    def _load(self) -> PipelineModel:
        load_path = self._fs_prefix + str(self._get_load_path())
        self._get_spark()
        return PipelineModel.load(load_path)

    def _save(self, pipeline: PipelineModel) -> None:
        save_path = self._fs_prefix + str(self._get_save_path())
        pipeline.write().overwrite().save(save_path)

    def _exists(self) -> bool:
        load_path = self._fs_prefix + str(self._get_load_path())
        try:
            PipelineModel.load(load_path)
        except AnalysisException as exception:
            if exception.desc.startswith("Path does not exist:"):
                return False
            raise
        return True
meharji arumilli
12/28/2022, 3:18 PM
lightgbm_model:
  type: pickle.PySparkMLPipelineDataSet
  filepath: s3://bucket/data/lightgbm_model.pkl
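A note on the catalog entry above: in Kedro, `type:` must be the full dotted import path to the class, and the `pickle.` prefix resolves to Kedro's built-in pickle datasets, not to a custom `PySparkMLPipelineDataSet`. Assuming the custom class is defined in a hypothetical module `my_project.extras.datasets.spark_ml_dataset` (adjust to wherever it actually lives), the entry would look like:

```yaml
lightgbm_model:
  # `type` is the dotted path to the custom dataset class; the module
  # path below is a placeholder assumption, not from the thread.
  type: my_project.extras.datasets.spark_ml_dataset.PySparkMLPipelineDataSet
  # Spark ML saves a PipelineModel as a directory of metadata and stage
  # files, not a single pickle, so a directory-style path fits better
  # than a .pkl filename.
  filepath: s3://bucket/data/lightgbm_model
```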