|
- // @author Couchbase <info@couchbase.com>
- // @copyright 2018 Couchbase, Inc.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
-
// Package scramsha provides a client-side implementation of SCRAM-SHA
// authentication over HTTP, as specified in https://tools.ietf.org/html/rfc7804
- package scramsha
-
- import (
- "encoding/base64"
- "github.com/pkg/errors"
- "io"
- "io/ioutil"
- "net/http"
- "strings"
- )
-
// HTTP header names exchanged during the SCRAM-SHA handshake.
const (
	// Authorization is the request header that carries the client's
	// SCRAM-SHA messages.
	Authorization = "Authorization"
	// WWWAuthenticate is the response header from which the server's reply
	// to the start request is read.
	WWWAuthenticate = "WWW-Authenticate"
	// AuthenticationInfo is the response header from which the server's
	// reply to the final request is read.
	AuthenticationInfo = "Authentication-Info"
)

// Prefixes of the auth-params that appear inside those header values.
const (
	// DataPrefix introduces a base64-encoded SCRAM message.
	DataPrefix = "data="
	// SidPrefix introduces the sid that the server returns in its start
	// response and the client echoes back in its final request.
	SidPrefix = "sid="
)
-
// Request is an HTTP request that can be sent more than once. It keeps the
// caller-supplied seekable body alongside the embedded *http.Request so the
// body can be rewound between sends (DoScramSha does this before its second
// round trip). Construct it with NewRequest.
type Request struct {
	// body is the original body passed to NewRequest; it may be nil.
	body io.ReadSeeker

	// Embed an HTTP request directly. This makes a *Request act exactly
	// like an *http.Request so that all meta methods are supported.
	*http.Request
}
-
// lenReader is implemented by readers that can report their length through
// a Len method, such as *bytes.Reader, *bytes.Buffer and *strings.Reader
// (whose Len reports the number of unread bytes). NewRequest uses it to set
// the request's ContentLength up front.
type lenReader interface {
	Len() int
}
-
- // NewRequest creates http request that can be retried
- func NewRequest(method, url string, body io.ReadSeeker) (*Request, error) {
- // Wrap the body in a noop ReadCloser if non-nil. This prevents the
- // reader from being closed by the HTTP client.
- var rcBody io.ReadCloser
- if body != nil {
- rcBody = ioutil.NopCloser(body)
- }
-
- // Make the request with the noop-closer for the body.
- httpReq, err := http.NewRequest(method, url, rcBody)
- if err != nil {
- return nil, err
- }
-
- // Check if we can set the Content-Length automatically.
- if lr, ok := body.(lenReader); ok {
- httpReq.ContentLength = int64(lr.Len())
- }
-
- return &Request{body, httpReq}, nil
- }
-
// encode returns the standard, padded base64 encoding of str.
func encode(str string) string {
	enc := base64.StdEncoding
	out := make([]byte, enc.EncodedLen(len(str)))
	enc.Encode(out, []byte(str))
	return string(out)
}
-
- func decode(str string) (string, error) {
- bytes, err := base64.StdEncoding.DecodeString(str)
- if err != nil {
- return "", errors.Errorf("Cannot base64 decode %s",
- str)
- }
- return string(bytes), err
- }
-
- func trimPrefix(s, prefix string) (string, error) {
- l := len(s)
- trimmed := strings.TrimPrefix(s, prefix)
- if l == len(trimmed) {
- return trimmed, errors.Errorf("Prefix %s not found in %s",
- prefix, s)
- }
- return trimmed, nil
- }
-
// drainBody reads whatever remains of resp's body, discards it, and then
// closes the body.
func drainBody(resp *http.Response) {
	body := resp.Body
	defer body.Close()
	// Best effort: a failure while draining is deliberately ignored, the
	// body is closed regardless.
	_, _ = io.Copy(ioutil.Discard, body)
}
-
- // DoScramSha performs SCRAM-SHA handshake via Http
- func DoScramSha(req *Request,
- username string,
- password string,
- client *http.Client) (*http.Response, error) {
-
- method := "SCRAM-SHA-512"
- s, err := NewScramSha("SCRAM-SHA512")
- if err != nil {
- return nil, errors.Wrap(err,
- "Unable to initialize SCRAM-SHA handler")
- }
-
- message, err := s.GetStartRequest(username)
- if err != nil {
- return nil, err
- }
-
- encodedMessage := method + " " + DataPrefix + encode(message)
-
- req.Header.Set(Authorization, encodedMessage)
-
- res, err := client.Do(req.Request)
- if err != nil {
- return nil, errors.Wrap(err, "Problem sending SCRAM-SHA start"+
- "request")
- }
-
- if res.StatusCode != http.StatusUnauthorized {
- return res, nil
- }
-
- authHeader := res.Header.Get(WWWAuthenticate)
- if authHeader == "" {
- drainBody(res)
- return nil, errors.Errorf("Header %s is not populated in "+
- "SCRAM-SHA start response", WWWAuthenticate)
- }
-
- authHeader, err = trimPrefix(authHeader, method+" ")
- if err != nil {
- if strings.HasPrefix(authHeader, "Basic ") {
- // user not found
- return res, nil
- }
- drainBody(res)
- return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+
- "start response %s", authHeader)
- }
-
- drainBody(res)
-
- sid, response, err := parseSidAndData(authHeader)
- if err != nil {
- return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+
- "start response %s", authHeader)
- }
-
- err = s.HandleStartResponse(response)
- if err != nil {
- return nil, errors.Wrapf(err, "Error parsing SCRAM-SHA start "+
- "response %s", response)
- }
-
- message = s.GetFinalRequest(password)
- encodedMessage = method + " " + SidPrefix + sid + "," + DataPrefix +
- encode(message)
-
- req.Header.Set(Authorization, encodedMessage)
-
- // rewind request body so it can be resent again
- if req.body != nil {
- if _, err = req.body.Seek(0, 0); err != nil {
- return nil, errors.Errorf("Failed to seek body: %v",
- err)
- }
- }
-
- res, err = client.Do(req.Request)
- if err != nil {
- return nil, errors.Wrap(err, "Problem sending SCRAM-SHA final"+
- "request")
- }
-
- if res.StatusCode == http.StatusUnauthorized {
- // TODO retrieve and return error
- return res, nil
- }
-
- if res.StatusCode >= http.StatusInternalServerError {
- // in this case we cannot expect server to set headers properly
- return res, nil
- }
-
- authHeader = res.Header.Get(AuthenticationInfo)
- if authHeader == "" {
- drainBody(res)
- return nil, errors.Errorf("Header %s is not populated in "+
- "SCRAM-SHA final response", AuthenticationInfo)
- }
-
- finalSid, response, err := parseSidAndData(authHeader)
- if err != nil {
- drainBody(res)
- return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+
- "final response %s", authHeader)
- }
-
- if finalSid != sid {
- drainBody(res)
- return nil, errors.Errorf("Sid %s returned by server "+
- "doesn't match the original sid %s", finalSid, sid)
- }
-
- err = s.HandleFinalResponse(response)
- if err != nil {
- drainBody(res)
- return nil, errors.Wrapf(err,
- "Error handling SCRAM-SHA final server response %s",
- response)
- }
- return res, nil
- }
-
- func parseSidAndData(authHeader string) (string, string, error) {
- sidIndex := strings.Index(authHeader, SidPrefix)
- if sidIndex < 0 {
- return "", "", errors.Errorf("Cannot find %s in %s",
- SidPrefix, authHeader)
- }
-
- sidEndIndex := strings.Index(authHeader, ",")
- if sidEndIndex < 0 {
- return "", "", errors.Errorf("Cannot find ',' in %s",
- authHeader)
- }
-
- sid := authHeader[sidIndex+len(SidPrefix) : sidEndIndex]
-
- dataIndex := strings.Index(authHeader, DataPrefix)
- if dataIndex < 0 {
- return "", "", errors.Errorf("Cannot find %s in %s",
- DataPrefix, authHeader)
- }
-
- data, err := decode(authHeader[dataIndex+len(DataPrefix):])
- if err != nil {
- return "", "", err
- }
- return sid, data, nil
- }
|