package middleware

import (
	"bufio"
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"strings"

	"github.com/zero-ppanel/zero-ppanel/apps/api/internal/config"
	"github.com/zero-ppanel/zero-ppanel/pkg/cryptox"
	"github.com/zero-ppanel/zero-ppanel/pkg/xerr"
	"github.com/zeromicro/go-zero/rest/httpx"
)

// DecryptMiddleware decrypts requests from device clients and encrypts the
// top-level "data" field of their JSON responses.
//
// It only acts when Security.Enable is set in the config and the request
// carries the header "Login-Type: device"; every other request is passed to
// the next handler untouched.
type DecryptMiddleware struct {
	conf config.Config
}

// NewDecryptMiddleware returns a DecryptMiddleware configured from c.
func NewDecryptMiddleware(c config.Config) *DecryptMiddleware {
	return &DecryptMiddleware{conf: c}
}

// Handle wraps next with request decryption and response encryption.
//
// For qualifying requests:
//   - If the query string has both "data" and "time", they are decrypted with
//     cryptox.Decrypt and replaced by the fields of the resulting JSON object.
//     This step is best-effort: on any failure the query is left unchanged.
//   - A non-empty body must be a JSON envelope {"data": ..., "time": ...}; it is
//     decrypted and the plaintext becomes the new body. A body that cannot be
//     read, parsed, or decrypted is rejected with xerr.DecryptFailed and next
//     is not called.
//   - The response written by next is buffered and its "data" field encrypted
//     before being sent (see encryptResponseWriter.flush).
func (m *DecryptMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if !m.conf.Security.Enable || r.Header.Get("Login-Type") != "device" {
			next(w, r)
			return
		}

		secret := m.conf.Security.SecuritySecret

		decryptQuery(r, secret)

		if err := decryptBody(r, secret); err != nil {
			httpx.Error(w, xerr.NewErrCode(xerr.DecryptFailed))
			return
		}

		rw := newEncryptResponseWriter(w, secret)
		next(rw, r)
		rw.flush()
	}
}

// decryptQuery replaces the encrypted "data"/"time" query parameters of r with
// the key/value pairs of the decrypted JSON object. It is best-effort: if
// either parameter is missing, decryption fails, or the plaintext is not a
// JSON object, r is left unchanged.
func decryptQuery(r *http.Request, secret string) {
	query := r.URL.Query()
	dataStr, timeStr := query.Get("data"), query.Get("time")
	if dataStr == "" || timeStr == "" {
		return
	}
	plain, err := cryptox.Decrypt(dataStr, secret, timeStr)
	if err != nil {
		return
	}
	var params map[string]json.RawMessage
	if err := json.Unmarshal(plain, &params); err != nil {
		return
	}

	// Remove the envelope first so decrypted parameters that are themselves
	// named "data" or "time" are not deleted afterwards.
	query.Del("data")
	query.Del("time")
	for k, raw := range params {
		if v, ok := queryValue(raw); ok {
			query.Set(k, v)
		}
	}

	rawQuery := query.Encode()
	r.URL.RawQuery = rawQuery
	// Keep RequestURI consistent for code that reads it directly.
	if path, _, found := strings.Cut(r.RequestURI, "?"); found {
		r.RequestURI = path + "?" + rawQuery
	}
}

// queryValue renders one decrypted JSON value as a query-string value.
// Strings are unquoted; numbers, booleans, objects and arrays keep their JSON
// text, so integers are not reformatted through float64 (which would turn
// 1000000 into "1e+06"). ok is false for JSON null or an undecodable string,
// and the caller then leaves that parameter out.
func queryValue(raw json.RawMessage) (string, bool) {
	raw = bytes.TrimSpace(raw)
	switch {
	case len(raw) == 0, string(raw) == "null":
		return "", false
	case raw[0] == '"':
		var s string
		if err := json.Unmarshal(raw, &s); err != nil {
			return "", false
		}
		return s, true
	default:
		return string(raw), true
	}
}

