add rule-provider clash-api

(cherry picked from commit e6e574a8e868a630fe17ce6cab05a68fe96653c1)
This commit is contained in:
PuerNya 2024-08-13 06:42:47 +08:00 committed by CHIZI-0618
parent 0a59285020
commit 0ec07e573e
10 changed files with 142 additions and 37 deletions

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"net/http" "net/http"
"net/netip" "net/netip"
"time"
"github.com/sagernet/sing-box/common/geoip" "github.com/sagernet/sing-box/common/geoip"
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
@ -34,6 +35,7 @@ type Router interface {
GeoIPReader() *geoip.Reader GeoIPReader() *geoip.Reader
LoadGeosite(code string) (Rule, error) LoadGeosite(code string) (Rule, error)
RuleSets() []RuleSet
RuleSet(tag string) (RuleSet, bool) RuleSet(tag string) (RuleSet, bool)
NeedWIFIState() bool NeedWIFIState() bool
@ -76,6 +78,7 @@ func RouterFromContext(ctx context.Context) Router {
type HeadlessRule interface { type HeadlessRule interface {
Match(metadata *InboundContext) bool Match(metadata *InboundContext) bool
RuleCount() uint64
String() string String() string
} }
@ -98,6 +101,10 @@ type DNSRule interface {
type RuleSet interface { type RuleSet interface {
Name() string Name() string
Type() string
Format() string
UpdatedTime() time.Time
Update(ctx context.Context) error
StartContext(ctx context.Context, startContext RuleSetStartContext) error StartContext(ctx context.Context, startContext RuleSetStartContext) error
PostStart() error PostStart() error
Metadata() RuleSetMetadata Metadata() RuleSetMetadata

View File

@ -1,58 +1,93 @@
package clashapi package clashapi
import ( import (
"context"
"net/http" "net/http"
"strings"
"github.com/sagernet/sing-box/adapter"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/json/badjson"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/render" "github.com/go-chi/render"
) )
func ruleProviderRouter() http.Handler { func ruleProviderRouter(router adapter.Router) http.Handler {
r := chi.NewRouter() r := chi.NewRouter()
r.Get("/", getRuleProviders) r.Get("/", getRuleProviders(router))
r.Route("/{name}", func(r chi.Router) { r.Route("/{name}", func(r chi.Router) {
r.Use(parseProviderName, findRuleProviderByName) r.Use(parseProviderName, findRuleProviderByName(router))
r.Get("/", getRuleProvider) r.Get("/", getRuleProvider)
r.Put("/", updateRuleProvider) r.Put("/", updateRuleProvider)
}) })
return r return r
} }
func getRuleProviders(w http.ResponseWriter, r *http.Request) { func ruleSetInfo(ruleSet adapter.RuleSet) *badjson.JSONObject {
render.JSON(w, r, render.M{ var info badjson.JSONObject
"providers": []string{}, info.Put("name", ruleSet.Name())
}) info.Put("type", "Rule")
info.Put("vehicleType", strings.ToUpper(ruleSet.Type()))
info.Put("behavior", strings.ToUpper(ruleSet.Format()))
info.Put("ruleCount", ruleSet.RuleCount())
info.Put("updatedAt", ruleSet.UpdatedTime().Format("2006-01-02T15:04:05.999999999-07:00"))
return &info
}
func getRuleProviders(router adapter.Router) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
providerMap := render.M{}
for i, ruleSet := range router.RuleSets() {
var tag string
if ruleSet.Name() == "" {
tag = F.ToString(i)
} else {
tag = ruleSet.Name()
}
providerMap[tag] = ruleSetInfo(ruleSet)
}
render.JSON(w, r, render.M{
"providers": providerMap,
})
}
} }
func getRuleProvider(w http.ResponseWriter, r *http.Request) { func getRuleProvider(w http.ResponseWriter, r *http.Request) {
// provider := r.Context().Value(CtxKeyProvider).(provider.RuleProvider) ruleSet := r.Context().Value(CtxKeyProvider).(adapter.RuleSet)
// render.JSON(w, r, provider) response, err := ruleSetInfo(ruleSet).MarshalJSON()
render.NoContent(w, r) if err != nil {
render.Status(r, http.StatusInternalServerError)
render.JSON(w, r, newError(err.Error()))
return
}
w.Write(response)
} }
func updateRuleProvider(w http.ResponseWriter, r *http.Request) { func updateRuleProvider(w http.ResponseWriter, r *http.Request) {
/*provider := r.Context().Value(CtxKeyProvider).(provider.RuleProvider) ruleSet := r.Context().Value(CtxKeyProvider).(adapter.RuleSet)
if err := provider.Update(); err != nil { err := ruleSet.Update(r.Context())
render.Status(r, http.StatusServiceUnavailable) if err != nil {
render.Status(r, http.StatusInternalServerError)
render.JSON(w, r, newError(err.Error())) render.JSON(w, r, newError(err.Error()))
return return
}*/ }
render.NoContent(w, r) render.NoContent(w, r)
} }
func findRuleProviderByName(next http.Handler) http.Handler { func findRuleProviderByName(router adapter.Router) func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return func(next http.Handler) http.Handler {
/*name := r.Context().Value(CtxKeyProviderName).(string) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
providers := tunnel.RuleProviders() name := r.Context().Value(CtxKeyProviderName).(string)
provider, exist := providers[name] provider, exist := router.RuleSet(name)
if !exist {*/ if !exist {
render.Status(r, http.StatusNotFound) render.Status(r, http.StatusNotFound)
render.JSON(w, r, ErrNotFound) render.JSON(w, r, ErrNotFound)
//return return
//} }
ctx := context.WithValue(r.Context(), CtxKeyProvider, provider)
// ctx := context.WithValue(r.Context(), CtxKeyProvider, provider) next.ServeHTTP(w, r.WithContext(ctx))
// next.ServeHTTP(w, r.WithContext(ctx)) })
}) }
} }

