/*
 * epoll-based TCP proxy. Clients are accepted on a listening socket; on a
 * client's first readable event the external py_route() hook supplies an
 * upstream socket, after which forward_data() relays bytes in both
 * directions.
 */
#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <poll.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

/* Capacity of one epoll_wait() batch and of the fd-indexed connections[]
 * table; descriptors at or above this value are rejected.
 * NOTE(review): 8096 may have been intended as 8192 - confirm. */
#define MAX_EVENTS 8096
/* Maximum number of bytes relayed per forward_data() call. */
#define BUFFER_SIZE 1024
/* How long forward_data() waits for a full destination socket to drain. */
#define SEND_TIMEOUT_MS 5000

/* Suppress SIGPIPE on send() where the platform provides the flag. */
#ifndef MSG_NOSIGNAL
#define MSG_NOSIGNAL 0
#endif

/*
 * Per-file-descriptor connection state. After routing, the client's slot and
 * the upstream's slot both hold the same (client_fd, upstream_fd) pair, so an
 * event on either fd finds its peer. serve() marks every slot unused (-1 in
 * both fds) before accepting connections.
 */
typedef struct {
    int client_fd;
    int upstream_fd;
    /* Retention buffer filled by sock_read(). Nothing in this file allocates
     * or frees it. NOTE(review): confirm who owns this memory. */
    char *buffer;
    size_t buffer_size;   /* capacity of buffer in bytes */
    size_t buffer_offset; /* bytes of buffer already filled by sock_read() */
} connection_t;

int listen_fd = 0;
int epoll_fd = 0;
/* Connection state indexed directly by file descriptor. */
connection_t connections[MAX_EVENTS];

int sock_init(void);
void sock_exit(void);

/*
 * Routing hook defined outside this file. The routing step below passes it
 * the client fd and a freshly created socket, treats the result as the
 * upstream fd, and treats -1 as failure.
 * NOTE(review): prototype inferred from that call - confirm it matches the
 * real definition.
 */
int py_route(int client_fd, int upstream_fd);

/* Return the connections[] slot for fd, or NULL if fd cannot index it. */
static connection_t *conn_for_fd(int fd)
{
    if (fd < 0 || fd >= MAX_EVENTS) {
        return NULL;
    }
    return &connections[fd];
}

/* Mark fd's slot unused. buffer/buffer_size are left untouched because this
 * file does not own that memory. Out-of-range fds are ignored. */
static void reset_slot(int fd)
{
    connection_t *slot = conn_for_fd(fd);
    if (slot != NULL) {
        slot->client_fd = -1;
        slot->upstream_fd = -1;
        slot->buffer_offset = 0;
    }
}

/* Switch fd to non-blocking mode; exits the process if fcntl() fails. */
void set_nonblocking(int fd)
{
    int flags = fcntl(fd, F_GETFL, 0);
    if (flags == -1) {
        perror("fcntl get");
        exit(EXIT_FAILURE);
    }
    if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) {
        perror("fcntl set");
        exit(EXIT_FAILURE);
    }
}

/* Create an unconnected IPv4 TCP socket for the router to use as the
 * upstream. Returns the fd, or -1 on failure. */
int prepare_upstream(void)
{
    int sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (sockfd == -1) {
        perror("socket");
    }
    return sockfd;
}

/*
 * Open a non-blocking TCP connection to host:port, where host is a dotted
 * IPv4 address. The connection may still be in progress (EINPROGRESS) when
 * this returns. Returns the socket fd, or -1 on failure.
 */
int connect_upstream(const char *host, int port)
{
    if (host == NULL || port < 0 || port > 65535) {
        fprintf(stderr, "connect_upstream: invalid host or port\n");
        return -1;
    }

    int sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (sockfd == -1) {
        perror("socket");
        return -1;
    }
    set_nonblocking(sockfd);

    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons((uint16_t)port);

    int rc = inet_pton(AF_INET, host, &server_addr.sin_addr);
    if (rc != 1) {
        /* inet_pton returns 0 (errno untouched) for a malformed address. */
        if (rc == 0) {
            fprintf(stderr, "inet_pton: invalid IPv4 address '%s'\n", host);
        } else {
            perror("inet_pton");
        }
        close(sockfd);
        return -1;
    }

    if (connect(sockfd, (struct sockaddr *)&server_addr, sizeof(server_addr)) == -1 &&
        errno != EINPROGRESS) {
        perror("connect");
        close(sockfd);
        return -1;
    }
    return sockfd;
}

