From 415d2fb666e62b9c0b66fc165e9af8ba26471df7 Mon Sep 17 00:00:00 2001 From: retoor Date: Tue, 31 Dec 2024 17:34:40 +0100 Subject: [PATCH] Refactor. --- protocol.h | 4 + sock.h | 293 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 297 insertions(+) create mode 100644 protocol.h create mode 100644 sock.h diff --git a/protocol.h b/protocol.h new file mode 100644 index 0000000..c71e9b6 --- /dev/null +++ b/protocol.h @@ -0,0 +1,4 @@ +#include +#include +#include +#include diff --git a/sock.h b/sock.h new file mode 100644 index 0000000..462f272 --- /dev/null +++ b/sock.h @@ -0,0 +1,293 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define MAX_EVENTS 8096 +#define BUFFER_SIZE 1024 + +typedef struct { + int client_fd; + int upstream_fd; + char *buffer; + size_t buffer_size; + size_t buffer_offset; +} connection_t; + +int listen_fd = 0; +int epoll_fd = 0; +connection_t connections[MAX_EVENTS][sizeof(connection_t)] = {0}; + +int sock_init(void); +void sock_exit(void); + +void set_nonblocking(int fd) { + int flags = fcntl(fd, F_GETFL, 0); + if (flags == -1) { + perror("fcntl get"); + exit(EXIT_FAILURE); + } + if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) { + perror("fcntl set"); + exit(EXIT_FAILURE); + } +} + +int prepare_upstream() { + int sockfd = socket(AF_INET, SOCK_STREAM, 0); + + return sockfd; +} + +int connect_upstream(const char *host, int port) { + int sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (sockfd == -1) { + perror("socket"); + return -1; + } + + set_nonblocking(sockfd); + + struct sockaddr_in server_addr; + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(port); + if (inet_pton(AF_INET, host, &server_addr.sin_addr) <= 0) { + perror("inet_pton"); + close(sockfd); + return -1; + } + + if (connect(sockfd, (struct sockaddr *)&server_addr, sizeof(server_addr)) == + -1) { + if (errno != EINPROGRESS) { + perror("connect"); + close(sockfd); + return -1; + } + } + + return sockfd; +} + +int create_listening_socket(int port) { + int listen_fd = socket(AF_INET, SOCK_STREAM, 0); + if (listen_fd == -1) { + perror("socket"); + return -1; + } + + int opt = 1; + if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == + -1) { + perror("setsockopt"); + close(listen_fd); + return -1; + } + + struct sockaddr_in server_addr; + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_addr.s_addr = INADDR_ANY; + server_addr.sin_port = htons(port); + + if (bind(listen_fd, (struct sockaddr *)&server_addr, sizeof(server_addr)) == + -1) { + perror("bind"); + close(listen_fd); + return -1; + } + + if (listen(listen_fd, SOMAXCONN) == -1) { + perror("listen"); + close(listen_fd); + return -1; + } + + set_nonblocking(listen_fd); + return listen_fd; +} + +char *sock_read(int fd, char *buf, size_t size) { + connection_t *conn = connections[fd]; + size_t left_in_buffer = conn->buffer_size - conn->buffer_offset; + size_t bytes_to_read = size > left_in_buffer ? left_in_buffer : size; + ssize_t bytes_read = 0; + char *buffer; + buffer[size]; + if (bytes_to_read) { + bytes_read = recv(fd, buffer, size, 0); + buffer[bytes_read] = 0; + } + memcpy(buf, conn->buffer + conn->buffer_offset, bytes_to_read); + + if (bytes_read > 0) { + return buf; + } else if (bytes_read == 0) { + printf("Connection closed by remote (fd=%d)\n", fd); + } else { + perror("read"); + } + return NULL; +} + +void close_connection(int epoll_fd, connection_t *conn) { + if (conn->client_fd != -1) { + epoll_ctl(epoll_fd, EPOLL_CTL_DEL, conn->client_fd, NULL); + close(conn->client_fd); + } + if (conn->upstream_fd != -1) { + epoll_ctl(epoll_fd, EPOLL_CTL_DEL, conn->upstream_fd, NULL); + close(conn->upstream_fd); + } +} + +int forward_data(int from_fd, int to_fd) { + static char buffer[BUFFER_SIZE]; + // Feels great to do somehow. Better safe than sorry. + memset(buffer, 0, BUFFER_SIZE); + ssize_t bytes_read = recv(from_fd, buffer, sizeof(buffer), 0); + if (bytes_read > 0) { + ssize_t bytes_written = send(to_fd, buffer, bytes_read, 0); + if (bytes_written == -1) { + perror("write"); + } + } else if (bytes_read == 0) { + printf("Connection closed by remote (fd=%d)\n", from_fd); + } else { + perror("read"); + } + return (int)bytes_read; +} + +bool handle_connect(struct epoll_event event, int epoll_fd) { + struct sockaddr_in client_addr; + socklen_t client_len = sizeof(client_addr); + int client_fd = + accept(listen_fd, (struct sockaddr *)&client_addr, &client_len); + if (client_fd == -1) { + perror("accept"); + return false; + } + set_nonblocking(client_fd); + + struct epoll_event client_event; + client_event.events = EPOLLIN | EPOLLOUT | EPOLLERR | EPOLLHUP; + client_event.data.ptr = connections[client_fd]; + client_event.data.fd = client_fd; + connections[client_fd]->upstream_fd = -1; + connections[client_fd]->client_fd = client_fd; + + printf("New connection: client_fd=%d\n", client_fd); + + epoll_ctl(epoll_fd, EPOLL_CTL_ADD, client_fd, &client_event); + return true; +} + +void handle_close(int epoll_fd, connection_t *conn) { + printf("Connection closed: client_fd=%d, upstream_fd=%d\n", conn->client_fd, + conn->upstream_fd); + close_connection(epoll_fd, conn); +} + +void handle_stream(struct epoll_event event, int epoll_fd, connection_t *conn) { + if (conn->upstream_fd == -1) { + conn->upstream_fd = prepare_upstream(); + int upstream_fd = py_route(conn->client_fd, conn->upstream_fd); + + if (upstream_fd == -1) { + close_connection(epoll_fd, conn); + return; + } + set_nonblocking(upstream_fd); + struct epoll_event upstream_event; + upstream_event.events = EPOLLIN | EPOLLOUT | EPOLLERR | EPOLLHUP; + upstream_event.data.ptr = connections[upstream_fd]; + upstream_event.data.fd = upstream_fd; + + connections[conn->client_fd]->upstream_fd = upstream_fd; + connections[upstream_fd]->client_fd = conn->client_fd; + connections[upstream_fd]->upstream_fd = upstream_fd; + + epoll_ctl(epoll_fd, EPOLL_CTL_ADD, upstream_fd, &upstream_event); + + printf("Connected: client_fd=%d, upstream_fd=%d\n", conn->client_fd, + conn->upstream_fd); + return; + } + + if (event.data.fd == conn->client_fd) { + + if (forward_data(conn->client_fd, conn->upstream_fd) < 1) { + close_connection(epoll_fd, conn); + } + } else if (event.data.fd == conn->upstream_fd) { + + if (forward_data(conn->upstream_fd, conn->client_fd) < 1) { + close_connection(epoll_fd, conn); + } + } +} + +void serve(int port) { + + listen_fd = create_listening_socket(port); + if (listen_fd == -1) { + fprintf(stderr, "Failed to create listening socket\n"); + return; + } + + epoll_fd = epoll_create1(0); + if (epoll_fd == -1) { + perror("epoll_create1"); + close(listen_fd); + return; + } + + struct epoll_event event; + event.events = EPOLLIN; + event.data.fd = listen_fd; + + if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, listen_fd, &event) == -1) { + perror("epoll_ctl"); + close(listen_fd); + close(epoll_fd); + return; + } + + struct epoll_event events[MAX_EVENTS]; + memset(events, 0, sizeof(events)); + + printf("Pretty Good Server listening on port %d\n", port); + + while (1) { + int num_events = epoll_wait(epoll_fd, events, MAX_EVENTS, -1); + if (num_events == -1) { + perror("epoll_wait"); + break; + } + + for (int i = 0; i < num_events; i++) { + if (events[i].data.fd == listen_fd) { + handle_connect(events[i], epoll_fd); + } else { + connection_t *conn = connections[events[i].data.fd]; + if (events[i].events & (EPOLLHUP | EPOLLERR)) { + handle_close(epoll_fd, conn); + } else if (events[i].events & EPOLLIN) { + handle_stream(events[i], epoll_fd, conn); + } + } + } + } +} \ No newline at end of file