Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Gopjrt Changelog

# Next
# 0.10.0 - (Release Candidate) Added Shardy support.

- Package `cmd/gopjrt_installer`:
- Link `libcublasLt.so.13` and `libcublas.so.13` to the `lib` subdirectory of the installation directory given.
Expand All @@ -10,6 +10,8 @@
- Fixed wrong struct size set for `PJRT_Event_Destroy_Args`, in #65 (@timkaye11)
- Fixed `buffer.Destroy` to release the `client` pointer in the wrapper.
- Added a required `runtime.KeepAlive(program)` on a CGO call to compile `program`.
- Added new Shardy support for distributed (across multiple devices) execution.
- Added old SPMD support (see `pjrt.Compile().WithSPMD`)

# v0.9.1 2025/11/07: More multi-device support; updated CPU PJRT; dropped static CPU PJRT linking.

Expand Down
24 changes: 12 additions & 12 deletions pjrt/benchmarks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
if testing.Short() {
t.SkipNow()
}
plugin := must1(GetPlugin(*flagPluginName))
plugin := must1(GetPlugin(*FlagPluginName))
const repeats = 1000
repeatedCGO := func() {
for _ = range repeats {
Expand All @@ -52,7 +52,7 @@
if testing.Short() {
t.SkipNow()
}
plugin := must1(GetPlugin(*flagPluginName))
plugin := must1(GetPlugin(*FlagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

Expand Down Expand Up @@ -129,7 +129,7 @@
if testing.Short() {
t.SkipNow()
}
plugin := must1(GetPlugin(*flagPluginName))
plugin := must1(GetPlugin(*FlagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

Expand Down Expand Up @@ -171,7 +171,7 @@
if testing.Short() {
t.SkipNow()
}
plugin := must1(GetPlugin(*flagPluginName))
plugin := must1(GetPlugin(*FlagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

Expand Down Expand Up @@ -212,7 +212,7 @@
if testing.Short() {
t.SkipNow()
}
plugin := must1(GetPlugin(*flagPluginName))
plugin := must1(GetPlugin(*FlagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

Expand All @@ -233,7 +233,7 @@
builder := stablehlo.New(fmt.Sprintf("Add1/%s", s))
mainFn := builder.Main()
// f(x) = x + 1
x := mainFn.NamedInput("x", s)
x := must1(mainFn.NamedInput("x", s))

Check failure on line 236 in pjrt/benchmarks_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1

Check failure on line 236 in pjrt/benchmarks_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1
one := must1(mainFn.ConstantFromScalar(float32(1)))
broadcastedOne := must1(stablehlo.BroadcastInDim(one, x.Shape(), nil))
add1 := must1(stablehlo.Add(x, broadcastedOne))
Expand Down Expand Up @@ -278,7 +278,7 @@
if testing.Short() {
t.SkipNow()
}
plugin := must1(GetPlugin(*flagPluginName))
plugin := must1(GetPlugin(*FlagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

Expand All @@ -299,7 +299,7 @@
builder := stablehlo.New(fmt.Sprintf("Add1/%s", s))
mainFn := builder.Main()
// f(x) = (x + 1) * 0.5
x := mainFn.NamedInput("x", s)
x := must1(mainFn.NamedInput("x", s))

Check failure on line 302 in pjrt/benchmarks_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1

Check failure on line 302 in pjrt/benchmarks_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1
one := must1(mainFn.ConstantFromScalar(float32(1)))
broadcastedOne := must1(stablehlo.BroadcastInDim(one, x.Shape(), nil))
add1 := must1(stablehlo.Add(x, broadcastedOne))
Expand Down Expand Up @@ -347,7 +347,7 @@
if testing.Short() {
t.SkipNow()
}
plugin := must1(GetPlugin(*flagPluginName))
plugin := must1(GetPlugin(*FlagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

Expand All @@ -368,7 +368,7 @@
builder := stablehlo.New(fmt.Sprintf("MeanNormalized/%s", s))
mainFn := builder.Main()
// f(x) = (x + 1) * 0.5 - mean((x + 1) * 0.5)
x := mainFn.NamedInput("x", s)
x := must1(mainFn.NamedInput("x", s))

Check failure on line 371 in pjrt/benchmarks_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1

Check failure on line 371 in pjrt/benchmarks_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1
one := must1(mainFn.ConstantFromScalar(float32(1)))
broadcastedOne := must1(stablehlo.BroadcastInDim(one, x.Shape(), nil))
add1 := must1(stablehlo.Add(x, broadcastedOne))
Expand All @@ -377,8 +377,8 @@
div2 := must1(stablehlo.Multiply(add1, broadcastedHalf))

reductionFn := mainFn.Closure()
lhs := reductionFn.NamedInput("lhs", shapes.Make(dtypes.F32))
rhs := reductionFn.NamedInput("rhs", shapes.Make(dtypes.F32))
lhs := must1(reductionFn.NamedInput("lhs", shapes.Make(dtypes.F32)))

Check failure on line 380 in pjrt/benchmarks_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1

Check failure on line 380 in pjrt/benchmarks_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1
rhs := must1(reductionFn.NamedInput("rhs", shapes.Make(dtypes.F32)))

Check failure on line 381 in pjrt/benchmarks_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1

Check failure on line 381 in pjrt/benchmarks_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1
must(reductionFn.Return(must1(stablehlo.Add(lhs, rhs))))
initialValue := must1(mainFn.ConstantFromScalar(float32(0)))

Expand Down
26 changes: 13 additions & 13 deletions pjrt/buffers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
}

func TestTransfers(t *testing.T) {
plugin, err := GetPlugin(*flagPluginName)
plugin, err := GetPlugin(*FlagPluginName)
require.NoError(t, err)
fmt.Printf("Loaded %s\n", plugin)

Expand All @@ -97,7 +97,7 @@
}

func TestBufferProperties(t *testing.T) {
plugin, err := GetPlugin(*flagPluginName)
plugin, err := GetPlugin(*FlagPluginName)
require.NoError(t, err)
fmt.Printf("Loaded %s\n", plugin)

Expand Down Expand Up @@ -132,7 +132,7 @@
}

func TestBufferCopyToDevice(t *testing.T) {
plugin, err := GetPlugin(*flagPluginName)
plugin, err := GetPlugin(*FlagPluginName)
require.NoError(t, err)
client, err := plugin.NewClient(nil)
require.NoErrorf(t, err, "Failed to create a client on %s", plugin)
Expand Down Expand Up @@ -185,13 +185,13 @@
"force_shared_buffer", false, "Force executing TestCreateViewOfDeviceBuffer and TestBufferUnsafePointer even if plugin is not \"cpu\".")

func TestCreateViewOfDeviceBuffer(t *testing.T) {
if *flagPluginName != "cpu" && !*flagForceSharedBuffer {
if *FlagPluginName != "cpu" && !*flagForceSharedBuffer {
t.Skip("Skipping TestCreateViewOfDeviceBuffer because -plugin != \"cpu\". " +
"Set --force_create_view to force executing the test anyway")
}

// Create plugin.
plugin := must1(GetPlugin(*flagPluginName))
plugin := must1(GetPlugin(*FlagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

Expand All @@ -200,7 +200,7 @@
shape := shapes.Make(dtype, 2, 3)
builder := stablehlo.New("Add1")
mainFn := builder.Main()
x := mainFn.NamedInput("x", shape)
x := must1(mainFn.NamedInput("x", shape))

Check failure on line 203 in pjrt/buffers_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1

Check failure on line 203 in pjrt/buffers_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1
one := must1(mainFn.ConstantFromScalar(float32(1)))
broadcastedOne := must1(stablehlo.BroadcastInDim(one, x.Shape(), nil))
add1 := must1(stablehlo.Add(x, broadcastedOne))
Expand Down Expand Up @@ -248,13 +248,13 @@
}

func TestNewSharedBuffer(t *testing.T) {
if *flagPluginName != "cpu" && !*flagForceSharedBuffer {
if *FlagPluginName != "cpu" && !*flagForceSharedBuffer {
t.Skip("Skipping TestNewSharedBuffer because -plugin != \"cpu\". " +
"Set --force_create_view to force executing the test anyway")
}

// Create plugin.
plugin := must1(GetPlugin(*flagPluginName))
plugin := must1(GetPlugin(*FlagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

Expand All @@ -263,7 +263,7 @@
shape := shapes.Make(dtype, 2, 3)
builder := stablehlo.New("Add1")
mainFn := builder.Main()
x := mainFn.NamedInput("x", shape)
x := must1(mainFn.NamedInput("x", shape))

Check failure on line 266 in pjrt/buffers_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1

Check failure on line 266 in pjrt/buffers_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1
one := must1(mainFn.ConstantFromScalar(float32(1)))
broadcastedOne := must1(stablehlo.BroadcastInDim(one, x.Shape(), nil))
add1 := must1(stablehlo.Add(x, broadcastedOne))
Expand Down Expand Up @@ -310,13 +310,13 @@
}

func TestBufferData(t *testing.T) {
if *flagPluginName != "cpu" && !*flagForceSharedBuffer {
if *FlagPluginName != "cpu" && !*flagForceSharedBuffer {
t.Skip("Skipping TestNewSharedBuffer because -plugin != \"cpu\". " +
"Set --force_create_view to force executing the test anyway")
}

// Create plugin.
plugin := must1(GetPlugin(*flagPluginName))
plugin := must1(GetPlugin(*FlagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

Expand All @@ -325,7 +325,7 @@
shape := shapes.Make(dtype, 2, 3)
builder := stablehlo.New("Add1")
mainFn := builder.Main()
x := mainFn.NamedInput("x", shape)
x := must1(mainFn.NamedInput("x", shape))

Check failure on line 328 in pjrt/buffers_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1

Check failure on line 328 in pjrt/buffers_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1
one := must1(mainFn.ConstantFromScalar(float32(1)))
broadcastedOne := must1(stablehlo.BroadcastInDim(one, x.Shape(), nil))
add1 := must1(stablehlo.Add(x, broadcastedOne))
Expand Down Expand Up @@ -360,7 +360,7 @@

func TestBufferDestroyAfterClient(t *testing.T) {
// Create the plugin and the client.
plugin := must1(GetPlugin(*flagPluginName))
plugin := must1(GetPlugin(*FlagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

Expand Down
4 changes: 2 additions & 2 deletions pjrt/clients_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ var (
)

func TestPlugin_NewClient(t *testing.T) {
plugin, err := GetPlugin(*flagPluginName)
plugin, err := GetPlugin(*FlagPluginName)
require.NoError(t, err)
fmt.Printf("Loaded %s\n", plugin)

Expand All @@ -51,7 +51,7 @@ func TestPlugin_NewClient(t *testing.T) {
}

func TestCompileAndExecute(t *testing.T) {
plugin, err := GetPlugin(*flagPluginName)
plugin, err := GetPlugin(*FlagPluginName)
require.NoError(t, err)
fmt.Printf("Loaded %s\n", plugin)

Expand Down
32 changes: 30 additions & 2 deletions pjrt/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,8 @@ func (cc *CompileConfig) WithComputation(computation XlaComputation) *CompileCon

// WithSPMD configures the program to be compiled for "Single Program Multiple Data" (SPMD) mode.
//
// This is the old way, consider using WithShardy instead.
//
// This means the same program will be executed on multiple devices, with different inputs per device.
//
// Later the inputs to the LoadedExecutable.Execute method will be divided in numReplicas slices, each fed
Expand Down Expand Up @@ -327,21 +329,47 @@ func (cc *CompileConfig) WithSPMD(numReplicas int) *CompileConfig {
return cc
}

// WithShardy configures the program to be compiled with the XLA Shardy [1] partitioner, which automatically
// distributes the execution across a mesh of devices.
//
// Shardy takes the sharding specification of the inputs and outputs (and, optionally, hints placed inside the
// program) and uses them to shard the computation automatically.
//
// numDevices is the number of devices to distribute over: the program is compiled as a single replica
// with numDevices partitions.
//
// After calling this, consider also setting the device assignment (WithDeviceAssignment), which should be provided
// by your shardy.DeviceMesh.
//
// It returns itself (CompileConfig) to allow cascading configuration calls.
//
// [1] https://openxla.org/shardy/
func (cc *CompileConfig) WithShardy(numDevices int) *CompileConfig {
	// Every device is a "partition" of one single replica.
	cc.options.ExecutableBuildOptions.NumReplicas = 1
	cc.options.ExecutableBuildOptions.NumPartitions = int64(numDevices)

	// Both partitioner flags are enabled together.
	cc.options.ExecutableBuildOptions.UseSpmdPartitioning = true
	cc.options.ExecutableBuildOptions.UseShardyPartitioner = true

	// The program is compiled for specific devices, hence it is not "portable".
	cc.options.CompilePortableExecutable = false

	// Called last, once the replica/partition counts above are in place
	// (presumably it reads them -- verify against setDefaultDeviceAssignment).
	cc.setDefaultDeviceAssignment()
	return cc
}

// WithDeviceAssignment configures the device assignment for the program.
// The device assignment is used to determine the device on which each computation will be executed.
//
// Very important: this defines the device order of the sharded inputs: if you use `WithDeviceAssignment([]int{3, 2})`
// for a computation f(x) with x sharded, you need to feed the shard on device 3 first, followed by the x shard
// on device 2.
//
// The device assignment is a list of device IDs, one for each computation in the program,
// so there should be numReplicas*numPartitions entries -- organized in "numReplicas" rows and
// "numPartitions" columns (row-major).
//
// For example, if numReplicas=2 and numPartitions=3, a device assignment would specify the (replica, partition) pairs
// [(replica=0,partition=0), (0,1), (0,2), (1,0), (1,1), (1,2)].
//
// For XLA Shardy program (WithShardy), the "replica/partition" split is ignored, all devices are assigned as
// "partition" devices, and the assignment is a simple list, with the order of devices to use.
//
// If not set, the device assignment is done automatically (see Client.DefaultDeviceAssignment), and it can be queried
// using LoadedExecutable.GetDeviceAssignment method after compilation.
//
// If not set directly or indirectly with WithSPMD the program defaults to being portable, and the device assignment
// is ignored.
// If it is not set directly or indirectly with WithShardy or WithSPMD,
// the program defaults to being "portable", and the device assignment is ignored.
//
// It returns itself (CompileConfig) to allow cascading configuration calls.
func (cc *CompileConfig) WithDeviceAssignment(assignment []int) *CompileConfig {
Expand Down
2 changes: 1 addition & 1 deletion pjrt/devices_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

func TestClient_Devices(t *testing.T) {
plugin, err := GetPlugin(*flagPluginName)
plugin, err := GetPlugin(*FlagPluginName)
require.NoError(t, err)
fmt.Printf("Loaded %s\n", plugin)
client, err := plugin.NewClient(nil)
Expand Down
8 changes: 4 additions & 4 deletions pjrt/dynamiclib_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ import (
"testing"
)

// TestLoadNamedPlugin loads the *flagPluginName plugin, which defaults to "cpu", that should be made available.
// TestLoadNamedPlugin loads the *FlagPluginName plugin, which defaults to "cpu", that should be made available.
func TestLoadNamedPlugin(t *testing.T) {
plugin, err := loadNamedPlugin(*flagPluginName)
plugin, err := loadNamedPlugin(*FlagPluginName)
require.NoError(t, err)
fmt.Printf("Loaded %s\n", plugin)
fmt.Printf("\tattributes: %v\n", plugin.attributes)

// Checks cache works.
plugin2, err := loadNamedPlugin(*flagPluginName)
plugin2, err := loadNamedPlugin(*FlagPluginName)
require.NoError(t, err)
require.Equal(t, plugin, plugin2)
plugin3, err := loadNamedPlugin(plugin2.Path()) // Try by using the absolute path, should return the same plugin.
Expand All @@ -52,7 +52,7 @@ func TestLoadNamedPlugin(t *testing.T) {
// TestAvailablePlugins checks that AvailablePlugins reports a non-empty entry for the plugin
// named by the plugin-name flag.
func TestAvailablePlugins(t *testing.T) {
plugins := AvailablePlugins()
fmt.Printf("Available plugins: %v\n", plugins)
// An empty entry for the flag's plugin name is reported as "Can not find ... plugin".
require.NotEqualf(t, "", plugins[*flagPluginName], "Can not find %q plugin", *flagPluginName)
require.NotEqualf(t, "", plugins[*FlagPluginName], "Can not find %q plugin", *FlagPluginName)
}

// TestSuppressAbseilLoggingHack never fails, since errors are simply logged.
Expand Down
4 changes: 2 additions & 2 deletions pjrt/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

// f(x, y) = x+y
scalarF32 := shapes.Make(dtypes.F32)
x := mainFn.NamedInput("x", scalarF32) // Scalar float32.
y := mainFn.NamedInput("y", scalarF32) // Scalar float32.
x := must1(mainFn.NamedInput("x", scalarF32)) // Scalar float32.

Check failure on line 20 in pjrt/error_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1

Check failure on line 20 in pjrt/error_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1
y := must1(mainFn.NamedInput("y", scalarF32)) // Scalar float32.

Check failure on line 21 in pjrt/error_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1

Check failure on line 21 in pjrt/error_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to must1
fXY := capture(stablehlo.Add(x, y)).Test(t)

// Take program and compile.
Expand Down
6 changes: 3 additions & 3 deletions pjrt/loadedexecutables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ func TestDonatableConfig(t *testing.T) {

// f(x, y, z) = x*y + z
scalarF32 := shapes.Make(dtypes.F32)
x := mainFn.NamedInput("x", scalarF32) // Scalar float32.
y := mainFn.NamedInput("y", scalarF32) // Scalar float32.
z := mainFn.NamedInput("z", scalarF32) // Scalar float32.
x := must1(mainFn.NamedInput("x", scalarF32)) // Scalar float32.
y := must1(mainFn.NamedInput("y", scalarF32)) // Scalar float32.
z := must1(mainFn.NamedInput("z", scalarF32)) // Scalar float32.
fX := capture(stablehlo.Multiply(x, y)).Test(t)
fX = capture(stablehlo.Add(fX, z)).Test(t)

Expand Down
2 changes: 1 addition & 1 deletion pjrt/minimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func TestMinimal(t *testing.T) {
fmt.Printf("HLO Program:\n%s\n\n", hloModule.String())

// `dlopen` PJRT plugin.
plugin := must1(GetPlugin(*flagPluginName))
plugin := must1(GetPlugin(*FlagPluginName))
defer runtime.KeepAlive(plugin)
fmt.Printf("PJRT: %s\n", plugin.String())

Expand Down
Loading