From 33b1fafc2d9898740f10597609febcd1c89d37a9 Mon Sep 17 00:00:00 2001 From: Felipe Lessa Date: Sun, 15 Sep 2013 04:16:35 -0300 Subject: [PATCH] Thread IdentState through subqueries (fixes #28). There used to be name clashes if a subquery referenced an entity that was already being used on the outer query. Now we thread the outer query's IdentState to its subqueries, which use it instead of initialIdentState. Note that clashes still may occur between subqueries of a query, but I think that's harmless. --- src/Database/Esqueleto/Internal/Sql.hs | 150 ++++++++++++++----------- 1 file changed, 85 insertions(+), 65 deletions(-) diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index 1c73b85..205a462 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -35,6 +35,9 @@ module Database.Esqueleto.Internal.Sql , rawEsqueleto , toRawSql , Mode(..) + , IdentState + , initialIdentState + , IdentInfo , SqlSelect , veryUnsafeCoerceSqlExprValue ) where @@ -219,9 +222,13 @@ newIdentFor = Q . lift . try . unDBName return (I t) +-- | Information needed to escape and use identifiers. +type IdentInfo = (Connection, IdentState) + + -- | Use an identifier. -useIdent :: Connection -> Ident -> TLB.Builder -useIdent conn (I ident) = fromDBName conn $ DBName ident +useIdent :: IdentInfo -> Ident -> TLB.Builder +useIdent info (I ident) = fromDBName info $ DBName ident ---------------------------------------------------------------------- @@ -240,7 +247,7 @@ data SqlExpr a where -- connection (mainly for escaping names) and returns both an -- string ('TLB.Builder') and a list of values to be -- interpolated by the SQL backend. - ERaw :: NeedParens -> (Connection -> (TLB.Builder, [PersistValue])) -> SqlExpr (Value a) + ERaw :: NeedParens -> (IdentInfo -> (TLB.Builder, [PersistValue])) -> SqlExpr (Value a) -- | 'EList' and 'EEmptyList' are used by list operators. EList :: SqlExpr (Value a) -> SqlExpr (ValueList a) @@ -256,7 +263,7 @@ data SqlExpr a where EPreprocessedFrom :: a -> FromClause -> SqlExpr (PreprocessedFrom a) -- | Used by 'insertSelect'. - EInsert :: Proxy a -> (Connection -> (TLB.Builder, [PersistValue])) -> SqlExpr (Insertion a) + EInsert :: Proxy a -> (IdentInfo -> (TLB.Builder, [PersistValue])) -> SqlExpr (Insertion a) data NeedParens = Parens | Never @@ -317,7 +324,7 @@ instance Esqueleto SqlQuery SqlExpr SqlBackend where sub_selectDistinct = sub SELECT_DISTINCT EEntity ident ^. field = - ERaw Never $ \conn -> (useIdent conn ident <> ("." <> fieldName conn field), []) + ERaw Never $ \info -> (useIdent info ident <> ("." <> fieldName info field), []) EMaybe r ?. field = maybelize (r ^. field) where @@ -331,10 +338,10 @@ instance Esqueleto SqlQuery SqlExpr SqlBackend where nothing = unsafeSqlValue "NULL" joinV (ERaw p f) = ERaw p f countRows = unsafeSqlValue "COUNT(*)" - count (ERaw _ f) = ERaw Never $ \conn -> let (b, vals) = f conn + count (ERaw _ f) = ERaw Never $ \info -> let (b, vals) = f info in ("COUNT" <> parens b, vals) - not_ (ERaw p f) = ERaw Never $ \conn -> let (b, vals) = f conn + not_ (ERaw p f) = ERaw Never $ \info -> let (b, vals) = f info in ("NOT " <> parensM p b, vals) (==.) = unsafeSqlBinOp " = " @@ -399,21 +406,21 @@ instance ToSomeValues SqlExpr (SqlExpr (Value a)) where toSomeValues a = [SomeValue a] fieldName :: (PersistEntity val, PersistField typ) - => Connection -> EntityField val typ -> TLB.Builder -fieldName conn = fromDBName conn . fieldDB . persistFieldDef + => IdentInfo -> EntityField val typ -> TLB.Builder +fieldName info = fromDBName info . fieldDB . persistFieldDef setAux :: (PersistEntity val, PersistField typ) => EntityField val typ -> (SqlExpr (Entity val) -> SqlExpr (Value typ)) -> SqlExpr (Update val) setAux field mkVal = ESet $ \ent -> unsafeSqlBinOp " = " name (mkVal ent) - where name = ERaw Never $ \conn -> (fieldName conn field, mempty) + where name = ERaw Never $ \info -> (fieldName info field, mempty) sub :: PersistField a => Mode -> SqlQuery (SqlExpr (Value a)) -> SqlExpr (Value a) -sub mode query = ERaw Parens $ \conn -> toRawSql mode pureQuery conn query +sub mode query = ERaw Parens $ \info -> toRawSql mode pureQuery info query -fromDBName :: Connection -> DBName -> TLB.Builder -fromDBName conn = TLB.fromText . connEscapeName conn +fromDBName :: IdentInfo -> DBName -> TLB.Builder +fromDBName (conn, _) = TLB.fromText . connEscapeName conn existsHelper :: SqlQuery () -> SqlExpr (Value Bool) existsHelper = sub SELECT . (>> return true) @@ -444,8 +451,8 @@ ifNotEmptyList (EList _) _ x = x unsafeSqlBinOp :: TLB.Builder -> SqlExpr (Value a) -> SqlExpr (Value b) -> SqlExpr (Value c) unsafeSqlBinOp op (ERaw p1 f1) (ERaw p2 f2) = ERaw Parens f where - f conn = let (b1, vals1) = f1 conn - (b2, vals2) = f2 conn + f info = let (b1, vals1) = f1 info + (b2, vals2) = f2 info in ( parensM p1 b1 <> op <> parensM p2 b2 , vals1 <> vals2 ) {-# INLINE unsafeSqlBinOp #-} @@ -463,9 +470,9 @@ unsafeSqlValue v = ERaw Never $ \_ -> (v, mempty) unsafeSqlFunction :: UnsafeSqlFunctionArgument a => TLB.Builder -> a -> SqlExpr (Value b) unsafeSqlFunction name arg = - ERaw Never $ \conn -> + ERaw Never $ \info -> let (argsTLB, argsVals) = - uncommas' $ map (\(ERaw _ f) -> f conn) $ toArgList arg + uncommas' $ map (\(ERaw _ f) -> f info) $ toArgList arg in (name <> parens argsTLB, argsVals) class UnsafeSqlFunctionArgument a where @@ -527,7 +534,7 @@ rawSelectSource mode query = src run conn = uncurry rawQuery $ first builderToText $ - toRawSql mode pureQuery conn query + toRawSql mode pureQuery (conn, initialIdentState) query massage = do mrow <- C.await @@ -639,7 +646,7 @@ rawEsqueleto mode query = do conn <- SqlPersistT R.ask uncurry rawExecuteCount $ first builderToText $ - toRawSql mode pureQuery conn query + toRawSql mode pureQuery (conn, initialIdentState) query -- | Execute an @esqueleto@ @DELETE@ query inside @persistent@'s @@ -723,24 +730,37 @@ builderToText = TL.toStrict . TLB.toLazyTextWith defaultChunkSize -- @esqueleto@, instead of manually using this function (which is -- possible but tedious), you may just turn on query logging of -- @persistent@. -toRawSql :: SqlSelect a r => Mode -> QueryType a -> Connection -> SqlQuery a -> (TLB.Builder, [PersistValue]) -toRawSql mode qt conn query = - let (ret, SideData fromClauses setClauses whereClauses groupByClause havingClause orderByClauses limitClause) = - flip S.evalState initialIdentState $ +toRawSql :: SqlSelect a r => Mode -> QueryType a -> IdentInfo -> SqlQuery a -> (TLB.Builder, [PersistValue]) +toRawSql mode qt (conn, firstIdentState) query = + let ((ret, sd), finalIdentState) = + flip S.runState firstIdentState $ W.runWriterT $ unQ query + SideData fromClauses + setClauses + whereClauses + groupByClause + havingClause + orderByClauses + limitClause = sd + -- Pass the finalIdentState (containing all identifiers + -- that were used) to the subsequent calls. This ensures + -- that no name clashes will occur on subqueries that may + -- appear on the expressions below. + info = (conn, finalIdentState) in mconcat [ makeInsert qt ret - , makeSelect conn mode ret - , makeFrom conn mode fromClauses - , makeSet conn setClauses - , makeWhere conn whereClauses - , makeGroupBy conn groupByClause - , makeHaving conn havingClause - , makeOrderBy conn orderByClauses - , makeLimit conn limitClause + , makeSelect info mode ret + , makeFrom info mode fromClauses + , makeSet info setClauses + , makeWhere info whereClauses + , makeGroupBy info groupByClause + , makeHaving info havingClause + , makeOrderBy info orderByClauses + , makeLimit info limitClause ] + -- | (Internal) Mode of query being converted by 'toRawSql'. data Mode = SELECT | SELECT_DISTINCT | DELETE | UPDATE @@ -767,21 +787,21 @@ uncommas' :: Monoid a => [(TLB.Builder, a)] -> (TLB.Builder, a) uncommas' = (uncommas *** mconcat) . unzip -makeSelect :: SqlSelect a r => Connection -> Mode -> a -> (TLB.Builder, [PersistValue]) -makeSelect conn mode ret = +makeSelect :: SqlSelect a r => IdentInfo -> Mode -> a -> (TLB.Builder, [PersistValue]) +makeSelect info mode ret = case mode of SELECT -> withCols "SELECT " SELECT_DISTINCT -> withCols "SELECT DISTINCT " DELETE -> plain "DELETE " UPDATE -> plain "UPDATE " where - withCols v = first (v <>) (sqlSelectCols conn ret) + withCols v = first (v <>) (sqlSelectCols info ret) plain v = (v, []) -makeFrom :: Connection -> Mode -> [FromClause] -> (TLB.Builder, [PersistValue]) +makeFrom :: IdentInfo -> Mode -> [FromClause] -> (TLB.Builder, [PersistValue]) makeFrom _ _ [] = mempty -makeFrom conn mode fs = ret +makeFrom info mode fs = ret where ret = case collectOnClauses fs of Left expr -> throw $ mkExc expr @@ -802,8 +822,8 @@ makeFrom conn mode fs = ret base ident@(I identText) def = let db@(DBName dbText) = entityDB def in ( if dbText == identText - then fromDBName conn db - else fromDBName conn db <> (" AS " <> useIdent conn ident) + then fromDBName info db + else fromDBName info db <> (" AS " <> useIdent info ident) , mempty ) fromKind InnerJoinKind = " INNER JOIN " @@ -812,56 +832,56 @@ makeFrom conn mode fs = ret fromKind RightOuterJoinKind = " RIGHT OUTER JOIN " fromKind FullOuterJoinKind = " FULL OUTER JOIN " - makeOnClause (ERaw _ f) = first (" ON " <>) (f conn) + makeOnClause (ERaw _ f) = first (" ON " <>) (f info) mkExc :: SqlExpr (Value Bool) -> OnClauseWithoutMatchingJoinException mkExc (ERaw _ f) = OnClauseWithoutMatchingJoinException $ - TL.unpack $ TLB.toLazyText $ fst (f conn) + TL.unpack $ TLB.toLazyText $ fst (f info) -makeSet :: Connection -> [SetClause] -> (TLB.Builder, [PersistValue]) +makeSet :: IdentInfo -> [SetClause] -> (TLB.Builder, [PersistValue]) makeSet _ [] = mempty -makeSet conn os = first ("\nSET " <>) $ uncommas' (map mk os) +makeSet info os = first ("\nSET " <>) $ uncommas' (map mk os) where - mk (SetClause (ERaw _ f)) = f conn + mk (SetClause (ERaw _ f)) = f info -makeWhere :: Connection -> WhereClause -> (TLB.Builder, [PersistValue]) +makeWhere :: IdentInfo -> WhereClause -> (TLB.Builder, [PersistValue]) makeWhere _ NoWhere = mempty -makeWhere conn (Where (ERaw _ f)) = first ("\nWHERE " <>) (f conn) +makeWhere info (Where (ERaw _ f)) = first ("\nWHERE " <>) (f info) -makeGroupBy :: Connection -> GroupByClause -> (TLB.Builder, [PersistValue]) +makeGroupBy :: IdentInfo -> GroupByClause -> (TLB.Builder, [PersistValue]) makeGroupBy _ (GroupBy []) = (mempty, []) -makeGroupBy conn (GroupBy fields) = first ("\nGROUP BY " <>) build +makeGroupBy info (GroupBy fields) = first ("\nGROUP BY " <>) build where - build = uncommas' $ map (\(SomeValue (ERaw _ f)) -> f conn) fields + build = uncommas' $ map (\(SomeValue (ERaw _ f)) -> f info) fields -makeHaving :: Connection -> WhereClause -> (TLB.Builder, [PersistValue]) +makeHaving :: IdentInfo -> WhereClause -> (TLB.Builder, [PersistValue]) makeHaving _ NoWhere = mempty -makeHaving conn (Where (ERaw _ f)) = first ("\nHAVING " <>) (f conn) +makeHaving info (Where (ERaw _ f)) = first ("\nHAVING " <>) (f info) -makeOrderBy :: Connection -> [OrderByClause] -> (TLB.Builder, [PersistValue]) +makeOrderBy :: IdentInfo -> [OrderByClause] -> (TLB.Builder, [PersistValue]) makeOrderBy _ [] = mempty -makeOrderBy conn os = first ("\nORDER BY " <>) $ uncommas' (map mk os) +makeOrderBy info os = first ("\nORDER BY " <>) $ uncommas' (map mk os) where - mk (EOrderBy t (ERaw p f)) = first ((<> orderByType t) . parensM p) (f conn) + mk (EOrderBy t (ERaw p f)) = first ((<> orderByType t) . parensM p) (f info) orderByType ASC = " ASC" orderByType DESC = " DESC" -makeLimit :: Connection -> LimitClause -> (TLB.Builder, [PersistValue]) +makeLimit :: IdentInfo -> LimitClause -> (TLB.Builder, [PersistValue]) makeLimit _ (Limit Nothing Nothing) = mempty makeLimit _ (Limit Nothing (Just 0)) = mempty -makeLimit conn (Limit ml mo) = (ret, mempty) +makeLimit info (Limit ml mo) = (ret, mempty) where ret = TLB.singleton '\n' <> (limitTLB <> offsetTLB) limitTLB = case ml of Just l -> "LIMIT " <> TLBI.decimal l - Nothing -> TLB.fromText (connNoLimit conn) + Nothing -> TLB.fromText (connNoLimit $ fst info) offsetTLB = case mo of @@ -886,7 +906,7 @@ class SqlSelect a r | a -> r, r -> a where -- | Creates the variable part of the @SELECT@ query and -- returns the list of 'PersistValue's that will be given to -- 'rawQuery'. - sqlSelectCols :: Connection -> a -> (TLB.Builder, [PersistValue]) + sqlSelectCols :: IdentInfo -> a -> (TLB.Builder, [PersistValue]) -- | Number of columns that will be consumed. sqlSelectColCount :: Proxy a -> Int @@ -897,7 +917,7 @@ class SqlSelect a r | a -> r, r -> a where -- | You may return an insertion of some PersistEntity instance PersistEntity a => SqlSelect (SqlExpr (Insertion a)) (Insertion a) where - sqlSelectCols conn (EInsert _ f) = f conn + sqlSelectCols info (EInsert _ f) = f info sqlSelectColCount = const 0 sqlSelectProcessRow = const (Right (error msg)) where @@ -913,10 +933,10 @@ instance SqlSelect () () where -- | You may return an 'Entity' from a 'select' query. instance PersistEntity a => SqlSelect (SqlExpr (Entity a)) (Entity a) where - sqlSelectCols conn expr@(EEntity ident) = ret + sqlSelectCols info expr@(EEntity ident) = ret where process ed = uncommas $ - map ((name <>) . fromDBName conn) $ + map ((name <>) . fromDBName info) $ (entityID ed:) $ map fieldDB $ entityFields ed @@ -926,7 +946,7 @@ instance PersistEntity a => SqlSelect (SqlExpr (Entity a)) (Entity a) where -- clause), while 'rawSql' assumes that it's just the -- name of the table (which doesn't allow self-joins, for -- example). - name = useIdent conn ident <> "." + name = useIdent info ident <> "." ret = let ed = entityDef $ getEntityVal $ return expr in (process ed, mempty) sqlSelectColCount = (+1) . length . entityFields . entityDef . getEntityVal @@ -941,7 +961,7 @@ getEntityVal = const Proxy -- | You may return a possibly-@NULL@ 'Entity' from a 'select' query. instance PersistEntity a => SqlSelect (SqlExpr (Maybe (Entity a))) (Maybe (Entity a)) where - sqlSelectCols conn (EMaybe ent) = sqlSelectCols conn ent + sqlSelectCols info (EMaybe ent) = sqlSelectCols info ent sqlSelectColCount = sqlSelectColCount . fromEMaybe where fromEMaybe :: Proxy (SqlExpr (Maybe e)) -> Proxy (SqlExpr e) @@ -954,8 +974,8 @@ instance PersistEntity a => SqlSelect (SqlExpr (Maybe (Entity a))) (Maybe (Entit -- | You may return any single value (i.e. a single column) from -- a 'select' query. instance PersistField a => SqlSelect (SqlExpr (Value a)) (Value a) where - sqlSelectCols esc (ERaw p f) = let (b, vals) = f esc - in (parensM p b, vals) + sqlSelectCols info (ERaw p f) = let (b, vals) = f info + in (parensM p b, vals) sqlSelectColCount = const 1 sqlSelectProcessRow [pv] = Value <$> fromPersistValue pv sqlSelectProcessRow _ = Left "SqlSelect (Value a): wrong number of columns." @@ -1468,4 +1488,4 @@ insertGeneralSelect :: (MonadLogger m, MonadResourceBase m, SqlSelect (SqlExpr ( Mode -> SqlQuery (SqlExpr (Insertion a)) -> SqlPersistT m () insertGeneralSelect mode query = do conn <- SqlPersistT R.ask - uncurry rawExecute $ first builderToText $ toRawSql mode insertQuery conn query + uncurry rawExecute $ first builderToText $ toRawSql mode insertQuery (conn, initialIdentState) query