From 2f07061a282d1821df0a0394b306b060073fa7d7 Mon Sep 17 00:00:00 2001 From: shruti2522 Date: Sun, 20 Apr 2025 20:43:09 +0530 Subject: [PATCH] Support GroupsAccumulator for avg duration --- datafusion/functions-aggregate/src/average.rs | 41 ++++++++++- .../sqllogictest/test_files/aggregate.slt | 71 +++++++++++++++++++ 2 files changed, 111 insertions(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 798a039f50b1..595d23721af8 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -182,7 +182,7 @@ impl AggregateUDFImpl for Avg { fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { matches!( args.return_type, - DataType::Float64 | DataType::Decimal128(_, _) + DataType::Float64 | DataType::Decimal128(_, _) | DataType::Duration(_) ) } @@ -243,6 +243,45 @@ impl AggregateUDFImpl for Avg { ))) } + (Duration(time_unit), Duration(_result_unit)) => { + let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64); + + match time_unit { + TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::< + DurationSecondType, + _, + >::new( + &data_type, + args.return_type, + avg_fn, + ))), + TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::< + DurationMillisecondType, + _, + >::new( + &data_type, + args.return_type, + avg_fn, + ))), + TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::< + DurationMicrosecondType, + _, + >::new( + &data_type, + args.return_type, + avg_fn, + ))), + TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::< + DurationNanosecondType, + _, + >::new( + &data_type, + args.return_type, + avg_fn, + ))), + } + } + _ => not_impl_err!( "AvgGroupsAccumulator for ({} --> {})", &data_type, diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 3f250ab21d46..9fd5c6e664c2 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5036,6 +5036,77 @@ FROM d WHERE column1 IS NOT NULL; statement ok drop table d; +# avg_duration (GroupsAccumulator) + +statement ok +create table duration as values + (arrow_cast(10, 'Duration(Second)'), arrow_cast(100, 'Duration(Millisecond)'), 'a', 1), + (arrow_cast(20, 'Duration(Second)'), arrow_cast(200, 'Duration(Millisecond)'), 'a', 2), + (arrow_cast(30, 'Duration(Second)'), arrow_cast(300, 'Duration(Millisecond)'), 'b', 1), + (arrow_cast(40, 'Duration(Second)'), arrow_cast(400, 'Duration(Millisecond)'), 'b', 2), + (arrow_cast(50, 'Duration(Second)'), arrow_cast(500, 'Duration(Millisecond)'), 'c', 1), + (arrow_cast(60, 'Duration(Second)'), arrow_cast(600, 'Duration(Millisecond)'), 'c', 2); + +query T??I +SELECT column3, avg(column1), avg(column2), column4 FROM duration GROUP BY column3, column4 ORDER BY column3, column4; +---- +a 0 days 0 hours 0 mins 10 secs 0 days 0 hours 0 mins 0.100 secs 1 +a 0 days 0 hours 0 mins 20 secs 0 days 0 hours 0 mins 0.200 secs 2 +b 0 days 0 hours 0 mins 30 secs 0 days 0 hours 0 mins 0.300 secs 1 +b 0 days 0 hours 0 mins 40 secs 0 days 0 hours 0 mins 0.400 secs 2 +c 0 days 0 hours 0 mins 50 secs 0 days 0 hours 0 mins 0.500 secs 1 +c 0 days 0 hours 1 mins 0 secs 0 days 0 hours 0 mins 0.600 secs 2 + +query T? +SELECT column3, avg(column1) FROM duration GROUP BY column3 ORDER BY column3; +---- +a 0 days 0 hours 0 mins 15 secs +b 0 days 0 hours 0 mins 35 secs +c 0 days 0 hours 0 mins 55 secs + +query I?? +SELECT column4, avg(column1), avg(column2) FROM duration GROUP BY column4 ORDER BY column4; +---- +1 0 days 0 hours 0 mins 30 secs 0 days 0 hours 0 mins 0.300 secs +2 0 days 0 hours 0 mins 40 secs 0 days 0 hours 0 mins 0.400 secs + +query TI?? +SELECT column3, column4, column1, avg(column1) OVER (PARTITION BY column3 ORDER BY column4 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as running_avg +FROM duration +ORDER BY column3, column4; +---- +a 1 0 days 0 hours 0 mins 10 secs 0 days 0 hours 0 mins 10 secs +a 2 0 days 0 hours 0 mins 20 secs 0 days 0 hours 0 mins 15 secs +b 1 0 days 0 hours 0 mins 30 secs 0 days 0 hours 0 mins 30 secs +b 2 0 days 0 hours 0 mins 40 secs 0 days 0 hours 0 mins 35 secs +c 1 0 days 0 hours 0 mins 50 secs 0 days 0 hours 0 mins 50 secs +c 2 0 days 0 hours 1 mins 0 secs 0 days 0 hours 0 mins 55 secs + +statement ok +drop table duration; + +statement ok +create table duration_nulls as values + (arrow_cast(10, 'Duration(Second)'), 'a', 1), + (arrow_cast(20, 'Duration(Second)'), 'a', 2), + (NULL, 'b', 1), + (arrow_cast(40, 'Duration(Second)'), 'b', 2), + (arrow_cast(50, 'Duration(Second)'), 'c', 1), + (NULL, 'c', 2); + +query T?I +SELECT column2, avg(column1), column3 FROM duration_nulls GROUP BY column2, column3 ORDER BY column2, column3; +---- +a 0 days 0 hours 0 mins 10 secs 1 +a 0 days 0 hours 0 mins 20 secs 2 +b NULL 1 +b 0 days 0 hours 0 mins 40 secs 2 +c 0 days 0 hours 0 mins 50 secs 1 +c NULL 2 + +statement ok +drop table duration_nulls; + # Prepare the table with dictionary values for testing statement ok CREATE TABLE value(x bigint) AS VALUES (1), (2), (3), (1), (3), (4), (5), (2);