zero-ppanel/apps/api/internal/middleware/decryptMiddleware.go

159 lines
3.7 KiB
Go

package middleware
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strings"
"github.com/zero-ppanel/zero-ppanel/apps/api/internal/config"
"github.com/zero-ppanel/zero-ppanel/pkg/cryptox"
"github.com/zero-ppanel/zero-ppanel/pkg/xerr"
"github.com/zeromicro/go-zero/rest/httpx"
)
// DecryptMiddleware decrypts requests sent by device clients and encrypts the
// "data" field of their JSON responses when security is enabled in the
// configuration. See Handle for the exact rules.
type DecryptMiddleware struct {
	conf config.Config // Handle reads Security.Enable and Security.SecuritySecret from it
}
// NewDecryptMiddleware returns a DecryptMiddleware that uses c for its
// security settings.
func NewDecryptMiddleware(c config.Config) *DecryptMiddleware {
	m := new(DecryptMiddleware)
	m.conf = c
	return m
}
// Handle wraps next with request decryption and response encryption.
//
// The wrapper is active only when Security.Enable is set and the request
// carries the header "Login-Type: device"; otherwise next runs unchanged.
// When active:
//   - Query: if both "data" and "time" parameters are present, they are
//     decrypted with Security.SecuritySecret and, when the plaintext is a JSON
//     object, replaced by its fields. This is best effort: on any failure the
//     query is left as is.
//   - Body: a non-empty body must be a JSON envelope {"data": ..., "time": ...}
//     whose "data" decrypts successfully; otherwise the request is rejected
//     via httpx.Error with xerr.DecryptFailed. An empty or unreadable body is
//     passed through.
//   - Response: next writes into a buffering encryptResponseWriter, whose
//     flush encrypts the "data" field of the JSON output before sending it.
func (m *DecryptMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if !m.conf.Security.Enable || r.Header.Get("Login-Type") != "device" {
			next(w, r)
			return
		}

		secret := m.conf.Security.SecuritySecret
		rw := newEncryptResponseWriter(w, secret)

		decryptQueryParams(r, secret)

		if r.Body != nil {
			body, err := io.ReadAll(r.Body)
			if err != nil || len(body) == 0 {
				// Empty body or read failure: pass whatever was read through.
				r.Body = io.NopCloser(bytes.NewReader(body))
				next(rw, r)
				rw.flush()
				return
			}
			plain, ok := decryptEnvelope(body, secret)
			if !ok {
				// Nothing has been written to rw yet, so respond directly.
				httpx.Error(w, xerr.NewErrCode(xerr.DecryptFailed))
				return
			}
			r.Body = io.NopCloser(bytes.NewReader(plain))
			// ContentLength still described the encrypted envelope; keep it in
			// sync with the body the handler will actually read.
			r.ContentLength = int64(len(plain))
		}

		next(rw, r)
		rw.flush()
	}
}

// decryptQueryParams replaces the encrypted "data"/"time" query parameters of r
// with the fields of the decrypted JSON object, updating both r.URL.RawQuery
// and r.RequestURI. It leaves r untouched when either parameter is missing,
// decryption fails, or the plaintext is not a single JSON object.
func decryptQueryParams(r *http.Request, secret string) {
	query := r.URL.Query()
	dataStr, timeStr := query.Get("data"), query.Get("time")
	if dataStr == "" || timeStr == "" {
		return
	}
	plain, err := cryptox.Decrypt(dataStr, secret, timeStr)
	if err != nil {
		return
	}

	// UseNumber keeps numbers as their literal text. Decoding them as float64
	// would render 1000000 as "1e+06" and lose precision above 2^53.
	dec := json.NewDecoder(bytes.NewReader(plain))
	dec.UseNumber()
	var params map[string]interface{}
	if err := dec.Decode(&params); err != nil || dec.More() {
		return
	}

	for k, v := range params {
		if v == nil {
			// JSON null means "no value"; don't send the literal "<nil>".
			query.Del(k)
			continue
		}
		query.Set(k, fmt.Sprintf("%v", v))
	}
	query.Del("data")
	query.Del("time")

	rawQuery := query.Encode()
	if path, _, found := strings.Cut(r.RequestURI, "?"); found {
		r.RequestURI = path + "?" + rawQuery
	}
	r.URL.RawQuery = rawQuery
}

