Skip to content

Commit

Permalink
feat: add citations for knowledge sources (#1763)
Browse files Browse the repository at this point in the history
* feat: add citations for website knowledge

* enhance: prevent rerenders in messages

* feat: enable citations for google search

* chore: remove console logs

Signed-off-by: Ryan Hopper-Lowe <[email protected]>

---------

Signed-off-by: Ryan Hopper-Lowe <[email protected]>
  • Loading branch information
ryanhopperlowe authored Feb 18, 2025
1 parent 43500e4 commit f04e079
Show file tree
Hide file tree
Showing 11 changed files with 183 additions and 52 deletions.
90 changes: 74 additions & 16 deletions ui/admin/app/components/chat/Message.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,13 @@ export const Message = React.memo(
// prevent animation for messages that never run
// only calculate on mount because we don't want to stop animation when the message finishes streaming
const [shouldAnimate] = useState(isRunning);
const animatedText = useDeferredValue(
useAnimatedText(message.text, !shouldAnimate || isUser || !!toolCall)
const [_animatedText, isAnimating] = useAnimatedText(
message.text,
!shouldAnimate || isUser || !!toolCall
);

const animatedText = useDeferredValue(_animatedText);

const parsedMessage = useMemo(() => {
if (OpenMarkdownLinkRegex.test(animatedText)) {
return animatedText.replace(
Expand Down Expand Up @@ -90,17 +93,13 @@ export const Message = React.memo(
>
<div
className={cn({
"rounded-xl border border-error bg-error-foreground":
"rounded-xl border border-error bg-error-foreground p-4":
message.error,
"max-w-[80%] rounded-2xl bg-accent p-4": isUser,
"w-full max-w-full": !isUser,
})}
>
<div
className={cn(
"flex max-w-full items-center gap-2 overflow-hidden"
)}
>
<div className="flex max-w-full items-center gap-2 overflow-hidden">
{message.aborted && (
<AlertCircleIcon className="h-5 w-5 text-muted-foreground" />
)}
Expand All @@ -117,14 +116,21 @@ export const Message = React.memo(
{message.prompt ? (
<PromptMessage prompt={message.prompt} isRunning={isRunning} />
) : (
<Markdown
className={cn({
"prose-invert text-accent-foreground": isUser,
"text-muted-foreground": message.aborted,
})}
>
{parsedMessage || "Waiting for more information..."}
</Markdown>
<div className="flex w-full flex-col">
<Markdown
className={cn({
"prose-invert text-accent-foreground": isUser,
"text-muted-foreground": message.aborted,
})}
>
{parsedMessage || "Waiting for more information..."}
</Markdown>

<SourceCitations
message={message}
show={!isAnimating || !shouldAnimate}
/>
</div>
)}

{toolCall && (
Expand Down Expand Up @@ -352,3 +358,55 @@ export function PromptAuthForm({
</Form>
);
}

export function SourceCitations({
message,
show,
}: {
message: MessageType;
show: boolean;
}) {
if (!message.knowledgeSources || !show) return null;

const formatUrl = (url: string) => {
return url.replace(/(https?:\/\/)?(www\.)?/, "");
};

const citations = new Map(
message.knowledgeSources
.filter((s) => !!s.url)
.map((s) => {
return [
formatUrl(s.url as string),
{ ...s, url: new URL(s.url as string) },
];
})
);

return (
<div className="flex flex-col gap-2 pt-4">
<h4>Sources</h4>
<div className="flex flex-wrap gap-2 pb-3">
{Array.from(citations.entries()).map(([key, { url }]) => (
<Link
as="button"
variant="secondary"
size="sm"
to={url.toString()}
target="_blank"
rel="noreferrer"
key={key}
>
<img
src={`${url.origin}/favicon.ico`}
alt="Favicon"
onError={(e) => (e.currentTarget.src = "/favicon.ico")}
className="size-4"
/>
{key.length > 25 ? key.slice(0, 25) + "..." : key}
</Link>
))}
</div>
</div>
);
}
2 changes: 1 addition & 1 deletion ui/admin/app/components/chat/MessagePane.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ export function MessagePane({
<Message
key={i}
message={message}
isRunning={isRunning}
isRunning={isRunning && i === messages.length - 1}
icons={icons}
isDarkMode={isDarkMode}
isMostRecent={i === messages.length - 1}
Expand Down
11 changes: 8 additions & 3 deletions ui/admin/app/components/ui/scroll-area.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ const ScrollArea = React.forwardRef<
startScrollAt?: "bottom";
enableScrollStick?: "bottom";
enableScrollTo?: "bottom";
orientation?: "vertical" | "horizontal";
classNames?: {
root?: string;
viewport?: string;
content?: string;
};
}
>((props, ref) => {
Expand All @@ -29,6 +31,7 @@ const ScrollArea = React.forwardRef<
startScrollAt,
enableScrollTo,
enableScrollStick,
orientation = "vertical",
classNames = {},
...rootProps
} = props;
Expand Down Expand Up @@ -82,7 +85,7 @@ const ScrollArea = React.forwardRef<
// setting `display: table` in the `ScrollAreaPrimitive.Viewport` component.
// This is a known issue with Radix UI ScrollArea.
// https://github.com/radix-ui/primitives/issues/2722
"h-full max-h-[inherit] w-full scroll-smooth rounded-[inherit] [&>div]:!block",
"h-full max-h-[inherit] w-full max-w-[inherit] scroll-smooth rounded-[inherit] [&>div]:!block",
classNames.viewport
)}
ref={initRef}
Expand All @@ -94,7 +97,9 @@ const ScrollArea = React.forwardRef<
}
}}
>
<div ref={contentRef}>{children}</div>
<div ref={contentRef} className={classNames.content}>
{children}
</div>
{enableScrollTo === "bottom" && (
<ScrollToBottom
behavior="smooth"
Expand All @@ -104,7 +109,7 @@ const ScrollArea = React.forwardRef<
/>
)}
</ScrollAreaPrimitive.Viewport>
<ScrollBar />
<ScrollBar orientation={orientation} />
<ScrollAreaPrimitive.Corner />
</ScrollAreaPrimitive.Root>
);
Expand Down
8 changes: 7 additions & 1 deletion ui/admin/app/hooks/messages/useAnimatedText.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export function useAnimatedText(text: string, disabled?: boolean) {
const [cursor, setCursor] = useState(0);
const [prev, setPrev] = useState(text);
const [isSameText, setIsSameText] = useState(true);
const [isAnimating, setIsAnimating] = useState(false);

if (prev !== text) {
setPrev(text);
Expand All @@ -22,16 +23,21 @@ export function useAnimatedText(text: string, disabled?: boolean) {
animatedCursor.jump(0);
}

setIsAnimating(true);

const controls = animate(animatedCursor, text.length, {
duration: 0.5,
ease: "linear",
onUpdate(latest) {
setCursor(Math.floor(latest));
},
onComplete() {
setIsAnimating(false);
},
});

return () => controls.stop();
}, [animatedCursor, isSameText, text.length]);

return disabled ? text : text.slice(0, cursor);
return [disabled ? text : text.slice(0, cursor), isAnimating] as const;
}
8 changes: 8 additions & 0 deletions ui/admin/app/lib/model/chatEvents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ export type ToolInput = {
content: string;
};

export type KnowledgeToolOutput = { url?: string; content: string }[];

export type GoogleSearchOutput = {
duration: { search: number; refine: number; response: number };
query: string;
results: { url: string; content: string }[];
};

export type ToolCall = {
name: string;
description: string;
Expand Down
8 changes: 7 additions & 1 deletion ui/admin/app/lib/model/messages.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import { AuthPrompt, ChatEvent, ToolCall } from "~/lib/model/chatEvents";
import {
AuthPrompt,
ChatEvent,
KnowledgeToolOutput,
ToolCall,
} from "~/lib/model/chatEvents";
import { Run } from "~/lib/model/runs";

export interface Message {
Expand All @@ -14,6 +19,7 @@ export interface Message {
error?: boolean;
contentID?: string;
time?: Date | string;
knowledgeSources?: KnowledgeToolOutput;
}

export const runsToMessages = (runs: Run[]) => {
Expand Down
70 changes: 54 additions & 16 deletions ui/admin/app/lib/store/chat/message-store.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import { createStore } from "zustand";

import { ChatEvent } from "~/lib/model/chatEvents";
import {
ChatEvent,
GoogleSearchOutput,
KnowledgeToolOutput,
ToolCall,
} from "~/lib/model/chatEvents";
import { Message, promptMessage, toolCallMessage } from "~/lib/model/messages";
import { ThreadsService } from "~/lib/service/api/threadsService";
import { handleTry } from "~/lib/utils/handleTry";

type EventInitConfig = {
onEvent: (event: ChatEvent) => void;
Expand All @@ -20,6 +26,9 @@ export type MessageStore = {

export const createMessageStore = () => {
return createStore<MessageStore>()((set, get) => {
// pseudo private instance variable
let parsedSources: { url?: string; content: string }[] = [];

return {
messages: [],
cleanupFns: [],
Expand Down Expand Up @@ -169,31 +178,60 @@ export const createMessageStore = () => {
runId: runID,
contentID,
time,
knowledgeSources: parsedSources.length ? parsedSources : undefined,
});

// reset the knowledge output for the next message
parsedSources = [];
return { messages: copy };
}

return { messages: copy };
});
}
});
};

const handleToolCallEvent = (messages: Message[], event: ChatEvent) => {
if (!event.toolCall) return messages;
function handleToolCallEvent(messages: Message[], event: ChatEvent) {
if (!event.toolCall) return messages;

const { toolCall } = event;

const sources = pullSources(toolCall);

if (sources) parsedSources.push(...sources);

// if the toolCall is an output event
if (toolCall.output) {
const index = messages.findLastIndex(
(m) => m.tools && !m.tools[0].output
);
if (index !== -1) {
// update the previous pending toolcall message (without output)
messages[index].tools = [toolCall];
return messages;
}
}

const { toolCall } = event;
if (toolCall.output) {
// const index = findIndexLastPendingToolCall(messages);
const index = messages.findLastIndex((m) => m.tools && !m.tools[0].output);
if (index !== -1) {
// update the found pending toolcall message (without output)
messages[index].tools = [toolCall];
// otherwise add a new toolcall message
messages.push(toolCallMessage(toolCall));
return messages;
}
});
};

function pullSources(toolCall: ToolCall) {
if (!toolCall.output) return;

const [err, output] = handleTry(() => JSON.parse(toolCall.output));

if (err) return [];

if (toolCall.name === "Knowledge") {
const o = output as KnowledgeToolOutput;
return o;
}

// otherwise add a new toolcall message
messages.push(toolCallMessage(toolCall));
return messages;
};
if (toolCall.name === "Search") {
const o = output as GoogleSearchOutput;
return o.results;
}
}
10 changes: 10 additions & 0 deletions ui/admin/app/lib/utils/forceError.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
export const forceError = (
error: unknown,
fallbackMessage = "Something went wrong"
) => {
if (error instanceof Error) return error;

if (typeof error === "string") return new Error(error);

return new Error(fallbackMessage, { cause: error });
};
11 changes: 1 addition & 10 deletions ui/admin/app/lib/utils/handlePromise.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,4 @@
const forceError = (
error: unknown,
fallbackMessage = "Something went wrong"
) => {
if (error instanceof Error) return error;

if (typeof error === "string") return new Error(error);

return new Error(fallbackMessage, { cause: error });
};
import { forceError } from "~/lib/utils/forceError";

export type PromiseResult<TData> =
| readonly [null, TData]
Expand Down
9 changes: 9 additions & 0 deletions ui/admin/app/lib/utils/handleTry.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import { forceError } from "~/lib/utils/forceError";

export function handleTry<TResponse>(fn: () => TResponse) {
try {
return [null, fn()] as const;
} catch (e) {
return [forceError(e), null] as const;
}
}
Loading

0 comments on commit f04e079

Please sign in to comment.