tabensemb.model.AbstractNN.to_onnx

method

AbstractNN.to_onnx(file_path: str | Path, input_sample: Any | None = None, **kwargs: Any) → None

Saves the model in ONNX format.

Parameters:
  • file_path – The path of the file the ONNX model should be saved to.

  • input_sample – An input used for tracing. Default: None (uses self.example_input_array).

  • **kwargs – Additional keyword arguments passed through to the torch.onnx.export function.

class SimpleModel(LightningModule):
    """Toy model for the ONNX export example.

    Each sample is flattened to one row and fed through a single
    64-in / 4-out linear layer followed by a ReLU.
    """

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        # Collapse every non-batch dimension so the linear layer sees (batch, features).
        n_samples = x.size(0)
        flat = x.view(n_samples, -1)
        logits = self.l1(flat)
        return torch.relu(logits)

import os
import tempfile

import torch

model = SimpleModel()
# Export into a temporary *directory* rather than a NamedTemporaryFile(delete=False):
# the exporter can open the path on every platform (no handle is held open), and
# the file is removed automatically when the block exits instead of leaking.
with tempfile.TemporaryDirectory() as tmpdir:
    onnx_path = os.path.join(tmpdir, "model.onnx")
    # The random (1, 64) tensor is the input_sample used for tracing.
    model.to_onnx(onnx_path, torch.randn((1, 64)), export_params=True)
    # Actually check the export produced a file (the bare call discarded its result).
    assert os.path.isfile(onnx_path)