diff --git a/experimental/clashapi/mitm.go b/experimental/clashapi/mitm.go index 6afb8895..358f9694 100644 --- a/experimental/clashapi/mitm.go +++ b/experimental/clashapi/mitm.go @@ -40,7 +40,7 @@ func getMobileConfig(ctx context.Context) http.HandlerFunc { mobileConfig := map[string]interface{}{ "PayloadContent": []interface{}{ map[string]interface{}{ - "PayloadCertificateFileName": "Certificate.cer", + "PayloadCertificateFileName": "Certificates.cer", "PayloadContent": certificate.Raw, "PayloadDescription": "Adds a root certificate", "PayloadDisplayName": certificate.Subject.CommonName, diff --git a/mitm/engine.go b/mitm/engine.go index b9b830e1..8fbd8b44 100644 --- a/mitm/engine.go +++ b/mitm/engine.go @@ -26,6 +26,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/atomic" E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -165,7 +166,7 @@ func (e *Engine) newTLS(ctx context.Context, this N.Dialer, conn net.Conn, metad tlsConn := tls.Server(conn, tlsConfig) err := tlsConn.HandshakeContext(ctx) if err != nil { - return E.Cause(err, "TLS handshake") + return E.Cause(err, "TLS handshake failed for ", metadata.ClientHello.ServerName, ", ", strings.Join(metadata.ClientHello.SupportedProtos, ", ")) } if tlsConn.ConnectionState().NegotiatedProtocol == "h2" { return e.newHTTP2(ctx, this, tlsConn, tlsConfig, metadata, onClose) @@ -183,7 +184,11 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls return E.Cause(err, "read HTTP request") } rawRequestURL := request.URL - rawRequestURL.Scheme = "https" + if tlsConfig != nil { + rawRequestURL.Scheme = "https" + } else { + rawRequestURL.Scheme = "http" + } if rawRequestURL.Host == "" { rawRequestURL.Host = request.Host } @@ -482,7 +487,7 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls response.Body = io.NopCloser(bytes.NewReader(responseBody)) } if options.Print { - e.printResponse(ctx, response, responseBody) + e.printResponse(ctx, request, response, responseBody) } if responseScript != nil { if responseBody == nil && responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) { @@ -578,6 +583,22 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls } func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, tlsConfig *tls.Config, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { + httpTransport := &http.Transport{ + ForceAttemptHTTP2: true, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + ctx = adapter.WithContext(ctx, &metadata) + if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() { + return dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) + } else { + return this.DialContext(ctx, N.NetworkTCP, metadata.Destination) + } + }, + TLSClientConfig: tlsConfig, + } + err := http2.ConfigureTransport(httpTransport) + if err != nil { + return E.Cause(err, "configure HTTP/2 transport") + } handler := &engineHandler{ Engine: e, conn: conn, @@ -585,27 +606,7 @@ func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, tls dialer: this, metadata: metadata, httpClient: &http.Client{ - Transport: &http2.Transport{ - AllowHTTP: true, - MaxReadFrameSize: math.MaxUint32, - DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { - ctx = adapter.WithContext(ctx, &metadata) - var ( - remoteConn net.Conn - err error - ) - if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() { - remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) - } else { - remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) - } - if err != nil { - return nil, err - } - return tls.Client(remoteConn, cfg), nil - }, - TLSClientConfig: tlsConfig, - }, + Transport: httpTransport, CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, @@ -635,7 +636,6 @@ type engineHandler struct { func (e *engineHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { err := e.serveHTTP(request.Context(), writer, request) if err != nil { - e.conn.Close() if E.IsClosedOrCanceled(err) { e.logger.DebugContext(request.Context(), E.Cause(err, "connection closed")) } else { @@ -921,7 +921,7 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite response.Body = io.NopCloser(bytes.NewReader(responseBody)) } if options.Print { - e.printResponse(ctx, response, responseBody) + e.printResponse(ctx, request, response, responseBody) } if responseScript != nil { if responseBody == nil && responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) { @@ -1021,42 +1021,58 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite } func (e *Engine) printRequest(ctx context.Context, request *http.Request, body []byte) { - e.logger.TraceContext(ctx, "request: ", request.Proto, " ", request.Method, " ", request.URL.String()) + var builder strings.Builder + builder.WriteString(F.ToString(request.Proto, " ", request.Method, " ", request.URL)) + builder.WriteString("\n") if request.URL.Hostname() != "" && request.URL.Hostname() != request.Host { - e.logger.TraceContext(ctx, "request: ", "Host: ", request.Host) + builder.WriteString("Host: ") + builder.WriteString(request.Host) + builder.WriteString("\n") } for key, values := range request.Header { for _, value := range values { - e.logger.TraceContext(ctx, "request: ", key, ": ", value) + builder.WriteString(key) + builder.WriteString(": ") + builder.WriteString(value) + builder.WriteString("\n") } } if len(body) > 0 { + builder.WriteString("\n") if !bytes.ContainsFunc(body, func(r rune) bool { return !unicode.IsPrint(r) && !unicode.IsSpace(r) }) { - e.logger.TraceContext(ctx, "request: body: ", string(body)) + builder.Write(body) } else { - e.logger.TraceContext(ctx, "request: body unprintable") + builder.WriteString("(body not printable)") } } + e.logger.InfoContext(ctx, "request: ", builder.String()) } -func (e *Engine) printResponse(ctx context.Context, response *http.Response, body []byte) { - e.logger.TraceContext(ctx, "response: ", response.Proto, " ", response.Status) +func (e *Engine) printResponse(ctx context.Context, request *http.Request, response *http.Response, body []byte) { + var builder strings.Builder + builder.WriteString(F.ToString(response.Proto, " ", response.Status, " ", request.URL)) + builder.WriteString("\n") for key, values := range response.Header { for _, value := range values { - e.logger.TraceContext(ctx, "response: ", key, ": ", value) + builder.WriteString(key) + builder.WriteString(": ") + builder.WriteString(value) + builder.WriteString("\n") } } if len(body) > 0 { + builder.WriteString("\n") if !bytes.ContainsFunc(body, func(r rune) bool { return !unicode.IsPrint(r) && !unicode.IsSpace(r) }) { - e.logger.TraceContext(ctx, "response: ", string(body)) + builder.Write(body) + } else { + builder.WriteString("(body not printable)") } - } else { - e.logger.TraceContext(ctx, "response: body unprintable") } + e.logger.InfoContext(ctx, "response: ", builder.String()) } type simpleResponseWriter struct {