This commit is contained in:
retoor 2024-12-18 05:57:06 +01:00
parent 4badc7ae7d
commit 3069fc0168
2 changed files with 103 additions and 91 deletions

View File

@ -1,35 +1,22 @@
import argparse import argparse
from zamenyat.app import Application from zamenyat.app import Application
parser = argparse.ArgumentParser(description="Zamenyat sensitive content replacer.") parser = argparse.ArgumentParser(description="Zamenyat sensitive content replacer.")
parser.add_argument( parser.add_argument("--host", required=True, type=str)
"--host", parser.add_argument("--port", required=True, type=int)
required=True, parser.add_argument("--upstream-host", required=True, type=str)
type=str parser.add_argument("--upstream-port", required=True, type=int)
)
parser.add_argument(
"--port",
required=True,
type=int
)
parser.add_argument(
"--upstream-host",
required=True,
type=str
)
parser.add_argument(
"--upstream-port",
required=True,
type=int
)
def main(): def main():
args = parser.parse_args() args = parser.parse_args()
app = Application(upstream_host=args.upstream_host, upstream_port=args.upstream_port) app = Application(
app.serve(host=args.host,port=args.port) upstream_host=args.upstream_host, upstream_port=args.upstream_port
)
app.serve(host=args.host, port=args.port)
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View File

