diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h
index 4ca404c157..80c2e5a967 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h
@@ -213,6 +213,11 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
     impl_->release_snapshot(handle);
   }
 
+  /// Forwards to impl_->delete_rocksdb_checkpoint_dir().
+  void delete_rocksdb_checkpoint_dir() {
+    impl_->delete_rocksdb_checkpoint_dir();
+  }
+
   int64_t get_snapshot_count() const {
     return impl_->get_snapshot_count();
   }
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h
index 1c14d3f809..2d1601941c 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h
@@ -121,7 +121,10 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
   void deserialize(const std::string& serialized);
 
   std::vector<std::string> get_kvtensor_serializable_metadata() const;
 
+  /// Calls delete_rocksdb_checkpoint_dir() on the backing database, if any.
+  void delete_rocksdb_checkpoint_dir() const;
+
   friend void to_json(json& j, const KVTensorWrapper& kvt);
   friend void from_json(const json& j, KVTensorWrapper& kvt);
 
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp
index 2a91a1faa2..de1e78e9e6 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp
@@ -108,4 +108,11 @@ std::string KVTensorWrapper::layout_str() {
   oss << options_.layout();
   return oss.str();
 }
+
+// Placeholder definition: only invokes FBEXCEPTION("Not implemented").
+std::vector<std::string> KVTensorWrapper::get_kvtensor_serializable_metadata()
+    const {
+  FBEXCEPTION("Not implemented");
+  return std::vector<std::string>{};
+}
 } // namespace ssd
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
index 541ae73c9d..3a764e8d0f 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
@@ -412,9 +412,28 @@ std::string KVTensorWrapper::serialize() const {
   return json_serialized.dump();
 }
 
+/// Deletes the RocksDB checkpoint directory through readonly_db_ when it is
+/// set, otherwise through db_ when that is an EmbeddingRocksDB; does nothing
+/// if neither holds.
+void KVTensorWrapper::delete_rocksdb_checkpoint_dir() const {
+  if (readonly_db_) {
+    LOG(INFO) << "deleting checkpoint dir for " << checkpoint_uuid;
+    readonly_db_->delete_rocksdb_checkpoint_dir();
+  } else if (auto* db = dynamic_cast<EmbeddingRocksDB*>(db_.get())) {
+    // The cast is null for an unset db_ as well as for a backend that is not
+    // EmbeddingRocksDB, so this one condition guards the dereference below.
+    LOG(INFO) << "embedding delete";
+    db->delete_rocksdb_checkpoint_dir();
+  }
+}
+
 std::vector<std::string> KVTensorWrapper::get_kvtensor_serializable_metadata()
     const {
   std::vector<std::string> metadata;
+  // Return empty metadata if checkpoint_handle_ is not initialized yet
+  if (checkpoint_handle_ == nullptr) {
+    return metadata;
+  }
   auto* db = dynamic_cast<EmbeddingRocksDB*>(db_.get());
   auto checkpoint_paths = db->get_checkpoints(checkpoint_handle_->uuid);
   metadata.push_back(std::to_string(checkpoint_paths.size()));
@@ -931,6 +950,9 @@ static auto embedding_rocks_db_wrapper =
             &EmbeddingRocksDBWrapper::wait_util_filling_work_done)
         .def("create_snapshot", &EmbeddingRocksDBWrapper::create_snapshot)
         .def("release_snapshot", &EmbeddingRocksDBWrapper::release_snapshot)
+        .def(
+            "delete_rocksdb_checkpoint_dir",
+            &EmbeddingRocksDBWrapper::delete_rocksdb_checkpoint_dir)
         .def("get_snapshot_count", &EmbeddingRocksDBWrapper::get_snapshot_count)
         .def(
             "get_keys_in_range_by_snapshot",
@@ -1142,7 +1164,10 @@ static auto kv_tensor_wrapper =
         .def("logs", &KVTensorWrapper::logs, "")
         .def(
             "get_kvtensor_serializable_metadata",
-            &KVTensorWrapper::get_kvtensor_serializable_metadata);
+            &KVTensorWrapper::get_kvtensor_serializable_metadata)
+        .def(
+            "delete_rocksdb_checkpoint_dir",
+            &KVTensorWrapper::delete_rocksdb_checkpoint_dir);
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
diff
--git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h
index 050f046767..d1ed713f82 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h
@@ -345,6 +345,16 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
     return result;
   }
 
+  /// Logs each directory listed in db_paths_ and passes it to
+  /// kv_db_utils::remove_dir.
+  void delete_rocksdb_checkpoint_dir() {
+    // NOTE(review): assumes db_paths_ has exactly one entry per dbs_ shard.
+    for (const auto& db_path : db_paths_) {
+      LOG(INFO) << "removing checkpoint directories: " << db_path;
+      kv_db_utils::remove_dir(db_path);
+    }
+  }
+
   void initialize_dbs(
       int64_t num_shards,
       std::string path,