diff --git a/stdlib/src/collections/inline_list.mojo b/stdlib/src/collections/inline_list.mojo index bf17dfa8ba..f8933ba952 100644 --- a/stdlib/src/collections/inline_list.mojo +++ b/stdlib/src/collections/inline_list.mojo @@ -19,6 +19,7 @@ from collections import InlineList ``` """ +from collections._index_normalization import normalize_index from sys.intrinsics import _type_is_eq from memory.maybe_uninitialized import UnsafeMaybeUninitialized @@ -145,15 +146,11 @@ struct InlineList[ElementType: CollectionElementNew, capacity: Int = 16](Sized): Returns: A reference to the item at the given index. """ - var index = Int(idx) - debug_assert( - -self._size <= index < self._size, "Index must be within bounds." + # Using UInt to avoid extra signed normalization in self._array + var normalized_index = normalize_index["InlineList"]( + idx, UInt(self._size) ) - - if index < 0: - index += len(self) - - return self._array[index].assume_initialized() + return self._array[normalized_index].assume_initialized() # ===-------------------------------------------------------------------===# # Trait implementations diff --git a/stdlib/src/collections/linked_list.mojo b/stdlib/src/collections/linked_list.mojo index 83bf614c47..ad0befdb9e 100644 --- a/stdlib/src/collections/linked_list.mojo +++ b/stdlib/src/collections/linked_list.mojo @@ -284,7 +284,7 @@ struct LinkedList[ elem.free() return value^ - fn pop[I: Indexer](mut self, owned i: I) raises -> ElementType: + fn pop[I: Indexer](mut self, idx: I) raises -> ElementType: """ Remove the ith element of the list, counting from the tail if given a negative index. @@ -295,12 +295,12 @@ struct LinkedList[ I: The type of index to use. Args: - i: The index of the element to get. + idx: The index of the element to get. Returns: Ownership of the indicated element. """ - var current = self._get_node_ptr(Int(i)) + var current = self._get_node_ptr(idx) if current: var node = current[] @@ -323,7 +323,7 @@ struct LinkedList[ self._size -= 1 return data^ - raise String("Invalid index for pop: {}").format(Int(i)) + raise String("Invalid index for pop: ", Int(idx)) fn maybe_pop(mut self) -> Optional[ElementType]: """ @@ -347,7 +347,7 @@ struct LinkedList[ elem.free() return value^ - fn maybe_pop[I: Indexer](mut self, owned i: I) -> Optional[ElementType]: + fn maybe_pop[I: Indexer](mut self, idx: I) -> Optional[ElementType]: """ Remove the ith element of the list, counting from the tail if given a negative index. @@ -358,12 +358,12 @@ struct LinkedList[ I: The type of index to use. Args: - i: The index of the element to get. + idx: The index of the element to get. Returns: The element, if it was found. """ - var current = self._get_node_ptr(Int(i)) + var current = self._get_node_ptr(idx) if not current: return Optional[ElementType]() diff --git a/stdlib/src/collections/list.mojo b/stdlib/src/collections/list.mojo index ba6b8d9760..677ce18005 100644 --- a/stdlib/src/collections/list.mojo +++ b/stdlib/src/collections/list.mojo @@ -20,6 +20,7 @@ from collections import List """ +from collections._index_normalization import normalize_index from os import abort from sys import sizeof from sys.intrinsics import _type_is_eq @@ -884,23 +885,8 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False]( Returns: A reference to the element at the given index. """ - - @parameter - if _type_is_eq[I, UInt](): - return (self.data + idx)[] - else: - var normalized_idx = Int(idx) - debug_assert( - -self.size <= normalized_idx < self.size, - "index: ", - normalized_idx, - " is out of bounds for `List` of size: ", - self.size, - ) - if normalized_idx < 0: - normalized_idx += len(self) - - return (self.data + normalized_idx)[] + var normalized_index = normalize_index["List"](idx, self.size) + return (self.data + normalized_index)[] @always_inline fn unsafe_get(ref self, idx: Int) -> ref [self] Self.T: diff --git a/stdlib/src/collections/string/string.mojo b/stdlib/src/collections/string/string.mojo index 395dcdce10..217ab86427 100644 --- a/stdlib/src/collections/string/string.mojo +++ b/stdlib/src/collections/string/string.mojo @@ -861,7 +861,8 @@ struct String( A new string containing the character at the specified position. """ # TODO(#933): implement this for unicode when we support llvm intrinsic evaluation at compile time - var normalized_idx = normalize_index["String"](idx, len(self)) + # Using UInt to avoid extra signed normalization in self._buffer + var normalized_idx = normalize_index["String"](idx, UInt(len(self))) var buf = Self._buffer_type(capacity=1) buf.append(self._buffer[normalized_idx]) buf.append(0) diff --git a/stdlib/src/memory/span.mojo b/stdlib/src/memory/span.mojo index 756ece77b8..ba784bd617 100644 --- a/stdlib/src/memory/span.mojo +++ b/stdlib/src/memory/span.mojo @@ -20,6 +20,7 @@ from memory import Span ``` """ +from collections._index_normalization import normalize_index from collections import InlineArray from sys.info import simdwidthof @@ -178,15 +179,8 @@ struct Span[ Returns: An element reference. """ - # TODO: Simplify this with a UInt type. - debug_assert( - -self._len <= Int(idx) < self._len, "index must be within bounds" - ) - # TODO(MSTDL-1086): optimize away SIMD/UInt normalization check - var offset = Int(idx) - if offset < 0: - offset += len(self) - return self._data[offset] + var normalized_index = normalize_index["Span"](idx, self._len) + return self._data[normalized_index] @always_inline fn __getitem__(self, slc: Slice) -> Self: diff --git a/stdlib/test/collections/test_list_getitem_invalid_index.mojo b/stdlib/test/collections/test_list_getitem_invalid_index.mojo index 34ef22d26c..d7be0efbc4 100644 --- a/stdlib/test/collections/test_list_getitem_invalid_index.mojo +++ b/stdlib/test/collections/test_list_getitem_invalid_index.mojo @@ -17,7 +17,7 @@ # CHECK-FAIL-LABEL: test_fail_list_index fn main(): print("== test_fail_list_index") - # CHECK-FAIL: index: 4 is out of bounds for `List` of size: 3 + # CHECK-FAIL: index out of bounds nums = List[Int](1, 2, 3) print(nums[4])