esqueleto/src/Database/Esqueleto/Internal/Sql.hs
2012-09-04 00:29:39 -03:00

436 lines
13 KiB
Haskell

{-# LANGUAGE ConstraintKinds, MultiParamTypeClasses, FunctionalDependencies, FlexibleContexts, FlexibleInstances, UndecidableInstances, GADTs, OverloadedStrings #-}
module Database.Esqueleto.Internal.Sql
( SqlQuery
, SqlExpr
, select
, selectSource
, toRawSelectSql
) where
import Control.Applicative (Applicative(..), (<$>))
import Control.Arrow (first)
import Control.Exception (throwIO)
import Control.Monad (ap)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Logger (MonadLogger)
import Control.Monad.Trans.Resource (MonadResourceBase)
import Data.List (intersperse)
import Data.Monoid (Monoid(..), (<>))
import Database.Persist.EntityDef
import Database.Persist.GenericSql
import Database.Persist.GenericSql.Internal (Connection(escapeName))
import Database.Persist.GenericSql.Raw (withStmt)
import Database.Persist.Store
import qualified Control.Monad.Supply as S
import qualified Control.Monad.Trans.Reader as R
import qualified Control.Monad.Trans.Writer as W
import qualified Data.Conduit as C
import qualified Data.Conduit.List as CL
import qualified Data.Text as T
import qualified Data.Text.Lazy as TL
import qualified Data.Text.Lazy.Builder as TLB
import Database.Esqueleto.Internal.Language
-- | SQL backend for 'Esqueleto' using 'SqlPersist'.
--
-- A query writes 'SideData' (its @FROM@ and @WHERE@ pieces) while
-- drawing fresh table aliases from a supply of 'Ident's.
newtype SqlQuery a =
  Q { unQ :: W.WriterT SideData (S.Supply Ident) a }

instance Functor SqlQuery where
  fmap f (Q m) = Q (fmap f m)

instance Monad SqlQuery where
  return = pure
  Q m >>= k = Q (m >>= \x -> unQ (k x))

instance Applicative SqlQuery where
  pure x = Q (return x)
  (<*>) = ap
-- | Side data written by 'SqlQuery': every table brought into scope
-- (in order) and the accumulated @WHERE@ condition.
data SideData = SideData
  { sdFromClause  :: ![FromClause]
  , sdWhereClause :: !WhereClause
  }

-- | Combines side data field by field: @FROM@ clauses are appended in
-- order and @WHERE@ clauses are combined with their own 'Monoid'.
instance Monoid SideData where
  mempty = SideData { sdFromClause = mempty, sdWhereClause = mempty }
  mappend (SideData froms wheres) (SideData froms' wheres') =
    SideData (mappend froms froms') (mappend wheres wheres')
-- | A part of a @FROM@ clause: the generated alias ('Ident') of a
-- table together with the 'EntityDef' of the entity stored in it.
-- 'makeFrom' renders it as @\<escaped table name\> AS \<alias\>@.
data FromClause = From Ident EntityDef
-- | A complete @WHERE@ clause: a single boolean condition, or no
-- condition at all.
data WhereClause = Where (SqlExpr (Single Bool))
                 | NoWhere

-- | 'NoWhere' is the identity; two conditions are joined with @AND@.
instance Monoid WhereClause where
  mempty = NoWhere
  mappend lhs rhs =
    case (lhs, rhs) of
      (NoWhere, _)         -> rhs
      (_, NoWhere)         -> lhs
      (Where e1, Where e2) -> Where (e1 &&. e2)
-- | Identifier used for tables: the alias each table receives in the
-- generated @FROM@ clause (see 'idents' for how aliases are made).
newtype Ident = I TLB.Builder
-- | Infinite list of identifiers used as table aliases: @TA@ to @TZ@,
-- then @TAA@ to @TZZ@, and so on.
--
-- Names that are SQL reserved words are skipped. Without this, the
-- 15th alias would be @TO@. That word is reserved in standard SQL (and
-- in e.g. PostgreSQL and MySQL) and cannot be used as a table alias,
-- so on such backends any query over 15 or more tables would produce
-- invalid SQL.
idents :: () -- ^ Avoid keeping everything in memory.
       -> [Ident]
idents _ =
  let alpha = ['A'..'Z']
      -- All strings of exactly @n@ uppercase letters, in order.
      letters :: Int -> [String]
      letters 1 = map return alpha
      letters n = (:) <$> alpha <*> letters (n-1)
      candidates = map ('T':) (concatMap letters [1..])
  in map (I . TLB.fromString) (filter (`notElem` reservedWords) candidates)
  where
    -- SQL-standard reserved words that this naming scheme can produce.
    -- Only "TO" is reachable by realistic queries; the four-letter
    -- names only appear after 702 aliases.
    reservedWords :: [String]
    reservedWords = ["TO", "THEN", "TIME", "TRIM", "TRUE"]
-- | An expression on the SQL backend.
data SqlExpr a where
  -- A whole entity, referred to by the alias of its table.
  EEntity :: Ident -> SqlExpr (Entity val)
  -- A single value. Given an identifier-escaping function, it
  -- renders to SQL text plus the values to bind to the @?@
  -- placeholders in that text.
  ERaw :: (Escape -> (TLB.Builder, [PersistValue])) -> SqlExpr (Single a)

-- | Escapes a database identifier (table or column name) into SQL
-- text for the backend in use (see 'fromDBName').
type Escape = DBName -> TLB.Builder
-- | The SQL backend for the 'Esqueleto' language.
--
-- Queries accumulate their @FROM@ and @WHERE@ parts in 'SqlQuery'.
-- Expressions are 'SqlExpr's that render to SQL text together with
-- the values bound to the text's @?@ placeholders.
instance Esqueleto SqlQuery SqlExpr SqlPersist where
  -- Allocate a fresh alias, register the entity's table under it in
  -- the FROM clause, and return an expression naming that alias.
  fromSingle = Q $ do
    alias <- S.supply
    let expr = EEntity alias
        -- Placeholder used only for its type; assumes 'entityDef'
        -- never evaluates its argument (this 'error' would fire).
        entityOf :: SqlExpr (Entity val) -> val
        entityOf = error "SqlQuery/getVal: never here"
    W.tell mempty { sdFromClause = [From alias (entityDef (entityOf expr))] }
    return expr

  -- Each call adds one condition; the WhereClause monoid ANDs them.
  where_ cond = Q $ W.tell mempty { sdWhereClause = Where cond }

  -- A sub-select is rendered as a parenthesized SELECT.
  sub query = ERaw $ \esc ->
    let (txt, vals) = toRawSelectSql esc query
    in (parens txt, vals)

  -- Qualify the escaped column name with the table alias.
  EEntity (I alias) ^. field = ERaw $ \esc ->
    (alias <> ("." <> esc (fieldDB (persistFieldDef field))), [])
  _ ^. _ = error "Esqueleto/Sql/(^.): never here (see GHC #6124)"

  -- Constants become a "?" placeholder plus the bound value.
  val x = ERaw $ \_ -> ("?", [toPersistValue x])

  not_ (ERaw render) = ERaw $ \esc ->
    let (txt, vals) = render esc
    in ("NOT " <> parens txt, vals)
  not_ _ = error "Esqueleto/Sql/not_: never here (see GHC #6124)"

  -- Comparisons.
  (==.) = binop " = "
  (>=.) = binop " >= "
  (>.)  = binop " > "
  (<=.) = binop " <= "
  (<.)  = binop " < "
  (!=.) = binop " != "

  -- Boolean connectives.
  (&&.) = binop " AND "
  (||.) = binop " OR "

  -- Arithmetic.
  (+.) = binop " + "
  (-.) = binop " - "
  (/.) = binop " / "
  (*.) = binop " * "
-- | Escapes a database name with the connection's 'escapeName' and
-- turns the result into a 'TLB.Builder'; suitable as an 'Escape'.
fromDBName :: Connection -> DBName -> TLB.Builder
fromDBName conn name = TLB.fromText (escapeName conn name)
-- | Joins two single-value expressions with an infix SQL operator.
-- Both operands are parenthesized. Their bound values are concatenated
-- left then right, which matches the order of placeholders in the text.
binop :: TLB.Builder -> SqlExpr (Single a) -> SqlExpr (Single b) -> SqlExpr (Single c)
binop op (ERaw lhs) (ERaw rhs) = ERaw $ \esc ->
  let (lhsText, lhsVals) = lhs esc
      (rhsText, rhsVals) = rhs esc
  in (parens lhsText <> op <> parens rhsText, lhsVals <> rhsVals)
binop _ _ _ = error "Esqueleto/Sql/binop: never here (see GHC #6124)"
-- | Execute an Esqueleto's 'SqlQuery' inside @persistent@'s
-- 'SqlPersist' monad. Returns a 'C.Source' that streams the decoded
-- results one row at a time instead of collecting them.
--
-- A row that cannot be decoded is thrown as a 'PersistMarshalError'.
selectSource :: ( SqlSelect a r
                , MonadLogger m
                , MonadResourceBase m )
             => SqlQuery a -> SqlPersist m (C.Source (C.ResourceT (SqlPersist m)) r)
selectSource query = do
    conn <- getConnection
    return (rawRows conn C.$= decodeRows)
  where
    -- Render the query using the connection's identifier escaping
    -- and run it with its bound values.
    rawRows conn =
      let (sqlText, params) = toRawSelectSql (fromDBName conn) query
      in withStmt (TL.toStrict (TLB.toLazyText sqlText)) params
    -- Decode each raw row until the stream ends.
    decodeRows = do
      mrow <- C.await
      case mrow of
        Nothing  -> return ()
        Just row ->
          case decode row of
            Right r  -> C.yield r >> decodeRows
            Left err -> liftIO $ throwIO $ PersistMarshalError err
    -- Bound once so the decoder is shared by every row.
    decode = sqlSelectProcessRow
-- | Execute an Esqueleto's 'SqlQuery' inside @persistent@'s
-- 'SqlPersist' monad and collect every result into a list.
--
-- Drains the stream from 'selectSource' inside its own
-- 'C.runResourceT'.
select :: ( SqlSelect a r
          , MonadLogger m
          , MonadResourceBase m )
       => SqlQuery a -> SqlPersist m [r]
select query =
  selectSource query >>= \src -> C.runResourceT (src C.$$ CL.consume)
-- | Get current database 'Connection', read from the environment of
-- the reader wrapped by 'SqlPersist'.
getConnection :: Monad m => SqlPersist m Connection
getConnection = SqlPersist (R.asks id)
-- | Pretty prints a 'SqlQuery' into a SQL query.
--
-- Runs the query against a fresh supply of table aliases. Returns the
-- @SELECT ... FROM ... WHERE ...@ text together with the values for
-- its placeholders: @SELECT@ values first, then @WHERE@ values.
toRawSelectSql :: SqlSelect a r => Escape -> SqlQuery a -> (TLB.Builder, [PersistValue])
toRawSelectSql esc query = (sqlText, selectVars <> whereVars)
  where
    (ret, SideData fromClauses whereClauses) =
      S.evalSupply (W.runWriterT (unQ query)) (idents ())
    (_, selectText, selectVars) = sqlSelectCols esc ret
    (whereText, whereVars)      = makeWhere esc whereClauses
    sqlText = mconcat
      [ "SELECT "
      , selectText
      , makeFrom esc fromClauses
      , whereText
      ]
-- | Joins builders with @", "@ between them (no trailing separator).
uncommas :: [TLB.Builder] -> TLB.Builder
uncommas items = mconcat (intersperse ", " items)
-- | Merges several @SELECT@ column groups into one. Column counts are
-- summed, texts are joined with commas, and bound values are
-- concatenated in order.
uncommas' :: Monoid a => [(Int, TLB.Builder, a)] -> (Int, TLB.Builder, a)
uncommas' xs =
  ( sum      [n | (n, _, _) <- xs]
  , uncommas [b | (_, b, _) <- xs]
  , mconcat  [v | (_, _, v) <- xs]
  )
-- | Renders the @FROM@ clause. Each entity's escaped table name gets
-- its generated alias, and nothing is rendered when there are no tables.
makeFrom :: Escape -> [FromClause] -> TLB.Builder
makeFrom _ [] = mempty
makeFrom esc fs =
  "\nFROM " <> uncommas [ esc (entityDB def) <> " AS " <> alias
                        | From (I alias) def <- fs ]
-- | Renders the @WHERE@ clause and its bound values. Both are empty
-- when there is no condition.
makeWhere :: Escape -> WhereClause -> (TLB.Builder, [PersistValue])
makeWhere _ NoWhere = (mempty, [])
makeWhere esc (Where (ERaw render)) =
  let (cond, vals) = render esc
  in ("\nWHERE " <> cond, vals)
makeWhere _ _ = error "Esqueleto/Sql/makeWhere: never here (see GHC #6124)"
-- | Wraps a SQL fragment in parentheses.
parens :: TLB.Builder -> TLB.Builder
parens b = mconcat ["(", b, ")"]
-- | Class for mapping results coming from 'SqlQuery' into actual
-- results.
--
-- This looks very similar to @RawSql@, and it is! However,
-- there are some crucial differences and ultimately they're
-- different classes.
class SqlSelect a r | a -> r, r -> a where
  -- | Creates the variable part of the @SELECT@ query and
  -- returns the list of 'PersistValue's that will be given to
  -- 'withStmt'.
  --
  -- The first component is the number of columns selected. It must
  -- be computable /without/ evaluating the @a@ argument. The pair
  -- instance asks for it on a placeholder value, whose type is the
  -- only meaningful thing about it, to learn where to split a row.
  sqlSelectCols :: Escape -> a -> (Int, TLB.Builder, [PersistValue])
  -- | Transform a row of the result into the data type.
  sqlSelectProcessRow :: [PersistValue] -> Either T.Text r

-- | A whole entity: its id column followed by all of its field columns.
instance PersistEntity a => SqlSelect (SqlExpr (Entity a)) (Entity a) where
  -- The column count comes from the entity definition alone. The table
  -- alias is only extracted from @expr@ when the column text is
  -- demanded, so computing the count never forces @expr@.
  sqlSelectCols escape expr = (length (entityFields ed) + 1, process, mempty)
    where
      ed = entityDef (getEntityVal expr)
      process = uncommas $
                map ((name <>) . escape) $
                (entityID ed:) $
                map fieldDB $
                entityFields ed
      -- 'name' is the biggest difference between 'RawSql' and
      -- 'SqlSelect'. We automatically create names for tables
      -- (since it's not the user who's writing the FROM
      -- clause), while 'rawSql' assumes that it's just the
      -- name of the table (which doesn't allow self-joins, for
      -- example).
      name = tableAlias expr <> "."
      tableAlias :: SqlExpr (Entity b) -> TLB.Builder
      tableAlias (EEntity (I ident)) = ident
      tableAlias _ = error "Esqueleto/Sql/sqlSelectCols[Entity]: never here (see GHC #6124)"
      -- Placeholder used only for its type; assumes 'entityDef' never
      -- evaluates its argument.
      getEntityVal :: SqlExpr (Entity b) -> b
      getEntityVal = error "Database.Esqueleto.SqlSelect.getEntityVal"
  sqlSelectProcessRow (idCol:ent) =
    Entity <$> fromPersistValue idCol
           <*> fromPersistValues ent
  sqlSelectProcessRow _ = Left "SqlSelect (Entity a): wrong number of columns."

-- | A single value: always exactly one column.
instance PersistField a => SqlSelect (SqlExpr (Single a)) (Single a) where
  -- The rendering is bound lazily, so the column count never forces
  -- @expr@.
  sqlSelectCols esc expr = (1, parens b, vals)
    where
      (b, vals) = render expr
      render :: SqlExpr (Single c) -> (TLB.Builder, [PersistValue])
      render (ERaw f) = f esc
      render _ = error "Esqueleto/Sql/sqlSelectCols[Single]: never here (see GHC #6124)"
  sqlSelectProcessRow [pv] = Single <$> fromPersistValue pv
  sqlSelectProcessRow _ = Left "SqlSelect (Single a): wrong number of columns."

-- | Pairs select the columns of both components side by side and split
-- each result row between them.
instance ( SqlSelect a ra
         , SqlSelect b rb
         ) => SqlSelect (a, b) (ra, rb) where
  -- The lazy (~) pattern lets the column count be computed from a
  -- placeholder pair, which happens for nested tuples (see
  -- 'sqlSelectProcessRow' below).
  sqlSelectCols esc ~(a, b) =
    uncommas'
      [ sqlSelectCols esc a
      , sqlSelectCols esc b
      ]
  sqlSelectProcessRow =
    let x = getType processRow
        -- Produces a placeholder of the first component's expression
        -- type (fixed by the functional dependency @r -> a@). It is
        -- never evaluated; it only selects the instance whose column
        -- count says where to split the row.
        getType :: SqlSelect a r => (z -> Either y (r,x)) -> a
        getType = error "Esqueleto/SqlSelect[(a,b)]/sqlSelectProcessRow/getType"
        (colCountFst, _, _) = sqlSelectCols escape x
          where escape = error "Esqueleto/SqlSelect[(a,b)]/sqlSelectProcessRow/escape"
        processRow row =
          let (rowFst, rowSnd) = splitAt colCountFst row
          in (,) <$> sqlSelectProcessRow rowFst
                 <*> sqlSelectProcessRow rowSnd
    in colCountFst `seq` processRow
       -- Avoids recalculating 'colCountFst'.
-- | Triples are decoded as the nested pair @((a, b), c)@ and then
-- flattened with 'to3'.
instance ( SqlSelect a ra
         , SqlSelect b rb
         , SqlSelect c rc
         ) => SqlSelect (a, b, c) (ra, rb, rc) where
  sqlSelectCols esc (a, b, c) =
    let cols = [ sqlSelectCols esc a
               , sqlSelectCols esc b
               , sqlSelectCols esc c ]
    in uncommas' cols
  sqlSelectProcessRow = (to3 <$>) . sqlSelectProcessRow

-- | Flattens the nested decoding of a triple.
to3 :: ((a,b),c) -> (a,b,c)
to3 = \((a,b),c) -> (a,b,c)
-- | 4-tuples are decoded as @((a, b), (c, d))@ and then flattened
-- with 'to4'.
instance ( SqlSelect a ra
         , SqlSelect b rb
         , SqlSelect c rc
         , SqlSelect d rd
         ) => SqlSelect (a, b, c, d) (ra, rb, rc, rd) where
  sqlSelectCols esc (a, b, c, d) =
    let cols = [ sqlSelectCols esc a, sqlSelectCols esc b
               , sqlSelectCols esc c, sqlSelectCols esc d ]
    in uncommas' cols
  sqlSelectProcessRow = (to4 <$>) . sqlSelectProcessRow

-- | Flattens the nested decoding of a 4-tuple.
to4 :: ((a,b),(c,d)) -> (a,b,c,d)
to4 = \((a,b),(c,d)) -> (a,b,c,d)
-- | 5-tuples are decoded as @((a, b), (c, d), e)@ and then flattened
-- with 'to5'.
instance ( SqlSelect a ra
         , SqlSelect b rb
         , SqlSelect c rc
         , SqlSelect d rd
         , SqlSelect e re
         ) => SqlSelect (a, b, c, d, e) (ra, rb, rc, rd, re) where
  sqlSelectCols esc (a, b, c, d, e) =
    let cols = [ sqlSelectCols esc a, sqlSelectCols esc b
               , sqlSelectCols esc c, sqlSelectCols esc d
               , sqlSelectCols esc e ]
    in uncommas' cols
  sqlSelectProcessRow = (to5 <$>) . sqlSelectProcessRow

-- | Flattens the nested decoding of a 5-tuple.
to5 :: ((a,b),(c,d),e) -> (a,b,c,d,e)
to5 = \((a,b),(c,d),e) -> (a,b,c,d,e)
-- | 6-tuples are decoded as @((a, b), (c, d), (e, f))@ and then
-- flattened with 'to6'.
instance ( SqlSelect a ra
         , SqlSelect b rb
         , SqlSelect c rc
         , SqlSelect d rd
         , SqlSelect e re
         , SqlSelect f rf
         ) => SqlSelect (a, b, c, d, e, f) (ra, rb, rc, rd, re, rf) where
  sqlSelectCols esc (a, b, c, d, e, f) =
    let cols = [ sqlSelectCols esc a, sqlSelectCols esc b
               , sqlSelectCols esc c, sqlSelectCols esc d
               , sqlSelectCols esc e, sqlSelectCols esc f ]
    in uncommas' cols
  sqlSelectProcessRow = (to6 <$>) . sqlSelectProcessRow

-- | Flattens the nested decoding of a 6-tuple.
to6 :: ((a,b),(c,d),(e,f)) -> (a,b,c,d,e,f)
to6 = \((a,b),(c,d),(e,f)) -> (a,b,c,d,e,f)
-- | 7-tuples are decoded as @((a, b), (c, d), (e, f), g)@ and then
-- flattened with 'to7'.
instance ( SqlSelect a ra
         , SqlSelect b rb
         , SqlSelect c rc
         , SqlSelect d rd
         , SqlSelect e re
         , SqlSelect f rf
         , SqlSelect g rg
         ) => SqlSelect (a, b, c, d, e, f, g) (ra, rb, rc, rd, re, rf, rg) where
  sqlSelectCols esc (a, b, c, d, e, f, g) =
    let cols = [ sqlSelectCols esc a, sqlSelectCols esc b
               , sqlSelectCols esc c, sqlSelectCols esc d
               , sqlSelectCols esc e, sqlSelectCols esc f
               , sqlSelectCols esc g ]
    in uncommas' cols
  sqlSelectProcessRow = (to7 <$>) . sqlSelectProcessRow

-- | Flattens the nested decoding of a 7-tuple.
to7 :: ((a,b),(c,d),(e,f),g) -> (a,b,c,d,e,f,g)
to7 = \((a,b),(c,d),(e,f),g) -> (a,b,c,d,e,f,g)
-- | 8-tuples are decoded as @((a, b), (c, d), (e, f), (g, h))@ and
-- then flattened with 'to8'.
instance ( SqlSelect a ra
         , SqlSelect b rb
         , SqlSelect c rc
         , SqlSelect d rd
         , SqlSelect e re
         , SqlSelect f rf
         , SqlSelect g rg
         , SqlSelect h rh
         ) => SqlSelect (a, b, c, d, e, f, g, h) (ra, rb, rc, rd, re, rf, rg, rh) where
  sqlSelectCols esc (a, b, c, d, e, f, g, h) =
    let cols = [ sqlSelectCols esc a, sqlSelectCols esc b
               , sqlSelectCols esc c, sqlSelectCols esc d
               , sqlSelectCols esc e, sqlSelectCols esc f
               , sqlSelectCols esc g, sqlSelectCols esc h ]
    in uncommas' cols
  sqlSelectProcessRow = (to8 <$>) . sqlSelectProcessRow

-- | Flattens the nested decoding of an 8-tuple.
to8 :: ((a,b),(c,d),(e,f),(g,h)) -> (a,b,c,d,e,f,g,h)
to8 = \((a,b),(c,d),(e,f),(g,h)) -> (a,b,c,d,e,f,g,h)