diff --git a/src/core.ts b/src/core.ts index 3591728..c837dc6 100644 --- a/src/core.ts +++ b/src/core.ts @@ -349,7 +349,11 @@ export abstract class APIClient { delete reqHeaders['content-type']; } - reqHeaders['x-stainless-retry-count'] = String(retryCount); + // Don't set the retry count header if it was already set or removed by the caller. We check `headers`, + // which can contain nulls, instead of `reqHeaders` to account for the removal case. + if (getHeader(headers, 'x-stainless-retry-count') === undefined) { + reqHeaders['x-stainless-retry-count'] = String(retryCount); + } this.validateHeaders(reqHeaders, headers); @@ -1128,7 +1132,15 @@ export const isHeadersProtocol = (headers: any): headers is HeadersProtocol => { return typeof headers?.get === 'function'; }; -export const getRequiredHeader = (headers: HeadersLike, header: string): string => { +export const getRequiredHeader = (headers: HeadersLike | Headers, header: string): string => { + const foundHeader = getHeader(headers, header); + if (foundHeader === undefined) { + throw new Error(`Could not find ${header} header`); + } + return foundHeader; +}; + +export const getHeader = (headers: HeadersLike | Headers, header: string): string | undefined => { const lowerCasedHeader = header.toLowerCase(); if (isHeadersProtocol(headers)) { // to deal with the case where the header looks like Stainless-Event-Id @@ -1154,7 +1166,7 @@ export const getRequiredHeader = (headers: HeadersLike, header: string): string } } - throw new Error(`Could not find ${header} header`); + return undefined; }; /** diff --git a/tests/index.test.ts b/tests/index.test.ts index 28469d4..fb1e7c9 100644 --- a/tests/index.test.ts +++ b/tests/index.test.ts @@ -266,6 +266,64 @@ describe('retries', () => { expect(count).toEqual(3); }); + test('omit retry count header', async () => { + let count = 0; + let capturedRequest: RequestInit | undefined; + const testFetch = async (url: RequestInfo, init: RequestInit = {}): Promise => { + count++; + if (count <= 2) { + return new Response(undefined, { + status: 429, + headers: { + 'Retry-After': '0.1', + }, + }); + } + capturedRequest = init; + return new Response(JSON.stringify({ a: 1 }), { headers: { 'Content-Type': 'application/json' } }); + }; + const client = new RunwayML({ apiKey: 'My API Key', fetch: testFetch, maxRetries: 4 }); + + expect( + await client.request({ + path: '/foo', + method: 'get', + headers: { 'X-Stainless-Retry-Count': null }, + }), + ).toEqual({ a: 1 }); + + expect(capturedRequest!.headers as Headers).not.toHaveProperty('x-stainless-retry-count'); + }); + + test('overwrite retry count header', async () => { + let count = 0; + let capturedRequest: RequestInit | undefined; + const testFetch = async (url: RequestInfo, init: RequestInit = {}): Promise => { + count++; + if (count <= 2) { + return new Response(undefined, { + status: 429, + headers: { + 'Retry-After': '0.1', + }, + }); + } + capturedRequest = init; + return new Response(JSON.stringify({ a: 1 }), { headers: { 'Content-Type': 'application/json' } }); + }; + const client = new RunwayML({ apiKey: 'My API Key', fetch: testFetch, maxRetries: 4 }); + + expect( + await client.request({ + path: '/foo', + method: 'get', + headers: { 'X-Stainless-Retry-Count': '42' }, + }), + ).toEqual({ a: 1 }); + + expect((capturedRequest!.headers as Headers)['x-stainless-retry-count']).toBe('42'); + }); + test('retry on 429 with retry-after', async () => { let count = 0; const testFetch = async (url: RequestInfo, { signal }: RequestInit = {}): Promise => {