/*
* Description: Force bind on a specified address
* Author: Catalin(ux) M. BOIE
* E-mail: catab at embedromix dot ro
* Web: http://kernel.embedromix.ro/us/
*/
#define __USE_GNU
#define _GNU_SOURCE
#define __USE_XOPEN2K
#define __USE_LARGEFILE64
#define __USE_FILE_OFFSET64
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <syslog.h>
#include <dlfcn.h>
#include <fcntl.h>
#include <string.h>
#include <unistd.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <time.h>
#include <errno.h>
#include <dirent.h>
#include <asm/unistd.h>
#include <fcntl.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#define FB_FLAGS_NETSOCK 1
/*
 * State kept for one tracked socket. The same layout is reused by the
 * file-scope 'bw_global' variable to hold the process-wide bandwidth budget.
 */
struct private
{
/* 'domain' as passed to socket(); -1 for sockets obtained through accept() */
int domain;
/* socket type: the socket() argument, or the SO_TYPE value for accept()ed sockets */
int type;
/* bit mask of FB_FLAGS_* values */
unsigned int flags;
/* bandwidth */
/* rate limit in bytes per second; 0 means no limit configured at this level */
unsigned long long limit;
/* unused byte budget carried over from the previous accounting round */
unsigned long long rest;
/* time of the previous accounting round */
struct timeval last;
};
/*
 * One element of the singly linked list of tracked descriptors.
 * A node whose 'fd' is -1 is a free slot: del() marks it so and add()
 * may hand it out again for another descriptor.
 */
struct node
{
/* descriptor this node describes, or -1 when the slot is free */
int fd;
/* per-descriptor data */
struct private priv;
/* next element, NULL at the end of the list */
struct node *next;
};
/*
 * Anchor of the list of tracked descriptors: first and last element.
 * Both pointers are NULL while the list is empty.
 */
struct info
{
struct node *head, *tail;
};
/*
 * Pointers to the next (real) definitions of the functions hijacked by this
 * library. They are filled in by init() through dlsym(RTLD_NEXT, ...).
 */
static int (*old_bind)(int sockfd, const struct sockaddr *addr, socklen_t addrlen) = NULL;
static int (*old_setsockopt)(int sockfd, int level, int optname, const void *optval, socklen_t optlen);
static int (*old_socket)(int domain, int type, int protocol);
static int (*old_close)(int fd);
static ssize_t (*old_write)(int fd, const void *buf, size_t len);
static ssize_t (*old_send)(int sockfd, const void *buf, size_t len, int flags);
static ssize_t (*old_sendto)(int sockfd, const void *buf, size_t len, int flags, const struct sockaddr *dest_addr, socklen_t addrlen);
static ssize_t (*old_sendmsg)(int sockfd, const struct msghdr *msg, int flags);
static int (*old_accept)(int sockfd, struct sockaddr *addr, socklen_t *addrlen);
/* Textual address to force per family (points into the environment); NULL = leave untouched */
static char *force_address_v4 = NULL;
static char *force_address_v6 = NULL;
/* Port to force per family, in host byte order; -1 = leave untouched */
static int force_port_v4 = -1;
static int force_port_v6 = -1;
/*
 * Socket option overrides. For every pair, force_X is set to 1 by init()
 * when the matching environment variable exists and X holds the parsed value.
 */
static unsigned int force_tos = 0, tos;
static unsigned int force_ttl = 0, ttl;
/* 'keepalive' is used both as the SO_KEEPALIVE on/off switch and as the TCP_KEEPIDLE value */
static unsigned int force_keepalive = 0, keepalive;
static unsigned int force_mss = 0, mss;
static unsigned int force_reuseaddr = 0, reuseaddr;
static unsigned int force_nodelay = 0, nodelay;
/* Limit (bytes/s) copied into every tracked socket by add(); 0 = none */
static unsigned long long bw_limit_per_socket = 0;
/* Budget shared by all sockets; consulted by bw() when a socket has no limit of its own */
static struct private bw_global;
/* List of tracked descriptors */
static struct info fdinfo;
/* Non-zero enables logging through my_syslog() */
static unsigned int verbose = 0;
/* Helper functions */
/*
 * printf-like logging helper.
 * The message is handed to vsyslog() only when 'verbose' is non-zero,
 * otherwise the call does nothing. Always returns 0.
 */
