Skip to content

Commit 83aa1ba

Browse files
authored
support dictionary-encoded arrays (#52)
* wip: dictionary encoding
* Pass through args
* update check on readme
* support dictionary encoded arrays
* update readme
1 parent 6b83961 commit 83aa1ba

File tree

6 files changed

+208
-26
lines changed

6 files changed

+208
-26
lines changed

README.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,10 @@ Most of the unsupported types should be pretty straightforward to implement; the
152152

153153
### Decimal
154154

155-
- [ ] Decimal128 (failing a test)
156-
- [ ] Decimal256 (failing a test)
155+
- [ ] Decimal128 (failing a test, this may be [#37920])
156+
- [ ] Decimal256 (failing a test, this may be [#37920])
157+
158+
[#37920]: https://github.com/apache/arrow/issues/37920
157159

158160
### Temporal Types
159161

@@ -174,7 +176,7 @@ Most of the unsupported types should be pretty straightforward to implement; the
174176
- [ ] Map
175177
- [x] Dense Union
176178
- [x] Sparse Union
177-
- [ ] Dictionary-encoded arrays
179+
- [x] Dictionary-encoded arrays
178180

179181
### Extension Types
180182

src/field.ts

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ export function parseField(buffer: ArrayBuffer, ptr: number): arrow.Field {
6363
const nChildren = dataView.getBigInt64(ptr + 24, true);
6464

6565
const ptrToChildrenPtrs = dataView.getUint32(ptr + 32, true);
66+
const dictionaryPtr = dataView.getUint32(ptr + 36, true);
67+
6668
const childrenFields: arrow.Field[] = new Array(Number(nChildren));
6769
for (let i = 0; i < nChildren; i++) {
6870
childrenFields[i] = parseField(
@@ -71,6 +73,46 @@ export function parseField(buffer: ArrayBuffer, ptr: number): arrow.Field {
7173
);
7274
}
7375

76+
const field = parseFieldContent({
77+
formatString,
78+
flags,
79+
name,
80+
childrenFields,
81+
metadata,
82+
});
83+
84+
if (dictionaryPtr !== 0) {
85+
const dictionaryValuesField = parseField(buffer, dictionaryPtr);
86+
const dictionaryType = new arrow.Dictionary(
87+
dictionaryValuesField,
88+
field.type,
89+
null,
90+
flags.dictionaryOrdered,
91+
);
92+
return new arrow.Field(
93+
field.name,
94+
dictionaryType,
95+
flags.nullable,
96+
metadata,
97+
);
98+
}
99+
100+
return field;
101+
}
102+
103+
function parseFieldContent({
104+
formatString,
105+
flags,
106+
name,
107+
childrenFields,
108+
metadata,
109+
}: {
110+
formatString: string;
111+
flags: Flags;
112+
name: string;
113+
childrenFields: arrow.Field[];
114+
metadata: Map<string, string> | null;
115+
}): arrow.Field {
74116
const primitiveType = formatMapping[formatString];
75117
if (primitiveType) {
76118
return new arrow.Field(name, primitiveType, flags.nullable, metadata);

src/vector.ts

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ export function parseData<T extends DataType>(
6969
}
7070

7171
const ptrToChildrenPtrs = dataView.getUint32(ptr + 44, true);
72+
const dictionaryPtr = dataView.getUint32(ptr + 48, true);
73+
7274
const children: arrow.Data[] = new Array(Number(nChildren));
7375
for (let i = 0; i < nChildren; i++) {
7476
children[i] = parseData(
@@ -79,6 +81,77 @@ export function parseData<T extends DataType>(
7981
);
8082
}
8183

84+
// Special case for handling dictionary-encoded arrays
85+
if (dictionaryPtr !== 0) {
86+
const dictionaryType = dataType as unknown as arrow.Dictionary;
87+
88+
// the parent structure points to the index data, the ArrowArray.dictionary
89+
// points to the dictionary values array.
90+
const indicesType = dictionaryType.indices;
91+
const dictionaryIndices = parseDataContent({
92+
dataType: indicesType,
93+
dataView,
94+
copy,
95+
length,
96+
nullCount,
97+
offset,
98+
nChildren,
99+
children,
100+
bufferPtrs,
101+
});
102+
103+
const valueType = dictionaryType.dictionary.type;
104+
const dictionaryValues = parseData(buffer, dictionaryPtr, valueType, copy);
105+
106+
// @ts-expect-error we're casting to dictionary type
107+
return arrow.makeData({
108+
type: dictionaryType,
109+
// TODO: double check that this offset should be set on both the values
110+
// and indices of the dictionary array
111+
offset,
112+
length,
113+
nullCount,
114+
nullBitmap: dictionaryIndices.nullBitmap,
115+
// Note: Here we need to pass in the _raw TypedArray_ not a Data instance
116+
data: dictionaryIndices.values,
117+
dictionary: arrow.makeVector(dictionaryValues),
118+
});
119+
} else {
120+
return parseDataContent({
121+
dataType,
122+
dataView,
123+
copy,
124+
length,
125+
nullCount,
126+
offset,
127+
nChildren,
128+
children,
129+
bufferPtrs,
130+
});
131+
}
132+
}
133+
134+
function parseDataContent<T extends DataType>({
135+
dataType,
136+
dataView,
137+
copy,
138+
length,
139+
nullCount,
140+
offset,
141+
nChildren,
142+
children,
143+
bufferPtrs,
144+
}: {
145+
dataType: T;
146+
dataView: DataView;
147+
copy: boolean;
148+
length: number;
149+
nullCount: number;
150+
offset: number;
151+
nChildren: number;
152+
children: arrow.Data[];
153+
bufferPtrs: Uint32Array;
154+
}): arrow.Data<T> {
82155
if (DataType.isNull(dataType)) {
83156
return arrow.makeData({
84157
type: dataType,
@@ -653,7 +726,6 @@ export function parseData<T extends DataType>(
653726
});
654727
}
655728

656-
// TODO: map arrays, dictionary encoding
657729
throw new Error(`Unsupported type ${dataType}`);
658730
}
659731

tests/ffi.test.ts

Lines changed: 75 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,7 @@ describe("binary", (t) => {
196196
);
197197

198198
const originalField = TEST_TABLE.schema.fields[columnIndex];
199-
// declare it's not null
200-
const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
199+
const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
201200
const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
202201
const field = parseField(WASM_MEMORY.buffer, fieldPtr);
203202

@@ -277,8 +276,7 @@ describe("string", (t) => {
277276
);
278277

279278
const originalField = TEST_TABLE.schema.fields[columnIndex];
280-
// declare it's not null
281-
const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
279+
const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
282280
const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
283281
const field = parseField(WASM_MEMORY.buffer, fieldPtr);
284282

@@ -346,8 +344,7 @@ describe("boolean", (t) => {
346344
);
347345

348346
const originalField = TEST_TABLE.schema.fields[columnIndex];
349-
// declare it's not null
350-
const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
347+
const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
351348
const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
352349
const field = parseField(WASM_MEMORY.buffer, fieldPtr);
353350

@@ -379,8 +376,7 @@ describe("null array", (t) => {
379376
);
380377

381378
const originalField = TEST_TABLE.schema.fields[columnIndex];
382-
// declare it's not null
383-
const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
379+
const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
384380
const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
385381
const field = parseField(WASM_MEMORY.buffer, fieldPtr);
386382

@@ -412,8 +408,7 @@ describe("list array", (t) => {
412408
);
413409

414410
const originalField = TEST_TABLE.schema.fields[columnIndex];
415-
// declare it's not null
416-
const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
411+
const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
417412
const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
418413
const field = parseField(WASM_MEMORY.buffer, fieldPtr);
419414

@@ -499,8 +494,7 @@ describe("extension array", (t) => {
499494
);
500495

501496
const originalField = TEST_TABLE.schema.fields[columnIndex];
502-
// declare it's not null
503-
const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
497+
const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
504498
const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
505499
const field = parseField(WASM_MEMORY.buffer, fieldPtr);
506500

@@ -544,8 +538,7 @@ describe("extension array", (t) => {
544538
// );
545539

546540
// const originalField = TEST_TABLE.schema.fields[columnIndex];
547-
// // declare it's not null
548-
// const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
541+
// const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
549542
// const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
550543
// const field = parseField(WASM_MEMORY.buffer, fieldPtr);
551544

@@ -572,8 +565,7 @@ describe("date32", (t) => {
572565
);
573566

574567
const originalField = TEST_TABLE.schema.fields[columnIndex];
575-
// declare it's not null
576-
const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
568+
const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
577569
const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
578570
const field = parseField(WASM_MEMORY.buffer, fieldPtr);
579571

@@ -606,8 +598,7 @@ describe("date32", (t) => {
606598
// );
607599

608600
// const originalField = TEST_TABLE.schema.fields[columnIndex];
609-
// // declare it's not null
610-
// const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
601+
// const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
611602
// const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
612603
// const field = parseField(WASM_MEMORY.buffer, fieldPtr);
613604

@@ -634,8 +625,7 @@ describe("duration", (t) => {
634625
);
635626

636627
const originalField = TEST_TABLE.schema.fields[columnIndex];
637-
// declare it's not null
638-
const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
628+
const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
639629
const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
640630
const field = parseField(WASM_MEMORY.buffer, fieldPtr);
641631

@@ -667,8 +657,7 @@ describe("nullable int", (t) => {
667657
);
668658

669659
const originalField = TEST_TABLE.schema.fields[columnIndex];
670-
// declare it's not null
671-
const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
660+
const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
672661
const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
673662
const field = parseField(WASM_MEMORY.buffer, fieldPtr);
674663

@@ -693,3 +682,67 @@ describe("nullable int", (t) => {
693682
it("copy=false", () => test(false));
694683
it("copy=true", () => test(true));
695684
});
685+
686+
describe("dictionary encoded string", (t) => {
687+
function test(copy: boolean) {
688+
let columnIndex = TEST_TABLE.schema.fields.findIndex(
689+
(field) => field.name == "dictionary_encoded_string"
690+
);
691+
692+
const originalField = TEST_TABLE.schema.fields[columnIndex];
693+
const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
694+
const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
695+
const field = parseField(WASM_MEMORY.buffer, fieldPtr);
696+
697+
expect(field.name).toStrictEqual(originalField.name);
698+
expect(field.typeId).toStrictEqual(originalField.typeId);
699+
expect(field.nullable).toStrictEqual(originalField.nullable);
700+
701+
const arrayPtr = FFI_TABLE.arrayAddr(0, columnIndex);
702+
const wasmVector = parseVector(
703+
WASM_MEMORY.buffer,
704+
arrayPtr,
705+
field.type,
706+
copy
707+
);
708+
709+
for (let i = 0; i < 3; i++) {
710+
expect(originalVector.get(i)).toStrictEqual(wasmVector.get(i));
711+
}
712+
}
713+
714+
it("copy=false", () => test(false));
715+
it("copy=true", () => test(true));
716+
});
717+
718+
describe("dictionary encoded string (with nulls)", (t) => {
719+
function test(copy: boolean) {
720+
let columnIndex = TEST_TABLE.schema.fields.findIndex(
721+
(field) => field.name == "dictionary_encoded_string_null"
722+
);
723+
724+
const originalField = TEST_TABLE.schema.fields[columnIndex];
725+
const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
726+
const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
727+
const field = parseField(WASM_MEMORY.buffer, fieldPtr);
728+
729+
expect(field.name).toStrictEqual(originalField.name);
730+
expect(field.typeId).toStrictEqual(originalField.typeId);
731+
expect(field.nullable).toStrictEqual(originalField.nullable);
732+
733+
const arrayPtr = FFI_TABLE.arrayAddr(0, columnIndex);
734+
const wasmVector = parseVector(
735+
WASM_MEMORY.buffer,
736+
arrayPtr,
737+
field.type,
738+
copy
739+
);
740+
741+
for (let i = 0; i < 3; i++) {
742+
expect(originalVector.get(i)).toStrictEqual(wasmVector.get(i));
743+
}
744+
}
745+
746+
it("copy=false", () => test(false));
747+
it("copy=true", () => test(true));
748+
});

tests/pyarrow_generate_data.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
import pandas as pd
66
import pyarrow as pa
7+
import pyarrow.compute as pc
78
import pyarrow.feather as feather
89

910

@@ -194,6 +195,16 @@ def dense_union_array() -> pa.Array:
194195
return union_arr
195196

196197

198+
def dictionary_encoded_string_array() -> pa.DictionaryArray:
199+
arr = pa.StringArray.from_pandas(["a", "a", "b"])
200+
return pc.dictionary_encode(arr)
201+
202+
203+
def dictionary_encoded_string_array_null() -> pa.DictionaryArray:
204+
arr = pa.StringArray.from_pandas(["a", "a", None])
205+
return pc.dictionary_encode(arr)
206+
207+
197208
class MyExtensionType(pa.ExtensionType):
198209
"""
199210
Refer to https://arrow.apache.org/docs/python/extending_types.html for
@@ -243,6 +254,8 @@ def table() -> pa.Table:
243254
"sparse_union": sparse_union_array(),
244255
"dense_union": dense_union_array(),
245256
"duration": duration_array(),
257+
"dictionary_encoded_string": dictionary_encoded_string_array(),
258+
"dictionary_encoded_string_null": dictionary_encoded_string_array_null(),
246259
}
247260
)
248261

tests/table.arrow

1016 Bytes
Binary file not shown.

0 commit comments

Comments (0)