-
Notifications
You must be signed in to change notification settings - Fork 100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
#17477: Introduce ND coordinate system for TT-distributed #17745
base: main
Are you sure you want to change the base?
Conversation
bool eq_spans(const ArrayType& a, const ArrayType& b) { | ||
return std::equal(a.begin(), a.end(), b.begin(), b.end()); | ||
} | ||
bool eq_spans(const auto a, const auto b) { return std::equal(a.begin(), a.end(), b.begin(), b.end()); } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to do this because a
is now tt::stl::Span
while b
is std::span
. Annoying, but this keeps Metal at cpp17.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great overall. One thing to keep in mind is that our existing 2D coordinate system in Metal (exposed through CoreCoord
, CoreRange
and CoreRangeSet
) provides a bunch of utility functions, allowing users to compute set/range intersections, adjacency, etc.
It would be very useful for us to expose similar APIs for this new ND coordinate system as well. Especially as we start introducing more heterogeneity in our workloads.
mesh_device_->num_cols()); | ||
return buffers_[device_coord.row][device_coord.col]; | ||
return get_device_buffer(MeshCoordinate(device_coord.row, device_coord.col)); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should add a comment saying that this overload is kept around to be compatible with existing infra and will be removed once everything is migrated to use MeshCoordinate
.
@@ -218,7 +218,11 @@ std::vector<IDevice*> MeshDevice::get_devices() const { return view_->get_device | |||
|
|||
// TODO: Remove this function once we have a proper view interface | |||
IDevice* MeshDevice::get_device(size_t row_idx, size_t col_idx) const { | |||
return this->get_device_index(row_idx * num_cols() + col_idx); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here, I don't see why we need this overload to expose physical devices using [row, col]
in the long term. This doesn't make sense for an ND mesh anyway
} | ||
|
||
MeshCoordinate::MeshCoordinate(uint32_t coord) : value_({coord}) {} | ||
MeshCoordinate::MeshCoordinate(uint32_t row, uint32_t col) : value_({row, col}) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The use of [row, col]
should be removed eventually. I expect that users will be extremely confused by two different notations being exposed by the same data structure.
Fundamentally, removing this requires MeshDevice
to start using a Cartesian scheme.
} | ||
|
||
MeshCoordinateRange::MeshCoordinateRange(const SimpleMeshShape& shape) : | ||
MeshCoordinateRange(zero_coordinate(shape.dims()), shape_back(shape)) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approving on pad file changes
Ticket
#17477
Problem description
Existing mesh infra assumes 2D. This assumption won't hold in the future.
What's changed
Introduce a new
SimpleMeshShape
that will gradually replace the existingMeshShape
, after which it will be renamed toMeshShape
.Introduce
MeshCoordinate
,MeshCoordinateRange
, andMeshContainer
- primitives designed to work with the new ND coordinate system.MeshContainer
allows efficient flat representation of various metadata that matches the mesh shape. Iterators are available to make it easy to use.MeshCoordinate
along with strides that are precomputed onSimpleMeshShape
allows for an easy point access. The integration withMeshBuffer
demonstrates the use case.Next steps:
MeshShape
,MeshOffset
, and the related aliases with the newSimpleMeshShape
, andMeshCoordinate
.CoreCoord
, for now. Cores are fundamentally in 2D, so a more specialized system can be used for efficiency. Also it is not desired to makeCoreCoord
to interop withMeshCoordinate
- the 2 sets of coordinates mean entirely different concepts.Checklist