# Raw source file available here.
import asyncio
import logging
import pathlib
import ssl
import uuid
import signal
from datetime import datetime
from contextlib import asynccontextmanager
from snek import snode
from snek.view.threads import ThreadsView
# Configure the root logger at import time with DEBUG as the threshold.
logging.basicConfig(level=logging.DEBUG)
from concurrent.futures import ThreadPoolExecutor
from ipaddress import ip_address
import IP2Location
from aiohttp import web
from aiohttp_session import (
get_session as session_get,
session_middleware,
setup as session_setup,
)
from aiohttp_session.cookie_storage import EncryptedCookieStorage
from app.app import Application as BaseApplication
from jinja2 import FileSystemLoader
from snek.mapper import get_mappers
from snek.service import get_services
from snek.sgit import GitApplication
from snek.sssh import start_ssh_server
from snek.system import http
from snek.system.cache import Cache
from snek.system.markdown import MarkdownExtension
from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware
from snek.system.profiler import profiler_handler
from snek.system.template import (
EmojiExtension,
LinkifyExtension,
PythonExtension,
sanitize_html,
)
from snek.view.about import AboutHTMLView, AboutMDView
from snek.view.avatar import AvatarView
from snek.view.channel import ChannelAttachmentView,ChannelAttachmentUploadView, ChannelView
from snek.view.docs import DocsHTMLView, DocsMDView
from snek.view.drive import DriveApiView, DriveView
from snek.view.channel import ChannelDriveApiView
from snek.view.index import IndexView
from snek.view.login import LoginView
from snek.view.logout import LogoutView
from snek.view.push import PushView
from snek.view.register import RegisterView
from snek.view.repository import RepositoryView
from snek.view.rpc import RPCView
from snek.view.search_user import SearchUserView
from snek.view.container import ContainerView
from snek.view.settings.containers import (
ContainersCreateView,
ContainersDeleteView,
ContainersIndexView,
ContainersUpdateView,
)
from snek.view.settings.index import SettingsIndexView
from snek.view.settings.profile import SettingsProfileView
from snek.view.settings.repositories import (
RepositoriesCreateView,
RepositoriesDeleteView,
RepositoriesIndexView,
RepositoriesUpdateView,
)
from snek.view.stats import StatsView
from snek.view.status import StatusView
from snek.view.terminal import TerminalSocketView, TerminalView
from snek.view.upload import UploadView
from snek.view.user import UserView
from snek.view.web import WebView
from snek.webdav import WebdavApplication
import os

# Key handed to EncryptedCookieStorage for encrypting session cookies.
# The value can be overridden through the SNEK_SESSION_KEY environment
# variable so deployments do not have to rely on the key committed to source
# control; the historical value remains the default so existing sessions stay
# valid when the variable is not set.
# NOTE(review): the fallback below is a secret stored in the repository -
# rotate it and supply the replacement via the environment.
SESSION_KEY = os.environ.get(
    "SNEK_SESSION_KEY", "c79a0c5fda4b424189c427d28c9f7c34"
).encode("utf-8")
from snek.system.template import whitelist_attributes
@web.middleware
async def session_middleware(request, handler):
    """Expose the aiohttp session as ``request.session`` and run the handler.

    This definition rebinds the module-level name ``session_middleware`` that
    was imported from ``aiohttp_session`` earlier in the file.

    Args:
        request: Incoming request; receives a ``session`` attribute.
        handler: Next handler in the middleware chain.

    Returns:
        Whatever response the wrapped handler produces.
    """
    request.session = await session_get(request)
    return await handler(request)
