finish wsc outbound communications

This commit is contained in:
Mobin 2025-09-02 19:54:19 +03:30
parent 007bcb78d1
commit 740c1f305c
6 changed files with 284 additions and 305 deletions

View File

@ -2,7 +2,6 @@ package outbound
import (
"context"
"fmt"
"net"
"time"
@ -92,7 +91,6 @@ func (wsc *WSC) ListenPacket(ctx context.Context, destination metadata.Socksaddr
meta.Outbound = wsc.tag
meta.Destination = destination
wsc.logger.InfoContext(ctx, "WSC outbound packet to ", destination)
// return wsc.dialer.ListenPacket(ctx, destination)
return wsc.client.ListenPacket(ctx, N.NetworkUDP, destination.String())
}
@ -101,7 +99,6 @@ func (wsc *WSC) NewConnection(ctx context.Context, conn net.Conn, metadata adapt
}
func (wsc *WSC) NewPacketConnection(ctx context.Context, conn network.PacketConn, metadata adapter.InboundContext) error {
fmt.Println("new packet conn: ", metadata)
return NewPacketConnection(ctx, wsc, conn, metadata)
}

View File

@ -36,23 +36,11 @@ func (cli *Client) Close(ctx context.Context) error {
}
func (cli *Client) newWSConn(ctx context.Context, network string, endpoint string) (net.Conn, error) {
scheme := "ws"
if cli.TLS != nil {
scheme = "wss"
pURL, _, err := cli.newURL("ws", "", endpoint, network)
if err != nil {
return nil, err
}
pURL := url.URL{
Scheme: scheme,
Host: cli.Host,
Path: cli.Path,
RawQuery: "",
}
pQuery := pURL.Query()
pQuery.Set("auth", cli.Auth)
pQuery.Set("ep", endpoint)
pQuery.Set("net", network)
pURL.RawQuery = pQuery.Encode()
dialer := ws.Dialer{
NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := cli.Dialer.DialContext(ctx, N.NetworkTCP, metadata.ParseSocksaddr(addr))
@ -79,27 +67,11 @@ func (cli *Client) newWSConn(ctx context.Context, network string, endpoint strin
}
func (cli *Client) cleanup(ctx context.Context) error {
scheme := "http"
var tlsConfig *tls.STDConfig
if cli.TLS != nil {
scheme = "https"
var err error
tlsConfig, err = cli.TLS.Config()
if err != nil {
return err
}
pURL, tlsConfig, err := cli.newURL("http", "/cleanup", "", "")
if err != nil {
return err
}
pURL := url.URL{
Scheme: scheme,
Host: cli.Host,
Path: "/cleanup",
RawQuery: "",
}
pQuery := pURL.Query()
pQuery.Set("auth", cli.Auth)
pURL.RawQuery = pQuery.Encode()
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
@ -122,3 +94,37 @@ func (cli *Client) cleanup(ctx context.Context) error {
return nil
}
func (cli *Client) newURL(scheme string, path string, endpoint string, network string) (url.URL, *tls.STDConfig, error) {
var tlsConfig *tls.STDConfig = nil
if cli.TLS != nil {
scheme += "s"
var err error
tlsConfig, err = cli.TLS.Config()
if err != nil {
return url.URL{}, nil, err
}
}
if path == "" {
path = cli.Path
}
pURL := url.URL{
Scheme: scheme,
Host: cli.Host,
Path: path,
RawQuery: "",
}
pQuery := pURL.Query()
pQuery.Set("auth", cli.Auth)
if endpoint != "" {
pQuery.Set("ep", endpoint)
}
if network != "" {
pQuery.Set("net", network)
}
pURL.RawQuery = pQuery.Encode()
return pURL, tlsConfig, nil
}

View File

@ -2,20 +2,24 @@ package wsc
import (
"context"
"errors"
"io"
"net"
"sync"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/network"
"github.com/sagernet/ws"
"github.com/sagernet/ws/wsutil"
)
var _ net.Conn = &clientConn{}
var _ network.ExtendedConn = &clientConn{}
type clientConn struct {
net.Conn
reader *wsutil.Reader
buf [2048]byte
mu sync.Mutex
}
@ -38,6 +42,29 @@ func (conn *clientConn) Close() error {
return conn.Conn.Close()
}
func (conn *clientConn) ReadBuffer(buffer *buf.Buffer) error {
if buffer == nil {
return errors.New("buffer is nil")
}
n, err := conn.Read(conn.buf[:])
if _, wErr := buffer.Write(conn.buf[:n]); wErr != nil {
return wErr
}
if errors.Is(err, io.EOF) {
return nil
}
return err
}
func (conn *clientConn) WriteBuffer(buffer *buf.Buffer) error {
if buffer == nil {
return errors.New("buffer is nil")
}
conn.mu.Lock()
defer conn.mu.Unlock()
return wsutil.WriteClientBinary(conn.Conn, buffer.Bytes())
}
func (conn *clientConn) Read(b []byte) (n int, err error) {
err = nil
var header ws.Header

View File

@ -1,19 +1,32 @@
package wsc
import (
"bytes"
"context"
"errors"
"io"
"net"
"sync"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/network"
"github.com/sagernet/ws"
"github.com/sagernet/ws/wsutil"
)
var _ net.PacketConn = &clientPacketConn{}
var _ network.NetPacketReader = &clientPacketConn{}
var _ network.NetPacketWriter = &clientPacketConn{}
type readerCache struct {
reader *bytes.Reader
addr metadata.Socksaddr
}
type clientPacketConn struct {
net.Conn
reader *wsutil.Reader
cache *readerCache
mu sync.Mutex
}
@ -26,34 +39,122 @@ func (cli *Client) newPacketConn(ctx context.Context, network string, endpoint s
return &clientPacketConn{
Conn: conn,
reader: reader,
cache: nil,
}, nil
}
func (packetConn *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination metadata.Socksaddr, err error) {
if buffer == nil {
return metadata.Socksaddr{}, errors.New("buffer is nil")
}
buf, err := wsutil.ReadServerBinary(packetConn.Conn)
if err != nil {
var cerr wsutil.ClosedError
if errors.Is(err, &cerr) {
return metadata.Socksaddr{}, err
}
return metadata.Socksaddr{}, err
}
payload := packetConnPayload{}
if err := payload.UnmarshalBinaryUnsafe(buf); err != nil {
return metadata.Socksaddr{}, err
}
destination = metadata.SocksaddrFromNetIP(payload.addrPort)
if _, err := buffer.Write(payload.payload); err != nil {
return metadata.Socksaddr{}, err
}
return destination, nil
}
func (packetConn *clientPacketConn) WritePacket(buffer *buf.Buffer, destination metadata.Socksaddr) error {
if buffer == nil {
return errors.New("buffer is nil")
}
payload := packetConnPayload{
addrPort: destination.AddrPort(),
payload: buffer.Bytes(),
}
payloadBytes, err := payload.MarshalBinary()
if err != nil {
return err
}
packetConn.mu.Lock()
defer packetConn.mu.Unlock()
if err := wsutil.WriteClientBinary(packetConn.Conn, payloadBytes); err != nil {
return err
}
return nil
}
func (packetConn *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
return 0, nil, nil
err = nil
if packetConn.cache != nil {
n, err = packetConn.cache.reader.Read(p)
addr = packetConn.cache.addr
if err == io.EOF {
err = nil
packetConn.cache = nil
} else {
return
}
}
buf, err := wsutil.ReadServerBinary(packetConn.Conn)
if err != nil {
var cerr wsutil.ClosedError
if errors.Is(err, &cerr) {
return 0, nil, io.EOF
}
return 0, nil, err
}
payload := packetConnPayload{}
if err := payload.UnmarshalBinaryUnsafe(buf); err != nil {
return 0, nil, err
}
packetConn.cache = &readerCache{
reader: bytes.NewReader(payload.payload),
addr: metadata.SocksaddrFromNetIP(payload.addrPort),
}
n, err = packetConn.cache.reader.Read(p)
addr = packetConn.cache.addr
if err == io.EOF {
packetConn.cache = nil
}
return
}
func (packetConn *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return 0, nil
}
payload := packetConnPayload{
addrPort: metadata.SocksaddrFromNet(addr).AddrPort(),
payload: p,
}
payloadBytes, err := payload.MarshalBinary()
if err != nil {
return 0, err
}
// func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
// destination := M.SocksaddrFromNet(addr)
// buffer := buf.NewSize(M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
// defer buffer.Release()
// if err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination); err != nil {
// return 0, err
// }
// if _, err = buffer.Write(p); err != nil {
// return 0, err
// }
// c.mu.Lock()
// defer c.mu.Unlock()
// if err = wsutil.WriteClientBinary(c.Conn, buffer.Bytes()); err != nil {
// return 0, err
// }
// return len(p), nil
// }
packetConn.mu.Lock()
defer packetConn.mu.Unlock()
if err := wsutil.WriteClientBinary(packetConn.Conn, payloadBytes); err != nil {
return 0, err
}
return len(payloadBytes), nil
}
func (packetConn *clientPacketConn) Close() error {
packetConn.mu.Lock()
@ -61,153 +162,3 @@ func (packetConn *clientPacketConn) Close() error {
_ = wsutil.WriteClientMessage(packetConn.Conn, ws.OpClose, nil)
return packetConn.Conn.Close()
}
/*
package wsc
import (
"bytes"
"context"
"net"
"net/url"
"sync"
"github.com/sagernet/sing-box/common/tls"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/ws"
"github.com/sagernet/ws/wsutil"
)
// clientPacketConn implements net.PacketConn over WebSocket.
type clientPacketConn struct {
net.Conn
mu sync.Mutex
}
// newPacketConn dials a WebSocket endpoint for packet based communications.
func (cli *Client) newPacketConn(ctx context.Context, network string, endpoint string) (*clientPacketConn, error) {
scheme := "ws"
if cli.TLS != nil {
scheme = "wss"
}
pURL := url.URL{
Scheme: scheme,
Host: cli.Host,
Path: cli.Path,
RawQuery: "",
}
pQuery := pURL.Query()
pQuery.Set("auth", cli.Auth)
if network != "" {
pQuery.Set("net", network)
}
if endpoint != "" {
pQuery.Set("ep", endpoint)
}
pURL.RawQuery = pQuery.Encode()
dialer := ws.Dialer{
NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := cli.Dialer.DialContext(ctx, N.NetworkTCP, M.ParseSocksaddr(addr))
if err != nil {
return nil, err
}
if cli.TLS != nil {
conn, err = tls.ClientHandshake(ctx, conn, cli.TLS)
if err != nil {
return nil, err
}
}
return conn, nil
},
}
conn, _, _, err := dialer.Dial(ctx, pURL.String())
if err != nil {
return nil, err
}
return &clientPacketConn{Conn: conn}, nil
}
// ListenPacket creates a packet-oriented WebSocket connection.
func (cli *Client) ListenPacket(ctx context.Context, network string, endpoint string) (net.PacketConn, error) {
return cli.newPacketConn(ctx, network, endpoint)
}
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
msg, err := wsutil.ReadServerBinary(c.Conn)
if err != nil {
return M.Socksaddr{}, err
}
reader := bytes.NewReader(msg)
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
if err != nil {
return M.Socksaddr{}, err
}
_, err = buffer.Write(msg[len(msg)-reader.Len():])
if err != nil {
return M.Socksaddr{}, err
}
return destination, nil
}
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination)))
if err := M.SocksaddrSerializer.WriteAddrPort(header, destination); err != nil {
return err
}
c.mu.Lock()
defer c.mu.Unlock()
return wsutil.WriteClientBinary(c.Conn, buffer.Bytes())
}
func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
msg, err := wsutil.ReadServerBinary(c.Conn)
if err != nil {
return 0, nil, err
}
reader := bytes.NewReader(msg)
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
if err != nil {
return 0, nil, err
}
n = copy(p, msg[len(msg)-reader.Len():])
if destination.IsFqdn() {
addr = destination
} else {
addr = destination.UDPAddr()
}
return
}
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
destination := M.SocksaddrFromNet(addr)
buffer := buf.NewSize(M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
defer buffer.Release()
if err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination); err != nil {
return 0, err
}
if _, err = buffer.Write(p); err != nil {
return 0, err
}
c.mu.Lock()
defer c.mu.Unlock()
if err = wsutil.WriteClientBinary(c.Conn, buffer.Bytes()); err != nil {
return 0, err
}
return len(p), nil
}
func (c *clientPacketConn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
_ = wsutil.WriteClientMessage(c.Conn, ws.OpClose, nil)
return c.Conn.Close()
}
func (c *clientPacketConn) LocalAddr() net.Addr {
return c.Conn.LocalAddr()
}
*/

View File

@ -1,115 +1,76 @@
package wsc
type packetConnPayload struct {
ip [16]byte
port uint16
}
/*
import (
"encoding"
"encoding/binary"
"errors"
"net"
"net/netip"
)
// Header is 18 bytes: 16 for IP + 2 for port (big-endian)
type Header struct {
IP [16]byte
Port uint16
const packetConnPayloadHeaderLen = 18
var _ encoding.BinaryMarshaler = &packetConnPayload{}
var _ encoding.BinaryUnmarshaler = &packetConnPayload{}
type packetConnPayload struct {
addrPort netip.AddrPort
payload []byte
}
const (
headerLen = 18
)
// ipv4ToMapped fills a 16-byte buffer with ::ffff:w.x.y.z
func ipv4ToMapped(v4 net.IP, dst *[16]byte) {
// v4 must be 4 bytes (no zone)
copy(dst[:], []byte{
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0xff, 0xff,
v4[0], v4[1], v4[2], v4[3],
})
}
// ipTo16Mapped returns a 16-byte IPv6 form.
// - IPv6 stays as-is (compressed/expanded form doesn't matter; we copy the 16 raw bytes).
// - IPv4 becomes IPv4-mapped IPv6 ::ffff:w.x.y.z
func ipTo16Mapped(ip net.IP) ([16]byte, error) {
var out [16]byte
if ip == nil {
return out, errors.New("nil IP")
func (payload *packetConnPayload) UnmarshalBinary(data []byte) error {
if err := payload.UnmarshalBinaryUnsafe(data); err != nil {
return err
}
if v4 := ip.To4(); v4 != nil {
ipv4ToMapped(v4, &out)
return out, nil
}
v6 := ip.To16()
if v6 == nil || len(v6) != 16 {
return out, errors.New("invalid IP")
}
copy(out[:], v6)
return out, nil
}
// NewHeader builds a Header from net.IP + port.
func NewHeader(ip net.IP, port int) (Header, error) {
var h Header
ip16, err := ipTo16Mapped(ip)
if err != nil {
return h, err
}
h.IP = ip16
if port < 0 || port > 65535 {
return h, errors.New("invalid port")
}
h.Port = uint16(port)
return h, nil
}
payload.payload = append(make([]byte, 0, len(payload.payload)), payload.payload...)
// FromTCPAddr / FromUDPAddr convenience.
func FromTCPAddr(a *net.TCPAddr) (Header, error) { return NewHeader(a.IP, a.Port) }
func FromUDPAddr(a *net.UDPAddr) (Header, error) { return NewHeader(a.IP, a.Port) }
// MarshalBinary -> 18 bytes
func (h Header) MarshalBinary() []byte {
b := make([]byte, headerLen)
copy(b[:16], h.IP[:])
binary.BigEndian.PutUint16(b[16:], h.Port)
return b
}
// UnmarshalBinary <- 18 bytes
func (h *Header) UnmarshalBinary(b []byte) error {
if len(b) < headerLen {
return errors.New("short header")
}
copy(h.IP[:], b[:16])
h.Port = binary.BigEndian.Uint16(b[16:18])
return nil
}
// ToNetAddr returns a *net.TCPAddr or *net.UDPAddr-ready IP & port.
// If the address is IPv4-mapped, it returns the 4-byte form for convenience.
func (h Header) ToIPPort() (net.IP, int) {
ip := net.IP(h.IP[:]).To16()
// Detect IPv4-mapped ::ffff:w.x.y.z and convert back to v4 if you like:
if ip4 := ip.To4(); ip4 != nil {
return ip4, int(h.Port)
func (payload *packetConnPayload) MarshalBinary() (data []byte, err error) {
if !payload.addrPort.IsValid() {
return nil, errors.New("addr port is not valid")
}
return ip, int(h.Port)
data = make([]byte, len(payload.payload)+packetConnPayloadHeaderLen)
return data, payload.MarshalBinaryUnsafe(data)
}
*/
/*
// Encode
dst := net.ParseIP("192.0.2.10")
hdr, _ := NewHeader(dst, 443)
wireBytes := hdr.MarshalBinary() // 18 bytes ready to send
func (payload *packetConnPayload) UnmarshalBinaryUnsafe(data []byte) error {
const hLen = packetConnPayloadHeaderLen
// Decode
var got Header
_ = got.UnmarshalBinary(wireBytes)
ip, port := got.ToIPPort() // ip is 4-byte 192.0.2.10, port=443
*/
if len(data) < hLen {
return errors.New("invalid payload")
}
addr, ok := netip.AddrFromSlice(data[:hLen-2])
if !ok {
return errors.New("couldn't parse addr port")
}
port := binary.LittleEndian.Uint16(data[hLen-2 : hLen])
payload.addrPort = netip.AddrPortFrom(addr, port)
payload.payload = data[hLen:]
return nil
}
func (payload *packetConnPayload) MarshalBinaryUnsafe(data []byte) error {
const hLen = packetConnPayloadHeaderLen
if !payload.addrPort.IsValid() {
return errors.New("addr port is not valid")
}
if len(data) < hLen+len(payload.payload) {
return errors.New("invalid data length to write")
}
addr := payload.addrPort.Addr().As16()
copy(data[:hLen-2], addr[:])
binary.LittleEndian.PutUint16(data[hLen-2:hLen], payload.addrPort.Port())
copy(data[hLen:], payload.payload)
return nil
}

View File

@ -0,0 +1,37 @@
package wsc
import (
"net/netip"
"slices"
"testing"
)
func TestPacketPayload(t *testing.T) {
text := "salam chetori?"
payload := packetConnPayload{
addrPort: netip.MustParseAddrPort("9.9.9.9:53"),
payload: []byte(text),
}
bin, err := payload.MarshalBinary()
if err != nil {
t.Fatal(err)
}
if len(bin) != packetConnPayloadHeaderLen+len(text) {
t.Fatal("wrong marshal")
}
p2 := packetConnPayload{}
if err := p2.UnmarshalBinary(bin); err != nil {
t.Fatal(err)
}
if p2.addrPort.Port() != payload.addrPort.Port() || p2.addrPort.Addr().As16() != payload.addrPort.Addr().As16() {
t.Fatal("failed to unmarshal addrport")
}
if !slices.Equal(p2.payload, payload.payload) {
t.Fatal("unmarshaled payload not equal")
}
}