From d690e0b42540f0f7f686db5c46b774702ceefe9c Mon Sep 17 00:00:00 2001 From: Paul Berens Date: Sun, 2 Nov 2014 01:07:11 -0700 Subject: [PATCH 1/2] Implement CASE support * This seems to work but I don't have in-depth tests yet * I seem to still have some oddity here and there which needs to be nailed down * This only implements the "full" CASE syntax, not the simplified, and it makes ELSE mandatory, (its optional with CASE) --- src/Database/Esqueleto.hs | 6 +++++- src/Database/Esqueleto/Internal/Language.hs | 14 +++++++++++++ src/Database/Esqueleto/Internal/Sql.hs | 23 +++++++++++++++++++++ test/Test.hs | 21 +++++++++++++++++++ 4 files changed, 63 insertions(+), 1 deletion(-) diff --git a/src/Database/Esqueleto.hs b/src/Database/Esqueleto.hs index 65c0d80..c8fa9cf 100644 --- a/src/Database/Esqueleto.hs +++ b/src/Database/Esqueleto.hs @@ -49,7 +49,11 @@ module Database.Esqueleto , like, (%), concat_, (++.) , subList_select, subList_selectDistinct, valList , in_, notIn, exists, notExists - , set, (=.), (+=.), (-=.), (*=.), (/=.) ) + , set, (=.), (+=.), (-=.), (*=.), (/=.) + , case_ ) + , when_ + , then_ + , else_ , from , Value(..) , unValue diff --git a/src/Database/Esqueleto/Internal/Language.hs b/src/Database/Esqueleto/Internal/Language.hs index 96087b5..0e349f5 100644 --- a/src/Database/Esqueleto/Internal/Language.hs +++ b/src/Database/Esqueleto/Internal/Language.hs @@ -35,6 +35,9 @@ module Database.Esqueleto.Internal.Language , PreprocessedFrom , From , FromPreprocess + , when_ + , then_ + , else_ ) where import Control.Applicative (Applicative(..), (<$>)) @@ -336,6 +339,8 @@ class (Functor query, Applicative query, Monad query) => -- | Apply extra @expr Value@ arguments to a 'PersistField' constructor (<&>) :: expr (Insertion (a -> b)) -> expr (Value a) -> expr (Insertion b) + case_ :: PersistField a => [(expr (Value Bool), expr (Value a))] -> expr (Value a) -> expr (Value a) + -- Fixity declarations infixl 9 ^. @@ -346,6 +351,15 @@ infix 4 ==., >=., >., <=., <., !=. infixr 3 &&., =., +=., -=., *=., /=. infixr 2 ||., `InnerJoin`, `CrossJoin`, `LeftOuterJoin`, `RightOuterJoin`, `FullOuterJoin`, `like` +-- Syntax Sugar for Case +when_ :: expr (Value Bool) -> () -> expr a -> (expr (Value Bool), expr a) +when_ cond _ expr = (cond, expr) + +then_ :: () +then_ = () + +else_ :: expr a -> expr a +else_ = id -- | A single value (as opposed to a whole entity). You may use -- @('^.')@ or @('?.')@ to get a 'Value' from an 'Entity'. diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index 5a2dbc4..52d8ea9 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -413,6 +413,29 @@ instance Esqueleto SqlQuery SqlExpr SqlBackend where (gb, gv) = g x in (fb <> ", " <> gb, fv ++ gv) + case_ when_ else_ = unsafeSqlCase when_ else_ + +-- +-- TODO: this is not 100% compat with sqlite as defined, looks like postgres also supports the extended version +unsafeSqlCase :: PersistField a => [(SqlExpr (Value Bool), SqlExpr (Value a))] -> SqlExpr (Value a) -> SqlExpr (Value a) +unsafeSqlCase when_ (ERaw p1 f1) = ERaw Never buildCase + where + buildCase :: IdentInfo -> (TLB.Builder, [PersistValue]) + buildCase info = + let (b1, vals1) = f1 info + (b2, vals2) = mapWhen when_ info + in ( "CASE" <> b2 <> " ELSE " <> parensM p1 b1 <> " END", vals2 <> vals1) + + mapWhen :: [(SqlExpr (Value Bool), SqlExpr (Value a))] -> IdentInfo -> (TLB.Builder, [PersistValue]) + mapWhen [] _ = error "unsafeSqlCase: empty when_ list." + mapWhen when_ info = foldl (foldHelp info) (mempty, mempty) when_ + + foldHelp :: IdentInfo -> (TLB.Builder, [PersistValue]) -> (SqlExpr (Value Bool), SqlExpr (Value a)) -> (TLB.Builder, [PersistValue]) + foldHelp info (b0, vals0) (ERaw p1 f1, ERaw p2 f2) = + let (b1, vals1) = f1 info + (b2, vals2) = f2 info + in ( b0 <> " WHEN " <> parensM p1 b1 <> " THEN " <> parensM p2 b2, vals0 <> vals1 <> vals2 ) + instance ToSomeValues SqlExpr (SqlExpr (Value a)) where toSomeValues a = [SomeValue a] diff --git a/test/Test.hs b/test/Test.hs index bb8e5b6..80041d0 100644 --- a/test/Test.hs +++ b/test/Test.hs @@ -863,6 +863,27 @@ main = do liftIO $ (ret1 == ret2) `shouldBe` False + describe "case" $ do + it "works for a single when" $ + run $ do + ret <- select $ + return $ + case_ + [ when_ + (exists $ from $ \p -> do + where_ (p ^. PersonName ==. val "Paul")) + then_ + (sub_select $ from $ \v -> do + let sub = + from $ \c -> do + where_ (c ^. PersonName ==. val "Paul") + return (c ^. PersonId) + where_ (v ^. PersonId >. sub_select sub) + return $ count (v ^. PersonName) +. val (1 :: Int)) ] + (else_ $ val (-1)) + + liftIO $ ret `shouldBe` [ Value (-1) ] + ---------------------------------------------------------------------- From ebe9185ef2ad78659b273ee3b4e7a451bc27c526 Mon Sep 17 00:00:00 2001 From: Paul Berens Date: Sun, 2 Nov 2014 14:37:12 -0800 Subject: [PATCH 2/2] Test improvement and documentation improvements --- src/Database/Esqueleto/Internal/Language.hs | 46 +++++++++++++++++++- src/Database/Esqueleto/Internal/Sql.hs | 48 +++++++++++---------- test/Test.hs | 38 +++++++++++++--- 3 files changed, 101 insertions(+), 31 deletions(-) diff --git a/src/Database/Esqueleto/Internal/Language.hs b/src/Database/Esqueleto/Internal/Language.hs index 0e349f5..57e2df0 100644 --- a/src/Database/Esqueleto/Internal/Language.hs +++ b/src/Database/Esqueleto/Internal/Language.hs @@ -339,9 +339,49 @@ class (Functor query, Applicative query, Monad query) => -- | Apply extra @expr Value@ arguments to a 'PersistField' constructor (<&>) :: expr (Insertion (a -> b)) -> expr (Value a) -> expr (Insertion b) + -- | @CASE@ statement. For example: + -- + -- @ + -- select $ + -- return $ + -- case_ + -- [ when_ + -- (exists $ + -- from $ \\p -> do + -- where_ (p ^. PersonName ==. val "Mike")) + -- then_ + -- (sub_select $ + -- from $ \\v -> do + -- let sub = + -- from $ \\c -> do + -- where_ (c ^. PersonName ==. val "Mike") + -- return (c ^. PersonFavNum) + -- where_ (v ^. PersonFavNum >. sub_select sub) + -- return $ count (v ^. PersonName) +. val (1 :: Int)) ] + -- (else_ $ val (-1)) + -- @ + -- + -- This query is a bit complicated, but basically it checks if a person + -- named "Mike" exists, and if that person does, run the subquery to find + -- out how many people have a ranking (by Fav Num) higher than "Mike". + -- + -- __NOTE:__ There are a few things to be aware about this statement. + -- + -- * This only implements the full CASE statement, it does not + -- implement the "simple" CASE statement. + -- + -- + -- * At least one 'when_' and 'then_' is mandatory otherwise it will + -- emit an error. + -- + -- + -- * The 'else_' is also mandatory, unlike the SQL statement in which + -- if the @ELSE@ is omitted it will return a @NULL@. You can + -- reproduce this via 'nothing'. + -- + -- Since: 2.1.1 case_ :: PersistField a => [(expr (Value Bool), expr (Value a))] -> expr (Value a) -> expr (Value a) - -- Fixity declarations infixl 9 ^. infixl 7 *., /. @@ -351,13 +391,15 @@ infix 4 ==., >=., >., <=., <., !=. infixr 3 &&., =., +=., -=., *=., /=. infixr 2 ||., `InnerJoin`, `CrossJoin`, `LeftOuterJoin`, `RightOuterJoin`, `FullOuterJoin`, `like` --- Syntax Sugar for Case +-- | Syntax Sugar for 'case_' when_ :: expr (Value Bool) -> () -> expr a -> (expr (Value Bool), expr a) when_ cond _ expr = (cond, expr) +-- | Syntax Sugar for 'case_' then_ :: () then_ = () +-- | Syntax Sugar for 'case_' else_ :: expr a -> expr a else_ = id diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index 52d8ea9..342b69e 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -27,6 +27,7 @@ module Database.Esqueleto.Internal.Sql , insertSelectDistinct , insertSelect -- * The guts + , unsafeSqlCase , unsafeSqlBinOp , unsafeSqlValue , unsafeSqlFunction @@ -413,29 +414,7 @@ instance Esqueleto SqlQuery SqlExpr SqlBackend where (gb, gv) = g x in (fb <> ", " <> gb, fv ++ gv) - case_ when_ else_ = unsafeSqlCase when_ else_ - --- --- TODO: this is not 100% compat with sqlite as defined, looks like postgres also supports the extended version -unsafeSqlCase :: PersistField a => [(SqlExpr (Value Bool), SqlExpr (Value a))] -> SqlExpr (Value a) -> SqlExpr (Value a) -unsafeSqlCase when_ (ERaw p1 f1) = ERaw Never buildCase - where - buildCase :: IdentInfo -> (TLB.Builder, [PersistValue]) - buildCase info = - let (b1, vals1) = f1 info - (b2, vals2) = mapWhen when_ info - in ( "CASE" <> b2 <> " ELSE " <> parensM p1 b1 <> " END", vals2 <> vals1) - - mapWhen :: [(SqlExpr (Value Bool), SqlExpr (Value a))] -> IdentInfo -> (TLB.Builder, [PersistValue]) - mapWhen [] _ = error "unsafeSqlCase: empty when_ list." - mapWhen when_ info = foldl (foldHelp info) (mempty, mempty) when_ - - foldHelp :: IdentInfo -> (TLB.Builder, [PersistValue]) -> (SqlExpr (Value Bool), SqlExpr (Value a)) -> (TLB.Builder, [PersistValue]) - foldHelp info (b0, vals0) (ERaw p1 f1, ERaw p2 f2) = - let (b1, vals1) = f1 info - (b2, vals2) = f2 info - in ( b0 <> " WHEN " <> parensM p1 b1 <> " THEN " <> parensM p2 b2, vals0 <> vals1 <> vals2 ) - + case_ = unsafeSqlCase instance ToSomeValues SqlExpr (SqlExpr (Value a)) where toSomeValues a = [SomeValue a] @@ -471,6 +450,29 @@ ifNotEmptyList (EList _) _ x = x ---------------------------------------------------------------------- +-- | (Internal) Create a case statement. +-- +-- Since: 2.1.1 +unsafeSqlCase :: PersistField a => [(SqlExpr (Value Bool), SqlExpr (Value a))] -> SqlExpr (Value a) -> SqlExpr (Value a) +unsafeSqlCase when_ (ERaw p1 f1) = ERaw Never buildCase + where + buildCase :: IdentInfo -> (TLB.Builder, [PersistValue]) + buildCase info = + let (b1, vals1) = f1 info + (b2, vals2) = mapWhen when_ info + in ( "CASE" <> b2 <> " ELSE " <> parensM p1 b1 <> " END", vals2 <> vals1) + + mapWhen :: [(SqlExpr (Value Bool), SqlExpr (Value a))] -> IdentInfo -> (TLB.Builder, [PersistValue]) + mapWhen [] _ = error "unsafeSqlCase: empty when_ list." + mapWhen when_ info = foldl (foldHelp info) (mempty, mempty) when_ + + foldHelp :: IdentInfo -> (TLB.Builder, [PersistValue]) -> (SqlExpr (Value Bool), SqlExpr (Value a)) -> (TLB.Builder, [PersistValue]) + foldHelp info (b0, vals0) (ERaw p1 f1, ERaw p2 f2) = + let (b1, vals1) = f1 info + (b2, vals2) = f2 info + in ( b0 <> " WHEN " <> parensM p1 b1 <> " THEN " <> parensM p2 b2, vals0 <> vals1 <> vals2 ) + + -- | (Internal) Create a custom binary operator. You /should/ -- /not/ use this function directly since its type is very -- general, you should always use it with an explicit type diff --git a/test/Test.hs b/test/Test.hs index 80041d0..ce9cecf 100644 --- a/test/Test.hs +++ b/test/Test.hs @@ -864,25 +864,51 @@ main = do liftIO $ (ret1 == ret2) `shouldBe` False describe "case" $ do - it "works for a single when" $ + it "Works for a simple value based when - False" $ run $ do + ret <- select $ + return $ + case_ + [ when_ (val False) then_ (val (1 :: Int)) ] + (else_ (val 2)) + + liftIO $ ret `shouldBe` [ Value 2 ] + + it "Works for a simple value based when - True" $ + run $ do + ret <- select $ + return $ + case_ + [ when_ (val True) then_ (val (1 :: Int)) ] + (else_ (val 2)) + + liftIO $ ret `shouldBe` [ Value 1 ] + + it "works for a semi-complicated query" $ + run $ do + _ <- insert p1 + _ <- insert p2 + _ <- insert p3 + _ <- insert p4 + _ <- insert p5 ret <- select $ return $ case_ [ when_ (exists $ from $ \p -> do - where_ (p ^. PersonName ==. val "Paul")) + where_ (p ^. PersonName ==. val "Mike")) then_ (sub_select $ from $ \v -> do let sub = from $ \c -> do - where_ (c ^. PersonName ==. val "Paul") - return (c ^. PersonId) - where_ (v ^. PersonId >. sub_select sub) + where_ (c ^. PersonName ==. val "Mike") + return (c ^. PersonFavNum) + where_ (v ^. PersonFavNum >. sub_select sub) return $ count (v ^. PersonName) +. val (1 :: Int)) ] (else_ $ val (-1)) - liftIO $ ret `shouldBe` [ Value (-1) ] + liftIO $ ret `shouldBe` [ Value (3) ] + ----------------------------------------------------------------------