diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..a44677b --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,17 @@ +name: test + +on: + push: + branches: [master] + pull_request: + branches: [master] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + - run: go test -v -count=1 -timeout 120s . diff --git a/add_event_test.go b/add_event_test.go new file mode 100644 index 0000000..dd49bbe --- /dev/null +++ b/add_event_test.go @@ -0,0 +1,226 @@ +package relayer + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/nbd-wtf/go-nostr" +) + +func TestAddEvent(t *testing.T) { + t.Run("nil event", func(t *testing.T) { + rl := &testRelay{storage: &testStorage{}} + accepted, msg := AddEvent(context.Background(), rl, nil) + if accepted || msg != "" { + t.Errorf("got (%v, %q), want (false, \"\")", accepted, msg) + } + }) + + t.Run("rejected by relay", func(t *testing.T) { + rl := &testRelay{ + storage: &testStorage{}, + acceptEvent: func(e *nostr.Event) (bool, string) { + return false, "blocked: custom" + }, + } + accepted, msg := AddEvent(context.Background(), rl, &nostr.Event{Kind: 1}) + if accepted { + t.Error("expected rejection") + } + if msg != "blocked: custom" { + t.Errorf("got %q", msg) + } + }) + + t.Run("rejected default message", func(t *testing.T) { + rl := &testRelay{ + storage: &testStorage{}, + acceptEvent: func(e *nostr.Event) (bool, string) { + return false, "" + }, + } + accepted, msg := AddEvent(context.Background(), rl, &nostr.Event{Kind: 1}) + if accepted { + t.Error("expected rejection") + } + if msg != "blocked: event blocked by relay" { + t.Errorf("got %q", msg) + } + }) + + t.Run("ephemeral event not saved", func(t *testing.T) { + clearListeners() + defer clearListeners() + saveCalled := false + rl := &testRelay{ + storage: &testStorage{ + saveEvent: func(_ context.Context, _ *nostr.Event) error { + saveCalled = true + return nil + }, + }, + } + accepted, _ := AddEvent(context.Background(), rl, &nostr.Event{Kind: 25000}) + if !accepted { + t.Error("expected acceptance") + } + if saveCalled { + t.Error("SaveEvent should not be called for ephemeral events") + } + }) + + t.Run("ephemeral boundary low", func(t *testing.T) { + clearListeners() + defer clearListeners() + saveCalled := false + rl := &testRelay{ + storage: &testStorage{ + saveEvent: func(_ context.Context, _ *nostr.Event) error { + saveCalled = true + return nil + }, + }, + } + AddEvent(context.Background(), rl, &nostr.Event{Kind: 20000}) + if saveCalled { + t.Error("kind 20000 should be ephemeral") + } + }) + + t.Run("ephemeral boundary high", func(t *testing.T) { + clearListeners() + defer clearListeners() + saveCalled := false + rl := &testRelay{ + storage: &testStorage{ + saveEvent: func(_ context.Context, _ *nostr.Event) error { + saveCalled = true + return nil + }, + }, + } + AddEvent(context.Background(), rl, &nostr.Event{Kind: 29999}) + if saveCalled { + t.Error("kind 29999 should be ephemeral") + } + }) + + t.Run("non-ephemeral kind 30000", func(t *testing.T) { + clearListeners() + defer clearListeners() + called := false + rl := &testRelay{ + storage: &testStorage{ + saveEvent: func(_ context.Context, _ *nostr.Event) error { + called = true + return nil + }, + replaceEvent: func(_ context.Context, _ *nostr.Event) error { + called = true + return nil + }, + }, + } + AddEvent(context.Background(), rl, &nostr.Event{Kind: 30000, Tags: nostr.Tags{{"d", ""}}}) + if !called { + t.Error("kind 30000 should be saved") + } + }) + + t.Run("save success", func(t *testing.T) { + clearListeners() + defer clearListeners() + saved := false + rl := &testRelay{ + storage: &testStorage{ + saveEvent: func(_ context.Context, _ *nostr.Event) error { + saved = true + return nil + }, + }, + } + accepted, msg := AddEvent(context.Background(), rl, &nostr.Event{Kind: 1}) + if !accepted || msg != "" { + t.Errorf("got (%v, %q), want (true, \"\")", accepted, msg) + } + if !saved { + t.Error("SaveEvent not called") + } + }) + + t.Run("duplicate event via wrapper", func(t *testing.T) { + clearListeners() + defer clearListeners() + // eventstore.RelayWrapper.Publish handles dups silently (returns nil) + rl := &testRelay{ + storage: &testStorage{}, + } + accepted, msg := AddEvent(context.Background(), rl, &nostr.Event{Kind: 1}) + if !accepted { + t.Error("expected acceptance") + } + if msg != "" { + t.Errorf("got %q", msg) + } + }) + + t.Run("save error", func(t *testing.T) { + rl := &testRelay{ + storage: &testStorage{ + saveEvent: func(_ context.Context, _ *nostr.Event) error { + return errors.New("db connection failed") + }, + }, + } + accepted, msg := AddEvent(context.Background(), rl, &nostr.Event{Kind: 1}) + if accepted { + t.Error("expected rejection") + } + if !strings.Contains(msg, "db connection failed") { + t.Errorf("expected error containing 'db connection failed', got %q", msg) + } + }) + + t.Run("AdvancedSaver hooks called", func(t *testing.T) { + clearListeners() + defer clearListeners() + var beforeCalled, afterCalled bool + st := &testAdvancedStorage{ + testStorage: testStorage{ + saveEvent: func(_ context.Context, _ *nostr.Event) error { return nil }, + }, + beforeSave: func(_ context.Context, _ *nostr.Event) { beforeCalled = true }, + afterSave: func(_ *nostr.Event) { afterCalled = true }, + } + rl := &testRelay{storage: st} + accepted, _ := AddEvent(context.Background(), rl, &nostr.Event{Kind: 1}) + if !accepted { + t.Error("expected acceptance") + } + if !beforeCalled { + t.Error("BeforeSave not called") + } + if !afterCalled { + t.Error("AfterSave not called") + } + }) + + t.Run("AdvancedSaver AfterSave not called on error", func(t *testing.T) { + afterCalled := false + st := &testAdvancedStorage{ + testStorage: testStorage{ + saveEvent: func(_ context.Context, _ *nostr.Event) error { + return errors.New("save failed") + }, + }, + afterSave: func(_ *nostr.Event) { afterCalled = true }, + } + rl := &testRelay{storage: st} + AddEvent(context.Background(), rl, &nostr.Event{Kind: 1}) + if afterCalled { + t.Error("AfterSave should not be called on save error") + } + }) +} diff --git a/handlers_test.go b/handlers_test.go new file mode 100644 index 0000000..e6e3964 --- /dev/null +++ b/handlers_test.go @@ -0,0 +1,462 @@ +package relayer + +import ( + "context" + "encoding/json" + "net/http" + "strings" + "testing" + "time" + + "github.com/fasthttp/websocket" + "github.com/fiatjaf/eventstore/slicestore" + "github.com/nbd-wtf/go-nostr" + "github.com/nbd-wtf/go-nostr/nip11" +) + +func dialWS(t *testing.T, addr string) *websocket.Conn { + t.Helper() + conn, _, err := websocket.DefaultDialer.Dial("ws://"+addr, nil) + if err != nil { + t.Fatalf("dial: %v", err) + } + t.Cleanup(func() { conn.Close() }) + return conn +} + +func sendJSON(t *testing.T, conn *websocket.Conn, v interface{}) { + t.Helper() + if err := conn.WriteJSON(v); err != nil { + t.Fatalf("writeJSON: %v", err) + } +} + +func recvMessage(t *testing.T, conn *websocket.Conn) (typ string, raw []json.RawMessage) { + t.Helper() + conn.SetReadDeadline(time.Now().Add(3 * time.Second)) + _, msg, err := conn.ReadMessage() + if err != nil { + t.Fatalf("readMessage: %v", err) + } + if err := json.Unmarshal(msg, &raw); err != nil { + t.Fatalf("unmarshal: %v (raw: %s)", err, msg) + } + if len(raw) > 0 { + json.Unmarshal(raw[0], &typ) + } + return +} + +func recvOK(t *testing.T, conn *websocket.Conn) (eventID string, ok bool, reason string) { + t.Helper() + typ, raw := recvMessage(t, conn) + if typ != "OK" { + t.Fatalf("expected OK, got %s", typ) + } + json.Unmarshal(raw[1], &eventID) + json.Unmarshal(raw[2], &ok) + if len(raw) > 3 { + json.Unmarshal(raw[3], &reason) + } + return +} + +func signedEvent(sk string, kind int, content string, tags nostr.Tags) nostr.Event { + evt := nostr.Event{ + Kind: kind, + Content: content, + CreatedAt: nostr.Now(), + Tags: tags, + } + evt.Sign(sk) + return evt +} + +// --- doEvent tests --- + +func TestDoEvent_ValidEvent(t *testing.T) { + srv := startTestRelay(t, &testRelay{storage: &slicestore.SliceStore{}}) + defer srv.Shutdown(context.TODO()) + + conn := dialWS(t, srv.Addr) + sk := nostr.GeneratePrivateKey() + evt := signedEvent(sk, 1, "hello", nostr.Tags{}) + + sendJSON(t, conn, []interface{}{"EVENT", evt}) + _, ok, _ := recvOK(t, conn) + if !ok { + t.Error("expected OK true") + } +} + +func TestDoEvent_InvalidID(t *testing.T) { + srv := startTestRelay(t, &testRelay{storage: &slicestore.SliceStore{}}) + defer srv.Shutdown(context.TODO()) + + conn := dialWS(t, srv.Addr) + sk := nostr.GeneratePrivateKey() + evt := signedEvent(sk, 1, "hello", nostr.Tags{}) + evt.ID = strings.Repeat("00", 32) + + sendJSON(t, conn, []interface{}{"EVENT", evt}) + _, ok, reason := recvOK(t, conn) + if ok { + t.Error("expected OK false") + } + if reason != "invalid: event id is computed incorrectly" { + t.Errorf("unexpected reason: %q", reason) + } +} + +func TestDoEvent_InvalidSignature(t *testing.T) { + srv := startTestRelay(t, &testRelay{storage: &slicestore.SliceStore{}}) + defer srv.Shutdown(context.TODO()) + + conn := dialWS(t, srv.Addr) + sk := nostr.GeneratePrivateKey() + evt := signedEvent(sk, 1, "hello", nostr.Tags{}) + // corrupt signature, keep valid ID + evt.Sig = strings.Repeat("00", 64) + + sendJSON(t, conn, []interface{}{"EVENT", evt}) + _, ok, _ := recvOK(t, conn) + if ok { + t.Error("expected OK false for invalid signature") + } +} + +func TestDoEvent_RejectedByRelay(t *testing.T) { + rl := &testRelay{ + storage: &slicestore.SliceStore{}, + acceptEvent: func(e *nostr.Event) (bool, string) { + return false, "blocked: not allowed" + }, + } + srv := startTestRelay(t, rl) + defer srv.Shutdown(context.TODO()) + + conn := dialWS(t, srv.Addr) + sk := nostr.GeneratePrivateKey() + evt := signedEvent(sk, 1, "hello", nostr.Tags{}) + + sendJSON(t, conn, []interface{}{"EVENT", evt}) + _, ok, reason := recvOK(t, conn) + if ok { + t.Error("expected OK false") + } + if reason != "blocked: not allowed" { + t.Errorf("unexpected reason: %q", reason) + } +} + +func TestDoEvent_NIP09Deletion(t *testing.T) { + srv := startTestRelay(t, &testRelay{storage: &slicestore.SliceStore{}}) + defer srv.Shutdown(context.TODO()) + + conn := dialWS(t, srv.Addr) + sk := nostr.GeneratePrivateKey() + + // publish a note + evt := signedEvent(sk, 1, "to delete", nostr.Tags{}) + sendJSON(t, conn, []interface{}{"EVENT", evt}) + recvOK(t, conn) + + // delete it + delEvt := signedEvent(sk, 5, "", nostr.Tags{{"e", evt.ID}}) + sendJSON(t, conn, []interface{}{"EVENT", delEvt}) + _, ok, _ := recvOK(t, conn) + if !ok { + t.Error("expected OK true for deletion") + } + + // verify it's gone + sendJSON(t, conn, []interface{}{"REQ", "check", nostr.Filter{IDs: []string{evt.ID}}}) + typ, _ := recvMessage(t, conn) + if typ != "EOSE" { + t.Errorf("expected EOSE (event deleted), got %s", typ) + } +} + +func TestDoEvent_NIP09DeletionWrongAuthor(t *testing.T) { + srv := startTestRelay(t, &testRelay{storage: &slicestore.SliceStore{}}) + defer srv.Shutdown(context.TODO()) + + conn := dialWS(t, srv.Addr) + sk1 := nostr.GeneratePrivateKey() + sk2 := nostr.GeneratePrivateKey() + + // publish with sk1 + evt := signedEvent(sk1, 1, "mine", nostr.Tags{}) + sendJSON(t, conn, []interface{}{"EVENT", evt}) + recvOK(t, conn) + + // try to delete with sk2 + delEvt := signedEvent(sk2, 5, "", nostr.Tags{{"e", evt.ID}}) + sendJSON(t, conn, []interface{}{"EVENT", delEvt}) + _, ok, reason := recvOK(t, conn) + if ok { + t.Error("expected OK false") + } + if reason != "insufficient permissions" { + t.Errorf("unexpected reason: %q", reason) + } +} + +// --- doReq tests --- + +func TestDoReq_Basic(t *testing.T) { + srv := startTestRelay(t, &testRelay{storage: &slicestore.SliceStore{}}) + defer srv.Shutdown(context.TODO()) + + conn := dialWS(t, srv.Addr) + sk := nostr.GeneratePrivateKey() + + // publish + evt := signedEvent(sk, 1, "hello", nostr.Tags{}) + sendJSON(t, conn, []interface{}{"EVENT", evt}) + recvOK(t, conn) + + // query + sendJSON(t, conn, []interface{}{"REQ", "sub1", nostr.Filter{Kinds: []int{1}}}) + + typ, raw := recvMessage(t, conn) + if typ != "EVENT" { + t.Fatalf("expected EVENT, got %s", typ) + } + // verify it's the right event + var subID string + json.Unmarshal(raw[1], &subID) + if subID != "sub1" { + t.Errorf("expected sub1, got %q", subID) + } + + typ, _ = recvMessage(t, conn) + if typ != "EOSE" { + t.Fatalf("expected EOSE, got %s", typ) + } +} + +func TestDoReq_EmptyID(t *testing.T) { + srv := startTestRelay(t, &testRelay{storage: &slicestore.SliceStore{}}) + defer srv.Shutdown(context.TODO()) + + conn := dialWS(t, srv.Addr) + sendJSON(t, conn, []interface{}{"REQ", "", nostr.Filter{Kinds: []int{1}}}) + + typ, raw := recvMessage(t, conn) + if typ != "NOTICE" { + t.Fatalf("expected NOTICE, got %s", typ) + } + var msg string + json.Unmarshal(raw[1], &msg) + if msg != "REQ has no " { + t.Errorf("unexpected notice: %q", msg) + } +} + +func TestDoReq_NoResults(t *testing.T) { + srv := startTestRelay(t, &testRelay{storage: &slicestore.SliceStore{}}) + defer srv.Shutdown(context.TODO()) + + conn := dialWS(t, srv.Addr) + sendJSON(t, conn, []interface{}{"REQ", "sub1", nostr.Filter{Kinds: []int{99999}}}) + + typ, _ := recvMessage(t, conn) + if typ != "EOSE" { + t.Fatalf("expected EOSE, got %s", typ) + } +} + +// --- doClose tests --- + +func TestDoClose_EmptyID(t *testing.T) { + srv := startTestRelay(t, &testRelay{storage: &slicestore.SliceStore{}}) + defer srv.Shutdown(context.TODO()) + + conn := dialWS(t, srv.Addr) + sendJSON(t, conn, []interface{}{"CLOSE", ""}) + + typ, raw := recvMessage(t, conn) + if typ != "NOTICE" { + t.Fatalf("expected NOTICE, got %s", typ) + } + var msg string + json.Unmarshal(raw[1], &msg) + if msg != "CLOSE has no " { + t.Errorf("unexpected notice: %q", msg) + } +} + +// --- doCount tests --- + +func TestDoCount_NotSupported(t *testing.T) { + // testStorage does not implement EventCounter + srv := startTestRelay(t, &testRelay{storage: &testStorage{}}) + defer srv.Shutdown(context.TODO()) + + conn := dialWS(t, srv.Addr) + sendJSON(t, conn, []interface{}{"COUNT", "c1", nostr.Filter{Kinds: []int{1}}}) + + typ, raw := recvMessage(t, conn) + if typ != "NOTICE" { + t.Fatalf("expected NOTICE, got %s", typ) + } + var msg string + json.Unmarshal(raw[1], &msg) + if msg != "restricted: this relay does not support NIP-45" { + t.Errorf("unexpected notice: %q", msg) + } +} + +func TestDoCount_EmptyID(t *testing.T) { + // Need a storage that implements EventCounter + st := &testStorageWithCounter{ + testStorage: testStorage{}, + countEvents: func(_ context.Context, _ nostr.Filter) (int64, error) { + return 0, nil + }, + } + srv := startTestRelay(t, &testRelay{storage: st}) + defer srv.Shutdown(context.TODO()) + + conn := dialWS(t, srv.Addr) + sendJSON(t, conn, []interface{}{"COUNT", "", nostr.Filter{Kinds: []int{1}}}) + + typ, raw := recvMessage(t, conn) + if typ != "NOTICE" { + t.Fatalf("expected NOTICE, got %s", typ) + } + var msg string + json.Unmarshal(raw[1], &msg) + if msg != "COUNT has no " { + t.Errorf("unexpected notice: %q", msg) + } +} + +// --- handleMessage tests --- + +func TestHandleMessage_UnknownType(t *testing.T) { + srv := startTestRelay(t, &testRelay{storage: &slicestore.SliceStore{}}) + defer srv.Shutdown(context.TODO()) + + conn := dialWS(t, srv.Addr) + sendJSON(t, conn, []interface{}{"UNKNOWN", "data"}) + + typ, raw := recvMessage(t, conn) + if typ != "NOTICE" { + t.Fatalf("expected NOTICE, got %s", typ) + } + var msg string + json.Unmarshal(raw[1], &msg) + if msg != "unknown message type UNKNOWN" { + t.Errorf("unexpected notice: %q", msg) + } +} + +func TestHandleMessage_TooFewParams(t *testing.T) { + srv := startTestRelay(t, &testRelay{storage: &slicestore.SliceStore{}}) + defer srv.Shutdown(context.TODO()) + + conn := dialWS(t, srv.Addr) + sendJSON(t, conn, []interface{}{"EVENT"}) + + typ, raw := recvMessage(t, conn) + if typ != "NOTICE" { + t.Fatalf("expected NOTICE, got %s", typ) + } + var msg string + json.Unmarshal(raw[1], &msg) + if msg != "request has less than 2 parameters" { + t.Errorf("unexpected notice: %q", msg) + } +} + +func TestHandleMessage_InvalidJSON(t *testing.T) { + srv := startTestRelay(t, &testRelay{storage: &slicestore.SliceStore{}}) + defer srv.Shutdown(context.TODO()) + + conn := dialWS(t, srv.Addr) + // send invalid JSON - server should silently ignore + conn.WriteMessage(websocket.TextMessage, []byte("not json")) + + // verify connection still works + sendJSON(t, conn, []interface{}{"CLOSE", ""}) + typ, _ := recvMessage(t, conn) + if typ != "NOTICE" { + t.Fatalf("expected NOTICE after invalid JSON, got %s", typ) + } +} + +// --- HandleNIP11 tests --- + +func TestHandleNIP11(t *testing.T) { + srv := startTestRelay(t, &testRelay{ + name: "test-nip11", + storage: &slicestore.SliceStore{}, + }) + defer srv.Shutdown(context.TODO()) + + req, _ := http.NewRequest("GET", "http://"+srv.Addr, nil) + req.Header.Set("Accept", "application/nostr+json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if ct := resp.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("unexpected content-type: %q", ct) + } + + var info nip11.RelayInformationDocument + json.NewDecoder(resp.Body).Decode(&info) + + if info.Name != "test-nip11" { + t.Errorf("expected name 'test-nip11', got %q", info.Name) + } + if info.Software != "https://github.com/fiatjaf/relayer" { + t.Errorf("unexpected software: %q", info.Software) + } +} + +// --- GetAuthStatus tests --- + +func TestGetAuthStatus(t *testing.T) { + t.Run("no auth in context", func(t *testing.T) { + pubkey, ok := GetAuthStatus(context.Background()) + if ok || pubkey != "" { + t.Errorf("expected (\"\", false), got (%q, %v)", pubkey, ok) + } + }) + + t.Run("with auth", func(t *testing.T) { + ws := &WebSocket{authed: "abc123"} + ctx := context.WithValue(context.Background(), AUTH_CONTEXT_KEY, ws) + pubkey, ok := GetAuthStatus(ctx) + if !ok || pubkey != "abc123" { + t.Errorf("expected (\"abc123\", true), got (%q, %v)", pubkey, ok) + } + }) + + t.Run("wrong type in context", func(t *testing.T) { + ctx := context.WithValue(context.Background(), AUTH_CONTEXT_KEY, "not a websocket") + pubkey, ok := GetAuthStatus(ctx) + if ok || pubkey != "" { + t.Errorf("expected (\"\", false), got (%q, %v)", pubkey, ok) + } + }) +} + +// --- test helpers --- + +type testStorageWithCounter struct { + testStorage + countEvents func(context.Context, nostr.Filter) (int64, error) +} + +func (s *testStorageWithCounter) CountEvents(ctx context.Context, f nostr.Filter) (int64, error) { + if s.countEvents != nil { + return s.countEvents(ctx, f) + } + return 0, nil +} diff --git a/listener_test.go b/listener_test.go new file mode 100644 index 0000000..c17ebad --- /dev/null +++ b/listener_test.go @@ -0,0 +1,178 @@ +package relayer + +import ( + "testing" + + "github.com/nbd-wtf/go-nostr" +) + +func listenerCount(ws *WebSocket) int { + listenersMutex.Lock() + defer listenersMutex.Unlock() + return len(listeners[ws]) +} + +func totalConnections() int { + listenersMutex.Lock() + defer listenersMutex.Unlock() + return len(listeners) +} + +func hasListener(ws *WebSocket, id string) bool { + listenersMutex.Lock() + defer listenersMutex.Unlock() + if subs, ok := listeners[ws]; ok { + _, ok = subs[id] + return ok + } + return false +} + +func TestSetListener(t *testing.T) { + clearListeners() + defer clearListeners() + + ws := &WebSocket{} + setListener("sub1", ws, nostr.Filters{{Kinds: []int{1}}}) + + if !hasListener(ws, "sub1") { + t.Error("sub1 not found") + } + if listenerCount(ws) != 1 { + t.Errorf("expected 1 sub, got %d", listenerCount(ws)) + } +} + +func TestSetListenerMultipleSubs(t *testing.T) { + clearListeners() + defer clearListeners() + + ws := &WebSocket{} + setListener("sub1", ws, nostr.Filters{{Kinds: []int{1}}}) + setListener("sub2", ws, nostr.Filters{{Kinds: []int{2}}}) + + if listenerCount(ws) != 2 { + t.Errorf("expected 2 subs, got %d", listenerCount(ws)) + } +} + +func TestSetListenerOverwrite(t *testing.T) { + clearListeners() + defer clearListeners() + + ws := &WebSocket{} + setListener("sub1", ws, nostr.Filters{{Kinds: []int{1}}}) + setListener("sub1", ws, nostr.Filters{{Kinds: []int{2}}}) + + if listenerCount(ws) != 1 { + t.Errorf("expected 1 sub after overwrite, got %d", listenerCount(ws)) + } +} + +func TestSetListenerMultipleWS(t *testing.T) { + clearListeners() + defer clearListeners() + + ws1 := &WebSocket{} + ws2 := &WebSocket{} + setListener("sub1", ws1, nostr.Filters{{Kinds: []int{1}}}) + setListener("sub1", ws2, nostr.Filters{{Kinds: []int{1}}}) + + if totalConnections() != 2 { + t.Errorf("expected 2 connections, got %d", totalConnections()) + } +} + +func TestRemoveListenerId(t *testing.T) { + clearListeners() + defer clearListeners() + + ws := &WebSocket{} + setListener("sub1", ws, nostr.Filters{{Kinds: []int{1}}}) + setListener("sub2", ws, nostr.Filters{{Kinds: []int{2}}}) + + removeListenerId(ws, "sub1") + + if hasListener(ws, "sub1") { + t.Error("sub1 should be removed") + } + if !hasListener(ws, "sub2") { + t.Error("sub2 should still exist") + } +} + +func TestRemoveListenerIdRemovesWSWhenEmpty(t *testing.T) { + clearListeners() + defer clearListeners() + + ws := &WebSocket{} + setListener("sub1", ws, nostr.Filters{{Kinds: []int{1}}}) + + removeListenerId(ws, "sub1") + + if totalConnections() != 0 { + t.Error("ws entry should be removed when all subs are gone") + } +} + +func TestRemoveListenerIdNonexistent(t *testing.T) { + clearListeners() + defer clearListeners() + + ws := &WebSocket{} + // should not panic + removeListenerId(ws, "nope") +} + +func TestRemoveListener(t *testing.T) { + clearListeners() + defer clearListeners() + + ws := &WebSocket{} + setListener("sub1", ws, nostr.Filters{{Kinds: []int{1}}}) + setListener("sub2", ws, nostr.Filters{{Kinds: []int{2}}}) + + removeListener(ws) + + if totalConnections() != 0 { + t.Error("ws should be removed") + } +} + +func TestRemoveListenerNonexistent(t *testing.T) { + clearListeners() + defer clearListeners() + + ws := &WebSocket{} + // should not panic + removeListener(ws) +} + +func TestGetListeningFilters(t *testing.T) { + clearListeners() + defer clearListeners() + + ws1 := &WebSocket{} + ws2 := &WebSocket{} + + f1 := nostr.Filter{Kinds: []int{1}} + f2 := nostr.Filter{Kinds: []int{2}} + + setListener("sub1", ws1, nostr.Filters{f1, f2}) + setListener("sub2", ws2, nostr.Filters{f1}) // duplicate of f1 + + filters := GetListeningFilters() + if len(filters) != 2 { + t.Errorf("expected 2 distinct filters, got %d", len(filters)) + } +} + +func TestGetListeningFiltersEmpty(t *testing.T) { + clearListeners() + defer clearListeners() + + filters := GetListeningFilters() + if len(filters) != 0 { + t.Errorf("expected 0 filters, got %d", len(filters)) + } +} diff --git a/start_test.go b/start_test.go index de9f311..f1ad8d5 100644 --- a/start_test.go +++ b/start_test.go @@ -2,14 +2,12 @@ package relayer import ( "context" - "errors" "fmt" "net/http" "testing" "time" "github.com/fiatjaf/eventstore/slicestore" - "github.com/gobwas/ws/wsutil" "github.com/nbd-wtf/go-nostr" "go.uber.org/goleak" ) @@ -110,10 +108,7 @@ func TestServerShutdownWebsocket(t *testing.T) { // wait for the client to receive a "connection close" time.Sleep(1 * time.Second) err = client.ConnectionError - if e := errors.Unwrap(err); e != nil { - err = e - } - if _, ok := err.(wsutil.ClosedError); !ok { - t.Errorf("client.ConnectionError: %v (%T); want wsutil.ClosedError", err, err) + if err == nil { + t.Error("expected client.ConnectionError to be non-nil after shutdown") } } diff --git a/util_test.go b/util_test.go index e1d2065..3eeb255 100644 --- a/util_test.go +++ b/util_test.go @@ -105,3 +105,27 @@ func (st *testStorage) ReplaceEvent(ctx context.Context, e *nostr.Event) error { } return nil } + +func clearListeners() { + listenersMutex.Lock() + defer listenersMutex.Unlock() + listeners = make(map[*WebSocket]map[string]*Listener) +} + +type testAdvancedStorage struct { + testStorage + beforeSave func(context.Context, *nostr.Event) + afterSave func(*nostr.Event) +} + +func (s *testAdvancedStorage) BeforeSave(ctx context.Context, evt *nostr.Event) { + if s.beforeSave != nil { + s.beforeSave(ctx, evt) + } +} + +func (s *testAdvancedStorage) AfterSave(evt *nostr.Event) { + if s.afterSave != nil { + s.afterSave(evt) + } +}