rust_data_processing/
transform.rs

1//! Transformation specifications and helpers.
2//!
3//! This module defines **engine-agnostic** transformation specs in crate-owned types that can be
4//! applied to an in-memory [`crate::types::DataSet`].
5//!
6//! Phase 1 intent:
7//! - Keep public API free of Polars types
8//! - Implement by compiling to the Polars-backed [`crate::pipeline::DataFrame`] where possible
9//! - Reserve room for additional backends later
10//!
11//! ## Example
12//!
13//! ```rust
14//! use rust_data_processing::pipeline::CastMode;
15//! use rust_data_processing::transform::{TransformSpec, TransformStep};
16//! use rust_data_processing::types::{DataSet, DataType, Field, Schema, Value};
17//!
18//! # fn main() -> Result<(), rust_data_processing::IngestionError> {
19//! let ds = DataSet::new(
20//!     Schema::new(vec![
21//!         Field::new("id", DataType::Int64),
22//!         Field::new("score", DataType::Int64),
23//!         Field::new("weather", DataType::Utf8),
24//!     ]),
25//!     vec![
26//!         vec![Value::Int64(1), Value::Int64(10), Value::Utf8("drizzle".to_string())],
27//!         vec![Value::Int64(2), Value::Null, Value::Utf8("rain".to_string())],
28//!     ],
29//! );
30//!
31//! let out_schema = Schema::new(vec![
32//!     Field::new("id", DataType::Int64),
33//!     Field::new("score_f", DataType::Float64),
34//!     Field::new("wx", DataType::Utf8),
35//! ]);
36//!
37//! let spec = TransformSpec::new(out_schema.clone())
38//!     .with_step(TransformStep::Rename {
39//!         pairs: vec![("weather".to_string(), "wx".to_string())],
40//!     })
41//!     .with_step(TransformStep::Rename {
42//!         pairs: vec![("score".to_string(), "score_f".to_string())],
43//!     })
44//!     .with_step(TransformStep::Cast {
45//!         column: "score_f".to_string(),
46//!         to: DataType::Float64,
47//!         mode: CastMode::Lossy,
48//!     })
49//!     .with_step(TransformStep::FillNull {
50//!         column: "score_f".to_string(),
51//!         value: Value::Float64(0.0),
52//!     })
53//!     .with_step(TransformStep::Select {
54//!         columns: vec!["id".to_string(), "score_f".to_string(), "wx".to_string()],
55//!     });
56//!
57//! let out = spec.apply(&ds)?;
58//! assert_eq!(out.schema, out_schema);
59//! # Ok(())
60//! # }
61//! ```
62
63use crate::error::IngestionResult;
64use crate::pipeline::{CastMode, DataFrame};
65use crate::types::{DataSet, DataType, Schema, Value};
66use serde::{Deserialize, Serialize};
67
/// A transformation step in a [`TransformSpec`].
///
/// Steps are applied in order by [`TransformSpec::apply`]; each step maps the
/// intermediate [`crate::pipeline::DataFrame`] to a new one.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TransformStep {
    /// Select/reorder columns (in the provided order).
    Select { columns: Vec<String> },
    /// Drop columns.
    Drop { columns: Vec<String> },
    /// Rename columns (strict: source columns must exist).
    ///
    /// Each pair is `(existing_name, new_name)`.
    Rename { pairs: Vec<(String, String)> },
    /// Cast a column to a target type.
    Cast {
        /// Name of the column to cast.
        column: String,
        /// Target type for the cast.
        to: DataType,
        /// Cast strictness; filled from `CastMode::default()` when absent in serialized form.
        #[serde(default)]
        mode: CastMode,
    },
    /// Fill nulls in a column with a literal.
    FillNull { column: String, value: Value },
    /// Add a derived column with a literal value.
    WithLiteral { name: String, value: Value },
    /// Add a derived Float64 column: `name = source * factor` (nulls propagate).
    DeriveMulF64 {
        name: String,
        source: String,
        factor: f64,
    },
    /// Add a derived Float64 column: `name = source + delta` (nulls propagate).
    DeriveAddF64 {
        name: String,
        source: String,
        delta: f64,
    },
}
101
/// A user-provided transformation specification with an explicit output schema.
///
/// The output schema is used to:
/// - enforce required output columns exist
/// - enforce output types (via casting) when collecting back into a [`DataSet`]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TransformSpec {
    /// Schema the transformed data must conform to (enforced when collecting in
    /// [`TransformSpec::apply`]).
    pub output_schema: Schema,
    /// Steps executed in order against the input dataset.
    pub steps: Vec<TransformStep>,
}
112
113impl TransformSpec {
114    pub fn new(output_schema: Schema) -> Self {
115        Self {
116            output_schema,
117            steps: Vec::new(),
118        }
119    }
120
121    pub fn with_step(mut self, step: TransformStep) -> Self {
122        self.steps.push(step);
123        self
124    }
125
126    /// Apply this spec to an input dataset.
127    pub fn apply(&self, input: &DataSet) -> IngestionResult<DataSet> {
128        let mut df = DataFrame::from_dataset(input)?;
129
130        for step in &self.steps {
131            df = match step {
132                TransformStep::Select { columns } => {
133                    let cols: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
134                    df.select(&cols)?
135                }
136                TransformStep::Drop { columns } => {
137                    let cols: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
138                    df.drop(&cols)?
139                }
140                TransformStep::Rename { pairs } => {
141                    let pairs_ref: Vec<(&str, &str)> = pairs
142                        .iter()
143                        .map(|(a, b)| (a.as_str(), b.as_str()))
144                        .collect();
145                    df.rename(&pairs_ref)?
146                }
147                TransformStep::Cast { column, to, mode } => {
148                    df.cast_with_mode(column, to.clone(), *mode)?
149                }
150                TransformStep::FillNull { column, value } => df.fill_null(column, value.clone())?,
151                TransformStep::WithLiteral { name, value } => {
152                    df.with_literal(name, value.clone())?
153                }
154                TransformStep::DeriveMulF64 {
155                    name,
156                    source,
157                    factor,
158                } => df.with_mul_f64(name, source, *factor)?,
159                TransformStep::DeriveAddF64 {
160                    name,
161                    source,
162                    delta,
163                } => df.with_add_f64(name, source, *delta)?,
164            };
165        }
166
167        df.collect_with_schema(&self.output_schema)
168    }
169}
170
/// Arrow interop helpers (feature-gated).
#[cfg(feature = "arrow")]
pub mod arrow {
    use std::sync::Arc;

