Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support file uploads using a stream #252

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/strange-chairs-search.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@google/generative-ai": minor
---

Support file uploads using a stream
6 changes: 5 additions & 1 deletion common/api-review/generative-ai-server.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

```ts

/// <reference types="node" />

import { Readable } from 'node:stream';

// @public
export interface CachedContent extends CachedContentBase {
createTime?: string;
Expand Down Expand Up @@ -347,7 +351,7 @@ export class GoogleAIFileManager {
deleteFile(fileId: string): Promise<void>;
getFile(fileId: string, requestOptions?: SingleRequestOptions): Promise<FileMetadataResponse>;
listFiles(listParams?: ListParams, requestOptions?: SingleRequestOptions): Promise<ListFilesResponse>;
uploadFile(filePath: string, fileMetadata: FileMetadata): Promise<UploadFileResponse>;
uploadFile(filePathOrStream: string | Readable, fileMetadata: FileMetadata): Promise<UploadFileResponse>;
}

// @public
Expand Down
2 changes: 1 addition & 1 deletion docs/reference/server/generative-ai.googleaifilemanager.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ export declare class GoogleAIFileManager
| [deleteFile(fileId)](./generative-ai.googleaifilemanager.deletefile.md) | | Delete file with given ID. |
| [getFile(fileId, requestOptions)](./generative-ai.googleaifilemanager.getfile.md) | | <p>Get metadata for file with given ID.</p><p>Any fields set in the optional [SingleRequestOptions](./generative-ai.singlerequestoptions.md) parameter will take precedence over the [RequestOptions](./generative-ai.requestoptions.md) values provided at the time of the [GoogleAIFileManager](./generative-ai.googleaifilemanager.md) initialization.</p> |
| [listFiles(listParams, requestOptions)](./generative-ai.googleaifilemanager.listfiles.md) | | <p>List all uploaded files.</p><p>Any fields set in the optional [SingleRequestOptions](./generative-ai.singlerequestoptions.md) parameter will take precedence over the [RequestOptions](./generative-ai.requestoptions.md) values provided at the time of the [GoogleAIFileManager](./generative-ai.googleaifilemanager.md) initialization.</p> |
| [uploadFile(filePath, fileMetadata)](./generative-ai.googleaifilemanager.uploadfile.md) | | Upload a file. |
| [uploadFile(filePathOrStream, fileMetadata)](./generative-ai.googleaifilemanager.uploadfile.md) | | Upload a file. |

Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ Upload a file.
**Signature:**

```typescript
uploadFile(filePath: string, fileMetadata: FileMetadata): Promise<UploadFileResponse>;
uploadFile(filePathOrStream: string | Readable, fileMetadata: FileMetadata): Promise<UploadFileResponse>;
```

## Parameters

| Parameter | Type | Description |
| --- | --- | --- |
| filePath | string | |
| filePathOrStream | string \| Readable | |
| fileMetadata | [FileMetadata](./generative-ai.filemetadata.md) | |

**Returns:**
Expand Down
24 changes: 16 additions & 8 deletions src/server/file-manager.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import * as request from "./request";
import { RpcTask } from "./constants";
import { DEFAULT_API_VERSION } from "../requests/request";
import { FileMetadata } from "../../types/server";

import { blob } from "node:stream/consumers";
use(sinonChai);
use(chaiAsPromised);

Expand Down Expand Up @@ -56,8 +56,10 @@ describe("GoogleAIFileManager", () => {
expect(makeRequestStub.args[0][1].get("X-Goog-Upload-Protocol")).to.equal(
"multipart",
);
expect(makeRequestStub.args[0][2]).to.be.instanceOf(Blob);
const bodyBlob = makeRequestStub.args[0][2];
expect(makeRequestStub.args[0][2]).to.have.property("next");
const bodyBlob = await blob(
makeRequestStub.args[0][2] as any as NodeJS.ReadableStream,
);
const blobText = await (bodyBlob as Blob).text();
expect(blobText).to.include("Content-Type: image/png");
});
Expand All @@ -73,8 +75,10 @@ describe("GoogleAIFileManager", () => {
displayName: "mydisplayname",
});
expect(result.file.uri).to.equal(FAKE_URI);
expect(makeRequestStub.args[0][2]).to.be.instanceOf(Blob);
const bodyBlob = makeRequestStub.args[0][2];
expect(makeRequestStub.args[0][2]).to.have.property("next");
const bodyBlob = await blob(
makeRequestStub.args[0][2] as any as NodeJS.ReadableStream,
);
const blobText = await (bodyBlob as Blob).text();
expect(blobText).to.include("Content-Type: image/png");
expect(blobText).to.include("files/customname");
Expand All @@ -91,7 +95,9 @@ describe("GoogleAIFileManager", () => {
name: "customname",
displayName: "mydisplayname",
});
const bodyBlob = makeRequestStub.args[0][2];
const bodyBlob = await blob(
makeRequestStub.args[0][2] as any as NodeJS.ReadableStream,
);
const blobText = await (bodyBlob as Blob).text();
expect(blobText).to.include("files/customname");
});
Expand All @@ -114,8 +120,10 @@ describe("GoogleAIFileManager", () => {
expect(makeRequestStub.args[0][1].get("X-Goog-Upload-Protocol")).to.equal(
"multipart",
);
expect(makeRequestStub.args[0][2]).to.be.instanceOf(Blob);
const bodyBlob = makeRequestStub.args[0][2];
expect(makeRequestStub.args[0][2]).to.have.property("next");
const bodyBlob = await blob(
makeRequestStub.args[0][2] as any as NodeJS.ReadableStream,
);
const blobText = await (bodyBlob as Blob).text();
expect(blobText).to.include("Content-Type: image/png");
expect(makeRequestStub.args[0][0].toString()).to.include("v3000/files");
Expand Down
47 changes: 29 additions & 18 deletions src/server/file-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/

import { RequestOptions, SingleRequestOptions } from "../../types";
import { readFileSync } from "fs";
import { createReadStream } from "node:fs";
import { FilesRequestUrl, getHeaders, makeServerRequest } from "./request";
import {
FileMetadata,
Expand All @@ -30,6 +30,7 @@ import {
GoogleGenerativeAIError,
GoogleGenerativeAIRequestInputError,
} from "../errors";
import { Readable } from "node:stream";

// Internal type, metadata sent in the upload
export interface UploadMetadata {
Expand All @@ -51,10 +52,14 @@ export class GoogleAIFileManager {
* Upload a file.
*/
async uploadFile(
filePath: string,
filePathOrStream: string | Readable,
fileMetadata: FileMetadata,
): Promise<UploadFileResponse> {
const file = readFileSync(filePath);
const file =
typeof filePathOrStream === "string"
? createReadStream(filePathOrStream)
: filePathOrStream;

const url = new FilesRequestUrl(
RpcTask.UPLOAD,
this.apiKey,
Expand All @@ -73,22 +78,28 @@ export class GoogleAIFileManager {

// Multipart formatting code taken from @firebase/storage
const metadataString = JSON.stringify({ file: uploadMetadata });
const preBlobPart =
const preBlobPart = new TextEncoder().encode(
"--" +
boundary +
"\r\n" +
"Content-Type: application/json; charset=utf-8\r\n\r\n" +
metadataString +
"\r\n--" +
boundary +
"\r\n" +
"Content-Type: " +
fileMetadata.mimeType +
"\r\n\r\n";
const postBlobPart = "\r\n--" + boundary + "--";
const blob = new Blob([preBlobPart, file, postBlobPart]);

const response = await makeServerRequest(url, uploadHeaders, blob);
boundary +
"\r\n" +
"Content-Type: application/json; charset=utf-8\r\n\r\n" +
metadataString +
"\r\n--" +
boundary +
"\r\n" +
"Content-Type: " +
fileMetadata.mimeType +
"\r\n\r\n",
);
const postBlobPart = new TextEncoder().encode("\r\n--" + boundary + "--");

const stream = (async function* () {
yield preBlobPart;
yield* file;
yield postBlobPart;
})();

const response = await makeServerRequest(url, uploadHeaders, stream);
return response.json();
}

Expand Down
5 changes: 3 additions & 2 deletions src/server/request.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,14 @@ describe("Files API - request methods", () => {
const response = await makeServerRequest(
url,
headers,
new Blob(),
(async function* () {})(),
fetchStub as typeof fetch,
);
expect(fetchStub).to.be.calledWith(match.string, {
method: "POST",
headers: match.instanceOf(Headers),
body: match.instanceOf(Blob),
body: match.instanceOf(ReadableStream),
duplex: "half",
});
expect(response.ok).to.be.true;
});
Expand Down
26 changes: 23 additions & 3 deletions src/server/request.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,36 @@ export function getHeaders(url: ServerRequestUrl): Headers {
export async function makeServerRequest(
url: FilesRequestUrl,
headers: Headers,
body?: Blob | string,
body?: Blob | string | AsyncIterable<Uint8Array>,
fetchFn: typeof fetch = fetch,
): Promise<Response> {
const requestInit: RequestInit = {
// Add the duplex option, which is required when streaming in newer versions of node.
// See: https://github.com/nodejs/node/issues/46221
const requestInit: RequestInit & { duplex?: "half" } = {
method: taskToMethod[url.task],
headers,
duplex: "half",
};

if (body) {
if (typeof body === "string" || body instanceof Blob) {
requestInit.body = body;
} else if (body?.[Symbol.asyncIterator]) {
// Note that in later versions, the signature `fetch` is updated to accept any AsyncIterator,
// and ReadableStream implements AsyncIterator. In this case, `body` can be passed exactly
// as supplied, and the following can be removed:
const iterator = body[Symbol.asyncIterator]();
requestInit.body = new ReadableStream({
type: "bytes",
async pull(controller) {
const { value, done } = await iterator.next();
if (done) {
controller.close();
return;
}

controller.enqueue(value);
},
});
}

const signal = getSignal(url.requestOptions);
Expand Down