@web.middleware
async def ip2location_middleware(request, handler):
    """Run the handler, then optionally refresh the user's stored location.

    The previous implementation returned right after calling the handler, so
    the lookup code below it never ran. That behaviour is kept as the default:
    the lookup only happens when the application sets a truthy
    ``ip2location_enabled`` attribute. Any failure during the lookup is logged
    and never affects the response.

    Args:
        request: Incoming request.
        handler: Next handler in the middleware chain.

    Returns:
        The response produced by the wrapped handler, unmodified.
    """
    response = await handler(request)

    # Opt-in switch; when absent the middleware is a pure pass-through.
    if not getattr(request.app, "ip2location_enabled", False):
        return response

    try:
        # X-Forwarded-For may carry a comma separated proxy chain; the first
        # entry is taken as the client address.
        forwarded = request.headers.get("X-Forwarded-For", "")
        ip = forwarded.split(",")[0].strip() or request.remote
        if not ip:
            return response
        try:
            parsed_ip = ip_address(ip)
        except ValueError:
            # Header content is client controlled and may not be an address.
            return response
        if parsed_ip.is_private:
            return response

        # The session is attached to the request by session_middleware.
        session = getattr(request, "session", None)
        uid = session.get("uid") if session is not None else None
        if not uid:
            return response

        user = await request.app.services.user.get(uid=uid)
        if not user:
            return response

        # NOTE(review): the IP2Location package documents get_all() and the
        # record fields country_short / country_long - confirm against the
        # installed version.
        location = request.app.ip2location.get_all(ip)
        if user["city"] != location.city:
            user["country_long"] = location.country_long
            user["country_short"] = location.country_short
            user["city"] = location.city
            user["region"] = location.region
            user["latitude"] = location.latitude
            user["longitude"] = location.longitude
            user["ip"] = ip
            # NOTE(review): assumes the user service exposes update(); verify.
            await request.app.services.user.update(user)
    except Exception:
        logging.getLogger(__name__).exception("IP geolocation update failed")
    return response
@web.middleware
async def trailing_slash_middleware(request, handler):
    """Redirect requests whose path lacks a trailing slash.

    The query string is carried over to the redirect target; previously it
    was dropped.

    Args:
        request: Incoming request.
        handler: Next handler in the middleware chain.

    Returns:
        The handler's response when the path already ends with ``/``.

    Raises:
        web.HTTPFound: Redirect to the same path with ``/`` appended.
    """
    path = request.path
    if path and not path.endswith("/"):
        location = path + "/"
        if request.query_string:
            location = f"{location}?{request.query_string}"
        # Redirect to the same path with a trailing slash
        raise web.HTTPFound(location)
    return await handler(request)
class Application(BaseApplication):
def __init__(self, *args, **kwargs):
    """Build the application: middlewares, templating, services and hooks.

    Positional and keyword arguments are forwarded to the base application.
    The request body limit is 5 GiB. The earlier expression multiplied that
    number by the ``args`` tuple, so the base class received a tuple instead
    of an integer and ``args`` itself was never forwarded.

    Args:
        *args: Positional arguments for the base application.
        **kwargs: Keyword arguments for the base application.
    """
    middlewares = [
        cors_middleware,
        web.normalize_path_middleware(merge_slashes=True),
        ip2location_middleware,
        csp_middleware,
    ]
    self.template_path = pathlib.Path(__file__).parent.joinpath("templates")
    self.static_path = pathlib.Path(__file__).parent.joinpath("static")
    super().__init__(
        *args,
        middlewares=middlewares,
        template_path=self.template_path,
        client_max_size=1024 * 1024 * 1024 * 5,
        **kwargs,
    )
    session_setup(self, EncryptedCookieStorage(SESSION_KEY))
    # Queue of awaitables consumed by task_runner().
    self.tasks = asyncio.Queue()
    # Appended after session_setup so they run inside the session machinery.
    self._middlewares.append(session_middleware)
    self._middlewares.append(auth_middleware)
    self.jinja2_env.add_extension(MarkdownExtension)
    self.jinja2_env.add_extension(LinkifyExtension)
    self.jinja2_env.add_extension(PythonExtension)
    self.jinja2_env.add_extension(EmojiExtension)
    self.jinja2_env.filters["sanitize"] = sanitize_html
    # Reference point for the uptime properties.
    self.time_start = datetime.now()
    self.ssh_host = "0.0.0.0"
    self.ssh_port = 2242
    self.setup_router()
    self.ssh_server = None
    self.sync_service = None
    self.executor = None
    self.cache = Cache(self)
    self.services = get_services(app=self)
    self.mappers = get_mappers(app=self)
    self.broadcast_service = None
    self.user_availability_service_task = None
    base_path = pathlib.Path(__file__).parent
    self.ip2location = IP2Location.IP2Location(
        base_path.joinpath("IP2LOCATION-LITE-DB11.BIN")
    )
    # Startup hooks run in registration order.
    self.on_startup.append(self.prepare_asyncio)
    self.on_startup.append(self.start_user_availability_service)
    self.on_startup.append(self.start_ssh_server)
    self.on_startup.append(self.prepare_database)
