diff --git a/option/v2ray_transport.go b/option/v2ray_transport.go index 6c45b178..f5c6b962 100644 --- a/option/v2ray_transport.go +++ b/option/v2ray_transport.go @@ -78,4 +78,5 @@ type V2RayQUICOptions struct{} type V2RayGRPCOptions struct { ServiceName string `json:"service_name,omitempty"` + ForceLite bool `json:"-"` // for test } diff --git a/test/v2ray_transport_test.go b/test/v2ray_transport_test.go index 326b6e0f..0632cec8 100644 --- a/test/v2ray_transport_test.go +++ b/test/v2ray_transport_test.go @@ -20,6 +20,52 @@ func TestV2RayGRPCSelf(t *testing.T) { }) } +func TestV2RayGRPCLite(t *testing.T) { + t.Run("server", func(t *testing.T) { + testV2RayTransportSelfWith(t, &option.V2RayTransportOptions{ + Type: C.V2RayTransportTypeGRPC, + GRPCOptions: option.V2RayGRPCOptions{ + ServiceName: "TunService", + ForceLite: true, + }, + }, &option.V2RayTransportOptions{ + Type: C.V2RayTransportTypeGRPC, + GRPCOptions: option.V2RayGRPCOptions{ + ServiceName: "TunService", + }, + }) + }) + t.Run("client", func(t *testing.T) { + testV2RayTransportSelfWith(t, &option.V2RayTransportOptions{ + Type: C.V2RayTransportTypeGRPC, + GRPCOptions: option.V2RayGRPCOptions{ + ServiceName: "TunService", + }, + }, &option.V2RayTransportOptions{ + Type: C.V2RayTransportTypeGRPC, + GRPCOptions: option.V2RayGRPCOptions{ + ServiceName: "TunService", + ForceLite: true, + }, + }) + }) + t.Run("self", func(t *testing.T) { + testV2RayTransportSelfWith(t, &option.V2RayTransportOptions{ + Type: C.V2RayTransportTypeGRPC, + GRPCOptions: option.V2RayGRPCOptions{ + ServiceName: "TunService", + ForceLite: true, + }, + }, &option.V2RayTransportOptions{ + Type: C.V2RayTransportTypeGRPC, + GRPCOptions: option.V2RayGRPCOptions{ + ServiceName: "TunService", + ForceLite: true, + }, + }) + }) +} + func TestV2RayWebscoketSelf(t *testing.T) { t.Run("basic", func(t *testing.T) { testV2RayTransportSelf(t, &option.V2RayTransportOptions{ @@ -58,15 +104,19 @@ func TestV2RayHTTPPlainSelf(t *testing.T) { } func testV2RayTransportSelf(t *testing.T, transport *option.V2RayTransportOptions) { + testV2RayTransportSelfWith(t, transport, transport) +} + +func testV2RayTransportSelfWith(t *testing.T, server, client *option.V2RayTransportOptions) { t.Run("vmess", func(t *testing.T) { - testVMessTransportSelf(t, transport) + testVMessTransportSelf(t, server, client) }) t.Run("trojan", func(t *testing.T) { - testTrojanTransportSelf(t, transport) + testTrojanTransportSelf(t, server, client) }) } -func testVMessTransportSelf(t *testing.T, transport *option.V2RayTransportOptions) { +func testVMessTransportSelf(t *testing.T, server *option.V2RayTransportOptions, client *option.V2RayTransportOptions) { user, err := uuid.DefaultGenerator.NewV4() require.NoError(t, err) _, certPem, keyPem := createSelfSignedCertificate(t, "example.org") @@ -104,7 +154,7 @@ func testVMessTransportSelf(t *testing.T, transport *option.V2RayTransportOption CertificatePath: certPem, KeyPath: keyPem, }, - Transport: transport, + Transport: server, }, }, }, @@ -127,7 +177,7 @@ func testVMessTransportSelf(t *testing.T, transport *option.V2RayTransportOption ServerName: "example.org", CertificatePath: certPem, }, - Transport: transport, + Transport: client, }, }, }, @@ -145,7 +195,7 @@ func testVMessTransportSelf(t *testing.T, transport *option.V2RayTransportOption testSuit(t, clientPort, testPort) } -func testTrojanTransportSelf(t *testing.T, transport *option.V2RayTransportOptions) { +func testTrojanTransportSelf(t *testing.T, server *option.V2RayTransportOptions, client *option.V2RayTransportOptions) { user, err := uuid.DefaultGenerator.NewV4() require.NoError(t, err) _, certPem, keyPem := createSelfSignedCertificate(t, "example.org") @@ -183,7 +233,7 @@ func testTrojanTransportSelf(t *testing.T, transport *option.V2RayTransportOptio CertificatePath: certPem, KeyPath: keyPem, }, - Transport: transport, + Transport: server, }, }, }, @@ -205,7 +255,7 @@ func testTrojanTransportSelf(t *testing.T, transport *option.V2RayTransportOptio ServerName: "example.org", CertificatePath: certPem, }, - Transport: transport, + Transport: client, }, }, }, diff --git a/transport/v2ray/grpc.go b/transport/v2ray/grpc.go index d7002a56..a6f031f8 100644 --- a/transport/v2ray/grpc.go +++ b/transport/v2ray/grpc.go @@ -9,14 +9,22 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/v2raygrpc" + "github.com/sagernet/sing-box/transport/v2raygrpclite" + E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) func NewGRPCServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler, errorHandler E.Handler) (adapter.V2RayServerTransport, error) { + if options.ForceLite { + return v2raygrpclite.NewServer(ctx, options, tlsConfig, handler, errorHandler), nil + } return v2raygrpc.NewServer(ctx, options, tlsConfig, handler), nil } func NewGRPCClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayGRPCOptions, tlsConfig *tls.Config) (adapter.V2RayClientTransport, error) { + if options.ForceLite { + return v2raygrpclite.NewClient(ctx, dialer, serverAddr, options, tlsConfig), nil + } return v2raygrpc.NewClient(ctx, dialer, serverAddr, options, tlsConfig), nil } diff --git a/transport/v2raygrpclite/client.go b/transport/v2raygrpclite/client.go index 46287e01..3610f94c 100644 --- a/transport/v2raygrpclite/client.go +++ b/transport/v2raygrpclite/client.go @@ -10,14 +10,11 @@ import ( "net/url" "github.com/sagernet/sing-box/adapter" - D "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - - "golang.org/x/net/http2" ) var _ adapter.V2RayClientTransport = (*Client)(nil) @@ -44,22 +41,12 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt serverAddr: serverAddr, options: options, client: &http.Client{ - Transport: &http2.Transport{ - DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { - conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - if err != nil { - return nil, err - } - tlsConn, err := D.TLSClient(ctx, conn, cfg) - if err != nil { - return nil, err - } - return tlsConn, nil + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) }, - TLSClientConfig: tlsConfig, - AllowHTTP: false, - DisableCompression: true, - PingTimeout: 0, + ForceAttemptHTTP2: true, + TLSClientConfig: tlsConfig, }, }, url: &url.URL{ @@ -72,7 +59,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { requestPipeReader, requestPipeWriter := io.Pipe() - request := (&http.Request{ + request := &http.Request{ Method: http.MethodPost, Body: requestPipeReader, URL: c.url, @@ -80,7 +67,8 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { ProtoMajor: 2, ProtoMinor: 0, Header: defaultClientHeader, - }).WithContext(ctx) + } + request = request.WithContext(ctx) responsePipeReader, responsePipeWriter := io.Pipe() go func() { defer responsePipeWriter.Close() diff --git a/transport/v2raygrpclite/conn.go b/transport/v2raygrpclite/conn.go index 0a827a11..8f1d2354 100644 --- a/transport/v2raygrpclite/conn.go +++ b/transport/v2raygrpclite/conn.go @@ -10,7 +10,6 @@ import ( "net" "net/http" "os" - "sync" "time" "github.com/sagernet/sing/common/buf" @@ -23,15 +22,11 @@ var ErrInvalidLength = E.New("invalid length") var _ net.Conn = (*GunConn)(nil) type GunConn struct { - reader io.Reader - writer io.Writer - closer io.Closer - // mu protect done - mu sync.Mutex - done chan struct{} - - toRead []byte - readAt int + reader io.Reader + writer io.Writer + closer io.Closer + cached []byte + cachedIndex int } func newGunConn(reader io.Reader, writer io.Writer, closer io.Closer) *GunConn { @@ -39,28 +34,18 @@ func newGunConn(reader io.Reader, writer io.Writer, closer io.Closer) *GunConn { reader: reader, writer: writer, closer: closer, - done: make(chan struct{}), - } -} - -func (c *GunConn) isClosed() bool { - select { - case <-c.done: - return true - default: - return false } } func (c *GunConn) Read(b []byte) (n int, err error) { - if c.toRead != nil { - n = copy(b, c.toRead[c.readAt:]) - c.readAt += n - if c.readAt >= len(c.toRead) { - buf.Put(c.toRead) - c.toRead = nil + if c.cached != nil { + n = copy(b, c.cached[c.cachedIndex:]) + c.cachedIndex += n + if c.cachedIndex == len(c.cached) { + buf.Put(c.cached) + c.cached = nil } - return n, nil + return } buffer := buf.Get(5) _, err = io.ReadFull(c.reader, buffer) @@ -84,17 +69,14 @@ func (c *GunConn) Read(b []byte) (n int, err error) { } n = copy(b, buffer[1+protobufLengthLen:]) if n < int(protobufPayloadLen) { - c.toRead = buffer - c.readAt = 1 + int(protobufLengthLen) + n + c.cached = buffer + c.cachedIndex = 1 + int(protobufLengthLen) + n return n, nil } return n, nil } func (c *GunConn) Write(b []byte) (n int, err error) { - if c.isClosed() { - return 0, io.ErrClosedPipe - } protobufHeader := [1 + binary.MaxVarintLen64]byte{0x0A} varuintLen := binary.PutUvarint(protobufHeader[1:], uint64(len(b))) grpcHeader := buf.Get(5) @@ -108,16 +90,14 @@ func (c *GunConn) Write(b []byte) (n int, err error) { return len(b), err } +/*func (c *GunConn) ReadBuffer(buffer *buf.Buffer) error { +} + +func (c *GunConn) WriteBuffer(buffer *buf.Buffer) error { +}*/ + func (c *GunConn) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - select { - case <-c.done: - return nil - default: - close(c.done) - return c.closer.Close() - } + return c.closer.Close() } func (c *GunConn) LocalAddr() net.Addr { diff --git a/transport/v2raygrpclite/server.go b/transport/v2raygrpclite/server.go index 1b59af1d..f9ad2b73 100644 --- a/transport/v2raygrpclite/server.go +++ b/transport/v2raygrpclite/server.go @@ -11,28 +11,21 @@ import ( "strings" "github.com/sagernet/sing-box/adapter" - C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" sHttp "github.com/sagernet/sing/protocol/http" - - "golang.org/x/net/http2" ) var _ adapter.V2RayServerTransport = (*Server)(nil) type Server struct { - ctx context.Context - canceler context.CancelFunc handler N.TCPConnectionHandler errorHandler E.Handler - h2Opts *http2.ServeConnOpts - h2Server *http2.Server + httpServer *http.Server path string - tlsConfig *tls.Config } func (s *Server) Network() []string { @@ -44,21 +37,13 @@ func NewServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig * handler: handler, errorHandler: errorHandler, path: fmt.Sprintf("/%s/Tun", url.QueryEscape(options.ServiceName)), - tlsConfig: tlsConfig, - h2Server: &http2.Server{}, } - server.ctx, server.canceler = context.WithCancel(ctx) - if !common.Contains(tlsConfig.NextProtos, http2.NextProtoTLS) { - tlsConfig.NextProtos = append(tlsConfig.NextProtos, http2.NextProtoTLS) + if !common.Contains(tlsConfig.NextProtos, "h2") { + tlsConfig.NextProtos = append(tlsConfig.NextProtos, "h2") } - server.h2Opts = &http2.ServeConnOpts{ - Context: ctx, - Handler: server, - BaseConfig: &http.Server{ - ReadHeaderTimeout: C.TCPTimeout, - MaxHeaderBytes: http.DefaultMaxHeaderBytes, - Handler: server, - }, + server.httpServer = &http.Server{ + Handler: server, + TLSConfig: tlsConfig, } return server } @@ -79,14 +64,9 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { s.badRequest(request, E.New("bad content type: ", ct)) return } - writer.Header().Set("Content-Type", "application/grpc") writer.Header().Set("TE", "trailers") - writer.WriteHeader(http.StatusOK) - //if f, ok := writer.(http.Flusher); ok { - // f.Flush() - //} var metadata M.Metadata metadata.Source = sHttp.SourceAddress(request) conn := newGunConn(request.Body, writer, request.Body) @@ -98,21 +78,10 @@ func (s *Server) badRequest(request *http.Request, err error) { } func (s *Server) Serve(listener net.Listener) error { - tlsEnabled := s.tlsConfig != nil - for { - conn, err := listener.Accept() - if err != nil { - return err - } - if tlsEnabled { - tlsConn := tls.Server(conn, s.tlsConfig.Clone()) - err = tlsConn.HandshakeContext(s.ctx) - if err != nil { - continue - } - conn = tlsConn - } - go s.h2Server.ServeConn(conn, s.h2Opts) + if s.httpServer.TLSConfig == nil { + return s.httpServer.Serve(listener) + } else { + return s.httpServer.ServeTLS(listener, "", "") } } @@ -121,6 +90,5 @@ func (s *Server) ServePacket(listener net.PacketConn) error { } func (s *Server) Close() error { - s.canceler() - return nil + return common.Close(common.PtrOrNil(s.httpServer)) }