refactor(apple): 优化base64解码和JWT签名逻辑
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 6m28s

移除notification.go中冗余的base64解码代码,使用统一的decodeB64URL函数处理
在serverapi.go中改进ES256签名实现,正确处理P-256曲线的R和S值填充
This commit is contained in:
shanshanzhong 2025-12-15 23:59:46 -08:00
parent e11ed2338d
commit b391c12c1b
2 changed files with 20 additions and 13 deletions

View File

@ -1,7 +1,6 @@
package apple package apple
import ( import (
"encoding/base64"
"encoding/json" "encoding/json"
"strings" "strings"
) )
@ -16,14 +15,7 @@ func ParseNotificationSignedPayload(jws string) (*NotificationEnvelope, error) {
if len(parts) != 3 { if len(parts) != 3 {
return nil, ErrInvalidJWS return nil, ErrInvalidJWS
} }
payloadB64 := cleanB64(parts[1]) data, err := decodeB64URL(parts[1])
switch len(payloadB64) % 4 {
case 2:
payloadB64 += "=="
case 3:
payloadB64 += "="
}
data, err := base64.RawURLEncoding.DecodeString(payloadB64)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -22,7 +22,7 @@ type ServerAPIConfig struct {
} }
func buildAPIToken(cfg ServerAPIConfig) (string, error) { func buildAPIToken(cfg ServerAPIConfig) (string, error) {
header := map[string]string{ header := map[string]interface{}{
"alg": "ES256", "alg": "ES256",
"kid": cfg.KeyID, "kid": cfg.KeyID,
"typ": "JWT", "typ": "JWT",
@ -52,12 +52,27 @@ func buildAPIToken(cfg ServerAPIConfig) (string, error) {
if !ok { if !ok {
return "", fmt.Errorf("private key is not ECDSA") return "", fmt.Errorf("private key is not ECDSA")
} }
hash := unsigned // ES256 signs SHA-256 of input; jwt libs do hashing, we implement manually // Correctly calculate R and S for ES256 (P-256 + SHA-256)
digest := sha256Sum([]byte(hash)) digest := sha256Sum([]byte(unsigned))
sig, err := ecdsa.SignASN1(rand.Reader, priv, digest) r, s, err := ecdsa.Sign(rand.Reader, priv, digest)
if err != nil { if err != nil {
return "", err return "", err
} }
// Concatenate R and S (each 32 bytes for P-256)
curveBits := priv.Curve.Params().BitSize
keyBytes := curveBits / 8
if curveBits%8 > 0 {
keyBytes += 1
}
rBytes := r.Bytes()
rBytesPadded := make([]byte, keyBytes)
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
sBytes := s.Bytes()
sBytesPadded := make([]byte, keyBytes)
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
sig := append(rBytesPadded, sBytesPadded...)
return unsigned + "." + base64.RawURLEncoding.EncodeToString(sig), nil return unsigned + "." + base64.RawURLEncoding.EncodeToString(sig), nil
} }