/*
 * Create a non-blocking IPv4 TCP socket listening on all interfaces at port,
 * with SO_REUSEADDR set. Returns the fd, or -1 on failure.
 */
int create_listening_socket(int port)
{
    if (port < 0 || port > 65535) {
        fprintf(stderr, "create_listening_socket: invalid port %d\n", port);
        return -1;
    }

    /* Named fd (not listen_fd) to avoid shadowing the global. */
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd == -1) {
        perror("socket");
        return -1;
    }

    int opt = 1;
    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1) {
        perror("setsockopt");
        close(fd);
        return -1;
    }

    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = INADDR_ANY;
    server_addr.sin_port = htons((uint16_t)port);

    if (bind(fd, (struct sockaddr *)&server_addr, sizeof(server_addr)) == -1) {
        perror("bind");
        close(fd);
        return -1;
    }
    if (listen(fd, SOMAXCONN) == -1) {
        perror("listen");
        close(fd);
        return -1;
    }

    set_nonblocking(fd);
    return fd;
}

/*
 * Read up to size - 1 bytes from fd into buf and NUL-terminate buf.
 *
 * The bytes are received into the connection's retention buffer at
 * buffer_offset (so they stay available after the call), then copied into
 * buf, and buffer_offset is advanced. The amount read is limited by both
 * size - 1 and the space left in that buffer.
 *
 * Returns buf when at least one byte was read. Returns NULL on EOF, on error,
 * when fd has no slot, when buf/size leave no room, or when the connection
 * has no retention buffer space.
 *
 * NOTE(review): callers can only learn the length via strlen(), so payloads
 * containing NUL bytes appear truncated; consider an API returning ssize_t.
 */
char *sock_read(int fd, char *buf, size_t size)
{
    connection_t *conn = conn_for_fd(fd);
    if (conn == NULL || buf == NULL || size == 0) {
        return NULL;
    }
    buf[0] = '\0';

    if (conn->buffer == NULL || conn->buffer_offset >= conn->buffer_size) {
        return NULL;
    }

    size_t room = conn->buffer_size - conn->buffer_offset;
    size_t want = size - 1 < room ? size - 1 : room; /* keep one byte for NUL */
    if (want == 0) {
        return NULL;
    }

    char *dst = conn->buffer + conn->buffer_offset;
    ssize_t bytes_read = recv(fd, dst, want, 0);
    if (bytes_read > 0) {
        memcpy(buf, dst, (size_t)bytes_read);
        buf[bytes_read] = '\0';
        conn->buffer_offset += (size_t)bytes_read;
        return buf;
    }

    if (bytes_read == 0) {
        printf("Connection closed by remote (fd=%d)\n", fd);
    } else if (errno != EAGAIN && errno != EWOULDBLOCK) {
        perror("read");
    }
    return NULL;
}

/*
 * Deregister and close both sockets of conn (whichever are not -1) and mark
 * their slots unused, so later events for the same fds in the current
 * epoll batch are recognised as stale instead of closing reused descriptors.
 */
void close_connection(int epoll_fd, connection_t *conn)
{
    if (conn == NULL) {
        return;
    }

    /* Copy first: conn points into connections[] and is cleared below. */
    int client_fd = conn->client_fd;
    int upstream_fd = conn->upstream_fd;

    if (client_fd != -1) {
        epoll_ctl(epoll_fd, EPOLL_CTL_DEL, client_fd, NULL);
        close(client_fd);
        reset_slot(client_fd);
    }
    if (upstream_fd != -1 && upstream_fd != client_fd) {
        epoll_ctl(epoll_fd, EPOLL_CTL_DEL, upstream_fd, NULL);
        close(upstream_fd);
        reset_slot(upstream_fd);
    }

    conn->client_fd = -1;
    conn->upstream_fd = -1;
}

/*
 * Write all len bytes of data to fd. While the socket is full (EAGAIN) it
 * waits up to SEND_TIMEOUT_MS for writability. MSG_NOSIGNAL keeps a reset
 * peer from raising SIGPIPE. Returns 0 on success, -1 on error or timeout
 * (errno set).
 */
