From 8f5f0b5bf5733d5489a7ba918db4335c9b7403c1 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Wed, 19 Nov 2025 15:59:32 +0100 Subject: [PATCH 1/9] Added more SPMD/distributed examples. Moved some tests from `pjrt` to `pjrt_test`. --- pjrt/benchmarks_test.go | 14 ++--- pjrt/buffers_test.go | 20 +++--- pjrt/clients_test.go | 4 +- pjrt/devices_test.go | 2 +- pjrt/dynamiclib_test.go | 8 +-- pjrt/minimal_test.go | 2 +- pjrt/pjrt_test.go | 11 ++-- pjrt/spmd_test.go | 132 ++++++++++++++++++++++++++++++++++++++-- pjrt/zero_dim_test.go | 4 +- 9 files changed, 159 insertions(+), 38 deletions(-) diff --git a/pjrt/benchmarks_test.go b/pjrt/benchmarks_test.go index ca6a8b8..5ad2b5e 100644 --- a/pjrt/benchmarks_test.go +++ b/pjrt/benchmarks_test.go @@ -35,7 +35,7 @@ func TestBenchCGO(t *testing.T) { if testing.Short() { t.SkipNow() } - plugin := must1(GetPlugin(*flagPluginName)) + plugin := must1(GetPlugin(*FlagPluginName)) const repeats = 1000 repeatedCGO := func() { for _ = range repeats { @@ -52,7 +52,7 @@ func TestBenchArena(t *testing.T) { if testing.Short() { t.SkipNow() } - plugin := must1(GetPlugin(*flagPluginName)) + plugin := must1(GetPlugin(*FlagPluginName)) client := must1(plugin.NewClient(nil)) defer runtime.KeepAlive(client) @@ -129,7 +129,7 @@ func TestBenchBufferFromHost(t *testing.T) { if testing.Short() { t.SkipNow() } - plugin := must1(GetPlugin(*flagPluginName)) + plugin := must1(GetPlugin(*FlagPluginName)) client := must1(plugin.NewClient(nil)) defer runtime.KeepAlive(client) @@ -171,7 +171,7 @@ func TestBenchBufferToHost(t *testing.T) { if testing.Short() { t.SkipNow() } - plugin := must1(GetPlugin(*flagPluginName)) + plugin := must1(GetPlugin(*FlagPluginName)) client := must1(plugin.NewClient(nil)) defer runtime.KeepAlive(client) @@ -212,7 +212,7 @@ func TestBenchAdd1Execution(t *testing.T) { if testing.Short() { t.SkipNow() } - plugin := must1(GetPlugin(*flagPluginName)) + plugin := must1(GetPlugin(*FlagPluginName)) client := must1(plugin.NewClient(nil)) defer runtime.KeepAlive(client) @@ -278,7 +278,7 @@ func TestBenchAdd1Div2Execution(t *testing.T) { if testing.Short() { t.SkipNow() } - plugin := must1(GetPlugin(*flagPluginName)) + plugin := must1(GetPlugin(*FlagPluginName)) client := must1(plugin.NewClient(nil)) defer runtime.KeepAlive(client) @@ -347,7 +347,7 @@ func TestBenchMeanNormalizedExecution(t *testing.T) { if testing.Short() { t.SkipNow() } - plugin := must1(GetPlugin(*flagPluginName)) + plugin := must1(GetPlugin(*FlagPluginName)) client := must1(plugin.NewClient(nil)) defer runtime.KeepAlive(client) diff --git a/pjrt/buffers_test.go b/pjrt/buffers_test.go index fd33f01..ed1b806 100644 --- a/pjrt/buffers_test.go +++ b/pjrt/buffers_test.go @@ -76,7 +76,7 @@ func testTransfersImpl[T interface { } func TestTransfers(t *testing.T) { - plugin, err := GetPlugin(*flagPluginName) + plugin, err := GetPlugin(*FlagPluginName) require.NoError(t, err) fmt.Printf("Loaded %s\n", plugin) @@ -97,7 +97,7 @@ func TestTransfers(t *testing.T) { } func TestBufferProperties(t *testing.T) { - plugin, err := GetPlugin(*flagPluginName) + plugin, err := GetPlugin(*FlagPluginName) require.NoError(t, err) fmt.Printf("Loaded %s\n", plugin) @@ -132,7 +132,7 @@ func TestBufferProperties(t *testing.T) { } func TestBufferCopyToDevice(t *testing.T) { - plugin, err := GetPlugin(*flagPluginName) + plugin, err := GetPlugin(*FlagPluginName) require.NoError(t, err) client, err := plugin.NewClient(nil) require.NoErrorf(t, err, "Failed to create a client on %s", plugin) @@ -185,13 +185,13 @@ var flagForceSharedBuffer = flag.Bool( "force_shared_buffer", false, "Force executing TestCreateViewOfDeviceBuffer and TestBufferUnsafePointer even if plugin is not \"cpu\".") func TestCreateViewOfDeviceBuffer(t *testing.T) { - if *flagPluginName != "cpu" && !*flagForceSharedBuffer { + if *FlagPluginName != "cpu" && !*flagForceSharedBuffer { t.Skip("Skipping TestCreateViewOfDeviceBuffer because -plugin != \"cpu\". " + "Set --force_create_view to force executing the test anyway") } // Create plugin. - plugin := must1(GetPlugin(*flagPluginName)) + plugin := must1(GetPlugin(*FlagPluginName)) client := must1(plugin.NewClient(nil)) defer runtime.KeepAlive(client) @@ -248,13 +248,13 @@ func TestCreateViewOfDeviceBuffer(t *testing.T) { } func TestNewSharedBuffer(t *testing.T) { - if *flagPluginName != "cpu" && !*flagForceSharedBuffer { + if *FlagPluginName != "cpu" && !*flagForceSharedBuffer { t.Skip("Skipping TestNewSharedBuffer because -plugin != \"cpu\". " + "Set --force_create_view to force executing the test anyway") } // Create plugin. - plugin := must1(GetPlugin(*flagPluginName)) + plugin := must1(GetPlugin(*FlagPluginName)) client := must1(plugin.NewClient(nil)) defer runtime.KeepAlive(client) @@ -310,13 +310,13 @@ func TestNewSharedBuffer(t *testing.T) { } func TestBufferData(t *testing.T) { - if *flagPluginName != "cpu" && !*flagForceSharedBuffer { + if *FlagPluginName != "cpu" && !*flagForceSharedBuffer { t.Skip("Skipping TestNewSharedBuffer because -plugin != \"cpu\". " + "Set --force_create_view to force executing the test anyway") } // Create plugin. - plugin := must1(GetPlugin(*flagPluginName)) + plugin := must1(GetPlugin(*FlagPluginName)) client := must1(plugin.NewClient(nil)) defer runtime.KeepAlive(client) @@ -360,7 +360,7 @@ func TestBufferData(t *testing.T) { func TestBufferDestroyAfterClient(t *testing.T) { // Create the plugin and the client. - plugin := must1(GetPlugin(*flagPluginName)) + plugin := must1(GetPlugin(*FlagPluginName)) client := must1(plugin.NewClient(nil)) defer runtime.KeepAlive(client) diff --git a/pjrt/clients_test.go b/pjrt/clients_test.go index 330a8ad..a9dda57 100644 --- a/pjrt/clients_test.go +++ b/pjrt/clients_test.go @@ -35,7 +35,7 @@ var ( ) func TestPlugin_NewClient(t *testing.T) { - plugin, err := GetPlugin(*flagPluginName) + plugin, err := GetPlugin(*FlagPluginName) require.NoError(t, err) fmt.Printf("Loaded %s\n", plugin) @@ -51,7 +51,7 @@ func TestPlugin_NewClient(t *testing.T) { } func TestCompileAndExecute(t *testing.T) { - plugin, err := GetPlugin(*flagPluginName) + plugin, err := GetPlugin(*FlagPluginName) require.NoError(t, err) fmt.Printf("Loaded %s\n", plugin) diff --git a/pjrt/devices_test.go b/pjrt/devices_test.go index ca7c35f..b8da78f 100644 --- a/pjrt/devices_test.go +++ b/pjrt/devices_test.go @@ -7,7 +7,7 @@ import ( ) func TestClient_Devices(t *testing.T) { - plugin, err := GetPlugin(*flagPluginName) + plugin, err := GetPlugin(*FlagPluginName) require.NoError(t, err) fmt.Printf("Loaded %s\n", plugin) client, err := plugin.NewClient(nil) diff --git a/pjrt/dynamiclib_test.go b/pjrt/dynamiclib_test.go index c8dc45f..653aec0 100644 --- a/pjrt/dynamiclib_test.go +++ b/pjrt/dynamiclib_test.go @@ -27,15 +27,15 @@ import ( "testing" ) -// TestLoadNamedPlugin loads the *flagPluginName plugin, which defaults to "cpu", that should be made available. +// TestLoadNamedPlugin loads the *FlagPluginName plugin, which defaults to "cpu", that should be made available. func TestLoadNamedPlugin(t *testing.T) { - plugin, err := loadNamedPlugin(*flagPluginName) + plugin, err := loadNamedPlugin(*FlagPluginName) require.NoError(t, err) fmt.Printf("Loaded %s\n", plugin) fmt.Printf("\tattributes: %v\n", plugin.attributes) // Checks cache works. - plugin2, err := loadNamedPlugin(*flagPluginName) + plugin2, err := loadNamedPlugin(*FlagPluginName) require.NoError(t, err) require.Equal(t, plugin, plugin2) plugin3, err := loadNamedPlugin(plugin2.Path()) // Try by using the absolute path, should return the same plugin. @@ -52,7 +52,7 @@ func TestLoadNamedPlugin(t *testing.T) { func TestAvailablePlugins(t *testing.T) { plugins := AvailablePlugins() fmt.Printf("Available plugins: %v\n", plugins) - require.NotEqualf(t, "", plugins[*flagPluginName], "Can not find %q plugin", *flagPluginName) + require.NotEqualf(t, "", plugins[*FlagPluginName], "Can not find %q plugin", *FlagPluginName) } // TestSuppressAbseilLoggingHack never fails, since errors are simply logged. diff --git a/pjrt/minimal_test.go b/pjrt/minimal_test.go index 56c6d3f..53e9ca5 100644 --- a/pjrt/minimal_test.go +++ b/pjrt/minimal_test.go @@ -61,7 +61,7 @@ func TestMinimal(t *testing.T) { fmt.Printf("HLO Program:\n%s\n\n", hloModule.String()) // `dlopen` PJRT plugin. - plugin := must1(GetPlugin(*flagPluginName)) + plugin := must1(GetPlugin(*FlagPluginName)) defer runtime.KeepAlive(plugin) fmt.Printf("PJRT: %s\n", plugin.String()) diff --git a/pjrt/pjrt_test.go b/pjrt/pjrt_test.go index 72d679d..8d452a6 100644 --- a/pjrt/pjrt_test.go +++ b/pjrt/pjrt_test.go @@ -5,15 +5,16 @@ package pjrt import ( "flag" "fmt" - + + "testing" + "github.com/gomlx/gopjrt/dtypes" "github.com/pkg/errors" "github.com/stretchr/testify/require" "k8s.io/klog/v2" - "testing" ) -var flagPluginName = flag.String("plugin", "cpu", "plugin name") +var FlagPluginName = flag.String("plugin", "cpu", "plugin name") func init() { klog.InitFlags(nil) @@ -49,8 +50,8 @@ func must1[T any](t T, err error) T { // It exits the test if anything goes wrong. func getPJRTClient(t *testing.T) *Client { // PJRT plugin and create a client. - plugin, err := GetPlugin(*flagPluginName) - require.NoError(t, err, "Failed to get plugin %q", *flagPluginName) + plugin, err := GetPlugin(*FlagPluginName) + require.NoError(t, err, "Failed to get plugin %q", *FlagPluginName) attributes := plugin.Attributes() fmt.Printf("Loaded PJRT plugin %s with %d atributes:\n", plugin, len(attributes)) for key, value := range attributes { diff --git a/pjrt/spmd_test.go b/pjrt/spmd_test.go index ef1d168..ddf3c8a 100644 --- a/pjrt/spmd_test.go +++ b/pjrt/spmd_test.go @@ -1,21 +1,115 @@ -package pjrt +package pjrt_test import ( "fmt" "testing" "github.com/gomlx/gopjrt/dtypes" + "github.com/gomlx/gopjrt/pjrt" "github.com/gomlx/stablehlo" "github.com/gomlx/stablehlo/types/shapes" + "github.com/pkg/errors" "github.com/stretchr/testify/require" ) +func panicf(format string, args ...any) { + panic(errors.Errorf(format, args...)) +} + +func must(err error) { + if err != nil { + panicf("Failed: %+v", errors.WithStack(err)) + } +} + +func must1[T any](t T, err error) T { + must(err) + return t +} + +var ( + allReduceProgram = []byte( + ` +module @TestDistributedAllReduce_multiple_values__different_dtype attributes {stablehlo.num_replicas = 2 } { + func.func @main(%x: tensor, %y: tensor<2xf32>, %z: tensor<3xf64>) -> (tensor, tensor<2xf32>, tensor<3xf64>) { + %1 = "stablehlo.all_reduce"(%z) ({ + ^computation(%lhs: tensor, %rhs: tensor) : + %0 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<3xf64>) -> tensor<3xf64> + %3, %4 = "stablehlo.all_reduce"(%x, %y) ({ + ^computation(%lhs: tensor, %rhs: tensor) : + %2 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor, tensor<2xf32>) -> (tensor, tensor<2xf32>) + "stablehlo.return"(%3, %4, %1) : (tensor, tensor<2xf32>, tensor<3xf64>) -> () + } +} +`) + + allReduceProgram2 = []byte( + ` +module @TestDistributedAllReduce_multiple_values__different_dtype attributes {stablehlo.num_replicas = 2 } { + func.func @main(%x: tensor, %y: tensor<2xf32>, %z: tensor<3xf64>) -> (tensor, tensor<2xf32>, tensor<3xf64>) { + %1, %2 = "stablehlo.all_reduce"(%x, %y) ({ + ^computation(%lhs: tensor, %rhs: tensor) : + %0 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + channel_handle = #stablehlo.channel_handle, + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + } : (tensor, tensor<2xf32>) -> (tensor, tensor<2xf32>) + %4 = "stablehlo.all_reduce"(%z) ({ + ^computation(%lhs: tensor, %rhs: tensor) : + %3 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor + "stablehlo.return"(%3) : (tensor) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<3xf64>) -> tensor<3xf64> + "stablehlo.return"(%1, %2, %4) : (tensor, tensor<2xf32>, tensor<3xf64>) -> () + } +}`) + + allReduceProgram3 = []byte( + ` +module @TestDistributedAllReduce_multiple_values__different_dtype attributes {stablehlo.num_replicas = 2 } { + func.func @main(%x: tensor, %y: tensor<2xf32>, %z: tensor<3xf64>) -> (tensor, tensor<2xf32>, tensor<3xf64>) { + %1 = "stablehlo.all_reduce"(%z) ({ + ^computation(%lhs: tensor, %rhs: tensor) : + %0 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + } : (tensor<3xf64>) -> tensor<3xf64> + %3, %4 = "stablehlo.all_reduce"(%x, %y) ({ + ^computation(%lhs: tensor, %rhs: tensor) : + %0 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + } : (tensor, tensor<2xf32>) -> (tensor, tensor<2xf32>) + "stablehlo.return"(%3, %4, %1) : (tensor, tensor<2xf32>, tensor<3xf64>) -> () + } +} +`) + _ = allReduceProgram + _ = allReduceProgram2 + _ = allReduceProgram3 +) + // TestSPMD builds, compiles, and executes a minimal distributed (SPMD = Single Program Multiple Data) computation, // and uses PJRT to compile and execute it. func TestSPMD(t *testing.T) { // PJRT plugin and create a client. - plugin, err := GetPlugin(*flagPluginName) - require.NoError(t, err, "Failed to get plugin %q", *flagPluginName) + plugin, err := pjrt.GetPlugin(*pjrt.FlagPluginName) + require.NoError(t, err, "Failed to get plugin %q", *pjrt.FlagPluginName) fmt.Printf("Loaded %s\n", plugin) fmt.Printf("\t- Attributes=%+v\n", plugin.Attributes()) client, err := plugin.NewClient(nil) @@ -67,7 +161,7 @@ func TestSPMD(t *testing.T) { fmt.Printf("\nStableHLO:\n%s\n", string(compBytes)) // Compile program. - var loadedExec *LoadedExecutable + var loadedExec *pjrt.LoadedExecutable loadedExec, err = client.Compile(). WithStableHLO(compBytes). WithSPMD(numReplicas). @@ -78,7 +172,7 @@ func TestSPMD(t *testing.T) { // Test values: fmt.Printf("f(x_r) = Reduce_sum(CollectiveAllReduce_sum(x_r)):\n") - inputBuffers := make([]*Buffer, numReplicas) + inputBuffers := make([]*pjrt.Buffer, numReplicas) for ii := range numReplicas { input := []float32{1.0 * float32(ii+1), 0.1 * float32(ii+1)} // Transfer input to an on-device buffer. @@ -96,7 +190,7 @@ func TestSPMD(t *testing.T) { require.Lenf(t, outputBuffers, numReplicas, "Expected %d outputs, got %d", numReplicas, len(outputBuffers)) // Transfer output on-device buffer to a "host" value (in Go). - output, err := BufferToScalar[float32](outputBuffers[0]) + output, err := pjrt.BufferToScalar[float32](outputBuffers[0]) require.NoErrorf(t, err, "Failed to transfer results of execution") // Print and check value is what we wanted. @@ -118,3 +212,29 @@ func TestSPMD(t *testing.T) { err = client.Destroy() require.NoErrorf(t, err, "Failed to destroy client on %s", plugin) } + +func TestCollectiveAllReduce(t *testing.T) { + // PJRT plugin and create a client. + plugin, err := pjrt.GetPlugin(*pjrt.FlagPluginName) + require.NoError(t, err, "Failed to get plugin %q", *pjrt.FlagPluginName) + fmt.Printf("Loaded %s\n", plugin) + fmt.Printf("\t- Attributes=%+v\n", plugin.Attributes()) + client, err := plugin.NewClient(nil) + require.NoErrorf(t, err, "Failed to create a client on %s", plugin) + fmt.Printf(" client: %s\n", client) + + // Verify that we have enough devices. + devices := client.AddressableDevices() + if len(devices) < 2 { + t.Skipf("TestCollectiveAllReduce requires at least 2 devices, only %d available", len(devices)) + } + + // Compile program: the default compilation is "portable", meaning it can be executed by any device. + var loadedExec *pjrt.LoadedExecutable + loadedExec, err = client.Compile(). + WithStableHLO(allReduceProgram3). + WithSPMD(2). + Done() + require.NoErrorf(t, err, "Failed to compile program") + fmt.Printf("Compiled program: name=%s, #outputs=%d\n", loadedExec.Name, loadedExec.NumOutputs) +} diff --git a/pjrt/zero_dim_test.go b/pjrt/zero_dim_test.go index e0c7847..1b5e127 100644 --- a/pjrt/zero_dim_test.go +++ b/pjrt/zero_dim_test.go @@ -13,7 +13,7 @@ import ( ) func TestZeroDim(t *testing.T) { - plugin, err := GetPlugin(*flagPluginName) + plugin, err := GetPlugin(*FlagPluginName) require.NoError(t, err) fmt.Printf("Loaded %s\n", plugin) @@ -74,7 +74,7 @@ func TestZeroDim(t *testing.T) { }) // Test 2: Create zero-dimension buffer using NewSharedBuffer (CPU only) - if *flagPluginName == "cpu" || *flagForceSharedBuffer { + if *FlagPluginName == "cpu" || *flagForceSharedBuffer { t.Run("NewSharedBuffer", func(t *testing.T) { fmt.Println("testing NewSharedBuffer") testZeroDimNewSharedBuffer(t, client, tc.dtype, tc.dimensions, tc.expectSize) From 3fe751a14c317cfa86b80ff091660dbd9d510996 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Thu, 20 Nov 2025 05:40:03 +0100 Subject: [PATCH 2/9] Fixed to new NamedInput API. --- docs/CHANGELOG.md | 4 +++- pjrt/benchmarks_test.go | 10 +++++----- pjrt/buffers_test.go | 6 +++--- pjrt/error_test.go | 4 ++-- pjrt/loadedexecutables_test.go | 6 +++--- pjrt/shardy_test.go | 1 + pjrt/spmd_test.go | 10 +++++----- pjrt/zero_dim_test.go | 2 +- 8 files changed, 23 insertions(+), 20 deletions(-) create mode 100644 pjrt/shardy_test.go diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 2eb237d..1028be3 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,6 +1,6 @@ # Gopjrt Changelog -# Next +# 0.10.0 Added Shardy support. - Package `cmd/gopjrt_installer`: - Link `libcublasLt.so.13` and `libcublas.so.13` to the `lib` subdirectory of the installation directory given. @@ -10,6 +10,8 @@ - Fixed wrong struct size set for `PJRT_Event_Destroy_Args`, in #65 (@timkaye11) - Fixed `buffer.Destroy` to release the `client` pointer in the wrapper. - Added a required `runtime.KeepAlive(program)` on a CGO call to compile `program`. + - Added new Shardy support for distributed (across multiple devices) execution. + - Added old SPMD support (see `pjrt.Compile().WithSPMD`) # v0.9.1 2025/11/07: More multi-device support; updated CPU PJRT; dropped static CPU PJRT linking. diff --git a/pjrt/benchmarks_test.go b/pjrt/benchmarks_test.go index 5ad2b5e..93906ee 100644 --- a/pjrt/benchmarks_test.go +++ b/pjrt/benchmarks_test.go @@ -233,7 +233,7 @@ func TestBenchAdd1Execution(t *testing.T) { builder := stablehlo.New(fmt.Sprintf("Add1/%s", s)) mainFn := builder.Main() // f(x) = x + 1 - x := mainFn.NamedInput("x", s) + x := must1(mainFn.NamedInput("x", s)) one := must1(mainFn.ConstantFromScalar(float32(1))) broadcastedOne := must1(stablehlo.BroadcastInDim(one, x.Shape(), nil)) add1 := must1(stablehlo.Add(x, broadcastedOne)) @@ -299,7 +299,7 @@ func TestBenchAdd1Div2Execution(t *testing.T) { builder := stablehlo.New(fmt.Sprintf("Add1/%s", s)) mainFn := builder.Main() // f(x) = (x + 1) * 0.5 - x := mainFn.NamedInput("x", s) + x := must1(mainFn.NamedInput("x", s)) one := must1(mainFn.ConstantFromScalar(float32(1))) broadcastedOne := must1(stablehlo.BroadcastInDim(one, x.Shape(), nil)) add1 := must1(stablehlo.Add(x, broadcastedOne)) @@ -368,7 +368,7 @@ func TestBenchMeanNormalizedExecution(t *testing.T) { builder := stablehlo.New(fmt.Sprintf("MeanNormalized/%s", s)) mainFn := builder.Main() // f(x) = (x + 1) * 0.5 - mean((x + 1) * 0.5) - x := mainFn.NamedInput("x", s) + x := must1(mainFn.NamedInput("x", s)) one := must1(mainFn.ConstantFromScalar(float32(1))) broadcastedOne := must1(stablehlo.BroadcastInDim(one, x.Shape(), nil)) add1 := must1(stablehlo.Add(x, broadcastedOne)) @@ -377,8 +377,8 @@ func TestBenchMeanNormalizedExecution(t *testing.T) { div2 := must1(stablehlo.Multiply(add1, broadcastedHalf)) reductionFn := mainFn.Closure() - lhs := reductionFn.NamedInput("lhs", shapes.Make(dtypes.F32)) - rhs := reductionFn.NamedInput("rhs", shapes.Make(dtypes.F32)) + lhs := must1(reductionFn.NamedInput("lhs", shapes.Make(dtypes.F32))) + rhs := must1(reductionFn.NamedInput("rhs", shapes.Make(dtypes.F32))) must(reductionFn.Return(must1(stablehlo.Add(lhs, rhs)))) initialValue := must1(mainFn.ConstantFromScalar(float32(0))) diff --git a/pjrt/buffers_test.go b/pjrt/buffers_test.go index ed1b806..55570b7 100644 --- a/pjrt/buffers_test.go +++ b/pjrt/buffers_test.go @@ -200,7 +200,7 @@ func TestCreateViewOfDeviceBuffer(t *testing.T) { shape := shapes.Make(dtype, 2, 3) builder := stablehlo.New("Add1") mainFn := builder.Main() - x := mainFn.NamedInput("x", shape) + x := must1(mainFn.NamedInput("x", shape)) one := must1(mainFn.ConstantFromScalar(float32(1))) broadcastedOne := must1(stablehlo.BroadcastInDim(one, x.Shape(), nil)) add1 := must1(stablehlo.Add(x, broadcastedOne)) @@ -263,7 +263,7 @@ func TestNewSharedBuffer(t *testing.T) { shape := shapes.Make(dtype, 2, 3) builder := stablehlo.New("Add1") mainFn := builder.Main() - x := mainFn.NamedInput("x", shape) + x := must1(mainFn.NamedInput("x", shape)) one := must1(mainFn.ConstantFromScalar(float32(1))) broadcastedOne := must1(stablehlo.BroadcastInDim(one, x.Shape(), nil)) add1 := must1(stablehlo.Add(x, broadcastedOne)) @@ -325,7 +325,7 @@ func TestBufferData(t *testing.T) { shape := shapes.Make(dtype, 2, 3) builder := stablehlo.New("Add1") mainFn := builder.Main() - x := mainFn.NamedInput("x", shape) + x := must1(mainFn.NamedInput("x", shape)) one := must1(mainFn.ConstantFromScalar(float32(1))) broadcastedOne := must1(stablehlo.BroadcastInDim(one, x.Shape(), nil)) add1 := must1(stablehlo.Add(x, broadcastedOne)) diff --git a/pjrt/error_test.go b/pjrt/error_test.go index 74f2c0f..3cb227c 100644 --- a/pjrt/error_test.go +++ b/pjrt/error_test.go @@ -17,8 +17,8 @@ func TestError(t *testing.T) { // f(x, y) = x+y scalarF32 := shapes.Make(dtypes.F32) - x := mainFn.NamedInput("x", scalarF32) // Scalar float32. - y := mainFn.NamedInput("y", scalarF32) // Scalar float32. + x := must1(mainFn.NamedInput("x", scalarF32)) // Scalar float32. + y := must1(mainFn.NamedInput("y", scalarF32)) // Scalar float32. fXY := capture(stablehlo.Add(x, y)).Test(t) // Take program and compile. diff --git a/pjrt/loadedexecutables_test.go b/pjrt/loadedexecutables_test.go index dc6438a..75ce352 100644 --- a/pjrt/loadedexecutables_test.go +++ b/pjrt/loadedexecutables_test.go @@ -17,9 +17,9 @@ func TestDonatableConfig(t *testing.T) { // f(x, y, z) = x*y + z scalarF32 := shapes.Make(dtypes.F32) - x := mainFn.NamedInput("x", scalarF32) // Scalar float32. - y := mainFn.NamedInput("y", scalarF32) // Scalar float32. - z := mainFn.NamedInput("z", scalarF32) // Scalar float32. + x := must1(mainFn.NamedInput("x", scalarF32)) // Scalar float32. + y := must1(mainFn.NamedInput("y", scalarF32)) // Scalar float32. + z := must1(mainFn.NamedInput("z", scalarF32)) // Scalar float32. fX := capture(stablehlo.Multiply(x, y)).Test(t) fX = capture(stablehlo.Add(fX, z)).Test(t) diff --git a/pjrt/shardy_test.go b/pjrt/shardy_test.go new file mode 100644 index 0000000..fc851bd --- /dev/null +++ b/pjrt/shardy_test.go @@ -0,0 +1 @@ +package pjrt diff --git a/pjrt/spmd_test.go b/pjrt/spmd_test.go index ddf3c8a..a09b920 100644 --- a/pjrt/spmd_test.go +++ b/pjrt/spmd_test.go @@ -142,15 +142,15 @@ func TestSPMD(t *testing.T) { builder := stablehlo.New("sum_x0") mainFn := builder.Main() argShape := shapes.Make(dtypes.F32, 2) - x := mainFn.NamedInput("x", argShape) + x := must1(mainFn.NamedInput("x", argShape)) reductionFn := mainFn.Closure() - lhs := reductionFn.NamedInput("lhs", shapes.Make(dtypes.F32)) - rhs := reductionFn.NamedInput("rhs", shapes.Make(dtypes.F32)) + lhs := must1(reductionFn.NamedInput("lhs", shapes.Make(dtypes.F32))) + rhs := must1(reductionFn.NamedInput("rhs", shapes.Make(dtypes.F32))) must(reductionFn.Return(must1(stablehlo.Add(lhs, rhs)))) - reducedReplicas, err := stablehlo.AllReduce(x, replicaGroups, reductionFn) + reducedReplicas, err := stablehlo.AllReduce([]*stablehlo.Value{x}, replicaGroups, reductionFn) require.NoError(t, err, "Failed operation CollectiveAllReduce") zero := must1(mainFn.ConstantFromScalar(float32(0))) - sum, err := stablehlo.Reduce(reducedReplicas, zero, reductionFn, 0) + sum, err := stablehlo.Reduce(reducedReplicas[0], zero, reductionFn, 0) require.NoError(t, err, "Failed operation Reduce") err = mainFn.Return(sum) require.NoError(t, err, "Failed operation Return") diff --git a/pjrt/zero_dim_test.go b/pjrt/zero_dim_test.go index 1b5e127..90a6ad2 100644 --- a/pjrt/zero_dim_test.go +++ b/pjrt/zero_dim_test.go @@ -217,7 +217,7 @@ func testZeroDimAsInput(t *testing.T, client *Client, dtype dtypes.DType, dimens builder := stablehlo.New("ZeroDimIdentity") shape := shapes.Make(dtype, dimensions...) mainFn := builder.Main() - param := mainFn.NamedInput("input", shape) + param := must1(mainFn.NamedInput("input", shape)) err := mainFn.Return(param) require.NoError(t, err, "Failed to set return value") From c55364370a303899022cfc3589f6e9a72039d0b9 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Thu, 20 Nov 2025 05:40:57 +0100 Subject: [PATCH 3/9] Added `Compile().WithShardy()` and minimal example. --- pjrt/compile.go | 21 +++++++++ pjrt/shardy_test.go | 111 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 131 insertions(+), 1 deletion(-) diff --git a/pjrt/compile.go b/pjrt/compile.go index 67111be..4a40200 100644 --- a/pjrt/compile.go +++ b/pjrt/compile.go @@ -297,6 +297,8 @@ func (cc *CompileConfig) WithComputation(computation XlaComputation) *CompileCon // WithSPMD configures the program to be compiled for "Single Program Multiple Data" (SPMD) mode. // +// This is the old way, consider using WithShardy instead. +// // This means the same program will be executed on multiple devices, with different inputs per device. // // Later the inputs to the LoadedExecutable.Execute method will be divided in numReplicas slices, each fed @@ -327,6 +329,25 @@ func (cc *CompileConfig) WithSPMD(numReplicas int) *CompileConfig { return cc } +// WithShardy uses XLA Shardy [1] to automatically distribute the execution across a mesh of devices. +// +// After setting this, consider doing also the device assignment (WithDeviceAssignment) that should be provided +// by your shardy.DeviceMesh. +// +// Shardy uses as input the sharding specification of the inputs and outputs (and optionally hints inside the program +// to automatically shard the computation). +// +// [1] https://openxla.org/shardy/ +func (cc *CompileConfig) WithShardy(numDevices int) *CompileConfig { + cc.options.ExecutableBuildOptions.UseShardyPartitioner = true + cc.options.ExecutableBuildOptions.UseSpmdPartitioning = true + cc.options.ExecutableBuildOptions.NumReplicas = 1 + cc.options.ExecutableBuildOptions.NumPartitions = int64(numDevices) + cc.options.CompilePortableExecutable = false + cc.setDefaultDeviceAssignment() + return cc +} + // WithDeviceAssignment configures the device assignment for the program. // The device assignment is used to determine the device on which each computation will be executed. // diff --git a/pjrt/shardy_test.go b/pjrt/shardy_test.go index fc851bd..c0abb72 100644 --- a/pjrt/shardy_test.go +++ b/pjrt/shardy_test.go @@ -1 +1,110 @@ -package pjrt +package pjrt_test + +import ( + "fmt" + "testing" + + "github.com/gomlx/gopjrt/dtypes" + "github.com/gomlx/gopjrt/pjrt" + "github.com/gomlx/stablehlo/types/shapes" + "github.com/gomlx/stablehlo/types/shardy" + "github.com/stretchr/testify/require" +) + +func TestShardy(t *testing.T) { + plugin, err := pjrt.GetPlugin(*pjrt.FlagPluginName) + require.NoError(t, err, "Failed to get plugin %q", *pjrt.FlagPluginName) + fmt.Printf("Loaded %s\n", plugin) + fmt.Printf("\t- Attributes=%+v\n", plugin.Attributes()) + client, err := plugin.NewClient(nil) + require.NoErrorf(t, err, "Failed to create a client on %s", plugin) + fmt.Printf(" client: %s\n", client) + + // We will test it with 2 devices. + const numReplicas = 2 + numDevices := client.NumDevices() + if numDevices < numReplicas { + t.Skipf("Skipping test: not enough devices: %d < %d", numDevices, numReplicas) + return + } + + t.Run("input-data-sharding", func(t *testing.T) { + mesh := must1(shardy.NewDeviceMesh("data_mesh", []int{2}, []string{"data"})) + program := []byte(`module @TestShardy_input_data_sharding attributes {mhlo.num_replicas = 2:i32, mhlo.num_partitions = 1:i32} { + sdy.mesh @data_mesh = <["data"=2]> + func.func @main(%arg0: tensor<2x3xf32> { sdy.sharding = #sdy.sharding<@data_mesh, [{"data"}, {}]> }) -> tensor { + %1 = "stablehlo.constant"() { value = dense<0.0> : tensor } : () -> tensor + %2 = "stablehlo.reduce"(%arg0, %1) ({ + ^reductionFn(%lhs: tensor, %rhs: tensor) : + %0 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { dimensions = array } : (tensor<2x3xf32>, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () + } + }`) + x0 := must1(client.BufferFromHost().ToDeviceNum(0).FromFlatDataWithDimensions( + []float32{0, 1, 2}, []int{1, 3}).Done()) + x1 := must1(client.BufferFromHost().ToDeviceNum(1).FromFlatDataWithDimensions( + []float32{0, 0.1, 0.2}, []int{1, 3}).Done()) + outputs := shardyCompileAndExecute(t, client, program, mesh, x0, x1) + requireBuffersEqual(t, []FlatAndDims{ + {[]float32{3.3}, nil}, + {[]float32{3.3}, nil}, + }, outputs) + }) +} + +// compileAndExecute program with PJRT. All inputs are donated. +func shardyCompileAndExecute(t *testing.T, client *pjrt.Client, program []byte, + mesh *shardy.DeviceMesh, inputs ...*pjrt.Buffer) []*pjrt.Buffer { + loadedExec, err := client.Compile(). + WithStableHLO(program). + WithShardy(mesh.NumDevices()). + WithDeviceAssignment(mesh.DeviceAssignment()). + Done() + require.NoErrorf(t, err, "failed to compile program: \n%s", program) + defer func() { + err := loadedExec.Destroy() + if err != nil { + t.Errorf("failed to destroy loaded exec: %+v", err) + } + }() + outputBuffers, err := loadedExec.Execute(inputs...).DonateAll().Done() + require.NoErrorf(t, err, "failed to execute program: \n%s", program) + return outputBuffers +} + +type FlatAndDims struct { + Flat any + Dims []int +} + +// requireBuffersEqual checks that the actual buffers contents match the expected flat values. +// It destroys the buffers. +func requireBuffersEqual(t *testing.T, expected []FlatAndDims, got []*pjrt.Buffer) { + defer func() { + for _, b := range got { + err := b.Destroy() + if err != nil { + t.Errorf("failed to destroy buffer: %+v", err) + } + } + }() + require.Len(t, got, len(expected)) + for i, b := range got { + gotFlat, gotDims, err := b.ToFlatDataAndDimensions() + expectedShape, err := shapes.FromAnyValue(expected[i].Flat) + require.NoErrorf(t, err, "failed to get shape for output #%d: %v", i, expected[i].Flat) + dtype := expectedShape.DType + fmt.Printf("\t - output #%d:\n\t - Got: dims=%v, flat_values=%v\n", i, gotDims, gotFlat) + fmt.Printf("\t - Want(%s): dims=%v, flat_values=%v\n", dtype, expected[i].Dims, expected[i].Flat) + require.NoErrorf(t, err, "failed to get buffer contents for output #%d, expected flat value %v", i, expected[i].Flat) + require.Equalf(t, expected[i].Dims, gotDims, "output #%d dims don't match", i) + switch dtype { + case dtypes.Float64, dtypes.Float32: + require.InDeltaSlicef(t, expected[i].Flat, gotFlat, 1e-4, "output #%d flat values don't match", i) + default: + require.Equalf(t, expected[i].Flat, gotFlat, "output #%d flat values don't match", i) + } + } +} From fc9a0a70efdca773ceec73c8a86c1c026aac76a5 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Thu, 20 Nov 2025 09:00:41 +0100 Subject: [PATCH 4/9] Adapted documentation and example to support multiple meshes. --- pjrt/compile.go | 11 +++++++-- pjrt/shardy_test.go | 59 +++++++++++++++++++++++---------------------- 2 files changed, 39 insertions(+), 31 deletions(-) diff --git a/pjrt/compile.go b/pjrt/compile.go index 4a40200..e3e781d 100644 --- a/pjrt/compile.go +++ b/pjrt/compile.go @@ -351,6 +351,10 @@ func (cc *CompileConfig) WithShardy(numDevices int) *CompileConfig { // WithDeviceAssignment configures the device assignment for the program. // The device assignment is used to determine the device on which each computation will be executed. // +// Very important: this defines the device order of the sharded inputs: if you use `WithDeviceAssignment([]int{3, 2})` +// for a computation f(x) with x sharded, you need to feed the shard on device 3 first, followed by the x shard +// on device 2. +// // The device assignment is a list of device IDs, one for each computation in the program, // so there should be numReplicas*numPartitions entries -- organized in "numReplicas" rows and // "numPartitions" columns (row-major). @@ -358,11 +362,14 @@ func (cc *CompileConfig) WithShardy(numDevices int) *CompileConfig { // For example, if numReplicas=2 and numPartitions=3, a device assignment would specify the (replica, partition) pairs // [(replica=0,partition=0), (0,1), (1,0), (1,1), (2,0), (2,1)]. // +// For XLA Shardy program (WithShardy), the "replica/partition" split is ignored, all devices are assigned as +// "partition" devices, and the assignment is a simple list, with the order of devices to use. +// // If not set, the device assignment is done automatically (see Client.DefaultDeviceAssignment), and it can be queried // using LoadedExecutable.GetDeviceAssignment method after compilation. // -// If not set directly or indirectly with WithSPMD the program defaults to being portable, and the device assignment -// is ignored. +// If it is not set directly or indirectly with WithShardy or WithSPMD, +// the program defaults to being "portable", and the device assignment is ignored. // // It returns itself (CompileConfig) to allow cascading configuration calls. func (cc *CompileConfig) WithDeviceAssignment(assignment []int) *CompileConfig { diff --git a/pjrt/shardy_test.go b/pjrt/shardy_test.go index c0abb72..4a474b5 100644 --- a/pjrt/shardy_test.go +++ b/pjrt/shardy_test.go @@ -7,7 +7,6 @@ import ( "github.com/gomlx/gopjrt/dtypes" "github.com/gomlx/gopjrt/pjrt" "github.com/gomlx/stablehlo/types/shapes" - "github.com/gomlx/stablehlo/types/shardy" "github.com/stretchr/testify/require" ) @@ -29,9 +28,9 @@ func TestShardy(t *testing.T) { } t.Run("input-data-sharding", func(t *testing.T) { - mesh := must1(shardy.NewDeviceMesh("data_mesh", []int{2}, []string{"data"})) - program := []byte(`module @TestShardy_input_data_sharding attributes {mhlo.num_replicas = 2:i32, mhlo.num_partitions = 1:i32} { - sdy.mesh @data_mesh = <["data"=2]> + program := []byte(`module @TestShardy_input_data_sharding attributes {} { + sdy.mesh @mesh_unused = <["x"=1, "y"=2], device_ids=[0, 1]> + sdy.mesh @data_mesh = <["data"=2], device_ids=[1, 0]> func.func @main(%arg0: tensor<2x3xf32> { sdy.sharding = #sdy.sharding<@data_mesh, [{"data"}, {}]> }) -> tensor { %1 = "stablehlo.constant"() { value = dense<0.0> : tensor } : () -> tensor %2 = "stablehlo.reduce"(%arg0, %1) ({ @@ -42,11 +41,33 @@ func TestShardy(t *testing.T) { "stablehlo.return"(%2) : (tensor) -> () } }`) - x0 := must1(client.BufferFromHost().ToDeviceNum(0).FromFlatDataWithDimensions( - []float32{0, 1, 2}, []int{1, 3}).Done()) - x1 := must1(client.BufferFromHost().ToDeviceNum(1).FromFlatDataWithDimensions( - []float32{0, 0.1, 0.2}, []int{1, 3}).Done()) - outputs := shardyCompileAndExecute(t, client, program, mesh, x0, x1) + deviceAssignment := []int{3, 2} // More than we actually use, but it doesn't matter. + loadedExec, err := client.Compile(). + WithStableHLO(program). + WithShardy(2). + WithDeviceAssignment(deviceAssignment). + Done() + require.NoErrorf(t, err, "failed to compile program: \n%s", program) + defer func() { + err := loadedExec.Destroy() + if err != nil { + t.Errorf("failed to destroy loaded exec: %+v", err) + } + }() + + // Notice we provided device nums + x0 := must1(client.BufferFromHost(). + ToDeviceNum(deviceAssignment[0]). + FromFlatDataWithDimensions([]float32{0, 1, 2}, []int{1, 3}). + Done()) + x1 := must1(client.BufferFromHost(). + ToDeviceNum(deviceAssignment[1]). + FromFlatDataWithDimensions([]float32{0, 0.1, 0.2}, []int{1, 3}). + Done()) + outputs, err := loadedExec.Execute(x0, x1).DonateAll().Done() + require.NoErrorf(t, err, "failed to execute program: \n%s", program) + + // Check results. requireBuffersEqual(t, []FlatAndDims{ {[]float32{3.3}, nil}, {[]float32{3.3}, nil}, @@ -54,26 +75,6 @@ func TestShardy(t *testing.T) { }) } -// compileAndExecute program with PJRT. All inputs are donated. -func shardyCompileAndExecute(t *testing.T, client *pjrt.Client, program []byte, - mesh *shardy.DeviceMesh, inputs ...*pjrt.Buffer) []*pjrt.Buffer { - loadedExec, err := client.Compile(). - WithStableHLO(program). - WithShardy(mesh.NumDevices()). - WithDeviceAssignment(mesh.DeviceAssignment()). - Done() - require.NoErrorf(t, err, "failed to compile program: \n%s", program) - defer func() { - err := loadedExec.Destroy() - if err != nil { - t.Errorf("failed to destroy loaded exec: %+v", err) - } - }() - outputBuffers, err := loadedExec.Execute(inputs...).DonateAll().Done() - require.NoErrorf(t, err, "failed to execute program: \n%s", program) - return outputBuffers -} - type FlatAndDims struct { Flat any Dims []int From 6ef2aa91c2603b41d05a136aacfe256ae96a09fe Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Fri, 21 Nov 2025 07:13:44 +0100 Subject: [PATCH 5/9] Updated shardy test. --- pjrt/shardy_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pjrt/shardy_test.go b/pjrt/shardy_test.go index 4a474b5..e079d44 100644 --- a/pjrt/shardy_test.go +++ b/pjrt/shardy_test.go @@ -28,8 +28,8 @@ func TestShardy(t *testing.T) { } t.Run("input-data-sharding", func(t *testing.T) { - program := []byte(`module @TestShardy_input_data_sharding attributes {} { - sdy.mesh @mesh_unused = <["x"=1, "y"=2], device_ids=[0, 1]> + program := []byte(`module @TestShardy_input_data_sharding attributes {stablehlo.num_replicas = 1, stablehlo.num_partitions = 2} { + sdy.mesh @unused_mesh = <["x"=1, "y"=2], device_ids=[0, 1]> sdy.mesh @data_mesh = <["data"=2], device_ids=[1, 0]> func.func @main(%arg0: tensor<2x3xf32> { sdy.sharding = #sdy.sharding<@data_mesh, [{"data"}, {}]> }) -> tensor { %1 = "stablehlo.constant"() { value = dense<0.0> : tensor } : () -> tensor @@ -41,7 +41,7 @@ func TestShardy(t *testing.T) { "stablehlo.return"(%2) : (tensor) -> () } }`) - deviceAssignment := []int{3, 2} // More than we actually use, but it doesn't matter. + deviceAssignment := []int{3, 2} loadedExec, err := client.Compile(). WithStableHLO(program). WithShardy(2). From 7722ce7d785e1bd3e1f14a3c604bb3dd1c69a1bd Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Sun, 23 Nov 2025 07:32:15 +0100 Subject: [PATCH 6/9] Updated test for AllReduce. --- pjrt/spmd_test.go | 97 +++++++++++++++++++---------------------------- 1 file changed, 38 insertions(+), 59 deletions(-) diff --git a/pjrt/spmd_test.go b/pjrt/spmd_test.go index a09b920..1947676 100644 --- a/pjrt/spmd_test.go +++ b/pjrt/spmd_test.go @@ -28,7 +28,7 @@ func must1[T any](t T, err error) T { } var ( - allReduceProgram = []byte( + allReduceProgramFail = []byte( ` module @TestDistributedAllReduce_multiple_values__different_dtype attributes {stablehlo.num_replicas = 2 } { func.func @main(%x: tensor, %y: tensor<2xf32>, %z: tensor<3xf64>) -> (tensor, tensor<2xf32>, tensor<3xf64>) { @@ -53,31 +53,7 @@ module @TestDistributedAllReduce_multiple_values__different_dtype attributes {st } `) - allReduceProgram2 = []byte( - ` -module @TestDistributedAllReduce_multiple_values__different_dtype attributes {stablehlo.num_replicas = 2 } { - func.func @main(%x: tensor, %y: tensor<2xf32>, %z: tensor<3xf64>) -> (tensor, tensor<2xf32>, tensor<3xf64>) { - %1, %2 = "stablehlo.all_reduce"(%x, %y) ({ - ^computation(%lhs: tensor, %rhs: tensor) : - %0 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor - "stablehlo.return"(%0) : (tensor) -> () - }) { - channel_handle = #stablehlo.channel_handle, - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> - } : (tensor, tensor<2xf32>) -> (tensor, tensor<2xf32>) - %4 = "stablehlo.all_reduce"(%z) ({ - ^computation(%lhs: tensor, %rhs: tensor) : - %3 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor - "stablehlo.return"(%3) : (tensor) -> () - }) { - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, - channel_handle = #stablehlo.channel_handle - } : (tensor<3xf64>) -> tensor<3xf64> - "stablehlo.return"(%1, %2, %4) : (tensor, tensor<2xf32>, tensor<3xf64>) -> () - } -}`) - - allReduceProgram3 = []byte( + allReduceProgram = []byte( ` module @TestDistributedAllReduce_multiple_values__different_dtype attributes {stablehlo.num_replicas = 2 } { func.func @main(%x: tensor, %y: tensor<2xf32>, %z: tensor<3xf64>) -> (tensor, tensor<2xf32>, tensor<3xf64>) { @@ -86,24 +62,52 @@ module @TestDistributedAllReduce_multiple_values__different_dtype attributes {st %0 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor "stablehlo.return"(%0) : (tensor) -> () }) { - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_id = 1 } : (tensor<3xf64>) -> tensor<3xf64> %3, %4 = "stablehlo.all_reduce"(%x, %y) ({ ^computation(%lhs: tensor, %rhs: tensor) : - %0 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor - "stablehlo.return"(%0) : (tensor) -> () + %2 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () }) { - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_id = 2 } : (tensor, tensor<2xf32>) -> (tensor, tensor<2xf32>) "stablehlo.return"(%3, %4, %1) : (tensor, tensor<2xf32>, tensor<3xf64>) -> () } } `) + _ = allReduceProgram - _ = allReduceProgram2 - _ = allReduceProgram3 + _ = allReduceProgramFail ) +func TestCollectiveAllReduce(t *testing.T) { + // PJRT plugin and create a client. + plugin, err := pjrt.GetPlugin(*pjrt.FlagPluginName) + require.NoError(t, err, "Failed to get plugin %q", *pjrt.FlagPluginName) + fmt.Printf("Loaded %s\n", plugin) + fmt.Printf("\t- Attributes=%+v\n", plugin.Attributes()) + client, err := plugin.NewClient(nil) + require.NoErrorf(t, err, "Failed to create a client on %s", plugin) + fmt.Printf(" client: %s\n", client) + + // Verify that we have enough devices. + devices := client.AddressableDevices() + if len(devices) < 2 { + t.Skipf("TestCollectiveAllReduce requires at least 2 devices, only %d available", len(devices)) + } + + // Compile program: the default compilation is "portable", meaning it can be executed by any device. + var loadedExec *pjrt.LoadedExecutable + loadedExec, err = client.Compile(). + WithStableHLO(allReduceProgram). + WithSPMD(2). + Done() + require.NoErrorf(t, err, "Failed to compile program") + fmt.Printf("Compiled program: name=%s, #outputs=%d\n", loadedExec.Name, loadedExec.NumOutputs) +} + // TestSPMD builds, compiles, and executes a minimal distributed (SPMD = Single Program Multiple Data) computation, // and uses PJRT to compile and execute it. func TestSPMD(t *testing.T) { @@ -126,7 +130,8 @@ func TestSPMD(t *testing.T) { require.NoError(t, err) desc, err := device.GetDescription() require.NoError(t, err) - fmt.Printf("\tDevice #%d: hardwareId=%d, addressable=%v, description=%s\n", deviceNum, hardwareId, addressable, desc.DebugString()) + fmt.Printf("\tDevice #%d: hardwareId=%d, addressable=%v, description=%s\n", + deviceNum, hardwareId, addressable, desc.DebugString()) } // Create replicaGroups [numPartitions=1][numReplicas=numDevices] according to the device assignment. @@ -212,29 +217,3 @@ func TestSPMD(t *testing.T) { err = client.Destroy() require.NoErrorf(t, err, "Failed to destroy client on %s", plugin) } - -func TestCollectiveAllReduce(t *testing.T) { - // PJRT plugin and create a client. - plugin, err := pjrt.GetPlugin(*pjrt.FlagPluginName) - require.NoError(t, err, "Failed to get plugin %q", *pjrt.FlagPluginName) - fmt.Printf("Loaded %s\n", plugin) - fmt.Printf("\t- Attributes=%+v\n", plugin.Attributes()) - client, err := plugin.NewClient(nil) - require.NoErrorf(t, err, "Failed to create a client on %s", plugin) - fmt.Printf(" client: %s\n", client) - - // Verify that we have enough devices. - devices := client.AddressableDevices() - if len(devices) < 2 { - t.Skipf("TestCollectiveAllReduce requires at least 2 devices, only %d available", len(devices)) - } - - // Compile program: the default compilation is "portable", meaning it can be executed by any device. - var loadedExec *pjrt.LoadedExecutable - loadedExec, err = client.Compile(). - WithStableHLO(allReduceProgram3). - WithSPMD(2). - Done() - require.NoErrorf(t, err, "Failed to compile program") - fmt.Printf("Compiled program: name=%s, #outputs=%d\n", loadedExec.Name, loadedExec.NumOutputs) -} From 8ce356fab16bca80286ae514013decca381d4ec4 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Sun, 23 Nov 2025 07:33:35 +0100 Subject: [PATCH 7/9] Start channel_id at 0 for test. --- pjrt/spmd_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pjrt/spmd_test.go b/pjrt/spmd_test.go index 1947676..5bcb25d 100644 --- a/pjrt/spmd_test.go +++ b/pjrt/spmd_test.go @@ -63,7 +63,7 @@ module @TestDistributedAllReduce_multiple_values__different_dtype attributes {st "stablehlo.return"(%0) : (tensor) -> () }) { replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, - channel_id = 1 + channel_id = 0 } : (tensor<3xf64>) -> tensor<3xf64> %3, %4 = "stablehlo.all_reduce"(%x, %y) ({ ^computation(%lhs: tensor, %rhs: tensor) : @@ -71,7 +71,7 @@ module @TestDistributedAllReduce_multiple_values__different_dtype attributes {st "stablehlo.return"(%2) : (tensor) -> () }) { replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, - channel_id = 2 + channel_id = 1 } : (tensor, tensor<2xf32>) -> (tensor, tensor<2xf32>) "stablehlo.return"(%3, %4, %1) : (tensor, tensor<2xf32>, tensor<3xf64>) -> () } From 07d2ab51ab1fc3c284b878809256ca168e4a40ba Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Sun, 23 Nov 2025 07:51:52 +0100 Subject: [PATCH 8/9] Fixed AllReduce test. --- pjrt/spmd_test.go | 34 ++-------------------------------- 1 file changed, 2 insertions(+), 32 deletions(-) diff --git a/pjrt/spmd_test.go b/pjrt/spmd_test.go index 5bcb25d..fde3067 100644 --- a/pjrt/spmd_test.go +++ b/pjrt/spmd_test.go @@ -28,31 +28,6 @@ func must1[T any](t T, err error) T { } var ( - allReduceProgramFail = []byte( - ` -module @TestDistributedAllReduce_multiple_values__different_dtype attributes {stablehlo.num_replicas = 2 } { - func.func @main(%x: tensor, %y: tensor<2xf32>, %z: tensor<3xf64>) -> (tensor, tensor<2xf32>, tensor<3xf64>) { - %1 = "stablehlo.all_reduce"(%z) ({ - ^computation(%lhs: tensor, %rhs: tensor) : - %0 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor - "stablehlo.return"(%0) : (tensor) -> () - }) { - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, - channel_handle = #stablehlo.channel_handle - } : (tensor<3xf64>) -> tensor<3xf64> - %3, %4 = "stablehlo.all_reduce"(%x, %y) ({ - ^computation(%lhs: tensor, %rhs: tensor) : - %2 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor - "stablehlo.return"(%2) : (tensor) -> () - }) { - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, - channel_handle = #stablehlo.channel_handle - } : (tensor, tensor<2xf32>) -> (tensor, tensor<2xf32>) - "stablehlo.return"(%3, %4, %1) : (tensor, tensor<2xf32>, tensor<3xf64>) -> () - } -} -`) - allReduceProgram = []byte( ` module @TestDistributedAllReduce_multiple_values__different_dtype attributes {stablehlo.num_replicas = 2 } { @@ -62,24 +37,19 @@ module @TestDistributedAllReduce_multiple_values__different_dtype attributes {st %0 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor "stablehlo.return"(%0) : (tensor) -> () }) { - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, - channel_id = 0 + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> } : (tensor<3xf64>) -> tensor<3xf64> %3, %4 = "stablehlo.all_reduce"(%x, %y) ({ ^computation(%lhs: tensor, %rhs: tensor) : %2 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor "stablehlo.return"(%2) : (tensor) -> () }) { - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, - channel_id = 1 + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> } : (tensor, tensor<2xf32>) -> (tensor, tensor<2xf32>) "stablehlo.return"(%3, %4, %1) : (tensor, tensor<2xf32>, tensor<3xf64>) -> () } } `) - - _ = allReduceProgram - _ = allReduceProgramFail ) func TestCollectiveAllReduce(t *testing.T) { From 26fd9b31087284b21c7412851c4b7a862ab7ea38 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Tue, 25 Nov 2025 17:31:54 +0100 Subject: [PATCH 9/9] Added "Release Candidate" tag to CHANGELOG on version that is being worked. --- docs/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 1028be3..2741b33 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,6 +1,6 @@ # Gopjrt Changelog -# 0.10.0 Added Shardy support. +# 0.10.0 - (Release Candidate) Added Shardy support. - Package `cmd/gopjrt_installer`: - Link `libcublasLt.so.13` and `libcublas.so.13` to the `lib` subdirectory of the installation directory given.