esqueleto/src/Database/Esqueleto/Internal/Sql.hs
2012-09-03 16:39:01 -03:00

318 lines
8.7 KiB
Haskell

{-# LANGUAGE ConstraintKinds, MultiParamTypeClasses, FunctionalDependencies, FlexibleContexts, FlexibleInstances, UndecidableInstances, GADTs, OverloadedStrings #-}
module Database.Esqueleto.Internal.Sql
( SqlQuery
, SqlExpr
, select
, toRawSelectSql
) where
import Control.Applicative (Applicative(..), (<$>))
import Control.Arrow (first, (&&&))
import Control.Monad (ap)
import Control.Monad.Logger (MonadLogger)
import Control.Monad.Trans.Resource (MonadResourceBase)
import Data.List (intersperse)
import Data.Monoid (Monoid(..), (<>))
import Database.Persist.EntityDef
import Database.Persist.GenericSql
import Database.Persist.GenericSql.Internal (Connection(escapeName))
import Database.Persist.Store
import qualified Control.Monad.Supply as S
import qualified Control.Monad.Trans.Reader as R
import qualified Control.Monad.Trans.Writer as W
import qualified Data.Text.Lazy as TL
import qualified Data.Text.Lazy.Builder as TLB
import Database.Esqueleto.Internal.Language
-- | Query monad of the SQL backend for 'Esqueleto' (executed through
-- 'SqlPersist').
--
-- It is a 'W.WriterT' collecting 'SideData' layered over a 'S.Supply'
-- of 'Ident's.
newtype SqlQuery a = Q
  { unQ :: W.WriterT SideData (S.Supply Ident) a
  }
-- | Maps over the result by delegating to the wrapped monad's 'fmap'.
instance Functor SqlQuery where
  fmap g (Q inner) = Q (fmap g inner)
-- | Sequencing is delegated to the wrapped monad; 'Q' is merely
-- unwrapped and rewrapped around it.
instance Monad SqlQuery where
  return x = Q (return x)
  Q inner >>= k = Q $ do
    x <- inner
    unQ (k x)
-- | Defined entirely in terms of the 'Monad' instance.
instance Applicative SqlQuery where
  pure x = return x
  qf <*> qx = ap qf qx
-- | Side data written by 'SqlQuery'.  Both fields are strict.
data SideData = SideData
  { sdFromClause  :: ![FromClause]  -- ^ Parts of the @FROM@ clause.
  , sdWhereClause :: !WhereClause   -- ^ The @WHERE@ clause.
  }
-- | Field-wise monoid: the @FROM@ parts and the @WHERE@ clauses are each
-- combined with their own 'mappend'.
instance Monoid SideData where
  mempty =
    SideData
      { sdFromClause  = mempty
      , sdWhereClause = mempty
      }
  mappend (SideData froms1 where1) (SideData froms2 where2) =
    SideData
      { sdFromClause  = froms1 `mappend` froms2
      , sdWhereClause = where1 `mappend` where2
      }
-- | A part of a @FROM@ clause: the 'Ident' rendered after @AS@, paired
-- with the 'EntityDef' whose database name is the table selected from.
data FromClause
  = From Ident EntityDef
-- | A complete @WHERE@ clause.
data WhereClause
  = Where (SqlExpr (Single Bool))  -- ^ A boolean condition.
  | NoWhere                        -- ^ No condition at all.
-- | 'NoWhere' is the identity; two actual conditions are joined
-- with '&&.'.
instance Monoid WhereClause where
  mempty = NoWhere
  mappend lhs rhs =
    case (lhs, rhs) of
      (NoWhere, _)       -> rhs
      (_, NoWhere)       -> lhs
      (Where p, Where q) -> Where (p &&. q)
-- | Identifier used for tables, kept as an already-built 'TLB.Builder'.
newtype Ident
  = I TLB.Builder
-- | Infinite list of identifiers: every non-empty string of upper-case
-- ASCII letters, shortest first, each prefixed with @T@
-- (@TA@, @TB@, ..., @TZ@, @TAA@, ...).
idents :: () -- ^ Avoid keeping everything in memory.
       -> [Ident]
idents _ = map mkIdent (concatMap ofLength [(1 :: Int) ..])
  where
    alphabet = ['A' .. 'Z']
    -- All strings of exactly @n@ letters, in lexicographic order.
    ofLength n = sequence (replicate n alphabet)
    mkIdent suffix = I (TLB.fromString ('T' : suffix))
-- | An expression on the SQL backend.
data SqlExpr a where
  -- | A whole entity, referred to through its table 'Ident'.
  EEntity :: Ident -> SqlExpr (Entity val)
  -- | A single value: given an 'Escape' function it yields the SQL text
  -- and the 'PersistValue's that accompany it.
  ERaw :: (Escape -> (TLB.Builder, [PersistValue])) -> SqlExpr (Single a)

-- | Turns a 'DBName' into the 'TLB.Builder' spliced into the query text.
type Escape = DBName -> TLB.Builder
-- | 'Esqueleto' implementation that renders each construct to SQL text
-- (a 'TLB.Builder') paired with its 'PersistValue' parameters.
instance Esqueleto SqlQuery SqlExpr SqlPersist where
  -- Draw the next table identifier, record a FROM entry for the entity
  -- and hand back an expression referring to it.
  fromSingle = Q $ do
    tableIdent <- S.supply
    let entityExpr = EEntity tableIdent
        -- Meant never to be evaluated (see the error message); it only
        -- supplies the entity type to 'entityDef'.
        getVal :: SqlExpr (Entity val) -> val
        getVal = error "SqlQuery/getVal: never here"
        fromEntry = From tableIdent (entityDef (getVal entityExpr))
    W.tell (mempty { sdFromClause = [fromEntry] })
    return entityExpr

  where_ expr = Q (W.tell (mempty { sdWhereClause = Where expr }))

  -- NOTE(review): 'toRawSelectSql' starts over from @idents ()@, so the
  -- aliases inside a sub-query may coincide with those of the enclosing
  -- query -- confirm whether that is intended.
  sub query = ERaw (\esc -> first parens (toRawSelectSql esc query))

  -- Column access renders as @<table identifier>.<escaped column name>@.
  EEntity (I tableIdent) ^. field =
    ERaw (\esc -> (tableIdent <> "." <> esc (fieldDB (persistFieldDef field)), []))

  -- A constant renders as a @?@ placeholder carrying its 'PersistValue'.
  val v = ERaw (\_ -> ("?", [toPersistValue v]))

  not_ (ERaw render) = ERaw (first (("NOT " <>) . parens) . render)

  a ==. b = binop " = " a b
  a >=. b = binop " >= " a b
  a >.  b = binop " > " a b
  a <=. b = binop " <= " a b
  a <.  b = binop " < " a b
  a !=. b = binop " != " a b
  a &&. b = binop " AND " a b
  a ||. b = binop " OR " a b
  a +.  b = binop " + " a b
  a -.  b = binop " - " a b
  a /.  b = binop " / " a b
  a *.  b = binop " * " a b
-- | Escapes a 'DBName' with the connection's 'escapeName' and converts
-- the resulting text into a 'TLB.Builder'.
fromDBName :: Connection -> DBName -> TLB.Builder
fromDBName conn name = TLB.fromText (escapeName conn name)
-- | Builds a binary-operator expression: both operands are rendered,
-- each is passed through 'parens', and the operator text is placed
-- between them.  The parameters of the left operand come before those
-- of the right operand.
binop :: TLB.Builder -> SqlExpr (Single a) -> SqlExpr (Single b) -> SqlExpr (Single c)
binop op (ERaw lhs) (ERaw rhs) = ERaw render
  where
    render esc =
      ( parens lhsSql <> op <> parens rhsSql
      , lhsVals <> rhsVals )
      where
        (lhsSql, lhsVals) = lhs esc
        (rhsSql, rhsVals) = rhs esc
-- | Execute an Esqueleto's 'SqlQuery' inside @persistent@'s
-- 'SqlPersist' monad.
--
-- The query is rendered by 'toRawSelectSql' using the current
-- connection's name escaping ('fromDBName'); the resulting text is made
-- strict and handed to 'rawSql' together with the parameters.
select :: ( SqlSelect a r
          , RawSql r
          , MonadLogger m
          , MonadResourceBase m )
       => SqlQuery a -> SqlPersist m [r]
select query = do
  conn <- getConnection
  let (builder, params) = toRawSelectSql (fromDBName conn) query
  rawSql (TL.toStrict (TLB.toLazyText builder)) params
-- | Get current database 'Connection' by asking the reader environment
-- that 'SqlPersist' wraps.
getConnection :: Monad m => SqlPersist m Connection
getConnection = SqlPersist (R.ask)
-- | Pretty prints a 'SqlQuery' into a SQL query.
--
-- Returns the query text together with its parameters: those of the
-- selected expressions first, then those of the @WHERE@ clause.
toRawSelectSql :: SqlSelect a r => Escape -> SqlQuery a -> (TLB.Builder, [PersistValue])
toRawSelectSql esc query = (sqlText, selectVals <> whereVals)
  where
    -- Run the query monad against a new identifier supply to obtain its
    -- result and the side data it wrote.
    (result, SideData froms wheres) =
      S.evalSupply (W.runWriterT (unQ query)) (idents ())
    (selectSql, selectVals) = makeSelect esc result
    (whereSql, whereVals) = makeWhere esc wheres
    sqlText =
      "SELECT " <> selectSql <> makeFrom esc froms <> whereSql
-- | Joins the given fragments, placing @", "@ between neighbours.
-- An empty list yields an empty builder.
uncommas :: [TLB.Builder] -> TLB.Builder
uncommas parts = mconcat (intersperse ", " parts)
-- | Like 'uncommas' for fragments that carry extra data: the fragments
-- are comma-joined while the attached values are 'mconcat'ed in order.
uncommas' :: Monoid a => [(TLB.Builder, a)] -> (TLB.Builder, a)
uncommas' pairs = (uncommas fragments, mconcat extras)
  where
    (fragments, extras) = unzip pairs
-- | Renders the @FROM@ clause: a newline and the @FROM@ keyword followed
-- by a comma-separated list of @<escaped table name> AS <identifier>@
-- entries.
--
-- The keyword is emitted here (mirroring how 'makeWhere' emits @WHERE@)
-- because 'toRawSelectSql' splices this text directly after the selected
-- columns; without it the table list would run into the column list.
-- An empty list renders to nothing, so a query that selects from no
-- table does not get a dangling @FROM@.
makeFrom :: Escape -> [FromClause] -> TLB.Builder
makeFrom _   []      = mempty
makeFrom esc clauses = "\nFROM " <> uncommas (map mk clauses)
  where
    mk (From (I i) def) = esc (entityDB def) <> (" AS " <> i)
-- | Renders the @WHERE@ clause and its parameters: nothing at all for
-- 'NoWhere', otherwise a newline, the @WHERE@ keyword and the rendered
-- condition.
makeWhere :: Escape -> WhereClause -> (TLB.Builder, [PersistValue])
makeWhere esc clause =
  case clause of
    NoWhere -> (mempty, [])
    Where (ERaw render) ->
      let (sql, params) = render esc
      in ("\nWHERE " <> sql, params)
-- | Wraps a SQL fragment in a matching pair of parentheses.
--
-- The closing delimiter must be @)@; emitting a second @(@ would leave
-- every parenthesised expression unbalanced.
parens :: TLB.Builder -> TLB.Builder
parens b = TLB.singleton '(' <> b <> TLB.singleton ')'
-- | Class for mapping results coming from 'SqlQuery' into actual
-- results.
--
-- The functional dependencies make @a@ (what the query returns) and @r@
-- (the result type) determine each other.
class RawSql r => SqlSelect a r | a -> r, r -> a where
  -- | Renders the text placed after @SELECT@ for a query result,
  -- together with the parameters it uses.
  makeSelect :: Escape -> a -> (TLB.Builder, [PersistValue])
-- | A whole entity is selected as the literal @??@ with no parameters
-- (presumably the placeholder that 'rawSql' expands into the entity's
-- columns -- confirm against @persistent@'s documentation).
instance PersistEntity a => SqlSelect (SqlExpr (Entity a)) (Entity a) where
  makeSelect _ entity =
    case entity of
      EEntity _ -> ("??", [])
-- | A single value is selected as its rendered expression passed
-- through 'parens'; its parameters are kept unchanged.
instance PersistField a => SqlSelect (SqlExpr (Single a)) (Single a) where
  makeSelect esc (ERaw render) =
    let (sql, params) = render esc
    in (parens sql, params)
-- | Pairs: the selections of both components, in order, combined
-- with @uncommas'@.
instance ( SqlSelect a ra
         , SqlSelect b rb
         ) => SqlSelect (a, b) (ra, rb) where
  makeSelect esc (x1, x2) = uncommas' columns
    where
      columns =
        [ makeSelect esc x1
        , makeSelect esc x2
        ]
-- | Triples: the selections of all three components, in order, combined
-- with @uncommas'@.
instance ( SqlSelect a ra
         , SqlSelect b rb
         , SqlSelect c rc
         ) => SqlSelect (a, b, c) (ra, rb, rc) where
  makeSelect esc (x1, x2, x3) = uncommas' columns
    where
      columns =
        [ makeSelect esc x1
        , makeSelect esc x2
        , makeSelect esc x3
        ]
-- | 4-tuples: the selections of all four components, in order, combined
-- with @uncommas'@.
instance ( SqlSelect a ra
         , SqlSelect b rb
         , SqlSelect c rc
         , SqlSelect d rd
         ) => SqlSelect (a, b, c, d) (ra, rb, rc, rd) where
  makeSelect esc (x1, x2, x3, x4) = uncommas' columns
    where
      columns =
        [ makeSelect esc x1
        , makeSelect esc x2
        , makeSelect esc x3
        , makeSelect esc x4
        ]
-- | 5-tuples: the selections of all five components, in order, combined
-- with @uncommas'@.
instance ( SqlSelect a ra
         , SqlSelect b rb
         , SqlSelect c rc
         , SqlSelect d rd
         , SqlSelect e re
         ) => SqlSelect (a, b, c, d, e) (ra, rb, rc, rd, re) where
  makeSelect esc (x1, x2, x3, x4, x5) = uncommas' columns
    where
      columns =
        [ makeSelect esc x1
        , makeSelect esc x2
        , makeSelect esc x3
        , makeSelect esc x4
        , makeSelect esc x5
        ]
-- | 6-tuples: the selections of all six components, in order, combined
-- with @uncommas'@.
instance ( SqlSelect a ra
         , SqlSelect b rb
         , SqlSelect c rc
         , SqlSelect d rd
         , SqlSelect e re
         , SqlSelect f rf
         ) => SqlSelect (a, b, c, d, e, f) (ra, rb, rc, rd, re, rf) where
  makeSelect esc (x1, x2, x3, x4, x5, x6) = uncommas' columns
    where
      columns =
        [ makeSelect esc x1
        , makeSelect esc x2
        , makeSelect esc x3
        , makeSelect esc x4
        , makeSelect esc x5
        , makeSelect esc x6
        ]
-- | 7-tuples: the selections of all seven components, in order, combined
-- with @uncommas'@.
instance ( SqlSelect a ra
         , SqlSelect b rb
         , SqlSelect c rc
         , SqlSelect d rd
         , SqlSelect e re
         , SqlSelect f rf
         , SqlSelect g rg
         ) => SqlSelect (a, b, c, d, e, f, g) (ra, rb, rc, rd, re, rf, rg) where
  makeSelect esc (x1, x2, x3, x4, x5, x6, x7) = uncommas' columns
    where
      columns =
        [ makeSelect esc x1
        , makeSelect esc x2
        , makeSelect esc x3
        , makeSelect esc x4
        , makeSelect esc x5
        , makeSelect esc x6
        , makeSelect esc x7
        ]
-- | 8-tuples: the selections of all eight components, in order, combined
-- with @uncommas'@.
instance ( SqlSelect a ra
         , SqlSelect b rb
         , SqlSelect c rc
         , SqlSelect d rd
         , SqlSelect e re
         , SqlSelect f rf
         , SqlSelect g rg
         , SqlSelect h rh
         ) => SqlSelect (a, b, c, d, e, f, g, h) (ra, rb, rc, rd, re, rf, rg, rh) where
  makeSelect esc (x1, x2, x3, x4, x5, x6, x7, x8) = uncommas' columns
    where
      columns =
        [ makeSelect esc x1
        , makeSelect esc x2
        , makeSelect esc x3
        , makeSelect esc x4
        , makeSelect esc x5
        , makeSelect esc x6
        , makeSelect esc x7
        , makeSelect esc x8
        ]