Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pkg/nvlib/device/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ type Interface interface {
GetMigDevices() ([]MigDevice, error)
GetMigProfiles() ([]MigProfile, error)
NewDevice(d nvml.Device) (Device, error)
NewDeviceByIdentifier(Identifier) (Device, error)
NewDeviceByUUID(uuid string) (Device, error)
NewMigDevice(d nvml.Device) (MigDevice, error)
NewMigDeviceByIdentifier(Identifier) (MigDevice, error)
NewMigDeviceByUUID(uuid string) (MigDevice, error)
NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error)
ParseMigProfile(profile string) (MigProfile, error)
Expand Down
26 changes: 26 additions & 0 deletions pkg/nvlib/device/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package device

import (
"fmt"
"strconv"

"github.com/NVIDIA/go-nvml/pkg/nvml"
)
Expand Down Expand Up @@ -49,6 +50,31 @@ func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) {
return d.newDevice(dev)
}

// NewDeviceByIdentifier resolves a full-GPU identifier to a Device.
// GPU UUIDs are delegated to NewDeviceByUUID; GPU indices are parsed as a
// decimal integer and delegated to NewDeviceByIndex. Any other identifier
// results in an error.
func (d *devicelib) NewDeviceByIdentifier(id Identifier) (Device, error) {
	raw := string(id)

	// UUIDs take precedence over indices, matching the order of the checks.
	if id.IsGpuUUID() {
		return d.NewDeviceByUUID(raw)
	}
	if !id.IsGpuIndex() {
		return nil, fmt.Errorf("invalid device identifier: %v", id)
	}

	index, convErr := strconv.Atoi(raw)
	if convErr != nil {
		return nil, fmt.Errorf("failed to convert device index to an int: %w", convErr)
	}
	return d.NewDeviceByIndex(index)
}

// NewDeviceByIndex looks up the NVML device handle at the given index and
// wraps it in a Device. An error is returned if NVML does not report success
// for the handle lookup.
func (d *devicelib) NewDeviceByIndex(index int) (Device, error) {
	handle, status := d.nvmllib.DeviceGetHandleByIndex(index)
	if status == nvml.SUCCESS {
		return d.newDevice(handle)
	}
	return nil, fmt.Errorf("error getting device handle for index '%v': %v", index, status)
}

// NewDeviceByUUID builds a new Device from a UUID.
func (d *devicelib) NewDeviceByUUID(uuid string) (Device, error) {
dev, ret := d.nvmllib.DeviceGetHandleByUUID(uuid)
Expand Down
38 changes: 38 additions & 0 deletions pkg/nvlib/device/mig_device.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package device

import (
"fmt"
"strconv"
"strings"

"github.com/NVIDIA/go-nvml/pkg/nvml"
)
Expand Down Expand Up @@ -45,6 +47,42 @@ func (d *devicelib) NewMigDevice(handle nvml.Device) (MigDevice, error) {
if !isMig {
return nil, fmt.Errorf("not a MIG device")
}
return d.newMigDevice(handle)
}

// NewMigDeviceByIdentifier builds a new MigDevice for the specified identifier.
// A MIG UUID is resolved through NewMigDeviceByUUID. A MIG index of the form
// "<gpu index>:<mig index>" is resolved by looking up the parent GPU by its
// index and then requesting the MIG device handle at the MIG index from it.
// If the identifier is not a valid MIG identifier, an error is returned.
func (d *devicelib) NewMigDeviceByIdentifier(id Identifier) (MigDevice, error) {
	switch {
	case id.IsMigUUID():
		return d.NewMigDeviceByUUID(string(id))
	case id.IsMigIndex():
		split := strings.SplitN(string(id), ":", 2)
		// Do not rely solely on IsMigIndex to guarantee the separator is
		// present; indexing split[1] on a single-element slice would panic.
		if len(split) != 2 {
			return nil, fmt.Errorf("invalid MIG device identifier: %v", id)
		}
		gpuIdx, err := strconv.Atoi(split[0])
		if err != nil {
			return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
		}
		migIdx, err := strconv.Atoi(split[1])
		if err != nil {
			// Distinct message so a bad MIG index is not reported as a bad GPU index.
			return nil, fmt.Errorf("failed to convert MIG index to an int: %w", err)
		}
		parent, err := d.NewDeviceByIndex(gpuIdx)
		if err != nil {
			return nil, fmt.Errorf("failed to get parent device handle: %w", err)
		}
		migDevice, ret := parent.GetMigDeviceHandleByIndex(migIdx)
		if ret != nvml.SUCCESS {
			return nil, fmt.Errorf("failed to get mig device by index: %w", ret)
		}
		return d.newMigDevice(migDevice)
	default:
		return nil, fmt.Errorf("invalid MIG device identifier: %v", id)
	}
}

// newMigDevice wraps the supplied handle in a migdevice bound to this
// devicelib. No validation of the handle is performed here, and the returned
// error is always nil.
func (d *devicelib) newMigDevice(handle nvml.Device) (MigDevice, error) {
	wrapped := &migdevice{handle, d, nil}
	return wrapped, nil
}

Expand Down