// decryptBody replaces the encrypted JSON envelope in r's body with its
// plaintext. A nil or empty body is accepted as-is (an empty body is restored
// so the handler can still read it). It returns an error if the body cannot be
// read, is not an envelope with a non-empty "data" field, or fails to decrypt.
func decryptBody(r *http.Request, secret string) error {
	if r.Body == nil {
		return nil
	}
	body, err := io.ReadAll(r.Body)
	if err != nil {
		// Do not forward a truncated body to the handler.
		return fmt.Errorf("reading request body: %w", err)
	}
	if len(body) == 0 {
		// Nothing to decrypt; give the handler a fresh, readable empty body.
		r.Body = io.NopCloser(bytes.NewReader(body))
		return nil
	}

	var envelope struct {
		Data string `json:"data"`
		Time string `json:"time"`
	}
	if err := json.Unmarshal(body, &envelope); err != nil {
		return fmt.Errorf("parsing encrypted envelope: %w", err)
	}
	if envelope.Data == "" {
		return errors.New("encrypted envelope has empty data")
	}
	plain, err := cryptox.Decrypt(envelope.Data, secret, envelope.Time)
	if err != nil {
		return fmt.Errorf("decrypting request body: %w", err)
	}

	r.Body = io.NopCloser(bytes.NewReader(plain))
	r.ContentLength = int64(len(plain))
	// httpx.Parse calls r.ParseForm, which reads the body only when
	// r.PostForm == nil. An empty, non-nil PostForm stops it from consuming
	// the replaced body.
	r.PostForm = url.Values{}
	return nil
}

// encryptResponseWriter buffers a handler's response so that flush can
// encrypt the top-level "data" field of a JSON body before sending it.
// Nothing reaches the client until flush is called.
type encryptResponseWriter struct {
	http.ResponseWriter
	body   *bytes.Buffer
	secret string
	status int
}

// newEncryptResponseWriter wraps w; the status defaults to 200 OK.
func newEncryptResponseWriter(w http.ResponseWriter, secret string) *encryptResponseWriter {
	return &encryptResponseWriter{
		ResponseWriter: w,
		body:           new(bytes.Buffer),
		secret:         secret,
		status:         http.StatusOK,
	}
}

// WriteHeader records code; it is sent to the client by flush.
func (rw *encryptResponseWriter) WriteHeader(code int) {
	rw.status = code
}

// Write appends data to the buffered body.
func (rw *encryptResponseWriter) Write(data []byte) (int, error) {
	return rw.body.Write(data)
}

// WriteString appends s to the buffered body.
func (rw *encryptResponseWriter) WriteString(s string) (int, error) {
	return rw.body.WriteString(s)
}

// Hijack delegates to the underlying ResponseWriter. If that writer cannot be
// hijacked, it returns an error wrapping http.ErrNotSupported instead of
// panicking.
func (rw *encryptResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	h, ok := rw.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, fmt.Errorf("encryptResponseWriter: hijack: %w", http.ErrNotSupported)
	}
	return h.Hijack()
}

// flush writes the recorded status and the buffered body to the client. If
// the body is a JSON object with a non-null "data" field, that field is first
// replaced by {"data": <ciphertext>, "time": <nonce>} from cryptox.Encrypt.
// Any other body, or one whose encryption fails, is sent unchanged.
func (rw *encryptResponseWriter) flush() {
	out := rw.body.Bytes()
	if enc, ok := rw.encryptData(out); ok {
		out = enc
		// The body length changed; drop any Content-Length the handler set so
		// net/http does not send a mismatched value.
		rw.ResponseWriter.Header().Del("Content-Length")
	}
	rw.ResponseWriter.WriteHeader(rw.status)
	// The response is already committed and there is no caller to report a
	// write error to.
	_, _ = rw.ResponseWriter.Write(out)
}

// encryptData returns body with its top-level "data" field encrypted. A
// string value is encrypted as its raw contents; any other value is encrypted
// as its original JSON text, so large integers keep full precision. ok is
// false if body is not a JSON object, "data" is absent or null, or encryption
// or re-encoding fails.
func (rw *encryptResponseWriter) encryptData(body []byte) ([]byte, bool) {
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(body, &fields); err != nil {
		return nil, false
	}
	raw := bytes.TrimSpace(fields["data"])
	if len(raw) == 0 || string(raw) == "null" {
		return nil, false
	}

	plain := []byte(raw)
	if raw[0] == '"' {
		var s string
		if err := json.Unmarshal(raw, &s); err != nil {
			return nil, false
		}
		plain = []byte(s)
	}

	dataB64, nonce, err := cryptox.Encrypt(plain, rw.secret)
	if err != nil {
		return nil, false
	}
	wrapped, err := json.Marshal(map[string]any{
		"data": dataB64,
		"time": nonce,
	})
	if err != nil {
		return nil, false
	}
	fields["data"] = wrapped

	enc, err := json.Marshal(fields)
	if err != nil {
		return nil, false
	}
	return enc, true
}