mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-06-08 11:44:13 +08:00
268 lines
7.4 KiB
Go
268 lines
7.4 KiB
Go
//go:build windows
|
|
|
|
package ndis
|
|
|
|
import (
|
|
"context"
|
|
"net/netip"
|
|
"time"
|
|
|
|
"github.com/sagernet/gvisor/pkg/buffer"
|
|
"github.com/sagernet/gvisor/pkg/tcpip"
|
|
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
|
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
|
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
|
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
|
"github.com/sagernet/sing-box/adapter"
|
|
"github.com/sagernet/sing-box/common/conntrack"
|
|
"github.com/sagernet/sing-tun"
|
|
"github.com/sagernet/sing/common/control"
|
|
"github.com/sagernet/sing/common/debug"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
"github.com/sagernet/sing/common/logger"
|
|
|
|
"github.com/wiresock/ndisapi-go"
|
|
"github.com/wiresock/ndisapi-go/driver"
|
|
"go4.org/netipx"
|
|
"golang.org/x/net/ipv4"
|
|
"golang.org/x/net/ipv6"
|
|
)
|
|
|
|
type Stack struct {
|
|
ctx context.Context
|
|
logger logger.ContextLogger
|
|
network adapter.NetworkManager
|
|
trackerIn conntrack.Tracker
|
|
trackerOut conntrack.Tracker
|
|
api *ndisapi.NdisApi
|
|
handler tun.Handler
|
|
udpTimeout time.Duration
|
|
filter *driver.QueuedPacketFilter
|
|
stack *stack.Stack
|
|
endpoint *ndisEndpoint
|
|
routeAddress []netip.Prefix
|
|
routeExcludeAddress []netip.Prefix
|
|
routeAddressSet []*netipx.IPSet
|
|
routeExcludeAddressSet []*netipx.IPSet
|
|
currentInterface *control.Interface
|
|
}
|
|
|
|
func (s *Stack) Start() error {
|
|
err := s.start(s.network.InterfaceMonitor().DefaultInterface())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s.network.InterfaceMonitor().RegisterCallback(s.updateDefaultInterface)
|
|
return nil
|
|
}
|
|
|
|
func (s *Stack) updateDefaultInterface(defaultInterface *control.Interface, flags int) {
|
|
if s.currentInterface.Equals(*defaultInterface) {
|
|
return
|
|
}
|
|
err := s.start(defaultInterface)
|
|
if err != nil {
|
|
s.logger.Error(E.Cause(err, "reconfigure NDIS at: ", defaultInterface.Name))
|
|
}
|
|
}
|
|
|
|
func (s *Stack) start(defaultInterface *control.Interface) error {
|
|
_ = s.Close()
|
|
adapters, err := s.api.GetTcpipBoundAdaptersInfo()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if defaultInterface != nil {
|
|
for index := 0; index < int(adapters.AdapterCount); index++ {
|
|
name := s.api.ConvertWindows2000AdapterName(string(adapters.AdapterNameList[index][:]))
|
|
if name != defaultInterface.Name {
|
|
continue
|
|
}
|
|
s.filter, err = driver.NewQueuedPacketFilter(s.api, adapters, nil, s.processOut)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
address := tcpip.LinkAddress(adapters.CurrentAddress[index][:])
|
|
mtu := uint32(adapters.MTU[index])
|
|
endpoint := &ndisEndpoint{
|
|
filter: s.filter,
|
|
mtu: mtu,
|
|
address: address,
|
|
}
|
|
s.stack, err = tun.NewGVisorStack(endpoint)
|
|
if err != nil {
|
|
s.filter = nil
|
|
return err
|
|
}
|
|
s.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(s.ctx, s.stack, s.handler).HandlePacket)
|
|
s.stack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(s.ctx, s.stack, s.handler, s.udpTimeout).HandlePacket)
|
|
err = s.filter.StartFilter(index)
|
|
if err != nil {
|
|
s.filter = nil
|
|
s.stack.Close()
|
|
s.stack = nil
|
|
return err
|
|
}
|
|
s.endpoint = endpoint
|
|
s.logger.Info("started at ", defaultInterface.Name)
|
|
break
|
|
}
|
|
}
|
|
s.currentInterface = defaultInterface
|
|
return nil
|
|
}
|
|
|
|
func (s *Stack) Close() error {
|
|
if s.filter != nil {
|
|
s.filter.StopFilter()
|
|
s.filter.Close()
|
|
s.filter = nil
|
|
}
|
|
if s.stack != nil {
|
|
s.stack.Close()
|
|
for _, endpoint := range s.stack.CleanupEndpoints() {
|
|
endpoint.Abort()
|
|
}
|
|
s.stack = nil
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Stack) processOut(handle ndisapi.Handle, packet *ndisapi.IntermediateBuffer) ndisapi.FilterAction {
|
|
if packet.Length < header.EthernetMinimumSize {
|
|
return ndisapi.FilterActionPass
|
|
}
|
|
if s.endpoint.dispatcher == nil || s.filterPacket(packet.Buffer[:packet.Length]) {
|
|
return ndisapi.FilterActionPass
|
|
}
|
|
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
|
Payload: buffer.MakeWithData(packet.Buffer[:packet.Length]),
|
|
})
|
|
_, ok := packetBuffer.LinkHeader().Consume(header.EthernetMinimumSize)
|
|
if !ok {
|
|
packetBuffer.DecRef()
|
|
return ndisapi.FilterActionPass
|
|
}
|
|
ethHdr := header.Ethernet(packetBuffer.LinkHeader().Slice())
|
|
destinationAddress := ethHdr.DestinationAddress()
|
|
if destinationAddress == header.EthernetBroadcastAddress {
|
|
packetBuffer.PktType = tcpip.PacketBroadcast
|
|
} else if header.IsMulticastEthernetAddress(destinationAddress) {
|
|
packetBuffer.PktType = tcpip.PacketMulticast
|
|
} else if destinationAddress == s.endpoint.address {
|
|
packetBuffer.PktType = tcpip.PacketHost
|
|
} else {
|
|
packetBuffer.PktType = tcpip.PacketOtherHost
|
|
}
|
|
s.endpoint.dispatcher.DeliverNetworkPacket(ethHdr.Type(), packetBuffer)
|
|
packetBuffer.DecRef()
|
|
return ndisapi.FilterActionDrop
|
|
}
|
|
|
|
func (s *Stack) filterPacket(packet []byte) bool {
|
|
var ipHdr header.Network
|
|
switch header.IPVersion(packet[header.EthernetMinimumSize:]) {
|
|
case ipv4.Version:
|
|
ipHdr = header.IPv4(packet[header.EthernetMinimumSize:])
|
|
case ipv6.Version:
|
|
ipHdr = header.IPv6(packet[header.EthernetMinimumSize:])
|
|
default:
|
|
return true
|
|
}
|
|
sourceAddr := tun.AddrFromAddress(ipHdr.SourceAddress())
|
|
destinationAddr := tun.AddrFromAddress(ipHdr.DestinationAddress())
|
|
if !destinationAddr.IsGlobalUnicast() {
|
|
return true
|
|
}
|
|
var (
|
|
transportProtocol tcpip.TransportProtocolNumber
|
|
transportHdr header.Transport
|
|
)
|
|
switch ipHdr.TransportProtocol() {
|
|
case tcp.ProtocolNumber:
|
|
transportProtocol = header.TCPProtocolNumber
|
|
transportHdr = header.TCP(ipHdr.Payload())
|
|
case udp.ProtocolNumber:
|
|
transportProtocol = header.UDPProtocolNumber
|
|
transportHdr = header.UDP(ipHdr.Payload())
|
|
default:
|
|
return false
|
|
}
|
|
source := netip.AddrPortFrom(sourceAddr, transportHdr.SourcePort())
|
|
destination := netip.AddrPortFrom(destinationAddr, transportHdr.DestinationPort())
|
|
if transportProtocol == header.TCPProtocolNumber {
|
|
if s.trackerIn.CheckConn(source, destination) {
|
|
if debug.Enabled {
|
|
s.logger.Trace("fall exists TCP ", source, " ", destination)
|
|
}
|
|
return false
|
|
}
|
|
} else {
|
|
if s.trackerIn.CheckPacketConn(source) {
|
|
if debug.Enabled {
|
|
s.logger.Trace("fall exists UDP ", source, " ", destination)
|
|
}
|
|
}
|
|
}
|
|
if len(s.routeAddress) > 0 {
|
|
var match bool
|
|
for _, route := range s.routeAddress {
|
|
if route.Contains(destinationAddr) {
|
|
match = true
|
|
}
|
|
}
|
|
if !match {
|
|
return true
|
|
}
|
|
}
|
|
if len(s.routeAddressSet) > 0 {
|
|
var match bool
|
|
for _, ipSet := range s.routeAddressSet {
|
|
if ipSet.Contains(destinationAddr) {
|
|
match = true
|
|
}
|
|
}
|
|
if !match {
|
|
return true
|
|
}
|
|
}
|
|
if len(s.routeExcludeAddress) > 0 {
|
|
for _, address := range s.routeExcludeAddress {
|
|
if address.Contains(destinationAddr) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
if len(s.routeExcludeAddressSet) > 0 {
|
|
for _, ipSet := range s.routeAddressSet {
|
|
if ipSet.Contains(destinationAddr) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
if s.trackerOut.CheckDestination(destination) {
|
|
if debug.Enabled {
|
|
s.logger.Trace("passing pending ", source, " ", destination)
|
|
}
|
|
return true
|
|
}
|
|
if transportProtocol == header.TCPProtocolNumber {
|
|
if s.trackerOut.CheckConn(source, destination) {
|
|
if debug.Enabled {
|
|
s.logger.Trace("passing TCP ", source, " ", destination)
|
|
}
|
|
return true
|
|
}
|
|
} else {
|
|
if s.trackerOut.CheckPacketConn(source) {
|
|
if debug.Enabled {
|
|
s.logger.Trace("passing UDP ", source, " ", destination)
|
|
}
|
|
}
|
|
}
|
|
if debug.Enabled {
|
|
s.logger.Trace("fall ", source, " ", destination)
|
|
}
|
|
return false
|
|
}
|