diff --git a/src/yura/client.py b/src/yura/client.py index 85d3c6e..08ba098 100644 --- a/src/yura/client.py +++ b/src/yura/client.py @@ -72,18 +72,21 @@ class AsyncClient: self.queue_out = asyncio.Queue() self.communication_task = None self.session_id = None + self.ws = None async def __aenter__(self): + self.ws = await self.client.ws return self async def __aexit__(self, *args, **kwargs): - pass + await self.client.close() + self.ws = None async def create(self, name, extends, system): return await self.client.create(name=name, extends=extends, system=system) - async def chat(self, token, message): - yield await self.client.chat(uid=token, message=message) + async def chat(self, token, message, datasets=None): + yield await self.client.chat(uid=token, message=message, datasets=datasets or []) async for msg in self.client: yield msg if msg.get("done"): @@ -94,7 +97,7 @@ class AsyncClient: async def close(self): await self.client.close() - self.client = None + self.ws = None async def cli_client(url="ws://127.0.0.1:8470"):