本站源代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

591 lines
18KB

  1. // Copyright 2015 go-swagger maintainers
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package middleware
  15. import (
  16. stdContext "context"
  17. "fmt"
  18. "net/http"
  19. "strings"
  20. "sync"
  21. "github.com/go-openapi/runtime/security"
  22. "github.com/go-openapi/analysis"
  23. "github.com/go-openapi/errors"
  24. "github.com/go-openapi/loads"
  25. "github.com/go-openapi/runtime"
  26. "github.com/go-openapi/runtime/logger"
  27. "github.com/go-openapi/runtime/middleware/untyped"
  28. "github.com/go-openapi/spec"
  29. "github.com/go-openapi/strfmt"
  30. )
// Debug when true turns on verbose logging.
// It is initialised from logger.DebugEnabled() and is consulted by debugLog
// before every debug message.
var Debug = logger.DebugEnabled()

// Logger is the sink debugLog writes to when Debug is true.
// It defaults to logger.StandardLogger{}.
var Logger logger.Logger = logger.StandardLogger{}
  34. func debugLog(format string, args ...interface{}) {
  35. if Debug {
  36. Logger.Printf(format, args...)
  37. }
  38. }
// A Builder can create middlewares: it takes an http.Handler and returns
// the http.Handler that should be used in its place.
type Builder func(http.Handler) http.Handler
  41. // PassthroughBuilder returns the handler, aka the builder identity function
  42. func PassthroughBuilder(handler http.Handler) http.Handler { return handler }
// RequestBinder is an interface for types to implement
// when they want to be able to bind from a request.
// BindRequest receives the incoming request together with the matched route
// and returns an error when binding fails.
type RequestBinder interface {
	BindRequest(*http.Request, *MatchedRoute) error
}
// Responder is an interface for types to implement
// when they want to be considered for writing HTTP responses.
// Context.Respond hands a Responder the response writer and the producer it
// selected for the negotiated format.
type Responder interface {
	WriteResponse(http.ResponseWriter, runtime.Producer)
}
// ResponderFunc wraps a func as a Responder interface, so a plain function
// with the matching signature can be used wherever a Responder is expected.
type ResponderFunc func(http.ResponseWriter, runtime.Producer)
// WriteResponse writes to the response by calling the wrapped function with
// the given response writer and producer.
func (fn ResponderFunc) WriteResponse(rw http.ResponseWriter, pr runtime.Producer) {
	fn(rw, pr)
}
// Context is a type safe wrapper around an untyped request context
// used throughout to store request context with the standard context attached
// to the http.Request
type Context struct {
	// spec is the loaded API document; it may be nil (the constructors guard for that).
	spec *loads.Document
	// analyzer is built from spec by the constructors and stays nil when spec is nil.
	analyzer *analysis.Spec
	// api supplies handlers, producers, consumers and error serving.
	api RoutableAPI
	// router is used to look up routes and allowed methods for a request.
	router Router
}
// routableUntypedAPI adapts an untyped.API so that its operation handlers can
// be looked up by HTTP method and path.
type routableUntypedAPI struct {
	// api is the wrapped untyped API that most methods delegate to.
	api *untyped.API
	// hlock is held while reading handlers in HandlerFor.
	hlock *sync.Mutex
	// handlers maps upper-cased HTTP method -> path -> handler.
	handlers map[string]map[string]http.Handler
	// defaultConsumes is copied from api.DefaultConsumes at construction time.
	defaultConsumes string
	// defaultProduces is copied from api.DefaultProduces at construction time.
	defaultProduces string
}
  75. func newRoutableUntypedAPI(spec *loads.Document, api *untyped.API, context *Context) *routableUntypedAPI {
  76. var handlers map[string]map[string]http.Handler
  77. if spec == nil || api == nil {
  78. return nil
  79. }
  80. analyzer := analysis.New(spec.Spec())
  81. for method, hls := range analyzer.Operations() {
  82. um := strings.ToUpper(method)
  83. for path, op := range hls {
  84. schemes := analyzer.SecurityRequirementsFor(op)
  85. if oh, ok := api.OperationHandlerFor(method, path); ok {
  86. if handlers == nil {
  87. handlers = make(map[string]map[string]http.Handler)
  88. }
  89. if b, ok := handlers[um]; !ok || b == nil {
  90. handlers[um] = make(map[string]http.Handler)
  91. }
  92. var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  93. // lookup route info in the context
  94. route, rCtx, _ := context.RouteInfo(r)
  95. if rCtx != nil {
  96. r = rCtx
  97. }
  98. // bind and validate the request using reflection
  99. var bound interface{}
  100. var validation error
  101. bound, r, validation = context.BindAndValidate(r, route)
  102. if validation != nil {
  103. context.Respond(w, r, route.Produces, route, validation)
  104. return
  105. }
  106. // actually handle the request
  107. result, err := oh.Handle(bound)
  108. if err != nil {
  109. // respond with failure
  110. context.Respond(w, r, route.Produces, route, err)
  111. return
  112. }
  113. // respond with success
  114. context.Respond(w, r, route.Produces, route, result)
  115. })
  116. if len(schemes) > 0 {
  117. handler = newSecureAPI(context, handler)
  118. }
  119. handlers[um][path] = handler
  120. }
  121. }
  122. }
  123. return &routableUntypedAPI{
  124. api: api,
  125. hlock: new(sync.Mutex),
  126. handlers: handlers,
  127. defaultProduces: api.DefaultProduces,
  128. defaultConsumes: api.DefaultConsumes,
  129. }
  130. }
  131. func (r *routableUntypedAPI) HandlerFor(method, path string) (http.Handler, bool) {
  132. r.hlock.Lock()
  133. paths, ok := r.handlers[strings.ToUpper(method)]
  134. if !ok {
  135. r.hlock.Unlock()
  136. return nil, false
  137. }
  138. handler, ok := paths[path]
  139. r.hlock.Unlock()
  140. return handler, ok
  141. }
