Francis Duval
01/26/2024, 8:36 PM
my_pytorch_model:
  type: pickle.PickleDataset
  filepath: data/05_models/my_pytorch_model.pkl
  backend: pickle
I wish you a great Friday! [emoji]
marrrcin
01/27/2024, 9:09 AM
Francis Duval
01/27/2024, 5:40 PM
marrrcin
01/29/2024, 10:15 AM
Francis Duval
01/29/2024, 1:44 PM
Francis Duval
01/29/2024, 2:21 PM
marrrcin
# 01/29/2024, 2:26 PM
class TorchScriptModelDataSet(AbstractDataSet):
    """Kedro dataset that saves/loads TorchScript modules through fsspec.

    The model is written with ``torch.jit.save`` and read back with
    ``torch.jit.load``. File access goes through an fsspec filesystem chosen
    from the protocol of ``filepath``, so the path may be local or use any
    fsspec-supported protocol.

    NOTE(review): ``AbstractDataSet`` is the older Kedro spelling; recent Kedro
    releases expose ``AbstractDataset`` instead — confirm against the installed
    Kedro version.
    """

    def __init__(
        self,
        filepath: str,
        map_location: str = "cpu",
        fs_args: Optional[Dict] = None,
        credentials: Optional[Dict] = None,
    ) -> None:
        """Create the dataset.

        Args:
            filepath: Location of the serialised model, optionally prefixed
                with an fsspec protocol (e.g. ``s3://``).
            map_location: Device argument forwarded to ``torch.jit.load``.
            fs_args: Extra keyword arguments for ``fsspec.filesystem``.
            credentials: Extra keyword arguments for ``fsspec.filesystem``;
                entries in ``fs_args`` take precedence on key collisions.
        """
        super().__init__()
        self.filepath = filepath
        self.map_location = map_location

        # Deep-copy so the setdefault below never mutates the caller's dicts.
        _fs_args = deepcopy(fs_args) or {}
        _credentials = deepcopy(credentials) or {}

        protocol, _ = get_protocol_and_path(filepath)
        if protocol == "file":
            # Let the local filesystem create missing parent dirs on write.
            _fs_args.setdefault("auto_mkdir", True)

        self._protocol = protocol
        self._storage_options = {**_credentials, **_fs_args}
        self._fs: AbstractFileSystem = fsspec.filesystem(
            self._protocol, **self._storage_options
        )

    def _load(self) -> ScriptModule:
        """Read the TorchScript module, mapping its tensors to ``map_location``."""
        with self._fs.open(self.filepath, "rb") as f:
            return torch.jit.load(f, map_location=self.map_location)

    def _save(self, data: ScriptModule) -> None:
        """Serialise ``data`` to ``filepath`` with ``torch.jit.save``."""
        with self._fs.open(self.filepath, "wb") as f:
            torch.jit.save(data, f)
        # Drop any cached directory listing so a following exists()/load()
        # observes the freshly written file on caching filesystems.
        self._fs.invalidate_cache(self.filepath)

    def _exists(self) -> bool:
        """Return whether ``filepath`` exists on the configured filesystem."""
        return self._fs.exists(self.filepath)

    def _describe(self) -> Dict[str, Any]:
        """Return the attributes shown when the dataset is logged/printed."""
        return {
            "type": "Torch Script Model",
            "filepath": self.filepath,
            "protocol": self._protocol,
            "map_location": self.map_location,
        }
marrrcin
01/29/2024, 2:26 PM
model:
  type: kedro_pytorch_demo.datasets.TorchScriptModelDataSet
  filepath: data/model.pt
Francis Duval
01/29/2024, 2:28 PM
marrrcin
01/29/2024, 2:29 PM
marrrcin
01/29/2024, 2:29 PM
Francis Duval
01/29/2024, 2:29 PM
marrrcin
01/29/2024, 2:29 PM
marrrcin
01/29/2024, 2:30 PM
marrrcin
01/29/2024, 2:30 PM
Francis Duval
01/29/2024, 2:32 PM
Francis Duval
# 01/29/2024, 3:02 PM
def extract_768_20_params(model):
    """Return a snapshot of the ``fc1`` layer's weights and biases.

    Args:
        model: Object exposing the wrapped torch module as ``model.module_``
            (presumably a skorch ``NeuralNet`` — confirm with the caller). That
            module must have an ``fc1`` layer with ``weight`` and ``bias``.
            NOTE(review): the function name suggests fc1 is a 768 -> 20 layer;
            the shape is not checked here.

    Returns:
        ``{'weights': tensor, 'biases': tensor}``. Both tensors are detached
        from autograd and cloned, so later training or in-place edits of the
        model do not change them, and edits to them do not change the model.
    """
    layer = model.module_.fc1
    # ``.detach().clone()`` rather than ``.data``: ``.data`` aliases the live
    # parameter storage and hides in-place modifications from autograd.
    weights = layer.weight.detach().clone()
    biases = layer.bias.detach().clone()
    return {'weights': weights, 'biases': biases}