Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: only perform lancedb cpu check for linux #3883

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
26 changes: 10 additions & 16 deletions core/config/load.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import * as fs from "fs";
import os from "os";
import path from "path";

import { ConfigResult, ConfigValidationError } from "@continuedev/config-yaml";
import { fetchwithRequestOptions } from "@continuedev/fetch";
import * as JSONC from "comment-json";
import * as tar from "tar";
Expand Down Expand Up @@ -30,14 +31,16 @@ import {
import {
slashCommandFromDescription,
slashFromCustomCommand,
} from "../commands/index.js";
} from "../commands/index";
import { AllRerankers } from "../context/allRerankers";
import { MCPManagerSingleton } from "../context/mcp";
import CodebaseContextProvider from "../context/providers/CodebaseContextProvider";
import ContinueProxyContextProvider from "../context/providers/ContinueProxyContextProvider";
import CustomContextProviderClass from "../context/providers/CustomContextProvider";
import FileContextProvider from "../context/providers/FileContextProvider";
import { contextProviderClassFromName } from "../context/providers/index";
import PromptFilesContextProvider from "../context/providers/PromptFilesContextProvider";
import { useHub } from "../control-plane/env";
import { allEmbeddingsProviders } from "../indexing/allEmbeddingsProviders";
import { BaseLLM } from "../llm";
import { llmFromDescription } from "../llm/llms";
Expand All @@ -61,18 +64,17 @@ import {
getContinueDotEnv,
getEsbuildBinaryPath,
} from "../util/paths";
import { localPathToUri } from "../util/pathToUri";

import { ConfigResult, ConfigValidationError } from "@continuedev/config-yaml";
import {
defaultContextProvidersJetBrains,
defaultContextProvidersVsCode,
defaultSlashCommandsJetBrains,
defaultSlashCommandsVscode,
} from "./default";
import { getSystemPromptDotFile } from "./getSystemPromptDotFile";
import { useHub } from "../control-plane/env";
import { localPathToUri } from "../util/pathToUri";
import { modifyContinueConfigWithSharedConfig } from "./sharedConfig";
import { isSupportedLanceDbCpuTargetForLinux } from "./util";
import { validateConfig } from "./validation.js";