@property
def uptime_seconds(self):
    """Seconds elapsed between ``self.time_start`` and now, as a float."""
    elapsed = datetime.now() - self.time_start
    return elapsed.total_seconds()
@property
def uptime(self):
    """Human-readable uptime, formatted by ``_format_uptime``."""
    elapsed_seconds = self.uptime_seconds
    return self._format_uptime(elapsed_seconds)
def _format_uptime(self, seconds):
seconds = int(seconds)
days, seconds = divmod(seconds, 86400)
hours, seconds = divmod(seconds, 3600)
minutes, seconds = divmod(seconds, 60)
parts = []
if days > 0:
parts.append(f"{days} day{'s' if days != 1 else ''}")
if hours > 0:
parts.append(f"{hours} hour{'s' if hours != 1 else ''}")
if minutes > 0:
parts.append(f"{minutes} minute{'s' if minutes != 1 else ''}")
if seconds > 0 or not parts:
parts.append(f"{seconds} second{'s' if seconds != 1 else ''}")
return ", ".join(parts)
async def start_user_availability_service(self, app):
    """Startup hook: launch the socket service's availability coroutine.

    The created task is stored on ``app.user_availability_service_task``.

    Args:
        app: Application whose socket service is started.
    """
    availability_coro = app.services.socket.user_availability_service()
    app.user_availability_service_task = asyncio.create_task(availability_coro)
async def snode_sync(self, app):
    """Start ``snode.sync_service(app)`` as a task kept in ``self.sync_service``.

    Args:
        app: Application passed through to the sync service.
    """
    sync_coro = snode.sync_service(app)
    self.sync_service = asyncio.create_task(sync_coro)
async def start_ssh_server(self, app):
    """Startup hook: start the SSH server on ``app.ssh_host:app.ssh_port``.

    The server object is stored on ``app.ssh_server``. When a server was
    created, a task awaiting its ``wait_closed()`` is started and kept on
    ``app.ssh_server_task``. Keeping the reference matters because the event
    loop only holds weak references to tasks, so an unreferenced task may be
    garbage collected before it finishes.

    Args:
        app: Application providing ``ssh_host`` and ``ssh_port``.
    """
    app.ssh_server = await start_ssh_server(app, app.ssh_host, app.ssh_port)
    app.ssh_server_task = None
    if app.ssh_server:
        app.ssh_server_task = asyncio.create_task(app.ssh_server.wait_closed())
async def prepare_asyncio(self, app):
    """Startup hook: install a 200-worker thread pool as default executor.

    The pool is stored on ``app.executor`` and registered on the running
    event loop. Earlier code wrote to ``app`` but read back from ``self`` and
    went through ``app.loop``; startup hooks run inside the loop, so the
    running loop is used directly.

    Args:
        app: Application that owns the executor.
    """
    app.executor = ThreadPoolExecutor(max_workers=200)
    asyncio.get_running_loop().set_default_executor(app.executor)
    # Signal handling for container shutdown is currently disabled:
    # for sig in (signal.SIGINT, signal.SIGTERM):
    #     app.loop.add_signal_handler(
    #         sig, lambda: asyncio.create_task(self.services.container.shutdown())
    #     )
