Format.
This commit is contained in:
parent
4badc7ae7d
commit
3069fc0168
@ -1,35 +1,22 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from zamenyat.app import Application
|
from zamenyat.app import Application
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Zamenyat sensitive content replacer.")
|
parser = argparse.ArgumentParser(description="Zamenyat sensitive content replacer.")
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument("--host", required=True, type=str)
|
||||||
"--host",
|
parser.add_argument("--port", required=True, type=int)
|
||||||
required=True,
|
parser.add_argument("--upstream-host", required=True, type=str)
|
||||||
type=str
|
parser.add_argument("--upstream-port", required=True, type=int)
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--port",
|
|
||||||
required=True,
|
|
||||||
type=int
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--upstream-host",
|
|
||||||
required=True,
|
|
||||||
type=str
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--upstream-port",
|
|
||||||
required=True,
|
|
||||||
type=int
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
app = Application(upstream_host=args.upstream_host, upstream_port=args.upstream_port)
|
app = Application(
|
||||||
|
upstream_host=args.upstream_host, upstream_port=args.upstream_port
|
||||||
|
)
|
||||||
app.serve(host=args.host, port=args.port)
|
app.serve(host=args.host, port=args.port)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
@ -1,13 +1,15 @@
|
|||||||
from app.app import Application as BaseApplication, get_timestamp
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from concurrent.futures import ThreadPoolExecutor as Executor
|
|
||||||
import time
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor as Executor
|
||||||
|
|
||||||
|
from app.app import get_timestamp
|
||||||
|
|
||||||
ZAMENYAT_BACKLOG = 100
|
ZAMENYAT_BACKLOG = 100
|
||||||
ZAMENYAT_THREAD_COUNT = 2500
|
ZAMENYAT_THREAD_COUNT = 2500
|
||||||
ZAMENYAT_BUFFER_SIZE = 4096
|
ZAMENYAT_BUFFER_SIZE = 4096
|
||||||
ZAMENYAT_HEADER_MAX_LENGTH = 4096 * 2
|
ZAMENYAT_HEADER_MAX_LENGTH = 4096 * 2
|
||||||
|
|
||||||
|
|
||||||
class AsyncWriter:
|
class AsyncWriter:
|
||||||
|
|
||||||
def __init__(self, writer, buffer_size=ZAMENYAT_BUFFER_SIZE, debug=False):
|
def __init__(self, writer, buffer_size=ZAMENYAT_BUFFER_SIZE, debug=False):
|
||||||
@ -29,16 +31,17 @@ class AsyncWriter:
|
|||||||
data = data[chunk_size:]
|
data = data[chunk_size:]
|
||||||
await self.writer.drain()
|
await self.writer.drain()
|
||||||
|
|
||||||
|
|
||||||
class AsyncReader:
|
class AsyncReader:
|
||||||
|
|
||||||
def __init__(self, reader, debug=False):
|
def __init__(self, reader, debug=False):
|
||||||
self.reader = reader
|
self.reader = reader
|
||||||
self.buffer = b''
|
self.buffer = b""
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
|
|
||||||
async def read_until(self, to_match):
|
async def read_until(self, to_match):
|
||||||
buffer = b''
|
buffer = b""
|
||||||
while not to_match in buffer:
|
while to_match not in buffer:
|
||||||
chunk = await self.read()
|
chunk = await self.read()
|
||||||
if not chunk:
|
if not chunk:
|
||||||
return None
|
return None
|
||||||
@ -49,7 +52,7 @@ class AsyncReader:
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
async def read(self, buffer_size=ZAMENYAT_BUFFER_SIZE, exact=False):
|
async def read(self, buffer_size=ZAMENYAT_BUFFER_SIZE, exact=False):
|
||||||
read_extra = buffer_size - len(self.buffer)
|
buffer_size - len(self.buffer)
|
||||||
while len(self.buffer) < buffer_size:
|
while len(self.buffer) < buffer_size:
|
||||||
chunk_size = buffer_size - len(self.buffer)
|
chunk_size = buffer_size - len(self.buffer)
|
||||||
chunk = await self.reader.read(chunk_size)
|
chunk = await self.reader.read(chunk_size)
|
||||||
@ -61,7 +64,9 @@ class AsyncReader:
|
|||||||
self.buffer += chunk
|
self.buffer += chunk
|
||||||
if not exact:
|
if not exact:
|
||||||
break
|
break
|
||||||
buffer_size = len(self.buffer) if len(self.buffer) < buffer_size else buffer_size
|
buffer_size = (
|
||||||
|
len(self.buffer) if len(self.buffer) < buffer_size else buffer_size
|
||||||
|
)
|
||||||
data = self.buffer[:buffer_size]
|
data = self.buffer[:buffer_size]
|
||||||
self.buffer = self.buffer[buffer_size:]
|
self.buffer = self.buffer[buffer_size:]
|
||||||
|
|
||||||
@ -70,10 +75,11 @@ class AsyncReader:
|
|||||||
async def unread(self, data):
|
async def unread(self, data):
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
if hasattr(data, 'encode'):
|
if hasattr(data, "encode"):
|
||||||
data = data.encode()
|
data = data.encode()
|
||||||
self.buffer = data + self.buffer
|
self.buffer = data + self.buffer
|
||||||
|
|
||||||
|
|
||||||
class Socket:
|
class Socket:
|
||||||
|
|
||||||
def __init__(self, reader, writer, buffer_size, debug=False):
|
def __init__(self, reader, writer, buffer_size, debug=False):
|
||||||
@ -88,12 +94,14 @@ class Socket:
|
|||||||
self.close = self.writer.close
|
self.close = self.writer.close
|
||||||
self.wait_closed = self.writer.wait_closed
|
self.wait_closed = self.writer.wait_closed
|
||||||
|
|
||||||
|
|
||||||
class Application:
|
class Application:
|
||||||
|
|
||||||
def __init__(self, upstream_host, upstream_port, *args, **kwargs):
|
def __init__(self, upstream_host, upstream_port, silent=False, *args, **kwargs):
|
||||||
self.upstream_host = upstream_host
|
self.upstream_host = upstream_host
|
||||||
self.upstream_port = upstream_port
|
self.upstream_port = upstream_port
|
||||||
self.server = None
|
self.server = None
|
||||||
|
self.silent = silent
|
||||||
self.host = None
|
self.host = None
|
||||||
self.port = None
|
self.port = None
|
||||||
self.executor = None
|
self.executor = None
|
||||||
@ -104,8 +112,7 @@ class Application:
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
async def get_headers(self, reader):
|
async def get_headers(self, reader):
|
||||||
data = b''
|
headers = await reader.read_until(b"\r\n\r\n")
|
||||||
headers = await reader.read_until(b'\r\n\r\n')
|
|
||||||
if not headers:
|
if not headers:
|
||||||
return None, None
|
return None, None
|
||||||
headers = headers[:-2]
|
headers = headers[:-2]
|
||||||
@ -121,7 +128,7 @@ class Application:
|
|||||||
def header_dict_to_bytes(self, req_resp, headers):
|
def header_dict_to_bytes(self, req_resp, headers):
|
||||||
header_list = [req_resp]
|
header_list = [req_resp]
|
||||||
for key, value in headers.items():
|
for key, value in headers.items():
|
||||||
header_list.append("{}: {}".format(key, value))
|
header_list.append(f"{key}: {value}")
|
||||||
header_list.append("\r\n")
|
header_list.append("\r\n")
|
||||||
return ("\r\n".join(header_list)).encode()
|
return ("\r\n".join(header_list)).encode()
|
||||||
|
|
||||||
@ -132,15 +139,20 @@ class Application:
|
|||||||
writer = Socket(reader, writer, ZAMENYAT_BUFFER_SIZE)
|
writer = Socket(reader, writer, ZAMENYAT_BUFFER_SIZE)
|
||||||
while True:
|
while True:
|
||||||
req_resp, headers = None, None
|
req_resp, headers = None, None
|
||||||
data = b''
|
data = b""
|
||||||
if not is_websocket:
|
if not is_websocket:
|
||||||
req_resp, headers = await self.get_headers(reader)
|
req_resp, headers = await self.get_headers(reader)
|
||||||
if not headers:
|
if not headers:
|
||||||
return None
|
return None
|
||||||
if headers:
|
if headers:
|
||||||
if 'Content-Length' in headers:
|
if "Content-Length" in headers:
|
||||||
while len(data) != headers['Content-Length']:
|
while len(data) != headers["Content-Length"]:
|
||||||
chunk_size = headers['Content-Length'] - len(data) if self.buffer_size > headers['Content-Length'] - len(data) else self.buffer_size
|
chunk_size = (
|
||||||
|
headers["Content-Length"] - len(data)
|
||||||
|
if self.buffer_size
|
||||||
|
> headers["Content-Length"] - len(data)
|
||||||
|
else self.buffer_size
|
||||||
|
)
|
||||||
chunk = await reader.read(chunk_size)
|
chunk = await reader.read(chunk_size)
|
||||||
if not chunk:
|
if not chunk:
|
||||||
data = None
|
data = None
|
||||||
@ -168,36 +180,50 @@ class Application:
|
|||||||
self.total_connection_count += 1
|
self.total_connection_count += 1
|
||||||
connection_nr = self.total_connection_count
|
connection_nr = self.total_connection_count
|
||||||
|
|
||||||
|
upstream_reader, upstream_writer = await asyncio.open_connection(
|
||||||
upstream_reader, upstream_writer = await asyncio.open_connection(self.upstream_host, self.upstream_port)
|
self.upstream_host, self.upstream_port
|
||||||
|
)
|
||||||
|
|
||||||
is_websocket = False
|
is_websocket = False
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
time_start = time.time()
|
time_start = time.time()
|
||||||
print(f"Connected to upstream #{connection_nr} server {self.upstream_host}:{self.upstream_port} #{self.connection_count} Time: {get_timestamp()}")
|
if not self.silent:
|
||||||
|
print(
|
||||||
|
f"Connected to upstream #{connection_nr} server {self.upstream_host}:{self.upstream_port} #{self.connection_count} Time: {get_timestamp()}"
|
||||||
|
)
|
||||||
|
|
||||||
if is_websocket:
|
if is_websocket:
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
self.stream(reader, upstream_writer, is_websocket),
|
self.stream(reader, upstream_writer, is_websocket),
|
||||||
self.stream(upstream_reader, writer, is_websocket)
|
self.stream(upstream_reader, writer, is_websocket),
|
||||||
|
)
|
||||||
|
if not self.silent:
|
||||||
|
print(
|
||||||
|
f"Disconnected upstream #{connection_nr} server {self.upstream_host}:{self.upstream_port} #{self.connection_count} Duration: {time_duration:.5f}s"
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
request_headers = await self.stream(reader, upstream_writer,is_websocket)
|
request_headers = await self.stream(
|
||||||
|
reader, upstream_writer, is_websocket
|
||||||
|
)
|
||||||
await self.stream(upstream_reader, writer, is_websocket)
|
await self.stream(upstream_reader, writer, is_websocket)
|
||||||
|
|
||||||
keep_alive = False
|
keep_alive = False
|
||||||
if request_headers:
|
if request_headers:
|
||||||
if request_headers.get('Connection') == 'keep-alive': # and not headers.get('Upgrade-Insecure-Requests'):
|
if (
|
||||||
|
request_headers.get("Connection") == "keep-alive"
|
||||||
|
): # and not headers.get('Upgrade-Insecure-Requests'):
|
||||||
keep_alive = True
|
keep_alive = True
|
||||||
if request_headers.get("Upgrade") == 'websocket':
|
if request_headers.get("Upgrade") == "websocket":
|
||||||
is_websocket = True
|
is_websocket = True
|
||||||
|
|
||||||
time_end = time.time()
|
time_end = time.time()
|
||||||
time_duration = time_end - time_start
|
time_duration = time_end - time_start
|
||||||
print(f"Disconnected upstream #{connection_nr} server {self.upstream_host}:{self.upstream_port} #{self.connection_count} Duration: {time_duration:.5f}s")
|
if not self.silent:
|
||||||
|
print(
|
||||||
|
f"Disconnected upstream #{connection_nr} server {self.upstream_host}:{self.upstream_port} #{self.connection_count} Duration: {time_duration:.5f}s"
|
||||||
|
)
|
||||||
|
|
||||||
if not any([keep_alive, is_websocket]):
|
if not any([keep_alive, is_websocket]):
|
||||||
break
|
break
|
||||||
@ -209,9 +235,6 @@ class Application:
|
|||||||
await writer.wait_closed()
|
await writer.wait_closed()
|
||||||
await upstream_writer.wait_closed()
|
await upstream_writer.wait_closed()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade_executor(self, thread_count):
|
def upgrade_executor(self, thread_count):
|
||||||
self.executor = Executor(max_workers=thread_count)
|
self.executor = Executor(max_workers=thread_count)
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
@ -222,7 +245,9 @@ class Application:
|
|||||||
self.upgrade_executor(ZAMENYAT_THREAD_COUNT)
|
self.upgrade_executor(ZAMENYAT_THREAD_COUNT)
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.server = await asyncio.start_server(self.handle_client, self.host, self.port,backlog=backlog)
|
self.server = await asyncio.start_server(
|
||||||
|
self.handle_client, self.host, self.port, backlog=backlog
|
||||||
|
)
|
||||||
async with self.server:
|
async with self.server:
|
||||||
await self.server.serve_forever()
|
await self.server.serve_forever()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user