Skip to content

Commit

Permalink
Fix host-tasks being enqueued before pending graph updates have completed.
Browse files Browse the repository at this point in the history
  • Loading branch information
Bensuo committed Feb 6, 2025
1 parent fd83dfe commit de6a71e
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 10 deletions.
25 changes: 22 additions & 3 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ void exec_graph_impl::makePartitions() {
}
}

MContainsHostTask = HostTaskList.size() > 0;
// Annotate nodes
// The first step in graph partitioning is to annotate all nodes of the graph
// with a temporary partition or group number. This step allows us to group
Expand Down Expand Up @@ -1078,6 +1079,16 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
} else if ((CurrentPartition->MSchedule.size() > 0) &&
(CurrentPartition->MSchedule.front()->MCGType ==
sycl::detail::CGType::CodeplayHostTask)) {
// If we have pending updates then we need to make sure that they are
// completed before the host-task is enqueued, to ensure it has received
// those updates prior to calling node->getCGCopy()
if (MUpdateEvents.size() > 0) {
for (auto &Event : MUpdateEvents) {
Event->wait_and_throw(Event);
}
MUpdateEvents.clear();
}

auto NodeImpl = CurrentPartition->MSchedule.front();
// Schedule host task
NodeImpl->MCommandGroup->getEvents().insert(
Expand Down Expand Up @@ -1436,9 +1447,17 @@ void exec_graph_impl::update(
sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()),
sycl::detail::getSyclObjImpl(MGraphImpl->getContext()),
sycl::async_handler{}, sycl::property_list{});
// Don't need to care about the return event here because it is synchronous
sycl::detail::Scheduler::getInstance().addCommandGraphUpdate(
this, Nodes, AllocaQueue, UpdateRequirements, MExecutionEvents);

auto UpdateEvent =
sycl::detail::Scheduler::getInstance().addCommandGraphUpdate(
this, Nodes, AllocaQueue, UpdateRequirements, MExecutionEvents);

// If the graph contains host-task(s) we need to track update events so we
// can explicitly wait on them before enqueuing further host-tasks, to ensure
// the updates have taken effect.
if (MContainsHostTask) {
MUpdateEvents.push_back(UpdateEvent);
}
} else {
for (auto &Node : Nodes) {
updateImpl(Node);
Expand Down
7 changes: 7 additions & 0 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
HostTaskCG->getAccStorage() = OtherHostTaskCG->getAccStorage();
HostTaskCG->getRequirements() = OtherHostTaskCG->getRequirements();
HostTaskCG->MHostTask = OtherHostTaskCG->MHostTask;
HostTaskCG->getEvents() = OtherHostTaskCG->getEvents();
break;
}
default:
Expand Down Expand Up @@ -1441,6 +1442,12 @@ class exec_graph_impl {
// Stores a cache of node ids from modifiable graph nodes to the companion
// node(s) in this graph. Used for quick access when updating this graph.
std::multimap<node_impl::id_type, std::shared_ptr<node_impl>> MIDCache;
// True if this graph contains any host-tasks. Controls whether we store
// events in MUpdateEvents.
bool MContainsHostTask = false;
// Contains events for updates submitted through the scheduler as we need to
// wait on them when enqueuing host-tasks.
std::vector<sycl::detail::EventImplPtr> MUpdateEvents;
};

class dynamic_parameter_impl {
Expand Down
11 changes: 4 additions & 7 deletions sycl/test-e2e/Graph/Inputs/whole_update_host_task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,23 +100,20 @@ int main() {
// Fill graphB with nodes, with a different set of pointers
add_nodes_to_graph(GraphB, Queue, PtrA2, PtrB2, PtrC2, ModValue);

// Execute several Iterations of the graph for 1st set of buffers
// Execute several Iterations of the graph, updating in between each
// execution.
event Event;
for (unsigned n = 0; n < Iterations; n++) {
Event = Queue.submit([&](handler &CGH) {
CGH.depends_on(Event);
CGH.ext_oneapi_graph(GraphExec);
});
}

GraphExec.update(GraphB);

// Execute several Iterations of the graph for 2nd set of buffers
for (unsigned n = 0; n < Iterations; n++) {
GraphExec.update(GraphB);
Event = Queue.submit([&](handler &CGH) {
CGH.depends_on(Event);
CGH.ext_oneapi_graph(GraphExec);
});
GraphExec.update(GraphA);
}

Queue.wait_and_throw();
Expand Down

0 comments on commit de6a71e

Please sign in to comment.