Skip to content

Commit fce680d

Browse files
committed
[C++][Compute] Support view arrays in selection kernels
1 parent 16fe342 commit fce680d

7 files changed

Lines changed: 363 additions & 6 deletions

cpp/src/arrow/compute/kernels/scalar_cast_internal.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,10 +280,8 @@ void AddZeroCopyCast(Type::type in_type_id, InputType in_type, OutputType out_ty
280280
}
281281

282282
static bool CanCastFromDictionary(Type::type type_id) {
283-
/// TODO(GH-43010): add is_binary_view_like() here once array_take
284-
/// can handle string-views
285283
return (is_primitive(type_id) || is_base_binary_like(type_id) ||
286-
is_fixed_size_binary(type_id));
284+
is_binary_view_like(type_id) || is_fixed_size_binary(type_id));
287285
}
288286

289287
void AddCommonCasts(Type::type out_type_id, OutputType out_ty, CastFunction* func) {

cpp/src/arrow/compute/kernels/scalar_cast_test.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4368,6 +4368,37 @@ TEST(Cast, FromDictionary) {
43684368
}
43694369
}
43704370

4371+
TEST(Cast, DictionaryDecodeFromViewDictionary) {
4372+
for (const auto& value_type : {binary_view(), utf8_view()}) {
4373+
ARROW_SCOPED_TRACE(value_type->ToString());
4374+
auto dict_values = ArrayFromJSON(
4375+
value_type, R"(["alpha", "long-value-over-inline-limit", "omega"])");
4376+
auto indices = ArrayFromJSON(int8(), "[0, 1, null, 2, 1]");
4377+
ASSERT_OK_AND_ASSIGN(auto dict_arr,
4378+
DictionaryArray::FromArrays(dictionary(int8(), value_type),
4379+
indices, dict_values));
4380+
auto expected = ArrayFromJSON(
4381+
value_type,
4382+
R"(["alpha", "long-value-over-inline-limit", null, "omega", "long-value-over-inline-limit"])");
4383+
4384+
ASSERT_OK_AND_ASSIGN(Datum decoded, CallFunction("dictionary_decode", {dict_arr}));
4385+
ValidateOutput(decoded);
4386+
AssertArraysEqual(*expected, *decoded.make_array(), /*verbose=*/true);
4387+
CheckCast(dict_arr, expected);
4388+
4389+
auto chunked_dict = std::make_shared<ChunkedArray>(
4390+
ArrayVector{dict_arr->Slice(0, 2), dict_arr->Slice(2, 3)});
4391+
ASSERT_OK_AND_ASSIGN(Datum decoded_chunked,
4392+
CallFunction("dictionary_decode", {chunked_dict}));
4393+
ValidateOutput(decoded_chunked);
4394+
AssertChunkedEqual(
4395+
*ChunkedArrayFromJSON(value_type,
4396+
{R"(["alpha", "long-value-over-inline-limit"])",
4397+
R"([null, "omega", "long-value-over-inline-limit"])"}),
4398+
*decoded_chunked.chunked_array());
4399+
}
4400+
}
4401+
43714402
std::shared_ptr<Array> SmallintArrayFromJSON(const std::string& json_data) {
43724403
auto arr = ArrayFromJSON(int16(), json_data);
43734404
auto ext_data = arr->data()->Copy();

cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,6 +919,11 @@ Status SparseUnionFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResu
919919
return FilterWithTakeExec(SparseUnionTakeExec, ctx, batch, out);
920920
}
921921

922+
Status VarBinaryViewFilterExec(KernelContext* ctx, const ExecSpan& batch,
923+
ExecResult* out) {
924+
return FilterWithTakeExec(VarBinaryViewTakeExec, ctx, batch, out);
925+
}
926+
922927
// ----------------------------------------------------------------------
923928
// Implement Filter metafunction
924929

