diff --git a/go.mod b/go.mod index 2edc5a48..72b0da2e 100644 --- a/go.mod +++ b/go.mod @@ -24,7 +24,7 @@ require ( github.com/sagernet/gomobile v0.0.0-20221130124640-349ebaa752ca github.com/sagernet/quic-go v0.0.0-20230202071646-a8c8afb18b32 github.com/sagernet/reality v0.0.0-20230226124550-f98d51fa21b5 - github.com/sagernet/sing v0.1.8-0.20230221060643-3401d210384b + github.com/sagernet/sing v0.1.8-0.20230226145949-3f0b21359af6 github.com/sagernet/sing-dns v0.1.4 github.com/sagernet/sing-shadowsocks v0.1.2-0.20230221080503-769c01d6bba9 github.com/sagernet/sing-shadowtls v0.0.0-20230221123345-78e50cd7b587 diff --git a/go.sum b/go.sum index 6fcf9d5d..7ee64517 100644 --- a/go.sum +++ b/go.sum @@ -129,8 +129,8 @@ github.com/sagernet/reality v0.0.0-20230226124550-f98d51fa21b5 h1:yDic66vLGsY3zq github.com/sagernet/reality v0.0.0-20230226124550-f98d51fa21b5/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU= github.com/sagernet/sing v0.0.0-20220812082120-05f9836bff8f/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= -github.com/sagernet/sing v0.1.8-0.20230221060643-3401d210384b h1:Ji2AfGlc4j9AitobOx4k3BCj7eS5nSxL1cgaL81zvlo= -github.com/sagernet/sing v0.1.8-0.20230221060643-3401d210384b/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk= +github.com/sagernet/sing v0.1.8-0.20230226145949-3f0b21359af6 h1:QLfccQ8S1nqw5+xYEM/xLXQDq70BjAeyuVWluIEytww= +github.com/sagernet/sing v0.1.8-0.20230226145949-3f0b21359af6/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk= github.com/sagernet/sing-dns v0.1.4 h1:7VxgeoSCiiazDSaXXQVcvrTBxFpOePPq/4XdgnUDN+0= github.com/sagernet/sing-dns v0.1.4/go.mod h1:1+6pCa48B1AI78lD+/i/dLgpw4MwfnsSpZo0Ds8wzzk= github.com/sagernet/sing-shadowsocks v0.1.2-0.20230221080503-769c01d6bba9 h1:qS39eA4C7x+zhEkySbASrtmb6ebdy5v0y2M6mgkmSO0= diff --git a/outbound/default.go b/outbound/default.go index fd140fb1..1e615701 100644 --- a/outbound/default.go +++ b/outbound/default.go @@ -39,30 +39,6 @@ func (a *myOutboundAdapter) Network() []string { } func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error { - ctx = adapter.WithContext(ctx, &metadata) - var outConn net.Conn - var err error - if len(metadata.DestinationAddresses) > 0 { - outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses) - } else { - outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) - } - if err != nil { - return N.HandshakeFailure(conn, err) - } - if cachedReader, isCached := conn.(N.CachedReader); isCached { - payload := cachedReader.ReadCached() - if payload != nil && !payload.IsEmpty() { - _, err = outConn.Write(payload.Bytes()) - if err != nil { - return err - } - } - } - return bufio.CopyConn(ctx, conn, outConn) -} - -func NewEarlyConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error { ctx = adapter.WithContext(ctx, &metadata) var outConn net.Conn var err error @@ -111,28 +87,30 @@ func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) erro return bufio.CopyConn(ctx, conn, serverConn) } } - _payload := buf.StackNew() - payload := common.Dup(_payload) - err := conn.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout)) - if err != os.ErrInvalid { + if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](conn); isEarlyConn && earlyConn.NeedHandshake() { + _payload := buf.StackNew() + payload := common.Dup(_payload) + err := conn.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout)) + if err != os.ErrInvalid { + if err != nil { + return err + } + _, err = payload.ReadOnceFrom(conn) + if err != nil && !E.IsTimeout(err) { + return E.Cause(err, "read payload") + } + err = conn.SetReadDeadline(time.Time{}) + if err != nil { + payload.Release() + return err + } + } + _, err = serverConn.Write(payload.Bytes()) if err != nil { - return err - } - _, err = payload.ReadOnceFrom(conn) - if err != nil && !E.IsTimeout(err) { - return E.Cause(err, "read payload") - } - err = conn.SetReadDeadline(time.Time{}) - if err != nil { - payload.Release() - return err + return N.HandshakeFailure(conn, err) } + runtime.KeepAlive(_payload) + payload.Release() } - _, err = serverConn.Write(payload.Bytes()) - if err != nil { - return N.HandshakeFailure(conn, err) - } - runtime.KeepAlive(_payload) - payload.Release() return bufio.CopyConn(ctx, conn, serverConn) } diff --git a/outbound/shadowsocks.go b/outbound/shadowsocks.go index 63781fdb..2831d4b1 100644 --- a/outbound/shadowsocks.go +++ b/outbound/shadowsocks.go @@ -125,7 +125,7 @@ func (h *Shadowsocks) ListenPacket(ctx context.Context, destination M.Socksaddr) } func (h *Shadowsocks) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { - return NewEarlyConnection(ctx, h, conn, metadata) + return NewConnection(ctx, h, conn, metadata) } func (h *Shadowsocks) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { diff --git a/outbound/trojan.go b/outbound/trojan.go index 7c11d445..690d97bb 100644 --- a/outbound/trojan.go +++ b/outbound/trojan.go @@ -96,7 +96,7 @@ func (h *Trojan) ListenPacket(ctx context.Context, destination M.Socksaddr) (net } func (h *Trojan) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { - return NewEarlyConnection(ctx, h, conn, metadata) + return NewConnection(ctx, h, conn, metadata) } func (h *Trojan) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { diff --git a/outbound/vless.go b/outbound/vless.go index 946b5c5b..31d4979f 100644 --- a/outbound/vless.go +++ b/outbound/vless.go @@ -135,7 +135,7 @@ func (h *VLESS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net. } func (h *VLESS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { - return NewEarlyConnection(ctx, h, conn, metadata) + return NewConnection(ctx, h, conn, metadata) } func (h *VLESS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { diff --git a/outbound/vmess.go b/outbound/vmess.go index 8c5fa0b9..fdbf83f3 100644 --- a/outbound/vmess.go +++ b/outbound/vmess.go @@ -133,7 +133,7 @@ func (h *VMess) ListenPacket(ctx context.Context, destination M.Socksaddr) (net. } func (h *VMess) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { - return NewEarlyConnection(ctx, h, conn, metadata) + return NewConnection(ctx, h, conn, metadata) } func (h *VMess) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { diff --git a/transport/trojan/protocol.go b/transport/trojan/protocol.go index d05a9d36..ad9174d1 100644 --- a/transport/trojan/protocol.go +++ b/transport/trojan/protocol.go @@ -26,6 +26,8 @@ const ( var CRLF = []byte{'\r', '\n'} +var _ N.EarlyConn = (*ClientConn)(nil) + type ClientConn struct { N.ExtendedConn key [KeyLength]byte @@ -41,6 +43,10 @@ func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) } } +func (c *ClientConn) NeedHandshake() bool { + return !c.headerWritten +} + func (c *ClientConn) Write(p []byte) (n int, err error) { if c.headerWritten { return c.ExtendedConn.Write(p) @@ -101,6 +107,10 @@ func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn { } } +func (c *ClientPacketConn) NeedHandshake() bool { + return !c.headerWritten +} + func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { return ReadPacket(c.Conn, buffer) } diff --git a/transport/vless/client.go b/transport/vless/client.go index 5a9fd2ed..dd70a2df 100644 --- a/transport/vless/client.go +++ b/transport/vless/client.go @@ -10,6 +10,7 @@ import ( "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/gofrs/uuid" ) @@ -82,6 +83,8 @@ func (c *Client) DialEarlyXUDPPacketConn(conn net.Conn, destination M.Socksaddr) return vmess.NewXUDPConn(&Conn{Conn: conn, protocolConn: conn, key: c.key, command: vmess.CommandMux, destination: destination, flow: c.flow}, destination), nil } +var _ N.EarlyConn = (*Conn)(nil) + type Conn struct { net.Conn protocolConn net.Conn @@ -93,6 +96,10 @@ type Conn struct { responseRead bool } +func (c *Conn) NeedHandshake() bool { + return !c.requestWritten +} + func (c *Conn) Read(b []byte) (n int, err error) { if !c.responseRead { err = ReadResponse(c.Conn)