Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add rating system for generated media #53

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 160 additions & 0 deletions src/components/analytics/analytics-dialog.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import { useVideoProjectStore } from "@/data/store";
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import { BarChart2Icon } from "lucide-react";
import { useProjectId } from "@/data/store";
import { useProjectMediaItems } from "@/data/queries";
import { calculateModelStats, analyzePrompts } from "@/lib/analytics";
import { ModelPerformance } from "./model-performance";
import { PromptAnalysis } from "./prompt-analysis";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { useCallback } from "react";
import type { MediaItem } from "@/data/schema";
import type { MediaType } from "@/data/store";
import type { PromptAnalysis as PromptAnalysisType } from "@/lib/analytics";

interface AnalyticsDialogProps {
  onOpenChange?: (open: boolean) => void;
}

/**
 * Groups rated, generated-media prompts by the model (endpointId) that
 * produced them.
 *
 * Only items that are generated, carry a prompt, and have a rating are
 * considered. Ratings other than "positive" are bucketed as negative.
 *
 * @param mediaItems - all media items of the current project
 * @returns one entry per model that has at least one rated prompt:
 *   `{ modelId, positivePrompts, negativePrompts }`
 */
function preparePromptData(mediaItems: MediaItem[]) {
  const byModel = new Map<
    string,
    { positivePrompts: string[]; negativePrompts: string[] }
  >();

  for (const item of mediaItems) {
    // Skip anything that is not a rated generation with a prompt.
    if (item.kind !== "generated" || !item.input?.prompt || !item.rating) {
      continue;
    }
    const modelId = item.endpointId;
    // Get-or-create the bucket (avoids a non-null assertion on .get()).
    let modelData = byModel.get(modelId);
    if (!modelData) {
      modelData = { positivePrompts: [], negativePrompts: [] };
      byModel.set(modelId, modelData);
    }
    if (item.rating === "positive") {
      modelData.positivePrompts.push(item.input.prompt);
    } else {
      modelData.negativePrompts.push(item.input.prompt);
    }
  }

  // Every map entry is created together with at least one prompt push, so the
  // former "non-empty" filter was dead code and has been removed.
  return Array.from(byModel.entries()).map(([modelId, data]) => ({
    modelId,
    ...data,
  }));
}

/**
 * Project-level analytics dialog with two tabs: per-model rating performance
 * and LLM-driven prompt analysis.
 *
 * Open/close state lives in the global video-project store
 * (analyticsDialogOpen), so any component can toggle it; the optional
 * onOpenChange prop lets a parent observe open-state changes as well.
 */