static int my_syslog(int priority, const char *format, ...)
{
	if (verbose != 0) {
		va_list args;

		va_start(args, format);
		vsyslog(priority, format, args);
		va_end(args);
	}

	return 0;
}
/*
 * Looks up the list element describing 'fd'.
 * Returns the first node whose fd matches, or NULL if there is none.
 */
static struct node *get(const int fd)
{
	struct node *cur;

	for (cur = fdinfo.head; cur != NULL; cur = cur->next) {
		if (cur->fd == fd)
			break;
	}

	/* Either the match, or NULL when the whole list was walked */
	return cur;
}
/*
 * Starts (or restarts) tracking of descriptor 'fd' with the data in 'p'.
 *
 * The node is chosen in this order:
 *   1. the node already describing 'fd', if any;
 *   2. a free slot (fd == -1) left behind by del();
 *   3. a freshly allocated node, appended at the tail of the list.
 *
 * Only a freshly allocated node is linked into the list. Nodes from cases
 * 1 and 2 are already part of it; appending them again would make the tail
 * point back into the list (a cycle), so a later lookup of an unknown
 * descriptor would never terminate.
 *
 * If memory cannot be allocated the descriptor is silently left untracked.
 */
static void add(const int fd, const struct private *p)
{
	struct node *q;
	int is_new = 0;

	/* do we have a copy? */
	q = get(fd);
	if (q == NULL) {
		/* Try to find a free location */
		q = fdinfo.head;
		while (q != NULL) {
			if (q->fd == -1) {
				q->fd = fd;
				break;
			}
			q = q->next;
		}

		if (q == NULL) {
			q = malloc(sizeof(*q));
			if (q == NULL) {
				my_syslog(LOG_INFO, "force_bind: Cannot alloc memory; ignore fd!\n");
				return;
			}
			q->next = NULL;
			q->fd = fd;
			is_new = 1;
		}
	}

	memcpy(&q->priv, p, sizeof(struct private));

	/* Set bandwidth requirements */
	q->priv.limit = bw_limit_per_socket;
	if (bw_limit_per_socket > 0) {
		q->priv.rest = 0;
		gettimeofday(&q->priv.last, NULL);
	}

	/* Reused nodes are already linked; see the function comment */
	if (is_new == 0)
		return;

	if (fdinfo.tail == NULL)
		fdinfo.head = q;
	else
		fdinfo.tail->next = q;
	fdinfo.tail = q;
}
/*
 * Stops tracking descriptor 'fd'.
 * The first matching node is not freed, it is only marked as a free slot
 * (fd = -1). Nothing happens if 'fd' is not tracked.
 */
static void del(const int fd)
{
	struct node *cur = fdinfo.head;

	/* Skip everything that is not our descriptor */
	while (cur != NULL && cur->fd != fd)
		cur = cur->next;

	if (cur != NULL)
		cur->fd = -1;
}
/* Functions */
/*
 * One-time setup, called at the start of every hijacked function.
 *
 * - resets the list of tracked descriptors;
 * - loads the configuration from the FORCE_* environment variables;
 * - resolves the next definition of every hijacked function with
 *   dlsym(RTLD_NEXT, ...).
 *
 * If a symbol cannot be resolved the process exits with status 1. The
 * message goes through my_syslog(), so it is only emitted when
 * FORCE_NET_VERBOSE enabled logging.
 */