// ServeErrorFor returns the error-serving function of the wrapped API.
// The operationID argument is not used: the same api.ServeError is returned
// for every operation.
func (r *routableUntypedAPI) ServeErrorFor(operationID string) func(http.ResponseWriter, *http.Request, error) {
	return r.api.ServeError
}
// ConsumersFor delegates to the wrapped API to obtain the consumers for the
// given media types.
func (r *routableUntypedAPI) ConsumersFor(mediaTypes []string) map[string]runtime.Consumer {
	return r.api.ConsumersFor(mediaTypes)
}
// ProducersFor delegates to the wrapped API to obtain the producers for the
// given media types.
func (r *routableUntypedAPI) ProducersFor(mediaTypes []string) map[string]runtime.Producer {
	return r.api.ProducersFor(mediaTypes)
}
// AuthenticatorsFor delegates to the wrapped API to obtain the authenticators
// for the given security schemes.
func (r *routableUntypedAPI) AuthenticatorsFor(schemes map[string]spec.SecurityScheme) map[string]runtime.Authenticator {
	return r.api.AuthenticatorsFor(schemes)
}
// Authorizer returns the authorizer of the wrapped API.
func (r *routableUntypedAPI) Authorizer() runtime.Authorizer {
	return r.api.Authorizer()
}
// Formats returns the string format registry of the wrapped API.
func (r *routableUntypedAPI) Formats() strfmt.Registry {
	return r.api.Formats()
}
// DefaultProduces returns the default produces value that was captured from
// the wrapped API when this adapter was constructed.
func (r *routableUntypedAPI) DefaultProduces() string {
	return r.defaultProduces
}
// DefaultConsumes returns the default consumes value that was captured from
// the wrapped API when this adapter was constructed.
func (r *routableUntypedAPI) DefaultConsumes() string {
	return r.defaultConsumes
}
  166. // NewRoutableContext creates a new context for a routable API
  167. func NewRoutableContext(spec *loads.Document, routableAPI RoutableAPI, routes Router) *Context {
  168. var an *analysis.Spec
  169. if spec != nil {
  170. an = analysis.New(spec.Spec())
  171. }
  172. ctx := &Context{spec: spec, api: routableAPI, analyzer: an, router: routes}
  173. return ctx
  174. }
  175. // NewContext creates a new context wrapper
  176. func NewContext(spec *loads.Document, api *untyped.API, routes Router) *Context {
  177. var an *analysis.Spec
  178. if spec != nil {
  179. an = analysis.New(spec.Spec())
  180. }
  181. ctx := &Context{spec: spec, analyzer: an}
  182. ctx.api = newRoutableUntypedAPI(spec, api, ctx)
  183. ctx.router = routes
  184. return ctx
  185. }