    use arrow::array::{
        Array, ArrayRef, BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray,
    };
    use arrow::datatypes::{DataType as ArrowDataType, Field, Schema as ArrowSchema};
    use arrow::record_batch::RecordBatch;

    use crate::error::{IngestionError, IngestionResult};
    use crate::types::{DataSet, DataType, Field as DsField, Schema, Value};

    /// Translate a `RecordBatch`'s Arrow schema into a crate [`Schema`].
    ///
    /// # Errors
    ///
    /// Returns [`IngestionError::SchemaMismatch`] for Arrow dtypes other than
    /// Int64, Float64, Boolean, Utf8, or LargeUtf8.
    pub fn schema_from_record_batch(batch: &RecordBatch) -> IngestionResult<Schema> {
        let mut fields = Vec::with_capacity(batch.schema().fields().len());
        for f in batch.schema().fields() {
            let dt = match f.data_type() {
                ArrowDataType::Int64 => DataType::Int64,
                ArrowDataType::Float64 => DataType::Float64,
                ArrowDataType::Boolean => DataType::Bool,
                ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => DataType::Utf8,
                other => {
                    return Err(IngestionError::SchemaMismatch {
                        message: format!("unsupported Arrow dtype in schema: {other:?}"),
                    });
                }
            };
            fields.push(DsField::new(f.name().to_string(), dt));
        }
        Ok(Schema::new(fields))
    }

