Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion transformers4rec/torch/features/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ def from_schema( # type: ignore
None,
None,
)

processed_features = []

if continuous_tags:
if continuous_soft_embeddings:
maybe_continuous_module = cls.SOFT_EMBEDDING_MODULE_CLASS.from_schema(
Expand All @@ -175,14 +176,29 @@ def from_schema( # type: ignore
maybe_continuous_module = cls.CONTINUOUS_MODULE_CLASS.from_schema(
schema, tags=continuous_tags, **kwargs
)
processed_features.extend(schema.select_by_tag(continuous_tags).column_names)
if categorical_tags:
maybe_categorical_module = cls.EMBEDDING_MODULE_CLASS.from_schema(
schema, tags=categorical_tags, **kwargs
)
processed_features.extend(schema.select_by_tag(categorical_tags).column_names)
if pretrained_embeddings_tags:
maybe_pretrained_module = cls.PRETRAINED_EMBEDDING_MODULE_CLASS.from_schema(
schema, tags=pretrained_embeddings_tags, **kwargs
)
processed_features.extend(schema.select_by_tag(pretrained_embeddings_tags).column_names)

unprocessed_features = set(schema.column_names).difference(set(processed_features))
if unprocessed_features:
raise ValueError(
"Schema provided to `TabularFeatures` includes features "
"without any of the following tags: "
f"continuous ({continuous_tags}), categorical ({categorical_tags}), "
f"or pretrained embeddings ({pretrained_embeddings_tags}). "
"Please ensure all columns have one of these tags "
"or are excluded from the schema. "
f"\nUnproceesed features: {unprocessed_features} "


output = cls(
continuous_module=maybe_continuous_module,
Expand Down
15 changes: 15 additions & 0 deletions transformers4rec/torch/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,21 @@ def __init__(
self.top_k = top_k

def forward(self, inputs: TabularData, targets=None, training=False, testing=False, **kwargs):
model_expected_features = set(self.input_schema.column_names)
call_input_features = set(inputs.keys())
if not (training or testing) and model_expected_features != call_input_features:
raise ValueError(
"Model forward called with different set of features "
"compared with the input schema it was configured with "
"Please check that the inputs passed to the model are only "
"those required by the model."
f"\nModel expected features:\n\t{model_expected_features}"
f"\nCall input features:\n\t{call_input_features}"
f"\nFeatures expected by model input schema only:"
f"\n\t{model_expected_features.difference(call_input_features)}"
f"\nFeatures provided in inputs only:"
f"\n\t{call_input_features.difference(model_expected_features)}"
)
# Convert inputs to float32 which is the default type, expected by PyTorch
for name, val in inputs.items():
if torch.is_floating_point(val):
Expand Down