|
- // Copyright 2015 go-swagger maintainers
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
-
- package middleware
-
- import (
- stdContext "context"
- "fmt"
- "net/http"
- "strings"
- "sync"
-
- "github.com/go-openapi/runtime/security"
-
- "github.com/go-openapi/analysis"
- "github.com/go-openapi/errors"
- "github.com/go-openapi/loads"
- "github.com/go-openapi/runtime"
- "github.com/go-openapi/runtime/logger"
- "github.com/go-openapi/runtime/middleware/untyped"
- "github.com/go-openapi/spec"
- "github.com/go-openapi/strfmt"
- )
-
- // Debug when true turns on verbose logging
- var Debug = logger.DebugEnabled()
- var Logger logger.Logger = logger.StandardLogger{}
-
- func debugLog(format string, args ...interface{}) {
- if Debug {
- Logger.Printf(format, args...)
- }
- }
-
// Builder decorates an http.Handler, typically by wrapping it in middleware.
type Builder func(http.Handler) http.Handler

// PassthroughBuilder is the identity Builder: it hands back the handler it
// receives without wrapping it.
func PassthroughBuilder(next http.Handler) http.Handler {
	return next
}
-
- // RequestBinder is an interface for types to implement
- // when they want to be able to bind from a request
- type RequestBinder interface {
- BindRequest(*http.Request, *MatchedRoute) error
- }
-
- // Responder is an interface for types to implement
- // when they want to be considered for writing HTTP responses
- type Responder interface {
- WriteResponse(http.ResponseWriter, runtime.Producer)
- }
-
- // ResponderFunc wraps a func as a Responder interface
- type ResponderFunc func(http.ResponseWriter, runtime.Producer)
-
- // WriteResponse writes to the response
- func (fn ResponderFunc) WriteResponse(rw http.ResponseWriter, pr runtime.Producer) {
- fn(rw, pr)
- }
-
// Context is a type safe wrapper around an untyped request context
// used throughout to store request context with the standard context attached
// to the http.Request
type Context struct {
	// spec is the loaded API document (read by BasePath and APIHandler).
	spec *loads.Document
	// analyzer is built from spec by NewContext/NewRoutableContext; it is nil
	// when those constructors receive a nil spec.
	analyzer *analysis.Spec
	// api supplies handlers, producers and error rendering for operations.
	api RoutableAPI
	// router resolves a method and escaped path to a MatchedRoute.
	router Router
}

