本站源代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

413 lines
12KB

  1. package openidConnect
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "github.com/markbates/goth"
  9. "golang.org/x/oauth2"
  10. "io/ioutil"
  11. "net/http"
  12. "strings"
  13. "time"
  14. )
  15. const (
  16. // Standard Claims http://openid.net/specs/openid-connect-core-1_0.html#StandardClaims
  17. // fixed, cannot be changed
  18. subjectClaim = "sub"
  19. expiryClaim = "exp"
  20. audienceClaim = "aud"
  21. issuerClaim = "iss"
  22. PreferredUsernameClaim = "preferred_username"
  23. EmailClaim = "email"
  24. NameClaim = "name"
  25. NicknameClaim = "nickname"
  26. PictureClaim = "picture"
  27. GivenNameClaim = "given_name"
  28. FamilyNameClaim = "family_name"
  29. AddressClaim = "address"
  30. // Unused but available to set in Provider claims
  31. MiddleNameClaim = "middle_name"
  32. ProfileClaim = "profile"
  33. WebsiteClaim = "website"
  34. EmailVerifiedClaim = "email_verified"
  35. GenderClaim = "gender"
  36. BirthdateClaim = "birthdate"
  37. ZoneinfoClaim = "zoneinfo"
  38. LocaleClaim = "locale"
  39. PhoneNumberClaim = "phone_number"
  40. PhoneNumberVerifiedClaim = "phone_number_verified"
  41. UpdatedAtClaim = "updated_at"
  42. clockSkew = 10 * time.Second
  43. )
