|
import argparse
|
|
import base64
|
|
import json
|
|
import pathlib
|
|
import time
|
|
import uuid
|
|
|
|
import aiohttp_jinja2
|
|
import dataset
|
|
import jinja2
|
|
from aiohttp import web
|
|
|
|
from app.agent import Agent
|
|
from app.rpc import Application as RPCApplication
|
|
|
|
from . import log
|
|
|
|
|
|
def get_timestamp():
    """Return the current local (naive) time as ``YYYY-MM-DD HH:MM:SS``."""
    from datetime import datetime

    return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
|
|
class BaseView(web.View):
    """Class-based view with shortcuts to the owning application."""

    @property
    def app(self):
        """The application that received the current request."""
        return self.request.app

    @property
    def template_path(self):
        """The application's template directory, as a ``pathlib.Path``."""
        configured = self.request.app.template_path
        return pathlib.Path(configured)

    async def render_template(self, name, context=None):
        """Render template *name* through the application.

        A falsy *context* (``None``, ``{}``) is replaced by a fresh empty dict.
        *name* is converted with ``str`` so ``Path`` objects are accepted.
        """
        context = context or {}
        owner = self.request.app
        return await owner.render_template(str(name), self.request, context)
|
|
|
|
|
|
class BaseApplication(RPCApplication):
    """RPC/aiohttp application with optional HTTP Basic auth, cookie-keyed
    in-memory sessions, per-request access logging and Jinja2 templates.

    Args:
        basic_username: Username required by HTTP Basic auth. When falsy,
            authentication is disabled entirely.
        basic_password: Password paired with ``basic_username``.
        cookie_name: Name of the session cookie; a random UUID when omitted.
        session: Mapping of session id -> session dict. A caller-supplied
            mapping (even an empty one) is used as-is so it can be shared.
        template_path: Jinja2 template directory; defaults to ``templates``
            next to this module.
        *args, **kwargs: Forwarded to ``RPCApplication``. A ``middlewares``
            keyword is copied and extended with this class's middlewares.
    """

    def __init__(
        self,
        basic_username=None,
        basic_password=None,
        cookie_name=None,
        session=None,
        template_path=None,
        *args,
        **kwargs,
    ):
        self.cookie_name = cookie_name or str(uuid.uuid4())
        self.basic_username = basic_username
        self.basic_password = basic_password
        self.session = session if session is not None else {}
        # Copy instead of appending to the caller's own list (or failing on a tuple).
        middlewares = list(kwargs.pop("middlewares", None) or [])
        middlewares.append(self.request_middleware)
        middlewares.append(self.base64_auth_middleware)
        middlewares.append(self.session_middleware)
        self.template_path = template_path or pathlib.Path(__file__).parent.joinpath(
            "templates"
        )
        # Maps thread id (str) -> Agent, populated by agent_create_thread().
        self.agents = {}
        super().__init__(*args, middlewares=middlewares, **kwargs)
        self.jinja2_env = aiohttp_jinja2.setup(
            self, loader=jinja2.FileSystemLoader(self.template_path)
        )

    def run(self, *args, **kwargs):
        """Serve the app with ``web.run_app``.

        When a ``port`` is given without a ``host``, binds to 127.0.0.1.
        """
        if kwargs.get("port") and not kwargs.get("host"):
            kwargs["host"] = "127.0.0.1"
        web.run_app(self, *args, **kwargs)

    @staticmethod
    def _credentials_match(expected, given):
        """Compare one credential, in constant time when both are strings."""
        import hmac

        if not isinstance(expected, str) or not isinstance(given, str):
            return expected == given
        return hmac.compare_digest(expected.encode("utf-8"), given.encode("utf-8"))

    async def authenticate(self, username, password):
        """Return True when *username*/*password* match the configured pair."""
        # Evaluate both comparisons so timing does not reveal which one failed.
        username_ok = self._credentials_match(self.basic_username, username)
        password_ok = self._credentials_match(self.basic_password, password)
        return bool(username_ok and password_ok)

    async def agent_create_thread(self, api_key, assistent_id):
        """Create an Agent, register it under its thread id and return that id."""
        agent = Agent(api_key, assistent_id)
        self.agents[str(agent.thread.id)] = agent
        return str(agent.thread.id)

    async def rpc_agent_create_thread(self, api_key, assistent_id):
        """RPC entry point for agent_create_thread()."""
        return await self.agent_create_thread(api_key, assistent_id)

    async def agent_prompt(self, thread_id, message):
        """Send *message* to the agent registered for *thread_id*.

        Any error (unknown thread id, failure inside ``agent.chat``) is
        returned as its string form instead of being raised.
        """
        try:
            agent = self.agents[str(thread_id)]
            return await agent.chat(message)
        except Exception as ex:
            return str(ex)

    async def rpc_agent_prompt(self, thread_id, message):
        """RPC entry point for agent_prompt()."""
        return await self.agent_prompt(str(thread_id), message)

    @web.middleware
    async def base64_auth_middleware(self, request, handler):
        """Enforce HTTP Basic auth when ``basic_username`` is configured.

        Responds 401 (with a WWW-Authenticate challenge) for a missing,
        non-Basic or wrong credential header, and 400 when the header cannot
        be decoded into ``username:password``.
        """
        if not self.basic_username:
            return await handler(request)

        auth_header = request.headers.get("Authorization")
        # The auth-scheme token is case-insensitive (RFC 7235 section 2.1).
        if not auth_header or auth_header[:6].lower() != "basic ":
            return web.Response(
                status=401,
                text="Unauthorized",
                headers={"WWW-Authenticate": 'Basic realm="Restricted"'},
            )

        try:
            decoded_credentials = base64.b64decode(auth_header[6:]).decode("utf-8")
            username, password = decoded_credentials.split(":", 1)
        except ValueError:
            # binascii.Error and UnicodeDecodeError both subclass ValueError.
            return web.Response(status=400, text="Invalid Authorization Header")

        if not await self.authenticate(username, password):
            return web.Response(
                status=401,
                text="Invalid Credentials",
                headers={"WWW-Authenticate": 'Basic realm="Restricted"'},
            )

        return await handler(request)

    async def render_template(self, name, request=None, context=None):
        """Render *name* with aiohttp_jinja2; a None context becomes ``{}``."""
        return aiohttp_jinja2.render_template(
            name, request, context if context is not None else {}
        )

    @web.middleware
    async def request_middleware(self, request: web.Request, handler):
        """Record each successfully handled request in the ``http_access`` table.

        Rows hold the start timestamp, the request path and the handler
        duration in seconds. Requests whose handler raises are not recorded.
        """
        started = time.perf_counter()  # monotonic, unlike time.time()
        created = get_timestamp()
        response = await handler(request)
        duration = time.perf_counter() - started
        # insert() is provided by storage-backed subclasses (e.g. WebDbApplication);
        # skip logging instead of failing every request when it is absent.
        insert = getattr(self, "insert", None)
        if insert is not None:
            await insert(
                "http_access",
                {
                    "created": created,
                    "path": request.path,
                    "duration": duration,
                },
            )
        return response

    @web.middleware
    async def session_middleware(self, request, handler):
        """Attach a per-client session dict to ``request.session``.

        The session id comes from the ``cookie_name`` cookie. A new id is
        generated and sent back (1 hour, HttpOnly) when the cookie is absent.
        Sessions that are non-empty after the handler runs are stored in
        ``self.session`` so later requests with the same cookie see them.
        NOTE(review): the store is in-memory and never expires entries.
        """
        session_id = request.cookies.get(self.cookie_name)
        is_new = not session_id
        if is_new:
            session_id = str(uuid.uuid4())

        session = self.session.get(session_id)
        if session is None:
            session = {}
        # Kept as an attribute (not request["session"]) for existing callers.
        setattr(request, "session", session)

        response = await handler(request)

        if session:
            self.session[session_id] = session
        if is_new:
            response.set_cookie(self.cookie_name, session_id, max_age=3600, httponly=True)
        return response
