tabensemb.model.AbstractNN.to_torchscript

method

AbstractNN.to_torchscript(file_path: str | Path | None = None, method: str | None = 'script', example_inputs: Any | None = None, **kwargs: Any) → ScriptModule | Dict[str, ScriptModule]

By default, compiles the whole model to a ScriptModule. To use tracing instead, pass method='trace' and make sure that either the example_inputs argument is provided or the model's example_input_array attribute is set. To customize which modules are scripted, override this method. If you want to return multiple modules, we recommend returning a dictionary.
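The script/trace dispatch described above can be approximated in plain PyTorch as follows. This is a minimal sketch, assuming the method delegates to torch.jit.script or torch.jit.trace and forwards **kwargs unchanged; the helper export and the class TinyNet are hypothetical and not part of tabensemb, and the fallback to example_input_array is replaced by an error.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in model, mirroring the Example on this page."""

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))


def export(model, method="script", example_inputs=None, **kwargs):
    """Hypothetical helper approximating the documented script/trace dispatch."""
    was_training = model.training
    model.eval()  # the exported module is documented to be in eval mode
    try:
        if method == "script":
            return torch.jit.script(model, **kwargs)
        if method == "trace":
            # The documented method would fall back to example_input_array here;
            # this sketch simply refuses to trace without inputs.
            if example_inputs is None:
                raise ValueError("tracing requires example_inputs")
            return torch.jit.trace(model, example_inputs=example_inputs, **kwargs)
        raise ValueError(f"method must be 'script' or 'trace', got {method!r}")
    finally:
        model.train(was_training)  # restore the original module's mode
```

Scripting compiles the Python source of forward(), so data-dependent control flow is preserved; tracing records only the operations executed for the given example input, which is why it needs example_inputs.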

Parameters:
  • file_path – Path where to save the torchscript. Default: None (no file saved).

  • method – Whether to use TorchScript’s script or trace method. Default: ‘script’

  • example_inputs – Input used for tracing when method is set to ‘trace’. Default: None (falls back to example_input_array)

  • **kwargs – Additional arguments that will be passed to the torch.jit.script() or torch.jit.trace() function.
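When file_path is given, the compiled module is also written to disk, and **kwargs reach the underlying torch.jit call. A rough plain-PyTorch equivalent of that save/load round trip is sketched below, assuming the file is written with torch.jit.save; the Sequential model is a hypothetical stand-in, and check_trace is simply one real torch.jit.trace keyword used to show pass-through.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical toy model standing in for an AbstractNN subclass.
model = nn.Sequential(nn.Linear(64, 4), nn.ReLU()).eval()
example = torch.randn(1, 64)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_trace.pt")
    # Extra keyword arguments such as check_trace go straight to torch.jit.trace.
    traced = torch.jit.trace(model, example_inputs=example, check_trace=True)
    torch.jit.save(traced, path)  # roughly what setting file_path triggers
    loaded = torch.jit.load(path)  # reload without needing the Python class
    same = torch.allclose(loaded(example), model(example))
```

The saved file can be loaded with torch.jit.load in a process that does not import the original model class.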

Note

  • Requires the implementation of the forward() method.

  • The exported script will be set to evaluation mode.

  • It is recommended that you install the latest supported version of PyTorch to use this feature without limitations. See also the torch.jit documentation for supported features.
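Because the exported module is in evaluation mode, layers such as dropout behave deterministically in it. The sketch below illustrates this with a hypothetical DropNet model, assuming the export is made from model.eval() as in plain torch.jit.script; switching the original back to training mode does not change the exported copy.

```python
import torch
from torch import nn


class DropNet(nn.Module):
    """Hypothetical model whose output differs between train and eval mode."""

    def __init__(self):
        super().__init__()
        self.drop = nn.Dropout(p=0.5)
        self.l1 = nn.Linear(8, 2)

    def forward(self, x):
        return self.l1(self.drop(x))


model = DropNet()
scripted = torch.jit.script(model.eval())  # export from eval mode
model.train()  # the original goes back to training; the exported copy does not

x = torch.ones(4, 8)
first, second = scripted(x), scripted(x)  # dropout is inactive in the export
```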

Example

>>> import os
>>> import torch
>>> from pytorch_lightning import LightningModule
>>> class SimpleModel(LightningModule):
...     def __init__(self):
...         super().__init__()
...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
...
...     def forward(self, x):
...         return torch.relu(self.l1(x.view(x.size(0), -1)))
...
>>> model = SimpleModel()
>>> script = model.to_torchscript(file_path="model.pt")
>>> os.path.isfile("model.pt")
True
>>> traced = model.to_torchscript(file_path="model_trace.pt", method="trace",
...                               example_inputs=torch.randn(1, 64))
>>> os.path.isfile("model_trace.pt")
True

Returns:

This LightningModule compiled to TorchScript (a ScriptModule, or a dictionary of ScriptModules if the method is overridden to return several), returned whether or not file_path is set.
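The Dict[str, ScriptModule] return type corresponds to the recommendation to return a dictionary when exporting several modules. A minimal sketch of such an override is shown below; EncoderDecoder and its submodule names are hypothetical, and because this plain nn.Module has no base to_torchscript, the method here only mimics what an override in an AbstractNN subclass might look like.

```python
import torch
from torch import nn


class EncoderDecoder(nn.Module):
    """Hypothetical two-part model; names are illustrative only."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 4)
        self.decoder = nn.Linear(4, 16)

    def forward(self, x):
        return self.decoder(torch.relu(self.encoder(x)))

    def to_torchscript(self):
        # In an AbstractNN subclass this would override the default method.
        # Each part is scripted separately and returned under its own key.
        self.eval()
        return {
            "encoder": torch.jit.script(self.encoder),
            "decoder": torch.jit.script(self.decoder),
        }
```

Keying the parts by name lets deployment code save, load, and call each one independently, for example running only the encoder to produce embeddings.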