static void init(void)
{
	static unsigned char done = 0;
	char *val;

	/* Everything below must run only once */
	if (done == 1)
		return;
	done = 1;

	fdinfo.head = NULL;
	fdinfo.tail = NULL;

	/* Verbosity is read first because all the other steps may log */
	if ((val = getenv("FORCE_NET_VERBOSE")) != NULL)
		verbose = strtol(val, NULL, 10);

	/* Addresses; the values keep pointing into the environment */
	if ((val = getenv("FORCE_BIND_ADDRESS_V4")) != NULL) {
		force_address_v4 = val;
		my_syslog(LOG_INFO, "force_bind: Force bind to address %s.\n",
			force_address_v4);
	}

	if ((val = getenv("FORCE_BIND_ADDRESS_V6")) != NULL) {
		force_address_v6 = val;
		my_syslog(LOG_INFO, "force_bind: Force bind to address %s.\n",
			force_address_v6);
	}

	/* obsolete mode: one variable for both families, overrides the above */
	if ((val = getenv("FORCE_BIND_ADDRESS")) != NULL) {
		force_address_v4 = val;
		force_address_v6 = val;
		my_syslog(LOG_INFO, "force_bind: Force bind to address %s."
			" Obsolete, use FORCE_BIND_ADDRESS_V4/6.\n",
			force_address_v4);
	}

	/* Ports */
	if ((val = getenv("FORCE_BIND_PORT_V4")) != NULL) {
		force_port_v4 = strtol(val, NULL, 10);
		my_syslog(LOG_INFO, "force_bind: Force bind to port %d.\n",
			force_port_v4);
	}

	if ((val = getenv("FORCE_BIND_PORT_V6")) != NULL) {
		force_port_v6 = strtol(val, NULL, 10);
		my_syslog(LOG_INFO, "force_bind: Force bind to port %d.\n",
			force_port_v6);
	}

	/* obsolete mode: one variable for both families, overrides the above */
	if ((val = getenv("FORCE_BIND_PORT")) != NULL) {
		force_port_v4 = strtol(val, NULL, 10);
		force_port_v6 = force_port_v4;
		my_syslog(LOG_INFO, "force_bind: Force bind to port %d."
			" Obsolete, use FORCE_BIND_PORT_V4/6.\n",
			force_port_v4);
	}

	/* tos */
	if ((val = getenv("FORCE_NET_TOS")) != NULL) {
		force_tos = 1;
		tos = strtoul(val, NULL, 0);
		my_syslog(LOG_INFO, "force_bind: Force TOS to %hhu.\n",
			tos);
	}

	/* ttl */
	if ((val = getenv("FORCE_NET_TTL")) != NULL) {
		force_ttl = 1;
		ttl = strtoul(val, NULL, 0);
		my_syslog(LOG_INFO, "force_bind: Force TTL to %hhu.\n",
			ttl);
	}

	/* keep alive */
	if ((val = getenv("FORCE_NET_KA")) != NULL) {
		force_keepalive = 1;
		keepalive = strtoul(val, NULL, 0);
		my_syslog(LOG_INFO, "force_bind: Force KA to %u.\n",
			keepalive);
	}

	/* mss */
	if ((val = getenv("FORCE_NET_MSS")) != NULL) {
		force_mss = 1;
		mss = strtoul(val, NULL, 0);
		my_syslog(LOG_INFO, "force_bind: Force MSS to %u.\n",
			mss);
	}

	/* REUSEADDR */
	if ((val = getenv("FORCE_NET_REUSEADDR")) != NULL) {
		force_reuseaddr = 1;
		reuseaddr = strtoul(val, NULL, 0);
		my_syslog(LOG_INFO, "force_bind: Force REUSEADDR to %u.\n",
			reuseaddr);
	}

	/* NODELAY */
	if ((val = getenv("FORCE_NET_NODELAY")) != NULL) {
		force_nodelay = 1;
		nodelay = strtoul(val, NULL, 0);
		my_syslog(LOG_INFO, "force_bind: Force NODELAY to %u.\n",
			nodelay);
	}

	/* bandwidth, shared by all sockets */
	val = getenv("FORCE_NET_BW");
	if (val == NULL) {
		bw_global.limit = 0;
	} else {
		bw_global.limit = strtoul(val, NULL, 0);
		gettimeofday(&bw_global.last, NULL);
		bw_global.rest = 0;
		my_syslog(LOG_INFO, "force_bind: Force bandwidth to %llub/s.\n",
			bw_global.limit);
	}

	/* bandwidth per socket; refused when a global limit is active */
	val = getenv("FORCE_NET_BW_PER_SOCKET");
	if (val != NULL && bw_global.limit > 0) {
		my_syslog(LOG_INFO, "force_bind: Cannot set limit per socket"
			" when global one is set.\n");
	} else if (val != NULL) {
		bw_limit_per_socket = strtoul(val, NULL, 0);
		my_syslog(LOG_INFO, "force_bind: Force bandwidth per socket to %llub/s.\n",
			bw_limit_per_socket);
	}

	/* Look up the next definitions of the functions we hijack */
	old_bind = dlsym(RTLD_NEXT, "bind");
	if (!old_bind) {
		my_syslog(LOG_ERR, "force_bind: Cannot resolve 'bind'!\n");
		exit(1);
	}

	old_setsockopt = dlsym(RTLD_NEXT, "setsockopt");
	if (!old_setsockopt) {
		my_syslog(LOG_ERR, "force_bind: Cannot resolve 'setsockopt'!\n");
		exit(1);
	}

	old_socket = dlsym(RTLD_NEXT, "socket");
	if (!old_socket) {
		my_syslog(LOG_ERR, "force_bind: Cannot resolve 'socket'!\n");
		exit(1);
	}

	old_close = dlsym(RTLD_NEXT, "close");
	if (!old_close) {
		my_syslog(LOG_ERR, "force_bind: Cannot resolve 'close'!\n");
		exit(1);
	}

	old_write = dlsym(RTLD_NEXT, "write");
	if (!old_write) {
		my_syslog(LOG_ERR, "force_bind: Cannot resolve 'write'!\n");
		exit(1);
	}

	old_send = dlsym(RTLD_NEXT, "send");
	if (!old_send) {
		my_syslog(LOG_ERR, "force_bind: Cannot resolve 'send'!\n");
		exit(1);
	}

	old_sendto = dlsym(RTLD_NEXT, "sendto");
	if (!old_sendto) {
		my_syslog(LOG_ERR, "force_bind: Cannot resolve 'sendto'!\n");
		exit(1);
	}

	old_sendmsg = dlsym(RTLD_NEXT, "sendmsg");
	if (!old_sendmsg) {
		my_syslog(LOG_ERR, "force_bind: Cannot resolve 'sendmsg'!\n");
		exit(1);
	}

	old_accept = dlsym(RTLD_NEXT, "accept");
	if (!old_accept) {
		my_syslog(LOG_ERR, "force_bind: Cannot resolve 'accept'!\n");
		exit(1);
	}

	my_syslog(LOG_INFO, "force_bind: Inited.\n");
}
/*
 * Applies the forced SO_KEEPALIVE setting to 'sockfd'.
 * Does nothing unless keep-alive forcing is active. The option is switched
 * on when the configured keep-alive value is non-zero and off otherwise.
 * The result of the underlying call is only logged; always returns 0.
 */