// Serve serves the specified spec with the specified api registrations as a http.Handler.
// It is ServeWithBuilder with the PassthroughBuilder, i.e. the handler is not decorated.
func Serve(spec *loads.Document, api *untyped.API) http.Handler {
	return ServeWithBuilder(spec, api, PassthroughBuilder)
}
  190. // ServeWithBuilder serves the specified spec with the specified api registrations as a http.Handler that is decorated
  191. // by the Builder
  192. func ServeWithBuilder(spec *loads.Document, api *untyped.API, builder Builder) http.Handler {
  193. context := NewContext(spec, api, nil)
  194. return context.APIHandler(builder)
  195. }
// contextKey is the private key type for values this package stores in a
// request's standard context.
type contextKey int8

const (
	// the zero value is skipped so that no key equals the zero contextKey
	_ contextKey = iota
	// ctxContentType keys the parsed *contentTypeValue (see Context.ContentType)
	ctxContentType
	// ctxResponseFormat keys the negotiated response format string (see Context.ResponseFormat)
	ctxResponseFormat
	// ctxMatchedRoute keys the *MatchedRoute (see Context.RouteInfo)
	ctxMatchedRoute
	// ctxBoundParams keys the cached validation result (see Context.BindAndValidate)
	ctxBoundParams
	// ctxSecurityPrincipal keys the authenticated principal (see Context.Authorize)
	ctxSecurityPrincipal
	// ctxSecurityScopes keys the scopes stored alongside the principal (see Context.Authorize)
	ctxSecurityScopes
)
  206. // MatchedRouteFrom request context value.
  207. func MatchedRouteFrom(req *http.Request) *MatchedRoute {
  208. mr := req.Context().Value(ctxMatchedRoute)
  209. if mr == nil {
  210. return nil
  211. }
  212. if res, ok := mr.(*MatchedRoute); ok {
  213. return res
  214. }
  215. return nil
  216. }
// SecurityPrincipalFrom returns the security principal stored in the request
// context, or nil when no value is stored under that key.
func SecurityPrincipalFrom(req *http.Request) interface{} {
	return req.Context().Value(ctxSecurityPrincipal)
}
  221. // SecurityScopesFrom request context value.
  222. func SecurityScopesFrom(req *http.Request) []string {
  223. rs := req.Context().Value(ctxSecurityScopes)
  224. if res, ok := rs.([]string); ok {
  225. return res
  226. }
  227. return nil
  228. }
