|
- // Copyright 2018 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
-
- package socks
-
- import (
- "context"
- "errors"
- "io"
- "net"
- "strconv"
- "time"
- )
-
var (
	// noDeadline is the zero time.Time. connect passes it to
	// c.SetDeadline on return to remove a deadline it installed from ctx.
	noDeadline time.Time

	// aLongTimeAgo is a fixed instant one second after the Unix epoch.
	// connect installs it as the deadline on c once ctx is done, so the
	// deadline is already in the past.
	aLongTimeAgo = time.Unix(1, 0)
)
-
- func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
- host, port, err := splitHostPort(address)
- if err != nil {
- return nil, err
- }
- if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
- c.SetDeadline(deadline)
- defer c.SetDeadline(noDeadline)
- }
- if ctx != context.Background() {
- errCh := make(chan error, 1)
- done := make(chan struct{})
- defer func() {
- close(done)
- if ctxErr == nil {
- ctxErr = <-errCh
- }
- }()
- go func() {
- select {
- case <-ctx.Done():
- c.SetDeadline(aLongTimeAgo)
- errCh <- ctx.Err()
- case <-done:
- errCh <- nil
- }
- }()
- }
-
- b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
- b = append(b, Version5)
- if len(d.AuthMethods) == 0 || d.Authenticate == nil {
- b = append(b, 1, byte(AuthMethodNotRequired))
- } else {
- ams := d.AuthMethods
- if len(ams) > 255 {
- return nil, errors.New("too many authentication methods")
- }
- b = append(b, byte(len(ams)))
- for _, am := range ams {
- b = append(b, byte(am))
- }
- }
- if _, ctxErr = c.Write(b); ctxErr != nil {
- return
- }
-
- if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
- return
- }
- if b[0] != Version5 {
- return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
- }
- am := AuthMethod(b[1])
- if am == AuthMethodNoAcceptableMethods {
- return nil, errors.New("no acceptable authentication methods")
- }
- if d.Authenticate != nil {
- if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
- return
- }
- }
-
- b = b[:0]
- b = append(b, Version5, byte(d.cmd), 0)
- if ip := net.ParseIP(host); ip != nil {
- if ip4 := ip.To4(); ip4 != nil {
- b = append(b, AddrTypeIPv4)
- b = append(b, ip4...)
- } else if ip6 := ip.To16(); ip6 != nil {
- b = append(b, AddrTypeIPv6)
- b = append(b, ip6...)
- } else {
- return nil, errors.New("unknown address type")
- }
- } else {
- if len(host) > 255 {
- return nil, errors.New("FQDN too long")
- }
- b = append(b, AddrTypeFQDN)
- b = append(b, byte(len(host)))
- b = append(b, host...)
- }
- b = append(b, byte(port>>8), byte(port))
- if _, ctxErr = c.Write(b); ctxErr != nil {
- return
- }
-
- if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
- return
- }
- if b[0] != Version5 {
- return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
- }
- if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
- return nil, errors.New("unknown error " + cmdErr.String())
- }
- if b[2] != 0 {
- return nil, errors.New("non-zero reserved field")
- }
- l := 2
- var a Addr
- switch b[3] {
- case AddrTypeIPv4:
- l += net.IPv4len
- a.IP = make(net.IP, net.IPv4len)
- case AddrTypeIPv6:
- l += net.IPv6len
- a.IP = make(net.IP, net.IPv6len)
- case AddrTypeFQDN:
- if _, err := io.ReadFull(c, b[:1]); err != nil {
- return nil, err
- }
- l += int(b[0])
- default:
- return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
- }
- if cap(b) < l {
- b = make([]byte, l)
- } else {
- b = b[:l]
- }
- if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
- return
- }
- if a.IP != nil {
- copy(a.IP, b)
- } else {
- a.Name = string(b[:len(b)-2])
- }
- a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
- return &a, nil
- }
-
// splitHostPort splits address, given in "host:port" form, into its host
// and a numeric port. It returns an error if address cannot be split, if
// the port is not a decimal integer, or if the port lies outside
// 1..65535.
func splitHostPort(address string) (string, int, error) {
	host, portStr, err := net.SplitHostPort(address)
	if err != nil {
		return "", 0, err
	}
	n, err := strconv.Atoi(portStr)
	switch {
	case err != nil:
		return "", 0, err
	case n < 1 || n > 0xffff:
		return "", 0, errors.New("port number out of range " + portStr)
	}
	return host, n, nil
}
|