From 91ab01d76fc9804601f1cd841ceec6ce1b07cd95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Estrella?= <2049686+sestrella@users.noreply.github.com> Date: Tue, 21 Jul 2020 21:37:55 -0500 Subject: [PATCH] [#197] Allow PostgreSQL aggregate functions to take a filter clause --- changelog.md | 5 ++ esqueleto.cabal | 2 +- src/Database/Esqueleto/Internal/Internal.hs | 2 + src/Database/Esqueleto/PostgreSQL.hs | 46 +++++++++++++- test/PostgreSQL/Test.hs | 70 +++++++++++++++++++++ 5 files changed, 123 insertions(+), 2 deletions(-) diff --git a/changelog.md b/changelog.md index 68566e1..364ffc1 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,8 @@ +3.3.3.3 +======== +- @sestrella + - [#198](https://github.com/bitemyapp/esqueleto/pull/198) - Allow PostgreSQL aggregate functions to take a filter clause + 3.3.3.2 ======== - @maxgabriel diff --git a/esqueleto.cabal b/esqueleto.cabal index 0e015f9..41d3d5a 100644 --- a/esqueleto.cabal +++ b/esqueleto.cabal @@ -1,7 +1,7 @@ cabal-version: 1.12 name: esqueleto -version: 3.3.3.2 +version: 3.3.3.3 synopsis: Type-safe EDSL for SQL queries on persistent backends. description: @esqueleto@ is a bare bones, type-safe EDSL for SQL queries that works with unmodified @persistent@ SQL backends. Its language closely resembles SQL, so you don't have to learn new concepts, just new syntax, and it's fairly easy to predict the generated SQL and optimize it for your backend. Most kinds of errors committed when writing SQL are caught as compile-time errors---although it is possible to write type-checked @esqueleto@ queries that fail at runtime. . diff --git a/src/Database/Esqueleto/Internal/Internal.hs b/src/Database/Esqueleto/Internal/Internal.hs index 02954e9..95f948a 100644 --- a/src/Database/Esqueleto/Internal/Internal.hs +++ b/src/Database/Esqueleto/Internal/Internal.hs @@ -1586,6 +1586,8 @@ data UnexpectedValueError = | MakeSetError | MakeWhereError | MakeHavingError + | FilterWhereAggError + | FilterWhereClauseError deriving (Show) type CompositeKeyError = UnexpectedValueError diff --git a/src/Database/Esqueleto/PostgreSQL.hs b/src/Database/Esqueleto/PostgreSQL.hs index dd1eab5..7afbc6b 100644 --- a/src/Database/Esqueleto/PostgreSQL.hs +++ b/src/Database/Esqueleto/PostgreSQL.hs @@ -23,6 +23,7 @@ module Database.Esqueleto.PostgreSQL , upsertBy , insertSelectWithConflict , insertSelectWithConflictCount + , filterWhere -- * Internal , unsafeSqlAggregateFunction ) where @@ -38,7 +39,7 @@ import Database.Esqueleto.Internal.Sql import Database.Esqueleto.Internal.Internal (EsqueletoError(..), CompositeKeyError(..), UnexpectedCaseError(..), SetClause, Ident(..), uncommas, FinalResult(..), toUniqueDef, - KnowResult, renderUpdates) + KnowResult, renderUpdates, UnexpectedValueError(..)) import Database.Persist.Class (OnlyOneUniqueKey) import Data.List.NonEmpty ( NonEmpty( (:|) ) ) import Data.Int (Int64) @@ -298,3 +299,46 @@ insertSelectWithConflictCount unique query conflictQuery = do ]),values) where (updatesTLB,values) = renderedUpdates conn + +-- | Allow aggregate functions to take a filter clause. +-- +-- Example of usage: +-- +-- @ +-- share [mkPersist sqlSettings] [persistLowerCase| +-- User +-- name Text +-- deriving Eq Show +-- Task +-- userId UserId +-- completed Bool +-- deriving Eq Show +-- |] +-- +-- select $ from $ \(users `InnerJoin` tasks) -> do +-- on $ users ^. UserId ==. tasks ^. TaskUserId +-- groupBy $ users ^. UserId +-- return +-- ( users ^. UserId +-- , count (tasks ^. TaskId) `filterWhere` (tasks ^. TaskCompleted ==. val True) +-- , count (tasks ^. TaskId) `filterWhere` (tasks ^. TaskCompleted ==. val False) +-- ) +-- @ +-- +-- @since 3.3.3.3 +filterWhere + :: SqlExpr (Value a) + -- ^ Aggregate function + -> SqlExpr (Value Bool) + -- ^ Filter clause + -> SqlExpr (Value a) +filterWhere aggExpr clauseExpr = ERaw Never $ \info -> + let (aggBuilder, aggValues) = case aggExpr of + ERaw _ aggF -> aggF info + ECompositeKey _ -> throw $ CompositeKeyErr FilterWhereAggError + (clauseBuilder, clauseValues) = case clauseExpr of + ERaw _ clauseF -> clauseF info + ECompositeKey _ -> throw $ CompositeKeyErr FilterWhereClauseError + in ( aggBuilder <> " FILTER (WHERE " <> clauseBuilder <> ")" + , aggValues <> clauseValues + ) diff --git a/test/PostgreSQL/Test.hs b/test/PostgreSQL/Test.hs index 93d2c14..bb0053f 100644 --- a/test/PostgreSQL/Test.hs +++ b/test/PostgreSQL/Test.hs @@ -1082,6 +1082,75 @@ testInsertSelectWithConflict = liftIO $ map entityVal uniques1 `shouldBe` test liftIO $ map entityVal uniques2 `shouldBe` test2 +testFilterWhere :: Spec +testFilterWhere = + describe "filterWhere" $ do + it "adds a filter clause to count aggregation" $ run $ do + -- Person "John" (Just 36) Nothing 1 + _ <- insert p1 + -- Person "Rachel" Nothing (Just 37) 2 + _ <- insert p2 + -- Person "Mike" (Just 17) Nothing 3 + _ <- insert p3 + -- Person "Livia" (Just 17) (Just 18) 4 + _ <- insert p4 + -- Person "Mitch" Nothing Nothing 5 + _ <- insert p5 + + usersByAge <- (fmap . fmap) (\(Value a, Value b, Value c) -> (a, b, c)) <$> select $ from $ \users -> do + groupBy $ users ^. PersonAge + return + ( users ^. PersonAge + -- Nothing: [Rachel { favNum = 2 }, Mitch { favNum = 5 }] = 2 + -- Just 36: [John { favNum = 1 } (excluded)] = 0 + -- Just 17: [Mike { favNum = 3 }, Livia { favNum = 4 }] = 2 + , count (users ^. PersonId) `EP.filterWhere` (users ^. PersonFavNum >=. val 2) + -- Nothing: [Rachel { favNum = 2 } (excluded), Mitch { favNum = 5 } (excluded)] = 0 + -- Just 36: [John { favNum = 1 }] = 1 + -- Just 17: [Mike { favNum = 3 } (excluded), Livia { favNum = 4 } (excluded)] = 0 + , count (users ^. PersonFavNum) `EP.filterWhere` (users ^. PersonFavNum <. val 2) + ) + + liftIO $ usersByAge `shouldMatchList` + ( [ (Nothing, 2, 0) + , (Just 36, 0, 1) + , (Just 17, 2, 0) + ] :: [(Maybe Int, Int, Int)] + ) + + it "adds a filter clause to sum aggregation" $ run $ do + -- Person "John" (Just 36) Nothing 1 + _ <- insert p1 + -- Person "Rachel" Nothing (Just 37) 2 + _ <- insert p2 + -- Person "Mike" (Just 17) Nothing 3 + _ <- insert p3 + -- Person "Livia" (Just 17) (Just 18) 4 + _ <- insert p4 + -- Person "Mitch" Nothing Nothing 5 + _ <- insert p5 + + usersByAge <- (fmap . fmap) (\(Value a, Value b, Value c) -> (a, b, c)) <$> select $ from $ \users -> do + groupBy $ users ^. PersonAge + return + ( users ^. PersonAge + -- Nothing: [Rachel { favNum = 2 }, Mitch { favNum = 5 }] = Just 7 + -- Just 36: [John { favNum = 1 } (excluded)] = Nothing + -- Just 17: [Mike { favNum = 3 }, Livia { favNum = 4 }] = Just 7 + , sum_ (users ^. PersonFavNum) `EP.filterWhere` (users ^. PersonFavNum >=. val 2) + -- Nothing: [Rachel { favNum = 2 } (excluded), Mitch { favNum = 5 } (excluded)] = Nothing + -- Just 36: [John { favNum = 1 }] = Just 1 + -- Just 17: [Mike { favNum = 3 } (excluded), Livia { favNum = 4 } (excluded)] = Nothing + , sum_ (users ^. PersonFavNum) `EP.filterWhere` (users ^. PersonFavNum <. val 2) + ) + + liftIO $ usersByAge `shouldMatchList` + ( [ (Nothing, Just 7, Nothing) + , (Just 36, Nothing, Just 1) + , (Just 17, Just 7, Nothing) + ] :: [(Maybe Int, Maybe Rational, Maybe Rational)] + ) + type JSONValue = Maybe (JSONB A.Value) createSaneSQL :: (PersistField a) => SqlExpr (Value a) -> T.Text -> [PersistValue] -> IO () @@ -1156,6 +1225,7 @@ main = do testInsertUniqueViolation testUpsert testInsertSelectWithConflict + testFilterWhere describe "PostgreSQL JSON tests" $ do -- NOTE: We only clean the table once, so we -- can use its contents across all JSON tests