"""
Written in 2024 by retoor@molodetz.nl.
MIT license. Enjoy!
2024-12-22 10:25:17 +00:00
You'll need a paid OpenAI account, named a project in it, requested an api key and created an assistant.
URL's to all these pages are described in the class for convenience.
2024-12-20 17:49:17 +00:00
The API keys described in this document are fake but are in the correct format for educational purposes.
How to start:
- sudo apt install python3.12-venv python3-pip -y
- python3 -m venv .venv
- . .venv/bin/activate
2024-12-22 10:25:17 +00:00
- pip install openapi
2024-12-20 17:49:17 +00:00
2024-12-22 10:25:17 +00:00
This file is to be used as part of your project or a standalone after doing
some modifications at the end of the file.
2024-12-20 17:49:17 +00:00
"""
try:
    import os
    import sys

    sys.path.append(os.getcwd())
    import env

    API_KEY = env.API_KEY
    ASSISTANT_ID = env.ASSISTANT_ID
except (ImportError, AttributeError):
    pass

import asyncio
import functools
from collections.abc import AsyncGenerator
from typing import Optional
from openai import OpenAI


class Agent:
    """
    An instance of this class represents a single user session with its own
    memory. The messages property is a list containing the full chat history:
    what the user said and what the assistant (agent) answered. It can be
    used later to continue where you left off; the format is described in the
    docstring of __init__ below.

    Introduction to the API, in case you want to extend this class:
    https://platform.openai.com/docs/api-reference/introduction
    """
    def __init__(
        self, api_key: str, assistant_id: str, messages: Optional[list] = None
    ):
"""
You can find and create API keys here:
https://platform.openai.com/api-keys
2024-12-22 10:25:17 +00:00
You can find assistant_id (agent_id) here. It is the id that starts with 'asst_', not your custom name:
2024-12-20 17:49:17 +00:00
https://platform.openai.com/assistants/
2024-12-22 10:25:17 +00:00
Messages are optional in this format, this is to keep a message history that you can later use again:
2024-12-20 17:49:17 +00:00
[
{"role": "user", "message": "What is choking the chicken?"},
{"role": "assistant", "message": "Lucky for the cock."}
]
"""
self.assistant_id = assistant_id
self.api_key = api_key
self.client = OpenAI(api_key=self.api_key)
self.messages = messages or []
self.thread = self.client.beta.threads.create(messages=self.messages)
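
    # Example usage (a sketch; `previous_messages` stands for a history you
    # saved earlier from agent.messages):
    #
    #     agent = Agent(api_key=API_KEY, assistant_id=ASSISTANT_ID,
    #                   messages=previous_messages)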
async def dalle2(
self, prompt: str, width: Optional[int] = 512, height: Optional[int] = 512
) -> dict:
"""
2024-12-22 10:25:17 +00:00
In my opinion dall-e-2 produces unusual results.
Sizes: 256x256, 512x512 or 1024x1024.
2024-12-20 17:49:17 +00:00
"""
result = self.client.images.generate(
model="dall-e-2", prompt=prompt, n=1, size=f"{width}x{height}"
)
return result
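
    # Example usage (a sketch; the openai SDK's ImagesResponse exposes the
    # generated image URL as result.data[0].url):
    #
    #     result = await agent.dalle2("a cat wearing a hat", 512, 512)
    #     print(result.data[0].url)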
@property
async def models(self):
"""
List models in dict format. That's more convenient than the original
2024-12-22 10:25:17 +00:00
list method because this can be directly converted to json to be used
in your front end or api. That's not the original result which is a
2024-12-20 17:49:17 +00:00
custom list with unserializable models.
"""
return [
{
"id": model.id,
"owned_by": model.owned_by,
"object": model.object,
"created": model.created,
}
for model in self.client.models.list()
]
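
    # Example usage (a sketch; the async property is awaited like a method
    # result and yields plain, JSON-serializable dicts):
    #
    #     for model in await agent.models:
    #         print(model["id"])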
async def dalle3(
self, prompt: str, height: Optional[int] = 1024, width: Optional[int] = 1024
) -> dict:
"""
2024-12-22 10:25:17 +00:00
Sadly only big sizes allowed. Is more pricy.
2024-12-20 17:49:17 +00:00
Sizes: 1024x1024, 1792x1024, or 1024x1792.
"""
result = self.client.images.generate(
model="dall-e-3", prompt=prompt, n=1, size=f"{width}x{height}"
)
return result
    def upload_file(self, file_name: str, purpose: str) -> str:
        """
        Upload a file to OpenAI and return its file id.
        """
        with open(file_name, "rb") as file_fd:
            response = self.client.files.create(file=file_fd, purpose=purpose)
        return response.id
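
    # Example usage (a sketch; the file name is just an example, and
    # "assistants" is the purpose value used for files an assistant
    # should be able to read):
    #
    #     file_id = agent.upload_file("notes.pdf", purpose="assistants")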
    async def chat(
        self, message: str, interval: Optional[float] = 0.2
    ) -> AsyncGenerator[Optional[str], None]:
"""
2024-12-22 10:25:17 +00:00
Chat with the agent. It yields on given interval to inform the caller it' still busy so you can
update the user with live status. It doesn't hang. You can use this fully async with other
2024-12-20 17:49:17 +00:00
instances of this class.
This function also updates the self.messages list with chat history for later use.
"""
message_object = {"role": "user", "content": message}
self.messages.append(message_object)
self.client.beta.threads.messages.create(
self.thread.id,
role=message_object["role"],
content=message_object["content"],
)
run = self.client.beta.threads.runs.create(
thread_id=self.thread.id, assistant_id=self.assistant_id
)
while run.status != "completed":
run = self.client.beta.threads.runs.retrieve(
thread_id=self.thread.id, run_id=run.id
)
yield None
await asyncio.sleep(interval)
response_messages = self.client.beta.threads.messages.list(
thread_id=self.thread.id
).data
last_message = response_messages[0].content[0].text.value
self.messages.append({"role": "assistant", "content": last_message})
yield str(last_message)
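
    # Example usage (a sketch): treat each None as a "still busy" tick and
    # the final string as the assistant's answer:
    #
    #     async for chunk in agent.chat("Hello!"):
    #         if chunk is None:
    #             continue  # run is still in progress
    #         print(chunk)  # the assistant's reply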
async def chatp(self, message: str) -> str:
"""
2024-12-22 10:25:17 +00:00
Just like regular chat function but with progress indication and returns string directly.
2024-12-20 17:49:17 +00:00
This is handy for interactive usage or for a process log.
"""
        print("Processing", end="")
        response = ""
        async for response in self.chat(message):
            if not response:
                print(".", end="", flush=True)
                continue
            print("")
            break
        return response
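
    # Example usage (a sketch; prints one dot per poll tick, then returns
    # the reply as a string):
    #
    #     answer = await agent.chatp("Summarize our conversation.")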
async def read_line(self, ps: Optional[str] = "> "):
"""
2024-12-22 10:25:17 +00:00
Non blocking read_line.
Blocking read line can break web socket connections.
2024-12-20 17:49:17 +00:00
That's why.
"""
        loop = asyncio.get_running_loop()
patched_input = functools.partial(input, ps)
return await loop.run_in_executor(None, patched_input)
async def cli(self):
"""
2024-12-22 10:25:17 +00:00
Interactive client. Can be used on terminal by user or a different process.
The bottom new line is so that a process can check for \n\n to check if it's end response
and there's nothing left to wait for and thus can send next prompt if the '>' shows.
2024-12-20 17:49:17 +00:00
"""
while True:
try:
message = await self.read_line("> ")
if not message.strip():
continue
                response = await self.chatp(message)
                print(response)
                print("")
except KeyboardInterrupt:
print("Exiting..")
break


async def main():
"""
2024-12-22 10:25:17 +00:00
Example main function. The keys here are not real but look exactly like
the real ones for example purposes and that you're sure your key is in the
right format.
2024-12-20 17:49:17 +00:00
"""
agent = Agent(api_key=API_KEY, assistant_id=ASSISTANT_ID)
# Run interactive chat
await agent.cli()
if __name__ == "__main__":
    # Only executed when this script is run directly, not when it is imported.
asyncio.run(main())