Skip to content

Commit

Permalink
Improve KdTree memory allocation using ArraySegments
Browse files Browse the repository at this point in the history
  • Loading branch information
BobLd committed Jul 23, 2023
1 parent 76fc980 commit 8a82500
Show file tree
Hide file tree
Showing 6 changed files with 340 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
namespace UglyToad.PdfPig.DocumentLayoutAnalysis
{
using System;
using System.Collections.Generic;

/// <summary>
/// Useful <see cref="ArraySegment{T}"/> extensions.
/// </summary>
public static class ArraySegmentExtensions
{
/// <summary>
/// Returns a specified number of contiguous elements from the start of a sequence.
/// </summary>
/// <typeparam name="T">The type of the elements of <see name="source"/>.</typeparam>
/// <param name="source">An <see cref="ArraySegment{T}"/> to return elements from.</param>
/// <param name="count">The number of elements to return.</param>
/// <returns>An <see cref="ArraySegment{T}"/> that contains the specified number of elements from the start of the input sequence.</returns>
public static ArraySegment<T> Take<T>(this ArraySegment<T> source, int count)
{
return new ArraySegment<T>(source.Array, source.Offset, count);
}

/// <summary>
/// Bypasses a specified number of elements in a sequence and then returns the remaining elements.
/// </summary>
/// <typeparam name="T">The type of the elements of <see name="source"/>.</typeparam>
/// <param name="source">An <see cref="ArraySegment{T}"/> to return elements from.</param>
/// <param name="count">The number of elements to skip before returning the remaining elements.</param>
/// <returns>An <see cref="ArraySegment{T}"/> that contains the elements that occur after the specified index in the input sequence.</returns>
public static ArraySegment<T> Skip<T>(this ArraySegment<T> source, int count)
{
return new ArraySegment<T>(source.Array, source.Offset + count, source.Count - count);
}

/// <summary>
/// Sorts the elements in a <see cref="ArraySegment{T}"/> using the specified <see cref="IComparer{T}"/>.
/// </summary>
/// <typeparam name="T">The type of the elements of <see name="source"/>.</typeparam>
/// <param name="source">The <see cref="ArraySegment{T}"/> to sort.</param>
/// <param name="comparer">The implementation to use when comparing elements.</param>
public static void Sort<T>(this ArraySegment<T> source, IComparer<T> comparer)
{
Array.Sort(source.Array, source.Offset, source.Count, comparer);
}

/// <summary>
/// Returns the element at a specified index in a sequence.
/// </summary>
/// <typeparam name="T">The type of the elements of <see name="source"/>.</typeparam>
/// <param name="source">The <see cref="ArraySegment{T}"/> to get the element from.</param>
/// <param name="index">The index of the element to retrieve.</param>
/// <returns>The element at the specified position in the <see name="source"/> sequence.</returns>
public static T GetAt<T>(this ArraySegment<T> source, int index)
{
return source.Array[source.Offset + index];
}
}
}
88 changes: 68 additions & 20 deletions src/UglyToad.PdfPig.DocumentLayoutAnalysis/KdTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ public PdfPoint FindNearestNeighbour(PdfPoint pivot, Func<PdfPoint, PdfPoint, do
/// <typeparam name="T"></typeparam>
public class KdTree<T>
{
private readonly KdTreeComparerY kdTreeComparerY = new KdTreeComparerY();
private readonly KdTreeComparerX kdTreeComparerX = new KdTreeComparerX();

/// <summary>
/// The root of the tree.
/// </summary>
Expand All @@ -77,40 +80,49 @@ public KdTree(IReadOnlyList<T> elements, Func<T, PdfPoint> elementsPointFunc)
}

Count = elements.Count;
Root = BuildTree(Enumerable.Range(0, elements.Count).Zip(elements, (e, p) => (e, elementsPointFunc(p), p)).ToArray(), 0);

KdTreeElement<T>[] array = new KdTreeElement<T>[Count];

for (int i = 0; i < Count; i++)
{
var el = elements[i];
array[i] = new KdTreeElement<T>(i, elementsPointFunc(el), el);
}

Root = BuildTree(new ArraySegment<KdTreeElement<T>>(array));
}