// Provider is the implementation of `goth.Provider` for accessing an OpenID
// Connect provider. Create it with New, which performs endpoint discovery and
// fills in the unexported configuration.
type Provider struct {
	// ClientKey is the OAuth2 client ID. ID Tokens must name it in their
	// "aud" claim.
	ClientKey string
	// Secret is the OAuth2 client secret.
	Secret string
	// CallbackURL is used as the OAuth2 redirect URL.
	CallbackURL string
	// HTTPClient is passed to goth.HTTPClientWithFallBack by Client and is
	// used for UserInfo requests. NOTE(review): New runs discovery before a
	// caller can set this field, so discovery always uses the fallback client.
	HTTPClient *http.Client
	// config is the OAuth2 configuration that New builds from the credentials
	// and the discovered endpoints.
	config *oauth2.Config
	// openIDConfig is the discovery document fetched by New.
	openIDConfig *OpenIDConfig
	// providerName is returned by Name. New sets it to "openid-connect" and
	// SetName can override it.
	providerName string

	// Each *Claims slice lists, in priority order, the claim names tried when
	// filling the matching goth.User field. The first claim whose value is a
	// non-empty string wins. New sets the defaults.
	UserIdClaims    []string
	NameClaims      []string
	NickNameClaims  []string
	EmailClaims     []string
	AvatarURLClaims []string
	FirstNameClaims []string
	LastNameClaims  []string
	// NOTE(review): the default "address" claim is usually a JSON object in
	// OIDC, and the string-only lookup would skip it. Verify against real
	// providers.
	LocationClaims []string
	// SkipUserInfoRequest disables the UserInfo endpoint request, so only the
	// ID Token claims are used.
	SkipUserInfoRequest bool
}
  57. AvatarURLClaims []string
  58. FirstNameClaims []string
  59. LastNameClaims []string
  60. LocationClaims []string
  61. SkipUserInfoRequest bool
  62. }
  63. type OpenIDConfig struct {
  64. AuthEndpoint string `json:"authorization_endpoint"`
  65. TokenEndpoint string `json:"token_endpoint"`
  66. UserInfoEndpoint string `json:"userinfo_endpoint"`
  67. Issuer string `json:"issuer"`
  68. }
  69. // New creates a new OpenID Connect provider, and sets up important connection details.
  70. // You should always call `openidConnect.New` to get a new Provider. Never try to create
  71. // one manually.
  72. // New returns an implementation of an OpenID Connect Authorization Code Flow
  73. // See http://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth
  74. // ID Token decryption is not (yet) supported
  75. // UserInfo decryption is not (yet) supported
  76. func New(clientKey, secret, callbackURL, openIDAutoDiscoveryURL string, scopes ...string) (*Provider, error) {
  77. p := &Provider{
  78. ClientKey: clientKey,
  79. Secret: secret,
  80. CallbackURL: callbackURL,
  81. UserIdClaims: []string{subjectClaim},
  82. NameClaims: []string{NameClaim},
  83. NickNameClaims: []string{NicknameClaim, PreferredUsernameClaim},
  84. EmailClaims: []string{EmailClaim},
  85. AvatarURLClaims: []string{PictureClaim},
  86. FirstNameClaims: []string{GivenNameClaim},
  87. LastNameClaims: []string{FamilyNameClaim},
  88. LocationClaims: []string{AddressClaim},
  89. providerName: "openid-connect",
  90. }
  91. openIDConfig, err := getOpenIDConfig(p, openIDAutoDiscoveryURL)
  92. if err != nil {
  93. return nil, err
  94. }
  95. p.openIDConfig = openIDConfig
  96. p.config = newConfig(p, scopes, openIDConfig)
  97. return p, nil
  98. }
  99. // Name is the name used to retrieve this provider later.
  100. func (p *Provider) Name() string {
  101. return p.providerName
  102. }
  103. // SetName is to update the name of the provider (needed in case of multiple providers of 1 type)
  104. func (p *Provider) SetName(name string) {
  105. p.providerName = name
  106. }
  107. func (p *Provider) Client() *http.Client {
  108. return goth.HTTPClientWithFallBack(p.HTTPClient)
  109. }
  110. // Debug is a no-op for the openidConnect package.
  111. func (p *Provider) Debug(debug bool) {}
  112. // BeginAuth asks the OpenID Connect provider for an authentication end-point.
  113. func (p *Provider) BeginAuth(state string) (goth.Session, error) {
  114. url := p.config.AuthCodeURL(state)
  115. session := &Session{
  116. AuthURL: url,
  117. }
  118. return session, nil
  119. }
  120. // FetchUser will use the the id_token and access requested information about the user.
  121. func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
  122. sess := session.(*Session)
  123. expiresAt := sess.ExpiresAt
  124. if sess.IDToken == "" {
  125. return goth.User{}, fmt.Errorf("%s cannot get user information without id_token", p.providerName)
  126. }
  127. // decode returned id token to get expiry
  128. claims, err := decodeJWT(sess.IDToken)
  129. if err != nil {
  130. return goth.User{}, fmt.Errorf("oauth2: error decoding JWT token: %v", err)
  131. }
  132. expiry, err := p.validateClaims(claims)
  133. if err != nil {
  134. return goth.User{}, fmt.Errorf("oauth2: error validating JWT token: %v", err)
  135. }
  136. if expiry.Before(expiresAt) {
  137. expiresAt = expiry
  138. }
  139. if err := p.getUserInfo(sess.AccessToken, claims); err != nil {
  140. return goth.User{}, err
  141. }
  142. user := goth.User{
  143. AccessToken: sess.AccessToken,
  144. Provider: p.Name(),
  145. RefreshToken: sess.RefreshToken,
  146. ExpiresAt: expiresAt,
  147. RawData: claims,
  148. }
  149. p.userFromClaims(claims, &user)
  150. return user, err
  151. }
  152. //RefreshTokenAvailable refresh token is provided by auth provider or not
  153. func (p *Provider) RefreshTokenAvailable() bool {
  154. return true
  155. }
  156. //RefreshToken get new access token based on the refresh token
  157. func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) {
  158. token := &oauth2.Token{RefreshToken: refreshToken}
  159. ts := p.config.TokenSource(oauth2.NoContext, token)
  160. newToken, err := ts.Token()
  161. if err != nil {
  162. return nil, err
  163. }
  164. return newToken, err
  165. }
  166. // validate according to standard, returns expiry
  167. // http://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
  168. func (p *Provider) validateClaims(claims map[string]interface{}) (time.Time, error) {
  169. audience := getClaimValue(claims, []string{audienceClaim})
  170. if audience != p.ClientKey {
  171. found := false
  172. audiences := getClaimValues(claims, []string{audienceClaim})
  173. for _, aud := range audiences {
  174. if aud == p.ClientKey {
  175. found = true
  176. break
  177. }
  178. }
  179. if !found {
  180. return time.Time{}, errors.New("audience in token does not match client key")
  181. }
  182. }
  183. issuer := getClaimValue(claims, []string{issuerClaim})
  184. if issuer != p.openIDConfig.Issuer {
  185. return time.Time{}, errors.New("issuer in token does not match issuer in OpenIDConfig discovery")
  186. }
  187. // expiry is required for JWT, not for UserInfoResponse
  188. // is actually a int64, so force it in to that type
  189. expiryClaim := int64(claims[expiryClaim].(float64))
  190. expiry := time.Unix(expiryClaim, 0)
  191. if expiry.Add(clockSkew).Before(time.Now()) {
  192. return time.Time{}, errors.New("user info JWT token is expired")
  193. }
  194. return expiry, nil
  195. }
  196. func (p *Provider) userFromClaims(claims map[string]interface{}, user *goth.User) {
  197. // required
  198. user.UserID = getClaimValue(claims, p.UserIdClaims)
  199. user.Name = getClaimValue(claims, p.NameClaims)
  200. user.NickName = getClaimValue(claims, p.NickNameClaims)
  201. user.Email = getClaimValue(claims, p.EmailClaims)
  202. user.AvatarURL = getClaimValue(claims, p.AvatarURLClaims)
  203. user.FirstName = getClaimValue(claims, p.FirstNameClaims)
  204. user.LastName = getClaimValue(claims, p.LastNameClaims)
  205. user.Location = getClaimValue(claims, p.LocationClaims)
  206. }
  207. func (p *Provider) getUserInfo(accessToken string, claims map[string]interface{}) error {
  208. // skip if there is no UserInfoEndpoint or is explicitly disabled
  209. if p.openIDConfig.UserInfoEndpoint == "" || p.SkipUserInfoRequest {
  210. return nil
  211. }
  212. userInfoClaims, err := p.fetchUserInfo(p.openIDConfig.UserInfoEndpoint, accessToken)
  213. if err != nil {
  214. return err
  215. }
  216. // The sub (subject) Claim MUST always be returned in the UserInfo Response.
  217. // http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
  218. userInfoSubject := getClaimValue(userInfoClaims, []string{subjectClaim})
  219. if userInfoSubject == "" {
  220. return fmt.Errorf("userinfo response did not contain a 'sub' claim: %#v", userInfoClaims)
  221. }
  222. // The sub Claim in the UserInfo Response MUST be verified to exactly match the sub Claim in the ID Token;
  223. // if they do not match, the UserInfo Response values MUST NOT be used.
  224. // http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
  225. subject := getClaimValue(claims, []string{subjectClaim})
  226. if userInfoSubject != subject {
  227. return fmt.Errorf("userinfo 'sub' claim (%s) did not match id_token 'sub' claim (%s)", userInfoSubject, subject)
  228. }
  229. // Merge in userinfo claims in case id_token claims contained some that userinfo did not
  230. for k, v := range userInfoClaims {
  231. claims[k] = v
  232. }
  233. return nil
  234. }
  235. // fetch and decode JSON from the given UserInfo URL
  236. func (p *Provider) fetchUserInfo(url, accessToken string) (map[string]interface{}, error) {
  237. req, _ := http.NewRequest("GET", url, nil)
  238. req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
  239. resp, err := p.Client().Do(req)
  240. if err != nil {
  241. return nil, err
  242. }
  243. defer resp.Body.Close()
  244. if resp.StatusCode != http.StatusOK {
  245. return nil, fmt.Errorf("Non-200 response from UserInfo: %d, WWW-Authenticate=%s", resp.StatusCode, resp.Header.Get("WWW-Authenticate"))
  246. }
  247. // The UserInfo Claims MUST be returned as the members of a JSON object
  248. // http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
  249. data, err := ioutil.ReadAll(resp.Body)
  250. if err != nil {
  251. return nil, err
  252. }
  253. return unMarshal(data)
  254. }
  255. func getOpenIDConfig(p *Provider, openIDAutoDiscoveryURL string) (*OpenIDConfig, error) {
  256. res, err := p.Client().Get(openIDAutoDiscoveryURL)
  257. if err != nil {
  258. return nil, err
  259. }
  260. defer res.Body.Close()
  261. body, err := ioutil.ReadAll(res.Body)
  262. if err != nil {
  263. return nil, err
  264. }
  265. openIDConfig := &OpenIDConfig{}
  266. err = json.Unmarshal(body, openIDConfig)
  267. if err != nil {
  268. return nil, err
  269. }
  270. return openIDConfig, nil
  271. }
  272. func newConfig(provider *Provider, scopes []string, openIDConfig *OpenIDConfig) *oauth2.Config {
  273. c := &oauth2.Config{
  274. ClientID: provider.ClientKey,
  275. ClientSecret: provider.Secret,
  276. RedirectURL: provider.CallbackURL,
  277. Endpoint: oauth2.Endpoint{
  278. AuthURL: openIDConfig.AuthEndpoint,
  279. TokenURL: openIDConfig.TokenEndpoint,
  280. },
  281. Scopes: []string{},
  282. }
  283. if len(scopes) > 0 {
  284. foundOpenIDScope := false
  285. for _, scope := range scopes {
  286. if scope == "openid" {
  287. foundOpenIDScope = true
  288. }
  289. c.Scopes = append(c.Scopes, scope)
  290. }
  291. if !foundOpenIDScope {
  292. c.Scopes = append(c.Scopes, "openid")
  293. }
  294. } else {
  295. c.Scopes = []string{"openid"}
  296. }
  297. return c
  298. }
  299. func getClaimValue(data map[string]interface{}, claims []string) string {
  300. for _, claim := range claims {
  301. if value, ok := data[claim]; ok {
  302. if stringValue, ok := value.(string); ok && len(stringValue) > 0 {
  303. return stringValue
  304. }
  305. }
  306. }
  307. return ""
  308. }
  309. func getClaimValues(data map[string]interface{}, claims []string) []string {
  310. var result []string
  311. for _, claim := range claims {
  312. if value, ok := data[claim]; ok {
  313. if stringValues, ok := value.([]interface{}); ok {
  314. for _, stringValue := range stringValues {
  315. if s, ok := stringValue.(string); ok && len(s) > 0 {
  316. result = append(result, s)
  317. }
  318. }
  319. }
  320. }
  321. }
  322. return result
  323. }
  324. // decodeJWT decodes a JSON Web Token into a simple map
  325. // http://openid.net/specs/draft-jones-json-web-token-07.html
  326. func decodeJWT(jwt string) (map[string]interface{}, error) {
  327. jwtParts := strings.Split(jwt, ".")
  328. if len(jwtParts) != 3 {
  329. return nil, errors.New("jws: invalid token received, not all parts available")
  330. }
  331. // Re-pad, if needed
  332. encodedPayload := jwtParts[1]
  333. if l := len(encodedPayload) % 4; l != 0 {
  334. encodedPayload += strings.Repeat("=", 4-l)
  335. }
  336. decodedPayload, err := base64.StdEncoding.DecodeString(encodedPayload)
  337. if err != nil {
  338. return nil, err
  339. }
  340. return unMarshal(decodedPayload)
  341. }
  342. func unMarshal(payload []byte) (map[string]interface{}, error) {
  343. data := make(map[string]interface{})
  344. return data, json.NewDecoder(bytes.NewBuffer(payload)).Decode(&data)
  345. }
上海开阖软件有限公司 沪ICP备12045867号-1