diff --git a/control/control_plane.go b/control/control_plane.go index 0f8f669..8f23f2f 100644 --- a/control/control_plane.go +++ b/control/control_plane.go @@ -471,6 +471,11 @@ func NewControlPlane( } go dnsUpstream.InitUpstreams() + InitDaeNetns(log) + if err = InitSysctlManager(log); err != nil { + return nil, err + } + close(plane.ready) return plane, nil } diff --git a/control/control_plane_core.go b/control/control_plane_core.go index 6becce7..bafe75c 100644 --- a/control/control_plane_core.go +++ b/control/control_plane_core.go @@ -556,7 +556,9 @@ func (c *controlPlaneCore) setupSkPidMonitor() error { func (c *controlPlaneCore) bindWan(ifname string, autoConfigKernelParameter bool) error { if autoConfigKernelParameter { - SetAcceptLocal(ifname, "1") + if err := sysctl.Set(fmt.Sprintf("net.ipv4.conf.%v.accept_local", ifname), "1", false); err != nil { + return err + } } return c._bindWan(ifname) } diff --git a/control/dns_control.go b/control/dns_control.go index e1f91fc..088a601 100644 --- a/control/dns_control.go +++ b/control/dns_control.go @@ -409,7 +409,7 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Msg, req *udpRequest) (er // resp is valid. cache2 := c.LookupDnsRespCache(c.cacheKey(qname, qtype2), true) if c.qtypePrefer == qtype || cache2 == nil || !cache2.IncludeAnyIp() { - return sendPkt(resp, req.realDst, req.realSrc, req.src, req.lConn) + return sendPkt(c.log, resp, req.realDst, req.realSrc, req.src, req.lConn) } else { return c.sendReject_(dnsMessage, req) } @@ -453,7 +453,7 @@ func (c *DnsController) handle_( if resp := c.LookupDnsRespCache_(dnsMessage, cacheKey, false); resp != nil { // Send cache to client directly. if needResp { - if err = sendPkt(resp, req.realDst, req.realSrc, req.src, req.lConn); err != nil { + if err = sendPkt(c.log, resp, req.realDst, req.realSrc, req.src, req.lConn); err != nil { return fmt.Errorf("failed to write cached DNS resp: %w", err) } } @@ -501,7 +501,7 @@ func (c *DnsController) sendReject_(dnsMessage *dnsmessage.Msg, req *udpRequest) if err != nil { return fmt.Errorf("pack DNS packet: %w", err) } - if err = sendPkt(data, req.realDst, req.realSrc, req.src, req.lConn); err != nil { + if err = sendPkt(c.log, data, req.realDst, req.realSrc, req.src, req.lConn); err != nil { return err } return nil @@ -751,7 +751,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte if err != nil { return err } - if err = sendPkt(data, req.realDst, req.realSrc, req.src, req.lConn); err != nil { + if err = sendPkt(c.log, data, req.realDst, req.realSrc, req.src, req.lConn); err != nil { return err } } diff --git a/control/netns_utils.go b/control/netns_utils.go index a33f5fb..1dee26a 100644 --- a/control/netns_utils.go +++ b/control/netns_utils.go @@ -27,6 +27,8 @@ var ( ) type DaeNetns struct { + log *logrus.Logger + setupDone atomic.Bool mu sync.Mutex @@ -34,8 +36,10 @@ type DaeNetns struct { hostNs, daeNs netns.NsHandle } -func init() { - daeNetns = &DaeNetns{} +func InitDaeNetns(log *logrus.Logger) { + daeNetns = &DaeNetns{ + log: log, + } } func GetDaeNetns() *DaeNetns { @@ -85,7 +89,7 @@ func (ns *DaeNetns) With(f func() error) (err error) { } func (ns *DaeNetns) setup() (err error) { - logrus.Trace("setting up dae netns") + ns.log.Trace("setting up dae netns") runtime.LockOSThread() defer runtime.UnlockOSThread() @@ -140,27 +144,27 @@ func (ns *DaeNetns) setupVeth() (err error) { func (ns *DaeNetns) setupSysctl() (err error) { // sysctl net.ipv4.conf.dae0.rp_filter=0 - if err = SetRpFilter(HostVethName, "0"); err != nil { + if err = sysctl.Set(fmt.Sprintf("net.ipv4.conf.%s.rp_filter", HostVethName), "0", true); err != nil { return fmt.Errorf("failed to set rp_filter for dae0: %v", err) } // sysctl net.ipv4.conf.all.rp_filter=0 - if err = SetRpFilter("all", "0"); err != nil { + if err = sysctl.Set("net.ipv4.conf.all.rp_filter", "0", true); err != nil { return fmt.Errorf("failed to set rp_filter for all: %v", err) } // sysctl net.ipv4.conf.dae0.arp_filter=0 - if err = SetArpFilter(HostVethName, "0"); err != nil { + if err = sysctl.Set(fmt.Sprintf("net.ipv4.conf.%s.arp_filter", HostVethName), "0", true); err != nil { return fmt.Errorf("failed to set arp_filter for dae0: %v", err) } // sysctl net.ipv4.conf.all.arp_filter=0 - if err = SetArpFilter("all", "0"); err != nil { + if err = sysctl.Set("net.ipv4.conf.all.arp_filter", "0", true); err != nil { return fmt.Errorf("failed to set arp_filter for all: %v", err) } // sysctl net.ipv4.conf.dae0.accept_local=1 - if err = SetAcceptLocal(HostVethName, "1"); err != nil { + if err = sysctl.Set(fmt.Sprintf("net.ipv4.conf.%s.accept_local", HostVethName), "1", true); err != nil { return fmt.Errorf("failed to set accept_local for dae0: %v", err) } // sysctl net.ipv6.conf.dae0.disable_ipv6=0 - if err = SetDisableIpv6(HostVethName, "0"); err != nil { + if err = sysctl.Set(fmt.Sprintf("net.ipv6.conf.%s.disable_ipv6", HostVethName), "0", true); err != nil { return fmt.Errorf("failed to set disable_ipv6 for dae0: %v", err) } // sysctl net.ipv6.conf.dae0.forwarding=1 @@ -286,17 +290,17 @@ func (ns *DaeNetns) monitorDae0LinkAddr() { err := netlink.LinkSubscribe(ch, done) if err != nil { - logrus.Errorf("failed to subscribe link updates: %v", err) + ns.log.Errorf("failed to subscribe link updates: %v", err) } if ns.dae0, err = netlink.LinkByName(HostVethName); err != nil { - logrus.Errorf("failed to get link dae0: %v", err) + ns.log.Errorf("failed to get link dae0: %v", err) } if err = ns.updateNeigh(); err != nil { - logrus.Errorf("failed to update neigh: %v", err) + ns.log.Errorf("failed to update neigh: %v", err) } for msg := range ch { if msg.Link.Attrs().Name == HostVethName && !bytes.Equal(msg.Link.Attrs().HardwareAddr, ns.dae0.Attrs().HardwareAddr) { - logrus.WithField("old addr", ns.dae0.Attrs().HardwareAddr).WithField("new addr", msg.Link.Attrs().HardwareAddr).Info("dae0 link addr changed") + ns.log.WithField("old addr", ns.dae0.Attrs().HardwareAddr).WithField("new addr", msg.Link.Attrs().HardwareAddr).Info("dae0 link addr changed") ns.dae0 = msg.Link ns.updateNeigh() } diff --git a/control/sysctl.go b/control/sysctl.go new file mode 100644 index 0000000..088ddaa --- /dev/null +++ b/control/sysctl.go @@ -0,0 +1,90 @@ +package control + +import ( + "os" + "strings" + "sync" + + "github.com/fsnotify/fsnotify" + "github.com/sirupsen/logrus" +) + +const SysctlPrefixPath = "/proc/sys/" + +var sysctl *SysctlManager + +type SysctlManager struct { + log *logrus.Logger + mux sync.Mutex + watcher *fsnotify.Watcher + expectations map[string]string +} + +func InitSysctlManager(log *logrus.Logger) (err error) { + sysctl, err = NewSysctlManager(log) + return err +} + +func NewSysctlManager(log *logrus.Logger) (*SysctlManager, error) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + + manager := &SysctlManager{ + log: log, + mux: sync.Mutex{}, + watcher: watcher, + expectations: map[string]string{}, + } + go manager.startWatch() + return manager, nil +} + +func (s *SysctlManager) startWatch() { + for { + select { + case event, ok := <-s.watcher.Events: + if !ok { + return + } + if event.Has(fsnotify.Write) { + s.log.Tracef("sysctl write event: %+v", event) + s.mux.Lock() + expected, ok := s.expectations[event.Name] + s.mux.Unlock() + if ok { + raw, err := os.ReadFile(event.Name) + if err != nil { + s.log.Errorf("failed to read sysctl file %s: %v", event.Name, err) + } + value := strings.TrimSpace(string(raw)) + if value != expected { + s.log.Infof("sysctl %s has unexpected value %s, expected %s", event.Name, value, expected) + if err := os.WriteFile(event.Name, []byte(expected), 0644); err != nil { + s.log.Errorf("failed to write sysctl file %s: %v", event.Name, err) + } + } + } + } + case err, ok := <-s.watcher.Errors: + if !ok { + return + } + s.log.Errorf("sysctl watcher error: %v", err) + } + } +} + +func (s *SysctlManager) Set(key string, value string, watch bool) (err error) { + path := SysctlPrefixPath + strings.Replace(key, ".", "/", -1) + if watch { + s.mux.Lock() + s.expectations[path] = value + s.mux.Unlock() + if err = s.watcher.Add(path); err != nil { + return + } + } + return os.WriteFile(path, []byte(value), 0644) +} diff --git a/control/udp.go b/control/udp.go index a647580..d996fce 100644 --- a/control/udp.go +++ b/control/udp.go @@ -49,7 +49,7 @@ func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Msg, timeout } // sendPkt uses bind first, and fallback to send hdr if addr is in use. -func sendPkt(data []byte, from netip.AddrPort, realTo, to netip.AddrPort, lConn *net.UDPConn) (err error) { +func sendPkt(log *logrus.Logger, data []byte, from netip.AddrPort, realTo, to netip.AddrPort, lConn *net.UDPConn) (err error) { transparentTimeout := AnyfromTimeout if from.Port() == 53 { @@ -58,7 +58,7 @@ func sendPkt(data []byte, from netip.AddrPort, realTo, to netip.AddrPort, lConn } uConn, _, err := DefaultAnyfromPool.GetOrCreate(from.String(), transparentTimeout) if err != nil && errors.Is(err, syscall.EADDRINUSE) { - logrus.WithField("from", from). + log.WithField("from", from). WithField("to", to). WithField("realTo", realTo). Trace("Port in use, fallback to use netns.") @@ -187,7 +187,7 @@ getNew: // Handler handles response packets and send it to the client. Handler: func(data []byte, from netip.AddrPort) (err error) { // Do not return conn-unrelated err in this func. - return sendPkt(data, from, realSrc, src, lConn) + return sendPkt(c.log, data, from, realSrc, src, lConn) }, NatTimeout: natTimeout, GetDialOption: func() (option *DialOption, err error) { diff --git a/control/utils.go b/control/utils.go index 9888845..f0b93d2 100644 --- a/control/utils.go +++ b/control/utils.go @@ -128,22 +128,6 @@ func SetForwarding(ifname string, val string) { _ = setForwarding(ifname, consts.IpVersionStr_6, val) } -func SetAcceptLocal(ifname, val string) error { - return os.WriteFile(fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/accept_local", ifname), []byte(val), 0644) -} - -func SetRpFilter(ifname, val string) error { - return os.WriteFile(fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/rp_filter", ifname), []byte(val), 0644) -} - -func SetArpFilter(ifname, val string) error { - return os.WriteFile(fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/arp_filter", ifname), []byte(val), 0644) -} - -func SetDisableIpv6(ifname, val string) error { - return os.WriteFile(fmt.Sprintf("/proc/sys/net/ipv6/conf/%s/disable_ipv6", ifname), []byte(val), 0644) -} - func checkSendRedirects(ifname string, ipversion consts.IpVersionStr) error { path := fmt.Sprintf("/proc/sys/net/ipv%v/conf/%v/send_redirects", ipversion, ifname) b, err := os.ReadFile(path) diff --git a/go.mod b/go.mod index 365956c..604fe8b 100644 --- a/go.mod +++ b/go.mod @@ -54,6 +54,7 @@ require ( github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 // indirect github.com/dgryski/go-rc2 v0.0.0-20150621095337-8a9021637152 // indirect github.com/eknkc/basex v1.0.1 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/uuid v1.3.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index 85bc295..3588dc5 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ github.com/frankban/quicktest v1.14.5/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7z github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/gaukas/godicttls v0.0.4 h1:NlRaXb3J6hAnTmWdsEKb9bcSBD6BvcIjdGdeb0zfXbk= github.com/gaukas/godicttls v0.0.4/go.mod h1:l6EenT4TLWgTdwslVb4sEMOCf7Bv0JAK67deKr9/NCI= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=