// contentTypeValue is the parsed content type that Context.ContentType caches
// in the request context.
type contentTypeValue struct {
	// MediaType is the media type returned by runtime.ContentType.
	MediaType string
	// Charset is the charset returned by runtime.ContentType.
	Charset string
}
// BasePath returns the base path for this API, as reported by the spec document.
func (c *Context) BasePath() string {
	return c.spec.BasePath()
}
// RequiredProduces returns the accepted content types for responses, as
// reported by the spec analyzer.
func (c *Context) RequiredProduces() []string {
	return c.analyzer.RequiredProduces()
}
  241. // BindValidRequest binds a params object to a request but only when the request is valid
  242. // if the request is not valid an error will be returned
  243. func (c *Context) BindValidRequest(request *http.Request, route *MatchedRoute, binder RequestBinder) error {
  244. var res []error
  245. requestContentType := "*/*"
  246. // check and validate content type, select consumer
  247. if runtime.HasBody(request) {
  248. ct, _, err := runtime.ContentType(request.Header)
  249. if err != nil {
  250. res = append(res, err)
  251. } else {
  252. if err := validateContentType(route.Consumes, ct); err != nil {
  253. res = append(res, err)
  254. }
  255. if len(res) == 0 {
  256. cons, ok := route.Consumers[ct]
  257. if !ok {
  258. res = append(res, errors.New(500, "no consumer registered for %s", ct))
  259. } else {
  260. route.Consumer = cons
  261. requestContentType = ct
  262. }
  263. }
  264. }
  265. }
  266. // check and validate the response format
  267. if len(res) == 0 && runtime.HasBody(request) {
  268. if str := NegotiateContentType(request, route.Produces, requestContentType); str == "" {
  269. res = append(res, errors.InvalidResponseFormat(request.Header.Get(runtime.HeaderAccept), route.Produces))
  270. }
  271. }
  272. // now bind the request with the provided binder
  273. // it's assumed the binder will also validate the request and return an error if the
  274. // request is invalid
  275. if binder != nil && len(res) == 0 {
  276. if err := binder.BindRequest(request, route); err != nil {
  277. return err
  278. }
  279. }
  280. if len(res) > 0 {
  281. return errors.CompositeValidationError(res...)
  282. }
  283. return nil
  284. }
  285. // ContentType gets the parsed value of a content type
  286. // Returns the media type, its charset and a shallow copy of the request
  287. // when its context doesn't contain the content type value, otherwise it returns
  288. // the same request
  289. // Returns the error that runtime.ContentType may retunrs.
  290. func (c *Context) ContentType(request *http.Request) (string, string, *http.Request, error) {
  291. var rCtx = request.Context()
  292. if v, ok := rCtx.Value(ctxContentType).(*contentTypeValue); ok {
  293. return v.MediaType, v.Charset, request, nil
  294. }
  295. mt, cs, err := runtime.ContentType(request.Header)
  296. if err != nil {
  297. return "", "", nil, err
  298. }
  299. rCtx = stdContext.WithValue(rCtx, ctxContentType, &contentTypeValue{mt, cs})
  300. return mt, cs, request.WithContext(rCtx), nil
  301. }
  302. // LookupRoute looks a route up and returns true when it is found
  303. func (c *Context) LookupRoute(request *http.Request) (*MatchedRoute, bool) {
  304. if route, ok := c.router.Lookup(request.Method, request.URL.EscapedPath()); ok {
  305. return route, ok
  306. }
  307. return nil, false
  308. }
  309. // RouteInfo tries to match a route for this request
  310. // Returns the matched route, a shallow copy of the request if its context
  311. // contains the matched router, otherwise the same request, and a bool to
  312. // indicate if it the request matches one of the routes, if it doesn't
  313. // then it returns false and nil for the other two return values
  314. func (c *Context) RouteInfo(request *http.Request) (*MatchedRoute, *http.Request, bool) {
  315. var rCtx = request.Context()
  316. if v, ok := rCtx.Value(ctxMatchedRoute).(*MatchedRoute); ok {
  317. return v, request, ok
  318. }
  319. if route, ok := c.LookupRoute(request); ok {
  320. rCtx = stdContext.WithValue(rCtx, ctxMatchedRoute, route)
  321. return route, request.WithContext(rCtx), ok
  322. }
  323. return nil, nil, false
  324. }
  325. // ResponseFormat negotiates the response content type
  326. // Returns the response format and a shallow copy of the request if its context
  327. // doesn't contain the response format, otherwise the same request
  328. func (c *Context) ResponseFormat(r *http.Request, offers []string) (string, *http.Request) {
  329. var rCtx = r.Context()
  330. if v, ok := rCtx.Value(ctxResponseFormat).(string); ok {
  331. debugLog("[%s %s] found response format %q in context", r.Method, r.URL.Path, v)
  332. return v, r
  333. }
  334. format := NegotiateContentType(r, offers, "")
  335. if format != "" {
  336. debugLog("[%s %s] set response format %q in context", r.Method, r.URL.Path, format)
  337. r = r.WithContext(stdContext.WithValue(rCtx, ctxResponseFormat, format))
  338. }
  339. debugLog("[%s %s] negotiated response format %q", r.Method, r.URL.Path, format)
  340. return format, r
  341. }
