static USAGE: &str = r#"
Pivots CSV data using the Polars engine.

The pivot operation consists of:
  - One or more index columns (these will be the new rows)
  - A column that will be pivoted (this will create the new columns)
  - A values column that will be aggregated
  - An aggregation function to apply. Features "smart" aggregation auto-selection.

For examples, see https://github.com/dathere/qsv/blob/master/tests/test_pivotp.rs.

Usage:
    qsv pivotp [options] <on-cols> <input>
    qsv pivotp --help

pivotp arguments:
    <on-cols>  The column(s) to pivot on (creates new columns).
    <input>    The input CSV file. The file must have headers.
               Stdin is not supported.

pivotp options:
    -i, --index <cols>     The column(s) to use as the index (row labels).
                           Specify multiple columns by separating them with a comma.
                           The output will have one row for each unique combination
                           of the index's values. If not specified, all remaining
                           columns not named in <on-cols> and --values will be used.
                           At least one of --index and --values must be specified.
    -v, --values <cols>    The column(s) containing values to aggregate.
                           If an aggregation is specified, these are the values on
                           which the aggregation will be computed. If not specified,
                           all remaining columns not named in <on-cols> and --index
                           will be used. At least one of --index and --values must
                           be specified.
    -a, --agg <func>       The aggregation function to use:
                             first  - First value encountered
                             last   - Last value encountered
                             sum    - Sum of values
                             min    - Minimum value
                             max    - Maximum value
                             mean   - Average value
                             median - Median value
                             count  - Count of values
                             none   - No aggregation is done. Raises an error if
                                      a group has multiple values.
                             smart  - Use the value column's data type & statistics
                                      to pick an aggregation. Only works if there is
                                      exactly one value column; otherwise falls back
                                      to `first`.
                           [default: smart]
    --sort-columns         Sort the pivoted columns by name.
                           Default is by order of discovery.
    --col-separator <arg>  The separator in generated column names in case of
                           multiple --values columns. [default: _]
    --validate             Validate the pivot by checking the pivot column(s)'
                           cardinality.
    --try-parsedates       When set, attempts to parse columns as dates.
    --infer-len <arg>      Number of rows to scan when inferring the schema.
                           Set to 0 to scan the entire file. [default: 10000]
    --decimal-comma        Use comma as the decimal separator when READING the input.
                           Note that you will need to specify an alternate --delimiter.
    --ignore-errors        Skip rows that can't be parsed.

Common options:
    -h, --help             Display this message
    -o, --output <file>    Write output to <file> instead of stdout.
    -d, --delimiter <arg>  The field delimiter for reading/writing CSV data.
                           Must be a single character. (default: ,)
    -Q, --quiet            Do not print the chosen smart aggregation or the pivot
                           result shape to stderr.
"#;

use std::{fs::File, io, io::Write, path::Path, sync::OnceLock};

use csv::ByteRecord;
use indicatif::HumanCount;
use polars::prelude::*;
use polars_ops::pivot::{pivot_stable, PivotAgg};
use serde::Deserialize;

use crate::{
    cmd::stats::StatsData,
    config::{Config, Delimiter},
    util,
    util::{get_stats_records, StatsMode},
    CliResult,
};
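
// Cache the stats records so calculate_pivot_metadata() and
// suggest_agg_function() don't recompute them for the same input file.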
static STATS_RECORDS: OnceLock<(ByteRecord, Vec<StatsData>)> = OnceLock::new();

#[derive(Deserialize)]
struct Args {
    arg_on_cols: String,
    arg_input: String,
    flag_index: Option<String>,
    flag_values: Option<String>,
    flag_agg: Option<String>,
    flag_sort_columns: bool,
    flag_col_separator: String,
    flag_validate: bool,
    flag_try_parsedates: bool,
    flag_infer_len: usize,
    flag_decimal_comma: bool,
    flag_ignore_errors: bool,
    flag_output: Option<String>,
    flag_delimiter: Option<Delimiter>,
    flag_quiet: bool,
}
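
// The fields above are filled in by util::get_args from the USAGE text
// (docopt-style parsing): positional arguments map to `arg_*` fields and
// flags to `flag_*` fields.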

/// Structure to hold pivot operation metadata
struct PivotMetadata {
    estimated_columns: u64,
    on_col_cardinalities: Vec<(String, u64)>,
}

/// Calculate pivot operation metadata using stats information
fn calculate_pivot_metadata(
    args: &Args,
    on_cols: &[String],
    value_cols: Option<&Vec<String>>,
) -> CliResult<Option<PivotMetadata>> {
    // Get stats records
    let schema_args = util::SchemaArgs {
        flag_enum_threshold: 0,
        flag_ignore_case: false,
        flag_strict_dates: false,
        flag_pattern_columns: crate::select::SelectColumns::parse("").unwrap(),
        flag_dates_whitelist: String::new(),
        flag_prefer_dmy: false,
        flag_force: false,
        flag_stdout: false,
        flag_jobs: None,
        flag_no_headers: false,
        flag_delimiter: args.flag_delimiter,
        arg_input: Some(args.arg_input.clone()),
        flag_memcheck: false,
    };
    let (csv_fields, csv_stats) = STATS_RECORDS.get_or_init(|| {
        get_stats_records(&schema_args, StatsMode::FrequencyForceStats)
            .unwrap_or_else(|_| (ByteRecord::new(), Vec::new()))
    });

    if csv_stats.is_empty() {
        return Ok(None);
    }

    // Get cardinalities for pivot columns
    let mut on_col_cardinalities = Vec::with_capacity(on_cols.len());
    let mut total_new_columns: u64 = 1;
    for on_col in on_cols {
        if let Some(pos) = csv_fields
            .iter()
            .position(|f| std::str::from_utf8(f).unwrap_or("") == on_col)
        {
            let cardinality = csv_stats[pos].cardinality;
            total_new_columns = total_new_columns.saturating_mul(cardinality);
            on_col_cardinalities.push((on_col.clone(), cardinality));
        }
    }

    // Calculate total columns in result
    let value_cols_count = match value_cols {
        Some(cols) => cols.len() as u64,
        None => 1,
    };
    let estimated_columns = total_new_columns.saturating_mul(value_cols_count);

    Ok(Some(PivotMetadata {
        estimated_columns,
        on_col_cardinalities,
    }))
}
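
// For example, pivoting on two columns with cardinalities 12 and 5 while
// aggregating 3 value columns yields an estimated 12 * 5 * 3 = 180 columns:
// every unique value combination of the pivot columns produces one new
// column per value column.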

/// Validate pivot operation using metadata
fn validate_pivot_operation(metadata: &PivotMetadata) -> CliResult<()> {
    const COLUMN_WARNING_THRESHOLD: u64 = 1000;

    // Print cardinality information
    if metadata.on_col_cardinalities.len() > 1 {
        eprintln!("Pivot <on-cols> cardinalities:");
    } else {
        eprintln!("Pivot on-column cardinality:");
    }
    for (col, card) in &metadata.on_col_cardinalities {
        eprintln!("  {col}: {}", HumanCount(*card));
    }

    // Warn about large number of columns
    if metadata.estimated_columns > COLUMN_WARNING_THRESHOLD {
        eprintln!(
            "Warning: Pivot will create {} columns. This might impact performance.",
            HumanCount(metadata.estimated_columns)
        );
    }

    // Error if operation would create an unreasonable number of columns
    if metadata.estimated_columns > 100_000 {
        return fail_clierror!(
            "Pivot would create too many columns ({}). Consider reducing the number of pivot \
             columns or using a different approach.",
            HumanCount(metadata.estimated_columns)
        );
    }

    Ok(())
}
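
// With --validate, a single pivot column "region" (an illustrative name) of
// cardinality 1500 and one value column would print something like:
//
//   Pivot on-column cardinality:
//     region: 1,500
//   Warning: Pivot will create 1,500 columns. This might impact performance.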

/// Suggest an appropriate aggregation function based on column statistics
#[allow(clippy::cast_precision_loss)]
fn suggest_agg_function(
    args: &Args,
    on_cols: &[String],
    index_cols: Option<&[String]>,
    value_cols: &[String],
) -> CliResult<Option<PivotAgg>> {
    // If multiple value columns, default to First
    if value_cols.len() > 1 {
        return Ok(Some(PivotAgg::First));
    }

    let quiet = args.flag_quiet;

    // Get stats for all columns with enhanced statistics
    let schema_args = util::SchemaArgs {
        flag_enum_threshold: 0,
        flag_ignore_case: false,
        flag_strict_dates: false,
        flag_pattern_columns: crate::select::SelectColumns::parse("").unwrap(),
        flag_dates_whitelist: String::new(),
        flag_prefer_dmy: false,
        flag_force: false,
        flag_stdout: false,
        flag_jobs: None,
        flag_no_headers: false,
        flag_delimiter: args.flag_delimiter,
        arg_input: Some(args.arg_input.clone()),
        flag_memcheck: false,
    };
    let (csv_fields, csv_stats) = STATS_RECORDS.get_or_init(|| {
        get_stats_records(&schema_args, StatsMode::FrequencyForceStats)
            .unwrap_or_else(|_| (ByteRecord::new(), Vec::new()))
    });

    let rconfig = Config::new(Some(&args.arg_input));
    let row_count = util::count_rows(&rconfig)? as u64;
    // eprintln!("row_count: {}\nstats: {:#?}", row_count, csv_stats);

    // Analyze pivot column characteristics
    let mut high_cardinality_pivot = false;
    let mut ordered_pivot = false; // Track if pivot columns are ordered
    for on_col in on_cols {
        if let Some(pos) = csv_fields
            .iter()
            .position(|f| std::str::from_utf8(f).unwrap_or("") == on_col)
        {
            let stats = &csv_stats[pos];

            // Check cardinality ratio
            if stats.cardinality as f64 / row_count as f64 > 0.5 {
                high_cardinality_pivot = true;
                if !quiet {
                    eprintln!("Info: Pivot column \"{on_col}\" has high cardinality");
                }
            }

            // Check if column is unordered based on sort_order
            if let Some(sort_order) = &stats.sort_order {
                ordered_pivot = sort_order != "Unsorted";
            }
        }
    }

    // Analyze index column characteristics
    let mut high_cardinality_index = false;
    let mut ordered_index = false;
    if let Some(idx_cols) = index_cols {
        for idx_col in idx_cols {
            if let Some(pos) = csv_fields
                .iter()
                .position(|f| std::str::from_utf8(f).unwrap_or("") == idx_col)
            {
                let stats = &csv_stats[pos];

                // Check cardinality ratio
                if stats.cardinality as f64 / row_count as f64 > 0.5 {
                    high_cardinality_index = true;
                    if !quiet {
                        eprintln!("Info: Index column \"{idx_col}\" has high cardinality");
                    }
                }

                // Check if column is unordered
                if let Some(sort_order) = &stats.sort_order {
                    ordered_index = sort_order != "Unsorted";
                }
            }
        }
    }

    // Get stats for the value column
    let value_col = &value_cols[0];
    let field_pos = csv_fields
        .iter()
        .position(|f| std::str::from_utf8(f).unwrap_or("") == value_col);

    if let Some(pos) = field_pos {
        let stats = &csv_stats[pos];

        // Suggest aggregation based on field type and statistics
        let suggested_agg = match stats.r#type.as_str() {
            "NULL" => {
                if !quiet {
                    eprintln!("Info: \"{value_col}\" contains only NULL values");
                }
                PivotAgg::Count
            },
            "Integer" | "Float" => {
                if stats.nullcount as f64 / row_count as f64 > 0.5 {
                    if !quiet {
                        eprintln!("Info: \"{value_col}\" contains >50% NULL values, using Count");
                    }
                    PivotAgg::Count
                } else if stats.cv > Some(1.0) {
                    // High coefficient of variation suggests using median for better central
                    // tendency
                    if !quiet {
                        eprintln!(
                            "Info: High variability in values (CV > 1), using Median for more \
                             robust central tendency"
                        );
                    }
                    PivotAgg::Median
                } else if high_cardinality_pivot && high_cardinality_index {
                    if ordered_pivot && ordered_index {
                        // With ordered high cardinality columns, mean might be more meaningful
                        if !quiet {
                            eprintln!(
                                "Info: Ordered high cardinality columns detected, using Mean"
                            );
                        }
                        PivotAgg::Mean
                    } else {
                        // With unordered high cardinality, sum might be more appropriate
                        if !quiet {
                            eprintln!(
                                "Info: High cardinality in pivot and index columns, using Sum"
                            );
                        }
                        PivotAgg::Sum
                    }
                } else if let Some(skewness) = stats.skewness {
                    if skewness.abs() > 2.0 {
                        // Highly skewed data might benefit from median
                        if !quiet {
                            eprintln!("Info: Highly skewed numeric data detected, using Median");
                        }
                        PivotAgg::Median
                    } else {
                        PivotAgg::Sum
                    }
                } else {
                    PivotAgg::Sum
                }
            },
            "Date" | "DateTime" => {
                if high_cardinality_pivot || high_cardinality_index {
                    if ordered_pivot && ordered_index {
                        if !quiet {
                            eprintln!(
                                "Info: Ordered temporal data with high cardinality, using Last"
                            );
                        }
                        PivotAgg::Last
                    } else {
                        if !quiet {
                            eprintln!(
                                "Info: High cardinality detected, using First for {} column",
                                stats.r#type
                            );
                        }
                        PivotAgg::First
                    }
                } else {
                    if !quiet {
                        eprintln!("Info: Using Count for {} column", stats.r#type);
                    }
                    PivotAgg::Count
                }
            },
            _ => {
                if stats.cardinality == row_count {
                    if !quiet {
                        eprintln!("Info: \"{value_col}\" contains all unique values, using First");
                    }
                    PivotAgg::First
                } else if stats.sparsity > Some(0.5) {
                    if !quiet {
                        eprintln!("Info: Sparse data detected, using Count");
                    }
                    PivotAgg::Count
                } else if high_cardinality_pivot || high_cardinality_index {
                    if !quiet {
                        eprintln!("Info: High cardinality detected, using Count");
                    }
                    PivotAgg::Count
                } else {
                    if !quiet {
                        eprintln!("Info: Using Count for String column");
                    }
                    PivotAgg::Count
                }
            },
        };
        Ok(Some(suggested_agg))
    } else {
        Ok(None)
    }
}
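
// Summarizing the heuristic above, in the order its branches are checked:
//   NULL          -> Count
//   Integer/Float -> Count    if >50% NULLs
//                    Median   if CV > 1
//                    Mean/Sum if high-cardinality pivot AND index columns
//                             (Mean when both are ordered, Sum otherwise)
//                    Median   if |skewness| > 2
//                    Sum      otherwise
//   Date/DateTime -> Last/First when cardinality is high, else Count
//   Other types   -> First if all values are unique, else Count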

pub fn run(argv: &[&str]) -> CliResult<()> {
    let args: Args = util::get_args(USAGE, argv)?;

    // Parse on column(s)
    let on_cols: Vec<String> = args
        .arg_on_cols
        .as_str()
        .split(',')
        .map(std::string::ToString::to_string)
        .collect();

    // Parse index column(s)
    let index_cols = if let Some(ref flag_index) = args.flag_index {
        let idx_cols: Vec<String> = flag_index
            .as_str()
            .split(',')
            .map(std::string::ToString::to_string)
            .collect();
        Some(idx_cols)
    } else {
        None
    };

    // Parse values column(s)
    let value_cols = if let Some(ref flag_values) = args.flag_values {
        let val_cols: Vec<String> = flag_values
            .as_str()
            .split(',')
            .map(std::string::ToString::to_string)
            .collect();
        Some(val_cols)
    } else {
        None
    };

    if index_cols.is_none() && value_cols.is_none() {
        return fail_incorrectusage_clierror!(
            "Either --index <cols> or --values <cols> must be specified."
        );
    }

    // Get aggregation function
    let agg_fn = if let Some(ref agg) = args.flag_agg {
        let lower_agg = agg.to_lowercase();
        if lower_agg == "none" {
            None
        } else {
            Some(match lower_agg.as_str() {
                "first" => PivotAgg::First,
                "sum" => PivotAgg::Sum,
                "min" => PivotAgg::Min,
                "max" => PivotAgg::Max,
                "mean" => PivotAgg::Mean,
                "median" => PivotAgg::Median,
                "count" => PivotAgg::Count,
                "last" => PivotAgg::Last,
                "smart" => {
                    if let Some(value_cols) = &value_cols {
                        // Try to suggest an appropriate aggregation function
                        if let Some(suggested_agg) = suggest_agg_function(
                            &args,
                            &on_cols,
                            index_cols.as_deref(),
                            value_cols,
                        )? {
                            suggested_agg
                        } else {
                            // fallback to first, which always works
                            PivotAgg::First
                        }
                    } else {
                        // Default to Count if no value columns are specified
                        PivotAgg::Count
                    }
                },
                _ => {
                    return fail_incorrectusage_clierror!(
                        "Invalid pivot aggregation function: {agg}"
                    )
                },
            })
        }
    } else {
        None
    };
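
    // Note: `--agg none` leaves agg_fn as None; per the USAGE text, the pivot
    // then errors if any group contains more than one value.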

    // Set delimiter if specified
    let delim = if let Some(delimiter) = args.flag_delimiter {
        delimiter.as_byte()
    } else {
        b','
    };

    if args.flag_decimal_comma && delim == b',' {
        return fail_incorrectusage_clierror!(
            "You need to specify an alternate --delimiter when using --decimal-comma."
        );
    }

    // Create CSV reader config.
    // --infer-len 0 is documented as "scan the entire file", which is what
    // passing None to with_infer_schema_length does.
    let csv_reader = LazyCsvReader::new(&args.arg_input)
        .with_has_header(true)
        .with_try_parse_dates(args.flag_try_parsedates)
        .with_decimal_comma(args.flag_decimal_comma)
        .with_separator(delim)
        .with_ignore_errors(args.flag_ignore_errors)
        .with_infer_schema_length(if args.flag_infer_len == 0 {
            None
        } else {
            Some(args.flag_infer_len)
        });

    // Read the CSV into a DataFrame
    let df = csv_reader.finish()?.collect()?;

    if args.flag_validate {
        // Validate the operation
        if let Some(metadata) = calculate_pivot_metadata(&args, &on_cols, value_cols.as_ref())? {
            validate_pivot_operation(&metadata)?;
        }
    }

    // Perform pivot operation
    let mut pivot_result = pivot_stable(
        &df,
        on_cols,
        index_cols,
        value_cols,
        args.flag_sort_columns,
        agg_fn,
        Some(&args.flag_col_separator),
    )?;

    // Write output
    let mut writer = match args.flag_output {
        Some(ref output_file) => {
            // no need to use a buffered writer here, as CsvWriter already does that
            let path = Path::new(&output_file);
            Box::new(File::create(path).unwrap()) as Box<dyn Write>
        },
        None => Box::new(io::stdout()) as Box<dyn Write>,
    };
    CsvWriter::new(&mut writer)
        .include_header(true)
        .with_datetime_format(Some("%Y-%m-%d %H:%M:%S".to_string()))
        .with_separator(delim)
        .finish(&mut pivot_result)?;

    // Print shape to stderr
    if !args.flag_quiet {
        eprintln!("{:?}", pivot_result.shape());
    }

    Ok(())
}