From 03abe35586ccbaaf8b39429f9ccfc120052a9bea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 12 Apr 2026 15:09:21 +0200 Subject: [PATCH 01/11] style(contract): standardize doc comments, naming, and add t.Parallel - Rename abbreviated params: k->name, d->fallback in ParamInt/QueryInt variants - Standardize doc comments from bulleted parameter lists to concise prose - Add t.Parallel() to all test functions across all contract test files - Add NeedsRehash doc comment on Rehashable interface - Fix WithTransaction doc wording about nested operations --- contract/cache_test.go | 8 +++ contract/database.go | 2 +- contract/database_test.go | 8 +++ contract/hash.go | 11 ++++ contract/hooks_test.go | 4 ++ contract/request/body_test.go | 62 +++++++++++++++++++++++ contract/request/cookie.go | 34 +++---------- contract/request/cookie_test.go | 14 ++++++ contract/request/header_test.go | 20 ++++++++ contract/request/hooks_test.go | 16 ++++++ contract/request/param.go | 69 ++++++------------------- contract/request/param_test.go | 26 ++++++++++ contract/request/query.go | 86 ++++++++------------------------ contract/request/query_test.go | 36 +++++++++++++ contract/request/session_test.go | 26 ++++++++++ contract/response/static_test.go | 58 +++++++++++++++++++++ contract/response/stream_test.go | 26 ++++++++++ contract/session_test.go | 4 ++ 18 files changed, 362 insertions(+), 148 deletions(-) diff --git a/contract/cache_test.go b/contract/cache_test.go index aa4cd3f..44026b5 100644 --- a/contract/cache_test.go +++ b/contract/cache_test.go @@ -8,14 +8,20 @@ import ( ) func TestErrCacheKeyNotFoundMessage(t *testing.T) { + t.Parallel() + require.Equal(t, "cache key not found", contract.ErrCacheKeyNotFound.Error()) } func TestErrCacheKeyNotFoundIsNonNil(t *testing.T) { + t.Parallel() + require.NotNil(t, contract.ErrCacheKeyNotFound) } func TestErrCacheUnsupportedOperationMessage(t *testing.T) { + t.Parallel() + require.Equal( t, "cache unsupported operation", @@ -24,5 +30,7 @@ func TestErrCacheUnsupportedOperationMessage(t *testing.T) { } func TestErrCacheUnsupportedOperationIsNonNil(t *testing.T) { + t.Parallel() + require.NotNil(t, contract.ErrCacheUnsupportedOperation) } diff --git a/contract/database.go b/contract/database.go index bd743bd..67feac5 100644 --- a/contract/database.go +++ b/contract/database.go @@ -65,6 +65,6 @@ type Database interface { // WithTransaction executes the provided function fn within a database transaction. // If fn returns an error, the transaction is rolled back. Otherwise, it is committed. // The tx passed to fn implements the same Database interface and can be used - // for nested operations within the transaction. + // for operations within the transaction scope. WithTransaction(ctx context.Context, fn func(tx Database) error) error } diff --git a/contract/database_test.go b/contract/database_test.go index 82fd9d4..c60af6d 100644 --- a/contract/database_test.go +++ b/contract/database_test.go @@ -8,6 +8,8 @@ import ( ) func TestErrDatabaseNoRowsMessage(t *testing.T) { + t.Parallel() + require.Equal( t, "no database rows were found", @@ -16,10 +18,14 @@ func TestErrDatabaseNoRowsMessage(t *testing.T) { } func TestErrDatabaseNoRowsIsNonNil(t *testing.T) { + t.Parallel() + require.NotNil(t, contract.ErrDatabaseNoRows) } func TestErrDatabaseNestedTransactionMessage(t *testing.T) { + t.Parallel() + require.Equal( t, "nested transactions are not supported", @@ -28,5 +34,7 @@ func TestErrDatabaseNestedTransactionMessage(t *testing.T) { } func TestErrDatabaseNestedTransactionIsNonNil(t *testing.T) { + t.Parallel() + require.NotNil(t, contract.ErrDatabaseNestedTransaction) } diff --git a/contract/hash.go b/contract/hash.go index 8a0aa3c..108d931 100644 --- a/contract/hash.go +++ b/contract/hash.go @@ -13,3 +13,14 @@ type Hasher interface { // It returns an error if the verification operation fails. Check(value []byte, hash []byte) (bool, error) } + +// Rehashable extends [Hasher] with the ability to detect stale hash parameters. +// Implementations should return true when the given hash was produced with +// different parameters than the current configuration, indicating the value +// should be re-hashed on the next successful authentication. +type Rehashable interface { + // NeedsRehash reports whether the given hash was produced with + // different parameters than the current configuration, indicating + // the value should be re-hashed. + NeedsRehash(hash []byte) bool +} diff --git a/contract/hooks_test.go b/contract/hooks_test.go index 75b9ab4..3408d95 100644 --- a/contract/hooks_test.go +++ b/contract/hooks_test.go @@ -8,10 +8,14 @@ import ( ) func TestHooksKeyIsNonNil(t *testing.T) { + t.Parallel() + require.NotNil(t, contract.HooksKey) } func TestHooksKeyIsDistinctType(t *testing.T) { + t.Parallel() + var other any = struct{}{} require.NotEqual(t, other, contract.HooksKey) diff --git a/contract/request/body_test.go b/contract/request/body_test.go index a5e5ca3..44dc0db 100644 --- a/contract/request/body_test.go +++ b/contract/request/body_test.go @@ -12,6 +12,8 @@ import ( ) func TestBytesReadsEntireBody(t *testing.T) { + t.Parallel() + body := "hello world" r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) @@ -22,6 +24,8 @@ func TestBytesReadsEntireBody(t *testing.T) { } func TestBytesEmptyBody(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("")) result, err := request.Bytes(r) @@ -31,6 +35,8 @@ func TestBytesEmptyBody(t *testing.T) { } func TestBytesErrorOnFailedRead(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodPost, "/", errReader{}) _, err := request.Bytes(r) @@ -39,6 +45,8 @@ func TestBytesErrorOnFailedRead(t *testing.T) { } func TestLimitedBytesReadsUpToLimit(t *testing.T) { + t.Parallel() + body := "abcdefghij" r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) @@ -49,6 +57,8 @@ func TestLimitedBytesReadsUpToLimit(t *testing.T) { } func TestLimitedBytesReadsFullBodyUnderLimit(t *testing.T) { + t.Parallel() + body := "abc" r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) @@ -59,6 +69,8 @@ func TestLimitedBytesReadsFullBodyUnderLimit(t *testing.T) { } func TestLimitedBytesUsesDefaultOnNegativeMaxSize(t *testing.T) { + t.Parallel() + body := "short" r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) @@ -69,6 +81,8 @@ func TestLimitedBytesUsesDefaultOnNegativeMaxSize(t *testing.T) { } func TestStringReadsBodyAsString(t *testing.T) { + t.Parallel() + body := "hello string" r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) @@ -79,6 +93,8 @@ func TestStringReadsBodyAsString(t *testing.T) { } func TestStringEmptyBody(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("")) result, err := request.String(r) @@ -88,6 +104,8 @@ func TestStringEmptyBody(t *testing.T) { } func TestStringErrorOnFailedRead(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodPost, "/", errReader{}) _, err := request.String(r) @@ -96,6 +114,8 @@ func TestStringErrorOnFailedRead(t *testing.T) { } func TestLimitedStringReadsUpToLimit(t *testing.T) { + t.Parallel() + body := "abcdefghij" r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) @@ -106,6 +126,8 @@ func TestLimitedStringReadsUpToLimit(t *testing.T) { } func TestLimitedStringReadsFullBodyUnderLimit(t *testing.T) { + t.Parallel() + body := "abc" r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) @@ -116,6 +138,8 @@ func TestLimitedStringReadsFullBodyUnderLimit(t *testing.T) { } func TestLimitedStringUsesDefaultOnNegativeMaxSize(t *testing.T) { + t.Parallel() + body := "short" r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) @@ -126,6 +150,8 @@ func TestLimitedStringUsesDefaultOnNegativeMaxSize(t *testing.T) { } func TestLimitedStringErrorOnFailedRead(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodPost, "/", errReader{}) _, err := request.LimitedString(r, 10) @@ -134,6 +160,8 @@ func TestLimitedStringErrorOnFailedRead(t *testing.T) { } func TestJSONDecodesValidPayload(t *testing.T) { + t.Parallel() + type payload struct { Name string `json:"name"` Age int `json:"age"` @@ -152,6 +180,8 @@ func TestJSONDecodesValidPayload(t *testing.T) { } func TestJSONReturnsErrorOnInvalidPayload(t *testing.T) { + t.Parallel() + type payload struct { Name string `json:"name"` } @@ -167,6 +197,8 @@ func TestJSONReturnsErrorOnInvalidPayload(t *testing.T) { } func TestJSONIgnoresUnknownFields(t *testing.T) { + t.Parallel() + type payload struct { Name string `json:"name"` } @@ -183,6 +215,8 @@ func TestJSONIgnoresUnknownFields(t *testing.T) { } func TestStrictJSONDecodesValidPayload(t *testing.T) { + t.Parallel() + type payload struct { Name string `json:"name"` } @@ -199,6 +233,8 @@ func TestStrictJSONDecodesValidPayload(t *testing.T) { } func TestStrictJSONRejectsUnknownFields(t *testing.T) { + t.Parallel() + type payload struct { Name string `json:"name"` } @@ -214,6 +250,8 @@ func TestStrictJSONRejectsUnknownFields(t *testing.T) { } func TestStrictJSONReturnsErrorOnInvalidPayload(t *testing.T) { + t.Parallel() + type payload struct { Name string `json:"name"` } @@ -229,6 +267,8 @@ func TestStrictJSONReturnsErrorOnInvalidPayload(t *testing.T) { } func TestLimitedJSONDecodesValidPayload(t *testing.T) { + t.Parallel() + type payload struct { Name string `json:"name"` } @@ -245,6 +285,8 @@ func TestLimitedJSONDecodesValidPayload(t *testing.T) { } func TestLimitedJSONUsesDefaultOnNegativeMaxSize(t *testing.T) { + t.Parallel() + type payload struct { Name string `json:"name"` } @@ -261,6 +303,8 @@ func TestLimitedJSONUsesDefaultOnNegativeMaxSize(t *testing.T) { } func TestLimitedJSONReturnsErrorOnInvalidPayload(t *testing.T) { + t.Parallel() + type payload struct { Name string `json:"name"` } @@ -276,6 +320,8 @@ func TestLimitedJSONReturnsErrorOnInvalidPayload(t *testing.T) { } func TestStrictLimitedJSONDecodesValidPayload(t *testing.T) { + t.Parallel() + type payload struct { Name string `json:"name"` } @@ -292,6 +338,8 @@ func TestStrictLimitedJSONDecodesValidPayload(t *testing.T) { } func TestStrictLimitedJSONRejectsUnknownFields(t *testing.T) { + t.Parallel() + type payload struct { Name string `json:"name"` } @@ -307,6 +355,8 @@ func TestStrictLimitedJSONRejectsUnknownFields(t *testing.T) { } func TestStrictLimitedJSONUsesDefaultOnNegativeMaxSize(t *testing.T) { + t.Parallel() + type payload struct { Name string `json:"name"` } @@ -323,6 +373,8 @@ func TestStrictLimitedJSONUsesDefaultOnNegativeMaxSize(t *testing.T) { } func TestStrictLimitedJSONReturnsErrorOnInvalidPayload(t *testing.T) { + t.Parallel() + type payload struct { Name string `json:"name"` } @@ -338,6 +390,8 @@ func TestStrictLimitedJSONReturnsErrorOnInvalidPayload(t *testing.T) { } func TestXMLDecodesValidPayload(t *testing.T) { + t.Parallel() + type payload struct { Name string `xml:"name"` } @@ -354,6 +408,8 @@ func TestXMLDecodesValidPayload(t *testing.T) { } func TestXMLReturnsErrorOnInvalidPayload(t *testing.T) { + t.Parallel() + type payload struct { Name string `xml:"name"` } @@ -369,6 +425,8 @@ func TestXMLReturnsErrorOnInvalidPayload(t *testing.T) { } func TestLimitedXMLDecodesValidPayload(t *testing.T) { + t.Parallel() + type payload struct { Name string `xml:"name"` } @@ -385,6 +443,8 @@ func TestLimitedXMLDecodesValidPayload(t *testing.T) { } func TestLimitedXMLUsesDefaultOnNegativeMaxSize(t *testing.T) { + t.Parallel() + type payload struct { Name string `xml:"name"` } @@ -401,6 +461,8 @@ func TestLimitedXMLUsesDefaultOnNegativeMaxSize(t *testing.T) { } func TestLimitedXMLReturnsErrorOnInvalidPayload(t *testing.T) { + t.Parallel() + type payload struct { Name string `xml:"name"` } diff --git a/contract/request/cookie.go b/contract/request/cookie.go index 3923e42..3019f14 100644 --- a/contract/request/cookie.go +++ b/contract/request/cookie.go @@ -2,15 +2,8 @@ package request import "net/http" -// Cookie retrieves a cookie by name from the HTTP request. -// Returns the cookie if found, or nil if the cookie doesn't exist -// or if there's an error retrieving it. -// -// Parameters: -// - r: The HTTP request to search for the cookie -// - name: The name of the cookie to retrieve -// -// Returns the cookie object or nil if not found. +// Cookie retrieves the named cookie from the HTTP request using +// [http.Request.Cookie]. It returns nil if the cookie does not exist. func Cookie(r *http.Request, name string) *http.Cookie { if cookie, err := r.Cookie(name); err == nil { return cookie @@ -19,15 +12,8 @@ func Cookie(r *http.Request, name string) *http.Cookie { return nil } -// CookieValue retrieves the value of a cookie by name from the HTTP request. -// This is a convenience function that extracts just the value from the cookie. -// Returns an empty string if the cookie doesn't exist. -// -// Parameters: -// - r: The HTTP request to search for the cookie -// - name: The name of the cookie whose value to retrieve -// -// Returns the cookie value as a string, or empty string if not found. +// CookieValue retrieves the value of the named cookie from the HTTP +// request. It returns an empty string if the cookie does not exist. func CookieValue(r *http.Request, name string) string { if cookie := Cookie(r, name); cookie != nil { return cookie.Value @@ -36,16 +22,8 @@ func CookieValue(r *http.Request, name string) string { return "" } -// CookieValueOr retrieves the value of a cookie by name, returning a default -// value if the cookie doesn't exist or has an empty value. This is useful -// for providing fallback values when cookies are optional. -// -// Parameters: -// - r: The HTTP request to search for the cookie -// - name: The name of the cookie whose value to retrieve -// - fallback: The default value to return if the cookie is not found or empty -// -// Returns the cookie value if found and non-empty, otherwise the default value. +// CookieValueOr retrieves the value of the named cookie, falling back +// to the provided default value if the cookie is missing or empty. func CookieValueOr(r *http.Request, name string, fallback string) string { if value := CookieValue(r, name); value != "" { return value diff --git a/contract/request/cookie_test.go b/contract/request/cookie_test.go index 14197e1..b6c3bb7 100644 --- a/contract/request/cookie_test.go +++ b/contract/request/cookie_test.go @@ -10,6 +10,8 @@ import ( ) func TestCookieReturnsMatchingCookie(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) r.AddCookie(&http.Cookie{Name: "session", Value: "abc123"}) @@ -21,6 +23,8 @@ func TestCookieReturnsMatchingCookie(t *testing.T) { } func TestCookieReturnsNilWhenNotFound(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) cookie := request.Cookie(r, "missing") @@ -29,6 +33,8 @@ func TestCookieReturnsNilWhenNotFound(t *testing.T) { } func TestCookieValueReturnsValue(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) r.AddCookie(&http.Cookie{Name: "token", Value: "xyz"}) @@ -38,6 +44,8 @@ func TestCookieValueReturnsValue(t *testing.T) { } func TestCookieValueReturnsEmptyWhenNotFound(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) value := request.CookieValue(r, "missing") @@ -46,6 +54,8 @@ func TestCookieValueReturnsEmptyWhenNotFound(t *testing.T) { } func TestCookieValueOrReturnsValueWhenPresent(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) r.AddCookie(&http.Cookie{Name: "lang", Value: "en"}) @@ -55,6 +65,8 @@ func TestCookieValueOrReturnsValueWhenPresent(t *testing.T) { } func TestCookieValueOrReturnsFallbackWhenMissing(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) value := request.CookieValueOr(r, "lang", "fr") @@ -63,6 +75,8 @@ func TestCookieValueOrReturnsFallbackWhenMissing(t *testing.T) { } func TestCookieValueOrReturnsFallbackWhenEmpty(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) r.AddCookie(&http.Cookie{Name: "lang", Value: ""}) diff --git a/contract/request/header_test.go b/contract/request/header_test.go index e8abfaa..0ed312d 100644 --- a/contract/request/header_test.go +++ b/contract/request/header_test.go @@ -10,6 +10,8 @@ import ( ) func TestHeaderReturnsValue(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("X-Custom", "value") @@ -19,6 +21,8 @@ func TestHeaderReturnsValue(t *testing.T) { } func TestHeaderReturnsEmptyWhenMissing(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) result := request.Header(r, "X-Missing") @@ -27,6 +31,8 @@ func TestHeaderReturnsEmptyWhenMissing(t *testing.T) { } func TestHasHeaderReturnsTrueWhenPresent(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("Authorization", "Bearer token") @@ -36,6 +42,8 @@ func TestHasHeaderReturnsTrueWhenPresent(t *testing.T) { } func TestHasHeaderReturnsFalseWhenMissing(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) result := request.HasHeader(r, "Authorization") @@ -44,6 +52,8 @@ func TestHasHeaderReturnsFalseWhenMissing(t *testing.T) { } func TestHasHeaderReturnsFalseWhenEmpty(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("X-Empty", "") @@ -53,6 +63,8 @@ func TestHasHeaderReturnsFalseWhenEmpty(t *testing.T) { } func TestHeaderOrReturnsValueWhenPresent(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("Accept", "application/json") @@ -62,6 +74,8 @@ func TestHeaderOrReturnsValueWhenPresent(t *testing.T) { } func TestHeaderOrReturnsFallbackWhenMissing(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) result := request.HeaderOr(r, "Accept", "text/html") @@ -70,6 +84,8 @@ func TestHeaderOrReturnsFallbackWhenMissing(t *testing.T) { } func TestHeaderOrReturnsFallbackWhenEmpty(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("Accept", "") @@ -79,6 +95,8 @@ func TestHeaderOrReturnsFallbackWhenEmpty(t *testing.T) { } func TestHeaderValuesReturnsAllValues(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Add("X-Multi", "first") r.Header.Add("X-Multi", "second") @@ -89,6 +107,8 @@ func TestHeaderValuesReturnsAllValues(t *testing.T) { } func TestHeaderValuesReturnsNilWhenMissing(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) result := request.HeaderValues(r, "X-Missing") diff --git a/contract/request/hooks_test.go b/contract/request/hooks_test.go index 6b78326..1b6a57e 100644 --- a/contract/request/hooks_test.go +++ b/contract/request/hooks_test.go @@ -24,6 +24,8 @@ func (stubHooks) BeforeWriteHeaderFuncs() []contract.BeforeWriteHeaderHook { } func TestHooksReturnsHooksFromContext(t *testing.T) { + t.Parallel() + hooks := stubHooks{} ctx := context.WithValue( context.Background(), contract.HooksKey, contract.Hooks(hooks), @@ -38,6 +40,8 @@ func TestHooksReturnsHooksFromContext(t *testing.T) { } func TestHooksPanicsWithoutContext(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) require.Panics(t, func() { @@ -46,6 +50,8 @@ func TestHooksPanicsWithoutContext(t *testing.T) { } func TestHooksPanicsWithErrNoHooksMiddleware(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) defer func() { @@ -57,6 +63,8 @@ func TestHooksPanicsWithErrNoHooksMiddleware(t *testing.T) { } func TestTryHooksReturnsTrueWhenPresent(t *testing.T) { + t.Parallel() + hooks := stubHooks{} ctx := context.WithValue( context.Background(), contract.HooksKey, contract.Hooks(hooks), @@ -72,6 +80,8 @@ func TestTryHooksReturnsTrueWhenPresent(t *testing.T) { } func TestTryHooksReturnsFalseWhenMissing(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) result, ok := request.TryHooks(r) @@ -81,6 +91,8 @@ func TestTryHooksReturnsFalseWhenMissing(t *testing.T) { } func TestTryHooksReturnsFalseWhenWrongType(t *testing.T) { + t.Parallel() + ctx := context.WithValue( context.Background(), contract.HooksKey, "not hooks", ) @@ -95,6 +107,8 @@ func TestTryHooksReturnsFalseWhenWrongType(t *testing.T) { } func TestErrNoHooksMiddlewareHasCorrectTitle(t *testing.T) { + t.Parallel() + require.Equal( t, "No hooks context", @@ -103,6 +117,8 @@ func TestErrNoHooksMiddlewareHasCorrectTitle(t *testing.T) { } func TestErrNoHooksMiddlewareHasCorrectStatus(t *testing.T) { + t.Parallel() + require.Equal( t, http.StatusInternalServerError, diff --git a/contract/request/param.go b/contract/request/param.go index cddb6ab..1deb6cb 100644 --- a/contract/request/param.go +++ b/contract/request/param.go @@ -6,32 +6,14 @@ import ( "strconv" ) -// Param retrieves a path parameter value by name from the -// HTTP request. This uses Go's built-in PathValue method -// which extracts values from URL path patterns like -// "/users/{id}" where {id} is the parameter name. -// -// Parameters: -// - r: The HTTP request containing the path parameters -// - name: The name of the path parameter to retrieve -// -// Returns the parameter value as a string, or empty string if not found. +// Param retrieves the value of the named path parameter from the +// given HTTP request using [http.Request.PathValue]. func Param(r *http.Request, name string) string { return r.PathValue(name) } -// ParamOr retrieves a path parameter value by name, -// returning a default value if the parameter doesn't exist -// or is empty. This is useful for providing fallback values -// when path parameters are optional or when you want to -// handle missing parameters gracefully. -// -// Parameters: -// - r: The HTTP request containing the path parameters -// - name: The name of the path parameter to retrieve -// - fallback: The default value to return if the parameter is not found or empty -// -// Returns the parameter value if found and non-empty, otherwise the default value. +// ParamOr retrieves the named path parameter, falling back to +// the provided default value if the parameter is missing or empty. func ParamOr(r *http.Request, name string, fallback string) string { if value := Param(r, name); value != "" { return value @@ -40,23 +22,14 @@ func ParamOr(r *http.Request, name string, fallback string) string { return fallback } -// ParamInt retrieves a path parameter by name and parses -// it as an integer. This prevents injection via malformed -// numeric path parameters by validating that the value is -// a well-formed integer. -// -// Parameters: -// - r: The HTTP request containing the path parameters -// - k: The name of the path parameter to parse -// -// Returns the parsed integer value and any parsing error. -// Returns an error if the parameter is empty or is not +// ParamInt retrieves the named path parameter and parses it as an +// integer. It returns an error if the parameter is empty or is not // a valid integer string. -func ParamInt(r *http.Request, k string) (int, error) { - raw := Param(r, k) +func ParamInt(r *http.Request, name string) (int, error) { + raw := Param(r, name) if raw == "" { - return 0, fmt.Errorf("path parameter %q is empty", k) + return 0, fmt.Errorf("path parameter %q is empty", name) } value, err := strconv.Atoi(raw) @@ -64,31 +37,21 @@ func ParamInt(r *http.Request, k string) (int, error) { if err != nil { return 0, fmt.Errorf( "path parameter %q is not a valid integer: %w", - k, err, + name, err, ) } return value, nil } -// ParamIntOr retrieves a path parameter by name and parses -// it as an integer, returning the provided fallback value -// if the parameter is empty or cannot be parsed. This is -// useful when a numeric path parameter is optional or when -// a sensible default exists. -// -// Parameters: -// - r: The HTTP request containing the path parameters -// - k: The name of the path parameter to parse -// - d: The fallback value to return on failure -// -// Returns the parsed integer if valid, otherwise the -// fallback value. -func ParamIntOr(r *http.Request, k string, d int) int { - value, err := ParamInt(r, k) +// ParamIntOr retrieves the named path parameter and parses it as an +// integer, falling back to the provided default value if the parameter +// is empty or cannot be parsed. +func ParamIntOr(r *http.Request, name string, fallback int) int { + value, err := ParamInt(r, name) if err != nil { - return d + return fallback } return value diff --git a/contract/request/param_test.go b/contract/request/param_test.go index b5c4842..6f1ceb9 100644 --- a/contract/request/param_test.go +++ b/contract/request/param_test.go @@ -10,6 +10,8 @@ import ( ) func TestParamReturnsPathValue(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/users/42", nil) r.SetPathValue("id", "42") @@ -19,6 +21,8 @@ func TestParamReturnsPathValue(t *testing.T) { } func TestParamReturnsEmptyWhenMissing(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) result := request.Param(r, "id") @@ -27,6 +31,8 @@ func TestParamReturnsEmptyWhenMissing(t *testing.T) { } func TestParamOrReturnsValueWhenPresent(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/users/42", nil) r.SetPathValue("id", "42") @@ -36,6 +42,8 @@ func TestParamOrReturnsValueWhenPresent(t *testing.T) { } func TestParamOrReturnsFallbackWhenMissing(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) result := request.ParamOr(r, "id", "default") @@ -44,6 +52,8 @@ func TestParamOrReturnsFallbackWhenMissing(t *testing.T) { } func TestParamOrReturnsFallbackWhenEmpty(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) r.SetPathValue("id", "") @@ -53,6 +63,8 @@ func TestParamOrReturnsFallbackWhenEmpty(t *testing.T) { } func TestParamIntReturnsInteger(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/users/42", nil) r.SetPathValue("id", "42") @@ -63,6 +75,8 @@ func TestParamIntReturnsInteger(t *testing.T) { } func TestParamIntReturnsErrorWhenEmpty(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) _, err := request.ParamInt(r, "id") @@ -72,6 +86,8 @@ func TestParamIntReturnsErrorWhenEmpty(t *testing.T) { } func TestParamIntReturnsErrorWhenNotInteger(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/users/abc", nil) r.SetPathValue("id", "abc") @@ -82,6 +98,8 @@ func TestParamIntReturnsErrorWhenNotInteger(t *testing.T) { } func TestParamIntReturnsNegativeInteger(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) r.SetPathValue("offset", "-5") @@ -92,6 +110,8 @@ func TestParamIntReturnsNegativeInteger(t *testing.T) { } func TestParamIntReturnsZero(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) r.SetPathValue("page", "0") @@ -102,6 +122,8 @@ func TestParamIntReturnsZero(t *testing.T) { } func TestParamIntOrReturnsIntegerWhenValid(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/users/99", nil) r.SetPathValue("id", "99") @@ -111,6 +133,8 @@ func TestParamIntOrReturnsIntegerWhenValid(t *testing.T) { } func TestParamIntOrReturnsFallbackWhenEmpty(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) result := request.ParamIntOr(r, "id", 10) @@ -119,6 +143,8 @@ func TestParamIntOrReturnsFallbackWhenEmpty(t *testing.T) { } func TestParamIntOrReturnsFallbackWhenNotInteger(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) r.SetPathValue("id", "abc") diff --git a/contract/request/query.go b/contract/request/query.go index 53beaab..bd21068 100644 --- a/contract/request/query.go +++ b/contract/request/query.go @@ -6,47 +6,22 @@ import ( "strconv" ) -// Query retrieves a query parameter value by name from the -// HTTP request URL. This extracts values from the URL query -// string like "?name=value&other=test" where the parameter -// name matches the provided key. -// -// Parameters: -// - r: The HTTP request containing the URL with query parameters -// - name: The name of the query parameter to retrieve -// -// Returns the first value associated with the key, or empty string if not found. +// Query retrieves the first value of the named query parameter +// from the request URL. It returns an empty string if the parameter +// is not present. func Query(r *http.Request, name string) string { return r.URL.Query().Get(name) } -// HasQuery checks if a query parameter exists in the HTTP -// request URL, regardless of its value. This is useful for -// distinguishing between a parameter that doesn't exist and -// one that exists but has an empty value. -// -// Parameters: -// - r: The HTTP request containing the URL with query parameters -// - name: The name of the query parameter to check for -// -// Returns true if the parameter exists in the query string, false otherwise. +// HasQuery reports whether the named query parameter exists in the +// request URL, regardless of its value. func HasQuery(r *http.Request, name string) bool { return r.URL.Query().Has(name) } -// QueryOr retrieves a query parameter value by name, -// returning a default value if the parameter doesn't exist. -// Note that if the parameter exists but has an empty value, -// the empty value is returned, not the default. This is -// useful for providing fallback values for optional -// parameters. -// -// Parameters: -// - r: The HTTP request containing the URL with query parameters -// - name: The name of the query parameter to retrieve -// - fallback: The default value to return if the parameter doesn't exist -// -// Returns the parameter value if it exists, otherwise the default value. +// QueryOr retrieves the named query parameter, falling back to the +// provided default value if the parameter does not exist. If the +// parameter exists but has an empty value, the empty value is returned. func QueryOr(r *http.Request, name string, fallback string) string { if HasQuery(r, name) { return Query(r, name) @@ -55,24 +30,14 @@ func QueryOr(r *http.Request, name string, fallback string) string { return fallback } -// QueryInt retrieves a query parameter by name and parses -// it as an integer. This prevents injection via malformed -// numeric query parameters by validating that the value is -// a well-formed integer. -// -// Parameters: -// - r: The HTTP request containing the URL with query -// parameters -// - k: The name of the query parameter to parse -// -// Returns the parsed integer value and any parsing error. -// Returns an error if the parameter is missing or is not +// QueryInt retrieves the named query parameter and parses it as an +// integer. It returns an error if the parameter is missing or is not // a valid integer string. -func QueryInt(r *http.Request, k string) (int, error) { - raw := Query(r, k) +func QueryInt(r *http.Request, name string) (int, error) { + raw := Query(r, name) if raw == "" { - return 0, fmt.Errorf("query parameter %q is empty", k) + return 0, fmt.Errorf("query parameter %q is empty", name) } value, err := strconv.Atoi(raw) @@ -80,32 +45,21 @@ func QueryInt(r *http.Request, k string) (int, error) { if err != nil { return 0, fmt.Errorf( "query parameter %q is not a valid integer: %w", - k, err, + name, err, ) } return value, nil } -// QueryIntOr retrieves a query parameter by name and parses -// it as an integer, returning the provided fallback value -// if the parameter is missing or cannot be parsed. This is -// useful when a numeric query parameter is optional or when -// a sensible default exists (e.g., pagination page numbers). -// -// Parameters: -// - r: The HTTP request containing the URL with query -// parameters -// - k: The name of the query parameter to parse -// - d: The fallback value to return on failure -// -// Returns the parsed integer if valid, otherwise the -// fallback value. -func QueryIntOr(r *http.Request, k string, d int) int { - value, err := QueryInt(r, k) +// QueryIntOr retrieves the named query parameter and parses it as an +// integer, falling back to the provided default value if the parameter +// is missing or cannot be parsed. +func QueryIntOr(r *http.Request, name string, fallback int) int { + value, err := QueryInt(r, name) if err != nil { - return d + return fallback } return value diff --git a/contract/request/query_test.go b/contract/request/query_test.go index 5bcf489..dff8e7d 100644 --- a/contract/request/query_test.go +++ b/contract/request/query_test.go @@ -10,6 +10,8 @@ import ( ) func TestQueryReturnsValue(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/?name=alice", nil) result := request.Query(r, "name") @@ -18,6 +20,8 @@ func TestQueryReturnsValue(t *testing.T) { } func TestQueryReturnsEmptyWhenMissing(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) result := request.Query(r, "name") @@ -26,6 +30,8 @@ func TestQueryReturnsEmptyWhenMissing(t *testing.T) { } func TestQueryReturnsEmptyValueWhenPresentButEmpty(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/?name=", nil) result := request.Query(r, "name") @@ -34,6 +40,8 @@ func TestQueryReturnsEmptyValueWhenPresentButEmpty(t *testing.T) { } func TestHasQueryReturnsTrueWhenPresent(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/?active=true", nil) result := request.HasQuery(r, "active") @@ -42,6 +50,8 @@ func TestHasQueryReturnsTrueWhenPresent(t *testing.T) { } func TestHasQueryReturnsTrueWhenPresentButEmpty(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/?active=", nil) result := request.HasQuery(r, "active") @@ -50,6 +60,8 @@ func TestHasQueryReturnsTrueWhenPresentButEmpty(t *testing.T) { } func TestHasQueryReturnsTrueWhenPresentNoValue(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/?active", nil) result := request.HasQuery(r, "active") @@ -58,6 +70,8 @@ func TestHasQueryReturnsTrueWhenPresentNoValue(t *testing.T) { } func TestHasQueryReturnsFalseWhenMissing(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) result := request.HasQuery(r, "active") @@ -66,6 +80,8 @@ func TestHasQueryReturnsFalseWhenMissing(t *testing.T) { } func TestQueryOrReturnsValueWhenPresent(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/?page=2", nil) result := request.QueryOr(r, "page", "1") @@ -74,6 +90,8 @@ func TestQueryOrReturnsValueWhenPresent(t *testing.T) { } func TestQueryOrReturnsEmptyValueWhenParamExists(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/?page=", nil) result := request.QueryOr(r, "page", "1") @@ -82,6 +100,8 @@ func TestQueryOrReturnsEmptyValueWhenParamExists(t *testing.T) { } func TestQueryOrReturnsFallbackWhenMissing(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) result := request.QueryOr(r, "page", "1") @@ -90,6 +110,8 @@ func TestQueryOrReturnsFallbackWhenMissing(t *testing.T) { } func TestQueryIntReturnsInteger(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/?page=5", nil) result, err := request.QueryInt(r, "page") @@ -99,6 +121,8 @@ func TestQueryIntReturnsInteger(t *testing.T) { } func TestQueryIntReturnsErrorWhenEmpty(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) _, err := request.QueryInt(r, "page") @@ -108,6 +132,8 @@ func TestQueryIntReturnsErrorWhenEmpty(t *testing.T) { } func TestQueryIntReturnsErrorWhenNotInteger(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/?page=abc", nil) _, err := request.QueryInt(r, "page") @@ -117,6 +143,8 @@ func TestQueryIntReturnsErrorWhenNotInteger(t *testing.T) { } func TestQueryIntReturnsNegativeInteger(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/?offset=-3", nil) result, err := request.QueryInt(r, "offset") @@ -126,6 +154,8 @@ func TestQueryIntReturnsNegativeInteger(t *testing.T) { } func TestQueryIntReturnsZero(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/?page=0", nil) result, err := request.QueryInt(r, "page") @@ -135,6 +165,8 @@ func TestQueryIntReturnsZero(t *testing.T) { } func TestQueryIntOrReturnsIntegerWhenValid(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/?page=7", nil) result := request.QueryIntOr(r, "page", 1) @@ -143,6 +175,8 @@ func TestQueryIntOrReturnsIntegerWhenValid(t *testing.T) { } func TestQueryIntOrReturnsFallbackWhenEmpty(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) result := request.QueryIntOr(r, "page", 1) @@ -151,6 +185,8 @@ func TestQueryIntOrReturnsFallbackWhenEmpty(t *testing.T) { } func TestQueryIntOrReturnsFallbackWhenNotInteger(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/?page=abc", nil) result := request.QueryIntOr(r, "page", 1) diff --git a/contract/request/session_test.go b/contract/request/session_test.go index 6f4a93a..bb1fe29 100644 --- a/contract/request/session_test.go +++ b/contract/request/session_test.go @@ -34,6 +34,8 @@ func (s stubSession) HasRegenerated() bool { return false } func (s stubSession) MarkAsUnchanged() {} func TestSessionReturnsTrueWhenPresent(t *testing.T) { + t.Parallel() + sess := stubSession{id: "sess-1"} ctx := context.WithValue( context.Background(), @@ -51,6 +53,8 @@ func TestSessionReturnsTrueWhenPresent(t *testing.T) { } func TestSessionReturnsFalseWhenMissing(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) result, ok := request.Session(r) @@ -60,6 +64,8 @@ func TestSessionReturnsFalseWhenMissing(t *testing.T) { } func TestSessionKeyedReturnsTrueWhenPresent(t *testing.T) { + t.Parallel() + type customKey struct{} sess := stubSession{id: "sess-2"} ctx := context.WithValue( @@ -78,6 +84,8 @@ func TestSessionKeyedReturnsTrueWhenPresent(t *testing.T) { } func TestSessionKeyedReturnsFalseWhenMissing(t *testing.T) { + t.Parallel() + type customKey struct{} r := httptest.NewRequest(http.MethodGet, "/", nil) @@ -88,6 +96,8 @@ func TestSessionKeyedReturnsFalseWhenMissing(t *testing.T) { } func TestSessionKeyedReturnsFalseWhenWrongType(t *testing.T) { + t.Parallel() + ctx := context.WithValue( context.Background(), contract.SessionKey, @@ -104,6 +114,8 @@ func TestSessionKeyedReturnsFalseWhenWrongType(t *testing.T) { } func TestMustSessionReturnsSessionWhenPresent(t *testing.T) { + t.Parallel() + sess := stubSession{id: "sess-3"} ctx := context.WithValue( context.Background(), @@ -120,6 +132,8 @@ func TestMustSessionReturnsSessionWhenPresent(t *testing.T) { } func TestMustSessionPanicsWhenMissing(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) require.Panics(t, func() { @@ -128,6 +142,8 @@ func TestMustSessionPanicsWhenMissing(t *testing.T) { } func TestMustSessionPanicsWithErrSessionNotFound(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodGet, "/", nil) defer func() { @@ -139,6 +155,8 @@ func TestMustSessionPanicsWithErrSessionNotFound(t *testing.T) { } func TestMustSessionKeyedReturnsSessionWhenPresent(t *testing.T) { + t.Parallel() + type customKey struct{} sess := stubSession{id: "sess-4"} ctx := context.WithValue( @@ -156,6 +174,8 @@ func TestMustSessionKeyedReturnsSessionWhenPresent(t *testing.T) { } func TestMustSessionKeyedPanicsWhenMissing(t *testing.T) { + t.Parallel() + type customKey struct{} r := httptest.NewRequest(http.MethodGet, "/", nil) @@ -165,6 +185,8 @@ func TestMustSessionKeyedPanicsWhenMissing(t *testing.T) { } func TestMustSessionKeyedPanicsWithErrSessionNotFound(t *testing.T) { + t.Parallel() + type customKey struct{} r := httptest.NewRequest(http.MethodGet, "/", nil) @@ -177,6 +199,8 @@ func TestMustSessionKeyedPanicsWithErrSessionNotFound(t *testing.T) { } func TestErrSessionNotFoundHasCorrectTitle(t *testing.T) { + t.Parallel() + require.Equal( t, "Session not found", @@ -185,6 +209,8 @@ func TestErrSessionNotFoundHasCorrectTitle(t *testing.T) { } func TestErrSessionNotFoundHasCorrectStatus(t *testing.T) { + t.Parallel() + require.Equal( t, http.StatusInternalServerError, diff --git a/contract/response/static_test.go b/contract/response/static_test.go index 9fb9393..38f7b80 100644 --- a/contract/response/static_test.go +++ b/contract/response/static_test.go @@ -14,6 +14,8 @@ import ( ) func TestRawWritesBytesWithStatus(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.Raw(w, http.StatusOK, []byte("hello")) @@ -24,6 +26,8 @@ func TestRawWritesBytesWithStatus(t *testing.T) { } func TestRawSetsDefaultContentType(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.Raw(w, http.StatusOK, []byte("data")) @@ -37,6 +41,8 @@ func TestRawSetsDefaultContentType(t *testing.T) { } func TestRawPreservesExistingContentType(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() w.Header().Set("Content-Type", "text/plain") @@ -47,6 +53,8 @@ func TestRawPreservesExistingContentType(t *testing.T) { } func TestRawWithEmptyBody(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.Raw(w, http.StatusOK, []byte{}) @@ -57,6 +65,8 @@ func TestRawWithEmptyBody(t *testing.T) { } func TestStatusSetsStatusCode(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.Status(w, http.StatusNoContent) @@ -66,6 +76,8 @@ func TestStatusSetsStatusCode(t *testing.T) { } func TestStatusWritesNoBody(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.Status(w, http.StatusCreated) @@ -75,6 +87,8 @@ func TestStatusWritesNoBody(t *testing.T) { } func TestBytesWritesWithOctetStreamContentType(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.Bytes( @@ -92,6 +106,8 @@ func TestBytesWritesWithOctetStreamContentType(t *testing.T) { } func TestStringWritesPlainText(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.String(w, http.StatusOK, "hello world") @@ -107,6 +123,8 @@ func TestStringWritesPlainText(t *testing.T) { } func TestStringEmptyBody(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.String(w, http.StatusOK, "") @@ -116,6 +134,8 @@ func TestStringEmptyBody(t *testing.T) { } func TestStringTemplateExecutes(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() tmpl := template.Must( template.New("test").Parse("Hello, {{.Name}}!"), @@ -136,6 +156,8 @@ func TestStringTemplateExecutes(t *testing.T) { } func TestStringTemplateReturnsErrorOnBadTemplate(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() tmpl := template.Must( template.New("test").Parse("{{.Name}}"), @@ -147,6 +169,8 @@ func TestStringTemplateReturnsErrorOnBadTemplate(t *testing.T) { } func TestHTMLWritesHTMLContent(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.HTML( @@ -164,6 +188,8 @@ func TestHTMLWritesHTMLContent(t *testing.T) { } func TestHTMLTemplateExecutes(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() tmpl := htmltemplate.Must( htmltemplate.New("test").Parse("

{{.Name}}

"), @@ -184,6 +210,8 @@ func TestHTMLTemplateExecutes(t *testing.T) { } func TestHTMLTemplateReturnsErrorOnBadTemplate(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() tmpl := htmltemplate.Must( htmltemplate.New("test").Parse("{{.Name}}"), @@ -195,6 +223,8 @@ func TestHTMLTemplateReturnsErrorOnBadTemplate(t *testing.T) { } func TestJSONWritesJSONContent(t *testing.T) { + t.Parallel() + type payload struct { Name string `json:"name"` Age int `json:"age"` @@ -223,6 +253,8 @@ func TestJSONWritesJSONContent(t *testing.T) { } func TestJSONWritesNullForNil(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.JSON[any](w, http.StatusOK, nil) @@ -232,6 +264,8 @@ func TestJSONWritesNullForNil(t *testing.T) { } func TestJSONSetsStatusCode(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.JSON(w, http.StatusCreated, map[string]string{ @@ -243,6 +277,8 @@ func TestJSONSetsStatusCode(t *testing.T) { } func TestXMLWritesXMLContent(t *testing.T) { + t.Parallel() + type payload struct { XMLName xml.Name `xml:"item"` Name string `xml:"name"` @@ -265,6 +301,8 @@ func TestXMLWritesXMLContent(t *testing.T) { } func TestXMLSetsStatusCode(t *testing.T) { + t.Parallel() + type payload struct { XMLName xml.Name `xml:"item"` ID int `xml:"id"` @@ -281,6 +319,8 @@ func TestXMLSetsStatusCode(t *testing.T) { } func TestRedirectSetsLocationHeader(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.Redirect( @@ -297,6 +337,8 @@ func TestRedirectSetsLocationHeader(t *testing.T) { } func TestRedirectWritesNoBody(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.Redirect( @@ -308,6 +350,8 @@ func TestRedirectWritesNoBody(t *testing.T) { } func TestSafeRedirectAllowsRelativePath(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.SafeRedirect( @@ -320,6 +364,8 @@ func TestSafeRedirectAllowsRelativePath(t *testing.T) { } func TestSafeRedirectAllowsRelativePathWithQuery(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.SafeRedirect( @@ -335,6 +381,8 @@ func TestSafeRedirectAllowsRelativePathWithQuery(t *testing.T) { } func TestSafeRedirectRejectsAbsoluteURL(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.SafeRedirect( @@ -345,6 +393,8 @@ func TestSafeRedirectRejectsAbsoluteURL(t *testing.T) { } func TestSafeRedirectRejectsProtocolRelativeURL(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.SafeRedirect( @@ -355,6 +405,8 @@ func TestSafeRedirectRejectsProtocolRelativeURL(t *testing.T) { } func TestSafeRedirectRejectsNonSlashPrefix(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.SafeRedirect( @@ -365,6 +417,8 @@ func TestSafeRedirectRejectsNonSlashPrefix(t *testing.T) { } func TestSafeRedirectRejectsEmptyURL(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.SafeRedirect(w, http.StatusFound, "") @@ -373,6 +427,8 @@ func TestSafeRedirectRejectsEmptyURL(t *testing.T) { } func TestSafeRedirectRejectsUnparseableURL(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() err := response.SafeRedirect( @@ -383,6 +439,8 @@ func TestSafeRedirectRejectsUnparseableURL(t *testing.T) { } func TestErrUnsafeRedirectMessage(t *testing.T) { + t.Parallel() + require.Equal( t, "unsafe redirect URL: must be a relative path", diff --git a/contract/response/stream_test.go b/contract/response/stream_test.go index d8a6bb6..25c15fa 100644 --- a/contract/response/stream_test.go +++ b/contract/response/stream_test.go @@ -70,6 +70,8 @@ func (w *errFlushWriter) WriteHeader(int) {} func (w *errFlushWriter) Flush() {} func TestStreamSendsDataFromChannel(t *testing.T) { + t.Parallel() + w := &flushRecorder{ ResponseRecorder: httptest.NewRecorder(), } @@ -87,6 +89,8 @@ func TestStreamSendsDataFromChannel(t *testing.T) { } func TestStreamSetsDefaultContentType(t *testing.T) { + t.Parallel() + w := &flushRecorder{ ResponseRecorder: httptest.NewRecorder(), } @@ -105,6 +109,8 @@ func TestStreamSetsDefaultContentType(t *testing.T) { } func TestStreamPreservesExistingContentType(t *testing.T) { + t.Parallel() + w := &flushRecorder{ ResponseRecorder: httptest.NewRecorder(), } @@ -120,6 +126,8 @@ func TestStreamPreservesExistingContentType(t *testing.T) { } func TestStreamSetsCacheControlHeader(t *testing.T) { + t.Parallel() + w := &flushRecorder{ ResponseRecorder: httptest.NewRecorder(), } @@ -134,6 +142,8 @@ func TestStreamSetsCacheControlHeader(t *testing.T) { } func TestStreamSetsConnectionHeader(t *testing.T) { + t.Parallel() + w := &flushRecorder{ ResponseRecorder: httptest.NewRecorder(), } @@ -152,6 +162,8 @@ func TestStreamSetsConnectionHeader(t *testing.T) { } func TestStreamReturnsNilOnChannelClose(t *testing.T) { + t.Parallel() + w := &flushRecorder{ ResponseRecorder: httptest.NewRecorder(), } @@ -165,6 +177,8 @@ func TestStreamReturnsNilOnChannelClose(t *testing.T) { } func TestStreamReturnsErrorOnContextCancellation(t *testing.T) { + t.Parallel() + w := &flushRecorder{ ResponseRecorder: httptest.NewRecorder(), } @@ -182,6 +196,8 @@ func TestStreamReturnsErrorOnContextCancellation(t *testing.T) { } func TestStreamReturnsErrorOnWriteFailure(t *testing.T) { + t.Parallel() + w := newErrFlushWriter() r := httptest.NewRequest(http.MethodGet, "/", nil) ch := make(chan []byte, 1) @@ -193,6 +209,8 @@ func TestStreamReturnsErrorOnWriteFailure(t *testing.T) { } func TestStreamReturnsErrNonFlushableWriter(t *testing.T) { + t.Parallel() + w := newNonFlushWriter() r := httptest.NewRequest(http.MethodGet, "/", nil) ch := make(chan []byte) @@ -203,6 +221,8 @@ func TestStreamReturnsErrNonFlushableWriter(t *testing.T) { } func TestErrNonFlushableWriterMessage(t *testing.T) { + t.Parallel() + require.Equal( t, "non-flushable response writer", @@ -211,6 +231,8 @@ func TestErrNonFlushableWriterMessage(t *testing.T) { } func TestSSESetsEventStreamContentType(t *testing.T) { + t.Parallel() + w := &flushRecorder{ ResponseRecorder: httptest.NewRecorder(), } @@ -229,6 +251,8 @@ func TestSSESetsEventStreamContentType(t *testing.T) { } func TestSSESendsDataFromChannel(t *testing.T) { + t.Parallel() + w := &flushRecorder{ ResponseRecorder: httptest.NewRecorder(), } @@ -244,6 +268,8 @@ func TestSSESendsDataFromChannel(t *testing.T) { } func TestSSEReturnsErrNonFlushableWriter(t *testing.T) { + t.Parallel() + w := newNonFlushWriter() r := httptest.NewRequest(http.MethodGet, "/", nil) ch := make(chan []byte) diff --git a/contract/session_test.go b/contract/session_test.go index 7dd6dec..9caeb24 100644 --- a/contract/session_test.go +++ b/contract/session_test.go @@ -8,10 +8,14 @@ import ( ) func TestSessionKeyIsNonNil(t *testing.T) { + t.Parallel() + require.NotNil(t, contract.SessionKey) } func TestSessionKeyIsDistinctType(t *testing.T) { + t.Parallel() + var other any = struct{}{} require.NotEqual(t, other, contract.SessionKey) From 60377068a2bca44b5b3805fdf305a8fc5def7604 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 12 Apr 2026 15:09:29 +0200 Subject: [PATCH 02/11] style(router): improve doc comments, naming, and panic assertions - Rewrite all doc comments to start with symbol name, end with period - Fix typos: recursivity->recursion, an handler->a handler, responses->requests and responses - Remove redundant zero-value initializers in New() - Rename Record parameter r->request for full-word convention - Add panic message assertions in dot-dot and empty-method tests - Remove stale RecordHandler reference --- router/router.go | 109 +++++++++++++++++------------------------- router/router_test.go | 54 +++++++++++++++++++++ 2 files changed, 98 insertions(+), 65 deletions(-) diff --git a/router/router.go b/router/router.go index cad5606..eb17d93 100644 --- a/router/router.go +++ b/router/router.go @@ -9,23 +9,14 @@ import ( "strings" ) -// Middleware is a func type that can be used to -// apply middleware logic between request and responses. +// Middleware is a func type that can be used to apply middleware logic between requests and responses. type Middleware[H http.Handler] = func(H) H -// Router is the structure that handles -// http routing in an application. -// -// This router is completely optional and -// uses [http.ServeMux] under the hood -// to register all the routes. -// -// It also handles some patterns automatically, -// such as {$}, that is appended on each route -// automatically, regardless of the pattern. +// Router is a generic HTTP router that wraps [http.ServeMux] and supports +// middleware, route groups, and automatic trailing-slash handling. type Router[H http.Handler] struct { // native stores the actual [http.ServeMux] - // that's used internally to register the routes. + // that's used internally to register the routes. native *http.ServeMux // pattern stores the current pattern that will be @@ -61,14 +52,10 @@ var allMethods = []string{ http.MethodOptions, } -// New creates a new [Router] instance and automatically -// creates all the needed components such as the middleware -// list or the native [http.ServeMux] used under the hood. +// New creates a new [Router] with an empty [http.ServeMux]. func New[H http.Handler]() *Router[H] { return &Router[H]{ native: http.NewServeMux(), - pattern: "", - parent: nil, middlewares: make([]Middleware[H], 0), } } @@ -99,12 +86,13 @@ func (router *Router[H]) Group(pattern string, subrouter func(*Router[H])) { }) } -// Grouped clones the router inside a subrouter. +// Grouped creates a cloned sub-router for scoping middleware without a path prefix. func (router *Router[H]) Grouped(subrouter func(*Router[H])) { subrouter(router.Clone()) } -// Clone creates a new subrouter and returns it. +// Clone creates a sub-router that shares the same [http.ServeMux] but has an +// independent middleware stack. func (router *Router[H]) Clone() *Router[H] { return &Router[H]{ native: nil, // parent's native will be used @@ -114,14 +102,13 @@ func (router *Router[H]) Clone() *Router[H] { } } -// With does create a new sub-router that automatically applies -// the given middlewares. +// With creates a new sub-router that applies the given middlewares in addition +// to any inherited ones. // -// This is very useful when used to inline some middlewares to -// specific routes. +// This is useful for inlining middleware on specific routes. // -// In contrast to [Router.Use] method, it does create a new -// sub-router instead of modifying the current router. +// In contrast to [Router.Use], it creates a new sub-router instead of +// modifying the current router. func (router *Router[H]) With(middlewares ...Middleware[H]) *Router[H] { return &Router[H]{ native: nil, // parent's native will be used @@ -134,7 +121,7 @@ func (router *Router[H]) With(middlewares ...Middleware[H]) *Router[H] { // mux returns the native [http.ServeMux] that is used // internally by the router. This exists because sub-routers // must use the same [http.ServeMux] and therefore, there's -// some recursivity involved to get the same [http.ServeMux]. +// some recursion involved to get the same [http.ServeMux]. func (router *Router[H]) mux() *http.ServeMux { if router.parent != nil { return router.parent.mux() @@ -143,10 +130,9 @@ func (router *Router[H]) mux() *http.ServeMux { return router.native } -// wrap makes an handler wrapped by the current routers' -// middlewares. This means that the resulting handler is -// the same as first calling the router middlewares and then the -// provided handler. +// wrap makes a handler wrapped by the current router's middleware. +// This means that the resulting handler is the same as first calling +// the router's middleware and then the provided handler. func (router *Router[H]) wrap(handler H) H { for i := len(router.middlewares) - 1; i >= 0; i-- { handler = router.middlewares[i](handler) @@ -275,6 +261,10 @@ func (router *Router[H]) registerPair(method string, pattern string, handler H) // CONNECT is reserved for HTTP proxies. If needed, use the // [Router.Trace] or [Router.Connect] methods explicitly. func (router *Router[H]) Method(method string, pattern string, handler H) { + if method == "" { + panic("router: method must not be empty") + } + pattern = path.Join(router.pattern, pattern) // When the pattern is simply a slash, we shall @@ -288,75 +278,67 @@ func (router *Router[H]) Method(method string, pattern string, handler H) { router.registerPair(method, pattern, handler) } -// Methods allows binding multiple methods to the pattern and handler. +// Methods registers a handler for each method in the given slice by calling +// [Router.Method] for each entry. func (router *Router[H]) Methods(methods []string, pattern string, handler H) { for _, method := range methods { router.Method(method, pattern, handler) } } -// Any registers all methods to the given pattern and handler. +// Any registers a handler for all standard HTTP methods (GET, HEAD, POST, PUT, +// PATCH, DELETE, OPTIONS) using [Router.Methods]. TRACE and CONNECT are +// intentionally excluded for security reasons. func (router *Router[H]) Any(pattern string, handler H) { router.Methods(allMethods, pattern, handler) } -// Get registers a new handler to the router using [Router.Method] -// and using the [http.MethodGet] as the method parameter. +// Get registers a handler for [http.MethodGet] using [Router.Method]. func (router *Router[H]) Get(pattern string, handler H) { router.Method(http.MethodGet, pattern, handler) } -// Head registers a new handler to the router using [Router.Method] -// and using the [http.MethodHead] as the method parameter. +// Head registers a handler for [http.MethodHead] using [Router.Method]. func (router *Router[H]) Head(pattern string, handler H) { router.Method(http.MethodHead, pattern, handler) } -// Post registers a new handler to the router using [Router.Method] -// and using the [http.MethodPost] as the method parameter. +// Post registers a handler for [http.MethodPost] using [Router.Method]. func (router *Router[H]) Post(pattern string, handler H) { router.Method(http.MethodPost, pattern, handler) } -// Put registers a new handler to the router using [Router.Method] -// and using the [http.MethodPut] as the method parameter. +// Put registers a handler for [http.MethodPut] using [Router.Method]. func (router *Router[H]) Put(pattern string, handler H) { router.Method(http.MethodPut, pattern, handler) } -// Patch registers a new handler to the router using [Router.Method] -// and using the [http.MethodPatch] as the method parameter. +// Patch registers a handler for [http.MethodPatch] using [Router.Method]. func (router *Router[H]) Patch(pattern string, handler H) { router.Method(http.MethodPatch, pattern, handler) } -// Delete registers a new handler to the router using [Router.Method] -// and using the [http.MethodDelete] as the method parameter. +// Delete registers a handler for [http.MethodDelete] using [Router.Method]. func (router *Router[H]) Delete(pattern string, handler H) { router.Method(http.MethodDelete, pattern, handler) } -// Connect registers a new handler to the router using [Router.Method] -// and using the [http.MethodConnect] as the method parameter. +// Connect registers a handler for [http.MethodConnect] using [Router.Method]. func (router *Router[H]) Connect(pattern string, handler H) { router.Method(http.MethodConnect, pattern, handler) } -// Options registers a new handler to the router using [Router.Method] -// and using the [http.MethodOptions] as the method parameter. +// Options registers a handler for [http.MethodOptions] using [Router.Method]. func (router *Router[H]) Options(pattern string, handler H) { router.Method(http.MethodOptions, pattern, handler) } -// Trace registers a new handler to the router using [Router.Method] -// and using the [http.MethodTrace] as the method parameter. +// Trace registers a handler for [http.MethodTrace] using [Router.Method]. func (router *Router[H]) Trace(pattern string, handler H) { router.Method(http.MethodTrace, pattern, handler) } -// ServeHTTP is the method that will make the router implement -// the handler interface, making it possible to be used -// as a handler in places like [http.Server]. +// ServeHTTP implements [http.Handler] by delegating to the underlying [http.ServeMux]. func (router *Router[H]) ServeHTTP(w http.ResponseWriter, r *http.Request) { router.mux().ServeHTTP(w, r) } @@ -364,8 +346,7 @@ func (router *Router[H]) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Has reports whether the given pattern is registered in the router // with the given method. // -// Alternatively, check out the [Router.Matches] to use an [http.Request] -// as the parameter. +// Alternatively, use the [Router.Matches] method with an [http.Request]. func (router *Router[H]) Has(method string, pattern string) bool { if req, err := http.NewRequest(method, pattern, nil); err == nil { return router.Matches(req) @@ -374,7 +355,7 @@ func (router *Router[H]) Has(method string, pattern string) bool { return false } -// Matches reports whether the given [http.Request] match any registered +// Matches reports whether the given [http.Request] matches any registered // route in the router. // // This means that, given the request method and the @@ -386,7 +367,8 @@ func (router *Router[H]) Matches(request *http.Request) bool { } // Handler returns the handler that matches the given method and pattern. -// The second return value determines if the handler was found or not. +// The second return value reports whether the handler was found; when false, +// the first return value is the zero value of H. // // For matching against an [http.Request] use the [Router.HandlerMatch] method. func (router *Router[H]) Handler(method string, pattern string) (h H, ok bool) { @@ -413,14 +395,11 @@ func (router *Router[H]) HandlerMatch(request *http.Request) (h H, ok bool) { return h, false } -// Record returns a [http.Response] that can be used to inspect what -// the given http request would have returned as a response. -// -// This method is a shortcut of calling [RecordHandler] with the router as the -// handler and the given request. -func (router *Router[H]) Record(r *http.Request) *http.Response { +// Record returns an [http.Response] produced by dispatching the given HTTP +// request through the router's full middleware and handler pipeline. +func (router *Router[H]) Record(request *http.Request) *http.Response { rr := httptest.NewRecorder() - router.ServeHTTP(rr, r) + router.ServeHTTP(rr, request) return rr.Result() } diff --git a/router/router_test.go b/router/router_test.go index 9cde5d5..7846021 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -1643,3 +1643,57 @@ func TestSubRouterMatchesFromClone(t *testing.T) { t.Fatal("sub-router Matches should find the route") } } + +func TestMethodPanicsOnDotDotPattern(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + + if r == nil { + t.Fatal("Method with '..' pattern should panic") + } + + msg, ok := r.(error) + + if !ok { + t.Fatalf("expected panic value to be an error, got %T", r) + } + + if msg.Error() == "" { + t.Fatal("panic error message should not be empty") + } + }() + + rt := router.New[http.HandlerFunc]() + rt.Method(http.MethodGet, "..", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) +} + +func TestMethodPanicsOnEmptyMethod(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + + if r == nil { + t.Fatal("Method with empty method should panic") + } + + msg, ok := r.(string) + + if !ok { + t.Fatalf("expected panic value to be a string, got %T", r) + } + + if msg == "" { + t.Fatal("panic message should not be empty") + } + }() + + rt := router.New[http.HandlerFunc]() + rt.Method("", "/test", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) +} From 6f29fae1fb8ba901d571d29493eef6c66d64c2ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 12 Apr 2026 15:09:38 +0200 Subject: [PATCH 03/11] style(problem): fix doc comments, remove unused import, use strings.Builder - Remove unused embed import (no //go:embed directive) - Fix all doc comments: Problem type interfaces, With/Without copy-on-write, Error/Errors/Unwrap descriptions, MarshalJSON/UnmarshalJSON behavior, Defaulted field enumeration, strack-trace typo - Add magic number comment for +5 RFC 9457 fields - Fix internal/accept.go doc comments: Accept type, find grammar, Order ordering - Replace make([]error, 0) with var for nil-error fast path - Use strings.Builder instead of += concatenation in textHandler - Use httptest.NewRequest in accept_test.go --- problem/internal/accept.go | 7 ++-- problem/internal/accept_test.go | 62 ++++++--------------------------- problem/problem.go | 49 ++++++++++++++------------ problem/utils.go | 6 ++-- 4 files changed, 43 insertions(+), 81 deletions(-) diff --git a/problem/internal/accept.go b/problem/internal/accept.go index 4a01e24..6dde0b6 100644 --- a/problem/internal/accept.go +++ b/problem/internal/accept.go @@ -15,8 +15,7 @@ type acceptPair struct { quality float64 } -// Accept is a type designed to help working -// with header values found in the "Accept" header. +// Accept represents a parsed HTTP Accept header for content negotiation. type Accept struct { values []acceptPair } @@ -57,7 +56,7 @@ func ParseAccept(request *http.Request) Accept { // find looks for a given media in the accept header and // returns its [acceptPair] if found. // -// The second return value is true when is found, and false otherwise. +// The second return value is true when found, and false otherwise. func (accept Accept) find(media string) (acceptPair, bool) { for _, pair := range accept.values { if media == pair.media { @@ -105,7 +104,7 @@ func (accept Accept) Quality(media string) float64 { } // Order creates an ordered slice that contains the -// actual acceptance order based on the accept quality. +// media types sorted in descending quality order. func (accept Accept) Order() []string { values := slices.Clone(accept.values) diff --git a/problem/internal/accept_test.go b/problem/internal/accept_test.go index 8abe0be..a71acdb 100644 --- a/problem/internal/accept_test.go +++ b/problem/internal/accept_test.go @@ -1,7 +1,7 @@ package internal_test import ( - "net/http" + "net/http/httptest" "testing" "github.com/studiolambda/cosmos/problem/internal" @@ -10,11 +10,7 @@ import ( func TestAcceptAccepts(t *testing.T) { t.Parallel() - request, err := http.NewRequest("GET", "/", nil) - - if err != nil { - t.Fatalf("failed to create request: %s", err) - } + request := httptest.NewRequest("GET", "/", nil) request.Header.Add("Accept", "application/json, text/*") @@ -40,11 +36,7 @@ func TestAcceptAccepts(t *testing.T) { func TestAcceptAcceptsWithMultipleHeaderValues(t *testing.T) { t.Parallel() - request, err := http.NewRequest("GET", "/", nil) - - if err != nil { - t.Fatalf("failed to create request: %s", err) - } + request := httptest.NewRequest("GET", "/", nil) request.Header.Add("Accept", "application/json") request.Header.Add("Accept", "text/*") @@ -71,11 +63,7 @@ func TestAcceptAcceptsWithMultipleHeaderValues(t *testing.T) { func TestAcceptOrder(t *testing.T) { t.Parallel() - request, err := http.NewRequest("GET", "/", nil) - - if err != nil { - t.Fatalf("failed to create request: %s", err) - } + request := httptest.NewRequest("GET", "/", nil) request.Header.Add("Accept", "application/json, text/*, foo/bar;q=0.3, another/*;q=0.4, bar/baz;q=0.5") @@ -110,11 +98,7 @@ func TestAcceptOrder(t *testing.T) { func TestAcceptQualityFound(t *testing.T) { t.Parallel() - request, err := http.NewRequest("GET", "/", nil) - - if err != nil { - t.Fatalf("failed to create request: %s", err) - } + request := httptest.NewRequest("GET", "/", nil) request.Header.Add("Accept", "application/json;q=0.8") @@ -129,11 +113,7 @@ func TestAcceptQualityFound(t *testing.T) { func TestAcceptQualityNotFound(t *testing.T) { t.Parallel() - request, err := http.NewRequest("GET", "/", nil) - - if err != nil { - t.Fatalf("failed to create request: %s", err) - } + request := httptest.NewRequest("GET", "/", nil) request.Header.Add("Accept", "application/json") @@ -148,11 +128,7 @@ func TestAcceptQualityNotFound(t *testing.T) { func TestAcceptQualityDefault(t *testing.T) { t.Parallel() - request, err := http.NewRequest("GET", "/", nil) - - if err != nil { - t.Fatalf("failed to create request: %s", err) - } + request := httptest.NewRequest("GET", "/", nil) request.Header.Add("Accept", "application/json") @@ -167,11 +143,7 @@ func TestAcceptQualityDefault(t *testing.T) { func TestParseAcceptMalformedMediaType(t *testing.T) { t.Parallel() - request, err := http.NewRequest("GET", "/", nil) - - if err != nil { - t.Fatalf("failed to create request: %s", err) - } + request := httptest.NewRequest("GET", "/", nil) request.Header.Add("Accept", "a/b/c/d, application/json") @@ -191,11 +163,7 @@ func TestParseAcceptMalformedMediaType(t *testing.T) { func TestAcceptOrderEmpty(t *testing.T) { t.Parallel() - request, err := http.NewRequest("GET", "/", nil) - - if err != nil { - t.Fatalf("failed to create request: %s", err) - } + request := httptest.NewRequest("GET", "/", nil) accept := internal.ParseAccept(request) order := accept.Order() @@ -208,11 +176,7 @@ func TestAcceptOrderEmpty(t *testing.T) { func TestAcceptAcceptsWildcardInMedia(t *testing.T) { t.Parallel() - request, err := http.NewRequest("GET", "/", nil) - - if err != nil { - t.Fatalf("failed to create request: %s", err) - } + request := httptest.NewRequest("GET", "/", nil) request.Header.Add("Accept", "application/json") @@ -226,11 +190,7 @@ func TestAcceptAcceptsWildcardInMedia(t *testing.T) { func TestAcceptNoAcceptHeader(t *testing.T) { t.Parallel() - request, err := http.NewRequest("GET", "/", nil) - - if err != nil { - t.Fatalf("failed to create request: %s", err) - } + request := httptest.NewRequest("GET", "/", nil) accept := internal.ParseAccept(request) diff --git a/problem/problem.go b/problem/problem.go index 2d164ae..535ed1e 100644 --- a/problem/problem.go +++ b/problem/problem.go @@ -1,7 +1,6 @@ package problem import ( - _ "embed" "encoding/json" "fmt" "maps" @@ -11,7 +10,8 @@ import ( "github.com/studiolambda/cosmos/problem/internal" ) -// Problem represents a problem details for HTTP APIs. +// Problem represents a problem details response for HTTP APIs. +// It implements [error], [http.Handler], and [json.Marshaler]. // See https://datatracker.ietf.org/doc/html/rfc9457 for more information. type Problem struct { @@ -79,7 +79,7 @@ type Problem struct { // use the HTTP status code. Status int - // The "instance" member is a JSON string containing a URI reference that + // The "instance" member is a JSON string containing a URI reference that // identifies the specific occurrence of the problem. // // When the "instance" URI is dereferenceable, the problem details object @@ -128,7 +128,9 @@ func (problem Problem) Additional(key string) (any, bool) { return additional, ok } -// With adds a new additional value to the given key. +// With returns a new [Problem] with an additional value set for the given key. +// The original [Problem] is not modified (copy-on-write). +// Use [Problem.Without] to remove additional values. func (problem Problem) With(key string, value any) Problem { if problem.additional == nil { problem.additional = map[string]any{key: value} @@ -158,7 +160,7 @@ func (problem Problem) WithoutError() Problem { return problem } -// Without removes an additional value to the given key. +// Without returns a new [Problem] with the additional value for the given key removed. func (problem Problem) Without(key string) Problem { if problem.additional == nil { return problem @@ -170,7 +172,7 @@ func (problem Problem) Without(key string) Problem { return problem } -// Error is the error-like string representation of a [Problem]. +// Error returns a string representation of the [Problem] for the [error] interface. func (problem Problem) Error() string { if problem.err != nil { return strings.ToLower( @@ -183,14 +185,14 @@ func (problem Problem) Error() string { ) } -// Errors returns all the strack-trace of errors that +// Errors returns all the stack trace of errors that // are bound to this particular [Problem]. func (problem Problem) Errors() []error { return stackTrace(problem.err) } -// Unwrap is used to get the original error from -// the problem to use with the errors pkg. +// Unwrap returns the underlying error from the [Problem] +// for use with [errors.Is] and [errors.As]. func (problem Problem) Unwrap() error { return problem.err } @@ -215,9 +217,10 @@ func (problem Problem) WithoutStackTrace() Problem { return problem.Without(StackTraceKey) } -// MarshalJSON replaces the default JSON encoding behaviour. +// MarshalJSON encodes the [Problem] as a flat JSON object by merging +// the five standard RFC 9457 fields with any additional values. func (problem Problem) MarshalJSON() ([]byte, error) { - mapped := make(map[string]any, len(problem.additional)+5) + mapped := make(map[string]any, len(problem.additional)+5) // 5 standard RFC 9457 fields mapped["detail"] = problem.Detail mapped["instance"] = problem.Instance @@ -230,7 +233,8 @@ func (problem Problem) MarshalJSON() ([]byte, error) { return json.Marshal(mapped) } -// UnmarshalJSON replaces the default JSON decoding behaviour. +// UnmarshalJSON decodes a flat JSON object into the [Problem] by extracting +// the five standard RFC 9457 fields and storing the remaining keys as additional values. func (problem *Problem) UnmarshalJSON(data []byte) error { mapped := make(map[string]any) @@ -269,8 +273,9 @@ func (problem *Problem) UnmarshalJSON(data []byte) error { return nil } -// Defaulted returns a [Problem] that is defaulted using the given -// request and the current instance. +// Defaulted returns a new [Problem] with zero-value fields filled with sensible +// defaults: Type defaults to "about:blank", Status to 500, Title to +// [http.StatusText] of the status, and Instance to the request URL path. func (problem Problem) Defaulted(request *http.Request) Problem { if problem.Type == "" { problem.Type = "about:blank" @@ -294,25 +299,23 @@ func (problem Problem) Defaulted(request *http.Request) Problem { // textHandler writes the problem as a plain text HTTP response // including the status code, title, detail, and any stack traces. func (problem Problem) textHandler(w http.ResponseWriter, r *http.Request) { - textResponse := fmt.Sprintf( - "%d %s\n\n%s", - problem.Status, - problem.Title, - problem.Detail, - ) + var b strings.Builder + + fmt.Fprintf(&b, "%d %s\n\n%s", problem.Status, problem.Title, problem.Detail) additional, found := problem.Additional(StackTraceKey) traces, tracesOK := additional.([]string) if found && tracesOK { - textResponse += "\n\n" + b.WriteString("\n\n") for _, trace := range traces { - textResponse += trace + "\n" + b.WriteString(trace) + b.WriteByte('\n') } } - http.Error(w, textResponse, problem.Status) + http.Error(w, b.String(), problem.Status) } // jsonHandler writes the problem as a standard application/json diff --git a/problem/utils.go b/problem/utils.go index a94baa9..f4d5f7f 100644 --- a/problem/utils.go +++ b/problem/utils.go @@ -6,12 +6,12 @@ import "errors" // that have been either Joined or Wrapped using [errors.Join] // or [fmt.Errorf] with `%w` directive. func stackTrace(err error) []error { - result := make([]error, 0) - if err == nil { - return result + return nil } + var result []error + type joined interface { Unwrap() []error } From ae133d464dec13883c50f0d888f6eea7012ad97e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 12 Apr 2026 16:49:29 +0200 Subject: [PATCH 04/11] style(framework): add handleError doc, StatusClientClosedRequest constant, hook fixes - Add doc comment on handleError explaining error inspection logic - Extract StatusClientClosedRequest = 499 constant replacing magic number - Initialize afterResponseHooks in NewHooks for consistency - Add mutex doc comment in Hooks struct - Fix ServerOptions doc to cover all zero-valued fields - Use t.Context() in handler_test.go for proper test cancellation --- framework/handler.go | 9 ++++++++- framework/handler_test.go | 30 +++++++++++++++--------------- framework/hooks.go | 2 ++ framework/server.go | 2 +- 4 files changed, 26 insertions(+), 17 deletions(-) diff --git a/framework/handler.go b/framework/handler.go index b623329..ef811cb 100644 --- a/framework/handler.go +++ b/framework/handler.go @@ -54,11 +54,18 @@ type HTTPStatus interface { HTTPStatus() int } +// StatusClientClosedRequest is the non-standard HTTP status code used +// when the client closes the connection before the server responds. +const StatusClientClosedRequest = 499 + +// handleError writes an error response by inspecting the error for context +// cancellation, custom status codes via [HTTPStatus], or self-rendering +// capability via [http.Handler], falling back to a Problem Details response. func handleError(w http.ResponseWriter, r *http.Request, err error) { status := http.StatusInternalServerError if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - status = 499 // A non-standard status code: 499 Client Closed Request + status = StatusClientClosedRequest } if target := (HTTPStatus)(nil); errors.As(err, &target) { diff --git a/framework/handler_test.go b/framework/handler_test.go index fb3f585..a8a79fe 100644 --- a/framework/handler_test.go +++ b/framework/handler_test.go @@ -36,7 +36,7 @@ func TestServeHTTPNoContent(t *testing.T) { return nil }) - req := httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) @@ -50,7 +50,7 @@ func TestServeHTTPHandlerReturnsError(t *testing.T) { return errors.New("something went wrong") }) - req := httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) @@ -64,11 +64,11 @@ func TestServeHTTPContextCanceled(t *testing.T) { return context.Canceled }) - req := httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) - require.Equal(t, 499, rec.Code) + require.Equal(t, framework.StatusClientClosedRequest, rec.Code) } func TestServeHTTPContextDeadlineExceeded(t *testing.T) { @@ -78,11 +78,11 @@ func TestServeHTTPContextDeadlineExceeded(t *testing.T) { return context.DeadlineExceeded }) - req := httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) - require.Equal(t, 499, rec.Code) + require.Equal(t, framework.StatusClientClosedRequest, rec.Code) } type statusError struct { @@ -104,7 +104,7 @@ func TestServeHTTPCustomHTTPStatus(t *testing.T) { return statusError{status: http.StatusTeapot} }) - req := httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) @@ -128,7 +128,7 @@ func TestServeHTTPErrorImplementsHandler(t *testing.T) { return handlerError{} }) - req := httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) @@ -143,7 +143,7 @@ func TestServeHTTPErrorAfterPartialWrite(t *testing.T) { return errors.New("late error") }) - req := httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) @@ -168,7 +168,7 @@ func TestServeHTTPAfterResponseHooksRun(t *testing.T) { return errors.New("test error") }) - req := httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) @@ -188,7 +188,7 @@ func TestServeHTTPAfterResponseHookPanicRecovered(t *testing.T) { return nil }) - req := httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil) rec := httptest.NewRecorder() // Should not panic. @@ -209,7 +209,7 @@ func TestServeHTTPHooksInContext(t *testing.T) { return nil }) - req := httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) @@ -232,7 +232,7 @@ func TestServeHTTPAfterResponseHookReceivesNilOnSuccess(t *testing.T) { return response.Status(w, http.StatusOK) }) - req := httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) @@ -247,9 +247,9 @@ func TestServeHTTPWrappedContextCanceled(t *testing.T) { return fmt.Errorf("request failed: %w", context.Canceled) }) - req := httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) - require.Equal(t, 499, rec.Code) + require.Equal(t, framework.StatusClientClosedRequest, rec.Code) } diff --git a/framework/hooks.go b/framework/hooks.go index e9b0427..532524e 100644 --- a/framework/hooks.go +++ b/framework/hooks.go @@ -14,6 +14,7 @@ import ( // // All methods are safe for concurrent use. type Hooks struct { + // mutex guards all hook slices. mutex sync.Mutex afterResponseHooks []contract.AfterResponseHook beforeWriteHeaderHooks []contract.BeforeWriteHeaderHook @@ -26,6 +27,7 @@ func NewHooks() *Hooks { return &Hooks{ beforeWriteHeaderHooks: []contract.BeforeWriteHeaderHook{}, beforeWriteHooks: []contract.BeforeWriteHook{}, + afterResponseHooks: []contract.AfterResponseHook{}, } } diff --git a/framework/server.go b/framework/server.go index e3d1f4f..26a5012 100644 --- a/framework/server.go +++ b/framework/server.go @@ -6,7 +6,7 @@ import ( ) // ServerOptions configures the HTTP server created by [NewServer]. -// All timeout fields default to secure values when zero. +// All zero-valued fields default to secure values from [DefaultServerOptions]. type ServerOptions struct { // Addr is the TCP address to listen on (e.g. ":8080"). Addr string From a5016803e8bc4e47bc501ec560388975de797c1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 12 Apr 2026 16:49:38 +0200 Subject: [PATCH 05/11] style(framework/middleware): fix import groups, naming, docs, and test consistency - Separate imports into three groups (stdlib/cosmos/external) in all test files - Rename abbreviated receiver rs->readerStringer in recover_test.go - Fix RecoverWith doc to suggest interface pattern instead of unexported type - Remove redundant 1* in rate limit CleanupInterval default - Remove compile-time interface check on test-only testTextMarshaler - Improve Provide doc comment with parameter guidance and retrieval example --- framework/middleware/cors_test.go | 3 ++- framework/middleware/csrf_test.go | 3 ++- framework/middleware/http_test.go | 3 ++- framework/middleware/logger_test.go | 3 ++- framework/middleware/provide.go | 5 +++-- framework/middleware/provide_test.go | 3 ++- framework/middleware/rate_limit.go | 2 +- framework/middleware/recover.go | 3 ++- framework/middleware/recover_test.go | 11 +++++------ framework/middleware/secure_headers_test.go | 3 ++- 10 files changed, 23 insertions(+), 16 deletions(-) diff --git a/framework/middleware/cors_test.go b/framework/middleware/cors_test.go index f14169f..c071acc 100644 --- a/framework/middleware/cors_test.go +++ b/framework/middleware/cors_test.go @@ -5,9 +5,10 @@ import ( "net/http/httptest" "testing" - "github.com/stretchr/testify/require" "github.com/studiolambda/cosmos/framework" "github.com/studiolambda/cosmos/framework/middleware" + + "github.com/stretchr/testify/require" ) func TestCORSPreflightSetsHeaders(t *testing.T) { diff --git a/framework/middleware/csrf_test.go b/framework/middleware/csrf_test.go index 42e670f..87e0ba7 100644 --- a/framework/middleware/csrf_test.go +++ b/framework/middleware/csrf_test.go @@ -6,10 +6,11 @@ import ( "net/http/httptest" "testing" - "github.com/stretchr/testify/require" "github.com/studiolambda/cosmos/framework" "github.com/studiolambda/cosmos/framework/middleware" "github.com/studiolambda/cosmos/problem" + + "github.com/stretchr/testify/require" ) func TestCSRFAllowsSameOriginRequest(t *testing.T) { diff --git a/framework/middleware/http_test.go b/framework/middleware/http_test.go index 1de7eb0..fd4d5b7 100644 --- a/framework/middleware/http_test.go +++ b/framework/middleware/http_test.go @@ -6,9 +6,10 @@ import ( "net/http/httptest" "testing" - "github.com/stretchr/testify/require" "github.com/studiolambda/cosmos/framework" "github.com/studiolambda/cosmos/framework/middleware" + + "github.com/stretchr/testify/require" ) func TestHTTPAdapterCallsMiddleware(t *testing.T) { diff --git a/framework/middleware/logger_test.go b/framework/middleware/logger_test.go index 1bb553b..082dee4 100644 --- a/framework/middleware/logger_test.go +++ b/framework/middleware/logger_test.go @@ -8,9 +8,10 @@ import ( "net/http/httptest" "testing" - "github.com/stretchr/testify/require" "github.com/studiolambda/cosmos/framework" "github.com/studiolambda/cosmos/framework/middleware" + + "github.com/stretchr/testify/require" ) func TestLoggerNilLoggerDoesNotPanic(t *testing.T) { diff --git a/framework/middleware/provide.go b/framework/middleware/provide.go index 8152cf1..120118b 100644 --- a/framework/middleware/provide.go +++ b/framework/middleware/provide.go @@ -13,8 +13,9 @@ import ( type ProvideFunc = func(w http.ResponseWriter, r *http.Request) (context.Context, error) // Provide returns a middleware that injects a static key-value pair -// into the request context. Every downstream handler and middleware -// can retrieve the value with the same key. +// into the request context. The key should be an unexported type to +// avoid collisions. Downstream handlers retrieve the value with +// r.Context().Value(key). func Provide(key any, val any) framework.Middleware { return ProvideWith(func(w http.ResponseWriter, r *http.Request) (context.Context, error) { return context.WithValue(r.Context(), key, val), nil diff --git a/framework/middleware/provide_test.go b/framework/middleware/provide_test.go index 45c2ad0..52edede 100644 --- a/framework/middleware/provide_test.go +++ b/framework/middleware/provide_test.go @@ -7,9 +7,10 @@ import ( "net/http/httptest" "testing" - "github.com/stretchr/testify/require" "github.com/studiolambda/cosmos/framework" "github.com/studiolambda/cosmos/framework/middleware" + + "github.com/stretchr/testify/require" ) type contextKey string diff --git a/framework/middleware/rate_limit.go b/framework/middleware/rate_limit.go index 991fd60..a985592 100644 --- a/framework/middleware/rate_limit.go +++ b/framework/middleware/rate_limit.go @@ -56,7 +56,7 @@ var DefaultRateLimitOptions = RateLimitOptions{ return r.RemoteAddr }, ErrorResponse: ErrRateLimited, - CleanupInterval: 1 * time.Minute, + CleanupInterval: time.Minute, MaxIdleTime: 5 * time.Minute, } diff --git a/framework/middleware/recover.go b/framework/middleware/recover.go index c65c5f4..47181fe 100644 --- a/framework/middleware/recover.go +++ b/framework/middleware/recover.go @@ -109,7 +109,8 @@ func Recover() framework.Middleware { // // The custom handler receives the raw panic value and must return // an error that will be passed through the normal error handling -// chain. +// chain. Use [errors.As] with an interface containing a +// Stack() []byte method to access the stack trace. func RecoverWith(handler func(value any) error) framework.Middleware { return func(next framework.Handler) framework.Handler { return func( diff --git a/framework/middleware/recover_test.go b/framework/middleware/recover_test.go index 67b8a3c..0fc4959 100644 --- a/framework/middleware/recover_test.go +++ b/framework/middleware/recover_test.go @@ -1,7 +1,6 @@ package middleware_test import ( - "encoding" "errors" "fmt" "io" @@ -10,9 +9,10 @@ import ( "strings" "testing" - "github.com/stretchr/testify/require" "github.com/studiolambda/cosmos/framework" "github.com/studiolambda/cosmos/framework/middleware" + + "github.com/stretchr/testify/require" ) func TestRecoverFromErrorPanic(t *testing.T) { @@ -111,7 +111,6 @@ type testTextMarshaler struct { // Ensure testTextMarshaler does NOT implement fmt.Stringer // so it falls through to the TextMarshaler case. -var _ encoding.TextMarshaler = testTextMarshaler{} func (marshaler testTextMarshaler) MarshalText() ([]byte, error) { if marshaler.err != nil { @@ -281,12 +280,12 @@ type readerStringer struct { message string } -func (rs readerStringer) Read(p []byte) (int, error) { +func (readerStringer readerStringer) Read(p []byte) (int, error) { return 0, io.EOF } -func (rs readerStringer) String() string { - return rs.message +func (readerStringer readerStringer) String() string { + return readerStringer.message } func TestRecoverStringerTakesPrecedenceOverReader(t *testing.T) { diff --git a/framework/middleware/secure_headers_test.go b/framework/middleware/secure_headers_test.go index c5b5547..0f8e5e1 100644 --- a/framework/middleware/secure_headers_test.go +++ b/framework/middleware/secure_headers_test.go @@ -5,9 +5,10 @@ import ( "net/http/httptest" "testing" - "github.com/stretchr/testify/require" "github.com/studiolambda/cosmos/framework" "github.com/studiolambda/cosmos/framework/middleware" + + "github.com/stretchr/testify/require" ) func TestSecureHeadersDefault(t *testing.T) { From 4c94f20fdee9fd5dcde67b433061dec2100a9101 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 12 Apr 2026 16:49:45 +0200 Subject: [PATCH 06/11] style(framework/session): fix docs, naming, export CacheDriverOptions.Prefix - Remove redundant sync.Mutex zero-value initialization - Fix informal comment and Regenerate/Delete doc inaccuracies - Replace string literal "GET" with http.MethodGet in all tests - Export CacheDriverOptions.Prefix field for external configurability - Fix import grouping in all test files --- framework/session/cache.go | 8 ++++---- framework/session/cache_test.go | 5 +++-- framework/session/middleware_test.go | 25 +++++++++++++------------ framework/session/session.go | 18 +++++++++--------- framework/session/session_test.go | 3 ++- 5 files changed, 31 insertions(+), 28 deletions(-) diff --git a/framework/session/cache.go b/framework/session/cache.go index c275afd..edb0ad8 100644 --- a/framework/session/cache.go +++ b/framework/session/cache.go @@ -26,9 +26,9 @@ type CacheDriver struct { } // CacheDriverOptions holds configuration for the CacheDriver. -// The prefix is prepended to session IDs when forming cache keys. +// The Prefix is prepended to session IDs when forming cache keys. type CacheDriverOptions struct { - prefix string + Prefix string } // ErrCacheDriverInvalidType is returned when a value retrieved from the @@ -40,7 +40,7 @@ var ErrCacheDriverInvalidType = errors.New("invalid cache type") // "cosmos.sessions". Use NewCacheDriverWith for custom options. func NewCacheDriver(cache contract.Cache) *CacheDriver { return NewCacheDriverWith(cache, CacheDriverOptions{ - prefix: "cosmos.sessions", + Prefix: "cosmos.sessions", }) } @@ -56,7 +56,7 @@ func NewCacheDriverWith(cache contract.Cache, options CacheDriverOptions) *Cache // key builds the full cache key by joining the configured prefix // with the session ID. func (driver *CacheDriver) key(id string) string { - return fmt.Sprintf("%s.%s", driver.options.prefix, id) + return fmt.Sprintf("%s.%s", driver.options.Prefix, id) } // Get retrieves a session from the cache by its ID. It returns diff --git a/framework/session/cache_test.go b/framework/session/cache_test.go index de6a3fd..39b827f 100644 --- a/framework/session/cache_test.go +++ b/framework/session/cache_test.go @@ -6,11 +6,12 @@ import ( "testing" "time" - tmock "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" "github.com/studiolambda/cosmos/contract" "github.com/studiolambda/cosmos/contract/mock" "github.com/studiolambda/cosmos/framework/session" + + tmock "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" ) func TestCacheDriverGetReturnsSession(t *testing.T) { diff --git a/framework/session/middleware_test.go b/framework/session/middleware_test.go index 4cfcb8a..c186c60 100644 --- a/framework/session/middleware_test.go +++ b/framework/session/middleware_test.go @@ -7,13 +7,14 @@ import ( "testing" "time" - tmock "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" "github.com/studiolambda/cosmos/contract" "github.com/studiolambda/cosmos/contract/mock" "github.com/studiolambda/cosmos/contract/request" "github.com/studiolambda/cosmos/framework" "github.com/studiolambda/cosmos/framework/session" + + tmock "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" ) func TestMiddlewareCookieExists(t *testing.T) { @@ -29,7 +30,7 @@ func TestMiddlewareCookieExists(t *testing.T) { handlerWithSessions := session.Middleware(cache)(handler) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) res := handlerWithSessions.Record(req) cookies := res.Cookies() @@ -78,7 +79,7 @@ func TestMiddlewareLoadsExistingSession(t *testing.T) { handlerWithSessions := session.Middleware(driver)(handler) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.AddCookie(&http.Cookie{ Name: session.DefaultCookie, Value: sessionID, @@ -109,7 +110,7 @@ func TestMiddlewareCreatesNewSessionForInvalidCookieID(t *testing.T) { handlerWithSessions := session.Middleware(driver)(handler) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.AddCookie(&http.Cookie{ Name: session.DefaultCookie, Value: "invalid!@#$", @@ -144,7 +145,7 @@ func TestMiddlewareCreatesNewSessionWhenDriverFails(t *testing.T) { handlerWithSessions := session.Middleware(driver)(handler) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.AddCookie(&http.Cookie{ Name: session.DefaultCookie, Value: validID, @@ -186,7 +187,7 @@ func TestMiddlewareWithExpiredSessionRegenerates(t *testing.T) { handlerWithSessions := session.Middleware(driver)(handler) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.AddCookie(&http.Cookie{ Name: session.DefaultCookie, Value: sessionID, @@ -232,7 +233,7 @@ func TestMiddlewareWithExpirationDeltaExtendsSession(t *testing.T) { }, )(handler) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.AddCookie(&http.Cookie{ Name: session.DefaultCookie, Value: sessionID, @@ -261,7 +262,7 @@ func TestMiddlewareWithDefaultShortcutUsesDefaults(t *testing.T) { handlerWithSessions := session.Middleware(driver)(handler) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) res := handlerWithSessions.Record(req) cookies := res.Cookies() @@ -299,7 +300,7 @@ func TestMiddlewareErrorHandlerCalledOnSaveError(t *testing.T) { }, )(handler) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) _ = handlerWithSessions.Record(req) require.ErrorIs(t, capturedErr, saveErr) @@ -334,7 +335,7 @@ func TestMiddlewareSessionNotSavedWhenUnchanged(t *testing.T) { }, )(handler) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.AddCookie(&http.Cookie{ Name: session.DefaultCookie, Value: sessionID, @@ -380,7 +381,7 @@ func TestMiddlewareRegenerateDeletesOldSession(t *testing.T) { handlerWithSessions := session.Middleware(driver)(handler) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.AddCookie(&http.Cookie{ Name: session.DefaultCookie, Value: sessionID, diff --git a/framework/session/session.go b/framework/session/session.go index 5af824e..60db814 100644 --- a/framework/session/session.go +++ b/framework/session/session.go @@ -85,8 +85,8 @@ func NewSession(expiresAt time.Time, storage map[string]any) (*Session, error) { createdAt: time.Now(), expiresAt: expiresAt, storage: storage, - mutex: sync.Mutex{}, - changed: true, // this will make sure first time its saved + // Mark changed so the session is persisted on first save. + changed: true, }, nil } @@ -139,7 +139,7 @@ func (session *Session) Put(key string, value any) { } // Delete removes a value from the session storage by key. If the key does not exist, -// this operation is a no-op. This operation marks the session as changed. +// the storage is unaffected but the session is still marked as changed. func (session *Session) Delete(key string) { session.mutex.Lock() defer session.mutex.Unlock() @@ -159,12 +159,12 @@ func (session *Session) Extend(expiresAt time.Time) { session.changed = true } -// Regenerate generates a new session ID and updates the expiration -// time. This is commonly used for security purposes such as -// preventing session fixation attacks after user authentication. -// The original session ID is preserved and can be retrieved via -// OriginalSessionID. This operation marks the session as changed. -// It returns an error if ID generation fails. +// Regenerate generates a new cryptographically random session ID +// for the session. This is commonly used for security purposes +// such as preventing session fixation attacks after user +// authentication. The original session ID is preserved and can be +// retrieved via OriginalSessionID. This operation marks the +// session as changed. It returns an error if ID generation fails. // // WARNING: This method MUST be called after any authentication // state change (login, logout, privilege escalation). Without diff --git a/framework/session/session_test.go b/framework/session/session_test.go index 3a44506..a7110b2 100644 --- a/framework/session/session_test.go +++ b/framework/session/session_test.go @@ -4,8 +4,9 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" "github.com/studiolambda/cosmos/framework/session" + + "github.com/stretchr/testify/require" ) func TestNewSessionReturnsNonNilSession(t *testing.T) { From ea5fa62f477149f7ec4d403734deee304f2f5eae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 12 Apr 2026 16:49:54 +0200 Subject: [PATCH 07/11] style(framework/cache): align doc comments, fix imports, document Pull atomicity - Fix import grouping to three groups in memory.go, redis.go, memory_test.go - Align redis doc comments with memory counterparts (Delete, Has, Forever, Increment, Decrement) - Remove named return values from redis Pull method - Document Pull's non-atomic nature in memory.go --- framework/cache/memory.go | 6 +++++- framework/cache/memory_test.go | 3 ++- framework/cache/redis.go | 19 ++++++++++++------- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/framework/cache/memory.go b/framework/cache/memory.go index b0ad08a..e812cbb 100644 --- a/framework/cache/memory.go +++ b/framework/cache/memory.go @@ -5,8 +5,9 @@ import ( "sync" "time" - "github.com/patrickmn/go-cache" "github.com/studiolambda/cosmos/contract" + + "github.com/patrickmn/go-cache" ) // Memory implements contract.Cache using an in-memory store backed @@ -66,6 +67,9 @@ func (memory *Memory) Has(_ context.Context, key string) (bool, error) { // Pull atomically retrieves and removes the value for the given key. // It holds a mutex to prevent races between the get and delete steps. +// +// Pull is not atomic; under concurrent access another caller may +// retrieve the same value before it is deleted. func (memory *Memory) Pull(ctx context.Context, key string) (any, error) { memory.mux.Lock() defer memory.mux.Unlock() diff --git a/framework/cache/memory_test.go b/framework/cache/memory_test.go index 4d2acb5..6dea5d6 100644 --- a/framework/cache/memory_test.go +++ b/framework/cache/memory_test.go @@ -6,9 +6,10 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" "github.com/studiolambda/cosmos/contract" "github.com/studiolambda/cosmos/framework/cache" + + "github.com/stretchr/testify/require" ) func TestMemoryGetReturnsStoredValue(t *testing.T) { diff --git a/framework/cache/redis.go b/framework/cache/redis.go index ad34de8..a2ecde8 100644 --- a/framework/cache/redis.go +++ b/framework/cache/redis.go @@ -6,8 +6,9 @@ import ( "errors" "time" - "github.com/redis/go-redis/v9" "github.com/studiolambda/cosmos/contract" + + "github.com/redis/go-redis/v9" ) // RedisOptions is an alias for redis.Options, exposing the full @@ -55,12 +56,12 @@ func (client *RedisClient) Put(ctx context.Context, key string, value any, ttl t return (*redis.Client)(client).Set(ctx, key, value, ttl).Err() } -// Delete removes a key from Redis. +// Delete removes a key from Redis. Deleting a non-existent key is a no-op. func (client *RedisClient) Delete(ctx context.Context, key string) error { return (*redis.Client)(client).Del(ctx, key).Err() } -// Has reports whether the key exists in Redis. +// Has reports whether the key exists in Redis and has not expired. func (client *RedisClient) Has(ctx context.Context, key string) (bool, error) { count, err := (*redis.Client)(client).Exists(ctx, key).Result() @@ -73,7 +74,7 @@ func (client *RedisClient) Has(ctx context.Context, key string) (bool, error) { // Pull atomically retrieves and deletes a key using Redis GETDEL. // The stored value is JSON-decoded into the return value. -func (client *RedisClient) Pull(ctx context.Context, key string) (value any, err error) { +func (client *RedisClient) Pull(ctx context.Context, key string) (any, error) { encoded, err := (*redis.Client)(client).GetDel(ctx, key).Result() if errors.Is(err, redis.Nil) { @@ -84,6 +85,8 @@ func (client *RedisClient) Pull(ctx context.Context, key string) (value any, err return nil, err } + var value any + if err := json.Unmarshal([]byte(encoded), &value); err != nil { return nil, err } @@ -91,17 +94,19 @@ func (client *RedisClient) Pull(ctx context.Context, key string) (value any, err return value, nil } -// Forever stores a value with no expiration. +// Forever stores a value with no expiration. Values are serialized for storage. func (client *RedisClient) Forever(ctx context.Context, key string, value any) error { return client.Put(ctx, key, value, 0) } -// Increment increases the integer value at key by the given amount. +// Increment atomically increases the integer value stored at key by +// the given amount. func (client *RedisClient) Increment(ctx context.Context, key string, by int64) (int64, error) { return (*redis.Client)(client).IncrBy(ctx, key, by).Result() } -// Decrement decreases the integer value at key by the given amount. +// Decrement atomically decreases the integer value stored at key by +// the given amount. func (client *RedisClient) Decrement(ctx context.Context, key string, by int64) (int64, error) { return (*redis.Client)(client).DecrBy(ctx, key, by).Result() } From 5e19e3a668b92645425eb0c7c5a80c09dc34206d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 12 Apr 2026 16:50:01 +0200 Subject: [PATCH 08/11] style(framework/crypto): standardize test names, variable naming, struct field order - Rename TestItCan* tests to TestAES*/TestChaCha20* naming convention - Rename abbreviated variable e->encrypter and cypher->ciphertext - Align ChaCha20 struct field order with AES (key, aead, AdditionalData) - Add ChaCha20 concurrency safety doc note matching AES - Fix import grouping in test files --- framework/crypto/aes_test.go | 69 ++++++++++++++++--------------- framework/crypto/chacha20.go | 11 +++-- framework/crypto/chacha20_test.go | 57 ++++++++++++------------- 3 files changed, 71 insertions(+), 66 deletions(-) diff --git a/framework/crypto/aes_test.go b/framework/crypto/aes_test.go index 7bde03c..6b65a7a 100644 --- a/framework/crypto/aes_test.go +++ b/framework/crypto/aes_test.go @@ -3,11 +3,12 @@ package crypto_test import ( "testing" - "github.com/stretchr/testify/require" "github.com/studiolambda/cosmos/framework/crypto" + + "github.com/stretchr/testify/require" ) -func TestItCanCreateAESEncrypter(t *testing.T) { +func TestAESNewCreatesEncrypter(t *testing.T) { t.Parallel() key := []byte("12345678901234567890123456789012") @@ -16,34 +17,34 @@ func TestItCanCreateAESEncrypter(t *testing.T) { require.NoError(t, err) } -func TestItCanEncryptAES(t *testing.T) { +func TestAESEncryptSucceeds(t *testing.T) { t.Parallel() key := []byte("12345678901234567890123456789012") - e, err := crypto.NewAES(key) + encrypter, err := crypto.NewAES(key) require.NoError(t, err) plain := []byte("Hello, World!") - _, err = e.Encrypt(plain) + _, err = encrypter.Encrypt(plain) require.NoError(t, err) } -func TestItCanDecryptAES(t *testing.T) { +func TestAESEncryptDecryptRoundTrip(t *testing.T) { t.Parallel() key := []byte("12345678901234567890123456789012") - e, err := crypto.NewAES(key) + encrypter, err := crypto.NewAES(key) require.NoError(t, err) plain := []byte("Hello, World!") - cypher, err := e.Encrypt(plain) + ciphertext, err := encrypter.Encrypt(plain) require.NoError(t, err) - res, err := e.Decrypt(cypher) + res, err := encrypter.Decrypt(ciphertext) require.NoError(t, err) require.Equal(t, plain, res) @@ -61,16 +62,16 @@ func TestAESNewWith16ByteKey(t *testing.T) { t.Parallel() key := []byte("1234567890123456") - e, err := crypto.NewAES(key) + encrypter, err := crypto.NewAES(key) require.NoError(t, err) plain := []byte("Hello, World!") - cypher, err := e.Encrypt(plain) + ciphertext, err := encrypter.Encrypt(plain) require.NoError(t, err) - res, err := e.Decrypt(cypher) + res, err := encrypter.Decrypt(ciphertext) require.NoError(t, err) require.Equal(t, plain, res) @@ -80,16 +81,16 @@ func TestAESNewWith24ByteKey(t *testing.T) { t.Parallel() key := []byte("123456789012345678901234") - e, err := crypto.NewAES(key) + encrypter, err := crypto.NewAES(key) require.NoError(t, err) plain := []byte("Hello, World!") - cypher, err := e.Encrypt(plain) + ciphertext, err := encrypter.Encrypt(plain) require.NoError(t, err) - res, err := e.Decrypt(cypher) + res, err := encrypter.Decrypt(ciphertext) require.NoError(t, err) require.Equal(t, plain, res) @@ -99,11 +100,11 @@ func TestAESDecryptWithShortCiphertext(t *testing.T) { t.Parallel() key := []byte("12345678901234567890123456789012") - e, err := crypto.NewAES(key) + encrypter, err := crypto.NewAES(key) require.NoError(t, err) - _, err = e.Decrypt([]byte("short")) + _, err = encrypter.Decrypt([]byte("short")) require.ErrorIs(t, err, crypto.ErrMismatchedAESNonceSize) } @@ -112,18 +113,18 @@ func TestAESDecryptWithCorruptedCiphertext(t *testing.T) { t.Parallel() key := []byte("12345678901234567890123456789012") - e, err := crypto.NewAES(key) + encrypter, err := crypto.NewAES(key) require.NoError(t, err) plain := []byte("Hello, World!") - cypher, err := e.Encrypt(plain) + ciphertext, err := encrypter.Encrypt(plain) require.NoError(t, err) - cypher[len(cypher)-1] ^= 0xFF + ciphertext[len(ciphertext)-1] ^= 0xFF - _, err = e.Decrypt(cypher) + _, err = encrypter.Decrypt(ciphertext) require.Error(t, err) } @@ -134,11 +135,11 @@ func TestAESCloseZerosKeyMaterial(t *testing.T) { key := make([]byte, 32) copy(key, "12345678901234567890123456789012") - e, err := crypto.NewAES(key) + encrypter, err := crypto.NewAES(key) require.NoError(t, err) - e.Close() + encrypter.Close() allZero := true for _, b := range key { @@ -162,7 +163,7 @@ func TestAESAdditionalDataMustMatchForDecrypt(t *testing.T) { encrypter.AdditionalData = []byte("context-v1") plain := []byte("Hello, World!") - cypher, err := encrypter.Encrypt(plain) + ciphertext, err := encrypter.Encrypt(plain) require.NoError(t, err) @@ -171,7 +172,7 @@ func TestAESAdditionalDataMustMatchForDecrypt(t *testing.T) { decrypter.AdditionalData = []byte("context-v2") - _, err = decrypter.Decrypt(cypher) + _, err = decrypter.Decrypt(ciphertext) require.Error(t, err) } @@ -180,18 +181,18 @@ func TestAESAdditionalDataRoundTrip(t *testing.T) { t.Parallel() key := []byte("12345678901234567890123456789012") - e, err := crypto.NewAES(key) + encrypter, err := crypto.NewAES(key) require.NoError(t, err) - e.AdditionalData = []byte("user-42") + encrypter.AdditionalData = []byte("user-42") plain := []byte("Hello, World!") - cypher, err := e.Encrypt(plain) + ciphertext, err := encrypter.Encrypt(plain) require.NoError(t, err) - res, err := e.Decrypt(cypher) + res, err := encrypter.Decrypt(ciphertext) require.NoError(t, err) require.Equal(t, plain, res) @@ -201,16 +202,16 @@ func TestAESEncryptProducesDifferentCiphertexts(t *testing.T) { t.Parallel() key := []byte("12345678901234567890123456789012") - e, err := crypto.NewAES(key) + encrypter, err := crypto.NewAES(key) require.NoError(t, err) plain := []byte("Hello, World!") - c1, err := e.Encrypt(plain) + c1, err := encrypter.Encrypt(plain) require.NoError(t, err) - c2, err := e.Encrypt(plain) + c2, err := encrypter.Encrypt(plain) require.NoError(t, err) require.NotEqual(t, c1, c2) @@ -220,11 +221,11 @@ func TestAESDecryptEmptyInput(t *testing.T) { t.Parallel() key := []byte("12345678901234567890123456789012") - e, err := crypto.NewAES(key) + encrypter, err := crypto.NewAES(key) require.NoError(t, err) - _, err = e.Decrypt([]byte{}) + _, err = encrypter.Decrypt([]byte{}) require.ErrorIs(t, err, crypto.ErrMismatchedAESNonceSize) } diff --git a/framework/crypto/chacha20.go b/framework/crypto/chacha20.go index 829f959..c0d1fc5 100644 --- a/framework/crypto/chacha20.go +++ b/framework/crypto/chacha20.go @@ -12,14 +12,17 @@ import ( // ChaCha20 implements contract.Encrypter using ChaCha20-Poly1305 // authenticated encryption. The AEAD cipher is created once at // construction time and reused for every Encrypt/Decrypt call. +// This is safe because ChaCha20-Poly1305 AEAD instances are safe +// for concurrent use with different nonces. The nonce is generated +// randomly for each Encrypt call and prepended to the ciphertext. type ChaCha20 struct { - // aead is the underlying ChaCha20-Poly1305 AEAD cipher. - aead cipher.AEAD - // key is the raw key material, retained so that Close // can zero it from memory. key []byte + // aead is the underlying ChaCha20-Poly1305 AEAD cipher. + aead cipher.AEAD + // AdditionalData is optional additional authenticated data (AAD) // passed to the AEAD Seal and Open operations. AAD is // authenticated but not encrypted, which allows binding the @@ -44,7 +47,7 @@ func NewChaCha20(key []byte) (*ChaCha20, error) { return nil, err } - return &ChaCha20{aead: aead, key: key}, nil + return &ChaCha20{key: key, aead: aead}, nil } // Encrypt encrypts the plaintext using ChaCha20-Poly1305 with a diff --git a/framework/crypto/chacha20_test.go b/framework/crypto/chacha20_test.go index ca9961a..873c85b 100644 --- a/framework/crypto/chacha20_test.go +++ b/framework/crypto/chacha20_test.go @@ -3,11 +3,12 @@ package crypto_test import ( "testing" - "github.com/stretchr/testify/require" "github.com/studiolambda/cosmos/framework/crypto" + + "github.com/stretchr/testify/require" ) -func TestItCanCreateChaCha20Encrypter(t *testing.T) { +func TestChaCha20NewCreatesEncrypter(t *testing.T) { t.Parallel() key := []byte("12345678901234567890123456789012") @@ -16,34 +17,34 @@ func TestItCanCreateChaCha20Encrypter(t *testing.T) { require.NoError(t, err) } -func TestItCanEncryptChaCha20(t *testing.T) { +func TestChaCha20EncryptSucceeds(t *testing.T) { t.Parallel() key := []byte("12345678901234567890123456789012") - e, err := crypto.NewChaCha20(key) + encrypter, err := crypto.NewChaCha20(key) require.NoError(t, err) plain := []byte("Hello, World!") - _, err = e.Encrypt(plain) + _, err = encrypter.Encrypt(plain) require.NoError(t, err) } -func TestItCanDecryptChaCha20(t *testing.T) { +func TestChaCha20EncryptDecryptRoundTrip(t *testing.T) { t.Parallel() key := []byte("12345678901234567890123456789012") - e, err := crypto.NewChaCha20(key) + encrypter, err := crypto.NewChaCha20(key) require.NoError(t, err) plain := []byte("Hello, World!") - cypher, err := e.Encrypt(plain) + ciphertext, err := encrypter.Encrypt(plain) require.NoError(t, err) - res, err := e.Decrypt(cypher) + res, err := encrypter.Decrypt(ciphertext) require.NoError(t, err) require.Equal(t, plain, res) @@ -61,11 +62,11 @@ func TestChaCha20DecryptWithShortCiphertext(t *testing.T) { t.Parallel() key := []byte("12345678901234567890123456789012") - e, err := crypto.NewChaCha20(key) + encrypter, err := crypto.NewChaCha20(key) require.NoError(t, err) - _, err = e.Decrypt([]byte("short")) + _, err = encrypter.Decrypt([]byte("short")) require.ErrorIs(t, err, crypto.ErrMismatchedChaCha20NonceSize) } @@ -74,18 +75,18 @@ func TestChaCha20DecryptWithCorruptedCiphertext(t *testing.T) { t.Parallel() key := []byte("12345678901234567890123456789012") - e, err := crypto.NewChaCha20(key) + encrypter, err := crypto.NewChaCha20(key) require.NoError(t, err) plain := []byte("Hello, World!") - cypher, err := e.Encrypt(plain) + ciphertext, err := encrypter.Encrypt(plain) require.NoError(t, err) - cypher[len(cypher)-1] ^= 0xFF + ciphertext[len(ciphertext)-1] ^= 0xFF - _, err = e.Decrypt(cypher) + _, err = encrypter.Decrypt(ciphertext) require.Error(t, err) } @@ -96,11 +97,11 @@ func TestChaCha20CloseZerosKeyMaterial(t *testing.T) { key := make([]byte, 32) copy(key, "12345678901234567890123456789012") - e, err := crypto.NewChaCha20(key) + encrypter, err := crypto.NewChaCha20(key) require.NoError(t, err) - e.Close() + encrypter.Close() allZero := true for _, b := range key { @@ -124,7 +125,7 @@ func TestChaCha20AdditionalDataMustMatchForDecrypt(t *testing.T) { encrypter.AdditionalData = []byte("context-v1") plain := []byte("Hello, World!") - cypher, err := encrypter.Encrypt(plain) + ciphertext, err := encrypter.Encrypt(plain) require.NoError(t, err) @@ -133,7 +134,7 @@ func TestChaCha20AdditionalDataMustMatchForDecrypt(t *testing.T) { decrypter.AdditionalData = []byte("context-v2") - _, err = decrypter.Decrypt(cypher) + _, err = decrypter.Decrypt(ciphertext) require.Error(t, err) } @@ -142,18 +143,18 @@ func TestChaCha20AdditionalDataRoundTrip(t *testing.T) { t.Parallel() key := []byte("12345678901234567890123456789012") - e, err := crypto.NewChaCha20(key) + encrypter, err := crypto.NewChaCha20(key) require.NoError(t, err) - e.AdditionalData = []byte("user-42") + encrypter.AdditionalData = []byte("user-42") plain := []byte("Hello, World!") - cypher, err := e.Encrypt(plain) + ciphertext, err := encrypter.Encrypt(plain) require.NoError(t, err) - res, err := e.Decrypt(cypher) + res, err := encrypter.Decrypt(ciphertext) require.NoError(t, err) require.Equal(t, plain, res) @@ -163,16 +164,16 @@ func TestChaCha20EncryptProducesDifferentCiphertexts(t *testing.T) { t.Parallel() key := []byte("12345678901234567890123456789012") - e, err := crypto.NewChaCha20(key) + encrypter, err := crypto.NewChaCha20(key) require.NoError(t, err) plain := []byte("Hello, World!") - c1, err := e.Encrypt(plain) + c1, err := encrypter.Encrypt(plain) require.NoError(t, err) - c2, err := e.Encrypt(plain) + c2, err := encrypter.Encrypt(plain) require.NoError(t, err) require.NotEqual(t, c1, c2) @@ -182,11 +183,11 @@ func TestChaCha20DecryptEmptyInput(t *testing.T) { t.Parallel() key := []byte("12345678901234567890123456789012") - e, err := crypto.NewChaCha20(key) + encrypter, err := crypto.NewChaCha20(key) require.NoError(t, err) - _, err = e.Decrypt([]byte{}) + _, err = encrypter.Decrypt([]byte{}) require.ErrorIs(t, err, crypto.ErrMismatchedChaCha20NonceSize) } From cb436dc6e6cd3f75a0ed06302c1ea3d3120a5591 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 12 Apr 2026 16:50:06 +0200 Subject: [PATCH 09/11] style(framework/hash): standardize test names, naming, use testify for zero checks - Rename TestItCan* tests to TestArgon2*/TestBcrypt* naming convention - Rename abbreviated variables h->hasher, r->hashed - Replace manual all-zero byte loops with require.Equal assertions - Use parenthesized import in argon2.go for consistency - Fix import grouping in test files --- framework/hash/argon2.go | 4 +- framework/hash/argon2_test.go | 63 ++++++++++++-------------------- framework/hash/bcrypt_test.go | 69 +++++++++++++---------------------- 3 files changed, 52 insertions(+), 84 deletions(-) diff --git a/framework/hash/argon2.go b/framework/hash/argon2.go index 24e3e70..94bec72 100644 --- a/framework/hash/argon2.go +++ b/framework/hash/argon2.go @@ -1,6 +1,8 @@ package hash -import "github.com/matthewhartstonge/argon2" +import ( + "github.com/matthewhartstonge/argon2" +) // Argon2Config is an alias for argon2.Config, exposing the full set // of tuning parameters (memory, iterations, parallelism) without diff --git a/framework/hash/argon2_test.go b/framework/hash/argon2_test.go index 6f162e2..da7b064 100644 --- a/framework/hash/argon2_test.go +++ b/framework/hash/argon2_test.go @@ -3,32 +3,33 @@ package hash_test import ( "testing" - "github.com/stretchr/testify/require" "github.com/studiolambda/cosmos/framework/hash" + + "github.com/stretchr/testify/require" ) -func TestItCanHashArgon2Passwords(t *testing.T) { +func TestArgon2HashProducesOutput(t *testing.T) { t.Parallel() - h := hash.NewArgon2() + hasher := hash.NewArgon2() content := []byte("hello, world") - r, err := h.Hash(content) + hashed, err := hasher.Hash(content) require.NoError(t, err) - require.Greater(t, len(r), 0) + require.Greater(t, len(hashed), 0) } -func TestItCanCheckHashedArgon2Hashes(t *testing.T) { +func TestArgon2CheckMatchesCorrectPassword(t *testing.T) { t.Parallel() - h := hash.NewArgon2() + hasher := hash.NewArgon2() - r, err := h.Hash([]byte("hello, world")) + hashed, err := hasher.Hash([]byte("hello, world")) require.NoError(t, err) - ok, err := h.Check([]byte("hello, world"), r) + ok, err := hasher.Check([]byte("hello, world"), hashed) require.NoError(t, err) require.True(t, ok) @@ -47,24 +48,24 @@ func TestArgon2WithCustomConfig(t *testing.T) { Version: 0, } - h := hash.NewArgon2With(config) + hasher := hash.NewArgon2With(config) - r, err := h.Hash([]byte("hello, world")) + hashed, err := hasher.Hash([]byte("hello, world")) require.NoError(t, err) - require.Greater(t, len(r), 0) + require.Greater(t, len(hashed), 0) } func TestArgon2CheckWrongPasswordReturnsFalse(t *testing.T) { t.Parallel() - h := hash.NewArgon2() + hasher := hash.NewArgon2() - hashed, err := h.Hash([]byte("correct-password")) + hashed, err := hasher.Hash([]byte("correct-password")) require.NoError(t, err) - ok, err := h.Check([]byte("wrong-password"), hashed) + ok, err := hasher.Check([]byte("wrong-password"), hashed) require.NoError(t, err) require.False(t, ok) @@ -73,46 +74,28 @@ func TestArgon2CheckWrongPasswordReturnsFalse(t *testing.T) { func TestArgon2HashZerosInputPassword(t *testing.T) { t.Parallel() - h := hash.NewArgon2() + hasher := hash.NewArgon2() password := []byte("sensitive-data") - _, err := h.Hash(password) + _, err := hasher.Hash(password) require.NoError(t, err) - - allZero := true - for _, b := range password { - if b != 0 { - allZero = false - break - } - } - - require.True(t, allZero) + require.Equal(t, make([]byte, len(password)), password) } func TestArgon2CheckZerosInputPassword(t *testing.T) { t.Parallel() - h := hash.NewArgon2() + hasher := hash.NewArgon2() - hashed, err := h.Hash([]byte("hello")) + hashed, err := hasher.Hash([]byte("hello")) require.NoError(t, err) password := []byte("hello") - _, err = h.Check(password, hashed) + _, err = hasher.Check(password, hashed) require.NoError(t, err) - - allZero := true - for _, b := range password { - if b != 0 { - allZero = false - break - } - } - - require.True(t, allZero) + require.Equal(t, make([]byte, len(password)), password) } diff --git a/framework/hash/bcrypt_test.go b/framework/hash/bcrypt_test.go index 37d9be0..dcb79ee 100644 --- a/framework/hash/bcrypt_test.go +++ b/framework/hash/bcrypt_test.go @@ -3,32 +3,33 @@ package hash_test import ( "testing" - "github.com/stretchr/testify/require" "github.com/studiolambda/cosmos/framework/hash" + + "github.com/stretchr/testify/require" ) -func TestItCanHashBcryptPasswords(t *testing.T) { +func TestBcryptHashProducesOutput(t *testing.T) { t.Parallel() - h := hash.NewBcrypt() + hasher := hash.NewBcrypt() content := []byte("hello, world") - r, err := h.Hash(content) + hashed, err := hasher.Hash(content) require.NoError(t, err) - require.Greater(t, len(r), 0) + require.Greater(t, len(hashed), 0) } -func TestItCanCheckHashedBcryptHashes(t *testing.T) { +func TestBcryptCheckMatchesCorrectPassword(t *testing.T) { t.Parallel() - h := hash.NewBcrypt() + hasher := hash.NewBcrypt() - r, err := h.Hash([]byte("hello, world")) + hashed, err := hasher.Hash([]byte("hello, world")) require.NoError(t, err) - ok, err := h.Check([]byte("hello, world"), r) + ok, err := hasher.Check([]byte("hello, world"), hashed) require.NoError(t, err) require.True(t, ok) @@ -37,14 +38,14 @@ func TestItCanCheckHashedBcryptHashes(t *testing.T) { func TestBcryptWithDefaultOptions(t *testing.T) { t.Parallel() - h := hash.NewBcryptWith(hash.BcryptOptions{}) + hasher := hash.NewBcryptWith(hash.BcryptOptions{}) - r, err := h.Hash([]byte("hello, world")) + hashed, err := hasher.Hash([]byte("hello, world")) require.NoError(t, err) - require.Greater(t, len(r), 0) + require.Greater(t, len(hashed), 0) - ok, err := h.Check([]byte("hello, world"), r) + ok, err := hasher.Check([]byte("hello, world"), hashed) require.NoError(t, err) require.True(t, ok) @@ -53,13 +54,13 @@ func TestBcryptWithDefaultOptions(t *testing.T) { func TestBcryptCheckWrongPasswordReturnsFalse(t *testing.T) { t.Parallel() - h := hash.NewBcrypt() + hasher := hash.NewBcrypt() - hashed, err := h.Hash([]byte("correct-password")) + hashed, err := hasher.Hash([]byte("correct-password")) require.NoError(t, err) - ok, err := h.Check([]byte("wrong-password"), hashed) + ok, err := hasher.Check([]byte("wrong-password"), hashed) require.NoError(t, err) require.False(t, ok) @@ -68,9 +69,9 @@ func TestBcryptCheckWrongPasswordReturnsFalse(t *testing.T) { func TestBcryptCheckCorruptedHashReturnsError(t *testing.T) { t.Parallel() - h := hash.NewBcrypt() + hasher := hash.NewBcrypt() - ok, err := h.Check([]byte("password"), []byte("not-a-hash")) + ok, err := hasher.Check([]byte("password"), []byte("not-a-hash")) require.Error(t, err) require.False(t, ok) @@ -79,46 +80,28 @@ func TestBcryptCheckCorruptedHashReturnsError(t *testing.T) { func TestBcryptHashZerosInputPassword(t *testing.T) { t.Parallel() - h := hash.NewBcrypt() + hasher := hash.NewBcrypt() password := []byte("sensitive-data") - _, err := h.Hash(password) + _, err := hasher.Hash(password) require.NoError(t, err) - - allZero := true - for _, b := range password { - if b != 0 { - allZero = false - break - } - } - - require.True(t, allZero) + require.Equal(t, make([]byte, len(password)), password) } func TestBcryptCheckZerosInputPassword(t *testing.T) { t.Parallel() - h := hash.NewBcrypt() + hasher := hash.NewBcrypt() - hashed, err := h.Hash([]byte("hello")) + hashed, err := hasher.Hash([]byte("hello")) require.NoError(t, err) password := []byte("hello") - _, err = h.Check(password, hashed) + _, err = hasher.Check(password, hashed) require.NoError(t, err) - - allZero := true - for _, b := range password { - if b != 0 { - allZero = false - break - } - } - - require.True(t, allZero) + require.Equal(t, make([]byte, len(password)), password) } From 94bb47aca9b9f3900d3d05bf2ae94b1431bd5603 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 12 Apr 2026 16:50:12 +0200 Subject: [PATCH 10/11] style(framework/database): fix import grouping and use nil TxOptions - Separate imports into three groups (stdlib/cosmos/external) - Replace &sql.TxOptions{} with nil for idiomatic driver defaults --- framework/database/sql.go | 5 +++-- framework/database/sql_test.go | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/framework/database/sql.go b/framework/database/sql.go index ffe47ed..fdb1658 100644 --- a/framework/database/sql.go +++ b/framework/database/sql.go @@ -5,8 +5,9 @@ import ( "database/sql" "errors" - "github.com/jmoiron/sqlx" "github.com/studiolambda/cosmos/contract" + + "github.com/jmoiron/sqlx" ) // prepare is the shared interface between *sqlx.DB and *sqlx.Tx @@ -182,7 +183,7 @@ func (database *SQL) WithTransaction(ctx context.Context, fn func(tx contract.Da return contract.ErrDatabaseNestedTransaction } - tx, err := database.raw.BeginTxx(ctx, &sql.TxOptions{}) + tx, err := database.raw.BeginTxx(ctx, nil) if err != nil { return err diff --git a/framework/database/sql_test.go b/framework/database/sql_test.go index d21cbe9..b4bfdbd 100644 --- a/framework/database/sql_test.go +++ b/framework/database/sql_test.go @@ -7,10 +7,11 @@ import ( "errors" "testing" - _ "github.com/mattn/go-sqlite3" - "github.com/stretchr/testify/require" "github.com/studiolambda/cosmos/contract" "github.com/studiolambda/cosmos/framework/database" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" ) func newTestDB(t *testing.T) *database.SQL { From 3da36248bebeb021c95342ef43ae2478e811e1d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 12 Apr 2026 16:50:19 +0200 Subject: [PATCH 11/11] style(framework/event): fix imports, idiomatic Go patterns, doc comments - Fix import grouping to three groups in all source and test files - Replace fmt.Sprintf with strconv.FormatUint for subscription IDs - Use var declarations instead of empty composite literals for zero-value types - Remove unnecessary var() grouping for single ErrBrokerClosed - Fix deliverToHandler doc comment accuracy - Add explicit _ = for discarded Close errors in AMQP broker --- framework/event/amqp.go | 21 +++++++++++++++------ framework/event/memory.go | 24 ++++++++++-------------- framework/event/memory_test.go | 3 ++- framework/event/mqtt.go | 3 ++- framework/event/nats.go | 5 +++-- framework/event/redis.go | 5 +++-- 6 files changed, 35 insertions(+), 26 deletions(-) diff --git a/framework/event/amqp.go b/framework/event/amqp.go index 2dd83d2..c5c0d86 100644 --- a/framework/event/amqp.go +++ b/framework/event/amqp.go @@ -5,8 +5,9 @@ import ( "encoding/json" "sync" - amqp091 "github.com/rabbitmq/amqp091-go" "github.com/studiolambda/cosmos/contract" + + amqp091 "github.com/rabbitmq/amqp091-go" ) // AMQPBroker implements the EventBroker interface using RabbitMQ's @@ -122,7 +123,9 @@ func NewAMQPBrokerFrom( ) if err != nil { - pubCh.Close() + // Close is best-effort: the connection is being abandoned due to + // the exchange declaration failure above. + _ = pubCh.Close() return nil, err } @@ -200,7 +203,9 @@ func (broker *AMQPBroker) Subscribe( nil, ) if err != nil { - ch.Close() + // Close is best-effort: the channel is being abandoned due to + // the queue declaration failure above. + _ = ch.Close() return nil, err } @@ -214,7 +219,9 @@ func (broker *AMQPBroker) Subscribe( ) if err != nil { - ch.Close() + // Close is best-effort: the channel is being abandoned due to + // the queue bind failure above. + _ = ch.Close() return nil, err } @@ -230,12 +237,14 @@ func (broker *AMQPBroker) Subscribe( nil, ) if err != nil { - ch.Close() + // Close is best-effort: the channel is being abandoned due to + // the consume setup failure above. + _ = ch.Close() return nil, err } - wg := sync.WaitGroup{} + var wg sync.WaitGroup wg.Go(func() { for delivery := range deliveries { diff --git a/framework/event/memory.go b/framework/event/memory.go index 17de430..c63ef07 100644 --- a/framework/event/memory.go +++ b/framework/event/memory.go @@ -4,8 +4,8 @@ import ( "context" "encoding/json" "errors" - "fmt" "log/slog" + "strconv" "strings" "sync" "sync/atomic" @@ -13,13 +13,11 @@ import ( "github.com/studiolambda/cosmos/contract" ) -var ( - // ErrBrokerClosed is returned when attempting operations on a closed - // broker. - // Once a broker is closed, it cannot be reused and a new instance - // must be created. - ErrBrokerClosed = errors.New("broker is closed") -) +// ErrBrokerClosed is returned when attempting operations on a closed +// broker. +// Once a broker is closed, it cannot be reused and a new instance +// must be created. +var ErrBrokerClosed = errors.New("broker is closed") // DefaultMaxConcurrentDeliveries is the maximum number of // concurrent handler goroutines allowed per MemoryBroker. @@ -162,7 +160,7 @@ func (broker *MemoryBroker) Subscribe( return nil, ErrBrokerClosed } - handlerID := fmt.Sprintf("%d", broker.nextID.Add(1)) + handlerID := strconv.FormatUint(broker.nextID.Add(1), 10) broker.mu.Lock() defer broker.mu.Unlock() @@ -206,11 +204,9 @@ func (broker *MemoryBroker) Close() error { return nil } -// deliverToHandler invokes a handler with the encoded payload -// in a goroutine with panic recovery. Recovered panics are -// logged via slog so they remain visible for debugging. -// This ensures that a panic in one handler doesn't affect -// other handlers or the broker itself. +// deliverToHandler invokes a handler with the encoded payload, +// recovering from any panic to prevent handler failures from +// affecting the broker. func (broker *MemoryBroker) deliverToHandler( handler contract.EventHandler, encoded []byte, diff --git a/framework/event/memory_test.go b/framework/event/memory_test.go index cbed5c5..8f3aaea 100644 --- a/framework/event/memory_test.go +++ b/framework/event/memory_test.go @@ -7,9 +7,10 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" "github.com/studiolambda/cosmos/contract" "github.com/studiolambda/cosmos/framework/event" + + "github.com/stretchr/testify/require" ) func TestMemoryBrokerPublishAndSubscribe(t *testing.T) { diff --git a/framework/event/mqtt.go b/framework/event/mqtt.go index 44f2589..ba82dae 100644 --- a/framework/event/mqtt.go +++ b/framework/event/mqtt.go @@ -11,9 +11,10 @@ import ( "sync" "sync/atomic" + "github.com/studiolambda/cosmos/contract" + "github.com/eclipse/paho.golang/autopaho" "github.com/eclipse/paho.golang/paho" - "github.com/studiolambda/cosmos/contract" ) // MQTTBroker implements the EventBroker interface using MQTT v5 diff --git a/framework/event/nats.go b/framework/event/nats.go index f813989..bdf4312 100644 --- a/framework/event/nats.go +++ b/framework/event/nats.go @@ -7,8 +7,9 @@ import ( "strings" "time" - "github.com/nats-io/nats.go" "github.com/studiolambda/cosmos/contract" + + "github.com/nats-io/nats.go" ) const ( @@ -138,7 +139,7 @@ func NewNATSBroker(url string) (*NATSBroker, error) { // // Returns an error if connection to the NATS server fails. func NewNATSBrokerWith(options *NATSBrokerOptions) (*NATSBroker, error) { - opts := []nats.Option{} + var opts []nats.Option if options.Name != "" { opts = append(opts, nats.Name(options.Name)) diff --git a/framework/event/redis.go b/framework/event/redis.go index 58578cc..198ae6d 100644 --- a/framework/event/redis.go +++ b/framework/event/redis.go @@ -6,8 +6,9 @@ import ( "strings" "sync" - "github.com/redis/go-redis/v9" "github.com/studiolambda/cosmos/contract" + + "github.com/redis/go-redis/v9" ) // RedisBroker implements contract.EventBus using Redis Pub/Sub. @@ -61,7 +62,7 @@ func (broker *RedisBroker) Subscribe( ) (contract.EventUnsubscribeFunc, error) { event = strings.ReplaceAll(event, "#", "*") sub := broker.client.PSubscribe(ctx, event) - wg := sync.WaitGroup{} + var wg sync.WaitGroup wg.Go(func() { for message := range sub.Channel() {