package middleware import ( "context" "crypto/hmac" "crypto/sha256" "encoding/hex" "encoding/json" "io" "net/http" "net/http/httptest" "strconv" "strings" "testing" "time" "github.com/gin-gonic/gin" "github.com/perfect-panel/server/internal/config" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/pkg/signature" "github.com/perfect-panel/server/pkg/xerr" ) type testNonceStore struct { seen map[string]bool } func newTestNonceStore() *testNonceStore { return &testNonceStore{seen: map[string]bool{}} } func (s *testNonceStore) SetIfNotExists(_ context.Context, appId, nonce string, _ int64) (bool, error) { key := appId + ":" + nonce if s.seen[key] { return true, nil } s.seen[key] = true return false, nil } func makeTestSignature(secret, sts string) string { mac := hmac.New(sha256.New, []byte(secret)) mac.Write([]byte(sts)) return hex.EncodeToString(mac.Sum(nil)) } func newTestServiceContext() *svc.ServiceContext { conf := config.Config{} conf.Signature.EnableSignature = true conf.AppSignature = signature.SignatureConf{ AppSecrets: map[string]string{ "web-client": "test-secret", }, ValidWindowSeconds: 300, SkipPrefixes: []string{ "/v1/public/health", }, } return &svc.ServiceContext{ Config: conf, SignatureValidator: signature.NewValidator(conf.AppSignature, newTestNonceStore()), } } func newTestServiceContextWithSwitch(enabled bool) *svc.ServiceContext { svcCtx := newTestServiceContext() svcCtx.Config.Signature.EnableSignature = enabled return svcCtx } func decodeCode(t *testing.T, body []byte) uint32 { t.Helper() var resp struct { Code uint32 `json:"code"` } if err := json.Unmarshal(body, &resp); err != nil { t.Fatalf("unmarshal response failed: %v", err) } return resp.Code } func TestSignatureMiddlewareMissingAppID(t *testing.T) { gin.SetMode(gin.TestMode) svcCtx := newTestServiceContext() r := gin.New() r.Use(SignatureMiddleware(svcCtx)) r.GET("/v1/public/ping", func(c *gin.Context) { c.String(http.StatusOK, "ok") }) req := httptest.NewRequest(http.MethodGet, "/v1/public/ping", nil) req.Header.Set("X-Signature-Enabled", "1") resp := httptest.NewRecorder() r.ServeHTTP(resp, req) if code := decodeCode(t, resp.Body.Bytes()); code != xerr.InvalidAccess { t.Fatalf("expected InvalidAccess(%d), got %d", xerr.InvalidAccess, code) } } func TestSignatureMiddlewareMissingSignatureHeaders(t *testing.T) { gin.SetMode(gin.TestMode) svcCtx := newTestServiceContext() r := gin.New() r.Use(SignatureMiddleware(svcCtx)) r.GET("/v1/public/ping", func(c *gin.Context) { c.String(http.StatusOK, "ok") }) req := httptest.NewRequest(http.MethodGet, "/v1/public/ping", nil) req.Header.Set("X-Signature-Enabled", "1") req.Header.Set("X-App-Id", "web-client") resp := httptest.NewRecorder() r.ServeHTTP(resp, req) if code := decodeCode(t, resp.Body.Bytes()); code != xerr.SignatureMissing { t.Fatalf("expected SignatureMissing(%d), got %d", xerr.SignatureMissing, code) } } func TestSignatureMiddlewarePassesWhenSignatureHeaderMissing(t *testing.T) { gin.SetMode(gin.TestMode) svcCtx := newTestServiceContext() r := gin.New() r.Use(SignatureMiddleware(svcCtx)) r.GET("/v1/public/ping", func(c *gin.Context) { c.String(http.StatusOK, "ok") }) req := httptest.NewRequest(http.MethodGet, "/v1/public/ping", nil) resp := httptest.NewRecorder() r.ServeHTTP(resp, req) if resp.Code != http.StatusOK || resp.Body.String() != "ok" { t.Fatalf("expected pass-through without X-Signature-Enabled, got code=%d body=%s", resp.Code, resp.Body.String()) } } func TestSignatureMiddlewarePassesWhenSignatureHeaderIsZero(t *testing.T) { gin.SetMode(gin.TestMode) svcCtx := newTestServiceContext() r := gin.New() r.Use(SignatureMiddleware(svcCtx)) r.GET("/v1/public/ping", func(c *gin.Context) { c.String(http.StatusOK, "ok") }) req := httptest.NewRequest(http.MethodGet, "/v1/public/ping", nil) req.Header.Set("X-Signature-Enabled", "0") resp := httptest.NewRecorder() r.ServeHTTP(resp, req) if resp.Code != http.StatusOK || resp.Body.String() != "ok" { t.Fatalf("expected pass-through when X-Signature-Enabled=0, got code=%d body=%s", resp.Code, resp.Body.String()) } } func TestSignatureMiddlewarePassesWhenSystemSwitchDisabled(t *testing.T) { gin.SetMode(gin.TestMode) svcCtx := newTestServiceContextWithSwitch(false) r := gin.New() r.Use(SignatureMiddleware(svcCtx)) r.GET("/v1/public/ping", func(c *gin.Context) { c.String(http.StatusOK, "ok") }) req := httptest.NewRequest(http.MethodGet, "/v1/public/ping", nil) req.Header.Set("X-Signature-Enabled", "1") resp := httptest.NewRecorder() r.ServeHTTP(resp, req) if resp.Code != http.StatusOK || resp.Body.String() != "ok" { t.Fatalf("expected pass-through when system switch is disabled, got code=%d body=%s", resp.Code, resp.Body.String()) } } func TestSignatureMiddlewareSkipsNonPublicPath(t *testing.T) { gin.SetMode(gin.TestMode) svcCtx := newTestServiceContext() r := gin.New() r.Use(SignatureMiddleware(svcCtx)) r.GET("/v1/admin/ping", func(c *gin.Context) { c.String(http.StatusOK, "ok") }) req := httptest.NewRequest(http.MethodGet, "/v1/admin/ping", nil) resp := httptest.NewRecorder() r.ServeHTTP(resp, req) if resp.Code != http.StatusOK || resp.Body.String() != "ok" { t.Fatalf("expected pass-through for non-public path, got code=%d body=%s", resp.Code, resp.Body.String()) } } func TestSignatureMiddlewareHonorsSkipPrefix(t *testing.T) { gin.SetMode(gin.TestMode) svcCtx := newTestServiceContext() r := gin.New() r.Use(SignatureMiddleware(svcCtx)) r.GET("/v1/public/healthz", func(c *gin.Context) { c.String(http.StatusOK, "ok") }) req := httptest.NewRequest(http.MethodGet, "/v1/public/healthz", nil) resp := httptest.NewRecorder() r.ServeHTTP(resp, req) if resp.Code != http.StatusOK || resp.Body.String() != "ok" { t.Fatalf("expected skip-prefix pass-through, got code=%d body=%s", resp.Code, resp.Body.String()) } } func TestSignatureMiddlewareRestoresBodyAfterVerify(t *testing.T) { gin.SetMode(gin.TestMode) svcCtx := newTestServiceContext() r := gin.New() r.Use(SignatureMiddleware(svcCtx)) r.POST("/v1/public/body", func(c *gin.Context) { body, _ := io.ReadAll(c.Request.Body) c.String(http.StatusOK, string(body)) }) body := `{"hello":"world"}` req := httptest.NewRequest(http.MethodPost, "/v1/public/body?a=1&b=2", strings.NewReader(body)) ts := strconv.FormatInt(time.Now().Unix(), 10) nonce := "nonce-body-1" sts := signature.BuildStringToSign(http.MethodPost, "/v1/public/body", "a=1&b=2", []byte(body), "web-client", ts, nonce) req.Header.Set("X-Signature-Enabled", "1") req.Header.Set("X-App-Id", "web-client") req.Header.Set("X-Timestamp", ts) req.Header.Set("X-Nonce", nonce) req.Header.Set("X-Signature", makeTestSignature("test-secret", sts)) resp := httptest.NewRecorder() r.ServeHTTP(resp, req) if resp.Code != http.StatusOK || resp.Body.String() != body { t.Fatalf("expected restored body, got code=%d body=%s", resp.Code, resp.Body.String()) } }