rust_data_processing/processing/filter.rs

1//! Row filtering for [`crate::types::DataSet`].
2
3use crate::types::{DataSet, Value};
4
5/// Returns a new [`DataSet`] containing only rows for which `predicate` returns `true`.
6///
7/// This is a convenience wrapper around [`DataSet::filter_rows`].
8pub fn filter<F>(dataset: &DataSet, predicate: F) -> DataSet
9where
10    F: FnMut(&[Value]) -> bool,
11{
12    dataset.filter_rows(predicate)
13}
14
#[cfg(test)]
mod tests {
    use super::filter;
    use crate::types::{DataSet, DataType, Field, Schema, Value};

    /// Builds one row of the sample dataset in schema order: `id`, `active`, `name`.
    fn row(id: i64, active: bool, name: &str) -> Vec<Value> {
        vec![
            Value::Int64(id),
            Value::Bool(active),
            Value::Utf8(name.to_string()),
        ]
    }

    /// Schema shared by every test: `id: Int64`, `active: Bool`, `name: Utf8`.
    fn sample_schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Int64),
            Field::new("active", DataType::Bool),
            Field::new("name", DataType::Utf8),
        ])
    }

    /// Three-row fixture: ids 1..=3, rows 1 and 3 active, names "a", "b", "c".
    fn sample_dataset() -> DataSet {
        DataSet::new(
            sample_schema(),
            vec![row(1, true, "a"), row(2, false, "b"), row(3, true, "c")],
        )
    }

    #[test]
    fn schema_index_of_works() {
        let ds = sample_dataset();
        assert_eq!(ds.schema.index_of("id"), Some(0));
        assert_eq!(ds.schema.index_of("active"), Some(1));
        assert_eq!(ds.schema.index_of("name"), Some(2));
        assert_eq!(ds.schema.index_of("missing"), None);
    }

    #[test]
    fn filter_rows_by_numeric_predicate() {
        let ds = sample_dataset();
        let id_idx = ds.schema.index_of("id").expect("fixture has an `id` column");

        let out = filter(&ds, |cells| {
            matches!(cells.get(id_idx), Some(Value::Int64(v)) if *v > 1)
        });

        assert_eq!(out.schema, ds.schema);
        assert_eq!(out.row_count(), 2);
        assert_eq!(out.rows, vec![row(2, false, "b"), row(3, true, "c")]);
        // The input is only borrowed, so it must still hold all of its rows.
        assert_eq!(ds.row_count(), 3);
    }

    #[test]
    fn filter_rows_by_bool_predicate() {
        let ds = sample_dataset();
        let active_idx = ds
            .schema
            .index_of("active")
            .expect("fixture has an `active` column");

        let out = filter(&ds, |cells| {
            matches!(cells.get(active_idx), Some(Value::Bool(true)))
        });

        assert_eq!(out.schema, ds.schema);
        assert_eq!(out.row_count(), 2);
        assert_eq!(out.rows, vec![row(1, true, "a"), row(3, true, "c")]);
    }

    #[test]
    fn filter_rows_can_return_empty_dataset() {
        let ds = sample_dataset();
        let out = filter(&ds, |_| false);
        assert_eq!(out.schema, ds.schema);
        assert!(out.rows.is_empty());
    }

    #[test]
    fn filter_keeps_every_row_when_predicate_always_true() {
        let ds = sample_dataset();
        let out = filter(&ds, |_| true);
        assert_eq!(out.schema, ds.schema);
        assert_eq!(out.rows, ds.rows);
    }

    #[test]
    fn filter_on_empty_dataset_returns_empty_dataset() {
        let ds = DataSet::new(sample_schema(), Vec::<Vec<Value>>::new());
        let out = filter(&ds, |_| true);
        assert_eq!(out.schema, sample_schema());
        assert!(out.rows.is_empty());
    }

    #[test]
    fn filter_matches_filter_rows() {
        let ds = sample_dataset();
        let name_idx = ds.schema.index_of("name").expect("fixture has a `name` column");
        // Captures only a `usize` by reference, so the closure is `Copy` and can be
        // handed to both the wrapper and the method.
        let keep_a_or_c = |cells: &[Value]| {
            matches!(
                cells.get(name_idx),
                Some(Value::Utf8(s)) if matches!(s.as_str(), "a" | "c")
            )
        };

        let via_wrapper = filter(&ds, keep_a_or_c);
        let via_method = ds.filter_rows(keep_a_or_c);

        assert_eq!(via_wrapper.schema, via_method.schema);
        assert_eq!(via_wrapper.rows, via_method.rows);
    }
}
117        assert!(out.rows.is_empty());
118    }
119}