Skip to content

Commit

Permalink
Added new comand line option -B <address> for binding specified addre…
Browse files Browse the repository at this point in the history
…ss when connect to DNS server via TCP (#15)

allow setting the local address for the tcp connection.
  • Loading branch information
ku4in authored Jul 10, 2024
1 parent 8796f05 commit 19eb638
Showing 1 changed file with 128 additions and 68 deletions.
196 changes: 128 additions & 68 deletions dns2tcp.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <errno.h>
#include <signal.h>
#include <unistd.h>
#include <assert.h>
#include <arpa/inet.h>
#include <sys/types.h>
#include <sys/socket.h>
Expand Down Expand Up @@ -60,7 +61,7 @@
})

#define log_verbose(fmt, args...) ({ \
if (verbose()) log_info(fmt, ##args); \
if (verbose) log_info(fmt, ##args); \
})

#define log_info(fmt, args...) \
Expand Down Expand Up @@ -131,48 +132,63 @@ typedef struct {
/* ======================== global-vars ======================== */

enum {
OPT_IPV6_V6ONLY = 1 << 0,
OPT_REUSE_PORT = 1 << 1,
OPT_VERBOSE = 1 << 2,
FLAG_IPV6_V6ONLY = 1 << 0, /* udp listen */
FLAG_REUSE_PORT = 1 << 1, /* udp listen */
FLAG_VERBOSE = 1 << 2, /* logging */
FLAG_LOCAL_ADDR = 1 << 3, /* tcp local addr */
};

#define has_opt(opt) (g_options & (opt))
#define enable_opt(opt) (g_options |= (opt))
#define has_flag(flag) (g_flags & (flag))
#define add_flag(flag) (g_flags |= (flag))

#define verbose() has_opt(OPT_VERBOSE)
#define verbose has_flag(FLAG_VERBOSE)

static uint8_t g_options = 0;
static uint8_t g_syn_maxcnt = 0;
static uint8_t g_flags = 0;
static uint8_t g_syn_cnt = 0;

static int g_udp_sockfd = -1;
/* udp listen */
static int g_listen_fd = -1;
static char g_listen_ipstr[IP6STRLEN] = {0};
static uint16_t g_listen_port = 0;
static union skaddr g_listen_skaddr = {0};

/* tcp server address */
static char g_remote_ipstr[IP6STRLEN] = {0};
static uint16_t g_remote_port = 0;
static union skaddr g_remote_skaddr = {0};

/* tcp local address [optional] */
static char g_local_ipstr[IP6STRLEN] = {0};
static uint16_t g_local_port = 0;
static union skaddr g_local_skaddr = {0};

static void udp_recvmsg_cb(evloop_t *evloop, evio_t *watcher, int events);
static void tcp_connect_cb(evloop_t *evloop, evio_t *watcher, int events);
static void tcp_sendmsg_cb(evloop_t *evloop, evio_t *watcher, int events);
static void tcp_recvmsg_cb(evloop_t *evloop, evio_t *watcher, int events);

static void print_help(void) {
printf("usage: dns2tcp <-L listen> <-R remote> [-s syncnt] [-6rvVh]\n"
" -L <ip[#port]> udp listen address, this is required\n"
" -R <ip[#port]> tcp remote address, this is required\n"
" -s <syncnt> set TCP_SYNCNT(max) for remote socket\n"
" -6 enable IPV6_V6ONLY for listen socket\n"
" -r enable SO_REUSEPORT for listen socket\n"
" -v print verbose log, default: <disabled>\n"
printf("usage: dns2tcp <-L listen> <-R remote> [options...]\n"
" -L <ip[#port]> udp listen address, port default to 53\n"
" -R <ip[#port]> tcp remote address, port default to 53\n"
" -l <ip[#port]> tcp local address, port default to 0\n"
" -s <syncnt> set TCP_SYNCNT option for tcp socket\n"
" -6 set IPV6_V6ONLY option for udp socket\n"
" -r set SO_REUSEPORT option for udp socket\n"
" -v print verbose log, used for debugging\n"
" -V print version number of dns2tcp and exit\n"
" -h print help information of dns2tcp and exit\n"
"bug report: https://github.com/zfl9/dns2tcp. email: [email protected]\n"
);
}

static void parse_addr(const char *addr, bool is_listen_addr) {
enum addr_type {
ADDR_UDP_LISTEN,
ADDR_TCP_REMOTE,
ADDR_TCP_LOCAL,
};

static void parse_addr(const char *addr, enum addr_type addr_type) {
const char *end = addr + strlen(addr);
const char *sep = strchr(addr, '#') ?: end;

Expand All @@ -191,22 +207,45 @@ static void parse_addr(const char *addr, bool is_listen_addr) {
int family = get_ipstr_family(ipstr);
if (family == -1) goto err;

uint16_t port = 53;
if (portlen >= 0 && (port = strtoul(portstart, NULL, 10)) == 0) goto err;

if (is_listen_addr) {
strcpy(g_listen_ipstr, ipstr);
g_listen_port = port;
skaddr_from_text(&g_listen_skaddr, family, ipstr, port);
} else {
strcpy(g_remote_ipstr, ipstr);
g_remote_port = port;
skaddr_from_text(&g_remote_skaddr, family, ipstr, port);
uint16_t port = addr_type != ADDR_TCP_LOCAL ? 53 : 0;
if (portlen >= 0 && (port = strtoul(portstart, NULL, 10)) == 0 && addr_type != ADDR_TCP_LOCAL) goto err;

#define set_addr(tag) ({ \
strcpy(g_##tag##_ipstr, ipstr); \
g_##tag##_port = port; \
skaddr_from_text(&g_##tag##_skaddr, family, ipstr, port); \
})

switch (addr_type) {
case ADDR_UDP_LISTEN:
set_addr(listen);
break;
case ADDR_TCP_REMOTE:
set_addr(remote);
break;
case ADDR_TCP_LOCAL:
set_addr(local);
break;
}

#undef set_addr

return;

err:;
const char *type = is_listen_addr ? "listen" : "remote";
const char *type;
switch (addr_type) {
case ADDR_UDP_LISTEN:
type = "udp_listen";
break;
case ADDR_TCP_REMOTE:
type = "tcp_remote";
break;
case ADDR_TCP_LOCAL:
type = "tcp_local";
break;
}

printf("invalid %s address: '%s'\n", type, addr);
print_help();
exit(1);
Expand All @@ -215,38 +254,47 @@ err:;
static void parse_opt(int argc, char *argv[]) {
char opt_listen_addr[IP6STRLEN + PORTSTRLEN] = {0};
char opt_remote_addr[IP6STRLEN + PORTSTRLEN] = {0};
char opt_local_addr[IP6STRLEN + PORTSTRLEN] = {0};

opterr = 0;
int shortopt;
const char *optstr = "L:R:s:6rafvVh";
const char *optstr = "L:R:l:s:6rafvVh";
while ((shortopt = getopt(argc, argv, optstr)) != -1) {
switch (shortopt) {
case 'L':
if (strlen(optarg) + 1 > IP6STRLEN + PORTSTRLEN) {
printf("invalid listen addr: %s\n", optarg);
printf("invalid udp listen addr: %s\n", optarg);
goto err;
}
strcpy(opt_listen_addr, optarg);
break;
case 'R':
if (strlen(optarg) + 1 > IP6STRLEN + PORTSTRLEN) {
printf("invalid remote addr: %s\n", optarg);
printf("invalid tcp remote addr: %s\n", optarg);
goto err;
}
strcpy(opt_remote_addr, optarg);
break;
case 'l':
if (strlen(optarg) + 1 > IP6STRLEN + PORTSTRLEN) {
printf("invalid tcp local addr: %s\n", optarg);
goto err;
}
strcpy(opt_local_addr, optarg);
add_flag(FLAG_LOCAL_ADDR);
break;
case 's':
g_syn_maxcnt = strtoul(optarg, NULL, 10);
if (g_syn_maxcnt == 0) {
g_syn_cnt = strtoul(optarg, NULL, 10);
if (g_syn_cnt == 0) {
printf("invalid tcp syn cnt: %s\n", optarg);
goto err;
}
break;
case '6':
enable_opt(OPT_IPV6_V6ONLY);
add_flag(FLAG_IPV6_V6ONLY);
break;
case 'r':
enable_opt(OPT_REUSE_PORT);
add_flag(FLAG_REUSE_PORT);
break;
case 'a':
/* nop */
Expand All @@ -255,7 +303,7 @@ static void parse_opt(int argc, char *argv[]) {
/* nop */
break;
case 'v':
enable_opt(OPT_VERBOSE);
add_flag(FLAG_VERBOSE);
break;
case 'V':
printf(DNS2TCP_VER"\n");
Expand All @@ -273,6 +321,7 @@ static void parse_opt(int argc, char *argv[]) {
}
}

/* check the required opt */
if (strlen(opt_listen_addr) == 0) {
printf("missing option: '-L'\n");
goto err;
Expand All @@ -282,8 +331,12 @@ static void parse_opt(int argc, char *argv[]) {
goto err;
}

parse_addr(opt_listen_addr, true);
parse_addr(opt_remote_addr, false);
parse_addr(opt_listen_addr, ADDR_UDP_LISTEN);
parse_addr(opt_remote_addr, ADDR_TCP_REMOTE);

if (has_flag(FLAG_LOCAL_ADDR))
parse_addr(opt_local_addr, ADDR_TCP_LOCAL);

return;

err:
Expand All @@ -302,17 +355,18 @@ static int create_socket(int family, int type) {
}

const int opt = 1;
if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
err_op = "set_reuseaddr";
goto out;
}

if (type == SOCK_DGRAM) {
// udp listen socket
if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
err_op = "set_reuseaddr";
goto out;
}
if (has_opt(OPT_REUSE_PORT) && setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)) < 0) {
if (has_flag(FLAG_REUSE_PORT) && setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)) < 0) {
err_op = "set_reuseport";
goto out;
}
if (family == AF_INET6 && has_opt(OPT_IPV6_V6ONLY) && setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &opt, sizeof(opt)) < 0) {
if (family == AF_INET6 && has_flag(FLAG_IPV6_V6ONLY) && setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &opt, sizeof(opt)) < 0) {
err_op = "set_ipv6only";
goto out;
}
Expand All @@ -322,8 +376,8 @@ static int create_socket(int family, int type) {
err_op = "set_tcp_nodelay";
goto out;
}
const int syn_maxcnt = g_syn_maxcnt;
if (syn_maxcnt && setsockopt(fd, IPPROTO_TCP, TCP_SYNCNT, &syn_maxcnt, sizeof(syn_maxcnt)) < 0) {
const int syn_cnt = g_syn_cnt;
if (syn_cnt && setsockopt(fd, IPPROTO_TCP, TCP_SYNCNT, &syn_cnt, sizeof(syn_cnt)) < 0) {
err_op = "set_tcp_syncnt";
goto out;
}
Expand All @@ -342,24 +396,25 @@ int main(int argc, char *argv[]) {

log_info("udp listen addr: %s#%hu", g_listen_ipstr, g_listen_port);
log_info("tcp remote addr: %s#%hu", g_remote_ipstr, g_remote_port);
if (g_syn_maxcnt) log_info("enable TCP_SYNCNT:%hhu sockopt", g_syn_maxcnt);
if (has_opt(OPT_IPV6_V6ONLY)) log_info("enable IPV6_V6ONLY sockopt");
if (has_opt(OPT_REUSE_PORT)) log_info("enable SO_REUSEPORT sockopt");
log_verbose("verbose mode, affect performance");

g_udp_sockfd = create_socket(skaddr_family(&g_listen_skaddr), SOCK_DGRAM);
if (g_udp_sockfd < 0)
if (has_flag(FLAG_LOCAL_ADDR)) log_info("tcp local addr: %s#%hu", g_local_ipstr, g_local_port);
if (g_syn_cnt) log_info("enable TCP_SYNCNT:%hhu sockopt", g_syn_cnt);
if (has_flag(FLAG_IPV6_V6ONLY)) log_info("enable IPV6_V6ONLY sockopt");
if (has_flag(FLAG_REUSE_PORT)) log_info("enable SO_REUSEPORT sockopt");
log_verbose("print the verbose log");

g_listen_fd = create_socket(skaddr_family(&g_listen_skaddr), SOCK_DGRAM);
if (g_listen_fd < 0)
return 1;

if (bind(g_udp_sockfd, &g_listen_skaddr.sa, skaddr_len(&g_listen_skaddr)) < 0) {
if (bind(g_listen_fd, &g_listen_skaddr.sa, skaddr_len(&g_listen_skaddr)) < 0) {
log_error("bind udp address: %m");
return 1;
}

evloop_t *evloop = ev_default_loop(0);

evio_t watcher;
ev_io_init(&watcher, udp_recvmsg_cb, g_udp_sockfd, EV_READ);
ev_io_init(&watcher, udp_recvmsg_cb, g_listen_fd, EV_READ);
ev_io_start(evloop, &watcher);

return ev_run(evloop, 0);
Expand All @@ -368,14 +423,14 @@ int main(int argc, char *argv[]) {
static void udp_recvmsg_cb(evloop_t *evloop, evio_t *watcher __unused, int events __unused) {
ctx_t *ctx = malloc(sizeof(*ctx));

ssize_t nrecv = recvfrom(g_udp_sockfd, (void *)ctx->buffer + 2, DNS_MSGSZ, 0, &ctx->srcaddr.sa, &(socklen_t){sizeof(ctx->srcaddr)});
ssize_t nrecv = recvfrom(g_listen_fd, (void *)ctx->buffer + 2, DNS_MSGSZ, 0, &ctx->srcaddr.sa, &(socklen_t){sizeof(ctx->srcaddr)});
if (nrecv < 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK)
log_warning("recv from udp socket: %m");
goto free_ctx;
}

if (verbose()) {
if (verbose) {
char ip[IP6STRLEN];
uint16_t port;
skaddr_to_text(&ctx->srcaddr, ip, &port);
Expand All @@ -385,23 +440,28 @@ static void udp_recvmsg_cb(evloop_t *evloop, evio_t *watcher __unused, int event
uint16_t *p_msglen = (void *)ctx->buffer;
*p_msglen = htons(nrecv); /* msg length */

int sockfd = create_socket(skaddr_family(&g_remote_skaddr), SOCK_STREAM);
if (sockfd < 0)
int fd = create_socket(skaddr_family(&g_remote_skaddr), SOCK_STREAM);
if (fd < 0)
goto free_ctx;

if (connect(sockfd, &g_remote_skaddr.sa, skaddr_len(&g_remote_skaddr)) < 0 && errno != EINPROGRESS) {
if (has_flag(FLAG_LOCAL_ADDR) && bind(fd, &g_local_skaddr.sa, skaddr_len(&g_local_skaddr)) < 0) {
log_warning("bind tcp address: %m");
goto close_fd;
}

if (connect(fd, &g_remote_skaddr.sa, skaddr_len(&g_remote_skaddr)) < 0 && errno != EINPROGRESS) {
log_warning("connect to %s#%hu: %m", g_remote_ipstr, g_remote_port);
goto close_sockfd;
goto close_fd;
}
log_verbose("try to connect to %s#%hu", g_remote_ipstr, g_remote_port);

ev_io_init(&ctx->watcher, tcp_connect_cb, sockfd, EV_WRITE);
ev_io_init(&ctx->watcher, tcp_connect_cb, fd, EV_WRITE);
ev_io_start(evloop, &ctx->watcher);

return;

close_sockfd:
close(sockfd);
close_fd:
close(fd);
free_ctx:
free(ctx);
}
Expand Down Expand Up @@ -473,8 +533,8 @@ static void tcp_recvmsg_cb(evloop_t *evloop, evio_t *watcher, int events __unuse
uint16_t msglen;
if (ctx->nbytes < 2 || ctx->nbytes < 2 + (msglen = ntohs(*(uint16_t *)buffer))) return;

ssize_t nsend = sendto(g_udp_sockfd, buffer + 2, msglen, 0, &ctx->srcaddr.sa, skaddr_len(&ctx->srcaddr));
if (nsend < 0 || verbose()) {
ssize_t nsend = sendto(g_listen_fd, buffer + 2, msglen, 0, &ctx->srcaddr.sa, skaddr_len(&ctx->srcaddr));
if (nsend < 0 || verbose) {
char ip[IP6STRLEN];
uint16_t port;
skaddr_to_text(&ctx->srcaddr, ip, &port);
Expand Down

0 comments on commit 19eb638

Please sign in to comment.