Skip to content

Commit e3f0032

Browse files
committed
CANN: In the ROPE operator, yarn_ramp uses cache
1 parent dea9ba2 commit e3f0032

2 files changed

Lines changed: 14 additions & 8 deletions

File tree

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2338,20 +2338,20 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
23382338

23392339
// Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
23402340
bool yarn_ramp_tensor_updated = false;
2341-
ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
23422341
acl_tensor_ptr acl_yarn_ramp_tensor;
23432342
if (ext_factor != 0 &&
23442343
// TODO: check more parameter.
23452344
(ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) {
23462345
yarn_ramp_tensor_updated = true;
2347-
2346+
if (ctx.rope_cache.yarn_ramp_cache != nullptr) {
2347+
ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache));
2348+
}
2349+
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
23482350
// -rope_yarn_ramp
23492351
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
23502352
// return MIN(1, MAX(0, y)) - 1;
2351-
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
2352-
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
23532353
acl_yarn_ramp_tensor =
2354-
ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
2354+
ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
23552355
float zero_value = 0, one_value = 1;
23562356
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
23572357
acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
@@ -2381,8 +2381,10 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
23812381
acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT);
23822382
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
23832383
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
2384+
} else {
2385+
acl_yarn_ramp_tensor =
2386+
ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
23842387
}
2385-
23862388
// Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
23872389
if (ext_factor != 0) {
23882390
if (theta_scale_updated || yarn_ramp_tensor_updated) {

ggml/src/ggml-cann/common.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,9 +315,12 @@ struct ggml_cann_rope_cache {
315315
if (theta_scale_exp_host) {
316316
free(theta_scale_exp_host);
317317
}
318-
if(position_select_index_host) {
318+
if (position_select_index_host) {
319319
free(position_select_index_host);
320320
}
321+
if (yarn_ramp_cache) {
322+
ACL_CHECK(aclrtFree(yarn_ramp_cache));
323+
}
321324
}
322325

323326
bool equal(int64_t theta_scale_length,
@@ -340,7 +343,7 @@ struct ggml_cann_rope_cache {
340343

341344
void set(int64_t theta_scale_length,
342345
int64_t position_length,
343-
float ext_factor,
346+
float ext_factor,
344347
float theta_scale,
345348
float freq_scale,
346349
float attn_factor,
@@ -370,6 +373,7 @@ struct ggml_cann_rope_cache {
370373
float * theta_scale_exp_host = nullptr;
371374
int * position_select_index_host = nullptr;
372375
void * position_select_index = nullptr;
376+
void * yarn_ramp_cache = nullptr;
373377
// sin/cos cache, used only to accelerate first layer on each device
374378
void * sin_cache = nullptr;
375379
void * cos_cache = nullptr;

0 commit comments

Comments
 (0)