diff --git a/esqueleto.cabal b/esqueleto.cabal index 7fe2782..921472e 100644 --- a/esqueleto.cabal +++ b/esqueleto.cabal @@ -49,11 +49,11 @@ library Database.Esqueleto.Internal.Language Database.Esqueleto.Internal.Sql build-depends: - base == 4.5.* - , text == 0.11.* - , persistent >= 1.0.1 && < 1.1 - , transformers == 0.3.* - , monad-supply == 0.3.* + base == 4.5.* + , text == 0.11.* + , persistent >= 1.0.1 && < 1.1 + , transformers == 0.3.* + , unordered-containers >= 0.2 , monad-logger , conduit diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index 980a0c7..eec506b 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -26,6 +26,7 @@ import Control.Exception (throw, throwIO) import Control.Monad ((>=>), ap, MonadPlus(..)) import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.Logger (MonadLogger) +import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.Resource (MonadResourceBase) import Data.List (intersperse) import Data.Monoid (Monoid(..), (<>)) @@ -34,11 +35,12 @@ import Database.Persist.GenericSql import Database.Persist.GenericSql.Internal (Connection(escapeName)) import Database.Persist.GenericSql.Raw (withStmt) import Database.Persist.Store -import qualified Control.Monad.Supply as S import qualified Control.Monad.Trans.Reader as R +import qualified Control.Monad.Trans.State as S import qualified Control.Monad.Trans.Writer as W import qualified Data.Conduit as C import qualified Data.Conduit.List as CL +import qualified Data.HashSet as HS import qualified Data.Text as T import qualified Data.Text.Lazy as TL import qualified Data.Text.Lazy.Builder as TLB @@ -48,7 +50,7 @@ import Database.Esqueleto.Internal.Language -- | SQL backend for @esqueleto@ using 'SqlPersist'. newtype SqlQuery a = - Q { unQ :: W.WriterT SideData (S.Supply Ident) a } + Q { unQ :: W.WriterT SideData (S.State IdentState) a } instance Functor SqlQuery where fmap f = Q . fmap f . unQ @@ -62,6 +64,9 @@ instance Applicative SqlQuery where (<*>) = ap +---------------------------------------------------------------------- + + -- | Side data written by 'SqlQuery'. data SideData = SideData { sdFromClause :: ![FromClause] , sdWhereClause :: !WhereClause @@ -124,19 +129,56 @@ instance Monoid WhereClause where type OrderByClause = SqlExpr OrderBy --- | Identifier used for tables. -newtype Ident = I TLB.Builder +---------------------------------------------------------------------- --- | Infinite list of identifiers. -idents :: () -- ^ Avoid keeping everything in memory. - -> [Ident] -idents _ = - let alpha = ['A'..'Z'] - letters 1 = map return alpha - letters n = (:) <$> alpha <*> letters (n-1) - everything = concat (map letters [(1::Int)..]) - in map (I . TLB.fromString . ('T':)) everything +-- | Identifier used for table names. +newtype Ident = I T.Text + + +-- | List of identifiers already in use and supply of temporary +-- identifiers. +data IdentState = IdentState { inUse :: !(HS.HashSet T.Text) + , fresh :: ![T.Text] } + +initialIdentState :: IdentState +initialIdentState = IdentState mempty idents + where + idents = + let alpha = ['A'..'Z'] + letters 1 = map return alpha + letters n = (:) <$> alpha <*> letters (n-1) + everything = concat (map letters [(1::Int)..]) + in map T.pack everything + + +-- | Create a fresh 'Ident'. If possible, use the given +-- 'DBName'. +newIdentFor :: DBName -> SqlQuery Ident +newIdentFor = Q . lift . try . unDBName + where + try t = do + s <- S.get + if t `HS.member` inUse s + then newIdent + else markAsUsed t >> return (I t) + + newIdent = do + s <- S.get + let (f:fs) = fresh s + S.put s { fresh = fs } + try f + + markAsUsed t = + S.modify (\s -> s { inUse = HS.insert t (inUse s) }) + + +-- | Use an identifier. +useIdent :: Escape -> Ident -> TLB.Builder +useIdent esc (I ident) = esc (DBName ident) + + +---------------------------------------------------------------------- -- | An expression on the SQL backend. @@ -153,11 +195,16 @@ type Escape = DBName -> TLB.Builder instance Esqueleto SqlQuery SqlExpr SqlPersist where - fromStart = Q $ do - ident <- S.supply - let ret = EEntity ident - from_ = FromStart ident $ entityDef (getVal ret) - return (EPreprocessedFrom ret from_) + fromStart = x + where + x = do + let ed = entityDef (getVal x) + ident <- newIdentFor (entityDB ed) + let ret = EEntity ident + from_ = FromStart ident ed + return (EPreprocessedFrom ret from_) + getVal :: SqlQuery (SqlExpr (PreprocessedFrom (SqlExpr (Entity a)))) -> a + getVal = error "Esqueleto/Sql/fromStart/getVal: never here" fromStartMaybe = maybelize <$> fromStart where @@ -192,7 +239,7 @@ instance Esqueleto SqlQuery SqlExpr SqlPersist where sub_select = sub SELECT sub_selectDistinct = sub SELECT_DISTINCT - EEntity (I ident) ^. field = ERaw $ \esc -> (ident <> ("." <> name esc field), []) + EEntity ident ^. field = ERaw $ \esc -> (useIdent esc ident <> ("." <> name esc field), []) where name esc = esc . fieldDB . persistFieldDef _ ^. _ = error "Esqueleto/Sql/(^.): never here (see GHC #6124)" @@ -231,9 +278,6 @@ instance Esqueleto SqlQuery SqlExpr SqlPersist where sub :: PersistField a => Mode -> SqlQuery (SqlExpr (Single a)) -> SqlExpr (Single a) sub mode query = ERaw $ \esc -> first parens (toRawSql mode esc query) -getVal :: SqlExpr (Entity val) -> val -getVal = error "SqlQuery/getVal: never here" - fromDBName :: Connection -> DBName -> TLB.Builder fromDBName conn = TLB.fromText . escapeName conn @@ -247,6 +291,9 @@ binop op (ERaw f1) (ERaw f2) = ERaw f binop _ _ _ = error "Esqueleto/Sql/binop: never here (see GHC #6124)" +---------------------------------------------------------------------- + + -- | (Internal) Execute an @esqueleto@ @SELECT@ 'SqlQuery' inside -- @persistent@'s 'SqlPersist' monad. rawSelectSource :: ( SqlSelect a r @@ -327,7 +374,7 @@ runSource src = C.runResourceT $ src C.$$ CL.consume toRawSql :: SqlSelect a r => Mode -> Escape -> SqlQuery a -> (TLB.Builder, [PersistValue]) toRawSql mode esc query = let (ret, SideData fromClauses whereClauses orderByClauses) = - flip S.evalSupply (idents ()) $ + flip S.evalState initialIdentState $ W.runWriterT $ unQ query in mconcat @@ -363,7 +410,7 @@ makeFrom esc fs = ret Left expr -> throw $ mkExc expr Right fs' -> first ("\nFROM " <>) $ uncommas' (map (mk False mempty) fs') - mk _ onClause (FromStart (I i) def) = base i def <> onClause + mk _ onClause (FromStart i def) = base i def <> onClause mk paren onClause (FromJoin lhs kind rhs monClause) = (if paren then first parens else id) $ mconcat [ mk True onClause lhs @@ -372,7 +419,12 @@ makeFrom esc fs = ret ] mk _ _ (OnClause _) = error "Esqueleto/Sql/makeFrom: never here (is collectOnClauses working?)" - base i def = (esc (entityDB def) <> (" AS " <> i), mempty) + base ident@(I identText) def = + let db@(DBName dbText) = entityDB def + in ( if dbText == identText + then esc db + else esc db <> (" AS " <> useIdent esc ident) + , mempty ) fromKind InnerJoinKind = " INNER JOIN " fromKind CrossJoinKind = " CROSS JOIN " @@ -430,7 +482,7 @@ class SqlSelect a r | a -> r, r -> a where instance PersistEntity a => SqlSelect (SqlExpr (Entity a)) (Entity a) where - sqlSelectCols escape expr@(EEntity (I ident)) = ret + sqlSelectCols escape expr@(EEntity ident) = ret where process ed = uncommas $ map ((name <>) . escape) $ @@ -443,7 +495,7 @@ instance PersistEntity a => SqlSelect (SqlExpr (Entity a)) (Entity a) where -- clause), while 'rawSql' assumes that it's just the -- name of the table (which doesn't allow self-joins, for -- example). - name = ident <> "." + name = useIdent escape ident <> "." ret = let ed = entityDef $ getEntityVal expr in (process ed, mempty) sqlSelectCols _ _ = error "Esqueleto/Sql/sqlSelectCols[Entity]: never here (see GHC #6124)"