本站源代码
Nie możesz wybrać więcej, niż 25 tematów Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.

395 lines
11KB

  1. package ssh
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net"
  7. "sync"
  8. "time"
  9. gossh "golang.org/x/crypto/ssh"
  10. )
  // ErrServerClosed is returned by the Server's Serve and ListenAndServe
// methods after a call to Shutdown or Close. Compare against it with
// errors.Is.
var ErrServerClosed = errors.New("ssh: Server closed")
  14. type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
  15. var DefaultRequestHandlers = map[string]RequestHandler{}
  16. type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
  17. var DefaultChannelHandlers = map[string]ChannelHandler{
  18. "session": DefaultSessionHandler,
  19. }
  20. // Server defines parameters for running an SSH server. The zero value for
  21. // Server is a valid configuration. When both PasswordHandler and
  22. // PublicKeyHandler are nil, no client authentication is performed.
  23. type Server struct {
  24. Addr string // TCP address to listen on, ":22" if empty
  25. Handler Handler // handler to invoke, ssh.DefaultHandler if nil
  26. HostSigners []Signer // private keys for the host key, must have at least one
  27. Version string // server version to be sent before the initial handshake
  28. KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler
  29. PasswordHandler PasswordHandler // password authentication handler
  30. PublicKeyHandler PublicKeyHandler // public key authentication handler
  31. PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
  32. ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
  33. LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
  34. ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil
  35. ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options
  36. SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions
  37. IdleTimeout time.Duration // connection timeout when no activity, none if empty
  38. MaxTimeout time.Duration // absolute connection timeout, none if empty
  39. // ChannelHandlers allow overriding the built-in session handlers or provide
  40. // extensions to the protocol, such as tcpip forwarding. By default only the
  41. // "session" handler is enabled.
  42. ChannelHandlers map[string]ChannelHandler
  43. // RequestHandlers allow overriding the server-level request handlers or
  44. // provide extensions to the protocol, such as tcpip forwarding. By default
  45. // no handlers are enabled.
  46. RequestHandlers map[string]RequestHandler
  47. listenerWg sync.WaitGroup
  48. mu sync.Mutex
  49. listeners map[net.Listener]struct{}
  50. conns map[*gossh.ServerConn]struct{}
  51. connWg sync.WaitGroup
  52. doneChan chan struct{}
  53. }
  54. func (srv *Server) ensureHostSigner() error {
  55. if len(srv.HostSigners) == 0 {
  56. signer, err := generateSigner()
  57. if err != nil {
  58. return err
  59. }
  60. srv.HostSigners = append(srv.HostSigners, signer)
  61. }
  62. return nil
  63. }
  64. func (srv *Server) ensureHandlers() {
  65. srv.mu.Lock()
  66. defer srv.mu.Unlock()
  67. if srv.RequestHandlers == nil {
  68. srv.RequestHandlers = map[string]RequestHandler{}
  69. for k, v := range DefaultRequestHandlers {
  70. srv.RequestHandlers[k] = v
  71. }
  72. }
  73. if srv.ChannelHandlers == nil {
  74. srv.ChannelHandlers = map[string]ChannelHandler{}
  75. for k, v := range DefaultChannelHandlers {
  76. srv.ChannelHandlers[k] = v
  77. }
  78. }
  79. }
  80. func (srv *Server) config(ctx Context) *gossh.ServerConfig {
  81. var config *gossh.ServerConfig
  82. if srv.ServerConfigCallback == nil {
  83. config = &gossh.ServerConfig{}
  84. } else {
  85. config = srv.ServerConfigCallback(ctx)
  86. }
  87. for _, signer := range srv.HostSigners {
  88. config.AddHostKey(signer)
  89. }
  90. if srv.PasswordHandler == nil && srv.PublicKeyHandler == nil {
  91. config.NoClientAuth = true
  92. }
  93. if srv.Version != "" {
  94. config.ServerVersion = "SSH-2.0-" + srv.Version
  95. }
  96. if srv.PasswordHandler != nil {
  97. config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) {
  98. applyConnMetadata(ctx, conn)
  99. if ok := srv.PasswordHandler(ctx, string(password)); !ok {
  100. return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
  101. }
  102. return ctx.Permissions().Permissions, nil
  103. }
  104. }
  105. if srv.PublicKeyHandler != nil {
  106. config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) {
  107. applyConnMetadata(ctx, conn)
  108. if ok := srv.PublicKeyHandler(ctx, key); !ok {
  109. return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
  110. }
  111. ctx.SetValue(ContextKeyPublicKey, key)
  112. return ctx.Permissions().Permissions, nil
  113. }
  114. }
  115. if srv.KeyboardInteractiveHandler != nil {
  116. config.KeyboardInteractiveCallback = func(conn gossh.ConnMetadata, challenger gossh.KeyboardInteractiveChallenge) (*gossh.Permissions, error) {
  117. if ok := srv.KeyboardInteractiveHandler(ctx, challenger); !ok {
  118. return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
  119. }
  120. return ctx.Permissions().Permissions, nil
  121. }
  122. }
  123. return config
  124. }
  125. // Handle sets the Handler for the server.
  126. func (srv *Server) Handle(fn Handler) {
  127. srv.Handler = fn
  128. }
  129. // Close immediately closes all active listeners and all active
  130. // connections.
  131. //
  132. // Close returns any error returned from closing the Server's
  133. // underlying Listener(s).
  134. func (srv *Server) Close() error {
  135. srv.mu.Lock()
  136. defer srv.mu.Unlock()
  137. srv.closeDoneChanLocked()
  138. err := srv.closeListenersLocked()
  139. for c := range srv.conns {
  140. c.Close()
  141. delete(srv.conns, c)
  142. }
  143. return err
  144. }
  145. // Shutdown gracefully shuts down the server without interrupting any
  146. // active connections. Shutdown works by first closing all open
  147. // listeners, and then waiting indefinitely for connections to close.
  148. // If the provided context expires before the shutdown is complete,
  149. // then the context's error is returned.
  150. func (srv *Server) Shutdown(ctx context.Context) error {
  151. srv.mu.Lock()
  152. lnerr := srv.closeListenersLocked()
  153. srv.closeDoneChanLocked()
  154. srv.mu.Unlock()
  155. finished := make(chan struct{}, 1)
  156. go func() {
  157. srv.listenerWg.Wait()
  158. srv.connWg.Wait()
  159. finished <- struct{}{}
  160. }()
  161. select {
  162. case <-ctx.Done():
  163. return ctx.Err()
  164. case <-finished:
  165. return lnerr
  166. }
  167. }
  168. // Serve accepts incoming connections on the Listener l, creating a new
  169. // connection goroutine for each. The connection goroutines read requests and then
  170. // calls srv.Handler to handle sessions.
  171. //
  172. // Serve always returns a non-nil error.
  173. func (srv *Server) Serve(l net.Listener) error {
  174. srv.ensureHandlers()
  175. defer l.Close()
  176. if err := srv.ensureHostSigner(); err != nil {
  177. return err
  178. }
  179. if srv.Handler == nil {
  180. srv.Handler = DefaultHandler
  181. }
  182. var tempDelay time.Duration
  183. srv.trackListener(l, true)
  184. defer srv.trackListener(l, false)
  185. for {
  186. conn, e := l.Accept()
  187. if e != nil {
  188. select {
  189. case <-srv.getDoneChan():
  190. return ErrServerClosed
  191. default:
  192. }
  193. if ne, ok := e.(net.Error); ok && ne.Temporary() {
  194. if tempDelay == 0 {
  195. tempDelay = 5 * time.Millisecond
  196. } else {
  197. tempDelay *= 2
  198. }
  199. if max := 1 * time.Second; tempDelay > max {
  200. tempDelay = max
  201. }
  202. time.Sleep(tempDelay)
  203. continue
  204. }
  205. return e
  206. }
  207. go srv.handleConn(conn)
  208. }
  209. }
  210. func (srv *Server) handleConn(newConn net.Conn) {
  211. if srv.ConnCallback != nil {
  212. cbConn := srv.ConnCallback(newConn)
  213. if cbConn == nil {
  214. newConn.Close()
  215. return
  216. }
  217. newConn = cbConn
  218. }
  219. ctx, cancel := newContext(srv)
  220. conn := &serverConn{
  221. Conn: newConn,
  222. idleTimeout: srv.IdleTimeout,
  223. closeCanceler: cancel,
  224. }
  225. if srv.MaxTimeout > 0 {
  226. conn.maxDeadline = time.Now().Add(srv.MaxTimeout)
  227. }
  228. defer conn.Close()
  229. sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx))
  230. if err != nil {
  231. // TODO: trigger event callback
  232. return
  233. }
  234. srv.trackConn(sshConn, true)
  235. defer srv.trackConn(sshConn, false)
  236. ctx.SetValue(ContextKeyConn, sshConn)
  237. applyConnMetadata(ctx, sshConn)
  238. //go gossh.DiscardRequests(reqs)
  239. go srv.handleRequests(ctx, reqs)
  240. for ch := range chans {
  241. handler := srv.ChannelHandlers[ch.ChannelType()]
  242. if handler == nil {
  243. handler = srv.ChannelHandlers["default"]
  244. }
  245. if handler == nil {
  246. ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
  247. continue
  248. }
  249. go handler(srv, sshConn, ch, ctx)
  250. }
  251. }
  252. func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) {
  253. for req := range in {
  254. handler := srv.RequestHandlers[req.Type]
  255. if handler == nil {
  256. handler = srv.RequestHandlers["default"]
  257. }
  258. if handler == nil {
  259. req.Reply(false, nil)
  260. continue
  261. }
  262. /*reqCtx, cancel := context.WithCancel(ctx)
  263. defer cancel() */
  264. ret, payload := handler(ctx, srv, req)
  265. req.Reply(ret, payload)
  266. }
  267. }
  268. // ListenAndServe listens on the TCP network address srv.Addr and then calls
  269. // Serve to handle incoming connections. If srv.Addr is blank, ":22" is used.
  270. // ListenAndServe always returns a non-nil error.
  271. func (srv *Server) ListenAndServe() error {
  272. addr := srv.Addr
  273. if addr == "" {
  274. addr = ":22"
  275. }
  276. ln, err := net.Listen("tcp", addr)
  277. if err != nil {
  278. return err
  279. }
  280. return srv.Serve(ln)
  281. }
  282. // AddHostKey adds a private key as a host key. If an existing host key exists
  283. // with the same algorithm, it is overwritten. Each server config must have at
  284. // least one host key.
  285. func (srv *Server) AddHostKey(key Signer) {
  286. // these are later added via AddHostKey on ServerConfig, which performs the
  287. // check for one of every algorithm.
  288. srv.HostSigners = append(srv.HostSigners, key)
  289. }
  290. // SetOption runs a functional option against the server.
  291. func (srv *Server) SetOption(option Option) error {
  292. return option(srv)
  293. }
  294. func (srv *Server) getDoneChan() <-chan struct{} {
  295. srv.mu.Lock()
  296. defer srv.mu.Unlock()
  297. return srv.getDoneChanLocked()
  298. }
  299. func (srv *Server) getDoneChanLocked() chan struct{} {
  300. if srv.doneChan == nil {
  301. srv.doneChan = make(chan struct{})
  302. }
  303. return srv.doneChan
  304. }
  305. func (srv *Server) closeDoneChanLocked() {
  306. ch := srv.getDoneChanLocked()
  307. select {
  308. case <-ch:
  309. // Already closed. Don't close again.
  310. default:
  311. // Safe to close here. We're the only closer, guarded
  312. // by srv.mu.
  313. close(ch)
  314. }
  315. }
  316. func (srv *Server) closeListenersLocked() error {
  317. var err error
  318. for ln := range srv.listeners {
  319. if cerr := ln.Close(); cerr != nil && err == nil {
  320. err = cerr
  321. }
  322. delete(srv.listeners, ln)
  323. }
  324. return err
  325. }
  326. func (srv *Server) trackListener(ln net.Listener, add bool) {
  327. srv.mu.Lock()
  328. defer srv.mu.Unlock()
  329. if srv.listeners == nil {
  330. srv.listeners = make(map[net.Listener]struct{})
  331. }
  332. if add {
  333. // If the *Server is being reused after a previous
  334. // Close or Shutdown, reset its doneChan:
  335. if len(srv.listeners) == 0 && len(srv.conns) == 0 {
  336. srv.doneChan = nil
  337. }
  338. srv.listeners[ln] = struct{}{}
  339. srv.listenerWg.Add(1)
  340. } else {
  341. delete(srv.listeners, ln)
  342. srv.listenerWg.Done()
  343. }
  344. }
  345. func (srv *Server) trackConn(c *gossh.ServerConn, add bool) {
  346. srv.mu.Lock()
  347. defer srv.mu.Unlock()
  348. if srv.conns == nil {
  349. srv.conns = make(map[*gossh.ServerConn]struct{})
  350. }
  351. if add {
  352. srv.conns[c] = struct{}{}
  353. srv.connWg.Add(1)
  354. } else {
  355. delete(srv.conns, c)
  356. srv.connWg.Done()
  357. }
  358. }
上海开阖软件有限公司 沪ICP备12045867号-1