static int send_all(int fd, const char *data, size_t len)
{
    size_t sent = 0;
    while (sent < len) {
        ssize_t n = send(fd, data + sent, len - sent, MSG_NOSIGNAL);
        if (n > 0) {
            sent += (size_t)n;
            continue;
        }
        if (n == -1 && errno == EINTR) {
            continue;
        }
        if (n == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
            struct pollfd pfd = { .fd = fd, .events = POLLOUT };
            int ready = poll(&pfd, 1, SEND_TIMEOUT_MS);
            if (ready > 0 || (ready == -1 && errno == EINTR)) {
                continue;
            }
            if (ready == 0) {
                errno = ETIMEDOUT;
            }
            return -1;
        }
        return -1;
    }
    return 0;
}

/*
 * Relay one recv() worth of data (at most BUFFER_SIZE bytes) from from_fd to
 * to_fd. The static buffer makes this non-reentrant.
 *
 * Returns the number of bytes relayed (> 0), 0 if from_fd reached EOF, or -1
 * on a read or write error. On -1, errno is preserved so callers can tell a
 * transient EAGAIN/EWOULDBLOCK/EINTR on the read side from a real failure.
 */
int forward_data(int from_fd, int to_fd)
{
    static char buffer[BUFFER_SIZE];

    ssize_t bytes_read = recv(from_fd, buffer, sizeof(buffer), 0);
    if (bytes_read > 0) {
        if (send_all(to_fd, buffer, (size_t)bytes_read) == -1) {
            int saved = errno;
            perror("write");
            errno = saved;
            return -1;
        }
        return (int)bytes_read;
    }

    if (bytes_read == 0) {
        printf("Connection closed by remote (fd=%d)\n", from_fd);
        return 0;
    }

    int saved = errno;
    if (saved != EAGAIN && saved != EWOULDBLOCK && saved != EINTR) {
        perror("read");
    }
    errno = saved;
    return -1;
}

/*
 * Accept one pending client on listen_fd, make it non-blocking, initialise
 * its slot (no upstream yet) and register it with epoll for input.
 * Returns true on success. Clients whose fd does not fit in connections[]
 * are closed and rejected. The event argument is unused.
 */
bool handle_connect(struct epoll_event event, int epoll_fd)
{
    (void)event;

    struct sockaddr_in client_addr;
    socklen_t client_len = sizeof(client_addr);
    int client_fd = accept(listen_fd, (struct sockaddr *)&client_addr, &client_len);
    if (client_fd == -1) {
        if (errno != EAGAIN && errno != EWOULDBLOCK) {
            perror("accept");
        }
        return false;
    }

    connection_t *conn = conn_for_fd(client_fd);
    if (conn == NULL) {
        fprintf(stderr, "Rejecting client_fd=%d: exceeds table size %d\n",
                client_fd, MAX_EVENTS);
        close(client_fd);
        return false;
    }

    set_nonblocking(client_fd);
    conn->client_fd = client_fd;
    conn->upstream_fd = -1;
    conn->buffer_offset = 0;

    /* EPOLLOUT is deliberately not requested: it is level-triggered, never
     * handled, and would make epoll_wait() return continuously. */
    struct epoll_event client_event = {
        .events = EPOLLIN | EPOLLERR | EPOLLHUP,
        .data.fd = client_fd,
    };
    if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, client_fd, &client_event) == -1) {
        perror("epoll_ctl");
        close(client_fd);
        reset_slot(client_fd);
        return false;
    }

    printf("New connection: client_fd=%d\n", client_fd);
    return true;
}

/* Log and close conn; does nothing for a slot already marked unused. */
void handle_close(int epoll_fd, connection_t *conn)
{
    if (conn == NULL || (conn->client_fd == -1 && conn->upstream_fd == -1)) {
        return;
    }
    printf("Connection closed: client_fd=%d, upstream_fd=%d\n",
           conn->client_fd, conn->upstream_fd);
    close_connection(epoll_fd, conn);
}

/*
 * Obtain an upstream for conn via py_route(), link the client and upstream
 * slots, and register the upstream with epoll. Closes the whole connection
 * on any failure.
 * NOTE(review): if py_route() returns a different fd than the socket created
 * by prepare_upstream(), that prepared socket is not closed here - confirm
 * py_route()'s ownership contract.
 */
