|
99 | 99 | if field.is_nullable() && array.is_null(seq) { |
100 | 100 | return Ok(Value::Null); |
101 | 101 | } |
102 | | - match array.as_any().downcast_ref::<LargeBinaryArray>() { |
103 | | - Some(array) => { |
104 | | - if settings.arrow_result_version.unwrap_or_default() > 1 { |
105 | | - Ok(Value::Variant( |
106 | | - String::from_utf8_lossy(array.value(seq)).into_owned(), |
107 | | - )) |
108 | | - } else { |
109 | | - Ok(Value::Variant(RawJsonb::new(array.value(seq)).to_string())) |
| 102 | + match array.data_type() { |
| 103 | + ArrowDataType::Utf8 => match array.as_any().downcast_ref::<StringArray>() { |
| 104 | + Some(array) => Ok(Value::Variant(array.value(seq).to_string())), |
| 105 | + None => Err(ConvertError::new("variant", format!("{array:?}")).into()), |
| 106 | + }, |
| 107 | + ArrowDataType::LargeUtf8 => { |
| 108 | + match array.as_any().downcast_ref::<LargeStringArray>() { |
| 109 | + Some(array) => Ok(Value::Variant(array.value(seq).to_string())), |
| 110 | + None => { |
| 111 | + Err(ConvertError::new("variant", format!("{array:?}")).into()) |
| 112 | + } |
| 113 | + } |
| 114 | + } |
| 115 | + ArrowDataType::LargeBinary => { |
| 116 | + match array.as_any().downcast_ref::<LargeBinaryArray>() { |
| 117 | + Some(array) => { |
| 118 | + if settings.arrow_result_version.unwrap_or_default() > 1 { |
| 119 | + Ok(Value::Variant( |
| 120 | + String::from_utf8_lossy(array.value(seq)).into_owned(), |
| 121 | + )) |
| 122 | + } else { |
| 123 | + Ok(Value::Variant( |
| 124 | + RawJsonb::new(array.value(seq)).to_string(), |
| 125 | + )) |
| 126 | + } |
| 127 | + } |
| 128 | + None => { |
| 129 | + Err(ConvertError::new("variant", format!("{array:?}")).into()) |
| 130 | + } |
110 | 131 | } |
111 | 132 | } |
112 | | - None => Err(ConvertError::new("variant", format!("{array:?}")).into()), |
| 133 | + _ => Err(ConvertError::new("variant", format!("{array:?}")).into()), |
113 | 134 | } |
114 | 135 | } |
115 | 136 | ARROW_EXT_TYPE_TIMESTAMP_TIMEZONE => { |
@@ -456,3 +477,51 @@ impl |
456 | 477 | } |
457 | 478 | } |
458 | 479 | } |
| 480 | + |
| 481 | +#[cfg(test)] |
| 482 | +mod tests { |
| 483 | + use super::*; |
| 484 | + use arrow_array::ArrayRef; |
| 485 | + use std::collections::HashMap; |
| 486 | + |
| 487 | + fn variant_field(data_type: ArrowDataType) -> ArrowField { |
| 488 | + ArrowField::new("v", data_type, false).with_metadata(HashMap::from([( |
| 489 | + EXTENSION_KEY.to_string(), |
| 490 | + ARROW_EXT_TYPE_VARIANT.to_string(), |
| 491 | + )])) |
| 492 | + } |
| 493 | + |
| 494 | + #[test] |
| 495 | + fn decode_variant_from_utf8_array() { |
| 496 | + let field = variant_field(ArrowDataType::Utf8); |
| 497 | + let array: ArrayRef = Arc::new(StringArray::from(vec!["{\"a\":1}"])); |
| 498 | + |
| 499 | + let value = Value::try_from((&field, &array, 0, &ResultFormatSettings::default())).unwrap(); |
| 500 | + |
| 501 | + assert_eq!(value, Value::Variant("{\"a\":1}".to_string())); |
| 502 | + } |
| 503 | + |
| 504 | + #[test] |
| 505 | + fn decode_variant_from_large_utf8_array() { |
| 506 | + let field = variant_field(ArrowDataType::LargeUtf8); |
| 507 | + let array: ArrayRef = Arc::new(LargeStringArray::from(vec!["{\"b\":2}"])); |
| 508 | + |
| 509 | + let value = Value::try_from((&field, &array, 0, &ResultFormatSettings::default())).unwrap(); |
| 510 | + |
| 511 | + assert_eq!(value, Value::Variant("{\"b\":2}".to_string())); |
| 512 | + } |
| 513 | + |
| 514 | + #[test] |
| 515 | + fn decode_variant_from_large_binary_array_for_v2() { |
| 516 | + let field = variant_field(ArrowDataType::LargeBinary); |
| 517 | + let array: ArrayRef = Arc::new(LargeBinaryArray::from(vec![b"{\"c\":3}".as_slice()])); |
| 518 | + let settings = ResultFormatSettings { |
| 519 | + arrow_result_version: Some(2), |
| 520 | + ..ResultFormatSettings::default() |
| 521 | + }; |
| 522 | + |
| 523 | + let value = Value::try_from((&field, &array, 0, &settings)).unwrap(); |
| 524 | + |
| 525 | + assert_eq!(value, Value::Variant("{\"c\":3}".to_string())); |
| 526 | + } |
| 527 | +} |
0 commit comments