diff --git a/src/zamenyat/app.py b/src/zamenyat/app.py index fc8400b..0895057 100644 --- a/src/zamenyat/app.py +++ b/src/zamenyat/app.py @@ -7,6 +7,64 @@ ZAMENYAT_THREAD_COUNT = 500 ZAMENYAT_BUFFER_SIZE = 4096*2 ZAMENYAT_HEADER_MAX_LENGTH = 4096*2 +class AsyncWriter: + + def __init__(self, writer, buffer_size=ZAMENYAT_BUFFER_SIZE): + self.writer = writer + self.buffer_size = buffer_size + self.drain = self.writer.drain + self.close = self.writer.close + self.wait_closed = self.writer.wait_closed + + async def write(self, data): + + while data: + chunk_size = self.buffer_size if len(data) > self.buffer_size else len(data) + chunk = data[:chunk_size] + self.writer.write(chunk) + data = data[chunk_size:] + await self.writer.drain() + +class AsyncReader: + + def __init__(self, reader): + self.reader = reader + self.buffer = b'' + + async def read(self, buffer_size=ZAMENYAT_BUFFER_SIZE, exact=False): + read_extra = buffer_size - len(self.buffer) + while len(self.buffer) < buffer_size: + chunk_size = buffer_size - len(self.buffer) + chunk = await self.reader.read(chunk_size) + if not chunk: + return None + self.buffer += chunk + if not exact: + break + buffer_size = len(self.buffer) if len(self.buffer) < buffer_size else buffer_size + data = self.buffer[:buffer_size] + self.buffer = self.buffer[buffer_size:] + return data + + async def unread(self, data): + if not data: + return + if hasattr(data, 'encode'): + data = data.encode() + self.buffer = data + self.buffer + +class Socket: + + def __init__(self, reader, writer, buffer_size): + self.reader = AsyncReader(reader) + self.writer = AsyncWriter(writer) + self.read = self.reader.read + self.unread = self.reader.unread + self.write = self.writer.write + self.drain = self.writer.drain + self.close = self.writer.close + self.wait_closed = self.writer.wait_closed + class Application: def __init__(self, upstream_host, upstream_port, *args, **kwargs): @@ -35,10 +93,11 @@ class Application: headers_end = data.find(b'\r\n\r\n') if headers_end: headers = data[:headers_end] - data = data[:headers_end + 4] + data = data[headers_end + 4:] + await reader.unread(data) break if not headers: - return None, None, None + return None, None header_dict = {} req_resp, *headers = headers.split(b"\r\n") for header_line in headers: @@ -46,7 +105,7 @@ class Application: key = key.decode() value = ": ".join([value.decode() for value in value]) header_dict[key] = int(value) if value.isdigit() else value - return req_resp.decode(), header_dict, data + return req_resp.decode(), header_dict def header_dict_to_bytes(self, req_resp, headers): header_list = [req_resp] @@ -55,42 +114,43 @@ class Application: header_list.append("\r\n") return ("\r\n".join(header_list)).encode() - async def stream(self, reader,writer): + async def stream(self, reader,writer,is_websocket=False): + global headers try: + reader = Socket(reader,writer, ZAMENYAT_BUFFER_SIZE) + writer = Socket(reader,writer, ZAMENYAT_BUFFER_SIZE) while True: - req_resp, headers, data = await self.get_headers(reader) - if not headers: - break + req_resp, headers = None, None + data = b'' + if not is_websocket: + req_resp, headers = await self.get_headers(reader) + if not headers: + break + else: + data = await reader.read() if 'Content-Length' in headers: while not len(data) == headers['Content-Length']: chunk_size = headers['Content-Length'] - len(data) if self.buffer_size > headers['Content-Length'] - len(data) else self.buffer_size + print("Bef read") chunk = await reader.read(chunk_size) if not chunk: data = None break + print("Aff read") data += chunk print(self.header_dict_to_bytes(req_resp,headers).decode()) - writer.write(self.header_dict_to_bytes(req_resp, headers)) - #await writer.drain() + await writer.write(self.header_dict_to_bytes(req_resp, headers)) + await writer.drain() if data: - print(data) - while data: - chunk_size = self.buffer_size if len(data) > self.buffer_size else len(data) - - - chunk = data[:chunk_size] - - - writer.write(chunk) - data = data[chunk_size:] - await writer.drain() - if not headers.get('Connection') == 'keep-alive': # and not headers.get('Upgrade-Insecure-Requests'): - break - + await writer.write(data) + #if not headers.get('Connection') == 'keep-alive': # and not headers.get('Upgrade-Insecure-Requests'): + # break + break except asyncio.CancelledError: pass finally: pass + return headers #writer.close() #await writer.wait_closed() @@ -101,18 +161,37 @@ class Application: upstream_reader, upstream_writer = await asyncio.open_connection(self.upstream_host, self.upstream_port) - time_start = time.time() - print(f"Connected to upstream #{self.total_connection_count} server {self.upstream_host}:{self.upstream_port} #{connection_nr} Time: {get_timestamp()}") - tasks = [ - self.stream(upstream_reader, writer), - self.stream(reader, upstream_writer) - ] - await asyncio.gather(*tasks) - time_end = time.time() - time_duration = time_end - time_start - print(f"Disconnected upstream #{self.total_connection_count} server {self.upstream_host}:{self.upstream_port} #{connection_nr} Duration: {time_duration:.5f}s") + + is_websocket = False + + while True: + time_start = time.time() + print(f"Connected to upstream #{self.total_connection_count} server {self.upstream_host}:{self.upstream_port} #{connection_nr} Time: {get_timestamp()}") + + request_headers = await self.stream(reader, upstream_writer,is_websocket) + await self.stream(upstream_reader, writer, is_websocket) + time_end = time.time() + time_duration = time_end - time_start + print(f"Disconnected upstream #{self.total_connection_count} server {self.upstream_host}:{self.upstream_port} #{connection_nr} Duration: {time_duration:.5f}s") + + keep_alive = False + + if request_headers.get('Connection') == 'keep-alive': # and not headers.get('Upgrade-Insecure-Requests'): + keep_alive = True + + if request_headers.get("Upgrade") == 'websocket': + is_websocket = True + + if not any([keep_alive, is_websocket]): + break + self.connection_count -= 1 + writer.close() + await writer.wait_closed() + upstream_writer.close() + await upstream_writer.wait_closed() + def upgrade_executor(self, thread_count): self.executor = Executor(max_workers=thread_count) loop = asyncio.get_running_loop()