export function AnalyticsDialog({ onOpenChange }: AnalyticsDialogProps) {
  const projectId = useProjectId();
  const analyticsDialogOpen = useVideoProjectStore(
    (s) => s.analyticsDialogOpen,
  );
  const setAnalyticsDialogOpen = useVideoProjectStore(
    (s) => s.setAnalyticsDialogOpen,
  );
  // Defaults to [] while the query loads, so the derived stats below are safe.
  const { data: mediaItems = [] } = useProjectMediaItems(projectId);

  // Recomputed on every render; fine for typical project sizes.
  // NOTE(review): consider useMemo if media lists grow large.
  const modelStats = calculateModelStats(mediaItems);
  const promptData = preparePromptData(mediaItems);

  // Analyze the rated prompts of a single model. Memoized on mediaItems so
  // child components receive a stable callback between renders.
  const handleAnalyzePrompts = useCallback(
    async (modelId: string): Promise<PromptAnalysisType> => {
      const modelItems = mediaItems.filter(
        (item) =>
          item.kind === "generated" &&
          item.endpointId === modelId &&
          item.input?.prompt &&
          item.rating,
      );
      // analyzePrompts returns an array; only the first result is used here.
      const [analysis] = await analyzePrompts(modelItems);
      return analysis;
    },
    [mediaItems],
  );

  // Cross-model analysis for a whole media category (e.g. all video models).
  const handleAnalyzeCategory = useCallback(
    async (category: MediaType, modelIds: string[]): Promise<PromptAnalysisType> => {
      // Get all rated prompts for the selected models
      const modelItems = mediaItems.filter(
        (item) =>
          item.kind === "generated" &&
          modelIds.includes(item.endpointId) &&
          item.input?.prompt &&
          item.rating
      );

      // Create a special system prompt for category analysis
      const systemPrompt = `You are an AI prompt analysis assistant specializing in ${category} generation.
Analyze patterns across multiple models to identify what works best for ${category} generation.
Compare and contrast different approaches, and provide strategic recommendations for using these models effectively.
Always respond in valid JSON format.`;

      // Call analyzePrompts with the category-specific system prompt
      const [analysis] = await analyzePrompts(modelItems, systemPrompt);
      return analysis;
    },
    [mediaItems],
  );

  // Keep the store in sync and notify the optional observer prop.
  const handleOnOpenChange = (isOpen: boolean) => {
    onOpenChange?.(isOpen);
    setAnalyticsDialogOpen(isOpen);
  };

  return (
    <Dialog open={analyticsDialogOpen} onOpenChange={handleOnOpenChange}>
      <DialogContent className="max-w-4xl">
        <DialogHeader>
          <DialogTitle className="flex items-center gap-2">
            <BarChart2Icon className="w-5 h-5" />
            <span>Analytics</span>
          </DialogTitle>
        </DialogHeader>

        <Tabs defaultValue="performance" className="space-y-4">
          <TabsList>
            <TabsTrigger value="performance">Model Performance</TabsTrigger>
            <TabsTrigger value="prompts">Prompt Analysis</TabsTrigger>
          </TabsList>

          <TabsContent value="performance" className="space-y-4">
            <p className="text-sm text-muted-foreground">
              Compare the performance of different AI models
            </p>

            {/* Empty-state message when no generations exist yet. */}
            {modelStats.length > 0 ? (
              <ModelPerformance data={modelStats} />
            ) : (
              <div className="min-h-[300px] flex items-center justify-center text-muted-foreground">
                No model data available yet
              </div>
            )}
          </TabsContent>

          <TabsContent value="prompts" className="space-y-4">
            <p className="text-sm text-muted-foreground">
              Analyze patterns in successful and unsuccessful prompts
            </p>

            <PromptAnalysis
              data={promptData}
              onAnalyze={handleAnalyzePrompts}
              onAnalyzeCategory={handleAnalyzeCategory}
            />
          </TabsContent>
        </Tabs>
      </DialogContent>
    </Dialog>
  );
}
172 changes: 172 additions & 0 deletions src/components/analytics/model-performance.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import { type ModelStats, getModelName } from "@/lib/analytics";
import {
BarChart,
Bar,
XAxis,
YAxis,
CartesianGrid,
Tooltip,
Legend,
ResponsiveContainer,
} from "recharts";
import { ToggleGroup, ToggleGroupItem } from "@/components/ui/toggle-group";
import {
FilmIcon,
ImageIcon,
MicIcon,
MusicIcon,
LayersIcon,
} from "lucide-react";
import { useState } from "react";
import type { MediaType } from "@/data/store";
import { AVAILABLE_ENDPOINTS } from "@/lib/fal";

interface ModelPerformanceProps {
  data: ModelStats[];
}

/** Media-type filter: a concrete MediaType or "all" (no filtering). */
type FilterType = MediaType | "all";

/**
 * Stacked bar chart plus summary cards comparing rating outcomes
 * (positive / negative / unrated) per model, filterable by media type.
 */