// routableUntypedAPI adapts an *untyped.API to the RoutableAPI contract; it
// is built by newRoutableUntypedAPI.
type routableUntypedAPI struct {
	api *untyped.API
	// hlock is held while HandlerFor reads handlers.
	hlock *sync.Mutex
	// handlers maps an upper-cased HTTP method, then a spec path, to the
	// http.Handler serving that operation.
	handlers map[string]map[string]http.Handler
	// defaultConsumes/defaultProduces are copied from api at construction.
	defaultConsumes string
	defaultProduces string
}
-
- func newRoutableUntypedAPI(spec *loads.Document, api *untyped.API, context *Context) *routableUntypedAPI {
- var handlers map[string]map[string]http.Handler
- if spec == nil || api == nil {
- return nil
- }
- analyzer := analysis.New(spec.Spec())
- for method, hls := range analyzer.Operations() {
- um := strings.ToUpper(method)
- for path, op := range hls {
- schemes := analyzer.SecurityRequirementsFor(op)
-
- if oh, ok := api.OperationHandlerFor(method, path); ok {
- if handlers == nil {
- handlers = make(map[string]map[string]http.Handler)
- }
- if b, ok := handlers[um]; !ok || b == nil {
- handlers[um] = make(map[string]http.Handler)
- }
-
- var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- // lookup route info in the context
- route, rCtx, _ := context.RouteInfo(r)
- if rCtx != nil {
- r = rCtx
- }
-
- // bind and validate the request using reflection
- var bound interface{}
- var validation error
- bound, r, validation = context.BindAndValidate(r, route)
- if validation != nil {
- context.Respond(w, r, route.Produces, route, validation)
- return
- }
-
- // actually handle the request
- result, err := oh.Handle(bound)
- if err != nil {
- // respond with failure
- context.Respond(w, r, route.Produces, route, err)
- return
- }
-
- // respond with success
- context.Respond(w, r, route.Produces, route, result)
- })
-
- if len(schemes) > 0 {
- handler = newSecureAPI(context, handler)
- }
- handlers[um][path] = handler
- }
- }
- }
-
- return &routableUntypedAPI{
- api: api,
- hlock: new(sync.Mutex),
- handlers: handlers,
- defaultProduces: api.DefaultProduces,
- defaultConsumes: api.DefaultConsumes,
- }
- }
-
- func (r *routableUntypedAPI) HandlerFor(method, path string) (http.Handler, bool) {
- r.hlock.Lock()
- paths, ok := r.handlers[strings.ToUpper(method)]
- if !ok {
- r.hlock.Unlock()
- return nil, false
- }
- handler, ok := paths[path]
- r.hlock.Unlock()
- return handler, ok
- }
- func (r *routableUntypedAPI) ServeErrorFor(operationID string) func(http.ResponseWriter, *http.Request, error) {
- return r.api.ServeError
- }
- func (r *routableUntypedAPI) ConsumersFor(mediaTypes []string) map[string]runtime.Consumer {
- return r.api.ConsumersFor(mediaTypes)
- }
- func (r *routableUntypedAPI) ProducersFor(mediaTypes []string) map[string]runtime.Producer {
- return r.api.ProducersFor(mediaTypes)
- }
- func (r *routableUntypedAPI) AuthenticatorsFor(schemes map[string]spec.SecurityScheme) map[string]runtime.Authenticator {
- return r.api.AuthenticatorsFor(schemes)
- }
- func (r *routableUntypedAPI) Authorizer() runtime.Authorizer {
- return r.api.Authorizer()
- }
- func (r *routableUntypedAPI) Formats() strfmt.Registry {
- return r.api.Formats()
- }
-
- func (r *routableUntypedAPI) DefaultProduces() string {
- return r.defaultProduces
- }
-
- func (r *routableUntypedAPI) DefaultConsumes() string {
- return r.defaultConsumes
- }
-
- // NewRoutableContext creates a new context for a routable API
- func NewRoutableContext(spec *loads.Document, routableAPI RoutableAPI, routes Router) *Context {
- var an *analysis.Spec
- if spec != nil {
- an = analysis.New(spec.Spec())
- }
- ctx := &Context{spec: spec, api: routableAPI, analyzer: an, router: routes}
- return ctx
- }
-
- // NewContext creates a new context wrapper
- func NewContext(spec *loads.Document, api *untyped.API, routes Router) *Context {
- var an *analysis.Spec
- if spec != nil {
- an = analysis.New(spec.Spec())
- }
- ctx := &Context{spec: spec, analyzer: an}
- ctx.api = newRoutableUntypedAPI(spec, api, ctx)
- ctx.router = routes
- return ctx
- }
-
- // Serve serves the specified spec with the specified api registrations as a http.Handler
- func Serve(spec *loads.Document, api *untyped.API) http.Handler {
- return ServeWithBuilder(spec, api, PassthroughBuilder)
- }
-
- // ServeWithBuilder serves the specified spec with the specified api registrations as a http.Handler that is decorated
- // by the Builder
- func ServeWithBuilder(spec *loads.Document, api *untyped.API, builder Builder) http.Handler {
- context := NewContext(spec, api, nil)
- return context.APIHandler(builder)
- }
-
// contextKey is the private type of the keys this package stores in request
// contexts, so they cannot collide with keys defined by other packages.
type contextKey int8

