1use crate::error::{IngestionError, IngestionResult};
61use crate::ingestion::polars_bridge::{
62 dataframe_to_dataset, dataset_to_dataframe, infer_schema_from_dataframe,
63 polars_error_to_ingestion,
64};
65use crate::processing::{FeatureMeanStd, ReduceOp, VarianceKind};
66use crate::types::{DataSet, DataType, Schema, Value};
67
68use polars::chunked_array::cast::CastOptions;
69use polars::prelude::*;
70use serde::{Deserialize, Serialize};
71
/// Alias given to the single output column produced by [`DataFrame::reduce`].
/// NOTE(review): the unusual `__rust_dp_` prefix presumably keeps it from clashing with
/// user column names — confirm.
const REDUCE_SCALAR_COL: &str = "__rust_dp_reduce_scalar";
73
/// A row predicate accepted by [`DataFrame::filter`].
#[derive(Debug, Clone, PartialEq)]
pub enum Predicate {
    /// Keep rows where `column` equals `value`; [`Value::Null`] is translated into an
    /// is-null test rather than an equality comparison.
    Eq { column: String, value: Value },
    /// Keep rows where `column` is not null.
    NotNull { column: String },
    /// Keep rows where `column % modulus == equals`, using `i64` literals.
    ModEqInt64 {
        /// Column the remainder is computed on.
        column: String,
        /// Divisor of the remainder operation.
        modulus: i64,
        /// Remainder a row must produce to be kept.
        equals: i64,
    },
}
88
/// Join strategy for [`DataFrame::join`]; each variant maps 1:1 onto polars' `JoinType`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinKind {
    /// Only rows whose keys appear on both sides (`JoinType::Inner`).
    Inner,
    /// All left rows, matched right rows where present (`JoinType::Left`).
    Left,
    /// All right rows, matched left rows where present (`JoinType::Right`).
    Right,
    /// All rows from both sides (`JoinType::Full`).
    Full,
}
97
/// One aggregation evaluated per group by [`DataFrame::group_by`].
///
/// Every variant writes its result to an output column named `alias`. The statistical
/// variants (`Mean`, `Variance`, `StdDev`, `SumSquares`, `L2Norm`) strictly cast the input
/// column to `Float64` before aggregating; `Sum`, `Min` and `Max` keep the column's dtype.
#[derive(Debug, Clone, PartialEq)]
pub enum Agg {
    /// Number of rows in each group (polars `len()`); no input column is read.
    CountRows {
        /// Output column name.
        alias: String,
    },
    /// Polars `count()` of `column`: the number of non-null values in each group.
    CountNotNull {
        column: String,
        alias: String,
    },
    /// Sum of `column`.
    Sum {
        column: String,
        alias: String,
    },
    /// Minimum of `column`.
    Min {
        column: String,
        alias: String,
    },
    /// Maximum of `column`.
    Max {
        column: String,
        alias: String,
    },
    /// Arithmetic mean of `column` after a strict cast to `Float64`.
    Mean {
        column: String,
        alias: String,
    },
    /// Variance of `column`; `kind` picks ddof 0 (population) or 1 (sample).
    Variance {
        column: String,
        alias: String,
        kind: VarianceKind,
    },
    /// Standard deviation of `column`; `kind` picks ddof 0 (population) or 1 (sample).
    StdDev {
        column: String,
        alias: String,
        kind: VarianceKind,
    },
    /// Sum of squared values of `column`.
    SumSquares {
        column: String,
        alias: String,
    },
    /// Euclidean norm of `column`: the square root of the sum of squares.
    L2Norm {
        column: String,
        alias: String,
    },
    /// Number of distinct values of `column`, with nulls dropped before counting.
    CountDistinctNonNull {
        column: String,
        alias: String,
    },
}
151
152#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
154#[serde(rename_all = "snake_case")]
155pub enum CastMode {
156 Strict,
158 Lossy,
160}
161
162impl Default for CastMode {
163 fn default() -> Self {
164 Self::Strict
165 }
166}
167
/// A lazily evaluated pipeline backed by a polars `LazyFrame`.
///
/// Transformation methods only extend the query plan; the plan is executed by `collect`,
/// `collect_with_schema`, `reduce` / `sum`, or `feature_wise_mean_std`. Cloning clones the
/// plan, not materialized data.
#[derive(Clone)]
pub struct DataFrame {
    lf: LazyFrame,
}
176
177impl DataFrame {
178 pub fn from_dataset(ds: &DataSet) -> IngestionResult<Self> {
183 let df = dataset_to_dataframe(ds)?;
184 Ok(Self { lf: df.lazy() })
185 }
186
187 pub fn filter(mut self, predicate: Predicate) -> IngestionResult<Self> {
189 let expr = match predicate {
190 Predicate::Eq { column, value } => match value {
191 Value::Null => col(&column).is_null(),
192 Value::Int64(x) => col(&column).eq(lit(x)),
193 Value::Float64(x) => col(&column).eq(lit(x)),
194 Value::Bool(x) => col(&column).eq(lit(x)),
195 Value::Utf8(s) => col(&column).eq(lit(s)),
196 },
197 Predicate::NotNull { column } => col(&column).is_not_null(),
198 Predicate::ModEqInt64 {
199 column,
200 modulus,
201 equals,
202 } => (col(&column) % lit(modulus)).eq(lit(equals)),
203 };
204 self.lf = self.lf.filter(expr);
206 Ok(self)
207 }
208
209 pub fn multiply_f64(mut self, column: &str, factor: f64) -> IngestionResult<Self> {
211 self.lf = self
213 .lf
214 .with_columns([(col(column) * lit(factor)).alias(column)]);
215 Ok(self)
216 }
217
218 pub fn add_f64(mut self, column: &str, delta: f64) -> IngestionResult<Self> {
220 self.lf = self
221 .lf
222 .with_columns([(col(column) + lit(delta)).alias(column)]);
223 Ok(self)
224 }
225
226 pub fn with_mul_f64(mut self, name: &str, source: &str, factor: f64) -> IngestionResult<Self> {
228 self.lf = self
229 .lf
230 .with_columns([(col(source) * lit(factor)).alias(name)]);
231 Ok(self)
232 }
233
234 pub fn with_add_f64(mut self, name: &str, source: &str, delta: f64) -> IngestionResult<Self> {
236 self.lf = self
237 .lf
238 .with_columns([(col(source) + lit(delta)).alias(name)]);
239 Ok(self)
240 }
241
242 pub fn select(mut self, columns: &[&str]) -> IngestionResult<Self> {
244 let exprs: Vec<Expr> = columns.iter().map(|c| col(*c)).collect();
245 self.lf = self.lf.select(exprs);
247 Ok(self)
248 }
249
250 pub fn rename(mut self, pairs: &[(&str, &str)]) -> IngestionResult<Self> {
254 let (existing, new): (Vec<&str>, Vec<&str>) = pairs.iter().copied().unzip();
255 self.lf = self.lf.rename(existing, new, true);
256 Ok(self)
257 }
258
259 pub fn cast(self, column: &str, to: DataType) -> IngestionResult<Self> {
263 self.cast_with_mode(column, to, CastMode::Strict)
264 }
265
266 pub fn cast_with_mode(
268 mut self,
269 column: &str,
270 to: DataType,
271 mode: CastMode,
272 ) -> IngestionResult<Self> {
273 let dt = to_polars_dtype(&to);
274 let expr = match mode {
275 CastMode::Strict => col(column).strict_cast(dt),
276 CastMode::Lossy => col(column).cast_with_options(dt, CastOptions::NonStrict),
277 }
278 .alias(column);
279 self.lf = self.lf.with_columns([expr]);
280 Ok(self)
281 }
282
283 pub fn drop(mut self, columns: &[&str]) -> IngestionResult<Self> {
285 let names: Vec<PlSmallStr> = columns.iter().map(|c| (*c).into()).collect();
286 let sel = Selector::ByName {
287 names: names.into(),
288 strict: true,
289 };
290 self.lf = self.lf.drop(sel);
291 Ok(self)
292 }
293
294 pub fn fill_null(mut self, column: &str, value: Value) -> IngestionResult<Self> {
296 let lit_expr = value_to_lit_expr(value)?;
297 self.lf = self
298 .lf
299 .with_columns([col(column).fill_null(lit_expr).alias(column)]);
300 Ok(self)
301 }
302
303 pub fn with_literal(mut self, name: &str, value: Value) -> IngestionResult<Self> {
305 let lit_expr = value_to_lit_expr(value)?;
306 self.lf = self.lf.with_columns([lit_expr.alias(name)]);
307 Ok(self)
308 }
309
310 pub fn group_by(mut self, keys: &[&str], aggs: &[Agg]) -> IngestionResult<Self> {
312 if keys.is_empty() {
313 return Err(IngestionError::SchemaMismatch {
314 message: "group_by requires at least one key column".to_string(),
315 });
316 }
317 if aggs.is_empty() {
318 return Err(IngestionError::SchemaMismatch {
319 message: "group_by requires at least one aggregation".to_string(),
320 });
321 }
322
323 let key_exprs: Vec<Expr> = keys.iter().map(|k| col(*k)).collect();
324 let agg_exprs: Vec<Expr> = aggs.iter().map(agg_to_expr).collect();
325 self.lf = self.lf.group_by(key_exprs).agg(agg_exprs);
326 Ok(self)
327 }
328
329 pub fn join(
333 mut self,
334 other: DataFrame,
335 left_on: &[&str],
336 right_on: &[&str],
337 how: JoinKind,
338 ) -> IngestionResult<Self> {
339 if left_on.is_empty() || right_on.is_empty() {
340 return Err(IngestionError::SchemaMismatch {
341 message: "join requires at least one join key on each side".to_string(),
342 });
343 }
344 if left_on.len() != right_on.len() {
345 return Err(IngestionError::SchemaMismatch {
346 message: format!(
347 "join requires left_on and right_on to have same length (left_on={}, right_on={})",
348 left_on.len(),
349 right_on.len()
350 ),
351 });
352 }
353
354 let left_exprs: Vec<Expr> = left_on.iter().map(|c| col(*c)).collect();
355 let right_exprs: Vec<Expr> = right_on.iter().map(|c| col(*c)).collect();
356
357 let how = match how {
358 JoinKind::Inner => JoinType::Inner,
359 JoinKind::Left => JoinType::Left,
360 JoinKind::Right => JoinType::Right,
361 JoinKind::Full => JoinType::Full,
362 };
363
364 self.lf = self
365 .lf
366 .join(other.lf, left_exprs, right_exprs, JoinArgs::new(how));
367 Ok(self)
368 }
369
370 pub fn collect(self) -> IngestionResult<DataSet> {
372 let df = self
373 .lf
374 .collect()
375 .map_err(|e| polars_error_to_ingestion("failed to collect polars lazy plan", e))?;
376 let out_schema = infer_schema_from_dataframe(&df)?;
377 dataframe_to_dataset(&df, &out_schema, "column", 1)
378 }
379
380 pub fn collect_with_schema(self, schema: &Schema) -> IngestionResult<DataSet> {
382 let df = self
383 .lf
384 .collect()
385 .map_err(|e| polars_error_to_ingestion("failed to collect polars lazy plan", e))?;
386 dataframe_to_dataset(&df, schema, "column", 1)
387 }
388
389 pub fn reduce(mut self, column: &str, op: ReduceOp) -> IngestionResult<Option<Value>> {
393 let df_schema = self
394 .lf
395 .collect_schema()
396 .map_err(|e| polars_error_to_ingestion("failed to collect polars schema", e))?;
397 if df_schema.get(column).is_none() {
398 return Ok(None);
399 }
400
401 let expr = polars_reduce_expr(column, op);
402 let df = self
403 .lf
404 .select([expr.alias(REDUCE_SCALAR_COL)])
405 .collect()
406 .map_err(|e| polars_error_to_ingestion("failed to collect polars reduce", e))?;
407
408 let s = df
409 .column(REDUCE_SCALAR_COL)
410 .map_err(|_| IngestionError::SchemaMismatch {
411 message: format!("missing reduce output column '{REDUCE_SCALAR_COL}'"),
412 })?
413 .as_materialized_series();
414 if s.len() == 0 {
415 return Ok(Some(Value::Null));
416 }
417 let av = s.get(0).map_err(|e| IngestionError::SchemaMismatch {
418 message: format!("polars reduce output error: {e}"),
419 })?;
420 Ok(Some(anyvalue_to_value(av)))
421 }
422
423 pub fn sum(self, column: &str) -> IngestionResult<Option<Value>> {
427 self.reduce(column, ReduceOp::Sum)
428 }
429
430 pub fn feature_wise_mean_std(
435 mut self,
436 columns: &[&str],
437 std_kind: VarianceKind,
438 ) -> IngestionResult<Vec<(String, FeatureMeanStd)>> {
439 let df_schema = self
440 .lf
441 .collect_schema()
442 .map_err(|e| polars_error_to_ingestion("failed to collect polars schema", e))?;
443 for c in columns {
444 if df_schema.get(*c).is_none() {
445 return Err(IngestionError::SchemaMismatch {
446 message: format!("feature_wise_mean_std: unknown column '{c}'"),
447 });
448 }
449 }
450 let ddof = match std_kind {
451 VarianceKind::Population => 0u8,
452 VarianceKind::Sample => 1u8,
453 };
454 use polars::datatypes::DataType as P;
455 let mut exprs: Vec<Expr> = Vec::with_capacity(columns.len() * 2);
456 for (i, c) in columns.iter().enumerate() {
457 let cf = col(*c).strict_cast(P::Float64);
458 exprs.push(cf.clone().mean().alias(format!("__fwm_{i}_mean").as_str()));
459 exprs.push(cf.std(ddof).alias(format!("__fwm_{i}_std").as_str()));
460 }
461 let df =
462 self.lf.select(exprs).collect().map_err(|e| {
463 polars_error_to_ingestion("failed to collect feature_wise_mean_std", e)
464 })?;
465
466 if df.height() == 0 {
467 return Ok(columns
468 .iter()
469 .map(|c| {
470 (
471 (*c).to_string(),
472 FeatureMeanStd {
473 mean: Value::Null,
474 std_dev: Value::Null,
475 },
476 )
477 })
478 .collect());
479 }
480
481 let mut out = Vec::with_capacity(columns.len());
482 for i in 0..columns.len() {
483 let mean_s = df
484 .column(&format!("__fwm_{i}_mean"))
485 .map_err(|_| IngestionError::SchemaMismatch {
486 message: format!("missing __fwm_{i}_mean"),
487 })?
488 .as_materialized_series();
489 let std_s = df
490 .column(&format!("__fwm_{i}_std"))
491 .map_err(|_| IngestionError::SchemaMismatch {
492 message: format!("missing __fwm_{i}_std"),
493 })?
494 .as_materialized_series();
495 let mean_av = mean_s.get(0).map_err(|e| IngestionError::SchemaMismatch {
496 message: format!("feature_wise mean get: {e}"),
497 })?;
498 let std_av = std_s.get(0).map_err(|e| IngestionError::SchemaMismatch {
499 message: format!("feature_wise std get: {e}"),
500 })?;
501 out.push((
502 columns[i].to_string(),
503 FeatureMeanStd {
504 mean: anyvalue_to_value(mean_av),
505 std_dev: anyvalue_to_value(std_av),
506 },
507 ));
508 }
509 Ok(out)
510 }
511
512 pub(crate) fn lazy_clone(&self) -> LazyFrame {
513 self.lf.clone()
514 }
515
516 pub(crate) fn from_lazyframe(lf: LazyFrame) -> Self {
517 Self { lf }
518 }
519}
520
521fn polars_reduce_expr(column: &str, op: ReduceOp) -> Expr {
522 use polars::datatypes::DataType as P;
523 let c = col(column);
524 match op {
525 ReduceOp::Count => len(),
526 ReduceOp::Sum => c.sum(),
527 ReduceOp::Min => c.min(),
528 ReduceOp::Max => c.max(),
529 ReduceOp::Mean => c.clone().strict_cast(P::Float64).mean(),
530 ReduceOp::Variance(kind) => {
531 let ddof = match kind {
532 VarianceKind::Population => 0u8,
533 VarianceKind::Sample => 1u8,
534 };
535 c.clone().strict_cast(P::Float64).var(ddof)
536 }
537 ReduceOp::StdDev(kind) => {
538 let ddof = match kind {
539 VarianceKind::Population => 0u8,
540 VarianceKind::Sample => 1u8,
541 };
542 c.clone().strict_cast(P::Float64).std(ddof)
543 }
544 ReduceOp::SumSquares => c.clone().strict_cast(P::Float64).pow(lit(2.0)).sum(),
545 ReduceOp::L2Norm => c.clone().strict_cast(P::Float64).pow(lit(2.0)).sum().sqrt(),
546 ReduceOp::CountDistinctNonNull => c.drop_nulls().n_unique(),
547 }
548}
549
550fn agg_to_expr(agg: &Agg) -> Expr {
551 use polars::datatypes::DataType as P;
552 match agg {
553 Agg::CountRows { alias } => len().alias(alias.as_str()),
554 Agg::CountNotNull { column, alias } => col(column.as_str()).count().alias(alias.as_str()),
555 Agg::Sum { column, alias } => col(column.as_str()).sum().alias(alias.as_str()),
556 Agg::Min { column, alias } => col(column.as_str()).min().alias(alias.as_str()),
557 Agg::Max { column, alias } => col(column.as_str()).max().alias(alias.as_str()),
558 Agg::Mean { column, alias } => col(column.as_str())
559 .strict_cast(P::Float64)
560 .mean()
561 .alias(alias.as_str()),
562 Agg::Variance {
563 column,
564 alias,
565 kind,
566 } => {
567 let ddof = match kind {
568 VarianceKind::Population => 0u8,
569 VarianceKind::Sample => 1u8,
570 };
571 col(column.as_str())
572 .strict_cast(P::Float64)
573 .var(ddof)
574 .alias(alias.as_str())
575 }
576 Agg::StdDev {
577 column,
578 alias,
579 kind,
580 } => {
581 let ddof = match kind {
582 VarianceKind::Population => 0u8,
583 VarianceKind::Sample => 1u8,
584 };
585 col(column.as_str())
586 .strict_cast(P::Float64)
587 .std(ddof)
588 .alias(alias.as_str())
589 }
590 Agg::SumSquares { column, alias } => col(column.as_str())
591 .strict_cast(P::Float64)
592 .pow(lit(2.0))
593 .sum()
594 .alias(alias.as_str()),
595 Agg::L2Norm { column, alias } => col(column.as_str())
596 .strict_cast(P::Float64)
597 .pow(lit(2.0))
598 .sum()
599 .sqrt()
600 .alias(alias.as_str()),
601 Agg::CountDistinctNonNull { column, alias } => col(column.as_str())
602 .drop_nulls()
603 .n_unique()
604 .alias(alias.as_str()),
605 }
606}
607
608fn to_polars_dtype(dt: &DataType) -> polars::datatypes::DataType {
609 match dt {
610 DataType::Int64 => polars::datatypes::DataType::Int64,
611 DataType::Float64 => polars::datatypes::DataType::Float64,
612 DataType::Bool => polars::datatypes::DataType::Boolean,
613 DataType::Utf8 => polars::datatypes::DataType::String,
614 }
615}
616
617fn value_to_lit_expr(value: Value) -> IngestionResult<Expr> {
618 match value {
619 Value::Null => Err(IngestionError::SchemaMismatch {
620 message: "Value::Null is not supported as a literal expression; use fill_null or cast/collect to materialize".to_string(),
621 }),
622 Value::Int64(v) => Ok(lit(v)),
623 Value::Float64(v) => Ok(lit(v)),
624 Value::Bool(v) => Ok(lit(v)),
625 Value::Utf8(v) => Ok(lit(v)),
626 }
627}
628
629fn anyvalue_to_value(av: AnyValue) -> Value {
630 match av {
631 AnyValue::Null => Value::Null,
632 AnyValue::Int8(v) => Value::Int64(v as i64),
633 AnyValue::Int16(v) => Value::Int64(v as i64),
634 AnyValue::Int32(v) => Value::Int64(v as i64),
635 AnyValue::Int64(v) => Value::Int64(v),
636 AnyValue::UInt8(v) => Value::Int64(v as i64),
637 AnyValue::UInt16(v) => Value::Int64(v as i64),
638 AnyValue::UInt32(v) => Value::Int64(v as i64),
639 AnyValue::UInt64(v) => Value::Int64(v as i64),
640 AnyValue::Float64(v) => Value::Float64(v),
641 AnyValue::Boolean(v) => Value::Bool(v),
642 AnyValue::String(v) => Value::Utf8(v.to_string()),
643 AnyValue::StringOwned(v) => Value::Utf8(v.to_string()),
644 other => Value::Utf8(other.to_string()),
645 }
646}
647
/// Backwards-compatible name for [`DataFrame`], kept so existing callers keep compiling.
pub type PolarsPipeline = DataFrame;
650
// Tests comparing the polars-backed pipeline with the in-memory `processing` helpers and
// exercising each DataFrame operation end to end.
#[cfg(test)]
mod tests {
    use super::{Agg, DataFrame, JoinKind, PolarsPipeline, Predicate};
    use crate::processing::{ReduceOp, VarianceKind, feature_wise_mean_std, filter, map, reduce};
    use crate::types::{DataSet, DataType, Field, Schema, Value};