static int set_ka(int sockfd)
{
	if (force_keepalive != 0) {
		int enable = (keepalive != 0);
		int rc;

		rc = old_setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, &enable, sizeof(enable));
		my_syslog(LOG_INFO, "force_bind: changing SO_KEEPALIVE to %d (ret=%d) [%d].\n",
			enable, rc, sockfd);
	}

	return 0;
}
/*
 * Applies the forced keep-alive value as TCP_KEEPIDLE on 'sockfd'.
 * Does nothing unless keep-alive forcing is active.
 * The result of the underlying call is only logged; always returns 0.
 */
static int set_ka_idle(int sockfd)
{
	if (force_keepalive != 0) {
		int rc;

		rc = old_setsockopt(sockfd, IPPROTO_TCP, TCP_KEEPIDLE, &keepalive, sizeof(keepalive));
		my_syslog(LOG_INFO, "force_bind: changing TCP_KEEPIDLE to %us (ret=%d) [%d].\n",
			keepalive, rc, sockfd);
	}

	return 0;
}
/*
 * Applies the forced TCP_MAXSEG value to 'sockfd'.
 * Does nothing unless MSS forcing is active.
 * The result of the underlying call is only logged; always returns 0.
 */
static int set_mss(int sockfd)
{
	if (force_mss != 0) {
		int rc;

		rc = old_setsockopt(sockfd, IPPROTO_TCP, TCP_MAXSEG, &mss, sizeof(mss));
		my_syslog(LOG_INFO, "force_bind: changing MSS to %u (ret=%d) [%d].\n",
			mss, rc, sockfd);
	}

	return 0;
}
/*
 * Applies the forced IP_TOS value to 'sockfd'.
 * Does nothing unless TOS forcing is active.
 * The result of the underlying call is only logged; always returns 0.
 */
