rust_data_processing/processing/
reduce.rs

1//! Reduction operations for [`crate::types::DataSet`].
2
3use std::collections::HashSet;
4
5use crate::types::{DataSet, DataType, Value};
6
/// Selects the denominator for variance / standard deviation (`ddof` 0 or 1).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VarianceKind {
    /// Population statistic: divide by `n` (defined whenever `n > 0`).
    Population,
    /// Sample statistic: divide by `n - 1` (defined only when `n >= 2`);
    /// with fewer values the result is [`None`] / null.
    Sample,
}
15
16/// Built-in reduction operations over a single column.
17#[derive(Debug, Clone, Copy, PartialEq, Eq)]
18pub enum ReduceOp {
19    /// Count all rows (including nulls).
20    Count,
21    /// Sum numeric values, ignoring nulls.
22    Sum,
23    /// Minimum numeric value, ignoring nulls.
24    Min,
25    /// Maximum numeric value, ignoring nulls.
26    Max,
27    /// Arithmetic mean of numeric values as [`Value::Float64`], ignoring nulls.
28    Mean,
29    /// Variance (Welford); null if no values, or sample with fewer than two values.
30    Variance(VarianceKind),
31    /// Standard deviation from variance; same null rules as [`ReduceOp::Variance`].
32    StdDev(VarianceKind),
33    /// \(\sum x^2\) over non-null numeric values as [`Value::Float64`].
34    SumSquares,
35    /// \(\sqrt{\sum x^2}\) over non-null numeric values as [`Value::Float64`].
36    L2Norm,
37    /// Count of distinct non-null values (returns [`Value::Int64`]).
38    CountDistinctNonNull,
39}
40
41/// Reduce a column using a built-in [`ReduceOp`].
42///
43/// - Returns `None` if `column` does not exist in the schema.
44/// - For `Count`, always returns `Some(Value::Int64(row_count))`.
45/// - For numeric aggregates other than `Count` / `CountDistinctNonNull`, returns
46///   `Some(Value::Null)` if there are no non-null numeric values, or if the column type is not
47///   numeric (for those ops). `CountDistinctNonNull` supports [`DataType::Bool`] and
48///   [`DataType::Utf8`] as well as numeric types.
49pub fn reduce(dataset: &DataSet, column: &str, op: ReduceOp) -> Option<Value> {
50    let idx = dataset.schema.index_of(column)?;
51
52    match op {
53        ReduceOp::Count => Some(Value::Int64(dataset.row_count() as i64)),
54        ReduceOp::CountDistinctNonNull => {
55            let field = dataset.schema.fields.get(idx)?;
56            reduce_count_distinct_non_null(dataset, idx, &field.data_type)
57        }
58        ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => match dataset.schema.fields.get(idx) {
59            Some(field) => reduce_numeric_typed(dataset, idx, field.data_type.clone(), op),
60            None => None,
61        },
62        ReduceOp::Mean
63        | ReduceOp::Variance(_)
64        | ReduceOp::StdDev(_)
65        | ReduceOp::SumSquares
66        | ReduceOp::L2Norm => match dataset.schema.fields.get(idx) {
67            Some(field) => reduce_numeric_float_stats(dataset, idx, field.data_type.clone(), op),
68            None => None,
69        },
70    }
71}
72
73#[derive(Default)]
74pub(crate) struct Welford {
75    n: u64,
76    mean: f64,
77    m2: f64,
78}
79
80impl Welford {
81    pub(crate) fn observe(&mut self, x: f64) {
82        self.n += 1;
83        let delta = x - self.mean;
84        self.mean += delta / self.n as f64;
85        let delta2 = x - self.mean;
86        self.m2 += delta * delta2;
87    }
88
89    pub(crate) fn mean(&self) -> Option<f64> {
90        (self.n > 0).then_some(self.mean)
91    }
92
93    pub(crate) fn variance(&self, kind: VarianceKind) -> Option<f64> {
94        if self.n == 0 {
95            return None;
96        }
97        match kind {
98            VarianceKind::Population => Some(self.m2 / self.n as f64),
99            VarianceKind::Sample => {
100                if self.n < 2 {
101                    None
102                } else {
103                    Some(self.m2 / (self.n - 1) as f64)
104                }
105            }
106        }
107    }
108
109    pub(crate) fn observation_count(&self) -> u64 {
110        self.n
111    }
112}
113
114fn reduce_numeric_float_stats(
115    dataset: &DataSet,
116    idx: usize,
117    data_type: DataType,
118    op: ReduceOp,
119) -> Option<Value> {
120    match data_type {
121        dt @ (DataType::Int64 | DataType::Float64) => {
122            let is_int = matches!(dt, DataType::Int64);
123            let mut w = Welford::default();
124            let mut sum_squares = 0.0_f64;
125            let mut any = false;
126
127            for row in &dataset.rows {
128                let x = match row.get(idx) {
129                    Some(Value::Null) | None => None,
130                    Some(Value::Int64(v)) if is_int => Some(*v as f64),
131                    Some(Value::Float64(v)) if !is_int => Some(*v),
132                    Some(_) => None,
133                };
134                if let Some(x) = x {
135                    any = true;
136                    w.observe(x);
137                    sum_squares += x * x;
138                }
139            }
140
141            if !any {
142                return Some(Value::Null);
143            }
144
145            let out = match op {
146                ReduceOp::Mean => Value::Float64(w.mean().expect("n > 0")),
147                ReduceOp::Variance(kind) => match w.variance(kind) {
148                    Some(v) => Value::Float64(v),
149                    None => Value::Null,
150                },
151                ReduceOp::StdDev(kind) => match w.variance(kind) {
152                    Some(v) => Value::Float64(v.sqrt()),
153                    None => Value::Null,
154                },
155                ReduceOp::SumSquares => Value::Float64(sum_squares),
156                ReduceOp::L2Norm => Value::Float64(sum_squares.sqrt()),
157                _ => unreachable!("caller only dispatches float stats ops"),
158            };
159            Some(out)
160        }
161        _ => Some(Value::Null),
162    }
163}
164
165fn reduce_count_distinct_non_null(
166    dataset: &DataSet,
167    idx: usize,
168    data_type: &DataType,
169) -> Option<Value> {
170    let n = match data_type {
171        DataType::Int64 => {
172            let mut set = HashSet::new();
173            for row in &dataset.rows {
174                if let Some(Value::Int64(v)) = row.get(idx) {
175                    set.insert(*v);
176                }
177            }
178            set.len() as i64
179        }
180        DataType::Float64 => {
181            let mut set = HashSet::new();
182            for row in &dataset.rows {
183                if let Some(Value::Float64(v)) = row.get(idx) {
184                    set.insert(v.to_bits());
185                }
186            }
187            set.len() as i64
188        }
189        DataType::Bool => {
190            let mut set = HashSet::new();
191            for row in &dataset.rows {
192                if let Some(Value::Bool(v)) = row.get(idx) {
193                    set.insert(*v);
194                }
195            }
196            set.len() as i64
197        }
198        DataType::Utf8 => {
199            let mut set = HashSet::new();
200            for row in &dataset.rows {
201                if let Some(Value::Utf8(s)) = row.get(idx) {
202                    set.insert(s.clone());
203                }
204            }
205            set.len() as i64
206        }
207    };
208    Some(Value::Int64(n))
209}
210
211fn reduce_numeric_typed(
212    dataset: &DataSet,
213    idx: usize,
214    data_type: DataType,
215    op: ReduceOp,
216) -> Option<Value> {
217    match data_type {
218        DataType::Int64 => {
219            let mut acc: Option<i64> = None;
220            for row in &dataset.rows {
221                match row.get(idx) {
222                    Some(Value::Null) | None => {}
223                    Some(Value::Int64(v)) => {
224                        acc = Some(match (op, acc) {
225                            (ReduceOp::Sum, Some(a)) => a + v,
226                            (ReduceOp::Sum, None) => *v,
227                            (ReduceOp::Min, Some(a)) => a.min(*v),
228                            (ReduceOp::Min, None) => *v,
229                            (ReduceOp::Max, Some(a)) => a.max(*v),
230                            (ReduceOp::Max, None) => *v,
231                            _ => unreachable!("non-numeric op handled earlier"),
232                        });
233                    }
234                    Some(_) => {}
235                }
236            }
237            Some(acc.map(Value::Int64).unwrap_or(Value::Null))
238        }
239        DataType::Float64 => {
240            let mut acc: Option<f64> = None;
241            for row in &dataset.rows {
242                match row.get(idx) {
243                    Some(Value::Null) | None => {}
244                    Some(Value::Float64(v)) => {
245                        acc = Some(match (op, acc) {
246                            (ReduceOp::Sum, Some(a)) => a + v,
247                            (ReduceOp::Sum, None) => *v,
248                            (ReduceOp::Min, Some(a)) => a.min(*v),
249                            (ReduceOp::Min, None) => *v,
250                            (ReduceOp::Max, Some(a)) => a.max(*v),
251                            (ReduceOp::Max, None) => *v,
252                            _ => unreachable!("non-numeric op handled earlier"),
253                        });
254                    }
255                    Some(_) => {}
256                }
257            }
258            Some(acc.map(Value::Float64).unwrap_or(Value::Null))
259        }
260        _ => Some(Value::Null),
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::{ReduceOp, VarianceKind, reduce};
    use crate::types::{DataSet, DataType, Field, Schema, Value};

    /// Two-column fixture: an `Int64` id column and a `Float64` score column with one null.
    fn numeric_dataset_with_nulls() -> DataSet {
        let fields = vec![
            Field::new("id", DataType::Int64),
            Field::new("score", DataType::Float64),
        ];
        let rows = vec![
            vec![Value::Int64(1), Value::Float64(10.0)],
            vec![Value::Int64(2), Value::Null],
            vec![Value::Int64(3), Value::Float64(5.5)],
        ];
        DataSet::new(Schema::new(fields), rows)
    }

    /// Builds a dataset with a single column holding `values`, one value per row.
    fn single_column(name: &'static str, data_type: DataType, values: Vec<Value>) -> DataSet {
        let schema = Schema::new(vec![Field::new(name, data_type)]);
        let rows: Vec<Vec<Value>> = values.into_iter().map(|value| vec![value]).collect();
        DataSet::new(schema, rows)
    }

    /// Builds a single `Float64` column named `x` from plain floats (no nulls).
    fn float_column_x(values: &[f64]) -> DataSet {
        let cells = values.iter().copied().map(Value::Float64).collect();
        single_column("x", DataType::Float64, cells)
    }

    /// Unwraps a `Value::Float64`, failing the test for any other variant.
    fn expect_f64(value: Value) -> f64 {
        match value {
            Value::Float64(v) => v,
            other => panic!("expected Float64, got {other:?}"),
        }
    }

    #[test]
    fn reduce_count_counts_rows() {
        let ds = numeric_dataset_with_nulls();
        for column in ["score", "id"] {
            assert_eq!(reduce(&ds, column, ReduceOp::Count), Some(Value::Int64(3)));
        }
    }

    #[test]
    fn reduce_sum_ignores_nulls_and_preserves_type() {
        let ds = numeric_dataset_with_nulls();
        let float_sum = reduce(&ds, "score", ReduceOp::Sum);
        let int_sum = reduce(&ds, "id", ReduceOp::Sum);
        assert_eq!(float_sum, Some(Value::Float64(15.5)));
        assert_eq!(int_sum, Some(Value::Int64(6)));
    }

    #[test]
    fn reduce_min_max_ignore_nulls() {
        let ds = numeric_dataset_with_nulls();
        let expectations = [
            ("score", ReduceOp::Min, Value::Float64(5.5)),
            ("score", ReduceOp::Max, Value::Float64(10.0)),
            ("id", ReduceOp::Min, Value::Int64(1)),
            ("id", ReduceOp::Max, Value::Int64(3)),
        ];
        for (column, op, expected) in expectations {
            assert_eq!(reduce(&ds, column, op), Some(expected));
        }
    }

    #[test]
    fn reduce_returns_none_for_missing_column() {
        let ds = numeric_dataset_with_nulls();
        for op in [ReduceOp::Count, ReduceOp::Sum] {
            assert_eq!(reduce(&ds, "missing", op), None);
        }
    }

    #[test]
    fn reduce_numeric_returns_null_if_all_values_null() {
        let ds = single_column(
            "score",
            DataType::Float64,
            vec![Value::Null, Value::Null],
        );
        let ops = [
            ReduceOp::Sum,
            ReduceOp::Min,
            ReduceOp::Max,
            ReduceOp::Mean,
            ReduceOp::Variance(VarianceKind::Population),
            ReduceOp::StdDev(VarianceKind::Sample),
        ];
        for op in ops {
            assert_eq!(reduce(&ds, "score", op), Some(Value::Null));
        }
    }

    #[test]
    fn reduce_mean_float_and_int() {
        let ds = numeric_dataset_with_nulls();
        let float_mean = reduce(&ds, "score", ReduceOp::Mean);
        let int_mean = reduce(&ds, "id", ReduceOp::Mean);
        assert_eq!(float_mean, Some(Value::Float64(7.75)));
        assert_eq!(int_mean, Some(Value::Float64(2.0)));
    }

    #[test]
    fn reduce_variance_std_known_values() {
        let ds = float_column_x(&[1.0, 2.0, 3.0]);
        let pop = 2.0 / 3.0;

        let var_pop = reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Population));
        assert_eq!(var_pop, Some(Value::Float64(pop)));

        let var_sample = reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Sample));
        assert_eq!(var_sample, Some(Value::Float64(1.0)));

        let std_pop = reduce(&ds, "x", ReduceOp::StdDev(VarianceKind::Population)).unwrap();
        assert!((expect_f64(std_pop) - pop.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn reduce_sample_variance_single_value_is_null() {
        let ds = float_column_x(&[42.0]);
        let var_sample = reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Sample));
        assert_eq!(var_sample, Some(Value::Null));
    }

    #[test]
    fn reduce_population_variance_single_value_is_zero() {
        let ds = float_column_x(&[42.0]);

        let var_pop = reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Population));
        assert_eq!(var_pop, Some(Value::Float64(0.0)));

        let std0 = reduce(&ds, "x", ReduceOp::StdDev(VarianceKind::Population)).unwrap();
        assert_eq!(expect_f64(std0), 0.0);
    }

    #[test]
    fn reduce_int64_mean_sum_squares_and_distinct() {
        let ds = single_column(
            "k",
            DataType::Int64,
            vec![Value::Int64(2), Value::Int64(3), Value::Null],
        );
        let expectations = [
            (ReduceOp::Mean, Value::Float64(2.5)),
            (ReduceOp::SumSquares, Value::Float64(13.0)),
            (ReduceOp::L2Norm, Value::Float64(13.0_f64.sqrt())),
            (ReduceOp::CountDistinctNonNull, Value::Int64(2)),
        ];
        for (op, expected) in expectations {
            assert_eq!(reduce(&ds, "k", op), Some(expected));
        }
    }

    #[test]
    fn reduce_sum_squares_and_l2() {
        let ds = single_column(
            "x",
            DataType::Float64,
            vec![Value::Float64(3.0), Value::Float64(4.0), Value::Null],
        );
        let sum_squares = reduce(&ds, "x", ReduceOp::SumSquares);
        let l2 = reduce(&ds, "x", ReduceOp::L2Norm);
        assert_eq!(sum_squares, Some(Value::Float64(25.0)));
        assert_eq!(l2, Some(Value::Float64(5.0)));
    }

    #[test]
    fn reduce_count_distinct_non_null() {
        let fields = vec![
            Field::new("f", DataType::Float64),
            Field::new("s", DataType::Utf8),
        ];
        let rows = vec![
            vec![Value::Float64(1.0), Value::Utf8("a".to_string())],
            vec![Value::Float64(1.0), Value::Utf8("b".to_string())],
            vec![Value::Null, Value::Null],
        ];
        let ds = DataSet::new(Schema::new(fields), rows);

        let distinct_floats = reduce(&ds, "f", ReduceOp::CountDistinctNonNull);
        let distinct_strings = reduce(&ds, "s", ReduceOp::CountDistinctNonNull);
        assert_eq!(distinct_floats, Some(Value::Int64(1)));
        assert_eq!(distinct_strings, Some(Value::Int64(2)));
    }

    #[test]
    fn reduce_new_ops_return_none_for_missing_column() {
        let ds = numeric_dataset_with_nulls();
        let ops = [
            ReduceOp::Mean,
            ReduceOp::Variance(VarianceKind::Sample),
            ReduceOp::CountDistinctNonNull,
        ];
        for op in ops {
            assert_eq!(reduce(&ds, "nope", op), None);
        }
    }

    #[test]
    fn reduce_sum_squares_and_l2_all_null() {
        let ds = single_column("x", DataType::Float64, vec![Value::Null]);
        for op in [ReduceOp::SumSquares, ReduceOp::L2Norm] {
            assert_eq!(reduce(&ds, "x", op), Some(Value::Null));
        }
    }

    #[test]
    fn reduce_count_distinct_bool_and_empty_rows() {
        let empty = single_column("b", DataType::Bool, vec![]);
        let distinct_in_empty = reduce(&empty, "b", ReduceOp::CountDistinctNonNull);
        assert_eq!(distinct_in_empty, Some(Value::Int64(0)));

        let populated = single_column(
            "b",
            DataType::Bool,
            vec![
                Value::Bool(true),
                Value::Bool(false),
                Value::Bool(true),
                Value::Null,
            ],
        );
        let distinct_in_populated = reduce(&populated, "b", ReduceOp::CountDistinctNonNull);
        assert_eq!(distinct_in_populated, Some(Value::Int64(2)));
    }

    #[test]
    fn reduce_mean_variance_null_for_non_numeric_column() {
        let ds = single_column(
            "label",
            DataType::Utf8,
            vec![Value::Utf8("a".to_string()), Value::Utf8("b".to_string())],
        );
        let ops = [
            ReduceOp::Mean,
            ReduceOp::Variance(VarianceKind::Population),
            ReduceOp::SumSquares,
        ];
        for op in ops {
            assert_eq!(reduce(&ds, "label", op), Some(Value::Null));
        }
    }

    #[test]
    fn reduce_std_dev_sample_matches_sqrt_of_sample_variance() {
        let ds = float_column_x(&[0.0, 4.0, 8.0]);
        let var_s =
            expect_f64(reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Sample)).unwrap());
        let std_s = expect_f64(reduce(&ds, "x", ReduceOp::StdDev(VarianceKind::Sample)).unwrap());
        assert!((std_s - var_s.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn reduce_l2_squared_matches_sum_squares_for_non_nulls() {
        let ds = float_column_x(&[2.0, 3.0]);
        let ss = expect_f64(reduce(&ds, "x", ReduceOp::SumSquares).unwrap());
        let l2 = expect_f64(reduce(&ds, "x", ReduceOp::L2Norm).unwrap());
        assert!((l2 * l2 - ss).abs() < 1e-12);
    }
}