Thread IdentState through subqueries (fixes #28).
There used to be name clashes if a subquery referenced an entity that was already being used on the outer query. Now we thread the outer query's IdentState to its subqueries, which use it instead of initialIdentState. Note that clashes still may occur between subqueries of a query, but I think that's harmless.
This commit is contained in:
parent
c5c76959bd
commit
33b1fafc2d
@ -35,6 +35,9 @@ module Database.Esqueleto.Internal.Sql
|
||||
, rawEsqueleto
|
||||
, toRawSql
|
||||
, Mode(..)
|
||||
, IdentState
|
||||
, initialIdentState
|
||||
, IdentInfo
|
||||
, SqlSelect
|
||||
, veryUnsafeCoerceSqlExprValue
|
||||
) where
|
||||
@ -219,9 +222,13 @@ newIdentFor = Q . lift . try . unDBName
|
||||
return (I t)
|
||||
|
||||
|
||||
-- | Information needed to escape and use identifiers.
|
||||
type IdentInfo = (Connection, IdentState)
|
||||
|
||||
|
||||
-- | Use an identifier.
|
||||
useIdent :: Connection -> Ident -> TLB.Builder
|
||||
useIdent conn (I ident) = fromDBName conn $ DBName ident
|
||||
useIdent :: IdentInfo -> Ident -> TLB.Builder
|
||||
useIdent info (I ident) = fromDBName info $ DBName ident
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
@ -240,7 +247,7 @@ data SqlExpr a where
|
||||
-- connection (mainly for escaping names) and returns both an
|
||||
-- string ('TLB.Builder') and a list of values to be
|
||||
-- interpolated by the SQL backend.
|
||||
ERaw :: NeedParens -> (Connection -> (TLB.Builder, [PersistValue])) -> SqlExpr (Value a)
|
||||
ERaw :: NeedParens -> (IdentInfo -> (TLB.Builder, [PersistValue])) -> SqlExpr (Value a)
|
||||
|
||||
-- | 'EList' and 'EEmptyList' are used by list operators.
|
||||
EList :: SqlExpr (Value a) -> SqlExpr (ValueList a)
|
||||
@ -256,7 +263,7 @@ data SqlExpr a where
|
||||
EPreprocessedFrom :: a -> FromClause -> SqlExpr (PreprocessedFrom a)
|
||||
|
||||
-- | Used by 'insertSelect'.
|
||||
EInsert :: Proxy a -> (Connection -> (TLB.Builder, [PersistValue])) -> SqlExpr (Insertion a)
|
||||
EInsert :: Proxy a -> (IdentInfo -> (TLB.Builder, [PersistValue])) -> SqlExpr (Insertion a)
|
||||
|
||||
data NeedParens = Parens | Never
|
||||
|
||||
@ -317,7 +324,7 @@ instance Esqueleto SqlQuery SqlExpr SqlBackend where
|
||||
sub_selectDistinct = sub SELECT_DISTINCT
|
||||
|
||||
EEntity ident ^. field =
|
||||
ERaw Never $ \conn -> (useIdent conn ident <> ("." <> fieldName conn field), [])
|
||||
ERaw Never $ \info -> (useIdent info ident <> ("." <> fieldName info field), [])
|
||||
|
||||
EMaybe r ?. field = maybelize (r ^. field)
|
||||
where
|
||||
@ -331,10 +338,10 @@ instance Esqueleto SqlQuery SqlExpr SqlBackend where
|
||||
nothing = unsafeSqlValue "NULL"
|
||||
joinV (ERaw p f) = ERaw p f
|
||||
countRows = unsafeSqlValue "COUNT(*)"
|
||||
count (ERaw _ f) = ERaw Never $ \conn -> let (b, vals) = f conn
|
||||
count (ERaw _ f) = ERaw Never $ \info -> let (b, vals) = f info
|
||||
in ("COUNT" <> parens b, vals)
|
||||
|
||||
not_ (ERaw p f) = ERaw Never $ \conn -> let (b, vals) = f conn
|
||||
not_ (ERaw p f) = ERaw Never $ \info -> let (b, vals) = f info
|
||||
in ("NOT " <> parensM p b, vals)
|
||||
|
||||
(==.) = unsafeSqlBinOp " = "
|
||||
@ -399,21 +406,21 @@ instance ToSomeValues SqlExpr (SqlExpr (Value a)) where
|
||||
toSomeValues a = [SomeValue a]
|
||||
|
||||
fieldName :: (PersistEntity val, PersistField typ)
|
||||
=> Connection -> EntityField val typ -> TLB.Builder
|
||||
fieldName conn = fromDBName conn . fieldDB . persistFieldDef
|
||||
=> IdentInfo -> EntityField val typ -> TLB.Builder
|
||||
fieldName info = fromDBName info . fieldDB . persistFieldDef
|
||||
|
||||
setAux :: (PersistEntity val, PersistField typ)
|
||||
=> EntityField val typ
|
||||
-> (SqlExpr (Entity val) -> SqlExpr (Value typ))
|
||||
-> SqlExpr (Update val)
|
||||
setAux field mkVal = ESet $ \ent -> unsafeSqlBinOp " = " name (mkVal ent)
|
||||
where name = ERaw Never $ \conn -> (fieldName conn field, mempty)
|
||||
where name = ERaw Never $ \info -> (fieldName info field, mempty)
|
||||
|
||||
sub :: PersistField a => Mode -> SqlQuery (SqlExpr (Value a)) -> SqlExpr (Value a)
|
||||
sub mode query = ERaw Parens $ \conn -> toRawSql mode pureQuery conn query
|
||||
sub mode query = ERaw Parens $ \info -> toRawSql mode pureQuery info query
|
||||
|
||||
fromDBName :: Connection -> DBName -> TLB.Builder
|
||||
fromDBName conn = TLB.fromText . connEscapeName conn
|
||||
fromDBName :: IdentInfo -> DBName -> TLB.Builder
|
||||
fromDBName (conn, _) = TLB.fromText . connEscapeName conn
|
||||
|
||||
existsHelper :: SqlQuery () -> SqlExpr (Value Bool)
|
||||
existsHelper = sub SELECT . (>> return true)
|
||||
@ -444,8 +451,8 @@ ifNotEmptyList (EList _) _ x = x
|
||||
unsafeSqlBinOp :: TLB.Builder -> SqlExpr (Value a) -> SqlExpr (Value b) -> SqlExpr (Value c)
|
||||
unsafeSqlBinOp op (ERaw p1 f1) (ERaw p2 f2) = ERaw Parens f
|
||||
where
|
||||
f conn = let (b1, vals1) = f1 conn
|
||||
(b2, vals2) = f2 conn
|
||||
f info = let (b1, vals1) = f1 info
|
||||
(b2, vals2) = f2 info
|
||||
in ( parensM p1 b1 <> op <> parensM p2 b2
|
||||
, vals1 <> vals2 )
|
||||
{-# INLINE unsafeSqlBinOp #-}
|
||||
@ -463,9 +470,9 @@ unsafeSqlValue v = ERaw Never $ \_ -> (v, mempty)
|
||||
unsafeSqlFunction :: UnsafeSqlFunctionArgument a =>
|
||||
TLB.Builder -> a -> SqlExpr (Value b)
|
||||
unsafeSqlFunction name arg =
|
||||
ERaw Never $ \conn ->
|
||||
ERaw Never $ \info ->
|
||||
let (argsTLB, argsVals) =
|
||||
uncommas' $ map (\(ERaw _ f) -> f conn) $ toArgList arg
|
||||
uncommas' $ map (\(ERaw _ f) -> f info) $ toArgList arg
|
||||
in (name <> parens argsTLB, argsVals)
|
||||
|
||||
class UnsafeSqlFunctionArgument a where
|
||||
@ -527,7 +534,7 @@ rawSelectSource mode query = src
|
||||
run conn =
|
||||
uncurry rawQuery $
|
||||
first builderToText $
|
||||
toRawSql mode pureQuery conn query
|
||||
toRawSql mode pureQuery (conn, initialIdentState) query
|
||||
|
||||
massage = do
|
||||
mrow <- C.await
|
||||
@ -639,7 +646,7 @@ rawEsqueleto mode query = do
|
||||
conn <- SqlPersistT R.ask
|
||||
uncurry rawExecuteCount $
|
||||
first builderToText $
|
||||
toRawSql mode pureQuery conn query
|
||||
toRawSql mode pureQuery (conn, initialIdentState) query
|
||||
|
||||
|
||||
-- | Execute an @esqueleto@ @DELETE@ query inside @persistent@'s
|
||||
@ -723,24 +730,37 @@ builderToText = TL.toStrict . TLB.toLazyTextWith defaultChunkSize
|
||||
-- @esqueleto@, instead of manually using this function (which is
|
||||
-- possible but tedious), you may just turn on query logging of
|
||||
-- @persistent@.
|
||||
toRawSql :: SqlSelect a r => Mode -> QueryType a -> Connection -> SqlQuery a -> (TLB.Builder, [PersistValue])
|
||||
toRawSql mode qt conn query =
|
||||
let (ret, SideData fromClauses setClauses whereClauses groupByClause havingClause orderByClauses limitClause) =
|
||||
flip S.evalState initialIdentState $
|
||||
toRawSql :: SqlSelect a r => Mode -> QueryType a -> IdentInfo -> SqlQuery a -> (TLB.Builder, [PersistValue])
|
||||
toRawSql mode qt (conn, firstIdentState) query =
|
||||
let ((ret, sd), finalIdentState) =
|
||||
flip S.runState firstIdentState $
|
||||
W.runWriterT $
|
||||
unQ query
|
||||
SideData fromClauses
|
||||
setClauses
|
||||
whereClauses
|
||||
groupByClause
|
||||
havingClause
|
||||
orderByClauses
|
||||
limitClause = sd
|
||||
-- Pass the finalIdentState (containing all identifiers
|
||||
-- that were used) to the subsequent calls. This ensures
|
||||
-- that no name clashes will occur on subqueries that may
|
||||
-- appear on the expressions below.
|
||||
info = (conn, finalIdentState)
|
||||
in mconcat
|
||||
[ makeInsert qt ret
|
||||
, makeSelect conn mode ret
|
||||
, makeFrom conn mode fromClauses
|
||||
, makeSet conn setClauses
|
||||
, makeWhere conn whereClauses
|
||||
, makeGroupBy conn groupByClause
|
||||
, makeHaving conn havingClause
|
||||
, makeOrderBy conn orderByClauses
|
||||
, makeLimit conn limitClause
|
||||
, makeSelect info mode ret
|
||||
, makeFrom info mode fromClauses
|
||||
, makeSet info setClauses
|
||||
, makeWhere info whereClauses
|
||||
, makeGroupBy info groupByClause
|
||||
, makeHaving info havingClause
|
||||
, makeOrderBy info orderByClauses
|
||||
, makeLimit info limitClause
|
||||
]
|
||||
|
||||
|
||||
-- | (Internal) Mode of query being converted by 'toRawSql'.
|
||||
data Mode = SELECT | SELECT_DISTINCT | DELETE | UPDATE
|
||||
|
||||
@ -767,21 +787,21 @@ uncommas' :: Monoid a => [(TLB.Builder, a)] -> (TLB.Builder, a)
|
||||
uncommas' = (uncommas *** mconcat) . unzip
|
||||
|
||||
|
||||
makeSelect :: SqlSelect a r => Connection -> Mode -> a -> (TLB.Builder, [PersistValue])
|
||||
makeSelect conn mode ret =
|
||||
makeSelect :: SqlSelect a r => IdentInfo -> Mode -> a -> (TLB.Builder, [PersistValue])
|
||||
makeSelect info mode ret =
|
||||
case mode of
|
||||
SELECT -> withCols "SELECT "
|
||||
SELECT_DISTINCT -> withCols "SELECT DISTINCT "
|
||||
DELETE -> plain "DELETE "
|
||||
UPDATE -> plain "UPDATE "
|
||||
where
|
||||
withCols v = first (v <>) (sqlSelectCols conn ret)
|
||||
withCols v = first (v <>) (sqlSelectCols info ret)
|
||||
plain v = (v, [])
|
||||
|
||||
|
||||
makeFrom :: Connection -> Mode -> [FromClause] -> (TLB.Builder, [PersistValue])
|
||||
makeFrom :: IdentInfo -> Mode -> [FromClause] -> (TLB.Builder, [PersistValue])
|
||||
makeFrom _ _ [] = mempty
|
||||
makeFrom conn mode fs = ret
|
||||
makeFrom info mode fs = ret
|
||||
where
|
||||
ret = case collectOnClauses fs of
|
||||
Left expr -> throw $ mkExc expr
|
||||
@ -802,8 +822,8 @@ makeFrom conn mode fs = ret
|
||||
base ident@(I identText) def =
|
||||
let db@(DBName dbText) = entityDB def
|
||||
in ( if dbText == identText
|
||||
then fromDBName conn db
|
||||
else fromDBName conn db <> (" AS " <> useIdent conn ident)
|
||||
then fromDBName info db
|
||||
else fromDBName info db <> (" AS " <> useIdent info ident)
|
||||
, mempty )
|
||||
|
||||
fromKind InnerJoinKind = " INNER JOIN "
|
||||
@ -812,56 +832,56 @@ makeFrom conn mode fs = ret
|
||||
fromKind RightOuterJoinKind = " RIGHT OUTER JOIN "
|
||||
fromKind FullOuterJoinKind = " FULL OUTER JOIN "
|
||||
|
||||
makeOnClause (ERaw _ f) = first (" ON " <>) (f conn)
|
||||
makeOnClause (ERaw _ f) = first (" ON " <>) (f info)
|
||||
|
||||
mkExc :: SqlExpr (Value Bool) -> OnClauseWithoutMatchingJoinException
|
||||
mkExc (ERaw _ f) =
|
||||
OnClauseWithoutMatchingJoinException $
|
||||
TL.unpack $ TLB.toLazyText $ fst (f conn)
|
||||
TL.unpack $ TLB.toLazyText $ fst (f info)
|
||||
|
||||
|
||||
makeSet :: Connection -> [SetClause] -> (TLB.Builder, [PersistValue])
|
||||
makeSet :: IdentInfo -> [SetClause] -> (TLB.Builder, [PersistValue])
|
||||
makeSet _ [] = mempty
|
||||
makeSet conn os = first ("\nSET " <>) $ uncommas' (map mk os)
|
||||
makeSet info os = first ("\nSET " <>) $ uncommas' (map mk os)
|
||||
where
|
||||
mk (SetClause (ERaw _ f)) = f conn
|
||||
mk (SetClause (ERaw _ f)) = f info
|
||||
|
||||
|
||||
makeWhere :: Connection -> WhereClause -> (TLB.Builder, [PersistValue])
|
||||
makeWhere :: IdentInfo -> WhereClause -> (TLB.Builder, [PersistValue])
|
||||
makeWhere _ NoWhere = mempty
|
||||
makeWhere conn (Where (ERaw _ f)) = first ("\nWHERE " <>) (f conn)
|
||||
makeWhere info (Where (ERaw _ f)) = first ("\nWHERE " <>) (f info)
|
||||
|
||||
|
||||
makeGroupBy :: Connection -> GroupByClause -> (TLB.Builder, [PersistValue])
|
||||
makeGroupBy :: IdentInfo -> GroupByClause -> (TLB.Builder, [PersistValue])
|
||||
makeGroupBy _ (GroupBy []) = (mempty, [])
|
||||
makeGroupBy conn (GroupBy fields) = first ("\nGROUP BY " <>) build
|
||||
makeGroupBy info (GroupBy fields) = first ("\nGROUP BY " <>) build
|
||||
where
|
||||
build = uncommas' $ map (\(SomeValue (ERaw _ f)) -> f conn) fields
|
||||
build = uncommas' $ map (\(SomeValue (ERaw _ f)) -> f info) fields
|
||||
|
||||
makeHaving :: Connection -> WhereClause -> (TLB.Builder, [PersistValue])
|
||||
makeHaving :: IdentInfo -> WhereClause -> (TLB.Builder, [PersistValue])
|
||||
makeHaving _ NoWhere = mempty
|
||||
makeHaving conn (Where (ERaw _ f)) = first ("\nHAVING " <>) (f conn)
|
||||
makeHaving info (Where (ERaw _ f)) = first ("\nHAVING " <>) (f info)
|
||||
|
||||
makeOrderBy :: Connection -> [OrderByClause] -> (TLB.Builder, [PersistValue])
|
||||
makeOrderBy :: IdentInfo -> [OrderByClause] -> (TLB.Builder, [PersistValue])
|
||||
makeOrderBy _ [] = mempty
|
||||
makeOrderBy conn os = first ("\nORDER BY " <>) $ uncommas' (map mk os)
|
||||
makeOrderBy info os = first ("\nORDER BY " <>) $ uncommas' (map mk os)
|
||||
where
|
||||
mk (EOrderBy t (ERaw p f)) = first ((<> orderByType t) . parensM p) (f conn)
|
||||
mk (EOrderBy t (ERaw p f)) = first ((<> orderByType t) . parensM p) (f info)
|
||||
orderByType ASC = " ASC"
|
||||
orderByType DESC = " DESC"
|
||||
|
||||
|
||||
makeLimit :: Connection -> LimitClause -> (TLB.Builder, [PersistValue])
|
||||
makeLimit :: IdentInfo -> LimitClause -> (TLB.Builder, [PersistValue])
|
||||
makeLimit _ (Limit Nothing Nothing) = mempty
|
||||
makeLimit _ (Limit Nothing (Just 0)) = mempty
|
||||
makeLimit conn (Limit ml mo) = (ret, mempty)
|
||||
makeLimit info (Limit ml mo) = (ret, mempty)
|
||||
where
|
||||
ret = TLB.singleton '\n' <> (limitTLB <> offsetTLB)
|
||||
|
||||
limitTLB =
|
||||
case ml of
|
||||
Just l -> "LIMIT " <> TLBI.decimal l
|
||||
Nothing -> TLB.fromText (connNoLimit conn)
|
||||
Nothing -> TLB.fromText (connNoLimit $ fst info)
|
||||
|
||||
offsetTLB =
|
||||
case mo of
|
||||
@ -886,7 +906,7 @@ class SqlSelect a r | a -> r, r -> a where
|
||||
-- | Creates the variable part of the @SELECT@ query and
|
||||
-- returns the list of 'PersistValue's that will be given to
|
||||
-- 'rawQuery'.
|
||||
sqlSelectCols :: Connection -> a -> (TLB.Builder, [PersistValue])
|
||||
sqlSelectCols :: IdentInfo -> a -> (TLB.Builder, [PersistValue])
|
||||
|
||||
-- | Number of columns that will be consumed.
|
||||
sqlSelectColCount :: Proxy a -> Int
|
||||
@ -897,7 +917,7 @@ class SqlSelect a r | a -> r, r -> a where
|
||||
|
||||
-- | You may return an insertion of some PersistEntity
|
||||
instance PersistEntity a => SqlSelect (SqlExpr (Insertion a)) (Insertion a) where
|
||||
sqlSelectCols conn (EInsert _ f) = f conn
|
||||
sqlSelectCols info (EInsert _ f) = f info
|
||||
sqlSelectColCount = const 0
|
||||
sqlSelectProcessRow = const (Right (error msg))
|
||||
where
|
||||
@ -913,10 +933,10 @@ instance SqlSelect () () where
|
||||
|
||||
-- | You may return an 'Entity' from a 'select' query.
|
||||
instance PersistEntity a => SqlSelect (SqlExpr (Entity a)) (Entity a) where
|
||||
sqlSelectCols conn expr@(EEntity ident) = ret
|
||||
sqlSelectCols info expr@(EEntity ident) = ret
|
||||
where
|
||||
process ed = uncommas $
|
||||
map ((name <>) . fromDBName conn) $
|
||||
map ((name <>) . fromDBName info) $
|
||||
(entityID ed:) $
|
||||
map fieldDB $
|
||||
entityFields ed
|
||||
@ -926,7 +946,7 @@ instance PersistEntity a => SqlSelect (SqlExpr (Entity a)) (Entity a) where
|
||||
-- clause), while 'rawSql' assumes that it's just the
|
||||
-- name of the table (which doesn't allow self-joins, for
|
||||
-- example).
|
||||
name = useIdent conn ident <> "."
|
||||
name = useIdent info ident <> "."
|
||||
ret = let ed = entityDef $ getEntityVal $ return expr
|
||||
in (process ed, mempty)
|
||||
sqlSelectColCount = (+1) . length . entityFields . entityDef . getEntityVal
|
||||
@ -941,7 +961,7 @@ getEntityVal = const Proxy
|
||||
|
||||
-- | You may return a possibly-@NULL@ 'Entity' from a 'select' query.
|
||||
instance PersistEntity a => SqlSelect (SqlExpr (Maybe (Entity a))) (Maybe (Entity a)) where
|
||||
sqlSelectCols conn (EMaybe ent) = sqlSelectCols conn ent
|
||||
sqlSelectCols info (EMaybe ent) = sqlSelectCols info ent
|
||||
sqlSelectColCount = sqlSelectColCount . fromEMaybe
|
||||
where
|
||||
fromEMaybe :: Proxy (SqlExpr (Maybe e)) -> Proxy (SqlExpr e)
|
||||
@ -954,8 +974,8 @@ instance PersistEntity a => SqlSelect (SqlExpr (Maybe (Entity a))) (Maybe (Entit
|
||||
-- | You may return any single value (i.e. a single column) from
|
||||
-- a 'select' query.
|
||||
instance PersistField a => SqlSelect (SqlExpr (Value a)) (Value a) where
|
||||
sqlSelectCols esc (ERaw p f) = let (b, vals) = f esc
|
||||
in (parensM p b, vals)
|
||||
sqlSelectCols info (ERaw p f) = let (b, vals) = f info
|
||||
in (parensM p b, vals)
|
||||
sqlSelectColCount = const 1
|
||||
sqlSelectProcessRow [pv] = Value <$> fromPersistValue pv
|
||||
sqlSelectProcessRow _ = Left "SqlSelect (Value a): wrong number of columns."
|
||||
@ -1468,4 +1488,4 @@ insertGeneralSelect :: (MonadLogger m, MonadResourceBase m, SqlSelect (SqlExpr (
|
||||
Mode -> SqlQuery (SqlExpr (Insertion a)) -> SqlPersistT m ()
|
||||
insertGeneralSelect mode query = do
|
||||
conn <- SqlPersistT R.ask
|
||||
uncurry rawExecute $ first builderToText $ toRawSql mode insertQuery conn query
|
||||
uncurry rawExecute $ first builderToText $ toRawSql mode insertQuery (conn, initialIdentState) query
|
||||
|
||||
Loading…
Reference in New Issue
Block a user