diff --git a/src/llmbox/__init__.py b/src/llmbox/__init__.py index a3c1390..a2d757f 100644 --- a/src/llmbox/__init__.py +++ b/src/llmbox/__init__.py @@ -1,7 +1,5 @@ -import logging +import logging -logging.basicConfig( - level=logging.INFO -) +logging.basicConfig(level=logging.INFO) log = logging.getLogger(__name__) diff --git a/src/llmbox/__main__.py b/src/llmbox/__main__.py index b6f5b76..6237291 100644 --- a/src/llmbox/__main__.py +++ b/src/llmbox/__main__.py @@ -1,6 +1,8 @@ -from llmbox.app import Application import argparse +from llmbox.app import Application + + def parse_args(): parser = argparse.ArgumentParser(description="LLMBox LLM frontend.") @@ -10,35 +12,35 @@ def parse_args(): help="Host to server on. Default: 127.0.0.1.", required=False, default="127.0.0.1", - type=str + type=str, ) parser.add_argument( "--port", help="Port to serve on. Default: 3020.", default=3020, required=False, - type=int + type=int, ) parser.add_argument( "--llm-name", help="Name for the build of your LLM.", default="llmbox", required=False, - type=str + type=str, ) parser.add_argument( "--llm-extends", help="Name of LLM to extend from, your basis. Must be loaded in Katya LLM server already.", - default="qwen:0.5b", + default="qwen2.5:0.5b", required=False, - type=str + type=str, ) parser.add_argument( "--llm-system", help="Path to text file with LLM system messages. The context of your LLM will be described here.", required=False, type=str, - default="Be a bot named LLMBox written by retoor." + default="Be a bot named LLMBox written by retoor.", ) return parser.parse_args() @@ -46,10 +48,11 @@ def parse_args(): def run(): args = parse_args() - app = Application(llm_name=args.llm_name, llm_extends=args.llm_extends, llm_system=args.llm_system) + app = Application( + llm_name=args.llm_name, llm_extends=args.llm_extends, llm_system=args.llm_system + ) app.run(host=args.host, port=args.port) -if __name__ == '__main__': +if __name__ == "__main__": run() - diff --git a/src/llmbox/app.py b/src/llmbox/app.py index 9f662a9..dcf14f5 100644 --- a/src/llmbox/app.py +++ b/src/llmbox/app.py @@ -1,39 +1,91 @@ -from app.app import Application as BaseApplication -from yura.client import AsyncClient -from llmbox import log -import pathlib +import json +import pathlib + +import aiohttp from aiohttp import web +from app.app import Application as BaseApplication +from yura.client import AsyncClient + +from llmbox import log + class Application(BaseApplication): + def __init__( + self, + llm_name, + llm_extends, + llm_system, + server_url="wss://flock.molodetz.nl", + *args, + **kwargs, + ): - def __init__(self, llm_name, llm_extends, llm_system,server_url="https://flock.molodetz.nl", *args, **kwargs): - self.server_url = server_url - self.client = AsyncClient(self.server_url) - - log.info("Server url: {}".format(server_url)) - log.info("LLM_name: {}".format(llm_name)) - log.info("LLM extends: {}".format(llm_extends)) + self.llm_client = AsyncClient(self.server_url) - self.llm_name = llm_name + log.info(f"Server url: {server_url}") + log.info(f"LLM_name: {llm_name}") + log.info(f"LLM extends: {llm_extends}") + + self.llm_name = llm_name self.llm_extends = llm_extends - self.llm_system = llm_system - + self.llm_system = llm_system + if pathlib.Path(self.llm_system).exists(): self.llm_system = pathlib.Path(self.llm_system).read_text() - log.info("LLM system: {}".format(self.llm_system)) + log.info(f"LLM system: {self.llm_system}") else: - log.info("LLM system: {}".format(llm_system)) - + log.info(f"LLM system: {llm_system}") + self.static_path = pathlib.Path(__file__).parent.joinpath("static") + self.llm_initialized = False + super().__init__(*args, **kwargs) - - self.router.add_get("/", self.handle_index) - + + self.add_routes([web.get("/ws/", self.ws_handler)]) + + # self.router.add_get("/ws", self.ws_handler) + self.router.add_get("/", self.handle_index) + self.router.add_static("/", self.static_path) async def handle_index(self, request): index_content = self.static_path.joinpath("index.html").read_text() return web.Response(text=index_content, content_type="text/html") + async def prepare_llm(self): + if not self.llm_initialized: + self.llm_initialized = await self.llm_client.create( + self.llm_name, self.llm_extends, self.llm_system + ) + + async def ws_handler(self, request): + + ws = web.WebSocketResponse() + await ws.prepare(request) + + await self.prepare_llm() + + token = await self.llm_client.connect(self.llm_name) + + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + if msg.data == "close": + await ws.close() + else: + prompt = json.loads(msg.data).get("prompt") + log.info(f"Received prompt: {prompt}") + async for message_chunk in self.llm_client.chat( + token=token, message=prompt + ): + log.info( + "Received from LLM: {}".format(message_chunk["content"]) + ) + await ws.send_str(json.dumps(message_chunk)) + elif msg.type == aiohttp.WSMsgType.ERROR: + log.warning(f"ws connection closed with exception {ws.exception()}") + + log.info("websocket connection closed") + + return ws diff --git a/src/llmbox/static/app.js b/src/llmbox/static/app.js index f355794..9848137 100644 --- a/src/llmbox/static/app.js +++ b/src/llmbox/static/app.js @@ -2,6 +2,7 @@ class App { url = null socket = null newMessage = null + isNewMessage = false constructor(){ this.url = window.location.href.replace(/^http/, 'ws') + 'ws/' @@ -22,8 +23,11 @@ class App { newMessage.classList.add('message') newMessage.classList.add('bot') newMessage.classList.add('botmessage') - messageList.appendChild(newMessage) + newMessage.innerText = "*thinking*" + messageList.appendChild(newMessage) this.newMessage = newMessage + this.isNewMessage = true + messageList.scrollTop = messageList.scrollHeight } createNewUserMessage(msg){ const messageList = document.querySelector('.message-list') @@ -51,15 +55,17 @@ class App { onMessage(event){ const messageList = document.querySelector('.message-list') let obj = JSON.parse(event.data) - if(typeof(obj) == 'string'){ - this.newMessage.innerText += obj - - messageList.scrollTop = messageList.scrollHeight - - }else if (typeof(obj) == 'object' && obj['done']){ + if(this.isNewMessage){ + this.newMessage.innerText = obj['content'] + this.isNewMessage = false + }else{ + this.newMessage.innerText += obj['content'] + } + messageList.scrollTop = messageList.scrollHeight + if (obj['done']){ this.newMessage = null; - messageList.scrollTop = messageList.scrollHeight - } + this.isNewMessage = true + } } }