static int set_tos(int sockfd)
{
	if (force_tos != 0) {
		int rc;

		rc = old_setsockopt(sockfd, IPPROTO_IP, IP_TOS, &tos, sizeof(tos));
		my_syslog(LOG_INFO, "force_bind: changing TOS to %hhu (ret=%d) [%d].\n",
			tos, rc, sockfd);
	}

	return 0;
}
/*
 * Applies the forced IP_TTL value to 'sockfd'.
 * Does nothing unless TTL forcing is active.
 * The result of the underlying call is only logged; always returns 0.
 */
static int set_ttl(int sockfd)
{
	if (force_ttl != 0) {
		int rc;

		rc = old_setsockopt(sockfd, IPPROTO_IP, IP_TTL, &ttl, sizeof(ttl));
		my_syslog(LOG_INFO, "force_bind: changing TTL to %hhu (ret=%d) [%d].\n",
			ttl, rc, sockfd);
	}

	return 0;
}
/*
 * Applies the forced SO_REUSEADDR value to 'sockfd'.
 * Does nothing unless REUSEADDR forcing is active.
 * The result of the underlying call is only logged; always returns 0.
 */
static int set_reuseaddr(int sockfd)
{
	if (force_reuseaddr != 0) {
		int rc;

		rc = old_setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &reuseaddr, sizeof(reuseaddr));
		my_syslog(LOG_INFO, "force_bind: changing reuseaddr to %u (ret=%d) [%d].\n",
			reuseaddr, rc, sockfd);
	}

	return 0;
}
/*
 * Applies the forced TCP_NODELAY value to 'sockfd'.
 * Does nothing unless NODELAY forcing is active.
 * The result of the underlying call is only logged; always returns 0.
 */
static int set_nodelay(int sockfd)
{
	if (force_nodelay != 0) {
		int rc;

		rc = old_setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay));
		my_syslog(LOG_INFO, "force_bind: changing nodelay to %u (ret=%d) [%d].\n",
			nodelay, rc, sockfd);
	}

	return 0;
}
static void change_things(int sockfd, struct sockaddr *sa)
{
int err;
struct sockaddr_storage tmp;
socklen_t tmp_len;
struct sockaddr_in *sa4;
struct sockaddr_in6 *sa6;
unsigned short *pport = NULL;
void *p;
struct node *q;
char *force_address;
int force_port;
init();
/* We do not touch non network sockets */
q = get(sockfd);
if ((q == NULL) || ((q->priv.flags & FB_FLAGS_NETSOCK) == 0))
return;
if (sa == NULL) {
tmp_len = sizeof(struct sockaddr_storage);
err = getsockname(sockfd, (struct sockaddr *) &tmp, &tmp_len);
if (err != 0) {
my_syslog(LOG_INFO, "force_bind: Cannot get socket name err=%d (%s) [%d]!\n",
err, strerror(errno), sockfd);
return;
}
sa = (struct sockaddr *) &tmp;
}
switch (sa->sa_family) {
case AF_INET:
sa4 = (struct sockaddr_in *) sa;
p = &sa4->sin_addr;
pport = &sa4->sin_port;
force_address = force_address_v4;
force_port = force_port_v4;
break;
case AF_INET6:
sa6 = (struct sockaddr_in6 *) sa;
p = &sa6->sin6_addr.s6_addr;
pport = &sa6->sin6_port;
force_address = force_address_v6;
force_port = force_port_v6;
break;
default:
my_syslog(LOG_INFO, "force_bind: unsupported family=%u [%d]!\n",
sa->sa_family, sockfd);
return;
}
if (force_address != NULL) {
err = inet_pton(sa->sa_family, force_address, p);
if (err != 1) {
my_syslog(LOG_INFO, "force_bind: cannot convert [%s] (%d) (%s) [%d]!\n",
force_address, err, strerror(errno), sockfd);
return;
}
}
if (force_port != -1)
*pport = htons(force_port);
}
/*
 * Hijacked bind(): binds 'sockfd' to a copy of 'addr' in which the address
 * and/or port were replaced by the forced values (see change_things()).
 * The caller's buffer is never modified.
 *
 * Returns what the real bind() returns. Two invalid inputs are rejected
 * here, before anything is copied, with the errors documented in bind(2):
 *   - addr == NULL                         -> -1, errno = EFAULT
 *   - addrlen > sizeof(sockaddr_storage)   -> -1, errno = EINVAL
 * Without these checks the copy below would dereference NULL or overflow
 * the local buffer.
 */
int bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen)
{
	struct sockaddr_storage new;

	if (addr == NULL) {
		errno = EFAULT;
		return -1;
	}

	if (addrlen > sizeof(new)) {
		errno = EINVAL;
		return -1;
	}

	/* Zero first, so a short 'addr' never leaves stale bytes (e.g. the family) */
	memset(&new, 0, sizeof(new));
	memcpy(&new, addr, addrlen);

	/* Also runs init(), so old_bind is valid afterwards */
	change_things(sockfd, (struct sockaddr *) &new);

	return old_bind(sockfd, (struct sockaddr *) &new, addrlen);
}
/*
 * Hijacked setsockopt().
 *
 * The caller's request is always passed to the real setsockopt() first.
 * For the options this library can force (SO_KEEPALIVE, SO_REUSEADDR,
 * IP_TOS, IP_TTL, TCP_KEEPIDLE, TCP_MAXSEG, TCP_NODELAY) the matching
 * set_*() helper is called afterwards: it re-applies the forced value when
 * forcing is configured and does nothing otherwise. This way a forced
 * option still wins, while an option that is not forced is no longer
 * silently dropped.
 *
 * Returns the result of the caller's own request; errno is preserved
 * across the helper call.
 */
int setsockopt(int sockfd, int level, int optname, const void *optval,
	socklen_t optlen)
{
	int ret, saved_errno;

	init();

	ret = old_setsockopt(sockfd, level, optname, optval, optlen);
	saved_errno = errno;

	switch (level) {
	case SOL_SOCKET:
		if (optname == SO_KEEPALIVE)
			set_ka(sockfd);
		else if (optname == SO_REUSEADDR)
			set_reuseaddr(sockfd);
		break;

	case IPPROTO_IP:
		if (optname == IP_TOS)
			set_tos(sockfd);
		else if (optname == IP_TTL)
			set_ttl(sockfd);
		break;

	case IPPROTO_TCP:
		if (optname == TCP_KEEPIDLE)
			set_ka_idle(sockfd);
		else if (optname == TCP_MAXSEG)
			set_mss(sockfd);
		else if (optname == TCP_NODELAY)
			set_nodelay(sockfd);
		break;

	default:
		break;
	}

	/* The helpers may have touched errno; report the caller's own result */
	errno = saved_errno;
	return ret;
}
/*
 * Helper called when a socket is created: socket, accept
 *
 * sockfd - the new descriptor
 * domain - domain passed to socket(), or -1 when unknown
 * type   - type passed to socket(), or -1 to query it with SO_TYPE
 *
 * Applies all the forced socket options and starts tracking the descriptor
 * as a network socket.
 */
