diff --git a/apps/web/src/app/api/trpc/[trpc]/route.ts b/apps/web/src/app/api/trpc/[trpc]/route.ts index 677ef2814fc..5c69ef06a4d 100644 --- a/apps/web/src/app/api/trpc/[trpc]/route.ts +++ b/apps/web/src/app/api/trpc/[trpc]/route.ts @@ -28,11 +28,16 @@ const handler = (req: NextRequest) => { } : undefined; + const ip = + process.env.NODE_ENV === "development" ? "127.0.0.1" : ipAddress(req); + + const identifier = + session?.user?.id ?? req.headers.get("x-vercel-ja4-digest") ?? ip; + return { user, locale, - ip: - process.env.NODE_ENV === "development" ? "127.0.0.1" : ipAddress(req), + identifier, } satisfies TRPCContext; }, onError({ error }) { diff --git a/apps/web/src/trpc/context.ts b/apps/web/src/trpc/context.ts index ffba2156f07..74bcca84179 100644 --- a/apps/web/src/trpc/context.ts +++ b/apps/web/src/trpc/context.ts @@ -11,5 +11,5 @@ type User = { export type TRPCContext = { user?: User; locale?: string; - ip?: string; + identifier?: string; }; diff --git a/apps/web/src/trpc/routers/auth.ts b/apps/web/src/trpc/routers/auth.ts index a87c17778e1..e66472153c0 100644 --- a/apps/web/src/trpc/routers/auth.ts +++ b/apps/web/src/trpc/routers/auth.ts @@ -29,7 +29,7 @@ export const auth = router({ return { isRegistered: count > 0 }; }), requestRegistration: publicProcedure - .use(createRateLimitMiddleware(5, "1 m")) + .use(createRateLimitMiddleware("request_registration", 5, "1 m")) .input( z.object({ name: z.string().min(1).max(100), diff --git a/apps/web/src/trpc/routers/polls.ts b/apps/web/src/trpc/routers/polls.ts index f28b46f8e94..9d0555bd3bd 100644 --- a/apps/web/src/trpc/routers/polls.ts +++ b/apps/web/src/trpc/routers/polls.ts @@ -130,7 +130,7 @@ export const polls = router({ // START LEGACY ROUTES create: possiblyPublicProcedure - .use(createRateLimitMiddleware(20, "1 h")) + .use(createRateLimitMiddleware("create_poll", 10, "1 h")) .use(requireUserMiddleware) .input( z.object({ @@ -233,7 +233,6 @@ export const polls = router({ return { id: poll.id }; }), update: possiblyPublicProcedure - .use(createRateLimitMiddleware(60, "1 h")) .input( z.object({ urlId: z.string(), @@ -306,7 +305,6 @@ export const polls = router({ }); }), delete: possiblyPublicProcedure - .use(createRateLimitMiddleware(30, "1 h")) .input( z.object({ urlId: z.string(), diff --git a/apps/web/src/trpc/routers/polls/comments.ts b/apps/web/src/trpc/routers/polls/comments.ts index b4de8091438..d541bfcbefe 100644 --- a/apps/web/src/trpc/routers/polls/comments.ts +++ b/apps/web/src/trpc/routers/polls/comments.ts @@ -72,7 +72,7 @@ export const comments = router({ }); }), add: publicProcedure - .use(createRateLimitMiddleware(5, "1 m")) + .use(createRateLimitMiddleware("add_comment", 5, "1 m")) .use(requireUserMiddleware) .input( z.object({ diff --git a/apps/web/src/trpc/routers/polls/participants.ts b/apps/web/src/trpc/routers/polls/participants.ts index 49d360c3b52..d1c571ede06 100644 --- a/apps/web/src/trpc/routers/polls/participants.ts +++ b/apps/web/src/trpc/routers/polls/participants.ts @@ -105,7 +105,6 @@ export const participants = router({ return participants; }), delete: publicProcedure - .use(createRateLimitMiddleware(20, "1 m")) .input( z.object({ participantId: z.string(), @@ -123,7 +122,7 @@ export const participants = router({ }); }), add: publicProcedure - .use(createRateLimitMiddleware(20, "1 m")) + .use(createRateLimitMiddleware("add_participant", 5, "1 m")) .use(requireUserMiddleware) .input( z.object({ @@ -218,7 +217,6 @@ export const participants = router({ return participant; }), rename: publicProcedure - .use(createRateLimitMiddleware(20, "1 m")) .input(z.object({ participantId: z.string(), newName: z.string() })) .mutation(async ({ input: { participantId, newName } }) => { await prisma.participant.update({ @@ -232,7 +230,6 @@ export const participants = router({ }); }), update: publicProcedure - .use(createRateLimitMiddleware(20, "1 m")) .input( z.object({ pollId: z.string(), diff --git a/apps/web/src/trpc/routers/user.ts b/apps/web/src/trpc/routers/user.ts index f14f4fc1158..fd60206667c 100644 --- a/apps/web/src/trpc/routers/user.ts +++ b/apps/web/src/trpc/routers/user.ts @@ -38,22 +38,20 @@ export const user = router({ }, }); }), - delete: privateProcedure - .use(createRateLimitMiddleware(5, "1 h")) - .mutation(async ({ ctx }) => { - if (ctx.user.isGuest) { - throw new TRPCError({ - code: "BAD_REQUEST", - message: "Guest users cannot be deleted", - }); - } - - await prisma.user.delete({ - where: { - id: ctx.user.id, - }, + delete: privateProcedure.mutation(async ({ ctx }) => { + if (ctx.user.isGuest) { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "Guest users cannot be deleted", }); - }), + } + + await prisma.user.delete({ + where: { + id: ctx.user.id, + }, + }); + }), subscription: publicProcedure.query( async ({ ctx }): Promise<{ legacy?: boolean; active: boolean }> => { if (!ctx.user || ctx.user.isGuest) { @@ -67,7 +65,6 @@ export const user = router({ }, ), changeName: privateProcedure - .use(createRateLimitMiddleware(20, "1 h")) .input( z.object({ name: z.string().min(1).max(100), @@ -84,7 +81,6 @@ export const user = router({ }); }), updatePreferences: privateProcedure - .use(createRateLimitMiddleware(30, "1 h")) .input( z.object({ locale: z.string().optional(), @@ -111,7 +107,7 @@ export const user = router({ return { success: true }; }), requestEmailChange: privateProcedure - .use(createRateLimitMiddleware(10, "1 h")) + .use(createRateLimitMiddleware("request_email_change", 10, "1 h")) .input(z.object({ email: z.string().email() })) .mutation(async ({ input, ctx }) => { const currentUser = await prisma.user.findUnique({ @@ -163,7 +159,7 @@ export const user = router({ return { success: true as const }; }), getAvatarUploadUrl: privateProcedure - .use(createRateLimitMiddleware(20, "1 h")) + .use(createRateLimitMiddleware("get_avatar_upload_url", 10, "1 h")) .input( z.object({ fileType: z.enum(["image/jpeg", "image/png"]), @@ -209,7 +205,6 @@ export const user = router({ }), updateAvatar: privateProcedure .input(z.object({ imageKey: z.string().max(255) })) - .use(createRateLimitMiddleware(10, "1 h")) .mutation(async ({ ctx, input }) => { const userId = ctx.user.id; const oldImageKey = ctx.user.image; diff --git a/apps/web/src/trpc/trpc.ts b/apps/web/src/trpc/trpc.ts index c09f66bc1d1..7c50efef7f7 100644 --- a/apps/web/src/trpc/trpc.ts +++ b/apps/web/src/trpc/trpc.ts @@ -90,6 +90,7 @@ export const proProcedure = privateProcedure.use(async ({ ctx, next }) => { }); export const createRateLimitMiddleware = ( + name: string, requests: number, duration: "1 m" | "1 h", ) => { @@ -98,20 +99,27 @@ export const createRateLimitMiddleware = ( return next(); } - if (!ctx.ip) { + if (!ctx.identifier) { throw new TRPCError({ code: "INTERNAL_SERVER_ERROR", - message: "Failed to get client IP", + message: "Failed to get identifier", }); } + const ratelimit = new Ratelimit({ redis: kv, limiter: Ratelimit.slidingWindow(requests, duration), }); - const res = await ratelimit.limit(ctx.ip); + const res = await ratelimit.limit(`${name}:${ctx.identifier}`); if (!res.success) { + console.warn("Rate limit exceeded", { + identifier: ctx.identifier, + endpoint: name, + limit: requests, + duration, + }); throw new TRPCError({ code: "TOO_MANY_REQUESTS", message: "Too many requests",