本站源代码
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

341 lines
9.3KB

  1. /*
Package gothic wraps common behaviour when using Goth. This makes it quick and easy to get up
and running with Goth. Of course, if you want complete control over how things flow with regard
to the authentication process, feel free to use Goth directly.
  5. See https://github.com/markbates/goth/blob/master/examples/main.go to see this in action.
  6. */
  7. package gothic
  8. import (
  9. "bytes"
  10. "compress/gzip"
  11. "crypto/rand"
  12. "encoding/base64"
  13. "errors"
  14. "fmt"
  15. "io"
  16. "io/ioutil"
  17. "net/http"
  18. "net/url"
  19. "os"
  20. "strings"
  21. "github.com/gorilla/mux"
  22. "github.com/gorilla/sessions"
  23. "github.com/markbates/goth"
  24. )
  25. // SessionName is the key used to access the session store.
  26. const SessionName = "_gothic_session"
  27. // Store can/should be set by applications using gothic. The default is a cookie store.
  28. var Store sessions.Store
  29. var defaultStore sessions.Store
  30. var keySet = false
  31. func init() {
  32. key := []byte(os.Getenv("SESSION_SECRET"))
  33. keySet = len(key) != 0
  34. cookieStore := sessions.NewCookieStore([]byte(key))
  35. cookieStore.Options.HttpOnly = true
  36. Store = cookieStore
  37. defaultStore = Store
  38. }
  39. /*
  40. BeginAuthHandler is a convenience handler for starting the authentication process.
  41. It expects to be able to get the name of the provider from the query parameters
  42. as either "provider" or ":provider".
  43. BeginAuthHandler will redirect the user to the appropriate authentication end-point
  44. for the requested provider.
  45. See https://github.com/markbates/goth/examples/main.go to see this in action.
  46. */
  47. func BeginAuthHandler(res http.ResponseWriter, req *http.Request) {
  48. url, err := GetAuthURL(res, req)
  49. if err != nil {
  50. res.WriteHeader(http.StatusBadRequest)
  51. fmt.Fprintln(res, err)
  52. return
  53. }
  54. http.Redirect(res, req, url, http.StatusTemporaryRedirect)
  55. }
  56. // SetState sets the state string associated with the given request.
  57. // If no state string is associated with the request, one will be generated.
  58. // This state is sent to the provider and can be retrieved during the
  59. // callback.
  60. var SetState = func(req *http.Request) string {
  61. state := req.URL.Query().Get("state")
  62. if len(state) > 0 {
  63. return state
  64. }
  65. // If a state query param is not passed in, generate a random
  66. // base64-encoded nonce so that the state on the auth URL
  67. // is unguessable, preventing CSRF attacks, as described in
  68. //
  69. // https://auth0.com/docs/protocols/oauth2/oauth-state#keep-reading
  70. nonceBytes := make([]byte, 64)
  71. _, err := io.ReadFull(rand.Reader, nonceBytes)
  72. if err != nil {
  73. panic("gothic: source of randomness unavailable: " + err.Error())
  74. }
  75. return base64.URLEncoding.EncodeToString(nonceBytes)
  76. }
  77. // GetState gets the state returned by the provider during the callback.
  78. // This is used to prevent CSRF attacks, see
  79. // http://tools.ietf.org/html/rfc6749#section-10.12
  80. var GetState = func(req *http.Request) string {
  81. return req.URL.Query().Get("state")
  82. }
  83. /*
  84. GetAuthURL starts the authentication process with the requested provided.
  85. It will return a URL that should be used to send users to.
  86. It expects to be able to get the name of the provider from the query parameters
  87. as either "provider" or ":provider".
  88. I would recommend using the BeginAuthHandler instead of doing all of these steps
  89. yourself, but that's entirely up to you.
  90. */
  91. func GetAuthURL(res http.ResponseWriter, req *http.Request) (string, error) {
  92. if !keySet && defaultStore == Store {
  93. fmt.Println("goth/gothic: no SESSION_SECRET environment variable is set. The default cookie store is not available and any calls will fail. Ignore this warning if you are using a different store.")
  94. }
  95. providerName, err := GetProviderName(req)
  96. if err != nil {
  97. return "", err
  98. }
  99. provider, err := goth.GetProvider(providerName)
  100. if err != nil {
  101. return "", err
  102. }
  103. sess, err := provider.BeginAuth(SetState(req))
  104. if err != nil {
  105. return "", err
  106. }
  107. url, err := sess.GetAuthURL()
  108. if err != nil {
  109. return "", err
  110. }
  111. err = StoreInSession(providerName, sess.Marshal(), req, res)
  112. if err != nil {
  113. return "", err
  114. }
  115. return url, err
  116. }
  117. /*
  118. CompleteUserAuth does what it says on the tin. It completes the authentication
  119. process and fetches all of the basic information about the user from the provider.
  120. It expects to be able to get the name of the provider from the query parameters
  121. as either "provider" or ":provider".
  122. See https://github.com/markbates/goth/examples/main.go to see this in action.
  123. */
  124. var CompleteUserAuth = func(res http.ResponseWriter, req *http.Request) (goth.User, error) {
  125. defer Logout(res, req)
  126. if !keySet && defaultStore == Store {
  127. fmt.Println("goth/gothic: no SESSION_SECRET environment variable is set. The default cookie store is not available and any calls will fail. Ignore this warning if you are using a different store.")
  128. }
  129. providerName, err := GetProviderName(req)
  130. if err != nil {
  131. return goth.User{}, err
  132. }
  133. provider, err := goth.GetProvider(providerName)
  134. if err != nil {
  135. return goth.User{}, err
  136. }
  137. value, err := GetFromSession(providerName, req)
  138. if err != nil {
  139. return goth.User{}, err
  140. }
  141. sess, err := provider.UnmarshalSession(value)
  142. if err != nil {
  143. return goth.User{}, err
  144. }
  145. err = validateState(req, sess)
  146. if err != nil {
  147. return goth.User{}, err
  148. }
  149. user, err := provider.FetchUser(sess)
  150. if err == nil {
  151. // user can be found with existing session data
  152. return user, err
  153. }
  154. // get new token and retry fetch
  155. _, err = sess.Authorize(provider, req.URL.Query())
  156. if err != nil {
  157. return goth.User{}, err
  158. }
  159. err = StoreInSession(providerName, sess.Marshal(), req, res)
  160. if err != nil {
  161. return goth.User{}, err
  162. }
  163. gu, err := provider.FetchUser(sess)
  164. return gu, err
  165. }
  166. // validateState ensures that the state token param from the original
  167. // AuthURL matches the one included in the current (callback) request.
  168. func validateState(req *http.Request, sess goth.Session) error {
  169. rawAuthURL, err := sess.GetAuthURL()
  170. if err != nil {
  171. return err
  172. }
  173. authURL, err := url.Parse(rawAuthURL)
  174. if err != nil {
  175. return err
  176. }
  177. originalState := authURL.Query().Get("state")
  178. if originalState != "" && (originalState != req.URL.Query().Get("state")) {
  179. return errors.New("state token mismatch")
  180. }
  181. return nil
  182. }
  183. // Logout invalidates a user session.
  184. func Logout(res http.ResponseWriter, req *http.Request) error {
  185. session, err := Store.Get(req, SessionName)
  186. if err != nil {
  187. return err
  188. }
  189. session.Options.MaxAge = -1
  190. session.Values = make(map[interface{}]interface{})
  191. err = session.Save(req, res)
  192. if err != nil {
  193. return errors.New("Could not delete user session ")
  194. }
  195. return nil
  196. }
  197. // GetProviderName is a function used to get the name of a provider
  198. // for a given request. By default, this provider is fetched from
  199. // the URL query string. If you provide it in a different way,
  200. // assign your own function to this variable that returns the provider
  201. // name for your request.
  202. var GetProviderName = getProviderName
  203. func getProviderName(req *http.Request) (string, error) {
  204. // try to get it from the url param "provider"
  205. if p := req.URL.Query().Get("provider"); p != "" {
  206. return p, nil
  207. }
  208. // try to get it from the url param ":provider"
  209. if p := req.URL.Query().Get(":provider"); p != "" {
  210. return p, nil
  211. }
  212. // try to get it from the context's value of "provider" key
  213. if p, ok := mux.Vars(req)["provider"]; ok {
  214. return p, nil
  215. }
  216. // try to get it from the go-context's value of "provider" key
  217. if p, ok := req.Context().Value("provider").(string); ok {
  218. return p, nil
  219. }
  220. // As a fallback, loop over the used providers, if we already have a valid session for any provider (ie. user has already begun authentication with a provider), then return that provider name
  221. providers := goth.GetProviders()
  222. session, _ := Store.Get(req, SessionName)
  223. for _, provider := range providers {
  224. p := provider.Name()
  225. value := session.Values[p]
  226. if _, ok := value.(string); ok {
  227. return p, nil
  228. }
  229. }
  230. // if not found then return an empty string with the corresponding error
  231. return "", errors.New("you must select a provider")
  232. }
  233. // StoreInSession stores a specified key/value pair in the session.
  234. func StoreInSession(key string, value string, req *http.Request, res http.ResponseWriter) error {
  235. session, _ := Store.New(req, SessionName)
  236. if err := updateSessionValue(session, key, value); err != nil {
  237. return err
  238. }
  239. return session.Save(req, res)
  240. }
  241. // GetFromSession retrieves a previously-stored value from the session.
  242. // If no value has previously been stored at the specified key, it will return an error.
  243. func GetFromSession(key string, req *http.Request) (string, error) {
  244. session, _ := Store.Get(req, SessionName)
  245. value, err := getSessionValue(session, key)
  246. if err != nil {
  247. return "", errors.New("could not find a matching session for this request")
  248. }
  249. return value, nil
  250. }
  251. func getSessionValue(session *sessions.Session, key string) (string, error) {
  252. value := session.Values[key]
  253. if value == nil {
  254. return "", fmt.Errorf("could not find a matching session for this request")
  255. }
  256. rdata := strings.NewReader(value.(string))
  257. r, err := gzip.NewReader(rdata)
  258. if err != nil {
  259. return "", err
  260. }
  261. s, err := ioutil.ReadAll(r)
  262. if err != nil {
  263. return "", err
  264. }
  265. return string(s), nil
  266. }
  267. func updateSessionValue(session *sessions.Session, key, value string) error {
  268. var b bytes.Buffer
  269. gz := gzip.NewWriter(&b)
  270. if _, err := gz.Write([]byte(value)); err != nil {
  271. return err
  272. }
  273. if err := gz.Flush(); err != nil {
  274. return err
  275. }
  276. if err := gz.Close(); err != nil {
  277. return err
  278. }
  279. session.Values[key] = b.String()
  280. return nil
  281. }
上海开阖软件有限公司 沪ICP备12045867号-1