fix: Watch sysctl changes to ensure expected values (#426)

Co-authored-by: Sumire (菫) <151038614+sumire88@users.noreply.github.com>
This commit is contained in:
/gray 2024-01-23 21:11:44 +08:00 committed by GitHub
parent c26169d3a4
commit f47caada0a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 125 additions and 37 deletions

View File

@ -471,6 +471,11 @@ func NewControlPlane(
}
go dnsUpstream.InitUpstreams()
InitDaeNetns(log)
if err = InitSysctlManager(log); err != nil {
return nil, err
}
close(plane.ready)
return plane, nil
}

View File

@ -556,7 +556,9 @@ func (c *controlPlaneCore) setupSkPidMonitor() error {
func (c *controlPlaneCore) bindWan(ifname string, autoConfigKernelParameter bool) error {
if autoConfigKernelParameter {
SetAcceptLocal(ifname, "1")
if err := sysctl.Set(fmt.Sprintf("net.ipv4.conf.%v.accept_local", ifname), "1", false); err != nil {
return err
}
}
return c._bindWan(ifname)
}

View File

@ -409,7 +409,7 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Msg, req *udpRequest) (er
// resp is valid.
cache2 := c.LookupDnsRespCache(c.cacheKey(qname, qtype2), true)
if c.qtypePrefer == qtype || cache2 == nil || !cache2.IncludeAnyIp() {
return sendPkt(resp, req.realDst, req.realSrc, req.src, req.lConn)
return sendPkt(c.log, resp, req.realDst, req.realSrc, req.src, req.lConn)
} else {
return c.sendReject_(dnsMessage, req)
}
@ -453,7 +453,7 @@ func (c *DnsController) handle_(
if resp := c.LookupDnsRespCache_(dnsMessage, cacheKey, false); resp != nil {
// Send cache to client directly.
if needResp {
if err = sendPkt(resp, req.realDst, req.realSrc, req.src, req.lConn); err != nil {
if err = sendPkt(c.log, resp, req.realDst, req.realSrc, req.src, req.lConn); err != nil {
return fmt.Errorf("failed to write cached DNS resp: %w", err)
}
}
@ -501,7 +501,7 @@ func (c *DnsController) sendReject_(dnsMessage *dnsmessage.Msg, req *udpRequest)
if err != nil {
return fmt.Errorf("pack DNS packet: %w", err)
}
if err = sendPkt(data, req.realDst, req.realSrc, req.src, req.lConn); err != nil {
if err = sendPkt(c.log, data, req.realDst, req.realSrc, req.src, req.lConn); err != nil {
return err
}
return nil
@ -751,7 +751,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
if err != nil {
return err
}
if err = sendPkt(data, req.realDst, req.realSrc, req.src, req.lConn); err != nil {
if err = sendPkt(c.log, data, req.realDst, req.realSrc, req.src, req.lConn); err != nil {
return err
}
}

View File

@ -27,6 +27,8 @@ var (
)
type DaeNetns struct {
log *logrus.Logger
setupDone atomic.Bool
mu sync.Mutex
@ -34,8 +36,10 @@ type DaeNetns struct {
hostNs, daeNs netns.NsHandle
}
func init() {
daeNetns = &DaeNetns{}
func InitDaeNetns(log *logrus.Logger) {
daeNetns = &DaeNetns{
log: log,
}
}
func GetDaeNetns() *DaeNetns {
@ -85,7 +89,7 @@ func (ns *DaeNetns) With(f func() error) (err error) {
}
func (ns *DaeNetns) setup() (err error) {
logrus.Trace("setting up dae netns")
ns.log.Trace("setting up dae netns")
runtime.LockOSThread()
defer runtime.UnlockOSThread()
@ -140,27 +144,27 @@ func (ns *DaeNetns) setupVeth() (err error) {
func (ns *DaeNetns) setupSysctl() (err error) {
// sysctl net.ipv4.conf.dae0.rp_filter=0
if err = SetRpFilter(HostVethName, "0"); err != nil {
if err = sysctl.Set(fmt.Sprintf("net.ipv4.conf.%s.rp_filter", HostVethName), "0", true); err != nil {
return fmt.Errorf("failed to set rp_filter for dae0: %v", err)
}
// sysctl net.ipv4.conf.all.rp_filter=0
if err = SetRpFilter("all", "0"); err != nil {
if err = sysctl.Set("net.ipv4.conf.all.rp_filter", "0", true); err != nil {
return fmt.Errorf("failed to set rp_filter for all: %v", err)
}
// sysctl net.ipv4.conf.dae0.arp_filter=0
if err = SetArpFilter(HostVethName, "0"); err != nil {
if err = sysctl.Set(fmt.Sprintf("net.ipv4.conf.%s.arp_filter", HostVethName), "0", true); err != nil {
return fmt.Errorf("failed to set arp_filter for dae0: %v", err)
}
// sysctl net.ipv4.conf.all.arp_filter=0
if err = SetArpFilter("all", "0"); err != nil {
if err = sysctl.Set("net.ipv4.conf.all.arp_filter", "0", true); err != nil {
return fmt.Errorf("failed to set arp_filter for all: %v", err)
}
// sysctl net.ipv4.conf.dae0.accept_local=1
if err = SetAcceptLocal(HostVethName, "1"); err != nil {
if err = sysctl.Set(fmt.Sprintf("net.ipv4.conf.%s.accept_local", HostVethName), "1", true); err != nil {
return fmt.Errorf("failed to set accept_local for dae0: %v", err)
}
// sysctl net.ipv6.conf.dae0.disable_ipv6=0
if err = SetDisableIpv6(HostVethName, "0"); err != nil {
if err = sysctl.Set(fmt.Sprintf("net.ipv6.conf.%s.disable_ipv6", HostVethName), "0", true); err != nil {
return fmt.Errorf("failed to set disable_ipv6 for dae0: %v", err)
}
// sysctl net.ipv6.conf.dae0.forwarding=1
@ -286,17 +290,17 @@ func (ns *DaeNetns) monitorDae0LinkAddr() {
err := netlink.LinkSubscribe(ch, done)
if err != nil {
logrus.Errorf("failed to subscribe link updates: %v", err)
ns.log.Errorf("failed to subscribe link updates: %v", err)
}
if ns.dae0, err = netlink.LinkByName(HostVethName); err != nil {
logrus.Errorf("failed to get link dae0: %v", err)
ns.log.Errorf("failed to get link dae0: %v", err)
}
if err = ns.updateNeigh(); err != nil {
logrus.Errorf("failed to update neigh: %v", err)
ns.log.Errorf("failed to update neigh: %v", err)
}
for msg := range ch {
if msg.Link.Attrs().Name == HostVethName && !bytes.Equal(msg.Link.Attrs().HardwareAddr, ns.dae0.Attrs().HardwareAddr) {
logrus.WithField("old addr", ns.dae0.Attrs().HardwareAddr).WithField("new addr", msg.Link.Attrs().HardwareAddr).Info("dae0 link addr changed")
ns.log.WithField("old addr", ns.dae0.Attrs().HardwareAddr).WithField("new addr", msg.Link.Attrs().HardwareAddr).Info("dae0 link addr changed")
ns.dae0 = msg.Link
ns.updateNeigh()
}

90
control/sysctl.go Normal file
View File

@ -0,0 +1,90 @@
package control
import (
"os"
"strings"
"sync"
"github.com/fsnotify/fsnotify"
"github.com/sirupsen/logrus"
)
const SysctlPrefixPath = "/proc/sys/"
var sysctl *SysctlManager
type SysctlManager struct {
log *logrus.Logger
mux sync.Mutex
watcher *fsnotify.Watcher
expectations map[string]string
}
func InitSysctlManager(log *logrus.Logger) (err error) {
sysctl, err = NewSysctlManager(log)
return err
}
func NewSysctlManager(log *logrus.Logger) (*SysctlManager, error) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
}
manager := &SysctlManager{
log: log,
mux: sync.Mutex{},
watcher: watcher,
expectations: map[string]string{},
}
go manager.startWatch()
return manager, nil
}
func (s *SysctlManager) startWatch() {
for {
select {
case event, ok := <-s.watcher.Events:
if !ok {
return
}
if event.Has(fsnotify.Write) {
s.log.Tracef("sysctl write event: %+v", event)
s.mux.Lock()
expected, ok := s.expectations[event.Name]
s.mux.Unlock()
if ok {
raw, err := os.ReadFile(event.Name)
if err != nil {
s.log.Errorf("failed to read sysctl file %s: %v", event.Name, err)
}
value := strings.TrimSpace(string(raw))
if value != expected {
s.log.Infof("sysctl %s has unexpected value %s, expected %s", event.Name, value, expected)
if err := os.WriteFile(event.Name, []byte(expected), 0644); err != nil {
s.log.Errorf("failed to write sysctl file %s: %v", event.Name, err)
}
}
}
}
case err, ok := <-s.watcher.Errors:
if !ok {
return
}
s.log.Errorf("sysctl watcher error: %v", err)
}
}
}
func (s *SysctlManager) Set(key string, value string, watch bool) (err error) {
path := SysctlPrefixPath + strings.Replace(key, ".", "/", -1)
if watch {
s.mux.Lock()
s.expectations[path] = value
s.mux.Unlock()
if err = s.watcher.Add(path); err != nil {
return
}
}
return os.WriteFile(path, []byte(value), 0644)
}

View File

@ -49,7 +49,7 @@ func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Msg, timeout
}
// sendPkt uses bind first, and fallback to send hdr if addr is in use.
func sendPkt(data []byte, from netip.AddrPort, realTo, to netip.AddrPort, lConn *net.UDPConn) (err error) {
func sendPkt(log *logrus.Logger, data []byte, from netip.AddrPort, realTo, to netip.AddrPort, lConn *net.UDPConn) (err error) {
transparentTimeout := AnyfromTimeout
if from.Port() == 53 {
@ -58,7 +58,7 @@ func sendPkt(data []byte, from netip.AddrPort, realTo, to netip.AddrPort, lConn
}
uConn, _, err := DefaultAnyfromPool.GetOrCreate(from.String(), transparentTimeout)
if err != nil && errors.Is(err, syscall.EADDRINUSE) {
logrus.WithField("from", from).
log.WithField("from", from).
WithField("to", to).
WithField("realTo", realTo).
Trace("Port in use, fallback to use netns.")
@ -187,7 +187,7 @@ getNew:
// Handler handles response packets and send it to the client.
Handler: func(data []byte, from netip.AddrPort) (err error) {
// Do not return conn-unrelated err in this func.
return sendPkt(data, from, realSrc, src, lConn)
return sendPkt(c.log, data, from, realSrc, src, lConn)
},
NatTimeout: natTimeout,
GetDialOption: func() (option *DialOption, err error) {

View File

@ -128,22 +128,6 @@ func SetForwarding(ifname string, val string) {
_ = setForwarding(ifname, consts.IpVersionStr_6, val)
}
func SetAcceptLocal(ifname, val string) error {
return os.WriteFile(fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/accept_local", ifname), []byte(val), 0644)
}
func SetRpFilter(ifname, val string) error {
return os.WriteFile(fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/rp_filter", ifname), []byte(val), 0644)
}
func SetArpFilter(ifname, val string) error {
return os.WriteFile(fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/arp_filter", ifname), []byte(val), 0644)
}
func SetDisableIpv6(ifname, val string) error {
return os.WriteFile(fmt.Sprintf("/proc/sys/net/ipv6/conf/%s/disable_ipv6", ifname), []byte(val), 0644)
}
func checkSendRedirects(ifname string, ipversion consts.IpVersionStr) error {
path := fmt.Sprintf("/proc/sys/net/ipv%v/conf/%v/send_redirects", ipversion, ifname)
b, err := os.ReadFile(path)

1
go.mod
View File

@ -54,6 +54,7 @@ require (
github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 // indirect
github.com/dgryski/go-rc2 v0.0.0-20150621095337-8a9021637152 // indirect
github.com/eknkc/basex v1.0.1 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect

2
go.sum
View File

@ -38,6 +38,8 @@ github.com/frankban/quicktest v1.14.5/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7z
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/gaukas/godicttls v0.0.4 h1:NlRaXb3J6hAnTmWdsEKb9bcSBD6BvcIjdGdeb0zfXbk=
github.com/gaukas/godicttls v0.0.4/go.mod h1:l6EenT4TLWgTdwslVb4sEMOCf7Bv0JAK67deKr9/NCI=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=