
Support building graphs from MLTensor containing constants #760

Open
bbernhar opened this issue Sep 12, 2024 · 7 comments
bbernhar commented Sep 12, 2024

Demonstrate how MLTensor can be used to help web developers manage constant data (e.g., trained weights) on-device.

Dependent PRs

Motivation

  • Allow constant data to be uploaded directly to the device, a capability that Execution Providers (EPs) leverage to prevent out-of-memory (OOM) errors (ORT example).
  • Re-use constant buffers in system memory between graphs, particularly for encoder-decoder models like Whisper.

Design

An MLTensor containing constant data is associated with the graph when the MLOperand is created. At build(), the constant data is forwarded to the device. The original constant data (i.e., the ArrayBuffer input or the uploaded device data held by the MLTensor) can be discarded immediately after createConstant() succeeds.

Example JS

// Upload constant data directly to the device
const constantTensor = await ctx.createConstant(...); // immutable

const builder1 = new MLGraphBuilder(ctx);
const builder2 = new MLGraphBuilder(ctx);
const constantOp1 = builder1.constant(constantTensor);
const constantOp2 = builder2.constant(constantTensor);
// ...
const graph1 = await builder1.build(...);
const graph2 = await builder2.build(...);

// Optional: free-up system memory
constantTensor.destroy();

Proposed IDL

interface MLConstantTensor : MLTensor {};

partial interface MLContext {
    Promise<MLConstantTensor> createConstant(MLOperandDataType dataType, ArrayBufferView sourceData);
};

partial interface MLGraphBuilder {
    MLOperand constant(MLConstantTensor tensor);
};

Edits:

  • 9/16: Added MLOperandDescriptor as required by MLOperand
  • 9/18: Added constant-initializer to createTensor()
  • 9/19: Reuse input(..) via constant usage flag
  • 1/29: Have new tensor type passed to constant()
@bbernhar (Author)

@a-sully @RafaelCintron @fdwr @huningxin appreciate any feedback

fdwr (Collaborator) commented Sep 12, 2024

constant_input -> constantInput (🚫🐍).

I'd need more time to think for meaningful feedback, but it may be rather confusing having this list of methods o_o:

graphBuilder.input()
graphBuilder.constant()
graphBuilder.constantInput()

@bbernhar (Author)

> I'd need more time to think for meaningful feedback, but it may be rather confusing having this list of methods o_o:

Thanks for the quick feedback. I simplified the proposal even further via initializer + new usage bit.

mmccool commented Sep 23, 2024

Definitely interested in this from the point of view of caching models as well (especially weights that might be used by both WebGPU and WebNN implementations).

@bbernhar (Author)

Revised the API design based on PoC feedback.

a-sully (Contributor) commented Jan 30, 2025

Overall approach mostly LGTM. Feedback here is mostly cosmetic except for #2 (some of which I've already mentioned in comments on your prototype CL)

1. destroy() != Freeing Memory

> At build(), the (un-optimized) constant data will be copied into the device.

We should not prescribe anything regarding how weights are stored or managed. Once the source buffer data is copied, everything after that is implementation-defined, just as the weights created from MLGraphBuilder.constant() are today. For example:

  • Rather than holding all weights in memory, the user agent could choose to write weights to disk to reduce memory pressure until build() is called (and then e.g. mmapped later as needed). In that case, MLConstantTensor.destroy() may delete the file, but with no effects on system memory usage
  • Some backends support weights which are decoupled from the model architecture. In this case, several models may point to the same file-backed memory and MLConstantTensor.destroy() may have no immediate effects (other than disallowing more models from also using this data)
  • Some backends support caching "packed" weights. In this case, MLConstantTensor.destroy() may allow the user agent to delete the original (unpacked) copy
  • Some backends (e.g. DML) require their own copy of each weight (see discussion starting with Allow no-op graphs? #614 (comment) and concluding with Allow no-op graphs? #614 (comment)). In this case MLConstantTensor.destroy() may allow the user agent to delete the original (unpacked) copy

Basically, MLConstantTensor.destroy() may or may not actually free system resources - it's just a signal that these weights will not be used in subsequent graphs.

-// Optional: free-up system memory
+// Optional: maybe free-up system resources
constantTensor.destroy();

2. Does this need to extend MLTensor?

interface MLConstantTensor : MLTensor {};

What's the rationale for this? Are you expecting that an MLConstantTensor may be used in dispatch() and readTensor()?

I think we should err on the side of not extending MLTensor unless there are very strong use cases to support MLTensor functionality. The proposed use cases have no overlap with the functionality of MLTensor. I'd prefer to not allow an MLConstantTensor to be passed to dispatch() or readTensor() at all.

We could consider naming this something like MLWeight to distinguish this new type from MLTensor?... Naming is hard ¯\_(ツ)_/¯
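For illustration, the non-MLTensor alternative might look something like the following; the MLWeight name is a placeholder from the comment above, and the exact shape is speculative, not part of the actual proposal:

```webidl
// Hypothetical sketch: a dedicated weight type that does not inherit from
// MLTensor, so it cannot be passed to dispatch() or readTensor().
interface MLWeight {
  undefined destroy();
};

partial interface MLContext {
  Promise<MLWeight> createConstant(MLOperandDataType dataType,
                                   AllowSharedBufferSource buffer);
};

partial interface MLGraphBuilder {
  MLOperand constant(MLWeight weight);
};
```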

3. Let's use AllowSharedBufferSource?

See #788

-    Promise<MLConstantTensor> createConstant(MLOperandDataType dataType, ArrayBufferView sourceData);
+    Promise<MLConstantTensor> createConstant(MLOperandDataType dataType, AllowSharedBufferSource buffer);

4. Comparison to Multi-Build: #567 (comment)

It's worth noting that sharing weights between graphs could also be supported by allowing build() to build multiple graphs. Here's my understanding (let me know if I'm missing anything) of the tradeoffs of these approaches:

| MLConstantTensor | Multi-graph build() |
| --- | --- |
| ❌ Weight lifecycle management is exposed to developers | ✅ All graphs using these weights are known upfront, so no need for explicit create/destroy API surfaces |
| ✅ More naturally matches how ML frameworks generally load weights | ❌ May be harder for web ML frameworks to utilize |
| ✅ Graph compilation costs may be spread out as the developer sees fit | ❌ All graphs sharing weights are compiled at once, which may exacerbate issues of slow compilation |
| ✅ If graph compilation fails, weights do not need to be re-uploaded | ❌ Not obvious how the API should behave if some fraction of graphs fail to compile |
| ✅ Logical decoupling of weights from the model architecture | ❌ API suggests weights are tightly coupled to a model (which they might be under the hood, but that's an implementation detail) |

As listed here, the left column does seem superior :) Just wanted to note this to be explicit about the tradeoffs we're making

@bbernhar (Author)

Thanks for the summary @a-sully.

> We should not prescribe anything regarding how weights are stored or managed.

> Let's use AllowSharedBufferSource?

SGTM.

> What's the rationale for this? Are you expecting that an MLConstantTensor may be used in dispatch() and readTensor()?

We need a way to communicate to web developers that 'constant' tensors are read-only and non-dispatchable tensors. The WebGPU approach would be to introduce a new MLTensorUsage. Others have suggested extending MLTensor with MLConstantTensor, which would implicitly have this usage. I slightly prefer the WebGPU approach as naming it something other than tensor could cause some confusion.
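To make the two directions concrete, here is a rough sketch of the usage-flag approach, modeled on WebGPU's GPUBufferUsage namespace; the MLTensorUsage flag values and the usage member on MLTensorDescriptor are hypothetical, not part of the current proposal:

```webidl
// Hypothetical sketch: WebGPU-style usage flags on the existing MLTensor.
// A tensor created with the CONSTANT usage would be immutable and would be
// rejected by dispatch() and readTensor().
namespace MLTensorUsage {
  const unsigned long READ = 0x1;
  const unsigned long WRITE = 0x2;
  const unsigned long CONSTANT = 0x4;
};

partial dictionary MLTensorDescriptor {
  unsigned long usage;
};
```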
