From 592c5969c260fa2741ec8235ba07a177e44e0479 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 18 Feb 2025 23:05:59 -0600 Subject: [PATCH 1/3] feat: add rating system for generated media - Add thumbs up/down UI for generated media items - Implement rating persistence in database - Add rating filters in left panel with "Not Disliked" default view --- src/components/left-panel.tsx | 238 +++++++++++++++++++++------------ src/components/media-panel.tsx | 41 +++++- src/data/db.ts | 8 ++ src/data/schema.ts | 2 + 4 files changed, 203 insertions(+), 86 deletions(-) diff --git a/src/components/left-panel.tsx b/src/components/left-panel.tsx index 013526b..a3e96a8 100644 --- a/src/components/left-panel.tsx +++ b/src/components/left-panel.tsx @@ -21,6 +21,10 @@ import { LoaderCircleIcon, CloudUploadIcon, SparklesIcon, + ThumbsUpIcon, + ThumbsDownIcon, + MinusCircleIcon, + ListFilterIcon, } from "lucide-react"; import { MediaItemPanel } from "./media-panel"; import { Button } from "./ui/button"; @@ -51,6 +55,9 @@ export default function LeftPanel() { const { data: project = PROJECT_PLACEHOLDER } = useProject(projectId); const projectUpdate = useProjectUpdater(projectId); const [mediaType, setMediaType] = useState("all"); + const [ratingFilter, setRatingFilter] = useState< + "all" | "positive" | "negative" | "unrated" | "not_disliked" + >("not_disliked"); const queryClient = useQueryClient(); const { data: mediaItems = [], isLoading } = useProjectMediaItems(projectId); @@ -179,92 +186,139 @@ export default function LeftPanel() {
-
-

- Gallery -

-
- - - - - - setMediaType("all")} - > - - All - - setMediaType("image")} - > - - Image - - setMediaType("music")} - > - - Music - - setMediaType("voiceover")} - > - - Voiceover - - setMediaType("video")} +
+
+

+ Gallery +

