From 0f677e92469c5801099a2cb94c19558d5fb6d80c Mon Sep 17 00:00:00 2001 From: Felipe Lessa Date: Wed, 5 Sep 2012 22:58:08 -0300 Subject: [PATCH] Add UPDATE support. --- src/Database/Esqueleto.hs | 4 +- src/Database/Esqueleto/Internal/Language.hs | 20 ++++- src/Database/Esqueleto/Internal/Sql.hs | 90 ++++++++++++++++++--- test/Test.hs | 19 +++++ 4 files changed, 118 insertions(+), 15 deletions(-) diff --git a/src/Database/Esqueleto.hs b/src/Database/Esqueleto.hs index 327032e..b1f2e44 100644 --- a/src/Database/Esqueleto.hs +++ b/src/Database/Esqueleto.hs @@ -20,7 +20,8 @@ module Database.Esqueleto , sub_select, sub_selectDistinct, (^.), (?.) , val, isNothing, just, nothing, not_, (==.), (>=.) , (>.), (<=.), (<.), (!=.), (&&.), (||.) - , (+.), (-.), (/.), (*.) ) + , (+.), (-.), (/.), (*.) + , set, (=.), (+=.), (-=.), (*=.), (/=.) ) , from , OrderBy -- ** Joins @@ -39,6 +40,7 @@ module Database.Esqueleto , selectSource , selectDistinctSource , delete + , update -- * Re-exports -- $reexports diff --git a/src/Database/Esqueleto/Internal/Language.hs b/src/Database/Esqueleto/Internal/Language.hs index 5af9b7f..7bf2c40 100644 --- a/src/Database/Esqueleto/Internal/Language.hs +++ b/src/Database/Esqueleto/Internal/Language.hs @@ -20,6 +20,7 @@ module Database.Esqueleto.Internal.Language , OnClauseWithoutMatchingJoinException(..) , PreprocessedFrom , OrderBy + , Update ) where import Control.Applicative (Applicative(..), (<$>)) @@ -173,12 +174,24 @@ class (Functor query, Applicative query, Monad query) => (/.) :: PersistField a => expr (Single a) -> expr (Single a) -> expr (Single a) (*.) :: PersistField a => expr (Single a) -> expr (Single a) -> expr (Single a) + -- | @SET@ clause used on @UPDATE@s. Note that while it's not + -- a type error to use this function on a @SELECT@, it will + -- most certainly result in a runtime error. + set :: PersistEntity val => expr (Entity val) -> [expr (Update val)] -> query () + + (=.) :: (PersistEntity val, PersistField typ) => EntityField val typ -> expr (Single typ) -> expr (Update val) + (+=.) :: (PersistEntity val, PersistField a) => EntityField val a -> expr (Single a) -> expr (Update val) + (-=.) :: (PersistEntity val, PersistField a) => EntityField val a -> expr (Single a) -> expr (Update val) + (*=.) :: (PersistEntity val, PersistField a) => EntityField val a -> expr (Single a) -> expr (Update val) + (/=.) :: (PersistEntity val, PersistField a) => EntityField val a -> expr (Single a) -> expr (Update val) + + -- Fixity declarations infixl 9 ^. infixl 7 *., /. infixl 6 +., -. infix 4 ==., >=., >., <=., <., !=. -infixr 3 &&. +infixr 3 &&., =., +=., -=., *=., /=. infixr 2 ||. infixr 2 `InnerJoin`, `CrossJoin`, `LeftOuterJoin`, `RightOuterJoin`, `FullOuterJoin` @@ -263,6 +276,11 @@ data PreprocessedFrom a data OrderBy +-- | Phantom type for a @SET@ operation on an entity of the given +-- type (see 'set' and '(=.)'). +data Update typ + + -- | @FROM@ clause: bring an entity into scope. -- -- The following types implement 'from': diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index 0eb1655..458fec4 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -18,6 +18,7 @@ module Database.Esqueleto.Internal.Sql , runSource , rawExecute , delete + , update , toRawSql , Mode(..) ) where @@ -71,14 +72,15 @@ instance Applicative SqlQuery where -- | Side data written by 'SqlQuery'. data SideData = SideData { sdFromClause :: ![FromClause] + , sdSetClause :: ![SetClause] , sdWhereClause :: !WhereClause , sdOrderByClause :: ![OrderByClause] } instance Monoid SideData where - mempty = SideData mempty mempty mempty - SideData f w o `mappend` SideData f' w' o' = - SideData (f <> f') (w <> w') (o <> o') + mempty = SideData mempty mempty mempty mempty + SideData f s w o `mappend` SideData f' s' w' o' = + SideData (f <> f') (s <> s') (w <> w') (o <> o') -- | A part of a @FROM@ clause. @@ -88,6 +90,10 @@ data FromClause = | OnClause (SqlExpr (Single Bool)) +-- | A part of a @SET@ clause. +newtype SetClause = SetClause (SqlExpr (Single ())) + + -- | Collect 'OnClause's on 'FromJoin's. Returns the first -- unmatched 'OnClause's data on error. Returns a list without -- 'OnClauses' on success. @@ -182,6 +188,7 @@ data SqlExpr a where EMaybe :: SqlExpr a -> SqlExpr (Maybe a) ERaw :: NeedParens -> (Escape -> (TLB.Builder, [PersistValue])) -> SqlExpr (Single a) EOrderBy :: OrderByType -> SqlExpr (Single a) -> SqlExpr OrderBy + ESet :: (SqlExpr (Entity val) -> SqlExpr (Single ())) -> SqlExpr (Update val) EPreprocessedFrom :: a -> FromClause -> SqlExpr (PreprocessedFrom a) data NeedParens = Parens | Never @@ -240,8 +247,8 @@ instance Esqueleto SqlQuery SqlExpr SqlPersist where sub_select = sub SELECT sub_selectDistinct = sub SELECT_DISTINCT - EEntity ident ^. field = ERaw Never $ \esc -> (useIdent esc ident <> ("." <> name esc field), []) - where name esc = esc . fieldDB . persistFieldDef + EEntity ident ^. field = + ERaw Never $ \esc -> (useIdent esc ident <> ("." <> fieldName esc field), []) _ ^. _ = error "Esqueleto/Sql/(^.): never here (see GHC #6124)" EMaybe r ?. field = maybelize (r ^. field) @@ -276,6 +283,29 @@ instance Esqueleto SqlQuery SqlExpr SqlPersist where (/.) = binop " / " (*.) = binop " * " + set ent upds = Q $ W.tell mempty { sdSetClause = map apply upds } + where + apply (ESet f) = SetClause (f ent) + apply _ = error "Esqueleto/Sql/set/apply: never here (see GHC #6124)" + + field =. expr = setAux field (const expr) + field +=. expr = setAux field (\ent -> ent ^. field +. expr) + field -=. expr = setAux field (\ent -> ent ^. field -. expr) + field *=. expr = setAux field (\ent -> ent ^. field *. expr) + field /=. expr = setAux field (\ent -> ent ^. field /. expr) + + +fieldName :: (PersistEntity val, PersistField typ) + => Escape -> EntityField val typ -> TLB.Builder +fieldName esc = esc . fieldDB . persistFieldDef + +setAux :: (PersistEntity val, PersistField typ) + => EntityField val typ + -> (SqlExpr (Entity val) -> SqlExpr (Single typ)) + -> SqlExpr (Update val) +setAux field mkVal = ESet $ \ent -> binop " = " name (mkVal ent) + where name = ERaw Never $ \esc -> (fieldName esc field, mempty) + sub :: PersistField a => Mode -> SqlQuery (SqlExpr (Single a)) -> SqlExpr (Single a) sub mode query = ERaw Parens $ \esc -> first parens (toRawSql mode esc query) @@ -407,24 +437,46 @@ delete :: ( MonadLogger m delete = rawExecute DELETE +-- | Execute an @esqueleto@ @UPDATE@ query inside @persistent@'s +-- 'SqlPersist' monad. Note that currently there are no type +-- checks for statements that should not appear on a @UPDATE@ +-- query. +-- +-- Example of usage: +-- +-- @ +-- update $ \p -> do +-- set p [ PersonAge =. just (val thisYear) -. p ^. PersonBorn ] +-- where_ $ isNull (p ^. PersonAge) +-- @ +update :: ( MonadLogger m + , MonadResourceBase m + , PersistEntity val + , PersistEntityBackend val ~ SqlPersist ) + => (SqlExpr (Entity val) -> SqlQuery ()) + -> SqlPersist m () +update = rawExecute UPDATE . from + + ---------------------------------------------------------------------- -- | Pretty prints a 'SqlQuery' into a SQL query. toRawSql :: SqlSelect a r => Mode -> Escape -> SqlQuery a -> (TLB.Builder, [PersistValue]) toRawSql mode esc query = - let (ret, SideData fromClauses whereClauses orderByClauses) = + let (ret, SideData fromClauses setClauses whereClauses orderByClauses) = flip S.evalState initialIdentState $ W.runWriterT $ unQ query in mconcat [ makeSelect esc mode ret - , makeFrom esc fromClauses + , makeFrom esc mode fromClauses + , makeSet esc setClauses , makeWhere esc whereClauses , makeOrderBy esc orderByClauses ] -data Mode = SELECT | SELECT_DISTINCT | DELETE +data Mode = SELECT | SELECT_DISTINCT | DELETE | UPDATE uncommas :: [TLB.Builder] -> TLB.Builder @@ -435,21 +487,25 @@ uncommas' = (uncommas *** mconcat) . unzip makeSelect :: SqlSelect a r => Escape -> Mode -> a -> (TLB.Builder, [PersistValue]) -makeSelect esc mode ret = first (s <>) (sqlSelectCols esc ret) +makeSelect esc mode ret = first (s <>) (sqlSelectCols esc ret) where s = case mode of SELECT -> "SELECT " SELECT_DISTINCT -> "SELECT DISTINCT " DELETE -> "DELETE" + UPDATE -> "UPDATE " -makeFrom :: Escape -> [FromClause] -> (TLB.Builder, [PersistValue]) -makeFrom _ [] = mempty -makeFrom esc fs = ret +makeFrom :: Escape -> Mode -> [FromClause] -> (TLB.Builder, [PersistValue]) +makeFrom _ _ [] = mempty +makeFrom esc mode fs = ret where ret = case collectOnClauses fs of Left expr -> throw $ mkExc expr - Right fs' -> first ("\nFROM " <>) $ uncommas' (map (mk Never mempty) fs') + Right fs' -> keyword $ uncommas' (map (mk Never mempty) fs') + keyword = case mode of + UPDATE -> id + _ -> first ("\nFROM " <>) mk _ onClause (FromStart i def) = base i def <> onClause mk paren onClause (FromJoin lhs kind rhs monClause) = @@ -482,6 +538,14 @@ makeFrom esc fs = ret mkExc _ = OnClauseWithoutMatchingJoinException "???" +makeSet :: Escape -> [SetClause] -> (TLB.Builder, [PersistValue]) +makeSet _ [] = mempty +makeSet esc os = first ("\nSET " <>) $ uncommas' (map mk os) + where + mk (SetClause (ERaw _ f)) = f esc + mk _ = error "Esqueleto/Sql/makeSet: never here (see GHC #6124)" + + makeWhere :: Escape -> WhereClause -> (TLB.Builder, [PersistValue]) makeWhere _ NoWhere = mempty makeWhere esc (Where (ERaw _ f)) = first ("\nWHERE " <>) (f esc) diff --git a/test/Test.hs b/test/Test.hs index 75db1ee..00b3183 100644 --- a/test/Test.hs +++ b/test/Test.hs @@ -363,6 +363,25 @@ main = do return p liftIO $ ret2 `shouldBe` [ p3e, p2e ] + describe "update" $ + it "works on a simple example" $ + run $ do + p1k <- insert p1 + p2k <- insert p2 + p3k <- insert p3 + let anon = "Anonymous" + () <- update $ \p -> do + set p [ PersonName =. val anon + , PersonAge *=. just (val 2) ] + where_ (p ^. PersonName !=. val "Mike") + ret <- select $ + from $ \p -> do + orderBy [ asc (p ^. PersonName), asc (p ^. PersonAge) ] + return p + liftIO $ ret `shouldBe` [ Entity p2k (Person anon Nothing) + , Entity p1k (Person anon (Just 72)) + , Entity p3k p3 ] + ----------------------------------------------------------------------