From 3069fc0168e1e316a7dd2b115def49dd4a7e8467 Mon Sep 17 00:00:00 2001 From: retoor Date: Wed, 18 Dec 2024 05:57:06 +0100 Subject: [PATCH] Format. --- src/zamenyat/__main__.py | 35 +++------ src/zamenyat/app.py | 159 ++++++++++++++++++++++----------------- 2 files changed, 103 insertions(+), 91 deletions(-) diff --git a/src/zamenyat/__main__.py b/src/zamenyat/__main__.py index bdb656a..235306d 100644 --- a/src/zamenyat/__main__.py +++ b/src/zamenyat/__main__.py @@ -1,35 +1,22 @@ import argparse + from zamenyat.app import Application parser = argparse.ArgumentParser(description="Zamenyat sensitive content replacer.") -parser.add_argument( - "--host", - required=True, - type=str -) -parser.add_argument( - "--port", - required=True, - type=int -) -parser.add_argument( - "--upstream-host", - required=True, - type=str -) -parser.add_argument( - "--upstream-port", - required=True, - type=int -) +parser.add_argument("--host", required=True, type=str) +parser.add_argument("--port", required=True, type=int) +parser.add_argument("--upstream-host", required=True, type=str) +parser.add_argument("--upstream-port", required=True, type=int) def main(): args = parser.parse_args() - app = Application(upstream_host=args.upstream_host, upstream_port=args.upstream_port) - app.serve(host=args.host,port=args.port) + app = Application( + upstream_host=args.upstream_host, upstream_port=args.upstream_port + ) + app.serve(host=args.host, port=args.port) -if __name__ == '__main__': + +if __name__ == "__main__": main() - diff --git a/src/zamenyat/app.py b/src/zamenyat/app.py index 65f8869..0bd5126 100644 --- a/src/zamenyat/app.py +++ b/src/zamenyat/app.py @@ -1,25 +1,27 @@ -from app.app import Application as BaseApplication, get_timestamp import asyncio +import time from concurrent.futures import ThreadPoolExecutor as Executor -import time -ZAMENYAT_BACKLOG=100 +from app.app import get_timestamp + +ZAMENYAT_BACKLOG = 100 ZAMENYAT_THREAD_COUNT = 2500 ZAMENYAT_BUFFER_SIZE = 4096 -ZAMENYAT_HEADER_MAX_LENGTH = 4096*2 +ZAMENYAT_HEADER_MAX_LENGTH = 4096 * 2 + class AsyncWriter: - def __init__(self, writer, buffer_size=ZAMENYAT_BUFFER_SIZE,debug=False): + def __init__(self, writer, buffer_size=ZAMENYAT_BUFFER_SIZE, debug=False): self.debug = debug - self.writer = writer + self.writer = writer self.buffer_size = buffer_size self.drain = self.writer.drain self.close = self.writer.close self.wait_closed = self.writer.wait_closed async def write(self, data): - + while data: chunk_size = self.buffer_size if len(data) > self.buffer_size else len(data) chunk = data[:chunk_size] @@ -28,40 +30,43 @@ class AsyncWriter: print("Write chunk:", chunk) data = data[chunk_size:] await self.writer.drain() - + + class AsyncReader: - def __init__(self, reader,debug=False): - self.reader = reader - self.buffer = b'' + def __init__(self, reader, debug=False): + self.reader = reader + self.buffer = b"" self.debug = debug - + async def read_until(self, to_match): - buffer = b'' - while not to_match in buffer: + buffer = b"" + while to_match not in buffer: chunk = await self.read() if not chunk: return None buffer += chunk match_start = buffer.find(to_match) - data = buffer[:match_start + len(to_match)] - await self.unread(buffer[match_start + len(to_match):]) + data = buffer[: match_start + len(to_match)] + await self.unread(buffer[match_start + len(to_match) :]) return data async def read(self, buffer_size=ZAMENYAT_BUFFER_SIZE, exact=False): - read_extra = buffer_size - len(self.buffer) + buffer_size - len(self.buffer) while len(self.buffer) < buffer_size: chunk_size = buffer_size - len(self.buffer) chunk = await self.reader.read(chunk_size) if not chunk: return None - + if self.debug: print("Read chunk:", chunk) self.buffer += chunk if not exact: break - buffer_size = len(self.buffer) if len(self.buffer) < buffer_size else buffer_size + buffer_size = ( + len(self.buffer) if len(self.buffer) < buffer_size else buffer_size + ) data = self.buffer[:buffer_size] self.buffer = self.buffer[buffer_size:] @@ -70,30 +75,33 @@ class AsyncReader: async def unread(self, data): if not data: return - if hasattr(data, 'encode'): + if hasattr(data, "encode"): data = data.encode() self.buffer = data + self.buffer + class Socket: - def __init__(self, reader, writer, buffer_size,debug=False): + def __init__(self, reader, writer, buffer_size, debug=False): self.debug = debug - self.reader = AsyncReader(reader,debug=self.debug) - self.writer = AsyncWriter(writer,debug=self.debug) + self.reader = AsyncReader(reader, debug=self.debug) + self.writer = AsyncWriter(writer, debug=self.debug) self.read = self.reader.read self.read_until = self.reader.read_until self.unread = self.reader.unread self.write = self.writer.write self.drain = self.writer.drain self.close = self.writer.close - self.wait_closed = self.writer.wait_closed + self.wait_closed = self.writer.wait_closed + class Application: - def __init__(self, upstream_host, upstream_port, *args, **kwargs): + def __init__(self, upstream_host, upstream_port, silent=False, *args, **kwargs): self.upstream_host = upstream_host self.upstream_port = upstream_port self.server = None + self.silent = silent self.host = None self.port = None self.executor = None @@ -102,45 +110,49 @@ class Application: self.connection_count = 0 self.total_connection_count = 0 super().__init__(*args, **kwargs) - + async def get_headers(self, reader): - data = b'' - headers = await reader.read_until(b'\r\n\r\n') + headers = await reader.read_until(b"\r\n\r\n") if not headers: return None, None - headers = headers[:-2] + headers = headers[:-2] header_dict = {} req_resp, *headers = headers.split(b"\r\n") for header_line in headers[:-1]: key, *value = header_line.split(b": ") key = key.decode() value = ": ".join([value.decode() for value in value]) - header_dict[key] = int(value) if value.isdigit() else value + header_dict[key] = int(value) if value.isdigit() else value return req_resp.decode(), header_dict def header_dict_to_bytes(self, req_resp, headers): header_list = [req_resp] for key, value in headers.items(): - header_list.append("{}: {}".format(key, value)) + header_list.append(f"{key}: {value}") header_list.append("\r\n") return ("\r\n".join(header_list)).encode() - async def stream(self, reader,writer,is_websocket=False): + async def stream(self, reader, writer, is_websocket=False): global headers try: - reader = Socket(reader,writer, ZAMENYAT_BUFFER_SIZE) - writer = Socket(reader,writer, ZAMENYAT_BUFFER_SIZE) + reader = Socket(reader, writer, ZAMENYAT_BUFFER_SIZE) + writer = Socket(reader, writer, ZAMENYAT_BUFFER_SIZE) while True: req_resp, headers = None, None - data = b'' + data = b"" if not is_websocket: - req_resp, headers = await self.get_headers(reader) + req_resp, headers = await self.get_headers(reader) if not headers: return None if headers: - if 'Content-Length' in headers: - while len(data) != headers['Content-Length']: - chunk_size = headers['Content-Length'] - len(data) if self.buffer_size > headers['Content-Length'] - len(data) else self.buffer_size + if "Content-Length" in headers: + while len(data) != headers["Content-Length"]: + chunk_size = ( + headers["Content-Length"] - len(data) + if self.buffer_size + > headers["Content-Length"] - len(data) + else self.buffer_size + ) chunk = await reader.read(chunk_size) if not chunk: data = None @@ -154,51 +166,65 @@ class Application: data = await reader.read() await writer.write(data) if not is_websocket: - break + break except asyncio.CancelledError: pass finally: pass - return headers - #writer.close() - #await writer.wait_closed() - - async def handle_client(self,reader,writer): + return headers + # writer.close() + # await writer.wait_closed() + + async def handle_client(self, reader, writer): self.connection_count += 1 self.total_connection_count += 1 connection_nr = self.total_connection_count + upstream_reader, upstream_writer = await asyncio.open_connection( + self.upstream_host, self.upstream_port + ) - upstream_reader, upstream_writer = await asyncio.open_connection(self.upstream_host, self.upstream_port) - is_websocket = False while True: time_start = time.time() - print(f"Connected to upstream #{connection_nr} server {self.upstream_host}:{self.upstream_port} #{self.connection_count} Time: {get_timestamp()}") - + if not self.silent: + print( + f"Connected to upstream #{connection_nr} server {self.upstream_host}:{self.upstream_port} #{self.connection_count} Time: {get_timestamp()}" + ) + if is_websocket: await asyncio.gather( - self.stream(reader, upstream_writer,is_websocket), - self.stream(upstream_reader, writer, is_websocket) + self.stream(reader, upstream_writer, is_websocket), + self.stream(upstream_reader, writer, is_websocket), ) + if not self.silent: + print( + f"Disconnected upstream #{connection_nr} server {self.upstream_host}:{self.upstream_port} #{self.connection_count} Duration: {time_duration:.5f}s" + ) break else: - request_headers = await self.stream(reader, upstream_writer,is_websocket) + request_headers = await self.stream( + reader, upstream_writer, is_websocket + ) await self.stream(upstream_reader, writer, is_websocket) - + keep_alive = False if request_headers: - if request_headers.get('Connection') == 'keep-alive': # and not headers.get('Upgrade-Insecure-Requests'): + if ( + request_headers.get("Connection") == "keep-alive" + ): # and not headers.get('Upgrade-Insecure-Requests'): keep_alive = True - if request_headers.get("Upgrade") == 'websocket': + if request_headers.get("Upgrade") == "websocket": is_websocket = True - + time_end = time.time() time_duration = time_end - time_start - print(f"Disconnected upstream #{connection_nr} server {self.upstream_host}:{self.upstream_port} #{self.connection_count} Duration: {time_duration:.5f}s") - - + if not self.silent: + print( + f"Disconnected upstream #{connection_nr} server {self.upstream_host}:{self.upstream_port} #{self.connection_count} Duration: {time_duration:.5f}s" + ) + if not any([keep_alive, is_websocket]): break @@ -209,25 +235,24 @@ class Application: await writer.wait_closed() await upstream_writer.wait_closed() - - - def upgrade_executor(self, thread_count): self.executor = Executor(max_workers=thread_count) loop = asyncio.get_running_loop() loop.set_default_executor(self.executor) return self.executor - async def serve_async(self, host,port,backlog=ZAMENYAT_BACKLOG): + async def serve_async(self, host, port, backlog=ZAMENYAT_BACKLOG): self.upgrade_executor(ZAMENYAT_THREAD_COUNT) self.host = host self.port = port - self.server = await asyncio.start_server(self.handle_client, self.host, self.port,backlog=backlog) + self.server = await asyncio.start_server( + self.handle_client, self.host, self.port, backlog=backlog + ) async with self.server: await self.server.serve_forever() - def serve(self, host, port,backlog=ZAMENYAT_BACKLOG): + def serve(self, host, port, backlog=ZAMENYAT_BACKLOG): try: - asyncio.run(self.serve_async(host,port,backlog=backlog)) - except KeyboardInterrupt: + asyncio.run(self.serve_async(host, port, backlog=backlog)) + except KeyboardInterrupt: print("Shutted down server")