refactor: refactor lan tproxy using sk_assign

This commit is contained in:
mzz2017 2023-02-06 13:56:43 +08:00
parent 31fcc288fc
commit 8c65f8ff39
11 changed files with 492 additions and 328 deletions

View File

@ -57,6 +57,8 @@ const (
OutboundControlPlaneDirect OutboundIndex = 0xFD
OutboundLogicalOr OutboundIndex = 0xFE
OutboundLogicalAnd OutboundIndex = 0xFF
OutboundLogicalMax = OutboundLogicalAnd
)
func (i OutboundIndex) String() string {
@ -102,6 +104,11 @@ var (
FtraceFeatureVersion = internal.Version{5, 5, 0}
UserspaceBatchUpdateFeatureVersion = internal.Version{5, 6, 0}
CgSocketCookieFeatureVersion = internal.Version{5, 7, 0}
SkAssignFeatureVersion = internal.Version{5, 7, 0}
ChecksumFeatureVersion = internal.Version{5, 8, 0}
UserspaceBatchUpdateLpmTrieFeatureVersion = internal.Version{5, 13, 0}
)
const (
TproxyMark uint32 = 0x80000000
)

View File

@ -7,6 +7,7 @@ package control
import (
"context"
"encoding/hex"
"errors"
"fmt"
"github.com/cilium/ebpf"
@ -30,6 +31,7 @@ import (
"strconv"
"strings"
"sync"
"syscall"
"time"
)
@ -68,15 +70,24 @@ func NewControlPlane(
}
// Must judge version from high to low to reduce the number of user upgrading kernel.
if kernelVersion.Less(consts.ChecksumFeatureVersion) {
return nil, fmt.Errorf("your kernel version %v does not support checksum related features; expect >=%v; upgrade your kernel and try again", kernelVersion.String(),
return nil, fmt.Errorf("your kernel version %v does not support checksum related features; expect >=%v; upgrade your kernel and try again",
kernelVersion.String(),
consts.ChecksumFeatureVersion.String())
}
if len(wanInterface) > 0 && kernelVersion.Less(consts.CgSocketCookieFeatureVersion) {
return nil, fmt.Errorf("your kernel version %v does not support bind to WAN; expect >=%v; remove wan_interface in config file and try again", kernelVersion.String(),
return nil, fmt.Errorf("your kernel version %v does not support bind to WAN; expect >=%v; remove wan_interface in config file and try again",
kernelVersion.String(),
consts.CgSocketCookieFeatureVersion.String())
}
if len(lanInterface) > 0 && c.kernelVersion.Less(consts.SkAssignFeatureVersion) {
return nil, fmt.Errorf("your kernel version %v does not support bind to LAN; expect >=%v; remove lan_interface in config file and try again",
c.kernelVersion.String(),
consts.SkAssignFeatureVersion.String())
}
if kernelVersion.Less(consts.BasicFeatureVersion) {
return nil, fmt.Errorf("your kernel version %v does not satisfy basic requirement; expect >=%v", c.kernelVersion.String(), consts.BasicFeatureVersion.String())
return nil, fmt.Errorf("your kernel version %v does not satisfy basic requirement; expect >=%v",
c.kernelVersion.String(),
consts.BasicFeatureVersion.String())
}
// Allow the current process to lock memory for eBPF resources.
@ -127,11 +138,11 @@ retryLoadBpf:
goto retryLoadBpf
}
// Get detailed log from ebpf.internal.(*VerifierError)
if log.Level == logrus.PanicLevel {
if log.Level == logrus.FatalLevel {
if v := reflect.Indirect(reflect.ValueOf(errors.Unwrap(errors.Unwrap(err)))); v.Kind() == reflect.Struct {
if _log := v.FieldByName("Log"); _log.IsValid() {
if strSlice, ok := _log.Interface().([]string); ok {
log.Panicln(strings.Join(strSlice, "\n"))
log.Fatalln(strings.Join(strSlice, "\n"))
}
}
}
@ -306,19 +317,22 @@ retryLoadBpf:
func (c *ControlPlane) ListenAndServe(port uint16) (err error) {
// Listen.
listener, err := net.Listen("tcp", "0.0.0.0:"+strconv.Itoa(int(port)))
var listenConfig = net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
return dialer.TproxyControl(c)
},
}
tcpListener, err := listenConfig.Listen(context.TODO(), "tcp", "[::1]:"+strconv.Itoa(int(port)))
if err != nil {
return fmt.Errorf("listenTCP: %w", err)
}
defer listener.Close()
lConn, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.IP{0, 0, 0, 0},
Port: int(port),
})
defer tcpListener.Close()
packetConn, err := listenConfig.ListenPacket(context.TODO(), "udp", "[::1]:"+strconv.Itoa(int(port)))
if err != nil {
return fmt.Errorf("listenUDP: %w", err)
}
defer lConn.Close()
defer packetConn.Close()
udpConn := packetConn.(*net.UDPConn)
// Serve.
@ -334,7 +348,7 @@ func (c *ControlPlane) ListenAndServe(port uint16) (err error) {
go func() {
defer cancel()
for {
lconn, err := listener.Accept()
lconn, err := tcpListener.Accept()
if err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
c.log.Errorf("Error when accept: %v", err)
@ -352,26 +366,30 @@ func (c *ControlPlane) ListenAndServe(port uint16) (err error) {
defer cancel()
for {
var buf [65535]byte
n, lAddrPort, err := lConn.ReadFromUDPAddrPort(buf[:])
var oob [120]byte // Size for original dest
n, oobn, _, src, err := udpConn.ReadMsgUDPAddrPort(buf[:], oob[:])
if err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
c.log.Errorf("ReadFromUDPAddrPort: %v, %v", lAddrPort.String(), err)
c.log.Errorf("ReadFromUDPAddrPort: %v, %v", src.String(), err)
}
break
}
addrHdr, dataOffset, err := ParseAddrHdr(buf[:n])
if err != nil {
c.log.Warnf("No AddrPort presented")
dst := RetrieveOriginalDest(oob[:oobn])
if !dst.IsValid() {
c.log.WithFields(logrus.Fields{
"source": src.String(),
"oob": hex.EncodeToString(oob[:oobn]),
}).Warnf("Failed to retrieve original dest")
continue
}
newBuf := pool.Get(n - dataOffset)
copy(newBuf, buf[dataOffset:n])
go func(data []byte, lConn *net.UDPConn, lAddrPort netip.AddrPort, addrHdr *AddrHdr) {
if e := c.handlePkt(newBuf, lConn, lAddrPort, addrHdr); e != nil {
newBuf := pool.Get(n)
copy(newBuf, buf[:n])
go func(data []byte, src, dst netip.AddrPort) {
if e := c.handlePkt(newBuf, src, dst); e != nil {
c.log.Warnln("handlePkt:", e)
}
pool.Put(newBuf)
}(newBuf, lConn, lAddrPort, addrHdr)
}(newBuf, src, dst)
}
}()
<-ctx.Done()

View File

@ -18,6 +18,7 @@ import (
"golang.org/x/sys/unix"
"net/netip"
"os"
"os/exec"
"regexp"
)
@ -115,7 +116,23 @@ func (c *ControlPlaneCore) BindLan(ifname string) error {
if err != nil {
return err
}
/// Insert ip rule / ip route.
if err = exec.Command("sh", "-c", `
ip rule add fwmark 0x80000000/0x80000000 table 2023
ip route add local 0.0.0.0/0 dev lo table 2023
ip -6 rule add fwmark 0x80000000/0x80000000 table 2023
ip -6 route add local ::/0 dev lo table 2023
`).Run(); err != nil {
return err
}
c.deferFuncs = append(c.deferFuncs, func() error {
return exec.Command("sh", "-c", `
ip rule del fwmark 0x80000000/0x80000000 table 2023
ip route del local 0.0.0.0/0 dev lo table 2023
ip -6 rule del fwmark 0x80000000/0x80000000 table 2023
ip -6 route del local ::/0 dev lo table 2023
`).Run()
})
/// Insert an elem into IfindexParamsMap.
ifParams, err := getifParamsFromLink(link)
if err != nil {

View File

@ -41,6 +41,7 @@
#define NOWHERE_IFINDEX 0
#define LOOPBACK_IFINDEX 1
#define LOOPBACK_ADDR 0x7f000001
#define MAX_PARAM_LEN 16
#define MAX_INTERFACE_NUM 128
@ -61,6 +62,12 @@
#define OUTBOUND_LOGICAL_AND 0xFF
#define OUTBOUND_LOGICAL_MASK 0xFE
#define TPROXY_MARK 0x80000000
#define ESOCKTNOSUPPORT 94 /* Socket type not supported */
enum { BPF_F_CURRENT_NETNS = -1 };
enum {
DisableL4ChecksumPolicy_EnableL4Checksum,
DisableL4ChecksumPolicy_Restore,
@ -88,6 +95,12 @@ struct ip_port_outbound {
__u8 unused;
};
struct tuples {
struct ip_port src;
struct ip_port dst;
__u8 l4proto;
};
/// TODO: Remove items from the dst_map by conntrack.
// Dest map:
struct {
@ -103,6 +116,13 @@ struct {
__uint(pinning, LIBBPF_PIN_BY_NAME);
} tcp_dst_map SEC(".maps");
struct {
__uint(type, BPF_MAP_TYPE_LRU_HASH);
__type(key, struct tuples);
__type(value, __u32); // outbound
__uint(max_entries, MAX_DST_MAPPING_NUM);
} routing_tuples_map SEC(".maps");
// Params:
struct {
__uint(type, BPF_MAP_TYPE_ARRAY);
@ -575,7 +595,8 @@ parse_transport(const struct __sk_buff *skb, struct ethhdr *ethh,
__u8 *l4proto) {
__u32 offset = 0;
int ret = bpf_skb_load_bytes(skb, offset, ethh, sizeof(struct ethhdr));
int ret;
ret = bpf_skb_load_bytes(skb, offset, ethh, sizeof(struct ethhdr));
if (ret) {
bpf_printk("not ethernet packet");
return 1;
@ -583,8 +604,8 @@ parse_transport(const struct __sk_buff *skb, struct ethhdr *ethh,
// Skip ethhdr for next hdr.
offset += sizeof(struct ethhdr);
*ihl = 0;
*ipversion = 0;
*ihl = 0;
*l4proto = 0;
// bpf_printk("parse_transport: h_proto: %u ? %u %u", eth->h_proto,
@ -635,8 +656,9 @@ parse_transport(const struct __sk_buff *skb, struct ethhdr *ethh,
return handle_ipv6_extensions(skb, offset, ipv6h->nexthdr, tcph, udph, ihl,
l4proto);
} else {
return 1;
}
return 1;
}
static __always_inline int
@ -913,7 +935,7 @@ static __always_inline int decap_after_udp_hdr(struct __sk_buff *skb,
// Do not use __always_inline here because this function is too heavy.
static int __attribute__((noinline))
routing(const __u32 flag[6], const void *l4_hdr, const __be32 saddr[4],
routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
const __be32 _daddr[4], const __be32 mac[4]) {
#define _l4proto_type flag[0]
#define _ipversion_type flag[1]
@ -938,11 +960,11 @@ routing(const __u32 flag[6], const void *l4_hdr, const __be32 saddr[4],
__u16 h_dport;
__u16 h_sport;
if (_l4proto_type == L4ProtoType_TCP) {
h_dport = bpf_ntohs(((struct tcphdr *)l4_hdr)->dest);
h_sport = bpf_ntohs(((struct tcphdr *)l4_hdr)->source);
h_dport = bpf_ntohs(((struct tcphdr *)l4hdr)->dest);
h_sport = bpf_ntohs(((struct tcphdr *)l4hdr)->source);
} else {
h_dport = bpf_ntohs(((struct udphdr *)l4_hdr)->dest);
h_sport = bpf_ntohs(((struct udphdr *)l4_hdr)->source);
h_dport = bpf_ntohs(((struct udphdr *)l4hdr)->dest);
h_sport = bpf_ntohs(((struct udphdr *)l4hdr)->source);
}
key = MatchType_SourcePort;
@ -1153,11 +1175,9 @@ int tproxy_lan_ingress(struct __sk_buff *skb) {
struct ipv6hdr ipv6h;
struct tcphdr tcph;
struct udphdr udph;
// __sum16 bak_cksm = 0;
__u8 ihl;
__u8 ipversion;
__u8 l4proto;
bool tcp_state_syn;
int ret = parse_transport(skb, &ethh, &iph, &ipv6h, &tcph, &udph, &ihl,
&ipversion, &l4proto);
if (ret) {
@ -1165,216 +1185,182 @@ int tproxy_lan_ingress(struct __sk_buff *skb) {
return TC_ACT_OK;
}
// Backup for further use.
__be16 ipv4_tot_len = 0;
// Parse saddr and daddr as ipv6 format.
__be32 saddr[4];
__be32 daddr[4];
// Prepare five tuples.
struct tuples tuples = {0};
tuples.l4proto = l4proto;
if (ipversion == 4) {
saddr[0] = 0;
saddr[1] = 0;
saddr[2] = bpf_htonl(0x0000ffff);
saddr[3] = iph.saddr;
tuples.src.ip[2] = bpf_htonl(0x0000ffff);
tuples.src.ip[3] = iph.saddr;
daddr[0] = 0;
daddr[1] = 0;
daddr[2] = bpf_htonl(0x0000ffff);
daddr[3] = iph.daddr;
tuples.dst.ip[2] = bpf_htonl(0x0000ffff);
tuples.dst.ip[3] = iph.daddr;
ipv4_tot_len = iph.tot_len;
} else {
__builtin_memcpy(daddr, &ipv6h.daddr, IPV6_BYTE_LENGTH);
__builtin_memcpy(saddr, &ipv6h.saddr, IPV6_BYTE_LENGTH);
__builtin_memcpy(tuples.dst.ip, &ipv6h.daddr, IPV6_BYTE_LENGTH);
__builtin_memcpy(tuples.src.ip, &ipv6h.saddr, IPV6_BYTE_LENGTH);
}
if (l4proto == IPPROTO_TCP) {
tuples.src.port = tcph.source;
tuples.dst.port = tcph.dest;
} else {
tuples.src.port = udph.source;
tuples.dst.port = udph.dest;
}
__u32 ifindex = skb->ifindex;
struct if_params *ifparams =
bpf_map_lookup_elem(&ifindex_params_map, &ifindex);
if (unlikely(!ifparams)) {
return -1;
}
// Never disable checksum in rx.
bool disable_checksum = false;
/**
ip rule add fwmark 0x80000000/0x80000000 table 1000
ip route add local 0.0.0.0/0 dev lo table 1000
ip -6 rule add fwmark 0x80000000/0x80000000 table 1000
ip -6 route add local ::/0 dev lo table 1000
// If this packet is sent to this host and not a DNS packet, accept it.
__u32 tproxy_ip[4];
int to_host = ip_is_host(ipversion, ifparams, daddr, tproxy_ip);
if (to_host < 0) { // error
// bpf_printk("to_host: %ld", to_host);
return TC_ACT_OK;
}
if (to_host == 1) {
if (l4proto == IPPROTO_UDP && udph.dest == 53) {
// To udp:host:53. Process it.
} else {
// To host. Accept.
return TC_ACT_OK;
}
ip rule del fwmark 0x80000000/0x80000000 table 1000
ip route del local 0.0.0.0/0 dev lo table 1000
ip -6 rule del fwmark 0x80000000/0x80000000 table 1000
ip -6 route del local ::/0 dev lo table 1000
*/
struct bpf_sock_tuple tuple = {0};
__u32 tuple_size;
if (ipversion == 4) {
tuple.ipv4.daddr = tuples.dst.ip[3];
tuple.ipv4.saddr = tuples.src.ip[3];
tuple.ipv4.dport = tuples.dst.port;
tuple.ipv4.sport = tuples.src.port;
tuple_size = sizeof(tuple.ipv4);
} else {
__builtin_memcpy(tuple.ipv6.daddr, tuples.dst.ip, IPV6_BYTE_LENGTH);
__builtin_memcpy(tuple.ipv6.saddr, tuples.src.ip, IPV6_BYTE_LENGTH);
tuple.ipv6.dport = tuples.dst.port;
tuple.ipv6.sport = tuples.src.port;
tuple_size = sizeof(tuple.ipv6);
}
__be16 *tproxy_port = bpf_map_lookup_elem(&param_map, &tproxy_port_key);
if (!tproxy_port) {
return TC_ACT_OK;
}
struct bpf_sock *sk;
bool is_old_conn = false;
if (l4proto == IPPROTO_TCP) {
// Backup for further use.
// bak_cksm = tcph.check;
tcp_state_syn = tcph.syn && !tcph.ack;
struct ip_port key_src;
__builtin_memset(&key_src, 0, sizeof(key_src));
__builtin_memcpy(key_src.ip, saddr, IPV6_BYTE_LENGTH);
key_src.port = tcph.source;
__u8 outbound;
if (unlikely(tcp_state_syn)) {
// New TCP connection.
// bpf_printk("[%X]New Connection", bpf_ntohl(tcph.seq));
__u32 flag[6] = {L4ProtoType_TCP}; // TCP
if (ipversion == 6) {
flag[1] = IpVersionType_6;
} else {
flag[1] = IpVersionType_4;
}
__be32 mac[4] = {
0,
0,
bpf_htonl((ethh.h_source[0] << 8) + (ethh.h_source[1])),
bpf_htonl((ethh.h_source[2] << 24) + (ethh.h_source[3] << 16) +
(ethh.h_source[4] << 8) + (ethh.h_source[5])),
};
if ((ret = routing(flag, &tcph, saddr, daddr, mac)) < 0) {
bpf_printk("shot routing: %d", ret);
return TC_ACT_SHOT;
}
// TCP.
outbound = ret;
#if defined(__DEBUG_ROUTING) || defined(__PRINT_ROUTING_RESULT)
// Print only new connection.
bpf_printk("tcp(lan): outbound: %u, %pI6:%u", outbound, daddr,
bpf_ntohs(key_src.port));
#endif
} else {
// bpf_printk("[%X]Old Connection", bpf_ntohl(tcph.seq));
// The TCP connection exists.
struct ip_port_outbound *dst =
bpf_map_lookup_elem(&tcp_dst_map, &key_src);
if (!dst) {
// Do not impact previous connections.
return TC_ACT_OK;
sk = bpf_skc_lookup_tcp(skb, &tuple, tuple_size, BPF_F_CURRENT_NETNS, 0);
if (sk) {
if (sk->state != BPF_TCP_LISTEN) {
// Old connection.
is_old_conn = true;
goto assign;
}
outbound = dst->outbound;
bpf_sk_release(sk);
}
} else {
// UDP.
if (outbound == OUTBOUND_DIRECT) {
return TC_ACT_OK;
} else if (unlikely(outbound == OUTBOUND_BLOCK)) {
return TC_ACT_SHOT;
} else {
// Rewrite to control plane.
if (unlikely(tcp_state_syn)) {
struct ip_port_outbound value_dst;
__builtin_memcpy(value_dst.ip, daddr, IPV6_BYTE_LENGTH);
value_dst.port = tcph.dest;
value_dst.outbound = outbound;
bpf_map_update_elem(&tcp_dst_map, &key_src, &value_dst, BPF_ANY);
}
__u32 *dst_ip = daddr;
__u16 dst_port = tcph.dest;
if ((ret = rewrite_ip(skb, ipversion, IPPROTO_TCP, ihl, dst_ip, tproxy_ip,
true, !disable_checksum))) {
bpf_printk("Shot IP: %d", ret);
return TC_ACT_SHOT;
}
if ((ret = rewrite_port(skb, IPPROTO_TCP, ihl, dst_port, *tproxy_port,
true, !disable_checksum))) {
bpf_printk("Shot Port: %d", ret);
return TC_ACT_SHOT;
}
}
} else if (l4proto == IPPROTO_UDP) {
// Backup for further use.
// bak_cksm = udph.check;
struct ip_port_outbound new_hdr;
__builtin_memset(&new_hdr, 0, sizeof(new_hdr));
__builtin_memcpy(new_hdr.ip, daddr, IPV6_BYTE_LENGTH);
new_hdr.port = udph.dest;
// Routing. It decides if we redirect traffic to control plane.
__u32 flag[6] = {L4ProtoType_UDP};
if (ipversion == 6) {
flag[1] = IpVersionType_6;
} else {
flag[1] = IpVersionType_4;
}
__be32 mac[4] = {
0,
0,
bpf_htonl((ethh.h_source[0] << 8) + (ethh.h_source[1])),
bpf_htonl((ethh.h_source[2] << 24) + (ethh.h_source[3] << 16) +
(ethh.h_source[4] << 8) + (ethh.h_source[5])),
};
if ((ret = routing(flag, &udph, saddr, daddr, mac)) < 0) {
bpf_printk("shot routing: %d", ret);
return TC_ACT_SHOT;
}
new_hdr.outbound = ret;
#if defined(__DEBUG_ROUTING) || defined(__PRINT_ROUTING_RESULT)
bpf_printk("udp(lan): outbound: %u, %pI6:%u", new_hdr.outbound, daddr,
bpf_ntohs(new_hdr.port));
#endif
if (new_hdr.outbound == OUTBOUND_DIRECT) {
return TC_ACT_OK;
} else if (unlikely(new_hdr.outbound == OUTBOUND_BLOCK)) {
return TC_ACT_SHOT;
} else {
// Rewrite to control plane.
// Encap a header to transmit fullcone tuple.
if ((ret =
encap_after_udp_hdr(skb, ipversion, ihl, ipv4_tot_len, &new_hdr,
sizeof(new_hdr), !disable_checksum))) {
return TC_ACT_SHOT;
}
// Rewrite udp dst ip.
// bpf_printk("rewrite dst ip from %pI4", &ori_dst.ip);
if ((ret = rewrite_ip(skb, ipversion, IPPROTO_UDP, ihl, new_hdr.ip,
tproxy_ip, true, !disable_checksum))) {
bpf_printk("Shot IP: %d", ret);
return TC_ACT_SHOT;
}
// Rewrite udp dst port.
if ((ret = rewrite_port(skb, IPPROTO_UDP, ihl, new_hdr.port, *tproxy_port,
true, !disable_checksum))) {
bpf_printk("Shot Port: %d", ret);
return TC_ACT_SHOT;
}
sk = bpf_sk_lookup_udp(skb, &tuple, tuple_size, BPF_F_CURRENT_NETNS, 0);
if (sk) {
goto assign;
}
}
// Print packet in hex for debugging (checksum or something else).
// bpf_printk("DEBUG");
// for (__u32 i = 0; i < skb->len && i < 200; i++) {
// __u8 t = 0;
// bpf_skb_load_bytes(skb, i, &t, 1);
// bpf_printk("%02x", t);
// }
// Routing for new connection.
__u32 flag[6] = {0}; // TCP
void *l4hdr;
if (l4proto == IPPROTO_TCP) {
l4hdr = &tcph;
flag[0] = L4ProtoType_TCP;
} else {
l4hdr = &udph;
flag[0] = L4ProtoType_UDP;
}
if (ipversion == 4) {
flag[1] = IpVersionType_4;
} else {
flag[1] = IpVersionType_6;
}
__be32 mac[4] = {
0,
0,
bpf_htonl((ethh.h_source[0] << 8) + (ethh.h_source[1])),
bpf_htonl((ethh.h_source[2] << 24) + (ethh.h_source[3] << 16) +
(ethh.h_source[4] << 8) + (ethh.h_source[5])),
};
if ((ret = routing(flag, l4hdr, tuples.src.ip, tuples.dst.ip, mac)) < 0) {
bpf_printk("shot routing: %d", ret);
return TC_ACT_SHOT;
}
__u32 outbound = ret;
#if defined(__DEBUG_ROUTING) || defined(__PRINT_ROUTING_RESULT)
if (l4proto == IPPROTO_TCP) {
bpf_printk("tcp(lan): outbound: %u, target: %pI6:%u", outbound,
tuples.dst.ip, bpf_ntohs(tuples.dst.port));
} else {
bpf_printk("udp(lan): outbound: %u, target: %pI6:%u", outbound,
tuples.dst.ip, bpf_ntohs(tuples.dst.port));
}
#endif
if (outbound == OUTBOUND_DIRECT) {
goto direct;
} else if (unlikely(outbound == OUTBOUND_BLOCK)) {
goto block;
}
// Disable checksum.
if (disable_checksum) {
// Set checksum zero.
__u32 l4_cksm_off = l4_checksum_off(l4proto, ihl);
__sum16 bak_cksm = 0;
bpf_skb_store_bytes(skb, l4_cksm_off, &bak_cksm, sizeof(bak_cksm), 0);
bpf_csum_level(skb, BPF_CSUM_LEVEL_RESET);
// Save routing result.
if ((ret = bpf_map_update_elem(&routing_tuples_map, &tuples, &outbound,
BPF_ANY))) {
bpf_printk("shot save routing result: %d", ret);
return TC_ACT_SHOT;
}
// Assign to control plane.
__be16 *tproxy_port = bpf_map_lookup_elem(&param_map, &tproxy_port_key);
if (!tproxy_port) {
bpf_printk("shot tproxy port not set: %d", ret);
return TC_ACT_SHOT;
}
__builtin_memset(&tuple, 0, sizeof(tuple));
tuple.ipv6.daddr[3] = bpf_htonl(0x00000001);
tuple.ipv6.dport = *tproxy_port;
if (l4proto == IPPROTO_TCP) {
// TCP.
sk = bpf_skc_lookup_tcp(skb, &tuple, sizeof(tuple), BPF_F_CURRENT_NETNS, 0);
if (!sk || sk->state != BPF_TCP_LISTEN) {
bpf_printk("shot tproxy not listen: %d", ret);
goto sk_shot;
}
} else {
// UDP.
sk = bpf_sk_lookup_udp(skb, &tuple, sizeof(tuple), BPF_F_CURRENT_NETNS, 0);
if (!sk) {
goto sk_shot;
}
}
assign:
skb->mark = TPROXY_MARK;
ret = bpf_sk_assign(skb, sk, 0);
bpf_sk_release(sk);
if (ret) {
if (is_old_conn && ret == -ESOCKTNOSUPPORT) {
bpf_printk("bpf_sk_assign: %d, perhaps you have other TPROXY programs "
"(such as v2ray) running?",
ret);
} else {
bpf_printk("bpf_sk_assign: %d", ret);
}
return TC_ACT_SHOT;
}
return TC_ACT_OK;
sk_shot:
if (sk) {
bpf_sk_release(sk);
}
return TC_ACT_SHOT;
direct:
return TC_ACT_OK;
block:
return TC_ACT_SHOT;
}
// Cookie will change after the first packet, so we just use it for
@ -2115,7 +2101,6 @@ static int __always_inline update_map_elem_by_cookie(const __u64 cookie) {
for (loc = 0, j = 0; j < MAX_ARG_LEN_TO_PROBE;
++j, loc = ((loc + 1) & (MAX_ARG_SCANNER_BUFFER_SIZE - 1))) {
// volatile unsigned long k = j; // Cheat to unroll.
// if (arg_start + k >= arg_end) {
if (unlikely(arg_start + j >= arg_end)) {
break;
}
@ -2129,6 +2114,7 @@ static int __always_inline update_map_elem_by_cookie(const __u64 cookie) {
} else {
buf[to_read] = 0;
}
// No need to CO-RE.
if ((ret = bpf_probe_read_user(&buf, to_read,
(const void *)(arg_start + j)))) {
bpf_printk("failed to read process name: %d", ret);
@ -2184,25 +2170,4 @@ int tproxy_wan_cg_sock_release(struct bpf_sock *sk) {
return 1;
}
// SEC("cgroup/connect4")
// int tproxy_wan_cg_connect4(struct bpf_sock_addr *ctx) {
// update_map_elem_by_cookie(bpf_get_socket_cookie(ctx));
// return 1;
// }
// SEC("cgroup/connect6")
// int tproxy_wan_cg_connect6(struct bpf_sock_addr *ctx) {
// update_map_elem_by_cookie(bpf_get_socket_cookie(ctx));
// return 1;
// }
// SEC("cgroup/sendmsg4")
// int tproxy_wan_cg_sendmsg4(struct bpf_sock_addr *ctx) {
// update_map_elem_by_cookie(bpf_get_socket_cookie(ctx));
// return 1;
// }
// SEC("cgroup/sendmsg6")
// int tproxy_wan_cg_sendmsg6(struct bpf_sock_addr *ctx) {
// update_map_elem_by_cookie(bpf_get_socket_cookie(ctx));
// return 1;
// }
SEC("license") const char __license[] = "Dual BSD/GPL";

View File

@ -70,6 +70,7 @@ func generate(output string) error {
func GenerateObjects(output string) {
if err := generate(output); err != nil {
fmt.Println(err.Error())
os.Exit(1)
}
}

View File

@ -9,48 +9,36 @@ import (
"fmt"
"github.com/mzz2017/softwind/pkg/zeroalloc/io"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
internal "github.com/v2rayA/dae/pkg/ebpf_internal"
"golang.org/x/sys/unix"
"net"
"net/netip"
"strings"
"time"
)
func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
defer lConn.Close()
rAddr := lConn.RemoteAddr().(*net.TCPAddr).AddrPort()
ip6 := rAddr.Addr().As16()
var value bpfIpPortOutbound
if err := c.bpf.TcpDstMap.Lookup(bpfIpPort{
Ip: common.Ipv6ByteSliceToUint32Array(ip6[:]),
Port: internal.Htons(rAddr.Port()),
}, &value); err != nil {
return fmt.Errorf("reading map: key %v: %w", rAddr.String(), err)
src := lConn.RemoteAddr().(*net.TCPAddr).AddrPort()
dst := lConn.LocalAddr().(*net.TCPAddr).AddrPort()
outboundIndex, err := c.RetrieveOutboundIndex(src, dst, unix.IPPROTO_TCP)
if err != nil {
return fmt.Errorf("RetrieveOutboundIndex: %w", err)
}
dstSlice, ok := netip.AddrFromSlice(common.Ipv6Uint32ArrayToByteSlice(value.Ip))
if !ok {
return fmt.Errorf("failed to parse dest ip: %v", value.Ip)
}
dst := netip.AddrPortFrom(dstSlice, internal.Htons(value.Port))
switch consts.OutboundIndex(value.Outbound) {
switch consts.OutboundIndex(outboundIndex) {
case consts.OutboundDirect:
case consts.OutboundControlPlaneDirect:
value.Outbound = uint8(consts.OutboundDirect)
outboundIndex = consts.OutboundDirect
c.log.Tracef("outbound: %v => %v",
consts.OutboundControlPlaneDirect.String(),
consts.OutboundIndex(value.Outbound).String(),
consts.OutboundIndex(outboundIndex).String(),
)
default:
}
outbound := c.outbounds[value.Outbound]
outbound := c.outbounds[outboundIndex]
// TODO: Set-up ip to domain mapping and show domain if possible.
src := lConn.RemoteAddr().(*net.TCPAddr).AddrPort()
if value.Outbound < 0 || int(value.Outbound) >= len(c.outbounds) {
return fmt.Errorf("outbound id from bpf is out of range: %v not in [0, %v]", value.Outbound, len(c.outbounds)-1)
if outboundIndex < 0 || int(outboundIndex) >= len(c.outbounds) {
return fmt.Errorf("outbound id from bpf is out of range: %v not in [0, %v]", outboundIndex, len(c.outbounds)-1)
}
dialer, err := outbound.Select()
if err != nil {

View File

@ -0,0 +1,60 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) since 2023, mzz2017 <mzz@tuta.io>
*/
package control
import (
"encoding/binary"
"fmt"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
internal "github.com/v2rayA/dae/pkg/ebpf_internal"
"golang.org/x/sys/unix"
"net/netip"
"syscall"
)
func (c *ControlPlaneCore) RetrieveOutboundIndex(src, dst netip.AddrPort, l4proto uint8) (consts.OutboundIndex, error) {
srcIp6 := src.Addr().As16()
dstIp6 := dst.Addr().As16()
var outboundIndex uint32
if err := c.bpf.RoutingTuplesMap.Lookup(bpfTuples{
Src: bpfIpPort{
Ip: common.Ipv6ByteSliceToUint32Array(srcIp6[:]),
Port: internal.Htons(src.Port()),
},
Dst: bpfIpPort{
Ip: common.Ipv6ByteSliceToUint32Array(dstIp6[:]),
Port: internal.Htons(dst.Port()),
},
L4proto: l4proto,
}, &outboundIndex); err != nil {
return 0, fmt.Errorf("reading map: key %v: %w", src.String(), err)
}
if outboundIndex > uint32(consts.OutboundLogicalMax) {
return 0, fmt.Errorf("bad outbound index")
}
return consts.OutboundIndex(outboundIndex), nil
}
func RetrieveOriginalDest(oob []byte) netip.AddrPort {
msgs, err := syscall.ParseSocketControlMessage(oob)
if err != nil {
return netip.AddrPort{}
}
for _, msg := range msgs {
if msg.Header.Level == syscall.SOL_IP && msg.Header.Type == syscall.IP_RECVORIGDSTADDR {
ip := msg.Data[4:8]
port := binary.BigEndian.Uint16(msg.Data[2:4])
return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(ip)), port)
} else if msg.Header.Level == syscall.SOL_IPV6 && msg.Header.Type == unix.IPV6_RECVORIGDSTADDR {
ip := msg.Data[8:24]
port := binary.BigEndian.Uint16(msg.Data[2:4])
return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(ip)), port)
}
}
return netip.AddrPort{}
}

View File

@ -14,9 +14,11 @@ import (
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/outbound/dialer"
"golang.org/x/net/dns/dnsmessage"
"golang.org/x/sys/unix"
"net"
"net/netip"
"strings"
"syscall"
"time"
)
@ -80,7 +82,20 @@ func sendPktWithHdr(data []byte, from netip.AddrPort, lConn *net.UDPConn, to net
return err
}
func (c *ControlPlane) RelayToUDP(lConn *net.UDPConn, to netip.AddrPort, isDNS bool, dummyFrom *netip.AddrPort, validateRushAns bool) UdpHandler {
func sendPktBind(data []byte, from netip.AddrPort, to netip.AddrPort) error {
d := net.Dialer{Control: func(network, address string, c syscall.RawConn) error {
return dialer.BindControl(c, from)
}}
conn, err := d.Dial("udp", to.String())
if err != nil {
return err
}
uConn := conn.(*net.UDPConn)
_, err = uConn.Write(data)
return err
}
func (c *ControlPlane) RelayToUDP(to netip.AddrPort, isDNS bool, dummyFrom *netip.AddrPort, validateRushAns bool) UdpHandler {
return func(data []byte, from netip.AddrPort) (err error) {
// Do not return conn-unrelated err in this func.
@ -103,52 +118,57 @@ func (c *ControlPlane) RelayToUDP(lConn *net.UDPConn, to netip.AddrPort, isDNS b
if dummyFrom != nil {
from = *dummyFrom
}
return sendPktWithHdr(data, from, lConn, to)
return sendPktBind(data, from, to)
}
}
func (c *ControlPlane) handlePkt(data []byte, lConn *net.UDPConn, lAddrPort netip.AddrPort, addrHdr *AddrHdr) (err error) {
switch consts.OutboundIndex(addrHdr.Outbound) {
func (c *ControlPlane) handlePkt(data []byte, src, dst netip.AddrPort) (err error) {
outboundIndex, err := c.RetrieveOutboundIndex(src, dst, unix.IPPROTO_UDP)
if err != nil {
return fmt.Errorf("RetrieveOutboundIndex: %w", err)
}
switch outboundIndex {
case consts.OutboundDirect:
case consts.OutboundControlPlaneDirect:
addrHdr.Outbound = uint8(consts.OutboundDirect)
outboundIndex = consts.OutboundDirect
c.log.Tracef("outbound: %v => %v",
consts.OutboundControlPlaneDirect.String(),
consts.OutboundIndex(addrHdr.Outbound).String(),
outboundIndex.String(),
)
default:
}
if int(addrHdr.Outbound) >= len(c.outbounds) {
return fmt.Errorf("outbound %v out of range", addrHdr.Outbound)
if int(outboundIndex) >= len(c.outbounds) {
return fmt.Errorf("outbound %v out of range", outboundIndex)
}
outbound := c.outbounds[addrHdr.Outbound]
outbound := c.outbounds[outboundIndex]
dnsMessage, natTimeout := ChooseNatTimeout(data)
// We should cache DNS records and set record TTL to 0, in order to monitor the dns req and resp in real time.
isDns := dnsMessage != nil
var dummyFrom *netip.AddrPort
dest := addrHdr.Dest
destToSend := dst
if isDns {
if resp := c.LookupDnsRespCache(dnsMessage); resp != nil {
// Send cache to client directly.
if err = sendPktWithHdr(resp, dest, lConn, lAddrPort); err != nil {
if err = sendPktBind(resp, destToSend, src); err != nil {
return fmt.Errorf("failed to write cached DNS resp: %w", err)
}
if c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Questions) > 0 {
q := dnsMessage.Questions[0]
c.log.Tracef("UDP(DNS) %v <-[%v]-> Cache: %v %v",
RefineSourceToShow(lAddrPort, dest.Addr()), outbound.Name, strings.ToLower(q.Name.String()), q.Type,
RefineSourceToShow(src, destToSend.Addr()), outbound.Name, strings.ToLower(q.Name.String()), q.Type,
)
}
return nil
}
// Need to make a DNS request.
c.log.Tracef("Modify dns target %v to upstream: %v", RefineAddrPortToShow(dest), c.dnsUpstream)
c.log.Tracef("Modify dns target %v to upstream: %v", RefineAddrPortToShow(destToSend), c.dnsUpstream)
// Modify dns target to upstream.
// NOTICE: Routing was calculated in advance by the eBPF program.
dummyFrom = &addrHdr.Dest
dest = c.dnsUpstream
dummyFrom = &dst
destToSend = c.dnsUpstream
// Flip dns question to reduce dns pollution.
FlipDnsQuestionCase(dnsMessage)
@ -166,9 +186,9 @@ func (c *ControlPlane) handlePkt(data []byte, lConn *net.UDPConn, lAddrPort neti
// We only validate rush-ans when outbound is direct and pkt does not send to a home device.
// Because additional record OPT may not be supported by home router.
// So se should trust home devices even if they make rush-answer (or looks like).
validateRushAns := addrHdr.Outbound == uint8(consts.OutboundDirect) && !dest.Addr().IsPrivate()
ue, err := DefaultUdpEndpointPool.GetOrCreate(lAddrPort, &UdpEndpointOptions{
Handler: c.RelayToUDP(lConn, lAddrPort, isDns, dummyFrom, validateRushAns),
validateRushAns := outboundIndex == consts.OutboundDirect && !destToSend.Addr().IsPrivate()
ue, err := DefaultUdpEndpointPool.GetOrCreate(src, &UdpEndpointOptions{
Handler: c.RelayToUDP(src, isDns, dummyFrom, validateRushAns),
NatTimeout: natTimeout,
DialerFunc: func() (*dialer.Dialer, error) {
newDialer, err := outbound.Select()
@ -177,7 +197,7 @@ func (c *ControlPlane) handlePkt(data []byte, lConn *net.UDPConn, lAddrPort neti
}
return newDialer, nil
},
Target: dest,
Target: destToSend,
})
if err != nil {
return fmt.Errorf("failed to GetOrCreate: %w", err)
@ -194,7 +214,7 @@ func (c *ControlPlane) handlePkt(data []byte, lConn *net.UDPConn, lAddrPort neti
"qname": strings.ToLower(q.Name.String()),
"qtype": q.Type,
}).Infof("%v <-> %v",
RefineSourceToShow(lAddrPort, dest.Addr()), RefineAddrPortToShow(dest),
RefineSourceToShow(src, destToSend.Addr()), RefineAddrPortToShow(destToSend),
)
} else {
// TODO: Set-up ip to domain mapping and show domain if possible.
@ -203,11 +223,11 @@ func (c *ControlPlane) handlePkt(data []byte, lConn *net.UDPConn, lAddrPort neti
"outbound": outbound.Name,
"dialer": d.Name(),
}).Infof("%v <-> %v",
RefineSourceToShow(lAddrPort, dest.Addr()), RefineAddrPortToShow(dest),
RefineSourceToShow(src, destToSend.Addr()), RefineAddrPortToShow(destToSend),
)
}
//log.Printf("WriteToUDPAddrPort->%v", dest)
_, err = ue.WriteToUDPAddrPort(data, dest)
//log.Printf("WriteToUDPAddrPort->%v", destToSend)
_, err = ue.WriteToUDPAddrPort(data, destToSend)
if err != nil {
return fmt.Errorf("failed to write UDP packet req: %w", err)
}

View File

@ -3,8 +3,6 @@ package dialer
import (
"golang.org/x/net/proxy"
"net"
"runtime"
"syscall"
)
var SymmetricDirect = newDirect(false)
@ -26,9 +24,7 @@ type direct struct {
func newDirect(fullCone bool) proxy.Dialer {
return &direct{
netDialer: &net.Dialer{Control: func(network, address string, c syscall.RawConn) error {
return SoMarkControl(c)
}},
netDialer: &net.Dialer{},
fullCone: fullCone,
}
}
@ -44,10 +40,6 @@ func (d *direct) Dial(network, addr string) (c net.Conn, err error) {
if err != nil {
return nil, err
}
raw, e := conn.SyscallConn()
if e == nil {
_ = SoMarkControl(raw)
}
return &directUDPConn{UDPConn: conn, FullCone: true}, nil
} else {
conn, err := d.netDialer.Dial(network, addr)
@ -88,26 +80,3 @@ func (c *directUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
}
return c.UDPConn.WriteToUDP(b, addr)
}
var fwmarkIoctl int
func init() {
switch runtime.GOOS {
case "linux", "android":
fwmarkIoctl = 36 /* unix.SO_MARK */
case "freebsd":
fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
case "openbsd":
fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
}
}
func SoMarkControl(c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
//TODO: force to set 0xff. any chances to customize this value?
err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, fwmarkIoctl, 0x100)
if err != nil {
return
}
})
}

View File

@ -0,0 +1,119 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) since 2023, mzz2017 <mzz@tuta.io>
*/
package dialer
import (
"fmt"
"golang.org/x/sys/unix"
"net/netip"
"runtime"
"syscall"
)
var fwmarkIoctl int
func init() {
switch runtime.GOOS {
case "linux", "android":
fwmarkIoctl = 36 /* unix.SO_MARK */
case "freebsd":
fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
case "openbsd":
fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
}
}
func SoMarkControl(c syscall.RawConn, mark int) error {
var sockOptErr error
controlErr := c.Control(func(fd uintptr) {
err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, fwmarkIoctl, mark)
if err != nil {
sockOptErr = fmt.Errorf("error setting SO_MARK socket option: %w", err)
}
})
if controlErr != nil {
return fmt.Errorf("error invoking socket control function: %w", controlErr)
}
return sockOptErr
}
func TproxyControl(c syscall.RawConn) error {
var sockOptErr error
controlErr := c.Control(func(fd uintptr) {
// - https://www.kernel.org/doc/Documentation/networking/tproxy.txt
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_TRANSPARENT, 1); err != nil {
sockOptErr = fmt.Errorf("error setting IP_TRANSPARENT socket option: %w", err)
return
}
if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil {
sockOptErr = fmt.Errorf("error setting SO_REUSEADDR socket option: %w", err)
return
}
e4 := unix.SetsockoptInt(int(fd), syscall.SOL_IP, unix.IP_RECVORIGDSTADDR, 1)
e6 := unix.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
if e4 != nil && e6 != nil {
if e4 != nil {
sockOptErr = fmt.Errorf("error setting IP_RECVORIGDSTADDR socket option: %w", e4)
} else {
sockOptErr = fmt.Errorf("error setting IPV6_RECVORIGDSTADDR socket option: %w", e6)
}
return
}
})
if controlErr != nil {
return fmt.Errorf("error invoking socket control function: %w", controlErr)
}
return sockOptErr
}
func BindControl(c syscall.RawConn, lAddrPort netip.AddrPort) error {
var sockOptErr error
controlErr := c.Control(func(fd uintptr) {
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1); err != nil {
sockOptErr = fmt.Errorf("error setting IP_TRANSPARENT socket option: %w", err)
}
if err := bindAddr(fd, lAddrPort); err != nil {
sockOptErr = fmt.Errorf("error bindAddr: %w", err)
}
})
if controlErr != nil {
return fmt.Errorf("error invoking socket control function: %w", controlErr)
}
return sockOptErr
}
func bindAddr(fd uintptr, addrPort netip.AddrPort) error {
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil {
return fmt.Errorf("error setting SO_REUSEADDR socket option: %w", err)
}
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
return fmt.Errorf("error setting SO_REUSEPORT socket option: %w", err)
}
var sockAddr syscall.Sockaddr
switch addr := addrPort.Addr().AsSlice(); len(addr) {
case 4:
a4 := &syscall.SockaddrInet4{
Port: int(addrPort.Port()),
}
copy(a4.Addr[:], addr)
sockAddr = a4
case 16:
a6 := &syscall.SockaddrInet6{
Port: int(addrPort.Port()),
}
copy(a6.Addr[:], addr)
sockAddr = a6
default:
return fmt.Errorf("unexpected length of ip")
}
return syscall.Bind(int(fd), sockAddr)
}

View File

@ -3,7 +3,7 @@ lan=docker0
wan=wlp5s0
sudo tc qdisc add dev $lan clsact > /dev/null 2>&1
sudo tc qdisc add dev $wan clsact > /dev/null 2>&1
# sudo tc qdisc add dev $wan clsact > /dev/null 2>&1
set -ex
@ -16,8 +16,8 @@ sudo tc filter del dev $lan egress
sudo tc filter del dev $wan ingress
sudo tc filter del dev $wan egress
sudo tc filter add dev $lan ingress bpf direct-action obj foo.o sec tc/ingress
sudo tc filter add dev $lan egress bpf direct-action obj foo.o sec tc/egress
sudo tc filter add dev $wan ingress bpf direct-action obj foo.o sec tc/wan_ingress
sudo tc filter add dev $wan egress bpf direct-action obj foo.o sec tc/wan_egress
# sudo tc filter add dev $lan egress bpf direct-action obj foo.o sec tc/egress
# sudo tc filter add dev $wan ingress bpf direct-action obj foo.o sec tc/wan_ingress
# sudo tc filter add dev $wan egress bpf direct-action obj foo.o sec tc/wan_egress
exit 0