Skip to content

Commit

Permalink
Even more nonnegative_int updating
Browse files: browse the repository at this point in the history
  • Loading branch information
lockshaw committed Jan 31, 2025
1 parent 3728251 commit f27d31b
Show file tree
Hide file tree
Showing 102 changed files with 981 additions and 975 deletions.
6 changes: 3 additions & 3 deletions lib/compiler/src/compiler/allowed_machine_views.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "utils/containers/unordered_multiset_of.h"
#include "utils/containers/unordered_set_of.h"
#include "utils/containers/zip.h"
#include "utils/nonnegative_int/ceildiv.h"
#include "utils/nonnegative_int/nonnegative_range.h"
#include "utils/nonnegative_int/num_elements.h"
#include "utils/overload.h"
Expand Down Expand Up @@ -52,9 +53,8 @@ static std::unordered_set<MachineView>
auto get_max_stride_upper_bound = [](std::vector<nonnegative_int> const &tensor_dims,
nonnegative_int total_devices) -> nonnegative_int {
nonnegative_int min_num_devices_with_full_stride_volume = product(transform(
tensor_dims, [](nonnegative_int num_devices) { return nonnegative_int{num_devices.value() - 1}; }));
return nonnegative_int{TODO colin
static_cast<int>(std::ceil(static_cast<float>(total_devices.value()) / min_num_devices_with_full_stride_volume.value()))};
tensor_dims, [](nonnegative_int num_devices) { return nonnegative_int{num_devices.unwrap_nonnegative() - 1}; }));
return ceildiv(total_devices, min_num_devices_with_full_stride_volume);
};

