/*
 * Intercepting TCP proxy ("load balancer"): accepts clients on LISTEN_PORT,
 * asks py_route() (declared in py.h) for an upstream socket per client, and
 * relays bytes in both directions from a single-threaded epoll loop.
 */
#include "py.h"

#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <poll.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <unistd.h>

#define LISTEN_PORT 2222
/* UPSTREAM_HOST / UPSTREAM_PORT are not referenced in this file; routing goes through py_route(). */
#define UPSTREAM_HOST "127.0.0.1"
#define UPSTREAM_PORT 9999
#define MAX_EVENTS 8096      /* max events returned by a single epoll_wait() call */
#define MAX_FDS MAX_EVENTS   /* size of the fd-indexed connection table; larger fds are rejected */
#define BUFFER_SIZE 1024
#define SEND_TIMEOUT_MS 5000 /* how long send_all() waits for a full send buffer to drain */

/* Put `fd` into non-blocking mode; terminates the process if fcntl() fails. */
void set_nonblocking(int fd)
{
    int flags = fcntl(fd, F_GETFL, 0);
    if (flags == -1) {
        perror("fcntl get");
        exit(EXIT_FAILURE);
    }
    if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) {
        perror("fcntl set");
        exit(EXIT_FAILURE);
    }
}

/*
 * Create an unconnected IPv4 TCP socket to be handed to py_route().
 * Returns the fd, or -1 (errno set) if socket() fails.
 */
int prepare_upstream(void)
{
    return socket(AF_INET, SOCK_STREAM, 0);
}

/*
 * Start a non-blocking IPv4 connect to host:port (host must be a dotted-quad).
 * Returns the socket (the connect may still be in progress), or -1 on error.
 * Not called anywhere in this file.
 */
int connect_upstream(const char *host, int port)
{
    int sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (sockfd == -1) {
        perror("socket");
        return -1;
    }
    set_nonblocking(sockfd);

    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(port);
    if (inet_pton(AF_INET, host, &server_addr.sin_addr) <= 0) {
        perror("inet_pton");
        close(sockfd);
        return -1;
    }

    if (connect(sockfd, (struct sockaddr *)&server_addr, sizeof(server_addr)) == -1) {
        /* EINPROGRESS is the normal outcome of a non-blocking connect. */
        if (errno != EINPROGRESS) {
            perror("connect");
            close(sockfd);
            return -1;
        }
    }
    return sockfd;
}

/*
 * Create a non-blocking IPv4 TCP socket listening on all interfaces at `port`
 * (with SO_REUSEADDR). Returns the fd, or -1 on error.
 */
int create_listening_socket(int port)
{
    int listen_fd = socket(AF_INET, SOCK_STREAM, 0);
    if (listen_fd == -1) {
        perror("socket");
        return -1;
    }

    int opt = 1;
    if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1) {
        perror("setsockopt");
        close(listen_fd);
        return -1;
    }

    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = INADDR_ANY;
    server_addr.sin_port = htons(port);

    if (bind(listen_fd, (struct sockaddr *)&server_addr, sizeof(server_addr)) == -1) {
        perror("bind");
        close(listen_fd);
        return -1;
    }
    if (listen(listen_fd, SOMAXCONN) == -1) {
        perror("listen");
        close(listen_fd);
        return -1;
    }

    set_nonblocking(listen_fd);
    return listen_fd;
}

/*
 * State for one client/upstream pair. The table below is indexed by fd; the
 * client's slot and its upstream's slot hold the same pair, so an event on
 * either socket finds its peer. -1 means "no socket".
 */
typedef struct {
    int client_fd;
    int upstream_fd;
} connection_t;

static connection_t connections[MAX_FDS];

/* Mark the table slot for `fd` as free; no-op for fds outside the table. */
static void reset_slot(int fd)
{
    if (fd >= 0 && fd < MAX_FDS) {
        connections[fd].client_fd = -1;
        connections[fd].upstream_fd = -1;
    }
}

/*
 * Tear down a client/upstream pair: deregister and close each socket that is
 * set, then free BOTH table slots so a later stale event for the peer fd is
 * recognised and ignored instead of closing the fd number a second time.
 * Accepts either the client's or the upstream's slot.
 */
void close_connection(int epoll_fd, connection_t *conn)
{
    /* Copy first: reset_slot() may overwrite *conn. */
    int client_fd = conn->client_fd;
    int upstream_fd = conn->upstream_fd;

    if (client_fd != -1) {
        epoll_ctl(epoll_fd, EPOLL_CTL_DEL, client_fd, NULL);
        close(client_fd);
    }
    if (upstream_fd != -1) {
        epoll_ctl(epoll_fd, EPOLL_CTL_DEL, upstream_fd, NULL);
        close(upstream_fd);
    }

    conn->client_fd = -1;
    conn->upstream_fd = -1;
    reset_slot(client_fd);
    reset_slot(upstream_fd);
}