// Request context keys. Numbering starts at 1; the zero value is never used
// as a key.
const (
	ctxContentType contextKey = iota + 1
	ctxResponseFormat
	ctxMatchedRoute
	ctxBoundParams
	ctxSecurityPrincipal
	ctxSecurityScopes
)
-
- // MatchedRouteFrom request context value.
- func MatchedRouteFrom(req *http.Request) *MatchedRoute {
- mr := req.Context().Value(ctxMatchedRoute)
- if mr == nil {
- return nil
- }
- if res, ok := mr.(*MatchedRoute); ok {
- return res
- }
- return nil
- }
-
- // SecurityPrincipalFrom request context value.
- func SecurityPrincipalFrom(req *http.Request) interface{} {
- return req.Context().Value(ctxSecurityPrincipal)
- }
-
- // SecurityScopesFrom request context value.
- func SecurityScopesFrom(req *http.Request) []string {
- rs := req.Context().Value(ctxSecurityScopes)
- if res, ok := rs.([]string); ok {
- return res
- }
- return nil
- }
-
- type contentTypeValue struct {
- MediaType string
- Charset string
- }
-
- // BasePath returns the base path for this API
- func (c *Context) BasePath() string {
- return c.spec.BasePath()
- }
-
- // RequiredProduces returns the accepted content types for responses
- func (c *Context) RequiredProduces() []string {
- return c.analyzer.RequiredProduces()
- }
-
- // BindValidRequest binds a params object to a request but only when the request is valid
- // if the request is not valid an error will be returned
- func (c *Context) BindValidRequest(request *http.Request, route *MatchedRoute, binder RequestBinder) error {
- var res []error
-
- requestContentType := "*/*"
- // check and validate content type, select consumer
- if runtime.HasBody(request) {
- ct, _, err := runtime.ContentType(request.Header)
- if err != nil {
- res = append(res, err)
- } else {
- if err := validateContentType(route.Consumes, ct); err != nil {
- res = append(res, err)
- }
- if len(res) == 0 {
- cons, ok := route.Consumers[ct]
- if !ok {
- res = append(res, errors.New(500, "no consumer registered for %s", ct))
- } else {
- route.Consumer = cons
- requestContentType = ct
- }
- }
- }
- }
-
- // check and validate the response format
- if len(res) == 0 && runtime.HasBody(request) {
- if str := NegotiateContentType(request, route.Produces, requestContentType); str == "" {
- res = append(res, errors.InvalidResponseFormat(request.Header.Get(runtime.HeaderAccept), route.Produces))
- }
- }
-
- // now bind the request with the provided binder
- // it's assumed the binder will also validate the request and return an error if the
- // request is invalid
- if binder != nil && len(res) == 0 {
- if err := binder.BindRequest(request, route); err != nil {
- return err
- }
- }
-
- if len(res) > 0 {
- return errors.CompositeValidationError(res...)
- }
- return nil
- }
-
- // ContentType gets the parsed value of a content type
- // Returns the media type, its charset and a shallow copy of the request
- // when its context doesn't contain the content type value, otherwise it returns
- // the same request
- // Returns the error that runtime.ContentType may retunrs.
- func (c *Context) ContentType(request *http.Request) (string, string, *http.Request, error) {
- var rCtx = request.Context()
-
- if v, ok := rCtx.Value(ctxContentType).(*contentTypeValue); ok {
- return v.MediaType, v.Charset, request, nil
- }
-
- mt, cs, err := runtime.ContentType(request.Header)
- if err != nil {
- return "", "", nil, err
- }
- rCtx = stdContext.WithValue(rCtx, ctxContentType, &contentTypeValue{mt, cs})
- return mt, cs, request.WithContext(rCtx), nil
- }
-
- // LookupRoute looks a route up and returns true when it is found
- func (c *Context) LookupRoute(request *http.Request) (*MatchedRoute, bool) {
- if route, ok := c.router.Lookup(request.Method, request.URL.EscapedPath()); ok {
- return route, ok
- }
- return nil, false
- }
-
- // RouteInfo tries to match a route for this request
- // Returns the matched route, a shallow copy of the request if its context
- // contains the matched router, otherwise the same request, and a bool to
- // indicate if it the request matches one of the routes, if it doesn't
- // then it returns false and nil for the other two return values
- func (c *Context) RouteInfo(request *http.Request) (*MatchedRoute, *http.Request, bool) {
- var rCtx = request.Context()
-
- if v, ok := rCtx.Value(ctxMatchedRoute).(*MatchedRoute); ok {
- return v, request, ok
- }
-
- if route, ok := c.LookupRoute(request); ok {
- rCtx = stdContext.WithValue(rCtx, ctxMatchedRoute, route)
- return route, request.WithContext(rCtx), ok
- }
-
- return nil, nil, false
- }
-
- // ResponseFormat negotiates the response content type
- // Returns the response format and a shallow copy of the request if its context
- // doesn't contain the response format, otherwise the same request
- func (c *Context) ResponseFormat(r *http.Request, offers []string) (string, *http.Request) {
- var rCtx = r.Context()
-
- if v, ok := rCtx.Value(ctxResponseFormat).(string); ok {
- debugLog("[%s %s] found response format %q in context", r.Method, r.URL.Path, v)
- return v, r
- }
-
- format := NegotiateContentType(r, offers, "")
- if format != "" {
- debugLog("[%s %s] set response format %q in context", r.Method, r.URL.Path, format)
- r = r.WithContext(stdContext.WithValue(rCtx, ctxResponseFormat, format))
- }
- debugLog("[%s %s] negotiated response format %q", r.Method, r.URL.Path, format)
- return format, r
- }
-
- // AllowedMethods gets the allowed methods for the path of this request
- func (c *Context) AllowedMethods(request *http.Request) []string {
- return c.router.OtherMethods(request.Method, request.URL.EscapedPath())
- }
-
- // ResetAuth removes the current principal from the request context
- func (c *Context) ResetAuth(request *http.Request) *http.Request {
- rctx := request.Context()
- rctx = stdContext.WithValue(rctx, ctxSecurityPrincipal, nil)
- rctx = stdContext.WithValue(rctx, ctxSecurityScopes, nil)
- return request.WithContext(rctx)
- }
-
- // Authorize authorizes the request
- // Returns the principal object and a shallow copy of the request when its
- // context doesn't contain the principal, otherwise the same request or an error
- // (the last) if one of the authenticators returns one or an Unauthenticated error
- func (c *Context) Authorize(request *http.Request, route *MatchedRoute) (interface{}, *http.Request, error) {
- if route == nil || !route.HasAuth() {
- return nil, nil, nil
- }
-
- var rCtx = request.Context()
- if v := rCtx.Value(ctxSecurityPrincipal); v != nil {
- return v, request, nil
- }
-
- applies, usr, err := route.Authenticators.Authenticate(request, route)
- if !applies || err != nil || !route.Authenticators.AllowsAnonymous() && usr == nil {
- if err != nil {
- return nil, nil, err
- }
- return nil, nil, errors.Unauthenticated("invalid credentials")
- }
- if route.Authorizer != nil {
- if err := route.Authorizer.Authorize(request, usr); err != nil {
- return nil, nil, errors.New(http.StatusForbidden, err.Error())
- }
- }
-
- rCtx = stdContext.WithValue(rCtx, ctxSecurityPrincipal, usr)
- rCtx = stdContext.WithValue(rCtx, ctxSecurityScopes, route.Authenticator.AllScopes())
- return usr, request.WithContext(rCtx), nil
- }
-
- // BindAndValidate binds and validates the request
- // Returns the validation map and a shallow copy of the request when its context
- // doesn't contain the validation, otherwise it returns the same request or an
- // CompositeValidationError error
- func (c *Context) BindAndValidate(request *http.Request, matched *MatchedRoute) (interface{}, *http.Request, error) {
- var rCtx = request.Context()
-
- if v, ok := rCtx.Value(ctxBoundParams).(*validation); ok {
- debugLog("got cached validation (valid: %t)", len(v.result) == 0)
- if len(v.result) > 0 {
- return v.bound, request, errors.CompositeValidationError(v.result...)
- }
- return v.bound, request, nil
- }
- result := validateRequest(c, request, matched)
- rCtx = stdContext.WithValue(rCtx, ctxBoundParams, result)
- request = request.WithContext(rCtx)
- if len(result.result) > 0 {
- return result.bound, request, errors.CompositeValidationError(result.result...)
- }
- debugLog("no validation errors found")
- return result.bound, request, nil
- }
-
- // NotFound the default not found responder for when no route has been matched yet
- func (c *Context) NotFound(rw http.ResponseWriter, r *http.Request) {
- c.Respond(rw, r, []string{c.api.DefaultProduces()}, nil, errors.NotFound("not found"))
- }
-
- // Respond renders the response after doing some content negotiation
- func (c *Context) Respond(rw http.ResponseWriter, r *http.Request, produces []string, route *MatchedRoute, data interface{}) {
- debugLog("responding to %s %s with produces: %v", r.Method, r.URL.Path, produces)
- offers := []string{}
- for _, mt := range produces {
- if mt != c.api.DefaultProduces() {
- offers = append(offers, mt)
- }
- }
- // the default producer is last so more specific producers take precedence
- offers = append(offers, c.api.DefaultProduces())
- debugLog("offers: %v", offers)
-
- var format string
- format, r = c.ResponseFormat(r, offers)
- rw.Header().Set(runtime.HeaderContentType, format)
-
- if resp, ok := data.(Responder); ok {
- producers := route.Producers
- prod, ok := producers[format]
- if !ok {
- prods := c.api.ProducersFor(normalizeOffers([]string{c.api.DefaultProduces()}))
- pr, ok := prods[c.api.DefaultProduces()]
- if !ok {
- panic(errors.New(http.StatusInternalServerError, "can't find a producer for "+format))
- }
- prod = pr
- }
- resp.WriteResponse(rw, prod)
- return
- }
-
- if err, ok := data.(error); ok {
- if format == "" {
- rw.Header().Set(runtime.HeaderContentType, runtime.JSONMime)
- }
-
- if realm := security.FailedBasicAuth(r); realm != "" {
- rw.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%q", realm))
- }
-
- if route == nil || route.Operation == nil {
- c.api.ServeErrorFor("")(rw, r, err)
- return
- }
- c.api.ServeErrorFor(route.Operation.ID)(rw, r, err)
- return
- }
-
- if route == nil || route.Operation == nil {
- rw.WriteHeader(200)
- if r.Method == "HEAD" {
- return
- }
- producers := c.api.ProducersFor(normalizeOffers(offers))
- prod, ok := producers[format]
- if !ok {
- panic(errors.New(http.StatusInternalServerError, "can't find a producer for "+format))
- }
- if err := prod.Produce(rw, data); err != nil {
- panic(err) // let the recovery middleware deal with this
- }
- return
- }
-
- if _, code, ok := route.Operation.SuccessResponse(); ok {
- rw.WriteHeader(code)
- if code == 204 || r.Method == "HEAD" {
- return
- }
-
- producers := route.Producers
- prod, ok := producers[format]
- if !ok {
- if !ok {
- prods := c.api.ProducersFor(normalizeOffers([]string{c.api.DefaultProduces()}))
- pr, ok := prods[c.api.DefaultProduces()]
- if !ok {
- panic(errors.New(http.StatusInternalServerError, "can't find a producer for "+format))
- }
- prod = pr
- }
- }
- if err := prod.Produce(rw, data); err != nil {
- panic(err) // let the recovery middleware deal with this
- }
- return
- }
-
- c.api.ServeErrorFor(route.Operation.ID)(rw, r, errors.New(http.StatusInternalServerError, "can't produce response"))
- }
-
- // APIHandler returns a handler to serve the API, this includes a swagger spec, router and the contract defined in the swagger spec
- func (c *Context) APIHandler(builder Builder) http.Handler {
- b := builder
- if b == nil {
- b = PassthroughBuilder
- }
-
- var title string
- sp := c.spec.Spec()
- if sp != nil && sp.Info != nil && sp.Info.Title != "" {
- title = sp.Info.Title
- }
-
- redocOpts := RedocOpts{
- BasePath: c.BasePath(),
- Title: title,
- }
-
- return Spec("", c.spec.Raw(), Redoc(redocOpts, c.RoutesHandler(b)))
- }
-
- // RoutesHandler returns a handler to serve the API, just the routes and the contract defined in the swagger spec
- func (c *Context) RoutesHandler(builder Builder) http.Handler {
- b := builder
- if b == nil {
- b = PassthroughBuilder
- }
- return NewRouter(c, b(NewOperationExecutor(c)))
- }
|