diff --git a/metaflow_extensions/slurm_ext/plugins/slurm/slurm_client.py b/metaflow_extensions/slurm_ext/plugins/slurm/slurm_client.py index c75130d..0959180 100644 --- a/metaflow_extensions/slurm_ext/plugins/slurm/slurm_client.py +++ b/metaflow_extensions/slurm_ext/plugins/slurm/slurm_client.py @@ -38,6 +38,17 @@ def __init__( % sys.executable ) + for param, env_var in [ + ("ssh_key_file", "METAFLOW_SLURM_SSH_KEY_FILE"), + ("username", "METAFLOW_SLURM_USERNAME"), + ("address", "METAFLOW_SLURM_ADDRESS"), + ]: + if locals()[param] is None: + raise SlurmException( + "%s is required. Set it via @slurm(%s=...) " + "or the %s environment variable." % (param, param, env_var) + ) + ssh_key_file_path = Path(ssh_key_file).expanduser().resolve() cert_file_path = Path(cert_file).expanduser().resolve() if cert_file else None