diff --git a/mitm/engine.go b/mitm/engine.go index 32a6baec..b6f502f0 100644 --- a/mitm/engine.go +++ b/mitm/engine.go @@ -17,6 +17,7 @@ import ( "path/filepath" "strings" "time" + "unicode" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" @@ -124,6 +125,7 @@ func (e *Engine) newTLS(ctx context.Context, this N.Dialer, conn net.Conn, metad acceptHTTP := len(metadata.ClientHello.SupportedProtos) == 0 || common.Contains(metadata.ClientHello.SupportedProtos, "http/1.1") acceptH2 := e.http2Enabled && common.Contains(metadata.ClientHello.SupportedProtos, "h2") if !acceptHTTP && !acceptH2 { + metadata.MITM = nil e.logger.DebugContext(ctx, "unsupported application protocol: ", strings.Join(metadata.ClientHello.SupportedProtos, ",")) e.connection.NewConnection(ctx, this, conn, metadata, onClose) return nil @@ -147,12 +149,11 @@ func (e *Engine) newTLS(ctx context.Context, this N.Dialer, conn net.Conn, metad serverName = metadata.Destination.Addr.String() } tlsConfig := &tls.Config{ - Time: e.timeFunc, - CipherSuites: metadata.ClientHello.CipherSuites, - ServerName: serverName, - CurvePreferences: metadata.ClientHello.SupportedCurves, - NextProtos: nextProtos, - MinVersion: minVersion, + Time: e.timeFunc, + ServerName: serverName, + NextProtos: nextProtos, + MinVersion: minVersion, + MaxVersion: maxVersion, GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { return sTLS.GenerateKeyPair(e.tlsCertificate, e.tlsPrivateKey, e.timeFunc, serverName) }, @@ -163,7 +164,7 @@ func (e *Engine) newTLS(ctx context.Context, this N.Dialer, conn net.Conn, metad return E.Cause(err, "TLS handshake") } if tlsConn.ConnectionState().NegotiatedProtocol == "h2" { - return e.newHTTP2(ctx, this, tlsConn, metadata, onClose) + return e.newHTTP2(ctx, this, tlsConn, tlsConfig, metadata, onClose) } else { return e.newHTTP1(ctx, this, tlsConn, tlsConfig, metadata) } @@ -171,7 +172,6 @@ func (e *Engine) newTLS(ctx context.Context, this N.Dialer, conn net.Conn, metad func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tlsConfig *tls.Config, metadata adapter.InboundContext) error { options := metadata.MITM - metadata.MITM = nil defer conn.Close() reader := bufio.NewReader(conn) request, err := sHTTP.ReadRequest(reader) @@ -209,9 +209,19 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls requestMatch = true break } + var body []byte + if options.Print && request.ContentLength > 0 && request.ContentLength <= 131072 { + body, err = io.ReadAll(request.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + request.Body = io.NopCloser(bytes.NewReader(body)) + } + if options.Print { + e.printRequest(ctx, request, body) + } if requestScript != nil { - var body []byte - if requestScript.RequiresBody() && request.ContentLength > 0 && (requestScript.MaxSize() == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScript.MaxSize()) { + if body == nil && requestScript.RequiresBody() && request.ContentLength > 0 && (requestScript.MaxSize() == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScript.MaxSize()) { body, err = io.ReadAll(request.Body) if err != nil { return E.Cause(err, "read HTTP request body") @@ -266,8 +276,9 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls request.Header.Del("Host") } if result.Body != nil { - request.Body = io.NopCloser(bytes.NewReader(result.Body)) - request.ContentLength = int64(len(result.Body)) + body = result.Body + request.Body = io.NopCloser(bytes.NewReader(body)) + request.ContentLength = int64(len(body)) } } } @@ -337,17 +348,18 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls } requestMatch = true e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String()) - var body []byte - if request.ContentLength <= 0 { - e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length") - break - } else if request.ContentLength > 131072 { - e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength) - break - } - body, err = io.ReadAll(request.Body) - if err != nil { - return E.Cause(err, "read HTTP request body") + if body == nil { + if request.ContentLength <= 0 { + e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length") + break + } else if request.ContentLength > 131072 { + e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength) + break + } + body, err = io.ReadAll(request.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } } for mi := 0; i < len(rule.Match); i++ { body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i])) @@ -366,7 +378,6 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls var ( statusCode = http.StatusOK headers = make(http.Header) - body []byte ) if rule.StatusCode > 0 { statusCode = rule.StatusCode @@ -410,26 +421,17 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls } } ctx = adapter.WithContext(ctx, &metadata) - var remoteConn net.Conn - if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() { - remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) - } else { - remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) - } - if err != nil { - return E.Cause(err, "open outbound connection") - } - defer remoteConn.Close() var innerErr atomic.TypedValue[error] httpClient := &http.Client{ Transport: &http.Transport{ - DialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { - if tlsConfig != nil { - return tls.Client(remoteConn, tlsConfig), nil + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() { + return dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) } else { - return remoteConn, nil + return this.DialContext(ctx, N.NetworkTCP, metadata.Destination) } }, + TLSClientConfig: tlsConfig, }, CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse @@ -467,17 +469,27 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls responseMatch = true break } + var responseBody []byte + if options.Print && response.ContentLength > 0 && response.ContentLength <= 131072 { + responseBody, err = io.ReadAll(response.Body) + if err != nil { + return E.Cause(err, "read HTTP response body") + } + response.Body = io.NopCloser(bytes.NewReader(responseBody)) + } + if options.Print { + e.printResponse(ctx, response, responseBody) + } if responseScript != nil { - var body []byte - if responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) { - body, err = io.ReadAll(response.Body) + if responseBody == nil && responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) { + responseBody, err = io.ReadAll(response.Body) if err != nil { return E.Cause(err, "read HTTP response body") } - response.Body = io.NopCloser(bytes.NewReader(body)) + response.Body = io.NopCloser(bytes.NewReader(responseBody)) } var result *adapter.HTTPResponseScriptResult - result, err = responseScript.Run(ctx, request, response, body) + result, err = responseScript.Run(ctx, request, response, responseBody) if err != nil { return E.Cause(err, "execute script/", responseScript.Type(), "[", responseScript.Tag(), "]") } @@ -490,8 +502,9 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls } if result.Body != nil { response.Body.Close() - response.Body = io.NopCloser(bytes.NewReader(result.Body)) - response.ContentLength = int64(len(result.Body)) + responseBody = result.Body + response.Body = io.NopCloser(bytes.NewReader(responseBody)) + response.ContentLength = int64(len(responseBody)) } } if !responseMatch { @@ -528,26 +541,27 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls } responseMatch = true e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String()) - var body []byte - if response.ContentLength <= 0 { - e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length") - break - } else if response.ContentLength > 131072 { - e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength) - break - } - body, err = io.ReadAll(response.Body) - if err != nil { - return E.Cause(err, "read HTTP request body") + if responseBody == nil { + if response.ContentLength <= 0 { + e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length") + break + } else if response.ContentLength > 131072 { + e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength) + break + } + responseBody, err = io.ReadAll(response.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } } for mi := 0; i < len(rule.Match); i++ { - body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i])) + responseBody = rule.Match[mi].ReplaceAll(responseBody, []byte(rule.Replace[i])) } - response.Body = io.NopCloser(bytes.NewReader(body)) - response.ContentLength = int64(len(body)) + response.Body = io.NopCloser(bytes.NewReader(responseBody)) + response.ContentLength = int64(len(responseBody)) } } - if !requestMatch && !responseMatch { + if !options.Print && !requestMatch && !responseMatch { e.logger.WarnContext(ctx, "request not modified") } err = response.Write(conn) @@ -559,12 +573,13 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls return nil } -func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { +func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, tlsConfig *tls.Config, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { handler := &engineHandler{ - Engine: e, - conn: conn, - dialer: this, - metadata: metadata, + Engine: e, + conn: conn, + tlsConfig: tlsConfig, + dialer: this, + metadata: metadata, httpClient: &http.Client{ Transport: &http2.Transport{ AllowHTTP: true, @@ -585,6 +600,7 @@ func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, met } return tls.Client(remoteConn, cfg), nil }, + TLSClientConfig: tlsConfig, }, CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse @@ -604,17 +620,18 @@ func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, met type engineHandler struct { *Engine - conn net.Conn - dialer N.Dialer - metadata adapter.InboundContext - onClose N.CloseHandlerFunc - + conn net.Conn + tlsConfig *tls.Config + dialer N.Dialer + metadata adapter.InboundContext + onClose N.CloseHandlerFunc httpClient *http.Client } func (e *engineHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { err := e.serveHTTP(request.Context(), writer, request) if err != nil { + e.conn.Close() if E.IsClosedOrCanceled(err) { e.logger.DebugContext(request.Context(), E.Cause(err, "connection closed")) } else { @@ -625,7 +642,6 @@ func (e *engineHandler) ServeHTTP(writer http.ResponseWriter, request *http.Requ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWriter, request *http.Request) error { options := e.metadata.MITM - e.metadata.MITM = nil rawRequestURL := request.URL rawRequestURL.Scheme = "https" if rawRequestURL.Host == "" { @@ -657,10 +673,23 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite requestMatch = true break } - var err error + var ( + body []byte + err error + ) + if options.Print && request.ContentLength > 0 && request.ContentLength <= 131072 { + body, err = io.ReadAll(request.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + request.Body.Close() + request.Body = io.NopCloser(bytes.NewReader(body)) + } + if options.Print { + e.printRequest(ctx, request, body) + } if requestScript != nil { - var body []byte - if requestScript.RequiresBody() && request.ContentLength > 0 && (requestScript.MaxSize() == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScript.MaxSize()) { + if body == nil && requestScript.RequiresBody() && request.ContentLength > 0 && (requestScript.MaxSize() == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScript.MaxSize()) { body, err = io.ReadAll(request.Body) if err != nil { return E.Cause(err, "read HTTP request body") @@ -700,6 +729,7 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite newDestination.Port = e.metadata.Destination.Port } e.metadata.Destination = newDestination + e.tlsConfig.ServerName = newURL.Hostname() } for key, values := range result.Headers { request.Header[key] = values @@ -734,6 +764,7 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite newDestination.Port = e.metadata.Destination.Port } e.metadata.Destination = newDestination + e.tlsConfig.ServerName = rule.Destination.Hostname() break } for i, rule := range options.SurgeHeaderRewrite { @@ -876,18 +907,29 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite responseMatch = true break } + var responseBody []byte + if options.Print && response.ContentLength > 0 && response.ContentLength <= 131072 { + responseBody, err = io.ReadAll(response.Body) + if err != nil { + return E.Cause(err, "read HTTP response body") + } + response.Body.Close() + response.Body = io.NopCloser(bytes.NewReader(responseBody)) + } + if options.Print { + e.printResponse(ctx, response, responseBody) + } if responseScript != nil { - var body []byte - if responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) { - body, err = io.ReadAll(response.Body) + if responseBody == nil && responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) { + responseBody, err = io.ReadAll(response.Body) if err != nil { return E.Cause(err, "read HTTP response body") } response.Body.Close() - response.Body = io.NopCloser(bytes.NewReader(body)) + response.Body = io.NopCloser(bytes.NewReader(responseBody)) } var result *adapter.HTTPResponseScriptResult - result, err = responseScript.Run(ctx, request, response, body) + result, err = responseScript.Run(ctx, request, response, responseBody) if err != nil { return E.Cause(err, "execute script/", responseScript.Type(), "[", responseScript.Tag(), "]") } @@ -938,30 +980,31 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite } responseMatch = true e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String()) - var body []byte - if response.ContentLength <= 0 { - e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length") - break - } else if response.ContentLength > 131072 { - e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength) - break + if responseBody == nil { + if response.ContentLength <= 0 { + e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length") + break + } else if response.ContentLength > 131072 { + e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength) + break + } + responseBody, err = io.ReadAll(response.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + response.Body.Close() } - body, err = io.ReadAll(response.Body) - if err != nil { - return E.Cause(err, "read HTTP request body") - } - response.Body.Close() for mi := 0; i < len(rule.Match); i++ { - body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i])) + responseBody = rule.Match[mi].ReplaceAll(responseBody, []byte(rule.Replace[i])) } - response.Body = io.NopCloser(bytes.NewReader(body)) - response.ContentLength = int64(len(body)) + response.Body = io.NopCloser(bytes.NewReader(responseBody)) + response.ContentLength = int64(len(responseBody)) } } - if !requestMatch && !responseMatch { + if !options.Print && !requestMatch && !responseMatch { e.logger.WarnContext(ctx, "request not modified") } - for key, values := range request.Header { + for key, values := range response.Header { writer.Header()[key] = values } writer.WriteHeader(response.StatusCode) @@ -973,6 +1016,45 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite return nil } +func (e *Engine) printRequest(ctx context.Context, request *http.Request, body []byte) { + e.logger.TraceContext(ctx, "request: ", request.Proto, " ", request.Method, " ", request.URL.String()) + if request.URL.Hostname() != "" && request.URL.Hostname() != request.Host { + e.logger.TraceContext(ctx, "request: ", "Host: ", request.Host) + } + for key, values := range request.Header { + for _, value := range values { + e.logger.TraceContext(ctx, "request: ", key, ": ", value) + } + } + if len(body) > 0 { + if !bytes.ContainsFunc(body, func(r rune) bool { + return !unicode.IsPrint(r) && !unicode.IsSpace(r) + }) { + e.logger.TraceContext(ctx, "request: body: ", string(body)) + } else { + e.logger.TraceContext(ctx, "request: body unprintable") + } + } +} + +func (e *Engine) printResponse(ctx context.Context, response *http.Response, body []byte) { + e.logger.TraceContext(ctx, "response: ", response.Proto, " ", response.Status) + for key, values := range response.Header { + for _, value := range values { + e.logger.TraceContext(ctx, "response: ", key, ": ", value) + } + } + if len(body) > 0 { + if !bytes.ContainsFunc(body, func(r rune) bool { + return !unicode.IsPrint(r) && !unicode.IsSpace(r) + }) { + e.logger.TraceContext(ctx, "response: ", string(body)) + } + } else { + e.logger.TraceContext(ctx, "response: body unprintable") + } +} + type simpleResponseWriter struct { statusCode int header http.Header diff --git a/option/mitm.go b/option/mitm.go index be9f0180..7166f76d 100644 --- a/option/mitm.go +++ b/option/mitm.go @@ -18,6 +18,7 @@ type TLSDecryptionOptions struct { type MITMRouteOptions struct { Enabled bool `json:"enabled,omitempty"` + Print bool `json:"print,omitempty"` Script badoption.Listable[string] `json:"script,omitempty"` SurgeURLRewrite badoption.Listable[SurgeURLRewriteLine] `json:"sg_url_rewrite,omitempty"` SurgeHeaderRewrite badoption.Listable[SurgeHeaderRewriteLine] `json:"sg_header_rewrite,omitempty"`