    /// Four rows over `id` (Int64), `active` (Bool), `score` (Float64); the last score is null.
    fn sample_dataset() -> DataSet {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int64),
            Field::new("active", DataType::Bool),
            Field::new("score", DataType::Float64),
        ]);
        let rows = vec![
            vec![Value::Int64(1), Value::Bool(true), Value::Float64(10.0)],
            vec![Value::Int64(2), Value::Bool(true), Value::Float64(20.0)],
            vec![Value::Int64(3), Value::Bool(false), Value::Float64(30.0)],
            vec![Value::Int64(4), Value::Bool(true), Value::Null],
        ];
        DataSet::new(schema, rows)
    }

    #[test]
    fn polars_pipeline_filter_map_reduce_parity_with_in_memory() {
        let ds = sample_dataset();

        // In-memory reference: keep active rows with even ids, double `score`, then sum it.
        let active_idx = ds.schema.index_of("active").unwrap();
        let id_idx = ds.schema.index_of("id").unwrap();
        let filtered = filter(&ds, |row| {
            let is_active = matches!(row.get(active_idx), Some(Value::Bool(true)));
            let even_id = matches!(row.get(id_idx), Some(Value::Int64(v)) if *v % 2 == 0);
            is_active && even_id
        });
        let mapped = map(&filtered, |row| {
            let mut out = row.to_vec();
            // Index 2 is the `score` column in `sample_dataset`.
            if let Some(Value::Float64(v)) = out.get(2) {
                out[2] = Value::Float64(v * 2.0);
            }
            out
        });
        let expected = reduce(&mapped, "score", ReduceOp::Sum).unwrap();

        // The same pipeline expressed through the polars-backed DataFrame.
        let got = DataFrame::from_dataset(&ds)
            .unwrap()
            .filter(Predicate::Eq {
                column: "active".to_string(),
                value: Value::Bool(true),
            })
            .unwrap()
            .filter(Predicate::ModEqInt64 {
                column: "id".to_string(),
                modulus: 2,
                equals: 0,
            })
            .unwrap()
            .multiply_f64("score", 2.0)
            .unwrap()
            .sum("score")
            .unwrap()
            .unwrap();

        assert_eq!(got, expected);
    }

    #[test]
    fn polars_pipeline_reduce_parity_mean_variance_l2_distinct() {
        // `x` contains a null and `tag` a repeated value, so null handling and distinctness
        // are both exercised.
        let schema = Schema::new(vec![
            Field::new("x", DataType::Float64),
            Field::new("tag", DataType::Utf8),
        ]);
        let ds = DataSet::new(
            schema,
            vec![
                vec![Value::Float64(1.0), Value::Utf8("a".to_string())],
                vec![Value::Float64(2.0), Value::Utf8("b".to_string())],
                vec![Value::Null, Value::Utf8("a".to_string())],
            ],
        );

        // In-memory reference values.
        let mean = reduce(&ds, "x", ReduceOp::Mean).unwrap();
        let var_pop = reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Population)).unwrap();
        let l2 = reduce(&ds, "x", ReduceOp::L2Norm).unwrap();
        let dcnt = reduce(&ds, "tag", ReduceOp::CountDistinctNonNull).unwrap();

        assert_eq!(
            DataFrame::from_dataset(&ds)
                .unwrap()
                .reduce("x", ReduceOp::Mean)
                .unwrap()
                .unwrap(),
            mean
        );
        assert_eq!(
            DataFrame::from_dataset(&ds)
                .unwrap()
                .reduce("x", ReduceOp::Variance(VarianceKind::Population))
                .unwrap()
                .unwrap(),
            var_pop
        );
        assert_eq!(
            DataFrame::from_dataset(&ds)
                .unwrap()
                .reduce("x", ReduceOp::L2Norm)
                .unwrap()
                .unwrap(),
            l2
        );
        assert_eq!(
            DataFrame::from_dataset(&ds)
                .unwrap()
                .reduce("tag", ReduceOp::CountDistinctNonNull)
                .unwrap()
                .unwrap(),
            dcnt
        );
    }

    #[test]
    fn polars_pipeline_collect_select_works() {
        let ds = sample_dataset();
        let out = DataFrame::from_dataset(&ds)
            .unwrap()
            .select(&["score", "id"])
            .unwrap()
            .collect()
            .unwrap();

        // Column order follows the `select` call, not the source schema.
        assert_eq!(
            out.schema.field_names().collect::<Vec<_>>(),
            vec!["score", "id"]
        );
        assert_eq!(out.row_count(), ds.row_count());
        assert_eq!(out.rows[0][0], Value::Float64(10.0));
        assert_eq!(out.rows[0][1], Value::Int64(1));
    }

    #[test]
    fn polars_pipeline_sum_returns_none_for_missing_column() {
        let ds = sample_dataset();
        let out = DataFrame::from_dataset(&ds)
            .unwrap()
            .sum("missing")
            .unwrap();
        assert_eq!(out, None);
    }

    #[test]
    fn polars_errors_are_preserved_as_engine_error_sources() {
        let schema = Schema::new(vec![Field::new("name", DataType::Utf8)]);
        let ds = DataSet::new(schema, vec![vec![Value::Utf8("x".to_string())]]);

        // Multiplying a Utf8 column is accepted while building the lazy plan; the type
        // error only surfaces at `collect()`.
        let err = DataFrame::from_dataset(&ds)
            .unwrap()
            .multiply_f64("name", 2.0)
            .unwrap()
            .collect()
            .unwrap_err();

        match err {
            crate::error::IngestionError::Engine { source, .. } => {
                assert!(!source.to_string().is_empty());
            }
            other => panic!("expected Engine error, got: {other:?}"),
        }
    }

    #[test]
    fn backwards_compatible_polars_pipeline_alias_exists() {
        let ds = sample_dataset();
        let _ = PolarsPipeline::from_dataset(&ds)
            .unwrap()
            .select(&["id"])
            .unwrap();
    }

    #[test]
    fn rename_cast_fill_null_group_by_and_join_work() {
        // Part 1: rename -> strict cast Int64 to Float64 -> fill the null.
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int64),
            Field::new("score", DataType::Int64),
        ]);
        let ds = DataSet::new(
            schema,
            vec![
                vec![Value::Int64(1), Value::Int64(10)],
                vec![Value::Int64(2), Value::Null],
            ],
        );

        let out = DataFrame::from_dataset(&ds)
            .unwrap()
            .rename(&[("score", "score_i")])
            .unwrap()
            .cast("score_i", DataType::Float64)
            .unwrap()
            .fill_null("score_i", Value::Float64(0.0))
            .unwrap()
            .collect()
            .unwrap();

        assert_eq!(
            out.schema.field_names().collect::<Vec<_>>(),
            vec!["id", "score_i"]
        );
        assert_eq!(out.rows[0][1], Value::Float64(10.0));
        assert_eq!(out.rows[1][1], Value::Float64(0.0));

        // Part 2: group_by with Sum and CountRows.
        let schema = Schema::new(vec![
            Field::new("grp", DataType::Utf8),
            Field::new("score", DataType::Float64),
        ]);
        let ds = DataSet::new(
            schema,
            vec![
                vec![Value::Utf8("A".to_string()), Value::Float64(1.0)],
                vec![Value::Utf8("A".to_string()), Value::Float64(2.0)],
                vec![Value::Utf8("B".to_string()), Value::Null],
            ],
        );

        let out = DataFrame::from_dataset(&ds)
            .unwrap()
            .group_by(
                &["grp"],
                &[
                    Agg::Sum {
                        column: "score".to_string(),
                        alias: "sum_score".to_string(),
                    },
                    Agg::CountRows {
                        alias: "cnt".to_string(),
                    },
                ],
            )
            .unwrap()
            .collect()
            .unwrap();

        // Group output order is not asserted; index the results by group key instead.
        let mut sums: std::collections::HashMap<String, (Value, Value)> =
            std::collections::HashMap::new();
        for row in &out.rows {
            if let Value::Utf8(g) = &row[0] {
                sums.insert(g.clone(), (row[1].clone(), row[2].clone()));
            }
        }
        assert_eq!(sums.get("A"), Some(&(Value::Float64(3.0), Value::Int64(2))));
        // Group B's only score is null: the sum comes back as 0.0 while CountRows still
        // counts the row.
        assert_eq!(
            sums.get("B"),
            Some(&(Value::Float64(0.0), Value::Int64(1)))
        );

        // Part 3: inner join keeps only the id present on both sides.
        let left = DataSet::new(
            Schema::new(vec![
                Field::new("id", DataType::Int64),
                Field::new("name", DataType::Utf8),
            ]),
            vec![
                vec![Value::Int64(1), Value::Utf8("Ada".to_string())],
                vec![Value::Int64(2), Value::Utf8("Grace".to_string())],
            ],
        );
        let right = DataSet::new(
            Schema::new(vec![
                Field::new("id", DataType::Int64),
                Field::new("score", DataType::Float64),
            ]),
            vec![
                vec![Value::Int64(1), Value::Float64(9.0)],
                vec![Value::Int64(3), Value::Float64(7.0)],
            ],
        );

        let out = DataFrame::from_dataset(&left)
            .unwrap()
            .join(
                DataFrame::from_dataset(&right).unwrap(),
                &["id"],
                &["id"],
                JoinKind::Inner,
            )
            .unwrap()
            .collect()
            .unwrap();
        assert_eq!(out.row_count(), 1);
        assert_eq!(out.rows[0][0], Value::Int64(1));
    }

    #[test]
    fn polars_feature_wise_mean_std_matches_in_memory() {
        // Mixes an Int64 and a Float64 feature to exercise the Float64 cast.
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int64),
            Field::new("b", DataType::Float64),
        ]);
        let ds = DataSet::new(
            schema,
            vec![
                vec![Value::Int64(1), Value::Float64(10.0)],
                vec![Value::Int64(3), Value::Float64(20.0)],
            ],
        );
        let mem = feature_wise_mean_std(&ds, &["a", "b"], VarianceKind::Sample).unwrap();
        let pol = DataFrame::from_dataset(&ds)
            .unwrap()
            .feature_wise_mean_std(&["a", "b"], VarianceKind::Sample)
            .unwrap();
        assert_eq!(mem.len(), pol.len());
        for i in 0..mem.len() {
            assert_eq!(mem[i].0, pol[i].0);
            assert_eq!(mem[i].1.mean, pol[i].1.mean);
            // Standard deviations are compared with a tolerance rather than exactly.
            match (&mem[i].1.std_dev, &pol[i].1.std_dev) {
                (Value::Float64(m), Value::Float64(p)) => assert!((m - p).abs() < 1e-9),
                (a, b) => assert_eq!(a, b),
            }
        }
    }

    #[test]
    fn group_by_mean_std_count_distinct_all_null_numeric_is_null() {
        // A single group whose numeric column is entirely null: mean and sample std must
        // come back as null, while the distinct count of `tag` is unaffected.
        let schema = Schema::new(vec![
            Field::new("g", DataType::Utf8),
            Field::new("x", DataType::Float64),
            Field::new("tag", DataType::Utf8),
        ]);
        let ds = DataSet::new(
            schema,
            vec![
                vec![
                    Value::Utf8("A".to_string()),
                    Value::Null,
                    Value::Utf8("p".to_string()),
                ],
                vec![
                    Value::Utf8("A".to_string()),
                    Value::Null,
                    Value::Utf8("q".to_string()),
                ],
            ],
        );
        let out = DataFrame::from_dataset(&ds)
            .unwrap()
            .group_by(
                &["g"],
                &[
                    Agg::Mean {
                        column: "x".to_string(),
                        alias: "mx".to_string(),
                    },
                    Agg::StdDev {
                        column: "x".to_string(),
                        alias: "sx".to_string(),
                        kind: VarianceKind::Sample,
                    },
                    Agg::CountDistinctNonNull {
                        column: "tag".to_string(),
                        alias: "dt".to_string(),
                    },
                ],
            )
            .unwrap()
            .collect()
            .unwrap();
        assert_eq!(out.row_count(), 1);
        assert_eq!(out.rows[0][0], Value::Utf8("A".to_string()));
        assert_eq!(out.rows[0][1], Value::Null);
        assert_eq!(out.rows[0][2], Value::Null);
        assert_eq!(out.rows[0][3], Value::Int64(2));
    }
}