diff --git a/common/dialer/conntrack/conn.go b/common/dialer/conntrack/conn.go index 743e39e6..d4c678c2 100644 --- a/common/dialer/conntrack/conn.go +++ b/common/dialer/conntrack/conn.go @@ -1,29 +1,32 @@ package conntrack import ( + "io" "net" - "runtime/debug" "github.com/sagernet/sing/common/x/list" ) type Conn struct { net.Conn - element *list.Element[*ConnEntry] + element *list.Element[io.Closer] } -func NewConn(conn net.Conn) *Conn { - entry := &ConnEntry{ - Conn: conn, - Stack: debug.Stack(), - } +func NewConn(conn net.Conn) (*Conn, error) { connAccess.Lock() - element := openConnection.PushBack(entry) + element := openConnection.PushBack(conn) connAccess.Unlock() + if KillerEnabled { + err := killerCheck() + if err != nil { + conn.Close() + return nil, err + } + } return &Conn{ Conn: conn, element: element, - } + }, nil } func (c *Conn) Close() error { diff --git a/common/dialer/conntrack/killer.go b/common/dialer/conntrack/killer.go new file mode 100644 index 00000000..40224462 --- /dev/null +++ b/common/dialer/conntrack/killer.go @@ -0,0 +1,38 @@ +package conntrack + +import ( + "runtime" + runtimeDebug "runtime/debug" + "time" + + E "github.com/sagernet/sing/common/exceptions" +) + +var ( + KillerEnabled bool + MemoryLimit int64 + killerLastCheck time.Time +) + +func killerCheck() error { + if !KillerEnabled { + return nil + } + nowTime := time.Now() + if nowTime.Sub(killerLastCheck) < 3*time.Second { + return nil + } + killerLastCheck = nowTime + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + inuseMemory := int64(memStats.StackInuse + memStats.HeapInuse + memStats.HeapIdle - memStats.HeapReleased) + if inuseMemory > MemoryLimit { + Close() + go func() { + time.Sleep(time.Second) + runtimeDebug.FreeOSMemory() + }() + return E.New("out of memory") + } + return nil +} diff --git a/common/dialer/conntrack/packet_conn.go b/common/dialer/conntrack/packet_conn.go index 00c56cec..33028a69 100644 --- a/common/dialer/conntrack/packet_conn.go +++ b/common/dialer/conntrack/packet_conn.go @@ -1,29 +1,32 @@ package conntrack import ( + "io" "net" - "runtime/debug" "github.com/sagernet/sing/common/x/list" ) type PacketConn struct { net.PacketConn - element *list.Element[*ConnEntry] + element *list.Element[io.Closer] } -func NewPacketConn(conn net.PacketConn) *PacketConn { - entry := &ConnEntry{ - Conn: conn, - Stack: debug.Stack(), - } +func NewPacketConn(conn net.PacketConn) (*PacketConn, error) { connAccess.Lock() - element := openConnection.PushBack(entry) + element := openConnection.PushBack(conn) connAccess.Unlock() + if KillerEnabled { + err := killerCheck() + if err != nil { + conn.Close() + return nil, err + } + } return &PacketConn{ PacketConn: conn, element: element, - } + }, nil } func (c *PacketConn) Close() error { diff --git a/common/dialer/conntrack/track.go b/common/dialer/conntrack/track.go index 531a5857..3de0b8eb 100644 --- a/common/dialer/conntrack/track.go +++ b/common/dialer/conntrack/track.go @@ -10,22 +10,17 @@ import ( var ( connAccess sync.RWMutex - openConnection list.List[*ConnEntry] + openConnection list.List[io.Closer] ) -type ConnEntry struct { - Conn io.Closer - Stack []byte -} - func Count() int { return openConnection.Len() } -func List() []*ConnEntry { +func List() []io.Closer { connAccess.RLock() defer connAccess.RUnlock() - connList := make([]*ConnEntry, 0, openConnection.Len()) + connList := make([]io.Closer, 0, openConnection.Len()) for element := openConnection.Front(); element != nil; element = element.Next() { connList = append(connList, element.Value) } @@ -36,8 +31,8 @@ func Close() { connAccess.Lock() defer connAccess.Unlock() for element := openConnection.Front(); element != nil; element = element.Next() { - common.Close(element.Value.Conn) + common.Close(element.Value) element.Value = nil } - openConnection = list.List[*ConnEntry]{} + openConnection.Init() } diff --git a/common/dialer/default.go b/common/dialer/default.go index 760d2d3c..ce4cd7d8 100644 --- a/common/dialer/default.go +++ b/common/dialer/default.go @@ -178,12 +178,12 @@ func trackConn(conn net.Conn, err error) (net.Conn, error) { if !conntrack.Enabled || err != nil { return conn, err } - return conntrack.NewConn(conn), nil + return conntrack.NewConn(conn) } func trackPacketConn(conn net.PacketConn, err error) (net.PacketConn, error) { if !conntrack.Enabled || err != nil { return conn, err } - return conntrack.NewPacketConn(conn), nil + return conntrack.NewPacketConn(conn) } diff --git a/experimental/libbox/memory.go b/experimental/libbox/memory.go index c3092b0e..173eaf7d 100644 --- a/experimental/libbox/memory.go +++ b/experimental/libbox/memory.go @@ -2,9 +2,17 @@ package libbox -import "runtime/debug" +import ( + runtimeDebug "runtime/debug" + + "github.com/sagernet/sing-box/common/dialer/conntrack" +) + +const memoryLimit = 30 * 1024 * 1024 func SetMemoryLimit() { - debug.SetGCPercent(10) - debug.SetMemoryLimit(30 * 1024 * 1024) + runtimeDebug.SetGCPercent(10) + runtimeDebug.SetMemoryLimit(memoryLimit) + conntrack.KillerEnabled = true + conntrack.MemoryLimit = memoryLimit }