rust_data_processing/processing/
multi.rs1use std::cmp::Ordering;
7use std::collections::HashMap;
8
9use crate::types::{DataSet, DataType, Value};
10
11use super::reduce::{VarianceKind, Welford};
12
13#[derive(Debug, Clone, PartialEq)]
15pub struct FeatureMeanStd {
16 pub mean: Value,
17 pub std_dev: Value,
18}
19
20pub fn feature_wise_mean_std(
26 dataset: &DataSet,
27 columns: &[&str],
28 std_kind: VarianceKind,
29) -> Option<Vec<(String, FeatureMeanStd)>> {
30 let mut meta: Vec<(String, usize, DataType)> = Vec::with_capacity(columns.len());
31 for &name in columns {
32 let idx = dataset.schema.index_of(name)?;
33 let dt = dataset.schema.fields.get(idx)?.data_type.clone();
34 if !matches!(dt, DataType::Int64 | DataType::Float64) {
35 return None;
36 }
37 meta.push((name.to_string(), idx, dt));
38 }
39
40 let mut w: Vec<Welford> = (0..meta.len()).map(|_| Welford::default()).collect();
41 for row in &dataset.rows {
42 for (i, (_, idx, dt)) in meta.iter().enumerate() {
43 let x = match (row.get(*idx), dt) {
44 (Some(Value::Int64(v)), DataType::Int64) => Some(*v as f64),
45 (Some(Value::Float64(v)), DataType::Float64) => Some(*v),
46 _ => None,
47 };
48 if let Some(x) = x {
49 w[i].observe(x);
50 }
51 }
52 }
53
54 let mut out = Vec::with_capacity(meta.len());
55 for ((name, _, _), wf) in meta.into_iter().zip(w) {
56 let mean = wf.mean().map(Value::Float64).unwrap_or(Value::Null);
57 let std_dev = wf
58 .variance(std_kind)
59 .map(|v| Value::Float64(v.sqrt()))
60 .unwrap_or(Value::Null);
61 let (mean, std_dev) = if wf.observation_count() == 0 {
62 (Value::Null, Value::Null)
63 } else {
64 (mean, std_dev)
65 };
66 out.push((name, FeatureMeanStd { mean, std_dev }));
67 }
68 Some(out)
69}
70
71fn cmp_non_null_values(a: &Value, b: &Value) -> Option<Ordering> {
72 match (a, b) {
73 (Value::Int64(x), Value::Int64(y)) => Some(x.cmp(y)),
74 (Value::Float64(x), Value::Float64(y)) => Some(x.total_cmp(y)),
75 (Value::Utf8(x), Value::Utf8(y)) => Some(x.cmp(y)),
76 (Value::Bool(x), Value::Bool(y)) => Some(x.cmp(y)),
77 _ => None,
78 }
79}
80
81pub fn arg_max_row(dataset: &DataSet, column: &str) -> Option<Option<(usize, Value)>> {
85 let idx = dataset.schema.index_of(column)?;
86 let mut best: Option<(usize, Value)> = None;
87 for (r, row) in dataset.rows.iter().enumerate() {
88 let Some(cell) = row.get(idx) else {
89 continue;
90 };
91 if matches!(cell, Value::Null) {
92 continue;
93 }
94 match &best {
95 None => best = Some((r, cell.clone())),
96 Some((_, bv)) => {
97 if cmp_non_null_values(cell, bv) == Some(Ordering::Greater) {
98 best = Some((r, cell.clone()));
99 }
100 }
101 }
102 }
103 Some(best)
104}
105
106pub fn arg_min_row(dataset: &DataSet, column: &str) -> Option<Option<(usize, Value)>> {
108 let idx = dataset.schema.index_of(column)?;
109 let mut best: Option<(usize, Value)> = None;
110 for (r, row) in dataset.rows.iter().enumerate() {
111 let Some(cell) = row.get(idx) else {
112 continue;
113 };
114 if matches!(cell, Value::Null) {
115 continue;
116 }
117 match &best {
118 None => best = Some((r, cell.clone())),
119 Some((_, bv)) => {
120 if cmp_non_null_values(cell, bv) == Some(Ordering::Less) {
121 best = Some((r, cell.clone()));
122 }
123 }
124 }
125 }
126 Some(best)
127}
128
129fn freq_bucket_key(v: &Value) -> Option<String> {
130 match v {
131 Value::Null => None,
132 Value::Int64(x) => Some(format!("i:{x}")),
133 Value::Float64(x) => Some(format!("f:{}", x.to_bits())),
134 Value::Bool(b) => Some(format!("b:{b}")),
135 Value::Utf8(s) => Some(format!("s:{s}")),
136 }
137}
138
139fn value_sort_key(v: &Value) -> String {
140 match v {
141 Value::Null => String::new(),
142 Value::Int64(x) => format!("i:{x:020}"),
143 Value::Float64(x) => format!("f:{:020}", x.to_bits()),
144 Value::Bool(b) => format!("b:{b}"),
145 Value::Utf8(s) => format!("s:{s}"),
146 }
147}
148
149pub fn top_k_by_frequency(dataset: &DataSet, column: &str, k: usize) -> Option<Vec<(Value, i64)>> {
154 let idx = dataset.schema.index_of(column)?;
155 let mut buckets: HashMap<String, (Value, i64)> = HashMap::new();
156 for row in &dataset.rows {
157 let Some(cell) = row.get(idx) else {
158 continue;
159 };
160 let Some(key) = freq_bucket_key(cell) else {
161 continue;
162 };
163 buckets
164 .entry(key)
165 .and_modify(|(_, c)| *c += 1)
166 .or_insert_with(|| (cell.clone(), 1));
167 }
168 let mut v: Vec<(Value, i64)> = buckets.into_values().collect();
169 v.sort_by(|a, b| {
170 b.1.cmp(&a.1)
171 .then_with(|| value_sort_key(&a.0).cmp(&value_sort_key(&b.0)))
172 });
173 v.truncate(k);
174 Some(v)
175}
176
#[cfg(test)]
mod tests {
    use super::{arg_max_row, arg_min_row, feature_wise_mean_std, top_k_by_frequency};
    use crate::processing::VarianceKind;
    use crate::types::{DataSet, DataType, Field, Schema, Value};

