diff --git a/.gitignore b/.gitignore index d118c3344..6afb875c8 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,5 @@ m4/ltoptions.m4 m4/ltsugar.m4 m4/ltversion.m4 m4/lt~obsolete.m4 + +.idea/ diff --git a/include/nccl_ofi.h b/include/nccl_ofi.h index cc4577e3e..f42c367e0 100644 --- a/include/nccl_ofi.h +++ b/include/nccl_ofi.h @@ -529,7 +529,7 @@ int nccl_net_ofi_dealloc_mr_buffer(void *ptr, size_t size); * Set required behavior flags (and print debugging information) for * local_mr, virt_addr_mr, and endpoint_mr. */ -int nccl_net_ofi_query_provider_capabilities(struct fi_info *selected_provider, +int nccl_net_ofi_query_provider_capabilities(const struct fi_info *selected_provider, unsigned int num_providers); /* Declare a platform-specific initialization hook that can be diff --git a/src/nccl_ofi_api.c b/src/nccl_ofi_api.c index e6abd756b..41c1f8c82 100644 --- a/src/nccl_ofi_api.c +++ b/src/nccl_ofi_api.c @@ -269,7 +269,7 @@ ncclResult_t nccl_net_ofi_connect(int dev_id, void *handle, void **sComm) nccl_net_ofi_ep_t *base_ep = NULL; if (ofi_handle->state.stage == COMM_CREATE_START) { /* Retrieve and validate device */ - nccl_net_ofi_device_t *base_dev = base_dev = plugin->devs[dev_id]; + nccl_net_ofi_device_t *base_dev = plugin->devs[dev_id]; if (OFI_UNLIKELY(base_dev == NULL)) { NCCL_OFI_WARN("Error accessing device. Device #%i has not been initialized.", dev_id); return ncclInternalError; @@ -491,7 +491,7 @@ ncclResult_t nccl_net_ofi_accept(void *lComm, void **rComm) ncclResult_t nccl_net_ofi_accept_v4(void* listenComm, void** recvComm) { - ncclResult_t ret; + ncclResult_t ret = ncclInvalidArgument; while (*recvComm == NULL) { ret = nccl_net_ofi_accept(listenComm, recvComm); diff --git a/src/nccl_ofi_net.c b/src/nccl_ofi_net.c index e714a7019..f52bdf289 100644 --- a/src/nccl_ofi_net.c +++ b/src/nccl_ofi_net.c @@ -353,12 +353,14 @@ int nccl_net_ofi_info_properties(struct fi_info *nic_prov, int dev_id, int num_d nccl_ofi_properties_t *props) { int ret = 0; - nccl_ofi_properties_t dev_props = {0}; struct fid_nic *nic_info = NULL; - ret = set_nic_props_default(dev_id, nic_prov, &dev_props); - if (ret != 0) + memset(props, 0, sizeof(*props)); + + ret = set_nic_props_default(dev_id, nic_prov, props); + if (ret != 0) { goto error; + } /* Change default values as set by NIC attributes */ nic_info = (struct fid_nic *)nic_prov->nic; @@ -373,7 +375,11 @@ int nccl_net_ofi_info_properties(struct fi_info *nic_prov, int dev_id, int num_d /* name is NULL if device is a part of multirail config */ /* overriding default name only if value is available from provider */ if (nic_info->device_attr->name) { - dev_props.name = strdup(nic_info->device_attr->name); + if (props->name) { + free(props->name); + } + props->name = strdup(nic_info->device_attr->name); + assert(props->name != NULL); } /* @@ -381,17 +387,17 @@ int nccl_net_ofi_info_properties(struct fi_info *nic_prov, int dev_id, int num_d * registration support to NCCL */ if (nic_prov->domain_attr->mr_mode & FI_MR_ENDPOINT) { - dev_props.mr_scope = NCCL_OFI_MR_SCOPE_ENDPOINT; + props->mr_scope = NCCL_OFI_MR_SCOPE_ENDPOINT; NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Libfabric provider associates MRs with endpoints"); } else { - dev_props.mr_scope = NCCL_OFI_MR_SCOPE_DOMAIN; + props->mr_scope = NCCL_OFI_MR_SCOPE_DOMAIN; NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Libfabric provider associates MRs with domains"); } /* Speed reported in Mbps */ - dev_props.port_speed = nic_info->link_attr->speed / (1e6); + props->port_speed = nic_info->link_attr->speed / (1e6); - ret = get_device_pci_path(nic_info, &(dev_props.pci_path)); + ret = get_device_pci_path(nic_info, &props->pci_path); if (ret != 0) { ret = 0; props->pci_path = NULL; @@ -436,18 +442,25 @@ int nccl_net_ofi_info_properties(struct fi_info *nic_prov, int dev_id, int num_d * this is probably ok; any affinity is lost by * bouncing through host buffers anyway. */ - if (active_cuda_device / gpus_per_conn != dev_id) { - for (c=strlen(dev_props.pci_path); c && dev_props.pci_path[c] != '/'; c--) { - dev_props.pci_path[c] = '\0'; + if ((active_cuda_device / gpus_per_conn != dev_id) && props->pci_path) { + for (c = strlen(props->pci_path); props->pci_path[c] != '/'; c--) { + props->pci_path[c] = '\0'; } - dev_props.pci_path[c] = '\0'; } - NCCL_OFI_TRACE(NCCL_INIT, "Returning synthetic PCI path for device %d of %s", - dev_id, dev_props.pci_path); - - snprintf(dev_props.name, FI_NAME_MAX + 2, "%s-%x", nic_info->device_attr->name, dev_id); - NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET, "Adjusted dev %d device name to %s", - dev_id, dev_props.name); + NCCL_OFI_TRACE(NCCL_INIT, + "Returning synthetic PCI path for device %d of %s", + dev_id, + props->pci_path); + + snprintf(props->name, + FI_NAME_MAX + 2, + "%s-%x", + nic_info->device_attr->name, + dev_id); + NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET, + "Adjusted dev %d device name to %s", + dev_id, + props->name); #else NCCL_OFI_WARN("NIC_DUP_CONNS enabled on platform that does not support NIC_DUP_CONNS. This should not happen."); ret = -ENOTSUP; @@ -456,11 +469,15 @@ int nccl_net_ofi_info_properties(struct fi_info *nic_prov, int dev_id, int num_d } goto exit; +error: + if (props->pci_path) { + free(props->pci_path); + } + if (props->name) { + free(props->name); + } - error: - props = NULL; - exit: - *props = dev_props; +exit: return ret; } @@ -482,7 +499,7 @@ int nccl_net_ofi_reg_mr_dma_buf_recv_comm(nccl_net_ofi_recv_comm_t *recv_comm, } -int nccl_net_ofi_query_provider_capabilities(struct fi_info *selected_provider, +int nccl_net_ofi_query_provider_capabilities(const struct fi_info *selected_provider, unsigned int num_providers) { NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Selected Provider is %s (found %d nics)", diff --git a/src/nccl_ofi_ofiutils.c b/src/nccl_ofi_ofiutils.c index 1ce552c7e..e29f2eb67 100644 --- a/src/nccl_ofi_ofiutils.c +++ b/src/nccl_ofi_ofiutils.c @@ -166,6 +166,7 @@ int nccl_ofi_ofiutils_get_providers(const char *prov_include, int rc = 0; struct fi_info *providers = NULL, *prov = NULL, *last_prov; char *selected_prov_name = NULL; + *num_prov_infos = 0; rc = fi_getinfo(required_version, NULL, NULL, 0ULL, hints, &providers); if (rc != 0) @@ -173,6 +174,9 @@ int nccl_ofi_ofiutils_get_providers(const char *prov_include, if (!providers) goto error; + if (!num_prov_infos) { + goto error; + } /* Pick a provider name to use. If there is a prov_include * provided, use the first provider which matches the list, @@ -201,7 +205,6 @@ int nccl_ofi_ofiutils_get_providers(const char *prov_include, prov = providers; providers = NULL; last_prov = NULL; - *num_prov_infos = 0; while (prov) { struct fi_info *prov_next = prov->next; prov->next = NULL; diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index bbf403de6..33f84bc98 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -5873,7 +5873,7 @@ int nccl_net_ofi_rdma_init(const char *provider_filter, goto error; } - base_devs = calloc(num_devs, sizeof(nccl_net_ofi_rdma_device_t *)); + base_devs = calloc(num_devs, sizeof(nccl_net_ofi_device_t *)); if (!base_devs) { NCCL_OFI_WARN("Unable to allocate " "nccl_net_ofi_rdma_device_t pointer array"); diff --git a/src/nccl_ofi_sendrecv.c b/src/nccl_ofi_sendrecv.c index dcbd811d0..a6d7e4251 100644 --- a/src/nccl_ofi_sendrecv.c +++ b/src/nccl_ofi_sendrecv.c @@ -1,11 +1,13 @@ /* * Copyright (c) 2023-2024 Amazon.com, Inc. or its affiliates. All rights reserved. */ + #include "config.h" #include #include #include +#include #include #include #include @@ -2400,7 +2402,7 @@ int nccl_net_ofi_sendrecv_init(const char *provider_filter, goto exit; } - base_devs = malloc(num_providers * sizeof(nccl_net_ofi_sendrecv_device_t *)); + base_devs = malloc(num_providers * sizeof(nccl_net_ofi_device_t *)); if (!base_devs) { NCCL_OFI_WARN("Unable to allocate " "nccl_net_ofi_sendrecv_device_t pointer array"); @@ -2476,7 +2478,17 @@ int nccl_net_ofi_sendrecv_init(const char *provider_filter, /* The provider may return support for a larger key size. Use * the size requested by the user to allow them to limit the * size of the mr_keys table. */ - ret = nccl_ofi_idpool_init(&device->key_pool, (size_t)(1 << (ofi_nccl_mr_key_size() * 8))); + const size_t shift = (ofi_nccl_mr_key_size() * 8); + const size_t size_t_bits = (sizeof(size_t) * CHAR_BIT); + if (shift > (size_t_bits - 1)) { + NCCL_OFI_WARN( + "Provided mr keypool size of %lu must be less than %zu", + ofi_nccl_mr_key_size(), + size_t_bits); + ret = -EINVAL; + goto error; + } + ret = nccl_ofi_idpool_init(&device->key_pool, 1 << shift); } else { /* Mark key pool as not in use */ ret = nccl_ofi_idpool_init(&device->key_pool, 0); diff --git a/src/nccl_ofi_topo.c b/src/nccl_ofi_topo.c index 6e196138c..84bd4a144 100644 --- a/src/nccl_ofi_topo.c +++ b/src/nccl_ofi_topo.c @@ -599,11 +599,15 @@ nccl_ofi_topo_t *nccl_ofi_topo_create(struct fi_info *info_list) */ static int mark_topo_nodes_with_ofi_info_subtree(nccl_ofi_topo_t *topo) { + int status; nccl_ofi_topo_data_t *data = NULL; /* Iterate over user data that stores libfabric NIC info structs */ nccl_ofi_topo_data_iterator_t data_iter; - nccl_ofi_topo_set_to_begin(topo, &data_iter); + if ((status = nccl_ofi_topo_set_to_begin(topo, &data_iter)) < 0) { + return status; + } + while ((data = nccl_ofi_get_user_data(&data_iter))) { nccl_ofi_inc_user_data_iter(&data_iter); if (!data->info_list) { @@ -1296,9 +1300,13 @@ static int write_pci_tag(FILE *file, int indent, "\n", - indent, "", - pcidev->domain, pcidev->bus, pcidev->dev, pcidev->func, + "link_width=\"%zu\"/>\n", + indent, + "", + pcidev->domain, + pcidev->bus, + pcidev->dev, + pcidev->func, pcie_gen[speed_idx], width); diff --git a/tests/unit/scheduler.c b/tests/unit/scheduler.c index 3e946f0af..5b9bfa0ec 100644 --- a/tests/unit/scheduler.c +++ b/tests/unit/scheduler.c @@ -91,12 +91,14 @@ int test_multiplexing_schedule() ret = create_multiplexed(size, num_rails, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); + free(ref_schedule); return ret; } ref_schedule->num_xfer_infos = 0; ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); + free(ref_schedule); return ret; } free(schedule); @@ -112,12 +114,14 @@ int test_multiplexing_schedule() ret = create_multiplexed(size, num_rails, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); + free(ref_schedule); return ret; } ref_schedule->num_xfer_infos = 0; ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); + free(ref_schedule); return ret; } free(schedule); @@ -129,6 +133,7 @@ int test_multiplexing_schedule() ret = create_multiplexed(size, num_rails, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); + free(ref_schedule); return ret; } ref_schedule->num_xfer_infos = 1; @@ -138,6 +143,7 @@ int test_multiplexing_schedule() ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); + free(ref_schedule); return ret; } free(schedule); @@ -149,6 +155,7 @@ int test_multiplexing_schedule() ret = create_multiplexed(size, num_rails, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); + free(ref_schedule); return ret; } ref_schedule->num_xfer_infos = 1; @@ -158,6 +165,7 @@ int test_multiplexing_schedule() ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); + free(ref_schedule); return ret; } free(schedule); @@ -169,6 +177,7 @@ int test_multiplexing_schedule() ret = create_multiplexed(size, num_rails, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); + free(ref_schedule); return ret; } ref_schedule->num_xfer_infos = 1; @@ -178,6 +187,7 @@ int test_multiplexing_schedule() ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); + free(ref_schedule); return ret; } free(schedule); @@ -193,12 +203,14 @@ int test_multiplexing_schedule() ret = create_multiplexed(size, num_rails, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); + free(ref_schedule); return ret; } ref_schedule->num_xfer_infos = 0; ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); + free(ref_schedule); return ret; } free(schedule); @@ -210,6 +222,7 @@ int test_multiplexing_schedule() ret = create_multiplexed(size, num_rails, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); + free(ref_schedule); return ret; } ref_schedule->num_xfer_infos = 2; @@ -222,6 +235,7 @@ int test_multiplexing_schedule() ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); + free(ref_schedule); return ret; } free(schedule); @@ -233,6 +247,7 @@ int test_multiplexing_schedule() ret = create_multiplexed(size, num_rails, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); + free(ref_schedule); return ret; } ref_schedule->num_xfer_infos = 2; @@ -245,6 +260,7 @@ int test_multiplexing_schedule() ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); + free(ref_schedule); return ret; } free(schedule); @@ -256,6 +272,7 @@ int test_multiplexing_schedule() ret = create_multiplexed(size, num_rails, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); + free(ref_schedule); return ret; } ref_schedule->num_xfer_infos = 3; @@ -271,6 +288,7 @@ int test_multiplexing_schedule() ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); + free(ref_schedule); return ret; } free(schedule); @@ -291,6 +309,7 @@ int test_threshold_scheduler() nccl_net_ofi_scheduler_t *scheduler; if (nccl_net_ofi_threshold_scheduler_init(num_rails, rr_threshold, &scheduler)) { NCCL_OFI_WARN("Failed to initialize threshold scheduler"); + free(ref_schedule); return -1; } @@ -298,6 +317,7 @@ int test_threshold_scheduler() schedule = scheduler->get_schedule(scheduler, rr_threshold + 1, num_rails); if (!schedule) { NCCL_OFI_WARN("Failed to get schedule"); + free(ref_schedule); return -1; } ref_schedule->num_xfer_infos = 2; @@ -310,6 +330,7 @@ int test_threshold_scheduler() ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); + free(ref_schedule); return ret; } nccl_net_ofi_release_schedule(scheduler, schedule); @@ -318,6 +339,7 @@ int test_threshold_scheduler() schedule = scheduler->get_schedule(scheduler, rr_threshold, num_rails); if (!schedule) { NCCL_OFI_WARN("Failed to get schedule"); + free(ref_schedule); return -1; } ref_schedule->num_xfer_infos = 1; @@ -327,6 +349,7 @@ int test_threshold_scheduler() ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); + free(ref_schedule); return ret; } nccl_net_ofi_release_schedule(scheduler, schedule); @@ -334,6 +357,7 @@ int test_threshold_scheduler() schedule = scheduler->get_schedule(scheduler, rr_threshold, num_rails); if (!schedule) { NCCL_OFI_WARN("Failed to get schedule"); + free(ref_schedule); return -1; } ref_schedule->num_xfer_infos = 1; @@ -343,6 +367,7 @@ int test_threshold_scheduler() ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); + free(ref_schedule); return ret; } nccl_net_ofi_release_schedule(scheduler, schedule); @@ -350,6 +375,7 @@ int test_threshold_scheduler() schedule = scheduler->get_schedule(scheduler, rr_threshold, num_rails); if (!schedule) { NCCL_OFI_WARN("Failed to get schedule"); + free(ref_schedule); return -1; } ref_schedule->num_xfer_infos = 1; @@ -359,6 +385,7 @@ int test_threshold_scheduler() ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); + free(ref_schedule); return ret; } nccl_net_ofi_release_schedule(scheduler, schedule);