filter.rs (21036B)
1 use anyhow::Result; 2 use polars::prelude::*; 3 4 #[derive(Debug, Clone, PartialEq)] 5 pub enum FilterOp { 6 Eq, 7 NotEq, 8 Gt, 9 Lt, 10 Gte, 11 Lte, 12 Contains, 13 NotContains, 14 } 15 16 #[derive(Debug, Clone)] 17 pub struct FilterExpr { 18 pub column: String, 19 pub op: FilterOp, 20 pub value: String, 21 } 22 23 #[derive(Debug, Clone)] 24 pub struct SortSpec { 25 pub column: String, 26 pub descending: bool, 27 } 28 29 /// Parse a filter expression like "State=CA", "Amount>1000", "Name~john". 30 /// Scans left-to-right for the first operator character (= ! > < ~), 31 /// then determines the full operator. 32 pub fn parse_filter_expr(s: &str) -> Result<FilterExpr, String> { 33 let op_chars = ['=', '!', '>', '<', '~']; 34 let pos = s 35 .find(|c: char| op_chars.contains(&c)) 36 .ok_or_else(|| { 37 format!( 38 "no operator found in '{}'. Use =, !=, >, <, >=, <=, ~ or !~", 39 s 40 ) 41 })?; 42 if pos == 0 { 43 return Err(format!("missing column name in '{}'", s)); 44 } 45 let column = s[..pos].to_string(); 46 let rest = &s[pos..]; 47 let (op, op_len) = if rest.starts_with(">=") { 48 (FilterOp::Gte, 2) 49 } else if rest.starts_with("<=") { 50 (FilterOp::Lte, 2) 51 } else if rest.starts_with("!=") { 52 (FilterOp::NotEq, 2) 53 } else if rest.starts_with("!~") { 54 (FilterOp::NotContains, 2) 55 } else if rest.starts_with('>') { 56 (FilterOp::Gt, 1) 57 } else if rest.starts_with('<') { 58 (FilterOp::Lt, 1) 59 } else if rest.starts_with('=') { 60 (FilterOp::Eq, 1) 61 } else if rest.starts_with('~') { 62 (FilterOp::Contains, 1) 63 } else { 64 return Err(format!("invalid operator in '{}'", s)); 65 }; 66 let value = rest[op_len..].to_string(); 67 Ok(FilterExpr { column, op, value }) 68 } 69 70 /// Parse a sort spec like "Amount:desc" or "Name" (default asc). 71 /// Splits on the last colon so column names containing colons are supported. 72 pub fn parse_sort_spec(s: &str) -> Result<SortSpec, String> { 73 if let Some(colon_pos) = s.rfind(':') { 74 let col = &s[..colon_pos]; 75 let dir = &s[colon_pos + 1..]; 76 match dir.to_lowercase().as_str() { 77 "asc" => Ok(SortSpec { 78 column: col.to_string(), 79 descending: false, 80 }), 81 "desc" => Ok(SortSpec { 82 column: col.to_string(), 83 descending: true, 84 }), 85 _ => Err(format!( 86 "invalid sort direction '{}'. Use 'asc' or 'desc'", 87 dir 88 )), 89 } 90 } else { 91 Ok(SortSpec { 92 column: s.to_string(), 93 descending: false, 94 }) 95 } 96 } 97 98 /// Resolve a column specifier to a DataFrame column name. 99 /// Accepts either: 100 /// - A header name (exact match first, then case-insensitive) 101 pub fn resolve_column(spec: &str, df_columns: &[String]) -> Result<String, String> { 102 if df_columns.contains(&spec.to_string()) { 103 return Ok(spec.to_string()); 104 } 105 let spec_lower = spec.to_lowercase(); 106 for col in df_columns { 107 if col.to_lowercase() == spec_lower { 108 return Ok(col.clone()); 109 } 110 } 111 let available = df_columns.join(", "); 112 Err(format!("column '{}' not found. Available columns: {}", spec, available)) 113 } 114 115 /// Resolve a list of column specifiers to DataFrame column names. 116 pub fn resolve_columns(specs: &[String], df_columns: &[String]) -> Result<Vec<String>, String> { 117 specs.iter().map(|s| resolve_column(s, df_columns)).collect() 118 } 119 120 /// Check if a polars DataType is numeric. 121 fn is_numeric_dtype(dtype: &DataType) -> bool { 122 matches!( 123 dtype, 124 DataType::Int8 125 | DataType::Int16 126 | DataType::Int32 127 | DataType::Int64 128 | DataType::UInt8 129 | DataType::UInt16 130 | DataType::UInt32 131 | DataType::UInt64 132 | DataType::Float32 133 | DataType::Float64 134 ) 135 } 136 137 /// Build a boolean mask for a single filter expression against a DataFrame. 138 fn build_filter_mask(df: &DataFrame, expr: &FilterExpr) -> Result<BooleanChunked> { 139 let col = df.column(&expr.column).map_err(|e| anyhow::anyhow!("{}", e))?; 140 let series = col.as_materialized_series(); 141 let dtype = series.dtype(); 142 143 match &expr.op { 144 FilterOp::Eq => { 145 if is_numeric_dtype(dtype) 146 && let Ok(n) = expr.value.parse::<f64>() { 147 let s = series.cast(&DataType::Float64)?; 148 return Ok(s.f64()?.equal(n)); 149 } 150 let s = series.cast(&DataType::String)?; 151 Ok(s.str()?.equal(expr.value.as_str())) 152 } 153 FilterOp::NotEq => { 154 if is_numeric_dtype(dtype) 155 && let Ok(n) = expr.value.parse::<f64>() { 156 let s = series.cast(&DataType::Float64)?; 157 return Ok(s.f64()?.not_equal(n)); 158 } 159 let s = series.cast(&DataType::String)?; 160 Ok(s.str()?.not_equal(expr.value.as_str())) 161 } 162 FilterOp::Gt => { 163 let n = parse_numeric_value(&expr.value, ">")?; 164 let s = series.cast(&DataType::Float64)?; 165 Ok(s.f64()?.gt(n)) 166 } 167 FilterOp::Lt => { 168 let n = parse_numeric_value(&expr.value, "<")?; 169 let s = series.cast(&DataType::Float64)?; 170 Ok(s.f64()?.lt(n)) 171 } 172 FilterOp::Gte => { 173 let n = parse_numeric_value(&expr.value, ">=")?; 174 let s = series.cast(&DataType::Float64)?; 175 Ok(s.f64()?.gt_eq(n)) 176 } 177 FilterOp::Lte => { 178 let n = parse_numeric_value(&expr.value, "<=")?; 179 let s = series.cast(&DataType::Float64)?; 180 Ok(s.f64()?.lt_eq(n)) 181 } 182 FilterOp::Contains => { 183 let s = series.cast(&DataType::String)?; 184 let ca = s.str()?; 185 let pat = expr.value.to_lowercase(); 186 let mask: BooleanChunked = ca.into_iter() 187 .map(|opt_s| opt_s.map(|s| s.to_lowercase().contains(&pat)).unwrap_or(false)) 188 .collect(); 189 Ok(mask) 190 } 191 FilterOp::NotContains => { 192 let s = series.cast(&DataType::String)?; 193 let ca = s.str()?; 194 let pat = expr.value.to_lowercase(); 195 let mask: BooleanChunked = ca.into_iter() 196 .map(|opt_s| opt_s.map(|s| !s.to_lowercase().contains(&pat)).unwrap_or(true)) 197 .collect(); 198 Ok(mask) 199 } 200 } 201 } 202 203 fn parse_numeric_value(value: &str, op: &str) -> Result<f64> { 204 value 205 .parse::<f64>() 206 .map_err(|_| anyhow::anyhow!("'{}' requires numeric value, got '{}'", op, value)) 207 } 208 209 /// Apply a list of filter expressions to a DataFrame (AND logic). 210 /// An empty list returns the DataFrame unchanged. 211 pub fn apply_filters(df: &DataFrame, exprs: &[FilterExpr]) -> Result<DataFrame> { 212 let mut result = df.clone(); 213 for expr in exprs { 214 let mask = build_filter_mask(&result, expr)?; 215 result = result.filter(&mask)?; 216 } 217 Ok(result) 218 } 219 220 /// Options for the filter pipeline. 221 pub struct FilterOptions { 222 pub filters: Vec<FilterExpr>, 223 pub cols: Option<Vec<String>>, 224 pub sort: Option<SortSpec>, 225 pub limit: Option<usize>, 226 pub head: Option<usize>, 227 pub tail: Option<usize>, 228 } 229 230 /// Apply a sort specification to a DataFrame. 231 pub fn apply_sort(df: &DataFrame, spec: &SortSpec) -> Result<DataFrame> { 232 let opts = SortMultipleOptions::default() 233 .with_order_descending(spec.descending); 234 Ok(df.sort([&spec.column], opts)?) 235 } 236 237 /// Run the full filter pipeline: head/tail → resolve & filter → sort → limit → select columns. 238 pub fn filter_pipeline(df: DataFrame, opts: &FilterOptions) -> Result<DataFrame> { 239 let df_columns: Vec<String> = df 240 .get_column_names() 241 .iter() 242 .map(|s| s.to_string()) 243 .collect(); 244 245 // 1. Pre-filter window: head or tail 246 let df = if let Some(n) = opts.head { 247 df.head(Some(n)) 248 } else if let Some(n) = opts.tail { 249 df.tail(Some(n)) 250 } else { 251 df 252 }; 253 254 // 2. Resolve column names in filter expressions and apply filters 255 let resolved_filters: Vec<FilterExpr> = opts 256 .filters 257 .iter() 258 .map(|f| { 259 let resolved_col = resolve_column(&f.column, &df_columns)?; 260 Ok(FilterExpr { 261 column: resolved_col, 262 op: f.op.clone(), 263 value: f.value.clone(), 264 }) 265 }) 266 .collect::<Result<Vec<_>, String>>() 267 .map_err(|e| anyhow::anyhow!("{}", e))?; 268 269 let df = apply_filters(&df, &resolved_filters)?; 270 271 // 3. Sort 272 let df = if let Some(ref spec) = opts.sort { 273 let resolved_col = resolve_column(&spec.column, &df_columns) 274 .map_err(|e| anyhow::anyhow!("{}", e))?; 275 let resolved_spec = SortSpec { 276 column: resolved_col, 277 descending: spec.descending, 278 }; 279 apply_sort(&df, &resolved_spec)? 280 } else { 281 df 282 }; 283 284 // 4. Limit (after filtering and sorting) 285 let df = if let Some(n) = opts.limit { 286 df.head(Some(n)) 287 } else { 288 df 289 }; 290 291 // 5. Select columns 292 let df = if let Some(ref col_specs) = opts.cols { 293 let resolved_cols = resolve_columns(col_specs, &df_columns) 294 .map_err(|e| anyhow::anyhow!("{}", e))?; 295 let col_refs: Vec<&str> = resolved_cols.iter().map(|s| s.as_str()).collect(); 296 df.select(col_refs)? 297 } else { 298 df 299 }; 300 301 Ok(df) 302 } 303 304 #[cfg(test)] 305 mod tests { 306 use super::*; 307 308 fn make_test_df() -> DataFrame { 309 DataFrame::new(vec![ 310 Column::new("State".into(), &["CA", "NY", "CA", "TX", "NY"]), 311 Column::new("City".into(), &["LA", "NYC", "SF", "Houston", "Albany"]), 312 Column::new("Amount".into(), &[1500i64, 2000, 800, 1200, 500]), 313 Column::new("Year".into(), &[2023i64, 2023, 2024, 2024, 2023]), 314 Column::new("Status".into(), &["Active", "Active", "Draft", "Active", "Draft"]), 315 ]) 316 .unwrap() 317 } 318 319 #[test] 320 fn filter_eq_string() { 321 let df = make_test_df(); 322 let expr = parse_filter_expr("State=CA").unwrap(); 323 let result = apply_filters(&df, &[expr]).unwrap(); 324 assert_eq!(result.height(), 2); 325 } 326 327 #[test] 328 fn filter_eq_numeric() { 329 let df = make_test_df(); 330 let expr = parse_filter_expr("Amount=1500").unwrap(); 331 let result = apply_filters(&df, &[expr]).unwrap(); 332 assert_eq!(result.height(), 1); 333 } 334 335 #[test] 336 fn filter_not_eq() { 337 let df = make_test_df(); 338 let expr = parse_filter_expr("Status!=Draft").unwrap(); 339 let result = apply_filters(&df, &[expr]).unwrap(); 340 assert_eq!(result.height(), 3); 341 } 342 343 #[test] 344 fn filter_gt() { 345 let df = make_test_df(); 346 let expr = parse_filter_expr("Amount>1000").unwrap(); 347 let result = apply_filters(&df, &[expr]).unwrap(); 348 assert_eq!(result.height(), 3); 349 } 350 351 #[test] 352 fn filter_lt() { 353 let df = make_test_df(); 354 let expr = parse_filter_expr("Amount<1000").unwrap(); 355 let result = apply_filters(&df, &[expr]).unwrap(); 356 assert_eq!(result.height(), 2); 357 } 358 359 #[test] 360 fn filter_gte() { 361 let df = make_test_df(); 362 let expr = parse_filter_expr("Amount>=1500").unwrap(); 363 let result = apply_filters(&df, &[expr]).unwrap(); 364 assert_eq!(result.height(), 2); 365 } 366 367 #[test] 368 fn filter_lte() { 369 let df = make_test_df(); 370 let expr = parse_filter_expr("Amount<=800").unwrap(); 371 let result = apply_filters(&df, &[expr]).unwrap(); 372 assert_eq!(result.height(), 2); 373 } 374 375 #[test] 376 fn filter_contains() { 377 let df = make_test_df(); 378 let expr = parse_filter_expr("City~ou").unwrap(); 379 let result = apply_filters(&df, &[expr]).unwrap(); 380 assert_eq!(result.height(), 1); 381 } 382 383 #[test] 384 fn filter_contains_case_insensitive() { 385 let df = make_test_df(); 386 let expr = parse_filter_expr("City~HOUSTON").unwrap(); 387 let result = apply_filters(&df, &[expr]).unwrap(); 388 assert_eq!(result.height(), 1); 389 } 390 391 #[test] 392 fn filter_not_contains() { 393 let df = make_test_df(); 394 let expr = parse_filter_expr("Status!~raft").unwrap(); 395 let result = apply_filters(&df, &[expr]).unwrap(); 396 assert_eq!(result.height(), 3); 397 } 398 399 #[test] 400 fn filter_multiple_and() { 401 let df = make_test_df(); 402 let e1 = parse_filter_expr("State=CA").unwrap(); 403 let e2 = parse_filter_expr("Amount>1000").unwrap(); 404 let result = apply_filters(&df, &[e1, e2]).unwrap(); 405 assert_eq!(result.height(), 1); 406 } 407 408 #[test] 409 fn filter_no_matches_returns_empty() { 410 let df = make_test_df(); 411 let expr = parse_filter_expr("State=ZZ").unwrap(); 412 let result = apply_filters(&df, &[expr]).unwrap(); 413 assert_eq!(result.height(), 0); 414 } 415 416 #[test] 417 fn filter_empty_exprs_returns_all() { 418 let df = make_test_df(); 419 let result = apply_filters(&df, &[]).unwrap(); 420 assert_eq!(result.height(), 5); 421 } 422 423 #[test] 424 fn parse_eq() { 425 let expr = parse_filter_expr("State=CA").unwrap(); 426 assert_eq!(expr.column, "State"); 427 assert_eq!(expr.op, FilterOp::Eq); 428 assert_eq!(expr.value, "CA"); 429 } 430 431 #[test] 432 fn parse_not_eq() { 433 let expr = parse_filter_expr("Status!=Draft").unwrap(); 434 assert_eq!(expr.column, "Status"); 435 assert_eq!(expr.op, FilterOp::NotEq); 436 assert_eq!(expr.value, "Draft"); 437 } 438 439 #[test] 440 fn parse_gt() { 441 let expr = parse_filter_expr("Amount>1000").unwrap(); 442 assert_eq!(expr.column, "Amount"); 443 assert_eq!(expr.op, FilterOp::Gt); 444 assert_eq!(expr.value, "1000"); 445 } 446 447 #[test] 448 fn parse_lt() { 449 let expr = parse_filter_expr("Year<2024").unwrap(); 450 assert_eq!(expr.column, "Year"); 451 assert_eq!(expr.op, FilterOp::Lt); 452 assert_eq!(expr.value, "2024"); 453 } 454 455 #[test] 456 fn parse_gte() { 457 let expr = parse_filter_expr("Score>=90").unwrap(); 458 assert_eq!(expr.column, "Score"); 459 assert_eq!(expr.op, FilterOp::Gte); 460 assert_eq!(expr.value, "90"); 461 } 462 463 #[test] 464 fn parse_lte() { 465 let expr = parse_filter_expr("Price<=50.5").unwrap(); 466 assert_eq!(expr.column, "Price"); 467 assert_eq!(expr.op, FilterOp::Lte); 468 assert_eq!(expr.value, "50.5"); 469 } 470 471 #[test] 472 fn parse_contains() { 473 let expr = parse_filter_expr("Name~john").unwrap(); 474 assert_eq!(expr.column, "Name"); 475 assert_eq!(expr.op, FilterOp::Contains); 476 assert_eq!(expr.value, "john"); 477 } 478 479 #[test] 480 fn parse_not_contains() { 481 let expr = parse_filter_expr("Name!~draft").unwrap(); 482 assert_eq!(expr.column, "Name"); 483 assert_eq!(expr.op, FilterOp::NotContains); 484 assert_eq!(expr.value, "draft"); 485 } 486 487 #[test] 488 fn parse_value_with_equals() { 489 let expr = parse_filter_expr("Formula=A+B=C").unwrap(); 490 assert_eq!(expr.column, "Formula"); 491 assert_eq!(expr.op, FilterOp::Eq); 492 assert_eq!(expr.value, "A+B=C"); 493 } 494 495 #[test] 496 fn parse_empty_value() { 497 let expr = parse_filter_expr("Status=").unwrap(); 498 assert_eq!(expr.column, "Status"); 499 assert_eq!(expr.op, FilterOp::Eq); 500 assert_eq!(expr.value, ""); 501 } 502 503 #[test] 504 fn parse_no_operator_is_err() { 505 assert!(parse_filter_expr("JustAWord").is_err()); 506 } 507 508 #[test] 509 fn parse_no_column_is_err() { 510 assert!(parse_filter_expr("=value").is_err()); 511 } 512 513 #[test] 514 fn parse_sort_desc() { 515 let spec = parse_sort_spec("Amount:desc").unwrap(); 516 assert_eq!(spec.column, "Amount"); 517 assert!(spec.descending); 518 } 519 520 #[test] 521 fn parse_sort_asc() { 522 let spec = parse_sort_spec("Name:asc").unwrap(); 523 assert_eq!(spec.column, "Name"); 524 assert!(!spec.descending); 525 } 526 527 #[test] 528 fn parse_sort_default_asc() { 529 let spec = parse_sort_spec("Name").unwrap(); 530 assert_eq!(spec.column, "Name"); 531 assert!(!spec.descending); 532 } 533 534 #[test] 535 fn parse_sort_bad_dir_is_err() { 536 assert!(parse_sort_spec("Name:up").is_err()); 537 } 538 539 #[test] 540 fn resolve_by_header_name() { 541 let cols = vec!["State".to_string(), "Amount".to_string(), "Year".to_string()]; 542 assert_eq!(resolve_column("Amount", &cols).unwrap(), "Amount"); 543 } 544 545 #[test] 546 fn resolve_case_insensitive_header() { 547 let cols = vec!["State".to_string(), "Amount".to_string()]; 548 assert_eq!(resolve_column("state", &cols).unwrap(), "State"); 549 } 550 551 #[test] 552 fn resolve_unknown_column_is_err() { 553 let cols = vec!["State".to_string(), "Amount".to_string()]; 554 let err = resolve_column("Foo", &cols).unwrap_err(); 555 assert!(err.contains("not found"), "error was: {}", err); 556 } 557 558 #[test] 559 fn resolve_multiple_columns() { 560 let cols = vec!["State".to_string(), "Amount".to_string(), "Year".to_string()]; 561 let resolved = resolve_columns(&["State".to_string(), "Year".to_string()], &cols).unwrap(); 562 assert_eq!(resolved, vec!["State", "Year"]); 563 } 564 565 #[test] 566 fn sort_ascending() { 567 let df = make_test_df(); 568 let spec = parse_sort_spec("Amount:asc").unwrap(); 569 let result = apply_sort(&df, &spec).unwrap(); 570 let col = result.column("Amount").unwrap().as_materialized_series(); 571 let amounts = col.i64().unwrap(); 572 assert_eq!(amounts.get(0), Some(500)); 573 assert_eq!(amounts.get(4), Some(2000)); 574 } 575 576 #[test] 577 fn sort_descending() { 578 let df = make_test_df(); 579 let spec = parse_sort_spec("Amount:desc").unwrap(); 580 let result = apply_sort(&df, &spec).unwrap(); 581 let col = result.column("Amount").unwrap().as_materialized_series(); 582 let amounts = col.i64().unwrap(); 583 assert_eq!(amounts.get(0), Some(2000)); 584 assert_eq!(amounts.get(4), Some(500)); 585 } 586 587 #[test] 588 fn pipeline_full() { 589 let df = make_test_df(); 590 let opts = FilterOptions { 591 filters: vec![parse_filter_expr("Amount>500").unwrap()], 592 cols: Some(vec!["State".to_string(), "Amount".to_string()]), 593 sort: Some(parse_sort_spec("Amount:desc").unwrap()), 594 limit: Some(2), 595 head: None, 596 tail: None, 597 }; 598 let result = filter_pipeline(df, &opts).unwrap(); 599 assert_eq!(result.height(), 2); 600 assert_eq!(result.width(), 2); 601 let col = result.column("Amount").unwrap().as_materialized_series(); 602 let amounts = col.i64().unwrap(); 603 assert_eq!(amounts.get(0), Some(2000)); 604 assert_eq!(amounts.get(1), Some(1500)); 605 } 606 607 #[test] 608 fn pipeline_head_before_filter() { 609 let df = make_test_df(); // 5 rows: CA/LA, NY/NYC, CA/SF, TX/Houston, NY/Albany 610 let opts = FilterOptions { 611 filters: vec![parse_filter_expr("State=NY").unwrap()], 612 cols: None, 613 sort: None, 614 limit: None, 615 head: Some(3), // Take first 3 rows before filtering 616 tail: None, 617 }; 618 let result = filter_pipeline(df, &opts).unwrap(); 619 // First 3 rows: CA/LA, NY/NYC, CA/SF → only NY/NYC matches 620 assert_eq!(result.height(), 1); 621 } 622 623 #[test] 624 fn pipeline_tail_before_filter() { 625 let df = make_test_df(); // 5 rows 626 let opts = FilterOptions { 627 filters: vec![parse_filter_expr("State=CA").unwrap()], 628 cols: None, 629 sort: None, 630 limit: None, 631 head: None, 632 tail: Some(3), // Last 3 rows before filtering 633 }; 634 let result = filter_pipeline(df, &opts).unwrap(); 635 // Last 3 rows: CA/SF, TX/Houston, NY/Albany → only CA/SF matches 636 assert_eq!(result.height(), 1); 637 } 638 639 #[test] 640 fn pipeline_no_options_returns_all() { 641 let df = make_test_df(); 642 let opts = FilterOptions { 643 filters: vec![], 644 cols: None, 645 sort: None, 646 limit: None, 647 head: None, 648 tail: None, 649 }; 650 let result = filter_pipeline(df, &opts).unwrap(); 651 assert_eq!(result.height(), 5); 652 assert_eq!(result.width(), 5); 653 } 654 655 #[test] 656 fn pipeline_limit_after_filter() { 657 let df = make_test_df(); 658 let opts = FilterOptions { 659 filters: vec![parse_filter_expr("Status=Active").unwrap()], 660 cols: None, 661 sort: None, 662 limit: Some(2), 663 head: None, 664 tail: None, 665 }; 666 let result = filter_pipeline(df, &opts).unwrap(); 667 assert_eq!(result.height(), 2); // 3 Active rows, limited to 2 668 } 669 }