From aa31d6676809039bc9801c1931e7754391ccd33d Mon Sep 17 00:00:00 2001
From: Nishant Bhakar
Date: Sun, 2 Feb 2025 16:55:45 -0800
Subject: [PATCH] fix(set_agg): use data equality instead of only hash
 equality for deduplication

---
 src/daft-core/src/array/ops/set_agg.rs | 149 ++++++++++++++++++++++++-
 1 file changed, 143 insertions(+), 6 deletions(-)

diff --git a/src/daft-core/src/array/ops/set_agg.rs b/src/daft-core/src/array/ops/set_agg.rs
index 0472c807b3..91eea7e8d5 100644
--- a/src/daft-core/src/array/ops/set_agg.rs
+++ b/src/daft-core/src/array/ops/set_agg.rs
@@ -1,37 +1,174 @@
 use std::collections::HashMap;
 
+use arrow2::array::{ListArray as Arrow2ListArray, PrimitiveArray};
 use common_error::{DaftError, DaftResult};
 
 use super::{DaftNotNull, DaftSetAggable, GroupIndices};
 use crate::{
     array::{
         growable::{Growable, GrowableArray},
+        ops::{arrow2::comparison::build_is_equal, as_arrow::AsArrow},
         DataArray, FixedSizeListArray, ListArray, StructArray,
     },
-    datatypes::{DaftArrowBackedType, UInt64Array},
+    datatypes::{DaftArrowBackedType, DataType, UInt64Array},
     series::{IntoSeries, Series},
 };
 
 fn deduplicate_series(series: &Series) -> DaftResult<(Series, Vec<u64>)> {
+    // Special handling for Null type
+    if series.data_type() == &DataType::Null {
+        let mut unique_indices = Vec::new();
+        if !series.is_empty() {
+            unique_indices.push(0); // Just take the first null value
+        }
+        let indices_array = UInt64Array::from(("", unique_indices.clone())).into_series();
+        let result = series.take(&indices_array)?;
+        return Ok((result, unique_indices));
+    }
+
+    // Special handling for List type
+    if let DataType::List(_) = series.data_type() {
+        let mut seen_lists = HashMap::new();
+        let mut unique_indices = Vec::new();
+        let mut has_null = false;
+
+        let list_array = series.to_arrow();
+        let list_array = list_array
+            .as_any()
+            .downcast_ref::<Arrow2ListArray<i64>>()
+            .ok_or_else(|| DaftError::ValueError("Failed to downcast to ListArray".to_string()))?;
+        let values = list_array
+            .values()
+            .as_any()
+            .downcast_ref::<PrimitiveArray<i64>>()
+            .ok_or_else(|| {
+                DaftError::ValueError(
+                    "Failed to downcast list values to PrimitiveArray".to_string(),
+                )
+            })?;
+
+        let hash_series = series.hash(None).map_err(|_| {
+            DaftError::ValueError(
+                "Cannot perform set aggregation on elements that are not hashable".to_string(),
+            )
+        })?;
+        let hash_array = hash_series.as_arrow();
+        let hash_values = hash_array.values();
+
+        for idx in 0..series.len() {
+            if !series.is_valid(idx) {
+                if !has_null {
+                    has_null = true;
+                    unique_indices.push(idx as u64);
+                }
+                continue;
+            }
+
+            let start = list_array.offsets()[idx] as usize;
+            let end = list_array.offsets()[idx + 1] as usize;
+            let current_list = &values.values()[start..end];
+            let hash = hash_values.get(idx).unwrap();
+
+            let mut is_duplicate = false;
+            if let Some(existing_indices) = seen_lists.get(hash) {
+                for &existing_idx in existing_indices {
+                    let start = list_array.offsets()[existing_idx] as usize;
+                    let end = list_array.offsets()[existing_idx + 1] as usize;
+                    let other_list = &values.values()[start..end];
+
+                    if current_list == other_list {
+                        is_duplicate = true;
+                        break;
+                    }
+                }
+            }
+
+            if !is_duplicate {
+                seen_lists.entry(*hash).or_insert_with(Vec::new).push(idx);
+                unique_indices.push(idx as u64);
+            }
+        }
+
+        let indices_array = UInt64Array::from(("", unique_indices.clone())).into_series();
+        let result = series.take(&indices_array)?;
+        return Ok((result, unique_indices));
+    }
+
+    // Special handling for Struct type
+    if let DataType::Struct(_) = series.data_type() {
+        let mut seen_structs: HashMap<u64, Vec<usize>> = HashMap::new();
+        let mut unique_indices = Vec::new();
+        let mut has_null = false;
+
+        let hash_series = series.hash(None).map_err(|_| {
+            DaftError::ValueError(
+                "Cannot perform set aggregation on elements that are not hashable".to_string(),
+            )
+        })?;
+        let hash_array = hash_series.as_arrow();
+        let hash_values = hash_array.values();
+
+        for idx in 0..series.len() {
+            if !series.is_valid(idx) {
+                if !has_null {
+                    has_null = true;
+                    unique_indices.push(idx as u64);
+                }
+                continue;
+            }
+
+            let hash = hash_values.get(idx).unwrap();
+            let mut is_duplicate = false;
+
+            if let Some(existing_indices) = seen_structs.get(hash) {
+                // For structs we rely on hash equality alone: the hash folds in every
+                // field and value, so collisions are assumed to be negligible here.
+                is_duplicate = !existing_indices.is_empty();
+            }
+
+            if !is_duplicate {
+                seen_structs.entry(*hash).or_default().push(idx);
+                unique_indices.push(idx as u64);
+            }
+        }
+
+        let indices_array = UInt64Array::from(("", unique_indices.clone())).into_series();
+        let result = series.take(&indices_array)?;
+        return Ok((result, unique_indices));
+    }
+
     let hashes = series.hash(None).map_err(|_| {
         DaftError::ValueError(
             "Cannot perform set aggregation on elements that are not hashable".to_string(),
         )
     })?;
 
-    let mut seen_hashes = HashMap::new();
+    let array = series.to_arrow();
+    let comparator = build_is_equal(&*array, &*array, true, false)?;
+    let mut seen_hashes = HashMap::<u64, Vec<usize>>::new();
     let mut unique_indices = Vec::new();
     let mut has_null = false;
 
-    for (idx, hash) in hashes.into_iter().enumerate() {
+    let hash_array = hashes.as_arrow();
+    for (idx, hash) in hash_array.values_iter().enumerate() {
         if !series.is_valid(idx) {
             if !has_null {
                 has_null = true;
                 unique_indices.push(idx as u64);
             }
-        } else if let Some(hash) = hash {
-            if let std::collections::hash_map::Entry::Vacant(e) = seen_hashes.entry(hash) {
-                e.insert(idx);
+        } else {
+            let mut is_duplicate = false;
+            if let Some(existing_indices) = seen_hashes.get(hash) {
+                for &existing_idx in existing_indices {
+                    if comparator(idx, existing_idx) {
+                        is_duplicate = true;
+                        break;
+                    }
+                }
+            }
+
+            if !is_duplicate {
+                seen_hashes.entry(*hash).or_default().push(idx);
                 unique_indices.push(idx as u64);
             }
         }
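
The core pattern this patch applies is hash-bucketed deduplication with a confirming data-equality check, so that two distinct values which happen to share a hash are never merged. Below is a minimal standalone sketch of that pattern, assuming plain `i64` input and a deliberately weak hash standing in for Daft's per-row hash; `dedup_indices` and `hash_of` are illustrative names, not part of the Daft codebase.

```rust
use std::collections::HashMap;

/// Deduplicate `values`, bucketing candidates by hash but confirming with
/// real equality, so a hash collision can never merge two distinct values.
fn dedup_indices(values: &[i64], hash_of: impl Fn(i64) -> u64) -> Vec<usize> {
    let mut seen: HashMap<u64, Vec<usize>> = HashMap::new();
    let mut unique_indices = Vec::new();

    for (idx, &v) in values.iter().enumerate() {
        let h = hash_of(v);
        // A hash match alone is not enough: walk the bucket and compare data.
        let is_duplicate = seen
            .get(&h)
            .map(|bucket| bucket.iter().any(|&i| values[i] == v))
            .unwrap_or(false);
        if !is_duplicate {
            seen.entry(h).or_default().push(idx);
            unique_indices.push(idx);
        }
    }
    unique_indices
}

fn main() {
    // 1 and 5 collide under the weak hash (both equal 1 mod 4), but survive
    // deduplication because the equality check sees that they differ.
    let data = [1, 5, 1, 9, 5];
    let kept = dedup_indices(&data, |v| (v % 4) as u64);
    assert_eq!(kept, vec![0, 1, 3]); // keeps 1, 5, 9
    println!("kept indices: {:?}", kept);
}
```

Under this weak hash, a hash-only scheme like the old `Entry::Vacant` logic would have dropped the 5, since it lands in the same bucket as 1.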
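The List branch's row comparison reduces to slice equality over one flat child buffer: row `i` of a list array is `values[offsets[i]..offsets[i + 1]]`. A toy illustration of that rule, assuming hand-built `values`/`offsets` stand-ins rather than arrow2's real `ListArray` accessors:

```rust
/// Row `i` of a list array is the slice of the flat child buffer between
/// offsets[i] and offsets[i + 1]; two rows are equal iff those slices are.
fn list_rows_equal(values: &[i64], offsets: &[usize], a: usize, b: usize) -> bool {
    let row = |i: usize| &values[offsets[i]..offsets[i + 1]];
    row(a) == row(b)
}

fn main() {
    // Three logical rows: [1, 2], [3], [1, 2]
    let values = [1, 2, 3, 1, 2];
    let offsets = [0, 2, 3, 5];
    assert!(list_rows_equal(&values, &offsets, 0, 2)); // [1, 2] == [1, 2]
    assert!(!list_rows_equal(&values, &offsets, 0, 1)); // [1, 2] != [3]
}
```

Raw slice comparison like this only covers primitive child values; lists whose children are themselves lists or structs would need a recursive comparator instead.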