Update wireguard-go

This commit is contained in:
世界 2023-04-20 13:16:31 +08:00
parent 2e98777f82
commit f61c5600e0
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
7 changed files with 97 additions and 158 deletions

2
go.mod
View File

@ -35,7 +35,7 @@ require (
github.com/sagernet/tfo-go v0.0.0-20230303015439-ffcfd8c41cf9 github.com/sagernet/tfo-go v0.0.0-20230303015439-ffcfd8c41cf9
github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2 github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2
github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e
github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c github.com/sagernet/wireguard-go v0.0.0-20230420044414-a7bac1754e77
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.8.2 github.com/stretchr/testify v1.8.2
go.etcd.io/bbolt v1.3.7 go.etcd.io/bbolt v1.3.7

4
go.sum
View File

@ -131,8 +131,8 @@ github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2 h1:kDUqhc9Vsk5HJuhfI
github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2/go.mod h1:JKQMZq/O2qnZjdrt+B57olmfgEmLtY9iiSIEYtWvoSM= github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2/go.mod h1:JKQMZq/O2qnZjdrt+B57olmfgEmLtY9iiSIEYtWvoSM=
github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e h1:7uw2njHFGE+VpWamge6o56j2RWk4omF6uLKKxMmcWvs= github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e h1:7uw2njHFGE+VpWamge6o56j2RWk4omF6uLKKxMmcWvs=
github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e/go.mod h1:45TUl8+gH4SIKr4ykREbxKWTxkDlSzFENzctB1dVRRY= github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e/go.mod h1:45TUl8+gH4SIKr4ykREbxKWTxkDlSzFENzctB1dVRRY=
github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c h1:vK2wyt9aWYHHvNLWniwijBu/n4pySypiKRhN32u/JGo= github.com/sagernet/wireguard-go v0.0.0-20230420044414-a7bac1754e77 h1:g6QtRWQ2dKX7EQP++1JLNtw4C2TNxd4/ov8YUpOPOSo=
github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c/go.mod h1:euOmN6O5kk9dQmgSS8Df4psAl3TCjxOz0NW60EWkSaI= github.com/sagernet/wireguard-go v0.0.0-20230420044414-a7bac1754e77/go.mod h1:pJDdXzZIwJ+2vmnT0TKzmf8meeum+e2mTDSehw79eE0=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=

View File

@ -101,7 +101,7 @@ func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint1
return []conn.ReceiveFunc{c.receive}, 0, nil return []conn.ReceiveFunc{c.receive}, 0, nil
} }
func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) { func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) (count int, err error) {
udpConn, err := c.connect() udpConn, err := c.connect()
if err != nil { if err != nil {
select { select {
@ -113,22 +113,26 @@ func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
err = nil err = nil
return return
} }
n, addr, err := udpConn.ReadFrom(b) n, addr, err := udpConn.ReadFrom(packets[0])
if err != nil { if err != nil {
udpConn.Close() udpConn.Close()
select { select {
case <-c.done: case <-c.done:
default: default:
c.errorHandler.NewError(context.Background(), E.Cause(err, "read packet")) c.errorHandler.NewError(context.Background(), E.Cause(err, "read packet"))
err = nil
} }
return return
} }
sizes[0] = n
if n > 3 { if n > 3 {
b := packets[0]
b[1] = 0 b[1] = 0
b[2] = 0 b[2] = 0
b[3] = 0 b[3] = 0
} }
ep = Endpoint(M.SocksaddrFromNet(addr)) eps[0] = Endpoint(M.SocksaddrFromNet(addr))
count = 1
return return
} }
@ -155,12 +159,13 @@ func (c *ClientBind) SetMark(mark uint32) error {
return nil return nil
} }
func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error { func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
udpConn, err := c.connect() udpConn, err := c.connect()
if err != nil { if err != nil {
return err return err
} }
destination := M.Socksaddr(ep.(Endpoint)) destination := M.Socksaddr(ep.(Endpoint))
for _, b := range bufs {
if len(b) > 3 { if len(b) > 3 {
reserved, loaded := c.reservedForEndpoint[destination] reserved, loaded := c.reservedForEndpoint[destination]
if !loaded { if !loaded {
@ -173,14 +178,20 @@ func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
_, err = udpConn.WriteTo(b, destination) _, err = udpConn.WriteTo(b, destination)
if err != nil { if err != nil {
udpConn.Close() udpConn.Close()
}
return err return err
}
}
return nil
} }
func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) { func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
return Endpoint(M.ParseSocksaddr(s)), nil return Endpoint(M.ParseSocksaddr(s)), nil
} }
func (c *ClientBind) BatchSize() int {
return 1
}
type wireConn struct { type wireConn struct {
net.PacketConn net.PacketConn
access sync.Mutex access sync.Mutex

View File

@ -26,23 +26,24 @@ func NewNATDevice(upstream Device, ipRewrite bool) NatDevice {
return wrapper return wrapper
} }
func (d *natDeviceWrapper) Read(p []byte, offset int) (int, error) { func (d *natDeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
select { select {
case packet := <-d.outbound: case packet := <-d.outbound:
defer packet.Release() defer packet.Release()
return copy(p[offset:], packet.Bytes()), nil sizes[0] = copy(bufs[0][offset:], packet.Bytes())
return 1, nil
default: default:
} }
return d.Device.Read(p, offset) return d.Device.Read(bufs, sizes, offset)
} }
func (d *natDeviceWrapper) Write(p []byte, offset int) (int, error) { func (d *natDeviceWrapper) Write(bufs [][]byte, offset int) (count int, err error) {
packet := p[offset:] packet := bufs[0][offset:]
handled, err := d.mapping.WritePacket(packet) handled, err := d.mapping.WritePacket(packet)
if handled { if handled {
return len(packet), err return 1, err
} }
return d.Device.Write(p, offset) return d.Device.Write(bufs, offset)
} }
func (d *natDeviceWrapper) CreateDestination(session tun.RouteSession, conn tun.RouteContext) tun.DirectDestination { func (d *natDeviceWrapper) CreateDestination(session tun.RouteSession, conn tun.RouteContext) tun.DirectDestination {

View File

@ -171,49 +171,60 @@ func (w *StackDevice) File() *os.File {
return nil return nil
} }
func (w *StackDevice) Read(p []byte, offset int) (n int, err error) { func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
select { select {
case packetBuffer, ok := <-w.outbound: case packetBuffer, ok := <-w.outbound:
if !ok { if !ok {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
defer packetBuffer.DecRef() defer packetBuffer.DecRef()
p := bufs[0]
p = p[offset:] p = p[offset:]
n := 0
for _, slice := range packetBuffer.AsSlices() { for _, slice := range packetBuffer.AsSlices() {
n += copy(p[n:], slice) n += copy(p[n:], slice)
} }
sizes[0] = n
count = 1
return return
case packet := <-w.packetOutbound: case packet := <-w.packetOutbound:
defer packet.Release() defer packet.Release()
n = copy(p[offset:], packet.Bytes()) sizes[0] = copy(bufs[0][offset:], packet.Bytes())
count = 1
return return
case <-w.done: case <-w.done:
return 0, os.ErrClosed return 0, os.ErrClosed
} }
} }
func (w *StackDevice) Write(p []byte, offset int) (n int, err error) { func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
p = p[offset:] for _, b := range bufs {
if len(p) == 0 { b = b[offset:]
return if len(b) == 0 {
continue
} }
handled, err := w.mapping.WritePacket(p) handled, err := w.mapping.WritePacket(b)
if handled { if handled {
return len(p), err count++
if err != nil {
return count, err
}
continue
} }
var networkProtocol tcpip.NetworkProtocolNumber var networkProtocol tcpip.NetworkProtocolNumber
switch header.IPVersion(p) { switch header.IPVersion(b) {
case header.IPv4Version: case header.IPv4Version:
networkProtocol = header.IPv4ProtocolNumber networkProtocol = header.IPv4ProtocolNumber
case header.IPv6Version: case header.IPv6Version:
networkProtocol = header.IPv6ProtocolNumber networkProtocol = header.IPv6ProtocolNumber
} }
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{ packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: bufferv2.MakeWithData(p), Payload: bufferv2.MakeWithData(b),
}) })
defer packetBuffer.DecRef()
w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer) w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
n = len(p) packetBuffer.DecRef()
count++
}
return return
} }
@ -229,7 +240,7 @@ func (w *StackDevice) Name() (string, error) {
return "sing-box", nil return "sing-box", nil
} }
func (w *StackDevice) Events() chan wgTun.Event { func (w *StackDevice) Events() <-chan wgTun.Event {
return w.events return w.events
} }
@ -248,6 +259,10 @@ func (w *StackDevice) Close() error {
return nil return nil
} }
func (w *StackDevice) BatchSize() int {
return 1
}
func (w *StackDevice) CreateDestination(session tun.RouteSession, conn tun.RouteContext) tun.DirectDestination { func (w *StackDevice) CreateDestination(session tun.RouteSession, conn tun.RouteContext) tun.DirectDestination {
w.mapping.CreateSession(session, conn) w.mapping.CreateSession(session, conn)
return &stackNatDestination{ return &stackNatDestination{

View File

@ -27,14 +27,6 @@ type SystemDevice struct {
addr6 netip.Addr addr6 netip.Addr
} }
/*func (w *SystemDevice) NewEndpoint() (stack.LinkEndpoint, error) {
gTun, isGTun := w.device.(tun.GVisorTun)
if !isGTun {
return nil, tun.ErrGVisorUnsupported
}
return gTun.NewEndpoint()
}*/
func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes []netip.Prefix, mtu uint32) (*SystemDevice, error) { func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes []netip.Prefix, mtu uint32) (*SystemDevice, error) {
var inet4Addresses []netip.Prefix var inet4Addresses []netip.Prefix
var inet6Addresses []netip.Prefix var inet6Addresses []netip.Prefix
@ -103,12 +95,23 @@ func (w *SystemDevice) File() *os.File {
return nil return nil
} }
func (w *SystemDevice) Read(p []byte, offset int) (int, error) { func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
return w.device.Read(p[offset-tun.PacketOffset:]) sizes[0], err = w.device.Read(bufs[0][offset-tun.PacketOffset:])
if err == nil {
count = 1
}
return
} }
func (w *SystemDevice) Write(p []byte, offset int) (int, error) { func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
return w.device.Write(p[offset:]) for _, b := range bufs {
_, err = w.device.Write(b[offset:])
if err != nil {
return
}
count++
}
return
} }
func (w *SystemDevice) Flush() error { func (w *SystemDevice) Flush() error {
@ -123,10 +126,14 @@ func (w *SystemDevice) Name() (string, error) {
return w.name, nil return w.name, nil
} }
func (w *SystemDevice) Events() chan wgTun.Event { func (w *SystemDevice) Events() <-chan wgTun.Event {
return w.events return w.events
} }
func (w *SystemDevice) Close() error { func (w *SystemDevice) Close() error {
return w.device.Close() return w.device.Close()
} }
func (w *SystemDevice) BatchSize() int {
return 1
}

View File

@ -1,95 +0,0 @@
package wireguard
import (
"io"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/wireguard-go/conn"
)
var _ conn.Bind = (*ServerBind)(nil)
type ServerBind struct {
inbound chan serverPacket
done chan struct{}
writeBack N.PacketWriter
}
func NewServerBind(writeBack N.PacketWriter) *ServerBind {
return &ServerBind{
inbound: make(chan serverPacket, 256),
done: make(chan struct{}),
writeBack: writeBack,
}
}
func (s *ServerBind) Abort() error {
select {
case <-s.done:
return io.ErrClosedPipe
default:
close(s.done)
}
return nil
}
type serverPacket struct {
buffer *buf.Buffer
source M.Socksaddr
}
func (s *ServerBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
fns = []conn.ReceiveFunc{s.receive}
return
}
func (s *ServerBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
select {
case packet := <-s.inbound:
defer packet.buffer.Release()
n = copy(b, packet.buffer.Bytes())
ep = Endpoint(packet.source)
return
case <-s.done:
err = io.ErrClosedPipe
return
}
}
func (s *ServerBind) WriteIsThreadUnsafe() {
}
func (s *ServerBind) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
select {
case s.inbound <- serverPacket{
buffer: buffer,
source: destination,
}:
return nil
case <-s.done:
return io.ErrClosedPipe
}
}
func (s *ServerBind) Close() error {
return nil
}
func (s *ServerBind) SetMark(mark uint32) error {
return nil
}
func (s *ServerBind) Send(b []byte, ep conn.Endpoint) error {
return s.writeBack.WritePacket(buf.As(b), M.Socksaddr(ep.(Endpoint)))
}
func (s *ServerBind) ParseEndpoint(addr string) (conn.Endpoint, error) {
destination := M.ParseSocksaddr(addr)
if !destination.IsValid() || destination.Port == 0 {
return nil, E.New("invalid endpoint: ", addr)
}
return Endpoint(destination), nil
}