From 0db528dab442832a82c15f2537a58bac68e90d03 Mon Sep 17 00:00:00 2001 From: caydyan Date: Sun, 14 Jun 2026 07:16:47 +0800 Subject: [PATCH] Validate workflow response token --- handlers/workflow.go | 37 +++++++++++++++++ handlers/workflow_test.go | 83 ++++++++++++++++++++++++++++++++++++++- mocks/Database.go | 58 ++++++++++++++++++++++++++- 3 files changed, 176 insertions(+), 2 deletions(-) diff --git a/handlers/workflow.go b/handlers/workflow.go index 0c08c8d43..6c718351b 100644 --- a/handlers/workflow.go +++ b/handlers/workflow.go @@ -1,6 +1,7 @@ package handlers import ( + "crypto/subtle" "encoding/json" "github.com/stakwork/sphinx-tribes/db" @@ -8,6 +9,8 @@ import ( "io" "net/http" + "os" + "strings" ) type workflowHandler struct { @@ -19,12 +22,41 @@ type CreateWorkflowRequestRequest struct { ResponseData db.PropertyMap `json:"response_data"` } +const workflowResponseTokenHeader = "x-api-token" + func NewWorkFlowHandler(database db.Database) *workflowHandler { return &workflowHandler{ db: database, } } +func validateStakworkResponseToken(r *http.Request) (int, string) { + expectedToken := strings.TrimSpace(os.Getenv("SWWF_RESKEY")) + if expectedToken == "" { + return http.StatusInternalServerError, "Server not configured" + } + + token := strings.TrimSpace(r.Header.Get(workflowResponseTokenHeader)) + if token == "" { + token = strings.TrimSpace(r.Header.Get("x-swwf-reskey")) + } + if token == "" { + authHeader := strings.TrimSpace(r.Header.Get("Authorization")) + if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { + token = strings.TrimSpace(authHeader[len("bearer "):]) + } + } + if token == "" { + return http.StatusUnauthorized, "Missing token" + } + + if subtle.ConstantTimeCompare([]byte(token), []byte(expectedToken)) != 1 { + return http.StatusUnauthorized, "Invalid token" + } + + return http.StatusOK, "" +} + // HandleWorkflowRequest godoc // // @Summary Handle Workflow Request @@ -86,6 +118,11 @@ func (wh *workflowHandler) HandleWorkflowRequest(w http.ResponseWriter, r *http. // @Success 200 {object} map[string]string // @Router /workflows/response [post] func (wh *workflowHandler) HandleWorkflowResponse(w http.ResponseWriter, r *http.Request) { + if status, msg := validateStakworkResponseToken(r); status != http.StatusOK { + http.Error(w, msg, status) + return + } + body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, "Error reading response body", http.StatusBadRequest) diff --git a/handlers/workflow_test.go b/handlers/workflow_test.go index 7f2f44063..e4057a952 100644 --- a/handlers/workflow_test.go +++ b/handlers/workflow_test.go @@ -14,6 +14,12 @@ import ( "github.com/stretchr/testify/assert" ) +const testWorkflowResponseToken = "test-swwf-token" + +func authorizeWorkflowResponse(req *http.Request) { + req.Header.Set("x-api-token", testWorkflowResponseToken) +} + func TestHandleWorkflowRequest(t *testing.T) { teardownSuite := SetupSuite(t) @@ -78,6 +84,7 @@ func TestHandleWorkflowResponse(t *testing.T) { teardownSuite := SetupSuite(t) defer teardownSuite(t) + t.Setenv("SWWF_RESKEY", testWorkflowResponseToken) wh := NewWorkFlowHandler(db.TestDB) t.Run("should process workflow response successfully", func(t *testing.T) { @@ -120,6 +127,7 @@ func TestHandleWorkflowResponse(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/workflows/response", bytes.NewBuffer(payload)) req.Header.Set("Content-Type", "application/json") + authorizeWorkflowResponse(req) w := httptest.NewRecorder() wh.HandleWorkflowResponse(w, req) @@ -180,6 +188,7 @@ func TestHandleWorkflowResponse(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/workflows/response", bytes.NewBuffer(payload)) req.Header.Set("Content-Type", "application/json") + authorizeWorkflowResponse(req) w := httptest.NewRecorder() wh.HandleWorkflowResponse(w, req) @@ -201,6 +210,7 @@ func TestHandleWorkflowResponse(t *testing.T) { payload, _ := json.Marshal(response) req := httptest.NewRequest(http.MethodPost, "/workflows/response", bytes.NewBuffer(payload)) + authorizeWorkflowResponse(req) w := httptest.NewRecorder() wh.HandleWorkflowResponse(w, req) @@ -212,6 +222,7 @@ func TestHandleWorkflowResponse(t *testing.T) { invalidJSON := []byte(`{"request_id": "123", status: invalid}`) req := httptest.NewRequest(http.MethodPost, "/workflows/response", bytes.NewBuffer(invalidJSON)) + authorizeWorkflowResponse(req) w := httptest.NewRecorder() wh.HandleWorkflowResponse(w, req) @@ -228,6 +239,7 @@ func TestHandleWorkflowResponse(t *testing.T) { payload, _ := json.Marshal(response) req := httptest.NewRequest(http.MethodPost, "/workflows/response", bytes.NewBuffer(payload)) + authorizeWorkflowResponse(req) w := httptest.NewRecorder() wh.HandleWorkflowResponse(w, req) @@ -259,6 +271,7 @@ func TestHandleWorkflowResponse(t *testing.T) { payload, _ := json.Marshal(response) req := httptest.NewRequest(http.MethodPost, "/workflows/response", bytes.NewBuffer(payload)) + authorizeWorkflowResponse(req) w := httptest.NewRecorder() wh.HandleWorkflowResponse(w, req) @@ -301,6 +314,7 @@ func TestHandleWorkflowResponse(t *testing.T) { payload, _ := json.Marshal(response) req := httptest.NewRequest(http.MethodPost, "/workflows/response", bytes.NewBuffer(payload)) + authorizeWorkflowResponse(req) w := httptest.NewRecorder() wh.HandleWorkflowResponse(w, req) @@ -313,4 +327,71 @@ func TestHandleWorkflowResponse(t *testing.T) { assert.Equal(t, db.StatusPending, updatedReq.Status) assert.Equal(t, response.ResponseData, updatedReq.ResponseData) }) -} \ No newline at end of file +} + +func TestHandleWorkflowResponseTokenValidation(t *testing.T) { + wh := NewWorkFlowHandler(db.TestDB) + payload := []byte(`{"request_id":"test-request","response_data":{}}`) + + t.Run("returns 401 when token is missing", func(t *testing.T) { + t.Setenv("SWWF_RESKEY", testWorkflowResponseToken) + + req := httptest.NewRequest(http.MethodPost, "/workflows/response", bytes.NewBuffer(payload)) + w := httptest.NewRecorder() + + wh.HandleWorkflowResponse(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + assert.Contains(t, w.Body.String(), "Missing token") + }) + + t.Run("returns 500 when server token is not configured", func(t *testing.T) { + t.Setenv("SWWF_RESKEY", "") + + req := httptest.NewRequest(http.MethodPost, "/workflows/response", bytes.NewBuffer(payload)) + req.Header.Set("x-api-token", testWorkflowResponseToken) + w := httptest.NewRecorder() + + wh.HandleWorkflowResponse(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Contains(t, w.Body.String(), "Server not configured") + }) + + t.Run("returns 401 when token is invalid", func(t *testing.T) { + t.Setenv("SWWF_RESKEY", testWorkflowResponseToken) + + req := httptest.NewRequest(http.MethodPost, "/workflows/response", bytes.NewBuffer(payload)) + req.Header.Set("x-api-token", "wrong-token") + w := httptest.NewRecorder() + + wh.HandleWorkflowResponse(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + assert.Contains(t, w.Body.String(), "Invalid token") + }) + + t.Run("accepts bearer token", func(t *testing.T) { + t.Setenv("SWWF_RESKEY", testWorkflowResponseToken) + + req := httptest.NewRequest(http.MethodPost, "/workflows/response", bytes.NewBuffer(payload)) + req.Header.Set("Authorization", "Bearer "+testWorkflowResponseToken) + + status, msg := validateStakworkResponseToken(req) + + assert.Equal(t, http.StatusOK, status) + assert.Empty(t, msg) + }) + + t.Run("accepts swwf token header", func(t *testing.T) { + t.Setenv("SWWF_RESKEY", testWorkflowResponseToken) + + req := httptest.NewRequest(http.MethodPost, "/workflows/response", bytes.NewBuffer(payload)) + req.Header.Set("x-swwf-reskey", testWorkflowResponseToken) + + status, msg := validateStakworkResponseToken(req) + + assert.Equal(t, http.StatusOK, status) + assert.Empty(t, msg) + }) +} diff --git a/mocks/Database.go b/mocks/Database.go index 5c98b7f66..15f0a18b8 100644 --- a/mocks/Database.go +++ b/mocks/Database.go @@ -5696,6 +5696,62 @@ func (_c *Database_GetBountyByCreated_Call) RunAndReturn(run func(uint) (db.NewB return _c } +// GetBountyByUnlockCode provides a mock function with given fields: code +func (_m *Database) GetBountyByUnlockCode(code string) (db.NewBounty, error) { + ret := _m.Called(code) + + if len(ret) == 0 { + panic("no return value specified for GetBountyByUnlockCode") + } + + var r0 db.NewBounty + var r1 error + if rf, ok := ret.Get(0).(func(string) (db.NewBounty, error)); ok { + return rf(code) + } + if rf, ok := ret.Get(0).(func(string) db.NewBounty); ok { + r0 = rf(code) + } else { + r0 = ret.Get(0).(db.NewBounty) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(code) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Database_GetBountyByUnlockCode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetBountyByUnlockCode' +type Database_GetBountyByUnlockCode_Call struct { + *mock.Call +} + +// GetBountyByUnlockCode is a helper method to define mock.On call +// - code string +func (_e *Database_Expecter) GetBountyByUnlockCode(code interface{}) *Database_GetBountyByUnlockCode_Call { + return &Database_GetBountyByUnlockCode_Call{Call: _e.mock.On("GetBountyByUnlockCode", code)} +} + +func (_c *Database_GetBountyByUnlockCode_Call) Run(run func(code string)) *Database_GetBountyByUnlockCode_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *Database_GetBountyByUnlockCode_Call) Return(_a0 db.NewBounty, _a1 error) *Database_GetBountyByUnlockCode_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Database_GetBountyByUnlockCode_Call) RunAndReturn(run func(string) (db.NewBounty, error)) *Database_GetBountyByUnlockCode_Call { + _c.Call.Return(run) + return _c +} + // GetBountyById provides a mock function with given fields: id func (_m *Database) GetBountyById(id string) ([]db.NewBounty, error) { ret := _m.Called(id) @@ -18423,4 +18479,4 @@ func (_c *Database_DeleteBountyStakeProcess_Call) Return(_a0 error) *Database_De func (_c *Database_DeleteBountyStakeProcess_Call) RunAndReturn(run func(uuid.UUID) error) *Database_DeleteBountyStakeProcess_Call { _c.Call.Return(run) return _c -} \ No newline at end of file +}