From 8896dc73f2ea603c247c305b60b1469ef71a1736 Mon Sep 17 00:00:00 2001 From: BritishWerewolf Date: Tue, 26 Nov 2024 14:15:58 +0000 Subject: [PATCH] Add `RawImage.split()` function to split images into channels; Improved documentation and tests (#978) * Add tests for original slice method. * Add vslice and tests to retrieve the entire length of a column. * Add a test for slicing every other column. * Add method to return each channel as a separate array. * Add documentation. Fix TypeScript error for unsure type. * Remove vslice as it doesn't work as it should. Update documentation. Update tests. * Optimize `RawImage.split()` function * Use dummy test image * Update tensor unit tests * Wrap `.split()` result in `RawImage` * Update JSDoc * Update JSDoc * Update comments --------- Co-authored-by: Joshua Lochner --- src/utils/image.js | 30 ++++++++++++++++++++++++++++++ src/utils/tensor.js | 34 +++++++++++++++++++++++++++++++++- tests/utils/tensor.test.js | 27 +++++++++++++++++++++++++++ tests/utils/utils.test.js | 23 +++++++++++++++++++++++ 4 files changed, 113 insertions(+), 1 deletion(-) diff --git a/src/utils/image.js b/src/utils/image.js index 3941b812b..04562592a 100644 --- a/src/utils/image.js +++ b/src/utils/image.js @@ -658,6 +658,36 @@ export class RawImage { return clonedCanvas; } + /** + * Split this image into individual bands. This method returns an array of individual image bands from an image. + * For example, splitting an "RGB" image creates three new images each containing a copy of one of the original bands (red, green, blue). + * + * Inspired by PIL's `Image.split()` [function](https://pillow.readthedocs.io/en/latest/reference/Image.html#PIL.Image.Image.split). + * @returns {RawImage[]} An array containing bands. + */ + split() { + const { data, width, height, channels } = this; + + /** @type {typeof Uint8Array | typeof Uint8ClampedArray} */ + const data_type = /** @type {any} */(data.constructor); + const per_channel_length = data.length / channels; + + // Pre-allocate buffers for each channel + const split_data = Array.from( + { length: channels }, + () => new data_type(per_channel_length), + ); + + // Write pixel data + for (let i = 0; i < per_channel_length; ++i) { + const data_offset = channels * i; + for (let j = 0; j < channels; ++j) { + split_data[j][i] = data[data_offset + j]; + } + } + return split_data.map((data) => new RawImage(data, width, height, 1)); + } + /** * Helper method to update the image data. * @param {Uint8ClampedArray} data The new image data. diff --git a/src/utils/tensor.js b/src/utils/tensor.js index 536a8c249..dec65d1d7 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -340,10 +340,43 @@ export class Tensor { return this; } + /** + * Creates a deep copy of the current Tensor. + * @returns {Tensor} A new Tensor with the same type, data, and dimensions as the original. + */ clone() { return new Tensor(this.type, this.data.slice(), this.dims.slice()); } + /** + * Performs a slice operation on the Tensor along specified dimensions. + * + * Consider a Tensor that has a dimension of [4, 7]: + * ``` + * [ 1, 2, 3, 4, 5, 6, 7] + * [ 8, 9, 10, 11, 12, 13, 14] + * [15, 16, 17, 18, 19, 20, 21] + * [22, 23, 24, 25, 26, 27, 28] + * ``` + * We can slice against the two dims of row and column, for instance in this + * case we can start at the second element, and return to the second last, + * like this: + * ``` + * tensor.slice([1, -1], [1, -1]); + * ``` + * which would return: + * ``` + * [ 9, 10, 11, 12, 13 ] + * [ 16, 17, 18, 19, 20 ] + * ``` + * + * @param {...(number|number[]|null)} slices The slice specifications for each dimension. + * - If a number is given, then a single element is selected. + * - If an array of two numbers is given, then a range of elements [start, end (exclusive)] is selected. + * - If null is given, then the entire dimension is selected. + * @returns {Tensor} A new Tensor containing the selected elements. + * @throws {Error} If the slice input is invalid. + */ slice(...slices) { // This allows for slicing with ranges and numbers const newTensorDims = []; @@ -413,7 +446,6 @@ export class Tensor { data[i] = this_data[originalIndex]; } return new Tensor(this.type, data, newTensorDims); - } /** diff --git a/tests/utils/tensor.test.js b/tests/utils/tensor.test.js index 0d36954e3..1bede1984 100644 --- a/tests/utils/tensor.test.js +++ b/tests/utils/tensor.test.js @@ -51,6 +51,33 @@ describe("Tensor operations", () => { // TODO add tests for errors }); + describe("slice", () => { + it("should return a given row dim", async () => { + const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]); + const t2 = t1.slice(1); + const target = new Tensor("float32", [3, 4], [2]); + + compare(t2, target); + }); + + it("should return a range of rows", async () => { + const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]); + const t2 = t1.slice([1, 3]); + const target = new Tensor("float32", [3, 4, 5, 6], [2, 2]); + + compare(t2, target); + }); + + it("should return a crop", async () => { + const t1 = new Tensor("float32", Array.from({ length: 28 }, (_, i) => i + 1), [4, 7]); + const t2 = t1.slice([1, -1], [1, -1]); + + const target = new Tensor("float32", [9, 10, 11, 12, 13, 16, 17, 18, 19, 20], [2, 5]); + + compare(t2, target); + }); + }); + describe("stack", () => { const t1 = new Tensor("float32", [0, 1, 2, 3, 4, 5], [1, 3, 2]); diff --git a/tests/utils/utils.test.js b/tests/utils/utils.test.js index 7d50cbdc0..79c6dcc7f 100644 --- a/tests/utils/utils.test.js +++ b/tests/utils/utils.test.js @@ -62,11 +62,34 @@ describe("Utilities", () => { }); describe("Image utilities", () => { + const [width, height, channels] = [2, 2, 3]; + const data = Uint8Array.from({ length: width * height * channels }, (_, i) => i % 5); + const tiny_image = new RawImage(data, width, height, channels); + let image; beforeAll(async () => { image = await RawImage.fromURL("https://picsum.photos/300/200"); }); + it("Can split image into separate channels", async () => { + const image_data = tiny_image.split().map(x => x.data); + + const target = [ + new Uint8Array([0, 3, 1, 4]), // Reds + new Uint8Array([1, 4, 2, 0]), // Greens + new Uint8Array([2, 0, 3, 1]), // Blues + ]; + + compare(image_data, target); + }); + + it("Can splits channels for grayscale", async () => { + const image_data = tiny_image.grayscale().split().map(x => x.data); + const target = [new Uint8Array([1, 3, 2, 1])]; + + compare(image_data, target); + }); + it("Read image from URL", async () => { expect(image.width).toBe(300); expect(image.height).toBe(200);