diff --git a/.github/workflows/pr_checks.yaml b/.github/workflows/pr_checks.yaml index c41c57f72b..3c33bfe1cb 100644 --- a/.github/workflows/pr_checks.yaml +++ b/.github/workflows/pr_checks.yaml @@ -96,7 +96,7 @@ jobs: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} install-gui: - needs: [ install-root, install-core ] + needs: [install-root, install-core] runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -157,7 +157,7 @@ jobs: npx tsc --noEmit binary-checks: - needs: [ install-root, install-core ] + needs: [install-root, install-core] runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -193,7 +193,7 @@ jobs: npx tsc --noEmit install-vscode: - needs: [ install-root, install-core ] + needs: [install-root, install-core] runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -274,10 +274,15 @@ jobs: run: | cd core npm test + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} vscode-get-test-file-matrix: runs-on: ubuntu-latest - needs: [ install-root, install-vscode ] + needs: [install-root, install-vscode] outputs: test_file_matrix: ${{ steps.vscode-get-test-file-matrix.outputs.test_file_matrix }} steps: @@ -306,7 +311,7 @@ jobs: vscode-package-extension: runs-on: ubuntu-latest - needs: [ install-vscode, install-core ] + needs: [install-vscode, install-core] steps: - uses: actions/checkout@v4 - uses: actions/setup-node@v4 @@ -338,7 +343,7 @@ jobs: vscode-download-e2e-dependencies: runs-on: ubuntu-latest - needs: [ install-vscode, install-core ] + needs: [install-vscode, install-core] steps: - uses: actions/checkout@v4 - uses: actions/setup-node@v4 @@ -370,7 +375,7 @@ jobs: path: extensions/vscode/e2e/storage vscode-e2e-tests: - name: ${{ matrix.test_file }}" + name: ${{ matrix.test_file }} (${{ matrix.command }}) needs: [ vscode-download-e2e-dependencies, @@ -384,6 +389,7 @@ jobs: fail-fast: false matrix: test_file: ${{ fromJson(needs.vscode-get-test-file-matrix.outputs.test_file_matrix) }} + command: ["e2e:ci:run", "e2e:ci:run-yaml"] steps: - uses: actions/checkout@v4 - uses: actions/setup-node@v4 @@ -432,7 +438,7 @@ jobs: - name: Run e2e tests run: | cd extensions/vscode - TEST_FILE="${{ matrix.test_file }}" npm run e2e:ci:run + TEST_FILE="${{ matrix.test_file }}" npm run ${{ matrix.command }} env: DISPLAY: :99 @@ -445,7 +451,7 @@ jobs: path: extensions/vscode/e2e/storage/screenshots gui-tests: - needs: [ install-gui, install-core ] + needs: [install-gui, install-core] runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -476,7 +482,7 @@ jobs: npm test jetbrains-tests: - needs: [ install-root, core-checks ] + needs: [install-root, core-checks] runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -573,5 +579,3 @@ jobs: name: jb-failure-report path: | ${{ github.workspace }}/extensions/intellij/build/reports - - diff --git a/.vscode/launch.json b/.vscode/launch.json index dfe474f40b..cdb502c8c1 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -20,10 +20,10 @@ "outFiles": ["${workspaceFolder}/extensions/vscode/out/extension.js"], "preLaunchTask": "vscode-extension:build", "env": { + // "CONTROL_PLANE_ENV": "local", "CONTINUE_GLOBAL_DIR": "${workspaceFolder}/extensions/.continue-debug" // "staging" for the preview deployment "CONTINUE_GLOBAL_DIR": "${workspaceFolder}/extensions/.continue-debug" // "local" for entirely local development of control plane/proxy - // "CONTROL_PLANE_ENV": "staging" } }, { diff --git a/binary/package-lock.json b/binary/package-lock.json index 158342e0b9..89ecc53cc7 100644 --- a/binary/package-lock.json +++ b/binary/package-lock.json @@ -48,7 +48,7 @@ "@aws-sdk/client-sagemaker-runtime": "^3.621.0", "@aws-sdk/credential-providers": "^3.620.1", "@continuedev/config-types": "^1.0.13", - "@continuedev/config-yaml": "^1.0.0", + "@continuedev/config-yaml": "^1.0.10", "@continuedev/fetch": "^1.0.4", "@continuedev/llm-info": "^1.0.2", "@continuedev/openai-adapters": "^1.0.10", diff --git a/core/config/ConfigHandler.test.ts b/core/config/ConfigHandler.test.ts index 5a25e71fc9..deb1dededf 100644 --- a/core/config/ConfigHandler.test.ts +++ b/core/config/ConfigHandler.test.ts @@ -14,7 +14,7 @@ describe.skip("Test the ConfigHandler and E2E config loading", () => { expect(profiles[0].id).toBe("local"); const currentProfile = testConfigHandler.currentProfile; - expect(currentProfile.profileId).toBe("local"); + expect(currentProfile.profileDescription.id).toBe("local"); }); test("should load the default config successfully", async () => { diff --git a/core/config/ConfigHandler.ts b/core/config/ConfigHandler.ts index 2c91148ba0..3e90b8c4f8 100644 --- a/core/config/ConfigHandler.ts +++ b/core/config/ConfigHandler.ts @@ -14,18 +14,21 @@ import Ollama from "../llm/llms/Ollama.js"; import { GlobalContext } from "../util/GlobalContext.js"; import { getConfigJsonPath } from "../util/paths.js"; -import { ConfigResult } from "./load.js"; +import { ConfigResult } from "@continuedev/config-yaml"; +import { pathToFileURL } from "url"; +import { controlPlaneEnv } from "../control-plane/env.js"; +import { usePlatform } from "../control-plane/flags.js"; import { LOCAL_ONBOARDING_CHAT_MODEL, ONBOARDING_LOCAL_MODEL_TITLE, } from "./onboarding.js"; import ControlPlaneProfileLoader from "./profile/ControlPlaneProfileLoader.js"; import LocalProfileLoader from "./profile/LocalProfileLoader.js"; +import PlatformProfileLoader from "./profile/PlatformProfileLoader.js"; import { ProfileDescription, ProfileLifecycleManager, } from "./ProfileLifecycleManager.js"; -import { pathToFileURL } from "url"; export type { ProfileDescription }; @@ -57,7 +60,7 @@ export class ConfigHandler { writeLog, ); this.profiles = [new ProfileLifecycleManager(localProfileLoader)]; - this.selectedProfileId = localProfileLoader.profileId; + this.selectedProfileId = localProfileLoader.description.id; // Always load local profile immediately in case control plane doesn't load try { @@ -77,13 +80,16 @@ export class ConfigHandler { get currentProfile() { return ( - this.profiles.find((p) => p.profileId === this.selectedProfileId) ?? - this.fallbackProfile + this.profiles.find( + (p) => p.profileDescription.id === this.selectedProfileId, + ) ?? this.fallbackProfile ); } get inactiveProfiles() { - return this.profiles.filter((p) => p.profileId !== this.selectedProfileId); + return this.profiles.filter( + (p) => p.profileDescription.id !== this.selectedProfileId, + ); } async openConfigProfile(profileId?: string) { @@ -91,25 +97,23 @@ export class ConfigHandler { if (openProfileId === "local") { await this.ide.openFile(pathToFileURL(getConfigJsonPath()).toString()); } else { - await this.ide.openUrl( - "https://app.continue.dev/", - // `https://app.continue.dev/workspaces/${openProfileId}/chat`, - ); + await this.ide.openUrl(`${controlPlaneEnv.APP_URL}${openProfileId}`); } } - private async fetchControlPlaneProfiles() { + private async loadPlatformProfiles() { // Get the profiles and create their lifecycle managers this.controlPlaneClient - .listWorkspaces() - .then(async (workspaces) => { + .listAssistants() + .then(async (assistants) => { this.profiles = this.profiles.filter( - (profile) => profile.profileId === "local", + (profile) => profile.profileDescription.id === "local", ); - workspaces.forEach((workspace) => { - const profileLoader = new ControlPlaneProfileLoader( - workspace.id, - workspace.name, + assistants.forEach((assistant) => { + const profileLoader = new PlatformProfileLoader( + assistant.configResult, + assistant.ownerSlug, + assistant.packageSlug, this.controlPlaneClient, this.ide, this.ideSettingsPromise, @@ -141,10 +145,67 @@ export class ConfigHandler { } }) .catch((e) => { - console.error(e); + console.error("Failed to list assistants: ", e); }); } + private platformProfilesRefreshInterval: NodeJS.Timeout | undefined; + + private async fetchControlPlaneProfiles() { + if (usePlatform()) { + clearInterval(this.platformProfilesRefreshInterval); + await this.loadPlatformProfiles(); + this.platformProfilesRefreshInterval = setInterval( + this.loadPlatformProfiles.bind(this), + PlatformProfileLoader.RELOAD_INTERVAL, + ); + } else { + this.controlPlaneClient + .listWorkspaces() + .then(async (workspaces) => { + this.profiles = this.profiles.filter( + (profile) => profile.profileDescription.id === "local", + ); + workspaces.forEach((workspace) => { + const profileLoader = new ControlPlaneProfileLoader( + workspace.id, + workspace.name, + this.controlPlaneClient, + this.ide, + this.ideSettingsPromise, + this.writeLog, + this.reloadConfig.bind(this), + ); + this.profiles.push(new ProfileLifecycleManager(profileLoader)); + }); + + this.notifyProfileListeners( + this.profiles.map((profile) => profile.profileDescription), + ); + + // Check the last selected workspace, and reload if it isn't local + const workspaceId = await this.getWorkspaceId(); + const lastSelectedWorkspaceIds = + this.globalContext.get("lastSelectedProfileForWorkspace") ?? {}; + const selectedWorkspaceId = lastSelectedWorkspaceIds[workspaceId]; + if (selectedWorkspaceId) { + this.selectedProfileId = selectedWorkspaceId; + await this.loadConfig(); + } else { + // Otherwise we stick with local profile, and record choice + lastSelectedWorkspaceIds[workspaceId] = this.selectedProfileId; + this.globalContext.update( + "lastSelectedProfileForWorkspace", + lastSelectedWorkspaceIds, + ); + } + }) + .catch((e) => { + console.error(e); + }); + } + } + async setSelectedProfile(profileId: string) { this.selectedProfileId = profileId; const result = await this.loadConfig(); diff --git a/core/config/ProfileLifecycleManager.ts b/core/config/ProfileLifecycleManager.ts index c5ae16f913..88fc387c56 100644 --- a/core/config/ProfileLifecycleManager.ts +++ b/core/config/ProfileLifecycleManager.ts @@ -1,15 +1,17 @@ +import { ConfigResult, ConfigValidationError } from "@continuedev/config-yaml"; import { BrowserSerializedContinueConfig, ContinueConfig, IContextProvider, } from "../index.js"; -import { ConfigResult, finalToBrowserConfig } from "./load.js"; +import { finalToBrowserConfig } from "./load.js"; import { IProfileLoader } from "./profile/IProfileLoader.js"; export interface ProfileDescription { title: string; id: string; + errors: ConfigValidationError[] | undefined; } export class ProfileLifecycleManager { @@ -19,19 +21,8 @@ export class ProfileLifecycleManager { constructor(private readonly profileLoader: IProfileLoader) {} - get profileId() { - return this.profileLoader.profileId; - } - - get profileTitle() { - return this.profileLoader.profileTitle; - } - get profileDescription(): ProfileDescription { - return { - title: this.profileTitle, - id: this.profileId, - }; + return this.profileLoader.description; } clearConfig() { @@ -71,14 +62,10 @@ export class ProfileLifecycleManager { result.config.contextProviders = ( result.config.contextProviders ?? [] ).concat(additionalContextProviders); - - this.savedConfigResult = result; - resolve(result); - } else if (result.errors) { - reject( - `Error in config.json: ${result.errors.map((item) => item.message).join(" | ")}`, - ); } + + this.savedConfigResult = result; + resolve(result); }); // Wait for the config promise to resolve diff --git a/core/config/load.ts b/core/config/load.ts index 9a6599608e..6aaf310468 100644 --- a/core/config/load.ts +++ b/core/config/load.ts @@ -62,6 +62,8 @@ import { getEsbuildBinaryPath, } from "../util/paths"; +import { ConfigResult, ConfigValidationError } from "@continuedev/config-yaml"; +import { usePlatform } from "../control-plane/flags"; import { defaultContextProvidersJetBrains, defaultContextProvidersVsCode, @@ -70,13 +72,7 @@ import { } from "./default"; import { getSystemPromptDotFile } from "./getSystemPromptDotFile"; // import { isSupportedLanceDbCpuTarget } from "./util"; -import { ConfigValidationError, validateConfig } from "./validation.js"; - -export interface ConfigResult { - config: T | undefined; - errors: ConfigValidationError[] | undefined; - configLoadInterrupted: boolean; -} +import { validateConfig } from "./validation.js"; function resolveSerializedConfig(filepath: string): SerializedContinueConfig { let content = fs.readFileSync(filepath, "utf8"); @@ -229,7 +225,7 @@ function isModelDescription( return (llm as ModelDescription).title !== undefined; } -function isContextProviderWithParams( +export function isContextProviderWithParams( contextProvider: CustomContextProvider | ContextProviderWithParams, ): contextProvider is ContextProviderWithParams { return (contextProvider as ContextProviderWithParams).name !== undefined; @@ -582,6 +578,7 @@ function finalToBrowserConfig( experimental: final.experimental, docs: final.docs, tools: final.tools, + usePlatform: usePlatform(), }; } diff --git a/core/config/profile/ControlPlaneProfileLoader.ts b/core/config/profile/ControlPlaneProfileLoader.ts index 448dbdb42d..38c7fb3899 100644 --- a/core/config/profile/ControlPlaneProfileLoader.ts +++ b/core/config/profile/ControlPlaneProfileLoader.ts @@ -7,16 +7,16 @@ import { IdeSettings, SerializedContinueConfig, } from "../../index.js"; -import { ConfigResult } from "../load.js"; +import { ConfigResult } from "@continuedev/config-yaml"; +import { ProfileDescription } from "../ProfileLifecycleManager.js"; import doLoadConfig from "./doLoadConfig.js"; import { IProfileLoader } from "./IProfileLoader.js"; export default class ControlPlaneProfileLoader implements IProfileLoader { private static RELOAD_INTERVAL = 1000 * 60 * 15; // every 15 minutes - readonly profileId: string; - profileTitle: string; + description: ProfileDescription; workspaceSettings: ConfigJson | undefined; @@ -29,12 +29,17 @@ export default class ControlPlaneProfileLoader implements IProfileLoader { private writeLog: (message: string) => Promise, private readonly onReload: () => void, ) { - this.profileId = workspaceId; - this.profileTitle = workspaceTitle; + this.description = { + id: workspaceId, + title: workspaceTitle, + errors: undefined, + }; setInterval(async () => { this.workspaceSettings = - await this.controlPlaneClient.getSettingsForWorkspace(this.profileId); + await this.controlPlaneClient.getSettingsForWorkspace( + this.description.id, + ); this.onReload(); }, ControlPlaneProfileLoader.RELOAD_INTERVAL); } @@ -43,7 +48,7 @@ export default class ControlPlaneProfileLoader implements IProfileLoader { const settings = this.workspaceSettings ?? ((await this.controlPlaneClient.getSettingsForWorkspace( - this.profileId, + this.description.id, )) as any); const serializedConfig: SerializedContinueConfig = settings; @@ -53,6 +58,8 @@ export default class ControlPlaneProfileLoader implements IProfileLoader { this.controlPlaneClient, this.writeLog, serializedConfig, + undefined, + undefined, this.workspaceId, ); diff --git a/core/config/profile/IProfileLoader.ts b/core/config/profile/IProfileLoader.ts index 2d22a821e1..5d07172529 100644 --- a/core/config/profile/IProfileLoader.ts +++ b/core/config/profile/IProfileLoader.ts @@ -1,12 +1,12 @@ // ProfileHandlers manage the loading of a config, allowing us to abstract over different ways of getting to a ContinueConfig +import { ConfigResult } from "@continuedev/config-yaml"; import { ContinueConfig } from "../../index.js"; -import { ConfigResult } from "../load.js"; +import { ProfileDescription } from "../ProfileLifecycleManager.js"; // After we have the ContinueConfig, the ConfigHandler takes care of everything else (loading models, lifecycle, etc.) export interface IProfileLoader { - profileTitle: string; - profileId: string; + description: ProfileDescription; doLoadConfig(): Promise>; setIsActive(isActive: boolean): void; } diff --git a/core/config/profile/LocalProfileLoader.ts b/core/config/profile/LocalProfileLoader.ts index 0705c87456..644e97069f 100644 --- a/core/config/profile/LocalProfileLoader.ts +++ b/core/config/profile/LocalProfileLoader.ts @@ -1,14 +1,18 @@ +import { ConfigResult } from "@continuedev/config-yaml"; import { ControlPlaneClient } from "../../control-plane/client.js"; import { ContinueConfig, IDE, IdeSettings } from "../../index.js"; -import { ConfigResult } from "../load.js"; +import { ProfileDescription } from "../ProfileLifecycleManager.js"; import doLoadConfig from "./doLoadConfig.js"; import { IProfileLoader } from "./IProfileLoader.js"; export default class LocalProfileLoader implements IProfileLoader { static ID = "local"; - profileId = LocalProfileLoader.ID; - profileTitle = "Local Config"; + description: ProfileDescription = { + id: LocalProfileLoader.ID, + title: "Local Config", + errors: undefined, + }; constructor( private ide: IDE, @@ -18,13 +22,19 @@ export default class LocalProfileLoader implements IProfileLoader { ) {} async doLoadConfig(): Promise> { - return doLoadConfig( + const result = await doLoadConfig( this.ide, this.ideSettingsPromise, this.controlPlaneClient, this.writeLog, undefined, + undefined, + undefined, ); + + this.description.errors = result.errors; + + return result; } setIsActive(isActive: boolean): void {} diff --git a/core/config/profile/PlatformProfileLoader.ts b/core/config/profile/PlatformProfileLoader.ts new file mode 100644 index 0000000000..d0f7d6f802 --- /dev/null +++ b/core/config/profile/PlatformProfileLoader.ts @@ -0,0 +1,90 @@ +import { ClientConfigYaml } from "@continuedev/config-yaml/dist/schemas/index.js"; + +import { ControlPlaneClient } from "../../control-plane/client.js"; +import { ContinueConfig, IDE, IdeSettings } from "../../index.js"; + +import { ConfigResult } from "@continuedev/config-yaml"; +import { ProfileDescription } from "../ProfileLifecycleManager.js"; +import doLoadConfig from "./doLoadConfig.js"; +import { IProfileLoader } from "./IProfileLoader.js"; + +/** + * Metadata about the package that is currently being loaded + * If this is `undefined`, it's not a config from the platform, + * could be local for example. + */ +export interface PlatformConfigMetadata { + ownerSlug: string; + packageSlug: string; +} + +export default class PlatformProfileLoader implements IProfileLoader { + static RELOAD_INTERVAL = 1000 * 60 * 15; // every 15 minutes + + description: ProfileDescription; + + constructor( + private configResult: ConfigResult, + private readonly ownerSlug: string, + private readonly packageSlug: string, + private readonly controlPlaneClient: ControlPlaneClient, + private readonly ide: IDE, + private ideSettingsPromise: Promise, + private writeLog: (message: string) => Promise, + private readonly onReload: () => void, + ) { + this.description = { + id: `${ownerSlug}/${packageSlug}`, + title: `${ownerSlug}/${packageSlug}`, + errors: configResult.errors, + }; + + setInterval(async () => { + const assistants = await this.controlPlaneClient.listAssistants(); + const newConfigResult = assistants.find( + (assistant) => + assistant.packageSlug === this.packageSlug && + assistant.ownerSlug === this.ownerSlug, + )?.configResult; + if (!newConfigResult) { + return; + } + this.configResult = { + config: newConfigResult.config, + errors: newConfigResult.errors, + configLoadInterrupted: false, + }; + this.onReload(); + }, PlatformProfileLoader.RELOAD_INTERVAL); + } + + async doLoadConfig(): Promise> { + if (this.configResult.errors?.length) { + return { + config: undefined, + errors: this.configResult.errors, + configLoadInterrupted: false, + }; + } + + const results = await doLoadConfig( + this.ide, + this.ideSettingsPromise, + this.controlPlaneClient, + this.writeLog, + undefined, + this.configResult.config, + { + ownerSlug: this.ownerSlug, + packageSlug: this.packageSlug, + }, + ); + + return { + ...results, + errors: [], // Don't do config validation here, it happens in admin panel + }; + } + + setIsActive(isActive: boolean): void {} +} diff --git a/core/config/profile/doLoadConfig.ts b/core/config/profile/doLoadConfig.ts index a82dd896de..5f20ac7e29 100644 --- a/core/config/profile/doLoadConfig.ts +++ b/core/config/profile/doLoadConfig.ts @@ -1,5 +1,7 @@ import fs from "fs"; +import { ConfigResult, ConfigValidationError } from "@continuedev/config-yaml"; +import { ClientConfigYaml } from "@continuedev/config-yaml/dist/schemas"; import { ContinueConfig, ContinueRcJson, @@ -15,9 +17,9 @@ import ContinueProxy from "../../llm/llms/stubs/ContinueProxy"; import { getConfigYamlPath } from "../../util/paths"; import { Telemetry } from "../../util/posthog"; import { TTS } from "../../util/tts"; -import { ConfigResult, loadFullConfigNode } from "../load"; -import { ConfigValidationError } from "../validation"; +import { loadFullConfigNode } from "../load"; import { loadContinueConfigFromYaml } from "../yaml/loadYaml"; +import { PlatformConfigMetadata } from "./PlatformProfileLoader"; export default async function doLoadConfig( ide: IDE, @@ -25,6 +27,8 @@ export default async function doLoadConfig( controlPlaneClient: ControlPlaneClient, writeLog: (message: string) => Promise, overrideConfigJson: SerializedContinueConfig | undefined, + overrideConfigYaml: ClientConfigYaml | undefined, + platformConfigMetadata: PlatformConfigMetadata | undefined, workspaceId?: string, ): Promise> { const workspaceConfigs = await getWorkspaceConfigs(ide); @@ -39,7 +43,7 @@ export default async function doLoadConfig( let errors: ConfigValidationError[] | undefined; let configLoadInterrupted = false; - if (fs.existsSync(configYamlPath)) { + if (fs.existsSync(configYamlPath) || overrideConfigYaml) { const result = await loadContinueConfigFromYaml( ide, workspaceConfigs.map((c) => JSON.stringify(c)), @@ -48,8 +52,9 @@ export default async function doLoadConfig( uniqueId, writeLog, workOsAccessToken, - undefined, - // overrideConfigYaml, TODO + overrideConfigYaml, + platformConfigMetadata, + controlPlaneClient, ); newConfig = result.config; errors = result.errors; diff --git a/core/config/validation.ts b/core/config/validation.ts index e10b3f5232..d28bb7809a 100644 --- a/core/config/validation.ts +++ b/core/config/validation.ts @@ -1,11 +1,7 @@ +import { ConfigValidationError } from "@continuedev/config-yaml"; import { ModelDescription, SerializedContinueConfig } from "../"; import { Telemetry } from "../util/posthog"; -export interface ConfigValidationError { - fatal: boolean; - message: string; -} - /** * Validates a SerializedContinueConfig object to ensure all properties are correctly formed. * @param config The configuration object to validate. diff --git a/core/config/yaml/convertFromJson.ts b/core/config/yaml/convertFromJson.ts index 59482b5c39..4109d7d377 100644 --- a/core/config/yaml/convertFromJson.ts +++ b/core/config/yaml/convertFromJson.ts @@ -8,6 +8,7 @@ export function convertConfigJsonToConfigYaml( ): ConfigYaml { return { name: "Local Config", + version: "1.0.0", models: [ ...configJson.models.map( (model): ModelConfig => ({ diff --git a/core/config/yaml/default.ts b/core/config/yaml/default.ts index f991561df4..cfdffc12de 100644 --- a/core/config/yaml/default.ts +++ b/core/config/yaml/default.ts @@ -5,10 +5,12 @@ export const defaultConfigYaml: ConfigYaml = { models: [], context: [], name: "Local Config", + version: "1.0.0", }; export const defaultConfigYamlJetBrains: ConfigYaml = { models: [], context: [], name: "Local Config", + version: "1.0.0", }; diff --git a/core/config/yaml/loadYaml.ts b/core/config/yaml/loadYaml.ts index ad33ea4d85..5e0161ad80 100644 --- a/core/config/yaml/loadYaml.ts +++ b/core/config/yaml/loadYaml.ts @@ -1,13 +1,12 @@ import fs from "node:fs"; -import path from "path"; import { - extendConfig, + ConfigResult, fillTemplateVariables, + resolveSecretsOnClient, validateConfigYaml, } from "@continuedev/config-yaml"; -import { ConfigYaml } from "@continuedev/config-yaml/dist/schemas"; -import { ValidationLevel } from "@continuedev/config-yaml/dist/validation"; +import { ClientConfigYaml } from "@continuedev/config-yaml/dist/schemas"; import { fetchwithRequestOptions } from "@continuedev/fetch"; import * as YAML from "yaml"; @@ -26,23 +25,17 @@ import { contextProviderClassFromName } from "../../context/providers/index"; import { allEmbeddingsProviders } from "../../indexing/allEmbeddingsProviders"; import FreeTrial from "../../llm/llms/FreeTrial"; import TransformersJsEmbeddingsProvider from "../../llm/llms/TransformersJsEmbeddingsProvider"; -import { - getConfigYamlPath, - getContinueDotEnv, - readAllGlobalPromptFiles, -} from "../../util/paths"; +import { slashCommandFromPromptFileV1 } from "../../promptFiles/v1/slashCommandFromPromptFile"; +import { getAllPromptFiles } from "../../promptFiles/v2/getPromptFiles"; +import { getConfigYamlPath, getContinueDotEnv } from "../../util/paths"; import { getSystemPromptDotFile } from "../getSystemPromptDotFile"; -import { ConfigValidationError } from "../validation.js"; +import { PlatformConfigMetadata } from "../profile/PlatformProfileLoader"; +import CodebaseContextProvider from "../../context/providers/CodebaseContextProvider"; +import FileContextProvider from "../../context/providers/FileContextProvider"; +import PromptFilesContextProvider from "../../context/providers/PromptFilesContextProvider"; +import { ControlPlaneClient } from "../../control-plane/client"; import { llmsFromModelConfig } from "./models"; -import { getAllPromptFiles } from "../../promptFiles/v2/getPromptFiles"; -import { slashCommandFromPromptFileV1 } from "../../promptFiles/v1/slashCommandFromPromptFile"; - -export interface ConfigResult { - config: T | undefined; - errors: ConfigValidationError[] | undefined; - configLoadInterrupted: boolean; -} function renderTemplateVars(configYaml: string): string { const data: Record = {}; @@ -60,36 +53,28 @@ function renderTemplateVars(configYaml: string): string { function loadConfigYaml( workspaceConfigs: string[], - ideSettings: IdeSettings, - ideType: IdeType, rawYaml: string, -): ConfigResult { - const renderedYaml = renderTemplateVars(rawYaml); - let config = YAML.parse(renderedYaml) as ConfigYaml; + overrideConfigYaml: ClientConfigYaml | undefined, +): ConfigResult { + let config = + overrideConfigYaml ?? + (YAML.parse(renderTemplateVars(rawYaml)) as ClientConfigYaml); const errors = validateConfigYaml(config); - if (errors?.some((error) => error.level === ValidationLevel.Error)) { + if (errors?.some((error) => error.fatal)) { return { - errors: errors.map((error) => ({ - message: error.message, - fatal: error.level === ValidationLevel.Error, - })), + errors, config: undefined, configLoadInterrupted: true, }; } - for (const workspaceConfig of workspaceConfigs) { - const rendered = renderTemplateVars(workspaceConfig); - config = extendConfig(config, YAML.parse(rendered) as ConfigYaml); - } - // Set defaults if undefined (this lets us keep config.json uncluttered for new users) return { config, errors: errors.map((error) => ({ message: error.message, - fatal: error.level === ValidationLevel.Error, + fatal: error.fatal, })), configLoadInterrupted: false, }; @@ -113,12 +98,13 @@ async function slashCommandsFromV1PromptFiles( } async function configYamlToContinueConfig( - config: ConfigYaml, + config: ClientConfigYaml, ide: IDE, ideSettings: IdeSettings, uniqueId: string, writeLog: (log: string) => Promise, workOsAccessToken: string | undefined, + platformConfigMetadata: PlatformConfigMetadata | undefined, allowFreeTrial: boolean = true, ): Promise { const continueConfig: ContinueConfig = { @@ -126,10 +112,24 @@ async function configYamlToContinueConfig( models: [], tabAutocompleteModels: [], tools: [], + systemMessage: config.rules?.join("\n"), embeddingsProvider: new TransformersJsEmbeddingsProvider(), experimental: { - modelContextProtocolServers: [], + modelContextProtocolServers: config.mcpServers?.map((mcpServer) => ({ + transport: { + type: "stdio", + command: mcpServer.command, + args: mcpServer.args ?? [], + env: mcpServer.env, + }, + })), }, + docs: config.docs?.map((doc) => ({ + title: doc.name, + startUrl: doc.startUrl, + rootUrl: doc.rootUrl, + faviconUrl: doc.faviconUrl, + })), }; // Models @@ -146,9 +146,13 @@ async function configYamlToContinueConfig( uniqueId, ideSettings, writeLog, + platformConfigMetadata, + continueConfig.systemMessage, ); continueConfig.models.push(...llms); - } else if (model.roles?.includes("autocomplete")) { + } + + if (model.roles?.includes("autocomplete")) { // Autocomplete models array const llms = await llmsFromModelConfig( model, @@ -156,6 +160,8 @@ async function configYamlToContinueConfig( uniqueId, ideSettings, writeLog, + platformConfigMetadata, + continueConfig.systemMessage, ); continueConfig.tabAutocompleteModels?.push(...llms); } @@ -181,18 +187,33 @@ async function configYamlToContinueConfig( // TODO: Split into model roles. - // Context + // Context providers + const codebaseContextParams: IContextProvider[] = + (config.context || []).find((cp) => cp.uses === "codebase")?.with || {}; + const DEFAULT_CONTEXT_PROVIDERS = [ + new FileContextProvider({}), + new CodebaseContextProvider(codebaseContextParams), + new PromptFilesContextProvider({}), + ]; + + const DEFAULT_CONTEXT_PROVIDERS_TITLES = DEFAULT_CONTEXT_PROVIDERS.map( + ({ description: { title } }) => title, + ); + continueConfig.contextProviders = config.context ?.map((context) => { const cls = contextProviderClassFromName(context.uses) as any; if (!cls) { - console.warn(`Unknown context provider ${context.uses}`); + if (!DEFAULT_CONTEXT_PROVIDERS_TITLES.includes(context.uses)) { + console.warn(`Unknown context provider ${context.uses}`); + } return undefined; } const instance: IContextProvider = new cls(context.with ?? {}); return instance; }) .filter((p) => !!p) as IContextProvider[]; + continueConfig.contextProviders.push(...DEFAULT_CONTEXT_PROVIDERS); // Embeddings Provider const embedConfig = config.models?.find((model) => @@ -283,16 +304,20 @@ export async function loadContinueConfigFromYaml( uniqueId: string, writeLog: (log: string) => Promise, workOsAccessToken: string | undefined, - overrideConfigYaml: string | undefined, + overrideConfigYaml: ClientConfigYaml | undefined, + platformConfigMetadata: PlatformConfigMetadata | undefined, + controlPlaneClient: ControlPlaneClient, ): Promise> { const configYamlPath = getConfigYamlPath(ideType); - const rawYaml = fs.readFileSync(configYamlPath, "utf-8"); + const rawYaml = + overrideConfigYaml === undefined + ? fs.readFileSync(configYamlPath, "utf-8") + : ""; - const configYamlResult = await loadConfigYaml( + const configYamlResult = loadConfigYaml( workspaceConfigs, - ideSettings, - ideType, - overrideConfigYaml ?? rawYaml, + rawYaml, + overrideConfigYaml, ); if (!configYamlResult.config || configYamlResult.configLoadInterrupted) { @@ -303,13 +328,24 @@ export async function loadContinueConfigFromYaml( }; } - const continueConfig = await configYamlToContinueConfig( + const configYaml = await resolveSecretsOnClient( configYamlResult.config, + ide.readSecrets.bind(ide), + async (secretNames: string[]) => { + const secretValues = await controlPlaneClient.syncSecrets(secretNames); + await ide.writeSecrets(secretValues); + return secretValues; + }, + ); + + const continueConfig = await configYamlToContinueConfig( + configYaml, ide, ideSettings, uniqueId, writeLog, workOsAccessToken, + platformConfigMetadata, ); const systemPromptDotFile = await getSystemPromptDotFile(ide); diff --git a/core/config/yaml/models.ts b/core/config/yaml/models.ts index d6eb6d3e07..309e061d19 100644 --- a/core/config/yaml/models.ts +++ b/core/config/yaml/models.ts @@ -3,21 +3,59 @@ import { ModelConfig } from "@continuedev/config-yaml"; import { IDE, IdeSettings, LLMOptions } from "../.."; import { BaseLLM } from "../../llm"; import { LLMClasses } from "../../llm/llms"; +import ContinueProxy from "../../llm/llms/stubs/ContinueProxy"; +import { PlatformConfigMetadata } from "../profile/PlatformProfileLoader"; const AUTODETECT = "AUTODETECT"; +function useContinueProxy( + model: ModelConfig, + platformConfigMetadata: PlatformConfigMetadata | undefined, +): boolean { + return !!platformConfigMetadata && model.apiKeySecret !== undefined; +} + +function getModelClass( + model: ModelConfig, + platformConfigMetadata: PlatformConfigMetadata | undefined, +): (typeof LLMClasses)[number] | undefined { + if (useContinueProxy(model, platformConfigMetadata)) { + return ContinueProxy; + } + return LLMClasses.find((llm) => llm.providerName === model.provider); +} + +function getContinueProxyModelName( + ownerSlug: string, + packageSlug: string, + model: ModelConfig, +): string { + return `${ownerSlug}/${packageSlug}/${model.provider}/${model.model}`; +} + async function modelConfigToBaseLLM( model: ModelConfig, uniqueId: string, ideSettings: IdeSettings, writeLog: (log: string) => Promise, + platformConfigMetadata: PlatformConfigMetadata | undefined, + systemMessage: string | undefined, ): Promise { - const cls = LLMClasses.find((llm) => llm.providerName === model.provider); + const cls = getModelClass(model, platformConfigMetadata); if (!cls) { return undefined; } + const usingContinueProxy = useContinueProxy(model, platformConfigMetadata); + const modelName = usingContinueProxy + ? getContinueProxyModelName( + platformConfigMetadata!.ownerSlug, + platformConfigMetadata!.packageSlug, + model, + ) + : model.model; + let options: LLMOptions = { ...model, completionOptions: { @@ -30,6 +68,8 @@ async function modelConfigToBaseLLM( writeLog, uniqueId, title: model.name, + model: modelName, + systemMessage, }; const llm = new cls(options); @@ -43,6 +83,8 @@ async function autodetectModels( uniqueId: string, ideSettings: IdeSettings, writeLog: (log: string) => Promise, + platformConfigMetadata: PlatformConfigMetadata | undefined, + systemMessage: string | undefined, ): Promise { try { const modelNames = await llm.listModels(); @@ -62,6 +104,8 @@ async function autodetectModels( uniqueId, ideSettings, writeLog, + platformConfigMetadata, + systemMessage, ); }), ); @@ -78,12 +122,16 @@ export async function llmsFromModelConfig( uniqueId: string, ideSettings: IdeSettings, writeLog: (log: string) => Promise, + platformConfigMetadata: PlatformConfigMetadata | undefined, + systemMessage: string | undefined, ): Promise { const baseLlm = await modelConfigToBaseLLM( model, uniqueId, ideSettings, writeLog, + platformConfigMetadata, + systemMessage, ); if (!baseLlm) { return []; @@ -97,6 +145,8 @@ export async function llmsFromModelConfig( uniqueId, ideSettings, writeLog, + platformConfigMetadata, + systemMessage, ); return models; } else { diff --git a/core/context/mcp/index.ts b/core/context/mcp/index.ts index 26be2f471c..304c2a8c4b 100644 --- a/core/context/mcp/index.ts +++ b/core/context/mcp/index.ts @@ -4,9 +4,9 @@ import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js" import { WebSocketClientTransport } from "@modelcontextprotocol/sdk/client/websocket.js"; import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; +import { ConfigValidationError } from "@continuedev/config-yaml"; import { ContinueConfig, MCPOptions, SlashCommand, Tool } from "../.."; import { constructMcpSlashCommand } from "../../commands/slash/mcp"; -import { ConfigValidationError } from "../../config/validation"; import { encodeMCPToolUri } from "../../tools/callTool"; import MCPContextProvider from "../providers/MCPContextProvider"; diff --git a/core/control-plane/auth/index.ts b/core/control-plane/auth/index.ts index ecb6221c74..c1a6f881db 100644 --- a/core/control-plane/auth/index.ts +++ b/core/control-plane/auth/index.ts @@ -1,5 +1,7 @@ import { v4 as uuidv4 } from "uuid"; +import { controlPlaneEnv } from "../env"; + const CLIENT_ID = "client_01J0FW6XN8N2XJAECF7NE0Y65J"; // const CLIENT_ID = "client_01J0FW6XCPMJMQ3CG51RB4HBZQ"; // Staging @@ -8,7 +10,7 @@ export async function getAuthUrlForTokenPage(): Promise { const params = { response_type: "code", client_id: CLIENT_ID, - redirect_uri: "https://app.continue.dev/tokens/callback", + redirect_uri: `${controlPlaneEnv.APP_URL}tokens/callback`, // redirect_uri: "http://localhost:3000/tokens/callback", state: uuidv4(), provider: "authkit", diff --git a/core/control-plane/client.ts b/core/control-plane/client.ts index b84e245702..3a29b3fee9 100644 --- a/core/control-plane/client.ts +++ b/core/control-plane/client.ts @@ -1,8 +1,10 @@ import { ConfigJson } from "@continuedev/config-types"; +import { ClientConfigYaml } from "@continuedev/config-yaml/dist/schemas/index.js"; import fetch, { RequestInit, Response } from "node-fetch"; import { ModelDescription } from "../index.js"; +import { ConfigResult } from "@continuedev/config-yaml"; import { controlPlaneEnv } from "./env.js"; export interface ControlPlaneSessionInfo { @@ -85,6 +87,29 @@ export class ControlPlaneClient { } } + public async listAssistants(): Promise< + { + configResult: ConfigResult; + ownerSlug: string; + packageSlug: string; + iconUrl: string; + }[] + > { + const userId = await this.userId; + if (!userId) { + return []; + } + + try { + const resp = await this.request("ide/list-assistants", { + method: "GET", + }); + return (await resp.json()) as any; + } catch (e) { + return []; + } + } + async getSettingsForWorkspace(workspaceId: string): Promise { const userId = await this.userId; if (!userId) { @@ -96,4 +121,21 @@ export class ControlPlaneClient { }); return ((await resp.json()) as any).settings; } + + async syncSecrets(secretNames: string[]): Promise> { + const userId = await this.userId; + if (!userId) { + throw new Error("No user id"); + } + + try { + const resp = await this.request("ide/sync-secrets", { + method: "POST", + body: JSON.stringify({ secretNames }), + }); + return (await resp.json()) as any; + } catch (e) { + return {}; + } + } } diff --git a/core/control-plane/env.ts b/core/control-plane/env.ts index 4556ce90f8..16adf0990f 100644 --- a/core/control-plane/env.ts +++ b/core/control-plane/env.ts @@ -1,3 +1,6 @@ +import { readUsePlatform } from "../util/paths"; +import { usePlatform } from "./flags"; + interface ControlPlaneEnv { DEFAULT_CONTROL_PLANE_PROXY_URL: string; CONTROL_PLANE_URL: string; @@ -21,7 +24,7 @@ const PRODUCTION_ENV: ControlPlaneEnv = { "https://control-plane-api-service-i3dqylpbqa-uc.a.run.app/", AUTH_TYPE: WORKOS_ENV_ID_PRODUCTION, WORKOS_CLIENT_ID: WORKOS_CLIENT_ID_PRODUCTION, - APP_URL: "https://app.continue.dev", + APP_URL: "https://app.continue.dev/", }; const STAGING_ENV: ControlPlaneEnv = { @@ -31,7 +34,15 @@ const STAGING_ENV: ControlPlaneEnv = { "https://control-plane-api-service-537175798139.us-central1.run.app/", AUTH_TYPE: WORKOS_CLIENT_ID_STAGING, WORKOS_CLIENT_ID: WORKOS_CLIENT_ID_STAGING, - APP_URL: "https://app-preview.continue.dev", + APP_URL: "https://app-preview.continue.dev/", +}; + +const TEST_ENV: ControlPlaneEnv = { + DEFAULT_CONTROL_PLANE_PROXY_URL: "https://api-test.continue.dev/", + CONTROL_PLANE_URL: "https://api-test.continue.dev/", + AUTH_TYPE: WORKOS_ENV_ID_STAGING, + WORKOS_CLIENT_ID: WORKOS_CLIENT_ID_STAGING, + APP_URL: "https://app-test.continue.dev/", }; const LOCAL_ENV: ControlPlaneEnv = { @@ -39,12 +50,20 @@ const LOCAL_ENV: ControlPlaneEnv = { CONTROL_PLANE_URL: "http://localhost:3001/", AUTH_TYPE: WORKOS_ENV_ID_STAGING, WORKOS_CLIENT_ID: WORKOS_CLIENT_ID_STAGING, - APP_URL: "http://localhost:3000", + APP_URL: "http://localhost:3000/", }; -export const controlPlaneEnv = - process.env.CONTROL_PLANE_ENV === "local" +function getControlPlaneEnv(): ControlPlaneEnv { + const usePlatformFileEnv = readUsePlatform(); + const env = usePlatformFileEnv || process.env.CONTROL_PLANE_ENV; + + return env === "local" ? LOCAL_ENV - : process.env.CONTROL_PLANE_ENV === "staging" + : env === "staging" ? STAGING_ENV - : PRODUCTION_ENV; + : env === "test" || usePlatform() + ? TEST_ENV + : PRODUCTION_ENV; +} + +export const controlPlaneEnv = getControlPlaneEnv(); diff --git a/core/control-plane/flags.ts b/core/control-plane/flags.ts new file mode 100644 index 0000000000..7b8ff6799e --- /dev/null +++ b/core/control-plane/flags.ts @@ -0,0 +1,13 @@ +import { readUsePlatform, usePlatformPathExists } from "../util/paths"; + +export function usePlatform(): boolean { + return usePlatformPathExists(); +} + +export function getEnvFromUsePlatformFile(): string | undefined { + const contents = readUsePlatform(); + if (contents && contents.trim().length > 0) { + return contents.trim(); + } + return undefined; +} diff --git a/core/core.ts b/core/core.ts index 7316f9aac2..dde3091ee9 100644 --- a/core/core.ts +++ b/core/core.ts @@ -44,7 +44,7 @@ import { getSymbolsForManyFiles } from "./util/treeSitter"; import { TTS } from "./util/tts"; import { type ContextItemId, type IDE, type IndexingProgressUpdate } from "."; - +import { usePlatform } from "./control-plane/flags"; import type { FromCoreProtocol, ToCoreProtocol } from "./protocol"; import type { IMessenger, Message } from "./protocol/messenger"; @@ -103,6 +103,7 @@ export class Core { const ideSettingsPromise = messenger.request("getIdeSettings", undefined); const sessionInfoPromise = messenger.request("getControlPlaneSessionInfo", { silent: true, + useOnboarding: usePlatform(), }); this.controlPlaneClient = new ControlPlaneClient(sessionInfoPromise); @@ -124,7 +125,7 @@ export class Core { const serializedResult = await this.configHandler.getSerializedConfig(); this.messenger.send("configUpdate", { result: serializedResult, - profileId: this.configHandler.currentProfile.profileId, + profileId: this.configHandler.currentProfile.profileDescription.id, }); }); @@ -378,7 +379,7 @@ export class Core { on("config/getSerializedProfileInfo", async (msg) => { return { result: await this.configHandler.getSerializedConfig(), - profileId: this.configHandler.currentProfile.profileId, + profileId: this.configHandler.currentProfile.profileDescription.id, }; }); diff --git a/core/index.d.ts b/core/index.d.ts index cd6510f197..518e301d96 100644 --- a/core/index.d.ts +++ b/core/index.d.ts @@ -460,6 +460,7 @@ export interface LLMOptions { writeLog?: (str: string) => Promise; llmRequestHook?: (model: string, prompt: string) => any; apiKey?: string; + apiKeySecret?: string; aiGatewaySlug?: string; apiBase?: string; cacheBehavior?: CacheBehavior; @@ -683,6 +684,11 @@ export interface IDE { getGitHubAuthToken(args: GetGhTokenArgs): Promise; + // Secret Storage + readSecrets(keys: string[]): Promise>; + + writeSecrets(secrets: { [key: string]: string }): Promise; + // LSP gotoDefinition(location: Location): Promise; @@ -854,6 +860,13 @@ export interface Tool { uri?: string; } +interface ToolChoice { + type: "function"; + function: { + name: string; + }; +} + export interface BaseCompletionOptions { temperature?: number; topP?: number; @@ -871,6 +884,7 @@ export interface BaseCompletionOptions { stream?: boolean; prediction?: Prediction; tools?: Tool[]; + toolChoice?: ToolChoice; } export interface ModelCapability { @@ -882,6 +896,7 @@ export interface ModelDescription { provider: string; model: string; apiKey?: string; + apiKeySecret?: string; apiBase?: string; contextLength?: number; maxStopWords?: number; @@ -1198,6 +1213,7 @@ export interface BrowserSerializedContinueConfig { analytics?: AnalyticsConfig; docs?: SiteIndexingConfig[]; tools: Tool[]; + usePlatform: boolean; } // DOCS SUGGESTIONS AND PACKAGE INFO @@ -1235,9 +1251,9 @@ export type PackageDocsResult = { } & ( | { error: string; details?: never } | { details: PackageDetailsSuccess; error?: never } - ); +); export interface TerminalOptions { - reuseTerminal?: boolean, - terminalName?: string, + reuseTerminal?: boolean; + terminalName?: string; } diff --git a/core/indexing/docs/DocsService.ts b/core/indexing/docs/DocsService.ts index 105d24b9e9..cab6c35dad 100644 --- a/core/indexing/docs/DocsService.ts +++ b/core/indexing/docs/DocsService.ts @@ -25,7 +25,7 @@ import { } from "../../util/paths"; import { Telemetry } from "../../util/posthog"; -import { ConfigResult } from "../../config/load"; +import { ConfigResult } from "@continuedev/config-yaml"; import { Article, chunkArticle, pageToArticle } from "./article"; import DocsCrawler from "./DocsCrawler"; import { runLanceMigrations, runSqliteMigrations } from "./migrations"; diff --git a/core/llm/index.ts b/core/llm/index.ts index 11b66ce8d3..5ecfa33ffa 100644 --- a/core/llm/index.ts +++ b/core/llm/index.ts @@ -117,6 +117,7 @@ export abstract class BaseLLM implements ILLM { writeLog?: (str: string) => Promise; llmRequestHook?: (model: string, prompt: string) => any; apiKey?: string; + apiKeySecret?: string; apiBase?: string; cacheBehavior?: CacheBehavior; capabilities?: ModelCapability; @@ -194,6 +195,7 @@ export abstract class BaseLLM implements ILLM { this.writeLog = options.writeLog; this.llmRequestHook = options.llmRequestHook; this.apiKey = options.apiKey; + this.apiKeySecret = options.apiKeySecret; this.aiGatewaySlug = options.aiGatewaySlug; this.apiBase = options.apiBase; this.cacheBehavior = options.cacheBehavior; diff --git a/core/llm/llm.test.ts b/core/llm/llm.test.ts index cb2218d1df..09777dae00 100644 --- a/core/llm/llm.test.ts +++ b/core/llm/llm.test.ts @@ -1,13 +1,17 @@ import * as dotenv from "dotenv"; -import { CompletionOptions } from ".."; +import { AssistantChatMessage, CompletionOptions } from ".."; + import { BaseLLM } from "."; +import Anthropic from "./llms/Anthropic"; +import Gemini from "./llms/Gemini"; +import Mistral from "./llms/Mistral"; import OpenAI from "./llms/OpenAI"; dotenv.config(); const COMPLETION_OPTIONS: Partial = { - temperature: 0.5, + // temperature: 0.5, topP: 1, topK: 40, presencePenalty: 0, @@ -16,12 +20,23 @@ const COMPLETION_OPTIONS: Partial = { // maxTokens: 5, }; -function testLLM(llm: BaseLLM) { +function testLLM( + llm: BaseLLM, + { + skip, + testFim, + testToolCall, + }: { skip?: boolean; testFim?: boolean; testToolCall?: boolean }, +) { + if (skip) { + return; + } + beforeAll(() => { llm.completionOptions = { ...llm.completionOptions, ...COMPLETION_OPTIONS }; }); - describe(llm.providerName, () => { + describe(llm.providerName + "/" + llm.model, () => { test("Stream Chat works", async () => { let total = ""; for await (const chunk of llm.streamChat( @@ -54,62 +69,119 @@ function testLLM(llm: BaseLLM) { expect(completion.length).toBeGreaterThan(0); return; }); + + if (testFim) { + test("FIM works", async () => { + let total = ""; + for await (const chunk of llm.streamFim( + "Hi", + "name is ChatGPT.", + new AbortController().signal, + )) { + total += chunk; + } + + expect(total.length).toBeGreaterThan(0); + return; + }); + } + + if (testToolCall) { + test("Tool Call works", async () => { + let args = ""; + let isFirstChunk = true; + for await (const chunk of llm.streamChat( + [{ role: "user", content: "Hi, my name is Nate." }], + new AbortController().signal, + { + tools: [ + { + displayTitle: "Say Hello", + function: { + name: "say_hello", + description: "Say Hello", + parameters: { + type: "object", + properties: { + name: { + type: "string", + description: "The name of the person to greet", + }, + }, + }, + }, + type: "function", + wouldLikeTo: "Say hello", + readonly: true, + }, + ], + toolChoice: { + type: "function", + function: { + name: "say_hello", + }, + }, + }, + )) { + const typedChunk = chunk as AssistantChatMessage; + if (!typedChunk.toolCalls) { + continue; + } + const toolCall = typedChunk.toolCalls[0]; + args += toolCall.function?.arguments ?? ""; + + expect(chunk.role).toBe("assistant"); + expect(chunk.content).toBe(""); + expect(typedChunk.toolCalls).toHaveLength(1); + + if (isFirstChunk) { + isFirstChunk = false; + expect(toolCall.id).toBeDefined(); + expect(toolCall.function!.name).toBe("say_hello"); + } + } + + const parsedArgs = JSON.parse(args); + expect(parsedArgs.name).toBe("Nate"); + }); + } }); } -describe.skip("LLM", () => { - // testLLM( - // new FreeTrial({ - // model: "gpt-3.5-turbo", - // }), - // ); - // testLLM( - // new Anthropic({ - // model: "claude-2", - // apiKey: process.env.ANTHROPIC_API_KEY, - // }), - // ); +describe("LLM", () => { + testLLM( + new Anthropic({ + model: "claude-3-5-sonnet-latest", + apiKey: process.env.ANTHROPIC_API_KEY, + }), + { + skip: false, + testToolCall: true, + }, + ); + testLLM(new OpenAI({ apiKey: process.env.OPENAI_API_KEY, model: "gpt-4o" }), { + skip: false, + testToolCall: true, + }); + testLLM( + new OpenAI({ apiKey: process.env.OPENAI_API_KEY, model: "o1-preview" }), + { skip: false }, + ); + testLLM(new OpenAI({ apiKey: process.env.OPENAI_API_KEY, model: "o1" }), { + skip: false, + }); + testLLM( + new Gemini({ + model: "gemini-2.0-flash-exp", + apiKey: process.env.GEMINI_API_KEY, + }), + { skip: false }, + ); testLLM( - new OpenAI({ apiKey: process.env.OPENAI_API_KEY, model: "gpt-3.5-turbo" }), + new Mistral({ + apiKey: process.env.MISTRAL_API_KEY, + model: "codestral-latest", + }), + { testFim: true, skip: false }, ); - // TODO: Fix Replicate - // testLLM( - // new Replicate({ - // apiKey: process.env.REPLICATE_API_KEY, - // model: "codellama-7b", - // }) - // ); - // testLLM(new LMStudio({ model: "codellama-7b" })); - // testLLM(new Ollama({ model: "codellama-7b" })); - // testLLM( - // new Together({ - // apiKey: process.env.TOGETHER_API_KEY, - // model: "codellama-7b", - // }) - // ); - // testLLM(new LlamaCpp({ model: "deepseek-7b" })); - // testLLM(new Llamafile({ model: "mistral-7b" })); - // TODO: Test these - // testLLM(new TextGenWebUI({ model: "codellama-7b" })); - // testLLM(new HuggingFaceTGI({ model: "codellama-7b" })); - // testLLM(new HuggingFaceInferenceAPI({ model: "codellama-7b" })); - // testLLM( - // new Gemini({ - // model: "gemini-pro", - // // model: "chat-bison-001", - // apiKey: process.env.GOOGLE_PALM_API_KEY, - // }) - // ); - // testLLM( - // new Gemini({ model: "gemini-pro", apiKey: process.env.GOOGLE_PALM_API_KEY }) - // ); - // testLLM( - // new Mistral({ apiKey: process.env.MISTRAL_API_KEY, model: "mistral-small" }) - // ); - // testLLM( - // new Flowise({ apiKey: process.env.FLOWISE_API_KEY, model: "gpt-3.5-turbo" }) - // ); - // testLLM( - // new Nebius({ apiKey: process.env.NEBIUS_API_KEY, model: "llama3.1-8b" }) - // ); }); diff --git a/core/llm/llms/Anthropic.ts b/core/llm/llms/Anthropic.ts index 234c168600..df527d579a 100644 --- a/core/llm/llms/Anthropic.ts +++ b/core/llm/llms/Anthropic.ts @@ -30,6 +30,12 @@ class Anthropic extends BaseLLM { description: tool.function.description, input_schema: tool.function.parameters, })), + tool_choice: options.toolChoice + ? { + type: "tool", + name: options.toolChoice.function.name, + } + : undefined, }; return finalOptions; diff --git a/core/llm/llms/Mock.ts b/core/llm/llms/Mock.ts index 236af9824a..03e2680747 100644 --- a/core/llm/llms/Mock.ts +++ b/core/llm/llms/Mock.ts @@ -1,7 +1,7 @@ import { ChatMessage, CompletionOptions, LLMOptions } from "../../index.js"; import { BaseLLM } from "../index.js"; -type MockMessage = ChatMessage | "REPEAT_LAST_MSG"; +type MockMessage = ChatMessage | "REPEAT_LAST_MSG" | "REPEAT_SYSTEM_MSG"; class MockLLM extends BaseLLM { public completion: string = "Test Completion"; @@ -42,6 +42,13 @@ class MockLLM extends BaseLLM { content: messages[messages.length - 1].content, }; break; + case "REPEAT_SYSTEM_MSG": + yield { + role: "assistant", + content: + messages.find((m) => m.role === "system")?.content || "", + }; + break; default: yield message; } diff --git a/core/llm/llms/OpenAI.ts b/core/llm/llms/OpenAI.ts index 1cf1bc44e7..41b3c755be 100644 --- a/core/llm/llms/OpenAI.ts +++ b/core/llm/llms/OpenAI.ts @@ -113,6 +113,10 @@ class OpenAI extends BaseLLM { }; } + protected extraBodyProperties(): Record { + return {}; + } + protected getMaxStopWords(): number { const url = new URL(this.apiBase!); @@ -150,6 +154,10 @@ class OpenAI extends BaseLLM { finalOptions.messages = formatMessageForO1(finalOptions.messages); } + if (options.model === "o1") { + finalOptions.stream = false; + } + if (options.prediction && this.supportsPrediction(options.model)) { if (finalOptions.presence_penalty) { // prediction doesn't support > 0 @@ -241,6 +249,11 @@ class OpenAI extends BaseLLM { body.messages = formatMessageForO1(body.messages); } + if (body.model === "o1") { + // o1 doesn't support streaming + body.stream = false; + } + if (body.prediction && this.supportsPrediction(body.model)) { if (body.presence_penalty) { // prediction doesn't support > 0 @@ -278,6 +291,7 @@ class OpenAI extends BaseLLM { body: JSON.stringify({ ...args, stream: true, + ...this.extraBodyProperties(), }), signal, }); @@ -319,7 +333,10 @@ class OpenAI extends BaseLLM { const response = await this.fetch(this._getEndpoint("chat/completions"), { method: "POST", headers: this._getHeaders(), - body: JSON.stringify(body), + body: JSON.stringify({ + ...body, + ...this.extraBodyProperties(), + }), signal, }); @@ -358,6 +375,7 @@ class OpenAI extends BaseLLM { presence_penalty: options.presencePenalty, stop: options.stop, stream: true, + ...this.extraBodyProperties(), }), headers: { "Content-Type": "application/json", @@ -404,6 +422,7 @@ class OpenAI extends BaseLLM { body: JSON.stringify({ input: chunks, model: this.model, + ...this.extraBodyProperties(), }), headers: { Authorization: `Bearer ${this.apiKey}`, diff --git a/core/llm/llms/stubs/ContinueProxy.ts b/core/llm/llms/stubs/ContinueProxy.ts index c037ebd61d..59d05b115a 100644 --- a/core/llm/llms/stubs/ContinueProxy.ts +++ b/core/llm/llms/stubs/ContinueProxy.ts @@ -10,11 +10,33 @@ class ContinueProxy extends OpenAI { this.apiBase = new URL("openai/v1/", value.controlPlaneProxyUrl).toString(); } + // The apiKey and apiBase are set to the values for the proxy, + // but we need to keep track of the actual values that the proxy will use + // to call whatever LLM API is chosen + private actualApiBase?: string; + private actualApiKey?: string; + + constructor(options: LLMOptions) { + super(options); + this.actualApiBase = options.apiBase; + this.actualApiKey = options.apiKey; + } + static providerName = "continue-proxy"; static defaultOptions: Partial = { useLegacyCompletionsEndpoint: false, }; + protected extraBodyProperties(): Record { + return { + continueProperties: { + apiKey: this.actualApiKey, + apiBase: this.actualApiBase, + apiKeySecret: this.apiKeySecret, + }, + }; + } + protected _getHeaders() { const headers: any = super._getHeaders(); headers["x-continue-unique-id"] = Telemetry.uniqueId; diff --git a/core/llm/openaiTypeConverters.ts b/core/llm/openaiTypeConverters.ts index 057219beb8..839efda4e7 100644 --- a/core/llm/openaiTypeConverters.ts +++ b/core/llm/openaiTypeConverters.ts @@ -80,6 +80,7 @@ export function toChatBody( stream: options.stream ?? true, stop: options.stop, prediction: options.prediction, + tool_choice: options.toolChoice, }; if (options.tools?.length) { diff --git a/core/package-lock.json b/core/package-lock.json index 18a318596a..350607e4d2 100644 --- a/core/package-lock.json +++ b/core/package-lock.json @@ -13,7 +13,7 @@ "@aws-sdk/client-sagemaker-runtime": "^3.621.0", "@aws-sdk/credential-providers": "^3.620.1", "@continuedev/config-types": "^1.0.13", - "@continuedev/config-yaml": "^1.0.0", + "@continuedev/config-yaml": "^1.0.11", "@continuedev/fetch": "^1.0.4", "@continuedev/llm-info": "^1.0.2", "@continuedev/openai-adapters": "^1.0.10", @@ -3030,18 +3030,19 @@ } }, "node_modules/@continuedev/config-types": { - "version": "1.0.13", - "resolved": "https://registry.npmjs.org/@continuedev/config-types/-/config-types-1.0.13.tgz", - "integrity": "sha512-5Qmy2x58uhW1GYknypFT/7iwxEJSbt8gMgfQReV7d6xjbJ09ISOTki3hIfCilKrTCo0gIY6pzKEfCoQkUTEi6Q==", + "version": "1.0.14", + "resolved": "https://registry.npmjs.org/@continuedev/config-types/-/config-types-1.0.14.tgz", + "integrity": "sha512-PVHyHPyRXd2QsaNgnCpiKYU3uHFTlyuQSkqE8OwrBmQqO6/TXUVIr/2EGtyIZGrml4Y+rGMSH40WU4/0t4SGpQ==", "dependencies": { "zod": "^3.23.8" } }, "node_modules/@continuedev/config-yaml": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/@continuedev/config-yaml/-/config-yaml-1.0.0.tgz", - "integrity": "sha512-NwagzzpnPLbMC3delzFSjWiFv5fp8ReFjdMQF0ko0Xg+U0Myja90k/J7JWGegF/SZMABxc/OJchElJFr58TwGA==", + "version": "1.0.11", + "resolved": "https://registry.npmjs.org/@continuedev/config-yaml/-/config-yaml-1.0.11.tgz", + "integrity": "sha512-E3RBQfNEPBGBmlnAbCXgeAasDzTjo4ON/HH0hr5g292i+WdAN3i/omjQ6Iusx00L1Fz7klZGJePZ3GVQKOGEUg==", "dependencies": { + "@continuedev/config-types": "^1.0.14", "yaml": "^2.6.1", "zod": "^3.23.8" } diff --git a/core/package.json b/core/package.json index e865e0ddbf..0e107f9916 100644 --- a/core/package.json +++ b/core/package.json @@ -46,7 +46,7 @@ "@aws-sdk/client-sagemaker-runtime": "^3.621.0", "@aws-sdk/credential-providers": "^3.620.1", "@continuedev/config-types": "^1.0.13", - "@continuedev/config-yaml": "^1.0.0", + "@continuedev/config-yaml": "^1.0.11", "@continuedev/fetch": "^1.0.4", "@continuedev/llm-info": "^1.0.2", "@continuedev/openai-adapters": "^1.0.10", diff --git a/core/protocol/core.ts b/core/protocol/core.ts index 9495dc3b2a..27f1c92ad1 100644 --- a/core/protocol/core.ts +++ b/core/protocol/core.ts @@ -1,6 +1,7 @@ import { AutocompleteInput } from "../autocomplete/util/types"; import { ProfileDescription } from "../config/ConfigHandler"; +import { ConfigResult } from "@continuedev/config-yaml"; import type { BrowserSerializedContinueConfig, ChatMessage, @@ -21,7 +22,6 @@ import type { SiteIndexingConfig, ToolCall, } from "../"; -import { ConfigResult } from "../config/load"; export type ProtocolGeneratorYield = { done?: boolean; diff --git a/core/protocol/ide.ts b/core/protocol/ide.ts index 8c0d2557c5..5b2a7ab1fb 100644 --- a/core/protocol/ide.ts +++ b/core/protocol/ide.ts @@ -13,8 +13,8 @@ import type { Problem, Range, RangeInFile, + TerminalOptions, Thread, - TerminalOptions } from "../"; export interface GetGhTokenArgs { @@ -29,7 +29,7 @@ export type ToIdeFromWebviewOrCoreProtocol = { showVirtualFile: [{ name: string; content: string }, void]; openFile: [{ path: string }, void]; openUrl: [string, void]; - runCommand: [{ command: string, options?: TerminalOptions }, void]; + runCommand: [{ command: string; options?: TerminalOptions }, void]; getSearchResults: [{ query: string }, string]; subprocess: [{ command: string; cwd?: string }, [string, string]]; saveFile: [{ filepath: string }, void]; @@ -72,6 +72,8 @@ export type ToIdeFromWebviewOrCoreProtocol = { isTelemetryEnabled: [undefined, boolean]; getUniqueId: [undefined, string]; getTags: [string, IndexTag[]]; + readSecrets: [{ keys: string[] }, Record]; + writeSecrets: [{ secrets: Record }, void]; // end methods from IDE type getIdeSettings: [undefined, IdeSettings]; @@ -92,7 +94,7 @@ export type ToIdeFromWebviewOrCoreProtocol = { getGitHubAuthToken: [GetGhTokenArgs, string | undefined]; getControlPlaneSessionInfo: [ - { silent: boolean }, + { silent: boolean; useOnboarding: boolean }, ControlPlaneSessionInfo | undefined, ]; logoutOfControlPlane: [undefined, void]; diff --git a/core/protocol/messenger/messageIde.ts b/core/protocol/messenger/messageIde.ts index aaa567f391..4bb78e0ed6 100644 --- a/core/protocol/messenger/messageIde.ts +++ b/core/protocol/messenger/messageIde.ts @@ -29,6 +29,14 @@ export class MessageIde implements IDE { ) => void, ) {} + async readSecrets(keys: string[]): Promise> { + return this.request("readSecrets", { keys }); + } + + async writeSecrets(secrets: { [key: string]: string }): Promise { + return this.request("writeSecrets", { secrets }); + } + fileExists(fileUri: string): Promise { return this.request("fileExists", { filepath: fileUri }); } @@ -152,7 +160,7 @@ export class MessageIde implements IDE { await this.request("openUrl", url); } - async runCommand(command: string, options?: TerminalOptions): Promise { + async runCommand(command: string, options?: TerminalOptions): Promise { await this.request("runCommand", { command, options }); } diff --git a/core/protocol/webview.ts b/core/protocol/webview.ts index 5e22faae6a..822d9b532e 100644 --- a/core/protocol/webview.ts +++ b/core/protocol/webview.ts @@ -1,11 +1,8 @@ -import { ConfigResult } from "../config/load.js"; -import { ConfigValidationError } from "../config/validation.js"; - +import { ConfigResult, ConfigValidationError } from "@continuedev/config-yaml"; import type { BrowserSerializedContinueConfig, ContextItemWithId, ContextProviderName, - ContextSubmenuItem, IndexingProgressUpdate, IndexingStatus, PackageDocsResult, diff --git a/core/util/filesystem.ts b/core/util/filesystem.ts index b4297f7bc2..ccb550ec93 100644 --- a/core/util/filesystem.ts +++ b/core/util/filesystem.ts @@ -21,6 +21,13 @@ import { GetGhTokenArgs } from "../protocol/ide.js"; class FileSystemIde implements IDE { constructor(private readonly workspaceDir: string) {} + + async readSecrets(keys: string[]): Promise> { + return {}; + } + + async writeSecrets(secrets: { [key: string]: string }): Promise {} + showToast( type: ToastType, message: string, diff --git a/core/util/paths.ts b/core/util/paths.ts index 8dd0cf84fb..562d2e819c 100644 --- a/core/util/paths.ts +++ b/core/util/paths.ts @@ -301,11 +301,18 @@ export function getPathToRemoteConfig(remoteConfigServerUrl: string): string { return dir; } -export function internalBetaPathExists(): boolean { - const sPath = path.join(getContinueGlobalPath(), ".internal_beta"); +export function usePlatformPathExists(): boolean { + const sPath = path.join(getContinueGlobalPath(), ".use_platform"); return fs.existsSync(sPath); } +export function readUsePlatform(): string | undefined { + const sPath = path.join(getContinueGlobalPath(), ".use_platform"); + if (fs.existsSync(sPath)) { + return fs.readFileSync(sPath, "utf8"); + } +} + export function getConfigJsonPathForRemote( remoteConfigServerUrl: string, ): string { diff --git a/extensions/vscode/.gitignore b/extensions/vscode/.gitignore index f922ad032f..05afd11d44 100644 --- a/extensions/vscode/.gitignore +++ b/extensions/vscode/.gitignore @@ -23,4 +23,6 @@ e2e/.test-extensions e2e/storage e2e/vsix e2e/test-continue/* -!e2e/test-continue/config.json \ No newline at end of file +!e2e/test-continue/config.json +e2e/test-continue-yaml/* +!e2e/test-continue-yaml/config.yaml \ No newline at end of file diff --git a/extensions/vscode/e2e/selectors/GUI.selectors.ts b/extensions/vscode/e2e/selectors/GUI.selectors.ts index 16c7b39d93..71a38c9f28 100644 --- a/extensions/vscode/e2e/selectors/GUI.selectors.ts +++ b/extensions/vscode/e2e/selectors/GUI.selectors.ts @@ -1,4 +1,5 @@ import { By, WebDriver, WebView } from "vscode-extension-tester"; + import { SelectorUtils } from "./SelectorUtils"; export class GUISelectors { @@ -41,6 +42,24 @@ export class GUISelectors { return SelectorUtils.getElementByDataTestId(view, "model-select-button"); } + public static getFirstContextProviderDropdownItem(view: WebView) { + return SelectorUtils.getElementByDataTestId( + view, + "context-provider-dropdown-item", + ); + } + + public static getContextItemsPeek(view: WebView) { + return SelectorUtils.getElementByDataTestId(view, "context-items-peek"); + } + + public static getFirstContextItemsPeekItem(view: WebView) { + return SelectorUtils.getElementByDataTestId( + view, + "context-items-peek-item", + ); + } + public static getNthHistoryTableRow(view: WebView, index: number) { return SelectorUtils.getElementByDataTestId(view, `history-row-${index}`); } diff --git a/extensions/vscode/e2e/test-continue-yaml/config.yaml b/extensions/vscode/e2e/test-continue-yaml/config.yaml new file mode 100644 index 0000000000..487cdc8dd0 --- /dev/null +++ b/extensions/vscode/e2e/test-continue-yaml/config.yaml @@ -0,0 +1,76 @@ +name: Test Config +version: 0.0.1 + +rules: + - TEST_SYS_MSG + +models: + - name: TEST LLM + provider: test + model: this field is not used + roles: + - chat + - edit + - apply + - autocomplete + + - name: Mock + provider: mock + model: this field is not used + roles: + - chat + - edit + - apply + + - name: TOOL MOCK LLM + provider: mock + model: claude-3-5-sonnet-latest + requestOptions: + extraBodyProperties: + chatStream: + - - role: assistant + content: "I'm going to call a tool:" + - role: assistant + content: "" + toolCalls: + - id: test_id + type: function + function: + name: builtin_exact_search + arguments: '{"query": "test query"}' + - REPEAT_LAST_MSG + roles: + - chat + - edit + - apply + + - name: SYSTEM MESSAGE MOCK LLM + provider: mock + model: claude-3-5-sonnet-latest + requestOptions: + extraBodyProperties: + chatStream: + - - REPEAT_SYSTEM_MSG + roles: + - chat + - edit + - apply + + - name: LAST MESSAGE MOCK LLM + provider: mock + model: claude-3-5-sonnet-latest + requestOptions: + extraBodyProperties: + chatStream: + - - REPEAT_LAST_MSG + roles: + - chat + - edit + - apply + +context: + - uses: docs + - uses: diff + - uses: url + - uses: folder + - uses: terminal diff --git a/extensions/vscode/e2e/test-continue/config.json b/extensions/vscode/e2e/test-continue/config.json index a8932322a3..ad67f1999c 100644 --- a/extensions/vscode/e2e/test-continue/config.json +++ b/extensions/vscode/e2e/test-continue/config.json @@ -44,8 +44,29 @@ ] } } + }, + { + "provider": "mock", + "title": "SYSTEM MESSAGE MOCK LLM", + "model": "claude-3-5-sonnet-latest", + "requestOptions": { + "extraBodyProperties": { + "chatStream": [["REPEAT_SYSTEM_MSG"]] + } + } + }, + { + "provider": "mock", + "title": "LAST MESSAGE MOCK LLM", + "model": "claude-3-5-sonnet-latest", + "requestOptions": { + "extraBodyProperties": { + "chatStream": [["REPEAT_LAST_MSG"]] + } + } } ], + "systemMessage": "TEST_SYS_MSG", "analytics": { "provider": "continue-proxy" }, @@ -69,6 +90,9 @@ }, { "name": "folder" + }, + { + "name": "terminal" } ], "docs": [] diff --git a/extensions/vscode/e2e/tests/GUI.test.ts b/extensions/vscode/e2e/tests/GUI.test.ts index e2545148a9..316db2d0e3 100644 --- a/extensions/vscode/e2e/tests/GUI.test.ts +++ b/extensions/vscode/e2e/tests/GUI.test.ts @@ -207,6 +207,53 @@ describe("GUI Test", () => { }); }); + describe("Context providers", () => { + it("should successfully use the terminal context provider", async () => { + await GUIActions.selectModelFromDropdown(view, "LAST MESSAGE MOCK LLM"); + + // Enter just the context provider in the input and send + const [messageInput] = await GUISelectors.getMessageInputFields(view); + await messageInput.sendKeys("@"); + await messageInput.sendKeys("terminal"); + await messageInput.sendKeys(Key.ENTER); + await messageInput.sendKeys(Key.ENTER); + + // Open the context items peek + const contextItemsPeek = await GUISelectors.getContextItemsPeek(view); + await contextItemsPeek.click(); + + await TestUtils.waitForSuccess(async () => { + const firstContextItemInPeek = + await GUISelectors.getFirstContextItemsPeekItem(view); + await firstContextItemInPeek.click(); + + // Check that item is there with correct name + const description = await firstContextItemInPeek.getText(); + expect(description).to.include("Terminal"); + }); + + // Check that the contents match what we expect (repeated back by the mock LLM) + await TestUtils.waitForSuccess(() => { + return GUISelectors.getThreadMessageByText( + view, + "Current terminal contents:", + ); + }); + }).timeout(DEFAULT_TIMEOUT.MD); + }); + + describe("Repeat back the system message", () => { + it("should repeat back the system message", async () => { + await GUIActions.selectModelFromDropdown(view, "SYSTEM MESSAGE MOCK LLM"); + const [messageInput] = await GUISelectors.getMessageInputFields(view); + await messageInput.sendKeys("Hello"); + await messageInput.sendKeys(Key.ENTER); + await TestUtils.waitForSuccess(() => + GUISelectors.getThreadMessageByText(view, "TEST_SYS_MSG"), + ); + }); + }); + describe("Chat Paths", () => { it("Send many messages → chat auto scrolls → go to history → open previous chat → it is scrolled to the bottom", async () => { for (let i = 0; i <= 20; i++) { diff --git a/extensions/vscode/package-lock.json b/extensions/vscode/package-lock.json index 26d319d834..31f06cbb14 100644 --- a/extensions/vscode/package-lock.json +++ b/extensions/vscode/package-lock.json @@ -1,12 +1,12 @@ { "name": "continue", - "version": "0.9.251", + "version": "0.9.252", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "continue", - "version": "0.9.251", + "version": "0.9.252", "license": "Apache-2.0", "dependencies": { "@continuedev/fetch": "^1.0.3", @@ -106,7 +106,7 @@ "@aws-sdk/client-sagemaker-runtime": "^3.621.0", "@aws-sdk/credential-providers": "^3.620.1", "@continuedev/config-types": "^1.0.13", - "@continuedev/config-yaml": "^1.0.0", + "@continuedev/config-yaml": "^1.0.11", "@continuedev/fetch": "^1.0.4", "@continuedev/llm-info": "^1.0.2", "@continuedev/openai-adapters": "^1.0.10", diff --git a/extensions/vscode/package.json b/extensions/vscode/package.json index fa8b206896..edb58a5543 100644 --- a/extensions/vscode/package.json +++ b/extensions/vscode/package.json @@ -2,7 +2,7 @@ "name": "continue", "icon": "media/icon.png", "author": "Continue Dev, Inc", - "version": "0.9.251", + "version": "0.9.252", "repository": { "type": "git", "url": "https://github.com/continuedev/continue" @@ -628,7 +628,8 @@ "e2e:all": "npm run e2e:build && npm run e2e:compile && npm run e2e:create-storage && npm run e2e:get-chromedriver && npm run e2e:get-vscode && npm run e2e:sign-vscode && npm run e2e:copy-vsix && npm run e2e:install-vsix && npm run e2e:install-extensions && CONTINUE_GLOBAL_DIR=e2e/test-continue npm run e2e:test && npm run e2e:clean", "e2e:quick": "npm run e2e:compile && CONTINUE_GLOBAL_DIR=e2e/test-continue npm run e2e:test && npm run e2e:clean", "e2e:ci:download": "npm run e2e:create-storage && npm run e2e:get-chromedriver && npm run e2e:get-vscode", - "e2e:ci:run": "npm run e2e:compile && npm run e2e:copy-vsix && npm run e2e:install-vsix && npm run e2e:install-extensions && CONTINUE_GLOBAL_DIR=e2e/test-continue npm run e2e:test" + "e2e:ci:run": "npm run e2e:compile && npm run e2e:copy-vsix && npm run e2e:install-vsix && npm run e2e:install-extensions && CONTINUE_GLOBAL_DIR=e2e/test-continue npm run e2e:test", + "e2e:ci:run-yaml": "npm run e2e:compile && npm run e2e:copy-vsix && npm run e2e:install-vsix && npm run e2e:install-extensions && CONTINUE_GLOBAL_DIR=e2e/test-continue-yaml npm run e2e:test" }, "devDependencies": { "@biomejs/biome": "1.6.4", diff --git a/extensions/vscode/src/VsCodeIde.ts b/extensions/vscode/src/VsCodeIde.ts index 2cec274e9b..51008db8ad 100644 --- a/extensions/vscode/src/VsCodeIde.ts +++ b/extensions/vscode/src/VsCodeIde.ts @@ -28,15 +28,41 @@ import type { TerminalOptions, Thread, } from "core"; +import { SecretStorage } from "./stubs/SecretStorage"; class VsCodeIde implements IDE { ideUtils: VsCodeIdeUtils; + secretStorage: SecretStorage; constructor( private readonly vscodeWebviewProtocolPromise: Promise, private readonly context: vscode.ExtensionContext, ) { this.ideUtils = new VsCodeIdeUtils(); + this.secretStorage = new SecretStorage(context); + } + + async readSecrets(keys: string[]): Promise> { + const secretValuePromises = keys.map((key) => this.secretStorage.get(key)); + const secretValues = await Promise.all(secretValuePromises); + + return keys.reduce( + (acc, key, index) => { + if (secretValues[index] === undefined) { + return acc; + } + + acc[key] = secretValues[index]; + return acc; + }, + {} as Record, + ); + } + + async writeSecrets(secrets: { [key: string]: string }): Promise { + for (const [key, value] of Object.entries(secrets)) { + await this.secretStorage.store(key, value); + } } async fileExists(uri: string): Promise { @@ -399,17 +425,22 @@ class VsCodeIde implements IDE { ); } - async runCommand(command: string, options: TerminalOptions = {reuseTerminal: true}): Promise { + async runCommand( + command: string, + options: TerminalOptions = { reuseTerminal: true }, + ): Promise { let terminal: vscode.Terminal | undefined; if (vscode.window.terminals.length && options.reuseTerminal) { if (options.terminalName) { - terminal = vscode.window.terminals.find(t => t?.name === options.terminalName); + terminal = vscode.window.terminals.find( + (t) => t?.name === options.terminalName, + ); } else { terminal = vscode.window.activeTerminal ?? vscode.window.terminals[0]; } } - if( !terminal) { + if (!terminal) { terminal = vscode.window.createTerminal(options?.terminalName); } terminal.show(); diff --git a/extensions/vscode/src/extension/VsCodeExtension.ts b/extensions/vscode/src/extension/VsCodeExtension.ts index 4ff02a736d..9e6da728db 100644 --- a/extensions/vscode/src/extension/VsCodeExtension.ts +++ b/extensions/vscode/src/extension/VsCodeExtension.ts @@ -171,7 +171,7 @@ export class VsCodeExtension { const result = await this.configHandler.getSerializedConfig(); this.sidebar.webviewProtocol?.request("configUpdate", { result, - profileId: this.configHandler.currentProfile.profileId, + profileId: this.configHandler.currentProfile.profileDescription.id, }); this.tabAutocompleteModel.clearLlm(); @@ -301,7 +301,7 @@ export class VsCodeExtension { true, ); - const sessionInfo = await getControlPlaneSessionInfo(true); + const sessionInfo = await getControlPlaneSessionInfo(true, false); this.webviewProtocolPromise.then(async (webviewProtocol) => { void webviewProtocol.request("didChangeControlPlaneSessionInfo", { sessionInfo, diff --git a/extensions/vscode/src/extension/VsCodeMessenger.ts b/extensions/vscode/src/extension/VsCodeMessenger.ts index ce388a4ff2..6bce4818a5 100644 --- a/extensions/vscode/src/extension/VsCodeMessenger.ts +++ b/extensions/vscode/src/extension/VsCodeMessenger.ts @@ -446,7 +446,10 @@ export class VsCodeMessenger { ide.getGitHubAuthToken(msg.data), ); this.onWebviewOrCore("getControlPlaneSessionInfo", async (msg) => { - return getControlPlaneSessionInfo(msg.data.silent); + return getControlPlaneSessionInfo( + msg.data.silent, + msg.data.useOnboarding, + ); }); this.onWebviewOrCore("logoutOfControlPlane", async (msg) => { const sessions = await this.workOsAuthProvider.getSessions(); diff --git a/extensions/vscode/src/stubs/WorkOsAuthProvider.ts b/extensions/vscode/src/stubs/WorkOsAuthProvider.ts index 486b0df12e..a22c4934a2 100644 --- a/extensions/vscode/src/stubs/WorkOsAuthProvider.ts +++ b/extensions/vscode/src/stubs/WorkOsAuthProvider.ts @@ -1,3 +1,7 @@ +import crypto from "crypto"; + +import { ControlPlaneSessionInfo } from "core/control-plane/client"; +import { controlPlaneEnv } from "core/control-plane/env"; import fetch from "node-fetch"; import { v4 as uuidv4 } from "uuid"; import { @@ -16,6 +20,7 @@ import { } from "vscode"; import { PromiseAdapter, promiseFromEvent } from "./promiseUtils"; +import { SecretStorage } from "./SecretStorage"; const AUTH_NAME = "Continue"; @@ -26,14 +31,6 @@ class UriEventHandler extends EventEmitter implements UriHandler { this.fire(uri); } } - -import { ControlPlaneSessionInfo } from "core/control-plane/client"; -import { controlPlaneEnv } from "core/control-plane/env"; - -import crypto from "crypto"; - -import { SecretStorage } from "./SecretStorage"; - // Function to generate a random string of specified length function generateRandomString(length: number): string { const possibleCharacters = @@ -163,12 +160,6 @@ export class WorkOsAuthProvider implements AuthenticationProvider, Disposable { } get ideRedirectUri() { - const publisher = this.context.extension.packageJSON.publisher; - const name = this.context.extension.packageJSON.name; - return `${env.uriScheme}://${publisher}.${name}`; - } - - get redirectUri() { if ( env.uriScheme === "vscode-insiders" || env.uriScheme === "vscode" || @@ -179,6 +170,18 @@ export class WorkOsAuthProvider implements AuthenticationProvider, Disposable { url.pathname = `/auth/${env.uriScheme}-redirect`; return url.toString(); } + const publisher = this.context.extension.packageJSON.publisher; + const name = this.context.extension.packageJSON.name; + return `${env.uriScheme}://${publisher}.${name}`; + } + + public static useOnboardingUri: boolean = false; + get redirectUri() { + if (WorkOsAuthProvider.useOnboardingUri) { + const url = new URL(controlPlaneEnv.APP_URL); + url.pathname = `/onboarding/redirect/${env.uriScheme}`; + return url.toString(); + } return this.ideRedirectUri; } @@ -407,8 +410,9 @@ export class WorkOsAuthProvider implements AuthenticationProvider, Disposable { try { return await Promise.race([ codeExchangePromise.promise, - new Promise((_, reject) => - setTimeout(() => reject("Cancelled"), 15 * 60 * 1_000), + new Promise( + (_, reject) => + setTimeout(() => reject("Cancelled"), 60 * 60 * 1_000), // 60min timeout ), promiseFromEvent( token.onCancellationRequested, @@ -488,20 +492,29 @@ export class WorkOsAuthProvider implements AuthenticationProvider, Disposable { export async function getControlPlaneSessionInfo( silent: boolean, + useOnboarding: boolean, ): Promise { - const session = await authentication.getSession( - controlPlaneEnv.AUTH_TYPE, - [], - silent ? { silent: true } : { createIfNone: true }, - ); - if (!session) { - return undefined; + try { + if (useOnboarding) { + WorkOsAuthProvider.useOnboardingUri = true; + } + + const session = await authentication.getSession( + controlPlaneEnv.AUTH_TYPE, + [], + silent ? { silent: true } : { createIfNone: true }, + ); + if (!session) { + return undefined; + } + return { + accessToken: session.accessToken, + account: { + id: session.account.id, + label: session.account.label, + }, + }; + } finally { + WorkOsAuthProvider.useOnboardingUri = false; } - return { - accessToken: session.accessToken, - account: { - id: session.account.id, - label: session.account.label, - }, - }; } diff --git a/gui/package-lock.json b/gui/package-lock.json index ac7f6ce819..d99d5ebab7 100644 --- a/gui/package-lock.json +++ b/gui/package-lock.json @@ -7,7 +7,7 @@ "name": "gui", "license": "Apache-2.0", "dependencies": { - "@continuedev/config-yaml": "^1.0.0", + "@continuedev/config-yaml": "^1.0.11", "@headlessui/react": "^1.7.17", "@heroicons/react": "^2.0.18", "@reduxjs/toolkit": "^2.3.0", @@ -108,7 +108,7 @@ "@aws-sdk/client-sagemaker-runtime": "^3.621.0", "@aws-sdk/credential-providers": "^3.620.1", "@continuedev/config-types": "^1.0.13", - "@continuedev/config-yaml": "^1.0.0", + "@continuedev/config-yaml": "^1.0.11", "@continuedev/fetch": "^1.0.4", "@continuedev/llm-info": "^1.0.2", "@continuedev/openai-adapters": "^1.0.10", @@ -547,11 +547,20 @@ "integrity": "sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==", "dev": true }, + "node_modules/@continuedev/config-types": { + "version": "1.0.14", + "resolved": "https://registry.npmjs.org/@continuedev/config-types/-/config-types-1.0.14.tgz", + "integrity": "sha512-PVHyHPyRXd2QsaNgnCpiKYU3uHFTlyuQSkqE8OwrBmQqO6/TXUVIr/2EGtyIZGrml4Y+rGMSH40WU4/0t4SGpQ==", + "dependencies": { + "zod": "^3.23.8" + } + }, "node_modules/@continuedev/config-yaml": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/@continuedev/config-yaml/-/config-yaml-1.0.0.tgz", - "integrity": "sha512-NwagzzpnPLbMC3delzFSjWiFv5fp8ReFjdMQF0ko0Xg+U0Myja90k/J7JWGegF/SZMABxc/OJchElJFr58TwGA==", + "version": "1.0.11", + "resolved": "https://registry.npmjs.org/@continuedev/config-yaml/-/config-yaml-1.0.11.tgz", + "integrity": "sha512-E3RBQfNEPBGBmlnAbCXgeAasDzTjo4ON/HH0hr5g292i+WdAN3i/omjQ6Iusx00L1Fz7klZGJePZ3GVQKOGEUg==", "dependencies": { + "@continuedev/config-types": "^1.0.14", "yaml": "^2.6.1", "zod": "^3.23.8" } diff --git a/gui/package.json b/gui/package.json index 58a4529d8a..bd58519fb0 100644 --- a/gui/package.json +++ b/gui/package.json @@ -15,7 +15,7 @@ "test:watch": "vitest" }, "dependencies": { - "@continuedev/config-yaml": "^1.0.0", + "@continuedev/config-yaml": "^1.0.11", "@headlessui/react": "^1.7.17", "@heroicons/react": "^2.0.18", "@reduxjs/toolkit": "^2.3.0", diff --git a/gui/src/components/AccountDialog.tsx b/gui/src/components/AccountDialog.tsx index c555205747..2b8e17ae4c 100644 --- a/gui/src/components/AccountDialog.tsx +++ b/gui/src/components/AccountDialog.tsx @@ -1,6 +1,7 @@ import { Listbox, Transition } from "@headlessui/react"; import { ChevronUpDownIcon } from "@heroicons/react/24/outline"; import { Fragment, useContext } from "react"; +import { useDispatch } from "react-redux"; import styled from "styled-components"; import { Button, @@ -9,10 +10,9 @@ import { vscListActiveBackground, vscListActiveForeground, } from "."; -import { IdeMessengerContext } from "../context/IdeMessenger"; import { useAuth } from "../context/Auth"; +import { IdeMessengerContext } from "../context/IdeMessenger"; import { setSelectedProfileId } from "../redux/slices/sessionSlice"; -import { useDispatch } from "react-redux"; import { setDialogMessage, setShowDialog } from "../redux/slices/uiSlice"; const StyledListboxOption = styled(Listbox.Option)<{ selected: boolean }>` @@ -46,7 +46,7 @@ export default function AccountDialog() { Signed out - diff --git a/gui/src/components/OnboardingCard/platform/PlatformOnboardingCard.tsx b/gui/src/components/OnboardingCard/platform/PlatformOnboardingCard.tsx new file mode 100644 index 0000000000..360282e4a7 --- /dev/null +++ b/gui/src/components/OnboardingCard/platform/PlatformOnboardingCard.tsx @@ -0,0 +1,63 @@ +import { XMarkIcon } from "@heroicons/react/24/outline"; +import { useState } from "react"; +import styled from "styled-components"; +import { CloseButton, defaultBorderRadius, vscInputBackground } from "../.."; +import { getLocalStorage, setLocalStorage } from "../../../util/localStorage"; +import Alert from "../../gui/Alert"; +import { TabTitle } from "../components/OnboardingCardTabs"; +import { useOnboardingCard } from "../hooks"; +import OnboardingLocalTab from "../tabs/OnboardingLocalTab"; +import MainTab from "./tabs/main"; + +const StyledCard = styled.div` + margin: auto; + border-radius: ${defaultBorderRadius}; + background-color: ${vscInputBackground}; + box-shadow: + 0 20px 25px -5px rgb(0 0 0 / 0.1), + 0 8px 10px -6px rgb(0 0 0 / 0.1); +`; + +export interface OnboardingCardState { + show?: boolean; + activeTab?: TabTitle; +} + +export function PlatformOnboardingCard() { + const onboardingCard = useOnboardingCard(); + + if (getLocalStorage("onboardingStatus") === undefined) { + setLocalStorage("onboardingStatus", "Started"); + } + + const [currentTab, setCurrentTab] = useState<"main" | "local">("main"); + + return ( + + + + +
+
+ {currentTab === "main" ? ( + setCurrentTab("local")} /> + ) : ( +
+ + By choosing this option, Continue will be configured by a local{" "} + config.yaml file. If you're just looking to use + Ollama and still want to manage your configuration through + Continue, click{" "} + setCurrentTab("main")}> + here + + + + +
+ )} +
+
+
+ ); +} diff --git a/gui/src/components/OnboardingCard/platform/tabs/local.tsx b/gui/src/components/OnboardingCard/platform/tabs/local.tsx new file mode 100644 index 0000000000..ae30085ecc --- /dev/null +++ b/gui/src/components/OnboardingCard/platform/tabs/local.tsx @@ -0,0 +1,26 @@ +import { useAuth } from "../../../../context/Auth"; +import ContinueLogo from "../../../gui/ContinueLogo"; +import { useOnboardingCard } from "../../hooks"; + +export default function LocalTab() { + const onboardingCard = useOnboardingCard(); + const auth = useAuth(); + + return ( +
+
+ +
+ +

+ Log in to quickly build your first custom AI code assistant +

+ + {/*

+ To prevent abuse, we'll ask you to sign in to GitHub. +

*/} + +
+
+ ); +} diff --git a/gui/src/components/OnboardingCard/platform/tabs/main.tsx b/gui/src/components/OnboardingCard/platform/tabs/main.tsx new file mode 100644 index 0000000000..f6a14cb939 --- /dev/null +++ b/gui/src/components/OnboardingCard/platform/tabs/main.tsx @@ -0,0 +1,49 @@ +import { ChevronRightIcon } from "@heroicons/react/24/outline"; +import { Button, ButtonSubtext } from "../../.."; +import { useAuth } from "../../../../context/Auth"; +import ContinueLogo from "../../../gui/ContinueLogo"; +import { useOnboardingCard } from "../../hooks"; + +export default function MainTab({ + onRemainLocal, +}: { + onRemainLocal: () => void; +}) { + const onboardingCard = useOnboardingCard(); + const auth = useAuth(); + + function onGetStarted() { + auth.login(true).then((success) => { + if (success) { + onboardingCard.close(); + } + }); + } + + return ( +
+
+ +
+ +

+ Log in to quickly build your first custom AI code assistant +

+ +
+ + +
+ Or, remain local + +
+
+
+
+ ); +} diff --git a/gui/src/components/ProfileSwitcher.tsx b/gui/src/components/ProfileSwitcher.tsx index 3a5d620ded..0cecdf8150 100644 --- a/gui/src/components/ProfileSwitcher.tsx +++ b/gui/src/components/ProfileSwitcher.tsx @@ -7,7 +7,7 @@ import { UserCircleIcon as UserCircleIconSolid } from "@heroicons/react/24/solid import { ProfileDescription } from "core/config/ConfigHandler"; import { Fragment, useContext, useEffect, useState } from "react"; import ReactDOM from "react-dom"; -import { useDispatch, useSelector } from "react-redux"; +import { useDispatch } from "react-redux"; import styled from "styled-components"; import { defaultBorderRadius, @@ -21,11 +21,10 @@ import { import { useAuth } from "../context/Auth"; import { IdeMessengerContext } from "../context/IdeMessenger"; import { useWebviewListener } from "../hooks/useWebviewListener"; +import { useAppSelector } from "../redux/hooks"; import { setLastControlServerBetaEnabledStatus } from "../redux/slices/miscSlice"; -import { RootState } from "../redux/store"; import { getFontSize } from "../util"; import HeaderButtonWithToolTip from "./gui/HeaderButtonWithToolTip"; -import { useAppSelector } from "../redux/hooks"; const StyledListbox = styled(Listbox)` background-color: ${vscBackground}; @@ -148,9 +147,6 @@ function ProfileSwitcher() { const shouldShowPopup = !lastControlServerBetaEnabledStatus && enableControlServerBeta; - if (shouldShowPopup) { - ideMessenger.ide.showToast("info", "Continue for Teams enabled"); - } }); }, []); diff --git a/gui/src/components/dialogs/AboutAssistantDialog.tsx b/gui/src/components/dialogs/AboutAssistantDialog.tsx new file mode 100644 index 0000000000..6f1734a92e --- /dev/null +++ b/gui/src/components/dialogs/AboutAssistantDialog.tsx @@ -0,0 +1,80 @@ +import { useAppSelector } from "../../redux/hooks"; +import { RootState } from "../../redux/store"; + +export default function AboutAssistantDialog() { + const configJson = useAppSelector((store: RootState) => store.config.config); + + const selectedProfileId = useAppSelector( + (store) => store.session.selectedProfileId, + ); + + return ( +
+