static void route_connection(int epoll_fd, connection_t *conn)
{
    int client_fd = conn->client_fd;

    conn->upstream_fd = prepare_upstream();
    if (conn->upstream_fd == -1) {
        close_connection(epoll_fd, conn);
        return;
    }

    int upstream_fd = py_route(client_fd, conn->upstream_fd);
    if (upstream_fd == -1) {
        close_connection(epoll_fd, conn);
        return;
    }

    connection_t *upstream = conn_for_fd(upstream_fd);
    if (upstream == NULL) {
        fprintf(stderr, "Rejecting upstream_fd=%d: exceeds table size %d\n",
                upstream_fd, MAX_EVENTS);
        if (upstream_fd != conn->upstream_fd) {
            close(upstream_fd);
        }
        close_connection(epoll_fd, conn);
        return;
    }

    set_nonblocking(upstream_fd);
    conn->upstream_fd = upstream_fd;
    upstream->client_fd = client_fd;
    upstream->upstream_fd = upstream_fd;

    struct epoll_event upstream_event = {
        .events = EPOLLIN | EPOLLERR | EPOLLHUP,
        .data.fd = upstream_fd,
    };
    if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, upstream_fd, &upstream_event) == -1) {
        perror("epoll_ctl");
        close_connection(epoll_fd, conn);
        return;
    }

    printf("Connected: client_fd=%d, upstream_fd=%d\n", client_fd, upstream_fd);
}

/*
 * Handle a readable event on one side of conn. The first event on a client
 * without an upstream triggers routing; afterwards data is relayed toward
 * the opposite socket. The connection is closed on EOF or on a real error;
 * transient EAGAIN/EWOULDBLOCK/EINTR leave it open.
 */
void handle_stream(struct epoll_event event, int epoll_fd, connection_t *conn)
{
    if (conn == NULL || conn->client_fd == -1) {
        /* Stale event for a connection closed earlier in this batch. */
        return;
    }

    if (conn->upstream_fd == -1) {
        route_connection(epoll_fd, conn);
        return;
    }

    int fd = event.data.fd;
    int result;
    if (fd == conn->client_fd) {
        result = forward_data(conn->client_fd, conn->upstream_fd);
    } else if (fd == conn->upstream_fd) {
        result = forward_data(conn->upstream_fd, conn->client_fd);
    } else {
        return;
    }

    if (result > 0) {
        return;
    }
    if (result < 0 && (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)) {
        return;
    }
    close_connection(epoll_fd, conn);
}

/*
 * Run the proxy on port: create the listening socket and epoll instance,
 * then dispatch events until epoll_wait() fails with something other than
 * EINTR. Stores the descriptors in the global listen_fd and epoll_fd.
 * NOTE(review): the descriptors are left open when the loop exits, as
 * before - presumably sock_exit() releases them; confirm.
 */
void serve(int port)
{
    /* Zero-initialised slots would otherwise look like fd 0 (stdin). */
    for (int fd = 0; fd < MAX_EVENTS; fd++) {
        reset_slot(fd);
    }

    listen_fd = create_listening_socket(port);
    if (listen_fd == -1) {
        fprintf(stderr, "Failed to create listening socket\n");
        return;
    }

    epoll_fd = epoll_create1(0);
    if (epoll_fd == -1) {
        perror("epoll_create1");
        close(listen_fd);
        return;
    }

    struct epoll_event event = { .events = EPOLLIN, .data.fd = listen_fd };
    if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, listen_fd, &event) == -1) {
        perror("epoll_ctl");
        close(listen_fd);
        close(epoll_fd);
        return;
    }

    struct epoll_event events[MAX_EVENTS];
    printf("Pretty Good Server listening on port %d\n", port);

    for (;;) {
        int num_events = epoll_wait(epoll_fd, events, MAX_EVENTS, -1);
        if (num_events == -1) {
            if (errno == EINTR) {
                continue; /* interrupted by a signal, not a failure */
            }
            perror("epoll_wait");
            break;
        }

        for (int i = 0; i < num_events; i++) {
            int fd = events[i].data.fd;
            uint32_t flags = events[i].events;

            if (fd == listen_fd) {
                handle_connect(events[i], epoll_fd);
                continue;
            }

            connection_t *conn = conn_for_fd(fd);
            if (conn == NULL) {
                continue;
            }
            if (flags & (EPOLLHUP | EPOLLERR)) {
                handle_close(epoll_fd, conn);
            } else if (flags & EPOLLIN) {
                handle_stream(events[i], epoll_fd, conn);
            }
        }
    }
}