diff --git a/src/Database/Esqueleto.hs b/src/Database/Esqueleto.hs index 65c0d80..c8fa9cf 100644 --- a/src/Database/Esqueleto.hs +++ b/src/Database/Esqueleto.hs @@ -49,7 +49,11 @@ module Database.Esqueleto , like, (%), concat_, (++.) , subList_select, subList_selectDistinct, valList , in_, notIn, exists, notExists - , set, (=.), (+=.), (-=.), (*=.), (/=.) ) + , set, (=.), (+=.), (-=.), (*=.), (/=.) + , case_ ) + , when_ + , then_ + , else_ , from , Value(..) , unValue diff --git a/src/Database/Esqueleto/Internal/Language.hs b/src/Database/Esqueleto/Internal/Language.hs index 96087b5..0e349f5 100644 --- a/src/Database/Esqueleto/Internal/Language.hs +++ b/src/Database/Esqueleto/Internal/Language.hs @@ -35,6 +35,9 @@ module Database.Esqueleto.Internal.Language , PreprocessedFrom , From , FromPreprocess + , when_ + , then_ + , else_ ) where import Control.Applicative (Applicative(..), (<$>)) @@ -336,6 +339,8 @@ class (Functor query, Applicative query, Monad query) => -- | Apply extra @expr Value@ arguments to a 'PersistField' constructor (<&>) :: expr (Insertion (a -> b)) -> expr (Value a) -> expr (Insertion b) + case_ :: PersistField a => [(expr (Value Bool), expr (Value a))] -> expr (Value a) -> expr (Value a) + -- Fixity declarations infixl 9 ^. @@ -346,6 +351,15 @@ infix 4 ==., >=., >., <=., <., !=. infixr 3 &&., =., +=., -=., *=., /=. infixr 2 ||., `InnerJoin`, `CrossJoin`, `LeftOuterJoin`, `RightOuterJoin`, `FullOuterJoin`, `like` +-- Syntax Sugar for Case +when_ :: expr (Value Bool) -> () -> expr a -> (expr (Value Bool), expr a) +when_ cond _ expr = (cond, expr) + +then_ :: () +then_ = () + +else_ :: expr a -> expr a +else_ = id -- | A single value (as opposed to a whole entity). You may use -- @('^.')@ or @('?.')@ to get a 'Value' from an 'Entity'. diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index 5a2dbc4..52d8ea9 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -413,6 +413,29 @@ instance Esqueleto SqlQuery SqlExpr SqlBackend where (gb, gv) = g x in (fb <> ", " <> gb, fv ++ gv) + case_ when_ else_ = unsafeSqlCase when_ else_ + +-- +-- TODO: this is not 100% compat with sqlite as defined, looks like postgres also supports the extended version +unsafeSqlCase :: PersistField a => [(SqlExpr (Value Bool), SqlExpr (Value a))] -> SqlExpr (Value a) -> SqlExpr (Value a) +unsafeSqlCase when_ (ERaw p1 f1) = ERaw Never buildCase + where + buildCase :: IdentInfo -> (TLB.Builder, [PersistValue]) + buildCase info = + let (b1, vals1) = f1 info + (b2, vals2) = mapWhen when_ info + in ( "CASE" <> b2 <> " ELSE " <> parensM p1 b1 <> " END", vals2 <> vals1) + + mapWhen :: [(SqlExpr (Value Bool), SqlExpr (Value a))] -> IdentInfo -> (TLB.Builder, [PersistValue]) + mapWhen [] _ = error "unsafeSqlCase: empty when_ list." + mapWhen when_ info = foldl (foldHelp info) (mempty, mempty) when_ + + foldHelp :: IdentInfo -> (TLB.Builder, [PersistValue]) -> (SqlExpr (Value Bool), SqlExpr (Value a)) -> (TLB.Builder, [PersistValue]) + foldHelp info (b0, vals0) (ERaw p1 f1, ERaw p2 f2) = + let (b1, vals1) = f1 info + (b2, vals2) = f2 info + in ( b0 <> " WHEN " <> parensM p1 b1 <> " THEN " <> parensM p2 b2, vals0 <> vals1 <> vals2 ) + instance ToSomeValues SqlExpr (SqlExpr (Value a)) where toSomeValues a = [SomeValue a] diff --git a/test/Test.hs b/test/Test.hs index bb8e5b6..80041d0 100644 --- a/test/Test.hs +++ b/test/Test.hs @@ -863,6 +863,27 @@ main = do liftIO $ (ret1 == ret2) `shouldBe` False + describe "case" $ do + it "works for a single when" $ + run $ do + ret <- select $ + return $ + case_ + [ when_ + (exists $ from $ \p -> do + where_ (p ^. PersonName ==. val "Paul")) + then_ + (sub_select $ from $ \v -> do + let sub = + from $ \c -> do + where_ (c ^. PersonName ==. val "Paul") + return (c ^. PersonId) + where_ (v ^. PersonId >. sub_select sub) + return $ count (v ^. PersonName) +. val (1 :: Int)) ] + (else_ $ val (-1)) + + liftIO $ ret `shouldBe` [ Value (-1) ] + ----------------------------------------------------------------------