1use crate::error::IngestionResult;
64use crate::pipeline::{CastMode, DataFrame};
65use crate::types::{DataSet, DataType, Schema, Value};
66use serde::{Deserialize, Serialize};
67
/// A single, serializable column-level operation applied by
/// [`TransformSpec::apply`]. Steps run in the order they appear in a spec.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TransformStep {
    /// Keep only the named columns.
    Select { columns: Vec<String> },
    /// Remove the named columns.
    Drop { columns: Vec<String> },
    /// Rename columns; each pair is `(old_name, new_name)`.
    Rename { pairs: Vec<(String, String)> },
    /// Cast `column` to the target [`DataType`] using the given [`CastMode`].
    Cast {
        column: String,
        to: DataType,
        // `mode` falls back to `CastMode::default()` when absent from the
        // serialized form.
        #[serde(default)]
        mode: CastMode,
    },
    /// Replace null cells in `column` with `value`.
    FillNull { column: String, value: Value },
    /// Add a column `name` holding the literal `value` in every row.
    WithLiteral { name: String, value: Value },
    /// Derive a float column: `name = source * factor`.
    DeriveMulF64 {
        name: String,
        source: String,
        factor: f64,
    },
    /// Derive a float column: `name = source + delta`.
    DeriveAddF64 {
        name: String,
        source: String,
        delta: f64,
    },
}
101
/// A declarative, serializable transformation pipeline: an ordered list of
/// [`TransformStep`]s plus the schema the final output must conform to.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TransformSpec {
    /// Schema the transformed data is collected and validated against.
    pub output_schema: Schema,
    /// Steps applied in order by [`apply`](Self::apply).
    pub steps: Vec<TransformStep>,
}
112
113impl TransformSpec {
114 pub fn new(output_schema: Schema) -> Self {
115 Self {
116 output_schema,
117 steps: Vec::new(),
118 }
119 }
120
121 pub fn with_step(mut self, step: TransformStep) -> Self {
122 self.steps.push(step);
123 self
124 }
125
126 pub fn apply(&self, input: &DataSet) -> IngestionResult<DataSet> {
128 let mut df = DataFrame::from_dataset(input)?;
129
130 for step in &self.steps {
131 df = match step {
132 TransformStep::Select { columns } => {
133 let cols: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
134 df.select(&cols)?
135 }
136 TransformStep::Drop { columns } => {
137 let cols: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
138 df.drop(&cols)?
139 }
140 TransformStep::Rename { pairs } => {
141 let pairs_ref: Vec<(&str, &str)> = pairs
142 .iter()
143 .map(|(a, b)| (a.as_str(), b.as_str()))
144 .collect();
145 df.rename(&pairs_ref)?
146 }
147 TransformStep::Cast { column, to, mode } => {
148 df.cast_with_mode(column, to.clone(), *mode)?
149 }
150 TransformStep::FillNull { column, value } => df.fill_null(column, value.clone())?,
151 TransformStep::WithLiteral { name, value } => {
152 df.with_literal(name, value.clone())?
153 }
154 TransformStep::DeriveMulF64 {
155 name,
156 source,
157 factor,
158 } => df.with_mul_f64(name, source, *factor)?,
159 TransformStep::DeriveAddF64 {
160 name,
161 source,
162 delta,
163 } => df.with_add_f64(name, source, *delta)?,
164 };
165 }
166
167 df.collect_with_schema(&self.output_schema)
168 }
169}
170
171#[cfg(feature = "arrow")]
173pub mod arrow {
174 use std::sync::Arc;
175
176 use arrow::array::{Array, ArrayRef, BooleanArray, Float64Array, Int64Array, StringArray};
177 use arrow::datatypes::{DataType as ArrowDataType, Field, Schema as ArrowSchema};
178 use arrow::record_batch::RecordBatch;
179
180 use crate::error::{IngestionError, IngestionResult};
181 use crate::types::{DataSet, DataType, Field as DsField, Schema, Value};
182
183 pub fn schema_from_record_batch(batch: &RecordBatch) -> IngestionResult<Schema> {
184 let mut fields = Vec::with_capacity(batch.schema().fields().len());
185 for f in batch.schema().fields() {
186 let dt = match f.data_type() {
187 ArrowDataType::Int64 => DataType::Int64,
188 ArrowDataType::Float64 => DataType::Float64,
189 ArrowDataType::Boolean => DataType::Bool,
190 ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => DataType::Utf8,
191 other => {
192 return Err(IngestionError::SchemaMismatch {
193 message: format!("unsupported Arrow dtype in schema: {other:?}"),
194 });
195 }
196 };
197 fields.push(DsField::new(f.name().to_string(), dt));
198 }
199 Ok(Schema::new(fields))
200 }
201
202 pub fn dataset_to_record_batch(ds: &DataSet) -> IngestionResult<RecordBatch> {
203 let mut arrow_fields = Vec::with_capacity(ds.schema.fields.len());
204 let mut cols: Vec<ArrayRef> = Vec::with_capacity(ds.schema.fields.len());
205
206 for (col_idx, field) in ds.schema.fields.iter().enumerate() {
207 match field.data_type {
208 DataType::Int64 => {
209 let mut v = Vec::with_capacity(ds.row_count());
210 for row in &ds.rows {
211 match row.get(col_idx) {
212 Some(Value::Null) | None => v.push(None),
213 Some(Value::Int64(x)) => v.push(Some(*x)),
214 Some(other) => {
215 return Err(IngestionError::ParseError {
216 row: 1,
217 column: field.name.clone(),
218 raw: format!("{other:?}"),
219 message: "value does not match schema type Int64".to_string(),
220 });
221 }
222 }
223 }
224 cols.push(Arc::new(Int64Array::from(v)) as ArrayRef);
225 arrow_fields.push(Field::new(&field.name, ArrowDataType::Int64, true));
226 }
227 DataType::Float64 => {
228 let mut v = Vec::with_capacity(ds.row_count());
229 for row in &ds.rows {
230 match row.get(col_idx) {
231 Some(Value::Null) | None => v.push(None),
232 Some(Value::Float64(x)) => v.push(Some(*x)),
233 Some(other) => {
234 return Err(IngestionError::ParseError {
235 row: 1,
236 column: field.name.clone(),
237 raw: format!("{other:?}"),
238 message: "value does not match schema type Float64".to_string(),
239 });
240 }
241 }
242 }
243 cols.push(Arc::new(Float64Array::from(v)) as ArrayRef);
244 arrow_fields.push(Field::new(&field.name, ArrowDataType::Float64, true));
245 }
246 DataType::Bool => {
247 let mut v = Vec::with_capacity(ds.row_count());
248 for row in &ds.rows {
249 match row.get(col_idx) {
250 Some(Value::Null) | None => v.push(None),
251 Some(Value::Bool(x)) => v.push(Some(*x)),
252 Some(other) => {
253 return Err(IngestionError::ParseError {
254 row: 1,
255 column: field.name.clone(),
256 raw: format!("{other:?}"),
257 message: "value does not match schema type Bool".to_string(),
258 });
259 }
260 }
261 }
262 cols.push(Arc::new(BooleanArray::from(v)) as ArrayRef);
263 arrow_fields.push(Field::new(&field.name, ArrowDataType::Boolean, true));
264 }
265 DataType::Utf8 => {
266 let mut v = Vec::with_capacity(ds.row_count());
267 for row in &ds.rows {
268 match row.get(col_idx) {
269 Some(Value::Null) | None => v.push(None),
270 Some(Value::Utf8(x)) => v.push(Some(x.as_str())),
271 Some(other) => {
272 return Err(IngestionError::ParseError {
273 row: 1,
274 column: field.name.clone(),
275 raw: format!("{other:?}"),
276 message: "value does not match schema type Utf8".to_string(),
277 });
278 }
279 }
280 }
281 cols.push(Arc::new(StringArray::from(v)) as ArrayRef);
282 arrow_fields.push(Field::new(&field.name, ArrowDataType::Utf8, true));
283 }
284 }
285 }
286
287 let schema = Arc::new(ArrowSchema::new(arrow_fields));
288 RecordBatch::try_new(schema, cols).map_err(|e| IngestionError::Engine {
289 message: "failed to build Arrow RecordBatch".to_string(),
290 source: Box::new(e),
291 })
292 }
293
294 pub fn record_batch_to_dataset(
295 batch: &RecordBatch,
296 schema: &Schema,
297 ) -> IngestionResult<DataSet> {
298 let mut col_idx = Vec::with_capacity(schema.fields.len());
300 for f in &schema.fields {
301 let idx =
302 batch
303 .schema()
304 .index_of(&f.name)
305 .map_err(|_| IngestionError::SchemaMismatch {
306 message: format!("missing required column '{}'", f.name),
307 })?;
308 col_idx.push(idx);
309 }
310
311 let nrows = batch.num_rows();
312 let mut out_rows = Vec::with_capacity(nrows);
313 for row_i in 0..nrows {
314 let mut row = Vec::with_capacity(schema.fields.len());
315 for (field, idx) in schema.fields.iter().zip(col_idx.iter().copied()) {
316 let arr = batch.column(idx);
317 let v = match field.data_type {
318 DataType::Int64 => {
319 let a = arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
320 IngestionError::SchemaMismatch {
321 message: format!("arrow column '{}' is not Int64", field.name),
322 }
323 })?;
324 if a.is_null(row_i) {
325 Value::Null
326 } else {
327 Value::Int64(a.value(row_i))
328 }
329 }
330 DataType::Float64 => {
331 let a = arr.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
332 IngestionError::SchemaMismatch {
333 message: format!("arrow column '{}' is not Float64", field.name),
334 }
335 })?;
336 if a.is_null(row_i) {
337 Value::Null
338 } else {
339 Value::Float64(a.value(row_i))
340 }
341 }
342 DataType::Bool => {
343 let a = arr.as_any().downcast_ref::<BooleanArray>().ok_or_else(|| {
344 IngestionError::SchemaMismatch {
345 message: format!("arrow column '{}' is not Boolean", field.name),
346 }
347 })?;
348 if a.is_null(row_i) {
349 Value::Null
350 } else {
351 Value::Bool(a.value(row_i))
352 }
353 }
354 DataType::Utf8 => {
355 if let Some(a) = arr.as_any().downcast_ref::<StringArray>() {
357 if a.is_null(row_i) {
358 Value::Null
359 } else {
360 Value::Utf8(a.value(row_i).to_string())
361 }
362 } else {
363 return Err(IngestionError::SchemaMismatch {
364 message: format!("arrow column '{}' is not Utf8", field.name),
365 });
366 }
367 }
368 };
369 row.push(v);
370 }
371 out_rows.push(row);
372 }
373 Ok(DataSet::new(schema.clone(), out_rows))
374 }
375}
376
377#[cfg(feature = "serde_arrow")]
381pub mod serde_interop {
382 use arrow::datatypes::FieldRef;
383 use arrow::record_batch::RecordBatch;
384 use serde_arrow::schema::{SchemaLike, TracingOptions};
385
386 use crate::error::{IngestionError, IngestionResult};
387
388 pub fn to_record_batch<T>(records: &Vec<T>) -> IngestionResult<RecordBatch>
390 where
391 T: serde::Serialize + for<'de> serde::Deserialize<'de>,
392 {
393 let fields = Vec::<FieldRef>::from_type::<T>(TracingOptions::default()).map_err(|e| {
394 IngestionError::Engine {
395 message: "failed to trace Arrow schema from type".to_string(),
396 source: Box::new(e),
397 }
398 })?;
399
400 serde_arrow::to_record_batch(&fields, records).map_err(|e| IngestionError::Engine {
401 message: "failed to convert records to Arrow RecordBatch".to_string(),
402 source: Box::new(e),
403 })
404 }
405
406 pub fn from_record_batch<T>(batch: &RecordBatch) -> IngestionResult<Vec<T>>
408 where
409 T: serde::de::DeserializeOwned,
410 {
411 serde_arrow::from_record_batch(batch).map_err(|e| IngestionError::Engine {
412 message: "failed to deserialize records from Arrow RecordBatch".to_string(),
413 source: Box::new(e),
414 })
415 }
416}
417
#[cfg(test)]
mod tests {
    use super::{TransformSpec, TransformStep};
    use crate::pipeline::CastMode;
    use crate::types::{DataSet, DataType, Field, Schema, Value};