export function ModelPerformance({ data }: ModelPerformanceProps) {
  // Defaults to "video" — NOTE(review): confirm this default matches intent.
  const [selectedType, setSelectedType] = useState<FilterType>("video");

  const handleFilterChange = (value: string | undefined) => {
    // ToggleGroup (type="single") fires with undefined when the active item
    // is clicked again; ignore that so one filter always stays selected.
    if (!value) return;
    setSelectedType(value as FilterType);
  };

  // Resolve a model's media category from AVAILABLE_ENDPOINTS; fall back to
  // the parent path to cover endpoint variants (e.g. ".../image-to-video").
  const getModelType = (modelId: string): MediaType | null => {
    const endpoint = AVAILABLE_ENDPOINTS.find((e) => e.endpointId === modelId);
    if (endpoint) return endpoint.category;

    const baseModelId = modelId.split("/").slice(0, -1).join("/");
    const baseEndpoint = AVAILABLE_ENDPOINTS.find(
      (e) => e.endpointId === baseModelId,
    );
    return baseEndpoint?.category || null;
  };

  const filteredData = data.filter((stat) => {
    if (selectedType === "all") return true;
    return getModelType(stat.modelId) === selectedType;
  });

  const chartData = filteredData.map((stat) => ({
    name: getModelName(stat.modelId),
    Positive: stat.positive,
    Negative: stat.negative,
    Unrated: stat.unrated,
    total: stat.totalGenerations,
  }));

  // BUG FIX: Math.max() with no arguments returns -Infinity, which corrupted
  // the Y-axis domain/ticks whenever the selected filter matched no models.
  const maxValue =
    chartData.length > 0 ? Math.max(...chartData.map((d) => d.total)) : 0;
  // Round up to the next multiple of 5 (minimum 5) for a tidy axis ceiling.
  const yAxisMax = Math.max(5, Math.ceil(maxValue / 5) * 5);

  return (
    <div className="w-full space-y-6">
      <ToggleGroup
        type="single"
        value={selectedType}
        onValueChange={handleFilterChange}
        className="justify-start"
      >
        <ToggleGroupItem value="video" aria-label="Video models">
          <FilmIcon className="w-4 h-4 mr-2" />
          Video
        </ToggleGroupItem>
        <ToggleGroupItem value="image" aria-label="Image models">
          <ImageIcon className="w-4 h-4 mr-2" />
          Image
        </ToggleGroupItem>
        <ToggleGroupItem value="music" aria-label="Music models">
          <MusicIcon className="w-4 h-4 mr-2" />
          Music
        </ToggleGroupItem>
        <ToggleGroupItem value="voiceover" aria-label="Voiceover models">
          <MicIcon className="w-4 h-4 mr-2" />
          Voice
        </ToggleGroupItem>
        <ToggleGroupItem value="all" aria-label="All models">
          <LayersIcon className="w-4 h-4 mr-2" />
          All
        </ToggleGroupItem>
      </ToggleGroup>

      <div className="w-full h-[400px]">
        <ResponsiveContainer width="100%" height="100%">
          <BarChart
            data={chartData}
            margin={{
              top: 20,
              right: 30,
              left: 20,
              bottom: 100,
            }}
          >
            <CartesianGrid strokeDasharray="3 3" />
            <XAxis
              dataKey="name"
              angle={-45}
              textAnchor="end"
              height={100}
              interval={0}
            />
            {/* NOTE(review): one tick per integer may crowd the axis for
                projects with very large generation counts. */}
            <YAxis
              label={{
                value: "Number of Generations",
                angle: -90,
                position: "insideLeft",
                offset: 10,
              }}
              tickFormatter={(value) => Math.round(value).toString()}
              domain={[0, yAxisMax]}
              ticks={Array.from({ length: yAxisMax + 1 }, (_, i) => i)}
              allowDecimals={false}
            />
            <Tooltip
              formatter={(value: number) => value}
              labelFormatter={(label) => `Model: ${label}`}
            />
            <Legend />
            {/* Stack order: positive (green), negative (red), unrated (orange) */}
            <Bar dataKey="Positive" fill="hsl(142.1 76.2% 36.3%)" stackId="a" />
            <Bar dataKey="Negative" fill="hsl(346.8 77.2% 49.8%)" stackId="a" />
            <Bar dataKey="Unrated" fill="hsl(24.6 95% 53.1%)" stackId="a" />
          </BarChart>
        </ResponsiveContainer>
      </div>

      <div className="grid grid-cols-2 md:grid-cols-4 gap-4">
        {filteredData.map((stat) => (
          <div key={stat.modelId} className="p-4 rounded-lg bg-accent">
            <h3 className="font-medium mb-1">{getModelName(stat.modelId)}</h3>
            <p className="text-sm text-muted-foreground">
              {stat.totalGenerations} generations
            </p>
            <div className="mt-2 space-y-1 text-sm">
              <p>
                <span className="text-[hsl(142.1,76.2%,36.3%)]">●</span>{" "}
                Positive: {stat.positive} ({stat.positiveRate.toFixed(1)}%)
              </p>
              <p>
                <span className="text-[hsl(346.8,77.2%,49.8%)]">●</span>{" "}
                Negative: {stat.negative} ({stat.negativeRate.toFixed(1)}%)
              </p>
              <p>
                <span className="text-[hsl(24.6,95%,53.1%)]">●</span> Unrated:{" "}
                {stat.unrated} ({stat.unratedRate.toFixed(1)}%)
              </p>
            </div>
          </div>
        ))}
      </div>
    </div>
  );
}
Loading