Skip to content

Commit

Permalink
fix(set_agg): use data equality instead of only hash equality for ded…
Browse files Browse the repository at this point in the history
…uplication
  • Loading branch information
f4t4nt committed Feb 3, 2025
1 parent 01cddd5 commit aa31d66
Showing 1 changed file with 143 additions and 6 deletions.
149 changes: 143 additions & 6 deletions src/daft-core/src/array/ops/set_agg.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,174 @@
use std::collections::HashMap;

use arrow2::array::{ListArray as Arrow2ListArray, PrimitiveArray};
use common_error::{DaftError, DaftResult};

use super::{DaftNotNull, DaftSetAggable, GroupIndices};
use crate::{
array::{
growable::{Growable, GrowableArray},
ops::{arrow2::comparison::build_is_equal, as_arrow::AsArrow},
DataArray, FixedSizeListArray, ListArray, StructArray,
},
datatypes::{DaftArrowBackedType, UInt64Array},
datatypes::{DaftArrowBackedType, DataType, UInt64Array},
series::{IntoSeries, Series},
};

fn deduplicate_series(series: &Series) -> DaftResult<(Series, Vec<u64>)> {
// Special handling for Null type
if series.data_type() == &DataType::Null {
let mut unique_indices = Vec::new();
if !series.is_empty() {
unique_indices.push(0); // Just take the first null value
}
let indices_array = UInt64Array::from(("", unique_indices.clone())).into_series();
let result = series.take(&indices_array)?;
return Ok((result, unique_indices));
}

// Special handling for List type
if let DataType::List(_) = series.data_type() {
let mut seen_lists = HashMap::new();
let mut unique_indices = Vec::new();
let mut has_null = false;

let list_array = series.to_arrow();
let list_array = list_array
.as_any()
.downcast_ref::<Arrow2ListArray<i64>>()
.ok_or_else(|| DaftError::ValueError("Failed to downcast to ListArray".to_string()))?;
let values = list_array
.values()
.as_any()
.downcast_ref::<PrimitiveArray<i64>>()
.ok_or_else(|| {
DaftError::ValueError(
"Failed to downcast list values to PrimitiveArray".to_string(),
)

Check warning on line 47 in src/daft-core/src/array/ops/set_agg.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/set_agg.rs#L45-L47

Added lines #L45 - L47 were not covered by tests
})?;

let hash_series = series.hash(None).map_err(|_| {
DaftError::ValueError(
"Cannot perform set aggregation on elements that are not hashable".to_string(),
)

Check warning on line 53 in src/daft-core/src/array/ops/set_agg.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/set_agg.rs#L51-L53

Added lines #L51 - L53 were not covered by tests
})?;
let hash_array = hash_series.as_arrow();
let hash_values = hash_array.values();

for idx in 0..series.len() {
if !series.is_valid(idx) {
if !has_null {
has_null = true;
unique_indices.push(idx as u64);
}
continue;
}

let start = list_array.offsets()[idx] as usize;
let end = list_array.offsets()[idx + 1] as usize;
let current_list = &values.values()[start..end];
let hash = hash_values.get(idx).unwrap();

let mut is_duplicate = false;
if let Some(existing_indices) = seen_lists.get(hash) {
for &existing_idx in existing_indices {
let start = list_array.offsets()[existing_idx] as usize;
let end = list_array.offsets()[existing_idx + 1] as usize;
let other_list = &values.values()[start..end];

if current_list == other_list {
is_duplicate = true;
break;
}

Check warning on line 82 in src/daft-core/src/array/ops/set_agg.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/set_agg.rs#L82

Added line #L82 was not covered by tests
}
}

if !is_duplicate {
seen_lists.entry(*hash).or_insert_with(Vec::new).push(idx);
unique_indices.push(idx as u64);
}
}

let indices_array = UInt64Array::from(("", unique_indices.clone())).into_series();
let result = series.take(&indices_array)?;
return Ok((result, unique_indices));
}

// Special handling for Struct type
if let DataType::Struct(_) = series.data_type() {
let mut seen_structs: HashMap<u64, Vec<usize>> = HashMap::new();
let mut unique_indices = Vec::new();
let mut has_null = false;

let hash_series = series.hash(None).map_err(|_| {
DaftError::ValueError(
"Cannot perform set aggregation on elements that are not hashable".to_string(),
)

Check warning on line 106 in src/daft-core/src/array/ops/set_agg.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/set_agg.rs#L104-L106

Added lines #L104 - L106 were not covered by tests
})?;
let hash_array = hash_series.as_arrow();
let hash_values = hash_array.values();

for idx in 0..series.len() {
if !series.is_valid(idx) {
if !has_null {
has_null = true;
unique_indices.push(idx as u64);
}
continue;
}

let hash = hash_values.get(idx).unwrap();
let mut is_duplicate = false;

if let Some(existing_indices) = seen_structs.get(hash) {
// For structs, we can rely on the hash equality since the hash function
// takes into account all fields and their values
is_duplicate = !existing_indices.is_empty();
}

if !is_duplicate {
seen_structs.entry(*hash).or_default().push(idx);
unique_indices.push(idx as u64);
}
}

let indices_array = UInt64Array::from(("", unique_indices.clone())).into_series();
let result = series.take(&indices_array)?;
return Ok((result, unique_indices));
}

let hashes = series.hash(None).map_err(|_| {
DaftError::ValueError(
"Cannot perform set aggregation on elements that are not hashable".to_string(),
)

Check warning on line 143 in src/daft-core/src/array/ops/set_agg.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/set_agg.rs#L141-L143

Added lines #L141 - L143 were not covered by tests
})?;

let mut seen_hashes = HashMap::new();
let array = series.to_arrow();
let comparator = build_is_equal(&*array, &*array, true, false)?;
let mut seen_hashes = HashMap::<u64, Vec<usize>>::new();
let mut unique_indices = Vec::new();
let mut has_null = false;

for (idx, hash) in hashes.into_iter().enumerate() {
let hash_array = hashes.as_arrow();
for (idx, hash) in hash_array.values_iter().enumerate() {
if !series.is_valid(idx) {
if !has_null {
has_null = true;
unique_indices.push(idx as u64);
}
} else if let Some(hash) = hash {
if let std::collections::hash_map::Entry::Vacant(e) = seen_hashes.entry(hash) {
e.insert(idx);
} else {
let mut is_duplicate = false;
if let Some(existing_indices) = seen_hashes.get(hash) {
for &existing_idx in existing_indices {
if comparator(idx, existing_idx) {
is_duplicate = true;
break;
}

Check warning on line 166 in src/daft-core/src/array/ops/set_agg.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/set_agg.rs#L166

Added line #L166 was not covered by tests
}
}

if !is_duplicate {
seen_hashes.entry(*hash).or_default().push(idx);
unique_indices.push(idx as u64);
}
}
Expand Down

0 comments on commit aa31d66

Please sign in to comment.