feat(lan): use sockmap instead of sk_lookup for tproxy socket (#8)

This commit is contained in:
mzz 2023-02-07 13:49:47 +08:00 committed by GitHub
parent 6f1ec9a4d6
commit 9f33ecf809
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 44 additions and 19 deletions

View File

@ -26,6 +26,9 @@ var (
logrus.Fatalln("Argument \"--config\" or \"-c\" is required but not provided.")
}
// Require "sudo" if necessary.
internal.AutoSu()
// Read config from --config cfgFile.
param, err := readConfig(cfgFile)
if err != nil {
@ -48,9 +51,6 @@ func init() {
func Run(log *logrus.Logger, param *config.Params) (err error) {
// Require "sudo" if necessary.
internal.AutoSu()
// Resolve subscriptions to nodes.
nodeList := make([]string, len(param.Node))
copy(nodeList, param.Node)

View File

@ -24,6 +24,8 @@ const (
DisableL4TxChecksumKey
DisableL4RxChecksumKey
ControlPlaneOidKey
OneKey ParamKey = 1
)
type DisableL4ChecksumPolicy uint32

View File

@ -345,8 +345,30 @@ func (c *ControlPlane) ListenAndServe(port uint16) (err error) {
defer packetConn.Close()
udpConn := packetConn.(*net.UDPConn)
// Serve.
/// Serve.
// TCP socket.
tcpFile, err := tcpListener.(*net.TCPListener).File()
if err != nil {
return fmt.Errorf("failed to retrieve copy of the underlying TCP connection file")
}
c.deferFuncs = append(c.deferFuncs, func() error {
return tcpFile.Close()
})
if err := c.bpf.ListenSocketMap.Update(consts.ZeroKey, uint64(tcpFile.Fd()), ebpf.UpdateAny); err != nil {
return err
}
// UDP socket.
udpFile, err := udpConn.File()
if err != nil {
return fmt.Errorf("failed to retrieve copy of the underlying UDP connection file")
}
c.deferFuncs = append(c.deferFuncs, func() error {
return udpFile.Close()
})
if err := c.bpf.ListenSocketMap.Update(consts.OneKey, uint64(udpFile.Fd()), ebpf.UpdateAny); err != nil {
return err
}
// Port.
if err := c.bpf.ParamMap.Update(consts.BigEndianTproxyPortKey, uint32(internal.Htons(port)), ebpf.UpdateAny); err != nil {
return err
}

View File

@ -74,9 +74,18 @@ enum {
DisableL4ChecksumPolicy_SetZero,
};
// Sockmap:
struct {
__uint(type, BPF_MAP_TYPE_SOCKMAP);
__type(key, __u32); // 0 is tcp, 1 is udp.
__type(value, __u64); // fd of socket.
__uint(max_entries, 2);
} listen_socket_map SEC(".maps");
// Param keys:
static const __u32 zero_key = 0;
static const __u32 tproxy_port_key = 1;
static const __u32 one_key = 1;
static const __u32 disable_l4_tx_checksum_key
__attribute__((unused, deprecated)) = 2;
static const __u32 disable_l4_rx_checksum_key
@ -1191,7 +1200,7 @@ int tproxy_lan_ingress(struct __sk_buff *skb) {
struct bpf_sock_tuple tuple = {0};
__u32 tuple_size;
struct bpf_sock *sk;
bool is_old_conn;
bool is_old_conn = false;
__u32 flag[6] = {0};
void *l4hdr;
@ -1288,28 +1297,20 @@ new_connection:
}
// Assign to control plane.
__be16 *tproxy_port = bpf_map_lookup_elem(&param_map, &tproxy_port_key);
if (!tproxy_port) {
bpf_printk("shot tproxy port not set: %d", ret);
return TC_ACT_SHOT;
}
__builtin_memset(&tuple, 0, sizeof(tuple));
tuple.ipv6.daddr[3] = bpf_htonl(0x00000001);
tuple.ipv6.dport = *tproxy_port;
if (l4proto == IPPROTO_TCP) {
// TCP.
sk = bpf_skc_lookup_tcp(skb, &tuple, sizeof(tuple), BPF_F_CURRENT_NETNS, 0);
sk = bpf_map_lookup_elem(&listen_socket_map, &zero_key);
if (!sk || sk->state != BPF_TCP_LISTEN) {
bpf_printk("shot tproxy not listen: %d", ret);
bpf_printk("shot tcp tproxy not listen: %d", ret);
goto sk_shot;
}
} else {
// UDP.
sk = bpf_sk_lookup_udp(skb, &tuple, sizeof(tuple), BPF_F_CURRENT_NETNS, 0);
sk = bpf_map_lookup_elem(&listen_socket_map, &one_key);
if (!sk) {
bpf_printk("shot udp tproxy not listen: %d", ret);
goto sk_shot;
}
}

View File

@ -34,7 +34,7 @@ func (c *ControlPlaneCore) RetrieveOutboundIndex(src, dst netip.AddrPort, l4prot
var _outboundIndex uint32
if err := c.bpf.RoutingTuplesMap.Lookup(tuples, &_outboundIndex); err != nil {
return 0, nil, fmt.Errorf("reading map: key %v: %w", src.String(), err)
return 0, nil, fmt.Errorf("reading map: key [%v, %v, %v]: %w", src.String(), l4proto, dst.String(), err)
}
if _outboundIndex > uint32(consts.OutboundLogicalMax) {
return 0, nil, fmt.Errorf("bad outbound index")