diff --git a/src/Database/Esqueleto.hs b/src/Database/Esqueleto.hs index 65c0d80..c8fa9cf 100644 --- a/src/Database/Esqueleto.hs +++ b/src/Database/Esqueleto.hs @@ -49,7 +49,11 @@ module Database.Esqueleto , like, (%), concat_, (++.) , subList_select, subList_selectDistinct, valList , in_, notIn, exists, notExists - , set, (=.), (+=.), (-=.), (*=.), (/=.) ) + , set, (=.), (+=.), (-=.), (*=.), (/=.) + , case_ ) + , when_ + , then_ + , else_ , from , Value(..) , unValue diff --git a/src/Database/Esqueleto/Internal/Language.hs b/src/Database/Esqueleto/Internal/Language.hs index 96087b5..57e2df0 100644 --- a/src/Database/Esqueleto/Internal/Language.hs +++ b/src/Database/Esqueleto/Internal/Language.hs @@ -35,6 +35,9 @@ module Database.Esqueleto.Internal.Language , PreprocessedFrom , From , FromPreprocess + , when_ + , then_ + , else_ ) where import Control.Applicative (Applicative(..), (<$>)) @@ -336,6 +339,48 @@ class (Functor query, Applicative query, Monad query) => -- | Apply extra @expr Value@ arguments to a 'PersistField' constructor (<&>) :: expr (Insertion (a -> b)) -> expr (Value a) -> expr (Insertion b) + -- | @CASE@ statement. For example: + -- + -- @ + -- select $ + -- return $ + -- case_ + -- [ when_ + -- (exists $ + -- from $ \\p -> do + -- where_ (p ^. PersonName ==. val "Mike")) + -- then_ + -- (sub_select $ + -- from $ \\v -> do + -- let sub = + -- from $ \\c -> do + -- where_ (c ^. PersonName ==. val "Mike") + -- return (c ^. PersonFavNum) + -- where_ (v ^. PersonFavNum >. sub_select sub) + -- return $ count (v ^. PersonName) +. val (1 :: Int)) ] + -- (else_ $ val (-1)) + -- @ + -- + -- This query is a bit complicated, but basically it checks if a person + -- named "Mike" exists, and if that person does, run the subquery to find + -- out how many people have a ranking (by Fav Num) higher than "Mike". + -- + -- __NOTE:__ There are a few things to be aware about this statement. + -- + -- * This only implements the full CASE statement, it does not + -- implement the "simple" CASE statement. + -- + -- + -- * At least one 'when_' and 'then_' is mandatory otherwise it will + -- emit an error. + -- + -- + -- * The 'else_' is also mandatory, unlike the SQL statement in which + -- if the @ELSE@ is omitted it will return a @NULL@. You can + -- reproduce this via 'nothing'. + -- + -- Since: 2.1.1 + case_ :: PersistField a => [(expr (Value Bool), expr (Value a))] -> expr (Value a) -> expr (Value a) -- Fixity declarations infixl 9 ^. @@ -346,6 +391,17 @@ infix 4 ==., >=., >., <=., <., !=. infixr 3 &&., =., +=., -=., *=., /=. infixr 2 ||., `InnerJoin`, `CrossJoin`, `LeftOuterJoin`, `RightOuterJoin`, `FullOuterJoin`, `like` +-- | Syntax Sugar for 'case_' +when_ :: expr (Value Bool) -> () -> expr a -> (expr (Value Bool), expr a) +when_ cond _ expr = (cond, expr) + +-- | Syntax Sugar for 'case_' +then_ :: () +then_ = () + +-- | Syntax Sugar for 'case_' +else_ :: expr a -> expr a +else_ = id -- | A single value (as opposed to a whole entity). You may use -- @('^.')@ or @('?.')@ to get a 'Value' from an 'Entity'. diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index 5a2dbc4..342b69e 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -27,6 +27,7 @@ module Database.Esqueleto.Internal.Sql , insertSelectDistinct , insertSelect -- * The guts + , unsafeSqlCase , unsafeSqlBinOp , unsafeSqlValue , unsafeSqlFunction @@ -413,6 +414,7 @@ instance Esqueleto SqlQuery SqlExpr SqlBackend where (gb, gv) = g x in (fb <> ", " <> gb, fv ++ gv) + case_ = unsafeSqlCase instance ToSomeValues SqlExpr (SqlExpr (Value a)) where toSomeValues a = [SomeValue a] @@ -448,6 +450,29 @@ ifNotEmptyList (EList _) _ x = x ---------------------------------------------------------------------- +-- | (Internal) Create a case statement. +-- +-- Since: 2.1.1 +unsafeSqlCase :: PersistField a => [(SqlExpr (Value Bool), SqlExpr (Value a))] -> SqlExpr (Value a) -> SqlExpr (Value a) +unsafeSqlCase when_ (ERaw p1 f1) = ERaw Never buildCase + where + buildCase :: IdentInfo -> (TLB.Builder, [PersistValue]) + buildCase info = + let (b1, vals1) = f1 info + (b2, vals2) = mapWhen when_ info + in ( "CASE" <> b2 <> " ELSE " <> parensM p1 b1 <> " END", vals2 <> vals1) + + mapWhen :: [(SqlExpr (Value Bool), SqlExpr (Value a))] -> IdentInfo -> (TLB.Builder, [PersistValue]) + mapWhen [] _ = error "unsafeSqlCase: empty when_ list." + mapWhen when_ info = foldl (foldHelp info) (mempty, mempty) when_ + + foldHelp :: IdentInfo -> (TLB.Builder, [PersistValue]) -> (SqlExpr (Value Bool), SqlExpr (Value a)) -> (TLB.Builder, [PersistValue]) + foldHelp info (b0, vals0) (ERaw p1 f1, ERaw p2 f2) = + let (b1, vals1) = f1 info + (b2, vals2) = f2 info + in ( b0 <> " WHEN " <> parensM p1 b1 <> " THEN " <> parensM p2 b2, vals0 <> vals1 <> vals2 ) + + -- | (Internal) Create a custom binary operator. You /should/ -- /not/ use this function directly since its type is very -- general, you should always use it with an explicit type diff --git a/test/Test.hs b/test/Test.hs index bb8e5b6..ce9cecf 100644 --- a/test/Test.hs +++ b/test/Test.hs @@ -863,6 +863,53 @@ main = do liftIO $ (ret1 == ret2) `shouldBe` False + describe "case" $ do + it "Works for a simple value based when - False" $ + run $ do + ret <- select $ + return $ + case_ + [ when_ (val False) then_ (val (1 :: Int)) ] + (else_ (val 2)) + + liftIO $ ret `shouldBe` [ Value 2 ] + + it "Works for a simple value based when - True" $ + run $ do + ret <- select $ + return $ + case_ + [ when_ (val True) then_ (val (1 :: Int)) ] + (else_ (val 2)) + + liftIO $ ret `shouldBe` [ Value 1 ] + + it "works for a semi-complicated query" $ + run $ do + _ <- insert p1 + _ <- insert p2 + _ <- insert p3 + _ <- insert p4 + _ <- insert p5 + ret <- select $ + return $ + case_ + [ when_ + (exists $ from $ \p -> do + where_ (p ^. PersonName ==. val "Mike")) + then_ + (sub_select $ from $ \v -> do + let sub = + from $ \c -> do + where_ (c ^. PersonName ==. val "Mike") + return (c ^. PersonFavNum) + where_ (v ^. PersonFavNum >. sub_select sub) + return $ count (v ^. PersonName) +. val (1 :: Int)) ] + (else_ $ val (-1)) + + liftIO $ ret `shouldBe` [ Value (3) ] + + ----------------------------------------------------------------------