tabensemb.model.AbstractWrapper

class tabensemb.model.AbstractWrapper(model: AbstractModel)

Bases: object

Wraps a deep learning model, where required, so that hidden information extracted during the forward pass, such as hidden_representation, is accessible.

Attributes:
hidden_rep_dim

The dimension of hidden_representation().

hidden_representation

The information extracted from a deep learning model during the forward pass of a batch.

Methods

__init__(model: AbstractModel)

eval()

reset_forward()

Reset the overridden forward method of the torch.nn.Module to ensure pickling compatibility.

wrap_forward()

Override the forward method of a torch.nn.Module to record hidden representations.
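The interplay between wrap_forward() and reset_forward() can be illustrated with a minimal sketch. It is not tabensemb's implementation: TinyModel is a hypothetical plain-Python stand-in for a torch.nn.Module, RecordingWrapper is a hypothetical stand-in for a concrete wrapper, and the extract/head split is an assumption made only to have something to record. The sketch shows the general idea of shadowing forward on the instance to capture an intermediate value, then removing that override so the object can be pickled again.

```python
class TinyModel:
    """Hypothetical stand-in for a torch.nn.Module (not part of tabensemb)."""

    def extract(self, x):
        # Pretend feature extractor: produces the "hidden representation".
        return [2 * v for v in x]

    def head(self, hidden):
        # Pretend output layer.
        return sum(hidden)

    def forward(self, x):
        return self.head(self.extract(x))


class RecordingWrapper:
    """Illustrative wrapper; the real AbstractWrapper may work differently."""

    def __init__(self, model):
        self.model = model
        self.hidden_representation = None
        self.wrap_forward()

    @property
    def hidden_rep_dim(self):
        # In this sketch the dimension is simply the length of the recorded list.
        return len(self.hidden_representation)

    def wrap_forward(self):
        model = self.model

        def recording_forward(x):
            hidden = model.extract(x)
            self.hidden_representation = hidden  # record the intermediate value
            return model.head(hidden)

        # An instance attribute shadows the class-level forward method.
        model.forward = recording_forward

    def reset_forward(self):
        # Dropping the instance attribute restores the class-level forward,
        # so the model no longer holds a locally defined (unpicklable) function.
        self.model.__dict__.pop("forward", None)
```

While the override is in place, the model instance holds a function defined inside wrap_forward, which the standard pickle module cannot serialize. Calling reset_forward() before saving removes it, which is presumably the pickling concern that the reset_forward() description refers to.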