diff --git a/src/zamenyat/app.py b/src/zamenyat/app.py index f9193aa..b0eea3b 100644 --- a/src/zamenyat/app.py +++ b/src/zamenyat/app.py @@ -4,12 +4,13 @@ from concurrent.futures import ThreadPoolExecutor as Executor import time ZAMENYAT_THREAD_COUNT = 500 -ZAMENYAT_BUFFER_SIZE = 4096*2 +ZAMENYAT_BUFFER_SIZE = 64 ZAMENYAT_HEADER_MAX_LENGTH = 4096*2 class AsyncWriter: - def __init__(self, writer, buffer_size=ZAMENYAT_BUFFER_SIZE): + def __init__(self, writer, buffer_size=ZAMENYAT_BUFFER_SIZE,debug=False): + self.debug = debug self.writer = writer self.buffer_size = buffer_size self.drain = self.writer.drain @@ -22,14 +23,29 @@ class AsyncWriter: chunk_size = self.buffer_size if len(data) > self.buffer_size else len(data) chunk = data[:chunk_size] self.writer.write(chunk) + if self.debug: + print("Write chunk:", chunk) data = data[chunk_size:] await self.writer.drain() class AsyncReader: - def __init__(self, reader): + def __init__(self, reader,debug=False): self.reader = reader self.buffer = b'' + self.debug = debug + + async def read_until(self, to_match): + buffer = b'' + while not to_match in buffer: + chunk = await self.read() + if not chunk: + return None + buffer += chunk + match_start = buffer.find(to_match) + data = buffer[:match_start + len(to_match)] + await self.unread(buffer[match_start + len(to_match):]) + return data async def read(self, buffer_size=ZAMENYAT_BUFFER_SIZE, exact=False): read_extra = buffer_size - len(self.buffer) @@ -38,12 +54,16 @@ class AsyncReader: chunk = await self.reader.read(chunk_size) if not chunk: return None + + if self.debug: + print("Read chunk:", chunk) self.buffer += chunk if not exact: break buffer_size = len(self.buffer) if len(self.buffer) < buffer_size else buffer_size data = self.buffer[:buffer_size] self.buffer = self.buffer[buffer_size:] + return data async def unread(self, data): @@ -55,10 +75,12 @@ class AsyncReader: class Socket: - def __init__(self, reader, writer, buffer_size): - self.reader = AsyncReader(reader) - self.writer = AsyncWriter(writer) + def __init__(self, reader, writer, buffer_size,debug=True): + self.debug = debug + self.reader = AsyncReader(reader,debug=self.debug) + self.writer = AsyncWriter(writer,debug=self.debug) self.read = self.reader.read + self.read_until = self.reader.read_until self.unread = self.reader.unread self.write = self.writer.write self.drain = self.writer.drain @@ -82,25 +104,13 @@ class Application: async def get_headers(self, reader): data = b'' - headers = None - while True: - chunk = await reader.read(self.buffer_size) - if not chunk: - break - data += chunk - if len(data) > self.header_max_length: - break - headers_end = data.find(b'\r\n\r\n') - if headers_end: - headers = data[:headers_end] - data = data[headers_end + 4:] - await reader.unread(data) - break + headers = await reader.read_until(b'\r\n\r\n') if not headers: return None, None + headers = headers[:-2] header_dict = {} req_resp, *headers = headers.split(b"\r\n") - for header_line in headers: + for header_line in headers[:-1]: key, *value = header_line.split(b": ") key = key.decode() value = ": ".join([value.decode() for value in value]) @@ -125,7 +135,7 @@ class Application: if not is_websocket: req_resp, headers = await self.get_headers(reader) if not headers: - break + return None if headers: if 'Content-Length' in headers: while len(data) != headers['Content-Length']: @@ -134,7 +144,7 @@ class Application: chunk = await reader.read(chunk_size) if not chunk: data = None - break + return None print("Aff read") data += chunk await writer.write(self.header_dict_to_bytes(req_resp, headers)) @@ -146,7 +156,8 @@ class Application: await writer.write(data) #if not headers.get('Connection') == 'keep-alive': # and not headers.get('Upgrade-Insecure-Requests'): # break - break + if not is_websocket: + break except asyncio.CancelledError: pass finally: @@ -168,23 +179,31 @@ class Application: while True: time_start = time.time() print(f"Connected to upstream #{self.total_connection_count} server {self.upstream_host}:{self.upstream_port} #{connection_nr} Time: {get_timestamp()}") - - request_headers = await self.stream(reader, upstream_writer,is_websocket) - - keep_alive = False - if request_headers: - if request_headers.get('Connection') == 'keep-alive': # and not headers.get('Upgrade-Insecure-Requests'): - keep_alive = True - if request_headers.get("Upgrade") == 'websocket': - is_websocket = True - await self.stream(upstream_reader, writer, is_websocket) - time_end = time.time() - time_duration = time_end - time_start - print(f"Disconnected upstream #{self.total_connection_count} server {self.upstream_host}:{self.upstream_port} #{connection_nr} Duration: {time_duration:.5f}s") - - - if not any([keep_alive, is_websocket]): + + if is_websocket: + await asyncio.gather( + self.stream(reader, upstream_writer,is_websocket), + self.stream(upstream_reader, writer, is_websocket) + ) break + else: + request_headers = await self.stream(reader, upstream_writer,is_websocket) + await self.stream(upstream_reader, writer, is_websocket) + + keep_alive = False + if request_headers: + if request_headers.get('Connection') == 'keep-alive': # and not headers.get('Upgrade-Insecure-Requests'): + keep_alive = True + if request_headers.get("Upgrade") == 'websocket': + is_websocket = True + + time_end = time.time() + time_duration = time_end - time_start + print(f"Disconnected upstream #{self.total_connection_count} server {self.upstream_host}:{self.upstream_port} #{connection_nr} Duration: {time_duration:.5f}s") + + + if not any([keep_alive, is_websocket]): + break self.connection_count -= 1