export function resolveSerializedConfig(
Expand Down Expand Up @@ -166,10 +168,9 @@ function loadSerializedConfig(
? [...defaultSlashCommandsVscode]
: [...defaultSlashCommandsJetBrains];

// Temporarily disabling this check until we can verify the commands are accuarate
// if (!isSupportedLanceDbCpuTarget(ide)) {
// config.disableIndexing = true;
// }
if (os.platform() === "linux" && !isSupportedLanceDbCpuTargetForLinux(ide)) {
config.disableIndexing = true;
}

return { config, errors, configLoadInterrupted: false };
}
Expand Down Expand Up @@ -226,13 +227,6 @@ export function isContextProviderWithParams(
return (contextProvider as ContextProviderWithParams).name !== undefined;
}

const getCodebaseProvider = async (params: any) => {
const { default: CodebaseContextProvider } = await import(
"../context/providers/CodebaseContextProvider"
);
return new CodebaseContextProvider(params);
};

/** Only difference between intermediate and final configs is the `models` array */
async function intermediateToFinalConfig(
config: Config,
Expand Down Expand Up @@ -397,7 +391,7 @@ async function intermediateToFinalConfig(
new FileContextProvider({}),
// Add codebase provider if indexing is enabled
...(!config.disableIndexing
? [await getCodebaseProvider(codebaseContextParams)]
? [new CodebaseContextProvider(codebaseContextParams)]
: []),
// Add prompt files provider if enabled
...(loadPromptFiles ? [new PromptFilesContextProvider({})] : []),
Expand Down
46 changes: 17 additions & 29 deletions core/config/util.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { execSync } from "child_process";
import fs from "fs";
import os from "os";

import {
Expand Down Expand Up @@ -108,63 +108,49 @@ export function getModelByRole<T extends keyof ModelRoles>(
*
* See here for details: https://github.com/continuedev/continue/issues/940
*/
export function isSupportedLanceDbCpuTarget(ide: IDE) {
export function isSupportedLanceDbCpuTargetForLinux(ide?: IDE) {
const CPU_FEATURES_TO_CHECK = ["avx2", "fma"] as const;

const globalContext = new GlobalContext();
const globalContextVal = globalContext.get("isSupportedLanceDbCpuTarget");
const globalContextVal = globalContext.get(
"isSupportedLanceDbCpuTargetForLinux",
);

// If we've already checked the CPU target, return the cached value
if (globalContextVal !== undefined) {
return globalContextVal;
}

const arch = os.arch();
const platform = os.platform();

// This check only applies to x64
//https://github.com/lancedb/lance/issues/2195#issuecomment-2057841311
if (arch !== "x64") {
globalContext.update("isSupportedLanceDbCpuTarget", true);
globalContext.update("isSupportedLanceDbCpuTargetForLinux", true);
return true;
}

try {
const cpuFlags = (() => {
switch (platform) {
case "darwin":
return execSync("sysctl -n machdep.cpu.features")
.toString()
.toLowerCase();
case "linux":
return execSync("cat /proc/cpuinfo").toString().toLowerCase();
case "win32":
return execSync("wmic cpu get caption /format:list")
.toString()
.toLowerCase();
default:
return "";
}
})();
const cpuFlags = fs.readFileSync("/proc/cpuinfo", "utf-8").toLowerCase();

const isSupportedLanceDbCpuTarget = cpuFlags
const isSupportedLanceDbCpuTargetForLinux = cpuFlags
? CPU_FEATURES_TO_CHECK.every((feature) => cpuFlags.includes(feature))
: true;

// If it's not a supported CPU target, and it's the first time we are checking,
// show a toast to inform the user that we are going to disable indexing.
if (!isSupportedLanceDbCpuTarget) {
if (!isSupportedLanceDbCpuTargetForLinux && ide) {
// We offload our async toast to `showUnsupportedCpuToast` to prevent making
// our config loading async upstream of `isSupportedLanceDbCpuTarget`
// our config loading async upstream of `isSupportedLanceDbCpuTargetForLinux`
void showUnsupportedCpuToast(ide);
}

globalContext.update(
"isSupportedLanceDbCpuTarget",
isSupportedLanceDbCpuTarget,
"isSupportedLanceDbCpuTargetForLinux",
isSupportedLanceDbCpuTargetForLinux,
);

return isSupportedLanceDbCpuTarget;
return isSupportedLanceDbCpuTargetForLinux;
} catch (error) {
// If we can't determine CPU features, default to true
return true;
Expand All @@ -174,11 +160,13 @@ export function isSupportedLanceDbCpuTarget(ide: IDE) {
async function showUnsupportedCpuToast(ide: IDE) {
const shouldOpenLink = await ide.showToast(
"warning",
"Codebase indexing is disabled due to CPU incompatibility",
"Codebase indexing disabled - Your Linux system lacks required CPU features (AVX2, FMA)",
"Learn more",
);

if (shouldOpenLink) {
void ide.openUrl("https://github.com/continuedev/continue/pull/3551");
void ide.openUrl(
"https://docs.continue.dev/troubleshooting#i-received-a-codebase-indexing-disabled---your-linux-system-lacks-required-cpu-features-avx2-fma-notification",
);
}
}
19 changes: 15 additions & 4 deletions core/context/retrieval/pipelines/BaseRetrievalPipeline.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@ export interface IRetrievalPipeline {

export default class BaseRetrievalPipeline implements IRetrievalPipeline {
private ftsIndex = new FullTextSearchCodebaseIndex();
private lanceDbIndex: LanceDbIndex;
private lanceDbIndex: LanceDbIndex | null = null;

constructor(protected readonly options: RetrievalPipelineOptions) {
this.lanceDbIndex = new LanceDbIndex(
options.config.embeddingsProvider,
(uri) => options.ide.readFile(uri),
void this.initLanceDb();
}

private async initLanceDb() {
this.lanceDbIndex = await LanceDbIndex.create(
this.options.config.embeddingsProvider,
(uri) => this.options.ide.readFile(uri),
);
}

Expand Down Expand Up @@ -125,6 +129,13 @@ export default class BaseRetrievalPipeline implements IRetrievalPipeline {
input: string,
n: number,
): Promise<Chunk[]> {
if (!this.lanceDbIndex) {
console.warn(
"LanceDB index not available, skipping embeddings retrieval",
);
return [];
}

return this.lanceDbIndex.retrieve(
input,
n,
Expand Down
25 changes: 16 additions & 9 deletions core/indexing/CodebaseIndexer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import { IContinueServerClient } from "../continueServer/interface.js";
import { IDE, IndexingProgressUpdate, IndexTag } from "../index.js";
import { extractMinimalStackTraceInfo } from "../util/extractMinimalStackTraceInfo.js";
import { getIndexSqlitePath, getLanceDbPath } from "../util/paths.js";

import { findUriInDirs, getUriPathBasename } from "../util/uri.js";

import { ChunkCodebaseIndex } from "./chunk/ChunkCodebaseIndex.js";
import { CodeSnippetsCodebaseIndex } from "./CodeSnippetsIndex.js";
import { FullTextSearchCodebaseIndex } from "./FullTextSearchCodebaseIndex.js";
Expand Down Expand Up @@ -80,24 +80,31 @@ export class CodebaseIndexer {
return [];
}

const indexes = [
const indexes: CodebaseIndex[] = [
new ChunkCodebaseIndex(
this.ide.readFile.bind(this.ide),
this.continueServerClient,
config.embeddingsProvider.maxEmbeddingChunkSize,
), // Chunking must come first
new LanceDbIndex(
config.embeddingsProvider,
this.ide.readFile.bind(this.ide),
this.continueServerClient,
),
];

const lanceDbIndex = await LanceDbIndex.create(
config.embeddingsProvider,
this.ide.readFile.bind(this.ide),
this.continueServerClient,
);

if (lanceDbIndex) {
indexes.push(lanceDbIndex);
}

indexes.push(
new FullTextSearchCodebaseIndex(),
new CodeSnippetsCodebaseIndex(this.ide),
];
);

return indexes;
}

public async refreshFile(
file: string,
workspaceDirs: string[],
Expand Down
10 changes: 5 additions & 5 deletions core/indexing/LanceDbIndex.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ import lance from "vectordb";

import { testConfigHandler, testIde } from "../test/fixtures";
import { getLanceDbPath } from "../util/paths";

import { LanceDbIndex } from "./LanceDbIndex";
import { DatabaseConnection, SqliteDb } from "./refreshIndex";
import {
mockPathAndCacheKey,
mockTag,
testContinueServerClient,
updateIndexAndAwaitGenerator,
} from "./test/indexing";

import { LanceDbIndex } from "./LanceDbIndex";
import { DatabaseConnection, SqliteDb } from "./refreshIndex";
import { IndexResultType } from "./types";

jest.useFakeTimers();
Expand All @@ -31,11 +31,11 @@ describe.skip("ChunkCodebaseIndex", () => {
throw new Error("Failed to load config");
}

index = new LanceDbIndex(
index = (await LanceDbIndex.create(
mockConfig.embeddingsProvider,
testIde.readFile.bind(testIde),
testContinueServerClient,
);
))!;

sqliteDb = await SqliteDb.get();
lanceDb = await lance.connect(getLanceDbPath());
Expand Down
Loading
Loading