package middleware import ( "bytes" "io" "net/http" "strings" "github.com/zero-ppanel/zero-ppanel/apps/api/internal/config" "github.com/zero-ppanel/zero-ppanel/pkg/signature" "github.com/zero-ppanel/zero-ppanel/pkg/xerr" "github.com/zeromicro/go-zero/rest/httpx" ) type SignatureMiddleware struct { conf config.Config validator *signature.Validator } func NewSignatureMiddleware(c config.Config, store signature.NonceStore) *SignatureMiddleware { return &SignatureMiddleware{ conf: c, validator: signature.NewValidator(c.AppSignature, store), } } func (m *SignatureMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { appId := r.Header.Get("X-App-Id") // X-App-Id 为空,提示非法访问 if appId == "" { httpx.Error(w, xerr.NewErrCode(xerr.InvalidAccess)) return } // SkipPrefixes 白名单 for _, prefix := range m.conf.AppSignature.SkipPrefixes { if strings.HasPrefix(r.URL.Path, prefix) { next(w, r) return } } timestamp := r.Header.Get("X-Timestamp") nonce := r.Header.Get("X-Nonce") sig := r.Header.Get("X-Signature") if timestamp == "" || nonce == "" || sig == "" { httpx.Error(w, xerr.NewErrCode(xerr.SignatureMissing)) return } // 读取 body(签名对原始 body bytes 计算) var bodyBytes []byte if r.Body != nil { bodyBytes, _ = io.ReadAll(r.Body) r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) } sts := signature.BuildStringToSign(r.Method, r.URL.Path, r.URL.RawQuery, bodyBytes, appId, timestamp, nonce) if err := m.validator.Validate(r.Context(), appId, timestamp, nonce, sig, sts); err != nil { code := mapSignatureErr(err) httpx.Error(w, xerr.NewErrCode(code)) return } next(w, r) } } func mapSignatureErr(err error) int { switch err { case signature.ErrSignatureMissing: return xerr.SignatureMissing case signature.ErrSignatureExpired: return xerr.SignatureExpired case signature.ErrSignatureReplay: return xerr.SignatureReplay default: return xerr.SignatureInvalid } }