diff --git a/experimental/libbox/http.go b/experimental/libbox/http.go new file mode 100644 index 00000000..982962a6 --- /dev/null +++ b/experimental/libbox/http.go @@ -0,0 +1,241 @@ +package libbox + +import ( + "bytes" + "context" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/hex" + "errors" + "fmt" + "io" + "math/rand" + "net" + "net/http" + "net/url" + "os" + "strconv" + "sync" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/protocol/socks" + "github.com/sagernet/sing/protocol/socks/socks5" +) + +type HTTPClient interface { + RestrictedTLS() + ModernTLS() + PinnedTLS12() + PinnedSHA256(sumHex string) + TrySocks5(port int32) + KeepAlive() + NewRequest() HTTPRequest + Close() +} + +type HTTPRequest interface { + SetURL(link string) error + SetMethod(method string) + SetHeader(key string, value string) + SetContent(content []byte) + SetContentString(content string) + RandomUserAgent() + SetUserAgent(userAgent string) + Execute() (HTTPResponse, error) +} + +type HTTPResponse interface { + GetContent() ([]byte, error) + GetContentString() (string, error) + WriteTo(path string) error +} + +var ( + _ HTTPClient = (*httpClient)(nil) + _ HTTPRequest = (*httpRequest)(nil) + _ HTTPResponse = (*httpResponse)(nil) +) + +type httpClient struct { + tls tls.Config + client http.Client + transport http.Transport +} + +func NewHTTPClient() HTTPClient { + client := new(httpClient) + client.client.Timeout = C.TCPTimeout + client.client.Transport = &client.transport + client.transport.TLSClientConfig = &client.tls + client.transport.DisableKeepAlives = true + return client +} + +func (c *httpClient) ModernTLS() { + c.tls.MinVersion = tls.VersionTLS12 + c.tls.CipherSuites = common.Map(tls.CipherSuites(), func(it *tls.CipherSuite) uint16 { return it.ID }) +} + +func (c *httpClient) RestrictedTLS() { + c.tls.MinVersion = tls.VersionTLS13 + c.tls.CipherSuites = common.Map(common.Filter(tls.CipherSuites(), func(it *tls.CipherSuite) bool { + return common.Contains(it.SupportedVersions, uint16(tls.VersionTLS13)) + }), func(it *tls.CipherSuite) uint16 { + return it.ID + }) +} + +func (c *httpClient) PinnedTLS12() { + c.tls.MinVersion = tls.VersionTLS12 + c.tls.MaxVersion = tls.VersionTLS12 +} + +func (c *httpClient) PinnedSHA256(sumHex string) { + c.tls.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + for _, rawCert := range rawCerts { + certSum := sha256.Sum256(rawCert) + if sumHex == hex.EncodeToString(certSum[:]) { + return nil + } + } + return E.New("pinned sha256 sum mismatch") + } +} + +func (c *httpClient) TrySocks5(port int32) { + dialer := new(net.Dialer) + c.transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + for { + socksConn, err := dialer.DialContext(ctx, "tcp", "127.0.0.1:"+strconv.Itoa(int(port))) + if err != nil { + break + } + _, err = socks.ClientHandshake5(socksConn, socks5.CommandConnect, M.ParseSocksaddr(addr), "", "") + if err != nil { + break + } + //nolint:staticcheck + return socksConn, err + } + return dialer.DialContext(ctx, network, addr) + } +} + +func (c *httpClient) KeepAlive() { + c.transport.ForceAttemptHTTP2 = true + c.transport.DisableKeepAlives = false +} + +func (c *httpClient) NewRequest() HTTPRequest { + req := &httpRequest{httpClient: c} + req.request = http.Request{ + Method: "GET", + Header: http.Header{}, + } + return req +} + +func (c *httpClient) Close() { + c.transport.CloseIdleConnections() +} + +type httpRequest struct { + *httpClient + request http.Request +} + +func (r *httpRequest) SetURL(link string) (err error) { + r.request.URL, err = url.Parse(link) + if r.request.URL.User != nil { + user := r.request.URL.User.Username() + password, _ := r.request.URL.User.Password() + r.request.SetBasicAuth(user, password) + } + return +} + +func (r *httpRequest) SetMethod(method string) { + r.request.Method = method +} + +func (r *httpRequest) SetHeader(key string, value string) { + r.request.Header.Set(key, value) +} + +func (r *httpRequest) RandomUserAgent() { + r.request.Header.Set("User-Agent", fmt.Sprintf("curl/7.%d.%d", rand.Int()%54, rand.Int()%2)) +} + +func (r *httpRequest) SetUserAgent(userAgent string) { + r.request.Header.Set("User-Agent", userAgent) +} + +func (r *httpRequest) SetContent(content []byte) { + buffer := bytes.Buffer{} + buffer.Write(content) + r.request.Body = io.NopCloser(bytes.NewReader(buffer.Bytes())) + r.request.ContentLength = int64(len(content)) +} + +func (r *httpRequest) SetContentString(content string) { + r.SetContent([]byte(content)) +} + +func (r *httpRequest) Execute() (HTTPResponse, error) { + response, err := r.client.Do(&r.request) + if err != nil { + return nil, err + } + httpResp := &httpResponse{Response: response} + if response.StatusCode != http.StatusOK { + return nil, errors.New(httpResp.errorString()) + } + return httpResp, nil +} + +type httpResponse struct { + *http.Response + + getContentOnce sync.Once + content []byte + contentError error +} + +func (h *httpResponse) errorString() string { + content, err := h.GetContentString() + if err != nil { + return fmt.Sprint("HTTP ", h.Status) + } + return fmt.Sprint("HTTP ", h.Status, ": ", content) +} + +func (h *httpResponse) GetContent() ([]byte, error) { + h.getContentOnce.Do(func() { + defer h.Body.Close() + h.content, h.contentError = io.ReadAll(h.Body) + }) + return h.content, h.contentError +} + +func (h *httpResponse) GetContentString() (string, error) { + content, err := h.GetContent() + if err != nil { + return "", err + } + return string(content), nil +} + +func (h *httpResponse) WriteTo(path string) error { + defer h.Body.Close() + file, err := os.Create(path) + if err != nil { + return err + } + defer file.Close() + return common.Error(bufio.Copy(file, h.Body)) +}