diff --git a/src/yura/client.py b/src/yura/client.py index 08ba098..af7b989 100644 --- a/src/yura/client.py +++ b/src/yura/client.py @@ -67,36 +67,43 @@ class AsyncClient: def __init__(self, url="ws://127.0.0.1:8470"): self.url = url - self.client = AsyncRPCClient(self.url) + self.queue_in = asyncio.Queue() self.queue_out = asyncio.Queue() self.communication_task = None self.session_id = None self.ws = None - async def __aenter__(self): + @property + def _connection(self): + if not self.client: + self.client = AsyncRPCClient(self.url) self.ws = await self.client.ws + return self.client + + async def __aenter__(self): + conn = self._connection return self async def __aexit__(self, *args, **kwargs): - await self.client.close() - self.ws = None + await self.close(0 async def create(self, name, extends, system): - return await self.client.create(name=name, extends=extends, system=system) + return await self._connection.create(name=name, extends=extends, system=system) async def chat(self, token, message, datasets=None): - yield await self.client.chat(uid=token, message=message, datasets=datasets or []) + yield await self._connection.chat(uid=token, message=message, datasets=datasets or []) async for msg in self.client: yield msg if msg.get("done"): break async def connect(self, name): - return await self.client.connect(name) + return await self._connection.connect(name) async def close(self): await self.client.close() + self.client = None self.ws = None