    /// Unwraps a `Value::Float64`, panicking with the debug form of anything else.
    fn expect_f64(value: &Value) -> f64 {
        match value {
            Value::Float64(x) => *x,
            other => panic!("{other:?}"),
        }
    }

    /// Single `Utf8` column named `label` holding a, b, a, c and one null.
    fn label_dataset() -> DataSet {
        let schema = Schema::new(vec![Field::new("label", DataType::Utf8)]);
        DataSet::new(
            schema,
            vec![
                vec![Value::Utf8("a".to_string())],
                vec![Value::Utf8("b".to_string())],
                vec![Value::Utf8("a".to_string())],
                vec![Value::Utf8("c".to_string())],
                vec![Value::Null],
            ],
        )
    }

    #[test]
    fn feature_wise_mean_std_two_columns_one_pass() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int64),
            Field::new("b", DataType::Float64),
        ]);
        let ds = DataSet::new(
            schema,
            vec![
                vec![Value::Int64(10), Value::Float64(1.0)],
                vec![Value::Int64(20), Value::Null],
                vec![Value::Null, Value::Float64(3.0)],
            ],
        );
        let got = feature_wise_mean_std(&ds, &["a", "b"], VarianceKind::Sample).unwrap();
        assert_eq!(got.len(), 2);

        // Column a observes 10 and 20: mean 15, sample variance 50.
        assert_eq!(got[0].0, "a");
        assert_eq!(got[0].1.mean, Value::Float64(15.0));
        let std_a = expect_f64(&got[0].1.std_dev);
        assert!((std_a - 50.0_f64.sqrt()).abs() < 1e-9);

        // Column b observes 1.0 and 3.0: mean 2, sample variance 2.
        assert_eq!(got[1].0, "b");
        assert_eq!(got[1].1.mean, Value::Float64(2.0));
        let std_b = expect_f64(&got[1].1.std_dev);
        assert!((std_b - 2.0_f64.sqrt()).abs() < 1e-9);
    }

    #[test]
    fn feature_wise_returns_none_for_unknown_or_non_numeric_column() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int64),
            Field::new("t", DataType::Utf8),
        ]);
        let ds = DataSet::new(
            schema,
            vec![vec![Value::Int64(1), Value::Utf8("x".to_string())]],
        );
        assert!(feature_wise_mean_std(&ds, &["missing"], VarianceKind::Sample).is_none());
        assert!(feature_wise_mean_std(&ds, &["a", "t"], VarianceKind::Sample).is_none());
    }

    #[test]
    fn arg_max_min_first_on_ties() {
        let schema = Schema::new(vec![Field::new("x", DataType::Int64)]);
        let ds = DataSet::new(
            schema,
            vec![
                vec![Value::Int64(1)],
                vec![Value::Int64(3)],
                vec![Value::Int64(3)],
                vec![Value::Null],
            ],
        );
        assert_eq!(arg_max_row(&ds, "x"), Some(Some((1, Value::Int64(3)))));
        assert_eq!(arg_min_row(&ds, "x"), Some(Some((0, Value::Int64(1)))));
    }

    #[test]
    fn arg_max_min_unknown_column_and_all_null_column() {
        let schema = Schema::new(vec![Field::new("x", DataType::Int64)]);
        let ds = DataSet::new(schema, vec![vec![Value::Null], vec![Value::Null]]);
        // Unknown column: outer None.
        assert_eq!(arg_max_row(&ds, "missing"), None);
        assert_eq!(arg_min_row(&ds, "missing"), None);
        // Known column without any non-null cell: inner None.
        assert_eq!(arg_max_row(&ds, "x"), Some(None));
        assert_eq!(arg_min_row(&ds, "x"), Some(None));
    }

    #[test]
    fn top_k_frequency_ordering() {
        let ds = label_dataset();
        let top = top_k_by_frequency(&ds, "label", 2).unwrap();
        assert_eq!(top.len(), 2);
        assert_eq!(top[0], (Value::Utf8("a".to_string()), 2));
        // "b" and "c" both occur once; the tie-break must pick "b" deterministically.
        assert_eq!(top[1], (Value::Utf8("b".to_string()), 1));
    }

    #[test]
    fn top_k_zero_k_and_unknown_column() {
        let ds = label_dataset();
        assert_eq!(top_k_by_frequency(&ds, "label", 0), Some(vec![]));
        assert!(top_k_by_frequency(&ds, "missing", 2).is_none());
    }
}