tabensemb.model.AbstractNN.call_required_model

method

static AbstractNN.call_required_model(required_model, x, derived_tensors, model_name=None) → Tensor

Call a required model and return its result. This method is also used before training to generate the required model's predictions and hidden representations.

Parameters:
required_model

A required model specified in AbstractModel.required_models() and extracted by AbstractModel._get_required_models().

x

See _forward().

derived_tensors

See _forward().

model_name

The name of the required model. It must be provided if the required model comes from the same model base.

Returns:
torch.Tensor

The result of the required model.

Notes

To run the required model and train it further, pass a copy of derived_tensors from whose data_required_models item the {MODEL_NAME}_pred entry has been removed.
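The Notes can be illustrated with a minimal, self-contained sketch. It assumes that derived_tensors is a plain dict whose data_required_models item is itself a dict of precomputed results keyed as {MODEL_NAME}_pred, and that a cached prediction, when present, is returned instead of running the model. The helpers without_cached_prediction and toy_call_required_model, the model name "MLP", and the stand-in model are hypothetical and are not part of the tabensemb API.

```python
import copy
from typing import Any, Callable, Dict


def without_cached_prediction(derived_tensors: Dict[str, Any], model_name: str) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of ``derived_tensors`` whose
    ``data_required_models`` dict no longer holds ``{model_name}_pred``."""
    # Shallow-copy the outer dict so the caller's object is left untouched.
    copied = copy.copy(derived_tensors)
    # Copy the nested dict too before mutating it; stored values are shared, not duplicated.
    required = dict(copied.get("data_required_models", {}))
    required.pop(f"{model_name}_pred", None)
    copied["data_required_models"] = required
    return copied


def toy_call_required_model(
    model: Callable[[Any], Any], x: Any, derived_tensors: Dict[str, Any], model_name: str
) -> Any:
    # Assumed behaviour for illustration only: reuse the cached prediction
    # if it is present, otherwise actually run the model on x.
    cached = derived_tensors.get("data_required_models", {})
    key = f"{model_name}_pred"
    if key in cached:
        return cached[key]
    return model(x)


def toy_model(x: Any) -> str:
    # Stand-in for a real required model.
    return f"fresh({x})"


derived = {"data_required_models": {"MLP_pred": "cached"}, "other": 1}
print(toy_call_required_model(toy_model, "x", derived, "MLP"))  # -> cached
fresh_input = without_cached_prediction(derived, "MLP")
print(toy_call_required_model(toy_model, "x", fresh_input, "MLP"))  # -> fresh(x)
```

Copying both the outer dict and the nested data_required_models dict keeps the caller's cached predictions intact for later calls, while the stored values themselves are shared rather than duplicated.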