rust_data_processing/pipeline/
mod.rs

1//! DataFrame-centric pipeline/transforms backed by a Polars lazy plan.
2//!
3//! This module provides a small, engine-delegated pipeline API that compiles to a Polars
4//! [`polars::prelude::LazyFrame`] and then collects results back into our in-memory [`crate::types::DataSet`].
5//!
6//! Design goals for Phase 1:
7//! - Keep the public API in our own types (no Polars types in signatures)
8//! - Support a minimal set of transformation primitives needed for parity/benchmarks
9//! - Provide deterministic, testable behavior (null handling, missing column errors)
10//!
11//! # Examples
12//!
13//! ```no_run
14//! use rust_data_processing::pipeline::{Agg, DataFrame, JoinKind, Predicate};
15//! use rust_data_processing::types::{DataSet, DataType, Field, Schema, Value};
16//!
17//! # fn main() -> Result<(), rust_data_processing::IngestionError> {
18//! let ds = DataSet::new(
19//!     Schema::new(vec![
20//!         Field::new("id", DataType::Int64),
21//!         Field::new("active", DataType::Bool),
22//!         Field::new("score", DataType::Int64),
23//!         Field::new("grp", DataType::Utf8),
24//!     ]),
25//!     vec![
26//!         vec![Value::Int64(1), Value::Bool(true), Value::Int64(10), Value::Utf8("A".to_string())],
27//!         vec![Value::Int64(2), Value::Bool(true), Value::Null, Value::Utf8("A".to_string())],
28//!     ],
29//! );
30//!
31//! // Rename + cast + fill nulls.
32//! let cleaned = DataFrame::from_dataset(&ds)?
33//!     .rename(&[("score", "score_i")])?
34//!     .cast("score_i", DataType::Float64)?
35//!     .fill_null("score_i", Value::Float64(0.0))?;
36//!
37//! // Filter + group_by.
38//! let _out = cleaned
39//!     .filter(Predicate::Eq {
40//!         column: "active".to_string(),
41//!         value: Value::Bool(true),
42//!     })?
43//!     .group_by(
44//!         &["grp"],
45//!         &[Agg::Sum {
46//!             column: "score_i".to_string(),
47//!             alias: "sum_score".to_string(),
48//!         }],
49//!     )?
50//!     .collect()?;
51//!
52//! // Join two DataFrames.
53//! let left = DataFrame::from_dataset(&ds)?;
54//! let right = DataFrame::from_dataset(&ds)?;
55//! let _joined = left.join(right, &["id"], &["id"], JoinKind::Inner)?;
56//! # Ok(())
57//! # }
58//! ```
59
60use crate::error::{IngestionError, IngestionResult};
61use crate::ingestion::polars_bridge::{
62    dataframe_to_dataset, dataset_to_dataframe, infer_schema_from_dataframe,
63    polars_error_to_ingestion,
64};
65use crate::processing::{FeatureMeanStd, ReduceOp, VarianceKind};
66use crate::types::{DataSet, DataType, Schema, Value};
67
68use polars::chunked_array::cast::CastOptions;
69use polars::prelude::*;
70use serde::{Deserialize, Serialize};
71
72const REDUCE_SCALAR_COL: &str = "__rust_dp_reduce_scalar";
73
74/// A predicate used by [`DataFrame::filter`].
75#[derive(Debug, Clone, PartialEq)]
76pub enum Predicate {
77    /// Keep rows where `column == value`.
78    Eq { column: String, value: Value },
79    /// Keep rows where `column` is not null.
80    NotNull { column: String },
81    /// Keep rows where `column % modulus == equals` (Int64 only).
82    ModEqInt64 {
83        column: String,
84        modulus: i64,
85        equals: i64,
86    },
87}
88
/// Join strategy accepted by [`DataFrame::join`].
///
/// Each variant is translated one-to-one into the engine join type of the same name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinKind {
    /// Inner join.
    Inner,
    /// Left join.
    Left,
    /// Right join.
    Right,
    /// Full join.
    Full,
}
97
98/// Aggregations for [`DataFrame::group_by`].
99#[derive(Debug, Clone, PartialEq)]
100pub enum Agg {
101    /// Count rows in each group (includes nulls).
102    CountRows {
103        alias: String,
104    },
105    /// Count non-null values of a column in each group.
106    CountNotNull {
107        column: String,
108        alias: String,
109    },
110    Sum {
111        column: String,
112        alias: String,
113    },
114    Min {
115        column: String,
116        alias: String,
117    },
118    Max {
119        column: String,
120        alias: String,
121    },
122    /// Mean of numeric values (cast to `Float64` first), nulls ignored.
123    Mean {
124        column: String,
125        alias: String,
126    },
127    Variance {
128        column: String,
129        alias: String,
130        kind: VarianceKind,
131    },
132    StdDev {
133        column: String,
134        alias: String,
135        kind: VarianceKind,
136    },
137    SumSquares {
138        column: String,
139        alias: String,
140    },
141    L2Norm {
142        column: String,
143        alias: String,
144    },
145    /// Distinct count of non-null values in each group.
146    CountDistinctNonNull {
147        column: String,
148        alias: String,
149    },
150}
151
152/// Casting behavior for [`DataFrame::cast_with_mode`].
153#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
154#[serde(rename_all = "snake_case")]
155pub enum CastMode {
156    /// Casting errors fail the pipeline at `collect()` time.
157    Strict,
158    /// Casting errors yield nulls instead of failing.
159    Lossy,
160}
161
162impl Default for CastMode {
163    fn default() -> Self {
164        Self::Strict
165    }
166}
167
168/// A DataFrame-centric pipeline compiled into a lazy plan.
169///
170/// The public API stays in this crate's own types. The current engine implementation is Polars,
171/// but callers do not need to depend on Polars types.
172#[derive(Clone)]
173pub struct DataFrame {
174    lf: LazyFrame,
175}
176
177impl DataFrame {
178    /// Build a pipeline starting from an in-memory [`DataSet`].
179    ///
180    /// Note: this converts the dataset into a Polars `DataFrame` first. The transformations after
181    /// that are planned lazily.
182    pub fn from_dataset(ds: &DataSet) -> IngestionResult<Self> {
183        let df = dataset_to_dataframe(ds)?;
184        Ok(Self { lf: df.lazy() })
185    }
186
187    /// Add a filter predicate.
188    pub fn filter(mut self, predicate: Predicate) -> IngestionResult<Self> {
189        let expr = match predicate {
190            Predicate::Eq { column, value } => match value {
191                Value::Null => col(&column).is_null(),
192                Value::Int64(x) => col(&column).eq(lit(x)),
193                Value::Float64(x) => col(&column).eq(lit(x)),
194                Value::Bool(x) => col(&column).eq(lit(x)),
195                Value::Utf8(s) => col(&column).eq(lit(s)),
196            },
197            Predicate::NotNull { column } => col(&column).is_not_null(),
198            Predicate::ModEqInt64 {
199                column,
200                modulus,
201                equals,
202            } => (col(&column) % lit(modulus)).eq(lit(equals)),
203        };
204        // Planning ops are infallible; errors surface at `collect` time.
205        self.lf = self.lf.filter(expr);
206        Ok(self)
207    }
208
209    /// Multiply a Float64 column by a constant factor (nulls remain null).
210    pub fn multiply_f64(mut self, column: &str, factor: f64) -> IngestionResult<Self> {
211        // Planning ops are infallible; errors surface at `collect` time.
212        self.lf = self
213            .lf
214            .with_columns([(col(column) * lit(factor)).alias(column)]);
215        Ok(self)
216    }
217
218    /// Add a constant Float64 value to a column (nulls remain null).
219    pub fn add_f64(mut self, column: &str, delta: f64) -> IngestionResult<Self> {
220        self.lf = self
221            .lf
222            .with_columns([(col(column) + lit(delta)).alias(column)]);
223        Ok(self)
224    }
225
226    /// Add a derived Float64 column: `name = source * factor` (nulls remain null).
227    pub fn with_mul_f64(mut self, name: &str, source: &str, factor: f64) -> IngestionResult<Self> {
228        self.lf = self
229            .lf
230            .with_columns([(col(source) * lit(factor)).alias(name)]);
231        Ok(self)
232    }
233
234    /// Add a derived Float64 column: `name = source + delta` (nulls remain null).
235    pub fn with_add_f64(mut self, name: &str, source: &str, delta: f64) -> IngestionResult<Self> {
236        self.lf = self
237            .lf
238            .with_columns([(col(source) + lit(delta)).alias(name)]);
239        Ok(self)
240    }
241
242    /// Select a subset of columns (in the provided order).
243    pub fn select(mut self, columns: &[&str]) -> IngestionResult<Self> {
244        let exprs: Vec<Expr> = columns.iter().map(|c| col(*c)).collect();
245        // Planning ops are infallible; errors surface at `collect` time.
246        self.lf = self.lf.select(exprs);
247        Ok(self)
248    }
249
250    /// Rename columns.
251    ///
252    /// This uses Polars' `rename(..., strict=true)` behavior: all `from` columns must exist.
253    pub fn rename(mut self, pairs: &[(&str, &str)]) -> IngestionResult<Self> {
254        let (existing, new): (Vec<&str>, Vec<&str>) = pairs.iter().copied().unzip();
255        self.lf = self.lf.rename(existing, new, true);
256        Ok(self)
257    }
258
259    /// Cast a column to a target type.
260    ///
261    /// Note: cast errors (e.g. invalid parses) surface at `collect()` time.
262    pub fn cast(self, column: &str, to: DataType) -> IngestionResult<Self> {
263        self.cast_with_mode(column, to, CastMode::Strict)
264    }
265
266    /// Cast a column with an explicit mode (strict vs lossy).
267    pub fn cast_with_mode(
268        mut self,
269        column: &str,
270        to: DataType,
271        mode: CastMode,
272    ) -> IngestionResult<Self> {
273        let dt = to_polars_dtype(&to);
274        let expr = match mode {
275            CastMode::Strict => col(column).strict_cast(dt),
276            CastMode::Lossy => col(column).cast_with_options(dt, CastOptions::NonStrict),
277        }
278        .alias(column);
279        self.lf = self.lf.with_columns([expr]);
280        Ok(self)
281    }
282
283    /// Drop columns by name.
284    pub fn drop(mut self, columns: &[&str]) -> IngestionResult<Self> {
285        let names: Vec<PlSmallStr> = columns.iter().map(|c| (*c).into()).collect();
286        let sel = Selector::ByName {
287            names: names.into(),
288            strict: true,
289        };
290        self.lf = self.lf.drop(sel);
291        Ok(self)
292    }
293
294    /// Fill nulls in a column with a literal.
295    pub fn fill_null(mut self, column: &str, value: Value) -> IngestionResult<Self> {
296        let lit_expr = value_to_lit_expr(value)?;
297        self.lf = self
298            .lf
299            .with_columns([col(column).fill_null(lit_expr).alias(column)]);
300        Ok(self)
301    }
302
303    /// Add a derived column with a literal value.
304    pub fn with_literal(mut self, name: &str, value: Value) -> IngestionResult<Self> {
305        let lit_expr = value_to_lit_expr(value)?;
306        self.lf = self.lf.with_columns([lit_expr.alias(name)]);
307        Ok(self)
308    }
309
310    /// Group rows by `keys` and compute aggregations.
311    pub fn group_by(mut self, keys: &[&str], aggs: &[Agg]) -> IngestionResult<Self> {
312        if keys.is_empty() {
313            return Err(IngestionError::SchemaMismatch {
314                message: "group_by requires at least one key column".to_string(),
315            });
316        }
317        if aggs.is_empty() {
318            return Err(IngestionError::SchemaMismatch {
319                message: "group_by requires at least one aggregation".to_string(),
320            });
321        }
322
323        let key_exprs: Vec<Expr> = keys.iter().map(|k| col(*k)).collect();
324        let agg_exprs: Vec<Expr> = aggs.iter().map(agg_to_expr).collect();
325        self.lf = self.lf.group_by(key_exprs).agg(agg_exprs);
326        Ok(self)
327    }
328
329    /// Join this pipeline with another [`DataFrame`] on key columns.
330    ///
331    /// Note: join planning is infallible; missing-column errors surface at `collect()` time.
332    pub fn join(
333        mut self,
334        other: DataFrame,
335        left_on: &[&str],
336        right_on: &[&str],
337        how: JoinKind,
338    ) -> IngestionResult<Self> {
339        if left_on.is_empty() || right_on.is_empty() {
340            return Err(IngestionError::SchemaMismatch {
341                message: "join requires at least one join key on each side".to_string(),
342            });
343        }
344        if left_on.len() != right_on.len() {
345            return Err(IngestionError::SchemaMismatch {
346                message: format!(
347                    "join requires left_on and right_on to have same length (left_on={}, right_on={})",
348                    left_on.len(),
349                    right_on.len()
350                ),
351            });
352        }
353
354        let left_exprs: Vec<Expr> = left_on.iter().map(|c| col(*c)).collect();
355        let right_exprs: Vec<Expr> = right_on.iter().map(|c| col(*c)).collect();
356
357        let how = match how {
358            JoinKind::Inner => JoinType::Inner,
359            JoinKind::Left => JoinType::Left,
360            JoinKind::Right => JoinType::Right,
361            JoinKind::Full => JoinType::Full,
362        };
363
364        self.lf = self
365            .lf
366            .join(other.lf, left_exprs, right_exprs, JoinArgs::new(how));
367        Ok(self)
368    }
369
370    /// Collect the pipeline into an in-memory [`DataSet`].
371    pub fn collect(self) -> IngestionResult<DataSet> {
372        let df = self
373            .lf
374            .collect()
375            .map_err(|e| polars_error_to_ingestion("failed to collect polars lazy plan", e))?;
376        let out_schema = infer_schema_from_dataframe(&df)?;
377        dataframe_to_dataset(&df, &out_schema, "column", 1)
378    }
379
380    /// Collect the pipeline into an in-memory [`DataSet`], enforcing an explicit output schema.
381    pub fn collect_with_schema(self, schema: &Schema) -> IngestionResult<DataSet> {
382        let df = self
383            .lf
384            .collect()
385            .map_err(|e| polars_error_to_ingestion("failed to collect polars lazy plan", e))?;
386        dataframe_to_dataset(&df, schema, "column", 1)
387    }
388
389    /// Reduce a column using a built-in [`ReduceOp`] (Polars-backed).
390    ///
391    /// Returns `None` if `column` does not exist (aligned with [`crate::processing::reduce`]).
392    pub fn reduce(mut self, column: &str, op: ReduceOp) -> IngestionResult<Option<Value>> {
393        let df_schema = self
394            .lf
395            .collect_schema()
396            .map_err(|e| polars_error_to_ingestion("failed to collect polars schema", e))?;
397        if df_schema.get(column).is_none() {
398            return Ok(None);
399        }
400
401        let expr = polars_reduce_expr(column, op);
402        let df = self
403            .lf
404            .select([expr.alias(REDUCE_SCALAR_COL)])
405            .collect()
406            .map_err(|e| polars_error_to_ingestion("failed to collect polars reduce", e))?;
407
408        let s = df
409            .column(REDUCE_SCALAR_COL)
410            .map_err(|_| IngestionError::SchemaMismatch {
411                message: format!("missing reduce output column '{REDUCE_SCALAR_COL}'"),
412            })?
413            .as_materialized_series();
414        if s.len() == 0 {
415            return Ok(Some(Value::Null));
416        }
417        let av = s.get(0).map_err(|e| IngestionError::SchemaMismatch {
418            message: format!("polars reduce output error: {e}"),
419        })?;
420        Ok(Some(anyvalue_to_value(av)))
421    }
422
423    /// Reduce a numeric column by summing values (nulls ignored; all-null -> null).
424    ///
425    /// Returns `None` if `column` does not exist (aligned with `processing::reduce`).
426    pub fn sum(self, column: &str) -> IngestionResult<Option<Value>> {
427        self.reduce(column, ReduceOp::Sum)
428    }
429
430    /// Single Polars collect: for each column, mean and standard deviation (`std_kind` maps to
431    /// Polars `ddof`). Columns are cast to `Float64` first (aligned with scalar reduces).
432    ///
433    /// Returns an error if any column name is missing from the lazy schema.
434    pub fn feature_wise_mean_std(
435        mut self,
436        columns: &[&str],
437        std_kind: VarianceKind,
438    ) -> IngestionResult<Vec<(String, FeatureMeanStd)>> {
439        let df_schema = self
440            .lf
441            .collect_schema()
442            .map_err(|e| polars_error_to_ingestion("failed to collect polars schema", e))?;
443        for c in columns {
444            if df_schema.get(*c).is_none() {
445                return Err(IngestionError::SchemaMismatch {
446                    message: format!("feature_wise_mean_std: unknown column '{c}'"),
447                });
448            }
449        }
450        let ddof = match std_kind {
451            VarianceKind::Population => 0u8,
452            VarianceKind::Sample => 1u8,
453        };
454        use polars::datatypes::DataType as P;
455        let mut exprs: Vec<Expr> = Vec::with_capacity(columns.len() * 2);
456        for (i, c) in columns.iter().enumerate() {
457            let cf = col(*c).strict_cast(P::Float64);
458            exprs.push(cf.clone().mean().alias(format!("__fwm_{i}_mean").as_str()));
459            exprs.push(cf.std(ddof).alias(format!("__fwm_{i}_std").as_str()));
460        }
461        let df =
462            self.lf.select(exprs).collect().map_err(|e| {
463                polars_error_to_ingestion("failed to collect feature_wise_mean_std", e)
464            })?;
465
466        if df.height() == 0 {
467            return Ok(columns
468                .iter()
469                .map(|c| {
470                    (
471                        (*c).to_string(),
472                        FeatureMeanStd {
473                            mean: Value::Null,
474                            std_dev: Value::Null,
475                        },
476                    )
477                })
478                .collect());
479        }
480
481        let mut out = Vec::with_capacity(columns.len());
482        for i in 0..columns.len() {
483            let mean_s = df
484                .column(&format!("__fwm_{i}_mean"))
485                .map_err(|_| IngestionError::SchemaMismatch {
486                    message: format!("missing __fwm_{i}_mean"),
487                })?
488                .as_materialized_series();
489            let std_s = df
490                .column(&format!("__fwm_{i}_std"))
491                .map_err(|_| IngestionError::SchemaMismatch {
492                    message: format!("missing __fwm_{i}_std"),
493                })?
494                .as_materialized_series();
495            let mean_av = mean_s.get(0).map_err(|e| IngestionError::SchemaMismatch {
496                message: format!("feature_wise mean get: {e}"),
497            })?;
498            let std_av = std_s.get(0).map_err(|e| IngestionError::SchemaMismatch {
499                message: format!("feature_wise std get: {e}"),
500            })?;
501            out.push((
502                columns[i].to_string(),
503                FeatureMeanStd {
504                    mean: anyvalue_to_value(mean_av),
505                    std_dev: anyvalue_to_value(std_av),
506                },
507            ));
508        }
509        Ok(out)
510    }
511
512    pub(crate) fn lazy_clone(&self) -> LazyFrame {
513        self.lf.clone()
514    }
515
516    pub(crate) fn from_lazyframe(lf: LazyFrame) -> Self {
517        Self { lf }
518    }
519}
520
521fn polars_reduce_expr(column: &str, op: ReduceOp) -> Expr {
522    use polars::datatypes::DataType as P;
523    let c = col(column);
524    match op {
525        ReduceOp::Count => len(),
526        ReduceOp::Sum => c.sum(),
527        ReduceOp::Min => c.min(),
528        ReduceOp::Max => c.max(),
529        ReduceOp::Mean => c.clone().strict_cast(P::Float64).mean(),
530        ReduceOp::Variance(kind) => {
531            let ddof = match kind {
532                VarianceKind::Population => 0u8,
533                VarianceKind::Sample => 1u8,
534            };
535            c.clone().strict_cast(P::Float64).var(ddof)
536        }
537        ReduceOp::StdDev(kind) => {
538            let ddof = match kind {
539                VarianceKind::Population => 0u8,
540                VarianceKind::Sample => 1u8,
541            };
542            c.clone().strict_cast(P::Float64).std(ddof)
543        }
544        ReduceOp::SumSquares => c.clone().strict_cast(P::Float64).pow(lit(2.0)).sum(),
545        ReduceOp::L2Norm => c.clone().strict_cast(P::Float64).pow(lit(2.0)).sum().sqrt(),
546        ReduceOp::CountDistinctNonNull => c.drop_nulls().n_unique(),
547    }
548}
549
550fn agg_to_expr(agg: &Agg) -> Expr {
551    use polars::datatypes::DataType as P;
552    match agg {
553        Agg::CountRows { alias } => len().alias(alias.as_str()),
554        Agg::CountNotNull { column, alias } => col(column.as_str()).count().alias(alias.as_str()),
555        Agg::Sum { column, alias } => col(column.as_str()).sum().alias(alias.as_str()),
556        Agg::Min { column, alias } => col(column.as_str()).min().alias(alias.as_str()),
557        Agg::Max { column, alias } => col(column.as_str()).max().alias(alias.as_str()),
558        Agg::Mean { column, alias } => col(column.as_str())
559            .strict_cast(P::Float64)
560            .mean()
561            .alias(alias.as_str()),
562        Agg::Variance {
563            column,
564            alias,
565            kind,
566        } => {
567            let ddof = match kind {
568                VarianceKind::Population => 0u8,
569                VarianceKind::Sample => 1u8,
570            };
571            col(column.as_str())
572                .strict_cast(P::Float64)
573                .var(ddof)
574                .alias(alias.as_str())
575        }
576        Agg::StdDev {
577            column,
578            alias,
579            kind,
580        } => {
581            let ddof = match kind {
582                VarianceKind::Population => 0u8,
583                VarianceKind::Sample => 1u8,
584            };
585            col(column.as_str())
586                .strict_cast(P::Float64)
587                .std(ddof)
588                .alias(alias.as_str())
589        }
590        Agg::SumSquares { column, alias } => col(column.as_str())
591            .strict_cast(P::Float64)
592            .pow(lit(2.0))
593            .sum()
594            .alias(alias.as_str()),
595        Agg::L2Norm { column, alias } => col(column.as_str())
596            .strict_cast(P::Float64)
597            .pow(lit(2.0))
598            .sum()
599            .sqrt()
600            .alias(alias.as_str()),
601        Agg::CountDistinctNonNull { column, alias } => col(column.as_str())
602            .drop_nulls()
603            .n_unique()
604            .alias(alias.as_str()),
605    }
606}
607
608fn to_polars_dtype(dt: &DataType) -> polars::datatypes::DataType {
609    match dt {
610        DataType::Int64 => polars::datatypes::DataType::Int64,
611        DataType::Float64 => polars::datatypes::DataType::Float64,
612        DataType::Bool => polars::datatypes::DataType::Boolean,
613        DataType::Utf8 => polars::datatypes::DataType::String,
614    }
615}
616
617fn value_to_lit_expr(value: Value) -> IngestionResult<Expr> {
618    match value {
619        Value::Null => Err(IngestionError::SchemaMismatch {
620            message: "Value::Null is not supported as a literal expression; use fill_null or cast/collect to materialize".to_string(),
621        }),
622        Value::Int64(v) => Ok(lit(v)),
623        Value::Float64(v) => Ok(lit(v)),
624        Value::Bool(v) => Ok(lit(v)),
625        Value::Utf8(v) => Ok(lit(v)),
626    }
627}
628
629fn anyvalue_to_value(av: AnyValue) -> Value {
630    match av {
631        AnyValue::Null => Value::Null,
632        AnyValue::Int8(v) => Value::Int64(v as i64),
633        AnyValue::Int16(v) => Value::Int64(v as i64),
634        AnyValue::Int32(v) => Value::Int64(v as i64),
635        AnyValue::Int64(v) => Value::Int64(v),
636        AnyValue::UInt8(v) => Value::Int64(v as i64),
637        AnyValue::UInt16(v) => Value::Int64(v as i64),
638        AnyValue::UInt32(v) => Value::Int64(v as i64),
639        AnyValue::UInt64(v) => Value::Int64(v as i64),
640        AnyValue::Float64(v) => Value::Float64(v),
641        AnyValue::Boolean(v) => Value::Bool(v),
642        AnyValue::String(v) => Value::Utf8(v.to_string()),
643        AnyValue::StringOwned(v) => Value::Utf8(v.to_string()),
644        other => Value::Utf8(other.to_string()),
645    }
646}
647
648/// Backwards-compatible alias for earlier naming.
649pub type PolarsPipeline = DataFrame;
650
651#[cfg(test)]
652mod tests {
653    use super::{Agg, DataFrame, JoinKind, PolarsPipeline, Predicate};
654    use crate::processing::{ReduceOp, VarianceKind, feature_wise_mean_std, filter, map, reduce};
655    use crate::types::{DataSet, DataType, Field, Schema, Value};
656
657    fn sample_dataset() -> DataSet {
658        let schema = Schema::new(vec![
659            Field::new("id", DataType::Int64),
660            Field::new("active", DataType::Bool),
661            Field::new("score", DataType::Float64),
662        ]);
663        let rows = vec![
664            vec![Value::Int64(1), Value::Bool(true), Value::Float64(10.0)],
665            vec![Value::Int64(2), Value::Bool(true), Value::Float64(20.0)],
666            vec![Value::Int64(3), Value::Bool(false), Value::Float64(30.0)],
667            vec![Value::Int64(4), Value::Bool(true), Value::Null],
668        ];
669        DataSet::new(schema, rows)
670    }
671
672    #[test]
673    fn polars_pipeline_filter_map_reduce_parity_with_in_memory() {
674        let ds = sample_dataset();
675
676        // In-memory baseline: active && even id, score *= 2.0, then sum(score)
677        let active_idx = ds.schema.index_of("active").unwrap();
678        let id_idx = ds.schema.index_of("id").unwrap();
679        let filtered = filter(&ds, |row| {
680            let is_active = matches!(row.get(active_idx), Some(Value::Bool(true)));
681            let even_id = matches!(row.get(id_idx), Some(Value::Int64(v)) if *v % 2 == 0);
682            is_active && even_id
683        });
684        let mapped = map(&filtered, |row| {
685            let mut out = row.to_vec();
686            if let Some(Value::Float64(v)) = out.get(2) {
687                out[2] = Value::Float64(v * 2.0);
688            }
689            out
690        });
691        let expected = reduce(&mapped, "score", ReduceOp::Sum).unwrap();
692
693        // Polars-delegated pipeline.
694        let got = DataFrame::from_dataset(&ds)
695            .unwrap()
696            .filter(Predicate::Eq {
697                column: "active".to_string(),
698                value: Value::Bool(true),
699            })
700            .unwrap()
701            .filter(Predicate::ModEqInt64 {
702                column: "id".to_string(),
703                modulus: 2,
704                equals: 0,
705            })
706            .unwrap()
707            .multiply_f64("score", 2.0)
708            .unwrap()
709            .sum("score")
710            .unwrap()
711            .unwrap();
712
713        assert_eq!(got, expected);
714    }
715
716    #[test]
717    fn polars_pipeline_reduce_parity_mean_variance_l2_distinct() {
718        let schema = Schema::new(vec![
719            Field::new("x", DataType::Float64),
720            Field::new("tag", DataType::Utf8),
721        ]);
722        let ds = DataSet::new(
723            schema,
724            vec![
725                vec![Value::Float64(1.0), Value::Utf8("a".to_string())],
726                vec![Value::Float64(2.0), Value::Utf8("b".to_string())],
727                vec![Value::Null, Value::Utf8("a".to_string())],
728            ],
729        );
730
731        let mean = reduce(&ds, "x", ReduceOp::Mean).unwrap();
732        let var_pop = reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Population)).unwrap();
733        let l2 = reduce(&ds, "x", ReduceOp::L2Norm).unwrap();
734        let dcnt = reduce(&ds, "tag", ReduceOp::CountDistinctNonNull).unwrap();
735
736        assert_eq!(
737            DataFrame::from_dataset(&ds)
738                .unwrap()
739                .reduce("x", ReduceOp::Mean)
740                .unwrap()
741                .unwrap(),
742            mean
743        );
744        assert_eq!(
745            DataFrame::from_dataset(&ds)
746                .unwrap()
747                .reduce("x", ReduceOp::Variance(VarianceKind::Population))
748                .unwrap()
749                .unwrap(),
750            var_pop
751        );
752        assert_eq!(
753            DataFrame::from_dataset(&ds)
754                .unwrap()
755                .reduce("x", ReduceOp::L2Norm)
756                .unwrap()
757                .unwrap(),
758            l2
759        );
760        assert_eq!(
761            DataFrame::from_dataset(&ds)
762                .unwrap()
763                .reduce("tag", ReduceOp::CountDistinctNonNull)
764                .unwrap()
765                .unwrap(),
766            dcnt
767        );
768    }
769
770    #[test]
771    fn polars_pipeline_collect_select_works() {
772        let ds = sample_dataset();
773        let out = DataFrame::from_dataset(&ds)
774            .unwrap()
775            .select(&["score", "id"])
776            .unwrap()
777            .collect()
778            .unwrap();
779
780        assert_eq!(
781            out.schema.field_names().collect::<Vec<_>>(),
782            vec!["score", "id"]
783        );
784        assert_eq!(out.row_count(), ds.row_count());
785        assert_eq!(out.rows[0][0], Value::Float64(10.0));
786        assert_eq!(out.rows[0][1], Value::Int64(1));
787    }
788
789    #[test]
790    fn polars_pipeline_sum_returns_none_for_missing_column() {
791        let ds = sample_dataset();
792        let out = DataFrame::from_dataset(&ds)
793            .unwrap()
794            .sum("missing")
795            .unwrap();
796        assert_eq!(out, None);
797    }
798
799    #[test]
800    fn polars_errors_are_preserved_as_engine_error_sources() {
801        // Trigger a Polars execution error by applying a numeric multiply to a Utf8 column.
802        let schema = Schema::new(vec![Field::new("name", DataType::Utf8)]);
803        let ds = DataSet::new(schema, vec![vec![Value::Utf8("x".to_string())]]);
804
805        let err = DataFrame::from_dataset(&ds)
806            .unwrap()
807            .multiply_f64("name", 2.0)
808            .unwrap()
809            .collect()
810            .unwrap_err();
811
812        // This should not be stringified into SchemaMismatch; it should preserve a source() chain.
813        match err {
814            crate::error::IngestionError::Engine { source, .. } => {
815                assert!(!source.to_string().is_empty());
816            }
817            other => panic!("expected Engine error, got: {other:?}"),
818        }
819    }
820
821    #[test]
822    fn backwards_compatible_polars_pipeline_alias_exists() {
823        let ds = sample_dataset();
824        let _ = PolarsPipeline::from_dataset(&ds)
825            .unwrap()
826            .select(&["id"])
827            .unwrap();
828    }
829
830    #[test]
831    fn rename_cast_fill_null_group_by_and_join_work() {
832        // rename + cast + fill_null
833        let schema = Schema::new(vec![
834            Field::new("id", DataType::Int64),
835            Field::new("score", DataType::Int64),
836        ]);
837        let ds = DataSet::new(
838            schema,
839            vec![
840                vec![Value::Int64(1), Value::Int64(10)],
841                vec![Value::Int64(2), Value::Null],
842            ],
843        );
844
845        let out = DataFrame::from_dataset(&ds)
846            .unwrap()
847            .rename(&[("score", "score_i")])
848            .unwrap()
849            .cast("score_i", DataType::Float64)
850            .unwrap()
851            .fill_null("score_i", Value::Float64(0.0))
852            .unwrap()
853            .collect()
854            .unwrap();
855
856        assert_eq!(
857            out.schema.field_names().collect::<Vec<_>>(),
858            vec!["id", "score_i"]
859        );
860        assert_eq!(out.rows[0][1], Value::Float64(10.0));
861        assert_eq!(out.rows[1][1], Value::Float64(0.0));
862
863        // group_by
864        let schema = Schema::new(vec![
865            Field::new("grp", DataType::Utf8),
866            Field::new("score", DataType::Float64),
867        ]);
868        let ds = DataSet::new(
869            schema,
870            vec![
871                vec![Value::Utf8("A".to_string()), Value::Float64(1.0)],
872                vec![Value::Utf8("A".to_string()), Value::Float64(2.0)],
873                vec![Value::Utf8("B".to_string()), Value::Null],
874            ],
875        );
876
877        let out = DataFrame::from_dataset(&ds)
878            .unwrap()
879            .group_by(
880                &["grp"],
881                &[
882                    Agg::Sum {
883                        column: "score".to_string(),
884                        alias: "sum_score".to_string(),
885                    },
886                    Agg::CountRows {
887                        alias: "cnt".to_string(),
888                    },
889                ],
890            )
891            .unwrap()
892            .collect()
893            .unwrap();
894
895        // Order is not guaranteed; validate via a lookup.
896        let mut sums: std::collections::HashMap<String, (Value, Value)> =
897            std::collections::HashMap::new();
898        for row in &out.rows {
899            if let Value::Utf8(g) = &row[0] {
900                sums.insert(g.clone(), (row[1].clone(), row[2].clone()));
901            }
902        }
903        assert_eq!(sums.get("A"), Some(&(Value::Float64(3.0), Value::Int64(2))));
904        assert_eq!(
905            sums.get("B"),
906            // Polars `sum` ignores nulls and returns 0.0 for all-null groups.
907            Some(&(Value::Float64(0.0), Value::Int64(1)))
908        );
909
910        // join
911        let left = DataSet::new(
912            Schema::new(vec![
913                Field::new("id", DataType::Int64),
914                Field::new("name", DataType::Utf8),
915            ]),
916            vec![
917                vec![Value::Int64(1), Value::Utf8("Ada".to_string())],
918                vec![Value::Int64(2), Value::Utf8("Grace".to_string())],
919            ],
920        );
921        let right = DataSet::new(
922            Schema::new(vec![
923                Field::new("id", DataType::Int64),
924                Field::new("score", DataType::Float64),
925            ]),
926            vec![
927                vec![Value::Int64(1), Value::Float64(9.0)],
928                vec![Value::Int64(3), Value::Float64(7.0)],
929            ],
930        );
931
932        let out = DataFrame::from_dataset(&left)
933            .unwrap()
934            .join(
935                DataFrame::from_dataset(&right).unwrap(),
936                &["id"],
937                &["id"],
938                JoinKind::Inner,
939            )
940            .unwrap()
941            .collect()
942            .unwrap();
943        assert_eq!(out.row_count(), 1);
944        // One matched row with id=1.
945        assert_eq!(out.rows[0][0], Value::Int64(1));
946    }
947
948    #[test]
949    fn polars_feature_wise_mean_std_matches_in_memory() {
950        let schema = Schema::new(vec![
951            Field::new("a", DataType::Int64),
952            Field::new("b", DataType::Float64),
953        ]);
954        let ds = DataSet::new(
955            schema,
956            vec![
957                vec![Value::Int64(1), Value::Float64(10.0)],
958                vec![Value::Int64(3), Value::Float64(20.0)],
959            ],
960        );
961        let mem = feature_wise_mean_std(&ds, &["a", "b"], VarianceKind::Sample).unwrap();
962        let pol = DataFrame::from_dataset(&ds)
963            .unwrap()
964            .feature_wise_mean_std(&["a", "b"], VarianceKind::Sample)
965            .unwrap();
966        assert_eq!(mem.len(), pol.len());
967        for i in 0..mem.len() {
968            assert_eq!(mem[i].0, pol[i].0);
969            assert_eq!(mem[i].1.mean, pol[i].1.mean);
970            match (&mem[i].1.std_dev, &pol[i].1.std_dev) {
971                (Value::Float64(m), Value::Float64(p)) => assert!((m - p).abs() < 1e-9),
972                (a, b) => assert_eq!(a, b),
973            }
974        }
975    }
976
977    #[test]
978    fn group_by_mean_std_count_distinct_all_null_numeric_is_null() {
979        let schema = Schema::new(vec![
980            Field::new("g", DataType::Utf8),
981            Field::new("x", DataType::Float64),
982            Field::new("tag", DataType::Utf8),
983        ]);
984        let ds = DataSet::new(
985            schema,
986            vec![
987                vec![
988                    Value::Utf8("A".to_string()),
989                    Value::Null,
990                    Value::Utf8("p".to_string()),
991                ],
992                vec![
993                    Value::Utf8("A".to_string()),
994                    Value::Null,
995                    Value::Utf8("q".to_string()),
996                ],
997            ],
998        );
999        let out = DataFrame::from_dataset(&ds)
1000            .unwrap()
1001            .group_by(
1002                &["g"],
1003                &[
1004                    Agg::Mean {
1005                        column: "x".to_string(),
1006                        alias: "mx".to_string(),
1007                    },
1008                    Agg::StdDev {
1009                        column: "x".to_string(),
1010                        alias: "sx".to_string(),
1011                        kind: VarianceKind::Sample,
1012                    },
1013                    Agg::CountDistinctNonNull {
1014                        column: "tag".to_string(),
1015                        alias: "dt".to_string(),
1016                    },
1017                ],
1018            )
1019            .unwrap()
1020            .collect()
1021            .unwrap();
1022        assert_eq!(out.row_count(), 1);
1023        assert_eq!(out.rows[0][0], Value::Utf8("A".to_string()));
1024        assert_eq!(out.rows[0][1], Value::Null);
1025        assert_eq!(out.rows[0][2], Value::Null);
1026        assert_eq!(out.rows[0][3], Value::Int64(2));
1027    }
1028}