|
- package ssh
-
- import (
- "io"
- "log"
- "net"
- "strconv"
- "sync"
-
- gossh "golang.org/x/crypto/ssh"
- )
-
// forwardedTCPChannelType is the channel type the server opens toward the
// client for each connection accepted on a remote-forward listener.
const forwardedTCPChannelType = "forwarded-tcpip"
-
// localForwardChannelData is the type-specific data of a "direct-tcpip"
// channel open request, as specified in RFC 4254, Section 7.2: the host and
// port the client asks the server to connect to, followed by the address and
// port the request originated from. It is decoded with gossh.Unmarshal, which
// follows field declaration order, so the fields must not be reordered.
type localForwardChannelData struct {
	DestAddr   string
	DestPort   uint32
	OriginAddr string
	OriginPort uint32
}
-
- // DirectTCPIPHandler can be enabled by adding it to the server's
- // ChannelHandlers under direct-tcpip.
- func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
- d := localForwardChannelData{}
- if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
- newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
- return
- }
-
- if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestAddr, d.DestPort) {
- newChan.Reject(gossh.Prohibited, "port forwarding is disabled")
- return
- }
-
- dest := net.JoinHostPort(d.DestAddr, strconv.FormatInt(int64(d.DestPort), 10))
-
- var dialer net.Dialer
- dconn, err := dialer.DialContext(ctx, "tcp", dest)
- if err != nil {
- newChan.Reject(gossh.ConnectionFailed, err.Error())
- return
- }
-
- ch, reqs, err := newChan.Accept()
- if err != nil {
- dconn.Close()
- return
- }
- go gossh.DiscardRequests(reqs)
-
- go func() {
- defer ch.Close()
- defer dconn.Close()
- io.Copy(ch, dconn)
- }()
- go func() {
- defer ch.Close()
- defer dconn.Close()
- io.Copy(dconn, ch)
- }()
- }
-
// Wire payloads used by ForwardedTCPHandler. They are encoded and decoded
// with gossh.Marshal/gossh.Unmarshal, which follow field declaration order,
// so neither the order nor the types of these fields may change.
type (
	// remoteForwardRequest is the payload of a "tcpip-forward" global
	// request: the address and port the client wants the server to listen on.
	remoteForwardRequest struct {
		BindAddr string
		BindPort uint32
	}

	// remoteForwardSuccess is the reply payload for a granted
	// "tcpip-forward" request, carrying the port that was bound.
	remoteForwardSuccess struct {
		BindPort uint32
	}

	// remoteForwardCancelRequest is the payload of a "cancel-tcpip-forward"
	// global request, naming the forward to stop.
	remoteForwardCancelRequest struct {
		BindAddr string
		BindPort uint32
	}

	// remoteForwardChannelData is the extra data sent when opening a
	// "forwarded-tcpip" channel: the listening address/port the connection
	// arrived on, then the remote peer's address/port.
	remoteForwardChannelData struct {
		DestAddr   string
		DestPort   uint32
		OriginAddr string
		OriginPort uint32
	}
)
-
- // ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and
- // adding the HandleSSHRequest callback to the server's RequestHandlers under
- // tcpip-forward and cancel-tcpip-forward.
- type ForwardedTCPHandler struct {
- forwards map[string]net.Listener
- sync.Mutex
- }
-
- func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) {
- h.Lock()
- if h.forwards == nil {
- h.forwards = make(map[string]net.Listener)
- }
- h.Unlock()
- conn := ctx.Value(ContextKeyConn).(*gossh.ServerConn)
- switch req.Type {
- case "tcpip-forward":
- var reqPayload remoteForwardRequest
- if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
- // TODO: log parse failure
- return false, []byte{}
- }
- if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) {
- return false, []byte("port forwarding is disabled")
- }
- addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort)))
- ln, err := net.Listen("tcp", addr)
- if err != nil {
- // TODO: log listen failure
- return false, []byte{}
- }
- _, destPortStr, _ := net.SplitHostPort(ln.Addr().String())
- destPort, _ := strconv.Atoi(destPortStr)
- h.Lock()
- h.forwards[addr] = ln
- h.Unlock()
- go func() {
- <-ctx.Done()
- h.Lock()
- ln, ok := h.forwards[addr]
- h.Unlock()
- if ok {
- ln.Close()
- }
- }()
- go func() {
- for {
- c, err := ln.Accept()
- if err != nil {
- // TODO: log accept failure
- break
- }
- originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String())
- originPort, _ := strconv.Atoi(orignPortStr)
- payload := gossh.Marshal(&remoteForwardChannelData{
- DestAddr: reqPayload.BindAddr,
- DestPort: uint32(destPort),
- OriginAddr: originAddr,
- OriginPort: uint32(originPort),
- })
- go func() {
- ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload)
- if err != nil {
- // TODO: log failure to open channel
- log.Println(err)
- c.Close()
- return
- }
- go gossh.DiscardRequests(reqs)
- go func() {
- defer ch.Close()
- defer c.Close()
- io.Copy(ch, c)
- }()
- go func() {
- defer ch.Close()
- defer c.Close()
- io.Copy(c, ch)
- }()
- }()
- }
- h.Lock()
- delete(h.forwards, addr)
- h.Unlock()
- }()
- return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)})
-
- case "cancel-tcpip-forward":
- var reqPayload remoteForwardCancelRequest
- if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
- // TODO: log parse failure
- return false, []byte{}
- }
- addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort)))
- h.Lock()
- ln, ok := h.forwards[addr]
- h.Unlock()
- if ok {
- ln.Close()
- }
- return true, nil
- default:
- return false, nil
- }
- }
|