tabensemb.model.AbstractNN.get_loss_fn

method

static AbstractNN.get_loss_fn(loss, task) → Module

Get the loss function that compares the output of forward() with the target.

Parameters:
loss

The name of the loss: “cross_entropy”, “mae”, or “mse”.

task

The type of task: “regression”, “multiclass”, or “binary”.

Returns:
nn.Module

The loss function, as a torch.nn.Module instance.
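The page lists the accepted values of loss and task but not how they combine. The following is a minimal, hypothetical sketch of such a dispatch, assuming that only the pairs (regression, mse), (regression, mae), (multiclass, cross_entropy) and (binary, cross_entropy) are meaningful, and that they map to the standard PyTorch classes MSELoss, L1Loss, CrossEntropyLoss and BCEWithLogitsLoss. The helpers loss_class_name and build_loss_fn are illustrative names, not part of tabensemb, and the mapping is not copied from the library's source.

```python
# Hypothetical sketch: the (task, loss) -> torch.nn class mapping below is an
# assumption for illustration, not tabensemb's actual implementation.
_LOSS_CLASS_NAMES = {
    ("regression", "mse"): "MSELoss",
    ("regression", "mae"): "L1Loss",  # mean absolute error is L1Loss in PyTorch
    ("multiclass", "cross_entropy"): "CrossEntropyLoss",
    # Assumes forward() returns raw logits for binary tasks.
    ("binary", "cross_entropy"): "BCEWithLogitsLoss",
}


def loss_class_name(loss: str, task: str) -> str:
    """Return the name of the torch.nn loss class assumed for (loss, task)."""
    try:
        return _LOSS_CLASS_NAMES[(task, loss)]
    except KeyError:
        raise ValueError(
            f"Unsupported combination: loss={loss!r}, task={task!r}"
        ) from None


def build_loss_fn(loss: str, task: str):
    """Instantiate the assumed loss module; requires PyTorch to be installed."""
    import torch.nn as nn  # imported lazily so the lookup works without torch

    return getattr(nn, loss_class_name(loss, task))()
```

For example, build_loss_fn("mse", "regression") would return an nn.MSELoss() instance. The choice of BCEWithLogitsLoss for binary tasks only holds if the model emits logits; a model that already applies a sigmoid would instead pair with nn.BCELoss.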