diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index cd28b81..bcff141 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -6,6 +6,7 @@ module Database.Esqueleto.Internal.Sql ) where import Control.Applicative (Applicative(..), (<$>)) +import Control.Arrow (first, (&&&)) import Control.Monad (ap) import Control.Monad.Logger (MonadLogger) import Control.Monad.Trans.Resource (MonadResourceBase) @@ -85,7 +86,7 @@ idents _ = -- | An expression on the SQL backend. data SqlExpr a where EEntity :: Ident -> SqlExpr (Entity val) - ERaw :: (Connection -> TLB.Builder) -> [PersistValue] -> SqlExpr (Single a) + ERaw :: (Connection -> (TLB.Builder, [PersistValue])) -> SqlExpr (Single a) instance Esqueleto SqlQuery SqlExpr SqlPersist where fromSingle = Q $ do @@ -99,12 +100,13 @@ instance Esqueleto SqlQuery SqlExpr SqlPersist where where_ expr = Q $ W.tell mempty { sdWhereClause = Where expr } - EEntity (I ident) ^. field = ERaw (\conn -> ident <> ("." <> name conn field)) [] + EEntity (I ident) ^. field = ERaw $ \conn -> (ident <> ("." <> name conn field), []) where name conn = fromDBName conn . fieldDB . persistFieldDef - val = ERaw (const "?") . return . toPersistValue + val = ERaw . const . (,) "?" . return . toPersistValue - not_ (ERaw b vals) = ERaw (\conn -> "NOT " <> parens (b conn)) vals + not_ (ERaw f) = ERaw $ \conn -> let (b, vals) = f conn + in ("NOT " <> parens b, vals) (==.) = binop " = " (>=.) = binop " >= " @@ -124,9 +126,12 @@ fromDBName :: Connection -> DBName -> TLB.Builder fromDBName conn = TLB.fromText . escapeName conn binop :: TLB.Builder -> SqlExpr (Single a) -> SqlExpr (Single b) -> SqlExpr (Single c) -binop op (ERaw b1 vals1) (ERaw b2 vals2) = ERaw b (vals1 <> vals2) +binop op (ERaw f1) (ERaw f2) = ERaw f where - b conn = parens (b1 conn) <> op <> parens (b2 conn) + f conn = let (b1, vals1) = f1 conn + (b2, vals2) = f2 conn + in ( parens b1 <> op <> parens b2 + , vals1 <> vals2 ) -- | TODO @@ -174,12 +179,12 @@ class RawSql (SqlSelectRet a) => SqlSelect a where instance RawSql a => SqlSelect (SqlExpr a) where type SqlSelectRet (SqlExpr a) = a - makeSelect _ (EEntity _) = ("??", mempty) - makeSelect conn (ERaw b vals) = (parens (b conn), vals) + makeSelect _ (EEntity _) = ("??", mempty) + makeSelect conn (ERaw f) = first parens (f conn) instance (SqlSelect a, SqlSelect b) => SqlSelect (a, b) where type SqlSelectRet (a, b) = (SqlSelectRet a, SqlSelectRet b) - makeSelect conn (a, b) = makeSelect conn a <> makeSelect conn b + makeSelect conn (a, b) = uncommas' [makeSelect conn a, makeSelect conn b] instance (SqlSelect a, SqlSelect b, SqlSelect c) => SqlSelect (a, b, c) where type SqlSelectRet (a, b, c) = ( SqlSelectRet a @@ -187,7 +192,7 @@ instance (SqlSelect a, SqlSelect b, SqlSelect c) => SqlSelect (a, b, c) where , SqlSelectRet c ) makeSelect conn (a, b, c) = - mconcat + uncommas' [ makeSelect conn a , makeSelect conn b , makeSelect conn c @@ -204,7 +209,7 @@ instance ( SqlSelect a , SqlSelectRet d ) makeSelect conn (a, b, c, d) = - mconcat + uncommas' [ makeSelect conn a , makeSelect conn b , makeSelect conn c @@ -224,7 +229,7 @@ instance ( SqlSelect a , SqlSelectRet e ) makeSelect conn (a, b, c, d, e) = - mconcat + uncommas' [ makeSelect conn a , makeSelect conn b , makeSelect conn c @@ -247,7 +252,7 @@ instance ( SqlSelect a , SqlSelectRet f ) makeSelect conn (a, b, c, d, e, f) = - mconcat + uncommas' [ makeSelect conn a , makeSelect conn b , makeSelect conn c @@ -273,7 +278,7 @@ instance ( SqlSelect a , SqlSelectRet g ) makeSelect conn (a, b, c, d, e, f, g) = - mconcat + uncommas' [ makeSelect conn a , makeSelect conn b , makeSelect conn c @@ -302,7 +307,7 @@ instance ( SqlSelect a , SqlSelectRet h ) makeSelect conn (a, b, c, d, e, f, g, h) = - mconcat + uncommas' [ makeSelect conn a , makeSelect conn b , makeSelect conn c @@ -314,15 +319,22 @@ instance ( SqlSelect a ] +uncommas :: [TLB.Builder] -> TLB.Builder +uncommas = mconcat . intersperse ", " + +uncommas' :: Monoid a => [(TLB.Builder, a)] -> (TLB.Builder, a) +uncommas' = uncommas . map fst &&& mconcat . map snd + + makeFrom :: Connection -> [FromClause] -> TLB.Builder -makeFrom conn = mconcat . intersperse ", " . map mk +makeFrom conn = uncommas . map mk where mk (From (I i) def) = fromDBName conn (entityDB def) <> (" AS " <> i) makeWhere :: Connection -> WhereClause -> (TLB.Builder, [PersistValue]) makeWhere _ NoWhere = mempty -makeWhere conn (Where (ERaw b vals)) = ("\nWHERE " <> b conn, vals) +makeWhere conn (Where (ERaw f)) = first ("\nWHERE " <>) (f conn) parens :: TLB.Builder -> TLB.Builder