/*
 * Write all `len` bytes of `buf` to the non-blocking socket `fd`.
 * When the send buffer is full, waits (up to SEND_TIMEOUT_MS per stall) for it
 * to drain rather than silently dropping the unsent tail. This stalls the
 * whole event loop while waiting; buffering with EPOLLOUT would avoid that.
 * Returns 0 on success, -1 on error or timeout (errno set).
 */
static int send_all(int fd, const char *buf, size_t len)
{
    while (len > 0) {
        /* MSG_NOSIGNAL: a peer that already closed must give EPIPE, not a
         * process-killing SIGPIPE. */
        ssize_t sent = send(fd, buf, len, MSG_NOSIGNAL);
        if (sent >= 0) {
            buf += sent;
            len -= (size_t)sent;
            continue;
        }
        if (errno == EINTR) {
            continue;
        }
        if (errno != EAGAIN && errno != EWOULDBLOCK) {
            return -1;
        }

        struct pollfd pfd = { .fd = fd, .events = POLLOUT };
        int ready = poll(&pfd, 1, SEND_TIMEOUT_MS);
        if (ready == 0) {
            errno = ETIMEDOUT;
            return -1;
        }
        if (ready == -1 && errno != EINTR) {
            return -1;
        }
    }
    return 0;
}

/*
 * Read one chunk (up to BUFFER_SIZE bytes) from `from_fd` and write all of it
 * to `to_fd`.
 * Returns the number of bytes relayed (> 0), 0 if `from_fd` reached EOF, or -1
 * on a read error or if the chunk could not be fully written to `to_fd`.
 * Callers close the pair on any result < 1.
 */
int forward_data(int from_fd, int to_fd)
{
    /* static: reused across calls; safe only because the proxy is single-threaded. */
    static char buffer[BUFFER_SIZE];

    ssize_t bytes_read;
    do {
        bytes_read = recv(from_fd, buffer, sizeof(buffer), 0);
    } while (bytes_read == -1 && errno == EINTR);

    if (bytes_read > 0) {
        if (send_all(to_fd, buffer, (size_t)bytes_read) == -1) {
            perror("write");
            return -1; /* data would be lost; let the caller close the pair */
        }
    } else if (bytes_read == 0) {
        printf("Connection closed by remote (fd=%d)\n", from_fd);
    } else {
        perror("read");
    }
    return (int)bytes_read;
}

/* -1 = not open; cleanup() relies on this so it never closes fd 0 (stdin). */
int listen_fd = -1;
int epoll_fd = -1;

/* Set by the SIGINT handler, polled by the main loop. */
static volatile sig_atomic_t stop_requested = 0;

/* atexit() handler: close the listener and epoll instance (once), then call py_destruct(). */
void cleanup(void)
{
    if (epoll_fd != -1) {
        close(epoll_fd);
        epoll_fd = -1;
    }
    if (listen_fd != -1) {
        close(listen_fd);
        listen_fd = -1;
    }
    py_destruct();
    printf("Graceful exit.\n");
}

/*
 * SIGINT handler. It only sets a flag because printf() and exit() are not
 * async-signal-safe. epoll_wait() is never restarted after a handler runs, so
 * it returns EINTR and the main loop exits normally; cleanup() then runs via
 * atexit().
 */
void handle_sigint(int sig)
{
    (void)sig;
    stop_requested = 1;
}

/* Accept one pending client, make it non-blocking, and register it for reads. */
static void accept_client(void)
{
    struct sockaddr_in client_addr;
    socklen_t client_len = sizeof(client_addr);
    int client_fd = accept(listen_fd, (struct sockaddr *)&client_addr, &client_len);
    if (client_fd == -1) {
        perror("accept");
        return;
    }
    if (client_fd >= MAX_FDS) {
        fprintf(stderr, "Rejecting client: fd %d exceeds connection table size %d\n",
                client_fd, MAX_FDS);
        close(client_fd);
        return;
    }
    set_nonblocking(client_fd);

    connections[client_fd].client_fd = client_fd;
    connections[client_fd].upstream_fd = -1;

    /* No EPOLLOUT: with level-triggered epoll a writable socket would be
     * reported on every epoll_wait(), busy-spinning the loop. EPOLLERR and
     * EPOLLHUP are always reported without being requested. */
    struct epoll_event client_event = { .events = EPOLLIN, .data.fd = client_fd };
    if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, client_fd, &client_event) == -1) {
        perror("epoll_ctl client");
        reset_slot(client_fd);
        close(client_fd);
    }
}

/*
 * Ask py_route() for an upstream socket for conn's client and register it.
 * On any failure the client (and the upstream, if already registered) is
 * closed.
 */
