diff --git a/src/Database/Esqueleto/Internal/Language.hs b/src/Database/Esqueleto/Internal/Language.hs index 0e349f5..57e2df0 100644 --- a/src/Database/Esqueleto/Internal/Language.hs +++ b/src/Database/Esqueleto/Internal/Language.hs @@ -339,9 +339,49 @@ class (Functor query, Applicative query, Monad query) => -- | Apply extra @expr Value@ arguments to a 'PersistField' constructor (<&>) :: expr (Insertion (a -> b)) -> expr (Value a) -> expr (Insertion b) + -- | @CASE@ statement. For example: + -- + -- @ + -- select $ + -- return $ + -- case_ + -- [ when_ + -- (exists $ + -- from $ \\p -> do + -- where_ (p ^. PersonName ==. val "Mike")) + -- then_ + -- (sub_select $ + -- from $ \\v -> do + -- let sub = + -- from $ \\c -> do + -- where_ (c ^. PersonName ==. val "Mike") + -- return (c ^. PersonFavNum) + -- where_ (v ^. PersonFavNum >. sub_select sub) + -- return $ count (v ^. PersonName) +. val (1 :: Int)) ] + -- (else_ $ val (-1)) + -- @ + -- + -- This query is a bit complicated, but basically it checks if a person + -- named "Mike" exists, and if that person does, run the subquery to find + -- out how many people have a ranking (by Fav Num) higher than "Mike". + -- + -- __NOTE:__ There are a few things to be aware about this statement. + -- + -- * This only implements the full CASE statement, it does not + -- implement the "simple" CASE statement. + -- + -- + -- * At least one 'when_' and 'then_' is mandatory otherwise it will + -- emit an error. + -- + -- + -- * The 'else_' is also mandatory, unlike the SQL statement in which + -- if the @ELSE@ is omitted it will return a @NULL@. You can + -- reproduce this via 'nothing'. + -- + -- Since: 2.1.1 case_ :: PersistField a => [(expr (Value Bool), expr (Value a))] -> expr (Value a) -> expr (Value a) - -- Fixity declarations infixl 9 ^. infixl 7 *., /. @@ -351,13 +391,15 @@ infix 4 ==., >=., >., <=., <., !=. infixr 3 &&., =., +=., -=., *=., /=. infixr 2 ||., `InnerJoin`, `CrossJoin`, `LeftOuterJoin`, `RightOuterJoin`, `FullOuterJoin`, `like` --- Syntax Sugar for Case +-- | Syntax Sugar for 'case_' when_ :: expr (Value Bool) -> () -> expr a -> (expr (Value Bool), expr a) when_ cond _ expr = (cond, expr) +-- | Syntax Sugar for 'case_' then_ :: () then_ = () +-- | Syntax Sugar for 'case_' else_ :: expr a -> expr a else_ = id diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index 52d8ea9..342b69e 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -27,6 +27,7 @@ module Database.Esqueleto.Internal.Sql , insertSelectDistinct , insertSelect -- * The guts + , unsafeSqlCase , unsafeSqlBinOp , unsafeSqlValue , unsafeSqlFunction @@ -413,29 +414,7 @@ instance Esqueleto SqlQuery SqlExpr SqlBackend where (gb, gv) = g x in (fb <> ", " <> gb, fv ++ gv) - case_ when_ else_ = unsafeSqlCase when_ else_ - --- --- TODO: this is not 100% compat with sqlite as defined, looks like postgres also supports the extended version -unsafeSqlCase :: PersistField a => [(SqlExpr (Value Bool), SqlExpr (Value a))] -> SqlExpr (Value a) -> SqlExpr (Value a) -unsafeSqlCase when_ (ERaw p1 f1) = ERaw Never buildCase - where - buildCase :: IdentInfo -> (TLB.Builder, [PersistValue]) - buildCase info = - let (b1, vals1) = f1 info - (b2, vals2) = mapWhen when_ info - in ( "CASE" <> b2 <> " ELSE " <> parensM p1 b1 <> " END", vals2 <> vals1) - - mapWhen :: [(SqlExpr (Value Bool), SqlExpr (Value a))] -> IdentInfo -> (TLB.Builder, [PersistValue]) - mapWhen [] _ = error "unsafeSqlCase: empty when_ list." - mapWhen when_ info = foldl (foldHelp info) (mempty, mempty) when_ - - foldHelp :: IdentInfo -> (TLB.Builder, [PersistValue]) -> (SqlExpr (Value Bool), SqlExpr (Value a)) -> (TLB.Builder, [PersistValue]) - foldHelp info (b0, vals0) (ERaw p1 f1, ERaw p2 f2) = - let (b1, vals1) = f1 info - (b2, vals2) = f2 info - in ( b0 <> " WHEN " <> parensM p1 b1 <> " THEN " <> parensM p2 b2, vals0 <> vals1 <> vals2 ) - + case_ = unsafeSqlCase instance ToSomeValues SqlExpr (SqlExpr (Value a)) where toSomeValues a = [SomeValue a] @@ -471,6 +450,29 @@ ifNotEmptyList (EList _) _ x = x ---------------------------------------------------------------------- +-- | (Internal) Create a case statement. +-- +-- Since: 2.1.1 +unsafeSqlCase :: PersistField a => [(SqlExpr (Value Bool), SqlExpr (Value a))] -> SqlExpr (Value a) -> SqlExpr (Value a) +unsafeSqlCase when_ (ERaw p1 f1) = ERaw Never buildCase + where + buildCase :: IdentInfo -> (TLB.Builder, [PersistValue]) + buildCase info = + let (b1, vals1) = f1 info + (b2, vals2) = mapWhen when_ info + in ( "CASE" <> b2 <> " ELSE " <> parensM p1 b1 <> " END", vals2 <> vals1) + + mapWhen :: [(SqlExpr (Value Bool), SqlExpr (Value a))] -> IdentInfo -> (TLB.Builder, [PersistValue]) + mapWhen [] _ = error "unsafeSqlCase: empty when_ list." + mapWhen when_ info = foldl (foldHelp info) (mempty, mempty) when_ + + foldHelp :: IdentInfo -> (TLB.Builder, [PersistValue]) -> (SqlExpr (Value Bool), SqlExpr (Value a)) -> (TLB.Builder, [PersistValue]) + foldHelp info (b0, vals0) (ERaw p1 f1, ERaw p2 f2) = + let (b1, vals1) = f1 info + (b2, vals2) = f2 info + in ( b0 <> " WHEN " <> parensM p1 b1 <> " THEN " <> parensM p2 b2, vals0 <> vals1 <> vals2 ) + + -- | (Internal) Create a custom binary operator. You /should/ -- /not/ use this function directly since its type is very -- general, you should always use it with an explicit type diff --git a/test/Test.hs b/test/Test.hs index 80041d0..ce9cecf 100644 --- a/test/Test.hs +++ b/test/Test.hs @@ -864,25 +864,51 @@ main = do liftIO $ (ret1 == ret2) `shouldBe` False describe "case" $ do - it "works for a single when" $ + it "Works for a simple value based when - False" $ run $ do + ret <- select $ + return $ + case_ + [ when_ (val False) then_ (val (1 :: Int)) ] + (else_ (val 2)) + + liftIO $ ret `shouldBe` [ Value 2 ] + + it "Works for a simple value based when - True" $ + run $ do + ret <- select $ + return $ + case_ + [ when_ (val True) then_ (val (1 :: Int)) ] + (else_ (val 2)) + + liftIO $ ret `shouldBe` [ Value 1 ] + + it "works for a semi-complicated query" $ + run $ do + _ <- insert p1 + _ <- insert p2 + _ <- insert p3 + _ <- insert p4 + _ <- insert p5 ret <- select $ return $ case_ [ when_ (exists $ from $ \p -> do - where_ (p ^. PersonName ==. val "Paul")) + where_ (p ^. PersonName ==. val "Mike")) then_ (sub_select $ from $ \v -> do let sub = from $ \c -> do - where_ (c ^. PersonName ==. val "Paul") - return (c ^. PersonId) - where_ (v ^. PersonId >. sub_select sub) + where_ (c ^. PersonName ==. val "Mike") + return (c ^. PersonFavNum) + where_ (v ^. PersonFavNum >. sub_select sub) return $ count (v ^. PersonName) +. val (1 :: Int)) ] (else_ $ val (-1)) - liftIO $ ret `shouldBe` [ Value (-1) ] + liftIO $ ret `shouldBe` [ Value (3) ] + ----------------------------------------------------------------------