// AllowedMethods gets the allowed methods for the path of this request.
// It asks the router for the other methods available on the request's
// escaped URL path, passing the request's own method along.
func (c *Context) AllowedMethods(request *http.Request) []string {
	return c.router.OtherMethods(request.Method, request.URL.EscapedPath())
}
  346. // ResetAuth removes the current principal from the request context
  347. func (c *Context) ResetAuth(request *http.Request) *http.Request {
  348. rctx := request.Context()
  349. rctx = stdContext.WithValue(rctx, ctxSecurityPrincipal, nil)
  350. rctx = stdContext.WithValue(rctx, ctxSecurityScopes, nil)
  351. return request.WithContext(rctx)
  352. }
  353. // Authorize authorizes the request
  354. // Returns the principal object and a shallow copy of the request when its
  355. // context doesn't contain the principal, otherwise the same request or an error
  356. // (the last) if one of the authenticators returns one or an Unauthenticated error
  357. func (c *Context) Authorize(request *http.Request, route *MatchedRoute) (interface{}, *http.Request, error) {
  358. if route == nil || !route.HasAuth() {
  359. return nil, nil, nil
  360. }
  361. var rCtx = request.Context()
  362. if v := rCtx.Value(ctxSecurityPrincipal); v != nil {
  363. return v, request, nil
  364. }
  365. applies, usr, err := route.Authenticators.Authenticate(request, route)
  366. if !applies || err != nil || !route.Authenticators.AllowsAnonymous() && usr == nil {
  367. if err != nil {
  368. return nil, nil, err
  369. }
  370. return nil, nil, errors.Unauthenticated("invalid credentials")
  371. }
  372. if route.Authorizer != nil {
  373. if err := route.Authorizer.Authorize(request, usr); err != nil {
  374. return nil, nil, errors.New(http.StatusForbidden, err.Error())
  375. }
  376. }
  377. rCtx = stdContext.WithValue(rCtx, ctxSecurityPrincipal, usr)
  378. rCtx = stdContext.WithValue(rCtx, ctxSecurityScopes, route.Authenticator.AllScopes())
  379. return usr, request.WithContext(rCtx), nil
  380. }
  381. // BindAndValidate binds and validates the request
  382. // Returns the validation map and a shallow copy of the request when its context
  383. // doesn't contain the validation, otherwise it returns the same request or an
  384. // CompositeValidationError error
  385. func (c *Context) BindAndValidate(request *http.Request, matched *MatchedRoute) (interface{}, *http.Request, error) {
  386. var rCtx = request.Context()
  387. if v, ok := rCtx.Value(ctxBoundParams).(*validation); ok {
  388. debugLog("got cached validation (valid: %t)", len(v.result) == 0)
  389. if len(v.result) > 0 {
  390. return v.bound, request, errors.CompositeValidationError(v.result...)
  391. }
  392. return v.bound, request, nil
  393. }
  394. result := validateRequest(c, request, matched)
  395. rCtx = stdContext.WithValue(rCtx, ctxBoundParams, result)
  396. request = request.WithContext(rCtx)
  397. if len(result.result) > 0 {
  398. return result.bound, request, errors.CompositeValidationError(result.result...)
  399. }
  400. debugLog("no validation errors found")
  401. return result.bound, request, nil
  402. }
  403. // NotFound the default not found responder for when no route has been matched yet
  404. func (c *Context) NotFound(rw http.ResponseWriter, r *http.Request) {
  405. c.Respond(rw, r, []string{c.api.DefaultProduces()}, nil, errors.NotFound("not found"))
  406. }
  407. // Respond renders the response after doing some content negotiation
  408. func (c *Context) Respond(rw http.ResponseWriter, r *http.Request, produces []string, route *MatchedRoute, data interface{}) {
  409. debugLog("responding to %s %s with produces: %v", r.Method, r.URL.Path, produces)
  410. offers := []string{}
  411. for _, mt := range produces {
  412. if mt != c.api.DefaultProduces() {
  413. offers = append(offers, mt)
  414. }
  415. }
  416. // the default producer is last so more specific producers take precedence
  417. offers = append(offers, c.api.DefaultProduces())
  418. debugLog("offers: %v", offers)
  419. var format string
  420. format, r = c.ResponseFormat(r, offers)
  421. rw.Header().Set(runtime.HeaderContentType, format)
  422. if resp, ok := data.(Responder); ok {
  423. producers := route.Producers
  424. prod, ok := producers[format]
  425. if !ok {
  426. prods := c.api.ProducersFor(normalizeOffers([]string{c.api.DefaultProduces()}))
  427. pr, ok := prods[c.api.DefaultProduces()]
  428. if !ok {
  429. panic(errors.New(http.StatusInternalServerError, "can't find a producer for "+format))
  430. }
  431. prod = pr
  432. }
  433. resp.WriteResponse(rw, prod)
  434. return
  435. }
  436. if err, ok := data.(error); ok {
  437. if format == "" {
  438. rw.Header().Set(runtime.HeaderContentType, runtime.JSONMime)
  439. }
  440. if realm := security.FailedBasicAuth(r); realm != "" {
  441. rw.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%q", realm))
  442. }
  443. if route == nil || route.Operation == nil {
  444. c.api.ServeErrorFor("")(rw, r, err)
  445. return
  446. }
  447. c.api.ServeErrorFor(route.Operation.ID)(rw, r, err)
  448. return
  449. }
  450. if route == nil || route.Operation == nil {
  451. rw.WriteHeader(200)
  452. if r.Method == "HEAD" {
  453. return
  454. }
  455. producers := c.api.ProducersFor(normalizeOffers(offers))
  456. prod, ok := producers[format]
  457. if !ok {
  458. panic(errors.New(http.StatusInternalServerError, "can't find a producer for "+format))
  459. }
  460. if err := prod.Produce(rw, data); err != nil {
  461. panic(err) // let the recovery middleware deal with this
  462. }
  463. return
  464. }
  465. if _, code, ok := route.Operation.SuccessResponse(); ok {
  466. rw.WriteHeader(code)
  467. if code == 204 || r.Method == "HEAD" {
  468. return
  469. }
  470. producers := route.Producers
  471. prod, ok := producers[format]
  472. if !ok {
  473. if !ok {
  474. prods := c.api.ProducersFor(normalizeOffers([]string{c.api.DefaultProduces()}))
  475. pr, ok := prods[c.api.DefaultProduces()]
  476. if !ok {
  477. panic(errors.New(http.StatusInternalServerError, "can't find a producer for "+format))
  478. }
  479. prod = pr
  480. }
  481. }
  482. if err := prod.Produce(rw, data); err != nil {
  483. panic(err) // let the recovery middleware deal with this
  484. }
  485. return
  486. }
  487. c.api.ServeErrorFor(route.Operation.ID)(rw, r, errors.New(http.StatusInternalServerError, "can't produce response"))
  488. }
  489. // APIHandler returns a handler to serve the API, this includes a swagger spec, router and the contract defined in the swagger spec
  490. func (c *Context) APIHandler(builder Builder) http.Handler {
  491. b := builder
  492. if b == nil {
  493. b = PassthroughBuilder
  494. }
  495. var title string
  496. sp := c.spec.Spec()
  497. if sp != nil && sp.Info != nil && sp.Info.Title != "" {
  498. title = sp.Info.Title
  499. }
  500. redocOpts := RedocOpts{
  501. BasePath: c.BasePath(),
  502. Title: title,
  503. }
  504. return Spec("", c.spec.Raw(), Redoc(redocOpts, c.RoutesHandler(b)))
  505. }
  506. // RoutesHandler returns a handler to serve the API, just the routes and the contract defined in the swagger spec
  507. func (c *Context) RoutesHandler(builder Builder) http.Handler {
  508. b := builder
  509. if b == nil {
  510. b = PassthroughBuilder
  511. }
  512. return NewRouter(c, b(NewOperationExecutor(c)))
  513. }
上海开阖软件有限公司 沪ICP备12045867号-1