Thread IdentState through subqueries (fixes #28).

There used to be name clashes if a subquery referenced
an entity that was already being used on the outer query.
Now we thread the outer query's IdentState to its subqueries,
which use it instead of initialIdentState.

Note that clashes still may occur between subqueries of
a query, but I think that's harmless.
This commit is contained in:
Felipe Lessa 2013-09-15 04:16:35 -03:00
parent c5c76959bd
commit 33b1fafc2d

View File

@ -35,6 +35,9 @@ module Database.Esqueleto.Internal.Sql
, rawEsqueleto
, toRawSql
, Mode(..)
, IdentState
, initialIdentState
, IdentInfo
, SqlSelect
, veryUnsafeCoerceSqlExprValue
) where
@ -219,9 +222,13 @@ newIdentFor = Q . lift . try . unDBName
return (I t)
-- | Information needed to escape and use identifiers.
type IdentInfo = (Connection, IdentState)
-- | Use an identifier.
useIdent :: Connection -> Ident -> TLB.Builder
useIdent conn (I ident) = fromDBName conn $ DBName ident
useIdent :: IdentInfo -> Ident -> TLB.Builder
useIdent info (I ident) = fromDBName info $ DBName ident
----------------------------------------------------------------------
@ -240,7 +247,7 @@ data SqlExpr a where
-- connection (mainly for escaping names) and returns both an
-- string ('TLB.Builder') and a list of values to be
-- interpolated by the SQL backend.
ERaw :: NeedParens -> (Connection -> (TLB.Builder, [PersistValue])) -> SqlExpr (Value a)
ERaw :: NeedParens -> (IdentInfo -> (TLB.Builder, [PersistValue])) -> SqlExpr (Value a)
-- | 'EList' and 'EEmptyList' are used by list operators.
EList :: SqlExpr (Value a) -> SqlExpr (ValueList a)
@ -256,7 +263,7 @@ data SqlExpr a where
EPreprocessedFrom :: a -> FromClause -> SqlExpr (PreprocessedFrom a)
-- | Used by 'insertSelect'.
EInsert :: Proxy a -> (Connection -> (TLB.Builder, [PersistValue])) -> SqlExpr (Insertion a)
EInsert :: Proxy a -> (IdentInfo -> (TLB.Builder, [PersistValue])) -> SqlExpr (Insertion a)
data NeedParens = Parens | Never
@ -317,7 +324,7 @@ instance Esqueleto SqlQuery SqlExpr SqlBackend where
sub_selectDistinct = sub SELECT_DISTINCT
EEntity ident ^. field =
ERaw Never $ \conn -> (useIdent conn ident <> ("." <> fieldName conn field), [])
ERaw Never $ \info -> (useIdent info ident <> ("." <> fieldName info field), [])
EMaybe r ?. field = maybelize (r ^. field)
where
@ -331,10 +338,10 @@ instance Esqueleto SqlQuery SqlExpr SqlBackend where
nothing = unsafeSqlValue "NULL"
joinV (ERaw p f) = ERaw p f
countRows = unsafeSqlValue "COUNT(*)"
count (ERaw _ f) = ERaw Never $ \conn -> let (b, vals) = f conn
count (ERaw _ f) = ERaw Never $ \info -> let (b, vals) = f info
in ("COUNT" <> parens b, vals)
not_ (ERaw p f) = ERaw Never $ \conn -> let (b, vals) = f conn
not_ (ERaw p f) = ERaw Never $ \info -> let (b, vals) = f info
in ("NOT " <> parensM p b, vals)
(==.) = unsafeSqlBinOp " = "
@ -399,21 +406,21 @@ instance ToSomeValues SqlExpr (SqlExpr (Value a)) where
toSomeValues a = [SomeValue a]
fieldName :: (PersistEntity val, PersistField typ)
=> Connection -> EntityField val typ -> TLB.Builder
fieldName conn = fromDBName conn . fieldDB . persistFieldDef
=> IdentInfo -> EntityField val typ -> TLB.Builder
fieldName info = fromDBName info . fieldDB . persistFieldDef
setAux :: (PersistEntity val, PersistField typ)
=> EntityField val typ
-> (SqlExpr (Entity val) -> SqlExpr (Value typ))
-> SqlExpr (Update val)
setAux field mkVal = ESet $ \ent -> unsafeSqlBinOp " = " name (mkVal ent)
where name = ERaw Never $ \conn -> (fieldName conn field, mempty)
where name = ERaw Never $ \info -> (fieldName info field, mempty)
sub :: PersistField a => Mode -> SqlQuery (SqlExpr (Value a)) -> SqlExpr (Value a)
sub mode query = ERaw Parens $ \conn -> toRawSql mode pureQuery conn query
sub mode query = ERaw Parens $ \info -> toRawSql mode pureQuery info query
fromDBName :: Connection -> DBName -> TLB.Builder
fromDBName conn = TLB.fromText . connEscapeName conn
fromDBName :: IdentInfo -> DBName -> TLB.Builder
fromDBName (conn, _) = TLB.fromText . connEscapeName conn
existsHelper :: SqlQuery () -> SqlExpr (Value Bool)
existsHelper = sub SELECT . (>> return true)
@ -444,8 +451,8 @@ ifNotEmptyList (EList _) _ x = x
unsafeSqlBinOp :: TLB.Builder -> SqlExpr (Value a) -> SqlExpr (Value b) -> SqlExpr (Value c)
unsafeSqlBinOp op (ERaw p1 f1) (ERaw p2 f2) = ERaw Parens f
where
f conn = let (b1, vals1) = f1 conn
(b2, vals2) = f2 conn
f info = let (b1, vals1) = f1 info
(b2, vals2) = f2 info
in ( parensM p1 b1 <> op <> parensM p2 b2
, vals1 <> vals2 )
{-# INLINE unsafeSqlBinOp #-}
@ -463,9 +470,9 @@ unsafeSqlValue v = ERaw Never $ \_ -> (v, mempty)
unsafeSqlFunction :: UnsafeSqlFunctionArgument a =>
TLB.Builder -> a -> SqlExpr (Value b)
unsafeSqlFunction name arg =
ERaw Never $ \conn ->
ERaw Never $ \info ->
let (argsTLB, argsVals) =
uncommas' $ map (\(ERaw _ f) -> f conn) $ toArgList arg
uncommas' $ map (\(ERaw _ f) -> f info) $ toArgList arg
in (name <> parens argsTLB, argsVals)
class UnsafeSqlFunctionArgument a where
@ -527,7 +534,7 @@ rawSelectSource mode query = src
run conn =
uncurry rawQuery $
first builderToText $
toRawSql mode pureQuery conn query
toRawSql mode pureQuery (conn, initialIdentState) query
massage = do
mrow <- C.await
@ -639,7 +646,7 @@ rawEsqueleto mode query = do
conn <- SqlPersistT R.ask
uncurry rawExecuteCount $
first builderToText $
toRawSql mode pureQuery conn query
toRawSql mode pureQuery (conn, initialIdentState) query
-- | Execute an @esqueleto@ @DELETE@ query inside @persistent@'s
@ -723,24 +730,37 @@ builderToText = TL.toStrict . TLB.toLazyTextWith defaultChunkSize
-- @esqueleto@, instead of manually using this function (which is
-- possible but tedious), you may just turn on query logging of
-- @persistent@.
toRawSql :: SqlSelect a r => Mode -> QueryType a -> Connection -> SqlQuery a -> (TLB.Builder, [PersistValue])
toRawSql mode qt conn query =
let (ret, SideData fromClauses setClauses whereClauses groupByClause havingClause orderByClauses limitClause) =
flip S.evalState initialIdentState $
toRawSql :: SqlSelect a r => Mode -> QueryType a -> IdentInfo -> SqlQuery a -> (TLB.Builder, [PersistValue])
toRawSql mode qt (conn, firstIdentState) query =
let ((ret, sd), finalIdentState) =
flip S.runState firstIdentState $
W.runWriterT $
unQ query
SideData fromClauses
setClauses
whereClauses
groupByClause
havingClause
orderByClauses
limitClause = sd
-- Pass the finalIdentState (containing all identifiers
-- that were used) to the subsequent calls. This ensures
-- that no name clashes will occur on subqueries that may
-- appear on the expressions below.
info = (conn, finalIdentState)
in mconcat
[ makeInsert qt ret
, makeSelect conn mode ret
, makeFrom conn mode fromClauses
, makeSet conn setClauses
, makeWhere conn whereClauses
, makeGroupBy conn groupByClause
, makeHaving conn havingClause
, makeOrderBy conn orderByClauses
, makeLimit conn limitClause
, makeSelect info mode ret
, makeFrom info mode fromClauses
, makeSet info setClauses
, makeWhere info whereClauses
, makeGroupBy info groupByClause
, makeHaving info havingClause
, makeOrderBy info orderByClauses
, makeLimit info limitClause
]
-- | (Internal) Mode of query being converted by 'toRawSql'.
data Mode = SELECT | SELECT_DISTINCT | DELETE | UPDATE
@ -767,21 +787,21 @@ uncommas' :: Monoid a => [(TLB.Builder, a)] -> (TLB.Builder, a)
uncommas' = (uncommas *** mconcat) . unzip
makeSelect :: SqlSelect a r => Connection -> Mode -> a -> (TLB.Builder, [PersistValue])
makeSelect conn mode ret =
makeSelect :: SqlSelect a r => IdentInfo -> Mode -> a -> (TLB.Builder, [PersistValue])
makeSelect info mode ret =
case mode of
SELECT -> withCols "SELECT "
SELECT_DISTINCT -> withCols "SELECT DISTINCT "
DELETE -> plain "DELETE "
UPDATE -> plain "UPDATE "
where
withCols v = first (v <>) (sqlSelectCols conn ret)
withCols v = first (v <>) (sqlSelectCols info ret)
plain v = (v, [])
makeFrom :: Connection -> Mode -> [FromClause] -> (TLB.Builder, [PersistValue])
makeFrom :: IdentInfo -> Mode -> [FromClause] -> (TLB.Builder, [PersistValue])
makeFrom _ _ [] = mempty
makeFrom conn mode fs = ret
makeFrom info mode fs = ret
where
ret = case collectOnClauses fs of
Left expr -> throw $ mkExc expr
@ -802,8 +822,8 @@ makeFrom conn mode fs = ret
base ident@(I identText) def =
let db@(DBName dbText) = entityDB def
in ( if dbText == identText
then fromDBName conn db
else fromDBName conn db <> (" AS " <> useIdent conn ident)
then fromDBName info db
else fromDBName info db <> (" AS " <> useIdent info ident)
, mempty )
fromKind InnerJoinKind = " INNER JOIN "
@ -812,56 +832,56 @@ makeFrom conn mode fs = ret
fromKind RightOuterJoinKind = " RIGHT OUTER JOIN "
fromKind FullOuterJoinKind = " FULL OUTER JOIN "
makeOnClause (ERaw _ f) = first (" ON " <>) (f conn)
makeOnClause (ERaw _ f) = first (" ON " <>) (f info)
mkExc :: SqlExpr (Value Bool) -> OnClauseWithoutMatchingJoinException
mkExc (ERaw _ f) =
OnClauseWithoutMatchingJoinException $
TL.unpack $ TLB.toLazyText $ fst (f conn)
TL.unpack $ TLB.toLazyText $ fst (f info)
makeSet :: Connection -> [SetClause] -> (TLB.Builder, [PersistValue])
makeSet :: IdentInfo -> [SetClause] -> (TLB.Builder, [PersistValue])
makeSet _ [] = mempty
makeSet conn os = first ("\nSET " <>) $ uncommas' (map mk os)
makeSet info os = first ("\nSET " <>) $ uncommas' (map mk os)
where
mk (SetClause (ERaw _ f)) = f conn
mk (SetClause (ERaw _ f)) = f info
makeWhere :: Connection -> WhereClause -> (TLB.Builder, [PersistValue])
makeWhere :: IdentInfo -> WhereClause -> (TLB.Builder, [PersistValue])
makeWhere _ NoWhere = mempty
makeWhere conn (Where (ERaw _ f)) = first ("\nWHERE " <>) (f conn)
makeWhere info (Where (ERaw _ f)) = first ("\nWHERE " <>) (f info)
makeGroupBy :: Connection -> GroupByClause -> (TLB.Builder, [PersistValue])
makeGroupBy :: IdentInfo -> GroupByClause -> (TLB.Builder, [PersistValue])
makeGroupBy _ (GroupBy []) = (mempty, [])
makeGroupBy conn (GroupBy fields) = first ("\nGROUP BY " <>) build
makeGroupBy info (GroupBy fields) = first ("\nGROUP BY " <>) build
where
build = uncommas' $ map (\(SomeValue (ERaw _ f)) -> f conn) fields
build = uncommas' $ map (\(SomeValue (ERaw _ f)) -> f info) fields
makeHaving :: Connection -> WhereClause -> (TLB.Builder, [PersistValue])
makeHaving :: IdentInfo -> WhereClause -> (TLB.Builder, [PersistValue])
makeHaving _ NoWhere = mempty
makeHaving conn (Where (ERaw _ f)) = first ("\nHAVING " <>) (f conn)
makeHaving info (Where (ERaw _ f)) = first ("\nHAVING " <>) (f info)
makeOrderBy :: Connection -> [OrderByClause] -> (TLB.Builder, [PersistValue])
makeOrderBy :: IdentInfo -> [OrderByClause] -> (TLB.Builder, [PersistValue])
makeOrderBy _ [] = mempty
makeOrderBy conn os = first ("\nORDER BY " <>) $ uncommas' (map mk os)
makeOrderBy info os = first ("\nORDER BY " <>) $ uncommas' (map mk os)
where
mk (EOrderBy t (ERaw p f)) = first ((<> orderByType t) . parensM p) (f conn)
mk (EOrderBy t (ERaw p f)) = first ((<> orderByType t) . parensM p) (f info)
orderByType ASC = " ASC"
orderByType DESC = " DESC"
makeLimit :: Connection -> LimitClause -> (TLB.Builder, [PersistValue])
makeLimit :: IdentInfo -> LimitClause -> (TLB.Builder, [PersistValue])
makeLimit _ (Limit Nothing Nothing) = mempty
makeLimit _ (Limit Nothing (Just 0)) = mempty
makeLimit conn (Limit ml mo) = (ret, mempty)
makeLimit info (Limit ml mo) = (ret, mempty)
where
ret = TLB.singleton '\n' <> (limitTLB <> offsetTLB)
limitTLB =
case ml of
Just l -> "LIMIT " <> TLBI.decimal l
Nothing -> TLB.fromText (connNoLimit conn)
Nothing -> TLB.fromText (connNoLimit $ fst info)
offsetTLB =
case mo of
@ -886,7 +906,7 @@ class SqlSelect a r | a -> r, r -> a where
-- | Creates the variable part of the @SELECT@ query and
-- returns the list of 'PersistValue's that will be given to
-- 'rawQuery'.
sqlSelectCols :: Connection -> a -> (TLB.Builder, [PersistValue])
sqlSelectCols :: IdentInfo -> a -> (TLB.Builder, [PersistValue])
-- | Number of columns that will be consumed.
sqlSelectColCount :: Proxy a -> Int
@ -897,7 +917,7 @@ class SqlSelect a r | a -> r, r -> a where
-- | You may return an insertion of some PersistEntity
instance PersistEntity a => SqlSelect (SqlExpr (Insertion a)) (Insertion a) where
sqlSelectCols conn (EInsert _ f) = f conn
sqlSelectCols info (EInsert _ f) = f info
sqlSelectColCount = const 0
sqlSelectProcessRow = const (Right (error msg))
where
@ -913,10 +933,10 @@ instance SqlSelect () () where
-- | You may return an 'Entity' from a 'select' query.
instance PersistEntity a => SqlSelect (SqlExpr (Entity a)) (Entity a) where
sqlSelectCols conn expr@(EEntity ident) = ret
sqlSelectCols info expr@(EEntity ident) = ret
where
process ed = uncommas $
map ((name <>) . fromDBName conn) $
map ((name <>) . fromDBName info) $
(entityID ed:) $
map fieldDB $
entityFields ed
@ -926,7 +946,7 @@ instance PersistEntity a => SqlSelect (SqlExpr (Entity a)) (Entity a) where
-- clause), while 'rawSql' assumes that it's just the
-- name of the table (which doesn't allow self-joins, for
-- example).
name = useIdent conn ident <> "."
name = useIdent info ident <> "."
ret = let ed = entityDef $ getEntityVal $ return expr
in (process ed, mempty)
sqlSelectColCount = (+1) . length . entityFields . entityDef . getEntityVal
@ -941,7 +961,7 @@ getEntityVal = const Proxy
-- | You may return a possibly-@NULL@ 'Entity' from a 'select' query.
instance PersistEntity a => SqlSelect (SqlExpr (Maybe (Entity a))) (Maybe (Entity a)) where
sqlSelectCols conn (EMaybe ent) = sqlSelectCols conn ent
sqlSelectCols info (EMaybe ent) = sqlSelectCols info ent
sqlSelectColCount = sqlSelectColCount . fromEMaybe
where
fromEMaybe :: Proxy (SqlExpr (Maybe e)) -> Proxy (SqlExpr e)
@ -954,8 +974,8 @@ instance PersistEntity a => SqlSelect (SqlExpr (Maybe (Entity a))) (Maybe (Entit
-- | You may return any single value (i.e. a single column) from
-- a 'select' query.
instance PersistField a => SqlSelect (SqlExpr (Value a)) (Value a) where
sqlSelectCols esc (ERaw p f) = let (b, vals) = f esc
in (parensM p b, vals)
sqlSelectCols info (ERaw p f) = let (b, vals) = f info
in (parensM p b, vals)
sqlSelectColCount = const 1
sqlSelectProcessRow [pv] = Value <$> fromPersistValue pv
sqlSelectProcessRow _ = Left "SqlSelect (Value a): wrong number of columns."
@ -1468,4 +1488,4 @@ insertGeneralSelect :: (MonadLogger m, MonadResourceBase m, SqlSelect (SqlExpr (
Mode -> SqlQuery (SqlExpr (Insertion a)) -> SqlPersistT m ()
insertGeneralSelect mode query = do
conn <- SqlPersistT R.ask
uncurry rawExecute $ first builderToText $ toRawSql mode insertQuery conn query
uncurry rawExecute $ first builderToText $ toRawSql mode insertQuery (conn, initialIdentState) query