diff --git a/.continueignore b/.continueignore index 7e0d0c115f..79e924012d 100644 --- a/.continueignore +++ b/.continueignore @@ -4,4 +4,5 @@ docs/docs/languages .idea/ .vscode/ .archive/ -**/*.scm \ No newline at end of file +**/*.scm +**/*.diff \ No newline at end of file diff --git a/core/autocomplete/generation/ListenableGenerator.test.ts b/core/autocomplete/generation/ListenableGenerator.test.ts index 9dc6cc8937..916d8dbef3 100644 --- a/core/autocomplete/generation/ListenableGenerator.test.ts +++ b/core/autocomplete/generation/ListenableGenerator.test.ts @@ -1,4 +1,5 @@ import { jest } from "@jest/globals"; + import { ListenableGenerator } from "./ListenableGenerator"; describe("ListenableGenerator", () => { diff --git a/core/commands/index.ts b/core/commands/index.ts index cb5e162f7c..492b7f90a6 100644 --- a/core/commands/index.ts +++ b/core/commands/index.ts @@ -1,6 +1,6 @@ import { CustomCommand, SlashCommand, SlashCommandDescription } from "../"; -import { stripImages } from "../llm/images"; import { renderTemplatedString } from "../promptFiles/v1/renderTemplatedString"; +import { renderChatMessage } from "../util/messageContent"; import SlashCommands from "./slash"; @@ -29,7 +29,8 @@ export function slashFromCustomCommand( const messages = [...history]; // Find the last chat message with this slash command and replace it with the user input for (let i = messages.length - 1; i >= 0; i--) { - const { role, content } = messages[i]; + const message = messages[i]; + const { role, content } = message; if (role !== "user") { continue; } @@ -41,7 +42,7 @@ export function slashFromCustomCommand( ) ) { messages[i] = { - ...messages[i], + ...message, content: content.map((part) => { return part.text?.startsWith(`/${customCommand.name}`) ? { ...part, text: promptUserInput } @@ -53,13 +54,16 @@ export function slashFromCustomCommand( typeof content === "string" && content.startsWith(`/${customCommand.name}`) ) { - messages[i] = { ...messages[i], content: promptUserInput }; + messages[i] = { ...message, content: promptUserInput }; break; } } - for await (const chunk of llm.streamChat(messages, new AbortController().signal)) { - yield stripImages(chunk.content); + for await (const chunk of llm.streamChat( + messages, + new AbortController().signal, + )) { + yield renderChatMessage(chunk); } }, }; diff --git a/core/commands/slash/comment.ts b/core/commands/slash/comment.ts deleted file mode 100644 index adadb026c0..0000000000 --- a/core/commands/slash/comment.ts +++ /dev/null @@ -1,19 +0,0 @@ -import { SlashCommand } from "../../"; - -import EditSlashCommand from "./edit"; - -const CommentSlashCommand: SlashCommand = { - name: "comment", - description: "Write comments for highlighted code", - run: async function* (sdk) { - for await (const update of EditSlashCommand.run({ - ...sdk, - input: - "Write comments for this code. Do not change anything about the code itself.", - })) { - yield update; - } - }, -}; - -export default CommentSlashCommand; diff --git a/core/commands/slash/commit.ts b/core/commands/slash/commit.ts index f6521a73dc..8fc1ae27bb 100644 --- a/core/commands/slash/commit.ts +++ b/core/commands/slash/commit.ts @@ -1,5 +1,5 @@ import { SlashCommand } from "../../index.js"; -import { stripImages } from "../../llm/images.js"; +import { renderChatMessage } from "../../util/messageContent.js"; const CommitMessageCommand: SlashCommand = { name: "commit", @@ -18,7 +18,7 @@ const CommitMessageCommand: SlashCommand = { [{ role: "user", content: prompt }], new AbortController().signal, )) { - yield stripImages(chunk.content); + yield renderChatMessage(chunk); } }, }; diff --git a/core/commands/slash/draftIssue.ts b/core/commands/slash/draftIssue.ts index 268c68384c..a640dbd25d 100644 --- a/core/commands/slash/draftIssue.ts +++ b/core/commands/slash/draftIssue.ts @@ -1,6 +1,6 @@ import { ChatMessage, SlashCommand } from "../../index.js"; -import { stripImages } from "../../llm/images.js"; import { removeQuotesAndEscapes } from "../../util/index.js"; +import { renderChatMessage } from "../../util/messageContent.js"; const PROMPT = ( input: string, @@ -30,7 +30,8 @@ const DraftIssueCommand: SlashCommand = { return; } let title = await llm.complete( - `Generate a title for the GitHub issue requested in this user input: '${input}'. Use no more than 20 words and output nothing other than the title. Do not surround it with quotes. The title is: `, new AbortController().signal, + `Generate a title for the GitHub issue requested in this user input: '${input}'. Use no more than 20 words and output nothing other than the title. Do not surround it with quotes. The title is: `, + new AbortController().signal, { maxTokens: 20 }, ); @@ -43,9 +44,12 @@ const DraftIssueCommand: SlashCommand = { { role: "user", content: PROMPT(input, title) }, ]; - for await (const chunk of llm.streamChat(messages, new AbortController().signal)) { + for await (const chunk of llm.streamChat( + messages, + new AbortController().signal, + )) { body += chunk.content; - yield stripImages(chunk.content); + yield renderChatMessage(chunk); } const url = `${params.repositoryUrl}/issues/new?title=${encodeURIComponent( diff --git a/core/commands/slash/edit.ts b/core/commands/slash/edit.ts deleted file mode 100644 index a8f8db944c..0000000000 --- a/core/commands/slash/edit.ts +++ /dev/null @@ -1,637 +0,0 @@ -import { - ContextItemWithId, - ILLM, - RangeInFileWithContents, - SlashCommand, -} from "../../"; -import { - filterCodeBlockLines, - filterEnglishLinesAtEnd, - filterEnglishLinesAtStart, - fixCodeLlamaFirstLineIndentation, - stopAtLines, - streamWithNewLines, -} from "../../autocomplete/filtering/streamTransforms/lineStream"; -import { streamLines } from "../../diff/util"; -import { stripImages } from "../../llm/images"; -import { - dedentAndGetCommonWhitespace, - getMarkdownLanguageTagForFile, -} from "../../util/"; -import { ctxItemToRifWithContents } from "../util"; - -const PROMPT = `Take the file prefix and suffix into account, but only rewrite the code_to_edit as specified in the user_request. The code you write in modified_code_to_edit will replace the code between the code_to_edit tags. Do NOT preface your answer or write anything other than code. The tag should be written to indicate the end of the modified code section. Do not ever use nested tags. - -Example: - - -class Database: - def __init__(self): - self._data = {{}} - - def get(self, key): - return self._data[key] - - - - def set(self, key, value): - self._data[key] = value - - - - def clear_all(): - self._data = {{}} - - -Raise an error if the key already exists. - - - def set(self, key, value): - if key in self._data: - raise KeyError(f"Key {{key}} already exists") - self._data[key] = value - - -Main task: -`; - -export async function getPromptParts( - rif: RangeInFileWithContents, - fullFileContents: string, - model: ILLM, - input: string, - tokenLimit: number | undefined, -) { - const maxTokens = Math.floor(model.contextLength / 2); - - const BUFFER_FOR_FUNCTIONS = 400; - let totalTokens = - model.countTokens(fullFileContents + PROMPT + input) + - BUFFER_FOR_FUNCTIONS + - maxTokens; - - const fullFileContentsList = fullFileContents.split("\n"); - const maxStartLine = rif.range.start.line; - const minEndLine = rif.range.end.line; - let curStartLine = 0; - let curEndLine = fullFileContentsList.length - 1; - - if (totalTokens > model.contextLength) { - while (curEndLine > minEndLine) { - totalTokens -= model.countTokens(fullFileContentsList[curEndLine]); - curEndLine--; - if (totalTokens < model.contextLength) { - break; - } - } - } - - if (totalTokens > model.contextLength) { - while (curStartLine < maxStartLine) { - curStartLine++; - totalTokens -= model.countTokens(fullFileContentsList[curStartLine]); - if (totalTokens < model.contextLength) { - break; - } - } - } - - let filePrefix = fullFileContentsList - .slice(curStartLine, maxStartLine - 1) - .join("\n"); - let fileSuffix = fullFileContentsList - .slice(minEndLine, curEndLine + 1) - .join("\n"); - - if (rif.contents.length > 0) { - let lines = rif.contents.split(/\r?\n/); - let firstLine = lines[0] || null; - while (firstLine && firstLine.trim() === "") { - filePrefix += firstLine; - rif.contents = rif.contents.substring(firstLine.length); - lines = rif.contents.split(/\r?\n/); - firstLine = lines[0] || null; - } - - let lastLine = lines[lines.length - 1] || null; - while (lastLine && lastLine.trim() === "") { - fileSuffix = lastLine + fileSuffix; - rif.contents = rif.contents.substring( - 0, - rif.contents.length - lastLine.length, - ); - lines = rif.contents.split(/\r?\n/); - lastLine = lines[lines.length - 1] || null; - } - - while (rif.contents.startsWith("\n")) { - filePrefix += "\n"; - rif.contents = rif.contents.substring(1); - } - while (rif.contents.endsWith("\n")) { - fileSuffix = `\n${fileSuffix}`; - rif.contents = rif.contents.substring(0, rif.contents.length - 1); - } - } - return { filePrefix, fileSuffix, contents: rif.contents, maxTokens }; -} - -function compilePrompt( - filePrefix: string, - contents: string, - fileSuffix: string, - input: string, -): string { - if (contents.trim() === "") { - // Separate prompt for insertion at the cursor, the other tends to cause it to repeat whole file - return `\ - -${filePrefix} - - - -${fileSuffix} - - -${input} - - -Please output the code to be inserted at the cursor in order to fulfill the user_request. Do NOT preface your answer or write anything other than code. You should not write any tags, just the code. Make sure to correctly indent the code:`; - } - - let prompt = PROMPT; - if (filePrefix.trim() !== "") { - prompt += ` - -${filePrefix} -`; - } - prompt += ` - -${contents} -`; - - if (fileSuffix.trim() !== "") { - prompt += ` - -${fileSuffix} -`; - } - prompt += ` - -${input} - - -`; - - return prompt; -} - -function isEndLine(line: string) { - return ( - line.includes("") || - line.includes("") || - line.includes("[/CODE]") - ); -} - -function lineToBeIgnored(line: string, isFirstLine = false): boolean { - return ( - line.includes("```") || - line.includes("") || - line.includes("") || - line.includes("") || - line.includes("") || - line.includes("") || - line.includes("") || - line.includes("") || - line.includes("") - ); -} - -const EditSlashCommand: SlashCommand = { - name: "edit", - description: "Edit selected code", - run: async function* ({ ide, llm, input, history, contextItems, params }) { - let contextItemToEdit = contextItems.find( - (item: ContextItemWithId) => - item.editing && item.id.providerTitle === "code", - ); - if (!contextItemToEdit) { - contextItemToEdit = contextItems.find( - (item: ContextItemWithId) => item.id.providerTitle === "code", - ); - } - - if (!contextItemToEdit) { - yield "Please highlight the code you want to edit, then press `cmd/ctrl+shift+L` to add it to chat"; - return; - } - - // Strip unecessary parts of the input (the fact that you have to do this is suboptimal, should be refactored away) - let content = history[history.length - 1].content; - if (typeof content !== "string") { - content.forEach((part) => { - if (part.text?.startsWith("/edit")) { - part.text = part.text.replace("/edit", "").trimStart(); - } - }); - } else if (input?.startsWith("/edit")) { - content = input.replace("/edit", "").trimStart(); - } else if (input?.startsWith("/comment")) { - content = input.replace("/comment", "").trimStart(); - } - let userInput = stripImages(content).replace( - `\`\`\`${contextItemToEdit.name}\n${contextItemToEdit.content}\n\`\`\`\n`, - "", - ); - // if the above replace fails to find a match, the code will still be present - // in the userInput. Replace it with input if available. - if (userInput.includes("```") && (input !== "" || !input)) { - userInput = input; - } - - const rif: RangeInFileWithContents = - ctxItemToRifWithContents(contextItemToEdit); - - await ide.saveFile(rif.filepath); - const fullFileContents = await ide.readFile(rif.filepath); - - let { filePrefix, contents, fileSuffix, maxTokens } = await getPromptParts( - rif, - fullFileContents, - llm, - userInput, - params?.tokenLimit, - ); - const [dedentedContents, commonWhitespace] = - dedentAndGetCommonWhitespace(contents); - contents = dedentedContents; - - const prompt = compilePrompt(filePrefix, contents, fileSuffix, userInput); - const fullFileContentsLines = fullFileContents.split("\n"); - const fullPrefixLines = fullFileContentsLines.slice( - 0, - Math.max(0, rif.range.start.line - 1), - ); - const fullSuffixLines = fullFileContentsLines.slice(rif.range.end.line); - - let linesToDisplay: string[] = []; - - async function sendDiffUpdate(lines: string[], final = false) { - const completion = lines.join("\n"); - - // Don't do this at the very end, just show the inserted code - if (final) { - linesToDisplay = []; - } - - // Only recalculate at every new-line, because this is sort of expensive - else if (completion.endsWith("\n")) { - const contentsLines = rif.contents.split("\n"); - let rewrittenLines = 0; - for (const line of lines) { - for (let i = rewrittenLines; i < contentsLines.length; i++) { - if ( - // difflib.SequenceMatcher( - // null, line, contentsLines[i] - // ).ratio() - // > 0.7 - line.trim() === contentsLines[i].trim() && // Temp replacement for difflib (TODO) - contentsLines[i].trim() !== "" - ) { - rewrittenLines = i + 1; - break; - } - } - } - linesToDisplay = contentsLines.slice(rewrittenLines); - } - - const newFileContents = `${fullPrefixLines.join("\n")}\n${completion}\n${ - linesToDisplay.length > 0 ? `${linesToDisplay.join("\n")}\n` : "" - }${fullSuffixLines.join("\n")}`; - - const stepIndex = history.length - 1; - - await ide.showDiff(rif.filepath, newFileContents, stepIndex); - } - - // Important state variables - // ------------------------- - const originalLines = rif.contents === "" ? [] : rif.contents.split("\n"); - // In the actual file, taking into account block offset - let currentLineInFile = rif.range.start.line; - let currentBlockLines: string[] = []; - let originalLinesBelowPreviousBlocks = originalLines; - // The start of the current block in file, taking into account block offset - let currentBlockStart = -1; - let offsetFromBlocks = 0; - - // Don't end the block until you've matched N simultaneous lines - // This helps avoid many tiny blocks - const LINES_TO_MATCH_BEFORE_ENDING_BLOCK = 2; - // If a line has been matched at the end of the block, this is its index within originalLinesBelowPreviousBlocks - // Except we are keeping track of multiple potentialities, so it's a list - // We always check the lines following each of these leads, but if multiple make it out at the end, we use the first one - // This is a tuple of (index_of_last_matched_line, number_of_lines_matched) - let indicesOfLastMatchedLines: [number, number][] = []; - - async function handleGeneratedLine(line: string) { - if (currentBlockLines.length === 0) { - // Set this as the start of the next block - currentBlockStart = - rif.range.start.line + - originalLines.length - - originalLinesBelowPreviousBlocks.length + - offsetFromBlocks; - if ( - originalLinesBelowPreviousBlocks.length > 0 && - line === originalLinesBelowPreviousBlocks[0] - ) { - // Line is equal to the next line in file, move past this line - originalLinesBelowPreviousBlocks = - originalLinesBelowPreviousBlocks.slice(1); - return; - } - } - - // In a block, and have already matched at least one line - // Check if the next line matches, for each of the candidates - const matchesFound: any[] = []; - let firstValidMatch: any = null; - for (const [ - index_of_last_matched_line, - num_lines_matched, - ] of indicesOfLastMatchedLines) { - if ( - index_of_last_matched_line + 1 < - originalLinesBelowPreviousBlocks.length && - line === - originalLinesBelowPreviousBlocks[index_of_last_matched_line + 1] - ) { - matchesFound.push([ - index_of_last_matched_line + 1, - num_lines_matched + 1, - ]); - if ( - firstValidMatch === null && - num_lines_matched + 1 >= LINES_TO_MATCH_BEFORE_ENDING_BLOCK - ) { - firstValidMatch = [ - index_of_last_matched_line + 1, - num_lines_matched + 1, - ]; - } - } - } - indicesOfLastMatchedLines = matchesFound; - - if (firstValidMatch !== null) { - // We've matched the required number of lines, insert suggestion! - - // We added some lines to the block that were matched (including maybe some blank lines) - // So here we will strip all matching lines from the end of currentBlockLines - const linesStripped: string[] = []; - let indexOfLastLineInBlock: number = firstValidMatch[0]; - while ( - currentBlockLines.length > 0 && - currentBlockLines[currentBlockLines.length - 1] === - originalLinesBelowPreviousBlocks[indexOfLastLineInBlock - 1] - ) { - linesStripped.push(currentBlockLines.pop() as string); - indexOfLastLineInBlock -= 1; - } - - // Reset current block / update variables - currentLineInFile += 1; - offsetFromBlocks += currentBlockLines.length; - originalLinesBelowPreviousBlocks = - originalLinesBelowPreviousBlocks.slice(indexOfLastLineInBlock + 1); - currentBlockLines = []; - currentBlockStart = -1; - indicesOfLastMatchedLines = []; - - return; - } - - // Always look for new matching candidates - const newMatches: any[] = []; - for (let i = 0; i < originalLinesBelowPreviousBlocks.length; i++) { - const ogLine = originalLinesBelowPreviousBlocks[i]; - // TODO: It's a bit sus to be disqualifying empty lines. - // What you ideally do is find ALL matches, and then throw them out as you check the following lines - if (ogLine === line) { - // and og_line.trim() !== "": - newMatches.push([i, 1]); - } - } - indicesOfLastMatchedLines = indicesOfLastMatchedLines.concat(newMatches); - - // Make sure they are sorted by index - indicesOfLastMatchedLines = indicesOfLastMatchedLines.sort( - (a, b) => a[0] - b[0], - ); - - currentBlockLines.push(line); - } - - let messages = history; - messages[messages.length - 1] = { role: "user", content: prompt }; - - let linesOfPrefixCopied = 0; - const lines = []; - let unfinishedLine = ""; - let completionLinesCovered = 0; - let repeatingFileSuffix = false; - const lineBelowHighlightedRange = fileSuffix.trim().split("\n")[0]; - - // Use custom templates defined by the model - const template = llm.promptTemplates?.edit; - let generator: AsyncGenerator; - if (template) { - const rendered = llm.renderPromptTemplate( - template, - // typeof template === 'string' ? template : template.prompt, - messages.slice(0, messages.length - 1), - { - codeToEdit: rif.contents, - userInput, - filePrefix: filePrefix, - fileSuffix: fileSuffix, - - // Some built-in templates use these instead of the above - prefix: filePrefix, - suffix: fileSuffix, - - language: getMarkdownLanguageTagForFile(rif.filepath), - systemMessage: llm.systemMessage ?? "", - // "contextItems": (await sdk.getContextItemChatMessages()).map(x => x.content || "").join("\n\n"), - }, - ); - if (typeof rendered === "string") { - messages = [ - { - role: "user", - content: rendered, - }, - ]; - } else { - messages = rendered; - } - - const completion = llm.streamComplete( - rendered as string, - new AbortController().signal, - { - maxTokens: Math.min( - maxTokens, - Math.floor(llm.contextLength / 2), - 4096, - ), - raw: true, - }, - ); - let lineStream = streamLines(completion); - - lineStream = filterEnglishLinesAtStart(lineStream); - - lineStream = filterEnglishLinesAtEnd(filterCodeBlockLines(lineStream)); - lineStream = stopAtLines(lineStream, () => {}); - - generator = streamWithNewLines( - fixCodeLlamaFirstLineIndentation(lineStream), - ); - } else { - async function* gen() { - for await (const chunk of llm.streamChat( - messages, - new AbortController().signal, - { - temperature: 0.5, // TODO - maxTokens: Math.min( - maxTokens, - Math.floor(llm.contextLength / 2), - 4096, - ), - }, - )) { - yield stripImages(chunk.content); - } - } - - generator = gen(); - } - - for await (const chunk of generator) { - // Stop early if it is repeating the fileSuffix or the step was deleted - if (repeatingFileSuffix) { - break; - } - - // Allow stopping breakpoints - yield undefined; - - // Accumulate lines - const chunkLines = chunk.split("\n"); - chunkLines[0] = unfinishedLine + chunkLines[0]; - if (chunk.endsWith("\n")) { - unfinishedLine = ""; - chunkLines.pop(); // because this will be an empty string - } else { - unfinishedLine = chunkLines.pop() ?? ""; - } - - // Deal with newly accumulated lines - for (let i = 0; i < chunkLines.length; i++) { - // Trailing whitespace doesn't matter - chunkLines[i] = chunkLines[i].trimEnd(); - chunkLines[i] = commonWhitespace + chunkLines[i]; - - // Lines that should signify the end of generation - if (isEndLine(chunkLines[i])) { - break; - } - // Lines that should be ignored, like the <> tags - if (lineToBeIgnored(chunkLines[i], completionLinesCovered === 0)) { - continue; // noice - } - // Check if we are currently just copying the prefix - if ( - (linesOfPrefixCopied > 0 || completionLinesCovered === 0) && - linesOfPrefixCopied < filePrefix.split("\n").length && - chunkLines[i] === fullPrefixLines[linesOfPrefixCopied] - ) { - // This is a sketchy way of stopping it from repeating the filePrefix. Is a bug if output happens to have a matching line - linesOfPrefixCopied += 1; - continue; // also nice - } - // Because really short lines might be expected to be repeated, this is only a !heuristic! - // Stop when it starts copying the fileSuffix - if ( - chunkLines[i].trim() === lineBelowHighlightedRange.trim() && - chunkLines[i].trim().length > 4 && - !( - originalLinesBelowPreviousBlocks.length > 0 && - chunkLines[i].trim() === originalLinesBelowPreviousBlocks[0].trim() - ) - ) { - repeatingFileSuffix = true; - break; - } - - lines.push(chunkLines[i]); - completionLinesCovered += 1; - currentLineInFile += 1; - } - - await sendDiffUpdate( - lines.concat([ - unfinishedLine?.startsWith("<") - ? commonWhitespace - : commonWhitespace + unfinishedLine, - ]), - ); - } - - // Add the unfinished line - if ( - unfinishedLine !== "" && - !lineToBeIgnored(unfinishedLine, completionLinesCovered === 0) && - !isEndLine(unfinishedLine) - ) { - unfinishedLine = commonWhitespace + unfinishedLine; - lines.push(unfinishedLine); - await handleGeneratedLine(unfinishedLine); - completionLinesCovered += 1; - currentLineInFile += 1; - } - - await sendDiffUpdate(lines, true); - - if (params?.recap) { - const prompt = `This is the code before editing: -\`\`\` -${contents} -\`\`\` - -This is the code after editing: - -\`\`\` -${lines.join("\n")} -\`\`\` - -Please briefly explain the changes made to the code above. Give no more than 2-3 sentences, and use markdown bullet points:`; - - for await (const update of llm.streamComplete( - prompt, - new AbortController().signal, - )) { - yield update; - } - } - yield `Edited ${contextItemToEdit.name}`; - }, -}; - -export default EditSlashCommand; diff --git a/core/commands/slash/mcp.ts b/core/commands/slash/mcp.ts new file mode 100644 index 0000000000..a9fa332bbf --- /dev/null +++ b/core/commands/slash/mcp.ts @@ -0,0 +1,46 @@ +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; + +import { ChatMessage, SlashCommand } from "../../index.js"; +import { renderChatMessage } from "../../util/messageContent.js"; +export function constructMcpSlashCommand( + client: Client, + name: string, + description?: string, + args?: string[], +): SlashCommand { + return { + name, + description: description ?? "MCP Prompt", + params: {}, + run: async function* (context) { + const argsObject: { [key: string]: string } = {}; + const userInput = context.input.split(" ").slice(1).join(" "); + if (args) { + args.forEach((arg, i) => { + argsObject[arg] = userInput; + }); + } + + const result = await client.getPrompt({ name, arguments: argsObject }); + + const messages: ChatMessage[] = result.messages.map((msg) => { + if (msg.content.type !== "text") { + throw new Error( + "Continue currently only supports text prompts through MCP", + ); + } + return { + content: msg.content.text, + role: msg.role, + }; + }); + + for await (const chunk of context.llm.streamChat( + messages, + new AbortController().signal, + )) { + yield renderChatMessage(chunk); + } + }, + }; +} diff --git a/core/commands/slash/onboard.ts b/core/commands/slash/onboard.ts index 4d8aa3a38a..b9a39aa0b3 100644 --- a/core/commands/slash/onboard.ts +++ b/core/commands/slash/onboard.ts @@ -9,7 +9,7 @@ import { defaultIgnoreFile, gitIgArrayFromFile, } from "../../indexing/ignore"; -import { stripImages } from "../../llm/images"; +import { renderChatMessage } from "../../util/messageContent"; const LANGUAGE_DEP_MGMT_FILENAMES = [ "package.json", // JavaScript (Node.js) @@ -45,10 +45,11 @@ const OnboardSlashCommand: SlashCommand = { const context = await gatherProjectContext(workspaceDir, ide); const prompt = createOnboardingPrompt(context); - for await (const chunk of llm.streamChat([ - { role: "user", content: prompt }, - ], new AbortController().signal)) { - yield stripImages(chunk.content); + for await (const chunk of llm.streamChat( + [{ role: "user", content: prompt }], + new AbortController().signal, + )) { + yield renderChatMessage(chunk); } }, }; diff --git a/core/commands/slash/review.ts b/core/commands/slash/review.ts index 77357647a1..a3a2be6a74 100644 --- a/core/commands/slash/review.ts +++ b/core/commands/slash/review.ts @@ -1,5 +1,5 @@ import { ChatMessage, SlashCommand } from "../../index.js"; -import { stripImages } from "../../llm/images.js"; +import { renderChatMessage } from "../../util/messageContent.js"; const prompt = ` Review the following code, focusing on Readability, Maintainability, Code Smells, Speed, and Memory Performance. Provide feedback with these guidelines: @@ -43,10 +43,11 @@ const ReviewMessageCommand: SlashCommand = { const content = `${prompt} \r\n ${reviewText}`; - for await (const chunk of llm.streamChat([ - { role: "user", content: content }, - ], new AbortController().signal)) { - yield stripImages(chunk.content); + for await (const chunk of llm.streamChat( + [{ role: "user", content: content }], + new AbortController().signal, + )) { + yield renderChatMessage(chunk); } }, }; diff --git a/core/commands/slash/share.ts b/core/commands/slash/share.ts index 5ee1fd938e..4b69f4c623 100644 --- a/core/commands/slash/share.ts +++ b/core/commands/slash/share.ts @@ -4,7 +4,7 @@ import path from "path"; import { languageForFilepath } from "../../autocomplete/constants/AutocompleteLanguageInfo.js"; import { SlashCommand } from "../../index.js"; -import { stripImages } from "../../llm/images.js"; +import { renderChatMessage } from "../../util/messageContent.js"; // If useful elsewhere, helper funcs should move to core/util/index.ts or similar function getOffsetDatetime(date: Date): Date { @@ -48,7 +48,7 @@ const ShareSlashCommand: SlashCommand = { // message in the chat history, this will omit it for (const msg of history.slice(0, history.length - 1)) { let msgText = msg.content; - msgText = stripImages(msg.content); + msgText = renderChatMessage(msg); if (msg.role === "user" && msgText.search("```") > -1) { msgText = reformatCodeBlocks(msgText); diff --git a/core/config/load.ts b/core/config/load.ts index 9911c840fe..1e16913e51 100644 --- a/core/config/load.ts +++ b/core/config/load.ts @@ -1,10 +1,11 @@ import { execSync } from "child_process"; -import * as JSONC from "comment-json"; import * as fs from "fs"; import os from "os"; import path from "path"; +import * as JSONC from "comment-json"; import * as tar from "tar"; + import { BrowserSerializedContinueConfig, Config, @@ -28,6 +29,7 @@ import { slashCommandFromDescription, slashFromCustomCommand, } from "../commands/index.js"; +import MCPConnectionSingleton from "../context/mcp"; import CodebaseContextProvider from "../context/providers/CodebaseContextProvider"; import ContinueProxyContextProvider from "../context/providers/ContinueProxyContextProvider"; import CustomContextProviderClass from "../context/providers/CustomContextProvider"; @@ -42,6 +44,7 @@ import { BaseLLM } from "../llm"; import { llmFromDescription } from "../llm/llms"; import CustomLLMClass from "../llm/llms/CustomLLM"; import FreeTrial from "../llm/llms/FreeTrial"; +import { allTools } from "../tools"; import { copyOf } from "../util"; import { fetchwithRequestOptions } from "../util/fetchWithOptions"; import { GlobalContext } from "../util/GlobalContext"; @@ -57,7 +60,6 @@ import { getEsbuildBinaryPath, readAllGlobalPromptFiles, } from "../util/paths"; - import { defaultContextProvidersJetBrains, defaultContextProvidersVsCode, @@ -484,14 +486,36 @@ async function intermediateToFinalConfig( } } - return { + let continueConfig: ContinueConfig = { ...config, contextProviders, models, embeddingsProvider: config.embeddingsProvider as any, tabAutocompleteModels, reranker: config.reranker as any, + tools: allTools, }; + + // Apply MCP if specified + if (config.experimental?.modelContextProtocolServer) { + const mcpConnection = await MCPConnectionSingleton.getInstance( + config.experimental.modelContextProtocolServer, + ); + continueConfig = await Promise.race([ + mcpConnection.modifyConfig(continueConfig), + new Promise((_, reject) => + setTimeout( + () => reject(new Error("MCP connection timed out after 2000ms")), + 2000, + ), + ), + ]).catch((error) => { + console.warn("MCP connection error:", error); + return continueConfig; // Return original config if timeout occurs + }); + } + + return continueConfig; } function finalToBrowserConfig( @@ -528,6 +552,7 @@ function finalToBrowserConfig( ui: final.ui, experimental: final.experimental, docs: final.docs, + tools: final.tools, }; } diff --git a/core/config/promptFile.ts b/core/config/promptFile.ts index a630e41d65..5020fc1122 100644 --- a/core/config/promptFile.ts +++ b/core/config/promptFile.ts @@ -4,9 +4,9 @@ import Handlebars from "handlebars"; import * as YAML from "yaml"; import { walkDir } from "../indexing/walkDir"; -import { stripImages } from "../llm/images"; import { renderTemplatedString } from "../promptFiles/v1/renderTemplatedString"; import { getBasename } from "../util/index"; +import { renderChatMessage } from "../util/messageContent"; import type { ChatMessage, @@ -129,7 +129,7 @@ export function slashCommandFromPromptFile( messages, new AbortController().signal, )) { - yield stripImages(chunk.content); + yield renderChatMessage(chunk); } context.llm.systemMessage = originalSystemMessage; @@ -236,7 +236,8 @@ function updateChatHistory( const messages = [...history]; for (let i = messages.length - 1; i >= 0; i--) { - const { role, content } = messages[i]; + const message = messages[i]; + const { role, content } = message; if (role !== "user") { continue; } @@ -254,7 +255,7 @@ function updateChatHistory( typeof content === "string" && content.startsWith(`/${commandName}`) ) { - messages[i] = { ...messages[i], content: renderedPrompt }; + messages[i] = { ...message, content: renderedPrompt }; break; } } diff --git a/core/context/mcp/index.ts b/core/context/mcp/index.ts new file mode 100644 index 0000000000..984ba8f7b6 --- /dev/null +++ b/core/context/mcp/index.ts @@ -0,0 +1,157 @@ +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; +import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"; +import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; +import { WebSocketClientTransport } from "@modelcontextprotocol/sdk/client/websocket.js"; +import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; + +import { ContinueConfig, MCPOptions, SlashCommand, Tool } from "../.."; +import { constructMcpSlashCommand } from "../../commands/slash/mcp"; +import MCPContextProvider from "../providers/MCPContextProvider"; + +class MCPConnectionSingleton { + private static instance: MCPConnectionSingleton; + public client: Client; + private transport: Transport; + + private constructor(private readonly options: MCPOptions) { + this.transport = this.constructTransport(options); + + this.client = new Client( + { + name: "continue-client", + version: "1.0.0", + }, + { + capabilities: {}, + }, + ); + } + + public static getInstance(options: MCPOptions): MCPConnectionSingleton { + if (!MCPConnectionSingleton.instance) { + MCPConnectionSingleton.instance = new MCPConnectionSingleton(options); + } + return MCPConnectionSingleton.instance; + } + + public static getExistingInstance(): MCPConnectionSingleton | null { + return MCPConnectionSingleton.instance; + } + + private constructTransport(options: MCPOptions): Transport { + switch (options.transport.type) { + case "stdio": + return new StdioClientTransport({ + command: options.transport.command, + args: options.transport.args, + }); + case "websocket": + return new WebSocketClientTransport(new URL(options.transport.url)); + case "sse": + return new SSEClientTransport(new URL(options.transport.url)); + default: + throw new Error( + `Unsupported transport type: ${(options.transport as any).type}`, + ); + } + } + + private isConnected: boolean = false; + private connectPromise: Promise | null = null; + + private async connectClient() { + if (this.isConnected) { + // Already connected + return; + } + + if (this.connectPromise) { + // Connection is already in progress; wait for it to complete + await this.connectPromise; + return; + } + + this.connectPromise = (async () => { + await this.client.connect(this.transport); + this.isConnected = true; + })(); + + try { + await this.connectPromise; + } catch (error) { + // Handle connection error if needed + throw error; + } finally { + // Reset the promise so future attempts can try again if necessary + this.connectPromise = null; + } + } + + async modifyConfig(config: ContinueConfig): Promise { + try { + await this.connectClient(); + } catch (error: any) { + if (!error.message.startsWith("StdioClientTransport already started")) { + console.error("Failed to connect client:", error); + return config; + } + } + + // Resources <—> Context Provider + const { resources } = await this.client.listResources(); + + const submenuItems = resources.map((resource: any) => ({ + title: resource.name, + description: resource.description, + id: resource.uri, + })); + + if (!config.contextProviders) { + config.contextProviders = []; + } + + config.contextProviders!.push( + new MCPContextProvider({ + submenuItems, + client: this.client, + }), + ); + + // Tools <—> Tools + const { tools } = await this.client.listTools(); + const continueTools: Tool[] = tools.map((tool) => ({ + displayTitle: tool.name, + function: { + description: tool.description, + name: tool.name, + parameters: tool.inputSchema, + }, + readonly: false, + type: "function", + wouldLikeTo: `use the ${tool.name} tool`, + uri: `mcp://${tool.name}`, + })); + + config.tools = [...config.tools, ...continueTools]; + + // Prompts <—> Slash commands + const { prompts } = await this.client.listPrompts(); + if (!config.slashCommands) { + config.slashCommands = []; + } + + const slashCommands: SlashCommand[] = prompts.map((prompt) => { + return constructMcpSlashCommand( + this.client, + prompt.name, + prompt.description, + prompt.arguments?.map((a) => a.name), + ); + }); + config.slashCommands!.push(...slashCommands); + + return config; + } +} + +export default MCPConnectionSingleton; diff --git a/core/context/providers/MCPContextProvider.ts b/core/context/providers/MCPContextProvider.ts new file mode 100644 index 0000000000..0c0a6939df --- /dev/null +++ b/core/context/providers/MCPContextProvider.ts @@ -0,0 +1,66 @@ +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; + +import { BaseContextProvider } from "../"; +import { + ContextItem, + ContextProviderDescription, + ContextProviderExtras, + ContextSubmenuItem, + LoadSubmenuItemsArgs, +} from "../../"; + +interface MCPContextProviderOptions { + submenuItems: ContextSubmenuItem[]; + client: Client; +} + +class MCPContextProvider extends BaseContextProvider { + static description: ContextProviderDescription = { + title: "mcp", + displayTitle: "MCP", + description: "Model Context Protocol", + type: "submenu", + }; + + constructor(options: MCPContextProviderOptions) { + super(options); + } + + async getContextItems( + query: string, + extras: ContextProviderExtras, + ): Promise { + const { contents } = await ( + this.options as MCPContextProviderOptions + ).client.readResource({ uri: query }); + + return await Promise.all( + contents.map(async (resource) => { + const content = resource.text; + if (typeof content !== "string") { + throw new Error( + "Continue currently only supports text resources from MCP", + ); + } + + return { + name: resource.uri, + description: resource.uri, + content, + uri: { + type: "url", + value: resource.uri, + }, + }; + }), + ); + } + + async loadSubmenuItems( + args: LoadSubmenuItemsArgs, + ): Promise { + return (this.options as MCPContextProviderOptions).submenuItems; + } +} + +export default MCPContextProvider; diff --git a/core/context/providers/WebContextProvider.ts b/core/context/providers/WebContextProvider.ts index 2c8cb89284..9ea4bac410 100644 --- a/core/context/providers/WebContextProvider.ts +++ b/core/context/providers/WebContextProvider.ts @@ -3,12 +3,37 @@ import { ContextItem, ContextProviderDescription, ContextProviderExtras, + FetchFunction, } from "../.."; import { getHeaders } from "../../continueServer/stubs/headers"; import { TRIAL_PROXY_URL } from "../../control-plane/client"; +export const fetchSearchResults = async ( + query: string, + n: number, + fetchFn: FetchFunction, +): Promise => { + const resp = await fetchFn(WebContextProvider.ENDPOINT, { + method: "POST", + headers: { + "Content-Type": "application/json", + ...(await getHeaders()), + }, + body: JSON.stringify({ + query, + n, + }), + }); + + if (!resp.ok) { + const text = await resp.text(); + throw new Error(`Failed to fetch web context: ${text}`); + } + return await resp.json(); +}; + export default class WebContextProvider extends BaseContextProvider { - private static ENDPOINT = new URL("web", TRIAL_PROXY_URL); + public static ENDPOINT = new URL("web", TRIAL_PROXY_URL); private static DEFAULT_N = 6; static description: ContextProviderDescription = { @@ -23,22 +48,10 @@ export default class WebContextProvider extends BaseContextProvider { query: string, extras: ContextProviderExtras, ): Promise { - const resp = await extras.fetch(WebContextProvider.ENDPOINT, { - method: "POST", - headers: { - "Content-Type": "application/json", - ...(await getHeaders()), - }, - body: JSON.stringify({ - query: extras.fullInput, - n: this.options.n ?? WebContextProvider.DEFAULT_N, - }), - }); - if (!resp.ok) { - const text = await resp.text(); - throw new Error(`Failed to fetch web context: ${text}`); - } - const json = await resp.json(); - return json; + return await fetchSearchResults( + extras.fullInput, + this.options.n ?? WebContextProvider.DEFAULT_N, + extras.fetch, + ); } } diff --git a/core/context/providers/index.ts b/core/context/providers/index.ts index 316e4936b6..f5656832c4 100644 --- a/core/context/providers/index.ts +++ b/core/context/providers/index.ts @@ -17,6 +17,7 @@ import GoogleContextProvider from "./GoogleContextProvider"; import GreptileContextProvider from "./GreptileContextProvider"; import HttpContextProvider from "./HttpContextProvider"; import JiraIssuesContextProvider from "./JiraIssuesContextProvider/"; +import MCPContextProvider from "./MCPContextProvider"; import OpenFilesContextProvider from "./OpenFilesContextProvider"; import OSContextProvider from "./OSContextProvider"; import PostgresContextProvider from "./PostgresContextProvider"; @@ -60,6 +61,7 @@ export const Providers: (typeof BaseContextProvider)[] = [ DiscordContextProvider, GreptileContextProvider, WebContextProvider, + MCPContextProvider, ]; export function contextProviderClassFromName( diff --git a/core/context/providers/utils.ts b/core/context/providers/utils.ts index 7bc20e4918..fd73a9bec5 100644 --- a/core/context/providers/utils.ts +++ b/core/context/providers/utils.ts @@ -4,4 +4,5 @@ export const INSTRUCTIONS_BASE_ITEM: ContextItem = { name: "Instructions", description: "Instructions", content: "", + hidden: true, }; diff --git a/core/context/retrieval/repoMapRequest.ts b/core/context/retrieval/repoMapRequest.ts index 8dc55df9db..57f2b339e4 100644 --- a/core/context/retrieval/repoMapRequest.ts +++ b/core/context/retrieval/repoMapRequest.ts @@ -1,7 +1,7 @@ import { Chunk, ContinueConfig, IDE, ILLM } from "../.."; import { getModelByRole } from "../../config/util"; -import { stripImages } from "../../llm/images"; import generateRepoMap from "../../util/generateRepoMap"; +import { renderChatMessage } from "../../util/messageContent"; const SUPPORTED_MODEL_TITLE_FAMILIES = [ "claude-3", @@ -58,11 +58,14 @@ After this, your response should begin with a tag, followed by a list This is the question that you should select relevant files for: "${input}"`; - const response = await llm.chat([ - { role: "user", content: prompt }, - { role: "assistant", content: "" }, - ], new AbortController().signal); - const content = stripImages(response.content); + const response = await llm.chat( + [ + { role: "user", content: prompt }, + { role: "assistant", content: "" }, + ], + new AbortController().signal, + ); + const content = renderChatMessage(response); console.debug("Repo map retrieval response: ", content); if (!content.includes("\n")) { diff --git a/core/core.ts b/core/core.ts index a1d74f2818..2108d3706f 100644 --- a/core/core.ts +++ b/core/core.ts @@ -36,6 +36,7 @@ import { TTS } from "./util/tts"; import type { ContextItemId, IDE, IndexingProgressUpdate } from "."; import type { FromCoreProtocol, ToCoreProtocol } from "./protocol"; +import { callTool } from "./tools/callTool"; import type { IMessenger, Message } from "./util/messenger"; export class Core { @@ -382,6 +383,7 @@ export class Core { } const model = await configHandler.llmFromTitle(msg.data.title); + const gen = model.streamChat( msg.data.messages, new AbortController().signal, @@ -402,8 +404,11 @@ export class Core { }); break; } + + const chunk = next.value; + // @ts-ignore - yield { content: next.value.content }; + yield { content: chunk }; next = await gen.next(); } @@ -774,6 +779,32 @@ export class Core { recentlyEditedFilesCache.set(filepath, filepath); } }); + + on("tools/call", async ({ data: { toolCall } }) => { + const config = await this.configHandler.loadConfig(); + const tool = config.tools.find( + (t) => t.function.name === toolCall.function.name, + ); + + if (!tool) { + throw new Error(`Tool ${toolCall.function.name} not found`); + } + + const llm = await this.getSelectedModel(); + + const contextItems = await callTool( + tool.uri ?? tool.function.name, + JSON.parse(toolCall.function.arguments || "{}"), + { + ide: this.ide, + llm, + fetch: (url, init) => + fetchwithRequestOptions(url, init, config.requestOptions), + }, + ); + + return { contextItems }; + }); } private indexingCancellationController: AbortController | undefined; diff --git a/core/diff/util.ts b/core/diff/util.ts index 839ab15b16..c56e88504e 100644 --- a/core/diff/util.ts +++ b/core/diff/util.ts @@ -1,7 +1,7 @@ import { distance } from "fastest-levenshtein"; import { ChatMessage } from "../index.js"; -import { stripImages } from "../llm/images.js"; +import { renderChatMessage } from "../util/messageContent.js"; export type LineStream = AsyncGenerator; @@ -103,7 +103,7 @@ export async function* streamLines( try { for await (const update of streamCompletion) { const chunk = - typeof update === "string" ? update : stripImages(update.content); + typeof update === "string" ? update : renderChatMessage(update); buffer += chunk; const lines = buffer.split("\n"); buffer = lines.pop() ?? ""; diff --git a/core/index.d.ts b/core/index.d.ts index 58ba9d8a51..6c17fd1c7c 100644 --- a/core/index.d.ts +++ b/core/index.d.ts @@ -288,7 +288,7 @@ export interface CompletionOptions extends BaseCompletionOptions { model: string; } -export type ChatMessageRole = "user" | "assistant" | "system"; +export type ChatMessageRole = "user" | "assistant" | "system" | "tool"; export interface MessagePart { type: "text" | "imageUrl"; @@ -298,11 +298,52 @@ export interface MessagePart { export type MessageContent = string | MessagePart[]; -export interface ChatMessage { - role: ChatMessageRole; +export interface ToolCall { + id: string; + type: "function"; + function: { + name: string; + arguments: string; + }; +} + +export interface ToolCallDelta { + id?: string; + type?: "function"; + function?: { + name?: string; + arguments?: string; + }; +} + +export interface ToolResultChatMessage { + role: "tool"; + content: string; + toolCallId: string; +} + +export interface UserChatMessage { + role: "user"; content: MessageContent; } +export interface AssistantChatMessage { + role: "assistant"; + content: MessageContent; + toolCalls?: ToolCallDelta[]; +} + +export interface SystemChatMessage { + role: "system"; + content: string; +} + +export type ChatMessage = + | UserChatMessage + | AssistantChatMessage + | SystemChatMessage + | ToolResultChatMessage; + export interface ContextItemId { providerTitle: string; itemId: string; @@ -323,6 +364,7 @@ export interface ContextItem { editable?: boolean; icon?: string; uri?: ContextItemUri; + hidden?: boolean; } export interface ContextItemWithId extends ContextItem { @@ -347,12 +389,30 @@ export interface PromptLog { prompt: string; completion: string; } + +export type ToolStatus = + | "generating" + | "generated" + | "calling" + | "done" + | "canceled"; + +// Will exist only on "assistant" messages with tool calls +interface ToolCallState { + toolCallId: string; + toolCall: ToolCall; + status: ToolStatus; + parsedArgs: any; + output?: ContextItem[]; +} + export interface ChatHistoryItem { message: ChatMessage; editorState?: any; modifiers?: InputModifiers; contextItems: ContextItemWithId[]; promptLogs?: PromptLog[]; + toolCallState?: ToolCallState; } // LLM @@ -856,6 +916,27 @@ interface Prediction { }[]; } +export interface ToolExtras { + ide: IDE; + llm: ILLM; + fetch: FetchFunction; +} + +export interface Tool { + type: "function"; + function: { + name: string; + description?: string; + parameters?: Record; + strict?: boolean | null; + }; + + displayTitle: string; + wouldLikeTo: string; + readonly: boolean; + uri?: string; +} + interface BaseCompletionOptions { temperature?: number; topP?: number; @@ -872,6 +953,7 @@ interface BaseCompletionOptions { raw?: boolean; stream?: boolean; prediction?: Prediction; + tools?: Tool[]; } export interface ModelCapability { @@ -986,6 +1068,28 @@ export interface TabAutocompleteOptions { showWhateverWeHaveAtXMs?: number; } +interface StdioOptions { + type: "stdio"; + command: string; + args: string[]; +} + +interface WebSocketOptions { + type: "websocket"; + url: string; +} + +interface SSEOptions { + type: "sse"; + url: string; +} + +type TransportOptions = StdioOptions | WebSocketOptions | SSEOptions; + +export interface MCPOptions { + transport: TransportOptions; +} + export interface ContinueUIConfig { codeBlockToolbarPosition?: "top" | "bottom"; fontSize?: number; @@ -1091,6 +1195,8 @@ interface ExperimentalConfig { * This is needed to crawl a large number of documentation sites that are dynamically rendered. */ useChromiumForDocsCrawling?: boolean; + useTools?: boolean; + modelContextProtocolServer?: MCPOptions; } interface AnalyticsConfig { @@ -1195,6 +1301,7 @@ export interface ContinueConfig { experimental?: ExperimentalConfig; analytics?: AnalyticsConfig; docs?: SiteIndexingConfig[]; + tools: Tool[]; } export interface BrowserSerializedContinueConfig { @@ -1214,6 +1321,7 @@ export interface BrowserSerializedContinueConfig { experimental?: ExperimentalConfig; analytics?: AnalyticsConfig; docs?: SiteIndexingConfig[]; + tools: Tool[]; } // DOCS SUGGESTIONS AND PACKAGE INFO diff --git a/core/llm/autodetect.ts b/core/llm/autodetect.ts index 9849af95b9..363e7c28ad 100644 --- a/core/llm/autodetect.ts +++ b/core/llm/autodetect.ts @@ -84,6 +84,13 @@ const MODEL_SUPPORTS_IMAGES: string[] = [ "llama3.2", ]; +function modelSupportsTools(modelName: string) { + return ( + modelName.includes("claude") && + (modelName.includes("3-5") || modelName.includes("3.5")) + ); +} + function modelSupportsImages( provider: ModelProvider, model: string, @@ -354,4 +361,5 @@ export { autodetectTemplateType, llmCanGenerateInParallel, modelSupportsImages, + modelSupportsTools, }; diff --git a/core/llm/constructMessages.ts b/core/llm/constructMessages.ts index 5191765cd2..b1cb0da93d 100644 --- a/core/llm/constructMessages.ts +++ b/core/llm/constructMessages.ts @@ -1,4 +1,6 @@ import { ChatHistoryItem, ChatMessage, MessagePart } from "../index.js"; +import { normalizeToMessageParts } from "../util/messageContent.js"; +import { modelSupportsTools } from "./autodetect.js"; const SYSTEM_MESSAGE = `When generating new code: @@ -58,52 +60,67 @@ function helloWorld() { Always follow these guidelines when generating code responses.`; -function hasCodeBlockWithFilename(content: ChatMessage["content"]): boolean { - const contentStr = typeof content === "string" ? content : content[0].text; +const TOOL_USE_RULES = `When using tools, follow the following guidelines: +- Avoid calling tools unless they are absolutely necessary. For example, if you are asked a simple programming question you do not need web search. As another example, if the user asks you to explain something about code, do not create a new file.`; - if (!contentStr) { - return false; +function constructSystemPrompt(model: string): string | null { + if (model.includes("sonnet")) { + return SYSTEM_MESSAGE + "\n\n" + TOOL_USE_RULES; + } + if (modelSupportsTools(model)) { + return TOOL_USE_RULES; } - const codeBlockRegex = /```[\w\W]*?\.[\w\W]*/; - return codeBlockRegex.test(contentStr); + return null; } +const CANCELED_TOOL_CALL_MESSAGE = + "This tool call was cancelled by the user. You should clarify next steps, as they don't wish for you to use this tool."; + export function constructMessages( history: ChatHistoryItem[], model: string, ): ChatMessage[] { - const msgs = []; + const msgs: ChatMessage[] = []; - // Only using this system message with Sonnet right now - if ( - // hasCodeBlockWithFilename(history[0].message.content) && - model.includes("sonnet") - ) { + const systemMessage = constructSystemPrompt(model); + if (systemMessage) { msgs.push({ role: "system" as const, - content: SYSTEM_MESSAGE, + content: systemMessage, }); } for (let i = 0; i < history.length; i++) { const historyItem = history[i]; - let content = Array.isArray(historyItem.message.content) - ? historyItem.message.content - : [{ type: "text", text: historyItem.message.content } as MessagePart]; - - const ctxItems = historyItem.contextItems.map((ctxItem) => { - return { type: "text", text: `${ctxItem.content}\n` } as MessagePart; - }); - - content = [...ctxItems, ...content]; - - msgs.push({ - role: historyItem.message.role, - content, - }); + if (historyItem.message.role === "user") { + // Gather context items for user messages + let content = normalizeToMessageParts(historyItem.message); + + const ctxItems = historyItem.contextItems.map((ctxItem) => { + return { type: "text", text: `${ctxItem.content}\n` } as MessagePart; + }); + + content = [...ctxItems, ...content]; + msgs.push({ + ...historyItem.message, + content, + }); + } else if (historyItem.toolCallState?.status === "canceled") { + // Canceled tool call + msgs.push({ + ...historyItem.message, + content: CANCELED_TOOL_CALL_MESSAGE, + }); + } else { + msgs.push(historyItem.message); + } } - return msgs; + // Remove the "id" from all of the messages + return msgs.map((msg) => { + const { id, ...rest } = msg as any; + return rest; + }); } diff --git a/core/llm/countTokens.ts b/core/llm/countTokens.ts index 9b30ad4bc3..49aed0e8dc 100644 --- a/core/llm/countTokens.ts +++ b/core/llm/countTokens.ts @@ -2,6 +2,7 @@ import { Tiktoken, encodingForModel as _encodingForModel } from "js-tiktoken"; import { ChatMessage, MessageContent, MessagePart } from "../index.js"; +import { renderChatMessage } from "../util/messageContent.js"; import { AsyncEncoder, GPTAsyncEncoder, @@ -9,7 +10,6 @@ import { } from "./asyncEncoder.js"; import { autodetectTemplateType } from "./autodetect.js"; import { TOKEN_BUFFER_FOR_SAFETY } from "./constants.js"; -import { stripImages } from "./images.js"; import llamaTokenizer from "./llamaTokenizer.js"; interface Encoding { encode: Tiktoken["encode"]; @@ -115,19 +115,28 @@ function countTokens( } } -function flattenMessages(msgs: ChatMessage[]): ChatMessage[] { +function messageHasToolCalls(msg: ChatMessage): boolean { + return msg.role === "assistant" && !!msg.toolCalls; +} + +export function flattenMessages(msgs: ChatMessage[]): ChatMessage[] { const flattened: ChatMessage[] = []; + for (let i = 0; i < msgs.length; i++) { const msg = msgs[i]; + if ( flattened.length > 0 && - flattened[flattened.length - 1].role === msg.role + flattened[flattened.length - 1].role === msg.role && + !messageHasToolCalls(msg) && + !messageHasToolCalls(flattened[flattened.length - 1]) ) { flattened[flattened.length - 1].content += `\n\n${msg.content || ""}`; } else { flattened.push(msg); } } + return flattened; } @@ -222,11 +231,8 @@ function pruneRawPromptFromBottom( return pruneStringFromBottom(modelName, maxTokens, prompt); } -function summarize(message: MessageContent): string { - if (Array.isArray(message)) { - return `${stripImages(message).substring(0, 100)}...`; - } - return `${message.substring(0, 100)}...`; +function summarize(message: ChatMessage): string { + return `${renderChatMessage(message).substring(0, 100)}...`; } function pruneChatHistory( @@ -257,7 +263,7 @@ function pruneChatHistory( for (let i = 0; i < longerThanOneThird.length; i++) { // Prune line-by-line from the top const message = longerThanOneThird[i]; - const content = stripImages(message.content); + const content = renderChatMessage(message); const deltaNeeded = totalTokens - contextLength; const delta = Math.min(deltaNeeded, distanceFromThird[i]); message.content = pruneStringFromTop( @@ -273,8 +279,8 @@ function pruneChatHistory( while (totalTokens > contextLength && i < chatHistory.length - 5) { const message = chatHistory[0]; totalTokens -= countTokens(message.content, modelName); - totalTokens += countTokens(summarize(message.content), modelName); - message.content = summarize(message.content); + totalTokens += countTokens(summarize(message), modelName); + message.content = summarize(message); i++; } @@ -297,8 +303,8 @@ function pruneChatHistory( ) { const message = chatHistory[i]; totalTokens -= countTokens(message.content, modelName); - totalTokens += countTokens(summarize(message.content), modelName); - message.content = summarize(message.content); + totalTokens += countTokens(summarize(message), modelName); + message.content = summarize(message); i++; } @@ -314,7 +320,7 @@ function pruneChatHistory( message.content = pruneRawPromptFromTop( modelName, contextLength, - stripImages(message.content), + renderChatMessage(message), tokensForCompletion, ); totalTokens = contextLength; @@ -344,6 +350,24 @@ function addSpaceToAnyEmptyMessages(messages: ChatMessage[]): ChatMessage[] { }); } +function chatMessageIsEmpty(message: ChatMessage): boolean { + switch (message.role) { + case "system": + case "user": + return ( + typeof message.content === "string" && message.content.trim() === "" + ); + case "assistant": + return ( + typeof message.content === "string" && + message.content.trim() === "" && + !message.toolCalls + ); + case "tool": + return false; + } +} + function compileChatMessages( modelName: string, msgs: ChatMessage[] | undefined, @@ -357,7 +381,7 @@ function compileChatMessages( let msgsCopy = msgs ? msgs .map((msg) => ({ ...msg })) - .filter((msg) => msg.content !== "" && msg.role !== "system") + .filter((msg) => !chatMessageIsEmpty(msg) && msg.role !== "system") : []; msgsCopy = addSpaceToAnyEmptyMessages(msgsCopy); @@ -376,7 +400,7 @@ function compileChatMessages( ) { let content = ""; if (msgs?.[0].role === "system") { - content = stripImages(msgs?.[0].content); + content = renderChatMessage(msgs?.[0]); } if (systemMessage && systemMessage.trim() !== "") { const shouldAddNewLines = content !== ""; @@ -411,7 +435,7 @@ function compileChatMessages( if (!supportsImages) { for (const msg of msgsCopy) { if ("content" in msg && Array.isArray(msg.content)) { - const content = stripImages(msg.content); + const content = renderChatMessage(msg); msg.content = content; } } diff --git a/core/llm/flattenMessages.test.ts b/core/llm/flattenMessages.test.ts new file mode 100644 index 0000000000..6f5ba0a4c2 --- /dev/null +++ b/core/llm/flattenMessages.test.ts @@ -0,0 +1,104 @@ +import { ChatMessage } from ".."; +import { flattenMessages } from "./countTokens"; + +describe("flattenMessages", () => { + it("should return an empty array when given an empty array", () => { + expect(flattenMessages([])).toEqual([]); + }); + + it("should return the same array when there is only one message", () => { + const messages: ChatMessage[] = [{ role: "user", content: "Hello" }]; + expect(flattenMessages(messages)).toEqual(messages); + }); + + it("should return the same array when messages have different roles", () => { + const messages: ChatMessage[] = [ + { role: "user", content: "Hello" }, + { role: "assistant", content: "Hi there!" }, + { role: "user", content: "How are you?" }, + ]; + expect(flattenMessages(messages)).toEqual(messages); + }); + + it("should combine consecutive messages with the same role", () => { + const messages: ChatMessage[] = [ + { role: "user", content: "Hello" }, + { role: "user", content: "How are you?" }, + { role: "assistant", content: "I am an AI assistant." }, + { role: "assistant", content: "How can I help you?" }, + ]; + const expected: ChatMessage[] = [ + { role: "user", content: "Hello\n\nHow are you?" }, + { + role: "assistant", + content: "I am an AI assistant.\n\nHow can I help you?", + }, + ]; + expect(flattenMessages(messages)).toEqual(expected); + }); + + it("should handle messages with undefined or empty content", () => { + const messages: ChatMessage[] = [ + { role: "user", content: "Hello" }, + { role: "user", content: "" }, + { role: "assistant", content: "" }, + { role: "assistant", content: "Sure, I can help with that." }, + ]; + const expected: ChatMessage[] = [ + { role: "user", content: "Hello\n\n" }, + { role: "assistant", content: "\n\nSure, I can help with that." }, + ]; + expect(flattenMessages(messages)).toEqual(expected); + }); + + it("should not combine non-consecutive messages with the same role", () => { + const messages: ChatMessage[] = [ + { role: "user", content: "First message" }, + { role: "assistant", content: "Reply to first message" }, + { role: "user", content: "Second message" }, + ]; + expect(flattenMessages(messages)).toEqual(messages); + }); + + it("should handle a mix of message roles and contents correctly", () => { + const messages: ChatMessage[] = [ + { role: "user", content: "Message one" }, + { role: "user", content: "Message two" }, + { role: "assistant", content: "Reply one" }, + { role: "user", content: "Message three" }, + { role: "user", content: "Message four" }, + { role: "assistant", content: "Reply two" }, + ]; + const expected: ChatMessage[] = [ + { role: "user", content: "Message one\n\nMessage two" }, + { role: "assistant", content: "Reply one" }, + { role: "user", content: "Message three\n\nMessage four" }, + { role: "assistant", content: "Reply two" }, + ]; + expect(flattenMessages(messages)).toEqual(expected); + }); + + it("should not combine adjacent assistant messages when one of them is a tool call", () => { + const messages: ChatMessage[] = [ + { role: "user", content: "What's the capital of France?" }, + { role: "assistant", content: "Paris" }, + { + role: "assistant", + content: " ", + toolCalls: [ + { + id: "test", + type: "function", + function: { + name: "getCapital", + arguments: JSON.stringify({ + country: "France", + }), + }, + }, + ], + }, + ]; + expect(flattenMessages(messages).length).toBe(3); + }); +}); diff --git a/core/llm/images.ts b/core/llm/images.ts deleted file mode 100644 index 912ae48caf..0000000000 --- a/core/llm/images.ts +++ /dev/null @@ -1,11 +0,0 @@ -import { MessageContent } from "../index.js"; - -export function stripImages(content: MessageContent): string { - if (Array.isArray(content)) { - return content - .filter((part) => part.type === "text") - .map((part) => part.text) - .join("\n"); - } - return content; - } \ No newline at end of file diff --git a/core/llm/index.ts b/core/llm/index.ts index 8759af8a51..144df72d25 100644 --- a/core/llm/index.ts +++ b/core/llm/index.ts @@ -4,7 +4,6 @@ import Handlebars from "handlebars"; import { CacheBehavior, ChatMessage, - ChatMessageRole, CompletionOptions, ILLM, LLMFullCompletionOptions, @@ -24,6 +23,7 @@ import mergeJson from "../util/merge.js"; import { Telemetry } from "../util/posthog.js"; import { withExponentialBackoff } from "../util/withExponentialBackoff.js"; +import { renderChatMessage } from "../util/messageContent.js"; import { autodetectPromptTemplates, autodetectTemplateFunction, @@ -41,7 +41,6 @@ import { countTokens, pruneRawPromptFromTop, } from "./countTokens.js"; -import { stripImages } from "./images.js"; import CompletionOptionsForModels from "./templates/options.js"; export abstract class BaseLLM implements ILLM { @@ -372,7 +371,8 @@ export abstract class BaseLLM implements ILLM { `HTTP ${e.response.status} ${e.response.statusText} from ${e.response.url}\n\n${e.response.body}`, ); } else { - if (e.name !== "AbortError") { // Don't pollute console with abort errors. Check on name instead of instanceof, to avoid importing node-fetch here + if (e.name !== "AbortError") { + // Don't pollute console with abort errors. Check on name instead of instanceof, to avoid importing node-fetch here console.debug( `${e.message}\n\nCode: ${e.code}\nError number: ${e.errno}\nSyscall: ${e.erroredSysCall}\nType: ${e.type}\n\n${e.stack}`, ); @@ -414,7 +414,7 @@ export abstract class BaseLLM implements ILLM { let formatted = ""; for (const msg of msgsCopy) { if ("content" in msg && Array.isArray(msg.content)) { - const content = stripImages(msg.content); + const content = renderChatMessage(msg); msg.content = content; } formatted += `<${msg.role}>\n${msg.content || ""}\n\n`; @@ -533,7 +533,11 @@ export abstract class BaseLLM implements ILLM { }; } - async complete(_prompt: string, signal: AbortSignal, options: LLMFullCompletionOptions = {}) { + async complete( + _prompt: string, + signal: AbortSignal, + options: LLMFullCompletionOptions = {}, + ) { const { completionOptions, log, raw } = this._parseCompletionOptions(options); @@ -568,12 +572,16 @@ export abstract class BaseLLM implements ILLM { return completion; } - async chat(messages: ChatMessage[], signal: AbortSignal, options: LLMFullCompletionOptions = {}) { + async chat( + messages: ChatMessage[], + signal: AbortSignal, + options: LLMFullCompletionOptions = {}, + ) { let completion = ""; for await (const chunk of this.streamChat(messages, signal, options)) { completion += chunk.content; } - return { role: "assistant" as ChatMessageRole, content: completion }; + return { role: "assistant" as const, content: completion }; } async *streamChat( @@ -668,7 +676,11 @@ export abstract class BaseLLM implements ILLM { } } - protected async _complete(prompt: string, signal: AbortSignal, options: CompletionOptions) { + protected async _complete( + prompt: string, + signal: AbortSignal, + options: CompletionOptions, + ) { let completion = ""; for await (const chunk of this._streamComplete(prompt, signal, options)) { completion += chunk; diff --git a/core/llm/llms/Anthropic.ts b/core/llm/llms/Anthropic.ts index 5033658c7f..a10d2c4a48 100644 --- a/core/llm/llms/Anthropic.ts +++ b/core/llm/llms/Anthropic.ts @@ -4,7 +4,7 @@ import { LLMOptions, ModelProvider, } from "../../index.js"; -import { stripImages } from "../images.js"; +import { renderChatMessage, stripImages } from "../../util/messageContent.js"; import { BaseLLM } from "../index.js"; import { streamSse } from "../stream.js"; @@ -30,11 +30,79 @@ class Anthropic extends BaseLLM { model: options.model === "claude-2" ? "claude-2.1" : options.model, stop_sequences: options.stop?.filter((x) => x.trim() !== ""), stream: options.stream ?? true, + tools: options.tools?.map((tool) => ({ + name: tool.function.name, + description: tool.function.description, + input_schema: tool.function.parameters, + })), }; return finalOptions; } + private convertMessage(message: ChatMessage, addCaching: boolean): any { + if (message.role === "tool") { + return { + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: message.toolCallId, + content: renderChatMessage(message) || undefined, + }, + ], + }; + } else if (message.role === "assistant" && message.toolCalls) { + return { + role: "assistant", + content: message.toolCalls.map((toolCall) => ({ + type: "tool_use", + id: toolCall.id, + name: toolCall.function?.name, + input: JSON.parse(toolCall.function?.arguments || "{}"), + })), + }; + } + + if (typeof message.content === "string") { + var chatMessage = { + role: message.role, + content: [ + { + type: "text", + text: message.content, + ...(addCaching ? { cache_control: { type: "ephemeral" } } : {}), + }, + ], + }; + return chatMessage; + } + + return { + role: message.role, + content: message.content.map((part, contentIdx) => { + if (part.type === "text") { + const newpart = { + ...part, + // If multiple text parts, only add cache_control to the last one + ...(addCaching && contentIdx == message.content.length - 1 + ? { cache_control: { type: "ephemeral" } } + : {}), + }; + return newpart; + } + return { + type: "image", + source: { + type: "base64", + media_type: "image/jpeg", + data: part.imageUrl?.url.split(",")[1], + }, + }; + }), + }; + } + public convertMessages(msgs: ChatMessage[]): any[] { // should be public for use within VertexAI const filteredmessages = msgs.filter( @@ -54,43 +122,8 @@ class Anthropic extends BaseLLM { this.cacheBehavior?.cacheConversation && lastTwoUserMsgIndices.includes(filteredMsgIdx); - if (typeof message.content === "string") { - var chatMessage = { - ...message, - content: [ - { - type: "text", - text: message.content, - ...(addCaching ? { cache_control: { type: "ephemeral" } } : {}), - }, - ], - }; - return chatMessage; - } - - return { - ...message, - content: message.content.map((part, contentIdx) => { - if (part.type === "text") { - const newpart = { - ...part, - // If multiple text parts, only add cache_control to the last one - ...(addCaching && contentIdx == message.content.length - 1 - ? { cache_control: { type: "ephemeral" } } - : {}), - }; - return newpart; - } - return { - type: "image", - source: { - type: "base64", - media_type: "image/jpeg", - data: part.imageUrl?.url.split(",")[1], - }, - }; - }), - }; + const chatMessage = this.convertMessage(message, !!addCaching); + return chatMessage; }); return messages; } @@ -102,7 +135,7 @@ class Anthropic extends BaseLLM { ): AsyncGenerator { const messages = [{ role: "user" as const, content: prompt }]; for await (const update of this._streamChat(messages, signal, options)) { - yield stripImages(update.content); + yield renderChatMessage(update); } } @@ -114,9 +147,10 @@ class Anthropic extends BaseLLM { const shouldCacheSystemMessage = !!this.systemMessage && this.cacheBehavior?.cacheSystemMessage; const systemMessage: string = stripImages( - messages.filter((m) => m.role === "system")[0]?.content, + messages.filter((m) => m.role === "system")[0]?.content ?? "", ); + const msgs = this.convertMessages(messages); const response = await this.fetch(new URL("messages", this.apiBase), { method: "POST", headers: { @@ -130,7 +164,7 @@ class Anthropic extends BaseLLM { }, body: JSON.stringify({ ...this.convertArgs(options), - messages: this.convertMessages(messages), + messages: msgs, system: shouldCacheSystemMessage ? [ { @@ -141,21 +175,74 @@ class Anthropic extends BaseLLM { ] : systemMessage, }), - signal + signal, }); + if (!response.ok) { + const json = await response.json(); + if (json.type === "error") { + if (json.error?.type === "overloaded_error") { + throw new Error( + "The Anthropic API is currently overloaded. Please check their status page: https://status.anthropic.com/#past-incidents", + ); + } + throw new Error(json.message); + } + throw new Error( + `Anthropic API sent back ${response.status}: ${JSON.stringify(json)}`, + ); + } + if (options.stream === false) { const data = await response.json(); yield { role: "assistant", content: data.content[0].text }; return; } + let lastToolUseId: string | undefined; + let lastToolUseName: string | undefined; for await (const value of streamSse(response)) { - if (value.type == "message_start") { - console.log(value); - } - if (value.delta?.text) { - yield { role: "assistant", content: value.delta.text }; + // https://docs.anthropic.com/en/api/messages-streaming#event-types + switch (value.type) { + case "content_block_start": + if (value.content_block.type === "tool_use") { + lastToolUseId = value.content_block.id; + lastToolUseName = value.content_block.name; + } + break; + case "content_block_delta": + // https://docs.anthropic.com/en/api/messages-streaming#delta-types + switch (value.delta.type) { + case "text_delta": + yield { role: "assistant", content: value.delta.text }; + break; + case "input_json_delta": + if (!lastToolUseId || !lastToolUseName) { + throw new Error("No tool use found"); + } + yield { + role: "assistant", + content: "", + toolCalls: [ + { + id: lastToolUseId, + type: "function", + function: { + name: lastToolUseName, + arguments: value.delta.partial_json, + }, + }, + ], + }; + break; + } + break; + case "content_block_stop": + lastToolUseId = undefined; + lastToolUseName = undefined; + break; + default: + break; } } } diff --git a/core/llm/llms/Bedrock.ts b/core/llm/llms/Bedrock.ts index c2f76d0a8b..31445c84aa 100644 --- a/core/llm/llms/Bedrock.ts +++ b/core/llm/llms/Bedrock.ts @@ -11,7 +11,7 @@ import { MessageContent, ModelProvider, } from "../../index.js"; -import { stripImages } from "../images.js"; +import { renderChatMessage } from "../../util/messageContent.js"; import { BaseLLM } from "../index.js"; /** @@ -47,7 +47,7 @@ class Bedrock extends BaseLLM { ): AsyncGenerator { const messages = [{ role: "user" as const, content: prompt }]; for await (const update of this._streamChat(messages, signal, options)) { - yield stripImages(update.content); + yield renderChatMessage(update); } } @@ -91,7 +91,7 @@ class Bedrock extends BaseLLM { const input = this._generateConverseInput(messages, options); const command = new ConverseStreamCommand(input); - const response = await client.send(command, { abortSignal: signal}); + const response = await client.send(command, { abortSignal: signal }); if (response.stream) { for await (const chunk of response.stream) { diff --git a/core/llm/llms/Cloudflare.ts b/core/llm/llms/Cloudflare.ts index 07e2e427c9..215e97b37c 100644 --- a/core/llm/llms/Cloudflare.ts +++ b/core/llm/llms/Cloudflare.ts @@ -1,5 +1,5 @@ import { ChatMessage, CompletionOptions, ModelProvider } from "../../index.js"; -import { stripImages } from "../images.js"; +import { renderChatMessage } from "../../util/messageContent.js"; import { BaseLLM } from "../index.js"; import { streamSse } from "../stream.js"; @@ -36,7 +36,7 @@ export default class Cloudflare extends BaseLLM { model: this.model, ...this._convertArgs(options), }), - signal + signal, }); for await (const value of streamSse(resp)) { @@ -56,7 +56,7 @@ export default class Cloudflare extends BaseLLM { signal, options, )) { - yield stripImages(chunk.content); + yield renderChatMessage(chunk); } } } diff --git a/core/llm/llms/Cohere.ts b/core/llm/llms/Cohere.ts index 98135c7cc0..a43df7a62a 100644 --- a/core/llm/llms/Cohere.ts +++ b/core/llm/llms/Cohere.ts @@ -4,7 +4,7 @@ import { LLMOptions, ModelProvider, } from "../../index.js"; -import { stripImages } from "../images.js"; +import { renderChatMessage } from "../../util/messageContent.js"; import { BaseLLM } from "../index.js"; import { streamJSON } from "../stream.js"; @@ -51,7 +51,7 @@ class Cohere extends BaseLLM { ): AsyncGenerator { const messages = [{ role: "user" as const, content: prompt }]; for await (const update of this._streamChat(messages, signal, options)) { - yield stripImages(update.content); + yield renderChatMessage(update); } } @@ -75,7 +75,7 @@ class Cohere extends BaseLLM { chat_history: this._convertMessages(messages), preamble: this.systemMessage, }), - signal + signal, }); if (options.stream === false) { diff --git a/core/llm/llms/Flowise.ts b/core/llm/llms/Flowise.ts index ecca89d16a..8c5b0f4594 100644 --- a/core/llm/llms/Flowise.ts +++ b/core/llm/llms/Flowise.ts @@ -6,7 +6,7 @@ import { LLMOptions, ModelProvider, } from "../../index.js"; -import { stripImages } from "../images.js"; +import { renderChatMessage } from "../../util/messageContent.js"; import { BaseLLM } from "../index.js"; interface IFlowiseApiOptions { @@ -126,7 +126,7 @@ class Flowise extends BaseLLM { ): AsyncGenerator { const message: ChatMessage = { role: "user", content: prompt }; for await (const chunk of this._streamChat([message], signal, options)) { - yield stripImages(chunk.content); + yield renderChatMessage(chunk); } } @@ -141,7 +141,7 @@ class Flowise extends BaseLLM { method: "POST", headers: this._getHeaders(), body: JSON.stringify({ ...requestBody, socketIOClientId: socket.id }), - signal + signal, }).then((res) => res.json()); while (await socketInfo.hasNextToken()) { diff --git a/core/llm/llms/FreeTrial.ts b/core/llm/llms/FreeTrial.ts index 8f80c74189..895d6c42fc 100644 --- a/core/llm/llms/FreeTrial.ts +++ b/core/llm/llms/FreeTrial.ts @@ -74,7 +74,7 @@ class FreeTrial extends BaseLLM { prompt, ...args, }), - signal + signal, }); let completion = ""; @@ -82,10 +82,18 @@ class FreeTrial extends BaseLLM { yield value; completion += value; } - this._countTokens(completion, args.model, false); + void this._countTokens(completion, args.model, false); } protected _convertMessage(message: ChatMessage) { + if (message.role === "tool") { + return { + role: "tool", + content: message.content, + tool_call_id: message.toolCallId, + }; + } + if (typeof message.content === "string") { return message; } @@ -123,7 +131,7 @@ class FreeTrial extends BaseLLM { messages: messages.map(this._convertMessage), ...args, }), - signal + signal, }); let completion = ""; @@ -158,7 +166,7 @@ class FreeTrial extends BaseLLM { suffix, ...args, }), - signal + signal, }); let completion = ""; diff --git a/core/llm/llms/Gemini.ts b/core/llm/llms/Gemini.ts index ada6032401..e9cc484c8a 100644 --- a/core/llm/llms/Gemini.ts +++ b/core/llm/llms/Gemini.ts @@ -5,7 +5,7 @@ import { MessagePart, ModelProvider, } from "../../index.js"; -import { stripImages } from "../images.js"; +import { renderChatMessage } from "../../util/messageContent.js"; import { BaseLLM } from "../index.js"; import { streamResponse } from "../stream.js"; @@ -19,7 +19,8 @@ class Gemini extends BaseLLM { }; // Function to convert completion options to Gemini format - public convertArgs(options: CompletionOptions) { // should be public for use within VertexAI + public convertArgs(options: CompletionOptions) { + // should be public for use within VertexAI const finalOptions: any = {}; // Initialize an empty object // Map known options @@ -54,17 +55,19 @@ class Gemini extends BaseLLM { signal, options, )) { - yield stripImages(message.content); + yield renderChatMessage(message); } } - public removeSystemMessage(messages: ChatMessage[]) { // should be public for use within VertexAI + public removeSystemMessage(messages: ChatMessage[]) { + // should be public for use within VertexAI const msgs = [...messages]; if (msgs[0]?.role === "system") { const sysMsg = msgs.shift()?.content; // @ts-ignore if (msgs[0]?.role === "user") { + // @ts-ignore msgs[0].content = `System message - follow these instructions in every response: ${sysMsg}\n\n---\n\n${msgs[0].content}`; } } @@ -143,6 +146,9 @@ class Gemini extends BaseLLM { if (msg.role === "system" && !isV1API) { return null; // Don't include system message in contents } + if (msg.role === "tool") { + return null; + } return { role: msg.role === "assistant" ? "model" : "user", parts: @@ -165,7 +171,7 @@ class Gemini extends BaseLLM { const response = await this.fetch(apiURL, { method: "POST", body: JSON.stringify(body), - signal + signal, }); let buffer = ""; @@ -234,7 +240,7 @@ class Gemini extends BaseLLM { const response = await this.fetch(apiURL, { method: "POST", body: JSON.stringify(body), - signal + signal, }); const data = await response.json(); yield { role: "assistant", content: data.candidates[0].content }; diff --git a/core/llm/llms/Ollama.ts b/core/llm/llms/Ollama.ts index 642f8bcc5d..c56884e4b1 100644 --- a/core/llm/llms/Ollama.ts +++ b/core/llm/llms/Ollama.ts @@ -4,13 +4,11 @@ import { LLMOptions, ModelProvider, } from "../../index.js"; -import { stripImages } from "../images.js"; +import { renderChatMessage } from "../../util/messageContent.js"; import { BaseLLM } from "../index.js"; import { streamResponse } from "../stream.js"; -interface OllamaChatMessage extends ChatMessage { - images?: string[]; -} +type OllamaChatMessage = ChatMessage & { images?: string[] }; // See https://github.com/ollama/ollama/blob/main/docs/modelfile.md for details on each parameter interface ModelFileParams { @@ -195,6 +193,10 @@ class Ollama extends BaseLLM { } private _convertMessage(message: ChatMessage) { + if (message.role === "tool") { + return null; + } + if (typeof message.content === "string") { return message; } @@ -210,7 +212,7 @@ class Ollama extends BaseLLM { return { role: message.role, - content: stripImages(message.content), + content: renderChatMessage(message), images, }; } @@ -221,7 +223,7 @@ class Ollama extends BaseLLM { ): ChatOptions { return { model: this._getModel(), - messages: messages.map(this._convertMessage), + messages: messages.map(this._convertMessage).filter(Boolean) as any, options: this._getModelFileParams(options), keep_alive: options.keepAlive ?? 60 * 30, // 30 minutes stream: options.stream, diff --git a/core/llm/llms/OpenAI.ts b/core/llm/llms/OpenAI.ts index 4f929b5067..35d05a62d3 100644 --- a/core/llm/llms/OpenAI.ts +++ b/core/llm/llms/OpenAI.ts @@ -3,8 +3,9 @@ import { CompletionOptions, LLMOptions, ModelProvider, + Tool, } from "../../index.js"; -import { stripImages } from "../images.js"; +import { renderChatMessage } from "../../util/messageContent.js"; import { BaseLLM } from "../index.js"; import { streamSse } from "../stream.js"; @@ -56,6 +57,14 @@ class OpenAI extends BaseLLM { }; protected _convertMessage(message: ChatMessage) { + if (message.role === "tool") { + return { + role: "tool", + content: message.content, + tool_call_id: message.toolCallId, + }; + } + if (typeof message.content === "string") { return message; } else if (!message.content.some((item) => item.type !== "text")) { @@ -99,8 +108,22 @@ class OpenAI extends BaseLLM { return ["gpt-4o-mini", "gpt-4o"].includes(model); } + private convertTool(tool: Tool): any { + return { + type: tool.type, + function: { + name: tool.function.name, + description: tool.function.description, + parameters: tool.function.parameters, + strict: tool.function.strict, + }, + }; + } + protected _convertArgs(options: CompletionOptions, messages: ChatMessage[]) { const url = new URL(this.apiBase!); + const tools = options.tools?.map(this.convertTool); + const finalOptions: any = { messages: messages.map(this._convertMessage), model: this._convertModelName(options.model), @@ -122,6 +145,7 @@ class OpenAI extends BaseLLM { this.apiType === "azure" ? options.stop?.slice(0, 4) : options.stop, + tools, }; // OpenAI o1-preview and o1-mini: @@ -206,7 +230,7 @@ class OpenAI extends BaseLLM { signal, options, )) { - yield stripImages(chunk.content); + yield renderChatMessage(chunk); } } @@ -249,7 +273,7 @@ class OpenAI extends BaseLLM { options.raw) ) { for await (const content of this._legacystreamComplete( - stripImages(messages[messages.length - 1]?.content || ""), + renderChatMessage(messages[messages.length - 1]), signal, options, )) { @@ -266,6 +290,9 @@ class OpenAI extends BaseLLM { body.messages = body.messages.map((m: any) => ({ ...m, content: m.content === "" ? " " : m.content, + // We call it toolCalls, they call it tool_calls + tool_calls: m.toolCalls, + tool_call_id: m.toolCallId, })) as any; const response = await this.fetch(this._getEndpoint("chat/completions"), { method: "POST", @@ -284,6 +311,21 @@ class OpenAI extends BaseLLM { for await (const value of streamSse(response)) { if (value.choices?.[0]?.delta?.content) { yield value.choices[0].delta; + } else if (value.choices?.[0]?.delta?.tool_calls) { + yield { + role: "assistant", + content: "", + toolCalls: value.choices?.[0]?.delta?.tool_calls.map( + (tool_call: any) => ({ + id: tool_call.id, + type: tool_call.type, + function: { + name: tool_call.function.name, + arguments: tool_call.function.arguments, + }, + }), + ), + }; } } } diff --git a/core/llm/llms/VertexAI.ts b/core/llm/llms/VertexAI.ts index c17aac90df..c9254eaccb 100644 --- a/core/llm/llms/VertexAI.ts +++ b/core/llm/llms/VertexAI.ts @@ -4,13 +4,11 @@ import { ChatMessage, CompletionOptions, LLMOptions, - MessagePart, ModelProvider, } from "../../index.js"; -import { stripImages } from "../images.js"; +import { renderChatMessage } from "../../util/messageContent.js"; import { BaseLLM } from "../index.js"; -import { streamSse } from "../stream.js"; -import { streamResponse } from "../stream.js"; +import { streamResponse, streamSse } from "../stream.js"; import Anthropic from "./Anthropic.js"; import Gemini from "./Gemini.js"; @@ -40,16 +38,17 @@ class VertexAI extends BaseLLM { super(_options); this.apiBase ??= VertexAI.getDefaultApiBaseFrom(_options); this.vertexProvider = - _options.model.includes("mistral") || _options.model.includes("codestral") || _options.model.includes("mixtral") + _options.model.includes("mistral") || + _options.model.includes("codestral") || + _options.model.includes("mixtral") ? "mistral" : _options.model.includes("claude") ? "anthropic" : _options.model.includes("gemini") - ? "gemini" : - "unknown"; + ? "gemini" + : "unknown"; this.anthropicInstance = new Anthropic(_options); this.geminiInstance = new Gemini(_options); - } async fetch(url: RequestInfo | URL, init?: RequestInit) { @@ -82,18 +81,18 @@ class VertexAI extends BaseLLM { }; } - protected async *StreamChatAnthropic( messages: ChatMessage[], options: CompletionOptions, ): AsyncGenerator { - const shouldCacheSystemMessage = !!this.systemMessage && this.cacheBehavior?.cacheSystemMessage; - const systemMessage: string = stripImages( - messages.filter((m) => m.role === "system")[0]?.content, + const shouldCacheSystemMessage = + !!this.systemMessage && this.cacheBehavior?.cacheSystemMessage; + const systemMessage: string = renderChatMessage( + messages.filter((m) => m.role === "system")[0], ); const apiURL = new URL( `publishers/anthropic/models/${options.model}:streamRawPredict`, - this.apiBase + this.apiBase, ); const response = await this.fetch(apiURL, { @@ -108,12 +107,12 @@ class VertexAI extends BaseLLM { messages: this.anthropicInstance.convertMessages(messages), system: shouldCacheSystemMessage ? [ - { - type: "text", - text: this.systemMessage, - cache_control: { type: "ephemeral" }, - }, - ] + { + type: "text", + text: this.systemMessage, + cache_control: { type: "ephemeral" }, + }, + ] : systemMessage, }), }); @@ -125,14 +124,15 @@ class VertexAI extends BaseLLM { } for await (const value of streamSse(response)) { - if (value.type == "message_start") {console.log(value);} + if (value.type == "message_start") { + console.log(value); + } if (value.delta?.text) { yield { role: "assistant", content: value.delta.text }; } } } - //Gemini private async *streamChatGemini( @@ -152,6 +152,10 @@ class VertexAI extends BaseLLM { if (msg.role === "system" && !isV1API) { return null; // Don't include system message in contents } + if (msg.role === "tool") { + return null; + } + return { role: msg.role === "assistant" ? "model" : "user", parts: @@ -168,8 +172,8 @@ class VertexAI extends BaseLLM { // if this.systemMessage is defined, reformat it for Gemini API ...(this.systemMessage && !isV1API && { - systemInstruction: { parts: [{ text: this.systemMessage }] }, - }), + systemInstruction: { parts: [{ text: this.systemMessage }] }, + }), }; const response = await this.fetch(apiURL, { method: "POST", @@ -265,7 +269,6 @@ class VertexAI extends BaseLLM { //Mistral - protected async *StreamChatMistral( messages: ChatMessage[], options: CompletionOptions, @@ -330,7 +333,7 @@ class VertexAI extends BaseLLM { const response = await this.fetch(apiURL, { method: "POST", body: JSON.stringify(body), - signal + signal, }); for await (const chunk of streamSse(response)) { @@ -338,8 +341,6 @@ class VertexAI extends BaseLLM { } } - - //gecko protected async *streamFimGecko( prefix: string, @@ -347,15 +348,18 @@ class VertexAI extends BaseLLM { signal: AbortSignal, options: CompletionOptions, ): AsyncGenerator { - const endpoint = new URL("publishers/google/models/code-gecko:predict", this.apiBase); + const endpoint = new URL( + "publishers/google/models/code-gecko:predict", + this.apiBase, + ); const resp = await this.fetch(endpoint, { method: "POST", body: JSON.stringify({ instances: [ { prefix: prefix, - suffix: suffix - } + suffix: suffix, + }, ], parameters: { temperature: options.temperature, @@ -363,10 +367,9 @@ class VertexAI extends BaseLLM { stopSequences: options.stop?.splice(0, 5), frequencyPenalty: options.frequencyPenalty, presencePenalty: options.frequencyPenalty, - } - + }, }), - signal + signal, }); // Streaming is not supported by code-gecko // TODO: convert to non-streaming fim method when one exist in continue. @@ -409,7 +412,7 @@ class VertexAI extends BaseLLM { signal, options, )) { - yield stripImages(message.content); + yield renderChatMessage(message); } } @@ -419,9 +422,6 @@ class VertexAI extends BaseLLM { signal: AbortSignal, options: CompletionOptions, ): AsyncGenerator { - - - if (this.model === "code-gecko") { yield* this.streamFimGecko(prefix, suffix, signal, options); } else if (this.model.includes("codestral")) { @@ -429,21 +429,15 @@ class VertexAI extends BaseLLM { } else { throw new Error(`Unsupported model: ${this.model}`); } - - } supportsFim(): boolean { return ["code-gecko", "codestral-latest"].includes(this.model); } - - } async function delay(seconds: number) { return new Promise((resolve) => setTimeout(resolve, seconds * 1000)); } - - export default VertexAI; diff --git a/core/llm/llms/WatsonX.ts b/core/llm/llms/WatsonX.ts index 8085576920..fb82dacf9d 100644 --- a/core/llm/llms/WatsonX.ts +++ b/core/llm/llms/WatsonX.ts @@ -4,25 +4,22 @@ import { LLMOptions, ModelProvider, } from "../../index.js"; -import { stripImages } from "../images.js"; +import { renderChatMessage } from "../../util/messageContent.js"; import { BaseLLM } from "../index.js"; import { streamResponse } from "../stream.js"; let watsonxToken = { - expiration: 0, - token: "" + expiration: 0, + token: "", }; class WatsonX extends BaseLLM { - constructor(options: LLMOptions) { super(options); } async getBearerToken(): Promise<{ token: string; expiration: number }> { - if ( - this.apiBase?.includes("cloud.ibm.com") - ) { + if (this.apiBase?.includes("cloud.ibm.com")) { // watsonx SaaS const wxToken = await ( await fetch( @@ -82,11 +79,9 @@ class WatsonX extends BaseLLM { } getWatsonxEndpoint(): string { - return ( - this.deploymentId ? - `${this.apiBase}/ml/v1/deployments/${this.deploymentId}/text/generation_stream?version=${this.apiVersion}`: - `${this.apiBase}/ml/v1/text/generation_stream?version=${this.apiVersion}` - ); + return this.deploymentId + ? `${this.apiBase}/ml/v1/deployments/${this.deploymentId}/text/generation_stream?version=${this.apiVersion}` + : `${this.apiBase}/ml/v1/text/generation_stream?version=${this.apiVersion}`; } static providerName: ModelProvider = "watsonx"; @@ -96,6 +91,10 @@ class WatsonX extends BaseLLM { return message; } + if (message.role === "tool") { + return null; + } + const parts = message.content.map((part) => { const msg: any = { type: part.type, @@ -119,7 +118,7 @@ class WatsonX extends BaseLLM { protected _convertArgs(options: any, messages: ChatMessage[]) { const finalOptions = { - messages: messages.map(this._convertMessage), + messages: messages.map(this._convertMessage).filter(Boolean), model: this._convertModelName(options.model), max_tokens: options.maxTokens, temperature: options.temperature, @@ -133,8 +132,9 @@ class WatsonX extends BaseLLM { protected _getHeaders() { return { "Content-Type": "application/json", - Authorization: `${watsonxToken.expiration === -1 ? "ZenApiKey" : "Bearer" - } ${watsonxToken.token}`, + Authorization: `${ + watsonxToken.expiration === -1 ? "ZenApiKey" : "Bearer" + } ${watsonxToken.token}`, }; } @@ -165,7 +165,7 @@ class WatsonX extends BaseLLM { signal, options, )) { - yield stripImages(chunk.content); + yield renderChatMessage(chunk); } } @@ -183,7 +183,8 @@ class WatsonX extends BaseLLM { watsonxToken = await this.getBearerToken(); } else { console.log( - `Reusing token (expires in ${(watsonxToken.expiration - now) / 60 + `Reusing token (expires in ${ + (watsonxToken.expiration - now) / 60 } mins)`, ); } @@ -191,7 +192,8 @@ class WatsonX extends BaseLLM { throw new Error("Something went wrong. Check your credentials, please."); } const stopSequences = - options.stop?.slice(0,6) ?? (options.model?.includes("granite") ? ["Question:"] : []); + options.stop?.slice(0, 6) ?? + (options.model?.includes("granite") ? ["Question:"] : []); const url = this.getWatsonxEndpoint(); const headers = this._getHeaders(); @@ -224,7 +226,7 @@ class WatsonX extends BaseLLM { method: "POST", headers: headers, body: JSON.stringify(payload), - signal + signal, }); if (!response.ok || response.body === null) { diff --git a/core/llm/templates/chat.ts b/core/llm/templates/chat.ts index 9d26620f44..a4b999b05b 100644 --- a/core/llm/templates/chat.ts +++ b/core/llm/templates/chat.ts @@ -1,5 +1,5 @@ import { ChatMessage } from "../../index.js"; -import { stripImages } from "../images.js"; +import { renderChatMessage } from "../../util/messageContent.js"; function templateFactory( systemMessage: (msg: ChatMessage) => string, @@ -64,7 +64,7 @@ function llama2TemplateMessages(msgs: ChatMessage[]): string { let prompt = ""; let hasSystem = msgs[0].role === "system"; - if (hasSystem && stripImages(msgs[0].content).trim() === "") { + if (hasSystem && renderChatMessage(msgs[0]).trim() === "") { hasSystem = false; msgs = msgs.slice(1); } @@ -158,7 +158,7 @@ function deepseekTemplateMessages(msgs: ChatMessage[]): string { prompt += "You are an AI programming assistant, utilizing the DeepSeek Coder model, developed by DeepSeek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n"; if (msgs[0].role === "system") { - system = stripImages(msgs.shift()!.content); + system = renderChatMessage(msgs.shift()!); } for (let i = 0; i < msgs.length; i++) { @@ -245,7 +245,7 @@ function codeLlama70bTemplateMessages(msgs: ChatMessage[]): string { let prompt = ""; for (const msg of msgs) { - prompt += `Source: ${msg.role}\n\n ${stripImages(msg.content).trim()}`; + prompt += `Source: ${msg.role}\n\n ${renderChatMessage(msg).trim()}`; prompt += " "; } diff --git a/core/package-lock.json b/core/package-lock.json index a53a1a49cd..9175dbfd92 100644 --- a/core/package-lock.json +++ b/core/package-lock.json @@ -14,6 +14,7 @@ "@aws-sdk/credential-providers": "^3.620.1", "@continuedev/config-types": "^1.0.13", "@continuedev/llm-info": "^1.0.2", + "@modelcontextprotocol/sdk": "^1.0.0", "@mozilla/readability": "^0.5.0", "@octokit/rest": "^20.0.2", "@typescript-eslint/eslint-plugin": "^7.8.0", @@ -47,6 +48,7 @@ "onnxruntime-node": "1.14.0", "openai": "^4.20.1", "p-limit": "^6.1.0", + "partial-json": "^0.1.7", "pg": "^8.11.3", "posthog-node": "^3.6.3", "puppeteer": "^22.4.0", @@ -4155,6 +4157,16 @@ "win32" ] }, + "node_modules/@modelcontextprotocol/sdk": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.0.0.tgz", + "integrity": "sha512-mbe0otw8vTtZoL5pVucXAmx6oEC7YjdXBgVeFkJXASu4OAnLkrIeNw9zwzU5CwEp19M54bjOUGcna90Dl/H5Bw==", + "dependencies": { + "content-type": "^1.0.5", + "raw-body": "^3.0.0", + "zod": "^3.23.8" + } + }, "node_modules/@mozilla/readability": { "version": "0.5.0", "resolved": "https://registry.npmjs.org/@mozilla/readability/-/readability-0.5.0.tgz", @@ -6557,6 +6569,14 @@ "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==", "dev": true }, + "node_modules/bytes": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.2.tgz", + "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==", + "engines": { + "node": ">= 0.8" + } + }, "node_modules/cacache": { "version": "15.3.0", "resolved": "https://registry.npmjs.org/cacache/-/cacache-15.3.0.tgz", @@ -6996,6 +7016,14 @@ "resolved": "https://registry.npmjs.org/console-control-strings/-/console-control-strings-1.1.0.tgz", "integrity": "sha512-ty/fTekppD2fIwRvnZAVdeOiGd1c7YXEixbgJTNzqcxJWKQnjJ/V1bNEEE6hygpM3WjwHFUVK6HTjWSzV4a8sQ==" }, + "node_modules/content-type": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/content-type/-/content-type-1.0.5.tgz", + "integrity": "sha512-nTjqfcBFEipKdXCv4YDQWCfmcLZKm81ldF0pAopTvyrFGVbcR6P/VAAd5G7N+0tTr8QqiU0tFadD6FK4NtJwOA==", + "engines": { + "node": ">= 0.6" + } + }, "node_modules/convert-source-map": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/convert-source-map/-/convert-source-map-2.0.0.tgz", @@ -7379,6 +7407,14 @@ "node": ">=0.10" } }, + "node_modules/depd": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", + "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", + "engines": { + "node": ">= 0.8" + } + }, "node_modules/deprecation": { "version": "2.3.1", "resolved": "https://registry.npmjs.org/deprecation/-/deprecation-2.3.1.tgz", @@ -9343,6 +9379,21 @@ "integrity": "sha512-er295DKPVsV82j5kw1Gjt+ADA/XYHsajl82cGNQG2eyoPkvgUhX+nDIyelzhIWbbsXP39EHcI6l5tYs2FYqYXQ==", "optional": true }, + "node_modules/http-errors": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/http-errors/-/http-errors-2.0.0.tgz", + "integrity": "sha512-FtwrG/euBzaEjYeRqOgly7G0qviiXoJWnvEH2Z1plBdXgbyjv34pHTSb9zoeHMyDy33+DWy5Wt9Wo+TURtOYSQ==", + "dependencies": { + "depd": "2.0.0", + "inherits": "2.0.4", + "setprototypeof": "1.2.0", + "statuses": "2.0.1", + "toidentifier": "1.0.1" + }, + "engines": { + "node": ">= 0.8" + } + }, "node_modules/http-proxy-agent": { "version": "7.0.2", "resolved": "https://registry.npmjs.org/http-proxy-agent/-/http-proxy-agent-7.0.2.tgz", @@ -12497,6 +12548,11 @@ "url": "https://github.com/inikulin/parse5?sponsor=1" } }, + "node_modules/partial-json": { + "version": "0.1.7", + "resolved": "https://registry.npmjs.org/partial-json/-/partial-json-0.1.7.tgz", + "integrity": "sha512-Njv/59hHaokb/hRUjce3Hdv12wd60MtM9Z5Olmn+nehe0QDAsRtRbJPvJ0Z91TusF0SuZRIvnM+S4l6EIP8leA==" + }, "node_modules/path-exists": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", @@ -13262,6 +13318,20 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/raw-body": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-3.0.0.tgz", + "integrity": "sha512-RmkhL8CAyCRPXCE28MMH0z2PNWQBNk2Q09ZdxM9IOOXwxwZbN+qbWaatPkdkWIKL2ZVDImrN/pK5HTRz2PcS4g==", + "dependencies": { + "bytes": "3.1.2", + "http-errors": "2.0.0", + "iconv-lite": "0.6.3", + "unpipe": "1.0.0" + }, + "engines": { + "node": ">= 0.8" + } + }, "node_modules/rc": { "version": "1.2.8", "resolved": "https://registry.npmjs.org/rc/-/rc-1.2.8.tgz", @@ -13740,6 +13810,11 @@ "node": ">= 0.4" } }, + "node_modules/setprototypeof": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", + "integrity": "sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw==" + }, "node_modules/sharp": { "version": "0.32.6", "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.6.tgz", @@ -14134,6 +14209,14 @@ "node": ">=8" } }, + "node_modules/statuses": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", + "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", + "engines": { + "node": ">= 0.8" + } + }, "node_modules/stoppable": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/stoppable/-/stoppable-1.1.0.tgz", @@ -14626,6 +14709,14 @@ "node": ">=8.0" } }, + "node_modules/toidentifier": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/toidentifier/-/toidentifier-1.0.1.tgz", + "integrity": "sha512-o5sSPKEkg/DIQNmH43V0/uerLrpzVedkUh8tGNvaeXpfpuwjKenlSox/2O/BTlZUtEe+JG7s5YhEz608PlAHRA==", + "engines": { + "node": ">=0.6" + } + }, "node_modules/tough-cookie": { "version": "4.1.4", "resolved": "https://registry.npmjs.org/tough-cookie/-/tough-cookie-4.1.4.tgz", @@ -15042,6 +15133,14 @@ "node": ">= 4.0.0" } }, + "node_modules/unpipe": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/unpipe/-/unpipe-1.0.0.tgz", + "integrity": "sha512-pjy2bYhSsufwWlKwPc+l3cN7+wuJlK6uz0YdJEOlQDbl6jo/YlPi4mb8agUkVC8BF7V8NuzeyPNqRksA3hztKQ==", + "engines": { + "node": ">= 0.8" + } + }, "node_modules/update-browserslist-db": { "version": "1.0.16", "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.0.16.tgz", diff --git a/core/package.json b/core/package.json index 91652280a7..78d7ecf65e 100644 --- a/core/package.json +++ b/core/package.json @@ -45,6 +45,7 @@ "@aws-sdk/credential-providers": "^3.620.1", "@continuedev/config-types": "^1.0.13", "@continuedev/llm-info": "^1.0.2", + "@modelcontextprotocol/sdk": "^1.0.0", "@mozilla/readability": "^0.5.0", "@octokit/rest": "^20.0.2", "@typescript-eslint/eslint-plugin": "^7.8.0", @@ -78,6 +79,7 @@ "onnxruntime-node": "1.14.0", "openai": "^4.20.1", "p-limit": "^6.1.0", + "partial-json": "^0.1.7", "pg": "^8.11.3", "posthog-node": "^3.6.3", "puppeteer": "^22.4.0", diff --git a/core/promptFiles/v1/slashCommandFromPromptFile.ts b/core/promptFiles/v1/slashCommandFromPromptFile.ts index 9e4f0c1942..7d3750ad41 100644 --- a/core/promptFiles/v1/slashCommandFromPromptFile.ts +++ b/core/promptFiles/v1/slashCommandFromPromptFile.ts @@ -1,8 +1,8 @@ import * as YAML from "yaml"; import { ContinueSDK, SlashCommand } from "../.."; -import { stripImages } from "../../llm/images"; import { getBasename } from "../../util/index"; +import { renderChatMessage } from "../../util/messageContent"; import { getContextProviderHelpers } from "./getContextProviderHelpers"; import { renderTemplatedString } from "./renderTemplatedString"; @@ -90,8 +90,11 @@ export function slashCommandFromPromptFile( systemMessage, ); - for await (const chunk of context.llm.streamChat(messages, new AbortController().signal)) { - yield stripImages(chunk.content); + for await (const chunk of context.llm.streamChat( + messages, + new AbortController().signal, + )) { + yield renderChatMessage(chunk); } }, }; diff --git a/core/protocol/core.ts b/core/protocol/core.ts index 9202d08b40..35785d359c 100644 --- a/core/protocol/core.ts +++ b/core/protocol/core.ts @@ -4,21 +4,21 @@ import { ProfileDescription } from "../config/ConfigHandler"; import type { BrowserSerializedContinueConfig, ChatMessage, + ContextItem, ContextItemWithId, ContextSubmenuItem, DiffLine, FileSymbolMap, IdeSettings, - IndexingStatus, LLMFullCompletionOptions, - MessageContent, ModelDescription, ModelRoles, - Session, RangeInFile, SerializedContinueConfig, + Session, SessionMetadata, SiteIndexingConfig, + ToolCall, } from "../"; export type ProtocolGeneratorType = AsyncGenerator<{ @@ -123,7 +123,7 @@ export type ToCoreFromIdeOrWebviewProtocol = { completionOptions: LLMFullCompletionOptions; title: string; }, - ProtocolGeneratorType, + ProtocolGeneratorType, ]; streamDiffLines: [ { @@ -174,4 +174,5 @@ export type ToCoreFromIdeOrWebviewProtocol = { "profiles/switch": [{ id: string }, undefined]; "auth/getAuthUrl": [undefined, { url: string }]; + "tools/call": [{ toolCall: ToolCall }, { contextItems: ContextItem[] }]; }; diff --git a/core/protocol/passThrough.ts b/core/protocol/passThrough.ts index 720a6f66e8..bbcfbf3f28 100644 --- a/core/protocol/passThrough.ts +++ b/core/protocol/passThrough.ts @@ -58,6 +58,7 @@ export const WEBVIEW_TO_CORE_PASS_THROUGH: (keyof ToCoreFromWebviewProtocol)[] = "addAutocompleteModel", "profiles/switch", "didChangeSelectedProfile", + "tools/call", ]; // Message types to pass through from core to webview diff --git a/core/tools/builtIn.ts b/core/tools/builtIn.ts new file mode 100644 index 0000000000..55e2c52148 --- /dev/null +++ b/core/tools/builtIn.ts @@ -0,0 +1,11 @@ +export enum BuiltInToolNames { + ReadFile = "read_file", + ReadCurrentlyOpenFile = "read_currently_open_file", + CreateNewFile = "create_new_file", + RunTerminalCommand = "run_terminal_command", + ViewSubdirectory = "view_subdirectory", + ViewRepoMap = "view_repo_map", + ExactSearch = "exact_search", + SearchWeb = "search_web", + ViewDiff = "view_diff", +} diff --git a/core/tools/callTool.ts b/core/tools/callTool.ts new file mode 100644 index 0000000000..69307100e3 --- /dev/null +++ b/core/tools/callTool.ts @@ -0,0 +1,109 @@ +import { ContextItem, ToolExtras } from ".."; +import MCPConnectionSingleton from "../context/mcp"; +import { BuiltInToolNames } from "./builtIn"; + +import { createNewFileImpl } from "./implementations/createNewFile"; +import { exactSearchImpl } from "./implementations/exactSearch"; +import { readCurrentlyOpenFileImpl } from "./implementations/readCurrentlyOpenFile"; +import { readFileImpl } from "./implementations/readFile"; +import { runTerminalCommandImpl } from "./implementations/runTerminalCommand"; +import { searchWebImpl } from "./implementations/searchWeb"; +import { viewDiffImpl } from "./implementations/viewDiff"; +import { viewRepoMapImpl } from "./implementations/viewRepoMap"; +import { viewSubdirectoryImpl } from "./implementations/viewSubdirectory"; + +async function callHttpTool( + url: string, + args: any, + extras: ToolExtras, +): Promise { + const response = await extras.fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + arguments: args, + }), + }); + + if (!response.ok) { + throw new Error(`Failed to call tool: ${url}`); + } + + const data = await response.json(); + return data.output; +} + +async function callToolFromUri( + uri: string, + args: any, + extras: ToolExtras, +): Promise { + const canParse = URL.canParse(uri); + if (!canParse) { + throw new Error(`Invalid URI: ${uri}`); + } + const parsedUri = new URL(uri); + + switch (parsedUri?.protocol) { + case "http:": + case "https:": + return callHttpTool(uri, args, extras); + case "mcp:": + const client = await MCPConnectionSingleton.getExistingInstance(); + if (!client) { + throw new Error("MCP connection not found"); + } + const toolName = parsedUri!.hostname; + const response = await client.client.callTool({ + name: toolName, + arguments: args, + }); + + if (response.isError === true) { + throw new Error(`Failed to call tool: ${toolName}`); + } + + return (response.content as any).map((item: any) => { + if (item.type !== "text") { + throw new Error( + `Continue received item of type "${item.type}" from MCP tool, but currently only supports "text".`, + ); + } + return { name: toolName, description: toolName, content: item.text }; + }); + + default: + throw new Error(`Unsupported protocol: ${parsedUri?.protocol}`); + } +} + +export async function callTool( + uri: string, + args: any, + extras: ToolExtras, +): Promise { + switch (uri) { + case BuiltInToolNames.ReadFile: + return await readFileImpl(args, extras); + case BuiltInToolNames.CreateNewFile: + return await createNewFileImpl(args, extras); + case BuiltInToolNames.ExactSearch: + return await exactSearchImpl(args, extras); + case BuiltInToolNames.RunTerminalCommand: + return await runTerminalCommandImpl(args, extras); + case BuiltInToolNames.SearchWeb: + return await searchWebImpl(args, extras); + case BuiltInToolNames.ViewDiff: + return await viewDiffImpl(args, extras); + case BuiltInToolNames.ViewRepoMap: + return await viewRepoMapImpl(args, extras); + case BuiltInToolNames.ViewSubdirectory: + return await viewSubdirectoryImpl(args, extras); + case BuiltInToolNames.ReadCurrentlyOpenFile: + return await readCurrentlyOpenFileImpl(args, extras); + default: + return await callToolFromUri(uri, args, extras); + } +} diff --git a/core/tools/definitions/createNewFile.ts b/core/tools/definitions/createNewFile.ts new file mode 100644 index 0000000000..f0082891cf --- /dev/null +++ b/core/tools/definitions/createNewFile.ts @@ -0,0 +1,27 @@ +import { Tool } from "../.."; +import { BuiltInToolNames } from "../builtIn"; + +export const createNewFileTool: Tool = { + type: "function", + displayTitle: "Create New File", + wouldLikeTo: "create a new file", + readonly: false, + function: { + name: BuiltInToolNames.CreateNewFile, + description: "Create a new file", + parameters: { + type: "object", + required: ["filepath", "contents"], + properties: { + filepath: { + type: "string", + description: "The path where the new file should be created", + }, + contents: { + type: "string", + description: "The contents to write to the new file", + }, + }, + }, + }, +}; diff --git a/core/tools/definitions/exactSearch.ts b/core/tools/definitions/exactSearch.ts new file mode 100644 index 0000000000..8358b55da1 --- /dev/null +++ b/core/tools/definitions/exactSearch.ts @@ -0,0 +1,24 @@ +import { Tool } from "../.."; +import { BuiltInToolNames } from "../builtIn"; + +export const exactSearchTool: Tool = { + type: "function", + displayTitle: "Exact Search", + wouldLikeTo: 'search for "{{{ query }}}" in the repository', + readonly: true, + function: { + name: BuiltInToolNames.ExactSearch, + description: "Perform an exact search over the repository using ripgrep.", + parameters: { + type: "object", + required: ["query"], + properties: { + query: { + type: "string", + description: + "The search query to use. Must be a valid ripgrep regex expression, escaped where needed", + }, + }, + }, + }, +}; diff --git a/core/tools/definitions/readCurrentlyOpenFile.ts b/core/tools/definitions/readCurrentlyOpenFile.ts new file mode 100644 index 0000000000..79497a158c --- /dev/null +++ b/core/tools/definitions/readCurrentlyOpenFile.ts @@ -0,0 +1,18 @@ +import { Tool } from "../.."; +import { BuiltInToolNames } from "../builtIn"; + +export const readCurrentlyOpenFileTool: Tool = { + type: "function", + displayTitle: "Read Currently Open File", + wouldLikeTo: "read the current file", + readonly: true, + function: { + name: BuiltInToolNames.ReadCurrentlyOpenFile, + description: + "Read the currently open file in the IDE. If the user seems to be referring to a file that you can't see, this is probably it.", + parameters: { + type: "object", + properties: {}, + }, + }, +}; diff --git a/core/tools/definitions/readFile.ts b/core/tools/definitions/readFile.ts new file mode 100644 index 0000000000..00e16ba32e --- /dev/null +++ b/core/tools/definitions/readFile.ts @@ -0,0 +1,25 @@ +import { Tool } from "../.."; +import { BuiltInToolNames } from "../builtIn"; + +export const readFileTool: Tool = { + type: "function", + displayTitle: "Read File", + wouldLikeTo: "read {{{ filepath }}}", + readonly: true, + function: { + name: BuiltInToolNames.ReadFile, + description: + "Use this tool whenever you need to view the contents of a file.", + parameters: { + type: "object", + required: ["filepath"], + properties: { + filepath: { + type: "string", + description: + "The path of the file to read, relative to the root of the workspace.", + }, + }, + }, + }, +}; diff --git a/core/tools/definitions/runTerminalCommand.ts b/core/tools/definitions/runTerminalCommand.ts new file mode 100644 index 0000000000..82b2d342ce --- /dev/null +++ b/core/tools/definitions/runTerminalCommand.ts @@ -0,0 +1,25 @@ +import { Tool } from "../.."; +import { BuiltInToolNames } from "../builtIn"; + +export const runTerminalCommandTool: Tool = { + type: "function", + displayTitle: "Run Terminal Command", + wouldLikeTo: "run a terminal command", + readonly: false, + function: { + name: BuiltInToolNames.RunTerminalCommand, + description: + "Run a terminal command in the current directory. The shell is not stateful and will not remember any previous commands.", + parameters: { + type: "object", + required: ["command"], + properties: { + command: { + type: "string", + description: + "The command to run. This will be passed directly into the shell.", + }, + }, + }, + }, +}; diff --git a/core/tools/definitions/searchWeb.ts b/core/tools/definitions/searchWeb.ts new file mode 100644 index 0000000000..eb8e58c992 --- /dev/null +++ b/core/tools/definitions/searchWeb.ts @@ -0,0 +1,25 @@ +import { Tool } from "../.."; + +import { BuiltInToolNames } from "../builtIn"; + +export const searchWebTool: Tool = { + type: "function", + displayTitle: "Search Web", + wouldLikeTo: 'search the web for "{{{ query }}}"', + readonly: true, + function: { + name: BuiltInToolNames.SearchWeb, + description: + "Performs a web search, returning top results. This tool should only be called for questions that require external knowledge. Common programming questions do not require web search.", + parameters: { + type: "object", + required: ["query"], + properties: { + repo_url: { + type: "string", + description: "The natural language search query", + }, + }, + }, + }, +}; diff --git a/core/tools/definitions/viewDiff.ts b/core/tools/definitions/viewDiff.ts new file mode 100644 index 0000000000..e54086d052 --- /dev/null +++ b/core/tools/definitions/viewDiff.ts @@ -0,0 +1,18 @@ +import { Tool } from "../.."; + +import { BuiltInToolNames } from "../builtIn"; + +export const viewDiffTool: Tool = { + type: "function", + displayTitle: "View Diff", + wouldLikeTo: "view a diff", + readonly: true, + function: { + name: BuiltInToolNames.ViewDiff, + description: "View the current diff of working changes", + parameters: { + type: "object", + properties: {}, + }, + }, +}; diff --git a/core/tools/definitions/viewRepoMap.ts b/core/tools/definitions/viewRepoMap.ts new file mode 100644 index 0000000000..9cf3d319d2 --- /dev/null +++ b/core/tools/definitions/viewRepoMap.ts @@ -0,0 +1,18 @@ +import { Tool } from "../.."; + +import { BuiltInToolNames } from "../builtIn"; + +export const viewRepoMapTool: Tool = { + type: "function", + displayTitle: "View Repo Map", + wouldLikeTo: "view the repository map", + readonly: true, + function: { + name: BuiltInToolNames.ViewRepoMap, + description: "View the repository map", + parameters: { + type: "object", + properties: {}, + }, + }, +}; diff --git a/core/tools/definitions/viewSubdirectory.ts b/core/tools/definitions/viewSubdirectory.ts new file mode 100644 index 0000000000..70b7524bd8 --- /dev/null +++ b/core/tools/definitions/viewSubdirectory.ts @@ -0,0 +1,24 @@ +import { Tool } from "../.."; +import { BuiltInToolNames } from "../builtIn"; + +export const viewSubdirectoryTool: Tool = { + type: "function", + displayTitle: "View Subdirectory", + wouldLikeTo: 'view the contents of "{{{ directory_path }}}"', + readonly: true, + function: { + name: BuiltInToolNames.ViewSubdirectory, + description: "View the contents of a subdirectory", + parameters: { + type: "object", + required: ["directory_path"], + properties: { + directory_path: { + type: "string", + description: + "The path of the subdirectory to view, relative to the root of the workspace", + }, + }, + }, + }, +}; diff --git a/core/tools/implementations/createNewFile.ts b/core/tools/implementations/createNewFile.ts new file mode 100644 index 0000000000..0be187b193 --- /dev/null +++ b/core/tools/implementations/createNewFile.ts @@ -0,0 +1,17 @@ +import { getPathModuleForIde } from "../../util/pathModule"; + +import { ToolImpl } from "."; + +export const createNewFileImpl: ToolImpl = async (args, extras) => { + const pathSep = await extras.ide.pathSep(); + let filepath = args.filepath; + if (!args.filepath.startsWith(pathSep)) { + const pathModule = await getPathModuleForIde(extras.ide); + const workspaceDirs = await extras.ide.getWorkspaceDirs(); + const cwd = workspaceDirs[0]; + filepath = pathModule.join(cwd, filepath); + } + await extras.ide.writeFile(filepath, args.contents); + await extras.ide.openFile(filepath); + return []; +}; diff --git a/core/tools/implementations/exactSearch.ts b/core/tools/implementations/exactSearch.ts new file mode 100644 index 0000000000..ae8f6fd02b --- /dev/null +++ b/core/tools/implementations/exactSearch.ts @@ -0,0 +1,12 @@ +import { ToolImpl } from "."; + +export const exactSearchImpl: ToolImpl = async (args, extras) => { + const content = await extras.ide.getSearchResults(args.query); + return [ + { + name: "Search results", + description: "Results from exact search", + content, + }, + ]; +}; diff --git a/core/tools/implementations/index.ts b/core/tools/implementations/index.ts new file mode 100644 index 0000000000..c0ef15be9c --- /dev/null +++ b/core/tools/implementations/index.ts @@ -0,0 +1,6 @@ +import { ContextItem, ToolExtras } from "../.."; + +export type ToolImpl = ( + parameters: any, + extras: ToolExtras, +) => Promise; diff --git a/core/tools/implementations/readCurrentlyOpenFile.ts b/core/tools/implementations/readCurrentlyOpenFile.ts new file mode 100644 index 0000000000..7fff3b621f --- /dev/null +++ b/core/tools/implementations/readCurrentlyOpenFile.ts @@ -0,0 +1,20 @@ +import { ToolImpl } from "."; +import { getBasename } from "../../util"; + +export const readCurrentlyOpenFileImpl: ToolImpl = async (args, extras) => { + const result = await extras.ide.getCurrentFile(); + + if (!result) { + return []; + } + + const basename = getBasename(result.path); + + return [ + { + name: "Current file", + description: basename, + content: `\`\`\`${basename}\n${result.contents}\n\`\`\``, + }, + ]; +}; diff --git a/core/tools/implementations/readFile.ts b/core/tools/implementations/readFile.ts new file mode 100644 index 0000000000..032db3dce5 --- /dev/null +++ b/core/tools/implementations/readFile.ts @@ -0,0 +1,14 @@ +import { getBasename } from "../../util"; + +import { ToolImpl } from "."; + +export const readFileImpl: ToolImpl = async (args, extras) => { + const content = await extras.ide.readFile(args.filepath); + return [ + { + name: getBasename(args.filepath), + description: args.filepath, + content, + }, + ]; +}; diff --git a/core/tools/implementations/runTerminalCommand.ts b/core/tools/implementations/runTerminalCommand.ts new file mode 100644 index 0000000000..ddc8575c44 --- /dev/null +++ b/core/tools/implementations/runTerminalCommand.ts @@ -0,0 +1,40 @@ +import { ToolImpl } from "."; +const util = require("node:util"); +const asyncExec = util.promisify(require("node:child_process").exec); + +export const runTerminalCommandImpl: ToolImpl = async (args, extras) => { + const ideInfo = await extras.ide.getIdeInfo(); + + if (ideInfo.remoteName === "local" || ideInfo.remoteName === "") { + try { + const output = await asyncExec(args.command, { + cwd: (await extras.ide.getWorkspaceDirs())[0], + }); + return [ + { + name: "Terminal", + description: "Terminal command output", + content: output.stdout ?? "", + }, + ]; + } catch (error: any) { + return [ + { + name: "Terminal", + description: "Terminal command output", + content: error.stderr ?? error.toString(), + }, + ]; + } + } + + await extras.ide.runCommand(args.command); + return [ + { + name: "Terminal", + description: "Terminal command output", + content: + "[Terminal output not available. This is only available in local development environments and not in SSH environments for example.]", + }, + ]; +}; diff --git a/core/tools/implementations/searchWeb.ts b/core/tools/implementations/searchWeb.ts new file mode 100644 index 0000000000..a012126596 --- /dev/null +++ b/core/tools/implementations/searchWeb.ts @@ -0,0 +1,8 @@ +import { fetchSearchResults } from "../../context/providers/WebContextProvider"; + +import { ToolImpl } from "."; + +export const searchWebImpl: ToolImpl = async (args, extras) => { + const webResults = await fetchSearchResults(args.query, 5, extras.fetch); + return webResults; +}; diff --git a/core/tools/implementations/viewDiff.ts b/core/tools/implementations/viewDiff.ts new file mode 100644 index 0000000000..11003b61ad --- /dev/null +++ b/core/tools/implementations/viewDiff.ts @@ -0,0 +1,12 @@ +import { ToolImpl } from "."; + +export const viewDiffImpl: ToolImpl = async (args, extras) => { + const diff = await extras.ide.getDiff(true); + return [ + { + name: "Diff", + description: "The current git diff", + content: diff.join("\n"), + }, + ]; +}; diff --git a/core/tools/implementations/viewRepoMap.ts b/core/tools/implementations/viewRepoMap.ts new file mode 100644 index 0000000000..56c2654162 --- /dev/null +++ b/core/tools/implementations/viewRepoMap.ts @@ -0,0 +1,14 @@ +import generateRepoMap from "../../util/generateRepoMap"; + +import { ToolImpl } from "."; + +export const viewRepoMapImpl: ToolImpl = async (args, extras) => { + const repoMap = await generateRepoMap(extras.llm, extras.ide, {}); + return [ + { + name: "Repo map", + description: "Overview of the repository structure", + content: repoMap, + }, + ]; +}; diff --git a/core/tools/implementations/viewSubdirectory.ts b/core/tools/implementations/viewSubdirectory.ts new file mode 100644 index 0000000000..a8204ca6a2 --- /dev/null +++ b/core/tools/implementations/viewSubdirectory.ts @@ -0,0 +1,27 @@ +import generateRepoMap from "../../util/generateRepoMap"; +import { resolveRelativePathInWorkspace } from "../../util/ideUtils"; + +import { ToolImpl } from "."; + +export const viewSubdirectoryImpl: ToolImpl = async (args: any, extras) => { + const { directory_path } = args; + const absolutePath = await resolveRelativePathInWorkspace( + directory_path, + extras.ide, + ); + + if (!absolutePath) { + throw new Error(`Directory path "${directory_path}" does not exist.`); + } + + const repoMap = await generateRepoMap(extras.llm, extras.ide, { + dirs: [absolutePath], + }); + return [ + { + name: "Repo map", + description: `Map of ${directory_path}`, + content: repoMap, + }, + ]; +}; diff --git a/core/tools/index.ts b/core/tools/index.ts new file mode 100644 index 0000000000..c27860a6f9 --- /dev/null +++ b/core/tools/index.ts @@ -0,0 +1,21 @@ +import { createNewFileTool } from "./definitions/createNewFile"; +import { exactSearchTool } from "./definitions/exactSearch"; +import { readCurrentlyOpenFileTool } from "./definitions/readCurrentlyOpenFile"; +import { readFileTool } from "./definitions/readFile"; +import { runTerminalCommandTool } from "./definitions/runTerminalCommand"; +import { searchWebTool } from "./definitions/searchWeb"; +import { viewDiffTool } from "./definitions/viewDiff"; +import { viewRepoMapTool } from "./definitions/viewRepoMap"; +import { viewSubdirectoryTool } from "./definitions/viewSubdirectory"; + +export const allTools = [ + readFileTool, + createNewFileTool, + runTerminalCommandTool, + viewSubdirectoryTool, + viewRepoMapTool, + exactSearchTool, + searchWebTool, + viewDiffTool, + readCurrentlyOpenFileTool, +]; diff --git a/core/util/chatDescriber.ts b/core/util/chatDescriber.ts index b8ca609a21..b52fd7696a 100644 --- a/core/util/chatDescriber.ts +++ b/core/util/chatDescriber.ts @@ -1,5 +1,4 @@ import { ILLM, LLMFullCompletionOptions } from ".."; -import { stripImages } from "../llm/images"; import { removeQuotesAndEscapes } from "."; @@ -35,7 +34,6 @@ export class ChatDescriber { } // Clean up and distill the message we want to send to the LLM - message = stripImages(message); message = removeCodeBlocksAndTrim(message); if (!message) { diff --git a/core/util/incrementalParseJson.ts b/core/util/incrementalParseJson.ts new file mode 100644 index 0000000000..21385ad6a9 --- /dev/null +++ b/core/util/incrementalParseJson.ts @@ -0,0 +1,13 @@ +import { parse } from "partial-json"; + +export function incrementalParseJson(raw: string): [boolean, any] { + try { + return [true, JSON.parse(raw)]; + } catch (e) { + try { + return [false, parse(raw)]; + } catch (e2) { + return [false, {}]; + } + } +} diff --git a/core/util/messageContent.ts b/core/util/messageContent.ts new file mode 100644 index 0000000000..b42bb082a5 --- /dev/null +++ b/core/util/messageContent.ts @@ -0,0 +1,45 @@ +import { + ChatMessage, + ContextItem, + MessageContent, + MessagePart, +} from "../index.js"; + +export function stripImages(messageContent: MessageContent): string { + if (typeof messageContent === "string") { + return messageContent; + } + + return messageContent + .filter((part) => part.type === "text") + .map((part) => part.text) + .join("\n"); +} + +export function renderChatMessage(message: ChatMessage): string { + switch (message.role) { + case "user": + case "assistant": + case "system": + return stripImages(message.content); + case "tool": + return message.content; + } +} + +export function renderContextItems(contextItems: ContextItem[]): string { + return contextItems.map((item) => item.content).join("\n\n"); +} + +export function normalizeToMessageParts(message: ChatMessage): MessagePart[] { + switch (message.role) { + case "user": + case "assistant": + case "system": + return Array.isArray(message.content) + ? message.content + : [{ type: "text", text: message.content }]; + case "tool": + return [{ type: "text", text: message.content }]; + } +} diff --git a/docs/docs/customize/context-providers.md b/docs/docs/customize/context-providers.md index a423ff4dc8..2c7708cacc 100644 --- a/docs/docs/customize/context-providers.md +++ b/docs/docs/customize/context-providers.md @@ -495,6 +495,26 @@ Reference the architecture and platform of your current operating system. } ``` +### Model Context Protocol + +The [Model Context Protocol](https://modelcontextprotocol.io/introduction) is a standard proposed by Anthropic to unify prompts, context, and tool use. Continue supports any MCP server with the MCP context provider. Read their [quickstart](https://modelcontextprotocol.io/quickstart) to learn how to set up a local server and then configure your `config.json` like this: + +```json +{ + "experimental": { + "modelContextProtocolServer": { + "transport": { + "type": "stdio", + "command": "uvx", + "args": ["mcp-server-sqlite", "--db-path", "/Users/NAME/test.db"] + } + } + } +} +``` + +You'll then be able to type "@" and see "MCP" in the context providers dropdown. + ### `@HTTP` The HttpContextProvider makes a POST request to the url passed in the configuration. The server must return 200 OK with a ContextItem object or an array of ContextItems. diff --git a/docs/docs/customize/slash-commands.md b/docs/docs/customize/slash-commands.md index 96988e8587..20283ae899 100644 --- a/docs/docs/customize/slash-commands.md +++ b/docs/docs/customize/slash-commands.md @@ -106,3 +106,21 @@ The onboard slash command helps to familiarize yourself with a new project by an ] } ``` + +### Model Context Protocol + +The [Model Context Protocol](https://modelcontextprotocol.io/introduction) is a standard proposed by Anthropic to unify prompts, context, and tool use. Continue supports MCP "prompts" by creating slash commands. Read their [quickstart](https://modelcontextprotocol.io/quickstart) to learn how to set up a local server and then configure your `config.json` like this: + +```json +{ + "experimental": { + "modelContextProtocolServer": { + "transport": { + "type": "stdio", + "command": "uvx", + "args": ["mcp-server-sqlite", "--db-path", "/Users/NAME/test.db"] + } + } + } +} +``` diff --git a/extensions/vscode/config_schema.json b/extensions/vscode/config_schema.json index 9b17fe65d0..ba68ed19a9 100644 --- a/extensions/vscode/config_schema.json +++ b/extensions/vscode/config_schema.json @@ -3048,6 +3048,68 @@ "type": "boolean", "description": "Use Chromium to crawl docs locally. Useful if the default Cheerio crawler fails on sites that require JavaScript rendering. Downloads and installs Chromium to `~/.continue/.utils`.", "default": false + }, + "useTools": { + "type": "boolean", + "description": "Allow tool use. Currently only supported with Claude 3.5 Sonnet", + "default": true + }, + "modelContextProtocolServer": { + "type": "object", + "properties": { + "transport": { + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": ["stdio"] + }, + "command": { + "type": "string" + }, + "args": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "required": ["type", "command", "args"] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": ["websocket"] + }, + "url": { + "type": "string", + "format": "uri" + } + }, + "required": ["type", "url"] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": ["sse"] + }, + "url": { + "type": "string", + "format": "uri" + } + }, + "required": ["type", "url"] + } + ] + } + }, + "required": ["transport"] } } } diff --git a/extensions/vscode/package-lock.json b/extensions/vscode/package-lock.json index e0bc72e92e..ef1df33320 100644 --- a/extensions/vscode/package-lock.json +++ b/extensions/vscode/package-lock.json @@ -1,12 +1,12 @@ { "name": "continue", - "version": "0.9.235", + "version": "0.9.237", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "continue", - "version": "0.9.235", + "version": "0.9.237", "license": "Apache-2.0", "dependencies": { "@electron/rebuild": "^3.2.10", @@ -99,6 +99,7 @@ "@aws-sdk/credential-providers": "^3.620.1", "@continuedev/config-types": "^1.0.13", "@continuedev/llm-info": "^1.0.2", + "@modelcontextprotocol/sdk": "^1.0.0", "@mozilla/readability": "^0.5.0", "@octokit/rest": "^20.0.2", "@typescript-eslint/eslint-plugin": "^7.8.0", @@ -109,6 +110,7 @@ "cheerio": "^1.0.0-rc.12", "commander": "^12.0.0", "comment-json": "^4.2.3", + "dbinfoz": "^0.14.0", "diff": "^7.0.0", "dotenv": "^16.4.5", "fastest-levenshtein": "^1.0.16", @@ -131,6 +133,7 @@ "onnxruntime-node": "1.14.0", "openai": "^4.20.1", "p-limit": "^6.1.0", + "partial-json": "^0.1.7", "pg": "^8.11.3", "posthog-node": "^3.6.3", "puppeteer": "^22.4.0", diff --git a/extensions/vscode/package.json b/extensions/vscode/package.json index f3bb0896f4..af638882d0 100644 --- a/extensions/vscode/package.json +++ b/extensions/vscode/package.json @@ -57,6 +57,12 @@ ], "main": "./out/extension.js", "contributes": { + "authentication": [ + { + "id": "continue", + "label": "Continue" + } + ], "languages": [ { "filenames": [ diff --git a/extensions/vscode/src/VsCodeIde.ts b/extensions/vscode/src/VsCodeIde.ts index cc0d8c3435..5138ba490f 100644 --- a/extensions/vscode/src/VsCodeIde.ts +++ b/extensions/vscode/src/VsCodeIde.ts @@ -531,8 +531,14 @@ class VsCodeIde implements IDE { "bin", "rg", ), - ["-i", "-C", "2", "--", `${query}`, "."], //no regex - //["-i", "-C", "2", "-e", `${query}`, "."], //use regex + [ + "-i", // Case-insensitive search + "-C", + "2", // Show 2 lines of context + "-e", + query, // Pattern to search for + ".", // Directory to search in + ], { cwd: dir }, ); let output = ""; @@ -546,6 +552,9 @@ class VsCodeIde implements IDE { p.on("close", (code) => { if (code === 0) { resolve(output); + } else if (code === 1) { + // No matches + resolve("No matches found"); } else { reject(new Error(`Process exited with code ${code}`)); } diff --git a/extensions/vscode/src/autocomplete/completionProvider.ts b/extensions/vscode/src/autocomplete/completionProvider.ts index ffe69a2c4e..5d6faef3f6 100644 --- a/extensions/vscode/src/autocomplete/completionProvider.ts +++ b/extensions/vscode/src/autocomplete/completionProvider.ts @@ -19,8 +19,8 @@ import { stopStatusBarLoading, } from "./statusBar"; -import type { TabAutocompleteModel } from "../util/loadAutocompleteModel"; import type { IDE } from "core"; +import type { TabAutocompleteModel } from "../util/loadAutocompleteModel"; interface VsCodeCompletionInput { document: vscode.TextDocument; diff --git a/extensions/vscode/src/extension/VsCodeMessenger.ts b/extensions/vscode/src/extension/VsCodeMessenger.ts index 8a70038a34..0f5a13cffc 100644 --- a/extensions/vscode/src/extension/VsCodeMessenger.ts +++ b/extensions/vscode/src/extension/VsCodeMessenger.ts @@ -4,7 +4,6 @@ import * as path from "node:path"; import { ConfigHandler } from "core/config/ConfigHandler"; import { getModelByRole } from "core/config/util"; import { applyCodeBlock } from "core/edit/lazy/applyCodeBlock"; -import { stripImages } from "core/llm/images"; import { FromCoreProtocol, FromWebviewProtocol, @@ -21,6 +20,7 @@ import { getBasename } from "core/util"; import { InProcessMessenger, Message } from "core/util/messenger"; import * as vscode from "vscode"; +import { stripImages } from "core/util/messageContent"; import { VerticalDiffManager } from "../diff/vertical/manager"; import EditDecorationManager from "../quickEdit/EditDecorationManager"; import { diff --git a/gui/package-lock.json b/gui/package-lock.json index 3a69aaaeae..1b087fdf96 100644 --- a/gui/package-lock.json +++ b/gui/package-lock.json @@ -30,6 +30,7 @@ "downshift": "^7.6.0", "lodash": "^4.17.21", "minisearch": "^7.0.2", + "mustache": "^4.2.0", "onigasm": "^2.2.5", "posthog-js": "^1.130.1", "prismjs": "^1.29.0", @@ -47,6 +48,7 @@ "react-tooltip": "^5.18.0", "redux-persist": "^6.0.0", "redux-persist-transform-filter": "^0.0.22", + "redux-thunk": "^3.1.0", "rehype-highlight": "^7.0.0", "rehype-katex": "^7.0.0", "rehype-wrap-all": "^1.1.0", @@ -69,6 +71,7 @@ "@testing-library/react": "^16.0.1", "@testing-library/user-event": "^14.5.2", "@types/lodash": "^4.17.6", + "@types/mustache": "^4.2.5", "@types/node": "^20.5.6", "@types/node-fetch": "^2.6.4", "@types/react": "^18.3.11", @@ -105,6 +108,7 @@ "@aws-sdk/credential-providers": "^3.620.1", "@continuedev/config-types": "^1.0.13", "@continuedev/llm-info": "^1.0.2", + "@modelcontextprotocol/sdk": "^1.0.0", "@mozilla/readability": "^0.5.0", "@octokit/rest": "^20.0.2", "@typescript-eslint/eslint-plugin": "^7.8.0", @@ -138,6 +142,7 @@ "onnxruntime-node": "1.14.0", "openai": "^4.20.1", "p-limit": "^6.1.0", + "partial-json": "^0.1.7", "pg": "^8.11.3", "posthog-node": "^3.6.3", "puppeteer": "^22.4.0", @@ -1841,6 +1846,22 @@ } } }, + "node_modules/@reduxjs/toolkit/node_modules/redux": { + "version": "4.2.1", + "resolved": "https://registry.npmjs.org/redux/-/redux-4.2.1.tgz", + "integrity": "sha512-LAUYz4lc+Do8/g7aeRa8JkyDErK6ekstQaqWQrNRW//MY1TvCEpMtpTWvlQ+FPbWCx+Xixu/6SHt5N0HR+SB4w==", + "dependencies": { + "@babel/runtime": "^7.9.2" + } + }, + "node_modules/@reduxjs/toolkit/node_modules/redux-thunk": { + "version": "2.4.2", + "resolved": "https://registry.npmjs.org/redux-thunk/-/redux-thunk-2.4.2.tgz", + "integrity": "sha512-+P3TjtnP0k/FEjcBL5FZpoovtvrTNT/UXd4/sluaSyrURlSlhLSzEdfsTBW7WsKB6yPvgd7q/iZPICFjW4o57Q==", + "peerDependencies": { + "redux": "^4" + } + }, "node_modules/@reduxjs/toolkit/node_modules/reselect": { "version": "4.1.8", "resolved": "https://registry.npmjs.org/reselect/-/reselect-4.1.8.tgz", @@ -3102,6 +3123,12 @@ "resolved": "https://registry.npmjs.org/@types/ms/-/ms-0.7.34.tgz", "integrity": "sha512-nG96G3Wp6acyAgJqGasjODb+acrI7KltPiRxzHPXnP3NgI28bpQDRv53olbqGXbfcgF5aiiHmO3xpwEpS5Ld9g==" }, + "node_modules/@types/mustache": { + "version": "4.2.5", + "resolved": "https://registry.npmjs.org/@types/mustache/-/mustache-4.2.5.tgz", + "integrity": "sha512-PLwiVvTBg59tGFL/8VpcGvqOu3L4OuveNvPi0EYbWchRdEVP++yRUXJPFl+CApKEq13017/4Nf7aQ5lTtHUNsA==", + "dev": true + }, "node_modules/@types/node": { "version": "20.16.5", "resolved": "https://registry.npmjs.org/@types/node/-/node-20.16.5.tgz", @@ -9338,6 +9365,14 @@ "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==" }, + "node_modules/mustache": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/mustache/-/mustache-4.2.0.tgz", + "integrity": "sha512-71ippSywq5Yb7/tVYyGbkBggbU8H3u5Rz56fH60jGFgr8uHwxs+aSKeqmluIVzM0m0kB7xQjKS6qPfd0b2ZoqQ==", + "bin": { + "mustache": "bin/mustache" + } + }, "node_modules/mz": { "version": "2.7.0", "resolved": "https://registry.npmjs.org/mz/-/mz-2.7.0.tgz", @@ -10974,12 +11009,10 @@ } }, "node_modules/redux": { - "version": "4.2.1", - "resolved": "https://registry.npmjs.org/redux/-/redux-4.2.1.tgz", - "integrity": "sha512-LAUYz4lc+Do8/g7aeRa8JkyDErK6ekstQaqWQrNRW//MY1TvCEpMtpTWvlQ+FPbWCx+Xixu/6SHt5N0HR+SB4w==", - "dependencies": { - "@babel/runtime": "^7.9.2" - } + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/redux/-/redux-5.0.1.tgz", + "integrity": "sha512-M9/ELqF6fy8FwmkpnF0S3YKOqMyoWJ4+CS5Efg2ct3oY9daQvd/Pc71FpGZsVsbl3Cpb+IIcjBDUnnyBdQbq4w==", + "peer": true }, "node_modules/redux-persist": { "version": "6.0.0", @@ -11002,11 +11035,11 @@ } }, "node_modules/redux-thunk": { - "version": "2.4.2", - "resolved": "https://registry.npmjs.org/redux-thunk/-/redux-thunk-2.4.2.tgz", - "integrity": "sha512-+P3TjtnP0k/FEjcBL5FZpoovtvrTNT/UXd4/sluaSyrURlSlhLSzEdfsTBW7WsKB6yPvgd7q/iZPICFjW4o57Q==", + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/redux-thunk/-/redux-thunk-3.1.0.tgz", + "integrity": "sha512-NW2r5T6ksUKXCabzhL9z+h206HQw/NJkcLm1GPImRQ8IzfXwRGqjVhKJGauHirT0DAuyy6hjdnMZaRoAcy0Klw==", "peerDependencies": { - "redux": "^4" + "redux": "^5.0.0" } }, "node_modules/reflect.getprototypeof": { diff --git a/gui/package.json b/gui/package.json index 4365e2ba26..c305006b89 100644 --- a/gui/package.json +++ b/gui/package.json @@ -38,6 +38,7 @@ "downshift": "^7.6.0", "lodash": "^4.17.21", "minisearch": "^7.0.2", + "mustache": "^4.2.0", "onigasm": "^2.2.5", "posthog-js": "^1.130.1", "prismjs": "^1.29.0", @@ -55,6 +56,7 @@ "react-tooltip": "^5.18.0", "redux-persist": "^6.0.0", "redux-persist-transform-filter": "^0.0.22", + "redux-thunk": "^3.1.0", "rehype-highlight": "^7.0.0", "rehype-katex": "^7.0.0", "rehype-wrap-all": "^1.1.0", @@ -77,6 +79,7 @@ "@testing-library/react": "^16.0.1", "@testing-library/user-event": "^14.5.2", "@types/lodash": "^4.17.6", + "@types/mustache": "^4.2.5", "@types/node": "^20.5.6", "@types/node-fetch": "^2.6.4", "@types/react": "^18.3.11", diff --git a/gui/src/App.tsx b/gui/src/App.tsx index 2e09970be9..953e8a80c7 100644 --- a/gui/src/App.tsx +++ b/gui/src/App.tsx @@ -31,10 +31,6 @@ const router = createMemoryRouter([ path: "/", element: , }, - { - path: "/edit", - element: , - }, { path: "/history", element: , @@ -43,6 +39,10 @@ const router = createMemoryRouter([ path: "/stats", element: , }, + { + path: "/edit", + element: , + }, { path: "/settings", element: , diff --git a/gui/src/components/ChatScrollAnchor.tsx b/gui/src/components/ChatScrollAnchor.tsx index 40c0f9d190..9b32e7283b 100644 --- a/gui/src/components/ChatScrollAnchor.tsx +++ b/gui/src/components/ChatScrollAnchor.tsx @@ -1,5 +1,6 @@ import { useEffect } from "react"; import { useInView } from "react-intersection-observer"; +import { vscBackground } from "."; interface ChatScrollAnchorProps { trackVisibility: boolean; @@ -25,5 +26,13 @@ export function ChatScrollAnchor({ } }, [inView, entry, isAtBottom, trackVisibility]); - return
; + return ( +
+ ); } diff --git a/gui/src/components/InfoHover.tsx b/gui/src/components/InfoHover.tsx index aa6541c5bf..4b0774efe7 100644 --- a/gui/src/components/InfoHover.tsx +++ b/gui/src/components/InfoHover.tsx @@ -1,17 +1,28 @@ import { InformationCircleIcon } from "@heroicons/react/24/outline"; +import { ReactNode } from "react"; import { ToolTip } from "./gui/Tooltip"; -const InfoHover = ({ msg }: { msg: string }) => { - const id = `info-hover-${encodeURIComponent(msg)}`; +const DEFAULT_SIZE = "5"; + +const InfoHover = ({ + msg, + size, + id, +}: { + id: string; + msg: ReactNode; + size?: string; +}) => { + const dataTooltipId = `info-hover-${encodeURIComponent(id)}`; return ( <> - + {msg} diff --git a/gui/src/components/StepContainer/EditActions.tsx b/gui/src/components/StepContainer/EditActions.tsx index 733ae623b2..e1afee4ade 100644 --- a/gui/src/components/StepContainer/EditActions.tsx +++ b/gui/src/components/StepContainer/EditActions.tsx @@ -1,10 +1,11 @@ +import { ChatHistoryItem } from "core"; import { useSelector } from "react-redux"; import { RootState } from "../../redux/store"; import AcceptRejectAllButtons from "./AcceptRejectAllButtons"; import FeedbackButtons from "./FeedbackButtons"; import UndoAndRedoButtons from "./UndoAndRedoButtons"; -import { ChatHistoryItem } from "core"; -import { stripImages } from "core/llm/images"; + +import { renderChatMessage } from "core/util/messageContent"; import { CopyIconButton } from "../gui/CopyIconButton"; export interface EditActionsProps { @@ -61,7 +62,7 @@ export default function EditActions({ index, item }: EditActionsProps) {
diff --git a/gui/src/components/StepContainer/ResponseActions.tsx b/gui/src/components/StepContainer/ResponseActions.tsx index 8e09cb1422..70c4d73c09 100644 --- a/gui/src/components/StepContainer/ResponseActions.tsx +++ b/gui/src/components/StepContainer/ResponseActions.tsx @@ -1,6 +1,6 @@ import { BarsArrowDownIcon, TrashIcon } from "@heroicons/react/24/outline"; import { ChatHistoryItem } from "core"; -import { stripImages } from "core/llm/images"; +import { renderChatMessage } from "core/util/messageContent"; import { useSelector } from "react-redux"; import { RootState } from "../../redux/store"; import { CopyIconButton } from "../gui/CopyIconButton"; @@ -14,6 +14,7 @@ export interface ResponseActionsProps { index: number; onDelete: () => void; item: ChatHistoryItem; + shouldHideActions: boolean; } export default function ResponseActions({ @@ -22,6 +23,7 @@ export default function ResponseActions({ item, isTruncated, onDelete, + shouldHideActions, }: ResponseActionsProps) { const isInEditMode = useSelector( (store: RootState) => store.editModeState.isInEditMode, @@ -32,28 +34,37 @@ export default function ResponseActions({ } return ( -
- {isTruncated && ( - - - - )} +
+ {shouldHideActions || ( + <> + {isTruncated && ( + + + + )} + + + + - - - + - - + + + )}
); } diff --git a/gui/src/components/StepContainer/StepContainer.tsx b/gui/src/components/StepContainer/StepContainer.tsx index 7620f16ca9..d7f73f91fd 100644 --- a/gui/src/components/StepContainer/StepContainer.tsx +++ b/gui/src/components/StepContainer/StepContainer.tsx @@ -1,5 +1,5 @@ import { ChatHistoryItem } from "core"; -import { stripImages } from "core/llm/images"; +import { renderChatMessage, stripImages } from "core/util/messageContent"; import { useEffect, useState } from "react"; import { useDispatch, useSelector } from "react-redux"; import styled from "styled-components"; @@ -19,6 +19,7 @@ interface StepContainerProps { const ContentDiv = styled.div<{ fontSize?: number }>` padding-top: 4px; + padding-bottom: 4px; background-color: ${vscBackground}; font-size: ${getFontSize()}px; overflow: hidden; @@ -34,15 +35,23 @@ export default function StepContainer(props: StepContainerProps) { const isInEditMode = useSelector( (store: RootState) => store.editModeState.isInEditMode, ); + + const historyItemAfterThis = useSelector( + (store: RootState) => store.state.history[props.index + 1], + ); + const uiConfig = useUIConfig(); - const shouldHideActions = active && props.isLast; + const shouldHideActions = + (active && props.isLast) || + historyItemAfterThis?.message.role === "assistant"; + // const isStepAheadOfCurCheckpoint = // isInEditMode && Math.floor(props.index / 2) > curCheckpointIndex; useEffect(() => { if (!active) { - const content = stripImages(props.item.message.content).trim(); - const endingPunctuation = [".", "?", "!", "```"]; + const content = renderChatMessage(props.item.message).trim(); + const endingPunctuation = [".", "?", "!", "```", ":"]; // If not ending in punctuation or emoji, we assume the response got truncated if ( @@ -76,10 +85,7 @@ export default function StepContainer(props: StepContainerProps) { return (
{uiConfig?.displayRawMarkdown ? ( @@ -87,7 +93,7 @@ export default function StepContainer(props: StepContainerProps) { className="max-w-full overflow-x-auto whitespace-pre-wrap break-words p-4" style={{ fontSize: getFontSize() - 2 }} > - {stripImages(props.item.message.content)} + {renderChatMessage(props.item.message)} ) : ( )}
diff --git a/gui/src/components/mainInput/ContextItemsPeek.tsx b/gui/src/components/mainInput/ContextItemsPeek.tsx index 3772ba0570..80ed7243a0 100644 --- a/gui/src/components/mainInput/ContextItemsPeek.tsx +++ b/gui/src/components/mainInput/ContextItemsPeek.tsx @@ -1,17 +1,19 @@ -import { ChevronDownIcon, ChevronRightIcon } from "@heroicons/react/24/outline"; +import { + ArrowTopRightOnSquareIcon, + ChevronDownIcon, + ChevronRightIcon, +} from "@heroicons/react/24/outline"; import { ContextItemWithId } from "core"; import { ctxItemToRifWithContents } from "core/commands/util"; +import { getBasename } from "core/util"; import { useContext, useMemo, useState } from "react"; +import { useSelector } from "react-redux"; import { AnimatedEllipsis, lightGray, vscBackground } from ".."; import { IdeMessengerContext } from "../../context/IdeMessenger"; +import { RootState } from "../../redux/store"; import FileIcon from "../FileIcon"; import SafeImg from "../SafeImg"; -import { INSTRUCTIONS_BASE_ITEM } from "core/context/providers/utils"; import { getIconFromDropdownItem } from "./MentionList"; -import { getBasename } from "core/util"; -import { RootState } from "../../redux/store"; -import { useSelector } from "react-redux"; -import { ArrowTopRightOnSquareIcon } from "@heroicons/react/24/outline"; interface ContextItemsPeekProps { contextItems?: ContextItemWithId[]; @@ -133,11 +135,7 @@ function ContextItemsPeek({ const [open, setOpen] = useState(false); const ctxItems = useMemo(() => { - return ( - contextItems?.filter( - (ctxItem) => !ctxItem.name.includes(INSTRUCTIONS_BASE_ITEM.name), - ) ?? [] - ); + return contextItems?.filter((ctxItem) => !ctxItem.hidden) ?? []; }, [contextItems]); const isGatheringContext = useSelector( @@ -154,7 +152,12 @@ function ContextItemsPeek({ } return ( -
+
setOpen((prev) => !prev)} diff --git a/gui/src/components/mainInput/ContinueInputBox.tsx b/gui/src/components/mainInput/ContinueInputBox.tsx index 93ac4fddf8..463203c4e5 100644 --- a/gui/src/components/mainInput/ContinueInputBox.tsx +++ b/gui/src/components/mainInput/ContinueInputBox.tsx @@ -67,6 +67,9 @@ function ContinueInputBox(props: ContinueInputBoxProps) { const availableContextProviders = useSelector( (store: RootState) => store.state.config.contextProviders, ); + const useTools = useSelector( + (store: RootState) => store.state.config.experimental?.useTools !== false, + ); useWebviewListener( "newSessionWithPrompt", @@ -85,7 +88,7 @@ function ContinueInputBox(props: ContinueInputBoxProps) { ); return ( -
+
diff --git a/gui/src/components/mainInput/InputToolbar.tsx b/gui/src/components/mainInput/InputToolbar.tsx index e1fd7ad745..1aec78c921 100644 --- a/gui/src/components/mainInput/InputToolbar.tsx +++ b/gui/src/components/mainInput/InputToolbar.tsx @@ -20,6 +20,8 @@ import { } from "../../util"; import { ToolTip } from "../gui/Tooltip"; import ModelSelect from "../modelSelection/ModelSelect"; +import HoverItem from "./InputToolbar/HoverItem"; +import ToggleToolsButton from "./InputToolbar/ToggleToolsButton"; const StyledDiv = styled.div<{ isHidden?: boolean }>` padding-top: 4px; @@ -38,17 +40,6 @@ const StyledDiv = styled.div<{ isHidden?: boolean }>` } `; -const HoverItem = styled.span<{ isActive?: boolean }>` - padding: 0 4px; - padding-top: 2px; - padding-bottom: 2px; - cursor: pointer; - transition: - color 200ms, - background-color 200ms, - box-shadow 200ms; -`; - const EnterButton = styled.button` all: unset; padding: 2px 4px; @@ -69,6 +60,7 @@ export interface ToolbarOptions { hideAddContext?: boolean; enterText?: string; hideSelectModel?: boolean; + hideTools?: boolean; } interface InputToolbarProps { @@ -145,6 +137,8 @@ function InputToolbar(props: InputToolbarProps) { )} + + {props.toolbarOptions?.hideTools || }
@@ -153,7 +147,7 @@ function InputToolbar(props: InputToolbarProps) {
{props.activeKey === "Alt" ? ( - {`${getAltKeyLabel()}⏎ + {`${getAltKeyLabel()}⏎ ${useActiveFile ? "No active file" : "Active file"}`} ) : ( diff --git a/gui/src/components/mainInput/InputToolbar/HoverItem.tsx b/gui/src/components/mainInput/InputToolbar/HoverItem.tsx new file mode 100644 index 0000000000..e668a69f2a --- /dev/null +++ b/gui/src/components/mainInput/InputToolbar/HoverItem.tsx @@ -0,0 +1,14 @@ +import styled from "styled-components"; + +const HoverItem = styled.span<{ isActive?: boolean }>` + padding: 0 4px; + padding-top: 2px; + padding-bottom: 2px; + cursor: pointer; + transition: + color 200ms, + background-color 200ms, + box-shadow 200ms; +`; + +export default HoverItem; diff --git a/gui/src/components/mainInput/InputToolbar/ToggleToolsButton.tsx b/gui/src/components/mainInput/InputToolbar/ToggleToolsButton.tsx new file mode 100644 index 0000000000..1c91aee042 --- /dev/null +++ b/gui/src/components/mainInput/InputToolbar/ToggleToolsButton.tsx @@ -0,0 +1,161 @@ +import { Listbox, Transition } from "@headlessui/react"; +import { + ChevronDownIcon, + ChevronUpIcon, + WrenchScrewdriverIcon, +} from "@heroicons/react/24/outline"; +import { modelSupportsTools } from "core/llm/autodetect"; +import { useEffect, useRef, useState } from "react"; +import { useDispatch, useSelector } from "react-redux"; +import styled from "styled-components"; +import { defaultBorderRadius, lightGray, vscForeground } from "../.."; +import { defaultModelSelector } from "../../../redux/selectors/modelSelectors"; +import { toggleUseTools } from "../../../redux/slices/uiStateSlice"; +import { RootState } from "../../../redux/store"; +import { getFontSize } from "../../../util"; +import InfoHover from "../../InfoHover"; +import HoverItem from "./HoverItem"; +import ToolDropdownItem from "./ToolDropdownItem"; + +const BackgroundDiv = styled.div<{ useTools: boolean }>` + background-color: ${(props) => + props.useTools ? `${lightGray}33` : "transparent"}; + border-radius: ${defaultBorderRadius}; + padding: 1px; + + font-size: ${getFontSize() - 4}px; + + display: flex; + flex-direction: row; + align-items: center; + gap: 2px; + + transition: background-color 200ms; +`; + +export default function ToolDropdown() { + const buttonRef = useRef(null); + const dispatch = useDispatch(); + const [isDropdownOpen, setDropdownOpen] = useState(false); + + const useTools = useSelector((store: RootState) => store.uiState.useTools); + const availableTools = useSelector( + (store: RootState) => store.state.config.tools, + ); + const defaultModel = useSelector(defaultModelSelector); + + const [showAbove, setShowAbove] = useState(false); + + useEffect(() => { + const checkPosition = () => { + if (buttonRef.current) { + const rect = buttonRef.current.getBoundingClientRect(); + const windowHeight = window.innerHeight; + const spaceBelow = windowHeight - rect.bottom; + setShowAbove(spaceBelow < 250); + } + }; + + if (isDropdownOpen) { + checkPosition(); + } + }, [isDropdownOpen]); + + if (!defaultModel || !modelSupportsTools(defaultModel.model)) { + return null; + } + + return ( + dispatch(toggleUseTools())}> + + + {useTools && ( + <> + Tools + +
+ {}}> + { + e.stopPropagation(); + setDropdownOpen((prev) => !prev); + }} + className="text-lightgray flex cursor-pointer items-center border-none bg-transparent outline-none" + > + {isDropdownOpen ? ( + + ) : ( + + )} + + + +
+
+ Tool policies{" "} + +

+ + Automatic: + {" "} + Can be used without asking +

+

+ + Allowed: + {" "} + Will ask before using +

+

+ Disabled:{" "} + Cannot be used +

+
+ } + /> +
+
+
+ {availableTools.map((tool) => ( + + + + ))} +
+ + + +
+ + )} + + + ); +} diff --git a/gui/src/components/mainInput/InputToolbar/ToolDropdownItem.tsx b/gui/src/components/mainInput/InputToolbar/ToolDropdownItem.tsx new file mode 100644 index 0000000000..7098733ebf --- /dev/null +++ b/gui/src/components/mainInput/InputToolbar/ToolDropdownItem.tsx @@ -0,0 +1,59 @@ +import { Tool } from "core"; +import { useEffect } from "react"; +import { useDispatch, useSelector } from "react-redux"; +import { addTool, toggleToolSetting } from "../../../redux/slices/uiStateSlice"; +import { RootState } from "../../../redux/store"; + +interface ToolDropdownItemProps { + tool: Tool; +} + +function ToolDropdownItem(props: ToolDropdownItemProps) { + const dispatch = useDispatch(); + const settings = useSelector( + (state: RootState) => state.uiState.toolSettings[props.tool.function.name], + ); + + useEffect(() => { + if (!settings) { + dispatch(addTool(props.tool)); + } + }, [props.tool.function.name, settings]); + + if (!settings) { + return null; + } + + return ( +
{ + dispatch(toggleToolSetting(props.tool.function.name)); + e.stopPropagation(); + e.preventDefault(); + }} + > + + {props.tool.displayTitle}{" "} + {/* */} + +
+ {(settings === "allowedWithPermission" || settings === undefined) && ( + Allowed + )} + {settings === "allowedWithoutPermission" && ( + Automatic + )} + {settings === "disabled" && ( + Disabled + )} +
+
+ ); +} + +export default ToolDropdownItem; diff --git a/gui/src/components/mainInput/resolveInput.ts b/gui/src/components/mainInput/resolveInput.ts index 96ab0d93ba..7d3ac63099 100644 --- a/gui/src/components/mainInput/resolveInput.ts +++ b/gui/src/components/mainInput/resolveInput.ts @@ -7,7 +7,7 @@ import { MessagePart, RangeInFile, } from "core"; -import { stripImages } from "core/llm/images"; +import { stripImages } from "core/util/messageContent"; import { IIdeMessenger } from "../../context/IdeMessenger"; import { Dispatch } from "@reduxjs/toolkit"; import { setIsGatheringContext } from "../../redux/slices/stateSlice"; diff --git a/gui/src/components/markdown/StepContainerPreToolbar/ApplyActions.tsx b/gui/src/components/markdown/StepContainerPreToolbar/ApplyActions.tsx index d96afc7fe2..42ee8e0486 100644 --- a/gui/src/components/markdown/StepContainerPreToolbar/ApplyActions.tsx +++ b/gui/src/components/markdown/StepContainerPreToolbar/ApplyActions.tsx @@ -1,10 +1,10 @@ -import { CheckIcon, XMarkIcon, PlayIcon } from "@heroicons/react/24/outline"; -import Spinner from "./Spinner"; -import { lightGray, vscForeground } from "../.."; +import { CheckIcon, PlayIcon, XMarkIcon } from "@heroicons/react/24/outline"; +import { ApplyState } from "core"; import { useState } from "react"; +import { lightGray, vscForeground } from "../.."; import { getMetaKeyLabel } from "../../../util"; +import Spinner from "./Spinner"; import { ToolbarButtonWithTooltip } from "./ToolbarButtonWithTooltip"; -import { ApplyState } from "core"; interface ApplyActionsProps { applyState?: ApplyState; diff --git a/gui/src/components/markdown/StepContainerPreToolbar/RunInTerminalButton.tsx b/gui/src/components/markdown/StepContainerPreToolbar/RunInTerminalButton.tsx new file mode 100644 index 0000000000..35bc543497 --- /dev/null +++ b/gui/src/components/markdown/StepContainerPreToolbar/RunInTerminalButton.tsx @@ -0,0 +1,35 @@ +import { CommandLineIcon } from "@heroicons/react/24/outline"; +import { useContext } from "react"; +import { lightGray, vscForeground } from "../.."; +import { IdeMessengerContext } from "../../../context/IdeMessenger"; + +interface RunInTerminalButtonProps { + command: string; +} + +export default function RunInTerminalButton({ + command, +}: RunInTerminalButtonProps) { + const ideMessenger = useContext(IdeMessengerContext); + + function runInTerminal() { + void ideMessenger.post("runCommand", { command }); + } + + return ( +
+
+ <> + + Run + +
+
+ ); +} diff --git a/gui/src/components/markdown/StepContainerPreToolbar/StepContainerPreToolbar.tsx b/gui/src/components/markdown/StepContainerPreToolbar/StepContainerPreToolbar.tsx index b147fff2ed..ca7d78d971 100644 --- a/gui/src/components/markdown/StepContainerPreToolbar/StepContainerPreToolbar.tsx +++ b/gui/src/components/markdown/StepContainerPreToolbar/StepContainerPreToolbar.tsx @@ -7,14 +7,15 @@ import { v4 as uuidv4 } from "uuid"; import { defaultBorderRadius, lightGray, vscEditorBackground } from "../.."; import { IdeMessengerContext } from "../../../context/IdeMessenger"; import { useWebviewListener } from "../../../hooks/useWebviewListener"; +import { defaultModelSelector } from "../../../redux/selectors/modelSelectors"; import { RootState } from "../../../redux/store"; import { getFontSize } from "../../../util"; -import { childrenToText } from "../utils"; +import { childrenToText, isTerminalCodeBlock } from "../utils"; import ApplyActions from "./ApplyActions"; import CopyButton from "./CopyButton"; import FileInfo from "./FileInfo"; import GeneratingCodeLoader from "./GeneratingCodeLoader"; -import { defaultModelSelector } from "../../../redux/selectors/modelSelectors"; +import RunInTerminalButton from "./RunInTerminalButton"; const fadeInAnimation = keyframes` from { @@ -60,6 +61,8 @@ export interface StepContainerPreToolbarProps { codeBlockIndex: number; // To track which codeblock we are applying range?: string; children: any; + expanded?: boolean; + hideApply?: boolean; } export default function StepContainerPreToolbar( @@ -71,8 +74,10 @@ export default function StepContainerPreToolbar( const isInEditMode = useSelector( (state: RootState) => state.editModeState.isInEditMode, ); + const [isExpanded, setIsExpanded] = useState( + props.expanded ?? (isInEditMode ? false : true), + ); const active = useSelector((state: RootState) => state.state.active); - const [isExpanded, setIsExpanded] = useState(isInEditMode ? false : true); const [codeBlockContent, setCodeBlockContent] = useState(""); const isChatActive = useSelector((state: RootState) => state.state.active); @@ -198,12 +203,17 @@ export default function StepContainerPreToolbar( {!isGeneratingCodeBlock && ( <> - + {props.hideApply || + (isTerminalCodeBlock(props.language, props.codeBlockContent) ? ( + + ) : ( + + ))} )}
diff --git a/gui/src/components/markdown/StyledMarkdownPreview.tsx b/gui/src/components/markdown/StyledMarkdownPreview.tsx index caad6428c5..f857a1e18b 100644 --- a/gui/src/components/markdown/StyledMarkdownPreview.tsx +++ b/gui/src/components/markdown/StyledMarkdownPreview.tsx @@ -1,4 +1,7 @@ -import { memo, useCallback, useEffect, useMemo, useRef } from "react"; +import { SymbolWithRange } from "core"; +import { ctxItemToRifWithContents } from "core/commands/util"; +import { memo, useEffect, useRef } from "react"; +import { useSelector } from "react-redux"; import { useRemark } from "react-remark"; import rehypeHighlight, { Options } from "rehype-highlight"; import rehypeKatex from "rehype-katex"; @@ -11,20 +14,17 @@ import { vscEditorBackground, vscForeground, } from ".."; +import useUpdatingRef from "../../hooks/useUpdatingRef"; +import { RootState } from "../../redux/store"; import { getFontSize, isJetBrains } from "../../util"; +import FilenameLink from "./FilenameLink"; import "./katex.css"; import "./markdown.css"; -import { ctxItemToRifWithContents } from "core/commands/util"; -import FilenameLink from "./FilenameLink"; +import StepContainerPreActionButtons from "./StepContainerPreActionButtons"; import StepContainerPreToolbar from "./StepContainerPreToolbar"; +import SymbolLink from "./SymbolLink"; import { SyntaxHighlightedPre } from "./SyntaxHighlightedPre"; -import StepContainerPreActionButtons from "./StepContainerPreActionButtons"; import { patchNestedMarkdown } from "./utils/patchNestedMarkdown"; -import { RootState } from "../../redux/store"; -import { ContextItemWithId, SymbolWithRange } from "core"; -import SymbolLink from "./SymbolLink"; -import { useSelector } from "react-redux"; -import useUpdatingRef from "../../hooks/useUpdatingRef"; const StyledMarkdown = styled.div<{ fontSize?: number; diff --git a/gui/src/components/markdown/SyntaxHighlightedPre.tsx b/gui/src/components/markdown/SyntaxHighlightedPre.tsx index a3caba685b..6e2be08f27 100644 --- a/gui/src/components/markdown/SyntaxHighlightedPre.tsx +++ b/gui/src/components/markdown/SyntaxHighlightedPre.tsx @@ -1,6 +1,6 @@ import { useContext } from "react"; import styled from "styled-components"; -import { vscForeground, defaultBorderRadius } from ".."; +import { defaultBorderRadius, vscForeground } from ".."; import { VscThemeContext } from "../../context/VscTheme"; const generateThemeStyles = (theme: any) => { diff --git a/gui/src/components/modelSelection/ModelCard.tsx b/gui/src/components/modelSelection/ModelCard.tsx index 85658d1734..2c34bdc51b 100644 --- a/gui/src/components/modelSelection/ModelCard.tsx +++ b/gui/src/components/modelSelection/ModelCard.tsx @@ -175,7 +175,10 @@ function ModelCard(props: ModelCardProps) { <>
- +

{dimension.name}

@@ -204,6 +207,7 @@ function ModelCard(props: ModelCardProps) {
` `} `; -const IconBase = styled.div<{ hovered: boolean }>` +const IconBase = styled.div<{ $hovered: boolean }>` width: 1.2em; height: 1.2em; cursor: pointer; padding: 4px; border-radius: ${defaultBorderRadius}; - opacity: ${(props) => (props.hovered ? 0.75 : 0)}; - visibility: ${(props) => (props.hovered ? "visible" : "hidden")}; + opacity: ${(props) => (props.$hovered ? 0.75 : 0)}; + visibility: ${(props) => (props.$hovered ? "visible" : "hidden")}; &:hover { opacity: 1; @@ -204,9 +204,9 @@ function ModelOption({
- + {showDelete && ( - + )}
diff --git a/gui/src/context/IdeMessenger.ts b/gui/src/context/IdeMessenger.ts index c5c4f9b3e4..bddd3fbbee 100644 --- a/gui/src/context/IdeMessenger.ts +++ b/gui/src/context/IdeMessenger.ts @@ -156,7 +156,7 @@ export class IdeMessenger implements IIdeMessenger { this.post(messageType, data, messageId); - let buffer = ""; + const buffer: any[] = []; let index = 0; let done = false; let returnVal = undefined; @@ -169,7 +169,7 @@ export class IdeMessenger implements IIdeMessenger { done = true; returnVal = responseData; } else { - buffer += responseData.content; + buffer.push(responseData.content); } } }; @@ -181,16 +181,16 @@ export class IdeMessenger implements IIdeMessenger { while (!done) { if (buffer.length > index) { - const chunk = buffer.slice(index); - index = buffer.length; + const chunk = buffer[index]; + index++; yield chunk; } await new Promise((resolve) => setTimeout(resolve, 50)); } - if (buffer.length > index) { - const chunk = buffer.slice(index); - index = buffer.length; + while (buffer.length > index) { + const chunk = buffer[index]; + index++; yield chunk; } @@ -215,7 +215,7 @@ export class IdeMessenger implements IIdeMessenger { let next = await gen.next(); while (!next.done) { - yield { role: "user", content: next.value }; + yield next.value; next = await gen.next(); } diff --git a/gui/src/hooks/useChatHandler.ts b/gui/src/hooks/useChatHandler.ts deleted file mode 100644 index 06e60102e7..0000000000 --- a/gui/src/hooks/useChatHandler.ts +++ /dev/null @@ -1,341 +0,0 @@ -import { Dispatch } from "@reduxjs/toolkit"; -import { JSONContent } from "@tiptap/react"; -import { - ChatHistoryItem, - ChatMessage, - ContextItemWithId, - InputModifiers, - MessageContent, - PromptLog, - RangeInFile, - SlashCommandDescription, -} from "core"; -import { constructMessages } from "core/llm/constructMessages"; -import { stripImages } from "core/llm/images"; -import { getBasename, getRelativePath } from "core/util"; -import { usePostHog } from "posthog-js/react"; -import { useEffect, useMemo, useRef, useState } from "react"; -import { useSelector } from "react-redux"; -import resolveEditorContent, { - hasSlashCommandOrContextProvider, -} from "../components/mainInput/resolveInput"; -import { IIdeMessenger } from "../context/IdeMessenger"; -import { defaultModelSelector } from "../redux/selectors/modelSelectors"; -import { - abortStream, - addPromptCompletionPair, - clearLastEmptyResponse, - initNewActiveMessage, - resubmitAtIndex, - setCurCheckpointIndex, - setInactive, - setIsGatheringContext, - setMessageAtIndex, - streamUpdate, -} from "../redux/slices/stateSlice"; -import { resetNextCodeBlockToApplyIndex } from "../redux/slices/stateSlice"; -import { RootState } from "../redux/store"; -import useHistory from "./useHistory"; -import { updateFileSymbolsFromContextItems } from "../util/symbols"; -import { - selectDefaultContextProviders, - selectSlashCommands, -} from "../redux/selectors"; - -function useChatHandler(dispatch: Dispatch, ideMessenger: IIdeMessenger) { - const posthog = usePostHog(); - - const defaultModel = useSelector(defaultModelSelector); - const defaultContextProviders = useSelector(selectDefaultContextProviders); - - const slashCommands = useSelector(selectSlashCommands); - - const history = useSelector((store: RootState) => store.state.history); - const active = useSelector((store: RootState) => store.state.active); - const streamAborter = useSelector( - (store: RootState) => store.state.streamAborter, - ); - const activeRef = useRef(active); - - const { saveSession } = useHistory(dispatch); - const [save, triggerSave] = useState(false); - - useEffect(() => { - saveSession(false); - }, [save]); - - useEffect(() => { - activeRef.current = active; - }, [active]); - - async function _streamNormalInput(messages: ChatMessage[]) { - if (!defaultModel) { - throw new Error("Default model not defined"); - } - const gen = ideMessenger.llmStreamChat( - defaultModel.title, - streamAborter.signal, - messages, - ); - let next = await gen.next(); - while (!next.done) { - if (!activeRef.current) { - dispatch(abortStream()); - break; - } - dispatch(streamUpdate(stripImages((next.value as ChatMessage).content))); - next = await gen.next(); - } - - let returnVal = next.value as PromptLog; - if (returnVal) { - dispatch(addPromptCompletionPair([returnVal])); - } - } - - const getSlashCommandForInput = ( - input: MessageContent, - ): [SlashCommandDescription, string] | undefined => { - let slashCommand: SlashCommandDescription | undefined; - let slashCommandName: string | undefined; - - let lastText = - typeof input === "string" - ? input - : input.filter((part) => part.type === "text").slice(-1)[0]?.text || ""; - - if (lastText.startsWith("/")) { - slashCommandName = lastText.split(" ")[0].substring(1); - slashCommand = slashCommands.find( - (command) => command.name === slashCommandName, - ); - } - if (!slashCommand || !slashCommandName) { - return undefined; - } - - // Convert to actual slash command object with runnable function - return [slashCommand, stripImages(input)]; - }; - - async function _streamSlashCommand( - messages: ChatMessage[], - slashCommand: SlashCommandDescription, - input: string, - historyIndex: number, - selectedCode: RangeInFile[], - contextItems: ContextItemWithId[], - ) { - if (!defaultModel) { - throw new Error("Default model not defined"); - } - - const modelTitle = defaultModel.title; - - const checkActiveInterval = setInterval(() => { - if (!activeRef.current) { - dispatch(abortStream()); - clearInterval(checkActiveInterval); - } - }, 100); - - try { - for await (const update of ideMessenger.streamRequest( - "command/run", - { - input, - history: messages, - modelTitle, - slashCommandName: slashCommand.name, - contextItems, - params: slashCommand.params, - historyIndex, - selectedCode, - }, - streamAborter.signal, - )) { - if (!activeRef.current) { - dispatch(abortStream()); - clearInterval(checkActiveInterval); - break; - } - if (typeof update === "string") { - dispatch(streamUpdate(update)); - } - } - } catch (e) { - throw e; - } finally { - clearInterval(checkActiveInterval); - } - } - - async function streamResponse( - editorState: JSONContent, - modifiers: InputModifiers, - ideMessenger: IIdeMessenger, - index?: number, - promptPreamble?: string, - ) { - try { - if (typeof index === "number") { - dispatch(resubmitAtIndex({ index, editorState })); - } else { - dispatch(initNewActiveMessage({ editorState })); - } - - // Reset current code block index - dispatch(resetNextCodeBlockToApplyIndex()); - - if (index) { - dispatch(setCurCheckpointIndex(Math.floor(index / 2))); - } - - const shouldGatherContext = - modifiers.useCodebase || hasSlashCommandOrContextProvider(editorState); - - if (shouldGatherContext) { - dispatch( - setIsGatheringContext({ - isGathering: true, - gatheringMessage: "Gathering Context", - }), - ); - } - - // Resolve context providers and construct new history - let [selectedContextItems, selectedCode, content] = - await resolveEditorContent( - editorState, - modifiers, - ideMessenger, - defaultContextProviders, - dispatch, - ); - - // Automatically use currently open file - if (!modifiers.noContext) { - const usingFreeTrial = defaultModel?.provider === "free-trial"; - - const currentFile = await ideMessenger.ide.getCurrentFile(); - if (currentFile) { - let currentFileContents = currentFile.contents; - if (usingFreeTrial) { - currentFileContents = currentFile.contents - .split("\n") - .slice(0, 1000) - .join("\n"); - } - if ( - !selectedContextItems.find( - (item) => item.uri?.value === currentFile.path, - ) - ) { - // don't add the file if it's already in the context items - selectedContextItems.unshift({ - content: `The following file is currently open. Don't reference it if it's not relevant to the user's message.\n\n\`\`\`${getRelativePath( - currentFile.path, - await ideMessenger.ide.getWorkspaceDirs(), - )}\n${currentFileContents}\n\`\`\``, - name: `Active file: ${getBasename(currentFile.path)}`, - description: currentFile.path, - id: { - itemId: currentFile.path, - providerTitle: "file", - }, - uri: { - type: "file", - value: currentFile.path, - }, - }); - } - } - } - - await updateFileSymbolsFromContextItems( - selectedContextItems, - ideMessenger, - dispatch, - ); - if (promptPreamble) { - if (typeof content === "string") { - content = promptPreamble + content; - } else { - content[0].text = promptPreamble + content[0].text; - } - } - - const message: ChatMessage = { - role: "user", - content, - }; - - const historyItem: ChatHistoryItem = { - message, - contextItems: selectedContextItems, - editorState, - }; - - let newHistory: ChatHistoryItem[] = [ - ...history.slice(0, index), - historyItem, - ]; - const historyIndex = index || newHistory.length - 1; - dispatch( - setMessageAtIndex({ - message, - index: historyIndex, - contextItems: selectedContextItems, - }), - ); - - // TODO: hacky way to allow rerender - await new Promise((resolve) => setTimeout(resolve, 0)); - - posthog.capture("step run", { - step_name: "User Input", - params: {}, - }); - posthog.capture("userInput", {}); - - const messages = constructMessages(newHistory, defaultModel.model); - - // Determine if the input is a slash command - let commandAndInput = getSlashCommandForInput(content); - - if (!commandAndInput) { - await _streamNormalInput(messages); - } else { - const [slashCommand, commandInput] = commandAndInput; - - let updatedContextItems = []; - - posthog.capture("step run", { - step_name: slashCommand.name, - params: {}, - }); - - await _streamSlashCommand( - messages, - slashCommand, - commandInput, - historyIndex, - selectedCode, - updatedContextItems, - ); - } - } catch (e: any) { - dispatch(clearLastEmptyResponse()); - console.debug("Error streaming response: ", e); - } finally { - dispatch(setInactive()); - triggerSave(!save); - } - } - - return { - streamResponse, - }; -} - -export default useChatHandler; diff --git a/gui/src/hooks/useHistory.tsx b/gui/src/hooks/useHistory.tsx index e477dfba9d..32bee26e04 100644 --- a/gui/src/hooks/useHistory.tsx +++ b/gui/src/hooks/useHistory.tsx @@ -1,7 +1,7 @@ import { Dispatch } from "@reduxjs/toolkit"; import { Session, SessionMetadata } from "core"; -import { stripImages } from "core/llm/images"; +import { renderChatMessage } from "core/util/messageContent"; import { useCallback, useContext } from "react"; import { useSelector } from "react-redux"; import { IdeMessengerContext } from "../context/IdeMessenger"; @@ -77,7 +77,7 @@ function useHistory(dispatch: Dispatch) { let newTitle = currentTitle === "New Session" ? truncateText( - stripImages(history[0].message.content) + renderChatMessage(history[0].message) .split("\n") .filter((l) => l.trim() !== "") .slice(-1)[0] || "", diff --git a/gui/src/hooks/useSetup.ts b/gui/src/hooks/useSetup.ts index e1c1e103e6..c62f847ea5 100644 --- a/gui/src/hooks/useSetup.ts +++ b/gui/src/hooks/useSetup.ts @@ -1,4 +1,3 @@ -import { Dispatch } from "@reduxjs/toolkit"; import { useCallback, useContext, useEffect, useRef } from "react"; import { useSelector } from "react-redux"; import { VSC_THEME_COLOR_VARS } from "../components"; @@ -14,15 +13,15 @@ import { updateDocsSuggestions, updateIndexingStatus, } from "../redux/slices/stateSlice"; -import { RootState } from "../redux/store"; +import { AppDispatch, RootState } from "../redux/store"; +import { streamResponseThunk } from "../redux/thunks/streamResponse"; import { isJetBrains } from "../util"; import { setLocalStorage } from "../util/localStorage"; -import useChatHandler from "./useChatHandler"; -import { useWebviewListener } from "./useWebviewListener"; import { updateFileSymbolsFromContextItems } from "../util/symbols"; +import { useWebviewListener } from "./useWebviewListener"; -function useSetup(dispatch: Dispatch) { +function useSetup(dispatch: AppDispatch) { const ideMessenger = useContext(IdeMessengerContext); const history = useSelector((store: RootState) => store.state.history); @@ -131,8 +130,6 @@ function useSetup(dispatch: Dispatch) { dispatch(updateDocsSuggestions(data)); }); - const { streamResponse } = useChatHandler(dispatch, ideMessenger); - // IDE event listeners useWebviewListener( "getWebviewHistoryLength", @@ -163,10 +160,11 @@ function useSetup(dispatch: Dispatch) { // TODO - remove? useWebviewListener("submitMessage", async (data) => { - streamResponse( - data.message, - { useCodebase: false, noContext: true }, - ideMessenger, + dispatch( + streamResponseThunk({ + editorState: data.message, + modifiers: { useCodebase: false, noContext: true }, + }), ); }); diff --git a/gui/src/pages/edit/Edit.tsx b/gui/src/pages/edit/Edit.tsx index 201a4bc4cf..d1c8f343fa 100644 --- a/gui/src/pages/edit/Edit.tsx +++ b/gui/src/pages/edit/Edit.tsx @@ -1,27 +1,27 @@ import { ArrowLeftIcon } from "@heroicons/react/24/outline"; import { Editor, JSONContent } from "@tiptap/core"; import { InputModifiers, RangeInFileWithContents } from "core"; -import { stripImages } from "core/llm/images"; +import { stripImages } from "core/util/messageContent"; import { useCallback, useContext, useEffect, useMemo } from "react"; -import { useDispatch, useSelector } from "react-redux"; +import { useSelector } from "react-redux"; import { useNavigate } from "react-router-dom"; +import ContinueInputBox from "../../components/mainInput/ContinueInputBox"; import { NewSessionButton } from "../../components/mainInput/NewSessionButton"; import resolveEditorContent from "../../components/mainInput/resolveInput"; import TipTapEditor from "../../components/mainInput/TipTapEditor"; +import StepContainer from "../../components/StepContainer"; +import AcceptRejectAllButtons from "../../components/StepContainer/AcceptRejectAllButtons"; import { IdeMessengerContext } from "../../context/IdeMessenger"; +import { selectApplyState } from "../../redux/selectors"; import { clearCodeToEdit, setEditDone, submitEdit, } from "../../redux/slices/editModeState"; -import { RootState } from "../../redux/store"; +import { RootState, useAppDispatch } from "../../redux/store"; +import { streamResponseThunk } from "../../redux/thunks/streamResponse"; import CodeToEdit from "./CodeToEdit"; -import useChatHandler from "../../hooks/useChatHandler"; -import AcceptRejectAllButtons from "../../components/StepContainer/AcceptRejectAllButtons"; -import ContinueInputBox from "../../components/mainInput/ContinueInputBox"; -import StepContainer from "../../components/StepContainer"; import getMultifileEditPrompt from "./getMultifileEditPrompt"; -import { selectApplyState } from "../../redux/selectors"; const EDIT_DISALLOWED_CONTEXT_PROVIDERS = [ "codebase", @@ -36,10 +36,9 @@ const EDIT_DISALLOWED_CONTEXT_PROVIDERS = [ ]; export default function Edit() { - const dispatch = useDispatch(); + const dispatch = useAppDispatch(); const navigate = useNavigate(); const ideMessenger = useContext(IdeMessengerContext); - const { streamResponse } = useChatHandler(dispatch, ideMessenger); const editModeState = useSelector((state: RootState) => state.editModeState); const availableContextProviders = useSelector( (store: RootState) => store.state.config.contextProviders, @@ -138,12 +137,12 @@ export default function Edit() { } else { const promptPreamble = getMultifileEditPrompt(editModeState.codeToEdit); - streamResponse( - editorState, - modifiers, - ideMessenger, - undefined, - promptPreamble, + dispatch( + streamResponseThunk({ + editorState, + modifiers, + promptPreamble, + }), ); } } @@ -184,11 +183,12 @@ export default function Edit() { {item.message.role === "user" ? ( { - streamResponse( - editorState, - modifiers, - ideMessenger, - index, + dispatch( + streamResponseThunk({ + editorState, + modifiers, + index, + }), ); }} isLastUserInput={isLastUserInput(index)} diff --git a/gui/src/pages/gui/Chat.tsx b/gui/src/pages/gui/Chat.tsx index adea3ec80f..f297e71da3 100644 --- a/gui/src/pages/gui/Chat.tsx +++ b/gui/src/pages/gui/Chat.tsx @@ -5,10 +5,10 @@ import { ExclamationTriangleIcon, } from "@heroicons/react/24/outline"; import { JSONContent } from "@tiptap/react"; -import { InputModifiers } from "core"; +import { InputModifiers, ToolCallState } from "core"; +import { streamResponse } from "core/llm/stream"; import { usePostHog } from "posthog-js/react"; import { - Fragment, useCallback, useContext, useEffect, @@ -17,7 +17,7 @@ import { useState, } from "react"; import { ErrorBoundary } from "react-error-boundary"; -import { useDispatch, useSelector } from "react-redux"; +import { useSelector } from "react-redux"; import styled from "styled-components"; import { Button, @@ -26,8 +26,9 @@ import { vscBackground, } from "../../components"; import { ChatScrollAnchor } from "../../components/ChatScrollAnchor"; -import StepContainer from "../../components/StepContainer"; +import { useFindWidget } from "../../components/find/FindWidget"; import TimelineItem from "../../components/gui/TimelineItem"; +import ChatIndexingPeeks from "../../components/indexing/ChatIndexingPeeks"; import ContinueInputBox from "../../components/mainInput/ContinueInputBox"; import { NewSessionButton } from "../../components/mainInput/NewSessionButton"; import { TutorialCard } from "../../components/mainInput/TutorialCard"; @@ -35,12 +36,13 @@ import { OnboardingCard, useOnboardingCard, } from "../../components/OnboardingCard"; +import StepContainer from "../../components/StepContainer"; import { IdeMessengerContext } from "../../context/IdeMessenger"; -import useChatHandler from "../../hooks/useChatHandler"; import useHistory from "../../hooks/useHistory"; import { useTutorialCard } from "../../hooks/useTutorialCard"; import { useWebviewListener } from "../../hooks/useWebviewListener"; import { defaultModelSelector } from "../../redux/selectors/modelSelectors"; +import { selectCurrentToolCall } from "../../redux/selectors/selectCurrentToolCall"; import { clearLastEmptyResponse, newSession, @@ -51,7 +53,8 @@ import { setDialogMessage, setShowDialog, } from "../../redux/slices/uiStateSlice"; -import { RootState } from "../../redux/store"; +import { RootState, useAppDispatch } from "../../redux/store"; +import { streamResponseThunk } from "../../redux/thunks/streamResponse"; import { getFontSize, getMetaKeyLabel, @@ -60,8 +63,9 @@ import { import { FREE_TRIAL_LIMIT_REQUESTS } from "../../util/freeTrial"; import { getLocalStorage, setLocalStorage } from "../../util/localStorage"; import ConfigErrorIndicator from "./ConfigError"; -import ChatIndexingPeeks from "../../components/indexing/ChatIndexingPeeks"; -import { useFindWidget } from "../../components/find/FindWidget"; +import { ToolCallDiv } from "./ToolCallDiv"; +import { ToolCallButtons } from "./ToolCallDiv/ToolCallButtonsDiv"; +import ToolOutput from "./ToolCallDiv/ToolOutput"; const StopButton = styled.div` background-color: ${vscBackground}; @@ -95,7 +99,7 @@ const StepsDiv = styled.div` } .thread-message { - margin: 8px 4px 0 4px; + margin: 0px 4px 0 4px; } `; @@ -110,6 +114,7 @@ function fallbackRender({ error, resetErrorBoundary }: any) { >

Something went wrong:

{error.message}
+
{error.stack}
@@ -120,9 +125,8 @@ function fallbackRender({ error, resetErrorBoundary }: any) { export function Chat() { const posthog = usePostHog(); - const dispatch = useDispatch(); + const dispatch = useAppDispatch(); const ideMessenger = useContext(IdeMessengerContext); - const { streamResponse } = useChatHandler(dispatch, ideMessenger); const onboardingCard = useOnboardingCard(); const { showTutorialCard, closeTutorialCard } = useTutorialCard(); const defaultModel = useSelector(defaultModelSelector); @@ -139,6 +143,10 @@ export function Chat() { const { saveSession, getLastSessionId, loadLastSession } = useHistory(dispatch); + const toolCallState = useSelector( + selectCurrentToolCall, + ); + const snapToBottom = useCallback(() => { if (!stepsDivRef.current) return; const elem = stepsDivRef.current; @@ -176,7 +184,7 @@ export function Chat() { isMetaEquivalentKeyPressed(e) && !e.shiftKey ) { - dispatch(setInactive()); + // dispatch(cancelGeneration()); TODO!!! } }; window.addEventListener("keydown", listener); @@ -221,7 +229,7 @@ export function Chat() { } } - streamResponse(editorState, modifiers, ideMessenger); + dispatch(streamResponseThunk({ editorState, modifiers })); // Increment localstorage counter for popup const currentCount = getLocalStorage("mainTextEntryCounter"); @@ -328,7 +336,12 @@ export function Chat() { > {highlights} {history.map((item, index: number) => ( - +
{ @@ -338,13 +351,36 @@ export function Chat() { {item.message.role === "user" ? ( { - streamResponse(editorState, modifiers, ideMessenger, index); + dispatch( + streamResponseThunk({ editorState, modifiers, index }), + ); }} isLastUserInput={isLastUserInput(index)} isMainInput={false} editorState={item.editorState} contextItems={item.contextItems} /> + ) : item.message.role === "tool" ? ( + + ) : item.message.role === "assistant" && + item.message.toolCalls && + item.toolCallState ? ( +
+ {item.message.toolCalls?.map((toolCall, i) => { + return ( +
+ +
+ ); + })} +
) : (
)} - +
))} )}
+ + {toolCallState?.status === "generated" && } + ) : null; +} diff --git a/gui/src/pages/gui/ToolCallDiv/FunctionSpecificToolCallDiv.tsx b/gui/src/pages/gui/ToolCallDiv/FunctionSpecificToolCallDiv.tsx new file mode 100644 index 0000000000..d55055b031 --- /dev/null +++ b/gui/src/pages/gui/ToolCallDiv/FunctionSpecificToolCallDiv.tsx @@ -0,0 +1,31 @@ +import { ToolCall, ToolCallState } from "core"; +import { CreateFile } from "./CreateFile"; +import { RunTerminalCommand } from "./RunTerminalCommand"; + +function FunctionSpecificToolCallDiv({ + toolCall, + toolCallState, +}: { + toolCall: ToolCall; + toolCallState: ToolCallState; +}) { + const args = toolCallState.parsedArgs; + + switch (toolCall.function.name) { + case "create_new_file": + return ( + + ); + case "run_terminal_command": + return ( + + ); + default: + return null; + } +} + +export default FunctionSpecificToolCallDiv; diff --git a/gui/src/pages/gui/ToolCallDiv/RunTerminalCommand.tsx b/gui/src/pages/gui/ToolCallDiv/RunTerminalCommand.tsx new file mode 100644 index 0000000000..35fc0e54f5 --- /dev/null +++ b/gui/src/pages/gui/ToolCallDiv/RunTerminalCommand.tsx @@ -0,0 +1,16 @@ +import { ToolCallState } from "core"; +import StyledMarkdownPreview from "../../../components/markdown/StyledMarkdownPreview"; + +interface RunTerminalCommandToolCallProps { + command: string; + toolCallState: ToolCallState; +} + +export function RunTerminalCommand(props: RunTerminalCommandToolCallProps) { + return ( + + ); +} diff --git a/gui/src/pages/gui/ToolCallDiv/ThreadDiv.tsx b/gui/src/pages/gui/ToolCallDiv/ThreadDiv.tsx new file mode 100644 index 0000000000..e02eaca8f2 --- /dev/null +++ b/gui/src/pages/gui/ToolCallDiv/ThreadDiv.tsx @@ -0,0 +1,78 @@ +import { Tool, ToolCall, ToolCallState } from "core"; +import Mustache from "mustache"; +import { useSelector } from "react-redux"; +import styled from "styled-components"; +import { RootState } from "../../../redux/store"; + +interface ThreadDivProps { + children: React.ReactNode; + icon: React.ReactNode; + toolCall: ToolCall; + toolCallState: ToolCallState; + reactKey: string; +} + +const Container = styled.div` + display: flex; + flex-direction: column; + justify-content: center; + position: relative; + padding: 8px; + padding-bottom: 0; +`; + +const ChildrenDiv = styled.div``; + +const W = 16; + +const HeaderDiv = styled.div` + display: flex; + align-items: center; + gap: 6px; + margin-bottom: 8px; +`; + +export function ThreadDiv(props: ThreadDivProps) { + const availableTools = useSelector( + (state: RootState) => state.state.config.tools, + ); + + function renderWouldLikeToMessage( + tool: Tool | undefined, + toolCallState: ToolCallState, + ): string { + if (!tool) return ""; + + const rendered = Mustache.render( + tool.wouldLikeTo, + toolCallState.parsedArgs, + ); + return rendered; + } + + return ( + + +
+ {props.icon} +
+ Continue wants to{" "} + {renderWouldLikeToMessage( + availableTools.find( + (tool) => props.toolCall.function.name === tool.function.name, + ), + props.toolCallState, + )} +
+ {props.children} +
+ ); +} diff --git a/gui/src/pages/gui/ToolCallDiv/ToolCallButtonsDiv.tsx b/gui/src/pages/gui/ToolCallDiv/ToolCallButtonsDiv.tsx new file mode 100644 index 0000000000..d088091ca8 --- /dev/null +++ b/gui/src/pages/gui/ToolCallDiv/ToolCallButtonsDiv.tsx @@ -0,0 +1,91 @@ +import { useSelector } from "react-redux"; +import styled from "styled-components"; +import { + defaultBorderRadius, + lightGray, + vscButtonBackground, + vscButtonForeground, +} from "../../../components"; +import Spinner from "../../../components/markdown/StepContainerPreToolbar/Spinner"; +import { selectCurrentToolCall } from "../../../redux/selectors/selectCurrentToolCall"; +import { useAppDispatch } from "../../../redux/store"; +import { callTool } from "../../../redux/thunks/callTool"; +import { cancelTool } from "../../../redux/thunks/cancelTool"; + +const ButtonContainer = styled.div` + display: flex; + gap: 8px; + margin-top: 12px; + margin: 8px; +`; + +const Button = styled.button` + padding: 5px; + border-radius: ${defaultBorderRadius}; + flex: 1; + + &:hover { + cursor: pointer; + opacity: 0.8; + } +`; + +const AcceptButton = styled(Button)` + color: ${vscButtonForeground}; + border: none; + background-color: ${vscButtonBackground}; + color: ${vscButtonForeground}; + + &:hover { + cursor: pointer; + } +`; + +const RejectButton = styled(Button)` + color: ${lightGray}; + border: 1px solid ${lightGray}; + background-color: transparent; +`; + +interface ToolCallButtonsProps {} + +export function ToolCallButtons(props: ToolCallButtonsProps) { + const dispatch = useAppDispatch(); + const toolCallState = useSelector(selectCurrentToolCall); + + if (!toolCallState) { + return null; + } + + return ( + <> + + {toolCallState.status === "generating" ? ( +
+ Thinking... +
+ ) : toolCallState.status === "generated" ? ( + <> + dispatch(cancelTool())}> + Cancel + + dispatch(callTool())}> + Continue + + + ) : toolCallState.status === "calling" ? ( +
+ Loading... + +
+ ) : null} +
+ + ); +} diff --git a/gui/src/pages/gui/ToolCallDiv/ToolOutput.tsx b/gui/src/pages/gui/ToolCallDiv/ToolOutput.tsx new file mode 100644 index 0000000000..46057e90c5 --- /dev/null +++ b/gui/src/pages/gui/ToolCallDiv/ToolOutput.tsx @@ -0,0 +1,24 @@ +import { ContextItemWithId } from "core"; +import ContextItemsPeek from "../../../components/mainInput/ContextItemsPeek"; + +interface ToolOutputProps { + contextItems: ContextItemWithId[]; + toolCallId: string; +} + +function ToolOutput(props: ToolOutputProps) { + if (props.contextItems.length === 0) { + return null; + } + + return ( +
+ +
+ ); +} + +export default ToolOutput; diff --git a/gui/src/pages/gui/ToolCallDiv/index.tsx b/gui/src/pages/gui/ToolCallDiv/index.tsx new file mode 100644 index 0000000000..e1745c4596 --- /dev/null +++ b/gui/src/pages/gui/ToolCallDiv/index.tsx @@ -0,0 +1,46 @@ +import { + ArrowRightIcon, + CheckIcon, + XMarkIcon, +} from "@heroicons/react/24/outline"; +import { ToolCall, ToolCallState, ToolStatus } from "core"; +import { vscButtonBackground } from "../../../components"; +import Spinner from "../../../components/markdown/StepContainerPreToolbar/Spinner"; +import FunctionSpecificToolCallDiv from "./FunctionSpecificToolCallDiv"; +import { ThreadDiv } from "./ThreadDiv"; + +interface ToolCallDivProps { + toolCall: ToolCall; + toolCallState: ToolCallState; + reactKey: string; +} + +export function ToolCallDiv(props: ToolCallDivProps) { + function getIcon(state: ToolStatus) { + switch (state) { + case "generating": + case "calling": + return ; + case "generated": + return ; + case "done": + return ; + case "canceled": + return ; + } + } + + return ( + + + + ); +} diff --git a/gui/src/pages/settings.tsx b/gui/src/pages/settings.tsx index 613d5f2b16..6c871962a7 100644 --- a/gui/src/pages/settings.tsx +++ b/gui/src/pages/settings.tsx @@ -172,6 +172,7 @@ function Settings() {

System Message Temperature store.state.history, + (history) => { + return findCurrentToolCall(history); + }, +); diff --git a/gui/src/redux/selectors/selectLastToolCall.ts b/gui/src/redux/selectors/selectLastToolCall.ts new file mode 100644 index 0000000000..5509ab4766 --- /dev/null +++ b/gui/src/redux/selectors/selectLastToolCall.ts @@ -0,0 +1,19 @@ +import { createSelector } from "@reduxjs/toolkit"; +import { ToolCallState } from "core"; +import { RootState } from "../store"; + +export const selectLastToolCall = createSelector( + [(store: RootState) => store.state.history], + (history): ToolCallState | null => { + let lastToolCallHistoryItem = null; + for (let i = history.length - 1; i >= 0; i--) { + const item = history[i]; + if (item.message.role === "assistant" && item.message.toolCalls?.length) { + lastToolCallHistoryItem = item; + break; + } + } + if (!lastToolCallHistoryItem) return null; + return lastToolCallHistoryItem.toolCallState; + }, +); diff --git a/gui/src/redux/slices/stateSlice.ts b/gui/src/redux/slices/stateSlice.ts index 7b32247ff4..6f542a1b26 100644 --- a/gui/src/redux/slices/stateSlice.ts +++ b/gui/src/redux/slices/stateSlice.ts @@ -1,21 +1,31 @@ -import { PayloadAction, createSlice } from "@reduxjs/toolkit"; +import { + ActionReducerMapBuilder, + AsyncThunk, + PayloadAction, + createSlice, +} from "@reduxjs/toolkit"; import { JSONContent } from "@tiptap/react"; import { ApplyState, ChatHistoryItem, ChatMessage, Checkpoint, + ContextItem, ContextItemWithId, - PackageDocsResult, FileSymbolMap, IndexingStatus, - Session, + PackageDocsResult, PromptLog, + Session, + ToolCall, } from "core"; import { BrowserSerializedContinueConfig } from "core/config/load"; import { ConfigValidationError } from "core/config/validation"; -import { stripImages } from "core/llm/images"; +import { incrementalParseJson } from "core/util/incrementalParseJson"; +import { renderChatMessage } from "core/util/messageContent"; import { v4 as uuidv4, v4 } from "uuid"; +import { streamResponseThunk } from "../thunks/streamResponse"; +import { findCurrentToolCall } from "../util"; // We need this to handle reorderings (e.g. a mid-array deletion) of the messages array. // The proper fix is adding a UUID to all chat messages, but this is the temp workaround. @@ -74,6 +84,7 @@ const initialState: State = { ], contextProviders: [], models: [], + tools: [], }, title: "New Session", sessionId: v4(), @@ -252,7 +263,7 @@ export const stateSlice = createSlice({ message: { ...payload.message, id: uuidv4() }, editorState: { type: "doc", - content: stripImages(payload.message.content) + content: renderChatMessage(payload.message) .split("\n") .map((line) => ({ type: "paragraph", @@ -295,10 +306,72 @@ export const stateSlice = createSlice({ state.streamAborter.abort(); state.streamAborter = new AbortController(); }, - streamUpdate: (state, action: PayloadAction) => { + streamUpdate: (state, action: PayloadAction) => { if (state.history.length) { - state.history[state.history.length - 1].message.content += - action.payload; + const lastMessage = state.history[state.history.length - 1]; + + if ( + action.payload.role && + (lastMessage.message.role !== action.payload.role || + // This is when a tool call comes after assistant text + (lastMessage.message.content !== "" && + action.payload.role === "assistant" && + action.payload.toolCalls?.length)) + ) { + // Create a new message + const historyItem: ChatHistoryItemWithMessageId = { + contextItems: [], + message: { id: uuidv4(), ...action.payload }, + }; + + if (action.payload.role === "assistant" && action.payload.toolCalls) { + const [_, parsedArgs] = incrementalParseJson( + action.payload.toolCalls[0].function.arguments, + ); + historyItem.toolCallState = { + status: "generating", + toolCall: action.payload.toolCalls[0] as ToolCall, + toolCallId: action.payload.toolCalls[0].id, + parsedArgs, + }; + } + + state.history.push(historyItem); + } else { + // Add to the existing message + const msg = state.history[state.history.length - 1].message; + if (action.payload.content) { + msg.content += renderChatMessage(action.payload); + } else if ( + action.payload.role === "assistant" && + action.payload.toolCalls && + msg.role === "assistant" + ) { + if (!msg.toolCalls) { + msg.toolCalls = []; + } + action.payload.toolCalls.forEach((toolCall, i) => { + if (msg.toolCalls.length <= i) { + msg.toolCalls.push(toolCall); + } else { + msg.toolCalls[i].function.arguments += + toolCall.function.arguments; + + const [_, parsedArgs] = incrementalParseJson( + msg.toolCalls[i].function.arguments, + ); + + state.history[ + state.history.length - 1 + ].toolCallState.parsedArgs = parsedArgs; + state.history[ + state.history.length - 1 + ].toolCallState.toolCall.function.arguments += + toolCall.function.arguments; + } + }); + } + } } }, newSession: (state, { payload }: PayloadAction) => { @@ -308,6 +381,7 @@ export const stateSlice = createSlice({ state.active = false; state.context.isGathering = false; state.symbols = {}; + if (payload) { state.history = payload.history as any; state.title = payload.title; @@ -405,6 +479,38 @@ export const stateSlice = createSlice({ resetNextCodeBlockToApplyIndex: (state) => { state.nextCodeBlockToApplyIndex = 0; }, + + // Related to currentToolCallState + setToolGenerated: (state) => { + const toolCallState = findCurrentToolCall(state.history); + if (!toolCallState) return; + + toolCallState.status = "generated"; + }, + setToolCallOutput: (state, action: PayloadAction) => { + const toolCallState = findCurrentToolCall(state.history); + if (!toolCallState) return; + + toolCallState.output = action.payload; + }, + cancelToolCall: (state) => { + const toolCallState = findCurrentToolCall(state.history); + if (!toolCallState) return; + + toolCallState.status = "canceled"; + }, + acceptToolCall: (state) => { + const toolCallState = findCurrentToolCall(state.history); + if (!toolCallState) return; + + toolCallState.status = "done"; + }, + setCalling: (state) => { + const toolCallState = findCurrentToolCall(state.history); + if (!toolCallState) return; + + toolCallState.status = "calling"; + }, updateIndexingStatus: ( state, { payload }: PayloadAction, @@ -448,8 +554,23 @@ export const stateSlice = createSlice({ state.docsSuggestions = payload; }, }, + + extraReducers: (builder) => { + addPassthroughCases(builder, [streamResponseThunk]); + }, }); +function addPassthroughCases( + builder: ActionReducerMapBuilder, + thunks: AsyncThunk[], +) { + thunks.forEach((thunk) => { + builder + .addCase(thunk.fulfilled, (state, action) => {}) + .addCase(thunk.rejected, (state, action) => {}) + .addCase(thunk.pending, (state, action) => {}); + }); +} export const { updateFileSymbols, setContextItemsAtIndex, @@ -477,10 +598,15 @@ export const { setCurCheckpointIndex, resetNextCodeBlockToApplyIndex, updateApplyState, + abortStream, updateIndexingStatus, setIndexingChatPeekHidden, - abortStream, + setCalling, + cancelToolCall, + acceptToolCall, + setToolGenerated, updateDocsSuggestions, + setToolCallOutput, } = stateSlice.actions; export default stateSlice.reducer; diff --git a/gui/src/redux/slices/uiStateSlice.ts b/gui/src/redux/slices/uiStateSlice.ts index f952b5fc94..3472eb0184 100644 --- a/gui/src/redux/slices/uiStateSlice.ts +++ b/gui/src/redux/slices/uiStateSlice.ts @@ -1,17 +1,28 @@ import { createSlice, PayloadAction } from "@reduxjs/toolkit"; +import { Tool } from "core"; +import { BuiltInToolNames } from "core/tools/builtIn"; import { defaultOnboardingCardState, OnboardingCardState, } from "../../components/OnboardingCard"; +type ToolSetting = + | "allowedWithPermission" + | "allowedWithoutPermission" + | "disabled"; + type UiState = { showDialog: boolean; dialogMessage: string | JSX.Element | undefined; dialogEntryOn: boolean; onboardingCard: OnboardingCardState; shouldAddFileForEditing: boolean; + useTools: boolean; + toolSettings: { [toolName: string]: ToolSetting }; }; +export const DEFAULT_TOOL_SETTING: ToolSetting = "allowedWithPermission"; + export const uiStateSlice = createSlice({ name: "uiState", initialState: { @@ -20,6 +31,18 @@ export const uiStateSlice = createSlice({ dialogEntryOn: false, onboardingCard: defaultOnboardingCardState, shouldAddFileForEditing: false, + + useTools: false, + toolSettings: { + [BuiltInToolNames.ReadFile]: "allowedWithoutPermission", + [BuiltInToolNames.CreateNewFile]: "allowedWithPermission", + [BuiltInToolNames.RunTerminalCommand]: "allowedWithPermission", + [BuiltInToolNames.ViewSubdirectory]: "allowedWithoutPermission", + [BuiltInToolNames.ViewRepoMap]: "allowedWithoutPermission", + [BuiltInToolNames.ExactSearch]: "allowedWithoutPermission", + [BuiltInToolNames.SearchWeb]: "allowedWithoutPermission", + [BuiltInToolNames.ViewDiff]: "allowedWithoutPermission", + }, } as UiState, reducers: { setOnboardingCard: ( @@ -43,6 +66,33 @@ export const uiStateSlice = createSlice({ setShowDialog: (state, action: PayloadAction) => { state.showDialog = action.payload; }, + // Tools + toggleUseTools: (state) => { + state.useTools = !state.useTools; + }, + addTool: (state, action: PayloadAction) => { + state.toolSettings[action.payload.function.name] = action.payload.readonly + ? "allowedWithoutPermission" + : "allowedWithPermission"; + }, + toggleToolSetting: (state, action: PayloadAction) => { + const setting = state.toolSettings[action.payload]; + + switch (setting) { + case "allowedWithPermission": + state.toolSettings[action.payload] = "allowedWithoutPermission"; + break; + case "allowedWithoutPermission": + state.toolSettings[action.payload] = "disabled"; + break; + case "disabled": + state.toolSettings[action.payload] = "allowedWithPermission"; + break; + default: + state.toolSettings[action.payload] = DEFAULT_TOOL_SETTING; + break; + } + }, }, }); @@ -51,6 +101,9 @@ export const { setDialogMessage, setDialogEntryOn, setShowDialog, + toggleUseTools, + toggleToolSetting, + addTool, } = uiStateSlice.actions; export default uiStateSlice.reducer; diff --git a/gui/src/redux/store.ts b/gui/src/redux/store.ts index 87edf3c663..a594b30d5a 100644 --- a/gui/src/redux/store.ts +++ b/gui/src/redux/store.ts @@ -5,10 +5,12 @@ import serverStateReducer from "./slices/serverStateReducer"; import stateReducer from "./slices/stateSlice"; import uiStateReducer from "./slices/uiStateSlice"; +import { useDispatch } from "react-redux"; import { createTransform, persistReducer, persistStore } from "redux-persist"; import { createFilter } from "redux-persist-transform-filter"; import autoMergeLevel2 from "redux-persist/lib/stateReconciler/autoMergeLevel2"; import storage from "redux-persist/lib/storage"; +import { IdeMessenger, IIdeMessenger } from "../context/IdeMessenger"; import editModeStateReducer from "./slices/editModeState"; export interface ChatMessage { role: "system" | "user" | "assistant"; @@ -24,8 +26,6 @@ const rootReducer = combineReducers({ editModeState: editModeStateReducer, }); -export type RootState = ReturnType; - const windowIDTransform = (windowID: string) => createTransform( // transform state on its way to being serialized and persisted. @@ -40,6 +40,7 @@ const windowIDTransform = (windowID: string) => const saveSubsetFilters = [ createFilter("state", ["history", "sessionId", "defaultModelTitle"]), + // Don't persist any of the edit state for now createFilter("editModeState", []), ]; @@ -57,15 +58,32 @@ const persistedReducer = persistReducer(persistConfig, rootReducer); export function setupStore() { return configureStore({ - reducer: persistedReducer, + // persistedReducer causes type errors with async thunks + reducer: persistedReducer as unknown as typeof rootReducer, // reducer: rootReducer, middleware: (getDefaultMiddleware) => getDefaultMiddleware({ serializableCheck: false, + thunk: { + extraArgument: { + ideMessenger: new IdeMessenger(), + }, + }, }), }); } +export type ThunkApiType = { + state: RootState; + extra: { ideMessenger: IIdeMessenger }; +}; + export const store = setupStore(); +export type RootState = ReturnType; + +export type AppDispatch = typeof store.dispatch; + +export const useAppDispatch = () => useDispatch(); + export const persistor = persistStore(store); diff --git a/gui/src/redux/thunks/callTool.ts b/gui/src/redux/thunks/callTool.ts new file mode 100644 index 0000000000..b022b34973 --- /dev/null +++ b/gui/src/redux/thunks/callTool.ts @@ -0,0 +1,46 @@ +import { createAsyncThunk } from "@reduxjs/toolkit"; +import { selectCurrentToolCall } from "../selectors/selectCurrentToolCall"; +import { + acceptToolCall, + setCalling, + setToolCallOutput, +} from "../slices/stateSlice"; +import { ThunkApiType } from "../store"; +import { streamResponseAfterToolCall } from "./streamResponseAfterToolCall"; + +export const callTool = createAsyncThunk( + "chat/callTool", + async (_, { dispatch, extra, getState }) => { + const state = getState(); + const toolCallState = selectCurrentToolCall(state); + + console.log("calling tool", toolCallState.toolCall); + if (!toolCallState) { + return; + } + + if (toolCallState.status !== "generated") { + return; + } + + dispatch(setCalling()); + + const result = await extra.ideMessenger.request("tools/call", { + toolCall: toolCallState.toolCall, + }); + + if (result.status === "success") { + const contextItems = result.content.contextItems; + dispatch(setToolCallOutput(contextItems)); + dispatch(acceptToolCall()); + + // Send to the LLM to continue the conversation + dispatch( + streamResponseAfterToolCall({ + toolCallId: toolCallState.toolCall.id, + toolOutput: contextItems, + }), + ); + } + }, +); diff --git a/gui/src/redux/thunks/cancelTool.ts b/gui/src/redux/thunks/cancelTool.ts new file mode 100644 index 0000000000..c93c4585c2 --- /dev/null +++ b/gui/src/redux/thunks/cancelTool.ts @@ -0,0 +1,78 @@ +import { createAsyncThunk } from "@reduxjs/toolkit"; +import { selectCurrentToolCall } from "../selectors/selectCurrentToolCall"; +import { + cancelToolCall, + setCalling, + setToolCallOutput, +} from "../slices/stateSlice"; +import { ThunkApiType } from "../store"; +import { streamResponseAfterToolCall } from "./streamResponseAfterToolCall"; + +export const cancelTool = createAsyncThunk( + "chat/callTool", + async (_, { dispatch, extra, getState }) => { + const state = getState(); + const toolCallState = selectCurrentToolCall(state); + + if (!toolCallState) { + return; + } + + if (toolCallState.status !== "generated") { + return; + } + + dispatch(cancelToolCall()); + + dispatch( + streamResponseAfterToolCall({ + toolCallId: toolCallState.toolCallId, + toolOutput: [ + { + name: "Tool Cancelled", + description: "Tool Cancelled", + content: + "The tool call was cancelled by the user. Please try something else or request further instructions.", + hidden: true, + }, + ], + }), + ); + }, +); + +export const callTool = createAsyncThunk( + "chat/callTool", + async (_, { dispatch, extra, getState }) => { + const state = getState(); + const toolCallState = selectCurrentToolCall(state); + + console.log("calling tool", toolCallState.toolCall); + if (!toolCallState) { + return; + } + + if (toolCallState.status !== "generated") { + return; + } + + dispatch(setCalling()); + + const result = await extra.ideMessenger.request("tools/call", { + toolCall: toolCallState.toolCall, + }); + + if (result.status === "success") { + const contextItems = result.content.contextItems; + dispatch(setToolCallOutput(contextItems)); + + // Send to the LLM to continue the conversation + dispatch( + streamResponseAfterToolCall({ + toolCallId: toolCallState.toolCall.id, + toolOutput: contextItems, + }), + ); + } + }, +); diff --git a/gui/src/redux/thunks/gatherContext.ts b/gui/src/redux/thunks/gatherContext.ts new file mode 100644 index 0000000000..409004e08c --- /dev/null +++ b/gui/src/redux/thunks/gatherContext.ts @@ -0,0 +1,119 @@ +import { createAsyncThunk } from "@reduxjs/toolkit"; +import { JSONContent } from "@tiptap/core"; +import { + ContextItemWithId, + InputModifiers, + MessageContent, + RangeInFile, +} from "core"; +import { getBasename, getRelativePath } from "core/util"; +import resolveEditorContent, { + hasSlashCommandOrContextProvider, +} from "../../components/mainInput/resolveInput"; +import { updateFileSymbolsFromContextItems } from "../../util/symbols"; +import { defaultModelSelector } from "../selectors/modelSelectors"; +import { setIsGatheringContext } from "../slices/stateSlice"; +import { ThunkApiType } from "../store"; + +export const gatherContext = createAsyncThunk< + { + selectedContextItems: ContextItemWithId[]; + selectedCode: RangeInFile[]; + content: MessageContent; + }, + { + editorState: JSONContent; + modifiers: InputModifiers; + promptPreamble?: string; + }, + ThunkApiType +>( + "chat/gatherContext", + async ( + { modifiers, editorState, promptPreamble }, + { dispatch, extra, getState }, + ) => { + const state = getState(); + const defaultModel = defaultModelSelector(state); + const defaultContextProviders = + state.state.config.experimental?.defaultContext ?? []; + + // Resolve context providers and construct new history + const shouldGatherContext = + modifiers.useCodebase || hasSlashCommandOrContextProvider(editorState); + + if (shouldGatherContext) { + dispatch( + setIsGatheringContext({ + isGathering: true, + gatheringMessage: "Gathering Context", + }), + ); + } + + let [selectedContextItems, selectedCode, content] = + await resolveEditorContent( + editorState, + modifiers, + extra.ideMessenger, + defaultContextProviders, + dispatch, + ); + + // Automatically use currently open file + if (!modifiers.noContext) { + const usingFreeTrial = defaultModel?.provider === "free-trial"; + + const currentFile = await extra.ideMessenger.ide.getCurrentFile(); + if (currentFile) { + let currentFileContents = currentFile.contents; + if (usingFreeTrial) { + currentFileContents = currentFile.contents + .split("\n") + .slice(0, 1000) + .join("\n"); + } + if ( + !selectedContextItems.find( + (item) => item.uri?.value === currentFile.path, + ) + ) { + // don't add the file if it's already in the context items + selectedContextItems.unshift({ + content: `The following file is currently open. Don't reference it if it's not relevant to the user's message.\n\n\`\`\`${getRelativePath( + currentFile.path, + await extra.ideMessenger.ide.getWorkspaceDirs(), + )}\n${currentFileContents}\n\`\`\``, + name: `Active file: ${getBasename(currentFile.path)}`, + description: currentFile.path, + id: { + itemId: currentFile.path, + providerTitle: "file", + }, + uri: { + type: "file", + value: currentFile.path, + }, + }); + } + } + } + + await updateFileSymbolsFromContextItems( + selectedContextItems, + extra.ideMessenger, + dispatch, + ); + + if (promptPreamble) { + if (typeof content === "string") { + content = promptPreamble + content; + } else { + content[0].text = promptPreamble + content[0].text; + } + } + + // dispatch(addContextItems(contextItems)); + return { selectedContextItems, selectedCode, content }; + }, +); diff --git a/gui/src/redux/thunks/handleErrors.ts b/gui/src/redux/thunks/handleErrors.ts new file mode 100644 index 0000000000..6d7afb6327 --- /dev/null +++ b/gui/src/redux/thunks/handleErrors.ts @@ -0,0 +1,18 @@ +import { createAsyncThunk } from "@reduxjs/toolkit"; +import { setInactive } from "../slices/stateSlice"; +import { ThunkApiType } from "../store"; + +export const handleErrors = createAsyncThunk< + void, + () => Promise, + ThunkApiType +>("chat/handleErrors", async (runStream, { dispatch, extra }) => { + try { + await runStream(); + } catch (e: any) { + console.debug("Error streaming response: ", e); + } finally { + dispatch(setInactive()); + // triggerSave(!save); TODO + } +}); diff --git a/gui/src/redux/thunks/resetStateForNewMessage.ts b/gui/src/redux/thunks/resetStateForNewMessage.ts new file mode 100644 index 0000000000..e7f534e42d --- /dev/null +++ b/gui/src/redux/thunks/resetStateForNewMessage.ts @@ -0,0 +1,11 @@ +import { createAsyncThunk } from "@reduxjs/toolkit"; +import { resetNextCodeBlockToApplyIndex } from "../slices/stateSlice"; +import { ThunkApiType } from "../store"; + +export const resetStateForNewMessage = createAsyncThunk< + void, + undefined, + ThunkApiType +>("chat/resetStateForNewMessage", async (_, { dispatch }) => { + dispatch(resetNextCodeBlockToApplyIndex()); +}); diff --git a/gui/src/redux/thunks/streamNormalInput.ts b/gui/src/redux/thunks/streamNormalInput.ts new file mode 100644 index 0000000000..d17f2d8705 --- /dev/null +++ b/gui/src/redux/thunks/streamNormalInput.ts @@ -0,0 +1,91 @@ +import { createAsyncThunk } from "@reduxjs/toolkit"; +import { ChatMessage, PromptLog } from "core"; +import { defaultModelSelector } from "../selectors/modelSelectors"; +import { selectCurrentToolCall } from "../selectors/selectCurrentToolCall"; +import { + abortStream, + addPromptCompletionPair, + clearLastEmptyResponse, + setToolGenerated, + streamUpdate, +} from "../slices/stateSlice"; +import { ThunkApiType } from "../store"; +import { callTool } from "./callTool"; + +export const streamNormalInput = createAsyncThunk< + void, + ChatMessage[], + ThunkApiType +>("chat/streamNormalInput", async (messages, { dispatch, extra, getState }) => { + try { + // Gather state + const state = getState(); + const defaultModel = defaultModelSelector(state); + const toolSettings = state.uiState.toolSettings; + const streamAborter = state.state.streamAborter; + const useTools = state.uiState.useTools; + + if (!defaultModel) { + throw new Error("Default model not defined"); + } + + // Send request + const gen = extra.ideMessenger.llmStreamChat( + defaultModel.title, + streamAborter.signal, + messages, + { + tools: useTools + ? Object.keys(toolSettings) + .filter((tool) => toolSettings[tool] !== "disabled") + .map((toolName) => + state.state.config.tools.find( + (tool) => tool.function.name === toolName, + ), + ) + .filter(Boolean) + : undefined, + }, + ); + + // Stream response + let next = await gen.next(); + while (!next.done) { + if (!getState().state.active) { + dispatch(abortStream()); + break; + } + + const update = next.value as ChatMessage; + dispatch(streamUpdate(update)); + next = await gen.next(); + + // There has been lag when streaming tool calls. This is a temporary solution + if (update.role === "assistant" && update.toolCalls) { + await new Promise((resolve) => setTimeout(resolve, 10)); + } + } + + // Attach prompt log + let returnVal = next.value as PromptLog; + if (returnVal) { + dispatch(addPromptCompletionPair([returnVal])); + } + + // If it's a tool call that is automatically accepted, we should call it + const toolCallState = selectCurrentToolCall(getState()); + if (toolCallState) { + dispatch(setToolGenerated()); + + if ( + toolSettings[toolCallState.toolCall.function.name] === + "allowedWithoutPermission" + ) { + await dispatch(callTool()); + } + } + } catch (e) { + // If there's an error, we should clear the response so there aren't two input boxes + dispatch(clearLastEmptyResponse()); + } +}); diff --git a/gui/src/redux/thunks/streamResponse.ts b/gui/src/redux/thunks/streamResponse.ts new file mode 100644 index 0000000000..c1f6ba4d83 --- /dev/null +++ b/gui/src/redux/thunks/streamResponse.ts @@ -0,0 +1,157 @@ +import { createAsyncThunk, unwrapResult } from "@reduxjs/toolkit"; +import { JSONContent } from "@tiptap/core"; +import { + ChatHistoryItem, + ChatMessage, + InputModifiers, + MessageContent, + SlashCommandDescription, +} from "core"; +import { constructMessages } from "core/llm/constructMessages"; +import { renderChatMessage } from "core/util/messageContent"; +import posthog from "posthog-js"; +import { defaultModelSelector } from "../selectors/modelSelectors"; +import { + initNewActiveMessage, + resubmitAtIndex, + setCurCheckpointIndex, + setMessageAtIndex, +} from "../slices/stateSlice"; +import { ThunkApiType } from "../store"; +import { gatherContext } from "./gatherContext"; +import { handleErrors } from "./handleErrors"; +import { resetStateForNewMessage } from "./resetStateForNewMessage"; +import { streamNormalInput } from "./streamNormalInput"; +import { streamSlashCommand } from "./streamSlashCommand"; + +const getSlashCommandForInput = ( + input: MessageContent, + slashCommands: SlashCommandDescription[], +): [SlashCommandDescription, string] | undefined => { + let slashCommand: SlashCommandDescription | undefined; + let slashCommandName: string | undefined; + + let lastText = + typeof input === "string" + ? input + : input.filter((part) => part.type === "text").slice(-1)[0]?.text || ""; + + if (lastText.startsWith("/")) { + slashCommandName = lastText.split(" ")[0].substring(1); + slashCommand = slashCommands.find( + (command) => command.name === slashCommandName, + ); + } + if (!slashCommand || !slashCommandName) { + return undefined; + } + + // Convert to actual slash command object with runnable function + return [slashCommand, renderChatMessage({ role: "user", content: input })]; +}; + +export const streamResponseThunk = createAsyncThunk< + void, + { + editorState: JSONContent; + modifiers: InputModifiers; + index?: number; + promptPreamble?: string; + }, + ThunkApiType +>( + "chat/streamResponse", + async ( + { editorState, modifiers, index, promptPreamble }, + { dispatch, extra, getState }, + ) => { + await dispatch( + handleErrors(async () => { + const state = getState(); + const defaultModel = defaultModelSelector(state); + const slashCommands = state.state.config.slashCommands || []; + const inputIndex = index ?? state.state.history.length; + + if (typeof index === "number") { + dispatch(resubmitAtIndex({ index, editorState })); + } else { + dispatch(initNewActiveMessage({ editorState })); + } + + resetStateForNewMessage(); + + if (index) { + dispatch(setCurCheckpointIndex(Math.floor(index / 2))); + } + + const result = await dispatch( + gatherContext({ + editorState, + modifiers, + promptPreamble, + }), + ); + const unwrapped = unwrapResult(result); + const { selectedContextItems, selectedCode, content } = unwrapped; + + // Add the message to the history + const message: ChatMessage = { + role: "user", + content, + }; + const historyItem: ChatHistoryItem = { + message, + contextItems: selectedContextItems, + editorState, + }; + + dispatch( + setMessageAtIndex({ + message, + index: inputIndex, + contextItems: selectedContextItems, + }), + ); + + // Construct messages from updated history + const updatedHistory = getState().state.history; + const messages = constructMessages(updatedHistory, defaultModel.model); + + posthog.capture("step run", { + step_name: "User Input", + params: {}, + }); + posthog.capture("userInput", {}); + + // Determine if the input is a slash command + let commandAndInput = getSlashCommandForInput(content, slashCommands); + + if (!commandAndInput) { + await dispatch(streamNormalInput(messages)); + } else { + const [slashCommand, commandInput] = commandAndInput; + let updatedContextItems = []; + posthog.capture("step run", { + step_name: slashCommand.name, + params: {}, + }); + + // if (slashCommand.name === "multifile-edit") { + // dispatch(setIsInMultifileEdit(true)); + // } + + await dispatch( + streamSlashCommand({ + messages, + slashCommand, + input: commandInput, + historyIndex: inputIndex, + selectedCode, + contextItems: updatedContextItems, + }), + ); + } + }), + ); + }, +); diff --git a/gui/src/redux/thunks/streamResponseAfterToolCall.ts b/gui/src/redux/thunks/streamResponseAfterToolCall.ts new file mode 100644 index 0000000000..6652f3a0ed --- /dev/null +++ b/gui/src/redux/thunks/streamResponseAfterToolCall.ts @@ -0,0 +1,67 @@ +import { createAsyncThunk } from "@reduxjs/toolkit"; +import { ChatMessage, ContextItem } from "core"; +import { constructMessages } from "core/llm/constructMessages"; +import { renderContextItems } from "core/util/messageContent"; +import { defaultModelSelector } from "../selectors/modelSelectors"; +import { + addContextItemsAtIndex, + setActive, + streamUpdate, +} from "../slices/stateSlice"; +import { ThunkApiType } from "../store"; +import { handleErrors } from "./handleErrors"; +import { resetStateForNewMessage } from "./resetStateForNewMessage"; +import { streamNormalInput } from "./streamNormalInput"; + +export const streamResponseAfterToolCall = createAsyncThunk< + void, + { + toolCallId: string; + toolOutput: ContextItem[]; + }, + ThunkApiType +>( + "chat/streamAfterToolCall", + async ({ toolCallId, toolOutput }, { dispatch, getState, extra }) => { + await dispatch( + handleErrors(async () => { + const state = getState(); + const initialHistory = state.state.history; + const defaultModel = defaultModelSelector(state); + + resetStateForNewMessage(); + + await new Promise((resolve) => setTimeout(resolve, 0)); + + const newMessage: ChatMessage = { + role: "tool", + content: renderContextItems(toolOutput), + toolCallId, + }; + + dispatch(streamUpdate(newMessage)); + dispatch( + addContextItemsAtIndex({ + index: initialHistory.length, + contextItems: toolOutput.map((contextItem) => ({ + ...contextItem, + id: { + providerTitle: "toolCall", + itemId: toolCallId, + }, + })), + }), + ); + + dispatch(setActive()); + + const updatedHistory = getState().state.history; + const messages = constructMessages( + [...updatedHistory], + defaultModel.model, + ); + await dispatch(streamNormalInput(messages)); + }), + ); + }, +); diff --git a/gui/src/redux/thunks/streamSlashCommand.ts b/gui/src/redux/thunks/streamSlashCommand.ts new file mode 100644 index 0000000000..4b6c717910 --- /dev/null +++ b/gui/src/redux/thunks/streamSlashCommand.ts @@ -0,0 +1,71 @@ +import { createAsyncThunk } from "@reduxjs/toolkit"; +import { + ChatMessage, + ContextItemWithId, + RangeInFile, + SlashCommandDescription, +} from "core"; +import { defaultModelSelector } from "../selectors/modelSelectors"; +import { abortStream, streamUpdate } from "../slices/stateSlice"; +import { ThunkApiType } from "../store"; + +export const streamSlashCommand = createAsyncThunk< + void, + { + messages: ChatMessage[]; + slashCommand: SlashCommandDescription; + input: string; + historyIndex: number; + selectedCode: RangeInFile[]; + contextItems: ContextItemWithId[]; + }, + ThunkApiType +>( + "chat/streamSlashCommand", + async ( + { messages, selectedCode, slashCommand, input, historyIndex, contextItems }, + { dispatch, getState, extra }, + ) => { + const state = getState(); + const defaultModel = defaultModelSelector(state); + const active = state.state.active; + const streamAborter = state.state.streamAborter; + + if (!defaultModel) { + throw new Error("Default model not defined"); + } + + const modelTitle = defaultModel.title; + + const checkActiveInterval = setInterval(() => { + if (!active) { + dispatch(abortStream()); + clearInterval(checkActiveInterval); + } + }, 100); + + for await (const update of extra.ideMessenger.streamRequest( + "command/run", + { + input, + history: messages, + modelTitle, + slashCommandName: slashCommand.name, + contextItems, + params: slashCommand.params, + historyIndex, + selectedCode, + }, + streamAborter.signal, + )) { + if (!getState().state.active) { + dispatch(abortStream()); + break; + } + if (typeof update === "string") { + dispatch(streamUpdate(update)); + } + } + clearInterval(checkActiveInterval); + }, +); diff --git a/gui/src/redux/util/index.ts b/gui/src/redux/util/index.ts new file mode 100644 index 0000000000..06ceebb4ab --- /dev/null +++ b/gui/src/redux/util/index.ts @@ -0,0 +1,8 @@ +import { ToolCallState } from "core"; +import { RootState } from "../store"; + +export function findCurrentToolCall( + state: RootState["state"]["history"], +): ToolCallState | undefined { + return state[state.length - 1]?.toolCallState; +} diff --git a/manual-testing-sandbox/test.ts b/manual-testing-sandbox/test.ts new file mode 100644 index 0000000000..9cf524421e --- /dev/null +++ b/manual-testing-sandbox/test.ts @@ -0,0 +1,41 @@ +class Calculator { + private result: number; + + constructor() { + this.result = 0; + } + + add(number: number): Calculator { + this.result += number; + return this; + } + + subtract(number: number): Calculator { + this.result -= number; + return this; + } + + multiply(number: number): Calculator { + this.result *= number; + return this; + } + + divide(number: number): Calculator { + if (number === 0) { + throw new Error("Cannot divide by zero"); + } + this.result /= number; + return this; + } + + getResult(): number { + return this.result; + } + + reset(): Calculator { + this.result = 0; + return this; + } +} + +export default Calculator; \ No newline at end of file diff --git a/packages/llm-info/package.json b/packages/llm-info/package.json index 4ba97f6b8a..af83039381 100644 --- a/packages/llm-info/package.json +++ b/packages/llm-info/package.json @@ -1,6 +1,6 @@ { "name": "@continuedev/llm-info", - "version": "1.0.6", + "version": "1.0.7", "description": "", "main": "dist/index.js", "types": "dist/index.d.ts", diff --git a/packages/llm-info/src/index.ts b/packages/llm-info/src/index.ts index 84d88e031e..fe71f419dd 100644 --- a/packages/llm-info/src/index.ts +++ b/packages/llm-info/src/index.ts @@ -22,7 +22,7 @@ export const allModelProviders: ModelProvider[] = [ Vllm, Bedrock, Cohere, - xAI + xAI, ]; export const allLlms: LlmInfo[] = allModelProviders.flatMap((provider) =>