rust_data_processing/processing/
multi.rs

1//! Multi-column and row-index reductions over a [`DataSet`](crate::types::DataSet).
2//!
3//! Aggregate semantics (nulls, all-null groups, casting) are documented in
4//! `docs/REDUCE_AGG_SEMANTICS.md` at the repository root.
5
6use std::cmp::Ordering;
7use std::collections::HashMap;
8
9use crate::types::{DataSet, DataType, Value};
10
11use super::reduce::{VarianceKind, Welford};
12
13/// Per-column mean and standard deviation (square root of variance under `std_kind`).
14#[derive(Debug, Clone, PartialEq)]
15pub struct FeatureMeanStd {
16    pub mean: Value,
17    pub std_dev: Value,
18}
19
20/// One pass over all rows: compute mean and std dev for each listed **numeric** column (`Int64` /
21/// `Float64`). Nulls are ignored. If a column has no non-null values, both fields are
22/// [`Value::Null`]. Sample std dev is undefined for fewer than two values → [`Value::Null`].
23///
24/// Returns [`None`] if any name is missing from the schema or is not numeric.
25pub fn feature_wise_mean_std(
26    dataset: &DataSet,
27    columns: &[&str],
28    std_kind: VarianceKind,
29) -> Option<Vec<(String, FeatureMeanStd)>> {
30    let mut meta: Vec<(String, usize, DataType)> = Vec::with_capacity(columns.len());
31    for &name in columns {
32        let idx = dataset.schema.index_of(name)?;
33        let dt = dataset.schema.fields.get(idx)?.data_type.clone();
34        if !matches!(dt, DataType::Int64 | DataType::Float64) {
35            return None;
36        }
37        meta.push((name.to_string(), idx, dt));
38    }
39
40    let mut w: Vec<Welford> = (0..meta.len()).map(|_| Welford::default()).collect();
41    for row in &dataset.rows {
42        for (i, (_, idx, dt)) in meta.iter().enumerate() {
43            let x = match (row.get(*idx), dt) {
44                (Some(Value::Int64(v)), DataType::Int64) => Some(*v as f64),
45                (Some(Value::Float64(v)), DataType::Float64) => Some(*v),
46                _ => None,
47            };
48            if let Some(x) = x {
49                w[i].observe(x);
50            }
51        }
52    }
53
54    let mut out = Vec::with_capacity(meta.len());
55    for ((name, _, _), wf) in meta.into_iter().zip(w) {
56        let mean = wf.mean().map(Value::Float64).unwrap_or(Value::Null);
57        let std_dev = wf
58            .variance(std_kind)
59            .map(|v| Value::Float64(v.sqrt()))
60            .unwrap_or(Value::Null);
61        let (mean, std_dev) = if wf.observation_count() == 0 {
62            (Value::Null, Value::Null)
63        } else {
64            (mean, std_dev)
65        };
66        out.push((name, FeatureMeanStd { mean, std_dev }));
67    }
68    Some(out)
69}
70
71fn cmp_non_null_values(a: &Value, b: &Value) -> Option<Ordering> {
72    match (a, b) {
73        (Value::Int64(x), Value::Int64(y)) => Some(x.cmp(y)),
74        (Value::Float64(x), Value::Float64(y)) => Some(x.total_cmp(y)),
75        (Value::Utf8(x), Value::Utf8(y)) => Some(x.cmp(y)),
76        (Value::Bool(x), Value::Bool(y)) => Some(x.cmp(y)),
77        _ => None,
78    }
79}
80
81/// Returns [`None`] if `column` is not in the schema. Otherwise [`Some(None)`] if there is no
82/// non-null comparable value, or [`Some(Some((row_index, value)))`] for the **first** row
83/// attaining the maximum (stable tie-break).
84pub fn arg_max_row(dataset: &DataSet, column: &str) -> Option<Option<(usize, Value)>> {
85    let idx = dataset.schema.index_of(column)?;
86    let mut best: Option<(usize, Value)> = None;
87    for (r, row) in dataset.rows.iter().enumerate() {
88        let Some(cell) = row.get(idx) else {
89            continue;
90        };
91        if matches!(cell, Value::Null) {
92            continue;
93        }
94        match &best {
95            None => best = Some((r, cell.clone())),
96            Some((_, bv)) => {
97                if cmp_non_null_values(cell, bv) == Some(Ordering::Greater) {
98                    best = Some((r, cell.clone()));
99                }
100            }
101        }
102    }
103    Some(best)
104}
105
106/// Same as [`arg_max_row`] for the minimum.
107pub fn arg_min_row(dataset: &DataSet, column: &str) -> Option<Option<(usize, Value)>> {
108    let idx = dataset.schema.index_of(column)?;
109    let mut best: Option<(usize, Value)> = None;
110    for (r, row) in dataset.rows.iter().enumerate() {
111        let Some(cell) = row.get(idx) else {
112            continue;
113        };
114        if matches!(cell, Value::Null) {
115            continue;
116        }
117        match &best {
118            None => best = Some((r, cell.clone())),
119            Some((_, bv)) => {
120                if cmp_non_null_values(cell, bv) == Some(Ordering::Less) {
121                    best = Some((r, cell.clone()));
122                }
123            }
124        }
125    }
126    Some(best)
127}
128
129fn freq_bucket_key(v: &Value) -> Option<String> {
130    match v {
131        Value::Null => None,
132        Value::Int64(x) => Some(format!("i:{x}")),
133        Value::Float64(x) => Some(format!("f:{}", x.to_bits())),
134        Value::Bool(b) => Some(format!("b:{b}")),
135        Value::Utf8(s) => Some(format!("s:{s}")),
136    }
137}
138
139fn value_sort_key(v: &Value) -> String {
140    match v {
141        Value::Null => String::new(),
142        Value::Int64(x) => format!("i:{x:020}"),
143        Value::Float64(x) => format!("f:{:020}", x.to_bits()),
144        Value::Bool(b) => format!("b:{b}"),
145        Value::Utf8(s) => format!("s:{s}"),
146    }
147}
148
149/// Non-null value frequencies; returns the top `k` pairs by count (desc), breaking ties by
150/// [`value_sort_key`] ascending. `k == 0` yields an empty vector.
151///
152/// Returns [`None`] if the column is not in the schema.
153pub fn top_k_by_frequency(dataset: &DataSet, column: &str, k: usize) -> Option<Vec<(Value, i64)>> {
154    let idx = dataset.schema.index_of(column)?;
155    let mut buckets: HashMap<String, (Value, i64)> = HashMap::new();
156    for row in &dataset.rows {
157        let Some(cell) = row.get(idx) else {
158            continue;
159        };
160        let Some(key) = freq_bucket_key(cell) else {
161            continue;
162        };
163        buckets
164            .entry(key)
165            .and_modify(|(_, c)| *c += 1)
166            .or_insert_with(|| (cell.clone(), 1));
167    }
168    let mut v: Vec<(Value, i64)> = buckets.into_values().collect();
169    v.sort_by(|a, b| {
170        b.1.cmp(&a.1)
171            .then_with(|| value_sort_key(&a.0).cmp(&value_sort_key(&b.0)))
172    });
173    v.truncate(k);
174    Some(v)
175}
176
#[cfg(test)]
mod tests {
    use super::{arg_max_row, arg_min_row, feature_wise_mean_std, top_k_by_frequency};
    use crate::processing::VarianceKind;
    use crate::types::{DataSet, DataType, Field, Schema, Value};