async def create_task(self, task):
    """Put *task* on ``self.tasks``, the queue drained by ``task_runner``.

    Args:
        task: Awaitable to be run later by the background runner.
    """
    pending = self.tasks
    await pending.put(task)
async def task_runner(self):
    """Run queued awaitables forever, one at a time, each in a DB transaction.

    For every item taken from ``self.tasks`` a transaction is opened with
    ``self.db.begin()``, the item is awaited, and ``self.db.commit()`` is
    called. A failing task is logged with its traceback and does not stop
    the runner. ``task_done()`` is now signalled for failed tasks as well;
    before, a single failure left the queue's unfinished counter stuck, so
    ``self.tasks.join()`` could never return.

    NOTE(review): the transaction is committed even when the task raised
    (unchanged behaviour) - confirm whether a rollback is wanted instead.
    """
    log = logging.getLogger(__name__)
    while True:
        task = await self.tasks.get()
        self.db.begin()
        try:
            await task
        except Exception:
            log.exception("Background task failed")
        finally:
            # Always balance the get() above, whatever the outcome.
            self.tasks.task_done()
        self.db.commit()
async def prepare_database(self, app):
    """Startup hook: tune the database, ensure indexes, start the task runner.

    Each index is ensured independently and failures are logged. Previously a
    bare ``except: pass`` hid every error and one failure skipped the
    remaining indexes. The background runner task is kept on
    ``self.task_runner_task`` so it stays referenced.

    Args:
        app: Application instance supplied by the startup signal (unused).
    """
    log = logging.getLogger(__name__)
    self.db.query("PRAGMA journal_mode=WAL")
    # NOTE(review): SQLite documents a `synchronous` pragma, not `syncnorm`;
    # unknown pragmas are ignored, so this statement probably has no effect.
    # Left unchanged because the intended setting is unclear.
    self.db.query("PRAGMA syncnorm=off")

    wanted_indexes = (
        ("user", "username", {"unique": True}),
        ("channel_member", ["channel_uid", "user_uid"], {}),
        ("channel_message", ["channel_uid", "user_uid"], {}),
    )
    for table_name, columns, options in wanted_indexes:
        try:
            table = self.db[table_name]
            if not table.has_index(columns):
                table.create_index(columns, **options)
        except Exception:
            # Best effort: a missing index must not prevent startup.
            log.exception(
                "Could not ensure index on %s(%s)", table_name, columns
            )

    await self.services.drive.prepare_all()
    self.task_runner_task = asyncio.create_task(self.task_runner())
