Skip to content

Commit

Permalink
feat: show available models list in settings (paradigmxyz#83)
Browse files Browse the repository at this point in the history
* create utils/models

* add availableModels prop

* fetch models on startup and key change

* remove default model list

* consider model fetch for loading spinner

* clearer description

* remove accidental debug delay

* Update src/components/App.tsx

* Update src/components/App.tsx

* Update src/components/App.tsx

* distinguish isAnythingSaving and isAnythingLoading

---------

Co-authored-by: t11s <[email protected]>
  • Loading branch information
adietrichs and transmissions11 authored Jun 15, 2023
1 parent 54ebe4c commit 1dfc26f
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 6 deletions.
52 changes: 50 additions & 2 deletions src/components/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import {
} from "../utils/fluxNode";
import { useLocalStorage } from "../utils/lstore";
import { mod } from "../utils/mod";
import { getAvailableChatModels } from "../utils/models";
import { generateNodeId, generateStreamId } from "../utils/nodeId";
import { messagesFromLineage, promptFromLineage } from "../utils/prompt";
import { getQueryParam, resetURL } from "../utils/qparams";
Expand Down Expand Up @@ -864,11 +865,57 @@ function App() {

const [apiKey, setApiKey] = useLocalStorage<string>(API_KEY_LOCAL_STORAGE_KEY);

const isAnythingLoading = isSavingReactFlow || isSavingSettings;
const [availableModels, setAvailableModels] = useState<string[] | null>(null);

// modelsLoadCounter lets us discard the results of the requests if a concurrent newer one was made.
const modelsLoadCounter = useRef(0);
useEffect(() => {
if (isValidAPIKey(apiKey)) {
const modelsLoadIndex = modelsLoadCounter.current + 1;
modelsLoadCounter.current = modelsLoadIndex;

setAvailableModels(null);

(async () => {
let modelList: string[] = [];
try {
modelList = await getAvailableChatModels(apiKey!);
} catch (e) {
toast({
title: "Failed to load model list!",
status: "error",
...TOAST_CONFIG,
});
}
if (modelsLoadIndex !== modelsLoadCounter.current) return;

if (modelList.length === 0) modelList.push(settings.model);

setAvailableModels(modelList);

if (!modelList.includes(settings.model)) {
const oldModel = settings.model;
const newModel = modelList.includes(DEFAULT_SETTINGS.model) ? DEFAULT_SETTINGS.model : modelList[0];

setSettings((settings) => ({ ...settings, model: newModel }));

toast({
title: `Model "${oldModel}" no longer available!`,
description: `Switched to "${newModel}"`,
status: "warning",
...TOAST_CONFIG,
});
}
})();
}
}, [apiKey]);

const isAnythingSaving = isSavingReactFlow || isSavingSettings;
const isAnythingLoading = isAnythingSaving || (availableModels === null);

useBeforeunload((event: BeforeUnloadEvent) => {
// Prevent leaving the page before saving.
if (isAnythingLoading) event.preventDefault();
if (isAnythingSaving) event.preventDefault();
});

/*//////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -1000,6 +1047,7 @@ function App() {
onClose={onCloseSettingsModal}
apiKey={apiKey}
setApiKey={setApiKey}
availableModels={availableModels}
/>
<Column
mainAxisAlignment="center"
Expand Down
6 changes: 4 additions & 2 deletions src/components/modals/SettingsModal.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { MIXPANEL_TOKEN } from "../../main";
import { getFluxNodeTypeDarkColor } from "../../utils/color";
import { DEFAULT_SETTINGS, SUPPORTED_MODELS } from "../../utils/constants";
import { DEFAULT_SETTINGS } from "../../utils/constants";
import { Settings, FluxNodeType } from "../../utils/types";
import { APIKeyInput } from "../utils/APIKeyInput";
import { LabeledSelect, LabeledSlider } from "../utils/LabeledInputs";
Expand All @@ -26,13 +26,15 @@ export const SettingsModal = memo(function SettingsModal({
setSettings,
apiKey,
setApiKey,
availableModels
}: {
isOpen: boolean;
onClose: () => void;
settings: Settings;
setSettings: (settings: Settings) => void;
apiKey: string | null;
setApiKey: (apiKey: string) => void;
availableModels: string[] | null;
}) {
const reset = () => {
if (
Expand Down Expand Up @@ -78,7 +80,7 @@ export const SettingsModal = memo(function SettingsModal({
<LabeledSelect
label="Model"
value={settings.model}
options={SUPPORTED_MODELS}
options={availableModels || [settings.model]}
setValue={(v: string) => {
setSettings({ ...settings, model: v });

Expand Down
2 changes: 0 additions & 2 deletions src/utils/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ export const REACT_FLOW_NODE_TYPES: Record<
LabelUpdater: LabelUpdaterNode,
};

export const SUPPORTED_MODELS = ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k"];

export const DEFAULT_SETTINGS: Settings = {
temp: 1.2,
n: 3,
Expand Down
28 changes: 28 additions & 0 deletions src/utils/models.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
export function getAvailableModels(apiKey: string): Promise<string[]> {
return new Promise(async (resolve, reject) => {
try {
const response = await fetch("https://api.openai.com/v1/models", {
method: "GET",
headers: {
Authorization: `Bearer ${apiKey}`,
},
})
const data = await response.json();
resolve(data.data.map((model: any) => model.id).sort());
} catch (err) {
reject(err);
}
});
};

export function getAvailableChatModels(apiKey: string): Promise<string[]> {
return new Promise((resolve, reject) => {
getAvailableModels(apiKey)
.then((models) => {
resolve(models.filter((model) => model.startsWith("gpt-")));
})
.catch((err) => {
reject(err);
});
});
};

0 comments on commit 1dfc26f

Please sign in to comment.