    /// Build an error for a cell whose value does not match its column's declared type.
    ///
    /// `row_i` is the 0-based row index; it is reported 1-based for human-facing
    /// diagnostics.
    fn type_mismatch(row_i: usize, column: &str, value: &Value, expected: &str) -> IngestionError {
        IngestionError::ParseError {
            row: row_i + 1,
            column: column.to_string(),
            raw: format!("{value:?}"),
            message: format!("value does not match schema type {expected}"),
        }
    }

    /// Convert a [`DataSet`] into an Arrow `RecordBatch`.
    ///
    /// Every Arrow field is created nullable because [`Value::Null`] may appear in
    /// any cell. Short rows are treated as null-padded (missing trailing cells
    /// become nulls).
    ///
    /// # Errors
    ///
    /// Returns [`IngestionError::ParseError`] (carrying the offending 1-based row
    /// index) when a cell's value does not match its column's schema type, and
    /// [`IngestionError::Engine`] if the batch itself cannot be assembled.
    pub fn dataset_to_record_batch(ds: &DataSet) -> IngestionResult<RecordBatch> {
        let mut arrow_fields = Vec::with_capacity(ds.schema.fields.len());
        let mut cols: Vec<ArrayRef> = Vec::with_capacity(ds.schema.fields.len());

        for (col_idx, field) in ds.schema.fields.iter().enumerate() {
            match field.data_type {
                DataType::Int64 => {
                    let mut v = Vec::with_capacity(ds.row_count());
                    for (row_i, row) in ds.rows.iter().enumerate() {
                        match row.get(col_idx) {
                            Some(Value::Null) | None => v.push(None),
                            Some(Value::Int64(x)) => v.push(Some(*x)),
                            Some(other) => {
                                return Err(type_mismatch(row_i, &field.name, other, "Int64"));
                            }
                        }
                    }
                    cols.push(Arc::new(Int64Array::from(v)) as ArrayRef);
                    arrow_fields.push(Field::new(&field.name, ArrowDataType::Int64, true));
                }
                DataType::Float64 => {
                    let mut v = Vec::with_capacity(ds.row_count());
                    for (row_i, row) in ds.rows.iter().enumerate() {
                        match row.get(col_idx) {
                            Some(Value::Null) | None => v.push(None),
                            Some(Value::Float64(x)) => v.push(Some(*x)),
                            Some(other) => {
                                return Err(type_mismatch(row_i, &field.name, other, "Float64"));
                            }
                        }
                    }
                    cols.push(Arc::new(Float64Array::from(v)) as ArrayRef);
                    arrow_fields.push(Field::new(&field.name, ArrowDataType::Float64, true));
                }
                DataType::Bool => {
                    let mut v = Vec::with_capacity(ds.row_count());
                    for (row_i, row) in ds.rows.iter().enumerate() {
                        match row.get(col_idx) {
                            Some(Value::Null) | None => v.push(None),
                            Some(Value::Bool(x)) => v.push(Some(*x)),
                            Some(other) => {
                                return Err(type_mismatch(row_i, &field.name, other, "Bool"));
                            }
                        }
                    }
                    cols.push(Arc::new(BooleanArray::from(v)) as ArrayRef);
                    arrow_fields.push(Field::new(&field.name, ArrowDataType::Boolean, true));
                }
                DataType::Utf8 => {
                    let mut v = Vec::with_capacity(ds.row_count());
                    for (row_i, row) in ds.rows.iter().enumerate() {
                        match row.get(col_idx) {
                            Some(Value::Null) | None => v.push(None),
                            Some(Value::Utf8(x)) => v.push(Some(x.as_str())),
                            Some(other) => {
                                return Err(type_mismatch(row_i, &field.name, other, "Utf8"));
                            }
                        }
                    }
                    cols.push(Arc::new(StringArray::from(v)) as ArrayRef);
                    arrow_fields.push(Field::new(&field.name, ArrowDataType::Utf8, true));
                }
            }
        }

        let schema = Arc::new(ArrowSchema::new(arrow_fields));
        RecordBatch::try_new(schema, cols).map_err(|e| IngestionError::Engine {
            message: "failed to build Arrow RecordBatch".to_string(),
            source: Box::new(e),
        })
    }

