|
24 | 24 |
|
25 | 25 | #include "arrow/array/builder_primitive.h" |
26 | 26 | #include "arrow/array/concatenate.h" |
| 27 | +#include "arrow/buffer.h" |
27 | 28 | #include "arrow/buffer_builder.h" |
28 | 29 | #include "arrow/chunked_array.h" |
29 | 30 | #include "arrow/compute/api_vector.h" |
@@ -488,6 +489,136 @@ Status FixedWidthTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* |
488 | 489 |
|
489 | 490 | namespace { |
490 | 491 |
|
| 492 | +template <typename IndexCType> |
| 493 | +Status VarBinaryViewTakeTyped(const ArraySpan& values, const ArraySpan& indices, |
| 494 | + BinaryViewType::c_type* out_views, uint8_t* out_validity, |
| 495 | + int64_t* valid_count) { |
| 496 | + const auto* source_views = values.GetValues<BinaryViewType::c_type>(1); |
| 497 | + const auto* index_values = indices.GetValues<IndexCType>(1); |
| 498 | + |
| 499 | + const bool values_may_have_nulls = values.MayHaveNulls(); |
| 500 | + const bool indices_may_have_nulls = indices.MayHaveNulls(); |
| 501 | + |
| 502 | + if (!values_may_have_nulls && !indices_may_have_nulls) { |
| 503 | + for (int64_t out_i = 0; out_i < indices.length; ++out_i) { |
| 504 | + out_views[out_i] = source_views[static_cast<int64_t>(index_values[out_i])]; |
| 505 | + } |
| 506 | + *valid_count = indices.length; |
| 507 | + return Status::OK(); |
| 508 | + } |
| 509 | + |
| 510 | + for (int64_t out_i = 0; out_i < indices.length; ++out_i) { |
| 511 | + if (indices_may_have_nulls && |
| 512 | + !bit_util::GetBit(indices.buffers[0].data, indices.offset + out_i)) { |
| 513 | + continue; |
| 514 | + } |
| 515 | + |
| 516 | + const int64_t source_i = static_cast<int64_t>(index_values[out_i]); |
| 517 | + const bool source_valid = |
| 518 | + !values_may_have_nulls || |
| 519 | + bit_util::GetBit(values.buffers[0].data, values.offset + source_i); |
| 520 | + if (!source_valid) { |
| 521 | + continue; |
| 522 | + } |
| 523 | + |
| 524 | + out_views[out_i] = source_views[source_i]; |
| 525 | + if (out_validity != nullptr) { |
| 526 | + bit_util::SetBit(out_validity, out_i); |
| 527 | + } |
| 528 | + ++(*valid_count); |
| 529 | + } |
| 530 | + |
| 531 | + return Status::OK(); |
| 532 | +} |
| 533 | + |
| 534 | +Status VarBinaryViewTakeDispatch(const ArraySpan& values, const ArraySpan& indices, |
| 535 | + BinaryViewType::c_type* out_views, uint8_t* out_validity, |
| 536 | + int64_t* valid_count) { |
| 537 | + switch (indices.type->id()) { |
| 538 | + case Type::INT8: |
| 539 | + return VarBinaryViewTakeTyped<int8_t>(values, indices, out_views, out_validity, |
| 540 | + valid_count); |
| 541 | + case Type::INT16: |
| 542 | + return VarBinaryViewTakeTyped<int16_t>(values, indices, out_views, out_validity, |
| 543 | + valid_count); |
| 544 | + case Type::INT32: |
| 545 | + return VarBinaryViewTakeTyped<int32_t>(values, indices, out_views, out_validity, |
| 546 | + valid_count); |
| 547 | + case Type::INT64: |
| 548 | + return VarBinaryViewTakeTyped<int64_t>(values, indices, out_views, out_validity, |
| 549 | + valid_count); |
| 550 | + case Type::UINT8: |
| 551 | + return VarBinaryViewTakeTyped<uint8_t>(values, indices, out_views, out_validity, |
| 552 | + valid_count); |
| 553 | + case Type::UINT16: |
| 554 | + return VarBinaryViewTakeTyped<uint16_t>(values, indices, out_views, out_validity, |
| 555 | + valid_count); |
| 556 | + case Type::UINT32: |
| 557 | + return VarBinaryViewTakeTyped<uint32_t>(values, indices, out_views, out_validity, |
| 558 | + valid_count); |
| 559 | + case Type::UINT64: |
| 560 | + return VarBinaryViewTakeTyped<uint64_t>(values, indices, out_views, out_validity, |
| 561 | + valid_count); |
| 562 | + default: |
| 563 | + return Status::NotImplemented("Unsupported index type for take: ", *indices.type); |
| 564 | + } |
| 565 | +} |
| 566 | + |
| 567 | +} // namespace |
| 568 | + |
| 569 | +Status VarBinaryViewTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { |
| 570 | + const ArraySpan& values = batch[0].array; |
| 571 | + const ArraySpan& indices = batch[1].array; |
| 572 | + |
| 573 | + if (TakeState::Get(ctx).boundscheck) { |
| 574 | + RETURN_NOT_OK(CheckIndexBounds(indices, values.length)); |
| 575 | + } |
| 576 | + |
| 577 | + const int64_t out_length = indices.length; |
| 578 | + const bool may_have_nulls = values.MayHaveNulls() || indices.MayHaveNulls(); |
| 579 | + const auto data_buffers = values.GetVariadicBuffers(); |
| 580 | + |
| 581 | + ARROW_ASSIGN_OR_RAISE( |
| 582 | + auto views_buf, |
| 583 | + AllocateBuffer(out_length * static_cast<int64_t>(sizeof(BinaryViewType::c_type)), |
| 584 | + ctx->memory_pool())); |
| 585 | + auto* out_views = reinterpret_cast<BinaryViewType::c_type*>(views_buf->mutable_data()); |
| 586 | + if (may_have_nulls && views_buf->size() > 0) { |
| 587 | + std::memset(out_views, 0, views_buf->size()); |
| 588 | + } |
| 589 | + |
| 590 | + std::shared_ptr<Buffer> validity_buf; |
| 591 | + uint8_t* out_validity = nullptr; |
| 592 | + if (may_have_nulls) { |
| 593 | + ARROW_ASSIGN_OR_RAISE(validity_buf, |
| 594 | + AllocateEmptyBitmap(out_length, ctx->memory_pool())); |
| 595 | + if (validity_buf->size() > 0) { |
| 596 | + std::memset(validity_buf->mutable_data(), 0, validity_buf->size()); |
| 597 | + } |
| 598 | + out_validity = validity_buf->mutable_data(); |
| 599 | + } |
| 600 | + |
| 601 | + int64_t valid_count = 0; |
| 602 | + RETURN_NOT_OK( |
| 603 | + VarBinaryViewTakeDispatch(values, indices, out_views, out_validity, &valid_count)); |
| 604 | + |
| 605 | + const int64_t null_count = out_length - valid_count; |
| 606 | + BufferVector buffers; |
| 607 | + buffers.reserve(2 + data_buffers.size()); |
| 608 | + buffers.push_back(null_count == 0 ? nullptr : std::move(validity_buf)); |
| 609 | + buffers.push_back(std::move(views_buf)); |
| 610 | + |
| 611 | + for (const auto& data_buffer : data_buffers) { |
| 612 | + buffers.push_back(data_buffer); |
| 613 | + } |
| 614 | + |
| 615 | + out->value = ArrayData::Make(values.type->GetSharedPtr(), out_length, |
| 616 | + std::move(buffers), null_count, /*offset=*/0); |
| 617 | + return Status::OK(); |
| 618 | +} |
| 619 | + |
| 620 | +namespace { |
| 621 | + |
491 | 622 | // ---------------------------------------------------------------------- |
492 | 623 | // Null take |
493 | 624 |
|
@@ -740,6 +871,8 @@ void PopulateTakeKernels(std::vector<SelectionKernelData>* out) { |
740 | 871 | {InputType(match::Primitive()), take_indices, FixedWidthTakeExec}, |
741 | 872 | {InputType(match::BinaryLike()), take_indices, VarBinaryTakeExec}, |
742 | 873 | {InputType(match::LargeBinaryLike()), take_indices, LargeVarBinaryTakeExec}, |
| 874 | + {InputType(Type::BINARY_VIEW), take_indices, VarBinaryViewTakeExec}, |
| 875 | + {InputType(Type::STRING_VIEW), take_indices, VarBinaryViewTakeExec}, |
743 | 876 | {InputType(match::FixedSizeBinaryLike()), take_indices, FixedWidthTakeExec}, |
744 | 877 | {InputType(null()), take_indices, NullTakeExec}, |
745 | 878 | {InputType(Type::DICTIONARY), take_indices, DictionaryTake}, |
|
0 commit comments