def setup_router(self):
    """Register every route, the static directory and the two sub-apps.

    Registration order is: the index route, the static directory, a first
    batch of views, four GET-only routes, a second batch of views, and
    finally the WebDAV and Git sub-applications.
    """
    router = self.router

    router.add_get("/", IndexView)
    router.add_static(
        "/",
        pathlib.Path(__file__).parent.joinpath("static"),
        name="static",
        show_index=True,
    )

    first_views = (
        ("/profiler.html", profiler_handler),
        ("/container/sock/{channel_uid}.json", ContainerView),
        ("/about.html", AboutHTMLView),
        ("/about.md", AboutMDView),
        ("/logout.json", LogoutView),
        ("/logout.html", LogoutView),
        ("/docs.html", DocsHTMLView),
        ("/docs.md", DocsMDView),
        ("/status.json", StatusView),
        ("/settings/index.html", SettingsIndexView),
        ("/settings/profile.html", SettingsProfileView),
        ("/settings/profile.json", SettingsProfileView),
        ("/push.json", PushView),
        ("/web.html", WebView),
        ("/login.html", LoginView),
        ("/login.json", LoginView),
        ("/register.html", RegisterView),
        ("/register.json", RegisterView),
        ("/drive/{rel_path:.*}", DriveView),
        ("/drive.bin", UploadView),
        ("/drive.bin/{uid}.{ext}", UploadView),
        ("/search-user.html", SearchUserView),
        ("/search-user.json", SearchUserView),
        ("/avatar/{uid}.svg", AvatarView),
    )
    for route_path, view in first_views:
        router.add_view(route_path, view)

    get_only = (
        ("/http-get", self.handle_http_get),
        ("/http-photo", self.handle_http_photo),
        ("/rpc.ws", RPCView),
        ("/c/{channel:.*}", ChannelView),
    )
    for route_path, get_handler in get_only:
        router.add_get(route_path, get_handler)

    second_views = (
        ("/channel/{channel_uid}/attachment.bin", ChannelAttachmentView),
        ("/channel/{channel_uid}/drive.json", ChannelDriveApiView),
        ("/channel/{channel_uid}/attachment.sock", ChannelAttachmentUploadView),
        ("/channel/attachment/{relative_url:.*}", ChannelAttachmentView),
        ("/channel/{channel}.html", WebView),
        ("/threads.html", ThreadsView),
        ("/terminal.ws", TerminalSocketView),
        ("/terminal.html", TerminalView),
        ("/drive.json", DriveApiView),
        ("/drive.html", DriveView),
        ("/drive/{drive}.json", DriveView),
        ("/stats.json", StatsView),
        ("/user/{user}.html", UserView),
        ("/repository/{username}/{repository}", RepositoryView),
        ("/repository/{username}/{repository}/{path:.*}", RepositoryView),
        ("/settings/repositories/index.html", RepositoriesIndexView),
        ("/settings/repositories/create.html", RepositoriesCreateView),
        (
            "/settings/repositories/repository/{name}/update.html",
            RepositoriesUpdateView,
        ),
        (
            "/settings/repositories/repository/{name}/delete.html",
            RepositoriesDeleteView,
        ),
        ("/settings/containers/index.html", ContainersIndexView),
        ("/settings/containers/create.html", ContainersCreateView),
        ("/settings/containers/container/{uid}/update.html", ContainersUpdateView),
        ("/settings/containers/container/{uid}/delete.html", ContainersDeleteView),
    )
    for route_path, view in second_views:
        router.add_view(route_path, view)

    # Sub-applications are created first, then mounted under their prefixes.
    self.webdav = WebdavApplication(self)
    self.git = GitApplication(self)
    self.add_subapp("/webdav", self.webdav)
    self.add_subapp("/git", self.git)
    # self.router.add_get("/{file_path:.*}", self.static_handler)
async def handle_test(self, request):
    """Render ``test.html`` and pass the result to ``whitelist_attributes``.

    NOTE(review): ``render_template`` is a coroutine function and its result
    is handed over without being awaited, so ``whitelist_attributes``
    receives a coroutine object - confirm this is intended.

    Args:
        request: Incoming request.

    Returns:
        The awaited result of ``whitelist_attributes``.
    """
    pending_render = self.render_template(
        "test.html", request, context={"name": "retoor"}
    )
    return await whitelist_attributes(pending_render)
async def handle_http_get(self, request: web.Request):
    """Fetch the page named by the ``url`` query parameter and return its body.

    The parameter is now required and must use the http or https scheme;
    previously a missing parameter passed ``None`` to the fetcher.

    NOTE(review): this endpoint fetches caller-supplied URLs from the server
    (SSRF risk). The scheme check does not stop requests to internal hosts -
    consider an allow-list or address filtering.

    Args:
        request: Incoming request carrying the ``url`` query parameter.

    Returns:
        A response whose body is the fetched content.

    Raises:
        web.HTTPBadRequest: ``url`` is missing or not an http(s) URL.
    """
    url = request.query.get("url", "")
    if not url.lower().startswith(("http://", "https://")):
        raise web.HTTPBadRequest(text="An http(s) 'url' query parameter is required.")
    content = await http.get(url)
    return web.Response(body=content)
