dt-cli-tools

CLI tools for viewing, filtering, and comparing tabular data files
Log | Files | Refs | README | LICENSE

commit ffcccf83cde6fa0e636583d99574c03047392193
parent 37a609215ba119cf6ecc566db0fca113512cfc95
Author: Erik Loualiche <eloualic@umn.edu>
Date:   Mon, 30 Mar 2026 23:41:40 -0500

feat: port filter and diff modules from xl-cli-tools

filter.rs: ported with letter-based column resolution removed (name-only)
diff.rs: ported verbatim, no logic changes

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Diffstat:
Asrc/diff.rs | 830+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Asrc/filter.rs | 671+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 1501 insertions(+), 0 deletions(-)

diff --git a/src/diff.rs b/src/diff.rs @@ -0,0 +1,830 @@ +// Diff engine for comparing two Excel sheets. + +use anyhow::{Result, bail}; +use polars::prelude::*; +use std::collections::HashMap; + +use crate::formatter; + +/// Source file and sheet metadata for display. +#[derive(Debug, Clone)] +pub struct SheetSource { + pub file_name: String, + pub sheet_name: String, +} + +/// A single row from an added or removed set. +#[derive(Debug, Clone)] +pub struct DiffRow { + pub values: Vec<String>, +} + +/// A change in a single cell. +#[derive(Debug, Clone)] +pub struct CellChange { + pub column: String, + pub old_value: String, + pub new_value: String, +} + +/// A row present in both files with cell-level differences. +#[derive(Debug, Clone)] +pub struct ModifiedRow { + pub key: Vec<String>, + pub changes: Vec<CellChange>, +} + +/// Result of comparing two sheets. +#[derive(Debug, Clone)] +pub struct DiffResult { + pub headers: Vec<String>, + pub key_columns: Vec<String>, + pub added: Vec<DiffRow>, + pub removed: Vec<DiffRow>, + pub modified: Vec<ModifiedRow>, + pub source_a: SheetSource, + pub source_b: SheetSource, +} + +impl DiffResult { + pub fn has_differences(&self) -> bool { + !self.added.is_empty() || !self.removed.is_empty() || !self.modified.is_empty() + } +} + +/// Options controlling how the diff is performed. +#[derive(Debug, Clone, Default)] +pub struct DiffOptions { + pub key_columns: Vec<String>, + pub tolerance: Option<f64>, +} + +// --------------------------------------------------------------------------- +// Helper functions +// --------------------------------------------------------------------------- + +/// Format a cell value for display. Returns empty string for null. +fn cell_to_string(col: &Column, idx: usize) -> String { + match col.get(idx) { + Ok(AnyValue::Null) | Err(_) => String::new(), + Ok(v) => formatter::format_any_value(&v), + } +} + +/// Format a cell value for hashing. 
Uses a sentinel for null so that null +/// and empty string produce different keys. +fn cell_to_key_part(col: &Column, idx: usize) -> String { + match col.get(idx) { + Ok(AnyValue::Null) | Err(_) => "\x01NULL\x01".to_string(), + Ok(v) => formatter::format_any_value(&v), + } +} + +/// Build a string key for an entire row by joining all column values. +fn row_to_key(df: &DataFrame, row_idx: usize) -> String { + df.get_columns() + .iter() + .map(|col| cell_to_key_part(col, row_idx)) + .collect::<Vec<_>>() + .join("\0") +} + +/// Collect display values for every column in a row. +fn row_to_strings(df: &DataFrame, row_idx: usize) -> Vec<String> { + df.get_columns() + .iter() + .map(|col| cell_to_string(col, row_idx)) + .collect() +} + +// --------------------------------------------------------------------------- +// Positional diff +// --------------------------------------------------------------------------- + +/// Compare two DataFrames positionally (no key columns). +/// +/// Uses multiset comparison: each unique row is tracked by frequency. +/// Rows present in A but not (or fewer times) in B are "removed"; +/// rows present in B but not (or fewer times) in A are "added". +pub fn diff_positional( + df_a: &DataFrame, + df_b: &DataFrame, + _opts: &DiffOptions, + source_a: SheetSource, + source_b: SheetSource, +) -> Result<DiffResult> { + // Determine headers — use the longer header set. 
+ let headers_a: Vec<String> = df_a.get_column_names().iter().map(|s| s.to_string()).collect(); + let headers_b: Vec<String> = df_b.get_column_names().iter().map(|s| s.to_string()).collect(); + + let headers = if headers_b.len() > headers_a.len() { + if headers_a.len() != headers_b.len() { + eprintln!( + "Warning: column count differs ({} vs {}), using wider header set", + headers_a.len(), + headers_b.len() + ); + } + headers_b.clone() + } else { + if headers_a.len() != headers_b.len() { + eprintln!( + "Warning: column count differs ({} vs {}), using wider header set", + headers_a.len(), + headers_b.len() + ); + } + headers_a.clone() + }; + + let num_headers = headers.len(); + + // Build frequency maps: key → list of row indices (so we can consume them). + let mut freq_a: HashMap<String, Vec<usize>> = HashMap::new(); + for i in 0..df_a.height() { + let key = row_to_key(df_a, i); + freq_a.entry(key).or_default().push(i); + } + + let mut freq_b: HashMap<String, Vec<usize>> = HashMap::new(); + for i in 0..df_b.height() { + let key = row_to_key(df_b, i); + freq_b.entry(key).or_default().push(i); + } + + let mut removed = Vec::new(); + let mut added = Vec::new(); + + // Walk A: for each row, try to consume a matching row from B. + for i in 0..df_a.height() { + let key = row_to_key(df_a, i); + let consumed = freq_b + .get_mut(&key) + .and_then(|indices| indices.pop()) + .is_some(); + if !consumed { + let mut vals = row_to_strings(df_a, i); + vals.resize(num_headers, String::new()); + removed.push(DiffRow { values: vals }); + } + } + + // Walk B: for each row, try to consume a matching row from A. 
+ for i in 0..df_b.height() { + let key = row_to_key(df_b, i); + let consumed = freq_a + .get_mut(&key) + .and_then(|indices| indices.pop()) + .is_some(); + if !consumed { + let mut vals = row_to_strings(df_b, i); + vals.resize(num_headers, String::new()); + added.push(DiffRow { values: vals }); + } + } + + Ok(DiffResult { + headers, + key_columns: vec![], + added, + removed, + modified: vec![], + source_a, + source_b, + }) +} + +// --------------------------------------------------------------------------- +// Key-based diff +// --------------------------------------------------------------------------- + +/// A row indexed by its key columns. +struct KeyedRow { + values: Vec<String>, + key_values: Vec<String>, +} + +/// Build a map from composite key string to KeyedRow for every row in the DataFrame. +fn build_key_map( + df: &DataFrame, + key_indices: &[usize], + columns: &[Column], +) -> HashMap<String, KeyedRow> { + let mut map = HashMap::new(); + for i in 0..df.height() { + let key_values: Vec<String> = key_indices + .iter() + .map(|&ki| cell_to_string(&columns[ki], i)) + .collect(); + let composite_key = key_values.join("\0"); + let values: Vec<String> = columns.iter().map(|col| cell_to_string(col, i)).collect(); + map.insert( + composite_key, + KeyedRow { + values, + key_values, + }, + ); + } + map +} + +/// Warn on stderr when duplicate keys are found. 
+fn check_duplicate_keys( + df: &DataFrame, + key_indices: &[usize], + columns: &[Column], + source: &SheetSource, +) { + let mut seen: HashMap<String, usize> = HashMap::new(); + for i in 0..df.height() { + let key: String = key_indices + .iter() + .map(|&ki| cell_to_string(&columns[ki], i)) + .collect::<Vec<_>>() + .join("\0"); + let count = seen.entry(key.clone()).or_insert(0); + *count += 1; + if *count == 2 { + let display_key = key.replace('\0', ", "); + eprintln!( + "Warning: duplicate key [{}] in {}:{}", + display_key, source.file_name, source.sheet_name + ); + } + } +} + +/// Check whether a polars DataType is a float type. +fn is_float_dtype(dt: &DataType) -> bool { + matches!(dt, DataType::Float32 | DataType::Float64) +} + +/// Check whether a polars DataType is an integer type. +fn is_int_dtype(dt: &DataType) -> bool { + matches!( + dt, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + ) +} + +/// Compare two string-rendered values with optional numeric tolerance. +/// +/// Rules: +/// - NaN == NaN is true. +/// - NaN vs non-NaN is false. +/// - Pure int+int columns use exact comparison (no tolerance applied). +/// - At least one float column applies tolerance. +/// - Otherwise exact string comparison. 
+fn values_equal_with_tolerance( + val_a: &str, + val_b: &str, + tolerance: f64, + df_a: &DataFrame, + df_b: &DataFrame, + col_name: &str, +) -> bool { + let parsed_a = val_a.parse::<f64>(); + let parsed_b = val_b.parse::<f64>(); + + match (parsed_a, parsed_b) { + (Ok(a), Ok(b)) => { + if a.is_nan() && b.is_nan() { + return true; + } + if a.is_nan() || b.is_nan() { + return false; + } + + let dt_a = df_a + .column(col_name) + .map(|c| c.dtype().clone()) + .unwrap_or(DataType::String); + let dt_b = df_b + .column(col_name) + .map(|c| c.dtype().clone()) + .unwrap_or(DataType::String); + + if is_int_dtype(&dt_a) && is_int_dtype(&dt_b) { + val_a == val_b + } else if is_float_dtype(&dt_a) || is_float_dtype(&dt_b) { + (a - b).abs() <= tolerance + } else { + val_a == val_b + } + } + _ => val_a == val_b, + } +} + +/// Compare non-key columns of two keyed rows and return cell-level changes. +#[allow(clippy::too_many_arguments)] +fn compare_rows( + df_a: &DataFrame, + df_b: &DataFrame, + headers_a: &[String], + headers_b: &[String], + row_a: &KeyedRow, + row_b: &KeyedRow, + common_columns: &[String], + opts: &DiffOptions, +) -> Vec<CellChange> { + let mut changes = Vec::new(); + for col_name in common_columns { + let idx_a = headers_a.iter().position(|h| h == col_name); + let idx_b = headers_b.iter().position(|h| h == col_name); + let val_a = idx_a + .map(|i| row_a.values.get(i).cloned().unwrap_or_default()) + .unwrap_or_default(); + let val_b = idx_b + .map(|i| row_b.values.get(i).cloned().unwrap_or_default()) + .unwrap_or_default(); + + let equal = if let Some(tol) = opts.tolerance { + values_equal_with_tolerance(&val_a, &val_b, tol, df_a, df_b, col_name) + } else { + val_a == val_b + }; + + if !equal { + changes.push(CellChange { + column: col_name.clone(), + old_value: val_a, + new_value: val_b, + }); + } + } + changes +} + +/// Compare two DataFrames using key columns. 
+pub fn diff_keyed( + df_a: &DataFrame, + df_b: &DataFrame, + opts: &DiffOptions, + source_a: SheetSource, + source_b: SheetSource, +) -> Result<DiffResult> { + let columns_a = df_a.get_columns(); + let columns_b = df_b.get_columns(); + let headers_a: Vec<String> = df_a.get_column_names().iter().map(|s| s.to_string()).collect(); + let headers_b: Vec<String> = df_b.get_column_names().iter().map(|s| s.to_string()).collect(); + + // Resolve key column indices in both frames. + let mut key_indices_a = Vec::new(); + let mut key_indices_b = Vec::new(); + for key_col in &opts.key_columns { + match headers_a.iter().position(|h| h == key_col) { + Some(idx) => key_indices_a.push(idx), + None => bail!("Key column '{}' not found in {}", key_col, source_a.file_name), + } + match headers_b.iter().position(|h| h == key_col) { + Some(idx) => key_indices_b.push(idx), + None => bail!("Key column '{}' not found in {}", key_col, source_b.file_name), + } + } + + // Find non-key columns. + let non_key_a: Vec<String> = headers_a + .iter() + .filter(|h| !opts.key_columns.contains(h)) + .cloned() + .collect(); + let non_key_b: Vec<String> = headers_b + .iter() + .filter(|h| !opts.key_columns.contains(h)) + .cloned() + .collect(); + + // Common non-key columns (for modification detection). + let common_columns: Vec<String> = non_key_a + .iter() + .filter(|h| non_key_b.contains(h)) + .cloned() + .collect(); + + // Warn about columns only in one file. + for col in &non_key_a { + if !non_key_b.contains(col) { + eprintln!("Warning: column '{}' only in {}", col, source_a.file_name); + } + } + for col in &non_key_b { + if !non_key_a.contains(col) { + eprintln!("Warning: column '{}' only in {}", col, source_b.file_name); + } + } + + // Build output headers: key columns + all from A non-key + B-only non-key. 
+ let mut headers = opts.key_columns.clone(); + headers.extend(non_key_a.iter().cloned()); + for col in &non_key_b { + if !non_key_a.contains(col) { + headers.push(col.clone()); + } + } + + // Check for duplicate keys. + check_duplicate_keys(df_a, &key_indices_a, columns_a, &source_a); + check_duplicate_keys(df_b, &key_indices_b, columns_b, &source_b); + + // Build key maps. + let map_a = build_key_map(df_a, &key_indices_a, columns_a); + let map_b = build_key_map(df_b, &key_indices_b, columns_b); + + let mut removed = Vec::new(); + let mut added = Vec::new(); + let mut modified = Vec::new(); + + // Keys in A but not in B → removed. + for (composite_key, row_a) in &map_a { + if !map_b.contains_key(composite_key) { + let mut vals = Vec::new(); + for h in &headers { + if let Some(idx) = headers_a.iter().position(|ha| ha == h) { + vals.push(row_a.values.get(idx).cloned().unwrap_or_default()); + } else { + vals.push(String::new()); + } + } + removed.push(DiffRow { values: vals }); + } + } + + // Keys in B but not in A → added. + for (composite_key, row_b) in &map_b { + if !map_a.contains_key(composite_key) { + let mut vals = Vec::new(); + for h in &headers { + if let Some(idx) = headers_b.iter().position(|hb| hb == h) { + vals.push(row_b.values.get(idx).cloned().unwrap_or_default()); + } else { + vals.push(String::new()); + } + } + added.push(DiffRow { values: vals }); + } + } + + // Keys in both → compare for modifications. 
+ for (composite_key, row_a) in &map_a { + if let Some(row_b) = map_b.get(composite_key) { + let changes = compare_rows( + df_a, + df_b, + &headers_a, + &headers_b, + row_a, + row_b, + &common_columns, + opts, + ); + if !changes.is_empty() { + modified.push(ModifiedRow { + key: row_a.key_values.clone(), + changes, + }); + } + } + } + + Ok(DiffResult { + headers, + key_columns: opts.key_columns.clone(), + added, + removed, + modified, + source_a, + source_b, + }) +} + +// --------------------------------------------------------------------------- +// Entry point +// --------------------------------------------------------------------------- + +/// Compare two DataFrames, dispatching to positional or key-based diff +/// depending on whether key columns are specified. +pub fn diff_sheets( + df_a: &DataFrame, + df_b: &DataFrame, + opts: &DiffOptions, + source_a: SheetSource, + source_b: SheetSource, +) -> Result<DiffResult> { + if opts.key_columns.is_empty() { + diff_positional(df_a, df_b, opts, source_a, source_b) + } else { + diff_keyed(df_a, df_b, opts, source_a, source_b) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_source_a() -> SheetSource { + SheetSource { + file_name: "a.xlsx".into(), + sheet_name: "Sheet1".into(), + } + } + + fn test_source_b() -> SheetSource { + SheetSource { + file_name: "b.xlsx".into(), + sheet_name: "Sheet1".into(), + } + } + + // ---- Positional diff tests ---- + + #[test] + fn test_positional_no_diff() { + let df_a = df! { + "name" => &["Alice", "Bob"], + "score" => &[100, 200], + } + .unwrap(); + let df_b = df_a.clone(); + let opts = DiffOptions::default(); + + let result = diff_positional(&df_a, &df_b, &opts, test_source_a(), test_source_b()).unwrap(); + + assert!(!result.has_differences()); + assert!(result.added.is_empty()); + assert!(result.removed.is_empty()); + assert!(result.modified.is_empty()); + } + + #[test] + fn test_positional_added_removed() { + let df_a = df! 
{ + "name" => &["Alice", "Bob"], + } + .unwrap(); + let df_b = df! { + "name" => &["Alice", "Charlie"], + } + .unwrap(); + let opts = DiffOptions::default(); + + let result = diff_positional(&df_a, &df_b, &opts, test_source_a(), test_source_b()).unwrap(); + + assert!(result.has_differences()); + assert_eq!(result.removed.len(), 1); + assert_eq!(result.removed[0].values, vec!["Bob"]); + assert_eq!(result.added.len(), 1); + assert_eq!(result.added[0].values, vec!["Charlie"]); + } + + #[test] + fn test_positional_duplicate_rows() { + let df_a = df! { + "val" => &["A", "A", "A"], + } + .unwrap(); + let df_b = df! { + "val" => &["A", "A"], + } + .unwrap(); + let opts = DiffOptions::default(); + + let result = diff_positional(&df_a, &df_b, &opts, test_source_a(), test_source_b()).unwrap(); + + assert_eq!(result.removed.len(), 1); + assert_eq!(result.removed[0].values, vec!["A"]); + assert!(result.added.is_empty()); + } + + // ---- Key-based diff tests ---- + + #[test] + fn test_keyed_no_diff() { + let df_a = df! { + "id" => &[1, 2], + "name" => &["Alice", "Bob"], + } + .unwrap(); + let df_b = df_a.clone(); + let opts = DiffOptions { + key_columns: vec!["id".into()], + tolerance: None, + }; + + let result = diff_keyed(&df_a, &df_b, &opts, test_source_a(), test_source_b()).unwrap(); + + assert!(!result.has_differences()); + } + + #[test] + fn test_keyed_added_removed() { + let df_a = df! { + "id" => &[1, 2], + "name" => &["Alice", "Bob"], + } + .unwrap(); + let df_b = df! 
{ + "id" => &[2, 3], + "name" => &["Bob", "Charlie"], + } + .unwrap(); + let opts = DiffOptions { + key_columns: vec!["id".into()], + tolerance: None, + }; + + let result = diff_keyed(&df_a, &df_b, &opts, test_source_a(), test_source_b()).unwrap(); + + assert_eq!(result.removed.len(), 1); + assert!(result.removed[0].values.contains(&"1".to_string())); + assert!(result.removed[0].values.contains(&"Alice".to_string())); + + assert_eq!(result.added.len(), 1); + assert!(result.added[0].values.contains(&"3".to_string())); + assert!(result.added[0].values.contains(&"Charlie".to_string())); + } + + #[test] + fn test_keyed_modified() { + let df_a = df! { + "id" => &[1, 2], + "score" => &[100, 200], + } + .unwrap(); + let df_b = df! { + "id" => &[1, 2], + "score" => &[100, 250], + } + .unwrap(); + let opts = DiffOptions { + key_columns: vec!["id".into()], + tolerance: None, + }; + + let result = diff_keyed(&df_a, &df_b, &opts, test_source_a(), test_source_b()).unwrap(); + + assert!(result.added.is_empty()); + assert!(result.removed.is_empty()); + assert_eq!(result.modified.len(), 1); + + let m = &result.modified[0]; + assert_eq!(m.key, vec!["2"]); + assert_eq!(m.changes.len(), 1); + assert_eq!(m.changes[0].column, "score"); + assert_eq!(m.changes[0].old_value, "200"); + assert_eq!(m.changes[0].new_value, "250"); + } + + #[test] + fn test_keyed_composite_key() { + let df_a = df! { + "date" => &["2024-01-01", "2024-01-01", "2024-01-02"], + "ticker" => &["AAPL", "GOOG", "AAPL"], + "price" => &[150.0, 140.0, 151.0], + } + .unwrap(); + let df_b = df! 
{ + "date" => &["2024-01-01", "2024-01-01", "2024-01-02"], + "ticker" => &["AAPL", "GOOG", "AAPL"], + "price" => &[150.0, 142.0, 151.0], + } + .unwrap(); + let opts = DiffOptions { + key_columns: vec!["date".into(), "ticker".into()], + tolerance: None, + }; + + let result = diff_keyed(&df_a, &df_b, &opts, test_source_a(), test_source_b()).unwrap(); + + assert!(result.added.is_empty()); + assert!(result.removed.is_empty()); + assert_eq!(result.modified.len(), 1); + + let m = &result.modified[0]; + assert_eq!(m.key, vec!["2024-01-01", "GOOG"]); + assert_eq!(m.changes[0].column, "price"); + assert_eq!(m.changes[0].old_value, "140"); + assert_eq!(m.changes[0].new_value, "142"); + } + + // ---- Tolerance tests ---- + + #[test] + fn test_keyed_tolerance_within() { + let df_a = df! { + "id" => &[1], + "price" => &[100.001_f64], + } + .unwrap(); + let df_b = df! { + "id" => &[1], + "price" => &[100.002_f64], + } + .unwrap(); + let opts = DiffOptions { + key_columns: vec!["id".into()], + tolerance: Some(0.01), + }; + + let result = diff_keyed(&df_a, &df_b, &opts, test_source_a(), test_source_b()).unwrap(); + + assert!(!result.has_differences()); + } + + #[test] + fn test_keyed_tolerance_exceeded() { + let df_a = df! { + "id" => &[1], + "price" => &[100.0_f64], + } + .unwrap(); + let df_b = df! { + "id" => &[1], + "price" => &[100.05_f64], + } + .unwrap(); + let opts = DiffOptions { + key_columns: vec!["id".into()], + tolerance: Some(0.01), + }; + + let result = diff_keyed(&df_a, &df_b, &opts, test_source_a(), test_source_b()).unwrap(); + + assert_eq!(result.modified.len(), 1); + assert_eq!(result.modified[0].changes[0].column, "price"); + } + + #[test] + fn test_keyed_nan_handling() { + let df_a = df! { + "id" => &[1], + "value" => &[f64::NAN], + } + .unwrap(); + let df_b = df! 
{ + "id" => &[1], + "value" => &[f64::NAN], + } + .unwrap(); + let opts = DiffOptions { + key_columns: vec!["id".into()], + tolerance: Some(0.01), + }; + + let result = diff_keyed(&df_a, &df_b, &opts, test_source_a(), test_source_b()).unwrap(); + + assert!(!result.has_differences(), "NaN vs NaN should be treated as equal"); + } + + // ---- diff_sheets entry point tests ---- + + #[test] + fn test_diff_sheets_positional() { + let df_a = df! { + "name" => &["Alice", "Bob"], + } + .unwrap(); + let df_b = df! { + "name" => &["Alice", "Charlie"], + } + .unwrap(); + let opts = DiffOptions::default(); // No key columns → positional. + + let result = diff_sheets(&df_a, &df_b, &opts, test_source_a(), test_source_b()).unwrap(); + + assert!(result.key_columns.is_empty()); + assert_eq!(result.removed.len(), 1); + assert_eq!(result.added.len(), 1); + } + + #[test] + fn test_diff_sheets_keyed() { + let df_a = df! { + "id" => &[1, 2], + "score" => &[100, 200], + } + .unwrap(); + let df_b = df! { + "id" => &[1, 2], + "score" => &[100, 250], + } + .unwrap(); + let opts = DiffOptions { + key_columns: vec!["id".into()], + tolerance: None, + }; + + let result = diff_sheets(&df_a, &df_b, &opts, test_source_a(), test_source_b()).unwrap(); + + assert_eq!(result.key_columns, vec!["id"]); + assert_eq!(result.modified.len(), 1); + } +} diff --git a/src/filter.rs b/src/filter.rs @@ -0,0 +1,671 @@ +use anyhow::Result; +use polars::prelude::*; + +#[derive(Debug, Clone, PartialEq)] +pub enum FilterOp { + Eq, + NotEq, + Gt, + Lt, + Gte, + Lte, + Contains, + NotContains, +} + +#[derive(Debug, Clone)] +pub struct FilterExpr { + pub column: String, + pub op: FilterOp, + pub value: String, +} + +#[derive(Debug, Clone)] +pub struct SortSpec { + pub column: String, + pub descending: bool, +} + +/// Parse a filter expression like "State=CA", "Amount>1000", "Name~john". +/// Scans left-to-right for the first operator character (= ! > < ~), +/// then determines the full operator. 
+pub fn parse_filter_expr(s: &str) -> Result<FilterExpr, String> { + let op_chars = ['=', '!', '>', '<', '~']; + let pos = s + .find(|c: char| op_chars.contains(&c)) + .ok_or_else(|| { + format!( + "no operator found in '{}'. Use =, !=, >, <, >=, <=, ~ or !~", + s + ) + })?; + if pos == 0 { + return Err(format!("missing column name in '{}'", s)); + } + let column = s[..pos].to_string(); + let rest = &s[pos..]; + let (op, op_len) = if rest.starts_with(">=") { + (FilterOp::Gte, 2) + } else if rest.starts_with("<=") { + (FilterOp::Lte, 2) + } else if rest.starts_with("!=") { + (FilterOp::NotEq, 2) + } else if rest.starts_with("!~") { + (FilterOp::NotContains, 2) + } else if rest.starts_with('>') { + (FilterOp::Gt, 1) + } else if rest.starts_with('<') { + (FilterOp::Lt, 1) + } else if rest.starts_with('=') { + (FilterOp::Eq, 1) + } else if rest.starts_with('~') { + (FilterOp::Contains, 1) + } else { + return Err(format!("invalid operator in '{}'", s)); + }; + let value = rest[op_len..].to_string(); + Ok(FilterExpr { column, op, value }) +} + +/// Parse a sort spec like "Amount:desc" or "Name" (default asc). +/// Splits on the last colon so column names containing colons are supported. +pub fn parse_sort_spec(s: &str) -> Result<SortSpec, String> { + if let Some(colon_pos) = s.rfind(':') { + let col = &s[..colon_pos]; + let dir = &s[colon_pos + 1..]; + match dir.to_lowercase().as_str() { + "asc" => Ok(SortSpec { + column: col.to_string(), + descending: false, + }), + "desc" => Ok(SortSpec { + column: col.to_string(), + descending: true, + }), + _ => Err(format!( + "invalid sort direction '{}'. Use 'asc' or 'desc'", + dir + )), + } + } else { + Ok(SortSpec { + column: s.to_string(), + descending: false, + }) + } +} + +/// Resolve a column specifier to a DataFrame column name. 
+/// Accepts either: +/// - A header name (exact match first, then case-insensitive) +pub fn resolve_column(spec: &str, df_columns: &[String]) -> Result<String, String> { + if df_columns.contains(&spec.to_string()) { + return Ok(spec.to_string()); + } + let spec_lower = spec.to_lowercase(); + for col in df_columns { + if col.to_lowercase() == spec_lower { + return Ok(col.clone()); + } + } + let available = df_columns.join(", "); + Err(format!("column '{}' not found. Available columns: {}", spec, available)) +} + +/// Resolve a list of column specifiers to DataFrame column names. +pub fn resolve_columns(specs: &[String], df_columns: &[String]) -> Result<Vec<String>, String> { + specs.iter().map(|s| resolve_column(s, df_columns)).collect() +} + +/// Check if a polars DataType is numeric. +fn is_numeric_dtype(dtype: &DataType) -> bool { + matches!( + dtype, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + ) +} + +/// Build a boolean mask for a single filter expression against a DataFrame. 
+fn build_filter_mask(df: &DataFrame, expr: &FilterExpr) -> Result<BooleanChunked> { + let col = df.column(&expr.column).map_err(|e| anyhow::anyhow!("{}", e))?; + let series = col.as_materialized_series(); + let dtype = series.dtype(); + + match &expr.op { + FilterOp::Eq => { + if is_numeric_dtype(dtype) { + if let Ok(n) = expr.value.parse::<f64>() { + let s = series.cast(&DataType::Float64)?; + return Ok(s.f64()?.equal(n)); + } + } + let s = series.cast(&DataType::String)?; + Ok(s.str()?.equal(expr.value.as_str())) + } + FilterOp::NotEq => { + if is_numeric_dtype(dtype) { + if let Ok(n) = expr.value.parse::<f64>() { + let s = series.cast(&DataType::Float64)?; + return Ok(s.f64()?.not_equal(n)); + } + } + let s = series.cast(&DataType::String)?; + Ok(s.str()?.not_equal(expr.value.as_str())) + } + FilterOp::Gt => { + let n = parse_numeric_value(&expr.value, ">")?; + let s = series.cast(&DataType::Float64)?; + Ok(s.f64()?.gt(n)) + } + FilterOp::Lt => { + let n = parse_numeric_value(&expr.value, "<")?; + let s = series.cast(&DataType::Float64)?; + Ok(s.f64()?.lt(n)) + } + FilterOp::Gte => { + let n = parse_numeric_value(&expr.value, ">=")?; + let s = series.cast(&DataType::Float64)?; + Ok(s.f64()?.gt_eq(n)) + } + FilterOp::Lte => { + let n = parse_numeric_value(&expr.value, "<=")?; + let s = series.cast(&DataType::Float64)?; + Ok(s.f64()?.lt_eq(n)) + } + FilterOp::Contains => { + let s = series.cast(&DataType::String)?; + let ca = s.str()?; + let pat = expr.value.to_lowercase(); + let mask: BooleanChunked = ca.into_iter() + .map(|opt_s| opt_s.map(|s| s.to_lowercase().contains(&pat)).unwrap_or(false)) + .collect(); + Ok(mask) + } + FilterOp::NotContains => { + let s = series.cast(&DataType::String)?; + let ca = s.str()?; + let pat = expr.value.to_lowercase(); + let mask: BooleanChunked = ca.into_iter() + .map(|opt_s| opt_s.map(|s| !s.to_lowercase().contains(&pat)).unwrap_or(true)) + .collect(); + Ok(mask) + } + } +} + +fn parse_numeric_value(value: &str, op: &str) -> 
Result<f64> { + value + .parse::<f64>() + .map_err(|_| anyhow::anyhow!("'{}' requires numeric value, got '{}'", op, value)) +} + +/// Apply a list of filter expressions to a DataFrame (AND logic). +/// An empty list returns the DataFrame unchanged. +pub fn apply_filters(df: &DataFrame, exprs: &[FilterExpr]) -> Result<DataFrame> { + let mut result = df.clone(); + for expr in exprs { + let mask = build_filter_mask(&result, expr)?; + result = result.filter(&mask)?; + } + Ok(result) +} + +/// Options for the filter pipeline. +pub struct FilterOptions { + pub filters: Vec<FilterExpr>, + pub cols: Option<Vec<String>>, + pub sort: Option<SortSpec>, + pub limit: Option<usize>, + pub head: Option<usize>, + pub tail: Option<usize>, +} + +/// Apply a sort specification to a DataFrame. +pub fn apply_sort(df: &DataFrame, spec: &SortSpec) -> Result<DataFrame> { + let opts = SortMultipleOptions::default() + .with_order_descending(spec.descending); + Ok(df.sort([&spec.column], opts)?) +} + +/// Run the full filter pipeline: head/tail → resolve & filter → sort → limit → select columns. +pub fn filter_pipeline(df: DataFrame, opts: &FilterOptions) -> Result<DataFrame> { + let df_columns: Vec<String> = df + .get_column_names() + .iter() + .map(|s| s.to_string()) + .collect(); + + // 1. Pre-filter window: head or tail + let df = if let Some(n) = opts.head { + df.head(Some(n)) + } else if let Some(n) = opts.tail { + df.tail(Some(n)) + } else { + df + }; + + // 2. Resolve column names in filter expressions and apply filters + let resolved_filters: Vec<FilterExpr> = opts + .filters + .iter() + .map(|f| { + let resolved_col = resolve_column(&f.column, &df_columns)?; + Ok(FilterExpr { + column: resolved_col, + op: f.op.clone(), + value: f.value.clone(), + }) + }) + .collect::<Result<Vec<_>, String>>() + .map_err(|e| anyhow::anyhow!("{}", e))?; + + let df = apply_filters(&df, &resolved_filters)?; + + // 3. 
Sort + let df = if let Some(ref spec) = opts.sort { + let resolved_col = resolve_column(&spec.column, &df_columns) + .map_err(|e| anyhow::anyhow!("{}", e))?; + let resolved_spec = SortSpec { + column: resolved_col, + descending: spec.descending, + }; + apply_sort(&df, &resolved_spec)? + } else { + df + }; + + // 4. Limit (after filtering and sorting) + let df = if let Some(n) = opts.limit { + df.head(Some(n)) + } else { + df + }; + + // 5. Select columns + let df = if let Some(ref col_specs) = opts.cols { + let resolved_cols = resolve_columns(col_specs, &df_columns) + .map_err(|e| anyhow::anyhow!("{}", e))?; + let col_refs: Vec<&str> = resolved_cols.iter().map(|s| s.as_str()).collect(); + df.select(col_refs)? + } else { + df + }; + + Ok(df) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_test_df() -> DataFrame { + DataFrame::new(vec![ + Column::new("State".into(), &["CA", "NY", "CA", "TX", "NY"]), + Column::new("City".into(), &["LA", "NYC", "SF", "Houston", "Albany"]), + Column::new("Amount".into(), &[1500i64, 2000, 800, 1200, 500]), + Column::new("Year".into(), &[2023i64, 2023, 2024, 2024, 2023]), + Column::new("Status".into(), &["Active", "Active", "Draft", "Active", "Draft"]), + ]) + .unwrap() + } + + #[test] + fn filter_eq_string() { + let df = make_test_df(); + let expr = parse_filter_expr("State=CA").unwrap(); + let result = apply_filters(&df, &[expr]).unwrap(); + assert_eq!(result.height(), 2); + } + + #[test] + fn filter_eq_numeric() { + let df = make_test_df(); + let expr = parse_filter_expr("Amount=1500").unwrap(); + let result = apply_filters(&df, &[expr]).unwrap(); + assert_eq!(result.height(), 1); + } + + #[test] + fn filter_not_eq() { + let df = make_test_df(); + let expr = parse_filter_expr("Status!=Draft").unwrap(); + let result = apply_filters(&df, &[expr]).unwrap(); + assert_eq!(result.height(), 3); + } + + #[test] + fn filter_gt() { + let df = make_test_df(); + let expr = parse_filter_expr("Amount>1000").unwrap(); + let result = 
apply_filters(&df, &[expr]).unwrap();
        assert_eq!(result.height(), 3);
    }

    // -----------------------------------------------------------------
    // apply_filters: row-count checks for each comparison operator.
    //
    // NOTE(review): make_test_df() is defined above, outside this view.
    // The expected heights below are mutually consistent only if the
    // fixture has 5 rows whose Amount values are {500, one <=800, one in
    // (1000, 1500), 1500, 2000} and rows CA/LA, NY/NYC, CA/SF,
    // TX/Houston, NY/Albany — TODO confirm against the fixture.
    // -----------------------------------------------------------------

    #[test]
    fn filter_lt() {
        let df = make_test_df();
        let expr = parse_filter_expr("Amount<1000").unwrap();
        let result = apply_filters(&df, &[expr]).unwrap();
        assert_eq!(result.height(), 2);
    }

    #[test]
    fn filter_gte() {
        let df = make_test_df();
        let expr = parse_filter_expr("Amount>=1500").unwrap();
        let result = apply_filters(&df, &[expr]).unwrap();
        assert_eq!(result.height(), 2);
    }

    #[test]
    fn filter_lte() {
        let df = make_test_df();
        let expr = parse_filter_expr("Amount<=800").unwrap();
        let result = apply_filters(&df, &[expr]).unwrap();
        assert_eq!(result.height(), 2);
    }

    // "ou" should match only "Houston" among the fixture cities.
    #[test]
    fn filter_contains() {
        let df = make_test_df();
        let expr = parse_filter_expr("City~ou").unwrap();
        let result = apply_filters(&df, &[expr]).unwrap();
        assert_eq!(result.height(), 1);
    }

    // Same match count with an all-caps needle: `~` is evidently
    // case-insensitive in this implementation.
    #[test]
    fn filter_contains_case_insensitive() {
        let df = make_test_df();
        let expr = parse_filter_expr("City~HOUSTON").unwrap();
        let result = apply_filters(&df, &[expr]).unwrap();
        assert_eq!(result.height(), 1);
    }

    #[test]
    fn filter_not_contains() {
        let df = make_test_df();
        let expr = parse_filter_expr("Status!~raft").unwrap();
        let result = apply_filters(&df, &[expr]).unwrap();
        assert_eq!(result.height(), 3);
    }

    // Multiple expressions passed together are AND-combined.
    #[test]
    fn filter_multiple_and() {
        let df = make_test_df();
        let e1 = parse_filter_expr("State=CA").unwrap();
        let e2 = parse_filter_expr("Amount>1000").unwrap();
        let result = apply_filters(&df, &[e1, e2]).unwrap();
        assert_eq!(result.height(), 1);
    }

    // No matching rows is Ok(empty frame), not an error.
    #[test]
    fn filter_no_matches_returns_empty() {
        let df = make_test_df();
        let expr = parse_filter_expr("State=ZZ").unwrap();
        let result = apply_filters(&df, &[expr]).unwrap();
        assert_eq!(result.height(), 0);
    }

    // An empty filter list is a no-op: all 5 fixture rows survive.
    #[test]
    fn filter_empty_exprs_returns_all() {
        let df = make_test_df();
        let result = apply_filters(&df, &[]).unwrap();
        assert_eq!(result.height(), 5);
    }

    // -----------------------------------------------------------------
    // parse_filter_expr: operator recognition and column/value splitting.
    // -----------------------------------------------------------------

    #[test]
    fn parse_eq() {
        let expr = parse_filter_expr("State=CA").unwrap();
        assert_eq!(expr.column, "State");
        assert_eq!(expr.op, FilterOp::Eq);
        assert_eq!(expr.value, "CA");
    }

    // `!=` must win over a bare `=` (otherwise column would be "Status!").
    #[test]
    fn parse_not_eq() {
        let expr = parse_filter_expr("Status!=Draft").unwrap();
        assert_eq!(expr.column, "Status");
        assert_eq!(expr.op, FilterOp::NotEq);
        assert_eq!(expr.value, "Draft");
    }

    #[test]
    fn parse_gt() {
        let expr = parse_filter_expr("Amount>1000").unwrap();
        assert_eq!(expr.column, "Amount");
        assert_eq!(expr.op, FilterOp::Gt);
        assert_eq!(expr.value, "1000");
    }

    #[test]
    fn parse_lt() {
        let expr = parse_filter_expr("Year<2024").unwrap();
        assert_eq!(expr.column, "Year");
        assert_eq!(expr.op, FilterOp::Lt);
        assert_eq!(expr.value, "2024");
    }

    // `>=` must win over a bare `>` (two-char operators first).
    #[test]
    fn parse_gte() {
        let expr = parse_filter_expr("Score>=90").unwrap();
        assert_eq!(expr.column, "Score");
        assert_eq!(expr.op, FilterOp::Gte);
        assert_eq!(expr.value, "90");
    }

    #[test]
    fn parse_lte() {
        let expr = parse_filter_expr("Price<=50.5").unwrap();
        assert_eq!(expr.column, "Price");
        assert_eq!(expr.op, FilterOp::Lte);
        assert_eq!(expr.value, "50.5");
    }

    #[test]
    fn parse_contains() {
        let expr = parse_filter_expr("Name~john").unwrap();
        assert_eq!(expr.column, "Name");
        assert_eq!(expr.op, FilterOp::Contains);
        assert_eq!(expr.value, "john");
    }

    #[test]
    fn parse_not_contains() {
        let expr = parse_filter_expr("Name!~draft").unwrap();
        assert_eq!(expr.column, "Name");
        assert_eq!(expr.op, FilterOp::NotContains);
        assert_eq!(expr.value, "draft");
    }

    // Only the FIRST operator splits; later `=` chars belong to the value.
    #[test]
    fn parse_value_with_equals() {
        let expr = parse_filter_expr("Formula=A+B=C").unwrap();
        assert_eq!(expr.column, "Formula");
        assert_eq!(expr.op, FilterOp::Eq);
        assert_eq!(expr.value, "A+B=C");
    }

    // An empty right-hand side is legal (matches empty cells).
    #[test]
    fn parse_empty_value() {
        let expr = parse_filter_expr("Status=").unwrap();
        assert_eq!(expr.column, "Status");
        assert_eq!(expr.op, FilterOp::Eq);
        assert_eq!(expr.value, "");
    }

    // No operator at all, or an empty column name, is a parse error.
    #[test]
    fn parse_no_operator_is_err() {
        assert!(parse_filter_expr("JustAWord").is_err());
    }

    #[test]
    fn parse_no_column_is_err() {
        assert!(parse_filter_expr("=value").is_err());
    }

    // -----------------------------------------------------------------
    // parse_sort_spec: "Column[:asc|desc]", ascending by default.
    // -----------------------------------------------------------------

    #[test]
    fn parse_sort_desc() {
        let spec = parse_sort_spec("Amount:desc").unwrap();
        assert_eq!(spec.column, "Amount");
        assert!(spec.descending);
    }

    #[test]
    fn parse_sort_asc() {
        let spec = parse_sort_spec("Name:asc").unwrap();
        assert_eq!(spec.column, "Name");
        assert!(!spec.descending);
    }

    #[test]
    fn parse_sort_default_asc() {
        let spec = parse_sort_spec("Name").unwrap();
        assert_eq!(spec.column, "Name");
        assert!(!spec.descending);
    }

    #[test]
    fn parse_sort_bad_dir_is_err() {
        assert!(parse_sort_spec("Name:up").is_err());
    }

    // -----------------------------------------------------------------
    // resolve_column(s): name-only resolution (per the commit, the
    // letter-based resolution from xl-cli-tools was removed).
    // -----------------------------------------------------------------

    #[test]
    fn resolve_by_header_name() {
        let cols = vec!["State".to_string(), "Amount".to_string(), "Year".to_string()];
        assert_eq!(resolve_column("Amount", &cols).unwrap(), "Amount");
    }

    // Lookup is case-insensitive but returns the header's original casing.
    #[test]
    fn resolve_case_insensitive_header() {
        let cols = vec!["State".to_string(), "Amount".to_string()];
        assert_eq!(resolve_column("state", &cols).unwrap(), "State");
    }

    // Error payload is a displayable message containing "not found".
    #[test]
    fn resolve_unknown_column_is_err() {
        let cols = vec!["State".to_string(), "Amount".to_string()];
        let err = resolve_column("Foo", &cols).unwrap_err();
        assert!(err.contains("not found"), "error was: {}", err);
    }

    // Resolution preserves the caller's requested order.
    #[test]
    fn resolve_multiple_columns() {
        let cols = vec!["State".to_string(), "Amount".to_string(), "Year".to_string()];
        let resolved = resolve_columns(&["State".to_string(), "Year".to_string()], &cols).unwrap();
        assert_eq!(resolved, vec!["State", "Year"]);
    }

    // -----------------------------------------------------------------
    // apply_sort: checks both ends of the sorted Amount column.
    // -----------------------------------------------------------------

    #[test]
    fn sort_ascending() {
        let df = make_test_df();
        let spec = parse_sort_spec("Amount:asc").unwrap();
        let result = apply_sort(&df, &spec).unwrap();
        let col = result.column("Amount").unwrap().as_materialized_series();
        let amounts = col.i64().unwrap();
        assert_eq!(amounts.get(0), Some(500));
        assert_eq!(amounts.get(4), Some(2000));
    }

    #[test]
    fn sort_descending() {
        let df = make_test_df();
        let spec = parse_sort_spec("Amount:desc").unwrap();
        let result = apply_sort(&df, &spec).unwrap();
        let col = result.column("Amount").unwrap().as_materialized_series();
        let amounts = col.i64().unwrap();
        assert_eq!(amounts.get(0), Some(2000));
        assert_eq!(amounts.get(4), Some(500));
    }

    // -----------------------------------------------------------------
    // filter_pipeline: stage-ordering tests. These pin the contract that
    // head/tail are applied BEFORE filters, while limit is applied AFTER.
    // -----------------------------------------------------------------

    // All stages at once: filter -> column projection -> sort -> limit.
    #[test]
    fn pipeline_full() {
        let df = make_test_df();
        let opts = FilterOptions {
            filters: vec![parse_filter_expr("Amount>500").unwrap()],
            cols: Some(vec!["State".to_string(), "Amount".to_string()]),
            sort: Some(parse_sort_spec("Amount:desc").unwrap()),
            limit: Some(2),
            head: None,
            tail: None,
        };
        let result = filter_pipeline(df, &opts).unwrap();
        assert_eq!(result.height(), 2);
        assert_eq!(result.width(), 2);
        let col = result.column("Amount").unwrap().as_materialized_series();
        let amounts = col.i64().unwrap();
        assert_eq!(amounts.get(0), Some(2000));
        assert_eq!(amounts.get(1), Some(1500));
    }

    #[test]
    fn pipeline_head_before_filter() {
        let df = make_test_df(); // 5 rows: CA/LA, NY/NYC, CA/SF, TX/Houston, NY/Albany
        let opts = FilterOptions {
            filters: vec![parse_filter_expr("State=NY").unwrap()],
            cols: None,
            sort: None,
            limit: None,
            head: Some(3), // Take first 3 rows before filtering
            tail: None,
        };
        let result = filter_pipeline(df, &opts).unwrap();
        // First 3 rows: CA/LA, NY/NYC, CA/SF → only NY/NYC matches
        assert_eq!(result.height(), 1);
    }

    #[test]
    fn pipeline_tail_before_filter() {
        let df = make_test_df(); // 5 rows
        let opts = FilterOptions {
            filters: vec![parse_filter_expr("State=CA").unwrap()],
            cols: None,
            sort: None,
            limit: None,
            head: None,
            tail: Some(3), // Last 3 rows before filtering
        };
        let result = filter_pipeline(df, &opts).unwrap();
        // Last 3 rows: CA/SF, TX/Houston, NY/Albany → only CA/SF matches
        assert_eq!(result.height(), 1);
    }

    // Default options are the identity: full height AND full width.
    #[test]
    fn pipeline_no_options_returns_all() {
        let df = make_test_df();
        let opts = FilterOptions {
            filters: vec![],
            cols: None,
            sort: None,
            limit: None,
            head: None,
            tail: None,
        };
        let result = filter_pipeline(df, &opts).unwrap();
        assert_eq!(result.height(), 5);
        assert_eq!(result.width(), 5);
    }

    #[test]
    fn pipeline_limit_after_filter() {
        let df = make_test_df();
        let opts = FilterOptions {
            filters: vec![parse_filter_expr("Status=Active").unwrap()],
            cols: None,
            sort: None,
            limit: Some(2),
            head: None,
            tail: None,
        };
        let result = filter_pipeline(df, &opts).unwrap();
        assert_eq!(result.height(), 2); // 3 Active rows, limited to 2
    }
}