Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added new comand line option -B <address> for binding specified address when connect to DNS server via TCP #15

Merged
merged 5 commits into from
Jul 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 128 additions & 68 deletions dns2tcp.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <errno.h>
#include <signal.h>
#include <unistd.h>
#include <assert.h>
#include <arpa/inet.h>
#include <sys/types.h>
#include <sys/socket.h>
Expand Down Expand Up @@ -60,7 +61,7 @@
})

#define log_verbose(fmt, args...) ({ \
if (verbose()) log_info(fmt, ##args); \
if (verbose) log_info(fmt, ##args); \
})

#define log_info(fmt, args...) \
Expand Down Expand Up @@ -131,48 +132,63 @@ typedef struct {
/* ======================== global-vars ======================== */

enum {
OPT_IPV6_V6ONLY = 1 << 0,
OPT_REUSE_PORT = 1 << 1,
OPT_VERBOSE = 1 << 2,
FLAG_IPV6_V6ONLY = 1 << 0, /* udp listen */
FLAG_REUSE_PORT = 1 << 1, /* udp listen */
FLAG_VERBOSE = 1 << 2, /* logging */
FLAG_LOCAL_ADDR = 1 << 3, /* tcp local addr */
};

#define has_opt(opt) (g_options & (opt))
#define enable_opt(opt) (g_options |= (opt))
#define has_flag(flag) (g_flags & (flag))
#define add_flag(flag) (g_flags |= (flag))

#define verbose() has_opt(OPT_VERBOSE)
#define verbose has_flag(FLAG_VERBOSE)

static uint8_t g_options = 0;
static uint8_t g_syn_maxcnt = 0;
static uint8_t g_flags = 0;
static uint8_t g_syn_cnt = 0;

static int g_udp_sockfd = -1;
/* udp listen */
static int g_listen_fd = -1;
static char g_listen_ipstr[IP6STRLEN] = {0};
static uint16_t g_listen_port = 0;
static union skaddr g_listen_skaddr = {0};

/* tcp server address */
static char g_remote_ipstr[IP6STRLEN] = {0};
static uint16_t g_remote_port = 0;
static union skaddr g_remote_skaddr = {0};

/* tcp local address [optional] */
static char g_local_ipstr[IP6STRLEN] = {0};
static uint16_t g_local_port = 0;
static union skaddr g_local_skaddr = {0};

static void udp_recvmsg_cb(evloop_t *evloop, evio_t *watcher, int events);
static void tcp_connect_cb(evloop_t *evloop, evio_t *watcher, int events);
static void tcp_sendmsg_cb(evloop_t *evloop, evio_t *watcher, int events);
static void tcp_recvmsg_cb(evloop_t *evloop, evio_t *watcher, int events);

static void print_help(void) {
printf("usage: dns2tcp <-L listen> <-R remote> [-s syncnt] [-6rvVh]\n"
" -L <ip[#port]> udp listen address, this is required\n"
" -R <ip[#port]> tcp remote address, this is required\n"
" -s <syncnt> set TCP_SYNCNT(max) for remote socket\n"
" -6 enable IPV6_V6ONLY for listen socket\n"
" -r enable SO_REUSEPORT for listen socket\n"
" -v print verbose log, default: <disabled>\n"
printf("usage: dns2tcp <-L listen> <-R remote> [options...]\n"
" -L <ip[#port]> udp listen address, port default to 53\n"
" -R <ip[#port]> tcp remote address, port default to 53\n"
" -l <ip[#port]> tcp local address, port default to 0\n"
" -s <syncnt> set TCP_SYNCNT option for tcp socket\n"
" -6 set IPV6_V6ONLY option for udp socket\n"
" -r set SO_REUSEPORT option for udp socket\n"
" -v print verbose log, used for debugging\n"
" -V print version number of dns2tcp and exit\n"
" -h print help information of dns2tcp and exit\n"
"bug report: https://github.com/zfl9/dns2tcp. email: [email protected]\n"
);
}

static void parse_addr(const char *addr, bool is_listen_addr) {
enum addr_type {
ADDR_UDP_LISTEN,
ADDR_TCP_REMOTE,
ADDR_TCP_LOCAL,
};

static void parse_addr(const char *addr, enum addr_type addr_type) {
const char *end = addr + strlen(addr);
const char *sep = strchr(addr, '#') ?: end;

Expand All @@ -191,22 +207,45 @@ static void parse_addr(const char *addr, bool is_listen_addr) {
int family = get_ipstr_family(ipstr);
if (family == -1) goto err;

uint16_t port = 53;
if (portlen >= 0 && (port = strtoul(portstart, NULL, 10)) == 0) goto err;

if (is_listen_addr) {
strcpy(g_listen_ipstr, ipstr);
g_listen_port = port;
skaddr_from_text(&g_listen_skaddr, family, ipstr, port);
} else {
strcpy(g_remote_ipstr, ipstr);
g_remote_port = port;
skaddr_from_text(&g_remote_skaddr, family, ipstr, port);
uint16_t port = addr_type != ADDR_TCP_LOCAL ? 53 : 0;
if (portlen >= 0 && (port = strtoul(portstart, NULL, 10)) == 0 && addr_type != ADDR_TCP_LOCAL) goto err;

#define set_addr(tag) ({ \
strcpy(g_##tag##_ipstr, ipstr); \
g_##tag##_port = port; \
skaddr_from_text(&g_##tag##_skaddr, family, ipstr, port); \
})

switch (addr_type) {
case ADDR_UDP_LISTEN:
set_addr(listen);
break;
case ADDR_TCP_REMOTE:
set_addr(remote);
break;
case ADDR_TCP_LOCAL:
set_addr(local);
break;
}

#undef set_addr

return;

err:;
const char *type = is_listen_addr ? "listen" : "remote";
const char *type;
switch (addr_type) {
case ADDR_UDP_LISTEN:
type = "udp_listen";
break;
case ADDR_TCP_REMOTE:
type = "tcp_remote";
break;
case ADDR_TCP_LOCAL:
type = "tcp_local";
break;
}

printf("invalid %s address: '%s'\n", type, addr);
print_help();
exit(1);
Expand All @@ -215,38 +254,47 @@ err:;
static void parse_opt(int argc, char *argv[]) {
char opt_listen_addr[IP6STRLEN + PORTSTRLEN] = {0};
char opt_remote_addr[IP6STRLEN + PORTSTRLEN] = {0};
char opt_local_addr[IP6STRLEN + PORTSTRLEN] = {0};

opterr = 0;
int shortopt;
const char *optstr = "L:R:s:6rafvVh";
const char *optstr = "L:R:l:s:6rafvVh";
while ((shortopt = getopt(argc, argv, optstr)) != -1) {
switch (shortopt) {
case 'L':
if (strlen(optarg) + 1 > IP6STRLEN + PORTSTRLEN) {
printf("invalid listen addr: %s\n", optarg);
printf("invalid udp listen addr: %s\n", optarg);
goto err;
}
strcpy(opt_listen_addr, optarg);
break;
case 'R':
if (strlen(optarg) + 1 > IP6STRLEN + PORTSTRLEN) {
printf("invalid remote addr: %s\n", optarg);
printf("invalid tcp remote addr: %s\n", optarg);
goto err;
}
strcpy(opt_remote_addr, optarg);
break;
case 'l':
if (strlen(optarg) + 1 > IP6STRLEN + PORTSTRLEN) {
printf("invalid tcp local addr: %s\n", optarg);
goto err;
}
strcpy(opt_local_addr, optarg);
add_flag(FLAG_LOCAL_ADDR);
break;
case 's':
g_syn_maxcnt = strtoul(optarg, NULL, 10);
if (g_syn_maxcnt == 0) {
g_syn_cnt = strtoul(optarg, NULL, 10);
if (g_syn_cnt == 0) {
printf("invalid tcp syn cnt: %s\n", optarg);
goto err;
}
break;
case '6':
enable_opt(OPT_IPV6_V6ONLY);
add_flag(FLAG_IPV6_V6ONLY);
break;
case 'r':
enable_opt(OPT_REUSE_PORT);
add_flag(FLAG_REUSE_PORT);
break;
case 'a':
/* nop */
Expand All @@ -255,7 +303,7 @@ static void parse_opt(int argc, char *argv[]) {
/* nop */
break;
case 'v':
enable_opt(OPT_VERBOSE);
add_flag(FLAG_VERBOSE);
break;
case 'V':
printf(DNS2TCP_VER"\n");
Expand All @@ -273,6 +321,7 @@ static void parse_opt(int argc, char *argv[]) {
}
}

/* check the required opt */
if (strlen(opt_listen_addr) == 0) {
printf("missing option: '-L'\n");
goto err;
Expand All @@ -282,8 +331,12 @@ static void parse_opt(int argc, char *argv[]) {
goto err;
}

parse_addr(opt_listen_addr, true);
parse_addr(opt_remote_addr, false);
parse_addr(opt_listen_addr, ADDR_UDP_LISTEN);
parse_addr(opt_remote_addr, ADDR_TCP_REMOTE);

if (has_flag(FLAG_LOCAL_ADDR))
parse_addr(opt_local_addr, ADDR_TCP_LOCAL);

return;

err:
Expand All @@ -302,17 +355,18 @@ static int create_socket(int family, int type) {
}

const int opt = 1;
if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
err_op = "set_reuseaddr";
goto out;
}

if (type == SOCK_DGRAM) {
// udp listen socket
if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
err_op = "set_reuseaddr";
goto out;
}
if (has_opt(OPT_REUSE_PORT) && setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)) < 0) {
if (has_flag(FLAG_REUSE_PORT) && setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)) < 0) {
err_op = "set_reuseport";
goto out;
}
if (family == AF_INET6 && has_opt(OPT_IPV6_V6ONLY) && setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &opt, sizeof(opt)) < 0) {
if (family == AF_INET6 && has_flag(FLAG_IPV6_V6ONLY) && setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &opt, sizeof(opt)) < 0) {
err_op = "set_ipv6only";
goto out;
}
Expand All @@ -322,8 +376,8 @@ static int create_socket(int family, int type) {
err_op = "set_tcp_nodelay";
goto out;
}
const int syn_maxcnt = g_syn_maxcnt;
if (syn_maxcnt && setsockopt(fd, IPPROTO_TCP, TCP_SYNCNT, &syn_maxcnt, sizeof(syn_maxcnt)) < 0) {
const int syn_cnt = g_syn_cnt;
if (syn_cnt && setsockopt(fd, IPPROTO_TCP, TCP_SYNCNT, &syn_cnt, sizeof(syn_cnt)) < 0) {
err_op = "set_tcp_syncnt";
goto out;
}
Expand All @@ -342,24 +396,25 @@ int main(int argc, char *argv[]) {

log_info("udp listen addr: %s#%hu", g_listen_ipstr, g_listen_port);
log_info("tcp remote addr: %s#%hu", g_remote_ipstr, g_remote_port);
if (g_syn_maxcnt) log_info("enable TCP_SYNCNT:%hhu sockopt", g_syn_maxcnt);
if (has_opt(OPT_IPV6_V6ONLY)) log_info("enable IPV6_V6ONLY sockopt");
if (has_opt(OPT_REUSE_PORT)) log_info("enable SO_REUSEPORT sockopt");
log_verbose("verbose mode, affect performance");

g_udp_sockfd = create_socket(skaddr_family(&g_listen_skaddr), SOCK_DGRAM);
if (g_udp_sockfd < 0)
if (has_flag(FLAG_LOCAL_ADDR)) log_info("tcp local addr: %s#%hu", g_local_ipstr, g_local_port);
if (g_syn_cnt) log_info("enable TCP_SYNCNT:%hhu sockopt", g_syn_cnt);
if (has_flag(FLAG_IPV6_V6ONLY)) log_info("enable IPV6_V6ONLY sockopt");
if (has_flag(FLAG_REUSE_PORT)) log_info("enable SO_REUSEPORT sockopt");
log_verbose("print the verbose log");

g_listen_fd = create_socket(skaddr_family(&g_listen_skaddr), SOCK_DGRAM);
if (g_listen_fd < 0)
return 1;

if (bind(g_udp_sockfd, &g_listen_skaddr.sa, skaddr_len(&g_listen_skaddr)) < 0) {
if (bind(g_listen_fd, &g_listen_skaddr.sa, skaddr_len(&g_listen_skaddr)) < 0) {
log_error("bind udp address: %m");
return 1;
}

evloop_t *evloop = ev_default_loop(0);

evio_t watcher;
ev_io_init(&watcher, udp_recvmsg_cb, g_udp_sockfd, EV_READ);
ev_io_init(&watcher, udp_recvmsg_cb, g_listen_fd, EV_READ);
ev_io_start(evloop, &watcher);

return ev_run(evloop, 0);
Expand All @@ -368,14 +423,14 @@ int main(int argc, char *argv[]) {
static void udp_recvmsg_cb(evloop_t *evloop, evio_t *watcher __unused, int events __unused) {
ctx_t *ctx = malloc(sizeof(*ctx));

ssize_t nrecv = recvfrom(g_udp_sockfd, (void *)ctx->buffer + 2, DNS_MSGSZ, 0, &ctx->srcaddr.sa, &(socklen_t){sizeof(ctx->srcaddr)});
ssize_t nrecv = recvfrom(g_listen_fd, (void *)ctx->buffer + 2, DNS_MSGSZ, 0, &ctx->srcaddr.sa, &(socklen_t){sizeof(ctx->srcaddr)});
if (nrecv < 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK)
log_warning("recv from udp socket: %m");
goto free_ctx;
}

if (verbose()) {
if (verbose) {
char ip[IP6STRLEN];
uint16_t port;
skaddr_to_text(&ctx->srcaddr, ip, &port);
Expand All @@ -385,23 +440,28 @@ static void udp_recvmsg_cb(evloop_t *evloop, evio_t *watcher __unused, int event
uint16_t *p_msglen = (void *)ctx->buffer;
*p_msglen = htons(nrecv); /* msg length */

int sockfd = create_socket(skaddr_family(&g_remote_skaddr), SOCK_STREAM);
if (sockfd < 0)
int fd = create_socket(skaddr_family(&g_remote_skaddr), SOCK_STREAM);
if (fd < 0)
goto free_ctx;

if (connect(sockfd, &g_remote_skaddr.sa, skaddr_len(&g_remote_skaddr)) < 0 && errno != EINPROGRESS) {
if (has_flag(FLAG_LOCAL_ADDR) && bind(fd, &g_local_skaddr.sa, skaddr_len(&g_local_skaddr)) < 0) {
log_warning("bind tcp address: %m");
goto close_fd;
}

if (connect(fd, &g_remote_skaddr.sa, skaddr_len(&g_remote_skaddr)) < 0 && errno != EINPROGRESS) {
log_warning("connect to %s#%hu: %m", g_remote_ipstr, g_remote_port);
goto close_sockfd;
goto close_fd;
}
log_verbose("try to connect to %s#%hu", g_remote_ipstr, g_remote_port);

ev_io_init(&ctx->watcher, tcp_connect_cb, sockfd, EV_WRITE);
ev_io_init(&ctx->watcher, tcp_connect_cb, fd, EV_WRITE);
ev_io_start(evloop, &ctx->watcher);

return;

close_sockfd:
close(sockfd);
close_fd:
close(fd);
free_ctx:
free(ctx);
}
Expand Down Expand Up @@ -473,8 +533,8 @@ static void tcp_recvmsg_cb(evloop_t *evloop, evio_t *watcher, int events __unuse
uint16_t msglen;
if (ctx->nbytes < 2 || ctx->nbytes < 2 + (msglen = ntohs(*(uint16_t *)buffer))) return;

ssize_t nsend = sendto(g_udp_sockfd, buffer + 2, msglen, 0, &ctx->srcaddr.sa, skaddr_len(&ctx->srcaddr));
if (nsend < 0 || verbose()) {
ssize_t nsend = sendto(g_listen_fd, buffer + 2, msglen, 0, &ctx->srcaddr.sa, skaddr_len(&ctx->srcaddr));
if (nsend < 0 || verbose) {
char ip[IP6STRLEN];
uint16_t port;
skaddr_to_text(&ctx->srcaddr, ip, &port);
Expand Down