@@ -82,6 +82,50 @@ static void localPermute(const cudecompHandle_t handle, const std::array<int64_t
8282 if (extent_out[i] == 0 ) return ;
8383 }
8484
85+ // Workaround for an out-of-bounds host write bug in cuTENSOR triggered when the
86+ // total number of tensor elements exceeds INT32_MAX/2. We split the tensor so each
87+ // cuTENSOR call stays below that limit.
88+ static constexpr int64_t CUTENSOR_EXTENT_LIMIT = (int64_t )std::numeric_limits<int32_t >::max () / 2 ;
89+ int64_t total_elems = extent_in[0 ] * extent_in[1 ] * extent_in[2 ];
90+ if (handle->cutensor_needs_permute_chunking && total_elems > CUTENSOR_EXTENT_LIMIT) {
91+
92+ // Always pass explicit strides when splitting
93+ std::array<int64_t , 3 > actual_strides_in = strides_in;
94+ if (!anyNonzeros (strides_in)) { actual_strides_in = {extent_in[1 ] * extent_in[2 ], extent_in[2 ], 1 }; }
95+ std::array<int64_t , 3 > actual_strides_out = strides_out;
96+ if (!anyNonzeros (strides_out)) { actual_strides_out = {extent_out[1 ] * extent_out[2 ], extent_out[2 ], 1 }; }
97+ // Try to split on input dims, starting with outermost dim.
98+ std::array<int , 3 > inv_order_out;
99+ for (int i = 0 ; i < 3 ; ++i)
100+ inv_order_out[order_out[i]] = i;
101+ int split_dim_in = -1 ;
102+ int64_t elems_per_slice = 0 ;
103+ for (int j = 2 ; j >= 0 ; --j) {
104+ elems_per_slice = total_elems / extent_in[j];
105+ if (elems_per_slice <= CUTENSOR_EXTENT_LIMIT) {
106+ split_dim_in = j;
107+ break ;
108+ }
109+ }
110+
111+ if (split_dim_in >= 0 ) {
112+ // Run localPermute multiple times, once per chunk.
113+ int64_t chunk = std::max ((int64_t )1 , CUTENSOR_EXTENT_LIMIT / elems_per_slice);
114+ for (int64_t offset = 0 ; offset < extent_in[split_dim_in]; offset += chunk) {
115+ int64_t this_chunk = std::min (chunk, extent_in[split_dim_in] - offset);
116+ std::array<int64_t , 3 > chunk_extent_in = extent_in;
117+ chunk_extent_in[split_dim_in] = this_chunk;
118+ localPermute (handle, chunk_extent_in, order_out, actual_strides_in, actual_strides_out,
119+ input + offset * actual_strides_in[split_dim_in],
120+ output + offset * actual_strides_out[inv_order_out[split_dim_in]], stream);
121+ }
122+ return ;
123+ }
124+ // All pairwise products exceed the limit so splitting isn't possible (requires each dimension > sqrt(INT32_MAX/2)
125+ // ~= 32768). This is not a realistic scenario, but throw an error here for completeness.
126+ THROW_INTERNAL_ERROR (" Input too large to work around CUTENSOR large-tensor bug" );
127+ }
128+
85129 auto strides_in_ptr = anyNonzeros (strides_in) ? strides_in.data () : nullptr ;
86130 auto strides_out_ptr = anyNonzeros (strides_out) ? strides_out.data () : nullptr ;
87131
0 commit comments