159 lines
3.7 KiB
Go
159 lines
3.7 KiB
Go
package middleware
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/zero-ppanel/zero-ppanel/apps/api/internal/config"
|
|
"github.com/zero-ppanel/zero-ppanel/pkg/cryptox"
|
|
"github.com/zero-ppanel/zero-ppanel/pkg/xerr"
|
|
"github.com/zeromicro/go-zero/rest/httpx"
|
|
)
|
|
|
|
type DecryptMiddleware struct {
|
|
conf config.Config
|
|
}
|
|
|
|
func NewDecryptMiddleware(c config.Config) *DecryptMiddleware {
|
|
return &DecryptMiddleware{conf: c}
|
|
}
|
|
|
|
func (m *DecryptMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
if !m.conf.Security.Enable {
|
|
next(w, r)
|
|
return
|
|
}
|
|
|
|
if r.Header.Get("Login-Type") != "device" {
|
|
next(w, r)
|
|
return
|
|
}
|
|
|
|
secret := m.conf.Security.SecuritySecret
|
|
rw := newEncryptResponseWriter(w, secret)
|
|
|
|
// 解密 GET query
|
|
query := r.URL.Query()
|
|
dataStr := query.Get("data")
|
|
timeStr := query.Get("time")
|
|
if dataStr != "" && timeStr != "" {
|
|
if plain, err := cryptox.Decrypt(dataStr, secret, timeStr); err == nil {
|
|
params := map[string]interface{}{}
|
|
if json.Unmarshal(plain, ¶ms) == nil {
|
|
for k, v := range params {
|
|
query.Set(k, fmt.Sprintf("%v", v))
|
|
}
|
|
query.Del("data")
|
|
query.Del("time")
|
|
rawQuery := query.Encode()
|
|
if strings.Contains(r.RequestURI, "?") {
|
|
r.RequestURI = r.RequestURI[:strings.Index(r.RequestURI, "?")] + "?" + rawQuery
|
|
}
|
|
r.URL.RawQuery = rawQuery
|
|
}
|
|
}
|
|
}
|
|
|
|
// 解密 POST body
|
|
if r.Body != nil {
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil || len(body) == 0 {
|
|
// body 为空或读取失败,直接放行
|
|
r.Body = io.NopCloser(bytes.NewBuffer(body))
|
|
next(rw, r)
|
|
rw.flush()
|
|
return
|
|
}
|
|
|
|
var envelope struct {
|
|
Data string `json:"data"`
|
|
Time string `json:"time"`
|
|
}
|
|
if err := json.Unmarshal(body, &envelope); err != nil || envelope.Data == "" {
|
|
httpx.Error(w, xerr.NewErrCode(xerr.DecryptFailed))
|
|
return
|
|
}
|
|
|
|
plain, err := cryptox.Decrypt(envelope.Data, secret, envelope.Time)
|
|
if err != nil {
|
|
httpx.Error(w, xerr.NewErrCode(xerr.DecryptFailed))
|
|
return
|
|
}
|
|
r.Body = io.NopCloser(bytes.NewBuffer(plain))
|
|
}
|
|
|
|
next(rw, r)
|
|
rw.flush()
|
|
}
|
|
}
|
|
|
|
// encryptResponseWriter buffers everything a handler writes so that flush can
// encrypt the "data" field of the JSON response before sending it on.
//
// Header() is served by the embedded http.ResponseWriter, so headers go
// straight to the wrapped writer; the status code and body are held back
// until flush is called.
type encryptResponseWriter struct {
	http.ResponseWriter
	body   *bytes.Buffer // buffered response body
	secret string        // key passed to cryptox.Encrypt
	status int           // status code recorded by WriteHeader; sent by flush
}

// newEncryptResponseWriter wraps w with an empty buffer and a default status
// of 200 OK, which is what flush sends if the handler never calls WriteHeader.
func newEncryptResponseWriter(w http.ResponseWriter, secret string) *encryptResponseWriter {
	return &encryptResponseWriter{
		ResponseWriter: w,
		body:           new(bytes.Buffer),
		secret:         secret,
		status:         http.StatusOK,
	}
}

// WriteHeader records code; it is sent to the client by flush.
func (rw *encryptResponseWriter) WriteHeader(code int) {
	rw.status = code
}

// Write appends data to the buffered body. The returned error is always nil.
func (rw *encryptResponseWriter) Write(data []byte) (int, error) {
	return rw.body.Write(data)
}

// WriteString appends s to the buffered body. The returned error is always nil.
func (rw *encryptResponseWriter) WriteString(s string) (int, error) {
	return rw.body.WriteString(s)
}

// Hijack hands the underlying connection to the caller when the wrapped
// writer supports it. Anything already buffered in rw is not sent. If the
// wrapped writer is not an http.Hijacker, an error wrapping
// http.ErrNotSupported is returned instead of panicking.
func (rw *encryptResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hj, ok := rw.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, fmt.Errorf("%T does not implement http.Hijacker: %w", rw.ResponseWriter, http.ErrNotSupported)
	}
	return hj.Hijack()
}
|
|
|
|
func (rw *encryptResponseWriter) flush() {
|
|
buf := rw.body.Bytes()
|
|
out := buf
|
|
|
|
// 尝试加密 data 字段
|
|
params := map[string]interface{}{}
|
|
if err := json.Unmarshal(buf, ¶ms); err == nil {
|
|
if data := params["data"]; data != nil {
|
|
var jsonData []byte
|
|
if str, ok := data.(string); ok {
|
|
jsonData = []byte(str)
|
|
} else {
|
|
jsonData, _ = json.Marshal(data)
|
|
}
|
|
if dataB64, nonce, err := cryptox.Encrypt(jsonData, rw.secret); err == nil {
|
|
params["data"] = map[string]interface{}{
|
|
"data": dataB64,
|
|
"time": nonce,
|
|
}
|
|
if enc, err := json.Marshal(params); err == nil {
|
|
out = enc
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
rw.ResponseWriter.WriteHeader(rw.status)
|
|
rw.ResponseWriter.Write(out)
|
|
}
|