Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test cases for @assert notation #153

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions src/Horus/FunctionAnalysis.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ module Horus.FunctionAnalysis
, isAuxFunc
, scopedFOfPc
, uncheckedScopedFOfPc
, functionsOf
, callgraph
, graphOfCG
)
where

Expand All @@ -25,7 +28,7 @@ import Control.Monad (liftM2, (<=<))
import Data.Array (assocs)
import Data.Coerce (coerce)
import Data.Function ((&))
import Data.Graph (Graph, Vertex, graphFromEdges', reachable)
import Data.Graph (Graph, Vertex, graphFromEdges, reachable)
import Data.List (foldl', sort, union)
import Data.Map qualified as Map
( Map
Expand Down Expand Up @@ -109,8 +112,8 @@ cgMbInsertArc (CG verts arcs) (fro, to) =
then Nothing
else Just . CG verts $ Map.insertWith (++) fro [to] arcs

graphOfCG :: CG -> (Graph, Vertex -> (Label, Label, [Label]))
graphOfCG cg = graphFromEdges' . map named . Map.assocs $ cg_arcs cg
graphOfCG :: CG -> (Graph, Vertex -> (Label, Label, [Label]), Label -> Maybe Vertex)
graphOfCG cg = graphFromEdges . map named . Map.assocs $ cg_arcs cg
where
named (fro, tos) = (fro, fro, tos)

Expand All @@ -121,7 +124,7 @@ cycles g = map fst . filter (uncurry reachableSet) $ assocs g

cyclicVerts :: CG -> [Label]
cyclicVerts cg =
let (graph, vertToNode) = graphOfCG cg
let (graph, vertToNode, _) = graphOfCG cg
in map ((\(lbl, _, _) -> lbl) . vertToNode) (cycles graph)

pcToFunOfProg :: Program -> Map.Map Label ScopedFunction
Expand Down Expand Up @@ -271,18 +274,13 @@ isGeneratedName fname cd = fname `elem` generatedNames
isSvarFunc :: ScopedName -> ContractDefinition -> Bool
isSvarFunc fname cd = isGeneratedName fname cd || fname `elem` [fStorageRead, fStorageWrite]

fHash2 :: ScopedName
fHash2 = ScopedName ["starkware", "cairo", "common", "hash", "hash2"]

fAssert250bit :: ScopedName
fAssert250bit = ScopedName ["starkware", "cairo", "common", "math", "assert_250_bit"]

fNormalizeAddress :: ScopedName
fNormalizeAddress = ScopedName ["starkware", "starknet", "common", "storage", "normalize_address"]

isAuxFunc :: ScopedFunction -> ContractDefinition -> Bool
isAuxFunc (ScopedFunction fname _) cd =
isSvarFunc fname cd || fname `elem` [fHash2, fAssert250bit, fNormalizeAddress]
where
fHash2 = ScopedName ["starkware", "cairo", "common", "hash", "hash2"]
fAssert250bit = ScopedName ["starkware", "cairo", "common", "math", "assert_250_bit"]
fNormalizeAddress = ScopedName ["starkware", "starknet", "common", "storage", "normalize_address"]

sizeOfCall :: Int
sizeOfCall = 2
Expand All @@ -304,9 +302,9 @@ inlinableFuns rows prog cd =
notIsAnnotated sf = maybe False (isNotAnnotated cd) . Map.lookup (sf_scopedName sf) $ idents
notIsAnnotatedLater f = sf_scopedName f `notElem` map fst stdSpecsList
localCycles = Map.map (cyclicVerts . jumpgraph)
isAcylic cyclicFuns f cyclicLbls = f `notElem` cyclicFuns && null cyclicLbls
isAcyclic cyclicFuns f cyclicLbls = f `notElem` cyclicFuns && null cyclicLbls
inlinable =
Map.keys . Map.filterWithKey (isAcylic . cyclicVerts $ callgraph (Map.mapKeys sf_pc functions)) $
Map.keys . Map.filterWithKey (isAcyclic . cyclicVerts $ callgraph (Map.mapKeys sf_pc functions)) $
Map.mapKeys sf_pc (localCycles functions)

uninlinableFuns :: [LabeledInst] -> Program -> ContractDefinition -> Map.Map ScopedFunction [LabeledInst]
Expand Down
33 changes: 26 additions & 7 deletions src/Horus/Global.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ import Control.Monad (when)
import Control.Monad.Except (MonadError (..))
import Control.Monad.Free.Church (F, liftF)
import Data.Foldable (for_)
import Data.List (groupBy)
import Data.Maybe (fromMaybe)
import Data.Graph (reachable)
import Data.List (groupBy, partition)
import Data.Map qualified as Map
import Data.Maybe (fromJust, fromMaybe)
import Data.Set (Set, singleton, toAscList, (\\))
import Data.Set qualified as Set (map)
import Data.Set qualified as Set (fromList, map, member)
import Data.Text (Text, unpack)
import Data.Text qualified as Text (isPrefixOf)
import Data.Traversable (for)
Expand All @@ -37,7 +39,7 @@ import Horus.CairoSemantics.Runner
import Horus.CallStack (CallStack, initialWithFunc)
import Horus.Expr qualified as Expr
import Horus.Expr.Util (gatherLogicalVariables)
import Horus.FunctionAnalysis (ScopedFunction (ScopedFunction, sf_pc), isWrapper)
import Horus.FunctionAnalysis (ScopedFunction (ScopedFunction, sf_pc), callgraph, functionsOf, graphOfCG, isWrapper)
import Horus.Logger qualified as L (LogL, logDebug, logError, logInfo, logWarning)
import Horus.Module (Module (..), ModuleL, gatherModules, getModuleNameParts)
import Horus.Preprocessor (HorusResult (..), PreprocessorL, SolverResult (..), goalListToTextList, optimizeQuery, solve)
Expand All @@ -49,6 +51,7 @@ import Horus.SW.Identifier (Function (..))
import Horus.SW.ScopedName (ScopedName ())
import Horus.SW.Std (trustedStdFuncs)
import Horus.Util (tShow, whenJust)
import Lens.Micro ((^.), _3)

data Config = Config
{ cfg_verbose :: Bool
Expand Down Expand Up @@ -329,6 +332,7 @@ collapseAllUnsats infos@(SolvingInfo _ funcName result _ _ : _)

{- | Return a solution of SMT queries corresponding with the contract.


For the purposes of reporting results,
we also remember which SMT query corresponding to a function was inlined.
-}
Expand All @@ -343,7 +347,8 @@ solveContract = do
let fs = toAscList inlinables
cfgs <- for fs $ \f -> runCFGBuildL (buildCFG lInstructions $ inlinables \\ singleton f)
for_ cfgs verbosePrint
modules <- concat <$> for ((cfg, isStandardSource inlinables) : zip cfgs (map (==) fs)) makeModules
sources <- userAnnotatedSources inlinables lInstructions
modules <- concat <$> for ((cfg, (`elem` sources)) : zip cfgs (map (==) fs)) makeModules

identifiers <- getIdentifiers
let isUntrusted :: Module -> Bool
Expand All @@ -359,8 +364,22 @@ solveContract = do
)
infos
where
isStandardSource :: Set ScopedFunction -> ScopedFunction -> Bool
isStandardSource inlinables f = f `notElem` inlinables && not (isWrapper f)
userAnnotatedSources :: Set ScopedFunction -> [LabeledInst] -> GlobalL (Set ScopedFunction)
userAnnotatedSources inlinableFs rows =
getProgram >>= \prog ->
let functionsWithBodies = functionsOf rows prog
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a bug here that triggers fromJust under some circumstances unknown to me. This needs investigation. Preferably, this whole bit of code wouldn't be here in the first place because we're replicating this kind of logic in FunctionAnalysis.hs and frankly even in Module.hs.

functions = Map.keys functionsWithBodies
(cg, vToLbl, lblToV) = graphOfCG . callgraph . Map.mapKeys sf_pc $ functionsWithBodies
(wrapperFunctions, nonwrapperFunctions) = partition isWrapper functions
reachableLabelsFromWrappers =
Set.fromList
. concatMap (concatMap ((^. _3) . vToLbl) . reachable cg . fromJust . lblToV . sf_pc)
$ wrapperFunctions
calledByWrappers =
Set.fromList
[ sf | sf <- functions, sf_pc sf `Set.member` reachableLabelsFromWrappers
]
in pure (Set.fromList nonwrapperFunctions \\ inlinableFs \\ calledByWrappers)

sameFuncName :: SolvingInfo -> SolvingInfo -> Bool
sameFuncName (SolvingInfo _ nameA _ _ _) (SolvingInfo _ nameB _ _ _) = nameA == nameB
Expand Down
12 changes: 12 additions & 0 deletions tests/resources/golden/extern_remove_dirty.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
%lang starknet

@external
func f() -> (array_len : felt, array : felt*) {
alloc_locals;
// An array of felts.
local felt_array: felt*;
assert felt_array[0] = 0;
assert felt_array[1] = 1;
assert felt_array[2] = 2;
return (array_len=3, array=felt_array);
}
2 changes: 2 additions & 0 deletions tests/resources/golden/extern_remove_dirty.gold
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
f [inlined]
Verified