// decryptEnvelope parses body as {"data": "<ciphertext>", "time": "<nonce>"}
// and returns the decrypted plaintext. ok is false when body is not such an
// envelope, "data" is empty, or decryption fails.
func decryptEnvelope(body []byte, secret string) ([]byte, bool) {
	var envelope struct {
		Data string `json:"data"`
		Time string `json:"time"`
	}
	if err := json.Unmarshal(body, &envelope); err != nil || envelope.Data == "" {
		return nil, false
	}
	plain, err := cryptox.Decrypt(envelope.Data, secret, envelope.Time)
	if err != nil {
		return nil, false
	}
	return plain, true
}
// encryptResponseWriter intercepts a handler's response: the status code and
// body are held in memory so that flush can encrypt the "data" field of the
// JSON output before anything reaches the client.
type encryptResponseWriter struct {
	http.ResponseWriter // wrapped writer; receives the final response in flush

	body   *bytes.Buffer // everything the handler wrote via Write/WriteString
	secret string        // key passed to cryptox.Encrypt in flush
	status int           // status recorded by WriteHeader
}
// newEncryptResponseWriter wraps w with an empty body buffer and a default
// status of 200 OK, which is what gets sent if the handler never calls
// WriteHeader.
func newEncryptResponseWriter(w http.ResponseWriter, secret string) *encryptResponseWriter {
	rw := &encryptResponseWriter{ResponseWriter: w, secret: secret}
	rw.body = &bytes.Buffer{}
	rw.status = http.StatusOK
	return rw
}
// WriteHeader records statusCode without sending it; flush sends the recorded
// status later. Repeated calls overwrite the previous value.
func (rw *encryptResponseWriter) WriteHeader(statusCode int) {
	rw.status = statusCode
}
// Write appends p to the in-memory body buffer; nothing reaches the client
// until flush runs.
func (rw *encryptResponseWriter) Write(p []byte) (int, error) {
	n, err := rw.body.Write(p)
	return n, err
}
// WriteString appends str to the in-memory body buffer, mirroring Write for
// callers that use io.StringWriter.
func (rw *encryptResponseWriter) WriteString(str string) (int, error) {
	n, err := rw.body.WriteString(str)
	return n, err
}
// Hijack delegates to the wrapped ResponseWriter so handlers can take over the
// connection. It returns an error wrapping http.ErrNotSupported, instead of
// panicking, when the wrapped writer does not implement http.Hijacker.
//
// NOTE(review): Handle still calls flush after next returns, so a hijacked
// connection will also receive a WriteHeader/Write attempt from flush.
func (rw *encryptResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hijacker, ok := rw.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, fmt.Errorf("middleware: %T does not implement http.Hijacker: %w",
			rw.ResponseWriter, http.ErrNotSupported)
	}
	return hijacker.Hijack()
}
// flush sends the buffered status code and body to the wrapped ResponseWriter.
//
// When the buffered body is a JSON object with a non-null "data" field, that
// field is replaced by {"data": <ciphertext>, "time": <nonce>} produced by
// cryptox.Encrypt; every other body is written unchanged (see
// encryptDataField).
func (rw *encryptResponseWriter) flush() {
	buf := rw.body.Bytes()
	out := encryptDataField(buf, rw.secret)
	if len(out) != len(buf) {
		// A Content-Length set by the handler describes the original body and
		// no longer matches the rewritten one.
		rw.ResponseWriter.Header().Del("Content-Length")
	}
	rw.ResponseWriter.WriteHeader(rw.status)
	// The handler has already returned, so there is no caller left to report a
	// write failure (e.g. a client disconnect) to.
	_, _ = rw.ResponseWriter.Write(out)
}

// encryptDataField returns body with its top-level "data" field replaced by an
// encrypted envelope {"data": <ciphertext>, "time": <nonce>}.
//
// A JSON string value is encrypted as its unquoted text; any other value is
// encrypted as compact JSON. body is returned unchanged when it is not a JSON
// object, has no "data" field, "data" is null, or any encoding or encryption
// step fails.
func encryptDataField(body []byte, secret string) []byte {
	// json.RawMessage keeps every field's original bytes. Decoding into
	// map[string]interface{} would turn all numbers into float64 and corrupt
	// integers above 2^53 (e.g. 64-bit IDs) when the response is re-encoded.
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(body, &fields); err != nil {
		return body
	}
	raw, ok := fields["data"]
	trimmed := bytes.TrimSpace(raw)
	if !ok || len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) {
		return body
	}

	var plain []byte
	if trimmed[0] == '"' {
		var s string
		if err := json.Unmarshal(trimmed, &s); err != nil {
			return body
		}
		plain = []byte(s)
	} else {
		var compact bytes.Buffer
		if err := json.Compact(&compact, trimmed); err != nil {
			return body
		}
		plain = compact.Bytes()
	}

	dataB64, nonce, err := cryptox.Encrypt(plain, secret)
	if err != nil {
		// NOTE(review): fail-open — the plaintext response is sent when
		// encryption fails. Consider sending an error response instead.
		return body
	}
	envelope, err := json.Marshal(map[string]interface{}{
		"data": dataB64,
		"time": nonce,
	})
	if err != nil {
		return body
	}
	fields["data"] = envelope

	out, err := json.Marshal(fields)
	if err != nil {
		return body
	}
	return out
}