auto candidate_strides = [&](std::vector<nonnegative_int> const &tensor_dims,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ std::unordered_set<std::pair<MachineSpecification, MachineSpecification>>
for (int i = 1; i < resource.num_nodes; i *= 2) {
MachineSpecification sub_resource1 = resource;
MachineSpecification sub_resource2 = resource;
sub_resource1.num_nodes = i;
sub_resource2.num_nodes = resource.num_nodes - i;
sub_resource1.num_nodes = nonnegative_int{i};
sub_resource2.num_nodes = nonnegative_int{resource.num_nodes.unwrap_nonnegative() - i};
result.insert(std::make_pair(sub_resource1, sub_resource2));
result.insert(std::make_pair(sub_resource2, sub_resource1));
}

for (int i = 1; i < resource.num_gpus_per_node; i *= 2) {
MachineSpecification sub_resource1 = resource;
MachineSpecification sub_resource2 = resource;
sub_resource1.num_gpus_per_node = i;
sub_resource2.num_gpus_per_node = resource.num_gpus_per_node - i;
sub_resource1.num_gpus_per_node = nonnegative_int{i};
sub_resource2.num_gpus_per_node = nonnegative_int{resource.num_gpus_per_node.unwrap_nonnegative() - i};
result.insert(std::make_pair(sub_resource1, sub_resource2));
result.insert(std::make_pair(sub_resource2, sub_resource1));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ TEST_SUITE(FF_TEST_SUITE) {
ParallelTensorShape input_shape = ParallelTensorShape{
ParallelTensorDims{
FFOrdered<ShardParallelDim>{
ShardParallelDim{10, 2},
ShardParallelDim{12, 1},
ShardParallelDim{10_n, 2_n},
ShardParallelDim{12_n, 1_n},
},
ReplicaParallelDimSet{
SumDegree{1},
DiscardCopyDegree{1},
SumDegree{1_n},
DiscardCopyDegree{1_n},
},
},
DataType::FLOAT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@ using namespace ::FlexFlow;

TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("get_machine_resource_splits") {
auto make_machine_spec = [](int num_nodes, int num_gpus_per_node) {
auto make_machine_spec = [](nonnegative_int num_nodes, nonnegative_int num_gpus_per_node) {
return MachineSpecification{
/*num_nodes=*/num_nodes,
/*num_cpus_per_node=*/1,
/*num_cpus_per_node=*/1_n,
/*num_gpus_per_node=*/num_gpus_per_node,
/*inter_node_bandwidth=*/1.0,
/*intra_node_bandwidth=*/1.0,
};
};

SUBCASE("returns no splits if no splits are possible") {
MachineSpecification input = make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/1);
MachineSpecification input = make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/1_n);

std::unordered_set<std::pair<MachineSpecification, MachineSpecification>>
result = get_machine_resource_splits(input);
Expand All @@ -32,25 +32,25 @@ TEST_SUITE(FF_TEST_SUITE) {

SUBCASE(
"returns splits in gpu and node dimensions, but not at the same time") {
MachineSpecification input = make_machine_spec(/*num_nodes=*/2,
/*num_gpus_per_node=*/2);
MachineSpecification input = make_machine_spec(/*num_nodes=*/2_n,
/*num_gpus_per_node=*/2_n);

std::unordered_set<std::pair<MachineSpecification, MachineSpecification>>
result = get_machine_resource_splits(input);

std::unordered_set<std::pair<MachineSpecification, MachineSpecification>>
correct = {
{
make_machine_spec(/*num_nodes=*/2,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/2,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/2_n,
/*num_gpus_per_node=*/1_n),
make_machine_spec(/*num_nodes=*/2_n,
/*num_gpus_per_node=*/1_n),
},
{
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/2),
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/2),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/2_n),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/2_n),
},

};
Expand All @@ -60,8 +60,8 @@ TEST_SUITE(FF_TEST_SUITE) {

SUBCASE("returns splits in node dimension in powers of two") {
SUBCASE("num_nodes is a power of 2") {
MachineSpecification input = make_machine_spec(/*num_nodes=*/8,
/*num_gpus_per_node=*/1);
MachineSpecification input = make_machine_spec(/*num_nodes=*/8_n,
/*num_gpus_per_node=*/1_n);

std::unordered_set<
std::pair<MachineSpecification, MachineSpecification>>
Expand All @@ -71,43 +71,43 @@ TEST_SUITE(FF_TEST_SUITE) {
std::pair<MachineSpecification, MachineSpecification>>
correct = {
{
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/7,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/1_n),
make_machine_spec(/*num_nodes=*/7_n,
/*num_gpus_per_node=*/1_n),
},
{
make_machine_spec(/*num_nodes=*/2,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/6,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/2_n,
/*num_gpus_per_node=*/1_n),
make_machine_spec(/*num_nodes=*/6_n,
/*num_gpus_per_node=*/1_n),
},
{
make_machine_spec(/*num_nodes=*/4,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/4,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/4_n,
/*num_gpus_per_node=*/1_n),
make_machine_spec(/*num_nodes=*/4_n,
/*num_gpus_per_node=*/1_n),
},
{
make_machine_spec(/*num_nodes=*/6,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/2,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/6_n,
/*num_gpus_per_node=*/1_n),
make_machine_spec(/*num_nodes=*/2_n,
/*num_gpus_per_node=*/1_n),
},
{
make_machine_spec(/*num_nodes=*/7,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/7_n,
/*num_gpus_per_node=*/1_n),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/1_n),
},
};

CHECK(result == correct);
}

SUBCASE("num_nodes is not a power of 2") {
MachineSpecification input = make_machine_spec(/*num_nodes=*/6,
/*num_gpus_per_node=*/1);
MachineSpecification input = make_machine_spec(/*num_nodes=*/6_n,
/*num_gpus_per_node=*/1_n);

std::unordered_set<
std::pair<MachineSpecification, MachineSpecification>>
Expand All @@ -117,28 +117,28 @@ TEST_SUITE(FF_TEST_SUITE) {
std::pair<MachineSpecification, MachineSpecification>>
correct = {
{
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/5,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/1_n),
make_machine_spec(/*num_nodes=*/5_n,
/*num_gpus_per_node=*/1_n),
},
{
make_machine_spec(/*num_nodes=*/2,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/4,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/2_n,
/*num_gpus_per_node=*/1_n),
make_machine_spec(/*num_nodes=*/4_n,
/*num_gpus_per_node=*/1_n),
},
{
make_machine_spec(/*num_nodes=*/4,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/2,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/4_n,
/*num_gpus_per_node=*/1_n),
make_machine_spec(/*num_nodes=*/2_n,
/*num_gpus_per_node=*/1_n),
},
{
make_machine_spec(/*num_nodes=*/5,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/5_n,
/*num_gpus_per_node=*/1_n),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/1_n),
},
};

Expand All @@ -148,8 +148,8 @@ TEST_SUITE(FF_TEST_SUITE) {

SUBCASE("returns splits in gpu dimension in powers of two") {
SUBCASE("num_gpus_per_node is a power of 2") {
MachineSpecification input = make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/8);
MachineSpecification input = make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/8_n);

std::unordered_set<
std::pair<MachineSpecification, MachineSpecification>>
Expand All @@ -159,43 +159,43 @@ TEST_SUITE(FF_TEST_SUITE) {
std::pair<MachineSpecification, MachineSpecification>>
correct = {
{
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/7),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/1_n),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/7_n),
},
{
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/2),
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/6),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/2_n),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/6_n),
},
{
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/4),
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/4),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/4_n),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/4_n),
},
{
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/6),
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/2),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/6_n),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/2_n),
},
{
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/7),
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/7_n),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/1_n),
},
};

CHECK(result == correct);
}

SUBCASE("num_gpus_per_node is not a power of 2") {
MachineSpecification input = make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/6);
MachineSpecification input = make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/6_n);

std::unordered_set<
std::pair<MachineSpecification, MachineSpecification>>
Expand All @@ -205,28 +205,28 @@ TEST_SUITE(FF_TEST_SUITE) {
std::pair<MachineSpecification, MachineSpecification>>
correct = {
{
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/5),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/1_n),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/5_n),
},
{
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/2),
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/4),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/2_n),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/4_n),
},
{
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/4),
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/2),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/4_n),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/2_n),
},
{
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/5),
make_machine_spec(/*num_nodes=*/1,
/*num_gpus_per_node=*/1),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/5_n),
make_machine_spec(/*num_nodes=*/1_n,
/*num_gpus_per_node=*/1_n),
},
};
}
Expand Down
Loading

0 comments on commit f27d31b

Please sign in to comment.