About {selectedProfileId} assistant

+ {configJson.models?.length > 0 && ( +
+ Chat models +
    + {configJson.models.map((model, index) => ( +
  • {model.title}
  • + ))} +
+
+ )} + + {configJson.embeddingsProvider && ( +
+ Embedding model +

{configJson.embeddingsProvider}

+
+ )} + + {configJson.reranker?.name && ( +
+ Rerank model +

{configJson.reranker.name}

+
+ )} + + {configJson.systemMessage && ( +
+ Rules +

{configJson.systemMessage}

+
+ )} + + {configJson.contextProviders?.length ? ( +
+ Context +
    + {configJson.contextProviders.map((provider, index) => ( +
  • {provider.title}
  • + ))} +
+
+ ) : null} + + {configJson.tools.length > 0 && ( +
+ Tools +
    + {configJson.tools.map((tool, index) => ( +
  • {tool.displayTitle}
  • + ))} +
+
+ )} + + {configJson.docs?.length ? ( +
+ Docs +
    + {configJson.docs.map((doc, index) => ( +
  • {doc.title}
  • + ))} +
+
+ ) : null} +
+ ); +} diff --git a/gui/src/components/index.ts b/gui/src/components/index.ts index 95e1784805..071900b51d 100644 --- a/gui/src/components/index.ts +++ b/gui/src/components/index.ts @@ -1,4 +1,4 @@ -import styled, { keyframes } from "styled-components"; +import styled from "styled-components"; import { getFontSize, isJetBrains } from "../util"; export const VSC_INPUT_BACKGROUND_VAR = "--vscode-input-background"; @@ -298,11 +298,6 @@ export const HeaderButton = styled.button<{ padding: 2px; `; -export const Divider = styled.div` - height: 1px; - background-color: ${lightGray}; -`; - export const StyledActionButton = styled.div` display: flex; justify-content: space-between; diff --git a/gui/src/components/mainInput/ContextItemsPeek.tsx b/gui/src/components/mainInput/ContextItemsPeek.tsx index 0450d15af1..9e7d34070c 100644 --- a/gui/src/components/mainInput/ContextItemsPeek.tsx +++ b/gui/src/components/mainInput/ContextItemsPeek.tsx @@ -107,6 +107,7 @@ function ContextItemsPeekItem({ contextItem }: ContextItemsPeekItemProps) {
{getContextItemIcon()} @@ -161,6 +162,7 @@ function ContextItemsPeek({
setOpen((prev) => !prev)} + data-testid="context-items-peek" >
- + {usePlatform ? : }
{props.toolbarOptions?.hideImageUpload || (supportsImages && ( diff --git a/gui/src/components/mainInput/InputToolbar/PopoverTransition.tsx b/gui/src/components/mainInput/InputToolbar/PopoverTransition.tsx new file mode 100644 index 0000000000..24b64c27ba --- /dev/null +++ b/gui/src/components/mainInput/InputToolbar/PopoverTransition.tsx @@ -0,0 +1,26 @@ +import { Transition } from "@headlessui/react"; + +export default function PopoverTransition({ + children, + show, + afterLeave, +}: { + children: React.ReactNode; + show?: boolean; + afterLeave?: () => void; +}) { + return ( + + {children} + + ); +} diff --git a/gui/src/components/mainInput/InputToolbar/ToggleToolsButton.tsx b/gui/src/components/mainInput/InputToolbar/ToggleToolsButton.tsx index edb82fd3d3..cb6a056668 100644 --- a/gui/src/components/mainInput/InputToolbar/ToggleToolsButton.tsx +++ b/gui/src/components/mainInput/InputToolbar/ToggleToolsButton.tsx @@ -1,4 +1,4 @@ -import { Listbox, Transition } from "@headlessui/react"; +import { Listbox } from "@headlessui/react"; import { EllipsisHorizontalCircleIcon as EllipsisHorizontalIcon, WrenchScrewdriverIcon as WrenchScrewdriverIconOutline, @@ -7,12 +7,13 @@ import { WrenchScrewdriverIcon as WrenchScrewdriverIconSolid } from "@heroicons/ import { useEffect, useRef, useState } from "react"; import { useDispatch } from "react-redux"; import { lightGray, vscForeground } from "../.."; +import { useAppSelector } from "../../../redux/hooks"; import { toggleUseTools } from "../../../redux/slices/uiSlice"; +import { ToolTip } from "../../gui/Tooltip"; import InfoHover from "../../InfoHover"; import HoverItem from "./HoverItem"; +import PopoverTransition from "./PopoverTransition"; import ToolDropdownItem from "./ToolDropdownItem"; -import { useAppSelector } from "../../../redux/hooks"; -import { ToolTip } from "../../gui/Tooltip"; export default function ToolDropdown() { const buttonRef = useRef(null); @@ -85,14 +86,8 @@ export default function ToolDropdown() { > - setDropdownOpen(false)} > - + )} diff --git a/gui/src/components/mainInput/MentionList.tsx b/gui/src/components/mainInput/MentionList.tsx index 815a12f001..530144e4ad 100644 --- a/gui/src/components/mainInput/MentionList.tsx +++ b/gui/src/components/mainInput/MentionList.tsx @@ -46,11 +46,11 @@ import FileIcon from "../FileIcon"; import SafeImg from "../SafeImg"; import AddDocsDialog from "../dialogs/AddDocsDialog"; import HeaderButtonWithToolTip from "../gui/HeaderButtonWithToolTip"; -import { ComboBoxItem, ComboBoxItemType } from "./types"; import { DiscordIcon } from "../svg/DiscordIcon"; -import { GoogleIcon } from "../svg/GoogleIcon"; -import { GitlabIcon } from "../svg/GitlabIcon"; import { GithubIcon } from "../svg/GithubIcon"; +import { GitlabIcon } from "../svg/GitlabIcon"; +import { GoogleIcon } from "../svg/GoogleIcon"; +import { ComboBoxItem, ComboBoxItemType } from "./types"; const ICONS_FOR_DROPDOWN: { [key: string]: any } = { file: FolderIcon, @@ -414,6 +414,7 @@ const MentionList = forwardRef((props: MentionListProps, ref) => { key={index} onClick={() => selectItem(index)} onMouseEnter={() => setSelectedIndex(index)} + data-testid="context-provider-dropdown-item" >
diff --git a/gui/src/components/modelSelection/ModelSelect.tsx b/gui/src/components/modelSelection/ModelSelect.tsx index 93047be18c..63d0973706 100644 --- a/gui/src/components/modelSelection/ModelSelect.tsx +++ b/gui/src/components/modelSelection/ModelSelect.tsx @@ -6,17 +6,17 @@ import { PlusIcon, TrashIcon, } from "@heroicons/react/24/outline"; -import { MouseEvent, useContext, useEffect, useRef, useState } from "react"; -import { useDispatch, useSelector } from "react-redux"; +import { useContext, useEffect, useRef, useState } from "react"; +import { useDispatch } from "react-redux"; import styled from "styled-components"; -import { - defaultBorderRadius, - Divider, - lightGray, - vscInputBackground, -} from ".."; +import { defaultBorderRadius, lightGray, vscInputBackground } from ".."; import { IdeMessengerContext } from "../../context/IdeMessenger"; import AddModelForm from "../../forms/AddModelForm"; +import { useAppSelector } from "../../redux/hooks"; +import { + selectDefaultModel, + setDefaultModel, +} from "../../redux/slices/configSlice"; import { setDialogMessage, setShowDialog } from "../../redux/slices/uiSlice"; import { getFontSize, @@ -24,11 +24,7 @@ import { isMetaEquivalentKeyPressed, } from "../../util"; import ConfirmationDialog from "../dialogs/ConfirmationDialog"; -import { useAppSelector } from "../../redux/hooks"; -import { - selectDefaultModel, - setDefaultModel, -} from "../../redux/slices/configSlice"; +import { Divider } from "./platform/shared"; interface ModelOptionProps { option: Option; @@ -144,7 +140,7 @@ function ModelOption({ const dispatch = useDispatch(); const [hovered, setHovered] = useState(false); - function onClickDelete(e: MouseEvent) { + function onClickDelete(e: any) { e.stopPropagation(); e.preventDefault(); @@ -164,7 +160,7 @@ function ModelOption({ ); } - function onClickGear(e: MouseEvent) { + function onClickGear(e: any) { e.stopPropagation(); e.preventDefault(); @@ -173,7 +169,7 @@ function ModelOption({ }); } - function handleOptionClick(e: MouseEvent) { + function handleOptionClick(e: any) { if (showMissingApiKeyMsg) { e.preventDefault(); e.stopPropagation(); diff --git a/gui/src/components/modelSelection/platform/AssistantAndModelSelect.tsx b/gui/src/components/modelSelection/platform/AssistantAndModelSelect.tsx new file mode 100644 index 0000000000..4d0df15e72 --- /dev/null +++ b/gui/src/components/modelSelection/platform/AssistantAndModelSelect.tsx @@ -0,0 +1,194 @@ +import { Popover } from "@headlessui/react"; +import { + ChevronDownIcon, + DocumentIcon, + SparklesIcon, +} from "@heroicons/react/24/outline"; +import { useEffect, useRef, useState } from "react"; +import styled from "styled-components"; +import { defaultBorderRadius, lightGray, vscInputBackground } from "../.."; +import { useAuth } from "../../../context/Auth"; +import AddModelForm from "../../../forms/AddModelForm"; +import { useAppDispatch, useAppSelector } from "../../../redux/hooks"; +import { + cycleDefaultModel, + selectDefaultModel, +} from "../../../redux/slices/configSlice"; +import { setDialogMessage, setShowDialog } from "../../../redux/slices/uiSlice"; +import { cycleProfile } from "../../../redux/thunks/cycleProfile"; +import { getFontSize, isMetaEquivalentKeyPressed } from "../../../util"; +import PopoverTransition from "../../mainInput/InputToolbar/PopoverTransition"; +import { AssistantSelect } from "./AssistantSelect"; +import { ModelSelect, modelSelectTitle } from "./ModelSelect"; +import { MAX_HEIGHT_PX } from "./shared"; + +const StyledPopoverButton = styled(Popover.Button)` + font-family: inherit; + display: flex; + align-items: center; + gap: 2px; + border: none; + cursor: pointer; + font-size: ${getFontSize() - 2}px; + background: transparent; + color: ${lightGray}; + &:focus { + outline: none; + } +`; + +const StyledPopoverPanel = styled(Popover.Panel)<{ $showabove: boolean }>` + margin-top: 4px; + position: absolute; + padding: 0px; + cursor: default; + + display: flex; + flex-direction: row; + + border-radius: ${defaultBorderRadius}; + border: 0.5px solid ${lightGray}; + background-color: ${vscInputBackground}; + + ${(props) => (props.$showabove ? "bottom: 100%;" : "top: 100%;")} +`; + +function AssistantAndModelSelect() { + const dispatch = useAppDispatch(); + const defaultModel = useAppSelector(selectDefaultModel); + const [showAbove, setShowAbove] = useState(false); + const buttonRef = useRef(null); + const selectedProfileId = useAppSelector( + (store) => store.session.selectedProfileId, + ); + + const { profiles, selectedProfile } = useAuth(); + + useEffect(() => { + const handleResize = () => calculatePosition(); + window.addEventListener("resize", handleResize); + return () => window.removeEventListener("resize", handleResize); + }, []); + + function calculatePosition() { + if (!buttonRef.current) { + return; + } + const rect = buttonRef.current.getBoundingClientRect(); + const spaceBelow = window.innerHeight - rect.bottom; + const spaceAbove = rect.top; + const dropdownHeight = MAX_HEIGHT_PX; + + setShowAbove(spaceBelow < dropdownHeight && spaceAbove > spaceBelow); + } + + function onClickAddModel(e: any) { + e.stopPropagation(); + e.preventDefault(); + + // Close the dropdown + if (buttonRef.current) { + buttonRef.current.click(); + } + dispatch(setShowDialog(true)); + dispatch( + setDialogMessage( + { + dispatch(setShowDialog(false)); + }} + />, + ), + ); + } + + useEffect(() => { + let lastToggleTime = 0; + const DEBOUNCE_MS = 500; + + const handleKeyDown = (event: KeyboardEvent) => { + if (event.key === "'" && isMetaEquivalentKeyPressed(event as any)) { + const now = Date.now(); + + if (event.shiftKey) { + dispatch(cycleDefaultModel("next")); + } else { + if (now - lastToggleTime >= DEBOUNCE_MS) { + dispatch(cycleProfile()); + lastToggleTime = now; + } + } + } + }; + + window.addEventListener("keydown", handleKeyDown); + return () => { + window.removeEventListener("keydown", handleKeyDown); + }; + }, []); + + return ( + +
+ +
+ {selectedProfile?.id === "local" ? ( + <> +
+
+ + + { + if (buttonRef.current) { + buttonRef.current.click(); + } + }} + /> + + + +
+
+ ); +} + +export default AssistantAndModelSelect; diff --git a/gui/src/components/modelSelection/platform/AssistantSelect.tsx b/gui/src/components/modelSelection/platform/AssistantSelect.tsx new file mode 100644 index 0000000000..383c50f1ef --- /dev/null +++ b/gui/src/components/modelSelection/platform/AssistantSelect.tsx @@ -0,0 +1,111 @@ +import { + DocumentIcon, + PlusIcon, + SparklesIcon, +} from "@heroicons/react/24/outline"; +import { useContext } from "react"; +import { lightGray } from "../.."; +import { useAuth } from "../../../context/Auth"; +import { IdeMessengerContext } from "../../../context/IdeMessenger"; +import { useAppDispatch } from "../../../redux/hooks"; +import { setDialogMessage, setShowDialog } from "../../../redux/slices/uiSlice"; +import { setProfileId } from "../../../redux/thunks/setProfileId"; +import { getFontSize, getMetaKeyLabel } from "../../../util"; +import AboutAssistantDialog from "../../dialogs/AboutAssistantDialog"; +import { Divider, Option, OptionDiv } from "./shared"; + +interface AssistantSelectProps { + onClose: () => void; +} + +export function AssistantSelect(props: AssistantSelectProps) { + const ideMessenger = useContext(IdeMessengerContext); + const { profiles, selectedProfile } = useAuth(); + + const dispatch = useAppDispatch(); + + function onNewAssistant() { + ideMessenger.post("openUrl", "https://app-test.continue.dev/new"); + } + + return ( +
+
+ {profiles.map((option, idx) => ( + + ))} +
+ +
+ +
+ + New Assistant +
+
+ + + + + {getMetaKeyLabel()}' toggle assistant + +
+
+ ); +} diff --git a/gui/src/components/modelSelection/platform/ModelSelect.tsx b/gui/src/components/modelSelection/platform/ModelSelect.tsx new file mode 100644 index 0000000000..80ad7179a5 --- /dev/null +++ b/gui/src/components/modelSelection/platform/ModelSelect.tsx @@ -0,0 +1,121 @@ +import { CubeIcon, PlusIcon } from "@heroicons/react/24/outline"; +import { useContext, useEffect, useState } from "react"; +import { lightGray } from "../.."; +import { IdeMessengerContext } from "../../../context/IdeMessenger"; +import { useAppDispatch, useAppSelector } from "../../../redux/hooks"; +import { + selectDefaultModel, + setDefaultModel, +} from "../../../redux/slices/configSlice"; +import { getFontSize, getMetaKeyLabel } from "../../../util"; +import { Divider, Option, OptionDiv } from "./shared"; + +export interface ModelOption { + value: string; + title: string; + apiKey?: string; +} + +export function modelSelectTitle(model: any): string { + if (model?.title) return model?.title; + if (model?.model !== undefined && model?.model.trim() !== "") { + if (model?.class_name) { + return `${model?.class_name} - ${model?.model}`; + } + return model?.model; + } + return model?.class_name; +} + +interface ModelSelectProps { + selectedProfileId: string; + onClickAddModel: (e: any) => void; +} + +export function ModelSelect(props: ModelSelectProps) { + const dispatch = useAppDispatch(); + const ideMessenger = useContext(IdeMessengerContext); + const defaultModel = useAppSelector(selectDefaultModel); + const allModels = useAppSelector((state) => state.config.config.models); + const [options, setOptions] = useState([]); + const [sortedOptions, setSortedOptions] = useState([]); + + // Sort so that options without an API key are at the end + useEffect(() => { + const enabledOptions = options.filter((option) => option.apiKey !== ""); + const disabledOptions = options.filter((option) => option.apiKey === ""); + + const sorted = [...enabledOptions, ...disabledOptions]; + + setSortedOptions(sorted); + }, [options]); + + useEffect(() => { + setOptions( + allModels.map((model) => { + return { + value: model.title, + title: modelSelectTitle(model), + apiKey: model.apiKey, + }; + }), + ); + }, [allModels]); + + function onModelChange(modelTitle: string) { + if (modelTitle === defaultModel?.title) return; + dispatch(setDefaultModel({ title: modelTitle })); + } + + return ( +
+
+ {sortedOptions.map((option, idx) => ( + + ))} +
+ +
+ {props.selectedProfileId === "local" && ( + <> + +
+ + Add Chat model +
+
+ + )} + + + + + {getMetaKeyLabel()}⇧' toggle model + +
+
+ ); +} diff --git a/gui/src/components/modelSelection/platform/shared.tsx b/gui/src/components/modelSelection/platform/shared.tsx new file mode 100644 index 0000000000..d860cab49f --- /dev/null +++ b/gui/src/components/modelSelection/platform/shared.tsx @@ -0,0 +1,151 @@ +import { ConfigValidationError } from "@continuedev/config-yaml"; +import { + ArrowTopRightOnSquareIcon, + Cog6ToothIcon, + ExclamationTriangleIcon, +} from "@heroicons/react/24/outline"; +import { useState } from "react"; +import styled from "styled-components"; +import { defaultBorderRadius, lightGray } from "../.."; +import { ToolTip } from "../../gui/Tooltip"; + +export const OptionDiv = styled.div<{ + isDisabled?: boolean; + isSelected?: boolean; +}>` + border-radius: ${defaultBorderRadius}; + padding: 6px 12px; + + min-width: 0px; + + ${({ isDisabled, isSelected }) => + !isDisabled && + ` + cursor: pointer; + + &:hover { + background: ${lightGray}33; + } + + ${ + isSelected && + ` + background: ${lightGray}22; + ` + } + `} + + ${({ isDisabled }) => + isDisabled && + ` + opacity: 0.5; + `} +`; + +export const MAX_HEIGHT_PX = 300; + +export const Divider = styled.div` + height: 0.5px; + background-color: ${lightGray}; +`; + +interface ModelOptionProps { + children: React.ReactNode; + idx: number; + disabled: boolean; + selected: boolean; + showConfigure: boolean; + onLink?: (e: any) => void; + onConfigure: (e: any) => void; + onClick: () => void; + errors?: ConfigValidationError[]; + onClickError?: (e: any) => void; +} + +const IconBase = styled.div<{ $hovered: boolean }>` + width: 1.2em; + height: 1.2em; + cursor: pointer; + padding: 4px; + border-radius: ${defaultBorderRadius}; + opacity: ${(props) => (props.$hovered ? 0.75 : 0)}; + visibility: ${(props) => (props.$hovered ? "visible" : "hidden")}; + + &:hover { + opacity: 1; + background-color: ${lightGray}33; + } +`; + +const StyledCog6ToothIcon = styled(IconBase).attrs({ as: Cog6ToothIcon })``; +const StyledArrowTopRightOnSquareIcon = styled(IconBase).attrs({ + as: ArrowTopRightOnSquareIcon, +})``; +const StyledExclamationTriangleIcon = styled(IconBase).attrs({ + as: ExclamationTriangleIcon, +})``; + +export function Option({ + onConfigure, + children, + idx, + disabled, + onClick, + showConfigure, + onLink, + selected, + errors, + onClickError, +}: ModelOptionProps) { + const [hovered, setHovered] = useState(false); + + function handleOptionClick(e: any) { + if (disabled) { + e.preventDefault(); + e.stopPropagation(); + } + onClick(); + } + + return ( + setHovered(true)} + onMouseLeave={() => setHovered(false)} + isSelected={selected} + onClick={!disabled ? handleOptionClick : undefined} + > +
+
+ {children} +
+ {!errors?.length && showConfigure && ( + + )} + {!errors?.length && onLink && ( + + )} + {errors?.length ? ( + <> + + +
Errors
+ {JSON.stringify(errors, null, 2)} +
+ + ) : null} +
+
+
+
+ ); +} diff --git a/gui/src/context/Auth.tsx b/gui/src/context/Auth.tsx index c2ad174f29..361ee8e0c2 100644 --- a/gui/src/context/Auth.tsx +++ b/gui/src/context/Auth.tsx @@ -1,26 +1,30 @@ +import { ProfileDescription } from "core/config/ProfileLifecycleManager"; +import { ControlPlaneSessionInfo } from "core/control-plane/client"; import React, { createContext, useContext, - useState, useEffect, useMemo, + useState, } from "react"; -import { ControlPlaneSessionInfo } from "core/control-plane/client"; import { useDispatch } from "react-redux"; +import AccountDialog from "../components/AccountDialog"; import ConfirmationDialog from "../components/dialogs/ConfirmationDialog"; -import { IdeMessengerContext } from "./IdeMessenger"; -import { setDialogMessage, setShowDialog } from "../redux/slices/uiSlice"; -import { getLocalStorage, setLocalStorage } from "../util/localStorage"; -import { ProfileDescription } from "core/config/ProfileLifecycleManager"; -import { setLastControlServerBetaEnabledStatus } from "../redux/slices/miscSlice"; import { useWebviewListener } from "../hooks/useWebviewListener"; -import AccountDialog from "../components/AccountDialog"; import { useAppSelector } from "../redux/hooks"; +import { setLastControlServerBetaEnabledStatus } from "../redux/slices/miscSlice"; +import { + selectAvailableProfiles, + setAvailableProfiles, +} from "../redux/slices/sessionSlice"; +import { setDialogMessage, setShowDialog } from "../redux/slices/uiSlice"; +import { getLocalStorage, setLocalStorage } from "../util/localStorage"; +import { IdeMessengerContext } from "./IdeMessenger"; interface AuthContextType { session: ControlPlaneSessionInfo | undefined; logout: () => void; - login: () => void; + login: (useOnboarding: boolean) => Promise; selectedProfile: ProfileDescription | undefined; profiles: ProfileDescription[]; controlServerBetaEnabled: boolean; @@ -34,7 +38,9 @@ export const AuthProvider: React.FC<{ children: React.ReactNode }> = ({ const [session, setSession] = useState( undefined, ); - const [profiles, setProfiles] = useState([]); + + const profiles = useAppSelector(selectAvailableProfiles); + const selectedProfileId = useAppSelector( (store) => store.session.selectedProfileId, ); @@ -49,33 +55,42 @@ export const AuthProvider: React.FC<{ children: React.ReactNode }> = ({ (state) => state.misc.lastControlServerBetaEnabledStatus, ); - const login = () => { - ideMessenger - .request("getControlPlaneSessionInfo", { silent: false }) - .then((result) => { - if (result.status === "error") { - return; - } - const session = result.content; - setSession(session); - - // If this is the first time the user has logged in, explain how profiles work - if (!getLocalStorage("shownProfilesIntroduction")) { - dispatch(setShowDialog(true)); - dispatch( - setDialogMessage( - {}} - />, - ), - ); - setLocalStorage("shownProfilesIntroduction", true); - } - }); + const login: AuthContextType["login"] = (useOnboarding: boolean) => { + return new Promise((resolve) => { + ideMessenger + .request("getControlPlaneSessionInfo", { + silent: false, + useOnboarding, + }) + .then((result) => { + if (result.status === "error") { + resolve(false); + return; + } + + const session = result.content; + setSession(session); + + // If this is the first time the user has logged in, explain how profiles work + if (!getLocalStorage("shownProfilesIntroduction")) { + dispatch(setShowDialog(true)); + dispatch( + setDialogMessage( + {}} + />, + ), + ); + setLocalStorage("shownProfilesIntroduction", true); + } + + resolve(true); + }); + }); }; const logout = () => { @@ -102,12 +117,15 @@ export const AuthProvider: React.FC<{ children: React.ReactNode }> = ({ }); useWebviewListener("signInToControlPlane", async () => { - login(); + login(false); }); useEffect(() => { ideMessenger - .request("getControlPlaneSessionInfo", { silent: true }) + .request("getControlPlaneSessionInfo", { + silent: true, + useOnboarding: false, + }) .then( (result) => result.status === "success" && setSession(result.content), ); @@ -123,9 +141,6 @@ export const AuthProvider: React.FC<{ children: React.ReactNode }> = ({ const shouldShowPopup = !lastControlServerBetaEnabledStatus && enableControlServerBeta; - if (shouldShowPopup) { - ideMessenger.ide.showToast("info", "Continue for Teams enabled"); - } }); }, []); @@ -145,14 +160,16 @@ export const AuthProvider: React.FC<{ children: React.ReactNode }> = ({ ideMessenger .request("config/listProfiles", undefined) .then( - (result) => result.status === "success" && setProfiles(result.content), + (result) => + result.status === "success" && + dispatch(setAvailableProfiles(result.content)), ); }, []); useWebviewListener( "didChangeAvailableProfiles", async (data) => { - setProfiles(data.profiles); + dispatch(setAvailableProfiles(data.profiles)); }, [], ); diff --git a/gui/src/hooks/useSetup.ts b/gui/src/hooks/useSetup.ts index a545aedafb..57b4479af0 100644 --- a/gui/src/hooks/useSetup.ts +++ b/gui/src/hooks/useSetup.ts @@ -2,8 +2,8 @@ import { useCallback, useContext, useEffect, useRef } from "react"; import { VSC_THEME_COLOR_VARS } from "../components"; import { IdeMessengerContext } from "../context/IdeMessenger"; +import { ConfigResult } from "@continuedev/config-yaml"; import { BrowserSerializedContinueConfig } from "core"; -import { ConfigResult } from "core/config/load"; import { useAppDispatch, useAppSelector } from "../redux/hooks"; import { setConfigError, setConfigResult } from "../redux/slices/configSlice"; import { updateIndexingStatus } from "../redux/slices/indexingSlice"; diff --git a/gui/src/pages/gui/Chat.tsx b/gui/src/pages/gui/Chat.tsx index 99c61a71c3..3f36adcef8 100644 --- a/gui/src/pages/gui/Chat.tsx +++ b/gui/src/pages/gui/Chat.tsx @@ -32,6 +32,7 @@ import { OnboardingCard, useOnboardingCard, } from "../../components/OnboardingCard"; +import { PlatformOnboardingCard } from "../../components/OnboardingCard/platform/PlatformOnboardingCard"; import PageHeader from "../../components/PageHeader"; import StepContainer from "../../components/StepContainer"; import AcceptRejectAllButtons from "../../components/StepContainer/AcceptRejectAllButtons"; @@ -39,6 +40,7 @@ import { IdeMessengerContext } from "../../context/IdeMessenger"; import { useTutorialCard } from "../../hooks/useTutorialCard"; import { useWebviewListener } from "../../hooks/useWebviewListener"; import { useAppDispatch, useAppSelector } from "../../redux/hooks"; +import { selectUsePlatform } from "../../redux/selectors"; import { selectCurrentToolCall } from "../../redux/selectors/selectCurrentToolCall"; import { selectDefaultModel } from "../../redux/slices/configSlice"; import { submitEdit } from "../../redux/slices/editModeState"; @@ -57,6 +59,7 @@ import { import { RootState } from "../../redux/store"; import { cancelStream } from "../../redux/thunks/cancelStream"; import { exitEditMode } from "../../redux/thunks/exitEditMode"; +import { loadLastSession } from "../../redux/thunks/session"; import { streamResponseThunk } from "../../redux/thunks/streamResponse"; import { getFontSize, @@ -70,7 +73,6 @@ import ConfigErrorIndicator from "./ConfigError"; import { ToolCallDiv } from "./ToolCallDiv"; import { ToolCallButtons } from "./ToolCallDiv/ToolCallButtonsDiv"; import ToolOutput from "./ToolCallDiv/ToolOutput"; -import { loadLastSession } from "../../redux/thunks/session"; const StopButton = styled.div` background-color: ${vscBackground}; @@ -215,6 +217,7 @@ export function Chat() { selectIsSingleRangeEditOrInsertion, ); const lastSessionId = useAppSelector((state) => state.session.lastSessionId); + const usePlatform = useAppSelector(selectUsePlatform); useEffect(() => { // Cmd + Backspace to delete current step @@ -539,7 +542,11 @@ export function Chat() { <> {onboardingCard.show && (
- + {usePlatform ? ( + + ) : ( + + )}
)} diff --git a/gui/src/redux/selectors/index.ts b/gui/src/redux/selectors/index.ts index 4ffd3d4aff..737e865d39 100644 --- a/gui/src/redux/selectors/index.ts +++ b/gui/src/redux/selectors/index.ts @@ -42,3 +42,8 @@ export const selectUseActiveFile = createSelector( [(state: RootState) => state.config.config.experimental?.defaultContext], (defaultContext) => defaultContext?.includes("activeFile" as any), ); + +export const selectUsePlatform = createSelector( + [(state: RootState) => state.config.config.usePlatform], + (usePlatform) => usePlatform, +); diff --git a/gui/src/redux/slices/configSlice.ts b/gui/src/redux/slices/configSlice.ts index 99401fe05d..e8939e81d0 100644 --- a/gui/src/redux/slices/configSlice.ts +++ b/gui/src/redux/slices/configSlice.ts @@ -1,7 +1,6 @@ +import { ConfigResult, ConfigValidationError } from "@continuedev/config-yaml"; import { createSlice, PayloadAction } from "@reduxjs/toolkit"; import { BrowserSerializedContinueConfig } from "core"; -import { ConfigResult } from "core/config/load"; -import { ConfigValidationError } from "core/config/validation"; import { DEFAULT_MAX_TOKENS } from "core/llm/constants"; export type ConfigState = { @@ -27,6 +26,7 @@ const initialState: ConfigState = { contextProviders: [], models: [], tools: [], + usePlatform: false, }, }; @@ -80,6 +80,20 @@ export const configSlice = createSlice({ defaultModelTitle: payload.title, }; }, + cycleDefaultModel: (state, { payload }: PayloadAction<"next" | "prev">) => { + const currentIndex = state.config.models.findIndex( + (model) => model.title === state.defaultModelTitle, + ); + const nextIndex = + (currentIndex + + (payload === "next" ? 1 : -1) + + state.config.models.length) % + state.config.models.length; + return { + ...state, + defaultModelTitle: state.config.models[nextIndex].title, + }; + }, }, selectors: { selectDefaultModel: (state) => { @@ -101,6 +115,7 @@ export const configSlice = createSlice({ export const { setDefaultModel, + cycleDefaultModel, updateConfig, setConfigResult, setConfigError, diff --git a/gui/src/redux/slices/sessionSlice.ts b/gui/src/redux/slices/sessionSlice.ts index 0fbfa3bfe8..919490412f 100644 --- a/gui/src/redux/slices/sessionSlice.ts +++ b/gui/src/redux/slices/sessionSlice.ts @@ -21,14 +21,15 @@ import { ToolCallDelta, ToolCallState, } from "core"; +import { ProfileDescription } from "core/config/ConfigHandler"; import { NEW_SESSION_TITLE } from "core/util/constants"; import { incrementalParseJson } from "core/util/incrementalParseJson"; import { renderChatMessage } from "core/util/messageContent"; +import { findUriInDirs, getUriPathBasename } from "core/util/uri"; import { v4 as uuidv4 } from "uuid"; import { RootState } from "../store"; import { streamResponseThunk } from "../thunks/streamResponse"; import { findCurrentToolCall } from "../util"; -import { findUriInDirs, getUriPathBasename } from "core/util/uri"; // We need this to handle reorderings (e.g. a mid-array deletion) of the messages array. // The proper fix is adding a UUID to all chat messages, but this is the temp workaround. @@ -44,6 +45,7 @@ type SessionState = { title: string; id: string; selectedProfileId: string; + availableProfiles: ProfileDescription[]; streamAborter: AbortController; codeToEdit: CodeToEdit[]; curCheckpointIndex: number; @@ -83,6 +85,13 @@ const initialState: SessionState = { title: NEW_SESSION_TITLE, id: uuidv4(), selectedProfileId: "local", + availableProfiles: [ + { + id: "local", + title: "Local", + errors: undefined, + }, + ], curCheckpointIndex: 0, streamAborter: new AbortController(), codeToEdit: [], @@ -504,8 +513,19 @@ export const sessionSlice = createSlice({ selectedProfileId: payload, }; }, - setCurCheckpointIndex: (state, { payload }: PayloadAction) => { - state.curCheckpointIndex = payload; + setAvailableProfiles: ( + state, + { payload }: PayloadAction, + ) => { + return { + ...state, + availableProfiles: payload, + selectedProfileId: payload.find( + (profile) => profile.id === state.selectedProfileId, + ) + ? state.selectedProfileId + : payload[0]?.id, + }; }, updateCurCheckpoint: ( state, @@ -516,6 +536,9 @@ export const sessionSlice = createSlice({ checkpoint[payload.filepath] = payload.content; } }, + setCurCheckpointIndex: (state, { payload }: PayloadAction) => { + state.curCheckpointIndex = payload; + }, updateApplyState: (state, { payload }: PayloadAction) => { const applyState = state.codeBlockApplyStates.states.find( (state) => state.streamId === payload.streamId, @@ -618,6 +641,9 @@ export const sessionSlice = createSlice({ selectHasCodeToEdit: (state) => { return state.codeToEdit.length > 0; }, + selectAvailableProfiles: (state) => { + return state.availableProfiles; + }, }, extraReducers: (builder) => { addPassthroughCases(builder, [streamResponseThunk]); @@ -681,6 +707,7 @@ export const { removeCodeToEdit, setCalling, cancelToolCall, + setAvailableProfiles, acceptToolCall, setToolGenerated, setToolCallOutput, @@ -696,6 +723,7 @@ export const { selectIsInEditMode, selectIsSingleRangeEditOrInsertion, selectHasCodeToEdit, + selectAvailableProfiles, } = sessionSlice.selectors; export default sessionSlice.reducer; diff --git a/gui/src/redux/thunks/cycleProfile.ts b/gui/src/redux/thunks/cycleProfile.ts new file mode 100644 index 0000000000..0ee548c319 --- /dev/null +++ b/gui/src/redux/thunks/cycleProfile.ts @@ -0,0 +1,21 @@ +import { createAsyncThunk } from "@reduxjs/toolkit"; +import { ThunkApiType } from "../store"; +import { setProfileId } from "./setProfileId"; + +export const cycleProfile = createAsyncThunk( + "profile/cycle", + async (_, { dispatch, getState, extra }) => { + const state = getState(); + const profileIds = state.session.availableProfiles.map( + (profile) => profile.id, + ); + const curIndex = profileIds.indexOf(state.session.selectedProfileId); + const nextIndex = (curIndex + 1) % profileIds.length; + const nextId = profileIds[nextIndex]; + dispatch(setProfileId(nextId)); + + extra.ideMessenger.post("didChangeSelectedProfile", { + id: nextId, + }); + }, +); diff --git a/gui/src/redux/thunks/setProfileId.ts b/gui/src/redux/thunks/setProfileId.ts new file mode 100644 index 0000000000..6ec1020ce9 --- /dev/null +++ b/gui/src/redux/thunks/setProfileId.ts @@ -0,0 +1,13 @@ +import { createAsyncThunk } from "@reduxjs/toolkit"; +import { setSelectedProfileId } from "../slices/sessionSlice"; +import { ThunkApiType } from "../store"; + +export const setProfileId = createAsyncThunk( + "profile/setId", + async (id, { dispatch, extra }) => { + dispatch(setSelectedProfileId(id)); + extra.ideMessenger.post("didChangeSelectedProfile", { + id, + }); + }, +); diff --git a/packages/config-types/package.json b/packages/config-types/package.json index fd8a423729..99df3ce607 100644 --- a/packages/config-types/package.json +++ b/packages/config-types/package.json @@ -1,6 +1,6 @@ { "name": "@continuedev/config-types", - "version": "1.0.13", + "version": "1.0.14", "description": "", "main": "dist/index.js", "types": "dist/index.d.ts", diff --git a/packages/config-yaml/package-lock.json b/packages/config-yaml/package-lock.json index 28126b962d..b62809bc35 100644 --- a/packages/config-yaml/package-lock.json +++ b/packages/config-yaml/package-lock.json @@ -1,14 +1,15 @@ { "name": "@continuedev/config-yaml", - "version": "1.0.0", + "version": "1.0.10", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@continuedev/config-yaml", - "version": "1.0.0", + "version": "1.0.10", "license": "Apache-2.0", "dependencies": { + "@continuedev/config-types": "^1.0.14", "yaml": "^2.6.1", "zod": "^3.23.8" }, @@ -18,6 +19,17 @@ "ts-node": "^10.9.2" } }, + "../config-types": { + "name": "@continuedev/config-types", + "version": "1.0.14", + "license": "Apache-2.0", + "dependencies": { + "zod": "^3.23.8" + }, + "devDependencies": { + "typescript": "^5.5.2" + } + }, "node_modules/@ampproject/remapping": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/@ampproject/remapping/-/remapping-2.3.0.tgz", @@ -575,6 +587,10 @@ "dev": true, "peer": true }, + "node_modules/@continuedev/config-types": { + "resolved": "../config-types", + "link": true + }, "node_modules/@cspotcode/source-map-support": { "version": "0.8.1", "resolved": "https://registry.npmjs.org/@cspotcode/source-map-support/-/source-map-support-0.8.1.tgz", diff --git a/packages/config-yaml/package.json b/packages/config-yaml/package.json index 3903c91d81..1e77d7d228 100644 --- a/packages/config-yaml/package.json +++ b/packages/config-yaml/package.json @@ -1,6 +1,6 @@ { "name": "@continuedev/config-yaml", - "version": "1.0.0", + "version": "1.0.11", "description": "", "main": "dist/index.js", "types": "dist/index.d.ts", @@ -13,7 +13,8 @@ "license": "Apache-2.0", "dependencies": { "yaml": "^2.6.1", - "zod": "^3.23.8" + "zod": "^3.23.8", + "@continuedev/config-types": "^1.0.14" }, "devDependencies": { "cross-env": "^7.0.3", diff --git a/packages/config-yaml/src/converter.ts b/packages/config-yaml/src/converter.ts new file mode 100644 index 0000000000..485d469a9e --- /dev/null +++ b/packages/config-yaml/src/converter.ts @@ -0,0 +1,142 @@ +import { ConfigJson } from "@continuedev/config-types"; +import { ConfigYaml } from "./schemas/index.js"; +import { ModelRoles } from "./schemas/models.js"; + +type ModelYaml = NonNullable[number]; +type ContextYaml = NonNullable[number]; +type PromptYaml = NonNullable[number]; + +function convertModel( + m: ConfigJson["models"][number], + roles: ModelRoles[], +): ModelYaml { + return { + name: m.title, + provider: m.provider, + model: m.model, + apiKey: m.apiKey, + apiBase: m.apiBase, + roles, + requestOptions: m.requestOptions, + defaultCompletionOptions: m.completionOptions, + }; +} + +function convertEmbeddingsProvider( + m: NonNullable, +): ModelYaml { + return { + name: "Embeddings Model", + provider: m.provider, + model: m.model ?? "", + apiKey: m.apiKey, + apiBase: m.apiBase, + roles: ["embed"], + }; +} + +function convertReranker(m: NonNullable): ModelYaml { + return { + name: "Reranker", + provider: m.name, + model: m.params?.model ?? "", + apiKey: m.params?.apiKey, + apiBase: m.params?.apiBase, + roles: ["rerank"], + }; +} + +function convertContext(configJson: ConfigJson): ContextYaml[] { + const context: ContextYaml[] = + configJson.contextProviders?.map((ctx) => { + return { + uses: `builtin/${ctx.name}`, + with: ctx.params, + }; + }) ?? []; + + return context; +} + +function convertCustomCommand( + cmd: NonNullable[number], +): PromptYaml { + return { + name: cmd.name, + description: cmd.description, + prompt: (cmd as any).prompt, // The type is wrong in @continuedev/config-types + type: "slash-command", + }; +} + +function convertMcp(mcp: any): NonNullable[number] { + const { transport } = mcp; + const { command, args, env } = transport; + + return { + command, + args, + env, + name: "MCP Server", + }; +} + +function convertDoc( + doc: NonNullable[number], +): NonNullable[number] { + return { + name: doc.title, + startUrl: doc.startUrl, + rootUrl: doc.rootUrl, + faviconUrl: doc.faviconUrl, + }; +} + +export function convertJsonToYamlConfig(configJson: ConfigJson): ConfigYaml { + // models + const models = configJson.models.map((m) => convertModel(m, ["chat"])); + const autocompleteModels = Array.isArray(configJson.tabAutocompleteModel) + ? configJson.tabAutocompleteModel + : configJson.tabAutocompleteModel + ? [configJson.tabAutocompleteModel] + : []; + models.push( + ...autocompleteModels.map((m) => convertModel(m, ["autocomplete"])), + ); + + if (configJson.embeddingsProvider) { + models.push(convertEmbeddingsProvider(configJson.embeddingsProvider)); + } + + if (configJson.reranker) { + models.push(convertReranker(configJson.reranker)); + } + + // context + const context = convertContext(configJson); + + // mcpServers + // Types for "experimental" don't exist + const mcpServers = ( + configJson as any + ).experimental?.modelContextProtocolServers?.map(convertMcp); + + // prompts + const prompts = configJson.customCommands?.map(convertCustomCommand); + + // docs + const docs = configJson.docs?.map(convertDoc); + + const configYaml: ConfigYaml = { + name: "Continue Config", + version: "0.0.1", + models, + context, + rules: configJson.systemMessage ? [configJson.systemMessage] : undefined, + prompts, + mcpServers, + docs, + }; + + return configYaml; +} diff --git a/packages/config-yaml/src/index.ts b/packages/config-yaml/src/index.ts index c229f147bd..c3be3e7133 100644 --- a/packages/config-yaml/src/index.ts +++ b/packages/config-yaml/src/index.ts @@ -91,5 +91,17 @@ export function fillTemplateVariables( }); } +export { convertJsonToYamlConfig } from "./converter.js"; +export { resolveSecretsOnClient } from "./resolveSecretsOnClient.js"; +export { + ClientConfigYaml, + clientConfigYamlSchema, + ConfigYaml, + configYamlSchema, +} from "./schemas/index.js"; export type { ModelConfig } from "./schemas/models.js"; -export { validateConfigYaml } from "./validation.js"; +export { + ConfigResult, + ConfigValidationError, + validateConfigYaml, +} from "./validation.js"; diff --git a/packages/config-yaml/src/resolveSecretsOnClient.ts b/packages/config-yaml/src/resolveSecretsOnClient.ts new file mode 100644 index 0000000000..c3029fbda2 --- /dev/null +++ b/packages/config-yaml/src/resolveSecretsOnClient.ts @@ -0,0 +1,64 @@ +import { ClientConfigYaml, ConfigYaml } from "./schemas/index.js"; +type SecretProvider = ( + secretNames: string[], +) => Promise<{ [key: string]: string }>; + +/** + * Take a ConfigYaml with apiKeySecrets, and look to fill in these secrets + * with whatever secret store exists in the client. + */ +export async function resolveSecretsOnClient( + configYaml: ClientConfigYaml, + getSecretsFromClientStore: SecretProvider, + getSecretsFromServer: SecretProvider, +): Promise { + const requiredSecrets = getRequiredSecretsInClientConfig(configYaml); + + const secretsFoundOnClient = await getSecretsFromClientStore(requiredSecrets); + + const secretsNotFoundOnClient = requiredSecrets.filter( + (secret) => !secretsFoundOnClient[secret], + ); + + let secretsFoundOnServer = {}; + if (secretsNotFoundOnClient.length > 0) { + secretsFoundOnServer = await getSecretsFromServer(secretsNotFoundOnClient); + } + + const clientSecrets = { + ...secretsFoundOnClient, + ...secretsFoundOnServer, + }; + + const finalConfigYaml = injectClientSecrets(configYaml, clientSecrets); + + // Anything with an apiKeySecret left over must use proxy + return finalConfigYaml; +} + +function getRequiredSecretsInClientConfig( + configYaml: ClientConfigYaml, +): string[] { + const secrets = new Set(); + for (const model of configYaml.models ?? []) { + if (model.apiKeySecret) { + secrets.add(model.apiKeySecret); + } + } + return Array.from(secrets); +} + +function injectClientSecrets( + configYaml: ClientConfigYaml, + clientSecrets: Record, +): ConfigYaml { + for (const model of configYaml.models ?? []) { + if (model.apiKeySecret && clientSecrets[model.apiKeySecret]) { + // Remove apiKeySecret and place the client secret in apiKey + model.apiKey = clientSecrets[model.apiKeySecret]; + delete model.apiKeySecret; + } + } + + return configYaml; +} diff --git a/packages/config-yaml/src/schemas/index.ts b/packages/config-yaml/src/schemas/index.ts index a75566f96b..f763c7be56 100644 --- a/packages/config-yaml/src/schemas/index.ts +++ b/packages/config-yaml/src/schemas/index.ts @@ -6,9 +6,13 @@ import { modelSchema } from "./models.js"; const packageSchema = z.object({ uses: z.string(), with: z.any().optional(), + secrets: z.array(z.string()).optional(), }); const toolSchema = z.object({ + name: z.string(), + description: z.string(), + policy: z.enum(["automatic", "allowed", "disabled"]).optional(), url: z.string(), apiKey: z.string().optional(), }); @@ -21,14 +25,36 @@ const mcpServerSchema = z.object({ env: z.record(z.string()).optional(), }); +const promptSchema = z.object({ + name: z.string(), + description: z.string().optional(), + type: z.enum(["slash-command", "context-provider"]).optional(), + prompt: z.string(), +}); + +const docSchema = z.object({ + name: z.string(), + startUrl: z.string(), + rootUrl: z.string().optional(), + faviconUrl: z.string().optional(), +}); + export const configYamlSchema = z.object({ name: z.string(), + version: z.string(), packages: z.array(packageSchema).optional(), models: z.array(modelSchema).optional(), context: z.array(contextSchema).optional(), data: z.array(dataSchema).optional(), tools: z.array(toolSchema).optional(), mcpServers: z.array(mcpServerSchema).optional(), + rules: z.array(z.string()).optional(), + prompts: z.array(promptSchema).optional(), + docs: z.array(docSchema).optional(), }); export type ConfigYaml = z.infer; + +export const clientConfigYamlSchema = configYamlSchema.omit({ packages: true }); + +export type ClientConfigYaml = z.infer; diff --git a/packages/config-yaml/src/schemas/models.ts b/packages/config-yaml/src/schemas/models.ts index 57415247d2..89b740e573 100644 --- a/packages/config-yaml/src/schemas/models.ts +++ b/packages/config-yaml/src/schemas/models.ts @@ -46,6 +46,9 @@ export const modelSchema = z.object({ name: z.string(), provider: z.string(), model: z.string(), + apiKey: z.string().optional(), + apiBase: z.string().optional(), + apiKeySecret: z.string().optional(), roles: modelRolesSchema.array().optional(), defaultCompletionOptions: completionOptionsSchema.optional(), requestOptions: requestOptionsSchema.optional(), diff --git a/packages/config-yaml/src/validation.ts b/packages/config-yaml/src/validation.ts index 223764e71b..7af938d999 100644 --- a/packages/config-yaml/src/validation.ts +++ b/packages/config-yaml/src/validation.ts @@ -1,15 +1,16 @@ import { ConfigYaml, configYamlSchema } from "./schemas/index.js"; -export enum ValidationLevel { - Warning, - Error, -} - export interface ConfigValidationError { - level: ValidationLevel; + fatal: boolean; message: string; } +export interface ConfigResult { + config: T | undefined; + errors: ConfigValidationError[] | undefined; + configLoadInterrupted: boolean; +} + export function validateConfigYaml( config: ConfigYaml, ): ConfigValidationError[] { @@ -20,7 +21,7 @@ export function validateConfigYaml( } catch (e: any) { return [ { - level: ValidationLevel.Error, + fatal: true, message: e.message, }, ]; @@ -38,7 +39,7 @@ export function validateConfigYaml( if (difference < 1000) { errors.push({ - level: ValidationLevel.Warning, + fatal: false, message: `Model "${model.name}" has a contextLength of ${model.defaultCompletionOptions?.contextLength} and a maxTokens of ${model.defaultCompletionOptions?.maxTokens}. This leaves only ${difference} tokens for input context and will likely result in your inputs being truncated.`, }); } @@ -60,7 +61,7 @@ export function validateConfigYaml( !modelName.toLowerCase().includes("coder") ) { errors.push({ - level: ValidationLevel.Warning, + fatal: false, message: `${model.model} is not trained for tab-autocomplete, and will result in low-quality suggestions. See the docs to learn more about why: https://docs.continue.dev/features/tab-autocomplete#i-want-better-completions-should-i-use-gpt-4`, }); } diff --git a/packages/openai-adapters/package.json b/packages/openai-adapters/package.json index a3d0bf7d0e..e075e8e9a8 100644 --- a/packages/openai-adapters/package.json +++ b/packages/openai-adapters/package.json @@ -1,6 +1,6 @@ { "name": "@continuedev/openai-adapters", - "version": "1.0.10", + "version": "1.0.12", "description": "", "main": "dist/index.js", "types": "dist/index.d.ts", diff --git a/packages/openai-adapters/src/apis/OpenAI.ts b/packages/openai-adapters/src/apis/OpenAI.ts index db52a67587..aa059f7dbf 100644 --- a/packages/openai-adapters/src/apis/OpenAI.ts +++ b/packages/openai-adapters/src/apis/OpenAI.ts @@ -4,6 +4,7 @@ import { OpenAI } from "openai/index"; import { ChatCompletion, ChatCompletionChunk, + ChatCompletionCreateParams, ChatCompletionCreateParamsNonStreaming, ChatCompletionCreateParamsStreaming, Completion, @@ -34,22 +35,46 @@ export class OpenAIApi implements BaseLlmApi { }); } + modifyChatBody(body: T): T { + // o-series models + if (body.model.startsWith("o")) { + // a) use max_completion_tokens instead of max_tokens + body.max_completion_tokens = body.max_tokens; + body.max_tokens = undefined; + + // b) use "developer" message role rather than "system" + body.messages = body.messages.map((message) => { + if (message.role === "system") { + return { ...message, role: "developer" } as any; + } + return message; + }); + } + return body; + } + async chatCompletionNonStream( body: ChatCompletionCreateParamsNonStreaming, signal: AbortSignal, ): Promise { - const response = await this.openai.chat.completions.create(body, { - signal, - }); + const response = await this.openai.chat.completions.create( + this.modifyChatBody(body), + { + signal, + }, + ); return response; } async *chatCompletionStream( body: ChatCompletionCreateParamsStreaming, signal: AbortSignal, ): AsyncGenerator { - const response = await this.openai.chat.completions.create(body, { - signal, - }); + const response = await this.openai.chat.completions.create( + this.modifyChatBody(body), + { + signal, + }, + ); for await (const result of response) { yield result; } diff --git a/packages/openai-adapters/src/index.ts b/packages/openai-adapters/src/index.ts index 4a42bbfeef..094920e63b 100644 --- a/packages/openai-adapters/src/index.ts +++ b/packages/openai-adapters/src/index.ts @@ -103,3 +103,4 @@ export { // export export type { BaseLlmApi } from "./apis/base.js"; +export type { LLMConfig } from "./types.js"; diff --git a/packages/openai-adapters/src/types.ts b/packages/openai-adapters/src/types.ts index 29d0310f25..c86d63abfc 100644 --- a/packages/openai-adapters/src/types.ts +++ b/packages/openai-adapters/src/types.ts @@ -54,6 +54,7 @@ export const OpenAIConfigSchema = BasePlusConfig.extend({ z.literal("text-gen-webui"), z.literal("vllm"), z.literal("x-ai"), + z.literal("scaleway"), ]), }); export type OpenAIConfig = z.infer;