tabensemb.model.TorchModel.count_params

method

TorchModel.count_params(model_name, trainable_only=False)

Count the number of parameters in the torch.nn.Module of the selected model.

Parameters:
model_name

The name of the selected model.

trainable_only

Only count trainable (requires_grad=True) parameters.

Returns:
float

The number of parameters (only the trainable ones if trainable_only=True).
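A minimal sketch of the counting logic, assuming the selected model is an ordinary torch.nn.Module: sum numel() over module.parameters(), optionally keeping only parameters with requires_grad=True. The helper name count_module_params is hypothetical and not part of tabensemb, and how TorchModel resolves model_name to a module is not shown or assumed here. The helper relies only on duck typing, so it imports nothing; the demo under __main__ runs only if PyTorch is installed.

```python
def count_module_params(module, trainable_only=False):
    """Return the number of scalar parameters in ``module``.

    Hypothetical helper (not part of tabensemb). ``module`` is assumed to
    behave like a torch.nn.Module: ``module.parameters()`` yields tensors
    that expose ``numel()`` and ``requires_grad``.
    """
    return sum(
        p.numel()
        for p in module.parameters()
        # Skip frozen parameters when only trainable ones are requested.
        if not trainable_only or p.requires_grad
    )


if __name__ == "__main__":
    try:
        import torch.nn as nn
    except ImportError:  # PyTorch not installed: skip the demo
        nn = None

    if nn is not None:
        layer = nn.Linear(4, 2)  # weight: 2 x 4 = 8, bias: 2, total 10
        layer.bias.requires_grad_(False)  # freeze the bias
        print(count_module_params(layer))  # 10
        print(count_module_params(layer, trainable_only=True))  # 8
```

Because module.parameters() yields each shared parameter only once, tied weights are not double-counted. This sketch returns an int, whereas count_params is documented as returning a float.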