@@ -1094,6 +1099,8 @@ void PopulateFilterKernels(std::vector<SelectionKernelData>* out) {
10941099
{InputType(match::Primitive()), plain_filter, PrimitiveFilterExec},
10951100
{InputType(match::BinaryLike()), plain_filter, BinaryFilterExec},
10961101
{InputType(match::LargeBinaryLike()), plain_filter, BinaryFilterExec},
1102+
{InputType(Type::BINARY_VIEW), plain_filter, VarBinaryViewFilterExec},
1103+
{InputType(Type::STRING_VIEW), plain_filter, VarBinaryViewFilterExec},
10971104
{InputType(null()), plain_filter, NullFilterExec},
10981105
{InputType(Type::FIXED_SIZE_BINARY), plain_filter, PrimitiveFilterExec},
10991106
{InputType(Type::DECIMAL32), plain_filter, PrimitiveFilterExec},
@@ -1116,6 +1123,8 @@ void PopulateFilterKernels(std::vector<SelectionKernelData>* out) {
11161123
{InputType(match::Primitive()), ree_filter, PrimitiveFilterExec},
11171124
{InputType(match::BinaryLike()), ree_filter, BinaryFilterExec},
11181125
{InputType(match::LargeBinaryLike()), ree_filter, BinaryFilterExec},
1126+
{InputType(Type::BINARY_VIEW), ree_filter, VarBinaryViewFilterExec},
1127+
{InputType(Type::STRING_VIEW), ree_filter, VarBinaryViewFilterExec},
11191128
{InputType(null()), ree_filter, NullFilterExec},
11201129
{InputType(Type::FIXED_SIZE_BINARY), ree_filter, PrimitiveFilterExec},
11211130
{InputType(Type::DECIMAL32), ree_filter, PrimitiveFilterExec},

cpp/src/arrow/compute/kernels/vector_selection_internal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ Status MapFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
7575

7676
Status VarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
7777
Status LargeVarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
78+
Status VarBinaryViewTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
7879
Status FixedWidthTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
7980
Status ListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
8081
Status LargeListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);

cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
#include "arrow/array/builder_primitive.h"
2626
#include "arrow/array/concatenate.h"
27+
#include "arrow/buffer.h"
2728
#include "arrow/buffer_builder.h"
2829
#include "arrow/chunked_array.h"
2930
#include "arrow/compute/api_vector.h"
@@ -488,6 +489,136 @@ Status FixedWidthTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult*
488489

489490
namespace {
490491

492+
template <typename IndexCType>
493+
Status VarBinaryViewTakeTyped(const ArraySpan& values, const ArraySpan& indices,
494+
BinaryViewType::c_type* out_views, uint8_t* out_validity,
495+
int64_t* valid_count) {
496+
const auto* source_views = values.GetValues<BinaryViewType::c_type>(1);
497+
const auto* index_values = indices.GetValues<IndexCType>(1);
498+
499+
const bool values_may_have_nulls = values.MayHaveNulls();
500+
const bool indices_may_have_nulls = indices.MayHaveNulls();
501+
502+
if (!values_may_have_nulls && !indices_may_have_nulls) {
503+
for (int64_t out_i = 0; out_i < indices.length; ++out_i) {
504+
out_views[out_i] = source_views[static_cast<int64_t>(index_values[out_i])];
505+
}
506+
*valid_count = indices.length;
507+
return Status::OK();
508+
}
509+
510+
for (int64_t out_i = 0; out_i < indices.length; ++out_i) {
511+
if (indices_may_have_nulls &&
512+
!bit_util::GetBit(indices.buffers[0].data, indices.offset + out_i)) {
513+
continue;
514+
}
515+
516+
const int64_t source_i = static_cast<int64_t>(index_values[out_i]);
517+
const bool source_valid =
518+
!values_may_have_nulls ||
519+
bit_util::GetBit(values.buffers[0].data, values.offset + source_i);
520+
if (!source_valid) {
521+
continue;
522+
}
523+
524+
out_views[out_i] = source_views[source_i];
525+
if (out_validity != nullptr) {
526+
bit_util::SetBit(out_validity, out_i);
527+
}
528+
++(*valid_count);
529+
}
530+
531+
return Status::OK();
532+
}
533+
534+
Status VarBinaryViewTakeDispatch(const ArraySpan& values, const ArraySpan& indices,
535+
BinaryViewType::c_type* out_views, uint8_t* out_validity,
536+
int64_t* valid_count) {
537+
switch (indices.type->id()) {
538+
case Type::INT8:
539+
return VarBinaryViewTakeTyped<int8_t>(values, indices, out_views, out_validity,
540+
valid_count);
541+
case Type::INT16:
542+
return VarBinaryViewTakeTyped<int16_t>(values, indices, out_views, out_validity,
543+
valid_count);
544+
case Type::INT32:
545+
return VarBinaryViewTakeTyped<int32_t>(values, indices, out_views, out_validity,
546+
valid_count);
547+
case Type::INT64:
548+
return VarBinaryViewTakeTyped<int64_t>(values, indices, out_views, out_validity,
549+
valid_count);
550+
case Type::UINT8:
551+
return VarBinaryViewTakeTyped<uint8_t>(values, indices, out_views, out_validity,
552+
valid_count);
553+
case Type::UINT16:
554+
return VarBinaryViewTakeTyped<uint16_t>(values, indices, out_views, out_validity,
555+
valid_count);
556+
case Type::UINT32:
557+
return VarBinaryViewTakeTyped<uint32_t>(values, indices, out_views, out_validity,
558+
valid_count);
559+
case Type::UINT64:
560+
return VarBinaryViewTakeTyped<uint64_t>(values, indices, out_views, out_validity,
561+
valid_count);
562+
default:
563+
return Status::NotImplemented("Unsupported index type for take: ", *indices.type);
564+
}
565+
}
566+
567+
} // namespace
568+
569+
Status VarBinaryViewTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
570+
const ArraySpan& values = batch[0].array;
571+
const ArraySpan& indices = batch[1].array;
572+
573+
if (TakeState::Get(ctx).boundscheck) {
574+
RETURN_NOT_OK(CheckIndexBounds(indices, values.length));
575+
}
576+
577+
const int64_t out_length = indices.length;
578+
const bool may_have_nulls = values.MayHaveNulls() || indices.MayHaveNulls();
579+
const auto data_buffers = values.GetVariadicBuffers();
580+
581+
ARROW_ASSIGN_OR_RAISE(
582+
auto views_buf,
583+
AllocateBuffer(out_length * static_cast<int64_t>(sizeof(BinaryViewType::c_type)),
584+
ctx->memory_pool()));
585+
auto* out_views = reinterpret_cast<BinaryViewType::c_type*>(views_buf->mutable_data());
586+
if (may_have_nulls && views_buf->size() > 0) {
587+
std::memset(out_views, 0, views_buf->size());
588+
}
589+
590+
std::shared_ptr<Buffer> validity_buf;
591+
uint8_t* out_validity = nullptr;
592+
if (may_have_nulls) {
593+
ARROW_ASSIGN_OR_RAISE(validity_buf,
594+
AllocateEmptyBitmap(out_length, ctx->memory_pool()));
595+
if (validity_buf->size() > 0) {
596+
std::memset(validity_buf->mutable_data(), 0, validity_buf->size());
597+
}
598+
out_validity = validity_buf->mutable_data();
599+
}
600+
601+
int64_t valid_count = 0;
602+
RETURN_NOT_OK(
603+
VarBinaryViewTakeDispatch(values, indices, out_views, out_validity, &valid_count));
604+
605+
const int64_t null_count = out_length - valid_count;
606+
BufferVector buffers;
607+
buffers.reserve(2 + data_buffers.size());
608+
buffers.push_back(null_count == 0 ? nullptr : std::move(validity_buf));
609+
buffers.push_back(std::move(views_buf));
610+
611+
for (const auto& data_buffer : data_buffers) {
612+
buffers.push_back(data_buffer);
613+
}
614+
615+
out->value = ArrayData::Make(values.type->GetSharedPtr(), out_length,
616+
std::move(buffers), null_count, /*offset=*/0);
617+
return Status::OK();
618+
}
619+
620+
namespace {
621+
491622
// ----------------------------------------------------------------------
492623
// Null take
493624

@@ -740,6 +871,8 @@ void PopulateTakeKernels(std::vector<SelectionKernelData>* out) {
740871
{InputType(match::Primitive()), take_indices, FixedWidthTakeExec},
741872
{InputType(match::BinaryLike()), take_indices, VarBinaryTakeExec},
742873
{InputType(match::LargeBinaryLike()), take_indices, LargeVarBinaryTakeExec},
874+
{InputType(Type::BINARY_VIEW), take_indices, VarBinaryViewTakeExec},
875+
{InputType(Type::STRING_VIEW), take_indices, VarBinaryViewTakeExec},
743876
{InputType(match::FixedSizeBinaryLike()), take_indices, FixedWidthTakeExec},
744877
{InputType(null()), take_indices, NullTakeExec},
745878
{InputType(Type::DICTIONARY), take_indices, DictionaryTake},

0 commit comments

Comments
 (0)