    /// Materialize `schema`'s columns from `batch` into a row-oriented [`DataSet`].
    ///
    /// Columns are matched by name (so `batch` may contain extra columns, in any
    /// order). For `Utf8` fields both `StringArray` and `LargeStringArray` columns
    /// are accepted, mirroring [`schema_from_record_batch`].
    ///
    /// # Errors
    ///
    /// Returns [`IngestionError::SchemaMismatch`] when a required column is missing
    /// or an Arrow column's array type disagrees with the requested field type.
    pub fn record_batch_to_dataset(
        batch: &RecordBatch,
        schema: &Schema,
    ) -> IngestionResult<DataSet> {
        // Map schema fields to column indices by name.
        let mut col_idx = Vec::with_capacity(schema.fields.len());
        for f in &schema.fields {
            let idx =
                batch
                    .schema()
                    .index_of(&f.name)
                    .map_err(|_| IngestionError::SchemaMismatch {
                        message: format!("missing required column '{}'", f.name),
                    })?;
            col_idx.push(idx);
        }

        let nrows = batch.num_rows();
        let mut out_rows = Vec::with_capacity(nrows);
        for row_i in 0..nrows {
            let mut row = Vec::with_capacity(schema.fields.len());
            for (field, idx) in schema.fields.iter().zip(col_idx.iter().copied()) {
                let arr = batch.column(idx);
                let v = match field.data_type {
                    DataType::Int64 => {
                        let a = arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
                            IngestionError::SchemaMismatch {
                                message: format!("arrow column '{}' is not Int64", field.name),
                            }
                        })?;
                        if a.is_null(row_i) {
                            Value::Null
                        } else {
                            Value::Int64(a.value(row_i))
                        }
                    }
                    DataType::Float64 => {
                        let a = arr.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
                            IngestionError::SchemaMismatch {
                                message: format!("arrow column '{}' is not Float64", field.name),
                            }
                        })?;
                        if a.is_null(row_i) {
                            Value::Null
                        } else {
                            Value::Float64(a.value(row_i))
                        }
                    }
                    DataType::Bool => {
                        let a = arr.as_any().downcast_ref::<BooleanArray>().ok_or_else(|| {
                            IngestionError::SchemaMismatch {
                                message: format!("arrow column '{}' is not Boolean", field.name),
                            }
                        })?;
                        if a.is_null(row_i) {
                            Value::Null
                        } else {
                            Value::Bool(a.value(row_i))
                        }
                    }
                    DataType::Utf8 => {
                        // Accept both Utf8 and LargeUtf8 arrays.
                        if let Some(a) = arr.as_any().downcast_ref::<StringArray>() {
                            if a.is_null(row_i) {
                                Value::Null
                            } else {
                                Value::Utf8(a.value(row_i).to_string())
                            }
                        } else if let Some(a) = arr.as_any().downcast_ref::<LargeStringArray>() {
                            if a.is_null(row_i) {
                                Value::Null
                            } else {
                                Value::Utf8(a.value(row_i).to_string())
                            }
                        } else {
                            return Err(IngestionError::SchemaMismatch {
                                message: format!("arrow column '{}' is not Utf8", field.name),
                            });
                        }
                    }
                };
                row.push(v);
            }
            out_rows.push(row);
        }
        Ok(DataSet::new(schema.clone(), out_rows))
    }
}
376
/// Serde-based interop helpers (feature-gated).
///
/// This uses `serde_arrow` to reduce boilerplate when turning a Rust record type into columnar data.
#[cfg(feature = "serde_arrow")]
pub mod serde_interop {
    use arrow::datatypes::FieldRef;
    use arrow::record_batch::RecordBatch;
    use serde_arrow::schema::{SchemaLike, TracingOptions};

    use crate::error::{IngestionError, IngestionResult};