void socket_create_callback(const int sockfd, int domain, int type)
{
	struct private p;
	socklen_t type_len;
	int base_type;
	int err;

	init();

	if (type == -1) {
		type_len = sizeof(type);
		err = getsockopt(sockfd, SOL_SOCKET, SO_TYPE, (void *) &type, &type_len);
		if (err != 0)
			my_syslog(LOG_INFO, "force_bind: Cannot get socket type err=%d (%s) [%d].\n",
				err, strerror(errno), sockfd);
	}

	/*
	 * The 'type' given to socket() may carry SOCK_NONBLOCK/SOCK_CLOEXEC
	 * besides the real type; strip them before comparing, otherwise a
	 * non-blocking TCP socket would not be recognized as SOCK_STREAM.
	 */
	base_type = type;
	if (base_type != -1) {
#ifdef SOCK_NONBLOCK
		base_type &= ~SOCK_NONBLOCK;
#endif
#ifdef SOCK_CLOEXEC
		base_type &= ~SOCK_CLOEXEC;
#endif
	}

	set_tos(sockfd);
	set_ttl(sockfd);
	set_ka(sockfd);
	if (base_type == SOCK_STREAM)
		set_ka_idle(sockfd);
	set_mss(sockfd);
	set_reuseaddr(sockfd);
	set_nodelay(sockfd);

	/* add() copies the whole structure, so do not leave garbage in it */
	memset(&p, 0, sizeof(p));
	p.domain = domain;
	p.type = type;
	p.flags = FB_FLAGS_NETSOCK;
	add(sockfd, &p);
}
/*
 * 'socket' is hijacked to be able to call setsockopt on it.
 *
 * Creates the socket with the real socket() and, on success, passes the
 * new descriptor to socket_create_callback().
 * Returns the new descriptor, or -1 if the real socket() failed.
 */
int socket(int domain, int type, int protocol)
{
	int fd;

	init();

	fd = old_socket(domain, type, protocol);
	if (fd != -1)
		socket_create_callback(fd, domain, type);

	return fd;
}
/*
 * Enforce bandwidth
 *
 * Called after 'bytes' bytes were sent on 'sockfd'. The budget grows with
 * 'limit' bytes for every elapsed second; what is not used is kept in
 * 'rest' for the next call. When more was sent than the budget allows, the
 * caller is put to sleep for as long as it takes to earn the difference.
 *
 * The per-socket limit is used when set, otherwise the global one.
 * Nothing is done for failed/empty transfers, for descriptors that are not
 * tracked network sockets, or when no limit is configured.
 *
 * Diagnostics go through my_syslog(), never to the stdout of the program
 * this library is loaded into.
 */
static void bw(const int sockfd, const ssize_t bytes)
{
	struct timeval now;
	struct timespec ts, rest;
	long long allowed;
	long long diff_ms, sleep_ms;
	int err;
	struct node *p;
	struct private *q;

	if (bytes <= 0)
		return;

	/* Is a network socket? */
	p = get(sockfd);
	if (p == NULL)
		return;
	q = &p->priv;
	if ((q->flags & FB_FLAGS_NETSOCK) == 0)
		return;

	/* No limit on this socket? Fall back to the global one. */
	if (q->limit == 0) {
		if (bw_global.limit == 0)
			return;
		q = &bw_global;
	}

	gettimeofday(&now, NULL);
	diff_ms = (now.tv_sec - q->last.tv_sec) * 1000
		+ (now.tv_usec - q->last.tv_usec) / 1000;
	if (diff_ms < 0) {
		/*
		 * The clock went backwards. Restart the accounting from the
		 * new time, otherwise the limit stays off until the clock
		 * catches up with the old reference.
		 */
		q->last = now;
		return;
	}

	allowed = q->rest + q->limit * diff_ms / 1000;
	q->last = now;
	my_syslog(LOG_DEBUG, "force_bind: diff_ms=%lld rest=%llu bytes=%lld allowed=%lldb [%d].\n",
		diff_ms, q->rest, (long long) bytes, allowed, sockfd);

	if (bytes <= allowed) {
		q->rest = allowed - bytes;
		my_syslog(LOG_DEBUG, "force_bind: inside limit, rest=%llu [%d].\n",
			q->rest, sockfd);
		return;
	}

	q->rest = 0;
	sleep_ms = (bytes - allowed) * 1000 / q->limit;
	ts.tv_sec = sleep_ms / 1000;
	ts.tv_nsec = (sleep_ms % 1000) * 1000 * 1000;
	my_syslog(LOG_DEBUG, "force_bind: will sleep %llds %ldnsec [%d].\n",
		(long long) ts.tv_sec, (long) ts.tv_nsec, sockfd);

	/* We try to sleep even if we are interrupted by signals */
	while (1) {
		err = nanosleep(&ts, &rest);
		if (err == -1) {
			if (errno == EINTR) {
				ts = rest;
				continue;
			}

			my_syslog(LOG_INFO, "force_bind: nanosleep returned error"
				" (%d) (%s).\n",
				err, strerror(errno));
		}

		break;
	}

	/*
	 * The time spent sleeping paid for the bytes above the budget; it
	 * must not be counted again by the next call, or the effective rate
	 * would be about twice the configured one.
	 */
	gettimeofday(&q->last, NULL);
}
/*
 * Hijacked close(): forgets everything known about 'fd' (its list slot is
 * released whether or not the real close() succeeds), then closes it.
 * Returns what the real close() returns.
 */