static void start_upstream(connection_t *conn)
{
    int client_fd = conn->client_fd;

    int prepared_fd = prepare_upstream();
    if (prepared_fd == -1) {
        perror("socket");
        close_connection(epoll_fd, conn);
        return;
    }

    /* NOTE(review): py_route()'s ownership contract is not visible here. It may
     * return prepared_fd itself or a different fd. When it returns -1 it is
     * unclear whether it already closed prepared_fd, so prepared_fd is
     * deliberately not closed here. Confirm against py.h and fix the possible
     * leak. */
    int upstream_fd = py_route(client_fd, prepared_fd);
    if (upstream_fd == -1) {
        close_connection(epoll_fd, conn); /* conn->upstream_fd is still -1: closes the client only */
        return;
    }
    if (upstream_fd >= MAX_FDS) {
        fprintf(stderr, "Rejecting upstream: fd %d exceeds connection table size %d\n",
                upstream_fd, MAX_FDS);
        close(upstream_fd);
        close_connection(epoll_fd, conn);
        return;
    }
    set_nonblocking(upstream_fd);

    conn->upstream_fd = upstream_fd;
    connections[upstream_fd].client_fd = client_fd;
    connections[upstream_fd].upstream_fd = upstream_fd;

    struct epoll_event upstream_event = { .events = EPOLLIN, .data.fd = upstream_fd };
    if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, upstream_fd, &upstream_event) == -1) {
        perror("epoll_ctl upstream");
        close_connection(epoll_fd, conn);
        return;
    }
    printf("Connected: client_fd=%d, upstream_fd=%d\n", client_fd, upstream_fd);
}

/* Handle one epoll event for a client or upstream socket. */
static void handle_connection_event(const struct epoll_event *ev)
{
    int fd = ev->data.fd;
    connection_t *conn = &connections[fd]; /* only fds < MAX_FDS are ever registered */

    /* The slot was freed by an earlier event in this same epoll_wait() batch
     * (the peer closed the pair), so this event is stale.
     * NOTE(review): if accept()/py_route() reuses that fd number within the
     * same batch, the stale event is still applied to the new socket.
     * Deferring close() to the end of the batch would remove that window. */
    if (conn->client_fd == -1) {
        return;
    }

    if (ev->events & EPOLLERR) {
        printf("Connection closed: client_fd=%d, upstream_fd=%d\n",
               conn->client_fd, conn->upstream_fd);
        close_connection(epoll_fd, conn);
        return;
    }

    /* EPOLLIN before EPOLLHUP, so bytes still queued ahead of a hang-up are
     * relayed; the following 0-byte read then closes the pair. */
    if (ev->events & EPOLLIN) {
        if (conn->upstream_fd == -1) {
            /* First readable event from a new client: choose an upstream now.
             * Unless py_route() consumed them, the client's bytes remain queued.
             * Level-triggered epoll reports them again on the next epoll_wait(),
             * and they are forwarded then. */
            start_upstream(conn);
            return;
        }
        int peer_fd = (fd == conn->client_fd) ? conn->upstream_fd : conn->client_fd;
        if (forward_data(fd, peer_fd) < 1) {
            close_connection(epoll_fd, conn);
        }
        return;
    }

    if (ev->events & EPOLLHUP) {
        printf("Connection closed: client_fd=%d, upstream_fd=%d\n",
               conn->client_fd, conn->upstream_fd);
        close_connection(epoll_fd, conn);
    }
}

/*
 * Set up the listener and epoll instance, then relay traffic until SIGINT or an
 * epoll_wait() failure. Descriptors are closed by cleanup() via atexit().
 */
int main(void)
{
    if (signal(SIGINT, handle_sigint) == SIG_ERR) {
        perror("Failed to register signal handler");
        return EXIT_FAILURE;
    }

    listen_fd = create_listening_socket(LISTEN_PORT);
    if (listen_fd == -1) {
        fprintf(stderr, "Failed to create listening socket\n");
        return EXIT_FAILURE;
    }
    /* From here on, every return path relies on cleanup() to close the fds;
     * closing them explicitly as well would double-close. */
    atexit(cleanup);

    epoll_fd = epoll_create1(0);
    if (epoll_fd == -1) {
        perror("epoll_create1");
        return EXIT_FAILURE;
    }

    struct epoll_event event = { .events = EPOLLIN, .data.fd = listen_fd };
    if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, listen_fd, &event) == -1) {
        perror("epoll_ctl");
        return EXIT_FAILURE;
    }

    for (int fd = 0; fd < MAX_FDS; fd++) {
        reset_slot(fd);
    }

    struct epoll_event events[MAX_EVENTS];

    printf("Intercepting load balancer listening on port %d\n", LISTEN_PORT);

    int status = EXIT_SUCCESS;
    while (!stop_requested) {
        int num_events = epoll_wait(epoll_fd, events, MAX_EVENTS, -1);
        if (num_events == -1) {
            if (errno == EINTR) {
                continue; /* e.g. SIGINT; the loop condition decides whether to stop */
            }
            perror("epoll_wait");
            status = EXIT_FAILURE;
            break;
        }

        for (int i = 0; i < num_events; i++) {
            if (events[i].data.fd == listen_fd) {
                accept_client();
            } else {
                handle_connection_event(&events[i]);
            }
        }
    }

    if (stop_requested) {
        printf("\nCtrl+C pressed.\n");
    }
    return status;
}