@@ -2338,20 +2338,20 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
23382338
23392339 // Step 1.2: prepare rope_yarn_ramp; if this part is updated, theta_scale_tensor must be updated as well.
23402340 bool yarn_ramp_tensor_updated = false ;
2341- ggml_cann_pool_alloc yarn_ramp_allocator (ctx.pool ());
23422341 acl_tensor_ptr acl_yarn_ramp_tensor;
23432342 if (ext_factor != 0 &&
23442343 // TODO: check more parameter.
23452344 (ctx.rope_cache .theta_scale_length != theta_scale_length || ctx.rope_cache .freq_scale != freq_scale)) {
23462345 yarn_ramp_tensor_updated = true ;
2347-
2346+ if (ctx.rope_cache .yarn_ramp_cache != nullptr ) {
2347+ ACL_CHECK (aclrtFree (ctx.rope_cache .yarn_ramp_cache ));
2348+ }
2349+ ACL_CHECK (aclrtMalloc (&ctx.rope_cache .yarn_ramp_cache , theta_scale_length * sizeof (float ), ACL_MEM_MALLOC_HUGE_FIRST));
23482350 // -rope_yarn_ramp
23492351 // const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
23502352 // return MIN(1, MAX(0, y)) - 1;
2351- yarn_ramp_allocator.alloc (theta_scale_length * sizeof (float ));
2352- void * yarn_ramp_buffer = yarn_ramp_allocator.get ();
23532353 acl_yarn_ramp_tensor =
2354- ggml_cann_create_tensor (yarn_ramp_buffer , ACL_FLOAT, sizeof (float ), theta_scale_ne, theta_scale_nb, 1 );
2354+ ggml_cann_create_tensor (ctx. rope_cache . yarn_ramp_cache , ACL_FLOAT, sizeof (float ), theta_scale_ne, theta_scale_nb, 1 );
23552355 float zero_value = 0 , one_value = 1 ;
23562356 float denom_safe_value = MAX (0 .001f , corr_dims[1 ] - corr_dims[0 ]);
23572357 acl_scalar_ptr low = ggml_cann_create_scalar (&corr_dims[0 ], aclDataType::ACL_FLOAT);
@@ -2381,8 +2381,10 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
23812381 acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar (&freq_scale_1, aclDataType::ACL_FLOAT);
23822382 GGML_CANN_CALL_ACLNN_OP (ctx, InplaceMuls, acl_yarn_ramp_tensor.get (), freq_scale_1_sc.get ());
23832383 GGML_CANN_CALL_ACLNN_OP (ctx, InplaceAdds, acl_yarn_ramp_tensor.get (), freq_scale_sc.get (), one.get ());
2384+ } else {
2385+ acl_yarn_ramp_tensor =
2386+ ggml_cann_create_tensor (ctx.rope_cache .yarn_ramp_cache , ACL_FLOAT, sizeof (float ), theta_scale_ne, theta_scale_nb, 1 );
23842387 }
2385-
23862388 // Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
23872389 if (ext_factor != 0 ) {
23882390 if (theta_scale_updated || yarn_ramp_tensor_updated) {
0 commit comments