|
|
|
|
|
|
class WebDbApplication(BaseApplication):
    """BaseApplication backed by a ``dataset`` database.

    Provides async CRUD helpers plus a JSON key/value store (table ``kv``).
    When ``db_web`` is true, the helpers are also exposed as POST endpoints
    under ``/db/*`` and as ``rpc_*`` attributes.

    Args:
        db: An existing dataset database; when None one is opened from
            ``db_path``.
        db_web: Whether to expose the ``/db/*`` routes and ``rpc_*`` aliases.
        db_path: Database URL passed to ``dataset.connect``.
    """

    def __init__(
        self, db=None, db_web=False, db_path="sqlite:///:memory:", *args, **kwargs
    ):
        self.db_web = db_web
        self.db_path = db_path
        # `is not None` so a db object that happens to be falsy is still used.
        self.db = db if db is not None else dataset.connect(self.db_path)

        super().__init__(*args, **kwargs)
        if not self.db_web:
            return
        self.router.add_post("/db/insert", self.insert_handler)
        self.router.add_post("/db/update", self.update_handler)
        self.router.add_post("/db/upsert", self.upsert_handler)
        self.router.add_post("/db/find", self.find_handler)
        self.router.add_post("/db/find_one", self.find_one_handler)
        self.router.add_post("/db/delete", self.delete_handler)
        self.router.add_post("/db/get", self.get_handler)
        self.router.add_post("/db/set", self.set_handler)
        self.rpc_set = self.set
        self.rpc_get = self.get
        self.rpc_insert = self.insert
        self.rpc_update = self.update
        self.rpc_upsert = self.upsert
        self.rpc_find = self.find
        self.rpc_find_one = self.find_one
        # Misspelled alias kept for backward compatibility.
        self.rpc_fine_one = self.find_one
        self.rpc_delete = self.delete

    @staticmethod
    async def _read_json(request):
        """Return the request body as a JSON object, or raise HTTP 400."""
        try:
            obj = await request.json()
        except ValueError:  # JSONDecodeError / UnicodeDecodeError
            raise web.HTTPBadRequest(text="Request body must be valid JSON") from None
        if not isinstance(obj, dict):
            raise web.HTTPBadRequest(text="Request body must be a JSON object")
        return obj

    async def set_handler(self, request):
        """POST /db/set with body ``{"key", "value"}``."""
        obj = await self._read_json(request)
        response = await self.set(obj.get("key"), obj.get("value"))
        return web.json_response(response)

    async def get_handler(self, request):
        """POST /db/get with body ``{"key"}``."""
        obj = await self._read_json(request)
        response = await self.get(obj.get("key"), None)
        return web.json_response(response)

    async def insert_handler(self, request):
        """POST /db/insert with body ``{"table", "data"}``."""
        obj = await self._read_json(request)
        response = await self.insert(obj.get("table"), obj.get("data"))
        return web.json_response(response)

    async def update_handler(self, request):
        """POST /db/update with body ``{"table", "data", "where"}``."""
        obj = await self._read_json(request)
        response = await self.update(
            obj.get("table"), obj.get("data"), obj.get("where", {})
        )
        return web.json_response(response)

    async def upsert_handler(self, request):
        """POST /db/upsert with body ``{"table", "data", "keys"}``."""
        obj = await self._read_json(request)
        response = await self.upsert(
            obj.get("table"), obj.get("data"), obj.get("keys", [])
        )
        return web.json_response(response)

    async def find_handler(self, request):
        """POST /db/find with body ``{"table", "where"}``."""
        obj = await self._read_json(request)
        response = await self.find(obj.get("table"), obj.get("where", {}))
        return web.json_response(response)

    async def find_one_handler(self, request):
        """POST /db/find_one with body ``{"table", "where"}``."""
        obj = await self._read_json(request)
        response = await self.find_one(obj.get("table"), obj.get("where", {}))
        return web.json_response(response)

    async def delete_handler(self, request):
        """POST /db/delete with body ``{"table", "where"}``.

        NOTE(review): an empty/missing ``where`` deletes every row of the table.
        """
        obj = await self._read_json(request)
        response = await self.delete(obj.get("table"), obj.get("where", {}))
        return web.json_response(response)

    async def set(self, key, value):
        """Async wrapper around sset()."""
        return self.sset(key, value)

    def sset(self, key, value):
        """Store *value* as JSON under *key* in the ``kv`` table (upsert on key)."""
        value = json.dumps(value, default=str)
        return self.db["kv"].upsert({"key": key, "value": value}, ["key"])

    async def get(self, key, default=None):
        """Async wrapper around sget()."""
        return self.sget(key, default)

    def sget(self, key, default=None):
        """Return the JSON-decoded value stored under *key*.

        Returns *default* when the key is missing or the stored value is
        NULL or JSON ``null`` (including falsy defaults such as 0).
        """
        record = self.db["kv"].find_one(key=key)
        if not record:
            return default
        raw = record.get("value")
        if raw is None or raw == "null":
            return default
        return json.loads(raw)

    async def insert(self, table_name, data):
        """Insert *data* (a dict) into *table_name*; returns dataset's result."""
        return self.db[table_name].insert(data)

    async def update(self, table_name, data, where=None):
        """Update rows of *table_name* with *data*.

        *where* may be a dict of ``column: value`` filters. Its values are
        merged into the row, so filtered columns cannot themselves be changed.
        Otherwise *where* is a list/str of column names whose values are
        taken from *data*, which is dataset's native ``keys`` form.

        Raises:
            ValueError: if *where* is an empty dict, to avoid an unfiltered
                update of every row.
        """
        table = self.db[table_name]
        if isinstance(where, dict):
            if not where:
                raise ValueError("update() requires a non-empty 'where' filter")
            # dataset filters on the key columns, using the values found in the row.
            row = dict(data)
            row.update(where)
            return table.update(row, list(where))
        return table.update(data, where or {})

    async def upsert(self, table_name, data, keys=None):
        """Update the row matching *keys* in *data*, or insert it.

        With no *keys* there is nothing to match on. The row is then inserted
        rather than running an unfiltered update over every row.
        """
        table = self.db[table_name]
        if not keys:
            return table.insert(data)
        return table.upsert(data, keys)

    async def find(self, table_name, filters=None):
        """Return all rows of *table_name* matching *filters*, as plain dicts."""
        filters = filters or {}
        return [dict(record) for record in self.db[table_name].find(**filters)]

    async def find_one(self, table_name, filters=None):
        """Return the first row matching *filters* as a dict, or None.

        None is also returned when dataset rejects the table or column names
        with ValueError.
        """
        filters = filters or {}
        try:
            record = self.db[table_name].find_one(**filters)
        except ValueError:
            return None
        # dataset returns None when nothing matches; dict(None) would raise.
        return None if record is None else dict(record)

    async def delete(self, table_name, where=None):
        """Delete rows of *table_name* matching *where* (all rows when empty)."""
        where = where or {}
        return self.db[table_name].delete(**where)
