1use std::collections::HashSet;
4
5use crate::types::{DataSet, DataType, Value};
6
/// Selects the divisor used when turning accumulated squared deviations
/// into a variance.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VarianceKind {
    /// Divide by the number of observations `n`.
    Population,
    /// Divide by `n - 1`; undefined (reported as null) for fewer than two
    /// observations.
    Sample,
}
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq)]
18pub enum ReduceOp {
19 Count,
21 Sum,
23 Min,
25 Max,
27 Mean,
29 Variance(VarianceKind),
31 StdDev(VarianceKind),
33 SumSquares,
35 L2Norm,
37 CountDistinctNonNull,
39}
40
41pub fn reduce(dataset: &DataSet, column: &str, op: ReduceOp) -> Option<Value> {
50 let idx = dataset.schema.index_of(column)?;
51
52 match op {
53 ReduceOp::Count => Some(Value::Int64(dataset.row_count() as i64)),
54 ReduceOp::CountDistinctNonNull => {
55 let field = dataset.schema.fields.get(idx)?;
56 reduce_count_distinct_non_null(dataset, idx, &field.data_type)
57 }
58 ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => match dataset.schema.fields.get(idx) {
59 Some(field) => reduce_numeric_typed(dataset, idx, field.data_type.clone(), op),
60 None => None,
61 },
62 ReduceOp::Mean
63 | ReduceOp::Variance(_)
64 | ReduceOp::StdDev(_)
65 | ReduceOp::SumSquares
66 | ReduceOp::L2Norm => match dataset.schema.fields.get(idx) {
67 Some(field) => reduce_numeric_float_stats(dataset, idx, field.data_type.clone(), op),
68 None => None,
69 },
70 }
71}
72
73#[derive(Default)]
74pub(crate) struct Welford {
75 n: u64,
76 mean: f64,
77 m2: f64,
78}
79
80impl Welford {
81 pub(crate) fn observe(&mut self, x: f64) {
82 self.n += 1;
83 let delta = x - self.mean;
84 self.mean += delta / self.n as f64;
85 let delta2 = x - self.mean;
86 self.m2 += delta * delta2;
87 }
88
89 pub(crate) fn mean(&self) -> Option<f64> {
90 (self.n > 0).then_some(self.mean)
91 }
92
93 pub(crate) fn variance(&self, kind: VarianceKind) -> Option<f64> {
94 if self.n == 0 {
95 return None;
96 }
97 match kind {
98 VarianceKind::Population => Some(self.m2 / self.n as f64),
99 VarianceKind::Sample => {
100 if self.n < 2 {
101 None
102 } else {
103 Some(self.m2 / (self.n - 1) as f64)
104 }
105 }
106 }
107 }
108
109 pub(crate) fn observation_count(&self) -> u64 {
110 self.n
111 }
112}
113
114fn reduce_numeric_float_stats(
115 dataset: &DataSet,
116 idx: usize,
117 data_type: DataType,
118 op: ReduceOp,
119) -> Option<Value> {
120 match data_type {
121 dt @ (DataType::Int64 | DataType::Float64) => {
122 let is_int = matches!(dt, DataType::Int64);
123 let mut w = Welford::default();
124 let mut sum_squares = 0.0_f64;
125 let mut any = false;
126
127 for row in &dataset.rows {
128 let x = match row.get(idx) {
129 Some(Value::Null) | None => None,
130 Some(Value::Int64(v)) if is_int => Some(*v as f64),
131 Some(Value::Float64(v)) if !is_int => Some(*v),
132 Some(_) => None,
133 };
134 if let Some(x) = x {
135 any = true;
136 w.observe(x);
137 sum_squares += x * x;
138 }
139 }
140
141 if !any {
142 return Some(Value::Null);
143 }
144
145 let out = match op {
146 ReduceOp::Mean => Value::Float64(w.mean().expect("n > 0")),
147 ReduceOp::Variance(kind) => match w.variance(kind) {
148 Some(v) => Value::Float64(v),
149 None => Value::Null,
150 },
151 ReduceOp::StdDev(kind) => match w.variance(kind) {
152 Some(v) => Value::Float64(v.sqrt()),
153 None => Value::Null,
154 },
155 ReduceOp::SumSquares => Value::Float64(sum_squares),
156 ReduceOp::L2Norm => Value::Float64(sum_squares.sqrt()),
157 _ => unreachable!("caller only dispatches float stats ops"),
158 };
159 Some(out)
160 }
161 _ => Some(Value::Null),
162 }
163}
164
165fn reduce_count_distinct_non_null(
166 dataset: &DataSet,
167 idx: usize,
168 data_type: &DataType,
169) -> Option<Value> {
170 let n = match data_type {
171 DataType::Int64 => {
172 let mut set = HashSet::new();
173 for row in &dataset.rows {
174 if let Some(Value::Int64(v)) = row.get(idx) {
175 set.insert(*v);
176 }
177 }
178 set.len() as i64
179 }
180 DataType::Float64 => {
181 let mut set = HashSet::new();
182 for row in &dataset.rows {
183 if let Some(Value::Float64(v)) = row.get(idx) {
184 set.insert(v.to_bits());
185 }
186 }
187 set.len() as i64
188 }
189 DataType::Bool => {
190 let mut set = HashSet::new();
191 for row in &dataset.rows {
192 if let Some(Value::Bool(v)) = row.get(idx) {
193 set.insert(*v);
194 }
195 }
196 set.len() as i64
197 }
198 DataType::Utf8 => {
199 let mut set = HashSet::new();
200 for row in &dataset.rows {
201 if let Some(Value::Utf8(s)) = row.get(idx) {
202 set.insert(s.clone());
203 }
204 }
205 set.len() as i64
206 }
207 };
208 Some(Value::Int64(n))
209}
210
211fn reduce_numeric_typed(
212 dataset: &DataSet,
213 idx: usize,
214 data_type: DataType,
215 op: ReduceOp,
216) -> Option<Value> {
217 match data_type {
218 DataType::Int64 => {
219 let mut acc: Option<i64> = None;
220 for row in &dataset.rows {
221 match row.get(idx) {
222 Some(Value::Null) | None => {}
223 Some(Value::Int64(v)) => {
224 acc = Some(match (op, acc) {
225 (ReduceOp::Sum, Some(a)) => a + v,
226 (ReduceOp::Sum, None) => *v,
227 (ReduceOp::Min, Some(a)) => a.min(*v),
228 (ReduceOp::Min, None) => *v,
229 (ReduceOp::Max, Some(a)) => a.max(*v),
230 (ReduceOp::Max, None) => *v,
231 _ => unreachable!("non-numeric op handled earlier"),
232 });
233 }
234 Some(_) => {}
235 }
236 }
237 Some(acc.map(Value::Int64).unwrap_or(Value::Null))
238 }
239 DataType::Float64 => {
240 let mut acc: Option<f64> = None;
241 for row in &dataset.rows {
242 match row.get(idx) {
243 Some(Value::Null) | None => {}
244 Some(Value::Float64(v)) => {
245 acc = Some(match (op, acc) {
246 (ReduceOp::Sum, Some(a)) => a + v,
247 (ReduceOp::Sum, None) => *v,
248 (ReduceOp::Min, Some(a)) => a.min(*v),
249 (ReduceOp::Min, None) => *v,
250 (ReduceOp::Max, Some(a)) => a.max(*v),
251 (ReduceOp::Max, None) => *v,
252 _ => unreachable!("non-numeric op handled earlier"),
253 });
254 }
255 Some(_) => {}
256 }
257 }
258 Some(acc.map(Value::Float64).unwrap_or(Value::Null))
259 }
260 _ => Some(Value::Null),
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::{ReduceOp, VarianceKind, reduce};
    use crate::types::{DataSet, DataType, Field, Schema, Value};

    /// Builds a one-column dataset, placing each cell in its own row.
    fn single_column(name: &'static str, data_type: DataType, cells: Vec<Value>) -> DataSet {
        let schema = Schema::new(vec![Field::new(name, data_type)]);
        let rows: Vec<Vec<Value>> = cells.into_iter().map(|cell| vec![cell]).collect();
        DataSet::new(schema, rows)
    }

    /// Builds a one-column `Float64` dataset; `None` entries become nulls.
    fn float_column(name: &'static str, cells: &[Option<f64>]) -> DataSet {
        let cells = cells
            .iter()
            .map(|cell| cell.map_or(Value::Null, Value::Float64))
            .collect();
        single_column(name, DataType::Float64, cells)
    }

    /// Extracts the payload of a `Value::Float64`, panicking otherwise.
    fn as_f64(value: Value) -> f64 {
        match value {
            Value::Float64(v) => v,
            other => panic!("expected Float64, got {other:?}"),
        }
    }

    /// Two-column fixture: a dense `id` column and a `score` column with a
    /// null in the middle row.
    fn numeric_dataset_with_nulls() -> DataSet {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int64),
            Field::new("score", DataType::Float64),
        ]);

        let ids = [1_i64, 2, 3];
        let scores = [Some(10.0), None, Some(5.5)];
        let rows: Vec<Vec<Value>> = ids
            .iter()
            .zip(scores.iter())
            .map(|(id, score)| {
                vec![
                    Value::Int64(*id),
                    score.map_or(Value::Null, Value::Float64),
                ]
            })
            .collect();

        DataSet::new(schema, rows)
    }

    #[test]
    fn reduce_count_counts_rows() {
        let ds = numeric_dataset_with_nulls();
        for column in ["score", "id"] {
            assert_eq!(reduce(&ds, column, ReduceOp::Count), Some(Value::Int64(3)));
        }
    }

    #[test]
    fn reduce_sum_ignores_nulls_and_preserves_type() {
        let ds = numeric_dataset_with_nulls();
        let float_sum = reduce(&ds, "score", ReduceOp::Sum);
        let int_sum = reduce(&ds, "id", ReduceOp::Sum);
        assert_eq!(float_sum, Some(Value::Float64(15.5)));
        assert_eq!(int_sum, Some(Value::Int64(6)));
    }

    #[test]
    fn reduce_min_max_ignore_nulls() {
        let ds = numeric_dataset_with_nulls();
        let cases = [
            ("score", ReduceOp::Min, Value::Float64(5.5)),
            ("score", ReduceOp::Max, Value::Float64(10.0)),
            ("id", ReduceOp::Min, Value::Int64(1)),
            ("id", ReduceOp::Max, Value::Int64(3)),
        ];
        for (column, op, expected) in cases {
            assert_eq!(reduce(&ds, column, op), Some(expected));
        }
    }

    #[test]
    fn reduce_returns_none_for_missing_column() {
        let ds = numeric_dataset_with_nulls();
        for op in [ReduceOp::Count, ReduceOp::Sum] {
            assert_eq!(reduce(&ds, "missing", op), None);
        }
    }

    #[test]
    fn reduce_numeric_returns_null_if_all_values_null() {
        let ds = float_column("score", &[None, None]);
        let ops = [
            ReduceOp::Sum,
            ReduceOp::Min,
            ReduceOp::Max,
            ReduceOp::Mean,
            ReduceOp::Variance(VarianceKind::Population),
            ReduceOp::StdDev(VarianceKind::Sample),
        ];
        for op in ops {
            assert_eq!(reduce(&ds, "score", op), Some(Value::Null));
        }
    }

    #[test]
    fn reduce_mean_float_and_int() {
        let ds = numeric_dataset_with_nulls();
        let score_mean = reduce(&ds, "score", ReduceOp::Mean);
        let id_mean = reduce(&ds, "id", ReduceOp::Mean);
        assert_eq!(score_mean, Some(Value::Float64(7.75)));
        assert_eq!(id_mean, Some(Value::Float64(2.0)));
    }

    #[test]
    fn reduce_variance_std_known_values() {
        let ds = float_column("x", &[Some(1.0), Some(2.0), Some(3.0)]);
        let pop: f64 = 2.0 / 3.0;

        let population = reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Population));
        assert_eq!(population, Some(Value::Float64(pop)));

        let sample = reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Sample));
        assert_eq!(sample, Some(Value::Float64(1.0)));

        let std_pop = as_f64(reduce(&ds, "x", ReduceOp::StdDev(VarianceKind::Population)).unwrap());
        assert!((std_pop - pop.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn reduce_sample_variance_single_value_is_null() {
        let ds = float_column("x", &[Some(42.0)]);
        let sample = reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Sample));
        assert_eq!(sample, Some(Value::Null));
    }

    #[test]
    fn reduce_population_variance_single_value_is_zero() {
        let ds = float_column("x", &[Some(42.0)]);
        let population = reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Population));
        assert_eq!(population, Some(Value::Float64(0.0)));

        let std0 = as_f64(reduce(&ds, "x", ReduceOp::StdDev(VarianceKind::Population)).unwrap());
        assert_eq!(std0, 0.0);
    }

    #[test]
    fn reduce_int64_mean_sum_squares_and_distinct() {
        let ds = single_column(
            "k",
            DataType::Int64,
            vec![Value::Int64(2), Value::Int64(3), Value::Null],
        );
        let cases = [
            (ReduceOp::Mean, Value::Float64(2.5)),
            (ReduceOp::SumSquares, Value::Float64(13.0)),
            (ReduceOp::L2Norm, Value::Float64(13.0_f64.sqrt())),
            (ReduceOp::CountDistinctNonNull, Value::Int64(2)),
        ];
        for (op, expected) in cases {
            assert_eq!(reduce(&ds, "k", op), Some(expected));
        }
    }

    #[test]
    fn reduce_sum_squares_and_l2() {
        let ds = float_column("x", &[Some(3.0), Some(4.0), None]);
        let sum_squares = reduce(&ds, "x", ReduceOp::SumSquares);
        let l2 = reduce(&ds, "x", ReduceOp::L2Norm);
        assert_eq!(sum_squares, Some(Value::Float64(25.0)));
        assert_eq!(l2, Some(Value::Float64(5.0)));
    }

    #[test]
    fn reduce_count_distinct_non_null() {
        let schema = Schema::new(vec![
            Field::new("f", DataType::Float64),
            Field::new("s", DataType::Utf8),
        ]);
        let rows = vec![
            vec![Value::Float64(1.0), Value::Utf8("a".to_string())],
            vec![Value::Float64(1.0), Value::Utf8("b".to_string())],
            vec![Value::Null, Value::Null],
        ];
        let ds = DataSet::new(schema, rows);

        let distinct = |column: &str| reduce(&ds, column, ReduceOp::CountDistinctNonNull);
        assert_eq!(distinct("f"), Some(Value::Int64(1)));
        assert_eq!(distinct("s"), Some(Value::Int64(2)));
    }

    #[test]
    fn reduce_new_ops_return_none_for_missing_column() {
        let ds = numeric_dataset_with_nulls();
        let ops = [
            ReduceOp::Mean,
            ReduceOp::Variance(VarianceKind::Sample),
            ReduceOp::CountDistinctNonNull,
        ];
        for op in ops {
            assert_eq!(reduce(&ds, "nope", op), None);
        }
    }

    #[test]
    fn reduce_sum_squares_and_l2_all_null() {
        let ds = float_column("x", &[None]);
        for op in [ReduceOp::SumSquares, ReduceOp::L2Norm] {
            assert_eq!(reduce(&ds, "x", op), Some(Value::Null));
        }
    }

    #[test]
    fn reduce_count_distinct_bool_and_empty_rows() {
        let empty = single_column("b", DataType::Bool, vec![]);
        assert_eq!(
            reduce(&empty, "b", ReduceOp::CountDistinctNonNull),
            Some(Value::Int64(0))
        );

        let populated = single_column(
            "b",
            DataType::Bool,
            vec![
                Value::Bool(true),
                Value::Bool(false),
                Value::Bool(true),
                Value::Null,
            ],
        );
        assert_eq!(
            reduce(&populated, "b", ReduceOp::CountDistinctNonNull),
            Some(Value::Int64(2))
        );
    }

    #[test]
    fn reduce_mean_variance_null_for_non_numeric_column() {
        let ds = single_column(
            "label",
            DataType::Utf8,
            vec![Value::Utf8("a".to_string()), Value::Utf8("b".to_string())],
        );
        let ops = [
            ReduceOp::Mean,
            ReduceOp::Variance(VarianceKind::Population),
            ReduceOp::SumSquares,
        ];
        for op in ops {
            assert_eq!(reduce(&ds, "label", op), Some(Value::Null));
        }
    }

    #[test]
    fn reduce_std_dev_sample_matches_sqrt_of_sample_variance() {
        let ds = float_column("x", &[Some(0.0), Some(4.0), Some(8.0)]);
        let var_s = as_f64(reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Sample)).unwrap());
        let std_s = as_f64(reduce(&ds, "x", ReduceOp::StdDev(VarianceKind::Sample)).unwrap());
        assert!((std_s - var_s.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn reduce_l2_squared_matches_sum_squares_for_non_nulls() {
        let ds = float_column("x", &[Some(2.0), Some(3.0)]);
        let ss = as_f64(reduce(&ds, "x", ReduceOp::SumSquares).unwrap());
        let l2 = as_f64(reduce(&ds, "x", ReduceOp::L2Norm).unwrap());
        assert!((l2 * l2 - ss).abs() < 1e-12);
    }
}