diff --git a/src/snek/service/socket.py b/src/snek/service/socket.py index a0f1e9b..0b6071d 100644 --- a/src/snek/service/socket.py +++ b/src/snek/service/socket.py @@ -34,30 +34,40 @@ class SocketService(BaseService): def __init__(self, app): super().__init__(app) - self.sockets = [] + self.sockets = set() + self.users = {} self.subscriptions = {} async def add(self, ws, user_uid): - self.sockets.append(self.Socket(ws, await self.app.services.user.get(uid=user_uid))) + s = self.Socket(ws, await self.app.services.user.get(uid=user_uid)) + self.sockets.add(s) + if not self.users.get(user_uid): + self.users[user_uid] = set() + self.users[user_uid].add(s) async def subscribe(self, ws,channel_uid, user_uid): + return if not channel_uid in self.subscriptions: self.subscriptions[channel_uid] = set() s = self.Socket(ws,await self.app.services.user.get(uid=user_uid)) self.subscriptions[channel_uid].add(s) + async def send_to_user(self, user_uid, message): + count = 0 + for s in self.users.get(user_uid,[]): + if await s.send_json(message): + count += 1 + return count + async def broadcast(self, channel_uid, message): count = 0 - subscriptions = set(self.subscriptions.get(channel_uid,[])) - for s in subscriptions: - if not await s.send_json(message): - self.subscriptions[channel_uid].remove(s) - continue - count += 1 + async for channel_member in self.app.services.channel_member.find(channel_uid=channel_uid): + count += await self.send_to_user(channel_member["user_uid"],message) return count + async def delete(self, ws): - for s in self.sockets: - if s.ws == ws: - await s.close() - self.sockets.remove(s) - \ No newline at end of file + for s in [sock for sock in self.sockets if sock.ws == ws]: + await s.close() + self.sockets.remove(s) + + \ No newline at end of file diff --git a/src/snek/system/cache.py b/src/snek/system/cache.py index 39e6fc3..f9c4761 100644 --- a/src/snek/system/cache.py +++ b/src/snek/system/cache.py @@ -61,7 +61,7 @@ class Cache: if is_new: self.version += 1 - print("Cache store! New version:", self.version, flush=True) + print(f"Cache store! {len(self.lru)} items. New version:", self.version, flush=True) async def delete(self, args): if args in self.cache: diff --git a/src/snek/view/rpc.py b/src/snek/view/rpc.py index f05f4a5..d664435 100644 --- a/src/snek/view/rpc.py +++ b/src/snek/view/rpc.py @@ -91,9 +91,9 @@ class RPCView(BaseView): }) return channels - async def send_message(self, room, message): + async def send_message(self, channel_uid, message): self._require_login() - await self.services.chat.send(self.user_uid, room, message) + await self.services.chat.send(self.user_uid, channel_uid, message) return True async def echo(self, *args):