|
|
|
|
|
|
class Application(WebDbApplication):
    """Concrete service application.

    Adds a ``GET /stat`` endpoint that reports a persisted request counter,
    the current timestamp, uptime and the time the app started serving.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.router.add_get("/stat", self.index_handler)
        self.on_startup.append(self.on_startup_task)
        # NOTE(review): not incremented by this class; /stat uses the kv store instead.
        self.request_count = 0
        self.time_started = time.time()
        # Filled in by on_startup_task() once the app starts.
        self.running_since = None

    @property
    def uptime(self):
        """Seconds elapsed since this instance was constructed."""
        started = self.time_started
        return time.time() - started

    async def on_startup_task(self, app):
        """Startup hook: log and record the start timestamp."""
        log.debug("App starting.")
        self.running_since = get_timestamp()

    async def inc_request_count(self):
        """Increment the persisted ``root_request_count`` and return the new value."""
        counter_key = "root_request_count"
        updated = await self.get(counter_key, 0) + 1
        await self.set(counter_key, updated)
        return updated

    async def index_handler(self, request):
        """GET /stat: return service statistics as JSON."""
        payload = {}
        payload["request_count"] = await self.inc_request_count()
        payload["timestamp"] = get_timestamp()
        payload["uptime"] = self.uptime
        payload["running_since"] = self.running_since
        return web.json_response(payload, content_type="application/json")
|
|
|
|
|
|
# Command-line options for create_app(); "Web service" is the program name.
argument_parser = argparse.ArgumentParser("Web service")
argument_parser.add_argument(
    "--host", type=str, required=False, default="0.0.0.0", help="Host to serve on."
)
argument_parser.add_argument(
    "--port", type=int, required=False, default=8888, help="Port to serve on."
)
argument_parser.add_argument(
    "--db-path",
    type=str,
    required=False,
    default="sqlite:///:memory:",
    help="SQLAlchemy db url. (e.g. sqlite:///app.db)",
)
argument_parser.add_argument(
    "--basic-username",
    type=str,
    required=False,
    default=None,
    help="Basic Auth username.",
)
argument_parser.add_argument(
    "--basic-password",
    type=str,
    required=False,
    default=None,
    help="Basic Auth password.",
)
argument_parser.add_argument(
    "--db-web", default=False, action="store_true", help="Enable /db/* endpoints"
)
|
|
|
|
|
|
def create_app(*args, **kwargs):
    """Build an ``Application`` configured from command-line options.

    The original implementation called ``create_app`` recursively and never
    returned. It now constructs ``Application`` directly.

    Args:
        *args: If given, ``args[0]`` is the argv list to parse instead of
            ``sys.argv``. ``python -m aiohttp.web module:create_app`` passes
            its unparsed arguments this way. Other positional arguments are
            ignored.
        **kwargs: Override the parsed options and are forwarded to
            ``Application``.

    Returns:
        The configured ``Application`` instance.
    """
    argv = args[0] if args else None
    options = argument_parser.parse_args(argv)
    config = {
        "db_path": options.db_path,
        "db_web": options.db_web,
        "basic_username": options.basic_username,
        "basic_password": options.basic_password,
    }
    config.update(kwargs)
    return Application(**config)
|
|
|
|
|
|
def main():
    """Create the application from ``sys.argv`` and return it (it is not started).

    NOTE(review): ``--host``/``--port`` are parsed but unused here; the caller
    presumably runs the returned app itself — confirm.
    """
    return create_app()
|