diff --git a/esqueleto.cabal b/esqueleto.cabal index 7c07394..6974d5e 100644 --- a/esqueleto.cabal +++ b/esqueleto.cabal @@ -31,6 +31,7 @@ library Database.Esqueleto Database.Esqueleto.Experimental Database.Esqueleto.Experimental.Aggregates + Database.Esqueleto.Experimental.WindowFunctions Database.Esqueleto.Internal.Language Database.Esqueleto.Internal.Sql Database.Esqueleto.Internal.Internal @@ -164,7 +165,6 @@ test-suite sqlite Paths_esqueleto hs-source-dirs: test - ghc-options: -Wall build-depends: base >=4.8 && <5.0 , attoparsec diff --git a/src/Database/Esqueleto/Experimental/WindowFunctions.hs b/src/Database/Esqueleto/Experimental/WindowFunctions.hs new file mode 100644 index 0000000..08da94a --- /dev/null +++ b/src/Database/Esqueleto/Experimental/WindowFunctions.hs @@ -0,0 +1,219 @@ +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} + +module Database.Esqueleto.Experimental.WindowFunctions + where + +import Data.Coerce (coerce) +import Data.Int (Int64) +import Data.Semigroup (First(..)) +import qualified Data.Text.Lazy.Builder as TLB +import Database.Esqueleto.Experimental.Aggregates +import Database.Esqueleto.Internal.Internal + ( NeedParens(..) + , SideData(..) + , SqlExpr(..) + , SqlQuery(..) + , SqlSelect(..) + , ToSomeValues(..) + , Value(..) + , noMeta + , select + , unsafeSqlFunction + , (?.) + , (^.) + ) +import Database.Esqueleto.Internal.PersistentImport + ( Entity + , EntityField + , PersistEntity + , PersistField + , PersistValue(..) + , SqlReadT + , fromPersistValue + ) + + +data Window = Window + { windowPartitionBy :: Maybe (First (TLB.Builder, [PersistValue])) + , windowOrderBy :: Maybe (First (TLB.Builder, [PersistValue])) + , windowFrame :: Maybe (First Frame) + } + +class RenderWindow a where + renderWindow :: a -> (TLB.Builder, [PersistValue]) +instance RenderWindow () where + renderWindow = mempty +instance RenderWindow Window where + renderWindow window = + let (p, pVal) = maybe mempty getFirst $ windowPartitionBy window + (o, oVal) = maybe mempty getFirst $ windowOrderBy window + (f, fVal) = maybe mempty (renderWindow . getFirst) (windowFrame window) + in (p <> o <> f, pVal <> oVal <> fVal) + +instance Semigroup Window where + (Window a b c) <> (Window a' b' c') = Window (a <> a') (b <> b') (c <> c') + +instance Monoid Window where + mempty = Window mempty mempty mempty + mappend = (<>) + +data Frame = Frame (Maybe FrameKind) FrameBody (Maybe FrameExclusion) + +frame :: ToFrame frame => frame -> Window +frame f = mempty{windowFrame = Just $ First $ toFrame f} + +instance RenderWindow Frame where + renderWindow (Frame mKind frameBody mExclusion) = + let (kind, kindVals) = maybe mempty renderWindow mKind + (exclusion, exclusionVals) = maybe mempty renderWindow mExclusion + (body, bodyVals) = renderWindow frameBody + in (kind <> body <> exclusion, kindVals <> bodyVals <> exclusionVals) + +class ToFrame a where + toFrame :: a -> Frame +instance ToFrame Frame where + toFrame = id + +newtype FrameKind = FrameKind { unFrameKind :: (TLB.Builder, [PersistValue]) } + +instance RenderWindow FrameKind where + renderWindow = unFrameKind + +frameKind :: ToFrame frame => TLB.Builder -> frame -> Frame +frameKind tlb frame = + let Frame _ b e = toFrame frame + in Frame (Just (FrameKind (tlb <> " ", []))) b e + +range :: ToFrame frame => frame -> Frame +range = frameKind "RANGE" + +rows :: ToFrame frame => frame -> Frame +rows = frameKind "ROWS" + +groups :: ToFrame frame => frame -> Frame +groups = frameKind "GROUPS" + +newtype FrameExclusion = FrameExclusion { unFrameExclusion :: (TLB.Builder, [PersistValue]) } + +instance RenderWindow FrameExclusion where + renderWindow = unFrameExclusion + +frameExclusion :: ToFrame frame => TLB.Builder -> frame -> Frame +frameExclusion tlb frame = + let Frame k b _ = toFrame frame + in Frame k b (Just $ FrameExclusion (" EXCLUDE " <> tlb, [])) + +excludeCurrentRow :: ToFrame frame => frame -> Frame +excludeCurrentRow = frameExclusion "CURRENT ROW" + +excludeGroup :: ToFrame frame => frame -> Frame +excludeGroup = frameExclusion "GROUP" + +excludeTies :: ToFrame frame => frame -> Frame +excludeTies = frameExclusion "TIES" + +excludeNoOthers :: ToFrame frame => frame -> Frame +excludeNoOthers = frameExclusion "NO OTHERS" + +data FrameBody + = FrameStart FrameRange + | FrameBetween FrameRange FrameRange + +instance ToFrame FrameBody where + toFrame b = Frame Nothing b Nothing + +instance RenderWindow FrameBody where + renderWindow (FrameStart (FrameRangeFollowing b)) = renderWindow (FrameBetween FrameRangeCurrentRow (FrameRangeFollowing b)) + renderWindow (FrameStart f) = renderWindow f + renderWindow (FrameBetween startRange endRange) + | startRange > endRange = renderWindow (FrameBetween endRange startRange) + renderWindow (FrameBetween r r') = + let (b, v) = renderWindow r + (b', v') = renderWindow r' + in ("BETWEEN " <> b <> " AND " <> b', v <> v') + +instance ToFrame FrameRange where + toFrame r = Frame Nothing (FrameStart r) Nothing +instance RenderWindow FrameRange where + renderWindow (FrameRangeCurrentRow) = ("CURRENT ROW", []) + renderWindow (FrameRangePreceeding bounds) = renderBounds bounds <> (" PRECEEDING", []) + renderWindow (FrameRangeFollowing bounds) = renderBounds bounds <> (" FOLLOWING", []) + +renderBounds :: FrameRangeBound -> (TLB.Builder, [PersistValue]) +renderBounds (FrameRangeUnbounded) = ("UNBOUNDED", []) +renderBounds (FrameRangeBounded i) = ("?", [PersistInt64 i]) + +data FrameRange + = FrameRangePreceeding FrameRangeBound + | FrameRangeCurrentRow + | FrameRangeFollowing FrameRangeBound + deriving Eq + +instance Ord FrameRange where + FrameRangePreceeding b1 <= FrameRangePreceeding b2 = b1 <= b2 + FrameRangePreceeding _ <= FrameRangeCurrentRow = True + FrameRangePreceeding _ <= FrameRangeFollowing _ = True + FrameRangeCurrentRow <= FrameRangePreceeding _ = False + FrameRangeCurrentRow <= FrameRangeCurrentRow = True + FrameRangeCurrentRow <= FrameRangeFollowing _ = True + FrameRangeFollowing _ <= FrameRangePreceeding _ = False + FrameRangeFollowing _ <= FrameRangeCurrentRow = False + FrameRangeFollowing b1 <= FrameRangeFollowing b2 = b1 <= b2 + +data FrameRangeBound + = FrameRangeUnbounded + | FrameRangeBounded Int64 + deriving Eq + +instance Ord FrameRangeBound where + FrameRangeUnbounded <= FrameRangeBounded _ = False + FrameRangeBounded _ <= FrameRangeUnbounded = True + FrameRangeBounded a <= FrameRangeBounded b = a <= b + +between :: FrameRange -> FrameRange -> FrameBody +between = FrameBetween + +unboundedPreceeding :: FrameRange +unboundedPreceeding = FrameRangePreceeding FrameRangeUnbounded + +preceeding :: Int64 -> FrameRange +preceeding offset = FrameRangePreceeding (FrameRangeBounded offset) + +following :: Int64 -> FrameRange +following offset = FrameRangeFollowing (FrameRangeBounded offset) + +unboundedFollowing :: FrameRange +unboundedFollowing = FrameRangeFollowing FrameRangeUnbounded + +currentRow :: FrameRange +currentRow = FrameRangeCurrentRow + +class Over expr where + over_ :: RenderWindow window => expr a -> window -> SqlExpr (WindowedValue a) + +data WindowedValue a = WindowedValue { unWindowedValue :: a } +instance PersistField a => SqlSelect (SqlExpr (WindowedValue a)) (WindowedValue a) where + sqlSelectCols info expr = sqlSelectCols info (coerce expr :: SqlExpr a) + sqlSelectColCount = const 1 + sqlSelectProcessRow _ [pv] = WindowedValue <$> fromPersistValue pv + sqlSelectProcessRow _ pvs = WindowedValue <$> fromPersistValue (PersistList pvs) + + +newtype WindowExpr a = WindowExpr { unsafeWindowExpr :: SqlExpr a } +instance Over WindowExpr where + (WindowExpr (ERaw _ f)) `over_` window = ERaw noMeta $ \p info -> + let (b, v) = f Never info + (w, vw) = renderWindow window + in (b <> " OVER (" <> w <> ")", v <> vw) + +deriving via WindowExpr instance Over SqlAggregate diff --git a/src/Database/Esqueleto/Internal/Internal.hs b/src/Database/Esqueleto/Internal/Internal.hs index 71e22b9..d265ade 100644 --- a/src/Database/Esqueleto/Internal/Internal.hs +++ b/src/Database/Esqueleto/Internal/Internal.hs @@ -16,6 +16,8 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fno-warn-redundant-constraints #-} -- | This is an internal module, anything exported by this module -- may change without a major version bump. Please use only @@ -25,7 +27,6 @@ -- tracker so we can safely support it. module Database.Esqueleto.Internal.Internal where -import Data.Kind (Constraint) import Control.Applicative ((<|>)) import Data.Coerce (Coercible, coerce) import Control.Arrow (first, (***)) @@ -268,7 +269,7 @@ orderByExpr orderByType (ERaw m f) in uncommas' $ zip (map (<> orderByType) fs) vals | otherwise = ERaw noMeta $ \_ info -> - first (<> orderByType) $ f Never info + first (<> orderByType) $ f Parens info -- | @LIMIT@. Limit the number of returned rows. limit :: Int64 -> SqlQuery () @@ -570,15 +571,15 @@ subSelectUnsafe = sub SELECT ed = entityDef $ getEntityVal $ getProxy field - dot info fieldDef = + dot info fd = sourceIdent info <> "." <> fieldIdent where sourceIdent = fmap fst $ f Never fieldIdent | Just baseI <- sqlExprMetaAlias m = - useIdent info $ aliasedEntityColumnIdent baseI fieldDef + useIdent info $ aliasedEntityColumnIdent baseI fd | otherwise = - fromDBName info (fieldDB fieldDef) + fromDBName info (fieldDB fd) -- | Project an SqlExpression that may be null, guarding against null cases. withNonNull @@ -632,9 +633,7 @@ isNothing v = first (parensM p) . flip (,) [] . (intersperseB " AND " . map (<> " IS NULL")) $ fields info Nothing -> ERaw noMeta $ \p info -> - first (parensM p) . isNullExpr $ f Never info - where - isNullExpr = first (<> " IS NULL") + first (parensM p . (<> " IS NULL")) $ f Never info -- | Analogous to 'Just', promotes a value of type @typ@ into -- one of type @Maybe typ@. It should hold that @'val' . Just @@ -930,7 +929,7 @@ in_ :: PersistField typ => SqlExpr typ -> SqlExpr (ValueList typ) -> SqlExpr Boo if b2 == "()" then ("FALSE", []) else - (b1 <> " IN " <> b2, vals1 <> vals2) + (parensM p (b1 <> " IN " <> b2), vals1 <> vals2) -- | @NOT IN@ operator. notIn :: PersistField typ => SqlExpr typ -> SqlExpr (ValueList typ) -> SqlExpr Bool @@ -938,7 +937,11 @@ notIn :: PersistField typ => SqlExpr typ -> SqlExpr (ValueList typ) -> SqlExpr B ERaw noMeta $ \p info -> let (b1, vals1) = v Parens info (b2, vals2) = list Parens info - in (b1 <> " NOT IN " <> b2, vals1 <> vals2) + in + if b2 == "()" then + ("FALSE", []) + else + (parensM p (b1 <> " NOT IN " <> b2), vals1 <> vals2) -- | @EXISTS@ operator. For example: -- @@ -953,14 +956,14 @@ notIn :: PersistField typ => SqlExpr typ -> SqlExpr (ValueList typ) -> SqlExpr B exists :: SqlQuery () -> SqlExpr Bool exists q = ERaw noMeta $ \p info -> let ERaw _ f = existsHelper q - (b, vals) = f Never info + (b, vals) = f Parens info in ( parensM p $ "EXISTS " <> b, vals) -- | @NOT EXISTS@ operator. notExists :: SqlQuery () -> SqlExpr Bool notExists q = ERaw noMeta $ \p info -> let ERaw _ f = existsHelper q - (b, vals) = f Never info + (b, vals) = f Parens info in ( parensM p $ "NOT EXISTS " <> b, vals) -- | @SET@ clause used on @UPDATE@s. Note that while it's not @@ -2107,7 +2110,7 @@ setAux field value = \ent -> ERaw noMeta $ \_ info -> in (fieldName info field <> " = " <> valueToSet, valueVals) sub :: (SqlSelect (SqlExpr a) r, PersistField a) => Mode -> SqlQuery (SqlExpr a) -> SqlExpr a -sub mode query = ERaw noMeta $ \_ info -> first parens $ toRawSql mode info query +sub mode query = ERaw noMeta $ \p info -> first (parensM p) $ toRawSql mode info query fromDBName :: IdentInfo -> DBName -> TLB.Builder fromDBName (conn, _) = TLB.fromText . connEscapeName conn @@ -2254,7 +2257,7 @@ unsafeSqlFunction :: UnsafeSqlFunctionArgument a => TLB.Builder -> a -> SqlExpr b unsafeSqlFunction name arg = - ERaw noMeta $ \p info -> + ERaw noMeta $ \_ info -> let (argsTLB, argsVals) = uncommas' $ map (valueToFunctionArg info) $ toArgList arg in @@ -2280,8 +2283,8 @@ unsafeSqlFunctionParens :: UnsafeSqlFunctionArgument a => TLB.Builder -> a -> SqlExpr b unsafeSqlFunctionParens name arg = - ERaw noMeta $ \p info -> - let valueToFunctionArgParens (ERaw _ f) = f Never info + ERaw noMeta $ \_ info -> + let valueToFunctionArgParens (ERaw _ f) = f Parens info (argsTLB, argsVals) = uncommas' $ map valueToFunctionArgParens $ toArgList arg in diff --git a/test/Common/Test.hs b/test/Common/Test.hs index 6bfc967..f79f928 100644 --- a/test/Common/Test.hs +++ b/test/Common/Test.hs @@ -22,6 +22,7 @@ {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fno-warn-unused-binds #-} +{-# OPTIONS_GHC -fno-warn-unused-imports #-} {-# OPTIONS_GHC -fno-warn-deprecations #-} module Common.Test ( tests