diff --git a/adapter/hook.go b/adapter/hook.go new file mode 100644 index 00000000..74c6922d --- /dev/null +++ b/adapter/hook.go @@ -0,0 +1,8 @@ +package adapter + +type Hook interface { + PreStart() error + PostStart() error + PreStop() error + PostStop() error +} diff --git a/adapter/hook/manager.go b/adapter/hook/manager.go new file mode 100644 index 00000000..abecb1af --- /dev/null +++ b/adapter/hook/manager.go @@ -0,0 +1,157 @@ +package hook + +import ( + "context" + "crypto/tls" + "errors" + "io" + "net/http" + "net/url" + "strings" + + "github.com/sagernet/sing-box/option" +) + +type Manager struct { + hook *option.HookOptions + httpExecutor *http.Client +} + +func NewManager(hook *option.HookOptions) *Manager { + return &Manager{ + hook: hook, + httpExecutor: &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + CheckRedirect: redirectChecker(false), + }, + } +} +func (m *Manager) PreStart() error { + if m.hook == nil { + return nil + } + return m.execute(m.hook.PreStart) +} +func (m *Manager) PostStart() error { + if m.hook == nil { + return nil + } + return m.execute(m.hook.PostStart) +} +func (m *Manager) PreStop() error { + if m.hook == nil { + return nil + } + return m.execute(m.hook.PreStop) +} +func (m *Manager) PostStop() error { + if m.hook == nil { + return nil + } + return m.execute(m.hook.PostStop) +} +func redirectChecker(followNonLocalRedirects bool) func(*http.Request, []*http.Request) error { + if followNonLocalRedirects { + return nil + } + return func(req *http.Request, via []*http.Request) error { + if req.URL.Hostname() != via[0].URL.Hostname() { + return http.ErrUseLastResponse + } + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + return nil + } +} +func (m *Manager) execute(execution option.Hook) error { + for _, httpExecution := range execution.HTTP { + if err := m.executeHTTP(&httpExecution); err != nil && execution.HandleError { + return err + } + } + return nil +} +func (m *Manager) executeHTTP(httpExecution *option.HTTPExecution) error { + if httpExecution == nil { + return nil + } + req, err := m.buildRequest(httpExecution) + if err != nil { + return err + } + resp, err := m.httpExecutor.Do(req) + discardHTTPRespBody(resp) + if isHTTPResponseError(err) { + req := req.Clone(context.Background()) + req.URL.Scheme = "http" + req.Header.Del("Authorization") + resp, httpErr := m.httpExecutor.Do(req) + if httpErr == nil { + err = nil + } + discardHTTPRespBody(resp) + } + return err +} +func isHTTPResponseError(err error) bool { + if err == nil { + return false + } + urlErr := &url.Error{} + if !errors.As(err, &urlErr) { + return false + } + return strings.Contains(urlErr.Err.Error(), "server gave HTTP response to HTTPS client") +} + +const ( + maxRespBodyLength = 10 * 1 << 10 +) + +func discardHTTPRespBody(resp *http.Response) { + if resp == nil { + return + } + defer resp.Body.Close() + if resp.ContentLength <= maxRespBodyLength { + io.Copy(io.Discard, &io.LimitedReader{R: resp.Body, N: maxRespBodyLength}) + } +} +func (m *Manager) buildRequest(httpExecution *option.HTTPExecution) (*http.Request, error) { + u, err := url.Parse(httpExecution.URL) + if err != nil { + return nil, err + } + headers := buildHeader(httpExecution.Headers) + return newProbeRequest(u, headers) +} +func newProbeRequest(url *url.URL, headers http.Header) (*http.Request, error) { + req, err := http.NewRequest("GET", url.String(), nil) + if err != nil { + return nil, err + } + if headers == nil { + headers = http.Header{} + } + if _, ok := headers["User-Agent"]; !ok { + headers.Set("User-Agent", "TODO://") + } + if _, ok := headers["Accept"]; !ok { + headers.Set("Accept", "*/*") + } else if headers.Get("Accept") == "" { + headers.Del("Accept") + } + req.Header = headers + req.Host = headers.Get("Host") + return req, nil +} +func buildHeader(headerList []option.Header) http.Header { + headers := make(http.Header) + for _, header := range headerList { + headers.Add(header.Name, header.Value) + } + return headers +} diff --git a/box.go b/box.go index 5dc76ebc..fbe61093 100644 --- a/box.go +++ b/box.go @@ -10,6 +10,7 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter/endpoint" + "github.com/sagernet/sing-box/adapter/hook" "github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/common/dialer" @@ -44,6 +45,7 @@ type Box struct { router *route.Router services []adapter.LifecycleService done chan struct{} + hook *hook.Manager } type Options struct { @@ -283,6 +285,7 @@ func New(options Options) (*Box, error) { logger: logFactory.Logger(), services: services, done: make(chan struct{}), + hook: hook.NewManager(options.Hook), }, nil } @@ -321,10 +324,13 @@ func (s *Box) Start() error { return err } s.logger.Info("sing-box started (", F.Seconds(time.Since(s.createdAt).Seconds()), "s)") - return nil + return s.hook.PostStart() } func (s *Box) preStart() error { + if err := s.hook.PreStart(); err != nil { + return err + } monitor := taskmonitor.New(s.logger, C.StartTimeout) monitor.Start("start logger") err := s.logFactory.Start() @@ -390,6 +396,9 @@ func (s *Box) Close() error { default: close(s.done) } + if err := s.preClose(); err != nil { + return err + } err := common.Close( s.inbound, s.outbound, s.router, s.connection, s.network, ) @@ -401,7 +410,10 @@ func (s *Box) Close() error { err = E.Append(err, s.logFactory.Close(), func(err error) error { return E.Cause(err, "close logger") }) - return err + if err != nil { + return err + } + return s.postClose() } func (s *Box) Network() adapter.NetworkManager { @@ -419,3 +431,17 @@ func (s *Box) Inbound() adapter.InboundManager { func (s *Box) Outbound() adapter.OutboundManager { return s.outbound } + +func (s *Box) preClose() error { + if err := s.hook.PreStop(); err != nil { + return err + } + return nil +} + +func (s *Box) postClose() error { + if err := s.hook.PostStop(); err != nil { + return err + } + return nil +} diff --git a/option/hook.go b/option/hook.go new file mode 100644 index 00000000..e390c1f6 --- /dev/null +++ b/option/hook.go @@ -0,0 +1,22 @@ +package option + +type HookOptions struct { + PreStart Hook `json:"preStart"` + PostStart Hook `json:"postStart"` + PreStop Hook `json:"preStop"` + PostStop Hook `json:"postStop"` +} +type Hook struct { + HandleError bool `json:"ignoreError"` + HTTP []HTTPExecution `json:"http"` + // Others: like tcp, websocket etc. +} +type Header struct { + Name string `json:"name"` + Value string `json:"value"` +} +type HTTPExecution struct { + Name string `json:"name"` + URL string `json:"url"` + Headers []Header `json:"headers"` +} diff --git a/option/options.go b/option/options.go index 94c97719..c07a570e 100644 --- a/option/options.go +++ b/option/options.go @@ -18,6 +18,7 @@ type _Options struct { Outbounds []Outbound `json:"outbounds,omitempty"` Route *RouteOptions `json:"route,omitempty"` Experimental *ExperimentalOptions `json:"experimental,omitempty"` + Hook *HookOptions `json:"hook,omitempty"` } type Options _Options