View File

@ -113,7 +113,7 @@ func NewServer(ctx context.Context, router adapter.Router, logFactory log.Observ
r.Mount("/rules", ruleRouter(router)) r.Mount("/rules", ruleRouter(router))
r.Mount("/connections", connectionRouter(router, trafficManager)) r.Mount("/connections", connectionRouter(router, trafficManager))
r.Mount("/providers/proxies", proxyProviderRouter()) r.Mount("/providers/proxies", proxyProviderRouter())
r.Mount("/providers/rules", ruleProviderRouter()) r.Mount("/providers/rules", ruleProviderRouter(router))
r.Mount("/script", scriptRouter()) r.Mount("/script", scriptRouter())
r.Mount("/profile", profileRouter()) r.Mount("/profile", profileRouter())
r.Mount("/cache", cacheRouter(ctx)) r.Mount("/cache", cacheRouter(ctx))

View File

@ -57,7 +57,7 @@ func (r *RuleSet) UnmarshalJSON(bytes []byte) error {
return E.New("unknown rule-set format: " + r.Format) return E.New("unknown rule-set format: " + r.Format)
} }
} else { } else {
r.Format = "" r.Format = C.RuleSetFormatSource
r.Path = "" r.Path = ""
} }
var v any var v any

View File

@ -784,6 +784,10 @@ func (r *Router) FakeIPStore() adapter.FakeIPStore {
return r.fakeIPStore return r.fakeIPStore
} }
func (r *Router) RuleSets() []adapter.RuleSet {
return r.ruleSets
}
func (r *Router) RuleSet(tag string) (adapter.RuleSet, bool) { func (r *Router) RuleSet(tag string) (adapter.RuleSet, bool) {
ruleSet, loaded := r.ruleSetMap[tag] ruleSet, loaded := r.ruleSetMap[tag]
return ruleSet, loaded return ruleSet, loaded

View File

@ -19,6 +19,7 @@ type abstractDefaultRule struct {
destinationPortItems []RuleItem destinationPortItems []RuleItem
allItems []RuleItem allItems []RuleItem
ruleSetItem RuleItem ruleSetItem RuleItem
ruleCount uint64
invert bool invert bool
outbound string outbound string
} }
@ -27,6 +28,10 @@ func (r *abstractDefaultRule) Type() string {
return C.RuleTypeDefault return C.RuleTypeDefault
} }
func (r *abstractDefaultRule) RuleCount() uint64 {
return r.ruleCount
}
func (r *abstractDefaultRule) Start() error { func (r *abstractDefaultRule) Start() error {
for _, item := range r.allItems { for _, item := range r.allItems {
if starter, isStarter := item.(interface { if starter, isStarter := item.(interface {
@ -163,16 +168,21 @@ func (r *abstractDefaultRule) String() string {
} }
type abstractLogicalRule struct { type abstractLogicalRule struct {
rules []adapter.HeadlessRule rules []adapter.HeadlessRule
mode string mode string
invert bool invert bool
outbound string outbound string
ruleCount uint64
} }
func (r *abstractLogicalRule) Type() string { func (r *abstractLogicalRule) Type() string {
return C.RuleTypeLogical return C.RuleTypeLogical
} }
func (r *abstractLogicalRule) RuleCount() uint64 {
return r.ruleCount
}
func (r *abstractLogicalRule) UpdateGeosite() error { func (r *abstractLogicalRule) UpdateGeosite() error {
for _, rule := range common.FilterIsInstance(r.rules, func(it adapter.HeadlessRule) (adapter.Rule, bool) { for _, rule := range common.FilterIsInstance(r.rules, func(it adapter.HeadlessRule) (adapter.Rule, bool) {
rule, loaded := it.(adapter.Rule) rule, loaded := it.(adapter.Rule)

View File

@ -159,6 +159,18 @@ func NewDefaultHeadlessRule(router adapter.Router, options option.DefaultHeadles
rule.destinationAddressItems = append(rule.destinationAddressItems, item) rule.destinationAddressItems = append(rule.destinationAddressItems, item)
rule.allItems = append(rule.allItems, item) rule.allItems = append(rule.allItems, item)
} }
switch true {
case len(rule.allItems) == len(rule.destinationAddressItems)+len(rule.destinationIPCIDRItems):
rule.ruleCount = uint64(len(rule.destinationAddressItems) + len(rule.destinationIPCIDRItems))
case len(rule.allItems) == len(rule.sourceAddressItems):
rule.ruleCount = uint64(len(rule.sourceAddressItems))
case len(rule.allItems) == len(rule.sourcePortItems):
rule.ruleCount = uint64(len(rule.sourcePortItems))
case len(rule.allItems) == len(rule.destinationPortItems):
rule.ruleCount = uint64(len(rule.destinationPortItems))
default:
rule.ruleCount = 1
}
return rule, nil return rule, nil
} }
@ -190,5 +202,6 @@ func NewLogicalHeadlessRule(router adapter.Router, options option.LogicalHeadles
} }
r.rules[i] = rule r.rules[i] = rule
} }
r.ruleCount = 1
return r, nil return r, nil
} }

View File

@ -26,9 +26,11 @@ type abstractRuleSet struct {
router adapter.Router router adapter.Router
logger logger.ContextLogger logger logger.ContextLogger
tag string tag string
sType string
path string path string
format string format string
rules []adapter.HeadlessRule rules []adapter.HeadlessRule
ruleCount uint64
metadata adapter.RuleSetMetadata metadata adapter.RuleSetMetadata
lastUpdated time.Time lastUpdated time.Time
refs atomic.Int32 refs atomic.Int32
@ -38,6 +40,22 @@ func (s *abstractRuleSet) Name() string {
return s.tag return s.tag
} }
func (s *abstractRuleSet) Type() string {
return s.sType
}
func (s *abstractRuleSet) Format() string {
return s.format
}
func (s *abstractRuleSet) RuleCount() uint64 {
return s.ruleCount
}
func (s *abstractRuleSet) UpdatedTime() time.Time {
return s.lastUpdated
}
func (s *abstractRuleSet) String() string { func (s *abstractRuleSet) String() string {
return strings.Join(F.MapToString(s.rules), " ") return strings.Join(F.MapToString(s.rules), " ")
} }
@ -129,18 +147,21 @@ func (s *abstractRuleSet) loadBytes(content []byte) error {
func (s *abstractRuleSet) reloadRules(headlessRules []option.HeadlessRule) error { func (s *abstractRuleSet) reloadRules(headlessRules []option.HeadlessRule) error {
rules := make([]adapter.HeadlessRule, len(headlessRules)) rules := make([]adapter.HeadlessRule, len(headlessRules))
var err error var ruleCount uint64
for i, ruleOptions := range headlessRules { for i, ruleOptions := range headlessRules {
rules[i], err = NewHeadlessRule(s.router, ruleOptions) rule, err := NewHeadlessRule(s.router, ruleOptions)
if err != nil { if err != nil {
return E.Cause(err, "parse rule_set.rules.[", i, "]") return E.Cause(err, "parse rule_set.rules.[", i, "]")
} }
rules[i] = rule
ruleCount += rule.RuleCount()
} }
var metadata adapter.RuleSetMetadata var metadata adapter.RuleSetMetadata
metadata.ContainsProcessRule = hasHeadlessRule(headlessRules, isProcessHeadlessRule) metadata.ContainsProcessRule = hasHeadlessRule(headlessRules, isProcessHeadlessRule)
metadata.ContainsWIFIRule = hasHeadlessRule(headlessRules, isWIFIHeadlessRule) metadata.ContainsWIFIRule = hasHeadlessRule(headlessRules, isWIFIHeadlessRule)
metadata.ContainsIPCIDRRule = hasHeadlessRule(headlessRules, isIPCIDRHeadlessRule) metadata.ContainsIPCIDRRule = hasHeadlessRule(headlessRules, isIPCIDRHeadlessRule)
s.rules = rules s.rules = rules
s.ruleCount = ruleCount
s.metadata = metadata s.metadata = metadata
return nil return nil
} }

View File

@ -28,6 +28,8 @@ func NewLocalRuleSet(ctx context.Context, router adapter.Router, logger logger.C
router: router, router: router,
logger: logger, logger: logger,
tag: options.Tag, tag: options.Tag,
sType: options.Type,
format: options.Format,
}, },
} }
if options.Type == C.RuleSetTypeInline { if options.Type == C.RuleSetTypeInline {
@ -41,7 +43,6 @@ func NewLocalRuleSet(ctx context.Context, router adapter.Router, logger logger.C
return ruleSet, nil return ruleSet, nil
} }
ruleSet.path = options.Path ruleSet.path = options.Path
ruleSet.format = options.Format
path, err := ruleSet.getPath(options.Path) path, err := ruleSet.getPath(options.Path)
if err != nil { if err != nil {
return nil, err return nil, err
@ -89,6 +90,10 @@ func (s *LocalRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback)
func (s *LocalRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSetUpdateCallback]) { func (s *LocalRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSetUpdateCallback]) {
} }
func (s *LocalRuleSet) Update(ctx context.Context) error {
return nil
}
func (s *LocalRuleSet) Close() error { func (s *LocalRuleSet) Close() error {
s.rules = nil s.rules = nil
return common.Close(common.PtrOrNil(s.watcher)) return common.Close(common.PtrOrNil(s.watcher))

View File

@ -157,6 +157,16 @@ func (s *RemoteRuleSet) update() {
} }
} }
func (s *RemoteRuleSet) Update(ctx context.Context) error {
err := s.fetchOnce(log.ContextWithNewID(ctx), nil)
if err != nil {
return err
} else if s.refs.Load() == 0 {
s.rules = nil
}
return nil
}
func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext adapter.RuleSetStartContext) error { func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext adapter.RuleSetStartContext) error {
s.logger.DebugContext(ctx, "updating rule-set ", s.tag, " from URL: ", s.options.URL) s.logger.DebugContext(ctx, "updating rule-set ", s.tag, " from URL: ", s.options.URL)
var httpClient *http.Client var httpClient *http.Client