diff --git a/python/dpu_utils/tfutils/tfvariablesaver.py b/python/dpu_utils/tfutils/tfvariablesaver.py index 9f77cf3..5b7b7be 100644 --- a/python/dpu_utils/tfutils/tfvariablesaver.py +++ b/python/dpu_utils/tfutils/tfvariablesaver.py @@ -10,7 +10,7 @@ class TFVariableSaver: def __init__(self): self.__saved_variables = {} # type: Dict[str, np.ndarray] - def save_all(self, session: tf.Session, exclude_variable: Optional[Callable[[str], bool]]=None) -> None: + def save_all(self, session: tf.compat.v1.Session, exclude_variable: Optional[Callable[[str], bool]]=None) -> None: self.__saved_variables = {} for variable in session.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES): assert variable.name not in self.__saved_variables @@ -21,7 +21,7 @@ def save_all(self, session: tf.Session, exclude_variable: Optional[Callable[[str def has_saved_variables(self) -> bool: return len(self.__saved_variables) > 0 - def restore_saved_values(self, session: tf.Session) -> None: + def restore_saved_values(self, session: tf.v1.compat.Session) -> None: assert len(self.__saved_variables) > 0 save_ops = [] with tf.name_scope("restore"):