diff --git a/go.mod b/go.mod index 66b0212b..19daca12 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/goccy/go-json v0.9.8 github.com/logrusorgru/aurora v2.0.3+incompatible github.com/oschwald/geoip2-golang v1.7.0 - github.com/sagernet/sing v0.0.0-20220703122912-677c52f01aba + github.com/sagernet/sing v0.0.0-20220704113227-8b990551511a github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649 github.com/sirupsen/logrus v1.8.1 github.com/spf13/cobra v1.5.0 diff --git a/go.sum b/go.sum index ad1bd5c6..79f316de 100644 --- a/go.sum +++ b/go.sum @@ -20,8 +20,8 @@ github.com/oschwald/maxminddb-golang v1.9.0/go.mod h1:TK+s/Z2oZq0rSl4PSeAEoP0bgm github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sagernet/sing v0.0.0-20220703122912-677c52f01aba h1:ffb+Es7ddyDDOYUXKoJz5vpA+9C80GK7f7sjYN9rFvY= -github.com/sagernet/sing v0.0.0-20220703122912-677c52f01aba/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c= +github.com/sagernet/sing v0.0.0-20220704113227-8b990551511a h1:IvYjuvuPNmZzQfBbCxE/uQqGkNWUa5/KrEMIecRMjZk= +github.com/sagernet/sing v0.0.0-20220704113227-8b990551511a/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c= github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649 h1:whNDUGOAX5GPZkSy4G3Gv9QyIgk5SXRyjkRuP7ohF8k= github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649/go.mod h1:MuyT+9fEPjvauAv0fSE0a6Q+l0Tv2ZrAafTkYfnxBFw= github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= diff --git a/route/router.go b/route/router.go index 7af53144..097fcc3d 100644 --- a/route/router.go +++ b/route/router.go @@ -19,6 +19,7 @@ import ( F "github.com/sagernet/sing/common/format" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" ) var _ adapter.Router = (*Router)(nil) @@ -82,7 +83,7 @@ func isGeoRule(rule option.DefaultRule) bool { } func notPrivateNode(code string) bool { - return code == "private" + return code != "private" } func (r *Router) Initialize(outbounds []adapter.Outbound, defaultOutbound func() adapter.Outbound) error { @@ -156,7 +157,10 @@ func (r *Router) Initialize(outbounds []adapter.Outbound, defaultOutbound func() func (r *Router) Start() error { if r.needGeoDatabase { - go r.prepareGeoIPDatabase() + err := r.prepareGeoIPDatabase() + if err != nil { + return err + } } return nil } @@ -171,15 +175,17 @@ func (r *Router) GeoIPReader() *geoip2.Reader { return r.geoReader } -func (r *Router) prepareGeoIPDatabase() { +func (r *Router) prepareGeoIPDatabase() error { var geoPath string if r.geoOptions.Path != "" { geoPath = r.geoOptions.Path } else { geoPath = "Country.mmdb" + if foundPath, loaded := C.Find(geoPath); loaded { + geoPath = foundPath + } } - geoPath, loaded := C.Find(geoPath) - if !loaded { + if !rw.FileExists(geoPath) { r.logger.Warn("geoip database not exists: ", geoPath) var err error for attempts := 0; attempts < 3; attempts++ { @@ -192,7 +198,7 @@ func (r *Router) prepareGeoIPDatabase() { time.Sleep(10 * time.Second) } if err != nil { - return + return err } } geoReader, err := geoip2.Open(geoPath) @@ -200,9 +206,9 @@ func (r *Router) prepareGeoIPDatabase() { r.logger.Info("loaded geoip database") r.geoReader = geoReader } else { - r.logger.Error("open geoip database: ", err) - return + return E.Cause(err, "open geoip database") } + return nil } func (r *Router) downloadGeoIPDatabase(savePath string) error { diff --git a/route/rule.go b/route/rule.go index 38a9c737..0adb2fbe 100644 --- a/route/rule.go +++ b/route/rule.go @@ -41,9 +41,10 @@ func NewRule(router adapter.Router, logger log.Logger, options option.Rule) (ada var _ adapter.Rule = (*DefaultRule)(nil) type DefaultRule struct { - index int - outbound string - items []RuleItem + items []RuleItem + sourceAddressItems []RuleItem + destinationAddressItems []RuleItem + outbound string } type RuleItem interface { @@ -78,37 +79,37 @@ func NewDefaultRule(router adapter.Router, logger log.Logger, options option.Def rule.items = append(rule.items, NewProtocolItem(options.Protocol)) } if len(options.Domain) > 0 || len(options.DomainSuffix) > 0 { - rule.items = append(rule.items, NewDomainItem(options.Domain, options.DomainSuffix)) + rule.destinationAddressItems = append(rule.destinationAddressItems, NewDomainItem(options.Domain, options.DomainSuffix)) } if len(options.DomainKeyword) > 0 { - rule.items = append(rule.items, NewDomainKeywordItem(options.DomainKeyword)) + rule.destinationAddressItems = append(rule.destinationAddressItems, NewDomainKeywordItem(options.DomainKeyword)) } if len(options.DomainRegex) > 0 { item, err := NewDomainRegexItem(options.DomainRegex) if err != nil { return nil, E.Cause(err, "domain_regex") } - rule.items = append(rule.items, item) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) } if len(options.SourceGeoIP) > 0 { - rule.items = append(rule.items, NewGeoIPItem(router, logger, true, options.SourceGeoIP)) + rule.sourceAddressItems = append(rule.sourceAddressItems, NewGeoIPItem(router, logger, true, options.SourceGeoIP)) } if len(options.GeoIP) > 0 { - rule.items = append(rule.items, NewGeoIPItem(router, logger, false, options.GeoIP)) + rule.destinationAddressItems = append(rule.destinationAddressItems, NewGeoIPItem(router, logger, false, options.GeoIP)) } if len(options.SourceIPCIDR) > 0 { item, err := NewIPCIDRItem(true, options.SourceIPCIDR) if err != nil { return nil, E.Cause(err, "source_ipcidr") } - rule.items = append(rule.items, item) + rule.sourceAddressItems = append(rule.sourceAddressItems, item) } if len(options.IPCIDR) > 0 { item, err := NewIPCIDRItem(false, options.IPCIDR) if err != nil { return nil, E.Cause(err, "ipcidr") } - rule.items = append(rule.items, item) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) } if len(options.SourcePort) > 0 { rule.items = append(rule.items, NewPortItem(true, options.SourcePort)) @@ -121,11 +122,38 @@ func NewDefaultRule(router adapter.Router, logger log.Logger, options option.Def func (r *DefaultRule) Match(metadata *adapter.InboundContext) bool { for _, item := range r.items { - if item.Match(metadata) { - return true + if !item.Match(metadata) { + return false } } - return false + + if len(r.sourceAddressItems) > 0 { + var sourceAddressMatch bool + for _, item := range r.sourceAddressItems { + if item.Match(metadata) { + sourceAddressMatch = true + break + } + } + if !sourceAddressMatch { + return false + } + } + + if len(r.destinationAddressItems) > 0 { + var destinationAddressMatch bool + for _, item := range r.destinationAddressItems { + if item.Match(metadata) { + destinationAddressMatch = true + break + } + } + if !destinationAddressMatch { + return false + } + } + + return true } func (r *DefaultRule) Outbound() string { diff --git a/service.go b/service.go index 79665784..4ee7f18f 100644 --- a/service.go +++ b/service.go @@ -2,6 +2,7 @@ package box import ( "context" + "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/inbound" @@ -11,6 +12,7 @@ import ( "github.com/sagernet/sing-box/route" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" ) var _ adapter.Service = (*Service)(nil) @@ -20,9 +22,11 @@ type Service struct { logger log.Logger inbounds []adapter.Inbound outbounds []adapter.Outbound + createdAt time.Time } func NewService(ctx context.Context, options option.Options) (*Service, error) { + createdAt := time.Now() logger, err := log.NewLogger(common.PtrValueOrDefault(options.Log)) if err != nil { return nil, E.Cause(err, "parse log options") @@ -63,6 +67,7 @@ func NewService(ctx context.Context, options option.Options) (*Service, error) { logger: logger, inbounds: inbounds, outbounds: outbounds, + createdAt: createdAt, }, nil } @@ -71,15 +76,18 @@ func (s *Service) Start() error { if err != nil { return err } + err = s.router.Start() + if err != nil { + return err + } for _, in := range s.inbounds { err = in.Start() if err != nil { return err } } - return common.AnyError( - s.router.Start(), - ) + s.logger.Info("sing-box started (", F.Seconds(time.Since(s.createdAt).Seconds()), "s)") + return nil } func (s *Service) Close() error {