Skip to content

Commit bae64d8

Browse files
authored
feat(sql): support Utf8/LargeUtf8 for Arrow variant decoding (#771)
1 parent f69909e commit bae64d8

1 file changed

Lines changed: 78 additions & 9 deletions

File tree

sql/src/value/arrow_decoder.rs

Lines changed: 78 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -99,17 +99,38 @@ impl
99 99
if field.is_nullable() && array.is_null(seq) {
100 100
return Ok(Value::Null);
101 101
}
102-
match array.as_any().downcast_ref::<LargeBinaryArray>() {
103-
Some(array) => {
104-
if settings.arrow_result_version.unwrap_or_default() > 1 {
105-
Ok(Value::Variant(
106-
String::from_utf8_lossy(array.value(seq)).into_owned(),
107-
))
108-
} else {
109-
Ok(Value::Variant(RawJsonb::new(array.value(seq)).to_string()))
102+
match array.data_type() {
103+
ArrowDataType::Utf8 => match array.as_any().downcast_ref::<StringArray>() {
104+
Some(array) => Ok(Value::Variant(array.value(seq).to_string())),
105+
None => Err(ConvertError::new("variant", format!("{array:?}")).into()),
106+
},
107+
ArrowDataType::LargeUtf8 => {
108+
match array.as_any().downcast_ref::<LargeStringArray>() {
109+
Some(array) => Ok(Value::Variant(array.value(seq).to_string())),
110+
None => {
111+
Err(ConvertError::new("variant", format!("{array:?}")).into())
112+
}
113+
}
114+
}
115+
ArrowDataType::LargeBinary => {
116+
match array.as_any().downcast_ref::<LargeBinaryArray>() {
117+
Some(array) => {
118+
if settings.arrow_result_version.unwrap_or_default() > 1 {
119+
Ok(Value::Variant(
120+
String::from_utf8_lossy(array.value(seq)).into_owned(),
121+
))
122+
} else {
123+
Ok(Value::Variant(
124+
RawJsonb::new(array.value(seq)).to_string(),
125+
))
126+
}
127+
}
128+
None => {
129+
Err(ConvertError::new("variant", format!("{array:?}")).into())
130+
}
110 131
}
111 132
}
112-
None => Err(ConvertError::new("variant", format!("{array:?}")).into()),
133+
_ => Err(ConvertError::new("variant", format!("{array:?}")).into()),
113 134
}
114 135
}
115 136
ARROW_EXT_TYPE_TIMESTAMP_TIMEZONE => {
@@ -456,3 +477,51 @@ impl
456 477
}
457 478
}
458 479
}
480+
481+
#[cfg(test)]
482+
mod tests {
483+
use super::*;
484+
use arrow_array::ArrayRef;
485+
use std::collections::HashMap;
486+
487+
fn variant_field(data_type: ArrowDataType) -> ArrowField {
488+
ArrowField::new("v", data_type, false).with_metadata(HashMap::from([(
489+
EXTENSION_KEY.to_string(),
490+
ARROW_EXT_TYPE_VARIANT.to_string(),
491+
)]))
492+
}
493+
494+
#[test]
495+
fn decode_variant_from_utf8_array() {
496+
let field = variant_field(ArrowDataType::Utf8);
497+
let array: ArrayRef = Arc::new(StringArray::from(vec!["{\"a\":1}"]));
498+
499+
let value = Value::try_from((&field, &array, 0, &ResultFormatSettings::default())).unwrap();
500+
501+
assert_eq!(value, Value::Variant("{\"a\":1}".to_string()));
502+
}
503+
504+
#[test]
505+
fn decode_variant_from_large_utf8_array() {
506+
let field = variant_field(ArrowDataType::LargeUtf8);
507+
let array: ArrayRef = Arc::new(LargeStringArray::from(vec!["{\"b\":2}"]));
508+
509+
let value = Value::try_from((&field, &array, 0, &ResultFormatSettings::default())).unwrap();
510+
511+
assert_eq!(value, Value::Variant("{\"b\":2}".to_string()));
512+
}
513+
514+
#[test]
515+
fn decode_variant_from_large_binary_array_for_v2() {
516+
let field = variant_field(ArrowDataType::LargeBinary);
517+
let array: ArrayRef = Arc::new(LargeBinaryArray::from(vec![b"{\"c\":3}".as_slice()]));
518+
let settings = ResultFormatSettings {
519+
arrow_result_version: Some(2),
520+
..ResultFormatSettings::default()
521+
};
522+
523+
let value = Value::try_from((&field, &array, 0, &settings)).unwrap();
524+
525+
assert_eq!(value, Value::Variant("{\"c\":3}".to_string()));
526+
}
527+
}

0 commit comments

Comments
 (0)