int close(int fd)
{
	int rc;

	init();

	del(fd);
	rc = old_close(fd);

	return rc;
}
/*
 * Hijacked write(): performs the real write and then accounts the bytes
 * actually written against the bandwidth limit (which may delay the return).
 * Returns what the real write() returns.
 */
ssize_t write(int fd, const void *buf, size_t len)
{
	ssize_t done;

	/* Runs init() as a side effect, so old_write is valid below */
	change_things(fd, NULL);

	done = old_write(fd, buf, len);
	bw(fd, done);

	return done;
}
/*
 * Hijacked send(): performs the real send and then accounts the bytes
 * actually sent against the bandwidth limit (which may delay the return).
 * Returns what the real send() returns.
 */
ssize_t send(int sockfd, const void *buf, size_t len, int flags)
{
	ssize_t done;

	/* Runs init() as a side effect, so old_send is valid below */
	change_things(sockfd, NULL);

	done = old_send(sockfd, buf, len, flags);
	bw(sockfd, done);

	return done;
}
/*
 * Hijacked sendto(): performs the real sendto and then accounts the bytes
 * actually sent against the bandwidth limit (which may delay the return).
 * The destination address is passed through unchanged.
 * Returns what the real sendto() returns.
 */
ssize_t sendto(int sockfd, const void *buf, size_t len, int flags,
	const struct sockaddr *dest_addr, socklen_t addrlen)
{
	ssize_t done;

	/* Runs init() as a side effect, so old_sendto is valid below */
	change_things(sockfd, NULL);

	done = old_sendto(sockfd, buf, len, flags, dest_addr, addrlen);
	bw(sockfd, done);

	return done;
}
/*
 * Hijacked sendmsg(): performs the real sendmsg and then accounts the bytes
 * actually sent against the bandwidth limit (which may delay the return).
 * Returns what the real sendmsg() returns.
 *
 * TODO: Add sendmmsg
 */
ssize_t sendmsg(int sockfd, const struct msghdr *msg, int flags)
{
	ssize_t done;

	/* Runs init() as a side effect, so old_sendmsg is valid below */
	change_things(sockfd, NULL);

	done = old_sendmsg(sockfd, msg, flags);
	bw(sockfd, done);

	return done;
}
/*
 * We have to hijack accept because this program may be a daemon.
 * TODO: accept4 should also be hijacked.
 *
 * Accepts a connection with the real accept() and, on success, passes the
 * new descriptor to socket_create_callback() (domain and type unknown: -1).
 *
 * On failure -1 is returned right away: calling the helper with an invalid
 * descriptor would overwrite the errno set by accept() (for example EAGAIN
 * on a non-blocking socket) and would register descriptor -1 as a socket.
 */
int accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen)
{
	int new_sock;

	init();

	new_sock = old_accept(sockfd, addr, addrlen);
	if (new_sock == -1)
		return -1;

	socket_create_callback(new_sock, -1, -1);

	return new_sock;
}