From 5dd2ac3f6024c61c635d5254e41ab5b2d74f8b01 Mon Sep 17 00:00:00 2001 From: changsun20 <110759360+changsun20@users.noreply.github.com> Date: Sun, 13 Apr 2025 01:50:14 -0400 Subject: [PATCH 1/2] feat: Emit warning with when doing = Null --- datafusion/sql/src/expr/mod.rs | 138 ++++++++++++++++++++++++++++++++- datafusion/sql/src/planner.rs | 3 + 2 files changed, 137 insertions(+), 4 deletions(-) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index d29ccdc6a7e9..a569252675d2 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -21,13 +21,13 @@ use datafusion_expr::planner::{ }; use sqlparser::ast::{ AccessExpr, BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, - DictionaryField, Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, MapEntry, + DictionaryField, Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, Ident, MapEntry, StructField, Subscript, TrimWhereField, Value, ValueWithSpan, }; use datafusion_common::{ - internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, Result, - ScalarValue, + internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, Diagnostic, + Result, ScalarValue, Span, }; use datafusion_expr::expr::ScalarFunction; @@ -86,6 +86,44 @@ impl SqlToRel<'_, S> { StackEntry::SQLExpr(sql_expr) => { match *sql_expr { SQLExpr::BinaryOp { left, op, right } => { + // Detect if there is "= Null" in SQL + if op == BinaryOperator::Eq { + if let SQLExpr::Value(ValueWithSpan { + value: Value::Null, + span: null_span, + }) = *right + { + let left_span = match &*left { + SQLExpr::Identifier(Ident { span, .. }) => { + span.clone() + } + // In this case, we expect left to be + // Indentifier. Just to make the code + // more robust, we'll make left_span + // equals to null_span otherwise. 
+ _ => null_span.clone(), + }; + let combined_span = Span { + start: Into::into(left_span.start), + end: Into::into(null_span.end), + }; + + let diagnostic = Diagnostic::new_warning( + "Ambiguous NULL comparison".to_string(), + Some(combined_span), + ) + .with_help( + "Use IS NULL instead of = NULL", + Some(Span { + start: Into::into(null_span.start), + end: Into::into(null_span.end), + }), + ); + + self.warnings.borrow_mut().push(diagnostic); + } + } + // Note the order that we push the entries to the stack // is important. We want to visit the left node first. stack.push(StackEntry::Operator(op)); @@ -1174,7 +1212,7 @@ mod tests { use sqlparser::parser::Parser; use datafusion_common::config::ConfigOptions; - use datafusion_common::TableReference; + use datafusion_common::{Location, TableReference}; use datafusion_expr::logical_plan::builder::LogicalTableSource; use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; @@ -1316,4 +1354,96 @@ mod tests { assert!(matches!(expr, Expr::Alias(_))); } + + // Helper to parse SQL expressions + fn parse_expr(sql: &str) -> SQLExpr { + let dialect = GenericDialect {}; + Parser::new(&dialect) + .try_with_sql(sql) + .unwrap() + .parse_expr() + .unwrap() + } + + #[test] + fn test_single_null_comparison() { + let context = TestContextProvider::new(); + let planner = SqlToRel::new(&context); + + // Test single = NULL case + let expr = parse_expr("password = NULL"); + let _ = planner + .sql_expr_to_logical_expr( + expr, + &DFSchema::empty(), + &mut PlannerContext::new(), + ) + .unwrap(); + + let warnings = planner.warnings.take(); + assert_eq!(warnings.len(), 1, "Should detect 1 warning"); + let warning = &warnings[0]; + assert_eq!(warning.message, "Ambiguous NULL comparison"); + + assert_eq!( + warning.span, + Some(Span { + start: Location { line: 1, column: 1 }, + end: Location { + line: 1, + column: 16 + } + }) + ); + + assert_eq!(warning.helps.len(), 1); + let help = &warning.helps[0]; + assert_eq!(help.message, 
"Use IS NULL instead of = NULL"); + } + + #[test] + fn test_multiple_null_comparisons() { + let context = TestContextProvider::new(); + let planner = SqlToRel::new(&context); + + // Test multiple = NULL cases + let expr = parse_expr("(name = NULL) OR (age = NULL)"); + let _ = planner + .sql_expr_to_logical_expr( + expr, + &DFSchema::empty(), + &mut PlannerContext::new(), + ) + .unwrap(); + + let warnings = planner.warnings.take(); + assert_eq!(warnings.len(), 2, "Should detect 2 warnings"); + + let first = &warnings[0]; + assert_eq!( + first.span, + Some(Span { + start: Location { line: 1, column: 2 }, + end: Location { + line: 1, + column: 13 + } + }) + ); + + let second = &warnings[1]; + assert_eq!( + second.span, + Some(Span { + start: Location { + line: 1, + column: 19 + }, + end: Location { + line: 1, + column: 29 + } + }) + ); + } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 3325c98aa74b..e7b61089f472 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -16,6 +16,7 @@ // under the License. //! 
[`SqlToRel`]: SQL Query Planner (produces [`LogicalPlan`] from SQL AST) +use std::cell::RefCell; use std::collections::HashMap; use std::sync::Arc; use std::vec; @@ -337,6 +338,7 @@ pub struct SqlToRel<'a, S: ContextProvider> { pub(crate) context_provider: &'a S, pub(crate) options: ParserOptions, pub(crate) ident_normalizer: IdentNormalizer, + pub(crate) warnings: RefCell<Vec<Diagnostic>>, } impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -359,6 +361,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { context_provider, options, ident_normalizer: IdentNormalizer::new(ident_normalize), + warnings: RefCell::new(Vec::new()), } } From 85ecef185104bc45fb7ed3f0a46a69a5f90e3e9a Mon Sep 17 00:00:00 2001 From: changsun20 <110759360+changsun20@users.noreply.github.com> Date: Sun, 13 Apr 2025 02:31:51 -0400 Subject: [PATCH 2/2] fix: fix clippy warnings --- datafusion/sql/src/expr/mod.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index a569252675d2..0c7769639ac7 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -94,14 +94,12 @@ impl SqlToRel<'_, S> { }) = *right { let left_span = match &*left { - SQLExpr::Identifier(Ident { span, .. }) => { - span.clone() - } + SQLExpr::Identifier(Ident { span, .. }) => *span, // In this case, we expect left to be // Indentifier. Just to make the code // more robust, we'll make left_span // equals to null_span otherwise. - _ => null_span.clone(), + _ => null_span, }; let combined_span = Span { start: Into::into(left_span.start),