@ -1,16 +1,18 @@
from app.app import Application as BaseApplication, get_timestamp
import asyncio import asyncio
from concurrent.futures import ThreadPoolExecutor as Executor
import time import time
from concurrent.futures import ThreadPoolExecutor as Executor
ZAMENYAT_BACKLOG=100 from app.app import get_timestamp
ZAMENYAT_BACKLOG = 100
ZAMENYAT_THREAD_COUNT = 2500 ZAMENYAT_THREAD_COUNT = 2500
ZAMENYAT_BUFFER_SIZE = 4096 ZAMENYAT_BUFFER_SIZE = 4096
ZAMENYAT_HEADER_MAX_LENGTH = 4096*2 ZAMENYAT_HEADER_MAX_LENGTH = 4096 * 2
class AsyncWriter: class AsyncWriter:
def __init__(self, writer, buffer_size=ZAMENYAT_BUFFER_SIZE,debug=False): def __init__(self, writer, buffer_size=ZAMENYAT_BUFFER_SIZE, debug=False):
self.debug = debug self.debug = debug
self.writer = writer self.writer = writer
self.buffer_size = buffer_size self.buffer_size = buffer_size
@ -29,27 +31,28 @@ class AsyncWriter:
data = data[chunk_size:] data = data[chunk_size:]
await self.writer.drain() await self.writer.drain()
class AsyncReader: class AsyncReader:
def __init__(self, reader,debug=False): def __init__(self, reader, debug=False):
self.reader = reader self.reader = reader
self.buffer = b'' self.buffer = b""
self.debug = debug self.debug = debug
async def read_until(self, to_match): async def read_until(self, to_match):
buffer = b'' buffer = b""
while not to_match in buffer: while to_match not in buffer:
chunk = await self.read() chunk = await self.read()
if not chunk: if not chunk:
return None return None
buffer += chunk buffer += chunk
match_start = buffer.find(to_match) match_start = buffer.find(to_match)
data = buffer[:match_start + len(to_match)] data = buffer[: match_start + len(to_match)]
await self.unread(buffer[match_start + len(to_match):]) await self.unread(buffer[match_start + len(to_match) :])
return data return data
async def read(self, buffer_size=ZAMENYAT_BUFFER_SIZE, exact=False): async def read(self, buffer_size=ZAMENYAT_BUFFER_SIZE, exact=False):
read_extra = buffer_size - len(self.buffer) buffer_size - len(self.buffer)
while len(self.buffer) < buffer_size: while len(self.buffer) < buffer_size:
chunk_size = buffer_size - len(self.buffer) chunk_size = buffer_size - len(self.buffer)
chunk = await self.reader.read(chunk_size) chunk = await self.reader.read(chunk_size)
@ -61,7 +64,9 @@ class AsyncReader:
self.buffer += chunk self.buffer += chunk
if not exact: if not exact:
break break
buffer_size = len(self.buffer) if len(self.buffer) < buffer_size else buffer_size buffer_size = (
len(self.buffer) if len(self.buffer) < buffer_size else buffer_size
)
data = self.buffer[:buffer_size] data = self.buffer[:buffer_size]
self.buffer = self.buffer[buffer_size:] self.buffer = self.buffer[buffer_size:]
@ -70,16 +75,17 @@ class AsyncReader:
async def unread(self, data): async def unread(self, data):
if not data: if not data:
return return
if hasattr(data, 'encode'): if hasattr(data, "encode"):
data = data.encode() data = data.encode()
self.buffer = data + self.buffer self.buffer = data + self.buffer
class Socket: class Socket:
def __init__(self, reader, writer, buffer_size,debug=False): def __init__(self, reader, writer, buffer_size, debug=False):
self.debug = debug self.debug = debug
self.reader = AsyncReader(reader,debug=self.debug) self.reader = AsyncReader(reader, debug=self.debug)
self.writer = AsyncWriter(writer,debug=self.debug) self.writer = AsyncWriter(writer, debug=self.debug)
self.read = self.reader.read self.read = self.reader.read
self.read_until = self.reader.read_until self.read_until = self.reader.read_until
self.unread = self.reader.unread self.unread = self.reader.unread
@ -88,12 +94,14 @@ class Socket:
self.close = self.writer.close self.close = self.writer.close
self.wait_closed = self.writer.wait_closed self.wait_closed = self.writer.wait_closed
class Application: class Application:
def __init__(self, upstream_host, upstream_port, *args, **kwargs): def __init__(self, upstream_host, upstream_port, silent=False, *args, **kwargs):
self.upstream_host = upstream_host self.upstream_host = upstream_host
self.upstream_port = upstream_port self.upstream_port = upstream_port
self.server = None self.server = None
self.silent = silent
self.host = None self.host = None
self.port = None self.port = None
self.executor = None self.executor = None
@ -104,8 +112,7 @@ class Application:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
async def get_headers(self, reader): async def get_headers(self, reader):
data = b'' headers = await reader.read_until(b"\r\n\r\n")
headers = await reader.read_until(b'\r\n\r\n')
if not headers: if not headers:
return None, None return None, None
headers = headers[:-2] headers = headers[:-2]
@ -121,26 +128,31 @@ class Application:
def header_dict_to_bytes(self, req_resp, headers): def header_dict_to_bytes(self, req_resp, headers):
header_list = [req_resp] header_list = [req_resp]
for key, value in headers.items(): for key, value in headers.items():
header_list.append("{}: {}".format(key, value)) header_list.append(f"{key}: {value}")
header_list.append("\r\n") header_list.append("\r\n")
return ("\r\n".join(header_list)).encode() return ("\r\n".join(header_list)).encode()
async def stream(self, reader,writer,is_websocket=False): async def stream(self, reader, writer, is_websocket=False):
global headers global headers
try: try:
reader = Socket(reader,writer, ZAMENYAT_BUFFER_SIZE) reader = Socket(reader, writer, ZAMENYAT_BUFFER_SIZE)
writer = Socket(reader,writer, ZAMENYAT_BUFFER_SIZE) writer = Socket(reader, writer, ZAMENYAT_BUFFER_SIZE)
while True: while True:
req_resp, headers = None, None req_resp, headers = None, None
data = b'' data = b""
if not is_websocket: if not is_websocket:
req_resp, headers = await self.get_headers(reader) req_resp, headers = await self.get_headers(reader)
if not headers: if not headers:
return None return None
if headers: if headers:
if 'Content-Length' in headers: if "Content-Length" in headers:
while len(data) != headers['Content-Length']: while len(data) != headers["Content-Length"]:
chunk_size = headers['Content-Length'] - len(data) if self.buffer_size > headers['Content-Length'] - len(data) else self.buffer_size chunk_size = (
headers["Content-Length"] - len(data)
if self.buffer_size
> headers["Content-Length"] - len(data)
else self.buffer_size
)
chunk = await reader.read(chunk_size) chunk = await reader.read(chunk_size)
if not chunk: if not chunk:
data = None data = None
@ -160,44 +172,58 @@ class Application:
finally: finally:
pass pass
return headers return headers
#writer.close() # writer.close()
#await writer.wait_closed() # await writer.wait_closed()
async def handle_client(self,reader,writer): async def handle_client(self, reader, writer):
self.connection_count += 1 self.connection_count += 1
self.total_connection_count += 1 self.total_connection_count += 1
connection_nr = self.total_connection_count connection_nr = self.total_connection_count
upstream_reader, upstream_writer = await asyncio.open_connection(
upstream_reader, upstream_writer = await asyncio.open_connection(self.upstream_host, self.upstream_port) self.upstream_host, self.upstream_port
)
is_websocket = False is_websocket = False
while True: while True:
time_start = time.time() time_start = time.time()
print(f"Connected to upstream #{connection_nr} server {self.upstream_host}:{self.upstream_port} #{self.connection_count} Time: {get_timestamp()}") if not self.silent:
print(
f"Connected to upstream #{connection_nr} server {self.upstream_host}:{self.upstream_port} #{self.connection_count} Time: {get_timestamp()}"
)
if is_websocket: if is_websocket:
await asyncio.gather( await asyncio.gather(
self.stream(reader, upstream_writer,is_websocket), self.stream(reader, upstream_writer, is_websocket),
self.stream(upstream_reader, writer, is_websocket) self.stream(upstream_reader, writer, is_websocket),
)
if not self.silent:
print(
f"Disconnected upstream #{connection_nr} server {self.upstream_host}:{self.upstream_port} #{self.connection_count} Duration: {time_duration:.5f}s"
) )
break break
else: else:
request_headers = await self.stream(reader, upstream_writer,is_websocket) request_headers = await self.stream(
reader, upstream_writer, is_websocket
)
await self.stream(upstream_reader, writer, is_websocket) await self.stream(upstream_reader, writer, is_websocket)
keep_alive = False keep_alive = False
if request_headers: if request_headers:
if request_headers.get('Connection') == 'keep-alive': # and not headers.get('Upgrade-Insecure-Requests'): if (
request_headers.get("Connection") == "keep-alive"
): # and not headers.get('Upgrade-Insecure-Requests'):
keep_alive = True keep_alive = True
if request_headers.get("Upgrade") == 'websocket': if request_headers.get("Upgrade") == "websocket":
is_websocket = True is_websocket = True
time_end = time.time() time_end = time.time()
time_duration = time_end - time_start time_duration = time_end - time_start
print(f"Disconnected upstream #{connection_nr} server {self.upstream_host}:{self.upstream_port} #{self.connection_count} Duration: {time_duration:.5f}s") if not self.silent:
print(
f"Disconnected upstream #{connection_nr} server {self.upstream_host}:{self.upstream_port} #{self.connection_count} Duration: {time_duration:.5f}s"
)
if not any([keep_alive, is_websocket]): if not any([keep_alive, is_websocket]):
break break
@ -209,25 +235,24 @@ class Application:
await writer.wait_closed() await writer.wait_closed()
await upstream_writer.wait_closed() await upstream_writer.wait_closed()
def upgrade_executor(self, thread_count): def upgrade_executor(self, thread_count):
self.executor = Executor(max_workers=thread_count) self.executor = Executor(max_workers=thread_count)
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
loop.set_default_executor(self.executor) loop.set_default_executor(self.executor)
return self.executor return self.executor
async def serve_async(self, host,port,backlog=ZAMENYAT_BACKLOG): async def serve_async(self, host, port, backlog=ZAMENYAT_BACKLOG):
self.upgrade_executor(ZAMENYAT_THREAD_COUNT) self.upgrade_executor(ZAMENYAT_THREAD_COUNT)
self.host = host self.host = host
self.port = port self.port = port
self.server = await asyncio.start_server(self.handle_client, self.host, self.port,backlog=backlog) self.server = await asyncio.start_server(
self.handle_client, self.host, self.port, backlog=backlog
)
async with self.server: async with self.server:
await self.server.serve_forever() await self.server.serve_forever()
def serve(self, host, port,backlog=ZAMENYAT_BACKLOG): def serve(self, host, port, backlog=ZAMENYAT_BACKLOG):
try: try:
asyncio.run(self.serve_async(host,port,backlog=backlog)) asyncio.run(self.serve_async(host, port, backlog=backlog))
except KeyboardInterrupt: except KeyboardInterrupt:
print("Shutted down server") print("Shutted down server")