diff --git a/gin.go b/gin.go index 106e1b9c..f9ea76d8 100644 --- a/gin.go +++ b/gin.go @@ -1,13 +1,17 @@ package gin import ( + "crypto/tls" + "errors" "github.com/gin-gonic/gin/render" "github.com/julienschmidt/httprouter" "html/template" "math" + "net" "net/http" "path" "sync" + "time" ) const ( @@ -40,6 +44,7 @@ type ( finalNoRoute []HandlerFunc noRoute []HandlerFunc router *httprouter.Router + listener *stoppableListener } ) @@ -108,17 +113,156 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { } func (engine *Engine) Run(addr string) { - if err := http.ListenAndServe(addr, engine); err != nil { - panic(err) + server := &http.Server{Addr: addr, Handler: engine} + err := engine.listenAndServe(server) + if err != nil { + if err != stoppedError { + panic(err) + } } } func (engine *Engine) RunTLS(addr string, cert string, key string) { - if err := http.ListenAndServeTLS(addr, cert, key, engine); err != nil { - panic(err) + server := &http.Server{Addr: addr, Handler: engine} + err := engine.listenAndServeTLS(server, cert, key) + if err != nil { + if err != stoppedError { + panic(err) + } } } +func (engine *Engine) Stop() { + engine.listener.Stop() +} + +// Inlined from net/http source so we can inject our own listener. +func (engine *Engine) listenAndServe(srv *http.Server) error { + addr := srv.Addr + if addr == "" { + addr = ":http" + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + engine.listener, err = newStoppableListener(tcpKeepAliveListener{ln.(*net.TCPListener)}) + if err != nil { + return err + } + return srv.Serve(engine.listener) +} + +// Inlined from net/http source so we can inject our own listener. +func (engine *Engine) listenAndServeTLS(srv *http.Server, certFile string, keyFile string) error { + addr := srv.Addr + if addr == "" { + addr = ":https" + } + config := &tls.Config{} + if srv.TLSConfig != nil { + *config = *srv.TLSConfig + } + if config.NextProtos == nil { + config.NextProtos = []string{"http/1.1"} + } + + var err error + config.Certificates = make([]tls.Certificate, 1) + config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return err + } + + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config) + engine.listener, err = newStoppableListener(tlsListener) + if err != nil { + return err + } + return srv.Serve(engine.listener) +} + +/************************************/ +/******** KEEP ALIVE LISTENER *******/ +/************************************/ + +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + +/************************************/ +/******** STOPPABLE LISTENER ********/ +/************************************/ + +var stoppedError = errors.New("Webserver is being stopped") + +type stoppableListener struct { + tcpKeepAliveListener //Wrapped listener + stop chan int //Channel used only to indicate listener should shutdown +} + +func newStoppableListener(l net.Listener) (*stoppableListener, error) { + tcpL, ok := l.(tcpKeepAliveListener) + + if !ok { + return nil, errors.New("Cannot wrap listener") + } + + retval := &stoppableListener{} + retval.tcpKeepAliveListener = tcpL + retval.stop = make(chan int) + + return retval, nil +} + +func (sl *stoppableListener) Accept() (net.Conn, error) { + for { + //Wait up to one second for a new connection + sl.SetDeadline(time.Now().Add(time.Second)) + + newConn, err := sl.tcpKeepAliveListener.Accept() + + //Check for the channel being closed + select { + case <-sl.stop: + return nil, stoppedError + default: + //If the channel is still open, continue as normal + } + + if err != nil { + netErr, ok := err.(net.Error) + + //If this is a timeout, then continue to wait for + //new connections + if ok && netErr.Timeout() && netErr.Temporary() { + continue + } + } + + return newConn, err + } +} + +func (sl *stoppableListener) Stop() { + close(sl.stop) +} + /************************************/ /********** ROUTES GROUPING *********/ /************************************/