async def handle_http_photo(self, request):
    """Return a PNG produced by ``http.create_site_photo`` for the given URL.

    The ``url`` query parameter is now required and must use the http or
    https scheme; previously a missing parameter passed ``None`` along.

    NOTE(review): like ``/http-get`` this acts on caller-supplied URLs from
    the server (SSRF risk) - consider restricting the allowed targets.

    Args:
        request: Incoming request carrying the ``url`` query parameter.

    Returns:
        A response with the image bytes and an ``image/png`` content type.

    Raises:
        web.HTTPBadRequest: ``url`` is missing or not an http(s) URL.
    """
    url = request.query.get("url", "")
    if not url.lower().startswith(("http://", "https://")):
        raise web.HTTPBadRequest(text="An http(s) 'url' query parameter is required.")
    path = await http.create_site_photo(url)
    return web.Response(
        body=path.read_bytes(), headers={"Content-Type": "image/png"}
    )
# @time_cache_async(60)
async def _subscribed_channels(self, uid):
    """Build the channel entries for *uid*, most recent activity first."""
    channels = []
    async for membership in self.services.channel_member.find(
        user_uid=uid, deleted_at=None, is_banned=False
    ):
        channel = await membership.get_channel()
        other_user = await self.services.channel_member.get_other_dm_user(
            membership["channel_uid"], uid
        )
        last_message = await channel.get_last_message()
        color = None
        if last_message:
            last_message_user = await last_message.get_user()
            color = last_message_user["color"]
        channels.append(
            {
                "color": color,
                "last_message_on": channel["last_message_on"],
                "is_private": channel["tag"] == "dm",
                # When another DM user is found, their nick is the name.
                "name": other_user["nick"] if other_user else membership["label"],
                "uid": membership["channel_uid"],
                "new_count": membership["new_count"],
            }
        )
    channels.sort(key=lambda x: x["last_message_on"] or "", reverse=True)
    return channels

async def render_template(self, template, request, context=None):
    """Render *template* with the standard context and per-user template paths.

    Adds ``rid`` (a fresh UUID), ``channels``, ``user`` and ``nonce`` to the
    context; ``channels`` and ``user`` are only filled in when the caller did
    not provide them. The channel list is no longer computed when the caller
    already supplied one. While rendering, the Jinja loader is replaced by
    the one from ``get_user_template_loader`` and is now restored in a
    ``finally`` block, so a failing render no longer leaves the per-user
    loader installed.

    NOTE(review): the loader lives on the shared ``jinja2_env``; overlapping
    renders can still observe each other's loader - confirm whether renders
    can interleave.

    Args:
        template: Template name passed to the base ``render_template``.
        request: Current request; ``request.session`` must be available.
        context: Optional dict of template variables.

    Returns:
        Whatever the base class ``render_template`` returns.
    """
    if not context:
        context = {}
    context["rid"] = str(uuid.uuid4())
    uid = request.session.get("uid")

    if "channels" not in context:
        context["channels"] = await self._subscribed_channels(uid) if uid else []
    if "user" not in context:
        context["user"] = await self.services.user.get(uid)

    # Result unused; call retained in case the service has side effects.
    # NOTE(review): confirm and drop if it is a pure lookup.
    await self.services.user.get_template_path(uid)

    try:
        context["nonce"] = request["csp_nonce"]
    except (KeyError, TypeError):
        # No nonce stored on this request; templates get a placeholder.
        context["nonce"] = "?"

    original_loader = self.jinja2_env.loader
    # Attribute kept for backward compatibility with earlier versions.
    self.original_loader = original_loader
    self.jinja2_env.loader = await self.get_user_template_loader(uid)
    try:
        rendered = await super().render_template(template, request, context)
    finally:
        self.jinja2_env.loader = original_loader
    # rendered.text = whitelist_attributes(rendered.text)
    # rendered.headers['Content-Lenght'] = len(rendered.text)
    return rendered
