esqueleto/src/Database/Esqueleto/PostgreSQL.hs
2020-10-29 07:43:05 -06:00

351 lines
11 KiB
Haskell

{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | This module contains PostgreSQL-specific functions.
--
-- @since 2.2.8
module Database.Esqueleto.PostgreSQL
( AggMode(..)
, arrayAggDistinct
, arrayAgg
, arrayAggWith
, arrayRemove
, arrayRemoveNull
, stringAgg
, stringAggWith
, maybeArray
, chr
, now_
, random_
, upsert
, upsertBy
, insertSelectWithConflict
, insertSelectWithConflictCount
, filterWhere
-- * Internal
, unsafeSqlAggregateFunction
) where
#if __GLASGOW_HASKELL__ < 804
import Data.Semigroup
#endif
import Control.Arrow (first, (***))
import Control.Exception (throw)
import Control.Monad (void)
import Control.Monad.IO.Class (MonadIO(..))
import qualified Control.Monad.Trans.Reader as R
import Data.Int (Int64)
import Data.List.NonEmpty (NonEmpty((:|)))
import Data.Proxy (Proxy(..))
import qualified Data.Text.Internal.Builder as TLB
import Data.Time.Clock (UTCTime)
import Database.Esqueleto.Internal.Internal hiding (random_)
import Database.Esqueleto.Internal.PersistentImport hiding (upsert, upsertBy)
import Database.Persist.Class (OnlyOneUniqueKey)
-- | (@random()@) The SQL expression @RANDOM()@. It is split out into the
-- database-specific modules because MySQL uses @rand()@ instead.
--
-- @since 2.6.0
random_
    :: (PersistField a, Num a)
    => SqlExpr (Value a)
random_ =
    unsafeSqlValue (TLB.fromText "RANDOM()")
-- | The empty array literal @'{}'@. Unfortunately (@val []@) does not work
-- for this purpose.
emptyArray :: SqlExpr (Value [a])
emptyArray =
    unsafeSqlValue $ TLB.fromText "'{}'"
-- | Coalesce a nullable array expression, falling back to the empty array
-- literal as the default value.
maybeArray
    :: (PersistField a, PersistField [a])
    => SqlExpr (Value (Maybe [a]))
    -> SqlExpr (Value [a])
maybeArray arr = [arr] `coalesceDefault` emptyArray
-- | Aggregate mode: the qualifier placed in front of the arguments of an
-- aggregate function call.
data AggMode
    = AggModeAll
    -- ^ @ALL@. This is the default, so it is rendered as nothing at all.
    | AggModeDistinct
    -- ^ @DISTINCT@
    deriving Show
-- | (Internal) Create a custom aggregate function call with an aggregate mode
-- and an optional @ORDER BY@ list.
--
-- The rendered SQL has the shape @name([DISTINCT ]args[ order-by])@, and the
-- argument values precede the order-by values in the parameter list.
--
-- /Do/ /not/ use this function directly, instead define a new function and give
-- it a type (see `unsafeSqlBinOp`)
--
-- NOTE(review): the argument renderer only matches 'ERaw'; any other
-- constructor coming out of 'toArgList' fails the lambda's pattern match.
unsafeSqlAggregateFunction
    :: UnsafeSqlFunctionArgument a
    => TLB.Builder
    -- ^ Name of the aggregate function
    -> AggMode
    -- ^ Aggregate mode (ALL or DISTINCT)
    -> a
    -- ^ Function arguments
    -> [OrderByClause]
    -- ^ ORDER BY clauses rendered inside the call
    -> SqlExpr (Value b)
unsafeSqlAggregateFunction name mode args orderByClauses = ERaw Never $ \info ->
    let -- ALL is the default, so only DISTINCT has to be spelled out.
        modePrefix =
            case mode of
                AggModeDistinct -> "DISTINCT "
                AggModeAll -> ""
        (renderedArgs, argValues) =
            uncommas' (map (\(ERaw _ f) -> f info) (toArgList args))
        (renderedOrder, orderValues) =
            makeOrderByNoNewline info orderByClauses
        -- Don't add a space if we don't have order by clauses.
        separator
            | null orderByClauses = ""
            | otherwise = " "
     in ( name <> parens (mconcat [modePrefix, renderedArgs, separator, renderedOrder])
        , argValues <> orderValues
        )
-- | (@array_agg@) Concatenate input values, including @NULL@s, into an array,
-- using the given aggregate mode and @ORDER BY@ clauses.
arrayAggWith
    :: AggMode
    -> SqlExpr (Value a)
    -> [OrderByClause]
    -> SqlExpr (Value (Maybe [a]))
arrayAggWith mode expr orderBys =
    unsafeSqlAggregateFunction "array_agg" mode expr orderBys
-- | (@array_agg@) Concatenate input values, including @NULL@s, into an array.
-- This is 'arrayAggWith' with 'AggModeAll' and no @ORDER BY@ clauses.
arrayAgg
    :: (PersistField a)
    => SqlExpr (Value a)
    -> SqlExpr (Value (Maybe [a]))
arrayAgg expr = arrayAggWith AggModeAll expr []
-- | (@array_agg@) Concatenate distinct input values, including @NULL@s, into
-- an array. This is 'arrayAggWith' with 'AggModeDistinct' and no @ORDER BY@
-- clauses.
--
-- @since 2.5.3
arrayAggDistinct
    :: (PersistField a, PersistField [a])
    => SqlExpr (Value a)
    -> SqlExpr (Value (Maybe [a]))
arrayAggDistinct expr = arrayAggWith AggModeDistinct expr []
-- | (@array_remove@) Remove all elements equal to the given value from the
-- array.
--
-- The first argument is the array, the second the element to remove.
--
-- @since 2.5.3
arrayRemove
    :: SqlExpr (Value [a])
    -> SqlExpr (Value a)
    -> SqlExpr (Value [a])
arrayRemove = curry (unsafeSqlFunction "array_remove")
-- | (@array_remove@) Remove @NULL@ values from an array.
arrayRemoveNull :: SqlExpr (Value [Maybe a]) -> SqlExpr (Value [a])
-- This can't be a call to arrayRemove because it changes the value type
-- (the elements go from @Maybe a@ to @a@).
arrayRemoveNull arr = unsafeSqlFunction "array_remove" (arr, nullLiteral)
  where
    nullLiteral = unsafeSqlValue "NULL"
-- | (@string_agg@) Concatenate input values separated by a delimiter, using
-- the given aggregate mode and @ORDER BY@ clauses.
stringAggWith
    :: SqlString s
    => AggMode
    -- ^ Aggregate mode (ALL or DISTINCT)
    -> SqlExpr (Value s)
    -- ^ Input values.
    -> SqlExpr (Value s)
    -- ^ Delimiter.
    -> [OrderByClause]
    -- ^ ORDER BY clauses
    -> SqlExpr (Value (Maybe s))
    -- ^ Concatenation.
stringAggWith mode expr delim orderBys =
    unsafeSqlAggregateFunction "string_agg" mode arguments orderBys
  where
    -- string_agg takes the values first and the delimiter second.
    arguments = (expr, delim)
-- | (@string_agg@) Concatenate input values separated by a delimiter. This is
-- 'stringAggWith' with 'AggModeAll' and no @ORDER BY@ clauses.
--
-- @since 2.2.8
stringAgg
    :: SqlString s
    => SqlExpr (Value s)
    -- ^ Input values.
    -> SqlExpr (Value s)
    -- ^ Delimiter.
    -> SqlExpr (Value (Maybe s))
    -- ^ Concatenation.
stringAgg values separator =
    stringAggWith AggModeAll values separator []
-- | (@chr@) Translate the given integer to a character. (Note the result will
-- depend on the character set of your database.)
--
-- @since 2.2.11
chr
    :: SqlString s
    => SqlExpr (Value Int)
    -> SqlExpr (Value s)
chr code = unsafeSqlFunction "chr" code
-- | (@NOW()@) A call to the SQL function @NOW@ with no arguments, typed as a
-- 'UTCTime' value.
now_ :: SqlExpr (Value UTCTime)
now_ =
    unsafeSqlFunction (TLB.fromText "NOW") ()
-- | Insert a record, or apply the given updates when the record already
-- exists.
--
-- The uniqueness constraint to look the record up by is obtained with
-- 'onlyUnique', which is why the record type must have exactly one unique key
-- ('OnlyOneUniqueKey'). The actual work is delegated to 'upsertBy'.
upsert
    ::
    ( MonadIO m
    , PersistEntity record
    , OnlyOneUniqueKey record
    , PersistRecordBackend record SqlBackend
    , IsPersistBackend (PersistEntityBackend record)
    )
    => record
    -- ^ new record to insert
    -> [SqlExpr (Update record)]
    -- ^ updates to perform if the record already exists
    -> R.ReaderT SqlBackend m (Entity record)
    -- ^ the record in the database after the operation
upsert record updates =
    onlyUnique record >>= \uniqueKey -> upsertBy uniqueKey record updates
-- | Insert a record, or apply the given updates when a record matching the
-- given uniqueness constraint already exists.
--
-- The statement text comes from the backend's 'connUpsertSql'. Its parameters
-- are, in order: the fields of the new record, the values of the rendered
-- updates, and the values of the unique key.
--
-- Throws @UnexpectedCaseErr OperationNotSupported@ when the backend provides
-- no 'connUpsertSql', and calls 'error' when the statement returns no rows.
--
-- NOTE(review): an empty update list is passed to the backend unchanged;
-- whether that yields a valid statement depends on 'connUpsertSql' — confirm.
upsertBy
    ::
    ( MonadIO m
    , PersistEntity record
    , IsPersistBackend (PersistEntityBackend record)
    )
    => Unique record
    -- ^ uniqueness constraint to find by
    -> record
    -- ^ new record to insert
    -> [SqlExpr (Update record)]
    -- ^ updates to perform if the record already exists
    -> R.ReaderT SqlBackend m (Entity record)
    -- ^ the record in the database after the operation
upsertBy uniqueKey record updates = do
    sqlB <- R.ask
    case connUpsertSql sqlB of
        Nothing ->
            -- Postgres backend should have connUpsertSql, if this error is
            -- thrown, check changes on persistent
            throw (UnexpectedCaseErr OperationNotSupported)
        Just upsertSql ->
            handler sqlB upsertSql
  where
    addVals l = map toPersistValue (toPersistFields record) ++ l ++ persistUniqueToValues uniqueKey
    entDef = entityDef (Just record)
    uDef = toUniqueDef uniqueKey
    updatesText conn = first builderToText $ renderUpdates conn updates
    handler conn f = do
        rows <-
            uncurry rawSql $
                (f entDef (uDef :| []) *** addVals) $ updatesText conn
        -- Match on the result explicitly instead of using the partial 'head',
        -- so an empty result fails here with a message naming this function.
        case rows of
            entity : _ ->
                return entity
            [] ->
                error
                    "Database.Esqueleto.PostgreSQL.upsertBy: upsert statement returned no rows"
-- | Inserts into a table the results of a query similar to 'insertSelect' but allows
-- to update values that violate a constraint during insertions.
--
-- Example of usage:
--
-- @
-- share [ mkPersist sqlSettings
--       , mkDeleteCascade sqlSettings
--       , mkMigrate "migrate"
--       ] [persistLowerCase|
--   Bar
--     num Int
--     deriving Eq Show
--   Foo
--     num Int
--     UniqueFoo num
--     deriving Eq Show
-- |]
--
-- insertSelectWithConflict
--   UniqueFoo -- (UniqueFoo undefined) or (UniqueFoo anyNumber) would also work
--   (from $ \b ->
--     return $ Foo <# (b ^. BarNum)
--   )
--   (\current excluded ->
--     [FooNum =. (current ^. FooNum) +. (excluded ^. FooNum)]
--   )
-- @
--
-- Inserts to table Foo all Bar.num values and, in case of a conflict on
-- @UniqueFoo@, the conflicting value is updated to the current plus the
-- excluded.
--
-- This is 'insertSelectWithConflictCount' with the row count discarded.
--
-- @since 3.1.3
insertSelectWithConflict
    :: forall a m val
     . (FinalResult a, KnowResult a ~ Unique val, MonadIO m, PersistEntity val)
    => a
    -- ^ Unique constructor or a unique, this is used just to get the name of
    -- the postgres constraint, the value(s) is(are) never used, so if you have
    -- a unique @MyUnique 0@, @MyUnique undefined@ would work as well.
    -> SqlQuery (SqlExpr (Insertion val))
    -- ^ Insert query.
    -> (SqlExpr (Entity val) -> SqlExpr (Entity val) -> [SqlExpr (Update val)])
    -- ^ A list of updates to be applied in case of the constraint being
    -- violated. The expression takes the current and excluded value to produce
    -- the updates.
    -> SqlWriteT m ()
-- Fully applied on purpose: 'SqlWriteT' is a rank-2 synonym, and composing
-- through it point-free is rejected by GHC 9.0+ (simplified subsumption).
insertSelectWithConflict unique query conflictQuery =
    void (insertSelectWithConflictCount unique query conflictQuery)
-- | Same as 'insertSelectWithConflict' but returns the number of rows affected.
--
-- The executed statement is the rendered @INSERT INTO@ query followed by an
-- @ON CONFLICT ON CONSTRAINT \"...\" DO@ clause: @NOTHING@ when the update
-- list is empty, @UPDATE SET ...@ otherwise.
--
-- @since 3.1.3
insertSelectWithConflictCount
    :: forall a val m
     . (FinalResult a, KnowResult a ~ Unique val, MonadIO m, PersistEntity val)
    => a
    -> SqlQuery (SqlExpr (Insertion val))
    -> (SqlExpr (Entity val) -> SqlExpr (Entity val) -> [SqlExpr (Update val)])
    -> SqlWriteT m Int64
insertSelectWithConflictCount unique query conflictQuery = do
    conn <- R.ask
    uncurry rawExecuteCount $
        toRawSql INSERT_INTO (conn, initialIdentState) (fmap EInsertFinal query)
            `joinParts` conflictPart conn
  where
    -- Concatenate two rendered fragments and their parameter lists.
    joinParts (sqlA, valsA) (sqlB, valsB) =
        (builderToText (sqlA <> sqlB), valsA ++ valsB)

    -- The row already in the table is referred to by the table's own name,
    -- the row proposed for insertion by the identifier "excluded".
    entCurrent = EEntity (I (unDBName (entityDB (entityDef (Proxy :: Proxy val)))))
    entExcluded = EEntity (I "excluded")
    updates = conflictQuery entCurrent entExcluded

    -- Only the name of the constraint is taken from the unique.
    constraintName = TLB.fromText (unDBName (uniqueDBName (toUniqueDef unique)))

    renderedUpdates :: (BackendCompatible SqlBackend backend) => backend -> (TLB.Builder, [PersistValue])
    renderedUpdates conn = renderUpdates conn updates

    conflictPart conn =
        (mconcat (onConflict ++ action), updateVals)
      where
        (updatesTLB, updateVals) = renderedUpdates conn
        onConflict =
            [ TLB.fromText "ON CONFLICT ON CONSTRAINT \""
            , constraintName
            , TLB.fromText "\" DO "
            ]
        action
            | null updates = [TLB.fromText "NOTHING"]
            | otherwise = [TLB.fromText "UPDATE SET ", updatesTLB]
-- | Allow aggregate functions to take a filter clause.
--
-- The result renders as the aggregate expression followed by
-- @FILTER (WHERE ...)@ wrapping the clause; the values of the aggregate
-- precede the values of the clause.
--
-- Example of usage:
--
-- @
-- share [mkPersist sqlSettings] [persistLowerCase|
--   User
--     name Text
--     deriving Eq Show
--   Task
--     userId UserId
--     completed Bool
--     deriving Eq Show
-- |]
--
-- select $ from $ \\(users \`InnerJoin\` tasks) -> do
--   on $ users ^. UserId ==. tasks ^. TaskUserId
--   groupBy $ users ^. UserId
--   return
--     ( users ^. UserId
--     , count (tasks ^. TaskId) \`filterWhere\` (tasks ^. TaskCompleted ==. val True)
--     , count (tasks ^. TaskId) \`filterWhere\` (tasks ^. TaskCompleted ==. val False)
--     )
-- @
--
-- Throws a 'CompositeKeyErr' when either argument is a composite key
-- expression.
--
-- @since 3.3.3.3
filterWhere
    :: SqlExpr (Value a)
    -- ^ Aggregate function
    -> SqlExpr (Value Bool)
    -- ^ Filter clause
    -> SqlExpr (Value a)
filterWhere aggExpr clauseExpr = ERaw Never $ \info ->
    let (aggSql, aggParams) =
            case aggExpr of
                ECompositeKey _ -> throw (CompositeKeyErr FilterWhereAggError)
                ERaw _ renderAgg -> renderAgg info
        (clauseSql, clauseParams) =
            case clauseExpr of
                ECompositeKey _ -> throw (CompositeKeyErr FilterWhereClauseError)
                ERaw _ renderClause -> renderClause info
     in ( mconcat [aggSql, " FILTER (WHERE ", clauseSql, ")"]
        , aggParams ++ clauseParams
        )