本站源代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

149 lines
3.6KB

  1. package jwt
  2. import (
  3. "crypto"
  4. "crypto/ecdsa"
  5. "crypto/rand"
  6. "errors"
  7. "math/big"
  8. )
  // ErrECDSAVerification is returned when an ECDSA signature fails to verify.
  // crypto/ecdsa offers no such sentinel (unlike crypto/rsa's ErrVerification),
  // so this package declares its own.
  var ErrECDSAVerification = errors.New("crypto/ecdsa: verification error")
  // SigningMethodECDSA implements the ECDSA family of signing methods.
  // It expects *ecdsa.PrivateKey for signing and *ecdsa.PublicKey for
  // verification.
  //
  // NOTE: the field order is relied upon by positional composite literals
  // elsewhere in this package; do not reorder the fields.
  type SigningMethodECDSA struct {
  	// Name is the algorithm identifier returned by Alg (e.g. "ES256").
  	Name string
  	// Hash is the hash function applied to the signing string before
  	// signing or verifying.
  	Hash crypto.Hash
  	// KeySize is the byte length of each of r and s in an encoded
  	// signature; a valid signature is exactly 2*KeySize bytes.
  	KeySize int
  	// CurveBits is the bit size Sign requires of the private key's curve.
  	CurveBits int
  }
  21. // Specific instances for EC256 and company
  22. var (
  23. SigningMethodES256 *SigningMethodECDSA
  24. SigningMethodES384 *SigningMethodECDSA
  25. SigningMethodES512 *SigningMethodECDSA
  26. )
  27. func init() {
  28. // ES256
  29. SigningMethodES256 = &SigningMethodECDSA{"ES256", crypto.SHA256, 32, 256}
  30. RegisterSigningMethod(SigningMethodES256.Alg(), func() SigningMethod {
  31. return SigningMethodES256
  32. })
  33. // ES384
  34. SigningMethodES384 = &SigningMethodECDSA{"ES384", crypto.SHA384, 48, 384}
  35. RegisterSigningMethod(SigningMethodES384.Alg(), func() SigningMethod {
  36. return SigningMethodES384
  37. })
  38. // ES512
  39. SigningMethodES512 = &SigningMethodECDSA{"ES512", crypto.SHA512, 66, 521}
  40. RegisterSigningMethod(SigningMethodES512.Alg(), func() SigningMethod {
  41. return SigningMethodES512
  42. })
  43. }
  44. func (m *SigningMethodECDSA) Alg() string {
  45. return m.Name
  46. }
  47. // Implements the Verify method from SigningMethod
  48. // For this verify method, key must be an ecdsa.PublicKey struct
  49. func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error {
  50. var err error
  51. // Decode the signature
  52. var sig []byte
  53. if sig, err = DecodeSegment(signature); err != nil {
  54. return err
  55. }
  56. // Get the key
  57. var ecdsaKey *ecdsa.PublicKey
  58. switch k := key.(type) {
  59. case *ecdsa.PublicKey:
  60. ecdsaKey = k
  61. default:
  62. return ErrInvalidKeyType
  63. }
  64. if len(sig) != 2*m.KeySize {
  65. return ErrECDSAVerification
  66. }
  67. r := big.NewInt(0).SetBytes(sig[:m.KeySize])
  68. s := big.NewInt(0).SetBytes(sig[m.KeySize:])
  69. // Create hasher
  70. if !m.Hash.Available() {
  71. return ErrHashUnavailable
  72. }
  73. hasher := m.Hash.New()
  74. hasher.Write([]byte(signingString))
  75. // Verify the signature
  76. if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true {
  77. return nil
  78. } else {
  79. return ErrECDSAVerification
  80. }
  81. }
  82. // Implements the Sign method from SigningMethod
  83. // For this signing method, key must be an ecdsa.PrivateKey struct
  84. func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string, error) {
  85. // Get the key
  86. var ecdsaKey *ecdsa.PrivateKey
  87. switch k := key.(type) {
  88. case *ecdsa.PrivateKey:
  89. ecdsaKey = k
  90. default:
  91. return "", ErrInvalidKeyType
  92. }
  93. // Create the hasher
  94. if !m.Hash.Available() {
  95. return "", ErrHashUnavailable
  96. }
  97. hasher := m.Hash.New()
  98. hasher.Write([]byte(signingString))
  99. // Sign the string and return r, s
  100. if r, s, err := ecdsa.Sign(rand.Reader, ecdsaKey, hasher.Sum(nil)); err == nil {
  101. curveBits := ecdsaKey.Curve.Params().BitSize
  102. if m.CurveBits != curveBits {
  103. return "", ErrInvalidKey
  104. }
  105. keyBytes := curveBits / 8
  106. if curveBits%8 > 0 {
  107. keyBytes += 1
  108. }
  109. // We serialize the outpus (r and s) into big-endian byte arrays and pad
  110. // them with zeros on the left to make sure the sizes work out. Both arrays
  111. // must be keyBytes long, and the output must be 2*keyBytes long.
  112. rBytes := r.Bytes()
  113. rBytesPadded := make([]byte, keyBytes)
  114. copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
  115. sBytes := s.Bytes()
  116. sBytesPadded := make([]byte, keyBytes)
  117. copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
  118. out := append(rBytesPadded, sBytesPadded...)
  119. return EncodeSegment(out), nil
  120. } else {
  121. return "", err
  122. }
  123. }
上海开阖软件有限公司 沪ICP备12045867号-1