private KdTreeNode<T> BuildTree((int, PdfPoint, T)[] P, int depth)
private KdTreeNode<T> BuildTree(ArraySegment<KdTreeElement<T>> P, int depth = 0)
{
if (P.Length == 0)
if (P.Count == 0)
{
return null;
}
else if (P.Length == 1)
else if (P.Count == 1)
{
return new KdTreeLeaf<T>(P[0], depth);
return new KdTreeLeaf<T>(P.GetAt(0), depth);
}

if (depth % 2 == 0)
{
Array.Sort(P, (p0, p1) => p0.Item2.X.CompareTo(p1.Item2.X));
P.Sort(kdTreeComparerX);
}
else
{
Array.Sort(P, (p0, p1) => p0.Item2.Y.CompareTo(p1.Item2.Y));
P.Sort(kdTreeComparerY);
}

if (P.Length == 2)
if (P.Count == 2)
{
return new KdTreeNode<T>(new KdTreeLeaf<T>(P[0], depth + 1), null, P[1], depth);
return new KdTreeNode<T>(new KdTreeLeaf<T>(P.GetAt(0), depth + 1), null, P.GetAt(1), depth);
}

int median = P.Length / 2;
int median = P.Count / 2;

KdTreeNode<T> vLeft = BuildTree(P.Take(median).ToArray(), depth + 1);
KdTreeNode<T> vRight = BuildTree(P.Skip(median + 1).ToArray(), depth + 1);
KdTreeNode<T> vLeft = BuildTree(P.Take(median), depth + 1);
KdTreeNode<T> vRight = BuildTree(P.Skip(median + 1), depth + 1);

return new KdTreeNode<T>(vLeft, vRight, P[median], depth);
return new KdTreeNode<T>(vLeft, vRight, P.GetAt(median), depth);
}

#region NN
Expand Down Expand Up @@ -216,7 +228,7 @@ private static (KdTreeNode<T>, double?) FindNearestNeighbour(KdTreeNode<T> node,
{
var kdTreeNodes = new KNearestNeighboursQueue(k);
FindNearestNeighbours(Root, pivot, k, pivotPointFunc, distanceMeasure, kdTreeNodes);
return kdTreeNodes.SelectMany(n => n.Value.Select(e => (e.Element, e.Index, n.Key))).ToList();
return kdTreeNodes.SelectMany(n => n.Value.Select(e => (e.Element, e.Index, n.Key))).ToArray();
}

private static (KdTreeNode<T>, double) FindNearestNeighbours(KdTreeNode<T> node, T pivot, int k,
Expand Down Expand Up @@ -350,6 +362,38 @@ public void Add(double key, KdTreeNode<T> value)
}
#endregion

internal readonly struct KdTreeElement<R>
{
internal KdTreeElement(int index, PdfPoint point, R value)
{
Index = index;
Value = point;
Element = value;
}

public int Index { get; }

public PdfPoint Value { get; }

public R Element { get; }
}

private sealed class KdTreeComparerY : IComparer<KdTreeElement<T>>
{
public int Compare(KdTreeElement<T> p0, KdTreeElement<T> p1)
{
return p0.Value.Y.CompareTo(p1.Value.Y);
}
}

private sealed class KdTreeComparerX : IComparer<KdTreeElement<T>>
{
public int Compare(KdTreeElement<T> p0, KdTreeElement<T> p1)
{
return p0.Value.X.CompareTo(p1.Value.X);
}
}

/// <summary>
/// K-D tree leaf.
/// </summary>
Expand All @@ -361,7 +405,7 @@ public class KdTreeLeaf<Q> : KdTreeNode<Q>
/// </summary>
public override bool IsLeaf => true;

internal KdTreeLeaf((int, PdfPoint, Q) point, int depth)
internal KdTreeLeaf(KdTreeElement<Q> point, int depth)
: base(null, null, point, depth)
{ }

Expand Down Expand Up @@ -423,15 +467,15 @@ public class KdTreeNode<Q>
/// </summary>
public int Index { get; }

internal KdTreeNode(KdTreeNode<Q> leftChild, KdTreeNode<Q> rightChild, (int, PdfPoint, Q) point, int depth)
internal KdTreeNode(KdTreeNode<Q> leftChild, KdTreeNode<Q> rightChild, KdTreeElement<Q> point, int depth)
{
LeftChild = leftChild;
RightChild = rightChild;
Value = point.Item2;
Element = point.Item3;
Value = point.Value;
Element = point.Element;
Depth = depth;
IsAxisCutX = depth % 2 == 0;
Index = point.Item1;
Index = point.Index;
}

/// <summary>
Expand All @@ -447,7 +491,11 @@ public IEnumerable<KdTreeLeaf<Q>> GetLeaves()

private void RecursiveGetLeaves(KdTreeNode<Q> leaf, ref List<KdTreeLeaf<Q>> leaves)
{
if (leaf == null) return;
if (leaf == null)
{
return;
}

if (leaf is KdTreeLeaf<Q> lLeaf)
{
leaves.Add(lLeaf);
Expand Down
Loading

0 comments on commit 8a82500

Please sign in to comment.