From 8e82bd96d9f12ee97f2a0c893417c7e4aedeef5a Mon Sep 17 00:00:00 2001 From: Hellojack <106379370+h1jk@users.noreply.github.com> Date: Sat, 27 Aug 2022 15:02:45 +0800 Subject: [PATCH] Add gRPC-lite implementation --- transport/v2ray/grpc.go | 2 +- .../v2ray/{grpc_stub.go => grpc_lite.go} | 9 +- transport/v2ray/transport.go | 2 +- transport/v2raygrpclite/client.go | 104 +++++++++++++ transport/v2raygrpclite/conn.go | 141 ++++++++++++++++++ transport/v2raygrpclite/server.go | 126 ++++++++++++++++ 6 files changed, 377 insertions(+), 7 deletions(-) rename transport/v2ray/{grpc_stub.go => grpc_lite.go} (60%) create mode 100644 transport/v2raygrpclite/client.go create mode 100644 transport/v2raygrpclite/conn.go create mode 100644 transport/v2raygrpclite/server.go diff --git a/transport/v2ray/grpc.go b/transport/v2ray/grpc.go index 41e70bbe..d7002a56 100644 --- a/transport/v2ray/grpc.go +++ b/transport/v2ray/grpc.go @@ -13,7 +13,7 @@ import ( N "github.com/sagernet/sing/common/network" ) -func NewGRPCServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler) (adapter.V2RayServerTransport, error) { +func NewGRPCServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler, errorHandler E.Handler) (adapter.V2RayServerTransport, error) { return v2raygrpc.NewServer(ctx, options, tlsConfig, handler), nil } diff --git a/transport/v2ray/grpc_stub.go b/transport/v2ray/grpc_lite.go similarity index 60% rename from transport/v2ray/grpc_stub.go rename to transport/v2ray/grpc_lite.go index 971492f5..35eb4462 100644 --- a/transport/v2ray/grpc_stub.go +++ b/transport/v2ray/grpc_lite.go @@ -8,17 +8,16 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/v2raygrpclite" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) -var errGRPCNotIncluded = E.New("gRPC is not included in this build, rebuild with -tags with_grpc") - -func NewGRPCServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler) (adapter.V2RayServerTransport, error) { - return nil, errGRPCNotIncluded +func NewGRPCServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler, errorHandler E.Handler) (adapter.V2RayServerTransport, error) { + return v2raygrpclite.NewServer(ctx, options, tlsConfig, handler, errorHandler), nil } func NewGRPCClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayGRPCOptions, tlsConfig *tls.Config) (adapter.V2RayClientTransport, error) { - return nil, errGRPCNotIncluded + return v2raygrpclite.NewClient(ctx, dialer, serverAddr, options, tlsConfig), nil } diff --git a/transport/v2ray/transport.go b/transport/v2ray/transport.go index 9411baaf..fc4f02e0 100644 --- a/transport/v2ray/transport.go +++ b/transport/v2ray/transport.go @@ -29,7 +29,7 @@ func NewServerTransport(ctx context.Context, options option.V2RayTransportOption } return NewQUICServer(ctx, options.QUICOptions, tlsConfig, handler, errorHandler) case C.V2RayTransportTypeGRPC: - return NewGRPCServer(ctx, options.GRPCOptions, tlsConfig, handler) + return NewGRPCServer(ctx, options.GRPCOptions, tlsConfig, handler, errorHandler) default: return nil, E.New("unknown transport type: " + options.Type) } diff --git a/transport/v2raygrpclite/client.go b/transport/v2raygrpclite/client.go new file mode 100644 index 00000000..46287e01 --- /dev/null +++ b/transport/v2raygrpclite/client.go @@ -0,0 +1,104 @@ +package v2raygrpclite + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "net/url" + + "github.com/sagernet/sing-box/adapter" + D "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "golang.org/x/net/http2" +) + +var _ adapter.V2RayClientTransport = (*Client)(nil) + +var defaultClientHeader = http.Header{ + "Content-Type": []string{"application/grpc"}, + "User-Agent": []string{"grpc-go/1.48.0"}, + "TE": []string{"trailers"}, +} + +type Client struct { + ctx context.Context + dialer N.Dialer + serverAddr M.Socksaddr + client *http.Client + options option.V2RayGRPCOptions + url *url.URL +} + +func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayGRPCOptions, tlsConfig *tls.Config) adapter.V2RayClientTransport { + return &Client{ + ctx: ctx, + dialer: dialer, + serverAddr: serverAddr, + options: options, + client: &http.Client{ + Transport: &http2.Transport{ + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + if err != nil { + return nil, err + } + tlsConn, err := D.TLSClient(ctx, conn, cfg) + if err != nil { + return nil, err + } + return tlsConn, nil + }, + TLSClientConfig: tlsConfig, + AllowHTTP: false, + DisableCompression: true, + PingTimeout: 0, + }, + }, + url: &url.URL{ + Scheme: "https", + Host: serverAddr.String(), + Path: fmt.Sprintf("/%s/Tun", url.QueryEscape(options.ServiceName)), + }, + } +} + +func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { + requestPipeReader, requestPipeWriter := io.Pipe() + request := (&http.Request{ + Method: http.MethodPost, + Body: requestPipeReader, + URL: c.url, + Proto: "HTTP/2", + ProtoMajor: 2, + ProtoMinor: 0, + Header: defaultClientHeader, + }).WithContext(ctx) + responsePipeReader, responsePipeWriter := io.Pipe() + go func() { + defer responsePipeWriter.Close() + response, err := c.client.Do(request) + if err != nil { + return + } + bufio.Copy(responsePipeWriter, response.Body) + }() + return newGunConn(responsePipeReader, requestPipeWriter, ChainedClosable{requestPipeReader, requestPipeWriter, responsePipeReader}), nil +} + +type ChainedClosable []io.Closer + +// Close implements io.Closer.Close(). +func (cc ChainedClosable) Close() error { + for _, c := range cc { + _ = common.Close(c) + } + return nil +} diff --git a/transport/v2raygrpclite/conn.go b/transport/v2raygrpclite/conn.go new file mode 100644 index 00000000..0a827a11 --- /dev/null +++ b/transport/v2raygrpclite/conn.go @@ -0,0 +1,141 @@ +// Modified from: https://github.com/Qv2ray/gun-lite +// License: MIT + +package v2raygrpclite + +import ( + "bytes" + "encoding/binary" + "io" + "net" + "net/http" + "os" + "sync" + "time" + + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" +) + +var ErrInvalidLength = E.New("invalid length") + +var _ net.Conn = (*GunConn)(nil) + +type GunConn struct { + reader io.Reader + writer io.Writer + closer io.Closer + // mu protect done + mu sync.Mutex + done chan struct{} + + toRead []byte + readAt int +} + +func newGunConn(reader io.Reader, writer io.Writer, closer io.Closer) *GunConn { + return &GunConn{ + reader: reader, + writer: writer, + closer: closer, + done: make(chan struct{}), + } +} + +func (c *GunConn) isClosed() bool { + select { + case <-c.done: + return true + default: + return false + } +} + +func (c *GunConn) Read(b []byte) (n int, err error) { + if c.toRead != nil { + n = copy(b, c.toRead[c.readAt:]) + c.readAt += n + if c.readAt >= len(c.toRead) { + buf.Put(c.toRead) + c.toRead = nil + } + return n, nil + } + buffer := buf.Get(5) + _, err = io.ReadFull(c.reader, buffer) + if err != nil { + return 0, err + } + grpcPayloadLen := binary.BigEndian.Uint32(buffer[1:]) + buf.Put(buffer) + + buffer = buf.Get(int(grpcPayloadLen)) + _, err = io.ReadFull(c.reader, buffer) + if err != nil { + return 0, io.ErrUnexpectedEOF + } + protobufPayloadLen, protobufLengthLen := binary.Uvarint(buffer[1:]) + if protobufLengthLen == 0 { + return 0, ErrInvalidLength + } + if grpcPayloadLen != uint32(protobufPayloadLen)+uint32(protobufLengthLen)+1 { + return 0, ErrInvalidLength + } + n = copy(b, buffer[1+protobufLengthLen:]) + if n < int(protobufPayloadLen) { + c.toRead = buffer + c.readAt = 1 + int(protobufLengthLen) + n + return n, nil + } + return n, nil +} + +func (c *GunConn) Write(b []byte) (n int, err error) { + if c.isClosed() { + return 0, io.ErrClosedPipe + } + protobufHeader := [1 + binary.MaxVarintLen64]byte{0x0A} + varuintLen := binary.PutUvarint(protobufHeader[1:], uint64(len(b))) + grpcHeader := buf.Get(5) + grpcPayloadLen := uint32(1 + varuintLen + len(b)) + binary.BigEndian.PutUint32(grpcHeader[1:5], grpcPayloadLen) + _, err = bufio.Copy(c.writer, io.MultiReader(bytes.NewReader(grpcHeader), bytes.NewReader(protobufHeader[:varuintLen+1]), bytes.NewReader(b))) + buf.Put(grpcHeader) + if f, ok := c.writer.(http.Flusher); ok { + f.Flush() + } + return len(b), err +} + +func (c *GunConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + select { + case <-c.done: + return nil + default: + close(c.done) + return c.closer.Close() + } +} + +func (c *GunConn) LocalAddr() net.Addr { + return nil +} + +func (c *GunConn) RemoteAddr() net.Addr { + return nil +} + +func (c *GunConn) SetDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *GunConn) SetReadDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *GunConn) SetWriteDeadline(t time.Time) error { + return os.ErrInvalid +} diff --git a/transport/v2raygrpclite/server.go b/transport/v2raygrpclite/server.go new file mode 100644 index 00000000..1b59af1d --- /dev/null +++ b/transport/v2raygrpclite/server.go @@ -0,0 +1,126 @@ +package v2raygrpclite + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "net/url" + "os" + "strings" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + sHttp "github.com/sagernet/sing/protocol/http" + + "golang.org/x/net/http2" +) + +var _ adapter.V2RayServerTransport = (*Server)(nil) + +type Server struct { + ctx context.Context + canceler context.CancelFunc + handler N.TCPConnectionHandler + errorHandler E.Handler + h2Opts *http2.ServeConnOpts + h2Server *http2.Server + path string + tlsConfig *tls.Config +} + +func (s *Server) Network() []string { + return []string{N.NetworkTCP} +} + +func NewServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig *tls.Config, handler N.TCPConnectionHandler, errorHandler E.Handler) *Server { + server := &Server{ + handler: handler, + errorHandler: errorHandler, + path: fmt.Sprintf("/%s/Tun", url.QueryEscape(options.ServiceName)), + tlsConfig: tlsConfig, + h2Server: &http2.Server{}, + } + server.ctx, server.canceler = context.WithCancel(ctx) + if !common.Contains(tlsConfig.NextProtos, http2.NextProtoTLS) { + tlsConfig.NextProtos = append(tlsConfig.NextProtos, http2.NextProtoTLS) + } + server.h2Opts = &http2.ServeConnOpts{ + Context: ctx, + Handler: server, + BaseConfig: &http.Server{ + ReadHeaderTimeout: C.TCPTimeout, + MaxHeaderBytes: http.DefaultMaxHeaderBytes, + Handler: server, + }, + } + return server +} + +func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + if request.URL.Path != s.path { + writer.WriteHeader(http.StatusNotFound) + s.badRequest(request, E.New("bad path: ", request.URL.Path)) + return + } + if request.Method != http.MethodPost { + writer.WriteHeader(http.StatusNotFound) + s.badRequest(request, E.New("bad method: ", request.Method)) + return + } + if ct := request.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/grpc") { + writer.WriteHeader(http.StatusNotFound) + s.badRequest(request, E.New("bad content type: ", ct)) + return + } + + writer.Header().Set("Content-Type", "application/grpc") + writer.Header().Set("TE", "trailers") + + writer.WriteHeader(http.StatusOK) + //if f, ok := writer.(http.Flusher); ok { + // f.Flush() + //} + var metadata M.Metadata + metadata.Source = sHttp.SourceAddress(request) + conn := newGunConn(request.Body, writer, request.Body) + s.handler.NewConnection(request.Context(), conn, metadata) +} + +func (s *Server) badRequest(request *http.Request, err error) { + s.errorHandler.NewError(request.Context(), E.Cause(err, "process connection from ", request.RemoteAddr)) +} + +func (s *Server) Serve(listener net.Listener) error { + tlsEnabled := s.tlsConfig != nil + for { + conn, err := listener.Accept() + if err != nil { + return err + } + if tlsEnabled { + tlsConn := tls.Server(conn, s.tlsConfig.Clone()) + err = tlsConn.HandshakeContext(s.ctx) + if err != nil { + continue + } + conn = tlsConn + } + go s.h2Server.ServeConn(conn, s.h2Opts) + } +} + +func (s *Server) ServePacket(listener net.PacketConn) error { + return os.ErrInvalid +} + +func (s *Server) Close() error { + s.canceler() + return nil +}