async def static_handler(self, request):
    """Serve a static file from user, admin or application static directories.

    Directories are searched in this order: the logged-in user's static
    path, each admin's static path, then ``self.static_path``. The requested
    name comes from the URL, so every candidate is resolved and must stay
    inside its directory. Previously ``..`` segments or an absolute path
    could escape it. Only regular files are served.

    Args:
        request: Incoming request; the relative file name is read from the
            ``filename`` match-info key, falling back to ``file_path`` (the
            name used by the commented-out catch-all route).

    Returns:
        A ``web.FileResponse`` for the first match, otherwise
        ``web.HTTPNotFound``.
    """
    file_name = request.match_info.get("filename") or request.match_info.get(
        "file_path", ""
    )
    search_roots = []
    uid = request.session.get("uid")
    if uid:
        user_static_path = await self.services.user.get_static_path(uid)
        if user_static_path:
            search_roots.append(user_static_path)
    for admin_uid in self.services.user.get_admin_uids():
        user_static_path = await self.services.user.get_static_path(admin_uid)
        if user_static_path:
            search_roots.append(user_static_path)
    search_roots.append(self.static_path)

    for root in search_roots:
        resolved_root = pathlib.Path(root).resolve()
        candidate = resolved_root.joinpath(file_name).resolve()
        try:
            # Raises ValueError when the candidate escapes the directory.
            candidate.relative_to(resolved_root)
        except ValueError:
            continue
        if candidate.is_file():
            return web.FileResponse(candidate)
    return web.HTTPNotFound()
async def get_user_template_loader(self, uid=None):
    """Create a ``FileSystemLoader`` over admin, user and built-in templates.

    Search order: template paths of every admin uid, then the path of *uid*
    (when given), then ``self.template_path``. Falsy paths are skipped.

    Args:
        uid: Optional user id whose template path is included.

    Returns:
        A ``FileSystemLoader`` configured with the collected paths.
    """
    search_paths = []

    async def _collect(owner_uid):
        candidate = await self.services.user.get_template_path(owner_uid)
        if candidate:
            search_paths.append(candidate)

    for admin_uid in self.services.user.get_admin_uids():
        await _collect(admin_uid)
    if uid:
        await _collect(uid)
    search_paths.append(self.template_path)
    return FileSystemLoader(search_paths)
@asynccontextmanager
async def no_save(self):
    """Temporarily replace the channel-message mapper's ``save``.

    Inside the context, ``save`` stores the model in ``self.cache`` under its
    ``uid``, prints how many saves were ignored so far, and returns the model
    without calling the original ``save``. The original method is put back on
    exit, and any exception raised inside the context propagates.
    """
    counter = {"count": 0}

    async def patched_save(*args, **kwargs):
        model = args[0]
        await self.cache.set(model["uid"], model)
        counter["count"] += 1
        print(f"save is ignored {counter['count']} times")
        return model

    save_original = self.services.channel_message.mapper.save
    self.services.channel_message.mapper.save = patched_save
    try:
        yield
    finally:
        # Runs on both normal exit and error, then the error continues up.
        self.services.channel_message.mapper.save = save_original
# Module-level application instance, constructed with the database URL
# "sqlite:///snek.db"; main() below serves this object.
app = Application(db_path="sqlite:///snek.db")
async def main():
    """Serve the module-level ``app`` over TLS on 0.0.0.0:8081.

    The certificate and key are loaded from ``cert.pem`` and ``key.pem``,
    given as relative paths.
    NOTE(review): ``web._run_app`` is an underscore-prefixed aiohttp helper.
    """
    tls_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    tls_context.load_cert_chain("cert.pem", "key.pem")
    await web._run_app(app, host="0.0.0.0", port=8081, ssl_context=tls_context)
# Run the TLS server only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())