本站源代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

219 lines
5.7KB

  1. // Copyright 2012 The Gorilla Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package sessions
  5. import (
  6. "context"
  7. "encoding/gob"
  8. "fmt"
  9. "net/http"
  10. "time"
  11. )
  // flashesKey is the Session.Values key that Flashes and AddFlash use
// when the caller does not supply a key of their own.
const flashesKey = "_flash"
  14. // Session --------------------------------------------------------------------
  15. // NewSession is called by session stores to create a new session instance.
  16. func NewSession(store Store, name string) *Session {
  17. return &Session{
  18. Values: make(map[interface{}]interface{}),
  19. store: store,
  20. name: name,
  21. Options: new(Options),
  22. }
  23. }
  24. // Session stores the values and optional configuration for a session.
  25. type Session struct {
  26. // The ID of the session, generated by stores. It should not be used for
  27. // user data.
  28. ID string
  29. // Values contains the user-data for the session.
  30. Values map[interface{}]interface{}
  31. Options *Options
  32. IsNew bool
  33. store Store
  34. name string
  35. }
  36. // Flashes returns a slice of flash messages from the session.
  37. //
  38. // A single variadic argument is accepted, and it is optional: it defines
  39. // the flash key. If not defined "_flash" is used by default.
  40. func (s *Session) Flashes(vars ...string) []interface{} {
  41. var flashes []interface{}
  42. key := flashesKey
  43. if len(vars) > 0 {
  44. key = vars[0]
  45. }
  46. if v, ok := s.Values[key]; ok {
  47. // Drop the flashes and return it.
  48. delete(s.Values, key)
  49. flashes = v.([]interface{})
  50. }
  51. return flashes
  52. }
  53. // AddFlash adds a flash message to the session.
  54. //
  55. // A single variadic argument is accepted, and it is optional: it defines
  56. // the flash key. If not defined "_flash" is used by default.
  57. func (s *Session) AddFlash(value interface{}, vars ...string) {
  58. key := flashesKey
  59. if len(vars) > 0 {
  60. key = vars[0]
  61. }
  62. var flashes []interface{}
  63. if v, ok := s.Values[key]; ok {
  64. flashes = v.([]interface{})
  65. }
  66. s.Values[key] = append(flashes, value)
  67. }
  68. // Save is a convenience method to save this session. It is the same as calling
  69. // store.Save(request, response, session). You should call Save before writing to
  70. // the response or returning from the handler.
  71. func (s *Session) Save(r *http.Request, w http.ResponseWriter) error {
  72. return s.store.Save(r, w, s)
  73. }
  74. // Name returns the name used to register the session.
  75. func (s *Session) Name() string {
  76. return s.name
  77. }
  78. // Store returns the session store used to register the session.
  79. func (s *Session) Store() Store {
  80. return s.store
  81. }
  82. // Registry -------------------------------------------------------------------
  83. // sessionInfo stores a session tracked by the registry.
  84. type sessionInfo struct {
  85. s *Session
  86. e error
  87. }
  // contextKey is this package's private type for context keys. Because the
// type is unexported, its keys cannot collide with keys from other packages.
type contextKey int

// registryKey is the context key under which GetRegistry stores a request's
// *Registry.
const registryKey = contextKey(0)
  92. // GetRegistry returns a registry instance for the current request.
  93. func GetRegistry(r *http.Request) *Registry {
  94. var ctx = r.Context()
  95. registry := ctx.Value(registryKey)
  96. if registry != nil {
  97. return registry.(*Registry)
  98. }
  99. newRegistry := &Registry{
  100. request: r,
  101. sessions: make(map[string]sessionInfo),
  102. }
  103. *r = *r.WithContext(context.WithValue(ctx, registryKey, newRegistry))
  104. return newRegistry
  105. }
  106. // Registry stores sessions used during a request.
  107. type Registry struct {
  108. request *http.Request
  109. sessions map[string]sessionInfo
  110. }
  111. // Get registers and returns a session for the given name and session store.
  112. //
  113. // It returns a new session if there are no sessions registered for the name.
  114. func (s *Registry) Get(store Store, name string) (session *Session, err error) {
  115. if !isCookieNameValid(name) {
  116. return nil, fmt.Errorf("sessions: invalid character in cookie name: %s", name)
  117. }
  118. if info, ok := s.sessions[name]; ok {
  119. session, err = info.s, info.e
  120. } else {
  121. session, err = store.New(s.request, name)
  122. session.name = name
  123. s.sessions[name] = sessionInfo{s: session, e: err}
  124. }
  125. session.store = store
  126. return
  127. }
  128. // Save saves all sessions registered for the current request.
  129. func (s *Registry) Save(w http.ResponseWriter) error {
  130. var errMulti MultiError
  131. for name, info := range s.sessions {
  132. session := info.s
  133. if session.store == nil {
  134. errMulti = append(errMulti, fmt.Errorf(
  135. "sessions: missing store for session %q", name))
  136. } else if err := session.store.Save(s.request, w, session); err != nil {
  137. errMulti = append(errMulti, fmt.Errorf(
  138. "sessions: error saving session %q -- %v", name, err))
  139. }
  140. }
  141. if errMulti != nil {
  142. return errMulti
  143. }
  144. return nil
  145. }
  // Helpers --------------------------------------------------------------------

// init registers []interface{} with encoding/gob. AddFlash stores flash
// messages in Session.Values as []interface{}. gob can only encode a concrete
// value held in an interface once its type has been registered.
func init() {
	gob.Register(make([]interface{}, 0))
}
  150. // Save saves all sessions used during the current request.
  151. func Save(r *http.Request, w http.ResponseWriter) error {
  152. return GetRegistry(r).Save(w)
  153. }
  154. // NewCookie returns an http.Cookie with the options set. It also sets
  155. // the Expires field calculated based on the MaxAge value, for Internet
  156. // Explorer compatibility.
  157. func NewCookie(name, value string, options *Options) *http.Cookie {
  158. cookie := newCookieFromOptions(name, value, options)
  159. if options.MaxAge > 0 {
  160. d := time.Duration(options.MaxAge) * time.Second
  161. cookie.Expires = time.Now().Add(d)
  162. } else if options.MaxAge < 0 {
  163. // Set it to the past to expire now.
  164. cookie.Expires = time.Unix(1, 0)
  165. }
  166. return cookie
  167. }
  // Error ----------------------------------------------------------------------

// MultiError collects several errors into a single error value.
//
// Borrowed from the App Engine SDK.
type MultiError []error

// Error reports the message of the first non-nil error, followed by a count
// of the other non-nil errors. nil entries are ignored. An empty or all-nil
// MultiError reports "(0 errors)".
func (m MultiError) Error() string {
	var first error
	count := 0
	for _, err := range m {
		if err == nil {
			continue
		}
		if first == nil {
			first = err
		}
		count++
	}
	switch {
	case count == 0:
		return "(0 errors)"
	case count == 1:
		return first.Error()
	case count == 2:
		return first.Error() + " (and 1 other error)"
	default:
		return fmt.Sprintf("%s (and %d other errors)", first.Error(), count-1)
	}
}
上海开阖软件有限公司 沪ICP备12045867号-1