    /// Build a `RecordBatch` from Rust records using schema tracing.
    ///
    /// Accepts any slice; existing `&Vec<T>` callers keep working via deref
    /// coercion. The `Deserialize` bound is required by schema tracing
    /// (`from_type`).
    ///
    /// # Errors
    ///
    /// Returns [`IngestionError::Engine`] if schema tracing or the conversion to
    /// an Arrow `RecordBatch` fails.
    pub fn to_record_batch<T>(records: &[T]) -> IngestionResult<RecordBatch>
    where
        T: serde::Serialize + for<'de> serde::Deserialize<'de>,
    {
        let fields = Vec::<FieldRef>::from_type::<T>(TracingOptions::default()).map_err(|e| {
            IngestionError::Engine {
                message: "failed to trace Arrow schema from type".to_string(),
                source: Box::new(e),
            }
        })?;

        serde_arrow::to_record_batch(&fields, &records).map_err(|e| IngestionError::Engine {
            message: "failed to convert records to Arrow RecordBatch".to_string(),
            source: Box::new(e),
        })
    }

    /// Deserialize Rust records from a `RecordBatch`.
    ///
    /// # Errors
    ///
    /// Returns [`IngestionError::Engine`] if any column cannot be deserialized
    /// into `T`.
    pub fn from_record_batch<T>(batch: &RecordBatch) -> IngestionResult<Vec<T>>
    where
        T: serde::de::DeserializeOwned,
    {
        serde_arrow::from_record_batch(batch).map_err(|e| IngestionError::Engine {
            message: "failed to deserialize records from Arrow RecordBatch".to_string(),
            source: Box::new(e),
        })
    }
}
417
#[cfg(test)]
mod tests {
    use super::{TransformSpec, TransformStep};
    use crate::pipeline::CastMode;
    use crate::types::{DataSet, DataType, Field, Schema, Value};

    /// Two-row dataset: one fully populated row and one with a null score.
    fn sample_dataset() -> DataSet {
        DataSet::new(
            Schema::new(vec![
                Field::new("id", DataType::Int64),
                Field::new("score", DataType::Int64),
            ]),
            vec![
                vec![Value::Int64(1), Value::Int64(10)],
                vec![Value::Int64(2), Value::Null],
            ],
        )
    }

    #[test]
    fn transform_spec_can_rename_cast_fill_and_derive() {
        let input = sample_dataset();

        let expected_schema = Schema::new(vec![
            Field::new("id", DataType::Int64),
            Field::new("score_x2", DataType::Float64),
            Field::new("score_f", DataType::Float64),
            Field::new("tag", DataType::Utf8),
        ]);

        // Build the spec from a plain list of steps.
        let steps = vec![
            TransformStep::Rename {
                pairs: vec![("score".to_string(), "score_f".to_string())],
            },
            TransformStep::Cast {
                column: "score_f".to_string(),
                to: DataType::Float64,
                mode: CastMode::Strict,
            },
            TransformStep::FillNull {
                column: "score_f".to_string(),
                value: Value::Float64(0.0),
            },
            TransformStep::DeriveMulF64 {
                name: "score_x2".to_string(),
                source: "score_f".to_string(),
                factor: 2.0,
            },
            TransformStep::WithLiteral {
                name: "tag".to_string(),
                value: Value::Utf8("A".to_string()),
            },
            TransformStep::Select {
                columns: vec![
                    "id".to_string(),
                    "score_x2".to_string(),
                    "score_f".to_string(),
                    "tag".to_string(),
                ],
            },
        ];
        let spec = steps
            .into_iter()
            .fold(TransformSpec::new(expected_schema.clone()), |s, step| {
                s.with_step(step)
            });

        let out = spec.apply(&input).unwrap();
        assert_eq!(out.schema, expected_schema);
        assert_eq!(out.row_count(), 2);

        // Compare whole rows at once: null score is filled to 0.0 before doubling.
        assert_eq!(
            out.rows,
            vec![
                vec![
                    Value::Int64(1),
                    Value::Float64(20.0),
                    Value::Float64(10.0),
                    Value::Utf8("A".to_string()),
                ],
                vec![
                    Value::Int64(2),
                    Value::Float64(0.0),
                    Value::Float64(0.0),
                    Value::Utf8("A".to_string()),
                ],
            ]
        );
    }
}
491}