diff --git a/app/Main.hs b/app/Main.hs index ce4280c3..828cdb71 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -62,13 +62,12 @@ compatibleHorusCompileVersionLower = makeVersion [0, 0, 6, 8] compatibleHorusCompileVersionHigher :: Version compatibleHorusCompileVersionHigher = makeVersion [0, 0, 7] -{- | The main entrypoint of everything that happens in our monad stack. - The contract is a 1-1 representation of the data in the compiled JSON file. - The contract is then used to create a 'ContractInfo' which is a more - convenient representation of the contract. - We run `solveContract`, which is the entrypoint into the *rest* of the - program, and gather the results for pretty-printing. --} +-- | The main entrypoint of everything that happens in our monad stack. +-- The contract is a 1-1 representation of the data in the compiled JSON file. +-- The contract is then used to create a 'ContractInfo' which is a more +-- convenient representation of the contract. +-- We run `solveContract`, which is the entrypoint into the *rest* of the +-- program, and gather the results for pretty-printing. main' :: Arguments -> FilePath -> FilePath -> EIO () main' Arguments{..} filename specFileName = do contract <- eioDecodeFileStrict filename specFileName <&> cdSpecs %~ (<> stdSpecs) @@ -81,30 +80,30 @@ main' Arguments{..} filename specFileName = do TextIO.putStrLn (ppSolvingInfo si) let unknowns = [res | res@(Timeout{}) <- map si_result infos] unless (null unknowns) $ liftIO (TextIO.putStrLn hint') - where - hint' = "\ESC[33m" <> (T.strip . T.unlines . map ("hint: " <>) . T.lines) hint <> "\ESC[0m" - guardVersion :: ContractDefinition -> EIO () - guardVersion cd = do - compilerVersion <- case [ x - | x@(_, lst) <- readP_to_S parseVersion (cd_version cd) - , null lst - ] of - [(v, [])] -> pure v - _ -> fail $ "Wrong version format: " <> cd_version cd - when - ( compilerVersion < compatibleHorusCompileVersionLower - || compilerVersion >= compatibleHorusCompileVersionHigher - ) - . fail - . concat - $ [ "The *.json on input has been compiled with an incompatible version of Horus-compile.\nExpected: " - , ">=" - , showVersion compatibleHorusCompileVersionLower - , ", <" - , showVersion compatibleHorusCompileVersionHigher - , " but got: " - , showVersion compilerVersion - ] + where + hint' = "\ESC[33m" <> (T.strip . T.unlines . map ("hint: " <>) . T.lines) hint <> "\ESC[0m" + guardVersion :: ContractDefinition -> EIO () + guardVersion cd = do + compilerVersion <- case [ x + | x@(_, lst) <- readP_to_S parseVersion (cd_version cd) + , null lst + ] of + [(v, [])] -> pure v + _ -> fail $ "Wrong version format: " <> cd_version cd + when + ( compilerVersion < compatibleHorusCompileVersionLower + || compilerVersion >= compatibleHorusCompileVersionHigher + ) + . fail + . concat + $ [ "The *.json on input has been compiled with an incompatible version of Horus-compile.\nExpected: " + , ">=" + , showVersion compatibleHorusCompileVersionLower + , ", <" + , showVersion compatibleHorusCompileVersionHigher + , " but got: " + , showVersion compilerVersion + ] eioDecodeFileStrict :: FromJSON a => FilePath -> FilePath -> EIO a eioDecodeFileStrict contractFile specFile = do @@ -122,17 +121,16 @@ eioDecodeFileStrict contractFile specFile = do ppSolvingInfo :: SolvingInfo -> Text ppSolvingInfo SolvingInfo{..} = si_moduleName <> inlinedIndicator <> "\n" <> tShow si_result <> "\n" - where - inlinedIndicator = if si_inlinable then " [inlined]" else "" + where + inlinedIndicator = if si_inlinable then " [inlined]" else "" -{- | Main entrypoint of the program. - Cases - ===== - 1. No arguments are passed. In this case, we print the help message. - 2. The `--version` flag is passed. In this case, we print the version number. - 3. No file is passed. In this case, we print an error. - 4. A file is passed. In this case, we run `main'`. --} +-- | Main entrypoint of the program. +-- Cases +-- ===== +-- 1. No arguments are passed. In this case, we print the help message. +-- 2. The `--version` flag is passed. In this case, we print the version number. +-- 3. No file is passed. In this case, we print an error. +-- 4. A file is passed. In this case, we run `main'`. main :: IO () main = do TextIO.putStrLn issuesMsg' @@ -153,26 +151,26 @@ main = do (_, Nothing) -> putStrLn "Missing specification JSON file. Use --help for more information." (Just filename, Just specFileName) -> do runExceptT (main' arguments filename specFileName) >>= either (fail . T.unpack) pure - where - issuesMsg' = - "\ESC[33m" <> (T.strip . T.unlines . map ("Warning: " <>) . T.lines) issuesMsg <> "\ESC[0m\n" - opts = - info - (argParser <**> helper) - ( fullDesc - <> progDescDoc - ( Just $ - text "Verifies " - <> text (T.unpack fileArgument) - <> text " (a JSON contract compiled with horus-compile) with the specification file " - <> text (T.unpack specFileArgument) - <> text " provided by horus-compile \n\n" - <> text "Example using solver cvc5 (default):\n" - <> text " $ horus-check a.json spec.json\n\n" - <> text "Example using solver mathsat:\n" - <> text " $ horus-check -s mathsat a.json spec.json\n\n" - <> text "Example using solvers z3, mathsat:\n" - <> text " $ horus-check -s z3 -s mathsat a.json spec.json\n" - ) - <> header "horus-check: an SMT-based checker for StarkNet contracts" - ) + where + issuesMsg' = + "\ESC[33m" <> (T.strip . T.unlines . map ("Warning: " <>) . T.lines) issuesMsg <> "\ESC[0m\n" + opts = + info + (argParser <**> helper) + ( fullDesc + <> progDescDoc + ( Just $ + text "Verifies " + <> text (T.unpack fileArgument) + <> text " (a JSON contract compiled with horus-compile) with the specification file " + <> text (T.unpack specFileArgument) + <> text " provided by horus-compile \n\n" + <> text "Example using solver cvc5 (default):\n" + <> text " $ horus-check a.json spec.json\n\n" + <> text "Example using solver mathsat:\n" + <> text " $ horus-check -s mathsat a.json spec.json\n\n" + <> text "Example using solvers z3, mathsat:\n" + <> text " $ horus-check -s z3 -s mathsat a.json spec.json\n" + ) + <> header "horus-check: an SMT-based checker for StarkNet contracts" + ) diff --git a/fourmolu.yaml b/fourmolu.yaml index bb9ee55a..69240dfa 100644 --- a/fourmolu.yaml +++ b/fourmolu.yaml @@ -1,10 +1,12 @@ indentation: 2 -comma-style: leading # for lists, tuples etc. - can also be 'trailing' -import-export-comma-style: leading # for module import export lists - can also be 'trailing' -record-brace-space: false # rec {x = 1} vs. rec{x = 1} -indent-wheres: false # 'false' means save space by only half-indenting the 'where' keyword -diff-friendly-import-export: false # 'false' uses Ormolu-style lists -respectful: true # don't be too opinionated about newlines etc. -haddock-style: multi-line # '--' vs. '{-' -newlines-between-decls: 1 # number of newlines between top-level declarations -fixities: [] # fixity information, see the section about fixities below. +comma-style: leading # for lists, tuples etc. - can also be 'trailing' +import-export-style: leading # for module import export lists - can also be 'trailing' +record-brace-space: false # rec {x = 1} vs. rec{x = 1} +indent-wheres: true # 'false' means save space by only half-indenting the 'where' keyword +respectful: true # don't be too opinionated about newlines etc. +haddock-style: single-line # '--' vs. '{-' +newlines-between-decls: 1 # number of newlines between top-level declarations +fixities: [] # fixity information, see the section about fixities below. +column-limit: 100 +function-arrows: leading +single-constraint-parens: never diff --git a/src/Horus/Arguments.hs b/src/Horus/Arguments.hs index 76065e36..945d67ea 100644 --- a/src/Horus/Arguments.hs +++ b/src/Horus/Arguments.hs @@ -12,7 +12,14 @@ import Data.Text (Text, unpack) import Options.Applicative import Horus.Global (Config (..)) -import Horus.Preprocessor.Solvers (MultiSolver (..), SingleSolver, SolverSettings (..), cvc5, mathsat, z3) +import Horus.Preprocessor.Solvers + ( MultiSolver (..) + , SingleSolver + , SolverSettings (..) + , cvc5 + , mathsat + , z3 + ) data Arguments = Arguments { arg_fileName :: Maybe FilePath @@ -55,7 +62,9 @@ singleSolverParser = ( long "solver" <> short 's' <> metavar "SOLVER" - <> help ("Solver to check the resulting smt queries (options: " <> intercalate ", " singleSolverNames <> ").") + <> help + ( "Solver to check the resulting smt queries (options: " <> intercalate ", " singleSolverNames <> ")." + ) <> completeWith singleSolverNames ) diff --git a/src/Horus/CFGBuild.hs b/src/Horus/CFGBuild.hs index 4899a61c..3c08af44 100644 --- a/src/Horus/CFGBuild.hs +++ b/src/Horus/CFGBuild.hs @@ -152,14 +152,13 @@ getSvarSpecs = liftF' (GetSvarSpecs id) getVerts :: Label -> CFGBuildL [Vertex] getVerts l = liftF' (GetVerts l id) -{- | Salient vertices can be thought of as 'main' vertices of the CFG, meaning that -if one wants to reason about flow control of the program, one should query salient vertices. - -Certain program transformations and optimisations can add various additional nodes into the CFG, -whose primary purpose is not to reason about control flow. - -It is enforced that for any one PC, one can add at most a single salient vertex. --} +-- | Salient vertices can be thought of as 'main' vertices of the CFG, meaning that +-- if one wants to reason about flow control of the program, one should query salient vertices. +-- +-- Certain program transformations and optimisations can add various additional nodes into the CFG, +-- whose primary purpose is not to reason about control flow. +-- +-- It is enforced that for any one PC, one can add at most a single salient vertex. getSalientVertex :: Label -> CFGBuildL Vertex getSalientVertex l = do verts <- filter (not . isPreCheckingVertex) <$> getVerts l @@ -200,31 +199,29 @@ buildFrame inlinables rows prog = do segmentsWithVerts <- for segments $ \s -> addVertex (segmentLabel s) <&> (s,) for_ segmentsWithVerts . uncurry $ addArcsFrom inlinables prog rows -{- | A simple procedure for splitting a stream of instructions into nonempty Segments based -on program labels, which more-or-less correspond with changes in control flow in the program. -We thus obtain linear segments of instructions without control flow. --} +-- | A simple procedure for splitting a stream of instructions into nonempty Segments based +-- on program labels, which more-or-less correspond with changes in control flow in the program. +-- We thus obtain linear segments of instructions without control flow. breakIntoSegments :: [Label] -> [LabeledInst] -> [Segment] breakIntoSegments _ [] = [] breakIntoSegments ls_ (i_ : is_) = coerce (go [] (i_ :| []) ls_ is_) - where - go gAcc lAcc [] rest = reverse (NonEmpty.reverse lAcc `appendList` rest : gAcc) - go gAcc lAcc (_ : _) [] = reverse (NonEmpty.reverse lAcc : gAcc) - go gAcc lAcc (l : ls) (i@(pc, _) : is) - | l < pc = go gAcc lAcc ls (i : is) - | l == pc = go (NonEmpty.reverse lAcc : gAcc) (i :| []) ls is - | otherwise = go gAcc (i NonEmpty.<| lAcc) (l : ls) is + where + go gAcc lAcc [] rest = reverse (NonEmpty.reverse lAcc `appendList` rest : gAcc) + go gAcc lAcc (_ : _) [] = reverse (NonEmpty.reverse lAcc : gAcc) + go gAcc lAcc (l : ls) (i@(pc, _) : is) + | l < pc = go gAcc lAcc ls (i : is) + | l == pc = go (NonEmpty.reverse lAcc : gAcc) (i :| []) ls is + | otherwise = go gAcc (i NonEmpty.<| lAcc) (l : ls) is addArc' :: Vertex -> Vertex -> [LabeledInst] -> CFGBuildL () addArc' lFrom lTo insts = addArc lFrom lTo insts ACNone Nothing -{- | This function adds arcs (edges) into the CFG and labels them with instructions that are -to be executed when traversing from one vertex to another. - -Currently, we do not have an optimisation post-processing pass in Horus and we therefore -also include an optimisation here that generates an extra vertex in order to implement -separate checking of preconditions for abstracted functions. --} +-- | This function adds arcs (edges) into the CFG and labels them with instructions that are +-- to be executed when traversing from one vertex to another. +-- +-- Currently, we do not have an optimisation post-processing pass in Horus and we therefore +-- also include an optimisation here that generates an extra vertex in order to implement +-- separate checking of preconditions for abstracted functions. addArcsFrom :: Set ScopedFunction -> Program -> [LabeledInst] -> Segment -> Vertex -> CFGBuildL () addArcsFrom inlinables prog rows seg@(Segment s) vFrom | Call <- i_opCode endInst = @@ -255,67 +252,67 @@ addArcsFrom inlinables prog rows seg@(Segment s) vFrom | otherwise = do lTo <- getSalientVertex $ nextSegmentLabel seg addArc' vFrom lTo insts - where - lInst@(endPc, endInst) = NonEmpty.last s - insts = segmentInsts seg - inlinableLabels = Set.map sf_pc inlinables - - callee = uncheckedScopedFOfPc (p_identifiers prog) (uncheckedCallDestination lInst) - - beginInlining = do - salientCalleeV <- getSalientVertex (sf_pc callee) - addArc vFrom salientCalleeV insts ACNone . Just $ ArcCall endPc (sf_pc callee) - - optimiseCheckingOfPre = do - -- Suppose F calls G where G has a precondition. We synthesize an extra module - -- Pre(F) -> Pre(G) to check whether Pre(G) holds. The standard module for F - -- is then Pre(F) -> Post(F) (conceptually, unless there's a split in the middle, etc.), - -- in which Pre(G) is assumed to hold at the PC of the G call site, as it will have - -- been checked by the module induced by the ghost vertex. - ghostV <- addOptimisingVertex (nextSegmentLabel seg) callee - pre <- maybe (mkPre Expr.True) mkPre . fs'_pre <$> getFuncSpec callee - - -- Important note on the way we deal with logical variables. These are @declare-d and - -- their values can be bound in preconditions. They generate existentials which only occur - -- in our models here and require special treatment, in addition to being somewhat - -- difficult for SMT checkers to deal with. - - -- First note that these preconditions now become optimising-module postconditions. - -- We existentially quantify all logical variables present in the expression, thus in the - -- following example: - -- func foo: - -- call bar // where bar refers to $my_logical_var - -- We get an optimising module along the lines of: - -- Pre(foo) -> Pre(bar) where Pre(bar) contains \exists my_logical_var, ... - -- We can then check whether this instantiation exists in the optimising module exclusively. - -- The module that then considers that pre holds as a fact now has the luxury of not having - -- to deal with existential quantifiers, as it can simply 'declare' them as free variables. - addAssertion ghostV $ quantifyEx pre - addArc' vFrom ghostV insts - - abstractOver = do - salientLinearV <- getSalientVertex (nextSegmentLabel seg) - addArc' vFrom salientLinearV insts - svarSpecs <- getSvarSpecs - when (sf_scopedName callee `Set.notMember` svarSpecs) optimiseCheckingOfPre - - addRetArc :: Label -> CFGBuildL () - addRetArc pc = do - retV <- getSalientVertex endPc - pastRet <- getSalientVertex pc - addArc retV pastRet [(endPc, endInst)] ACNone $ Just ArcRet - - addRetArcs :: Label -> CFGBuildL () - addRetArcs owner - | owner `Set.notMember` inlinableLabels = pure () - | otherwise = forM_ returnAddrs addRetArc - where - returnAddrs = map (`moveLabel` sizeOfCall) (callersOf rows owner) - - quantifyEx :: (AnnotationType, Expr 'TBool) -> (AnnotationType, Expr 'TBool) - quantifyEx = second $ \expr -> - let lvars = gatherLogicalVariables expr - in foldr Expr.ExistsFelt expr lvars + where + lInst@(endPc, endInst) = NonEmpty.last s + insts = segmentInsts seg + inlinableLabels = Set.map sf_pc inlinables + + callee = uncheckedScopedFOfPc (p_identifiers prog) (uncheckedCallDestination lInst) + + beginInlining = do + salientCalleeV <- getSalientVertex (sf_pc callee) + addArc vFrom salientCalleeV insts ACNone . Just $ ArcCall endPc (sf_pc callee) + + optimiseCheckingOfPre = do + -- Suppose F calls G where G has a precondition. We synthesize an extra module + -- Pre(F) -> Pre(G) to check whether Pre(G) holds. The standard module for F + -- is then Pre(F) -> Post(F) (conceptually, unless there's a split in the middle, etc.), + -- in which Pre(G) is assumed to hold at the PC of the G call site, as it will have + -- been checked by the module induced by the ghost vertex. + ghostV <- addOptimisingVertex (nextSegmentLabel seg) callee + pre <- maybe (mkPre Expr.True) mkPre . fs'_pre <$> getFuncSpec callee + + -- Important note on the way we deal with logical variables. These are @declare-d and + -- their values can be bound in preconditions. They generate existentials which only occur + -- in our models here and require special treatment, in addition to being somewhat + -- difficult for SMT checkers to deal with. + + -- First note that these preconditions now become optimising-module postconditions. + -- We existentially quantify all logical variables present in the expression, thus in the + -- following example: + -- func foo: + -- call bar // where bar refers to $my_logical_var + -- We get an optimising module along the lines of: + -- Pre(foo) -> Pre(bar) where Pre(bar) contains \exists my_logical_var, ... + -- We can then check whether this instantiation exists in the optimising module exclusively. + -- The module that then considers that pre holds as a fact now has the luxury of not having + -- to deal with existential quantifiers, as it can simply 'declare' them as free variables. + addAssertion ghostV $ quantifyEx pre + addArc' vFrom ghostV insts + + abstractOver = do + salientLinearV <- getSalientVertex (nextSegmentLabel seg) + addArc' vFrom salientLinearV insts + svarSpecs <- getSvarSpecs + when (sf_scopedName callee `Set.notMember` svarSpecs) optimiseCheckingOfPre + + addRetArc :: Label -> CFGBuildL () + addRetArc pc = do + retV <- getSalientVertex endPc + pastRet <- getSalientVertex pc + addArc retV pastRet [(endPc, endInst)] ACNone $ Just ArcRet + + addRetArcs :: Label -> CFGBuildL () + addRetArcs owner + | owner `Set.notMember` inlinableLabels = pure () + | otherwise = forM_ returnAddrs addRetArc + where + returnAddrs = map (`moveLabel` sizeOfCall) (callersOf rows owner) + + quantifyEx :: (AnnotationType, Expr 'TBool) -> (AnnotationType, Expr 'TBool) + quantifyEx = second $ \expr -> + let lvars = gatherLogicalVariables expr + in foldr Expr.ExistsFelt expr lvars -- | This function labels appropriate vertices (at 'ret'urns) with their respective postconditions. addAssertions :: Set ScopedFunction -> Identifiers -> CFGBuildL () diff --git a/src/Horus/CFGBuild/Runner.hs b/src/Horus/CFGBuild/Runner.hs index 7951c0ef..94c83063 100644 --- a/src/Horus/CFGBuild/Runner.hs +++ b/src/Horus/CFGBuild/Runner.hs @@ -20,19 +20,27 @@ import Lens.Micro (Lens', at, (&), (^.), _Just) import Lens.Micro.GHC () import Lens.Micro.Mtl ((%=), (<%=)) -import Horus.CFGBuild (AnnotationType, ArcCondition (..), CFGBuildF (..), CFGBuildL (..), Label, LabeledInst, Vertex (..), isPreCheckingVertex) +import Horus.CFGBuild + ( AnnotationType + , ArcCondition (..) + , CFGBuildF (..) + , CFGBuildL (..) + , Label + , LabeledInst + , Vertex (..) + , isPreCheckingVertex + ) import Horus.ContractInfo (ContractInfo (..)) import Horus.Expr (Expr, Ty (..)) import Horus.FunctionAnalysis (FInfo) type Impl = ReaderT ContractInfo (ExceptT Text (State CFG)) -{- | This represents a quasi Control Flow Graph. - -Normally, they store instructions in nodes and edges represent flow control / jumps. -In our case, we store instructions in edges and nodes represent points of program -with associated logical assertions - preconditions, postconditions and invariants. --} +-- | This represents a quasi Control Flow Graph. +-- +-- Normally, they store instructions in nodes and edges represent flow control / jumps. +-- In our case, we store instructions in edges and nodes represent points of program +-- with associated logical assertions - preconditions, postconditions and invariants. data CFG = CFG { cfg_vertices :: [Vertex] , cfg_arcs :: Map Vertex [(Vertex, [LabeledInst], ArcCondition, FInfo)] @@ -61,44 +69,44 @@ verticesLabelledBy cfg l = [v | v <- cfg_vertices cfg, v_label v == l] interpret :: CFGBuildL a -> Impl a interpret = iterM exec . runCFGBuildL - where - exec (AddVertex l mbPreCheckedF cont) = do - freshVal <- cfgVertexCounter <%= succ - let newVertex = Vertex (Text.pack (show freshVal)) l mbPreCheckedF - vs <- gets cfg_vertices - -- Currently, the design is such that it is convenient to be able to distinguish - -- 'the unique vertex the entire codebase relies on' from vertices that exist - -- with the same label for one reason or the other, e.g. optimisation purposes. - -- Ideally, vertices would be treated uniformally, regardless of their raison d'etre, - -- removing the need for enforcing invariants like this. - if (not . isPreCheckingVertex) newVertex - && (not . null) [vert | vert <- vs, v_label vert == l, (not . isPreCheckingVertex) vert] - then throwError "At most one salient Vertex is allowed per PC." - else cfgVertices %= ([newVertex] `union`) >> cont newVertex - exec (AddArc lFrom lTo insts test isF cont) = cfgArcs . at lFrom %= doAdd >> cont - where - doAdd mArcs = Just ((lTo, insts, test, isF) : mArcs ^. _Just) - exec (AddAssertion l assertion cont) = cfgAssertions . at l %= doAdd >> cont - where - doAdd mAssertions = Just (assertion : mAssertions ^. _Just) - exec (AskIdentifiers cont) = asks ci_identifiers >>= cont - exec (AskProgram cont) = asks ci_program >>= cont - exec (GetFuncSpec name cont) = do - ci <- ask - ci_getFuncSpec ci name & cont - exec (GetInvariant name cont) = do - ci <- ask - ci_getInvariant ci name & cont - exec (GetRets name cont) = do - ci <- ask - ci_getRets ci name >>= cont - exec (GetSvarSpecs cont) = - asks ci_svarSpecs >>= cont - exec (GetVerts l cont) = do - cfg <- get - cont $ verticesLabelledBy cfg l - exec (Throw t) = throwError t - exec (Catch m handler cont) = catchError (interpret m) (interpret . handler) >>= cont + where + exec (AddVertex l mbPreCheckedF cont) = do + freshVal <- cfgVertexCounter <%= succ + let newVertex = Vertex (Text.pack (show freshVal)) l mbPreCheckedF + vs <- gets cfg_vertices + -- Currently, the design is such that it is convenient to be able to distinguish + -- 'the unique vertex the entire codebase relies on' from vertices that exist + -- with the same label for one reason or the other, e.g. optimisation purposes. + -- Ideally, vertices would be treated uniformally, regardless of their raison d'etre, + -- removing the need for enforcing invariants like this. + if (not . isPreCheckingVertex) newVertex + && (not . null) [vert | vert <- vs, v_label vert == l, (not . isPreCheckingVertex) vert] + then throwError "At most one salient Vertex is allowed per PC." + else cfgVertices %= ([newVertex] `union`) >> cont newVertex + exec (AddArc lFrom lTo insts test isF cont) = cfgArcs . at lFrom %= doAdd >> cont + where + doAdd mArcs = Just ((lTo, insts, test, isF) : mArcs ^. _Just) + exec (AddAssertion l assertion cont) = cfgAssertions . at l %= doAdd >> cont + where + doAdd mAssertions = Just (assertion : mAssertions ^. _Just) + exec (AskIdentifiers cont) = asks ci_identifiers >>= cont + exec (AskProgram cont) = asks ci_program >>= cont + exec (GetFuncSpec name cont) = do + ci <- ask + ci_getFuncSpec ci name & cont + exec (GetInvariant name cont) = do + ci <- ask + ci_getInvariant ci name & cont + exec (GetRets name cont) = do + ci <- ask + ci_getRets ci name >>= cont + exec (GetSvarSpecs cont) = + asks ci_svarSpecs >>= cont + exec (GetVerts l cont) = do + cfg <- get + cont $ verticesLabelledBy cfg l + exec (Throw t) = throwError t + exec (Catch m handler cont) = catchError (interpret m) (interpret . handler) >>= cont runImpl :: ContractInfo -> Impl a -> Either Text CFG runImpl contractInfo m = do diff --git a/src/Horus/CairoSemantics.hs b/src/Horus/CairoSemantics.hs index a3f58856..0c6b2102 100644 --- a/src/Horus/CairoSemantics.hs +++ b/src/Horus/CairoSemantics.hs @@ -145,10 +145,9 @@ throw t = liftF (Throw t) enableStorage :: CairoSemanticsL () enableStorage = liftF (EnableStorage ()) -{- | Get an expression for the value of a storage variable with certain - arguments given a value of type `Storage`, which represents the state of all - storage variables during program execution at some specific point in time. --} +-- | Get an expression for the value of a storage variable with certain +-- arguments given a value of type `Storage`, which represents the state of all +-- storage variables during program execution at some specific point in time. readStorage :: Maybe Storage -> ScopedName -> [Expr TFelt] -> CairoSemanticsL (Expr TFelt) readStorage storage name args = liftF (ReadStorage storage name args id) @@ -172,10 +171,10 @@ expect a = expect' =<< memoryRemoval a memoryRemoval :: Expr a -> CairoSemanticsL (Expr a) memoryRemoval = Expr.transform step - where - step :: Expr b -> CairoSemanticsL (Expr b) - step (Memory x) = declareMem x - step e = pure e + where + step :: Expr b -> CairoSemanticsL (Expr b) + step (Memory x) = declareMem x + step e = pure e isInlinable :: ScopedFunction -> CairoSemanticsL Bool isInlinable f = liftF (IsInlinable f id) @@ -201,52 +200,50 @@ top = liftF (Top id) storageRemoval :: Expr a -> CairoSemanticsL (Expr a) storageRemoval = storageRemoval' Nothing -{- | Substitute a reference to a storage variable in an expression with its - value according to `storage :: Storage`. - - For example, suppose we have a storage variable called `state() : felt`. If - we reference this storage variable in the precondition for some function - `f`, for example in `// @pre state() == 5`, then when constructing the - assertions to represent this constraint, we must replace the symbolic name - `state()` in this expression with an expression for the actual value of the - storage variable just before the function `f` is called. - - This substitution is what `storageRemoval'` does, and it does it with - respect to the argument `storage :: Maybe Storage`, which represents the - state of all storage variables during program execution at a particular - point in time. - - Some better names: `resolveStorageReferences`, `resolveStorage`, - `expandStorageExpressions`, `substituteStorage`, or `dereferenceStorage`. --} +-- | Substitute a reference to a storage variable in an expression with its +-- value according to `storage :: Storage`. +-- +-- For example, suppose we have a storage variable called `state() : felt`. If +-- we reference this storage variable in the precondition for some function +-- `f`, for example in `// @pre state() == 5`, then when constructing the +-- assertions to represent this constraint, we must replace the symbolic name +-- `state()` in this expression with an expression for the actual value of the +-- storage variable just before the function `f` is called. +-- +-- This substitution is what `storageRemoval'` does, and it does it with +-- respect to the argument `storage :: Maybe Storage`, which represents the +-- state of all storage variables during program execution at a particular +-- point in time. +-- +-- Some better names: `resolveStorageReferences`, `resolveStorage`, +-- `expandStorageExpressions`, `substituteStorage`, or `dereferenceStorage`. storageRemoval' :: Maybe Storage -> Expr a -> CairoSemanticsL (Expr a) storageRemoval' storage = Expr.transform step - where - step :: Expr b -> CairoSemanticsL (Expr b) - step (StorageVar name args) = readStorage storage (ScopedName.fromText name) args - step e = pure e + where + step :: Expr b -> CairoSemanticsL (Expr b) + step (StorageVar name args) = readStorage storage (ScopedName.fromText name) args + step e = pure e substitute :: Text -> Expr TFelt -> Expr a -> Expr a substitute what forWhat = Expr.canonicalize . Expr.transformId step - where - step :: Expr b -> Expr b - step (Expr.cast @TFelt -> CastOk (Expr.Fun name)) | name == what = forWhat - step (Expr.cast @TBool -> CastOk (Expr.ExistsFelt name expr)) = Expr.ExistsFelt name (substitute what forWhat expr) - step e = e - -{- | Prepare the expression for usage in the model. - -That is, deduce AP from the ApTracking data by PC and replace FP name -with the given one. --} + where + step :: Expr b -> Expr b + step (Expr.cast @TFelt -> CastOk (Expr.Fun name)) | name == what = forWhat + step (Expr.cast @TBool -> CastOk (Expr.ExistsFelt name expr)) = Expr.ExistsFelt name (substitute what forWhat expr) + step e = e + +-- | Prepare the expression for usage in the model. +-- +-- That is, deduce AP from the ApTracking data by PC and replace FP name +-- with the given one. prepare :: Label -> Expr TFelt -> Expr a -> CairoSemanticsL (Expr a) prepare pc fp expr = getAp pc >>= \ap -> prepare' ap fp expr prepare' :: Expr TFelt -> Expr TFelt -> Expr a -> CairoSemanticsL (Expr a) prepare' ap fp expr = memoryRemoval (substitute "fp" fp (substitute "ap" ap expr)) -preparePost :: - Expr TFelt -> Expr TFelt -> Expr TBool -> Bool -> CairoSemanticsL (Expr TBool) +preparePost + :: Expr TFelt -> Expr TFelt -> Expr TBool -> Bool -> CairoSemanticsL (Expr TBool) preparePost ap fp expr isOptim = do if isOptim then do @@ -284,10 +281,9 @@ moduleEndAp mdl = [] -> getStackTraceDescr <&> Expr.const . ("ap!" <>) _ -> getAp' (Just callstack) pc where (callstack, pc) = m_lastPc mdl -{- | Gather the assertions and other state (in the `ConstraintsState` contained - in `CairoSemanticsL`) associated with a function specification that may - contain a storage update. --} +-- | Gather the assertions and other state (in the `ConstraintsState` contained +-- in `CairoSemanticsL`) associated with a function specification that may +-- contain a storage update. encodeModule :: Module -> CairoSemanticsL () encodeModule mdl@(Module (FuncSpec pre post storage) instrs oracle _ _ mbPreCheckedFuncWithCallStack) = do enableStorage @@ -328,54 +324,54 @@ exMemoryRemoval :: Expr TBool -> CairoSemanticsL ([MemoryVariable] -> Expr TBool exMemoryRemoval expr = do (expr', localMemVars, _referencesLocals) <- unsafeMemoryRemoval expr pure (intro expr' localMemVars) - where - exVars = gatherLogicalVariables expr - - restrictMemTail [] = [] - restrictMemTail (mv0 : rest) = - [ addr0 .== Expr.const mv_addrName .=> mem0 .== Expr.const mv_varName - | MemoryVariable{..} <- rest - ] - where - mem0 = Expr.const (mv_varName mv0) - addr0 = Expr.const (mv_addrName mv0) - - intro :: Expr TBool -> [MemoryVariable] -> [MemoryVariable] -> Expr TBool - intro ex lmv gmv = - let globMemRestrictions = - [ addr1 .== addr2 .=> Expr.const var1 .== Expr.const var2 - | MemoryVariable var1 _ addr1 <- lmv - , MemoryVariable var2 _ addr2 <- gmv - ] - locMemRestrictions = concatMap restrictMemTail (tails lmv) - innerExpr = Expr.and (ex : (locMemRestrictions ++ globMemRestrictions)) - quantLmv = foldr (\mvar e -> Expr.ExistsFelt (mv_varName mvar) e) innerExpr lmv - in foldr (\var e -> Expr.ExistsFelt var e) quantLmv exVars - - unsafeMemoryRemoval :: Expr a -> CairoSemanticsL (Expr a, [MemoryVariable], Bool) - unsafeMemoryRemoval (Memory addr) = do - (addr', localMemVars, referencesLocals) <- unsafeMemoryRemoval addr - if referencesLocals - then do - mv <- declareLocalMem addr' - pure (Expr.const (mv_varName mv), mv : localMemVars, True) - else do - mv <- declareMem addr' - pure (mv, localMemVars, False) - unsafeMemoryRemoval e@Expr.Felt{} = pure (e, [], False) - unsafeMemoryRemoval e@Expr.True = pure (e, [], False) - unsafeMemoryRemoval e@Expr.False = pure (e, [], False) - unsafeMemoryRemoval e@(Expr.Fun name) = pure (e, [], name `Set.member` exVars) - unsafeMemoryRemoval (f Expr.:*: x) = do - (f', localMemVars1, referencesLocals1) <- unsafeMemoryRemoval f - (x', localMemVars2, referencesLocals2) <- unsafeMemoryRemoval x - pure (f' Expr.:*: x', localMemVars2 <> localMemVars1, referencesLocals1 || referencesLocals2) - unsafeMemoryRemoval (Expr.ExistsFelt name e) = do - (e', localMemVars, referencesLocals) <- unsafeMemoryRemoval e - pure (Expr.ExistsFelt name e', localMemVars, referencesLocals) - unsafeMemoryRemoval (Expr.ExitField e) = do - (e', localMemVars, referencesLocals) <- unsafeMemoryRemoval e - pure (Expr.ExitField e', localMemVars, referencesLocals) + where + exVars = gatherLogicalVariables expr + + restrictMemTail [] = [] + restrictMemTail (mv0 : rest) = + [ addr0 .== Expr.const mv_addrName .=> mem0 .== Expr.const mv_varName + | MemoryVariable{..} <- rest + ] + where + mem0 = Expr.const (mv_varName mv0) + addr0 = Expr.const (mv_addrName mv0) + + intro :: Expr TBool -> [MemoryVariable] -> [MemoryVariable] -> Expr TBool + intro ex lmv gmv = + let globMemRestrictions = + [ addr1 .== addr2 .=> Expr.const var1 .== Expr.const var2 + | MemoryVariable var1 _ addr1 <- lmv + , MemoryVariable var2 _ addr2 <- gmv + ] + locMemRestrictions = concatMap restrictMemTail (tails lmv) + innerExpr = Expr.and (ex : (locMemRestrictions ++ globMemRestrictions)) + quantLmv = foldr (\mvar e -> Expr.ExistsFelt (mv_varName mvar) e) innerExpr lmv + in foldr (\var e -> Expr.ExistsFelt var e) quantLmv exVars + + unsafeMemoryRemoval :: Expr a -> CairoSemanticsL (Expr a, [MemoryVariable], Bool) + unsafeMemoryRemoval (Memory addr) = do + (addr', localMemVars, referencesLocals) <- unsafeMemoryRemoval addr + if referencesLocals + then do + mv <- declareLocalMem addr' + pure (Expr.const (mv_varName mv), mv : localMemVars, True) + else do + mv <- declareMem addr' + pure (mv, localMemVars, False) + unsafeMemoryRemoval e@Expr.Felt{} = pure (e, [], False) + unsafeMemoryRemoval e@Expr.True = pure (e, [], False) + unsafeMemoryRemoval e@Expr.False = pure (e, [], False) + unsafeMemoryRemoval e@(Expr.Fun name) = pure (e, [], name `Set.member` exVars) + unsafeMemoryRemoval (f Expr.:*: x) = do + (f', localMemVars1, referencesLocals1) <- unsafeMemoryRemoval f + (x', localMemVars2, referencesLocals2) <- unsafeMemoryRemoval x + pure (f' Expr.:*: x', localMemVars2 <> localMemVars1, referencesLocals1 || referencesLocals2) + unsafeMemoryRemoval (Expr.ExistsFelt name e) = do + (e', localMemVars, referencesLocals) <- unsafeMemoryRemoval e + pure (Expr.ExistsFelt name e', localMemVars, referencesLocals) + unsafeMemoryRemoval (Expr.ExitField e) = do + (e', localMemVars, referencesLocals) <- unsafeMemoryRemoval e + pure (Expr.ExitField e', localMemVars, referencesLocals) withExecutionCtx :: CallEntry -> CairoSemanticsL b -> CairoSemanticsL b withExecutionCtx ctx action = do @@ -384,28 +380,27 @@ withExecutionCtx ctx action = do pop pure res -{- | Records in the `ConstraintsState` (and in particular, in `cs_asserts` - field) the assertions corresponding with the semantics of `assert_eq` and - `call`, and possibly returns a felt expression that represents an FP. - - This is only used in `encodePlainSpec`, and so is essentially a helper function. - - We need this information because sometimes, when we call `getFp` in - `encodePlainSpec`, we get a value that is misleading as a result of the - optimising modules, which interrupt execution, meaning there may be a - missing Cairo `ret`. - - The return value is usually `Nothing` because most functions execute until - the end, matching every call with a `ret`. A return value of `Just fp` - represents the FP of the function that is on the top of the stack at the - point when the execution is interrupted. --} -mkInstructionConstraints :: - [LabeledInst] -> - Maybe (CallStack, ScopedFunction) -> - Map (NonEmpty Label, Label) Bool -> - (Int, LabeledInst) -> - CairoSemanticsL (Maybe (Expr TFelt)) +-- | Records in the `ConstraintsState` (and in particular, in `cs_asserts` +-- field) the assertions corresponding with the semantics of `assert_eq` and +-- `call`, and possibly returns a felt expression that represents an FP. +-- +-- This is only used in `encodePlainSpec`, and so is essentially a helper function. +-- +-- We need this information because sometimes, when we call `getFp` in +-- `encodePlainSpec`, we get a value that is misleading as a result of the +-- optimising modules, which interrupt execution, meaning there may be a +-- missing Cairo `ret`. +-- +-- The return value is usually `Nothing` because most functions execute until +-- the end, matching every call with a `ret`. A return value of `Just fp` +-- represents the FP of the function that is on the top of the stack at the +-- point when the execution is interrupted. +mkInstructionConstraints + :: [LabeledInst] + -> Maybe (CallStack, ScopedFunction) + -> Map (NonEmpty Label, Label) Bool + -> (Int, LabeledInst) + -> CairoSemanticsL (Maybe (Expr TFelt)) mkInstructionConstraints instrs mbPreCheckedFuncWithCallStack jnzOracle (idx, lInst@(pc, Instruction{..})) = do fp <- getFp dst <- prepare pc fp (memory (regToVar i_dstRegister + fromInteger i_dstOffset)) @@ -419,17 +414,17 @@ mkInstructionConstraints instrs mbPreCheckedFuncWithCallStack jnzOracle (idx, lI Just True -> assert (dst ./= 0) $> Nothing Nothing -> pure Nothing Ret -> pop $> Nothing - where - nextPc = getNextPcInlinedWithFallback instrs idx + where + nextPc = getNextPcInlinedWithFallback instrs idx -- | A particular case of mkInstructionConstraints for the instruction 'call'. -mkCallConstraints :: - Label -> - Label -> - Expr TFelt -> - Maybe (CallStack, ScopedFunction) -> - ScopedFunction -> - CairoSemanticsL (Maybe (Expr TFelt)) +mkCallConstraints + :: Label + -> Label + -> Expr TFelt + -> Maybe (CallStack, ScopedFunction) + -> ScopedFunction + -> CairoSemanticsL (Maybe (Expr TFelt)) mkCallConstraints pc nextPc fp mbPreCheckedFuncWithCallStack f = do calleeFp <- withExecutionCtx stackFrame getFp nextAp <- prepare pc calleeFp (Vars.fp .== Vars.ap + 2) @@ -465,24 +460,25 @@ mkCallConstraints pc nextPc fp mbPreCheckedFuncWithCallStack f = do assert preparedPre assert preparedPost pure Nothing - where - lvarSuffix = "+" <> tShowLabel pc - calleePc = sf_pc f - stackFrame = (pc, calleePc) - -- Determine whether the current function matches the function being optimised exactly - - -- this necessitates comparing execution traces. - isModuleCheckingPre = do - stackDescr <- getStackTraceDescr - preCheckedFuncStackDescr <- getStackTraceDescr' ((^. _1) <$> mbPreCheckedFuncWithCallStack) - pure (isJust mbPreCheckedFuncWithCallStack && stackDescr == preCheckedFuncStackDescr) - guardWith condM val cont = do cond <- condM; if cond then val else cont - -traverseStorage :: (forall a. Expr a -> CairoSemanticsL (Expr a)) -> Storage -> CairoSemanticsL Storage + where + lvarSuffix = "+" <> tShowLabel pc + calleePc = sf_pc f + stackFrame = (pc, calleePc) + -- Determine whether the current function matches the function being optimised exactly - + -- this necessitates comparing execution traces. + isModuleCheckingPre = do + stackDescr <- getStackTraceDescr + preCheckedFuncStackDescr <- getStackTraceDescr' ((^. _1) <$> mbPreCheckedFuncWithCallStack) + pure (isJust mbPreCheckedFuncWithCallStack && stackDescr == preCheckedFuncStackDescr) + guardWith condM val cont = do cond <- condM; if cond then val else cont + +traverseStorage + :: (forall a. Expr a -> CairoSemanticsL (Expr a)) -> Storage -> CairoSemanticsL Storage traverseStorage preparer = traverse prepareWrites - where - prepareWrites = traverse prepareWrite - prepareWrite (args, value) = (,) <$> traverse prepareExpr args <*> prepareExpr value - prepareExpr e = storageRemoval e >>= preparer + where + prepareWrites = traverse prepareWrite + prepareWrite (args, value) = (,) <$> traverse prepareExpr args <*> prepareExpr value + prepareExpr e = storageRemoval e >>= preparer mkApConstraints :: Expr TFelt -> NonEmpty LabeledInst -> CairoSemanticsL () mkApConstraints apEnd insts = do @@ -507,10 +503,11 @@ mkApConstraints apEnd insts = do getApIncrement fp lastLInst >>= \case Just lastApIncrement -> assert (lastAp + lastApIncrement .== apEnd) Nothing -> assert (lastAp .< apEnd) - where - lastLInst@(lastPc, lastInst) = NonEmpty.last insts + where + lastLInst@(lastPc, lastInst) = NonEmpty.last insts -mkBuiltinConstraints :: Expr TFelt -> NonEmpty LabeledInst -> Maybe (CallStack, ScopedFunction) -> CairoSemanticsL () +mkBuiltinConstraints + :: Expr TFelt -> NonEmpty LabeledInst -> Maybe (CallStack, ScopedFunction) -> CairoSemanticsL () mkBuiltinConstraints apEnd insts optimisesF = unless (isJust optimisesF) $ do fp <- getFp @@ -524,14 +521,14 @@ mkBuiltinConstraints apEnd insts optimisesF = mkBuiltinConstraintsForInst i (NonEmpty.toList insts) b inst Nothing -> checkBuiltinNotRequired b (toList insts) -getBuiltinContract :: - Expr TFelt -> Expr TFelt -> Builtin -> BuiltinOffsets -> (Expr TBool, Expr TBool) +getBuiltinContract + :: Expr TFelt -> Expr TFelt -> Builtin -> BuiltinOffsets -> (Expr TBool, Expr TBool) getBuiltinContract fp apEnd b bo = (pre, post) - where - pre = builtinAligned initialPtr b .&& finalPtr .<= builtinEnd b - post = initialPtr .<= finalPtr .&& builtinAligned finalPtr b - initialPtr = memory (fp - fromIntegral (bo_input bo)) - finalPtr = memory (apEnd - fromIntegral (bo_output bo)) + where + pre = builtinAligned initialPtr b .&& finalPtr .<= builtinEnd b + post = initialPtr .<= finalPtr .&& builtinAligned finalPtr b + initialPtr = memory (fp - fromIntegral (bo_input bo)) + finalPtr = memory (apEnd - fromIntegral (bo_output bo)) mkBuiltinConstraintsForInst :: Int -> [LabeledInst] -> Builtin -> LabeledInst -> CairoSemanticsL () mkBuiltinConstraintsForInst pos instrs b inst@(pc, Instruction{..}) = @@ -553,30 +550,30 @@ mkBuiltinConstraintsForInst pos instrs b inst@(pc, Instruction{..}) = -- 'ret's are not in the bytecote for functions that are not inlinable Ret -> mkBuiltinConstraintsForFunc True _ -> pure () - where - mkBuiltinConstraintsForFunc canInline = do - calleeFp <- getFp - callEntry@(_, calleePc) <- top <* pop - whenJustM (getBuiltinOffsets calleePc b) $ \bo -> do - calleeApEnd <- - if canInline - then withExecutionCtx callEntry (getAp pc) - else getAp (getNextPcInlinedWithFallback instrs pos) - let (pre, post) = getBuiltinContract calleeFp calleeApEnd b bo - expect pre *> assert post + where + mkBuiltinConstraintsForFunc canInline = do + calleeFp <- getFp + callEntry@(_, calleePc) <- top <* pop + whenJustM (getBuiltinOffsets calleePc b) $ \bo -> do + calleeApEnd <- + if canInline + then withExecutionCtx callEntry (getAp pc) + else getAp (getNextPcInlinedWithFallback instrs pos) + let (pre, post) = getBuiltinContract calleeFp calleeApEnd b bo + expect pre *> assert post checkBuiltinNotRequired :: Builtin -> [LabeledInst] -> CairoSemanticsL () checkBuiltinNotRequired b = traverse_ check - where - check inst = whenJust (callDestination inst) $ \calleePc -> - whenJustM (getBuiltinOffsets calleePc b) $ \_ -> - throw - ( "The function doesn't require the '" - <> Builtin.name b - <> "' builtin, but calls a function (at PC " - <> tShow (unLabel calleePc) - <> ") that does require it" - ) + where + check inst = whenJust (callDestination inst) $ \calleePc -> + whenJustM (getBuiltinOffsets calleePc b) $ \_ -> + throw + ( "The function doesn't require the '" + <> Builtin.name b + <> "' builtin, but calls a function (at PC " + <> tShow (unLabel calleePc) + <> ") that does require it" + ) getRes :: Expr TFelt -> LabeledInst -> CairoSemanticsL (Expr TFelt) getRes fp (pc, Instruction{..}) = do diff --git a/src/Horus/CairoSemantics/Runner.hs b/src/Horus/CairoSemantics/Runner.hs index 6ed768a8..2764c986 100644 --- a/src/Horus/CairoSemantics/Runner.hs +++ b/src/Horus/CairoSemantics/Runner.hs @@ -31,7 +31,12 @@ import Lens.Micro (Lens', (%~), (<&>), (^.)) import Lens.Micro.GHC () import Lens.Micro.Mtl (use, (%=), (.=), (<%=)) -import Horus.CairoSemantics (AssertionType (PreAssertion), CairoSemanticsF (..), CairoSemanticsL, MemoryVariable (..)) +import Horus.CairoSemantics + ( AssertionType (PreAssertion) + , CairoSemanticsF (..) + , CairoSemanticsL + , MemoryVariable (..) + ) import Horus.CallStack (CallStack, digestOfCallStack, pop, push, reset, stackTrace, top) import Horus.Command.SMT qualified as Command import Horus.ContractInfo (ContractInfo (..)) @@ -119,79 +124,79 @@ type Impl = ReaderT ContractInfo (ExceptT Text (State Env)) interpret :: forall a. CairoSemanticsL a -> Impl a interpret = iterM exec - where - exec :: CairoSemanticsF (Impl a) -> Impl a - exec (Assert' a assType cont) = eConstraints . csAsserts %= ((QFAss a, assType) :) >> cont - exec (Expect' a assType cont) = eConstraints . csExpects %= ((a, assType) :) >> cont - exec (DeclareMem address cont) = do - memVars <- use (eConstraints . csMemoryVariables) - case List.find ((address ==) . mv_addrExpr) memVars of - Just MemoryVariable{..} -> cont (Expr.const mv_varName) - Nothing -> do - freshCount <- eConstraints . csNameCounter <%= (+ 1) - let name = "MEM!" <> tShow freshCount - let addrName = "ADDR!" <> tShow freshCount - eConstraints . csMemoryVariables %= (MemoryVariable name addrName address :) - cont (Expr.const name) - exec (DeclareLocalMem address cont) = do - memVars <- use (eConstraints . csMemoryVariables) - case List.find ((address ==) . mv_addrExpr) memVars of - Just mv -> cont mv - Nothing -> do - freshCount <- eConstraints . csNameCounter <%= (+ 1) - let name = "MEM!" <> tShow freshCount - let addrName = "ADDR!" <> tShow freshCount - cont (MemoryVariable name addrName address) - exec (GetApTracking label cont) = do - ci <- ask - ci_getApTracking ci label >>= cont - exec (GetBuiltinOffsets label builtin cont) = do - ci <- ask - ci_getBuiltinOffsets ci label builtin >>= cont - exec (GetCallee inst cont) = do - ci <- ask - ci_getCallee ci inst >>= cont - exec (GetFuncSpec name cont) = do - ci <- ask - ci_getFuncSpec ci name & cont - exec (GetFunPc label cont) = do - ci <- ask - ci_getFunPc ci label >>= cont - exec (GetInlinable cont) = do - ask >>= cont . ci_inlinables - exec (GetStackTraceDescr callstack cont) = do - fNames <- asks (Map.map sf_scopedName . ci_functions) - case callstack of - Nothing -> get >>= cont . digestOfCallStack fNames . (^. csCallStack) . e_constraints - Just stack -> cont $ digestOfCallStack fNames stack - exec (GetMemVars cont) = do - use (eConstraints . csMemoryVariables) >>= cont - exec (GetOracle cont) = do - get >>= cont . stackTrace . (^. csCallStack) . e_constraints - exec (IsInlinable label cont) = do - inlinableFs <- asks ci_inlinables - cont (label `elem` inlinableFs) - exec (Push entry cont) = eConstraints . csCallStack %= push entry >> cont - exec (Pop cont) = eConstraints . csCallStack %= (snd . pop) >> cont - exec (Top cont) = do - get >>= cont . top . (^. csCallStack) . e_constraints - exec (EnableStorage cont) = eStorageEnabled .= True >> cont - exec (ReadStorage mbStorage name args cont) = do - storage <- case mbStorage of Nothing -> use eStorage; Just st -> pure st - cont (Storage.read storage name args) - exec (ResetStack cont) = eConstraints . csCallStack %= reset >> cont - exec (UpdateStorage newStorage cont) = do - storageEnabled <- use eStorageEnabled - unless (storageEnabled || Map.null newStorage) $ - throwError plainSpecStorageAccessErr - oldStorage <- use eStorage - let combined = Map.unionWith (<>) newStorage oldStorage - eStorage .= combined >> cont - exec (GetStorage cont) = use eStorage >>= cont - exec (Throw t) = throwError t + where + exec :: CairoSemanticsF (Impl a) -> Impl a + exec (Assert' a assType cont) = eConstraints . csAsserts %= ((QFAss a, assType) :) >> cont + exec (Expect' a assType cont) = eConstraints . csExpects %= ((a, assType) :) >> cont + exec (DeclareMem address cont) = do + memVars <- use (eConstraints . csMemoryVariables) + case List.find ((address ==) . mv_addrExpr) memVars of + Just MemoryVariable{..} -> cont (Expr.const mv_varName) + Nothing -> do + freshCount <- eConstraints . csNameCounter <%= (+ 1) + let name = "MEM!" <> tShow freshCount + let addrName = "ADDR!" <> tShow freshCount + eConstraints . csMemoryVariables %= (MemoryVariable name addrName address :) + cont (Expr.const name) + exec (DeclareLocalMem address cont) = do + memVars <- use (eConstraints . csMemoryVariables) + case List.find ((address ==) . mv_addrExpr) memVars of + Just mv -> cont mv + Nothing -> do + freshCount <- eConstraints . csNameCounter <%= (+ 1) + let name = "MEM!" <> tShow freshCount + let addrName = "ADDR!" <> tShow freshCount + cont (MemoryVariable name addrName address) + exec (GetApTracking label cont) = do + ci <- ask + ci_getApTracking ci label >>= cont + exec (GetBuiltinOffsets label builtin cont) = do + ci <- ask + ci_getBuiltinOffsets ci label builtin >>= cont + exec (GetCallee inst cont) = do + ci <- ask + ci_getCallee ci inst >>= cont + exec (GetFuncSpec name cont) = do + ci <- ask + ci_getFuncSpec ci name & cont + exec (GetFunPc label cont) = do + ci <- ask + ci_getFunPc ci label >>= cont + exec (GetInlinable cont) = do + ask >>= cont . ci_inlinables + exec (GetStackTraceDescr callstack cont) = do + fNames <- asks (Map.map sf_scopedName . ci_functions) + case callstack of + Nothing -> get >>= cont . digestOfCallStack fNames . (^. csCallStack) . e_constraints + Just stack -> cont $ digestOfCallStack fNames stack + exec (GetMemVars cont) = do + use (eConstraints . csMemoryVariables) >>= cont + exec (GetOracle cont) = do + get >>= cont . stackTrace . (^. csCallStack) . e_constraints + exec (IsInlinable label cont) = do + inlinableFs <- asks ci_inlinables + cont (label `elem` inlinableFs) + exec (Push entry cont) = eConstraints . csCallStack %= push entry >> cont + exec (Pop cont) = eConstraints . csCallStack %= (snd . pop) >> cont + exec (Top cont) = do + get >>= cont . top . (^. csCallStack) . e_constraints + exec (EnableStorage cont) = eStorageEnabled .= True >> cont + exec (ReadStorage mbStorage name args cont) = do + storage <- case mbStorage of Nothing -> use eStorage; Just st -> pure st + cont (Storage.read storage name args) + exec (ResetStack cont) = eConstraints . csCallStack %= reset >> cont + exec (UpdateStorage newStorage cont) = do + storageEnabled <- use eStorageEnabled + unless (storageEnabled || Map.null newStorage) $ + throwError plainSpecStorageAccessErr + oldStorage <- use eStorage + let combined = Map.unionWith (<>) newStorage oldStorage + eStorage .= combined >> cont + exec (GetStorage cont) = use eStorage >>= cont + exec (Throw t) = throwError t - plainSpecStorageAccessErr :: Text - plainSpecStorageAccessErr = "Storage access isn't allowed in a plain spec." + plainSpecStorageAccessErr :: Text + plainSpecStorageAccessErr = "Storage access isn't allowed in a plain spec." debugFriendlyModel :: ConstraintsState -> Text debugFriendlyModel ConstraintsState{..} = @@ -204,11 +209,11 @@ debugFriendlyModel ConstraintsState{..} = , ["# Expect"] , map (pprExpr . fst) cs_expects ] - where - memoryPairs = - [ mv_varName <> "=[" <> pprExpr mv_addrExpr <> "]" - | MemoryVariable{..} <- cs_memoryVariables - ] + where + memoryPairs = + [ mv_varName <> "=[" <> pprExpr mv_addrExpr <> "]" + | MemoryVariable{..} <- cs_memoryVariables + ] restrictMemTail :: [MemoryVariable] -> [Expr TBool] restrictMemTail [] = [] @@ -216,53 +221,53 @@ restrictMemTail (mv0 : rest) = [ addr0 .== Expr.const mv_addrName .=> mem0 .== Expr.const mv_varName | MemoryVariable{..} <- rest ] - where - mem0 = Expr.const (mv_varName mv0) - addr0 = Expr.const (mv_addrName mv0) + where + mem0 = Expr.const (mv_varName mv0) + addr0 = Expr.const (mv_addrName mv0) makeModel :: Bool -> ConstraintsState -> Integer -> Text makeModel checkPreOnly ConstraintsState{..} fPrime = Text.intercalate "\n" (decls <> map (Command.assert fPrime) restrictions) - where - functions = - toList (foldMap gatherNonStdFunctions generalRestrictions <> gatherNonStdFunctions prime) - decls = map (foldSome Command.declare) functions - rangeRestrictions = mapMaybe (foldSome restrictRange) functions - memRestrictions = concatMap restrictMemTail (List.tails cs_memoryVariables) - addrRestrictions = - [Expr.const mv_addrName .== mv_addrExpr | MemoryVariable{..} <- cs_memoryVariables] + where + functions = + toList (foldMap gatherNonStdFunctions generalRestrictions <> gatherNonStdFunctions prime) + decls = map (foldSome Command.declare) functions + rangeRestrictions = mapMaybe (foldSome restrictRange) functions + memRestrictions = concatMap restrictMemTail (List.tails cs_memoryVariables) + addrRestrictions = + [Expr.const mv_addrName .== mv_addrExpr | MemoryVariable{..} <- cs_memoryVariables] - -- If checking @pre only, only take `PreAssertion`s, no postconditions. - allowedAsserts = if checkPreOnly then filter ((== PreAssertion) . snd) cs_asserts else cs_asserts - allowedExpects = if checkPreOnly then [] else cs_expects + -- If checking @pre only, only take `PreAssertion`s, no postconditions. + allowedAsserts = if checkPreOnly then filter ((== PreAssertion) . snd) cs_asserts else cs_asserts + allowedExpects = if checkPreOnly then [] else cs_expects - generalRestrictions = - concat - [ memRestrictions - , addrRestrictions - , map (builderToAss cs_memoryVariables . fst) allowedAsserts - , [Expr.not (Expr.and (map fst allowedExpects)) | not (null allowedExpects)] - ] - restrictions = rangeRestrictions <> generalRestrictions + generalRestrictions = + concat + [ memRestrictions + , addrRestrictions + , map (builderToAss cs_memoryVariables . fst) allowedAsserts + , [Expr.not (Expr.and (map fst allowedExpects)) | not (null allowedExpects)] + ] + restrictions = rangeRestrictions <> generalRestrictions - restrictRange :: forall ty. Function ty -> Maybe (Expr TBool) - restrictRange (Function name) = case sing @ty of - SFelt - | Just value <- name `lookup` constants -> Just (ExitField (var .== fromInteger value)) - | otherwise -> Just (0 .<= var .&& var .< prime) - where - var = Fun name - constants :: [(Text, Integer)] - constants = [(pprExpr prime, fPrime), (pprExpr rcBound, Builtin.rcBound)] - _ -> Nothing + restrictRange :: forall ty. Function ty -> Maybe (Expr TBool) + restrictRange (Function name) = case sing @ty of + SFelt + | Just value <- name `lookup` constants -> Just (ExitField (var .== fromInteger value)) + | otherwise -> Just (0 .<= var .&& var .< prime) + where + var = Fun name + constants :: [(Text, Integer)] + constants = [(pprExpr prime, fPrime), (pprExpr rcBound, Builtin.rcBound)] + _ -> Nothing runImpl :: CallStack -> ContractInfo -> Impl a -> Either Text ConstraintsState runImpl initStack contractInfo m = v $> e_constraints env - where - (v, env) = - runReaderT m contractInfo - & runExceptT - & flip runState (emptyEnv initStack) + where + (v, env) = + runReaderT m contractInfo + & runExceptT + & flip runState (emptyEnv initStack) run :: CallStack -> ContractInfo -> CairoSemanticsL a -> Either Text ConstraintsState run initStack contractInfo a = diff --git a/src/Horus/CallStack.hs b/src/Horus/CallStack.hs index c11b300a..0fa2fa82 100644 --- a/src/Horus/CallStack.hs +++ b/src/Horus/CallStack.hs @@ -75,13 +75,13 @@ callerOfRoot = Label (-1) digestOfStackTrace :: Map Label ScopedName -> NonEmpty CallEntry -> Text digestOfStackTrace names = wrap . foldr (flip (<>) . tShowCaller) "" - where - tShowCaller (Label pc, calledF) = - if pc == unLabel callerOfRoot - then "root" - else tShow pc <> "=" <> fName (names ! calledF) <> "/" - wrap trace = "<" <> trace <> ">" - fName (ScopedName name) = last name + where + tShowCaller (Label pc, calledF) = + if pc == unLabel callerOfRoot + then "root" + else tShow pc <> "=" <> fName (names ! calledF) <> "/" + wrap trace = "<" <> trace <> ">" + fName (ScopedName name) = last name digestOfCallStack :: Map Label ScopedName -> CallStack -> Text digestOfCallStack names = digestOfStackTrace names . descriptiveStackTrace diff --git a/src/Horus/Command/SMT.hs b/src/Horus/Command/SMT.hs index 9da35af2..f36bd032 100644 --- a/src/Horus/Command/SMT.hs +++ b/src/Horus/Command/SMT.hs @@ -13,10 +13,10 @@ import Horus.Expr.Type.SMT qualified as Ty (toSMT) declare :: forall ty. Function ty -> Text declare (Function name) = pack (printf "(declare-fun %s (%s) %s)" name args res) - where - args = unwords [SMT.showsSExpr x "" | x <- argTys] - res = SMT.showsSExpr resTy "" - (resTy :| argTys) = Ty.toSMT @ty + where + args = unwords [SMT.showsSExpr x "" | x <- argTys] + res = SMT.showsSExpr resTy "" + (resTy :| argTys) = Ty.toSMT @ty assert :: Integer -> Expr TBool -> Text assert fPrime e = pack (printf "(assert %s)" (SMT.showsSExpr (Expr.toSMT fPrime e) "")) diff --git a/src/Horus/ContractInfo.hs b/src/Horus/ContractInfo.hs index a60a5314..870094d0 100644 --- a/src/Horus/ContractInfo.hs +++ b/src/Horus/ContractInfo.hs @@ -13,10 +13,30 @@ import Data.Text (Text) import Horus.ContractDefinition (ContractDefinition (..)) import Horus.Expr (Expr, Ty (..)) -import Horus.FunctionAnalysis (ScopedFunction (..), inlinableFuns, mkGeneratedNames, storageVarsOfCD) -import Horus.Instruction (LabeledInst, callDestination, isRet, labelInstructions, readAllInstructions, toSemiAsmUnsafe) +import Horus.FunctionAnalysis + ( ScopedFunction (..) + , inlinableFuns + , mkGeneratedNames + , storageVarsOfCD + ) +import Horus.Instruction + ( LabeledInst + , callDestination + , isRet + , labelInstructions + , readAllInstructions + , toSemiAsmUnsafe + ) import Horus.Label (Label) -import Horus.Program (ApTracking, DebugInfo (..), FlowTrackingData (..), ILInfo (..), Identifiers, Program (..), sizeOfType) +import Horus.Program + ( ApTracking + , DebugInfo (..) + , FlowTrackingData (..) + , ILInfo (..) + , Identifiers + , Program (..) + , sizeOfType + ) import Horus.SW.Builtin (Builtin, BuiltinOffsets (..)) import Horus.SW.Builtin qualified as Builtin (ptrName) import Horus.SW.CairoType (CairoType (..)) @@ -83,144 +103,144 @@ mkContractInfo cd = do , ci_getCallee = getCallee , ci_getRets = mkGetRets retsByFun } - where - ---- plain data - debugInfo = p_debugInfo (cd_program cd) - identifiers = p_identifiers (cd_program cd) - instructionLocations = di_instructionLocations debugInfo - program = cd_program cd - storageVarsNames = storageVarsOfCD cd - - functions :: [(ScopedName, Label)] - functions = mapMaybe (\(name, f) -> (name,) <$> getFunctionPc f) (Map.toList identifiers) - - ---- functions, purely computable from plain data - callDestination' :: MonadError Text m => LabeledInst -> m Label - callDestination' i = maybeToError msg (callDestination i) - where - msg = "Can't find the call destination of " <> toSemiAsmUnsafe (snd i) - - getApTracking :: MonadError Text m => Label -> m ApTracking - getApTracking l = case instructionLocations Map.!? l of - Nothing -> throwError ("There is no instruction_locations entry for '" <> tShow l <> "'") - Just il -> pure (il_flowTrackingData il & ftd_apTracking) - - getBuiltinOffsets :: MonadError Text m => Label -> Builtin -> m (Maybe BuiltinOffsets) - getBuiltinOffsets l b = do - funName <- getFunName' l - args <- getStruct (funName <> "Args") - implicits <- getStruct (funName <> "ImplicitArgs") - returns <- getTypeDef (funName <> "Return") - outputOffset <- getOutputOffset (Builtin.ptrName b) returns implicits - pure $ - BuiltinOffsets - <$> getInputOffset (Builtin.ptrName b) args implicits - <*> outputOffset - where - getStruct name = case identifiers Map.! name of - IStruct s -> pure s - _ -> throwError ("Expected '" <> tShow name <> "' to have a 'struct' type") - getTypeDef name = case identifiers Map.! name of - IType t -> pure t - _ -> throwError ("Expected '" <> tShow name <> "' to have a 'type_definition' type") - getInputOffset n args implicits = - asum - [ st_members args Map.!? n - <&> \m -> -me_offset m + 2 + st_size args - , st_members implicits Map.!? n - <&> \m -> -me_offset m + 2 + st_size args + st_size implicits - ] - getOutputOffset n returns@(TypeTuple mems) implicits = do - returnSize <- sizeOfType returns identifiers - asum - <$> sequenceA - [ sequenceA $ - lookup (Just (fromText n)) mems - <&> \m -> - case elemIndex (Just $ fromText n, m) mems of - Just memIndex -> do - offset <- sizeOfType (TypeTuple $ take (memIndex + 1) mems) identifiers - pure $ -offset + returnSize - Nothing -> throwError "This not supposed to be reachable." - , pure $ - st_members implicits Map.!? n - <&> \m -> -me_offset m + returnSize + st_size implicits - ] - getOutputOffset n returns implicits = getOutputOffset n (TypeTuple [(Nothing, Just returns)]) implicits - - getCallee :: MonadError Text m => LabeledInst -> m ScopedFunction - getCallee inst = do - callee <- callDestination' inst - name <- getFunName' callee - pure $ ScopedFunction name callee - - getFunName :: Label -> Maybe ScopedName - getFunName l = do - ilInfo <- instructionLocations Map.!? l - safeLast (il_accessibleScopes ilInfo) - - getFunName' :: MonadError Text m => Label -> m ScopedName - getFunName' l = - maybeToError ("Can't find a function enclosing '" <> tShow l <> "'") $ - getFunName l - - getFunPc :: MonadError Text m => Label -> m Label - getFunPc l = do - name <- getFunName' l - getFunctionPc (identifiers Map.! name) - & maybeToError ("'" <> tShow name <> "' isn't a function") - - getFuncSpec :: ScopedFunction -> FuncSpec' - getFuncSpec name = - maybe - emptyFuncSpec' - ( \FuncSpec{..} -> - FuncSpec' - { fs'_pre = Just fs_pre - , fs'_post = Just fs_post - , fs'_storage = fs_storage - } - ) - $ allSpecs Map.!? sf_scopedName name - - allSpecs :: Map ScopedName FuncSpec - allSpecs = Map.union (cd_specs cd) storageVarsSpecs - - storageVarsSpecs :: Map ScopedName FuncSpec - storageVarsSpecs = - Map.foldrWithKey - ( \name arity m -> - Map.insert (name <> "read") (mkReadSpec name arity) $ - Map.insert (name <> "write") (mkWriteSpec name arity) m - ) - Map.empty - (cd_storageVars cd) - - getInvariant :: ScopedName -> Maybe (Expr TBool) - getInvariant name = Map.lookup name (cd_invariants cd) - - ---- non-plain data producers that depend on the outer monad (likely, for errors) - mkInstructions :: Integer -> m' [LabeledInst] - mkInstructions fPrime = fmap labelInstructions (readAllInstructions fPrime (p_code (cd_program cd))) - - mkRetsByFun :: [LabeledInst] -> m' (Map ScopedName [Label]) - mkRetsByFun insts = do - retAndFun <- sequenceA [fmap (,[pc]) (getFunName' pc) | (pc, inst) <- insts, isRet inst] - let preliminaryRes = Map.fromListWith (++) retAndFun - -- Note that `preliminaryRes` doesn't contain info about functions with - -- zero returns. A function might not contain returns when it ends with an - -- endless loop. - let insertFunWithNoRets fun = Map.insertWith (\_new old -> old) fun [] - pure (foldr (insertFunWithNoRets . fst) preliminaryRes functions) - - mkGetRets :: MonadError Text m => Map ScopedName [Label] -> ScopedName -> m [Label] - mkGetRets retsByFun name = maybeToError msg (retsByFun Map.!? name) - where - msg = "Can't find 'ret' instructions for " <> tShow name <> ". Is it a function?" - - mkSources :: [ScopedName] -> [(Function, ScopedName, FuncSpec)] - mkSources generatedNames = - [ (f, name, toFuncSpec . getFuncSpec . ScopedFunction name $ fu_pc f) - | (name, IFunction f) <- Map.toList identifiers - , name `notElem` generatedNames - ] + where + ---- plain data + debugInfo = p_debugInfo (cd_program cd) + identifiers = p_identifiers (cd_program cd) + instructionLocations = di_instructionLocations debugInfo + program = cd_program cd + storageVarsNames = storageVarsOfCD cd + + functions :: [(ScopedName, Label)] + functions = mapMaybe (\(name, f) -> (name,) <$> getFunctionPc f) (Map.toList identifiers) + + ---- functions, purely computable from plain data + callDestination' :: MonadError Text m => LabeledInst -> m Label + callDestination' i = maybeToError msg (callDestination i) + where + msg = "Can't find the call destination of " <> toSemiAsmUnsafe (snd i) + + getApTracking :: MonadError Text m => Label -> m ApTracking + getApTracking l = case instructionLocations Map.!? l of + Nothing -> throwError ("There is no instruction_locations entry for '" <> tShow l <> "'") + Just il -> pure (il_flowTrackingData il & ftd_apTracking) + + getBuiltinOffsets :: MonadError Text m => Label -> Builtin -> m (Maybe BuiltinOffsets) + getBuiltinOffsets l b = do + funName <- getFunName' l + args <- getStruct (funName <> "Args") + implicits <- getStruct (funName <> "ImplicitArgs") + returns <- getTypeDef (funName <> "Return") + outputOffset <- getOutputOffset (Builtin.ptrName b) returns implicits + pure $ + BuiltinOffsets + <$> getInputOffset (Builtin.ptrName b) args implicits + <*> outputOffset + where + getStruct name = case identifiers Map.! name of + IStruct s -> pure s + _ -> throwError ("Expected '" <> tShow name <> "' to have a 'struct' type") + getTypeDef name = case identifiers Map.! name of + IType t -> pure t + _ -> throwError ("Expected '" <> tShow name <> "' to have a 'type_definition' type") + getInputOffset n args implicits = + asum + [ st_members args Map.!? n + <&> \m -> -me_offset m + 2 + st_size args + , st_members implicits Map.!? n + <&> \m -> -me_offset m + 2 + st_size args + st_size implicits + ] + getOutputOffset n returns@(TypeTuple mems) implicits = do + returnSize <- sizeOfType returns identifiers + asum + <$> sequenceA + [ sequenceA $ + lookup (Just (fromText n)) mems + <&> \m -> + case elemIndex (Just $ fromText n, m) mems of + Just memIndex -> do + offset <- sizeOfType (TypeTuple $ take (memIndex + 1) mems) identifiers + pure $ -offset + returnSize + Nothing -> throwError "This not supposed to be reachable." + , pure $ + st_members implicits Map.!? n + <&> \m -> -me_offset m + returnSize + st_size implicits + ] + getOutputOffset n returns implicits = getOutputOffset n (TypeTuple [(Nothing, Just returns)]) implicits + + getCallee :: MonadError Text m => LabeledInst -> m ScopedFunction + getCallee inst = do + callee <- callDestination' inst + name <- getFunName' callee + pure $ ScopedFunction name callee + + getFunName :: Label -> Maybe ScopedName + getFunName l = do + ilInfo <- instructionLocations Map.!? l + safeLast (il_accessibleScopes ilInfo) + + getFunName' :: MonadError Text m => Label -> m ScopedName + getFunName' l = + maybeToError ("Can't find a function enclosing '" <> tShow l <> "'") $ + getFunName l + + getFunPc :: MonadError Text m => Label -> m Label + getFunPc l = do + name <- getFunName' l + getFunctionPc (identifiers Map.! name) + & maybeToError ("'" <> tShow name <> "' isn't a function") + + getFuncSpec :: ScopedFunction -> FuncSpec' + getFuncSpec name = + maybe + emptyFuncSpec' + ( \FuncSpec{..} -> + FuncSpec' + { fs'_pre = Just fs_pre + , fs'_post = Just fs_post + , fs'_storage = fs_storage + } + ) + $ allSpecs Map.!? sf_scopedName name + + allSpecs :: Map ScopedName FuncSpec + allSpecs = Map.union (cd_specs cd) storageVarsSpecs + + storageVarsSpecs :: Map ScopedName FuncSpec + storageVarsSpecs = + Map.foldrWithKey + ( \name arity m -> + Map.insert (name <> "read") (mkReadSpec name arity) $ + Map.insert (name <> "write") (mkWriteSpec name arity) m + ) + Map.empty + (cd_storageVars cd) + + getInvariant :: ScopedName -> Maybe (Expr TBool) + getInvariant name = Map.lookup name (cd_invariants cd) + + ---- non-plain data producers that depend on the outer monad (likely, for errors) + mkInstructions :: Integer -> m' [LabeledInst] + mkInstructions fPrime = fmap labelInstructions (readAllInstructions fPrime (p_code (cd_program cd))) + + mkRetsByFun :: [LabeledInst] -> m' (Map ScopedName [Label]) + mkRetsByFun insts = do + retAndFun <- sequenceA [fmap (,[pc]) (getFunName' pc) | (pc, inst) <- insts, isRet inst] + let preliminaryRes = Map.fromListWith (++) retAndFun + -- Note that `preliminaryRes` doesn't contain info about functions with + -- zero returns. A function might not contain returns when it ends with an + -- endless loop. + let insertFunWithNoRets fun = Map.insertWith (\_new old -> old) fun [] + pure (foldr (insertFunWithNoRets . fst) preliminaryRes functions) + + mkGetRets :: MonadError Text m => Map ScopedName [Label] -> ScopedName -> m [Label] + mkGetRets retsByFun name = maybeToError msg (retsByFun Map.!? name) + where + msg = "Can't find 'ret' instructions for " <> tShow name <> ". Is it a function?" + + mkSources :: [ScopedName] -> [(Function, ScopedName, FuncSpec)] + mkSources generatedNames = + [ (f, name, toFuncSpec . getFuncSpec . ScopedFunction name $ fu_pc f) + | (name, IFunction f) <- Map.toList identifiers + , name `notElem` generatedNames + ] diff --git a/src/Horus/Expr.hs b/src/Horus/Expr.hs index da4e9599..218a712c 100644 --- a/src/Horus/Expr.hs +++ b/src/Horus/Expr.hs @@ -159,11 +159,11 @@ transform_ f = transform (\x -> f x $> x) canonicalize :: Expr a -> Expr a canonicalize = transformId step - where - step (a :+ b) = a + b - step (a :- b) = a - b - step (Negate a) = negate a - step a = a + where + step (a :+ b) = a + b + step (a :- b) = a - b + step (Negate a) = negate a + step a = a -- pattern synonyms @@ -205,47 +205,47 @@ apply1 acc xs = apply acc (toList xs) apply1' :: SingI c => (forall a. SingI a => Expr a) -> Expr c -> [Expr b] -> Expr c apply1' acc whenEmpty = maybe whenEmpty (apply1 acc) . nonEmpty -unfoldVariadic :: - forall arg res ty. - (Typeable arg, Typeable res) => - Expr ty -> - Maybe (ty :~: res, Text, [Expr arg]) +unfoldVariadic + :: forall arg res ty + . (Typeable arg, Typeable res) + => Expr ty + -> Maybe (ty :~: res, Text, [Expr arg]) unfoldVariadic e = do Refl <- eqT @res @ty \\ isProper e (name, args) <- gatherArgs [] e pure (Refl, name, args) - where - gatherArgs :: [Expr arg] -> Expr ty' -> Maybe (Text, [Expr arg]) - gatherArgs acc (f :*: x) = do - x' <- cast' @arg x - gatherArgs (x' : acc) f - gatherArgs acc (Fun name) = pure (name, acc) - gatherArgs _ _ = Nothing - -pattern FeltConst :: () => (a ~ TFelt) => Text -> Expr a + where + gatherArgs :: [Expr arg] -> Expr ty' -> Maybe (Text, [Expr arg]) + gatherArgs acc (f :*: x) = do + x' <- cast' @arg x + gatherArgs (x' : acc) f + gatherArgs acc (Fun name) = pure (name, acc) + gatherArgs _ _ = Nothing + +pattern FeltConst :: () => a ~ TFelt => Text -> Expr a pattern FeltConst name <- (cast @TFelt -> CastOk (Fun name)) where FeltConst = const -pattern (:+) :: () => (a ~ TFelt) => Expr TFelt -> Expr TFelt -> Expr a +pattern (:+) :: () => a ~ TFelt => Expr TFelt -> Expr TFelt -> Expr a pattern a :+ b <- (cast @(TFelt :-> TFelt :-> TFelt) -> CastOk (Fun "+")) :*: a :*: b where (:+) = function "+" -pattern (:*) :: () => (a ~ TFelt) => Expr TFelt -> Expr TFelt -> Expr a +pattern (:*) :: () => a ~ TFelt => Expr TFelt -> Expr TFelt -> Expr a pattern a :* b <- (cast @(TFelt :-> TFelt :-> TFelt) -> CastOk (Fun "*")) :*: a :*: b where (:*) = function "*" -pattern (:-) :: () => (a ~ TFelt) => Expr TFelt -> Expr TFelt -> Expr a +pattern (:-) :: () => a ~ TFelt => Expr TFelt -> Expr TFelt -> Expr a pattern a :- b <- (cast @(TFelt :-> TFelt :-> TFelt) -> CastOk (Fun "-")) :*: a :*: b -pattern Negate :: () => (a ~ TFelt) => Expr TFelt -> Expr a +pattern Negate :: () => a ~ TFelt => Expr TFelt -> Expr a pattern Negate a <- (cast @(TFelt :-> TFelt) -> CastOk (Fun "-")) :*: a where Negate = function "-" -pattern And :: () => (a ~ TBool) => [Expr TBool] -> Expr a +pattern And :: () => a ~ TBool => [Expr TBool] -> Expr a pattern And cs <- (unfoldVariadic @TBool @TBool -> Just (Refl, "and", cs)) where And = apply1' (Fun "and") True @@ -297,10 +297,10 @@ and xs | [] <- xs' = True | [x] <- xs' = x | otherwise = And xs' - where - xs' = filter (/= True) (concatMap unfold xs) - unfold (And cs) = cs - unfold x = [x] + where + xs' = filter (/= True) (concatMap unfold xs) + unfold (And cs) = cs + unfold x = [x] infixr 2 .|| (.||) :: Expr TBool -> Expr TBool -> Expr TBool diff --git a/src/Horus/Expr/SMT.hs b/src/Horus/Expr/SMT.hs index d06abf7f..d59ee0ca 100644 --- a/src/Horus/Expr/SMT.hs +++ b/src/Horus/Expr/SMT.hs @@ -54,8 +54,8 @@ toSMT' (Felt b) = SMT.int b toSMT' (f :*: x) = let (h, args) = splitApp (f :*: x) in SMT.app h (reverse args) toSMT' (Fun s) = SMT.Atom (unpack s) toSMT' (ExistsFelt name e) = SMT.fun "exists" [bindings, toSMT' e] - where - bindings = SMT.List [SMT.List [SMT.Atom (unpack name), SMT.tInt]] + where + bindings = SMT.List [SMT.List [SMT.Atom (unpack name), SMT.tInt]] toSMT' (ExitField e) = toSMT' e splitApp :: Expr b -> (SMT.SExpr, [SMT.SExpr]) @@ -82,19 +82,19 @@ parseArithmetic = parse inlineLets :: SMT.SExpr -> SMT.SExpr inlineLets = flip runReader Map.empty . go - where - go :: SMT.SExpr -> Reader (Map String SMT.SExpr) SMT.SExpr - go (SMT.Atom s) = view (at s . non (SMT.Atom s)) - go (SMT.List [SMT.Atom "let", SMT.List bs, body]) = do - extension <- bindingsToMap bs - local (<> extension) (go body) - go (SMT.List l) = SMT.List <$> traverse go l - - bindingsToMap :: [SMT.SExpr] -> Reader (Map String SMT.SExpr) (Map String SMT.SExpr) - bindingsToMap bs = - [(s, v) | SMT.List [SMT.Atom s, v] <- bs] - & traverse (\(s, v) -> (s,) <$> go v) - & fmap Map.fromList + where + go :: SMT.SExpr -> Reader (Map String SMT.SExpr) SMT.SExpr + go (SMT.Atom s) = view (at s . non (SMT.Atom s)) + go (SMT.List [SMT.Atom "let", SMT.List bs, body]) = do + extension <- bindingsToMap bs + local (<> extension) (go body) + go (SMT.List l) = SMT.List <$> traverse go l + + bindingsToMap :: [SMT.SExpr] -> Reader (Map String SMT.SExpr) (Map String SMT.SExpr) + bindingsToMap bs = + [(s, v) | SMT.List [SMT.Atom s, v] <- bs] + & traverse (\(s, v) -> (s,) <$> go v) + & fmap Map.fromList -- parsing per se @@ -105,9 +105,9 @@ informativeCast :: forall b a. Typeable b => Expr a -> Either Text (Expr b) informativeCast e = case cast' @b e of Just e' -> pure e' Nothing -> Left (pack (printf "Can't cast '%s' to '%s'." aType bType)) - where - aType = show (typeRep @a \\ isProper e) - bType = show (typeRep @b) + where + aType = show (typeRep @a \\ isProper e) + bType = show (typeRep @b) pureSome :: Applicative m => f a -> m (Some f) pureSome = pure . Some @@ -131,26 +131,26 @@ parseSexp' s@(SMT.List (SMT.Atom f : x1 : xTail)) | f == "not" = parseUnary Expr.not | f == "abs" = parseUnary @TFelt abs | otherwise = parseStorageVar - where - fText = pack f + where + fText = pack f - parseVariadic :: forall arg res. (IsProper arg, IsProper res) => Either Text (Some Expr) - parseVariadic = do - xs <- traverse (parseSexp @arg) (x1 :| xTail) - pureSome (Expr.apply1 @res @arg (Fun fText) xs) + parseVariadic :: forall arg res. (IsProper arg, IsProper res) => Either Text (Some Expr) + parseVariadic = do + xs <- traverse (parseSexp @arg) (x1 :| xTail) + pureSome (Expr.apply1 @res @arg (Fun fText) xs) - parseArithL :: (Expr TFelt -> Expr TFelt -> Expr TFelt) -> Either Text (Some Expr) - parseArithL op = do - x1' <- parseSexp x1 - xTail' <- traverse parseSexp xTail - pureSome (foldl' op x1' xTail') + parseArithL :: (Expr TFelt -> Expr TFelt -> Expr TFelt) -> Either Text (Some Expr) + parseArithL op = do + x1' <- parseSexp x1 + xTail' <- traverse parseSexp xTail + pureSome (foldl' op x1' xTail') - parseUnary :: forall arg res. Typeable arg => (Expr arg -> Expr res) -> Either Text (Some Expr) - parseUnary con = case xTail of - [] -> pureSome . con =<< parseSexp x1 - _ -> Left (eNonUnary fText s) + parseUnary :: forall arg res. Typeable arg => (Expr arg -> Expr res) -> Either Text (Some Expr) + parseUnary con = case xTail of + [] -> pureSome . con =<< parseSexp x1 + _ -> Left (eNonUnary fText s) - parseStorageVar = parseVariadic @TFelt @TFelt + parseStorageVar = parseVariadic @TFelt @TFelt eEmptySexp :: Text eEmptySexp = "Can't parse an empty sexp." @@ -170,5 +170,5 @@ eNullaryFunction s = ] eNonUnary :: Text -> SMT.SExpr -> Text eNonUnary f s = "'" <> f <> "' must have only one argument, but has several: '" <> sText <> "'." - where - sText = pack (SMT.showsSExpr s "") + where + sText = pack (SMT.showsSExpr s "") diff --git a/src/Horus/Expr/Type/SMT.hs b/src/Horus/Expr/Type/SMT.hs index fdaca7b6..8d3ac04e 100644 --- a/src/Horus/Expr/Type/SMT.hs +++ b/src/Horus/Expr/Type/SMT.hs @@ -6,17 +6,16 @@ import SimpleSMT qualified as SMT import Horus.Expr.Type (STy (..), Ty) -{- | For the type of a function 'ty' return it's SExpr representation of - the form 'resType :| argTypes'. - -Example: toSMT (TFelt :-> TBool :-> TBool) = Bool :| [Felt, Bool] --} +-- | For the type of a function 'ty' return it's SExpr representation of +-- the form 'resType :| argTypes'. +-- +-- Example: toSMT (TFelt :-> TBool :-> TBool) = Bool :| [Felt, Bool] toSMT :: forall ty. SingI (ty :: Ty) => NonEmpty SMT.SExpr toSMT = go [] (sing @ty) - where - go :: [SMT.SExpr] -> Sing (ty' :: Ty) -> NonEmpty SMT.SExpr - go args SFelt = SMT.tInt :| reverse args - go args SBool = SMT.tBool :| reverse args - go args ((sArg :: STy ty'') ::-> sRes) = case withSingI sArg (toSMT @ty'') of - arg :| [] -> go (arg : args) sRes - _ -> error "Horus.Expr.Type.SMT.toSMT: higher order functions are not supported by SMT-Libv2" + where + go :: [SMT.SExpr] -> Sing (ty' :: Ty) -> NonEmpty SMT.SExpr + go args SFelt = SMT.tInt :| reverse args + go args SBool = SMT.tBool :| reverse args + go args ((sArg :: STy ty'') ::-> sRes) = case withSingI sArg (toSMT @ty'') of + arg :| [] -> go (arg : args) sRes + _ -> error "Horus.Expr.Type.SMT.toSMT: higher order functions are not supported by SMT-Libv2" diff --git a/src/Horus/Expr/Util.hs b/src/Horus/Expr/Util.hs index 7bcb1ecc..39874bd8 100644 --- a/src/Horus/Expr/Util.hs +++ b/src/Horus/Expr/Util.hs @@ -24,28 +24,28 @@ import Horus.Expr.Vars (prime) gatherNonStdFunctions :: Expr a -> Set (Some Function) gatherNonStdFunctions = execWriter . transform_ step - where - step :: forall ty. Expr ty -> Writer (Set (Some Function)) () - step (Fun name) | name `notElem` stdNames = emit (Function @ty name) - step _ = pure () + where + step :: forall ty. Expr ty -> Writer (Set (Some Function)) () + step (Fun name) | name `notElem` stdNames = emit (Function @ty name) + step _ = pure () - emit :: Function ty -> Writer (Set (Some Function)) () - emit f = tell (Set.singleton (Some f)) + emit :: Function ty -> Writer (Set (Some Function)) () + emit f = tell (Set.singleton (Some f)) gatherLogicalVariables :: Expr a -> Set Text gatherLogicalVariables (ExistsFelt name expr) = Set.singleton name `Set.union` gatherLogicalVariables expr gatherLogicalVariables expr = Set.filter isLogical . Set.map takeName . gatherNonStdFunctions $ expr - where - takeName (Some (Function name)) = name - isLogical name = "$" `Text.isPrefixOf` name + where + takeName (Some (Function name)) = name + isLogical name = "$" `Text.isPrefixOf` name suffixLogicalVariables :: Text -> Expr a -> Expr a suffixLogicalVariables suffix = Expr.transformId step - where - step :: Expr b -> Expr b - step (Expr.FeltConst name) | "$" `Text.isPrefixOf` name = Expr.FeltConst (name <> suffix) - step e = e + where + step :: Expr b -> Expr b + step (Expr.FeltConst name) | "$" `Text.isPrefixOf` name = Expr.FeltConst (name <> suffix) + step e = e fieldToInt :: Integer -> Expr a -> Expr a fieldToInt fPrime e = runReader (fieldToInt' fPrime e) UCNo diff --git a/src/Horus/Expr/Vars.hs b/src/Horus/Expr/Vars.hs index 064a6754..dc279355 100644 --- a/src/Horus/Expr/Vars.hs +++ b/src/Horus/Expr/Vars.hs @@ -61,7 +61,7 @@ parseRegKind t = fmap CallFp (Text.stripPrefix "fp@" t >>= readMaybe . unpack) <|> fmap ApGroup (Text.stripPrefix "ap!" t >>= readMaybe . unpack) -pattern Memory :: () => (a ~ TFelt) => Expr TFelt -> Expr a +pattern Memory :: () => a ~ TFelt => Expr TFelt -> Expr a pattern Memory addr <- (cast @(TFelt :-> TFelt) -> CastOk (Fun "memory")) :*: addr where Memory = memory @@ -85,17 +85,17 @@ parseStorageVar e = do guard (not (isReg name)) guard (not (isLVar name)) pure res - where - isStd n = n `elem` stdNames || n == "memory" - isReg n = - isJust (parseRegKind n) - || n == "ap" - || n == "fp" - || n == "range-check-bound" - || n == "prime" - isLVar n = "$" `Text.isPrefixOf` n - -pattern StorageVar :: () => (a ~ TFelt) => Text -> [Expr TFelt] -> Expr a + where + isStd n = n `elem` stdNames || n == "memory" + isReg n = + isJust (parseRegKind n) + || n == "ap" + || n == "fp" + || n == "range-check-bound" + || n == "prime" + isLVar n = "$" `Text.isPrefixOf` n + +pattern StorageVar :: () => a ~ TFelt => Text -> [Expr TFelt] -> Expr a pattern StorageVar name args <- (parseStorageVar -> Just (Refl, name, args)) rcBound :: Expr TFelt diff --git a/src/Horus/FunctionAnalysis.hs b/src/Horus/FunctionAnalysis.hs index 739225cd..763782f5 100644 --- a/src/Horus/FunctionAnalysis.hs +++ b/src/Horus/FunctionAnalysis.hs @@ -49,7 +49,9 @@ import Data.Map qualified as Map import Data.Maybe (fromJust, fromMaybe, isNothing, listToMaybe, mapMaybe) import Data.Text (Text) -import Horus.ContractDefinition (ContractDefinition (cd_invariants, cd_program, cd_specs, cd_storageVars)) +import Horus.ContractDefinition + ( ContractDefinition (cd_invariants, cd_program, cd_specs, cd_storageVars) + ) import Horus.Instruction ( LabeledInst , callDestination @@ -111,13 +113,13 @@ cgMbInsertArc (CG verts arcs) (fro, to) = graphOfCG :: CG -> (Graph, Vertex -> (Label, Label, [Label])) graphOfCG cg = graphFromEdges' . map named . Map.assocs $ cg_arcs cg - where - named (fro, tos) = (fro, fro, tos) + where + named (fro, tos) = (fro, fro, tos) cycles :: Graph -> [Vertex] cycles g = map fst . filter (uncurry reachableSet) $ assocs g - where - reachableSet v = elem v . concatMap (reachable g) + where + reachableSet v = elem v . concatMap (reachable g) cyclicVerts :: CG -> [Label] cyclicVerts cg = @@ -126,22 +128,22 @@ cyclicVerts cg = pcToFunOfProg :: Program -> Map.Map Label ScopedFunction pcToFunOfProg prog = Map.mapMaybe (go <=< ilInfoToFun) ilInfoOfLabel - where - -- The last accessible scope of the given label is the function said label belongs to. - idents = p_identifiers prog - ilInfoOfLabel = di_instructionLocations (p_debugInfo prog) + where + -- The last accessible scope of the given label is the function said label belongs to. + idents = p_identifiers prog + ilInfoOfLabel = di_instructionLocations (p_debugInfo prog) - ilInfoToFun :: ILInfo -> Maybe Label - ilInfoToFun ilInfo = - safeLast (il_accessibleScopes ilInfo) >>= getFunctionPc . (idents Map.!) + ilInfoToFun :: ILInfo -> Maybe Label + ilInfoToFun ilInfo = + safeLast (il_accessibleScopes ilInfo) >>= getFunctionPc . (idents Map.!) - go :: Label -> Maybe ScopedFunction - go label = ScopedFunction <$> fNameOfPc idents label <*> Just label + go :: Label -> Maybe ScopedFunction + go label = ScopedFunction <$> fNameOfPc idents label <*> Just label fNameOfPc :: Identifiers -> Label -> Maybe ScopedName fNameOfPc idents lblpc = listToMaybe fLblsAtPc - where - fLblsAtPc = [name | (name, ident) <- Map.toList idents, Just lblpc == getFunctionPc ident] + where + fLblsAtPc = [name | (name, ident) <- Map.toList idents, Just lblpc == getFunctionPc ident] functionsOf :: [LabeledInst] -> Program -> Map.Map ScopedFunction [LabeledInst] functionsOf rows prog = @@ -203,14 +205,14 @@ labelOfIdent _ = Nothing scopedFOfPc :: Identifiers -> Label -> Maybe ScopedFunction scopedFOfPc idents label = ScopedFunction <$> scopedName <*> Just label - where - scopedName = - listToMaybe $ - [ name - | (name, ident) <- Map.toList idents - , Just pc <- [getFunctionPc ident] - , pc == label - ] + where + scopedName = + listToMaybe $ + [ name + | (name, ident) <- Map.toList idents + , Just pc <- [getFunctionPc ident] + , pc == label + ] uncheckedScopedFOfPc :: Identifiers -> Label -> ScopedFunction uncheckedScopedFOfPc idents = fromJust . scopedFOfPc idents @@ -227,18 +229,18 @@ labelIdentifiersOfPc idents lblpc = -- (last parameter). isNotAnnotated :: ContractDefinition -> Identifier -> Bool isNotAnnotated cd = not . maybe False isAnnotated' . labelOfIdent - where - idents = (p_identifiers . cd_program) cd - isAnnotated' :: Label -> Bool - isAnnotated' = any (liftM2 (||) isSpec isInvariant) . labelIdentifiersOfPc idents - identToName :: Identifier -> Maybe ScopedName - identToName ident = listToMaybe [name | (name, i) <- Map.toList idents, i == ident] + where + idents = (p_identifiers . cd_program) cd + isAnnotated' :: Label -> Bool + isAnnotated' = any (liftM2 (||) isSpec isInvariant) . labelIdentifiersOfPc idents + identToName :: Identifier -> Maybe ScopedName + identToName ident = listToMaybe [name | (name, i) <- Map.toList idents, i == ident] - isSpec :: Identifier -> Bool - isSpec ident = maybe False (`Map.member` cd_specs cd) $ identToName ident + isSpec :: Identifier -> Bool + isSpec ident = maybe False (`Map.member` cd_specs cd) $ identToName ident - isInvariant :: Identifier -> Bool - isInvariant ident = maybe False (`Map.member` cd_invariants cd) $ identToName ident + isInvariant :: Identifier -> Bool + isInvariant ident = maybe False (`Map.member` cd_invariants cd) $ identToName ident wrapperScope :: Text wrapperScope = "__wrappers__" @@ -250,9 +252,9 @@ wrapperScope = "__wrappers__" -- be able to identify and exclude these. isWrapper :: ScopedFunction -> Bool isWrapper f = outerScope (sf_scopedName f) == wrapperScope - where - outerScope (ScopedName []) = "" - outerScope (ScopedName (scope : _)) = scope + where + outerScope (ScopedName []) = "" + outerScope (ScopedName (scope : _)) = scope fStorageRead :: ScopedName fStorageRead = ScopedName ["starkware", "starknet", "common", "syscalls", "storage_read"] @@ -262,16 +264,16 @@ fStorageWrite = ScopedName ["starkware", "starknet", "common", "syscalls", "stor mkGeneratedNames :: [ScopedName] -> [ScopedName] mkGeneratedNames = concatMap svNames - where - svNames sv = [sv <> "addr", sv <> "read", sv <> "write"] + where + svNames sv = [sv <> "addr", sv <> "read", sv <> "write"] storageVarsOfCD :: ContractDefinition -> [ScopedName] storageVarsOfCD = Map.keys . cd_storageVars isGeneratedName :: ScopedName -> ContractDefinition -> Bool isGeneratedName fname cd = fname `elem` generatedNames - where - generatedNames = mkGeneratedNames $ storageVarsOfCD cd + where + generatedNames = mkGeneratedNames $ storageVarsOfCD cd isSvarFunc :: ScopedName -> ContractDefinition -> Bool isSvarFunc fname cd = isGeneratedName fname cd || fname `elem` [fStorageRead, fStorageWrite] @@ -292,7 +294,8 @@ isAuxFunc (ScopedFunction fname _) cd = sizeOfCall :: Int sizeOfCall = 2 -inlinableFuns :: [LabeledInst] -> Program -> ContractDefinition -> Map.Map ScopedFunction [LabeledInst] +inlinableFuns + :: [LabeledInst] -> Program -> ContractDefinition -> Map.Map ScopedFunction [LabeledInst] inlinableFuns rows prog cd = Map.filterWithKey ( \f _ -> @@ -303,22 +306,23 @@ inlinableFuns rows prog cd = && not (isAuxFunc f cd) ) functions - where - idents = p_identifiers prog - functions = functionsOf rows prog - notIsAnnotated sf = maybe False (isNotAnnotated cd) . Map.lookup (sf_scopedName sf) $ idents - - -- Annotated *later* by horus-checker because they come from e.g. the - -- standard library. These are things with default specs. - notIsAnnotatedLater f = sf_scopedName f `notElem` map fst stdSpecsList - localCycles = Map.map (cyclicVerts . jumpgraph) - isAcyclic cyclicFuns f cyclicLbls = f `notElem` cyclicFuns && null cyclicLbls - -- The functions that contain neither global nor local cycles. - inlinable = - Map.keys . Map.filterWithKey (isAcyclic . cyclicVerts $ callgraph (Map.mapKeys sf_pc functions)) $ - Map.mapKeys sf_pc (localCycles functions) - -uninlinableFuns :: [LabeledInst] -> Program -> ContractDefinition -> Map.Map ScopedFunction [LabeledInst] + where + idents = p_identifiers prog + functions = functionsOf rows prog + notIsAnnotated sf = maybe False (isNotAnnotated cd) . Map.lookup (sf_scopedName sf) $ idents + + -- Annotated *later* by horus-checker because they come from e.g. the + -- standard library. These are things with default specs. + notIsAnnotatedLater f = sf_scopedName f `notElem` map fst stdSpecsList + localCycles = Map.map (cyclicVerts . jumpgraph) + isAcyclic cyclicFuns f cyclicLbls = f `notElem` cyclicFuns && null cyclicLbls + -- The functions that contain neither global nor local cycles. + inlinable = + Map.keys . Map.filterWithKey (isAcyclic . cyclicVerts $ callgraph (Map.mapKeys sf_pc functions)) $ + Map.mapKeys sf_pc (localCycles functions) + +uninlinableFuns + :: [LabeledInst] -> Program -> ContractDefinition -> Map.Map ScopedFunction [LabeledInst] uninlinableFuns rows prog cd = Map.difference (functionsOf rows prog) (inlinableFuns rows prog cd) diff --git a/src/Horus/Global.hs b/src/Horus/Global.hs index 3d383c7e..66ee8a9c 100644 --- a/src/Horus/Global.hs +++ b/src/Horus/Global.hs @@ -40,9 +40,22 @@ import Horus.Expr.Util (gatherLogicalVariables) import Horus.FunctionAnalysis (ScopedFunction (ScopedFunction, sf_pc), isWrapper) import Horus.Logger qualified as L (LogL, logDebug, logError, logInfo, logWarning) import Horus.Module (Module (..), ModuleL, gatherModules, getModuleNameParts) -import Horus.Preprocessor (HorusResult (..), PreprocessorL, SolverResult (..), goalListToTextList, optimizeQuery, solve) +import Horus.Preprocessor + ( HorusResult (..) + , PreprocessorL + , SolverResult (..) + , goalListToTextList + , optimizeQuery + , solve + ) import Horus.Preprocessor.Runner (PreprocessorEnv (..)) -import Horus.Preprocessor.Solvers (Solver, SolverSettings, filterMathsat, includesMathsat, isEmptySolver) +import Horus.Preprocessor.Solvers + ( Solver + , SolverSettings + , filterMathsat + , includesMathsat + , isEmptySolver + ) import Horus.Program (Identifiers, Program (p_prime)) import Horus.SW.FuncSpec (FuncSpec, FuncSpec' (fs'_pre)) import Horus.SW.Identifier (Function (..)) @@ -171,26 +184,24 @@ data SolvingInfo = SolvingInfo } deriving (Eq, Show) -{- | Construct a function name from a qualified function name and a summary of - the label(s) (usually just one). - - Basically, just concatenates the function name with the label(s), separated - by a dot. But crucially, it does this safely, so that if the label is empty, - it doesn't add a dot, and vice-versa if tghe function name is empty. - - This terminology comes from the `normalizedName` function in `Module.hs`. --} +-- | Construct a function name from a qualified function name and a summary of +-- the label(s) (usually just one). +-- +-- Basically, just concatenates the function name with the label(s), separated +-- by a dot. But crucially, it does this safely, so that if the label is empty, +-- it doesn't add a dot, and vice-versa if tghe function name is empty. +-- +-- This terminology comes from the `normalizedName` function in `Module.hs`. mkLabeledFuncName :: Text -> Text -> Text mkLabeledFuncName qualifiedFuncName "" = qualifiedFuncName mkLabeledFuncName "" labelsSummary = labelsSummary mkLabeledFuncName qualifiedFuncName labelsSummary = qualifiedFuncName <> "." <> labelsSummary -{- | Solve the constraints for a single module. - - Here, a module is a label-delimited section of a function (or possibly the - whole function). In general, we have multiple modules in a function when - that function contains multiple branches (an if-then-else, for example). --} +-- | Solve the constraints for a single module. +-- +-- Here, a module is a label-delimited section of a function (or possibly the +-- whole function). In general, we have multiple modules in a function when +-- that function contains multiple branches (an if-then-else, for example). solveModule :: Module -> GlobalL SolvingInfo solveModule m = do inlinables <- getInlinables @@ -207,16 +218,16 @@ solveModule m = do , si_inlinable = inlinable , si_preCheckingContext = m_preCheckedFuncAndCallStack m } - where - mkResult :: Text -> GlobalL HorusResult - mkResult moduleName = printingErrors $ do - constraints <- extractConstraints m - outputSmtQueries moduleName constraints - verbosePrint m - verbosePrint (debugFriendlyModel constraints) - solveSMT constraints - printingErrors :: GlobalL HorusResult -> GlobalL HorusResult - printingErrors a = a `catchError` (\e -> pure (Timeout (Just ("Error: " <> e)))) + where + mkResult :: Text -> GlobalL HorusResult + mkResult moduleName = printingErrors $ do + constraints <- extractConstraints m + outputSmtQueries moduleName constraints + verbosePrint m + verbosePrint (debugFriendlyModel constraints) + solveSMT constraints + printingErrors :: GlobalL HorusResult -> GlobalL HorusResult + printingErrors a = a `catchError` (\e -> pure (Timeout (Just ("Error: " <> e)))) outputSmtQueries :: Text -> ConstraintsState -> GlobalL () outputSmtQueries moduleName constraints = do @@ -225,30 +236,31 @@ outputSmtQueries moduleName constraints = do Config{..} <- getConfig whenJust cfg_outputQueries (writeSmtFile query) whenJust cfg_outputOptimizedQueries (writeSmtFileOptimized query) - where - memVars = map (\mv -> (mv_varName mv, mv_addrName mv)) (cs_memoryVariables constraints) + where + memVars = map (\mv -> (mv_varName mv, mv_addrName mv)) (cs_memoryVariables constraints) - writeSmtFile :: Text -> FilePath -> GlobalL () - writeSmtFile query dir = do - writeFile' (dir unpack moduleName <> ".smt2") query + writeSmtFile :: Text -> FilePath -> GlobalL () + writeSmtFile query dir = do + writeFile' (dir unpack moduleName <> ".smt2") query - getQueryList :: Text -> PreprocessorL [Text] - getQueryList query = do - queryList <- optimizeQuery query - goalListToTextList queryList + getQueryList :: Text -> PreprocessorL [Text] + getQueryList query = do + queryList <- optimizeQuery query + goalListToTextList queryList - writeSmtFileOptimized :: Text -> FilePath -> GlobalL () - writeSmtFileOptimized query dir = do - Config{..} <- getConfig - queries <- runPreprocessorL (PreprocessorEnv memVars cfg_solver cfg_solverSettings) (getQueryList query) - writeSmtQueries queries dir moduleName + writeSmtFileOptimized :: Text -> FilePath -> GlobalL () + writeSmtFileOptimized query dir = do + Config{..} <- getConfig + queries <- + runPreprocessorL (PreprocessorEnv memVars cfg_solver cfg_solverSettings) (getQueryList query) + writeSmtQueries queries dir moduleName writeSmtQueries :: [Text] -> FilePath -> Text -> GlobalL () writeSmtQueries queries dir moduleName = do for_ (zip [1 :: Int ..] queries) writeQueryFile - where - newFileName n = dir "optimized_goals_" <> unpack moduleName show n <> ".smt2" - writeQueryFile (n, q) = writeFile' (newFileName n) q + where + newFileName n = dir "optimized_goals_" <> unpack moduleName show n <> ".smt2" + writeQueryFile (n, q) = writeFile' (newFileName n) q removeMathSAT :: Module -> GlobalL a -> GlobalL a removeMathSAT m run = do @@ -259,22 +271,24 @@ removeMathSAT m run = do then do let solver' = filterMathsat solver if isEmptySolver solver' - then throw "Only the MathSAT solver was used to analyze a call with a logical variable in its specification." + then + throw + "Only the MathSAT solver was used to analyze a call with a logical variable in its specification." else do setConfig conf{cfg_solver = solver'} result <- run setConfig conf pure result else run - where - -- FIXME should check not just pre, but also post - instUsesLvars i = falseIfError $ do - callee <- getCallee i - spec <- getFuncSpec callee - let lvars = gatherLogicalVariables (fromMaybe Expr.True (fs'_pre spec)) - pure (not (null lvars)) + where + -- FIXME should check not just pre, but also post + instUsesLvars i = falseIfError $ do + callee <- getCallee i + spec <- getFuncSpec callee + let lvars = gatherLogicalVariables (fromMaybe Expr.True (fs'_pre spec)) + pure (not (null lvars)) - falseIfError a = a `catchError` const (pure False) + falseIfError a = a `catchError` const (pure False) solveSMT :: ConstraintsState -> GlobalL HorusResult solveSMT cs = do @@ -283,7 +297,8 @@ solveSMT cs = do let query = makeModel False cs fPrime let preQuery = makeModel True cs fPrime res <- runPreprocessorL (PreprocessorEnv memVars cfg_solver cfg_solverSettings) (solve fPrime query) - preRes <- runPreprocessorL (PreprocessorEnv memVars cfg_solver cfg_solverSettings) (solve fPrime preQuery) + preRes <- + runPreprocessorL (PreprocessorEnv memVars cfg_solver cfg_solverSettings) (solve fPrime preQuery) -- Convert the `SolverResult` to a `HorusResult`. -- @@ -294,8 +309,8 @@ solveSMT cs = do (Unknown mbReason, _) -> pure $ Timeout mbReason (Unsat, Unsat) -> pure ContradictoryPrecondition (Unsat, _) -> pure Verified - where - memVars = map (\mv -> (mv_varName mv, mv_addrName mv)) (cs_memoryVariables cs) + where + memVars = map (\mv -> (mv_varName mv, mv_addrName mv)) (cs_memoryVariables cs) -- | Add an oracle suffix to the module name when the module name *is* the function name. appendMissingDefaultOracleSuffix :: SolvingInfo -> SolvingInfo @@ -304,34 +319,33 @@ appendMissingDefaultOracleSuffix si@(SolvingInfo moduleName funcName result inli then SolvingInfo (moduleName <> ":::default") funcName result inlinable preCheckingContext else si -{- | Collapse a list of modules for the same function if they are all `Unsat`. - - Given a list of `SolvingInfo`s, each associated with a module, under the - assumption that they all have the same `si_funcName`, if they are all `Unsat`, - we collapse them into a singleton list of one `SolvingInfo`, where we - hot-patch the `moduleName`, replacing it with the `funcName`. - - Otherwise, we have at least one `Sat`. - - We break the remaining cases into two subcases: - * If it is just a single module, we return the list as-is. - * If there are multiple modules, we add a `:::default` oracle suffix to the - module name that would otherwise just be ``. --} +-- | Collapse a list of modules for the same function if they are all `Unsat`. +-- +-- Given a list of `SolvingInfo`s, each associated with a module, under the +-- assumption that they all have the same `si_funcName`, if they are all `Unsat`, +-- we collapse them into a singleton list of one `SolvingInfo`, where we +-- hot-patch the `moduleName`, replacing it with the `funcName`. +-- +-- Otherwise, we have at least one `Sat`. +-- +-- We break the remaining cases into two subcases: +-- * If it is just a single module, we return the list as-is. +-- * If there are multiple modules, we add a `:::default` oracle suffix to the +-- module name that would otherwise just be ``. collapseAllUnsats :: [SolvingInfo] -> [SolvingInfo] collapseAllUnsats [] = [] collapseAllUnsats infos@(SolvingInfo _ funcName result _ _ : _) - | all ((== Verified) . si_result) infos = [SolvingInfo funcName funcName result reportInlinable Nothing] + | all ((== Verified) . si_result) infos = + [SolvingInfo funcName funcName result reportInlinable Nothing] | length infos == 1 = infos | otherwise = map appendMissingDefaultOracleSuffix infos - where - reportInlinable = all si_inlinable infos - -{- | Return a solution of SMT queries corresponding with the contract. + where + reportInlinable = all si_inlinable infos - For the purposes of reporting results, - we also remember which SMT query corresponding to a function was inlined. --} +-- | Return a solution of SMT queries corresponding with the contract. +-- +-- For the purposes of reporting results, +-- we also remember which SMT query corresponding to a function was inlined. solveContract :: GlobalL [SolvingInfo] solveContract = do lInstructions <- getLabelledInstructions @@ -352,9 +366,9 @@ solveContract = do identifiers <- getIdentifiers let isUntrusted :: Module -> Bool isUntrusted m = labeledFuncName `notElem` trustedStdFuncs - where - (qualifiedFuncName, labelsSummary, _, _) = getModuleNameParts identifiers m - labeledFuncName = mkLabeledFuncName qualifiedFuncName labelsSummary + where + (qualifiedFuncName, labelsSummary, _, _) = getModuleNameParts identifiers m + labeledFuncName = mkLabeledFuncName qualifiedFuncName labelsSummary infos <- for (filter isUntrusted modules) solveModule pure $ ( concatMap collapseAllUnsats @@ -362,19 +376,19 @@ solveContract = do . filter (not . isVerifiedIgnorable) ) infos - where - isStandardSource :: Set ScopedFunction -> ScopedFunction -> Bool - isStandardSource inlinables f = f `notElem` inlinables && not (isWrapper f) + where + isStandardSource :: Set ScopedFunction -> ScopedFunction -> Bool + isStandardSource inlinables f = f `notElem` inlinables && not (isWrapper f) - sameFuncName :: SolvingInfo -> SolvingInfo -> Bool - sameFuncName (SolvingInfo _ nameA _ _ _) (SolvingInfo _ nameB _ _ _) = nameA == nameB + sameFuncName :: SolvingInfo -> SolvingInfo -> Bool + sameFuncName (SolvingInfo _ nameA _ _ _) (SolvingInfo _ nameB _ _ _) = nameA == nameB - ignorableFuncPrefixes :: [Text] - ignorableFuncPrefixes = ["empty: ", "starkware.cairo.lang", "starkware.cairo.common", "starkware.starknet.common"] + ignorableFuncPrefixes :: [Text] + ignorableFuncPrefixes = ["empty: ", "starkware.cairo.lang", "starkware.cairo.common", "starkware.starknet.common"] - isVerifiedIgnorable :: SolvingInfo -> Bool - isVerifiedIgnorable (SolvingInfo name _ res _ _) = - res == Verified && any (`Text.isPrefixOf` name) ignorableFuncPrefixes + isVerifiedIgnorable :: SolvingInfo -> Bool + isVerifiedIgnorable (SolvingInfo name _ res _ _) = + res == Verified && any (`Text.isPrefixOf` name) ignorableFuncPrefixes logM :: (a -> L.LogL ()) -> a -> GlobalL () logM lg v = diff --git a/src/Horus/Global/Runner.hs b/src/Horus/Global/Runner.hs index cff693b1..751caa10 100644 --- a/src/Horus/Global/Runner.hs +++ b/src/Horus/Global/Runner.hs @@ -26,42 +26,42 @@ type Impl = ReaderT Env (ExceptT Text IO) -- TODO replace ExceptT with exception interpret :: GlobalL a -> Impl a interpret = iterM exec . runGlobalL - where - exec :: GlobalF (Impl a) -> Impl a - exec (RunCFGBuildL builder cont) = do - ci <- asks e_contractInfo - liftEither (CFGBuild.runImpl ci (CFGBuild.interpret builder)) >>= cont - exec (RunCairoSemanticsL initStack builder cont) = do - ci <- asks e_contractInfo - liftEither (CairoSemantics.run initStack ci builder) >>= cont - exec (RunModuleL builder cont) = liftEither (Module.run builder) >>= cont - exec (RunPreprocessorL penv preprocessor cont) = do - mPreprocessed <- lift (Preprocessor.run penv preprocessor) - liftEither mPreprocessed >>= cont - exec (GetCallee inst cont) = do - ci <- asks e_contractInfo - ci_getCallee ci inst >>= cont - exec (GetConfig cont) = asks e_config >>= liftIO . readIORef >>= cont - exec (GetFuncSpec name cont) = do - ci <- asks e_contractInfo - cont (ci_getFuncSpec ci name) - exec (GetIdentifiers cont) = asks (ci_identifiers . e_contractInfo) >>= cont - exec (GetInlinable cont) = asks (ci_inlinables . e_contractInfo) >>= cont - exec (GetLabelledInstrs cont) = asks (ci_labelledInstrs . e_contractInfo) >>= cont - exec (GetProgram cont) = asks (ci_program . e_contractInfo) >>= cont - exec (GetSources cont) = asks (ci_sources . e_contractInfo) >>= cont - exec (SetConfig conf cont) = do - configRef <- asks e_config - liftIO (writeIORef configRef conf) - cont - exec (PutStrLn' what cont) = pPrintString (unpack what) >> cont - exec (WriteFile' file text cont) = liftIO (createAndWriteFile file text) >> cont - exec (Log logger cont) = do - (_, vs) <- liftEither $ Logger.runImpl (Logger.interpret logger) - liftIO $ mapM_ print vs - cont - exec (Throw t) = throwError t - exec (Catch m handler cont) = catchError (interpret m) (interpret . handler) >>= cont + where + exec :: GlobalF (Impl a) -> Impl a + exec (RunCFGBuildL builder cont) = do + ci <- asks e_contractInfo + liftEither (CFGBuild.runImpl ci (CFGBuild.interpret builder)) >>= cont + exec (RunCairoSemanticsL initStack builder cont) = do + ci <- asks e_contractInfo + liftEither (CairoSemantics.run initStack ci builder) >>= cont + exec (RunModuleL builder cont) = liftEither (Module.run builder) >>= cont + exec (RunPreprocessorL penv preprocessor cont) = do + mPreprocessed <- lift (Preprocessor.run penv preprocessor) + liftEither mPreprocessed >>= cont + exec (GetCallee inst cont) = do + ci <- asks e_contractInfo + ci_getCallee ci inst >>= cont + exec (GetConfig cont) = asks e_config >>= liftIO . readIORef >>= cont + exec (GetFuncSpec name cont) = do + ci <- asks e_contractInfo + cont (ci_getFuncSpec ci name) + exec (GetIdentifiers cont) = asks (ci_identifiers . e_contractInfo) >>= cont + exec (GetInlinable cont) = asks (ci_inlinables . e_contractInfo) >>= cont + exec (GetLabelledInstrs cont) = asks (ci_labelledInstrs . e_contractInfo) >>= cont + exec (GetProgram cont) = asks (ci_program . e_contractInfo) >>= cont + exec (GetSources cont) = asks (ci_sources . e_contractInfo) >>= cont + exec (SetConfig conf cont) = do + configRef <- asks e_config + liftIO (writeIORef configRef conf) + cont + exec (PutStrLn' what cont) = pPrintString (unpack what) >> cont + exec (WriteFile' file text cont) = liftIO (createAndWriteFile file text) >> cont + exec (Log logger cont) = do + (_, vs) <- liftEither $ Logger.runImpl (Logger.interpret logger) + liftIO $ mapM_ print vs + cont + exec (Throw t) = throwError t + exec (Catch m handler cont) = catchError (interpret m) (interpret . handler) >>= cont run :: Env -> GlobalL a -> IO (Either Text a) run env = runExceptT . flip runReaderT env . interpret diff --git a/src/Horus/Instruction.hs b/src/Horus/Instruction.hs index 973f8fb8..7a419217 100644 --- a/src/Horus/Instruction.hs +++ b/src/Horus/Instruction.hs @@ -85,8 +85,8 @@ type LabeledInst = (Label, Instruction) labelInstructions :: [Instruction] -> [LabeledInst] labelInstructions insts = zip (coerce pcs) insts - where - pcs = scanl (+) 0 (map instructionSize insts) + where + pcs = scanl (+) 0 (map instructionSize insts) instructionSize :: Instruction -> Int instructionSize Instruction{i_op1Source = Imm} = 2 @@ -116,9 +116,9 @@ jumpDestination (pc, i@Instruction{i_opCode = Nop}) = case i_pcUpdate i of JumpAbs -> pure absDst Jnz -> pure relDst _ -> Nothing - where - relDst = moveLabel pc (fromInteger (i_imm i)) - absDst = Label (fromInteger (i_imm i)) + where + relDst = moveLabel pc (fromInteger (i_imm i)) + absDst = Label (fromInteger (i_imm i)) jumpDestination _ = Nothing n15, n16 :: Int @@ -130,7 +130,8 @@ readAllInstructions fPrime (i : is) = do (instr, is') <- readInstruction fPrime (i :| is) (instr :) <$> readAllInstructions fPrime is' -readInstruction :: forall m. MonadError Text m => Integer -> NonEmpty Integer -> m (Instruction, [Integer]) +readInstruction + :: forall m. MonadError Text m => Integer -> NonEmpty Integer -> m (Instruction, [Integer]) readInstruction fPrime (i :| is) = do let flags = i `shiftR` (3 * 16) let dstEnc = i .&. (2 ^ n16 - 1) @@ -190,35 +191,35 @@ readInstruction fPrime (i :| is) = do ) <*> return (toSignedFelt fPrime imm) pure (instruction, is') - where - op1Map :: Bool -> Bool -> Bool -> m Op1Source - op1Map True False False = return Imm - op1Map False True False = return $ RegisterSource AllocationPointer - op1Map False False True = return $ RegisterSource FramePointer - op1Map False False False = return Op0 - op1Map _ _ _ = throwError "wrong op1 code" - resMap :: Bool -> Bool -> m ResLogic - resMap True False = return Add - resMap False True = return Mult - resMap False False = return Op1 - resMap True True = return Unconstrained - pcMap :: Bool -> Bool -> Bool -> m PcUpdate - pcMap True False False = return JumpAbs - pcMap False True False = return JumpRel - pcMap False False True = return Jnz - pcMap False False False = return Regular - pcMap _ _ _ = throwError "wrong pc flag" - apMap :: Bool -> Bool -> m ApUpdate - apMap True False = return AddRes - apMap False True = return Add1 - apMap False False = return NoUpdate - apMap _ _ = throwError "wrong ap flag" - opCodeMap :: Bool -> Bool -> Bool -> m OpCode - opCodeMap True False False = return Call - opCodeMap False True False = return Ret - opCodeMap False False True = return AssertEqual - opCodeMap False False False = return Nop - opCodeMap _ _ _ = throwError "wrong opcode" + where + op1Map :: Bool -> Bool -> Bool -> m Op1Source + op1Map True False False = return Imm + op1Map False True False = return $ RegisterSource AllocationPointer + op1Map False False True = return $ RegisterSource FramePointer + op1Map False False False = return Op0 + op1Map _ _ _ = throwError "wrong op1 code" + resMap :: Bool -> Bool -> m ResLogic + resMap True False = return Add + resMap False True = return Mult + resMap False False = return Op1 + resMap True True = return Unconstrained + pcMap :: Bool -> Bool -> Bool -> m PcUpdate + pcMap True False False = return JumpAbs + pcMap False True False = return JumpRel + pcMap False False True = return Jnz + pcMap False False False = return Regular + pcMap _ _ _ = throwError "wrong pc flag" + apMap :: Bool -> Bool -> m ApUpdate + apMap True False = return AddRes + apMap False True = return Add1 + apMap False False = return NoUpdate + apMap _ _ = throwError "wrong ap flag" + opCodeMap :: Bool -> Bool -> Bool -> m OpCode + opCodeMap True False False = return Call + opCodeMap False True False = return Ret + opCodeMap False False True = return AssertEqual + opCodeMap False False False = return Nop + opCodeMap _ _ _ = throwError "wrong opcode" toSemiAsmUnsafe :: Instruction -> Text toSemiAsmUnsafe i = case toSemiAsm i of @@ -241,29 +242,29 @@ toSemiAsm Instruction{..} = do Regular -> case i_apUpdate of AddRes -> withRes ("ap += " <>) other -> throwError ("Unexpected AP update for a NOP, non-jump opcode: " <> tShow other) - where - withRes f = fmap f getRes - dst = mem (printReg i_dstRegister `add` i_dstOffset) - mbApPP = case i_apUpdate of - Add1 -> "; ap++" - _ -> "" - getRes = case i_resLogic of - Op1 -> pure op1 - Add -> pure (op0 <> " + " <> op1) - Mult -> pure (op0 <> " * " <> op1) - Unconstrained -> throwError "Don't use the result" - mem addr = "[" <> addr <> "]" - printReg AllocationPointer = "ap" - printReg FramePointer = "fp" - op1 = case i_op1Source of - Op0 -> mem (op0 `add` i_op1Offset) - RegisterSource reg -> mem (printReg reg `add` i_op1Offset) - Imm -> tShow i_imm - op0 = mem (printReg i_op0Register `add` i_op0Offset) - op `add` v - | v < 0 = op <> " - " <> tShow (-v) - | v == 0 = op - | otherwise = op <> " + " <> tShow v + where + withRes f = fmap f getRes + dst = mem (printReg i_dstRegister `add` i_dstOffset) + mbApPP = case i_apUpdate of + Add1 -> "; ap++" + _ -> "" + getRes = case i_resLogic of + Op1 -> pure op1 + Add -> pure (op0 <> " + " <> op1) + Mult -> pure (op0 <> " * " <> op1) + Unconstrained -> throwError "Don't use the result" + mem addr = "[" <> addr <> "]" + printReg AllocationPointer = "ap" + printReg FramePointer = "fp" + op1 = case i_op1Source of + Op0 -> mem (op0 `add` i_op1Offset) + RegisterSource reg -> mem (printReg reg `add` i_op1Offset) + Imm -> tShow i_imm + op0 = mem (printReg i_op0Register `add` i_op0Offset) + op `add` v + | v < 0 = op <> " - " <> tShow (-v) + | v == 0 = op + | otherwise = op <> " + " <> tShow v isRet :: Instruction -> Bool isRet Instruction{i_opCode = Ret} = True diff --git a/src/Horus/Logger/Runner.hs b/src/Horus/Logger/Runner.hs index 58d4f1da..104373c3 100644 --- a/src/Horus/Logger/Runner.hs +++ b/src/Horus/Logger/Runner.hs @@ -22,8 +22,8 @@ data Message instance Show Message where show (Message s t) = "[" <> show s <> "] - " <> t' - where - t' = unpack $ filter (/= '\"') t + where + t' = unpack $ filter (/= '\"') t newtype ImplL a = ImplL (State (Seq Message) a) @@ -37,11 +37,11 @@ newtype ImplL a runImpl :: ImplL a -> Either Text (a, [Message]) runImpl (ImplL s) = return $ f (runState s mempty) - where - f (x, y) = (x, toList y) + where + f (x, y) = (x, toList y) interpret :: LogL a -> ImplL a interpret = iterM exec . runLogL - where - exec (LogF sev txt next) = - modify' (|> Message sev txt) >> next + where + exec (LogF sev txt next) = + modify' (|> Message sev txt) >> next diff --git a/src/Horus/Module.hs b/src/Horus/Module.hs index 900463c9..86e7dc47 100644 --- a/src/Horus/Module.hs +++ b/src/Horus/Module.hs @@ -27,13 +27,28 @@ import Text.Printf (printf) import Horus.CFGBuild (ArcCondition (..), Label (unLabel), Vertex (..)) import Horus.CFGBuild.Runner (CFG (..), verticesLabelledBy) -import Horus.CallStack (CallStack, callerPcOfCallEntry, digestOfCallStack, initialWithFunc, pop, push, stackTrace, top) +import Horus.CallStack + ( CallStack + , callerPcOfCallEntry + , digestOfCallStack + , initialWithFunc + , pop + , push + , stackTrace + , top + ) import Horus.ContractInfo (pcToFun) import Horus.Expr (Expr, Ty (..), (.&&), (.==)) import Horus.Expr qualified as Expr (and) import Horus.Expr.SMT (pprExpr) import Horus.Expr.Vars (ap, fp) -import Horus.FunctionAnalysis (FInfo, FuncOp (ArcCall, ArcRet), ScopedFunction (sf_scopedName), isRetArc, sizeOfCall) +import Horus.FunctionAnalysis + ( FInfo + , FuncOp (ArcCall, ArcRet) + , ScopedFunction (sf_scopedName) + , isRetArc + , sizeOfCall + ) import Horus.Instruction (LabeledInst, uncheckedCallDestination) import Horus.Label (moveLabel) import Horus.Program (Identifiers) @@ -74,15 +89,14 @@ dropMain :: ScopedName -> ScopedName dropMain (ScopedName ("__main__" : xs)) = ScopedName xs dropMain name = name -{- | Summarize a list of labels for a function. - - If you have `__main__.foo.bar` on the same PC* as `__main__.foo.baz`, you - get a string that tells you you're in `foo` scope for `bar | baz`. - - If you get more than one scope (possibly, this cannot occur in Cairo), for - example, `__main__.foo.bar` and `__main__.FOO.baz` you get a summarization - of the scopes `fooFOO` and `bar|baz`. --} +-- | Summarize a list of labels for a function. +-- +-- If you have `__main__.foo.bar` on the same PC* as `__main__.foo.baz`, you +-- get a string that tells you you're in `foo` scope for `bar | baz`. +-- +-- If you get more than one scope (possibly, this cannot occur in Cairo), for +-- example, `__main__.foo.bar` and `__main__.FOO.baz` you get a summarization +-- of the scopes `fooFOO` and `bar|baz`. summarizeLabels :: [Text] -> Text summarizeLabels labels = let prettyLabels = Text.intercalate "|" labels @@ -90,60 +104,58 @@ summarizeLabels labels = then prettyLabels else Text.concat ["{", prettyLabels, "}"] -commonPrefix :: (Eq e) => [e] -> [e] -> [e] +commonPrefix :: Eq e => [e] -> [e] -> [e] commonPrefix _ [] = [] commonPrefix [] _ = [] commonPrefix (x : xs) (y : ys) | x == y = x : commonPrefix xs ys | otherwise = [] -{- | For labels whose names are prefixed by the scope specifier equivalent to the - scope of the function they are declared in, do not replicate this scope - information in their name. - - We do this by computing the longest common prefix, dropping it from all the - names, and then adding the prefix itself as a new name. --} +-- | For labels whose names are prefixed by the scope specifier equivalent to the +-- scope of the function they are declared in, do not replicate this scope +-- information in their name. +-- +-- We do this by computing the longest common prefix, dropping it from all the +-- names, and then adding the prefix itself as a new name. sansCommonAncestor :: [[Text]] -> [[Text]] sansCommonAncestor xss = prefix : remainders - where - prefix = foldl1 commonPrefix xss - remainders = map (drop (length prefix)) xss - -{- | Returns the function name parts, in particular the fully qualified - function name and the label summary. - - We take as arguments a list of scoped names, and a boolean flag indicating - whether the list of scoped names belongs to a function or a *floating label* - (as distinct from a function label). - - A floating label is, for example, `add:` in the snippet below, which is - taken from the `func_multiple_ret.cairo` test file at revision 89ddeb2: - - ```cairo - func succpred(m) -> (res: felt) { - ... - add: - [ap] = [fp - 3] - 1, ap++; - ... - } - ``` - In particular, `add` is not a function name. A function name itself is, of - course, a label. But it is not a *floating label*, as defined above. - - Note: we say "fully qualified", but we remove the `__main__` prefix from - top-level function names, if it exists. --} + where + prefix = foldl1 commonPrefix xss + remainders = map (drop (length prefix)) xss + +-- | Returns the function name parts, in particular the fully qualified +-- function name and the label summary. +-- +-- We take as arguments a list of scoped names, and a boolean flag indicating +-- whether the list of scoped names belongs to a function or a *floating label* +-- (as distinct from a function label). +-- +-- A floating label is, for example, `add:` in the snippet below, which is +-- taken from the `func_multiple_ret.cairo` test file at revision 89ddeb2: +-- +-- ```cairo +-- func succpred(m) -> (res: felt) { +-- ... +-- add: +-- [ap] = [fp - 3] - 1, ap++; +-- ... +-- } +-- ``` +-- In particular, `add` is not a function name. A function name itself is, of +-- course, a label. But it is not a *floating label*, as defined above. +-- +-- Note: we say "fully qualified", but we remove the `__main__` prefix from +-- top-level function names, if it exists. normalizedName :: [ScopedName] -> Bool -> (Text, Text) normalizedName scopedNames isFloatingLabel = (Text.concat scopes, labelsSummary) - where - -- Extract list of scopes from each ScopedName, dropping `__main__`. - names = filter (not . null) $ sansCommonAncestor $ map (sn_path . dropMain) scopedNames - -- If we have a floating label, we need to drop the last scope, because it is - -- the label name itself. - scopes = map (Text.intercalate ".") (if isFloatingLabel then map init names else names) - -- This will almost always just be the name of the single label. - labelsSummary = if isFloatingLabel then summarizeLabels (map last names) else "" + where + -- Extract list of scopes from each ScopedName, dropping `__main__`. + names = filter (not . null) $ sansCommonAncestor $ map (sn_path . dropMain) scopedNames + -- If we have a floating label, we need to drop the last scope, because it is + -- the label name itself. + scopes = map (Text.intercalate ".") (if isFloatingLabel then map init names else names) + -- This will almost always just be the name of the single label. + labelsSummary = if isFloatingLabel then summarizeLabels (map last names) else "" descrOfBool :: Bool -> Text descrOfBool True = "1" @@ -155,37 +167,36 @@ descrOfOracle oracle = then "" else (<>) ":::" . Text.concat . map descrOfBool . Map.elems $ oracle -{- | Return a quadruple of the function name, the label summary, the oracle and - precondition check suffix (indicates, for precondition-checking modules, - which function's precondition is being checked). - - The oracle is a string of `1` and `2` characters, representing a path - through the control flow graph of the function. For example, if we have a - function - - ```cairo - func f(x : felt) -> felt { - if (x == 0) { - return 0; - } else { - return 1; - } - } - ``` - - then the branch where we return 0 is usually represented by `1` (since the - predicate `x == 0` is True), and the branch where we return 1 is represented - by `2`. - - Nested control flow results in multiple `1` or `2` characters. - - See `normalizedName` for the definition of a floating label. Here, the label - is floating if it is not a function declaration (i.e. equal to `calledF`), - since these are the only two types of labels we may encounter. - - Note: while we do have the name of the called function in the `Module` type, - it does not contain the rest of the labels. --} +-- | Return a quadruple of the function name, the label summary, the oracle and +-- precondition check suffix (indicates, for precondition-checking modules, +-- which function's precondition is being checked). +-- +-- The oracle is a string of `1` and `2` characters, representing a path +-- through the control flow graph of the function. For example, if we have a +-- function +-- +-- ```cairo +-- func f(x : felt) -> felt { +-- if (x == 0) { +-- return 0; +-- } else { +-- return 1; +-- } +-- } +-- ``` +-- +-- then the branch where we return 0 is usually represented by `1` (since the +-- predicate `x == 0` is True), and the branch where we return 1 is represented +-- by `2`. +-- +-- Nested control flow results in multiple `1` or `2` characters. +-- +-- See `normalizedName` for the definition of a floating label. Here, the label +-- is floating if it is not a function declaration (i.e. equal to `calledF`), +-- since these are the only two types of labels we may encounter. +-- +-- Note: while we do have the name of the called function in the `Module` type, +-- it does not contain the rest of the labels. getModuleNameParts :: Identifiers -> Module -> (Text, Text, Text, Text) getModuleNameParts idents (Module spec prog oracle calledF _ mbPreCheckedFuncAndCallStack) = case beginOfModule prog of @@ -195,14 +206,14 @@ getModuleNameParts idents (Module spec prog oracle calledF _ mbPreCheckedFuncAnd isFloatingLabel = label /= calledF (prefix, labelsSummary) = normalizedName scopedNames isFloatingLabel in (prefix, labelsSummary, descrOfOracle oracle, preCheckingSuffix) - where - post = fs_post spec - preCheckingSuffix = case mbPreCheckedFuncAndCallStack of - Nothing -> "" - Just (callstack, f) -> - let fName = toText . dropMain . sf_scopedName $ f - stackDigest = digestOfCallStack (Map.map sf_scopedName (pcToFun idents)) callstack - in " Pre<" <> fName <> "|" <> stackDigest <> ">" + where + post = fs_post spec + preCheckingSuffix = case mbPreCheckedFuncAndCallStack of + Nothing -> "" + Just (callstack, f) -> + let fName = toText . dropMain . sf_scopedName $ f + stackDigest = digestOfCallStack (Map.map sf_scopedName (pcToFun idents)) callstack + in " Pre<" <> fName <> "|" <> stackDigest <> ">" data Error = ELoopNoInvariant Label @@ -215,7 +226,8 @@ instance Show Error where data ModuleF a = EmitModule Module a - | forall b. Visiting (NonEmpty Label, Map (NonEmpty Label, Label) Bool, Vertex) (Bool -> ModuleL b) (b -> a) + | forall b. + Visiting (NonEmpty Label, Map (NonEmpty Label, Label) Bool, Vertex) (Bool -> ModuleL b) (b -> a) | Throw Error | forall b. Catch (ModuleL b) (Error -> ModuleL b) (b -> a) @@ -235,11 +247,11 @@ liftF' = ModuleL . liftF emitModule :: Module -> ModuleL () emitModule m = liftF' (EmitModule m ()) -{- | Perform the action on the path where the label 'l' has been marked - as visited. The `action` parameter takes a boolean argument determining - whether the vertex has already been visited. --} -visiting :: (NonEmpty Label, Map (NonEmpty Label, Label) Bool, Vertex) -> (Bool -> ModuleL b) -> ModuleL b +-- | Perform the action on the path where the label 'l' has been marked +-- as visited. The `action` parameter takes a boolean argument determining +-- whether the vertex has already been visited. +visiting + :: (NonEmpty Label, Map (NonEmpty Label, Label) Bool, Vertex) -> (Bool -> ModuleL b) -> ModuleL b visiting vertexDesc action = liftF' (Visiting vertexDesc action id) throw :: Error -> ModuleL a @@ -258,24 +270,24 @@ extractPlainBuilder (FuncSpec pre _ storage) gatherModules :: CFG -> [(Function, ScopedName, FuncSpec)] -> ModuleL () gatherModules cfg = traverse_ $ \(f, _, spec) -> gatherFromSource cfg f spec -visitArcs :: - CFG -> - FuncSpec -> - Function -> - CallStack -> - Map (NonEmpty Label, Label) Bool -> - [LabeledInst] -> - SpecBuilder -> - Vertex -> - ModuleL () +visitArcs + :: CFG + -> FuncSpec + -> Function + -> CallStack + -> Map (NonEmpty Label, Label) Bool + -> [LabeledInst] + -> SpecBuilder + -> Vertex + -> ModuleL () visitArcs cfg fSpec function callstack' newOracle acc' pre v' = do unless (null outArcs) $ for_ outArcs' $ \(lTo, insts, test, f') -> visit cfg fSpec function newOracle callstack' (acc' <> insts) pre lTo test f' - where - outArcs = cfg_arcs cfg ^. ix v' - isCalledBy = (moveLabel (callerPcOfCallEntry $ top callstack') sizeOfCall ==) . v_label - outArcs' = filter (\(dst, _, _, f') -> not (isRetArc f') || isCalledBy dst) outArcs + where + outArcs = cfg_arcs cfg ^. ix v' + isCalledBy = (moveLabel (callerPcOfCallEntry $ top callstack') sizeOfCall ==) . v_label + outArcs' = filter (\(dst, _, _, f') -> not (isRetArc f') || isCalledBy dst) outArcs {- Revisiting nodes (thus looping) within the CFG is verboten in all cases but one, specifically when we are jumping back to a label that is annotated @@ -301,96 +313,95 @@ visitArcs cfg fSpec function callstack' newOracle acc' pre v' = do form of ArcCondition and CallStack needs a bit of extra information about when call/ret are called, in the form of FInfo. -} -visit :: - CFG -> - FuncSpec -> - Function -> - Map (NonEmpty Label, Label) Bool -> - CallStack -> - [LabeledInst] -> - SpecBuilder -> - Vertex -> - ArcCondition -> - FInfo -> - ModuleL () +visit + :: CFG + -> FuncSpec + -> Function + -> Map (NonEmpty Label, Label) Bool + -> CallStack + -> [LabeledInst] + -> SpecBuilder + -> Vertex + -> ArcCondition + -> FInfo + -> ModuleL () visit cfg fSpec function oracle callstack acc builder v@(Vertex _ label preCheckedF) arcCond f = visiting (stackTrace callstack', oracle, v) $ \alreadyVisited -> if alreadyVisited then visitLoop builder else visitLinear builder - where - visitLoop :: SpecBuilder -> ModuleL () - visitLoop SBRich = extractPlainBuilder fSpec >>= visitLoop - visitLoop (SBPlain pre) - | null assertions = throwError (ELoopNoInvariant label) - | otherwise = emit pre (Expr.and assertions) - - visitLinear :: SpecBuilder -> ModuleL () - visitLinear SBRich - | onFinalNode = emit (fs_pre fSpec) (Expr.and $ map snd (cfg_assertions cfg ^. ix v)) - | null assertions = visitArcs cfg fSpec function callstack' oracle' acc builder v - | otherwise = extractPlainBuilder fSpec >>= visitLinear - visitLinear (SBPlain pre) - | null assertions = visitArcs cfg fSpec function callstack' oracle' acc builder v - | otherwise = do - emit pre (Expr.and assertions) - visitArcs cfg fSpec function callstack' Map.empty [] (SBPlain (Expr.and assertions)) v - - callstack' = case f of - Nothing -> callstack - (Just (ArcCall callerPc calleePc)) -> push (callerPc, calleePc) callstack - (Just ArcRet) -> snd (pop callstack) - - oracle' = updateOracle arcCond callstack' oracle - assertions = map snd (cfg_assertions cfg ^. ix v) - onFinalNode = null (cfg_arcs cfg ^. ix v) - - labelledCall@(fCallerPc, _) = last acc - preCheckingStackFrame = (fCallerPc, uncheckedCallDestination labelledCall) - preCheckingContext = (push preCheckingStackFrame callstack',) <$> preCheckedF - - emit :: Expr TBool -> Expr TBool -> ModuleL () - emit pre post = emitModule (Module spec acc oracle' pc (callstack', label) preCheckingContext) - where - pc = fu_pc function - spec = FuncSpec pre post (fs_storage fSpec) - -{- | This function represents a depth first search through the CFG that uses as - sentinels (for where to begin and where to end) assertions in nodes, such - that nodes that are not annotated are traversed without stopping the search, - gathering labels from respective edges that represent instructions and - concatenating them into final Modules, that are subsequently transformed into - actual *.smt2 queries. - - Thus, a module can comprise of 0 to several segments, where the precondition - of the module is the annotation of the node 'begin' that begins the first - segment, the postcondition of the module is the annotation of the node 'end' - that ends the last segment and instructions of the module are a concatenation - of edge labels for the given path through the graph from 'begin' to 'end'. - - Note that NO node with an annotation can be encountered in the middle of one - such path, because annotated nodes are sentinels and the search would - terminate. - - We distinguish between plain and rich modules. A plain module is a - self-contained 'sub-program' with its own semantics that is referentially - pure in the sense that it has no side-effects on the environment, i.e. does - not access storage variables. - - A rich module is very much like a plain module except it allows side effects, - i.e. accesses to storage variables. --} + where + visitLoop :: SpecBuilder -> ModuleL () + visitLoop SBRich = extractPlainBuilder fSpec >>= visitLoop + visitLoop (SBPlain pre) + | null assertions = throwError (ELoopNoInvariant label) + | otherwise = emit pre (Expr.and assertions) + + visitLinear :: SpecBuilder -> ModuleL () + visitLinear SBRich + | onFinalNode = emit (fs_pre fSpec) (Expr.and $ map snd (cfg_assertions cfg ^. ix v)) + | null assertions = visitArcs cfg fSpec function callstack' oracle' acc builder v + | otherwise = extractPlainBuilder fSpec >>= visitLinear + visitLinear (SBPlain pre) + | null assertions = visitArcs cfg fSpec function callstack' oracle' acc builder v + | otherwise = do + emit pre (Expr.and assertions) + visitArcs cfg fSpec function callstack' Map.empty [] (SBPlain (Expr.and assertions)) v + + callstack' = case f of + Nothing -> callstack + (Just (ArcCall callerPc calleePc)) -> push (callerPc, calleePc) callstack + (Just ArcRet) -> snd (pop callstack) + + oracle' = updateOracle arcCond callstack' oracle + assertions = map snd (cfg_assertions cfg ^. ix v) + onFinalNode = null (cfg_arcs cfg ^. ix v) + + labelledCall@(fCallerPc, _) = last acc + preCheckingStackFrame = (fCallerPc, uncheckedCallDestination labelledCall) + preCheckingContext = (push preCheckingStackFrame callstack',) <$> preCheckedF + + emit :: Expr TBool -> Expr TBool -> ModuleL () + emit pre post = emitModule (Module spec acc oracle' pc (callstack', label) preCheckingContext) + where + pc = fu_pc function + spec = FuncSpec pre post (fs_storage fSpec) + +-- | This function represents a depth first search through the CFG that uses as +-- sentinels (for where to begin and where to end) assertions in nodes, such +-- that nodes that are not annotated are traversed without stopping the search, +-- gathering labels from respective edges that represent instructions and +-- concatenating them into final Modules, that are subsequently transformed into +-- actual *.smt2 queries. +-- +-- Thus, a module can comprise of 0 to several segments, where the precondition +-- of the module is the annotation of the node 'begin' that begins the first +-- segment, the postcondition of the module is the annotation of the node 'end' +-- that ends the last segment and instructions of the module are a concatenation +-- of edge labels for the given path through the graph from 'begin' to 'end'. +-- +-- Note that NO node with an annotation can be encountered in the middle of one +-- such path, because annotated nodes are sentinels and the search would +-- terminate. +-- +-- We distinguish between plain and rich modules. A plain module is a +-- self-contained 'sub-program' with its own semantics that is referentially +-- pure in the sense that it has no side-effects on the environment, i.e. does +-- not access storage variables. +-- +-- A rich module is very much like a plain module except it allows side effects, +-- i.e. accesses to storage variables. gatherFromSource :: CFG -> Function -> FuncSpec -> ModuleL () gatherFromSource cfg function fSpec = do let verticesAtFuPc = verticesLabelledBy cfg $ fu_pc function for_ verticesAtFuPc $ \v -> visit cfg fSpec function Map.empty (initialWithFunc (fu_pc function)) [] SBRich v ACNone Nothing -updateOracle :: - ArcCondition -> - CallStack -> - Map (NonEmpty Label, Label) Bool -> - Map (NonEmpty Label, Label) Bool +updateOracle + :: ArcCondition + -> CallStack + -> Map (NonEmpty Label, Label) Bool + -> Map (NonEmpty Label, Label) Bool updateOracle ACNone _ = id updateOracle (ACJnz jnzPc isSat) callstack = Map.insert (stackTrace callstack, jnzPc) isSat diff --git a/src/Horus/Module/Runner.hs b/src/Horus/Module/Runner.hs index 769d2433..b3943bc3 100644 --- a/src/Horus/Module/Runner.hs +++ b/src/Horus/Module/Runner.hs @@ -27,15 +27,15 @@ type Impl = interpret :: ModuleL a -> Impl a interpret = iterM exec . runModuleL - where - exec :: ModuleF (Impl a) -> Impl a - exec (EmitModule m cont) = tell (D.singleton m) *> cont - exec (Visiting l action cont) = do - visited <- ask - local (Set.insert l) $ - interpret (action (Set.member l visited)) >>= cont - exec (Throw t) = throwError t - exec (Catch m handler cont) = catchError (interpret m) (interpret . handler) >>= cont + where + exec :: ModuleF (Impl a) -> Impl a + exec (EmitModule m cont) = tell (D.singleton m) *> cont + exec (Visiting l action cont) = do + visited <- ask + local (Set.insert l) $ + interpret (action (Set.member l visited)) >>= cont + exec (Throw t) = throwError t + exec (Catch m handler cont) = catchError (interpret m) (interpret . handler) >>= cont run :: ModuleL a -> Either Text [Module] run m = diff --git a/src/Horus/Preprocessor.hs b/src/Horus/Preprocessor.hs index 69afbdc6..21519726 100644 --- a/src/Horus/Preprocessor.hs +++ b/src/Horus/Preprocessor.hs @@ -120,12 +120,11 @@ interpConst model name = do data SolverResult = Unsat | Sat (Maybe Model) | Unknown (Maybe Text) deriving (Eq) -{- | The set of user-facing results for a given module or function. - - This is just like `SolverResult`, except that we rename the constructors to - match more closely what a person unfamiliar with SMT solvers would expect, - and we add the `ContradictoryPrecondition` constructor. --} +-- | The set of user-facing results for a given module or function. +-- +-- This is just like `SolverResult`, except that we rename the constructors to +-- match more closely what a person unfamiliar with SMT solvers would expect, +-- and we add the `ContradictoryPrecondition` constructor. data HorusResult = Verified | Counterexample (Maybe Model) @@ -156,37 +155,36 @@ instance Show Model where concatMap showAp (toList m_regs) <> concatMap showMem (toList m_mem) <> concatMap showLVar (toList m_lvars) - where - showAp (reg, value) = printf "%8s\t=\t%d\n" reg value - showMem (addr, value) = printf "mem[%3d]\t=\t%d\n" addr value - showLVar (lvar, value) = printf "%8s\t=\t%d\n" lvar value - -{- | Optimize the query into a list of `Goal`s, and then fold the results of - each goal together into a single `SolverResult`. --} + where + showAp (reg, value) = printf "%8s\t=\t%d\n" reg value + showMem (addr, value) = printf "mem[%3d]\t=\t%d\n" addr value + showLVar (lvar, value) = printf "%8s\t=\t%d\n" lvar value + +-- | Optimize the query into a list of `Goal`s, and then fold the results of +-- each goal together into a single `SolverResult`. solve :: Integer -> Text -> PreprocessorL SolverResult solve fPrime smtQuery = do optimizeQuery smtQuery >>= foldlM combineResult (Unknown Nothing) - where - -- Given a `SolverResult` which is the combination of a bunch of results for - -- some list of `Goal`s, we compute the result of one additional `Goal` (the - -- second argument), and combine it with the existing results. - combineResult :: SolverResult -> Goal -> PreprocessorL SolverResult - combineResult (Sat mbModel) _ = pure (Sat mbModel) - combineResult Unsat subgoal = do - result <- computeResult subgoal - pure $ case result of - Sat mbModel -> Sat mbModel - _ -> Unsat - combineResult Unknown{} subgoal = computeResult subgoal - - computeResult :: Goal -> PreprocessorL SolverResult - computeResult subgoal = do - result <- runSolver =<< runZ3 (goalToSExpr subgoal) - case result of - (SMT.Sat, mbModel) -> maybe (pure (Sat Nothing)) (processModel fPrime subgoal) mbModel - (SMT.Unsat, _mbCore) -> pure Unsat - (SMT.Unknown, mbReason) -> pure (Unknown mbReason) + where + -- Given a `SolverResult` which is the combination of a bunch of results for + -- some list of `Goal`s, we compute the result of one additional `Goal` (the + -- second argument), and combine it with the existing results. + combineResult :: SolverResult -> Goal -> PreprocessorL SolverResult + combineResult (Sat mbModel) _ = pure (Sat mbModel) + combineResult Unsat subgoal = do + result <- computeResult subgoal + pure $ case result of + Sat mbModel -> Sat mbModel + _ -> Unsat + combineResult Unknown{} subgoal = computeResult subgoal + + computeResult :: Goal -> PreprocessorL SolverResult + computeResult subgoal = do + result <- runSolver =<< runZ3 (goalToSExpr subgoal) + case result of + (SMT.Sat, mbModel) -> maybe (pure (Sat Nothing)) (processModel fPrime subgoal) mbModel + (SMT.Unsat, _mbCore) -> pure Unsat + (SMT.Unknown, mbReason) -> pure (Unknown mbReason) optimizeQuery :: Text -> PreprocessorL [Goal] optimizeQuery smtQuery = do @@ -234,9 +232,9 @@ z3ModelToHorusModel fPrime model = consts <- getConsts fPrime model mbLVars <- for consts (pure . parseLVar) pure $ fromList $ catMaybes mbLVars - where - parseRegVar :: (Text, Integer) -> Maybe (RegKind, Text, Integer) - parseRegVar (name, value) = - parseRegKind name <&> (,name,toSignedFelt fPrime value) - parseLVar :: (Text, Integer) -> Maybe (Text, Integer) - parseLVar (name, value) = (,value) . pack . ('$' :) . unpack <$> stripPrefix "$" name + where + parseRegVar :: (Text, Integer) -> Maybe (RegKind, Text, Integer) + parseRegVar (name, value) = + parseRegKind name <&> (,name,toSignedFelt fPrime value) + parseLVar :: (Text, Integer) -> Maybe (Text, Integer) + parseLVar (name, value) = (,value) . pack . ('$' :) . unpack <$> stripPrefix "$" name diff --git a/src/Horus/Preprocessor/Runner.hs b/src/Horus/Preprocessor/Runner.hs index 0ddcff22..55360b97 100644 --- a/src/Horus/Preprocessor/Runner.hs +++ b/src/Horus/Preprocessor/Runner.hs @@ -44,19 +44,19 @@ type Impl a = ReaderT PreprocessorEnv (ExceptT Text Z3) a interpret :: PreprocessorL a -> Impl a interpret = iterM exec . runPreprocessor - where - exec :: PreprocessorF (Impl a) -> Impl a - exec (RunZ3 z3 cont) = do - lift (lift z3) >>= cont - exec (RunSolver tGoal cont) = do - externalSolver <- view peSolver - solverSettings <- view peSolverSettings - liftIO (runSolver externalSolver solverSettings tGoal) >>= cont - exec (GetMemsAndAddrs cont) = do - view peMemsAndAddrs >>= cont - exec (Throw e) = throwError e - exec (Catch preprocessor handler cont) = do - catchError (interpret preprocessor) (interpret . handler) >>= cont + where + exec :: PreprocessorF (Impl a) -> Impl a + exec (RunZ3 z3 cont) = do + lift (lift z3) >>= cont + exec (RunSolver tGoal cont) = do + externalSolver <- view peSolver + solverSettings <- view peSolverSettings + liftIO (runSolver externalSolver solverSettings tGoal) >>= cont + exec (GetMemsAndAddrs cont) = do + view peMemsAndAddrs >>= cont + exec (Throw e) = throwError e + exec (Catch preprocessor handler cont) = do + catchError (interpret preprocessor) (interpret . handler) >>= cont runImpl :: PreprocessorEnv -> Impl a -> Z3 (Either Text a) runImpl penv m = diff --git a/src/Horus/Preprocessor/Solvers.hs b/src/Horus/Preprocessor/Solvers.hs index c8e37db7..f9ba1c0e 100644 --- a/src/Horus/Preprocessor/Solvers.hs +++ b/src/Horus/Preprocessor/Solvers.hs @@ -107,25 +107,25 @@ runSingleSolver SingleSolver{..} SolverSettings{..} query = solving $ \solver -> pure (Just (pack (SMT.ppSExpr reason ""))) SMT.Unsat -> pure Nothing pure (res, mbModelOrReason) - where - solving f = withTimeout (withSolver s_name s_auxFlags f) - withTimeout f = do - mbResult <- timeout (ss_timeoutMillis * 1000) f - pure (fromMaybe timeoutResult mbResult) - timeoutResult = (SMT.Unknown, Just (s_name <> ": Time is out.")) + where + solving f = withTimeout (withSolver s_name s_auxFlags f) + withTimeout f = do + mbResult <- timeout (ss_timeoutMillis * 1000) f + pure (fromMaybe timeoutResult mbResult) + timeoutResult = (SMT.Unknown, Just (s_name <> ": Time is out.")) runSolver :: Solver -> SolverSettings -> Text -> IO (SMT.Result, Maybe Text) runSolver (MultiSolver solvers) settings query = foldlM combineResult (SMT.Unknown, Just "All solvers failed.") solvers - where - combineResult (SMT.Unknown, mbReason) nextSolver = do - (nextStatus, nextMbReason) <- runSingleSolver nextSolver settings query - case nextStatus of - SMT.Unknown -> pure (SMT.Unknown, mbReason <> annotateFailure nextSolver nextMbReason) - _ -> pure (nextStatus, nextMbReason) - combineResult res _ = pure res - - annotateFailure solver reason = Just ("\n" <> tShow solver <> " failed with: ") <> reason + where + combineResult (SMT.Unknown, mbReason) nextSolver = do + (nextStatus, nextMbReason) <- runSingleSolver nextSolver settings query + case nextStatus of + SMT.Unknown -> pure (SMT.Unknown, mbReason <> annotateFailure nextSolver nextMbReason) + _ -> pure (nextStatus, nextMbReason) + combineResult res _ = pure res + + annotateFailure solver reason = Just ("\n" <> tShow solver <> " failed with: ") <> reason withSolver :: Text -> [Text] -> (SMT.Solver -> IO a) -> IO a withSolver solverName args = diff --git a/src/Horus/SW/FuncSpec.hs b/src/Horus/SW/FuncSpec.hs index 7a64c5ca..158b1a32 100644 --- a/src/Horus/SW/FuncSpec.hs +++ b/src/Horus/SW/FuncSpec.hs @@ -20,13 +20,12 @@ data FuncSpec = FuncSpec emptyFuncSpec :: FuncSpec emptyFuncSpec = FuncSpec{fs_pre = Expr.True, fs_post = Expr.True, fs_storage = mempty} -{- | A version of `FuncSpec` that distinguishes omitted preconditions and - postconditions from trivial ones. - - We define this in addition to `FuncSpec` for separation of concerns. Note - that `FuncSpec` has a direct mapping from JSON, but conflates `True` with - `Nothing`. --} +-- | A version of `FuncSpec` that distinguishes omitted preconditions and +-- postconditions from trivial ones. +-- +-- We define this in addition to `FuncSpec` for separation of concerns. Note +-- that `FuncSpec` has a direct mapping from JSON, but conflates `True` with +-- `Nothing`. data FuncSpec' = FuncSpec' { fs'_pre :: Maybe (Expr TBool) , fs'_post :: Maybe (Expr TBool) diff --git a/src/Horus/SW/Std.hs b/src/Horus/SW/Std.hs index 90bc356b..24770a27 100644 --- a/src/Horus/SW/Std.hs +++ b/src/Horus/SW/Std.hs @@ -8,7 +8,16 @@ import Data.Text (Text) import Horus.Expr (Expr (ExitField), (.&&), (.<), (.<=), (.==)) import Horus.Expr qualified as Expr -import Horus.Expr.Vars (ap, blockTimestamp, callerAddress, contractAddress, fp, memory, prime, rcBound) +import Horus.Expr.Vars + ( ap + , blockTimestamp + , callerAddress + , contractAddress + , fp + , memory + , prime + , rcBound + ) import Horus.SW.FuncSpec (FuncSpec (..), emptyFuncSpec) import Horus.SW.ScopedName (ScopedName) import Horus.Util (tShow) @@ -18,20 +27,19 @@ stdSpecs = Map.fromList stdSpecsList mkReadSpec :: ScopedName -> Int -> FuncSpec mkReadSpec name arity = emptyFuncSpec{fs_post = memory (ap - 1) .== var} - where - offsets = [-3 - arity + 1 .. -3] - args = [memory (fp + fromIntegral offset) | offset <- offsets] - var = Expr.apply (Expr.Fun (tShow name)) args + where + offsets = [-3 - arity + 1 .. -3] + args = [memory (fp + fromIntegral offset) | offset <- offsets] + var = Expr.apply (Expr.Fun (tShow name)) args mkWriteSpec :: ScopedName -> Int -> FuncSpec mkWriteSpec name arity = emptyFuncSpec{fs_storage = [(name, [(args, memory (fp - 3))])]} - where - offsets = [-4 - arity + 1 .. -4] - args = [memory (fp + fromIntegral offset) | offset <- offsets] + where + offsets = [-4 - arity + 1 .. -4] + args = [memory (fp + fromIntegral offset) | offset <- offsets] -{- | A list of names of trusted standard library functions. -These functions will not be checked against their specifications. --} +-- | A list of names of trusted standard library functions. +-- These functions will not be checked against their specifications. trustedStdFuncs :: [Text] trustedStdFuncs = [ "starkware.starknet.common.syscalls.get_block_timestamp" @@ -40,13 +48,12 @@ trustedStdFuncs = , "starkware.cairo.common.math.assert_le_felt" ] -{- | A lexicographically sorted by fs_name list of specifications of - standard library functions. - -The list should be lexicographically sorted by function name. It -doesn't impact correctness of the program, but simplifies looking for -functions. --} +-- | A lexicographically sorted by fs_name list of specifications of +-- standard library functions. +-- +-- The list should be lexicographically sorted by function name. It +-- doesn't impact correctness of the program, but simplifies looking for +-- functions. stdSpecsList :: [(ScopedName, FuncSpec)] stdSpecsList = [ diff --git a/src/Horus/SW/Storage.hs b/src/Horus/SW/Storage.hs index 32bc4070..49e636be 100644 --- a/src/Horus/SW/Storage.hs +++ b/src/Horus/SW/Storage.hs @@ -21,28 +21,29 @@ equivalenceExpr a b = Expr.and [checkStorageIsSubset a b, checkStorageIsSubset b checkStorageIsSubset :: Storage -> Storage -> Expr TBool checkStorageIsSubset a b = Expr.and $ map equalReads (getWrites a) - where - equalReads (name, args, _value) = read a name args .== read b name args + where + equalReads (name, args, _value) = read a name args .== read b name args read :: Storage -> ScopedName -> [Expr TFelt] -> Expr TFelt read storage name args = buildReadChain args baseCase writes - where - baseCase = Expr.apply (Expr.Fun (tShow name)) args - writes = Map.findWithDefault [] name storage + where + baseCase = Expr.apply (Expr.Fun (tShow name)) args + writes = Map.findWithDefault [] name storage buildReadChain :: [Expr TFelt] -> Expr TFelt -> [([Expr TFelt], Expr TFelt)] -> Expr TFelt buildReadChain readAt baseCase writes = go baseCase (reverse writes) - where - go acc [] = acc - go acc ((args, value) : rest) - | length args /= arity = error "buildReadChain: a storage var is accessed with a wrong number of arguments." - | otherwise = go (Expr.ite (Expr.and (zipWith (.==) readAt args)) value acc) rest - arity = length readAt + where + go acc [] = acc + go acc ((args, value) : rest) + | length args /= arity = + error "buildReadChain: a storage var is accessed with a wrong number of arguments." + | otherwise = go (Expr.ite (Expr.and (zipWith (.==) readAt args)) value acc) rest + arity = length readAt getWrites :: Storage -> [(ScopedName, [Expr TFelt], Expr TFelt)] getWrites storage = concatMap getWritesForName (Map.toList storage) - where - getWritesForName (name, writes) = [(name, args, value) | (args, value) <- writes] + where + getWritesForName (name, writes) = [(name, args, value) | (args, value) <- writes] parse :: Value -> Parser Storage parse v = fmap elimHelpersFromStorage (parseJSON v) diff --git a/src/Horus/Util.hs b/src/Horus/Util.hs index 2ee1091f..3946b3b0 100644 --- a/src/Horus/Util.hs +++ b/src/Horus/Util.hs @@ -31,8 +31,8 @@ toSignedFelt :: Integer -> Integer -> Integer toSignedFelt fPrime x | moddedX > fPrime `div` 2 = moddedX - fPrime | otherwise = moddedX - where - moddedX = x `mod` fPrime + where + moddedX = x `mod` fPrime whenJust :: Applicative f => Maybe a -> (a -> f ()) -> f () whenJust Nothing _ = pure () @@ -61,9 +61,9 @@ tShow = pack . show commonPrefix :: [Text] -> Text commonPrefix = foldr (\x acc -> unspoon $ Text.commonPrefixes x acc) "" - where - unspoon :: Maybe (Text, Text, Text) -> Text - unspoon = maybe "" $ \(prefix, _, _) -> prefix + where + unspoon :: Maybe (Text, Text, Text) -> Text + unspoon = maybe "" $ \(prefix, _, _) -> prefix enumerate :: (Enum a, Bounded a) => [a] enumerate = [minBound ..]