    #[test]
    fn feature_wise_mean_std_two_columns_one_pass() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int64),
            Field::new("b", DataType::Float64),
        ]);
        let ds = DataSet::new(
            schema,
            vec![
                vec![Value::Int64(10), Value::Float64(1.0)],
                vec![Value::Int64(20), Value::Null],
                vec![Value::Null, Value::Float64(3.0)],
            ],
        );
        let got = feature_wise_mean_std(&ds, &["a", "b"], VarianceKind::Sample).unwrap();
        assert_eq!(got[0].0, "a");
        assert_eq!(got[0].1.mean, Value::Float64(15.0));
        let std_a = match &got[0].1.std_dev {
            Value::Float64(x) => *x,
            o => panic!("{o:?}"),
        };
        assert!((std_a - 50.0_f64.sqrt()).abs() < 1e-9);
        assert_eq!(got[1].0, "b");
        assert_eq!(got[1].1.mean, Value::Float64(2.0));
        let std_b = match &got[1].1.std_dev {
            Value::Float64(x) => *x,
            o => panic!("{o:?}"),
        };
        assert!((std_b - 2.0_f64.sqrt()).abs() < 1e-9);
    }

    #[test]
    fn feature_wise_returns_none_for_unknown_or_non_numeric_column() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int64),
            Field::new("t", DataType::Utf8),
        ]);
        let ds = DataSet::new(
            schema,
            vec![vec![Value::Int64(1), Value::Utf8("x".to_string())]],
        );
        assert!(feature_wise_mean_std(&ds, &["missing"], VarianceKind::Sample).is_none());
        assert!(feature_wise_mean_std(&ds, &["a", "t"], VarianceKind::Sample).is_none());
    }

    #[test]
    fn feature_wise_all_null_column_yields_nulls() {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64)]);
        let ds = DataSet::new(schema, vec![vec![Value::Null], vec![Value::Null]]);
        let got = feature_wise_mean_std(&ds, &["a"], VarianceKind::Sample).unwrap();
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].0, "a");
        assert_eq!(got[0].1.mean, Value::Null);
        assert_eq!(got[0].1.std_dev, Value::Null);
    }

    #[test]
    fn arg_max_min_first_on_ties() {
        let schema = Schema::new(vec![Field::new("x", DataType::Int64)]);
        let ds = DataSet::new(
            schema,
            vec![
                vec![Value::Int64(1)],
                vec![Value::Int64(3)],
                vec![Value::Int64(3)],
                vec![Value::Null],
            ],
        );
        assert_eq!(arg_max_row(&ds, "x"), Some(Some((1, Value::Int64(3)))));
        assert_eq!(arg_min_row(&ds, "x"), Some(Some((0, Value::Int64(1)))));
    }

    #[test]
    fn arg_max_min_unknown_column_and_all_null() {
        let schema = Schema::new(vec![Field::new("x", DataType::Int64)]);
        let ds = DataSet::new(schema, vec![vec![Value::Null], vec![Value::Null]]);
        assert_eq!(arg_max_row(&ds, "missing"), None);
        assert_eq!(arg_min_row(&ds, "missing"), None);
        assert_eq!(arg_max_row(&ds, "x"), Some(None));
        assert_eq!(arg_min_row(&ds, "x"), Some(None));
    }

    #[test]
    fn top_k_frequency_ordering() {
        let schema = Schema::new(vec![Field::new("label", DataType::Utf8)]);
        let ds = DataSet::new(
            schema,
            vec![
                vec![Value::Utf8("a".to_string())],
                vec![Value::Utf8("b".to_string())],
                vec![Value::Utf8("a".to_string())],
                vec![Value::Utf8("c".to_string())],
                vec![Value::Null],
            ],
        );
        let top = top_k_by_frequency(&ds, "label", 2).unwrap();
        assert_eq!(top.len(), 2);
        assert_eq!(top[0], (Value::Utf8("a".to_string()), 2));
        // "b" and "c" tie at 1; the tie-break key orders them ascending.
        assert_eq!(top[1], (Value::Utf8("b".to_string()), 1));
    }

    #[test]
    fn top_k_zero_and_unknown_column() {
        let schema = Schema::new(vec![Field::new("label", DataType::Utf8)]);
        let ds = DataSet::new(schema, vec![vec![Value::Utf8("a".to_string())]]);
        assert_eq!(top_k_by_frequency(&ds, "label", 0), Some(vec![]));
        assert!(top_k_by_frequency(&ds, "missing", 3).is_none());
    }
}
262}