Nice working version.

This commit is contained in:
retoor 2024-12-08 01:57:02 +01:00
parent 507bd9fb2c
commit 7a5f4da67c
4 changed files with 102 additions and 43 deletions

View File

@@ -1,7 +1,5 @@
# Package init: configure logging once at import time so every llmbox
# module shares the same setup.
import logging

# Root logger at INFO level; submodules inherit this configuration.
logging.basicConfig(level=logging.INFO)

# Package-wide logger, imported elsewhere as `from llmbox import log`.
log = logging.getLogger(__name__)

View File

@@ -1,6 +1,8 @@
from llmbox.app import Application
import argparse import argparse
from llmbox.app import Application
def parse_args(argv=None):
    """Parse command-line options for the LLMBox server.

    Args:
        argv: Optional list of argument strings. Defaults to None, which
            makes argparse fall back to ``sys.argv[1:]`` — existing callers
            that pass nothing keep their behavior; tests can inject args.

    Returns:
        argparse.Namespace with ``host``, ``port``, ``llm_name``,
        ``llm_extends`` and ``llm_system`` attributes.
    """
    parser = argparse.ArgumentParser(description="LLMBox LLM frontend.")
    parser.add_argument(
        "--host",
        # typo fix: "Host to server on" -> "Host to serve on"
        help="Host to serve on. Default: 127.0.0.1.",
        required=False,
        default="127.0.0.1",
        type=str,
    )
    parser.add_argument(
        "--port",
        help="Port to serve on. Default: 3020.",
        default=3020,
        required=False,
        type=int,
    )
    parser.add_argument(
        "--llm-name",
        help="Name for the build of your LLM.",
        default="llmbox",
        required=False,
        type=str,
    )
    parser.add_argument(
        "--llm-extends",
        help="Name of LLM to extend from, your basis. Must be loaded in Katya LLM server already.",
        default="qwen2.5:0.5b",
        required=False,
        type=str,
    )
    parser.add_argument(
        "--llm-system",
        help="Path to text file with LLM system messages. The context of your LLM will be described here.",
        required=False,
        type=str,
        default="Be a bot named LLMBox written by retoor.",
    )
    return parser.parse_args(argv)
@@ -46,10 +48,11 @@ def parse_args():
def run():
    """CLI entry point: build the application from parsed options and serve it."""
    options = parse_args()
    application = Application(
        llm_name=options.llm_name,
        llm_extends=options.llm_extends,
        llm_system=options.llm_system,
    )
    application.run(host=options.host, port=options.port)


if __name__ == "__main__":
    run()

View File

@@ -1,20 +1,32 @@
import json
import pathlib
import aiohttp
from aiohttp import web
from app.app import Application as BaseApplication from app.app import Application as BaseApplication
from yura.client import AsyncClient from yura.client import AsyncClient
from llmbox import log from llmbox import log
import pathlib
from aiohttp import web
class Application(BaseApplication): class Application(BaseApplication):
def __init__(
def __init__(self, llm_name, llm_extends, llm_system,server_url="https://flock.molodetz.nl", *args, **kwargs): self,
llm_name,
llm_extends,
llm_system,
server_url="wss://flock.molodetz.nl",
*args,
**kwargs,
):
self.server_url = server_url self.server_url = server_url
self.client = AsyncClient(self.server_url) self.llm_client = AsyncClient(self.server_url)
log.info("Server url: {}".format(server_url)) log.info(f"Server url: {server_url}")
log.info("LLM_name: {}".format(llm_name)) log.info(f"LLM_name: {llm_name}")
log.info("LLM extends: {}".format(llm_extends)) log.info(f"LLM extends: {llm_extends}")
self.llm_name = llm_name self.llm_name = llm_name
self.llm_extends = llm_extends self.llm_extends = llm_extends
@@ -22,18 +34,58 @@ class Application(BaseApplication):
if pathlib.Path(self.llm_system).exists(): if pathlib.Path(self.llm_system).exists():
self.llm_system = pathlib.Path(self.llm_system).read_text() self.llm_system = pathlib.Path(self.llm_system).read_text()
log.info("LLM system: {}".format(self.llm_system)) log.info(f"LLM system: {self.llm_system}")
else: else:
log.info("LLM system: {}".format(llm_system)) log.info(f"LLM system: {llm_system}")
self.static_path = pathlib.Path(__file__).parent.joinpath("static") self.static_path = pathlib.Path(__file__).parent.joinpath("static")
self.llm_initialized = False
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.router.add_get("/", self.handle_index) self.add_routes([web.get("/ws/", self.ws_handler)])
# self.router.add_get("/ws", self.ws_handler)
self.router.add_get("/", self.handle_index)
self.router.add_static("/", self.static_path)
async def handle_index(self, request):
    """Serve the SPA landing page (static/index.html) as text/html."""
    page = (self.static_path / "index.html").read_text()
    return web.Response(text=page, content_type="text/html")
