-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcompletions.py
153 lines (122 loc) · 3.95 KB
/
completions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import openai
from groq import Groq
from rich.console import Console
from rich import print
from pathlib import Path
import os
from dotenv import load_dotenv
import yaml
from prompt_toolkit import PromptSession
ENV_FILENAME = ".promptlab"
console = Console()
# model looks like this: "openai:gpt-4o" or "groq:llama-3"
def parse_model(model):
model_components = model.split(":")
if len(model_components) == 1:
provider = "openai"
model = model
else:
provider = model_components[0]
model = model_components[1]
return provider, model
# Load the config file
def load_config():
home = str(Path.home())
# Test if the file exists
if not os.path.isfile(home + "/" + ENV_FILENAME):
return {}
# Load the file
with open(home + "/" + ENV_FILENAME, "r") as f:
config = yaml.safe_load(f)
return config
def save_config(config):
home = str(Path.home())
with open(home + "/" + ENV_FILENAME, "w") as f:
yaml.dump(config, f)
console.log(f"API key set successfully and saved in {home}/{ENV_FILENAME}")
return
# Write the API key to the .promptlab file in the home directory
def action_set_api_key():
session = PromptSession()
# get the config
config = load_config()
# get the provider and API key
provider = session.prompt("Provider (openai | groq)> ")
api_key = session.prompt(f"API key > ")
# write the key to the config
config[provider] = api_key
# save the config
save_config(config)
# Check if the .promptlab file exists in the home directory
def load_env():
home = str(Path.home())
if not os.path.isfile(home + "/" + ENV_FILENAME):
return False
load_dotenv(home + "/" + ENV_FILENAME)
console.log(f"Loaded API key from {home}/{ENV_FILENAME}")
return True
def openai_completion(args, config, text):
openai.api_key = config["openai"]
response = openai.ChatCompletion.create(
model=args.model, messages=[{"role": "user", "content": text}], temperature=0.1
)
response_txt = str(response.choices[0].message.content)
return response_txt
def openai_models(args):
# Load the config file
config = load_config()
# Check if the API key is set
openai.api_key = config["openai"]
models = openai.Model.list()
out = [model.id for model in models.data]
return sorted(out)
def groq_completion(args, config, text):
provider, model = parse_model(args.model)
client = Groq(
# This is the default and can be omitted
api_key=config["groq"],
)
chat_completion = client.chat.completions.create(
messages=[
{"role": "system", "content": "you are a helpful assistant."},
{
"role": "user",
"content": text,
},
],
model=model,
)
return chat_completion.choices[0].message.content
def groq_models(args):
provider, model = parse_model(args.model)
# Load the config file
config = load_config()
# Check if the API key is set
client = Groq(
# This is the default and can be omitted
api_key=config["groq"],
)
models = client.models.list()
out = [model.id for model in models.data]
return sorted(out)
def action_models(args):
if args.provider == "openai":
return openai_models(args)
elif args.provider == "groq":
return groq_models(args)
return "Unknown provider"
def complete(args, text):
provider, model = parse_model(args.model)
# Load the config file
config = load_config()
# Check if the API key is set
if provider not in config:
raise Exception(
f"You must set an API key for {provider} to use the {model}. Run the auth command to set it."
)
# Perform the correct call
if provider == "openai":
return openai_completion(args, config, text)
elif provider == "groq":
return groq_completion(args, config, text)
return "Unknown provider"