Skip to content

Commit

Permalink
Instance simplifications
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentRDC committed Jan 21, 2025
1 parent b7ec49d commit ffb326e
Showing 1 changed file with 18 additions and 20 deletions.
38 changes: 18 additions & 20 deletions javelin-frames/src/Data/Frame.hs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ import Data.Functor.Identity (Identity(..))
import Data.Kind (Type)
import Data.Vector (Vector)
import qualified Data.Vector
import GHC.Generics hiding (Selector)
import GHC.Generics ( Generic(..), K1(..), M1(..), type (:*:)(..) )


-- | Type family which allows for higher-kinded record types
Expand All @@ -97,12 +97,10 @@ type family Column (f :: Type -> Type) a where
Column Vector x = Vector x

-- | Type synonym for a record type with scalar elements
type Row (dt :: (Type -> Type) -> Type)
= (dt Identity)
type Row (dt :: (Type -> Type) -> Type) = dt Identity

-- | Type synonym for a record type whose elements are arrays (columns)
type Frame (dt :: (Type -> Type) -> Type)
= (dt Vector)
type Frame (dt :: (Type -> Type) -> Type) = dt Vector


-- | Typeclass to generically derive the function `fromRows`.
Expand All @@ -125,7 +123,7 @@ instance GFromRows tI tV => GFromRows (M1 i c tI) (M1 i c tV) where
class GToRows tI tV where
gtoRows :: tV a -> Vector (tI a)

instance (v ~ Vector a) => GToRows (K1 i a) (K1 i v) where
instance GToRows (K1 i a) (K1 i (Vector a)) where
gtoRows = Data.Vector.map K1 . unK1

instance (GToRows tI1 tV1, GToRows tI2 tV2)
Expand All @@ -149,9 +147,9 @@ class Frameable t where
-- To convert a @`Frame` t@ to rows, see `toRows`
fromRows :: Vector (Row t) -> Frame t

default fromRows :: ( Generic (t Identity)
, Generic (t Vector)
, GFromRows (Rep (t Identity)) (Rep (t Vector))
default fromRows :: ( Generic (Row t)
, Generic (Frame t)
, GFromRows (Rep (Row t)) (Rep (Frame t))
)
=> Vector (Row t)
-> Frame t
Expand All @@ -162,7 +160,7 @@ class Frameable t where

default toRows :: ( Generic (t Identity)
, Generic (t Vector)
, GToRows (Rep (t Identity)) (Rep (t Vector))
, GToRows (Rep (Row t)) (Rep (Frame t))
)
=> Frame t
-> Vector (Row t)
Expand All @@ -171,9 +169,9 @@ class Frameable t where

-- | Map a function over each row individually.
mapFrame :: (Frameable t1, Frameable t2)
=> (Row t1 -> Row t2)
-> Frame t1
-> Frame t2
=> (Row t1 -> Row t2)
-> Frame t1
-> Frame t2
mapFrame f = fromRows
. Data.Vector.map f
. toRows
Expand All @@ -182,9 +180,9 @@ mapFrame f = fromRows
-- | Filter rows from a @`Frame` t@, only keeping
-- the rows where the predicate is `True`.
filterFrame :: (Frameable t)
=> (Row t -> Bool)
-> Frame t
-> Frame t
=> (Row t -> Bool)
-> Frame t
-> Frame t
filterFrame f = fromRows
. Data.Vector.filter f
. toRows
Expand All @@ -208,9 +206,9 @@ zipFramesWith f xs ys

-- | Left-associative fold of a structure but with strict application of the operator.
foldlFrame :: Frameable t
=> (b -> Row t -> b) -- ^ Reduction function that takes in individual rows
-> b -- ^ Initial value for the accumulator
-> Frame t -- ^ Data frame
-> b
=> (b -> Row t -> b) -- ^ Reduction function that takes in individual rows
-> b -- ^ Initial value for the accumulator
-> Frame t -- ^ Data frame
-> b
foldlFrame f start
= Data.Vector.foldl' f start . toRows

0 comments on commit ffb326e

Please sign in to comment.