tabensemb.model.AbstractNN.to_torchscript
method
- AbstractNN.to_torchscript(file_path: str | Path | None = None, method: str | None = 'script', example_inputs: Any | None = None, **kwargs: Any) -> ScriptModule | Dict[str, ScriptModule]
By default, compiles the whole model to a ScriptModule. To use tracing instead, pass the argument method='trace' and make sure that either the example_inputs argument is provided or the model has example_input_array set. To customize which modules are scripted, override this method. If you want to return multiple modules, we recommend using a dictionary.
- Parameters:
file_path – Path where to save the torchscript. Default: None (no file saved).
method – Whether to use TorchScript's script or trace method. Default: 'script'
example_inputs – An input to be used for tracing when method is set to 'trace'. Default: None (uses example_input_array).
**kwargs – Additional arguments that will be passed to the torch.jit.script() or torch.jit.trace() function.
Note
Requires the implementation of the forward() method.
The exported script will be set to evaluation mode.
It is recommended that you install the latest supported version of PyTorch to use this feature without limitations. See also the torch.jit documentation for supported features.
Example
>>> import os
>>> import torch
>>> from pytorch_lightning import LightningModule  # or lightning.pytorch, depending on the install
>>> class SimpleModel(LightningModule):
...     def __init__(self):
...         super().__init__()
...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
...
...     def forward(self, x):
...         return torch.relu(self.l1(x.view(x.size(0), -1)))
...
>>> model = SimpleModel()
>>> scripted = model.to_torchscript(file_path="model.pt")
>>> os.path.isfile("model.pt")
True
>>> traced = model.to_torchscript(file_path="model_trace.pt", method='trace',
...                               example_inputs=torch.randn(1, 64))
>>> os.path.isfile("model_trace.pt")
True
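The description above suggests overriding this method to customize what is scripted, and returning a dictionary when there are several modules. The following is only a minimal sketch of that idea on a plain torch.nn.Module, so that it runs without Lightning or tabensemb. The class EncoderHead, its submodule names, and the dictionary keys are hypothetical and are not part of the AbstractNN API.

```python
import torch
from torch import nn


class EncoderHead(nn.Module):
    """Hypothetical two-part model, used only to illustrate the override."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(64, 16)
        self.head = nn.Linear(16, 4)

    def forward(self, x):
        return self.head(torch.relu(self.encoder(x)))

    def to_torchscript(self):
        # Assumption: mirror the documented behavior by exporting in
        # evaluation mode, then restore the previous training flag.
        was_training = self.training
        self.eval()
        scripted = {
            "encoder": torch.jit.script(self.encoder),
            "head": torch.jit.script(self.head),
        }
        self.train(was_training)
        # Several modules are returned as a dictionary keyed by name.
        return scripted


model = EncoderHead()
parts = model.to_torchscript()
```

Each entry of parts can then be saved or called on its own, for example parts["encoder"](torch.randn(3, 64)).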
- Returns:
This LightningModule as a torchscript, regardless of whether file_path is defined or not.
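A file written through file_path is an ordinary TorchScript archive, so it should be loadable with torch.jit.load() without the original Python class. The sketch below illustrates that round trip. It is an assumption-laden stand-in: TinyNet is a hypothetical plain torch.nn.Module, and torch.jit.script() plus torch.jit.save() are used in place of to_torchscript(file_path=...) so that the snippet runs with PyTorch alone.

```python
import os
import tempfile

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the model being exported."""

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))


# Export in evaluation mode, as the note above describes for the real method.
model = TinyNet().eval()
scripted = torch.jit.script(model)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pt")
    torch.jit.save(scripted, path)
    # Loading needs only the file, not the TinyNet class definition.
    restored = torch.jit.load(path)

batch = torch.randn(2, 64)
with torch.no_grad():
    prediction = restored(batch)
```

The restored module is called like the original one and here returns one row of four outputs per input row.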