#questions

Francis Duval

01/26/2024, 8:36 PM
Hello all! When I attempt to save a PyTorch model with a pickle dataset, it only saves the weights and biases as a dictionary. How can I save the whole model? This documentation page that lists all the Kedro datasets (https://docs.kedro.org/projects/kedro-datasets/en/kedro-datasets-2.0.0/) does not seem to include a class for saving PyTorch models, which I find weird since there is one for TensorFlow models. Okay, I could create my own custom AbstractDataset class, but I'd rather use a built-in one! So, as I said, this only saves a dictionary containing the weights and biases of the PyTorch model:
```yaml
my_pytorch_model:
  type: pickle.PickleDataset
  filepath: data/05_models/my_pytorch_model.pkl
  backend: pickle
```
I wish you a great Friday 🙂
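Pickle itself does not reduce an object to its parameters: it round-trips whatever object it is given, as long as the object's class can be imported again at load time. A minimal stdlib sketch of that behavior, using a hypothetical `TinyRegressor` class as a stand-in for a fitted model (no PyTorch or skorch involved):

```python
import pickle


class TinyRegressor:
    """Hypothetical stand-in for a fitted model object (not a real library class)."""

    def __init__(self, weights, bias):
        self.weights = weights
        self.bias = bias

    def predict(self, x):
        # Linear prediction: dot(weights, x) + bias
        return sum(w * xi for w, xi in zip(self.weights, x)) + self.bias


model = TinyRegressor(weights=[0.5, -1.0], bias=2.0)

# pickle serializes the whole object: a reference to its class plus its attributes
restored = pickle.loads(pickle.dumps(model))

print(type(restored).__name__)       # TinyRegressor
print(restored.predict([2.0, 1.0]))  # 2.0
```

So if a `PickleDataset` hands back a dictionary, a dictionary was most likely what got saved to it.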

marrrcin

01/27/2024, 9:09 AM
Why are you trying to pickle a PyTorch model?
Use TorchScript or something PyTorch-native. Check this video at around 23:29:

https://youtu.be/u2hRsCSDcQ8?si=j46A1D31voSmQNCK
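A minimal sketch of the PyTorch-native route, assuming a plain `nn.Module` (the `TinyNet` class is hypothetical): `torch.jit.script` compiles the module, and `torch.jit.save` / `torch.jit.load` round-trip it through a file-like object. These are the same two save/load calls that the `TorchScriptModelDataSet` shared later in this thread wraps.

```python
import io

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical two-layer regressor used only for illustration."""

    def __init__(self, n_in: int = 4, n_hidden: int = 8):
        super().__init__()
        self.fc1 = nn.Linear(n_in, n_hidden)
        self.fc2 = nn.Linear(n_hidden, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))


model = TinyNet().eval()

# Compile the module to TorchScript: the result carries both code and weights
scripted = torch.jit.script(model)

# Save to an in-memory buffer (a file path works the same way)
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)

# Load it back onto the CPU
loaded = torch.jit.load(buffer, map_location="cpu")

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(model(x), loaded(x))
print(same)
```

Unlike pickle, loading a TorchScript archive does not require the original Python class to be importable, since the compiled code is stored alongside the weights.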

Francis Duval

01/27/2024, 5:40 PM
I don't know, I thought it was possible 😂 Thanks!
πŸ‘ 1
m

marrrcin

01/29/2024, 10:15 AM
It's better to stick to each library's native format 🙂
πŸ‘ 1
f

Francis Duval

01/29/2024, 1:44 PM
@marrrcin, do you have the code for your TorchScriptModelDataSet class somewhere? I did not find it on GitHub.
This documentation page from Skorch says we can save a NeuralNet object with pickle, which is why I thought it was possible: https://skorch.readthedocs.io/en/stable/user/save_load.html

marrrcin

01/29/2024, 2:26 PM
```python
from copy import deepcopy
from typing import Any, Dict, Optional

import fsspec
import torch
from fsspec.spec import AbstractFileSystem
from kedro.io import AbstractDataset  # spelled AbstractDataSet before Kedro 0.18.12
from kedro.io.core import get_protocol_and_path
from torch.jit import ScriptModule


class TorchScriptModelDataSet(AbstractDataset):
    """Saves/loads a TorchScript module through fsspec (local or remote paths)."""

    def __init__(
        self,
        filepath: str,
        map_location: str = "cpu",
        fs_args: Optional[Dict] = None,
        credentials: Optional[Dict] = None,
    ) -> None:
        super().__init__()
        self.filepath = filepath
        self.map_location = map_location

        _fs_args = deepcopy(fs_args) or {}
        _credentials = deepcopy(credentials) or {}

        protocol, path = get_protocol_and_path(filepath)
        if protocol == "file":
            # Create missing parent directories automatically for local paths
            _fs_args.setdefault("auto_mkdir", True)

        self._protocol = protocol
        self._storage_options = {**_credentials, **_fs_args}

        self._fs: AbstractFileSystem = fsspec.filesystem(
            self._protocol, **self._storage_options
        )

    def _load(self) -> ScriptModule:
        with self._fs.open(self.filepath, "rb") as f:
            return torch.jit.load(f, map_location=self.map_location)

    def _save(self, data: ScriptModule) -> None:
        with self._fs.open(self.filepath, "wb") as f:
            torch.jit.save(data, f)

    def _describe(self) -> Dict[str, Any]:
        return {"type": "Torch Script Model"}
```
Usage:
```yaml
model:
  type: kedro_pytorch_demo.datasets.TorchScriptModelDataSet
  filepath: data/model.pt
```

Francis Duval

01/29/2024, 2:28 PM
Thanks! Unfortunately, the object I want to save is a NeuralNetRegressor (from the Skorch package), which does not have a `to_torchscript` method, so I guess I'll need to train my NN with something other than Skorch!
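Not something tried in the thread, but a possible workaround: skorch keeps the trained PyTorch module of a fitted net in its `module_` attribute (the attribute `extract_768_20_params` relies on further down), so that inner module could be exported on its own with `torch.jit.trace`. The sketch uses a hypothetical `nn.Sequential` stand-in for `net.module_` so it runs without skorch; the 768 and 20 layer sizes are guesses taken from that function's name.

```python
import torch
from torch import nn

# Hypothetical stand-in for `net.module_` of a fitted skorch NeuralNetRegressor;
# with skorch you would use `inner = net.module_` instead.
inner = nn.Sequential(nn.Linear(768, 20), nn.ReLU(), nn.Linear(20, 1)).eval()

# trace records the operations executed on an example input of the expected shape
example_input = torch.randn(1, 768)
traced = torch.jit.trace(inner, example_input)

# The traced module is a ScriptModule, so torch.jit.save / torch.jit.load
# (or a TorchScript-based Kedro dataset) can persist it
batch = torch.randn(5, 768)
with torch.no_grad():
    matches = torch.allclose(inner(batch), traced(batch))
print(matches)
```

Only the inner module is exported this way, so skorch conveniences such as `fit`/`predict` are not part of the saved artifact. Tracing also records only the operations run on the example input; a module with data-dependent control flow would need `torch.jit.script` instead.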

marrrcin

01/29/2024, 2:29 PM
Lightning - always 😄

Francis Duval

01/29/2024, 2:29 PM
Oh yes? But is it compatible with scikit-learn?

marrrcin

01/29/2024, 2:29 PM
Most likely not.
Do you need it, though? Skorch is a high-level wrapper around PyTorch that is sometimes too high-level.
I get the whole scikit-learn `fit`/`predict` approach, but IMHO it does not apply that well in the NN world.

Francis Duval

01/29/2024, 2:32 PM
Interesting! Maybe Lightning is better suited to my needs, I'll try it. I'm new to Python, so I just took the first package I came across 😂 Many thanks!
It's very weird: my fitted NeuralNetRegressor, saved with pickle, is a dictionary (with keys 'weights' and 'biases') when I load it with catalog.load(). However, when I run the pipeline, it is an actual NeuralNetRegressor object, since I apply this function to it:
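```python
def extract_768_20_params(model):
    # Pull the first layer's parameters out of the fitted skorch net;
    # the result is a plain dict, not the model itself
    weights = model.module_.fc1.weight.data
    biases = model.module_.fc1.bias.data

    return {'weights': weights, 'biases': biases}
```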
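One possible explanation, not confirmed in the thread: a `PickleDataset` stores exactly what the node writing to it returns. If `my_pytorch_model` is declared as the output of the node that runs `extract_768_20_params`, the file holds that function's dictionary, while the `NeuralNetRegressor` only exists in memory as the node's input during the run. A stdlib sketch of the effect, with `SimpleNamespace` objects as hypothetical stand-ins for the fitted net and plain lists instead of tensors:

```python
import pickle
from types import SimpleNamespace


def extract_768_20_params(model):
    # Same logic as the node above: returns a dict, not the model
    weights = model.module_.fc1.weight.data
    biases = model.module_.fc1.bias.data
    return {'weights': weights, 'biases': biases}


# Hypothetical stand-in for a fitted NeuralNetRegressor (lists instead of tensors)
fc1 = SimpleNamespace(
    weight=SimpleNamespace(data=[[0.1, 0.2], [0.3, 0.4]]),
    bias=SimpleNamespace(data=[0.0, 0.5]),
)
fake_net = SimpleNamespace(module_=SimpleNamespace(fc1=fc1))

# Simulate a pickle-backed catalog entry that is the *output* of this node:
# whatever the node returns is what gets pickled and later reloaded
node_output = extract_768_20_params(fake_net)
reloaded = pickle.loads(pickle.dumps(node_output))

print(type(reloaded).__name__)  # dict
print(sorted(reloaded))         # ['biases', 'weights']
```

If that is the cause, declaring the fitted regressor and the extracted parameters as two separate catalog entries would keep both.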