async def prepare_llm(self):
    """Create the LLM build on the backing server, at most once.

    Guard clause: once ``llm_initialized`` is truthy, later calls return
    immediately without touching the client again.
    """
    if self.llm_initialized:
        return
    self.llm_initialized = await self.llm_client.create(
        self.llm_name, self.llm_extends, self.llm_system
    )
async def ws_handler(self, request):
    """Bridge a browser websocket to the LLM backend.

    Lazily creates the LLM (prepare_llm), opens a chat session, then for
    every TEXT frame containing {"prompt": ...} streams the LLM's reply
    chunks back to the browser as JSON frames.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)
    # Idempotent: only the first websocket connection triggers LLM creation.
    await self.prepare_llm()
    # One chat session token per websocket connection.
    token = await self.llm_client.connect(self.llm_name)
    async for msg in ws:
        if msg.type == aiohttp.WSMsgType.TEXT:
            if msg.data == "close":
                # Client-initiated shutdown handshake.
                await ws.close()
            else:
                prompt = json.loads(msg.data).get("prompt")
                log.info(f"Received prompt: {prompt}")
                # Relay each streamed chunk as-is; the frontend
                # concatenates chunk["content"] pieces.
                async for message_chunk in self.llm_client.chat(
                    token=token, message=prompt
                ):
                    log.info(
                        "Received from LLM: {}".format(message_chunk["content"])
                    )
                    await ws.send_str(json.dumps(message_chunk))
        elif msg.type == aiohttp.WSMsgType.ERROR:
            log.warning(f"ws connection closed with exception {ws.exception()}")
    log.info("websocket connection closed")
    return ws

View File

@@ -2,6 +2,7 @@ class App {
url = null url = null
socket = null socket = null
newMessage = null newMessage = null
isNewMessage = false
constructor(){ constructor(){
this.url = window.location.href.replace(/^http/, 'ws') + 'ws/' this.url = window.location.href.replace(/^http/, 'ws') + 'ws/'
@@ -22,8 +23,11 @@ class App {
newMessage.classList.add('message') newMessage.classList.add('message')
newMessage.classList.add('bot') newMessage.classList.add('bot')
newMessage.classList.add('botmessage') newMessage.classList.add('botmessage')
newMessage.innerText = "*thinking*"
messageList.appendChild(newMessage) messageList.appendChild(newMessage)
this.newMessage = newMessage this.newMessage = newMessage
this.isNewMessage = true
messageList.scrollTop = messageList.scrollHeight
} }
createNewUserMessage(msg){ createNewUserMessage(msg){
const messageList = document.querySelector('.message-list') const messageList = document.querySelector('.message-list')
@@ -51,14 +55,16 @@ class App {
onMessage(event){ onMessage(event){
const messageList = document.querySelector('.message-list') const messageList = document.querySelector('.message-list')
let obj = JSON.parse(event.data) let obj = JSON.parse(event.data)
if(typeof(obj) == 'string'){ if(this.isNewMessage){
this.newMessage.innerText += obj this.newMessage.innerText = obj['content']
this.isNewMessage = false
}else{
this.newMessage.innerText += obj['content']
}
messageList.scrollTop = messageList.scrollHeight messageList.scrollTop = messageList.scrollHeight
if (obj['done']){
}else if (typeof(obj) == 'object' && obj['done']){
this.newMessage = null; this.newMessage = null;
messageList.scrollTop = messageList.scrollHeight this.isNewMessage = true
} }
} }