    /// Two rows over Int64 columns `id` and `score`: (1, 10) and (2, NULL).
    fn sample_dataset() -> DataSet {
        DataSet::new(
            Schema::new(vec![
                Field::new("id", DataType::Int64),
                Field::new("score", DataType::Int64),
            ]),
            vec![
                vec![Value::Int64(1), Value::Int64(10)],
                vec![Value::Int64(2), Value::Null],
            ],
        )
    }

    #[test]
    fn transform_spec_can_rename_cast_fill_and_derive() {
        let input = sample_dataset();

        let expected_schema = Schema::new(vec![
            Field::new("id", DataType::Int64),
            Field::new("score_x2", DataType::Float64),
            Field::new("score_f", DataType::Float64),
            Field::new("tag", DataType::Utf8),
        ]);

        // Rename -> cast -> fill nulls -> derive -> tag -> reorder via select.
        let steps = vec![
            TransformStep::Rename {
                pairs: vec![("score".into(), "score_f".into())],
            },
            TransformStep::Cast {
                column: "score_f".into(),
                to: DataType::Float64,
                mode: CastMode::Strict,
            },
            TransformStep::FillNull {
                column: "score_f".into(),
                value: Value::Float64(0.0),
            },
            TransformStep::DeriveMulF64 {
                name: "score_x2".into(),
                source: "score_f".into(),
                factor: 2.0,
            },
            TransformStep::WithLiteral {
                name: "tag".into(),
                value: Value::Utf8("A".into()),
            },
            TransformStep::Select {
                columns: vec![
                    "id".into(),
                    "score_x2".into(),
                    "score_f".into(),
                    "tag".into(),
                ],
            },
        ];

        let spec = steps
            .into_iter()
            .fold(TransformSpec::new(expected_schema.clone()), TransformSpec::with_step);

        let result = spec.apply(&input).unwrap();
        assert_eq!(result.schema, expected_schema);
        assert_eq!(result.row_count(), 2);

        let expected_rows = [
            [
                Value::Int64(1),
                Value::Float64(20.0),
                Value::Float64(10.0),
                Value::Utf8("A".into()),
            ],
            [
                Value::Int64(2),
                Value::Float64(0.0),
                Value::Float64(0.0),
                Value::Utf8("A".into()),
            ],
        ];
        for (r, expected_row) in expected_rows.iter().enumerate() {
            for (c, expected_cell) in expected_row.iter().enumerate() {
                assert_eq!(&result.rows[r][c], expected_cell, "row {r}, col {c}");
            }
        }
    }
}
491}