Skip to content

Commit

Permalink
Add mutexes to mock dgxa100 types to avoid concurrent map reads/writes
Browse files: browse the repository at this point in the history
Signed-off-by: Kevin Klues <[email protected]>
  • Loading branch information
klueska committed Apr 23, 2024
1 parent 5e1cdb1 commit 6895ece
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions pkg/nvml/mock/dgxa100/dgxa100.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package dgxa100

import (
"fmt"
"sync"

"github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/NVIDIA/go-nvml/pkg/nvml/mock"
Expand All @@ -34,6 +35,7 @@ type Server struct {
}
type Device struct {
mock.Device
sync.RWMutex
UUID string
Name string
Brand nvml.BrandType
Expand All @@ -50,6 +52,7 @@ type Device struct {

type GpuInstance struct {
mock.GpuInstance
sync.RWMutex
Info nvml.GpuInstanceInfo
ComputeInstances map[*ComputeInstance]struct{}
ComputeInstanceCounter uint32
Expand Down Expand Up @@ -254,6 +257,8 @@ func (d *Device) setMockFuncs() {
}

d.CreateGpuInstanceFunc = func(info *nvml.GpuInstanceProfileInfo) (nvml.GpuInstance, nvml.Return) {
d.Lock()
defer d.Unlock()
giInfo := nvml.GpuInstanceInfo{
Device: d,
Id: d.GpuInstanceCounter,
Expand All @@ -266,6 +271,8 @@ func (d *Device) setMockFuncs() {
}

d.CreateGpuInstanceWithPlacementFunc = func(info *nvml.GpuInstanceProfileInfo, placement *nvml.GpuInstancePlacement) (nvml.GpuInstance, nvml.Return) {
d.Lock()
defer d.Unlock()
giInfo := nvml.GpuInstanceInfo{
Device: d,
Id: d.GpuInstanceCounter,
Expand All @@ -279,6 +286,8 @@ func (d *Device) setMockFuncs() {
}

d.GetGpuInstancesFunc = func(info *nvml.GpuInstanceProfileInfo) ([]nvml.GpuInstance, nvml.Return) {
d.RLock()
defer d.RUnlock()
var gis []nvml.GpuInstance
for gi := range d.GpuInstances {
if gi.Info.ProfileId == info.Id {
Expand Down Expand Up @@ -321,6 +330,8 @@ func (gi *GpuInstance) setMockFuncs() {
}

gi.CreateComputeInstanceFunc = func(info *nvml.ComputeInstanceProfileInfo) (nvml.ComputeInstance, nvml.Return) {
gi.Lock()
defer gi.Unlock()
ciInfo := nvml.ComputeInstanceInfo{
Device: gi.Info.Device,
GpuInstance: gi,
Expand All @@ -334,6 +345,8 @@ func (gi *GpuInstance) setMockFuncs() {
}

gi.GetComputeInstancesFunc = func(info *nvml.ComputeInstanceProfileInfo) ([]nvml.ComputeInstance, nvml.Return) {
gi.RLock()
defer gi.RUnlock()
var cis []nvml.ComputeInstance
for ci := range gi.ComputeInstances {
if ci.Info.ProfileId == info.Id {
Expand All @@ -344,7 +357,10 @@ func (gi *GpuInstance) setMockFuncs() {
}

gi.DestroyFunc = func() nvml.Return {
delete(gi.Info.Device.(*Device).GpuInstances, gi)
d := gi.Info.Device.(*Device)
d.Lock()
defer d.Unlock()
delete(d.GpuInstances, gi)
return nvml.SUCCESS
}
}
Expand All @@ -355,7 +371,10 @@ func (ci *ComputeInstance) setMockFuncs() {
}

ci.DestroyFunc = func() nvml.Return {
delete(ci.Info.GpuInstance.(*GpuInstance).ComputeInstances, ci)
gi := ci.Info.GpuInstance.(*GpuInstance)
gi.Lock()
defer gi.Unlock()
delete(gi.ComputeInstances, ci)
return nvml.SUCCESS
}
}

0 comments on commit 6895ece

Please sign in to comment.