@@ -204,6 +207,7 @@ function ModelCard(props: ModelCardProps) {
`
`}
`;
-const IconBase = styled.div<{ hovered: boolean }>`
+const IconBase = styled.div<{ $hovered: boolean }>`
width: 1.2em;
height: 1.2em;
cursor: pointer;
padding: 4px;
border-radius: ${defaultBorderRadius};
- opacity: ${(props) => (props.hovered ? 0.75 : 0)};
- visibility: ${(props) => (props.hovered ? "visible" : "hidden")};
+ opacity: ${(props) => (props.$hovered ? 0.75 : 0)};
+ visibility: ${(props) => (props.$hovered ? "visible" : "hidden")};
&:hover {
opacity: 1;
@@ -204,9 +204,9 @@ function ModelOption({
-
+
{showDelete && (
-
+
)}
diff --git a/gui/src/context/IdeMessenger.ts b/gui/src/context/IdeMessenger.ts
index c5c4f9b3e4..bddd3fbbee 100644
--- a/gui/src/context/IdeMessenger.ts
+++ b/gui/src/context/IdeMessenger.ts
@@ -156,7 +156,7 @@ export class IdeMessenger implements IIdeMessenger {
this.post(messageType, data, messageId);
- let buffer = "";
+ const buffer: any[] = [];
let index = 0;
let done = false;
let returnVal = undefined;
@@ -169,7 +169,7 @@ export class IdeMessenger implements IIdeMessenger {
done = true;
returnVal = responseData;
} else {
- buffer += responseData.content;
+ buffer.push(responseData.content);
}
}
};
@@ -181,16 +181,16 @@ export class IdeMessenger implements IIdeMessenger {
while (!done) {
if (buffer.length > index) {
- const chunk = buffer.slice(index);
- index = buffer.length;
+ const chunk = buffer[index];
+ index++;
yield chunk;
}
await new Promise((resolve) => setTimeout(resolve, 50));
}
- if (buffer.length > index) {
- const chunk = buffer.slice(index);
- index = buffer.length;
+ while (buffer.length > index) {
+ const chunk = buffer[index];
+ index++;
yield chunk;
}
@@ -215,7 +215,7 @@ export class IdeMessenger implements IIdeMessenger {
let next = await gen.next();
while (!next.done) {
- yield { role: "user", content: next.value };
+ yield next.value;
next = await gen.next();
}
diff --git a/gui/src/hooks/useChatHandler.ts b/gui/src/hooks/useChatHandler.ts
deleted file mode 100644
index 06e60102e7..0000000000
--- a/gui/src/hooks/useChatHandler.ts
+++ /dev/null
@@ -1,341 +0,0 @@
-import { Dispatch } from "@reduxjs/toolkit";
-import { JSONContent } from "@tiptap/react";
-import {
- ChatHistoryItem,
- ChatMessage,
- ContextItemWithId,
- InputModifiers,
- MessageContent,
- PromptLog,
- RangeInFile,
- SlashCommandDescription,
-} from "core";
-import { constructMessages } from "core/llm/constructMessages";
-import { stripImages } from "core/llm/images";
-import { getBasename, getRelativePath } from "core/util";
-import { usePostHog } from "posthog-js/react";
-import { useEffect, useMemo, useRef, useState } from "react";
-import { useSelector } from "react-redux";
-import resolveEditorContent, {
- hasSlashCommandOrContextProvider,
-} from "../components/mainInput/resolveInput";
-import { IIdeMessenger } from "../context/IdeMessenger";
-import { defaultModelSelector } from "../redux/selectors/modelSelectors";
-import {
- abortStream,
- addPromptCompletionPair,
- clearLastEmptyResponse,
- initNewActiveMessage,
- resubmitAtIndex,
- setCurCheckpointIndex,
- setInactive,
- setIsGatheringContext,
- setMessageAtIndex,
- streamUpdate,
-} from "../redux/slices/stateSlice";
-import { resetNextCodeBlockToApplyIndex } from "../redux/slices/stateSlice";
-import { RootState } from "../redux/store";
-import useHistory from "./useHistory";
-import { updateFileSymbolsFromContextItems } from "../util/symbols";
-import {
- selectDefaultContextProviders,
- selectSlashCommands,
-} from "../redux/selectors";
-
-function useChatHandler(dispatch: Dispatch, ideMessenger: IIdeMessenger) {
- const posthog = usePostHog();
-
- const defaultModel = useSelector(defaultModelSelector);
- const defaultContextProviders = useSelector(selectDefaultContextProviders);
-
- const slashCommands = useSelector(selectSlashCommands);
-
- const history = useSelector((store: RootState) => store.state.history);
- const active = useSelector((store: RootState) => store.state.active);
- const streamAborter = useSelector(
- (store: RootState) => store.state.streamAborter,
- );
- const activeRef = useRef(active);
-
- const { saveSession } = useHistory(dispatch);
- const [save, triggerSave] = useState(false);
-
- useEffect(() => {
- saveSession(false);
- }, [save]);
-
- useEffect(() => {
- activeRef.current = active;
- }, [active]);
-
- async function _streamNormalInput(messages: ChatMessage[]) {
- if (!defaultModel) {
- throw new Error("Default model not defined");
- }
- const gen = ideMessenger.llmStreamChat(
- defaultModel.title,
- streamAborter.signal,
- messages,
- );
- let next = await gen.next();
- while (!next.done) {
- if (!activeRef.current) {
- dispatch(abortStream());
- break;
- }
- dispatch(streamUpdate(stripImages((next.value as ChatMessage).content)));
- next = await gen.next();
- }
-
- let returnVal = next.value as PromptLog;
- if (returnVal) {
- dispatch(addPromptCompletionPair([returnVal]));
- }
- }
-
- const getSlashCommandForInput = (
- input: MessageContent,
- ): [SlashCommandDescription, string] | undefined => {
- let slashCommand: SlashCommandDescription | undefined;
- let slashCommandName: string | undefined;
-
- let lastText =
- typeof input === "string"
- ? input
- : input.filter((part) => part.type === "text").slice(-1)[0]?.text || "";
-
- if (lastText.startsWith("/")) {
- slashCommandName = lastText.split(" ")[0].substring(1);
- slashCommand = slashCommands.find(
- (command) => command.name === slashCommandName,
- );
- }
- if (!slashCommand || !slashCommandName) {
- return undefined;
- }
-
- // Convert to actual slash command object with runnable function
- return [slashCommand, stripImages(input)];
- };
-
- async function _streamSlashCommand(
- messages: ChatMessage[],
- slashCommand: SlashCommandDescription,
- input: string,
- historyIndex: number,
- selectedCode: RangeInFile[],
- contextItems: ContextItemWithId[],
- ) {
- if (!defaultModel) {
- throw new Error("Default model not defined");
- }
-
- const modelTitle = defaultModel.title;
-
- const checkActiveInterval = setInterval(() => {
- if (!activeRef.current) {
- dispatch(abortStream());
- clearInterval(checkActiveInterval);
- }
- }, 100);
-
- try {
- for await (const update of ideMessenger.streamRequest(
- "command/run",
- {
- input,
- history: messages,
- modelTitle,
- slashCommandName: slashCommand.name,
- contextItems,
- params: slashCommand.params,
- historyIndex,
- selectedCode,
- },
- streamAborter.signal,
- )) {
- if (!activeRef.current) {
- dispatch(abortStream());
- clearInterval(checkActiveInterval);
- break;
- }
- if (typeof update === "string") {
- dispatch(streamUpdate(update));
- }
- }
- } catch (e) {
- throw e;
- } finally {
- clearInterval(checkActiveInterval);
- }
- }
-
- async function streamResponse(
- editorState: JSONContent,
- modifiers: InputModifiers,
- ideMessenger: IIdeMessenger,
- index?: number,
- promptPreamble?: string,
- ) {
- try {
- if (typeof index === "number") {
- dispatch(resubmitAtIndex({ index, editorState }));
- } else {
- dispatch(initNewActiveMessage({ editorState }));
- }
-
- // Reset current code block index
- dispatch(resetNextCodeBlockToApplyIndex());
-
- if (index) {
- dispatch(setCurCheckpointIndex(Math.floor(index / 2)));
- }
-
- const shouldGatherContext =
- modifiers.useCodebase || hasSlashCommandOrContextProvider(editorState);
-
- if (shouldGatherContext) {
- dispatch(
- setIsGatheringContext({
- isGathering: true,
- gatheringMessage: "Gathering Context",
- }),
- );
- }
-
- // Resolve context providers and construct new history
- let [selectedContextItems, selectedCode, content] =
- await resolveEditorContent(
- editorState,
- modifiers,
- ideMessenger,
- defaultContextProviders,
- dispatch,
- );
-
- // Automatically use currently open file
- if (!modifiers.noContext) {
- const usingFreeTrial = defaultModel?.provider === "free-trial";
-
- const currentFile = await ideMessenger.ide.getCurrentFile();
- if (currentFile) {
- let currentFileContents = currentFile.contents;
- if (usingFreeTrial) {
- currentFileContents = currentFile.contents
- .split("\n")
- .slice(0, 1000)
- .join("\n");
- }
- if (
- !selectedContextItems.find(
- (item) => item.uri?.value === currentFile.path,
- )
- ) {
- // don't add the file if it's already in the context items
- selectedContextItems.unshift({
- content: `The following file is currently open. Don't reference it if it's not relevant to the user's message.\n\n\`\`\`${getRelativePath(
- currentFile.path,
- await ideMessenger.ide.getWorkspaceDirs(),
- )}\n${currentFileContents}\n\`\`\``,
- name: `Active file: ${getBasename(currentFile.path)}`,
- description: currentFile.path,
- id: {
- itemId: currentFile.path,
- providerTitle: "file",
- },
- uri: {
- type: "file",
- value: currentFile.path,
- },
- });
- }
- }
- }
-
- await updateFileSymbolsFromContextItems(
- selectedContextItems,
- ideMessenger,
- dispatch,
- );
- if (promptPreamble) {
- if (typeof content === "string") {
- content = promptPreamble + content;
- } else {
- content[0].text = promptPreamble + content[0].text;
- }
- }
-
- const message: ChatMessage = {
- role: "user",
- content,
- };
-
- const historyItem: ChatHistoryItem = {
- message,
- contextItems: selectedContextItems,
- editorState,
- };
-
- let newHistory: ChatHistoryItem[] = [
- ...history.slice(0, index),
- historyItem,
- ];
- const historyIndex = index || newHistory.length - 1;
- dispatch(
- setMessageAtIndex({
- message,
- index: historyIndex,
- contextItems: selectedContextItems,
- }),
- );
-
- // TODO: hacky way to allow rerender
- await new Promise((resolve) => setTimeout(resolve, 0));
-
- posthog.capture("step run", {
- step_name: "User Input",
- params: {},
- });
- posthog.capture("userInput", {});
-
- const messages = constructMessages(newHistory, defaultModel.model);
-
- // Determine if the input is a slash command
- let commandAndInput = getSlashCommandForInput(content);
-
- if (!commandAndInput) {
- await _streamNormalInput(messages);
- } else {
- const [slashCommand, commandInput] = commandAndInput;
-
- let updatedContextItems = [];
-
- posthog.capture("step run", {
- step_name: slashCommand.name,
- params: {},
- });
-
- await _streamSlashCommand(
- messages,
- slashCommand,
- commandInput,
- historyIndex,
- selectedCode,
- updatedContextItems,
- );
- }
- } catch (e: any) {
- dispatch(clearLastEmptyResponse());
- console.debug("Error streaming response: ", e);
- } finally {
- dispatch(setInactive());
- triggerSave(!save);
- }
- }
-
- return {
- streamResponse,
- };
-}
-
-export default useChatHandler;
diff --git a/gui/src/hooks/useHistory.tsx b/gui/src/hooks/useHistory.tsx
index e477dfba9d..32bee26e04 100644
--- a/gui/src/hooks/useHistory.tsx
+++ b/gui/src/hooks/useHistory.tsx
@@ -1,7 +1,7 @@
import { Dispatch } from "@reduxjs/toolkit";
import { Session, SessionMetadata } from "core";
-import { stripImages } from "core/llm/images";
+import { renderChatMessage } from "core/util/messageContent";
import { useCallback, useContext } from "react";
import { useSelector } from "react-redux";
import { IdeMessengerContext } from "../context/IdeMessenger";
@@ -77,7 +77,7 @@ function useHistory(dispatch: Dispatch) {
let newTitle =
currentTitle === "New Session"
? truncateText(
- stripImages(history[0].message.content)
+ renderChatMessage(history[0].message)
.split("\n")
.filter((l) => l.trim() !== "")
.slice(-1)[0] || "",
diff --git a/gui/src/hooks/useSetup.ts b/gui/src/hooks/useSetup.ts
index e1c1e103e6..c62f847ea5 100644
--- a/gui/src/hooks/useSetup.ts
+++ b/gui/src/hooks/useSetup.ts
@@ -1,4 +1,3 @@
-import { Dispatch } from "@reduxjs/toolkit";
import { useCallback, useContext, useEffect, useRef } from "react";
import { useSelector } from "react-redux";
import { VSC_THEME_COLOR_VARS } from "../components";
@@ -14,15 +13,15 @@ import {
updateDocsSuggestions,
updateIndexingStatus,
} from "../redux/slices/stateSlice";
-import { RootState } from "../redux/store";
+import { AppDispatch, RootState } from "../redux/store";
+import { streamResponseThunk } from "../redux/thunks/streamResponse";
import { isJetBrains } from "../util";
import { setLocalStorage } from "../util/localStorage";
-import useChatHandler from "./useChatHandler";
-import { useWebviewListener } from "./useWebviewListener";
import { updateFileSymbolsFromContextItems } from "../util/symbols";
+import { useWebviewListener } from "./useWebviewListener";
-function useSetup(dispatch: Dispatch) {
+function useSetup(dispatch: AppDispatch) {
const ideMessenger = useContext(IdeMessengerContext);
const history = useSelector((store: RootState) => store.state.history);
@@ -131,8 +130,6 @@ function useSetup(dispatch: Dispatch) {
dispatch(updateDocsSuggestions(data));
});
- const { streamResponse } = useChatHandler(dispatch, ideMessenger);
-
// IDE event listeners
useWebviewListener(
"getWebviewHistoryLength",
@@ -163,10 +160,11 @@ function useSetup(dispatch: Dispatch) {
// TODO - remove?
useWebviewListener("submitMessage", async (data) => {
- streamResponse(
- data.message,
- { useCodebase: false, noContext: true },
- ideMessenger,
+ dispatch(
+ streamResponseThunk({
+ editorState: data.message,
+ modifiers: { useCodebase: false, noContext: true },
+ }),
);
});
diff --git a/gui/src/pages/edit/Edit.tsx b/gui/src/pages/edit/Edit.tsx
index 201a4bc4cf..d1c8f343fa 100644
--- a/gui/src/pages/edit/Edit.tsx
+++ b/gui/src/pages/edit/Edit.tsx
@@ -1,27 +1,27 @@
import { ArrowLeftIcon } from "@heroicons/react/24/outline";
import { Editor, JSONContent } from "@tiptap/core";
import { InputModifiers, RangeInFileWithContents } from "core";
-import { stripImages } from "core/llm/images";
+import { stripImages } from "core/util/messageContent";
import { useCallback, useContext, useEffect, useMemo } from "react";
-import { useDispatch, useSelector } from "react-redux";
+import { useSelector } from "react-redux";
import { useNavigate } from "react-router-dom";
+import ContinueInputBox from "../../components/mainInput/ContinueInputBox";
import { NewSessionButton } from "../../components/mainInput/NewSessionButton";
import resolveEditorContent from "../../components/mainInput/resolveInput";
import TipTapEditor from "../../components/mainInput/TipTapEditor";
+import StepContainer from "../../components/StepContainer";
+import AcceptRejectAllButtons from "../../components/StepContainer/AcceptRejectAllButtons";
import { IdeMessengerContext } from "../../context/IdeMessenger";
+import { selectApplyState } from "../../redux/selectors";
import {
clearCodeToEdit,
setEditDone,
submitEdit,
} from "../../redux/slices/editModeState";
-import { RootState } from "../../redux/store";
+import { RootState, useAppDispatch } from "../../redux/store";
+import { streamResponseThunk } from "../../redux/thunks/streamResponse";
import CodeToEdit from "./CodeToEdit";
-import useChatHandler from "../../hooks/useChatHandler";
-import AcceptRejectAllButtons from "../../components/StepContainer/AcceptRejectAllButtons";
-import ContinueInputBox from "../../components/mainInput/ContinueInputBox";
-import StepContainer from "../../components/StepContainer";
import getMultifileEditPrompt from "./getMultifileEditPrompt";
-import { selectApplyState } from "../../redux/selectors";
const EDIT_DISALLOWED_CONTEXT_PROVIDERS = [
"codebase",
@@ -36,10 +36,9 @@ const EDIT_DISALLOWED_CONTEXT_PROVIDERS = [
];
export default function Edit() {
- const dispatch = useDispatch();
+ const dispatch = useAppDispatch();
const navigate = useNavigate();
const ideMessenger = useContext(IdeMessengerContext);
- const { streamResponse } = useChatHandler(dispatch, ideMessenger);
const editModeState = useSelector((state: RootState) => state.editModeState);
const availableContextProviders = useSelector(
(store: RootState) => store.state.config.contextProviders,
@@ -138,12 +137,12 @@ export default function Edit() {
} else {
const promptPreamble = getMultifileEditPrompt(editModeState.codeToEdit);
- streamResponse(
- editorState,
- modifiers,
- ideMessenger,
- undefined,
- promptPreamble,
+ dispatch(
+ streamResponseThunk({
+ editorState,
+ modifiers,
+ promptPreamble,
+ }),
);
}
}
@@ -184,11 +183,12 @@ export default function Edit() {
{item.message.role === "user" ? (
{
- streamResponse(
- editorState,
- modifiers,
- ideMessenger,
- index,
+ dispatch(
+ streamResponseThunk({
+ editorState,
+ modifiers,
+ index,
+ }),
);
}}
isLastUserInput={isLastUserInput(index)}
diff --git a/gui/src/pages/gui/Chat.tsx b/gui/src/pages/gui/Chat.tsx
index adea3ec80f..f297e71da3 100644
--- a/gui/src/pages/gui/Chat.tsx
+++ b/gui/src/pages/gui/Chat.tsx
@@ -5,10 +5,10 @@ import {
ExclamationTriangleIcon,
} from "@heroicons/react/24/outline";
import { JSONContent } from "@tiptap/react";
-import { InputModifiers } from "core";
+import { InputModifiers, ToolCallState } from "core";
+import { streamResponse } from "core/llm/stream";
import { usePostHog } from "posthog-js/react";
import {
- Fragment,
useCallback,
useContext,
useEffect,
@@ -17,7 +17,7 @@ import {
useState,
} from "react";
import { ErrorBoundary } from "react-error-boundary";
-import { useDispatch, useSelector } from "react-redux";
+import { useSelector } from "react-redux";
import styled from "styled-components";
import {
Button,
@@ -26,8 +26,9 @@ import {
vscBackground,
} from "../../components";
import { ChatScrollAnchor } from "../../components/ChatScrollAnchor";
-import StepContainer from "../../components/StepContainer";
+import { useFindWidget } from "../../components/find/FindWidget";
import TimelineItem from "../../components/gui/TimelineItem";
+import ChatIndexingPeeks from "../../components/indexing/ChatIndexingPeeks";
import ContinueInputBox from "../../components/mainInput/ContinueInputBox";
import { NewSessionButton } from "../../components/mainInput/NewSessionButton";
import { TutorialCard } from "../../components/mainInput/TutorialCard";
@@ -35,12 +36,13 @@ import {
OnboardingCard,
useOnboardingCard,
} from "../../components/OnboardingCard";
+import StepContainer from "../../components/StepContainer";
import { IdeMessengerContext } from "../../context/IdeMessenger";
-import useChatHandler from "../../hooks/useChatHandler";
import useHistory from "../../hooks/useHistory";
import { useTutorialCard } from "../../hooks/useTutorialCard";
import { useWebviewListener } from "../../hooks/useWebviewListener";
import { defaultModelSelector } from "../../redux/selectors/modelSelectors";
+import { selectCurrentToolCall } from "../../redux/selectors/selectCurrentToolCall";
import {
clearLastEmptyResponse,
newSession,
@@ -51,7 +53,8 @@ import {
setDialogMessage,
setShowDialog,
} from "../../redux/slices/uiStateSlice";
-import { RootState } from "../../redux/store";
+import { RootState, useAppDispatch } from "../../redux/store";
+import { streamResponseThunk } from "../../redux/thunks/streamResponse";
import {
getFontSize,
getMetaKeyLabel,
@@ -60,8 +63,9 @@ import {
import { FREE_TRIAL_LIMIT_REQUESTS } from "../../util/freeTrial";
import { getLocalStorage, setLocalStorage } from "../../util/localStorage";
import ConfigErrorIndicator from "./ConfigError";
-import ChatIndexingPeeks from "../../components/indexing/ChatIndexingPeeks";
-import { useFindWidget } from "../../components/find/FindWidget";
+import { ToolCallDiv } from "./ToolCallDiv";
+import { ToolCallButtons } from "./ToolCallDiv/ToolCallButtonsDiv";
+import ToolOutput from "./ToolCallDiv/ToolOutput";
const StopButton = styled.div`
background-color: ${vscBackground};
@@ -95,7 +99,7 @@ const StepsDiv = styled.div`
}
.thread-message {
- margin: 8px 4px 0 4px;
+ margin: 0px 4px 0 4px;
}
`;
@@ -110,6 +114,7 @@ function fallbackRender({ error, resetErrorBoundary }: any) {
>
Something went wrong:
{error.message}
+ {error.stack}
@@ -120,9 +125,8 @@ function fallbackRender({ error, resetErrorBoundary }: any) {
export function Chat() {
const posthog = usePostHog();
- const dispatch = useDispatch();
+ const dispatch = useAppDispatch();
const ideMessenger = useContext(IdeMessengerContext);
- const { streamResponse } = useChatHandler(dispatch, ideMessenger);
const onboardingCard = useOnboardingCard();
const { showTutorialCard, closeTutorialCard } = useTutorialCard();
const defaultModel = useSelector(defaultModelSelector);
@@ -139,6 +143,10 @@ export function Chat() {
const { saveSession, getLastSessionId, loadLastSession } =
useHistory(dispatch);
+ const toolCallState = useSelector
(
+ selectCurrentToolCall,
+ );
+
const snapToBottom = useCallback(() => {
if (!stepsDivRef.current) return;
const elem = stepsDivRef.current;
@@ -176,7 +184,7 @@ export function Chat() {
isMetaEquivalentKeyPressed(e) &&
!e.shiftKey
) {
- dispatch(setInactive());
+ // dispatch(cancelGeneration()); TODO!!!
}
};
window.addEventListener("keydown", listener);
@@ -221,7 +229,7 @@ export function Chat() {
}
}
- streamResponse(editorState, modifiers, ideMessenger);
+ dispatch(streamResponseThunk({ editorState, modifiers }));
// Increment localstorage counter for popup
const currentCount = getLocalStorage("mainTextEntryCounter");
@@ -328,7 +336,12 @@ export function Chat() {
>
{highlights}
{history.map((item, index: number) => (
-
+
{
@@ -338,13 +351,36 @@ export function Chat() {
{item.message.role === "user" ? (
{
- streamResponse(editorState, modifiers, ideMessenger, index);
+ dispatch(
+ streamResponseThunk({ editorState, modifiers, index }),
+ );
}}
isLastUserInput={isLastUserInput(index)}
isMainInput={false}
editorState={item.editorState}
contextItems={item.contextItems}
/>
+ ) : item.message.role === "tool" ? (
+
+ ) : item.message.role === "assistant" &&
+ item.message.toolCalls &&
+ item.toolCallState ? (
+
+ {item.message.toolCalls?.map((toolCall, i) => {
+ return (
+
+
+
+ );
+ })}
+
) : (
)}
-
+
))}
)}
+
+ {toolCallState?.status === "generated" && }
+ ) : null;
+}
diff --git a/gui/src/pages/gui/ToolCallDiv/FunctionSpecificToolCallDiv.tsx b/gui/src/pages/gui/ToolCallDiv/FunctionSpecificToolCallDiv.tsx
new file mode 100644
index 0000000000..d55055b031
--- /dev/null
+++ b/gui/src/pages/gui/ToolCallDiv/FunctionSpecificToolCallDiv.tsx
@@ -0,0 +1,31 @@
+import { ToolCall, ToolCallState } from "core";
+import { CreateFile } from "./CreateFile";
+import { RunTerminalCommand } from "./RunTerminalCommand";
+
+function FunctionSpecificToolCallDiv({
+ toolCall,
+ toolCallState,
+}: {
+ toolCall: ToolCall;
+ toolCallState: ToolCallState;
+}) {
+ const args = toolCallState.parsedArgs;
+
+ switch (toolCall.function.name) {
+ case "create_new_file":
+ return (
+
+ );
+ case "run_terminal_command":
+ return (
+
+ );
+ default:
+ return null;
+ }
+}
+
+export default FunctionSpecificToolCallDiv;
diff --git a/gui/src/pages/gui/ToolCallDiv/RunTerminalCommand.tsx b/gui/src/pages/gui/ToolCallDiv/RunTerminalCommand.tsx
new file mode 100644
index 0000000000..35fc0e54f5
--- /dev/null
+++ b/gui/src/pages/gui/ToolCallDiv/RunTerminalCommand.tsx
@@ -0,0 +1,16 @@
+import { ToolCallState } from "core";
+import StyledMarkdownPreview from "../../../components/markdown/StyledMarkdownPreview";
+
+interface RunTerminalCommandToolCallProps {
+ command: string;
+ toolCallState: ToolCallState;
+}
+
+export function RunTerminalCommand(props: RunTerminalCommandToolCallProps) {
+ return (
+
+ );
+}
diff --git a/gui/src/pages/gui/ToolCallDiv/ThreadDiv.tsx b/gui/src/pages/gui/ToolCallDiv/ThreadDiv.tsx
new file mode 100644
index 0000000000..e02eaca8f2
--- /dev/null
+++ b/gui/src/pages/gui/ToolCallDiv/ThreadDiv.tsx
@@ -0,0 +1,78 @@
+import { Tool, ToolCall, ToolCallState } from "core";
+import Mustache from "mustache";
+import { useSelector } from "react-redux";
+import styled from "styled-components";
+import { RootState } from "../../../redux/store";
+
+interface ThreadDivProps {
+ children: React.ReactNode;
+ icon: React.ReactNode;
+ toolCall: ToolCall;
+ toolCallState: ToolCallState;
+ reactKey: string;
+}
+
+const Container = styled.div`
+ display: flex;
+ flex-direction: column;
+ justify-content: center;
+ position: relative;
+ padding: 8px;
+ padding-bottom: 0;
+`;
+
+const ChildrenDiv = styled.div``;
+
+const W = 16;
+
+const HeaderDiv = styled.div`
+ display: flex;
+ align-items: center;
+ gap: 6px;
+ margin-bottom: 8px;
+`;
+
+export function ThreadDiv(props: ThreadDivProps) {
+ const availableTools = useSelector(
+ (state: RootState) => state.state.config.tools,
+ );
+
+ function renderWouldLikeToMessage(
+ tool: Tool | undefined,
+ toolCallState: ToolCallState,
+ ): string {
+ if (!tool) return "";
+
+ const rendered = Mustache.render(
+ tool.wouldLikeTo,
+ toolCallState.parsedArgs,
+ );
+ return rendered;
+ }
+
+ return (
+
+
+
+ {props.icon}
+
+ Continue wants to{" "}
+ {renderWouldLikeToMessage(
+ availableTools.find(
+ (tool) => props.toolCall.function.name === tool.function.name,
+ ),
+ props.toolCallState,
+ )}
+
+ {props.children}
+
+ );
+}
diff --git a/gui/src/pages/gui/ToolCallDiv/ToolCallButtonsDiv.tsx b/gui/src/pages/gui/ToolCallDiv/ToolCallButtonsDiv.tsx
new file mode 100644
index 0000000000..d088091ca8
--- /dev/null
+++ b/gui/src/pages/gui/ToolCallDiv/ToolCallButtonsDiv.tsx
@@ -0,0 +1,91 @@
+import { useSelector } from "react-redux";
+import styled from "styled-components";
+import {
+ defaultBorderRadius,
+ lightGray,
+ vscButtonBackground,
+ vscButtonForeground,
+} from "../../../components";
+import Spinner from "../../../components/markdown/StepContainerPreToolbar/Spinner";
+import { selectCurrentToolCall } from "../../../redux/selectors/selectCurrentToolCall";
+import { useAppDispatch } from "../../../redux/store";
+import { callTool } from "../../../redux/thunks/callTool";
+import { cancelTool } from "../../../redux/thunks/cancelTool";
+
+const ButtonContainer = styled.div`
+ display: flex;
+ gap: 8px;
+ margin-top: 12px;
+ margin: 8px;
+`;
+
+const Button = styled.button`
+ padding: 5px;
+ border-radius: ${defaultBorderRadius};
+ flex: 1;
+
+ &:hover {
+ cursor: pointer;
+ opacity: 0.8;
+ }
+`;
+
+const AcceptButton = styled(Button)`
+ color: ${vscButtonForeground};
+ border: none;
+ background-color: ${vscButtonBackground};
+ color: ${vscButtonForeground};
+
+ &:hover {
+ cursor: pointer;
+ }
+`;
+
+const RejectButton = styled(Button)`
+ color: ${lightGray};
+ border: 1px solid ${lightGray};
+ background-color: transparent;
+`;
+
+interface ToolCallButtonsProps {}
+
+export function ToolCallButtons(props: ToolCallButtonsProps) {
+ const dispatch = useAppDispatch();
+ const toolCallState = useSelector(selectCurrentToolCall);
+
+ if (!toolCallState) {
+ return null;
+ }
+
+ return (
+ <>
+
+ {toolCallState.status === "generating" ? (
+
+ Thinking...
+
+ ) : toolCallState.status === "generated" ? (
+ <>
+ dispatch(cancelTool())}>
+ Cancel
+
+ dispatch(callTool())}>
+ Continue
+
+ >
+ ) : toolCallState.status === "calling" ? (
+
+ Loading...
+
+
+ ) : null}
+
+ >
+ );
+}
diff --git a/gui/src/pages/gui/ToolCallDiv/ToolOutput.tsx b/gui/src/pages/gui/ToolCallDiv/ToolOutput.tsx
new file mode 100644
index 0000000000..46057e90c5
--- /dev/null
+++ b/gui/src/pages/gui/ToolCallDiv/ToolOutput.tsx
@@ -0,0 +1,24 @@
+import { ContextItemWithId } from "core";
+import ContextItemsPeek from "../../../components/mainInput/ContextItemsPeek";
+
+interface ToolOutputProps {
+ contextItems: ContextItemWithId[];
+ toolCallId: string;
+}
+
+function ToolOutput(props: ToolOutputProps) {
+ if (props.contextItems.length === 0) {
+ return null;
+ }
+
+ return (
+
+
+
+ );
+}
+
+export default ToolOutput;
diff --git a/gui/src/pages/gui/ToolCallDiv/index.tsx b/gui/src/pages/gui/ToolCallDiv/index.tsx
new file mode 100644
index 0000000000..e1745c4596
--- /dev/null
+++ b/gui/src/pages/gui/ToolCallDiv/index.tsx
@@ -0,0 +1,46 @@
+import {
+ ArrowRightIcon,
+ CheckIcon,
+ XMarkIcon,
+} from "@heroicons/react/24/outline";
+import { ToolCall, ToolCallState, ToolStatus } from "core";
+import { vscButtonBackground } from "../../../components";
+import Spinner from "../../../components/markdown/StepContainerPreToolbar/Spinner";
+import FunctionSpecificToolCallDiv from "./FunctionSpecificToolCallDiv";
+import { ThreadDiv } from "./ThreadDiv";
+
+interface ToolCallDivProps {
+ toolCall: ToolCall;
+ toolCallState: ToolCallState;
+ reactKey: string;
+}
+
+export function ToolCallDiv(props: ToolCallDivProps) {
+ function getIcon(state: ToolStatus) {
+ switch (state) {
+ case "generating":
+ case "calling":
+ return ;
+ case "generated":
+ return ;
+ case "done":
+ return ;
+ case "canceled":
+ return ;
+ }
+ }
+
+ return (
+
+
+
+ );
+}
diff --git a/gui/src/pages/settings.tsx b/gui/src/pages/settings.tsx
index 613d5f2b16..6c871962a7 100644
--- a/gui/src/pages/settings.tsx
+++ b/gui/src/pages/settings.tsx
@@ -172,6 +172,7 @@ function Settings() {
System Message
Temperature
store.state.history,
+ (history) => {
+ return findCurrentToolCall(history);
+ },
+);
diff --git a/gui/src/redux/selectors/selectLastToolCall.ts b/gui/src/redux/selectors/selectLastToolCall.ts
new file mode 100644
index 0000000000..5509ab4766
--- /dev/null
+++ b/gui/src/redux/selectors/selectLastToolCall.ts
@@ -0,0 +1,19 @@
+import { createSelector } from "@reduxjs/toolkit";
+import { ToolCallState } from "core";
+import { RootState } from "../store";
+
+export const selectLastToolCall = createSelector(
+ [(store: RootState) => store.state.history],
+ (history): ToolCallState | null => {
+ let lastToolCallHistoryItem = null;
+ for (let i = history.length - 1; i >= 0; i--) {
+ const item = history[i];
+ if (item.message.role === "assistant" && item.message.toolCalls?.length) {
+ lastToolCallHistoryItem = item;
+ break;
+ }
+ }
+ if (!lastToolCallHistoryItem) return null;
+ return lastToolCallHistoryItem.toolCallState;
+ },
+);
diff --git a/gui/src/redux/slices/stateSlice.ts b/gui/src/redux/slices/stateSlice.ts
index 7b32247ff4..6f542a1b26 100644
--- a/gui/src/redux/slices/stateSlice.ts
+++ b/gui/src/redux/slices/stateSlice.ts
@@ -1,21 +1,31 @@
-import { PayloadAction, createSlice } from "@reduxjs/toolkit";
+import {
+ ActionReducerMapBuilder,
+ AsyncThunk,
+ PayloadAction,
+ createSlice,
+} from "@reduxjs/toolkit";
import { JSONContent } from "@tiptap/react";
import {
ApplyState,
ChatHistoryItem,
ChatMessage,
Checkpoint,
+ ContextItem,
ContextItemWithId,
- PackageDocsResult,
FileSymbolMap,
IndexingStatus,
- Session,
+ PackageDocsResult,
PromptLog,
+ Session,
+ ToolCall,
} from "core";
import { BrowserSerializedContinueConfig } from "core/config/load";
import { ConfigValidationError } from "core/config/validation";
-import { stripImages } from "core/llm/images";
+import { incrementalParseJson } from "core/util/incrementalParseJson";
+import { renderChatMessage } from "core/util/messageContent";
import { v4 as uuidv4, v4 } from "uuid";
+import { streamResponseThunk } from "../thunks/streamResponse";
+import { findCurrentToolCall } from "../util";
// We need this to handle reorderings (e.g. a mid-array deletion) of the messages array.
// The proper fix is adding a UUID to all chat messages, but this is the temp workaround.
@@ -74,6 +84,7 @@ const initialState: State = {
],
contextProviders: [],
models: [],
+ tools: [],
},
title: "New Session",
sessionId: v4(),
@@ -252,7 +263,7 @@ export const stateSlice = createSlice({
message: { ...payload.message, id: uuidv4() },
editorState: {
type: "doc",
- content: stripImages(payload.message.content)
+ content: renderChatMessage(payload.message)
.split("\n")
.map((line) => ({
type: "paragraph",
@@ -295,10 +306,72 @@ export const stateSlice = createSlice({
state.streamAborter.abort();
state.streamAborter = new AbortController();
},
- streamUpdate: (state, action: PayloadAction) => {
+ streamUpdate: (state, action: PayloadAction) => {
if (state.history.length) {
- state.history[state.history.length - 1].message.content +=
- action.payload;
+ const lastMessage = state.history[state.history.length - 1];
+
+ if (
+ action.payload.role &&
+ (lastMessage.message.role !== action.payload.role ||
+ // This is when a tool call comes after assistant text
+ (lastMessage.message.content !== "" &&
+ action.payload.role === "assistant" &&
+ action.payload.toolCalls?.length))
+ ) {
+ // Create a new message
+ const historyItem: ChatHistoryItemWithMessageId = {
+ contextItems: [],
+ message: { id: uuidv4(), ...action.payload },
+ };
+
+ if (action.payload.role === "assistant" && action.payload.toolCalls) {
+ const [_, parsedArgs] = incrementalParseJson(
+ action.payload.toolCalls[0].function.arguments,
+ );
+ historyItem.toolCallState = {
+ status: "generating",
+ toolCall: action.payload.toolCalls[0] as ToolCall,
+ toolCallId: action.payload.toolCalls[0].id,
+ parsedArgs,
+ };
+ }
+
+ state.history.push(historyItem);
+ } else {
+ // Add to the existing message
+ const msg = state.history[state.history.length - 1].message;
+ if (action.payload.content) {
+ msg.content += renderChatMessage(action.payload);
+ } else if (
+ action.payload.role === "assistant" &&
+ action.payload.toolCalls &&
+ msg.role === "assistant"
+ ) {
+ if (!msg.toolCalls) {
+ msg.toolCalls = [];
+ }
+ action.payload.toolCalls.forEach((toolCall, i) => {
+ if (msg.toolCalls.length <= i) {
+ msg.toolCalls.push(toolCall);
+ } else {
+ msg.toolCalls[i].function.arguments +=
+ toolCall.function.arguments;
+
+ const [_, parsedArgs] = incrementalParseJson(
+ msg.toolCalls[i].function.arguments,
+ );
+
+ state.history[
+ state.history.length - 1
+ ].toolCallState.parsedArgs = parsedArgs;
+ state.history[
+ state.history.length - 1
+ ].toolCallState.toolCall.function.arguments +=
+ toolCall.function.arguments;
+ }
+ });
+ }
+ }
}
},
newSession: (state, { payload }: PayloadAction) => {
@@ -308,6 +381,7 @@ export const stateSlice = createSlice({
state.active = false;
state.context.isGathering = false;
state.symbols = {};
+
if (payload) {
state.history = payload.history as any;
state.title = payload.title;
@@ -405,6 +479,38 @@ export const stateSlice = createSlice({
resetNextCodeBlockToApplyIndex: (state) => {
state.nextCodeBlockToApplyIndex = 0;
},
+
+ // Related to currentToolCallState
+ setToolGenerated: (state) => {
+ const toolCallState = findCurrentToolCall(state.history);
+ if (!toolCallState) return;
+
+ toolCallState.status = "generated";
+ },
+ setToolCallOutput: (state, action: PayloadAction) => {
+ const toolCallState = findCurrentToolCall(state.history);
+ if (!toolCallState) return;
+
+ toolCallState.output = action.payload;
+ },
+ cancelToolCall: (state) => {
+ const toolCallState = findCurrentToolCall(state.history);
+ if (!toolCallState) return;
+
+ toolCallState.status = "canceled";
+ },
+ acceptToolCall: (state) => {
+ const toolCallState = findCurrentToolCall(state.history);
+ if (!toolCallState) return;
+
+ toolCallState.status = "done";
+ },
+ setCalling: (state) => {
+ const toolCallState = findCurrentToolCall(state.history);
+ if (!toolCallState) return;
+
+ toolCallState.status = "calling";
+ },
updateIndexingStatus: (
state,
{ payload }: PayloadAction,
@@ -448,8 +554,23 @@ export const stateSlice = createSlice({
state.docsSuggestions = payload;
},
},
+
+ extraReducers: (builder) => {
+ addPassthroughCases(builder, [streamResponseThunk]);
+ },
});
+function addPassthroughCases(
+ builder: ActionReducerMapBuilder,
+ thunks: AsyncThunk[],
+) {
+ thunks.forEach((thunk) => {
+ builder
+ .addCase(thunk.fulfilled, (state, action) => {})
+ .addCase(thunk.rejected, (state, action) => {})
+ .addCase(thunk.pending, (state, action) => {});
+ });
+}
export const {
updateFileSymbols,
setContextItemsAtIndex,
@@ -477,10 +598,15 @@ export const {
setCurCheckpointIndex,
resetNextCodeBlockToApplyIndex,
updateApplyState,
+ abortStream,
updateIndexingStatus,
setIndexingChatPeekHidden,
- abortStream,
+ setCalling,
+ cancelToolCall,
+ acceptToolCall,
+ setToolGenerated,
updateDocsSuggestions,
+ setToolCallOutput,
} = stateSlice.actions;
export default stateSlice.reducer;
diff --git a/gui/src/redux/slices/uiStateSlice.ts b/gui/src/redux/slices/uiStateSlice.ts
index f952b5fc94..3472eb0184 100644
--- a/gui/src/redux/slices/uiStateSlice.ts
+++ b/gui/src/redux/slices/uiStateSlice.ts
@@ -1,17 +1,28 @@
import { createSlice, PayloadAction } from "@reduxjs/toolkit";
+import { Tool } from "core";
+import { BuiltInToolNames } from "core/tools/builtIn";
import {
defaultOnboardingCardState,
OnboardingCardState,
} from "../../components/OnboardingCard";
+type ToolSetting =
+ | "allowedWithPermission"
+ | "allowedWithoutPermission"
+ | "disabled";
+
type UiState = {
showDialog: boolean;
dialogMessage: string | JSX.Element | undefined;
dialogEntryOn: boolean;
onboardingCard: OnboardingCardState;
shouldAddFileForEditing: boolean;
+ useTools: boolean;
+ toolSettings: { [toolName: string]: ToolSetting };
};
+export const DEFAULT_TOOL_SETTING: ToolSetting = "allowedWithPermission";
+
export const uiStateSlice = createSlice({
name: "uiState",
initialState: {
@@ -20,6 +31,18 @@ export const uiStateSlice = createSlice({
dialogEntryOn: false,
onboardingCard: defaultOnboardingCardState,
shouldAddFileForEditing: false,
+
+ useTools: false,
+ toolSettings: {
+ [BuiltInToolNames.ReadFile]: "allowedWithoutPermission",
+ [BuiltInToolNames.CreateNewFile]: "allowedWithPermission",
+ [BuiltInToolNames.RunTerminalCommand]: "allowedWithPermission",
+ [BuiltInToolNames.ViewSubdirectory]: "allowedWithoutPermission",
+ [BuiltInToolNames.ViewRepoMap]: "allowedWithoutPermission",
+ [BuiltInToolNames.ExactSearch]: "allowedWithoutPermission",
+ [BuiltInToolNames.SearchWeb]: "allowedWithoutPermission",
+ [BuiltInToolNames.ViewDiff]: "allowedWithoutPermission",
+ },
} as UiState,
reducers: {
setOnboardingCard: (
@@ -43,6 +66,33 @@ export const uiStateSlice = createSlice({
setShowDialog: (state, action: PayloadAction) => {
state.showDialog = action.payload;
},
+ // Tools
+ toggleUseTools: (state) => {
+ state.useTools = !state.useTools;
+ },
+ addTool: (state, action: PayloadAction) => {
+ state.toolSettings[action.payload.function.name] = action.payload.readonly
+ ? "allowedWithoutPermission"
+ : "allowedWithPermission";
+ },
+ toggleToolSetting: (state, action: PayloadAction) => {
+ const setting = state.toolSettings[action.payload];
+
+ switch (setting) {
+ case "allowedWithPermission":
+ state.toolSettings[action.payload] = "allowedWithoutPermission";
+ break;
+ case "allowedWithoutPermission":
+ state.toolSettings[action.payload] = "disabled";
+ break;
+ case "disabled":
+ state.toolSettings[action.payload] = "allowedWithPermission";
+ break;
+ default:
+ state.toolSettings[action.payload] = DEFAULT_TOOL_SETTING;
+ break;
+ }
+ },
},
});
@@ -51,6 +101,9 @@ export const {
setDialogMessage,
setDialogEntryOn,
setShowDialog,
+ toggleUseTools,
+ toggleToolSetting,
+ addTool,
} = uiStateSlice.actions;
export default uiStateSlice.reducer;
diff --git a/gui/src/redux/store.ts b/gui/src/redux/store.ts
index 87edf3c663..a594b30d5a 100644
--- a/gui/src/redux/store.ts
+++ b/gui/src/redux/store.ts
@@ -5,10 +5,12 @@ import serverStateReducer from "./slices/serverStateReducer";
import stateReducer from "./slices/stateSlice";
import uiStateReducer from "./slices/uiStateSlice";
+import { useDispatch } from "react-redux";
import { createTransform, persistReducer, persistStore } from "redux-persist";
import { createFilter } from "redux-persist-transform-filter";
import autoMergeLevel2 from "redux-persist/lib/stateReconciler/autoMergeLevel2";
import storage from "redux-persist/lib/storage";
+import { IdeMessenger, IIdeMessenger } from "../context/IdeMessenger";
import editModeStateReducer from "./slices/editModeState";
export interface ChatMessage {
role: "system" | "user" | "assistant";
@@ -24,8 +26,6 @@ const rootReducer = combineReducers({
editModeState: editModeStateReducer,
});
-export type RootState = ReturnType;
-
const windowIDTransform = (windowID: string) =>
createTransform(
// transform state on its way to being serialized and persisted.
@@ -40,6 +40,7 @@ const windowIDTransform = (windowID: string) =>
const saveSubsetFilters = [
createFilter("state", ["history", "sessionId", "defaultModelTitle"]),
+ // Don't persist any of the edit state for now
createFilter("editModeState", []),
];
@@ -57,15 +58,32 @@ const persistedReducer = persistReducer(persistConfig, rootReducer);
export function setupStore() {
return configureStore({
- reducer: persistedReducer,
+ // persistedReducer causes type errors with async thunks
+ reducer: persistedReducer as unknown as typeof rootReducer,
// reducer: rootReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
serializableCheck: false,
+ thunk: {
+ extraArgument: {
+ ideMessenger: new IdeMessenger(),
+ },
+ },
}),
});
}
+export type ThunkApiType = {
+ state: RootState;
+ extra: { ideMessenger: IIdeMessenger };
+};
+
export const store = setupStore();
+export type RootState = ReturnType;
+
+export type AppDispatch = typeof store.dispatch;
+
+export const useAppDispatch = () => useDispatch();
+
export const persistor = persistStore(store);
diff --git a/gui/src/redux/thunks/callTool.ts b/gui/src/redux/thunks/callTool.ts
new file mode 100644
index 0000000000..b022b34973
--- /dev/null
+++ b/gui/src/redux/thunks/callTool.ts
@@ -0,0 +1,46 @@
+import { createAsyncThunk } from "@reduxjs/toolkit";
+import { selectCurrentToolCall } from "../selectors/selectCurrentToolCall";
+import {
+ acceptToolCall,
+ setCalling,
+ setToolCallOutput,
+} from "../slices/stateSlice";
+import { ThunkApiType } from "../store";
+import { streamResponseAfterToolCall } from "./streamResponseAfterToolCall";
+
+export const callTool = createAsyncThunk(
+ "chat/callTool",
+ async (_, { dispatch, extra, getState }) => {
+ const state = getState();
+ const toolCallState = selectCurrentToolCall(state);
+
+ console.log("calling tool", toolCallState.toolCall);
+ if (!toolCallState) {
+ return;
+ }
+
+ if (toolCallState.status !== "generated") {
+ return;
+ }
+
+ dispatch(setCalling());
+
+ const result = await extra.ideMessenger.request("tools/call", {
+ toolCall: toolCallState.toolCall,
+ });
+
+ if (result.status === "success") {
+ const contextItems = result.content.contextItems;
+ dispatch(setToolCallOutput(contextItems));
+ dispatch(acceptToolCall());
+
+ // Send to the LLM to continue the conversation
+ dispatch(
+ streamResponseAfterToolCall({
+ toolCallId: toolCallState.toolCall.id,
+ toolOutput: contextItems,
+ }),
+ );
+ }
+ },
+);
diff --git a/gui/src/redux/thunks/cancelTool.ts b/gui/src/redux/thunks/cancelTool.ts
new file mode 100644
index 0000000000..c93c4585c2
--- /dev/null
+++ b/gui/src/redux/thunks/cancelTool.ts
@@ -0,0 +1,78 @@
+import { createAsyncThunk } from "@reduxjs/toolkit";
+import { selectCurrentToolCall } from "../selectors/selectCurrentToolCall";
+import {
+ cancelToolCall,
+ setCalling,
+ setToolCallOutput,
+} from "../slices/stateSlice";
+import { ThunkApiType } from "../store";
+import { streamResponseAfterToolCall } from "./streamResponseAfterToolCall";
+
+export const cancelTool = createAsyncThunk(
+ "chat/callTool",
+ async (_, { dispatch, extra, getState }) => {
+ const state = getState();
+ const toolCallState = selectCurrentToolCall(state);
+
+ if (!toolCallState) {
+ return;
+ }
+
+ if (toolCallState.status !== "generated") {
+ return;
+ }
+
+ dispatch(cancelToolCall());
+
+ dispatch(
+ streamResponseAfterToolCall({
+ toolCallId: toolCallState.toolCallId,
+ toolOutput: [
+ {
+ name: "Tool Cancelled",
+ description: "Tool Cancelled",
+ content:
+ "The tool call was cancelled by the user. Please try something else or request further instructions.",
+ hidden: true,
+ },
+ ],
+ }),
+ );
+ },
+);
+
+export const callTool = createAsyncThunk(
+ "chat/callTool",
+ async (_, { dispatch, extra, getState }) => {
+ const state = getState();
+ const toolCallState = selectCurrentToolCall(state);
+
+ console.log("calling tool", toolCallState.toolCall);
+ if (!toolCallState) {
+ return;
+ }
+
+ if (toolCallState.status !== "generated") {
+ return;
+ }
+
+ dispatch(setCalling());
+
+ const result = await extra.ideMessenger.request("tools/call", {
+ toolCall: toolCallState.toolCall,
+ });
+
+ if (result.status === "success") {
+ const contextItems = result.content.contextItems;
+ dispatch(setToolCallOutput(contextItems));
+
+ // Send to the LLM to continue the conversation
+ dispatch(
+ streamResponseAfterToolCall({
+ toolCallId: toolCallState.toolCall.id,
+ toolOutput: contextItems,
+ }),
+ );
+ }
+ },
+);
diff --git a/gui/src/redux/thunks/gatherContext.ts b/gui/src/redux/thunks/gatherContext.ts
new file mode 100644
index 0000000000..409004e08c
--- /dev/null
+++ b/gui/src/redux/thunks/gatherContext.ts
@@ -0,0 +1,119 @@
+import { createAsyncThunk } from "@reduxjs/toolkit";
+import { JSONContent } from "@tiptap/core";
+import {
+ ContextItemWithId,
+ InputModifiers,
+ MessageContent,
+ RangeInFile,
+} from "core";
+import { getBasename, getRelativePath } from "core/util";
+import resolveEditorContent, {
+ hasSlashCommandOrContextProvider,
+} from "../../components/mainInput/resolveInput";
+import { updateFileSymbolsFromContextItems } from "../../util/symbols";
+import { defaultModelSelector } from "../selectors/modelSelectors";
+import { setIsGatheringContext } from "../slices/stateSlice";
+import { ThunkApiType } from "../store";
+
+export const gatherContext = createAsyncThunk<
+ {
+ selectedContextItems: ContextItemWithId[];
+ selectedCode: RangeInFile[];
+ content: MessageContent;
+ },
+ {
+ editorState: JSONContent;
+ modifiers: InputModifiers;
+ promptPreamble?: string;
+ },
+ ThunkApiType
+>(
+ "chat/gatherContext",
+ async (
+ { modifiers, editorState, promptPreamble },
+ { dispatch, extra, getState },
+ ) => {
+ const state = getState();
+ const defaultModel = defaultModelSelector(state);
+ const defaultContextProviders =
+ state.state.config.experimental?.defaultContext ?? [];
+
+ // Resolve context providers and construct new history
+ const shouldGatherContext =
+ modifiers.useCodebase || hasSlashCommandOrContextProvider(editorState);
+
+ if (shouldGatherContext) {
+ dispatch(
+ setIsGatheringContext({
+ isGathering: true,
+ gatheringMessage: "Gathering Context",
+ }),
+ );
+ }
+
+ let [selectedContextItems, selectedCode, content] =
+ await resolveEditorContent(
+ editorState,
+ modifiers,
+ extra.ideMessenger,
+ defaultContextProviders,
+ dispatch,
+ );
+
+ // Automatically use currently open file
+ if (!modifiers.noContext) {
+ const usingFreeTrial = defaultModel?.provider === "free-trial";
+
+ const currentFile = await extra.ideMessenger.ide.getCurrentFile();
+ if (currentFile) {
+ let currentFileContents = currentFile.contents;
+ if (usingFreeTrial) {
+ currentFileContents = currentFile.contents
+ .split("\n")
+ .slice(0, 1000)
+ .join("\n");
+ }
+ if (
+ !selectedContextItems.find(
+ (item) => item.uri?.value === currentFile.path,
+ )
+ ) {
+ // don't add the file if it's already in the context items
+ selectedContextItems.unshift({
+ content: `The following file is currently open. Don't reference it if it's not relevant to the user's message.\n\n\`\`\`${getRelativePath(
+ currentFile.path,
+ await extra.ideMessenger.ide.getWorkspaceDirs(),
+ )}\n${currentFileContents}\n\`\`\``,
+ name: `Active file: ${getBasename(currentFile.path)}`,
+ description: currentFile.path,
+ id: {
+ itemId: currentFile.path,
+ providerTitle: "file",
+ },
+ uri: {
+ type: "file",
+ value: currentFile.path,
+ },
+ });
+ }
+ }
+ }
+
+ await updateFileSymbolsFromContextItems(
+ selectedContextItems,
+ extra.ideMessenger,
+ dispatch,
+ );
+
+ if (promptPreamble) {
+ if (typeof content === "string") {
+ content = promptPreamble + content;
+ } else {
+ content[0].text = promptPreamble + content[0].text;
+ }
+ }
+
+ // dispatch(addContextItems(contextItems));
+ return { selectedContextItems, selectedCode, content };
+ },
+);
diff --git a/gui/src/redux/thunks/handleErrors.ts b/gui/src/redux/thunks/handleErrors.ts
new file mode 100644
index 0000000000..6d7afb6327
--- /dev/null
+++ b/gui/src/redux/thunks/handleErrors.ts
@@ -0,0 +1,18 @@
+import { createAsyncThunk } from "@reduxjs/toolkit";
+import { setInactive } from "../slices/stateSlice";
+import { ThunkApiType } from "../store";
+
+export const handleErrors = createAsyncThunk<
+ void,
+ () => Promise,
+ ThunkApiType
+>("chat/handleErrors", async (runStream, { dispatch, extra }) => {
+ try {
+ await runStream();
+ } catch (e: any) {
+ console.debug("Error streaming response: ", e);
+ } finally {
+ dispatch(setInactive());
+ // triggerSave(!save); TODO
+ }
+});
diff --git a/gui/src/redux/thunks/resetStateForNewMessage.ts b/gui/src/redux/thunks/resetStateForNewMessage.ts
new file mode 100644
index 0000000000..e7f534e42d
--- /dev/null
+++ b/gui/src/redux/thunks/resetStateForNewMessage.ts
@@ -0,0 +1,11 @@
+import { createAsyncThunk } from "@reduxjs/toolkit";
+import { resetNextCodeBlockToApplyIndex } from "../slices/stateSlice";
+import { ThunkApiType } from "../store";
+
+export const resetStateForNewMessage = createAsyncThunk<
+ void,
+ undefined,
+ ThunkApiType
+>("chat/resetStateForNewMessage", async (_, { dispatch }) => {
+ dispatch(resetNextCodeBlockToApplyIndex());
+});
diff --git a/gui/src/redux/thunks/streamNormalInput.ts b/gui/src/redux/thunks/streamNormalInput.ts
new file mode 100644
index 0000000000..d17f2d8705
--- /dev/null
+++ b/gui/src/redux/thunks/streamNormalInput.ts
@@ -0,0 +1,91 @@
+import { createAsyncThunk } from "@reduxjs/toolkit";
+import { ChatMessage, PromptLog } from "core";
+import { defaultModelSelector } from "../selectors/modelSelectors";
+import { selectCurrentToolCall } from "../selectors/selectCurrentToolCall";
+import {
+ abortStream,
+ addPromptCompletionPair,
+ clearLastEmptyResponse,
+ setToolGenerated,
+ streamUpdate,
+} from "../slices/stateSlice";
+import { ThunkApiType } from "../store";
+import { callTool } from "./callTool";
+
+export const streamNormalInput = createAsyncThunk<
+ void,
+ ChatMessage[],
+ ThunkApiType
+>("chat/streamNormalInput", async (messages, { dispatch, extra, getState }) => {
+ try {
+ // Gather state
+ const state = getState();
+ const defaultModel = defaultModelSelector(state);
+ const toolSettings = state.uiState.toolSettings;
+ const streamAborter = state.state.streamAborter;
+ const useTools = state.uiState.useTools;
+
+ if (!defaultModel) {
+ throw new Error("Default model not defined");
+ }
+
+ // Send request
+ const gen = extra.ideMessenger.llmStreamChat(
+ defaultModel.title,
+ streamAborter.signal,
+ messages,
+ {
+ tools: useTools
+ ? Object.keys(toolSettings)
+ .filter((tool) => toolSettings[tool] !== "disabled")
+ .map((toolName) =>
+ state.state.config.tools.find(
+ (tool) => tool.function.name === toolName,
+ ),
+ )
+ .filter(Boolean)
+ : undefined,
+ },
+ );
+
+ // Stream response
+ let next = await gen.next();
+ while (!next.done) {
+ if (!getState().state.active) {
+ dispatch(abortStream());
+ break;
+ }
+
+ const update = next.value as ChatMessage;
+ dispatch(streamUpdate(update));
+ next = await gen.next();
+
+ // There has been lag when streaming tool calls. This is a temporary solution
+ if (update.role === "assistant" && update.toolCalls) {
+ await new Promise((resolve) => setTimeout(resolve, 10));
+ }
+ }
+
+ // Attach prompt log
+ let returnVal = next.value as PromptLog;
+ if (returnVal) {
+ dispatch(addPromptCompletionPair([returnVal]));
+ }
+
+ // If it's a tool call that is automatically accepted, we should call it
+ const toolCallState = selectCurrentToolCall(getState());
+ if (toolCallState) {
+ dispatch(setToolGenerated());
+
+ if (
+ toolSettings[toolCallState.toolCall.function.name] ===
+ "allowedWithoutPermission"
+ ) {
+ await dispatch(callTool());
+ }
+ }
+ } catch (e) {
+ // If there's an error, we should clear the response so there aren't two input boxes
+ dispatch(clearLastEmptyResponse());
+ }
+});
diff --git a/gui/src/redux/thunks/streamResponse.ts b/gui/src/redux/thunks/streamResponse.ts
new file mode 100644
index 0000000000..c1f6ba4d83
--- /dev/null
+++ b/gui/src/redux/thunks/streamResponse.ts
@@ -0,0 +1,157 @@
+import { createAsyncThunk, unwrapResult } from "@reduxjs/toolkit";
+import { JSONContent } from "@tiptap/core";
+import {
+ ChatHistoryItem,
+ ChatMessage,
+ InputModifiers,
+ MessageContent,
+ SlashCommandDescription,
+} from "core";
+import { constructMessages } from "core/llm/constructMessages";
+import { renderChatMessage } from "core/util/messageContent";
+import posthog from "posthog-js";
+import { defaultModelSelector } from "../selectors/modelSelectors";
+import {
+ initNewActiveMessage,
+ resubmitAtIndex,
+ setCurCheckpointIndex,
+ setMessageAtIndex,
+} from "../slices/stateSlice";
+import { ThunkApiType } from "../store";
+import { gatherContext } from "./gatherContext";
+import { handleErrors } from "./handleErrors";
+import { resetStateForNewMessage } from "./resetStateForNewMessage";
+import { streamNormalInput } from "./streamNormalInput";
+import { streamSlashCommand } from "./streamSlashCommand";
+
+const getSlashCommandForInput = (
+ input: MessageContent,
+ slashCommands: SlashCommandDescription[],
+): [SlashCommandDescription, string] | undefined => {
+ let slashCommand: SlashCommandDescription | undefined;
+ let slashCommandName: string | undefined;
+
+ let lastText =
+ typeof input === "string"
+ ? input
+ : input.filter((part) => part.type === "text").slice(-1)[0]?.text || "";
+
+ if (lastText.startsWith("/")) {
+ slashCommandName = lastText.split(" ")[0].substring(1);
+ slashCommand = slashCommands.find(
+ (command) => command.name === slashCommandName,
+ );
+ }
+ if (!slashCommand || !slashCommandName) {
+ return undefined;
+ }
+
+ // Convert to actual slash command object with runnable function
+ return [slashCommand, renderChatMessage({ role: "user", content: input })];
+};
+
+export const streamResponseThunk = createAsyncThunk<
+ void,
+ {
+ editorState: JSONContent;
+ modifiers: InputModifiers;
+ index?: number;
+ promptPreamble?: string;
+ },
+ ThunkApiType
+>(
+ "chat/streamResponse",
+ async (
+ { editorState, modifiers, index, promptPreamble },
+ { dispatch, extra, getState },
+ ) => {
+ await dispatch(
+ handleErrors(async () => {
+ const state = getState();
+ const defaultModel = defaultModelSelector(state);
+ const slashCommands = state.state.config.slashCommands || [];
+ const inputIndex = index ?? state.state.history.length;
+
+ if (typeof index === "number") {
+ dispatch(resubmitAtIndex({ index, editorState }));
+ } else {
+ dispatch(initNewActiveMessage({ editorState }));
+ }
+
+ resetStateForNewMessage();
+
+ if (index) {
+ dispatch(setCurCheckpointIndex(Math.floor(index / 2)));
+ }
+
+ const result = await dispatch(
+ gatherContext({
+ editorState,
+ modifiers,
+ promptPreamble,
+ }),
+ );
+ const unwrapped = unwrapResult(result);
+ const { selectedContextItems, selectedCode, content } = unwrapped;
+
+ // Add the message to the history
+ const message: ChatMessage = {
+ role: "user",
+ content,
+ };
+ const historyItem: ChatHistoryItem = {
+ message,
+ contextItems: selectedContextItems,
+ editorState,
+ };
+
+ dispatch(
+ setMessageAtIndex({
+ message,
+ index: inputIndex,
+ contextItems: selectedContextItems,
+ }),
+ );
+
+ // Construct messages from updated history
+ const updatedHistory = getState().state.history;
+ const messages = constructMessages(updatedHistory, defaultModel.model);
+
+ posthog.capture("step run", {
+ step_name: "User Input",
+ params: {},
+ });
+ posthog.capture("userInput", {});
+
+ // Determine if the input is a slash command
+ let commandAndInput = getSlashCommandForInput(content, slashCommands);
+
+ if (!commandAndInput) {
+ await dispatch(streamNormalInput(messages));
+ } else {
+ const [slashCommand, commandInput] = commandAndInput;
+ let updatedContextItems = [];
+ posthog.capture("step run", {
+ step_name: slashCommand.name,
+ params: {},
+ });
+
+ // if (slashCommand.name === "multifile-edit") {
+ // dispatch(setIsInMultifileEdit(true));
+ // }
+
+ await dispatch(
+ streamSlashCommand({
+ messages,
+ slashCommand,
+ input: commandInput,
+ historyIndex: inputIndex,
+ selectedCode,
+ contextItems: updatedContextItems,
+ }),
+ );
+ }
+ }),
+ );
+ },
+);
diff --git a/gui/src/redux/thunks/streamResponseAfterToolCall.ts b/gui/src/redux/thunks/streamResponseAfterToolCall.ts
new file mode 100644
index 0000000000..6652f3a0ed
--- /dev/null
+++ b/gui/src/redux/thunks/streamResponseAfterToolCall.ts
@@ -0,0 +1,67 @@
+import { createAsyncThunk } from "@reduxjs/toolkit";
+import { ChatMessage, ContextItem } from "core";
+import { constructMessages } from "core/llm/constructMessages";
+import { renderContextItems } from "core/util/messageContent";
+import { defaultModelSelector } from "../selectors/modelSelectors";
+import {
+ addContextItemsAtIndex,
+ setActive,
+ streamUpdate,
+} from "../slices/stateSlice";
+import { ThunkApiType } from "../store";
+import { handleErrors } from "./handleErrors";
+import { resetStateForNewMessage } from "./resetStateForNewMessage";
+import { streamNormalInput } from "./streamNormalInput";
+
+export const streamResponseAfterToolCall = createAsyncThunk<
+ void,
+ {
+ toolCallId: string;
+ toolOutput: ContextItem[];
+ },
+ ThunkApiType
+>(
+ "chat/streamAfterToolCall",
+ async ({ toolCallId, toolOutput }, { dispatch, getState, extra }) => {
+ await dispatch(
+ handleErrors(async () => {
+ const state = getState();
+ const initialHistory = state.state.history;
+ const defaultModel = defaultModelSelector(state);
+
+ resetStateForNewMessage();
+
+ await new Promise((resolve) => setTimeout(resolve, 0));
+
+ const newMessage: ChatMessage = {
+ role: "tool",
+ content: renderContextItems(toolOutput),
+ toolCallId,
+ };
+
+ dispatch(streamUpdate(newMessage));
+ dispatch(
+ addContextItemsAtIndex({
+ index: initialHistory.length,
+ contextItems: toolOutput.map((contextItem) => ({
+ ...contextItem,
+ id: {
+ providerTitle: "toolCall",
+ itemId: toolCallId,
+ },
+ })),
+ }),
+ );
+
+ dispatch(setActive());
+
+ const updatedHistory = getState().state.history;
+ const messages = constructMessages(
+ [...updatedHistory],
+ defaultModel.model,
+ );
+ await dispatch(streamNormalInput(messages));
+ }),
+ );
+ },
+);
diff --git a/gui/src/redux/thunks/streamSlashCommand.ts b/gui/src/redux/thunks/streamSlashCommand.ts
new file mode 100644
index 0000000000..4b6c717910
--- /dev/null
+++ b/gui/src/redux/thunks/streamSlashCommand.ts
@@ -0,0 +1,71 @@
+import { createAsyncThunk } from "@reduxjs/toolkit";
+import {
+ ChatMessage,
+ ContextItemWithId,
+ RangeInFile,
+ SlashCommandDescription,
+} from "core";
+import { defaultModelSelector } from "../selectors/modelSelectors";
+import { abortStream, streamUpdate } from "../slices/stateSlice";
+import { ThunkApiType } from "../store";
+
+export const streamSlashCommand = createAsyncThunk<
+ void,
+ {
+ messages: ChatMessage[];
+ slashCommand: SlashCommandDescription;
+ input: string;
+ historyIndex: number;
+ selectedCode: RangeInFile[];
+ contextItems: ContextItemWithId[];
+ },
+ ThunkApiType
+>(
+ "chat/streamSlashCommand",
+ async (
+ { messages, selectedCode, slashCommand, input, historyIndex, contextItems },
+ { dispatch, getState, extra },
+ ) => {
+ const state = getState();
+ const defaultModel = defaultModelSelector(state);
+ const active = state.state.active;
+ const streamAborter = state.state.streamAborter;
+
+ if (!defaultModel) {
+ throw new Error("Default model not defined");
+ }
+
+ const modelTitle = defaultModel.title;
+
+ const checkActiveInterval = setInterval(() => {
+ if (!active) {
+ dispatch(abortStream());
+ clearInterval(checkActiveInterval);
+ }
+ }, 100);
+
+ for await (const update of extra.ideMessenger.streamRequest(
+ "command/run",
+ {
+ input,
+ history: messages,
+ modelTitle,
+ slashCommandName: slashCommand.name,
+ contextItems,
+ params: slashCommand.params,
+ historyIndex,
+ selectedCode,
+ },
+ streamAborter.signal,
+ )) {
+ if (!getState().state.active) {
+ dispatch(abortStream());
+ break;
+ }
+ if (typeof update === "string") {
+ dispatch(streamUpdate(update));
+ }
+ }
+ clearInterval(checkActiveInterval);
+ },
+);
diff --git a/gui/src/redux/util/index.ts b/gui/src/redux/util/index.ts
new file mode 100644
index 0000000000..06ceebb4ab
--- /dev/null
+++ b/gui/src/redux/util/index.ts
@@ -0,0 +1,8 @@
+import { ToolCallState } from "core";
+import { RootState } from "../store";
+
+export function findCurrentToolCall(
+ state: RootState["state"]["history"],
+): ToolCallState | undefined {
+ return state[state.length - 1]?.toolCallState;
+}
diff --git a/manual-testing-sandbox/test.ts b/manual-testing-sandbox/test.ts
new file mode 100644
index 0000000000..9cf524421e
--- /dev/null
+++ b/manual-testing-sandbox/test.ts
@@ -0,0 +1,41 @@
+class Calculator {
+ private result: number;
+
+ constructor() {
+ this.result = 0;
+ }
+
+ add(number: number): Calculator {
+ this.result += number;
+ return this;
+ }
+
+ subtract(number: number): Calculator {
+ this.result -= number;
+ return this;
+ }
+
+ multiply(number: number): Calculator {
+ this.result *= number;
+ return this;
+ }
+
+ divide(number: number): Calculator {
+ if (number === 0) {
+ throw new Error("Cannot divide by zero");
+ }
+ this.result /= number;
+ return this;
+ }
+
+ getResult(): number {
+ return this.result;
+ }
+
+ reset(): Calculator {
+ this.result = 0;
+ return this;
+ }
+}
+
+export default Calculator;
\ No newline at end of file
diff --git a/packages/llm-info/package.json b/packages/llm-info/package.json
index 4ba97f6b8a..af83039381 100644
--- a/packages/llm-info/package.json
+++ b/packages/llm-info/package.json
@@ -1,6 +1,6 @@
{
"name": "@continuedev/llm-info",
- "version": "1.0.6",
+ "version": "1.0.7",
"description": "",
"main": "dist/index.js",
"types": "dist/index.d.ts",
diff --git a/packages/llm-info/src/index.ts b/packages/llm-info/src/index.ts
index 84d88e031e..fe71f419dd 100644
--- a/packages/llm-info/src/index.ts
+++ b/packages/llm-info/src/index.ts
@@ -22,7 +22,7 @@ export const allModelProviders: ModelProvider[] = [
Vllm,
Bedrock,
Cohere,
- xAI
+ xAI,
];
export const allLlms: LlmInfo[] = allModelProviders.flatMap((provider) =>