Perfect working version with websocket.
This commit is contained in:
parent
19f8f938ff
commit
f506e5e52b
@ -4,12 +4,13 @@ from concurrent.futures import ThreadPoolExecutor as Executor
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
ZAMENYAT_THREAD_COUNT = 500
|
ZAMENYAT_THREAD_COUNT = 500
|
||||||
ZAMENYAT_BUFFER_SIZE = 4096*2
|
ZAMENYAT_BUFFER_SIZE = 64
|
||||||
ZAMENYAT_HEADER_MAX_LENGTH = 4096*2
|
ZAMENYAT_HEADER_MAX_LENGTH = 4096*2
|
||||||
|
|
||||||
class AsyncWriter:
|
class AsyncWriter:
|
||||||
|
|
||||||
def __init__(self, writer, buffer_size=ZAMENYAT_BUFFER_SIZE):
|
def __init__(self, writer, buffer_size=ZAMENYAT_BUFFER_SIZE,debug=False):
|
||||||
|
self.debug = debug
|
||||||
self.writer = writer
|
self.writer = writer
|
||||||
self.buffer_size = buffer_size
|
self.buffer_size = buffer_size
|
||||||
self.drain = self.writer.drain
|
self.drain = self.writer.drain
|
||||||
@ -22,14 +23,29 @@ class AsyncWriter:
|
|||||||
chunk_size = self.buffer_size if len(data) > self.buffer_size else len(data)
|
chunk_size = self.buffer_size if len(data) > self.buffer_size else len(data)
|
||||||
chunk = data[:chunk_size]
|
chunk = data[:chunk_size]
|
||||||
self.writer.write(chunk)
|
self.writer.write(chunk)
|
||||||
|
if self.debug:
|
||||||
|
print("Write chunk:", chunk)
|
||||||
data = data[chunk_size:]
|
data = data[chunk_size:]
|
||||||
await self.writer.drain()
|
await self.writer.drain()
|
||||||
|
|
||||||
class AsyncReader:
|
class AsyncReader:
|
||||||
|
|
||||||
def __init__(self, reader):
|
def __init__(self, reader,debug=False):
|
||||||
self.reader = reader
|
self.reader = reader
|
||||||
self.buffer = b''
|
self.buffer = b''
|
||||||
|
self.debug = debug
|
||||||
|
|
||||||
|
async def read_until(self, to_match):
|
||||||
|
buffer = b''
|
||||||
|
while not to_match in buffer:
|
||||||
|
chunk = await self.read()
|
||||||
|
if not chunk:
|
||||||
|
return None
|
||||||
|
buffer += chunk
|
||||||
|
match_start = buffer.find(to_match)
|
||||||
|
data = buffer[:match_start + len(to_match)]
|
||||||
|
await self.unread(buffer[match_start + len(to_match):])
|
||||||
|
return data
|
||||||
|
|
||||||
async def read(self, buffer_size=ZAMENYAT_BUFFER_SIZE, exact=False):
|
async def read(self, buffer_size=ZAMENYAT_BUFFER_SIZE, exact=False):
|
||||||
read_extra = buffer_size - len(self.buffer)
|
read_extra = buffer_size - len(self.buffer)
|
||||||
@ -38,12 +54,16 @@ class AsyncReader:
|
|||||||
chunk = await self.reader.read(chunk_size)
|
chunk = await self.reader.read(chunk_size)
|
||||||
if not chunk:
|
if not chunk:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if self.debug:
|
||||||
|
print("Read chunk:", chunk)
|
||||||
self.buffer += chunk
|
self.buffer += chunk
|
||||||
if not exact:
|
if not exact:
|
||||||
break
|
break
|
||||||
buffer_size = len(self.buffer) if len(self.buffer) < buffer_size else buffer_size
|
buffer_size = len(self.buffer) if len(self.buffer) < buffer_size else buffer_size
|
||||||
data = self.buffer[:buffer_size]
|
data = self.buffer[:buffer_size]
|
||||||
self.buffer = self.buffer[buffer_size:]
|
self.buffer = self.buffer[buffer_size:]
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def unread(self, data):
|
async def unread(self, data):
|
||||||
@ -55,10 +75,12 @@ class AsyncReader:
|
|||||||
|
|
||||||
class Socket:
|
class Socket:
|
||||||
|
|
||||||
def __init__(self, reader, writer, buffer_size):
|
def __init__(self, reader, writer, buffer_size,debug=True):
|
||||||
self.reader = AsyncReader(reader)
|
self.debug = debug
|
||||||
self.writer = AsyncWriter(writer)
|
self.reader = AsyncReader(reader,debug=self.debug)
|
||||||
|
self.writer = AsyncWriter(writer,debug=self.debug)
|
||||||
self.read = self.reader.read
|
self.read = self.reader.read
|
||||||
|
self.read_until = self.reader.read_until
|
||||||
self.unread = self.reader.unread
|
self.unread = self.reader.unread
|
||||||
self.write = self.writer.write
|
self.write = self.writer.write
|
||||||
self.drain = self.writer.drain
|
self.drain = self.writer.drain
|
||||||
@ -82,25 +104,13 @@ class Application:
|
|||||||
|
|
||||||
async def get_headers(self, reader):
|
async def get_headers(self, reader):
|
||||||
data = b''
|
data = b''
|
||||||
headers = None
|
headers = await reader.read_until(b'\r\n\r\n')
|
||||||
while True:
|
|
||||||
chunk = await reader.read(self.buffer_size)
|
|
||||||
if not chunk:
|
|
||||||
break
|
|
||||||
data += chunk
|
|
||||||
if len(data) > self.header_max_length:
|
|
||||||
break
|
|
||||||
headers_end = data.find(b'\r\n\r\n')
|
|
||||||
if headers_end:
|
|
||||||
headers = data[:headers_end]
|
|
||||||
data = data[headers_end + 4:]
|
|
||||||
await reader.unread(data)
|
|
||||||
break
|
|
||||||
if not headers:
|
if not headers:
|
||||||
return None, None
|
return None, None
|
||||||
|
headers = headers[:-2]
|
||||||
header_dict = {}
|
header_dict = {}
|
||||||
req_resp, *headers = headers.split(b"\r\n")
|
req_resp, *headers = headers.split(b"\r\n")
|
||||||
for header_line in headers:
|
for header_line in headers[:-1]:
|
||||||
key, *value = header_line.split(b": ")
|
key, *value = header_line.split(b": ")
|
||||||
key = key.decode()
|
key = key.decode()
|
||||||
value = ": ".join([value.decode() for value in value])
|
value = ": ".join([value.decode() for value in value])
|
||||||
@ -125,7 +135,7 @@ class Application:
|
|||||||
if not is_websocket:
|
if not is_websocket:
|
||||||
req_resp, headers = await self.get_headers(reader)
|
req_resp, headers = await self.get_headers(reader)
|
||||||
if not headers:
|
if not headers:
|
||||||
break
|
return None
|
||||||
if headers:
|
if headers:
|
||||||
if 'Content-Length' in headers:
|
if 'Content-Length' in headers:
|
||||||
while len(data) != headers['Content-Length']:
|
while len(data) != headers['Content-Length']:
|
||||||
@ -134,7 +144,7 @@ class Application:
|
|||||||
chunk = await reader.read(chunk_size)
|
chunk = await reader.read(chunk_size)
|
||||||
if not chunk:
|
if not chunk:
|
||||||
data = None
|
data = None
|
||||||
break
|
return None
|
||||||
print("Aff read")
|
print("Aff read")
|
||||||
data += chunk
|
data += chunk
|
||||||
await writer.write(self.header_dict_to_bytes(req_resp, headers))
|
await writer.write(self.header_dict_to_bytes(req_resp, headers))
|
||||||
@ -146,7 +156,8 @@ class Application:
|
|||||||
await writer.write(data)
|
await writer.write(data)
|
||||||
#if not headers.get('Connection') == 'keep-alive': # and not headers.get('Upgrade-Insecure-Requests'):
|
#if not headers.get('Connection') == 'keep-alive': # and not headers.get('Upgrade-Insecure-Requests'):
|
||||||
# break
|
# break
|
||||||
break
|
if not is_websocket:
|
||||||
|
break
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
@ -168,23 +179,31 @@ class Application:
|
|||||||
while True:
|
while True:
|
||||||
time_start = time.time()
|
time_start = time.time()
|
||||||
print(f"Connected to upstream #{self.total_connection_count} server {self.upstream_host}:{self.upstream_port} #{connection_nr} Time: {get_timestamp()}")
|
print(f"Connected to upstream #{self.total_connection_count} server {self.upstream_host}:{self.upstream_port} #{connection_nr} Time: {get_timestamp()}")
|
||||||
|
|
||||||
request_headers = await self.stream(reader, upstream_writer,is_websocket)
|
if is_websocket:
|
||||||
|
await asyncio.gather(
|
||||||
keep_alive = False
|
self.stream(reader, upstream_writer,is_websocket),
|
||||||
if request_headers:
|
self.stream(upstream_reader, writer, is_websocket)
|
||||||
if request_headers.get('Connection') == 'keep-alive': # and not headers.get('Upgrade-Insecure-Requests'):
|
)
|
||||||
keep_alive = True
|
|
||||||
if request_headers.get("Upgrade") == 'websocket':
|
|
||||||
is_websocket = True
|
|
||||||
await self.stream(upstream_reader, writer, is_websocket)
|
|
||||||
time_end = time.time()
|
|
||||||
time_duration = time_end - time_start
|
|
||||||
print(f"Disconnected upstream #{self.total_connection_count} server {self.upstream_host}:{self.upstream_port} #{connection_nr} Duration: {time_duration:.5f}s")
|
|
||||||
|
|
||||||
|
|
||||||
if not any([keep_alive, is_websocket]):
|
|
||||||
break
|
break
|
||||||
|
else:
|
||||||
|
request_headers = await self.stream(reader, upstream_writer,is_websocket)
|
||||||
|
await self.stream(upstream_reader, writer, is_websocket)
|
||||||
|
|
||||||
|
keep_alive = False
|
||||||
|
if request_headers:
|
||||||
|
if request_headers.get('Connection') == 'keep-alive': # and not headers.get('Upgrade-Insecure-Requests'):
|
||||||
|
keep_alive = True
|
||||||
|
if request_headers.get("Upgrade") == 'websocket':
|
||||||
|
is_websocket = True
|
||||||
|
|
||||||
|
time_end = time.time()
|
||||||
|
time_duration = time_end - time_start
|
||||||
|
print(f"Disconnected upstream #{self.total_connection_count} server {self.upstream_host}:{self.upstream_port} #{connection_nr} Duration: {time_duration:.5f}s")
|
||||||
|
|
||||||
|
|
||||||
|
if not any([keep_alive, is_websocket]):
|
||||||
|
break
|
||||||
|
|
||||||
self.connection_count -= 1
|
self.connection_count -= 1
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user