Skip to content

Correct load_class bugs in tabsurvey models#1

Open
joelromanky wants to merge 6 commits into
mainfrom
jky-dev
Open

Correct load_class bugs in tabsurvey models#1
joelromanky wants to merge 6 commits into
mainfrom
jky-dev

Conversation

@joelromanky

Copy link
Copy Markdown

Summary:

This PR updates the load_class function in torch_models.py to ensure that model scalers are correctly re-initialized during loading. It now automatically loads ScalerData from a scaler.json file if one is present, ensuring the model's internal dimensions align with the training state.

Changes:

  • Modified load_class in tabularbench/models/torch_models.py to check for and load scaler.json if a scaler is not explicitly provided in kwargs.
  • Ensured TabScaler is properly fitted with the loaded ScalerData before model initialization, preventing size mismatch errors in the state_dict during load_state_dict().
  • Added a validation check to raise an error if scaler.json is missing, ensuring configuration integrity.
  • Updated model-related logic in the tabularbench/models/tabsurvey folder to ensure consistency with the new loading flow.

Impact:

  • Resolves RuntimeError: Error(s) in loading state_dict (size mismatches) that occurred when loading pre-trained models.
  • Ensures that models correctly reconstruct their input/feature dimensions from saved training data, rather than defaulting to empty/incorrect sizes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants