From fd744c4389be3aaec55b93d8ae9ed6d13fe753fa Mon Sep 17 00:00:00 2001 From: "cubic-dev-ai[bot]" <1082092+cubic-dev-ai[bot]@users.noreply.github.com> Date: Wed, 15 Apr 2026 19:45:30 +0000 Subject: [PATCH] fix: validate file-derived lengths before allocation to prevent OOM on corruption --- file.go | 8 ++++++++ file_test.go | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/file.go b/file.go index 0d35c50..ce79c79 100644 --- a/file.go +++ b/file.go @@ -228,6 +228,10 @@ func loadMetadata(dir string) (uint64, error) { if maxBucketChunks == 0 { return 0, fmt.Errorf("invalid maxBucketChunks=0 read from %q", metadataPath) } + maxAllowedChunks := maxBucketSize / chunkSize + if maxBucketChunks > maxAllowedChunks { + return 0, fmt.Errorf("too big maxBucketChunks=%d read from %q; cannot exceed %d", maxBucketChunks, metadataPath, maxAllowedChunks) + } return maxBucketChunks, nil } @@ -352,6 +356,10 @@ func (b *bucket) Load(r io.Reader, maxChunks uint64) error { if err != nil { return fmt.Errorf("cannot read len(b.m): %s", err) } + maxKvs := maxChunks * chunkSize / 4 + if kvsLen > maxKvs { + return fmt.Errorf("too big kvsLen=%d; cannot exceed %d", kvsLen, maxKvs) + } kvsLen *= 2 * 8 kvs := make([]byte, kvsLen) if _, err := io.ReadFull(r, kvs); err != nil { diff --git a/file_test.go b/file_test.go index b2ccd97..0bc2bff 100644 --- a/file_test.go +++ b/file_test.go @@ -1,11 +1,13 @@ package fastcache import ( + "encoding/binary" "errors" "fmt" "io/ioutil" "os" "path/filepath" + "strings" "sync" "testing" ) @@ -259,3 +261,56 @@ func TestSaveLoadConcurrent(t *testing.T) { close(stopCh) wgWorkers.Wait() } + +func writeUint64ToBytes(v uint64) []byte { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], v) + return buf[:] +} + +func TestLoadCorruptedMetadataTooBigChunks(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + filePath := filepath.Join(tmpDir, "corrupted.fastcache") + if err := os.MkdirAll(filePath, 0755); err != nil { + t.Fatal(err) + } + + metadataPath := filepath.Join(filePath, "metadata.bin") + hugeChunks := maxBucketSize/chunkSize + 1 + if err := os.WriteFile(metadataPath, writeUint64ToBytes(hugeChunks), 0644); err != nil { + t.Fatal(err) + } + + _, err = LoadFromFile(filePath) + if err == nil { + t.Fatal("expected error for corrupted metadata with huge maxBucketChunks") + } + if !strings.Contains(err.Error(), "too big maxBucketChunks") { + t.Fatalf("unexpected error: %s", err) + } +} + +func TestLoadCorruptedDataTooBigKvsLen(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + filePath := filepath.Join(tmpDir, "corrupted.fastcache") + c := New(bucketsCount * chunkSize * 2) + c.Set([]byte("key"), []byte("value")) + if err := c.SaveToFile(filePath); err != nil { + t.Fatalf("SaveToFile error: %s", err) + } + + _, err = LoadFromFile(filePath) + if err != nil { + t.Fatalf("LoadFromFile must succeed for valid cache: %s", err) + } +}