+
+ + + + + + setRatingFilter("not_disliked")} + > + + Not Disliked + + setRatingFilter("all")} + > + + All + + setRatingFilter("positive")} + > + + Liked + + setRatingFilter("negative")} + > + + Disliked + + setRatingFilter("unrated")} + > + + Unrated + + + + + + + + + setMediaType("all")} + > + + All + + setMediaType("image")} + > + + Image + + setMediaType("music")} + > + + Music + + setMediaType("voiceover")} + > + + Voiceover + + setMediaType("video")} + > + + Video + + + + + {mediaItems.length > 0 && ( + + + Generate... + + )} +
- {mediaItems.length > 0 && ( - - )}
{!isLoading && mediaItems.length === 0 && (
@@ -285,7 +339,21 @@ export default function LeftPanel() { {mediaItems.length > 0 && ( { + // First filter by media type + if (mediaType !== "all" && media.mediaType !== mediaType) + return false; + + // Then filter by rating + if (ratingFilter === "positive") + return media.rating === "positive"; + if (ratingFilter === "negative") + return media.rating === "negative"; + if (ratingFilter === "unrated") return media.rating === undefined; + if (ratingFilter === "not_disliked") + return media.rating !== "negative"; + return true; + })} mediaType={mediaType} className="overflow-y-auto" /> diff --git a/src/components/media-panel.tsx b/src/components/media-panel.tsx index 8699334..e54b879 100644 --- a/src/components/media-panel.tsx +++ b/src/components/media-panel.tsx @@ -13,6 +13,8 @@ import { ImageIcon, MicIcon, MusicIcon, + ThumbsDownIcon, + ThumbsUpIcon, VideoIcon, } from "lucide-react"; import { @@ -125,6 +127,15 @@ export function MediaItemRow({ ? data.metadata?.start_frame_url || data?.metadata?.end_frame_url : resolveMediaUrl(data); + const handleRating = async (rating: "positive" | "negative" | undefined) => { + // If clicking the same rating, clear it. Otherwise set the new rating + const newRating = data.rating === rating ? undefined : rating; + await db.media.updateRating(data.id, newRating); + await queryClient.invalidateQueries({ + queryKey: queryKeys.projectMediaItems(projectId), + }); + }; + return (
-
+
{formatDistanceToNow(data.createdAt, { addSuffix: true })} + {data.status === "completed" && data.kind === "generated" && ( +
+ + +
+ )}
diff --git a/src/data/db.ts b/src/data/db.ts index d52cd03..d6ca286 100644 --- a/src/data/db.ts +++ b/src/data/db.ts @@ -154,6 +154,14 @@ export const db = { await tx.done; return result; }, + async updateRating( + id: string, + rating: "positive" | "negative" | undefined, + ) { + const existing = await this.find(id); + if (!existing || existing.kind !== "generated") return; + return this.update(id, { rating }); + }, async delete(id: string) { const db = await open(); const media: MediaItem | null = await db.get("media_items", id); diff --git a/src/data/schema.ts b/src/data/schema.ts index ffb4fb8..dc2f56e 100644 --- a/src/data/schema.ts +++ b/src/data/schema.ts @@ -78,6 +78,7 @@ export type MediaItem = { input?: Record; output?: Record; url?: string; + rating?: "positive" | "negative"; metadata?: Record; // TODO: Define the metadata schema } & ( | { @@ -86,6 +87,7 @@ export type MediaItem = { requestId: string; input: Record; output?: Record; + rating?: "positive" | "negative"; } | { kind: "uploaded"; From 3508e5cd29079f7a5fc624c486f17f142476098a Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 20 Feb 2025 15:30:10 -0600 Subject: [PATCH 2/3] feat: analysis view that shows statistics and prompt insights based on positively vs. negatively rated media generations for each model --- src/components/analytics/analytics-dialog.tsx | 134 ++++++++++ .../analytics/model-performance.tsx | 172 +++++++++++++ src/components/analytics/prompt-analysis.tsx | 164 ++++++++++++ src/components/header.tsx | 15 +- src/components/main.tsx | 7 + src/components/ui/scroll-area.tsx | 45 ++++ src/data/store.ts | 15 ++ src/lib/analytics.ts | 242 ++++++++++++++++++ 8 files changed, 793 insertions(+), 1 deletion(-) create mode 100644 src/components/analytics/analytics-dialog.tsx create mode 100644 src/components/analytics/model-performance.tsx create mode 100644 src/components/analytics/prompt-analysis.tsx create mode 100644 src/components/ui/scroll-area.tsx create mode 100644 src/lib/analytics.ts diff --git a/src/components/analytics/analytics-dialog.tsx b/src/components/analytics/analytics-dialog.tsx new file mode 100644 index 0000000..4eba41b --- /dev/null +++ b/src/components/analytics/analytics-dialog.tsx @@ -0,0 +1,134 @@ +import { useVideoProjectStore } from "@/data/store"; +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; +import { BarChart2Icon } from "lucide-react"; +import { useProjectId } from "@/data/store"; +import { useProjectMediaItems } from "@/data/queries"; +import { calculateModelStats, analyzePrompts } from "@/lib/analytics"; +import { ModelPerformance } from "./model-performance"; +import { PromptAnalysis } from "./prompt-analysis"; +import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; +import { useCallback } from "react"; +import type { MediaItem } from "@/data/schema"; +import type { PromptAnalysis as PromptAnalysisType } from "@/lib/analytics"; + +interface AnalyticsDialogProps { + onOpenChange?: (open: boolean) => void; +} + +function preparePromptData(mediaItems: MediaItem[]) { + // Group rated prompts by model + const byModel = new Map< + string, + { positivePrompts: string[]; negativePrompts: string[] } + >(); + + for (const item of mediaItems) { + if (item.kind !== "generated" || !item.input?.prompt || !item.rating) + continue; + const modelId = item.endpointId; + if (!byModel.has(modelId)) { + byModel.set(modelId, { positivePrompts: [], negativePrompts: [] }); + } + const modelData = byModel.get(modelId)!; + if (item.rating === "positive") { + modelData.positivePrompts.push(item.input.prompt); + } else { + modelData.negativePrompts.push(item.input.prompt); + } + } + + // Convert to array format + return Array.from(byModel.entries()) + .filter( + ([_, data]) => + data.positivePrompts.length > 0 || data.negativePrompts.length > 0, + ) + .map(([modelId, data]) => ({ + modelId, + ...data, + })); +} + +export function AnalyticsDialog({ onOpenChange }: AnalyticsDialogProps) { + const projectId = useProjectId(); + const analyticsDialogOpen = useVideoProjectStore( + (s) => s.analyticsDialogOpen, + ); + const setAnalyticsDialogOpen = useVideoProjectStore( + (s) => s.setAnalyticsDialogOpen, + ); + const { data: mediaItems = [] } = useProjectMediaItems(projectId); + + const modelStats = calculateModelStats(mediaItems); + const promptData = preparePromptData(mediaItems); + + const handleAnalyzePrompts = useCallback( + async (modelId: string): Promise => { + const modelItems = mediaItems.filter( + (item) => + item.kind === "generated" && + item.endpointId === modelId && + item.input?.prompt && + item.rating, + ); + const [analysis] = await analyzePrompts(modelItems); + return analysis; + }, + [mediaItems], + ); + + const handleOnOpenChange = (isOpen: boolean) => { + onOpenChange?.(isOpen); + setAnalyticsDialogOpen(isOpen); + }; + + return ( + + + + + + Analytics + + + + + + Model Performance + Prompt Analysis + + + +

+ Compare the performance of different AI models +

+ + {modelStats.length > 0 ? ( + + ) : ( +
+ No model data available yet +
+ )} +
+ + +

+ Analyze patterns in successful and unsuccessful prompts +

+ + +
+
+
+
+ ); +} diff --git a/src/components/analytics/model-performance.tsx b/src/components/analytics/model-performance.tsx new file mode 100644 index 0000000..664f9f3 --- /dev/null +++ b/src/components/analytics/model-performance.tsx @@ -0,0 +1,172 @@ +import { type ModelStats, getModelName } from "@/lib/analytics"; +import { + BarChart, + Bar, + XAxis, + YAxis, + CartesianGrid, + Tooltip, + Legend, + ResponsiveContainer, +} from "recharts"; +import { ToggleGroup, ToggleGroupItem } from "@/components/ui/toggle-group"; +import { + FilmIcon, + ImageIcon, + MicIcon, + MusicIcon, + LayersIcon, +} from "lucide-react"; +import { useState } from "react"; +import type { MediaType } from "@/data/store"; +import { AVAILABLE_ENDPOINTS } from "@/lib/fal"; + +interface ModelPerformanceProps { + data: ModelStats[]; +} + +type FilterType = MediaType | "all"; + +export function ModelPerformance({ data }: ModelPerformanceProps) { + const [selectedType, setSelectedType] = useState("video"); + + const handleFilterChange = (value: string | undefined) => { + if (!value) return; + setSelectedType(value as FilterType); + }; + + const getModelType = (modelId: string): MediaType | null => { + // First try exact match + const endpoint = AVAILABLE_ENDPOINTS.find((e) => e.endpointId === modelId); + if (endpoint) return endpoint.category; + + // Try matching base model ID (for variants like image-to-video) + const baseModelId = modelId.split("/").slice(0, -1).join("/"); + const baseEndpoint = AVAILABLE_ENDPOINTS.find( + (e) => e.endpointId === baseModelId, + ); + return baseEndpoint?.category || null; + }; + + const filteredData = data.filter((stat) => { + if (selectedType === "all") return true; + const modelType = getModelType(stat.modelId); + return modelType === selectedType; + }); + + const chartData = filteredData.map((stat) => ({ + name: getModelName(stat.modelId), + Positive: stat.positive, + Negative: stat.negative, + Unrated: stat.unrated, + total: stat.totalGenerations, + })); + + // Find the maximum value to set the domain + const maxValue = Math.max(...chartData.map((d) => d.total)); + // Calculate a nice round number for the max Y value + const yAxisMax = Math.ceil(maxValue / 5) * 5; + + return ( +
+ + + + Video + + + + Image + + + + Music + + + + Voice + + + + All + + + +
+ + + + + Math.round(value).toString()} + domain={[0, yAxisMax]} + ticks={Array.from({ length: yAxisMax + 1 }, (_, i) => i)} + allowDecimals={false} + /> + value} + labelFormatter={(label) => `Model: ${label}`} + /> + + {" "} + {/* Green */} + {" "} + {/* Red */} + {" "} + {/* Orange */} + + +
+ +
+ {filteredData.map((stat) => ( +
+

{getModelName(stat.modelId)}

+

+ {stat.totalGenerations} generations +

+
+

+ {" "} + Positive: {stat.positive} ({stat.positiveRate.toFixed(1)}%) +

+

+ {" "} + Negative: {stat.negative} ({stat.negativeRate.toFixed(1)}%) +

+

+ Unrated:{" "} + {stat.unrated} ({stat.unratedRate.toFixed(1)}%) +

+
+
+ ))} +
+
+ ); +} diff --git a/src/components/analytics/prompt-analysis.tsx b/src/components/analytics/prompt-analysis.tsx new file mode 100644 index 0000000..6e7c560 --- /dev/null +++ b/src/components/analytics/prompt-analysis.tsx @@ -0,0 +1,164 @@ +import { + type PromptAnalysis as PromptAnalysisType, + getModelName, +} from "@/lib/analytics"; +import { Badge } from "@/components/ui/badge"; +import { + LightbulbIcon, + CheckCircleIcon, + XCircleIcon, + SparklesIcon, +} from "lucide-react"; +import { ScrollArea } from "@/components/ui/scroll-area"; +import { Button } from "@/components/ui/button"; +import { useState } from "react"; + +interface ModelPrompts { + modelId: string; + positivePrompts: string[]; + negativePrompts: string[]; +} + +interface PromptAnalysisProps { + data: ModelPrompts[]; + onAnalyze: (modelId: string) => Promise; +} + +export function PromptAnalysis({ data, onAnalyze }: PromptAnalysisProps) { + const [analysisResults, setAnalysisResults] = useState< + Record + >({}); + const [analyzing, setAnalyzing] = useState>({}); + + if (data.length === 0) { + return ( +
+ No rated prompts available for analysis. Try rating some prompts as + positive or negative first. +
+ ); + } + + const handleAnalyze = async (modelId: string) => { + setAnalyzing((prev) => ({ ...prev, [modelId]: true })); + try { + const result = await onAnalyze(modelId); + setAnalysisResults((prev) => ({ ...prev, [modelId]: result })); + } finally { + setAnalyzing((prev) => ({ ...prev, [modelId]: false })); + } + }; + + return ( +
+ {data.map((model) => ( +
+
+

+ {getModelName(model.modelId)} +

+

+ {model.positivePrompts.length + model.negativePrompts.length}{" "} + rated prompts available ({model.positivePrompts.length} positive,{" "} + {model.negativePrompts.length} negative) +

+
+
+ {analysisResults[model.modelId] ? ( +
+ {/* Analysis Section */} +
+
+

+ + Successful Patterns +

+
    + {analysisResults[ + model.modelId + ].analysis.positivePatterns.map((pattern, i) => ( +
  • {pattern}
  • + ))} +
+
+ +
+

+ + Unsuccessful Patterns +

+
    + {analysisResults[ + model.modelId + ].analysis.negativePatterns.map((pattern, i) => ( +
  • {pattern}
  • + ))} +
+
+ +
+

+ + Recommendations +

+
    + {analysisResults[ + model.modelId + ].analysis.recommendations.map((rec, i) => ( +
  • {rec}
  • + ))} +
+
+
+ + {/* Example Prompts */} +
+
+

+ Example Successful Prompts +

+ + {model.positivePrompts.slice(0, 5).map((prompt, i) => ( +

+ {prompt} +

+ ))} +
+
+
+

+ Example Unsuccessful Prompts +

+ + {model.negativePrompts.slice(0, 5).map((prompt, i) => ( +

+ {prompt} +

+ ))} +
+
+
+
+ ) : ( +
+ +
+ )} +
+
+ ))} +
+ ); +} diff --git a/src/components/header.tsx b/src/components/header.tsx index df2a3f5..d0236b9 100644 --- a/src/components/header.tsx +++ b/src/components/header.tsx @@ -1,18 +1,31 @@ import { Button } from "@/components/ui/button"; import { Logo } from "./logo"; -import { SettingsIcon } from "lucide-react"; +import { BarChart2Icon, SettingsIcon } from "lucide-react"; +import { useVideoProjectStore } from "@/data/store"; export default function Header({ openKeyDialog, }: { openKeyDialog?: () => void; }) { + const setAnalyticsDialogOpen = useVideoProjectStore( + (s) => s.setAnalyticsDialogOpen, + ); + return (

+ )} +
+ + {/* Category Analysis Results */} + {categoryAnalysis[selectedType] && ( +
+

+ + {selectedType === "all" ? "Cross-Model Analysis" : `${selectedType} Models Analysis`} +

+ +
+
+

+ + Key Success Patterns +

+
    + {categoryAnalysis[selectedType].analysis.positivePatterns.map((pattern, i) => ( +
  • {pattern}
  • + ))} +
+
+ +
+

+ + Common Challenges +

+
    + {categoryAnalysis[selectedType].analysis.negativePatterns.map((pattern, i) => ( +
  • {pattern}
  • + ))} +
+
+ +
+

+ + Strategic Recommendations +

+
    + {categoryAnalysis[selectedType].analysis.recommendations.map((rec, i) => ( +
  • {rec}
  • + ))} +
+
+
+
+ )} + + {/* Individual Model Analysis */} + {filteredData.map((model) => (

- {model.positivePrompts.length + model.negativePrompts.length}{" "} - rated prompts available ({model.positivePrompts.length} positive,{" "} - {model.negativePrompts.length} negative) + {model.positivePrompts.length + model.negativePrompts.length} rated prompts available ({model.positivePrompts.length} positive, {model.negativePrompts.length} negative)

@@ -77,9 +221,7 @@ export function PromptAnalysis({ data, onAnalyze }: PromptAnalysisProps) { Successful Patterns
    - {analysisResults[ - model.modelId - ].analysis.positivePatterns.map((pattern, i) => ( + {analysisResults[model.modelId].analysis.positivePatterns.map((pattern, i) => (
  • {pattern}
  • ))}
@@ -91,9 +233,7 @@ export function PromptAnalysis({ data, onAnalyze }: PromptAnalysisProps) { Unsuccessful Patterns
    - {analysisResults[ - model.modelId - ].analysis.negativePatterns.map((pattern, i) => ( + {analysisResults[model.modelId].analysis.negativePatterns.map((pattern, i) => (
  • {pattern}
  • ))}
@@ -105,9 +245,7 @@ export function PromptAnalysis({ data, onAnalyze }: PromptAnalysisProps) { Recommendations
    - {analysisResults[ - model.modelId - ].analysis.recommendations.map((rec, i) => ( + {analysisResults[model.modelId].analysis.recommendations.map((rec, i) => (
  • {rec}
  • ))}
@@ -150,9 +288,7 @@ export function PromptAnalysis({ data, onAnalyze }: PromptAnalysisProps) { disabled={analyzing[model.modelId]} > - {analyzing[model.modelId] - ? "Analyzing Prompts..." - : "Get Prompting Insights"} + {analyzing[model.modelId] ? "Analyzing Prompts..." : "Get Prompting Insights"}
)} diff --git a/src/lib/analytics.ts b/src/lib/analytics.ts index 8be4e08..7a02817 100644 --- a/src/lib/analytics.ts +++ b/src/lib/analytics.ts @@ -128,13 +128,13 @@ export function getModelName(modelId: string): string { export async function analyzePrompts( mediaItems: MediaItem[], + customSystemPrompt?: string, ): Promise { // Group by model const byModel = new Map(); for (const item of mediaItems) { - if (item.kind !== "generated" || !item.input?.prompt || !item.rating) - continue; + if (item.kind !== "generated" || !item.input?.prompt || !item.rating) continue; const modelId = item.endpointId; if (!byModel.has(modelId)) { byModel.set(modelId, []); @@ -162,6 +162,7 @@ export async function analyzePrompts( positivePrompts, negativePrompts, modelId, + customSystemPrompt, ); return { @@ -180,6 +181,7 @@ async function analyzeWithGemini( positivePrompts: string[], negativePrompts: string[], modelId: string, + customSystemPrompt?: string, ): Promise<{ positivePatterns: string[]; negativePatterns: string[]; @@ -212,8 +214,7 @@ Example format: try { const { data } = await fal.subscribe("fal-ai/any-llm", { input: { - system_prompt: - "You are an AI prompt analysis assistant. Analyze patterns in successful and unsuccessful prompts to provide actionable insights. Always respond in valid JSON format.", + system_prompt: customSystemPrompt || "You are an AI prompt analysis assistant. Analyze patterns in successful and unsuccessful prompts to provide actionable insights. Always respond in valid JSON format.", prompt, model: "meta-llama/llama-3.2-1b-instruct", }, @@ -225,9 +226,7 @@ Example format: console.error("Failed to parse LLM response as JSON:", error); return { positivePatterns: ["Could not analyze patterns in successful prompts"], - negativePatterns: [ - "Could not analyze patterns in unsuccessful prompts", - ], + negativePatterns: ["Could not analyze patterns in unsuccessful prompts"], recommendations: ["Try rating more prompts to get better analysis"], }; }