diff --git a/granule-compiler.cabal b/granule-compiler.cabal index 450072e..dcc6c42 100644 --- a/granule-compiler.cabal +++ b/granule-compiler.cabal @@ -47,7 +47,9 @@ library Language.Granule.Codegen.Emit.Names Language.Granule.Codegen.Emit.Primitives Language.Granule.Codegen.Emit.Types + Language.Granule.Codegen.Monomorphise Language.Granule.Codegen.RewriteAST + Language.Granule.Codegen.SubstituteTypes Paths_granule_compiler hs-source-dirs: src @@ -62,6 +64,7 @@ library base >=4.10 && <5 , containers , granule-frontend + , hashable , llvm-hs ==12.* , llvm-hs-pure ==12.* , mtl @@ -91,6 +94,7 @@ executable grlc , gitrev , granule-compiler , granule-frontend + , hashable , llvm-hs ==12.* , llvm-hs-pure ==12.* , optparse-applicative @@ -123,6 +127,7 @@ test-suite compiler-spec , filemanip , granule-compiler , granule-frontend + , hashable , hspec , mtl , process @@ -149,6 +154,7 @@ test-suite golden , filepath , granule-compiler , granule-frontend + , hashable , llvm-hs ==12.* , llvm-hs-pure ==12.* , process diff --git a/package.yaml b/package.yaml index 784c838..df5eb50 100644 --- a/package.yaml +++ b/package.yaml @@ -8,6 +8,7 @@ github: granule-project/granule-compiler-llvm dependencies: - base >=4.10 && <5 - process + - hashable default-extensions: - LambdaCase - RecordWildCards diff --git a/src/Language/Granule/Codegen/Builtins/Builtins.hs b/src/Language/Granule/Codegen/Builtins/Builtins.hs index f4aba6e..91064e8 100644 --- a/src/Language/Granule/Codegen/Builtins/Builtins.hs +++ b/src/Language/Granule/Codegen/Builtins/Builtins.hs @@ -20,5 +20,16 @@ builtins = deleteFloatArrayDef ] +specialisable :: [Specialisable] +specialisable = + [ useDef + ] + +monoBuiltinIds :: [Id] +monoBuiltinIds = map (mkId . builtinId) builtins + +polyBuiltinIds :: [Id] +polyBuiltinIds = map (mkId . specialisableId) specialisable + builtinIds :: [Id] -builtinIds = map (mkId . builtinId) builtins +builtinIds = monoBuiltinIds ++ polyBuiltinIds diff --git a/src/Language/Granule/Codegen/Builtins/Extras.hs b/src/Language/Granule/Codegen/Builtins/Extras.hs index 63e96b8..9c09fb7 100644 --- a/src/Language/Granule/Codegen/Builtins/Extras.hs +++ b/src/Language/Granule/Codegen/Builtins/Extras.hs @@ -25,3 +25,10 @@ divDef = args = [TyCon (Id "Int" "Int"), TyCon (Id "Int" "Int")] ret = TyCon (Id "Int" "Int") impl [x, y] = sdiv x y + +-- use :: a -> a [1] +useDef :: Specialisable +useDef = + Specialisable "use" impl + where + impl _ [val] = return val diff --git a/src/Language/Granule/Codegen/Builtins/Shared.hs b/src/Language/Granule/Codegen/Builtins/Shared.hs index 29c52c5..9e72691 100644 --- a/src/Language/Granule/Codegen/Builtins/Shared.hs +++ b/src/Language/Granule/Codegen/Builtins/Shared.hs @@ -1,5 +1,6 @@ {-# LANGUAGE RankNTypes #-} {-# OPTIONS_GHC -Wno-incomplete-patterns #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} module Language.Granule.Codegen.Builtins.Shared where @@ -21,6 +22,18 @@ data Builtin = Builtin { builtinRetTy :: Gr.Type, builtinImpl :: forall m. (MonadModuleBuilder m, MonadIRBuilder m) => [Operand] -> m Operand} +data Specialisable = Specialisable { + specialisableId :: String, + specialisableImpl :: [Gr.Type] -> forall m. (MonadModuleBuilder m, MonadIRBuilder m) => [Operand] -> m Operand} + +specialise :: Specialisable -> String -> Gr.Type -> Builtin +specialise builtin id ty = Builtin id args ret impl + where + args = Gr.parameterTypes ty + ret = Gr.resultType ty + {-# HLINT ignore "Eta reduce" #-} -- has to be lazy or we need IRBuilder early + impl xs = specialisableImpl builtin args xs + -- LLVM helpers allocate :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> IR.Type -> m Operand diff --git a/src/Language/Granule/Codegen/ClosureFreeDef.hs b/src/Language/Granule/Codegen/ClosureFreeDef.hs index 0626e0b..ffd109c 100644 --- a/src/Language/Granule/Codegen/ClosureFreeDef.hs +++ b/src/Language/Granule/Codegen/ClosureFreeDef.hs @@ -54,7 +54,6 @@ data ClosureMarker = CapturedVar Type Id Int | MakeClosure Id ClosureEnvironmentInit | MakeTrivialClosure Id - | MakeBuiltinClosure Id deriving (Show, Eq) data ClosureFreeAST = @@ -84,4 +83,3 @@ instance Pretty ClosureMarker where "env(ident = \"" ++ envName ++ "\", " ++ intercalate ", " (map prettyEnvVar varInits) ++ ")" in "make-closure(" ++ pretty ident ++ ", " ++ prettyEnv env ++ ")" pretty (MakeTrivialClosure ident) = pretty ident - pretty (MakeBuiltinClosure ident) = pretty ident diff --git a/src/Language/Granule/Codegen/Compile.hs b/src/Language/Granule/Codegen/Compile.hs index 892c241..f76aa65 100644 --- a/src/Language/Granule/Codegen/Compile.hs +++ b/src/Language/Granule/Codegen/Compile.hs @@ -8,14 +8,18 @@ import Language.Granule.Codegen.TopsortDefinitions import Language.Granule.Codegen.ConvertClosures import Language.Granule.Codegen.Emit.EmitLLVM import Language.Granule.Codegen.MarkGlobals +import Language.Granule.Codegen.Monomorphise import Language.Granule.Codegen.RewriteAST +import Language.Granule.Codegen.SubstituteTypes import qualified LLVM.AST as IR compile :: String -> AST () Type -> Either String IR.Module compile moduleName typedAST = - let rewritten = rewriteAST typedAST - normalised = normaliseDefinitions rewritten + let substituted = substituteTypes typedAST + rewritten = rewriteAST substituted + monomorphised = monomorphiseAST rewritten + normalised = normaliseDefinitions monomorphised markedGlobals = markGlobals normalised (Ok topsorted) = topologicallySortDefinitions markedGlobals closureFree = convertClosures topsorted diff --git a/src/Language/Granule/Codegen/ConvertClosures.hs b/src/Language/Granule/Codegen/ConvertClosures.hs index 6842b6f..f221252 100644 --- a/src/Language/Granule/Codegen/ConvertClosures.hs +++ b/src/Language/Granule/Codegen/ConvertClosures.hs @@ -167,8 +167,5 @@ convertClosuresFromValue (_, maybeCurrentEnv, locals) (VarF ty ident) ++ sourceName ident ++ " in environment." in return $ Ext ty (Right (CapturedVar ty ident indexInEnv)) -convertClosuresFromValue (_, maybeCurrentEnv, _) (ExtF ty (BuiltinVar _ ident)) = - return $ Ext ty $ Right $ MakeBuiltinClosure ident - convertClosuresFromValue _ other = return $ fixMapExtValue (\ty gv -> Ext ty $ Left gv) other diff --git a/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs b/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs index 728990d..10ae8f6 100644 --- a/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs +++ b/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs @@ -6,7 +6,7 @@ import Control.Monad (forM) import LLVM.AST (Operand) import qualified LLVM.AST as IR import qualified LLVM.AST.Constant as C -import LLVM.AST.Type hiding (Type) +import LLVM.AST.Type hiding (resultType, Type) import LLVM.IRBuilder.Constant (int32) import LLVM.IRBuilder.Instruction import LLVM.IRBuilder.Module @@ -16,9 +16,17 @@ import Language.Granule.Codegen.Builtins.Shared import Language.Granule.Codegen.Emit.LLVMHelpers import Language.Granule.Codegen.Emit.LowerClosure (mallocEnvironment) import Language.Granule.Codegen.Emit.LowerType (llvmType, llvmTypeForClosure, llvmTypeForFunction) +import Language.Granule.Syntax.Identifiers +import Language.Granule.Syntax.Type -emitBuiltins :: (MonadModuleBuilder m) => m [Operand] -emitBuiltins = mapM emitBuiltin builtins + +emitBuiltins :: (MonadModuleBuilder m) => [(Id, Type)] -> m () +emitBuiltins uses = mapM_ emitBuiltin (monos ++ polys) + where + monos = + [b | (id, _) <- uses, b <- builtins, builtinId b == sourceName id] + polys = + [specialise b (internalName id) ty | (id, ty) <- uses, b <- specialisable, specialisableId b == sourceName id] emitBuiltin :: (MonadModuleBuilder m) => Builtin -> m Operand emitBuiltin builtin = diff --git a/src/Language/Granule/Codegen/Emit/EmitLLVM.hs b/src/Language/Granule/Codegen/Emit/EmitLLVM.hs index fae0c94..0f1ccc5 100644 --- a/src/Language/Granule/Codegen/Emit/EmitLLVM.hs +++ b/src/Language/Granule/Codegen/Emit/EmitLLVM.hs @@ -40,20 +40,21 @@ import LLVM.IRBuilder (int32) emitLLVM :: String -> ClosureFreeAST -> Either String IR.Module emitLLVM moduleName (ClosureFreeAST dataDecls functionDefs valueDefs) = - let buildModule name m = evalState (buildModuleT name m) (EmitterState { localSymbols = Map.empty }) + let buildModule name m = evalState (buildModuleT name m) (EmitterState { localSymbols = Map.empty, builtins = Map.empty }) in Right $ buildModule (fromString moduleName) $ do _ <- extern (mkName "malloc") [i64] (ptr i8) _ <- extern (mkName "abort") [] void _ <- externVarArgs (mkName "printf") [ptr i8] i32 _ <- extern (mkName "llvm.memcpy.p0.p0.i32") [ptr i8, ptr i8, i32, i1] void _ <- extern (mkName "free") [ptr i8] void - _ <- emitBuiltins let mainTy = findMainReturnType valueDefs _ <- emitMainOut mainTy mapM_ emitDataDecl dataDecls mapM_ emitEnvironmentType functionDefs mapM_ emitFunctionDef functionDefs valueInitPairs <- mapM emitValueDef valueDefs + builtins <- usedBuiltins + _ <- emitBuiltins builtins emitGlobalInitializer valueInitPairs mainTy emitGlobalInitializer :: (MonadModuleBuilder m) => [(Operand, Operand)] -> GrType -> m Operand @@ -125,7 +126,9 @@ emitFunction _ _ _ _ _ = error "cannot emit function with non function type" paramName :: Pattern GrType -> ParameterName paramName (PConstr _ _ _ (Id "," _) _ _) = parameterNameFromId $ mkId "pair" paramName (PConstr _ _ _ (Id "()" _) _ _) = parameterNameFromId $ mkId "unit" -paramName pat = parameterNameFromId $ head $ boundVars pat +paramName pat = case boundVars pat of + [] -> parameterNameFromId $ mkId "wildcard" + (var : _) -> parameterNameFromId var emitArg :: (MonadState EmitterState m, MonadModuleBuilder m, MonadIRBuilder m) => Pattern GrType diff --git a/src/Language/Granule/Codegen/Emit/EmitterState.hs b/src/Language/Granule/Codegen/Emit/EmitterState.hs index c01c268..2e669bc 100644 --- a/src/Language/Granule/Codegen/Emit/EmitterState.hs +++ b/src/Language/Granule/Codegen/Emit/EmitterState.hs @@ -2,6 +2,7 @@ module Language.Granule.Codegen.Emit.EmitterState where import Language.Granule.Syntax.Identifiers (Id, internalName) +import Language.Granule.Syntax.Type import Control.Monad.State.Strict hiding (void) import LLVM.AST (Operand) @@ -9,7 +10,7 @@ import LLVM.AST (Operand) import Data.Map (Map, insertWith) import qualified Data.Map as Map -data EmitterState = EmitterState { localSymbols :: Map Id Operand } +data EmitterState = EmitterState { localSymbols :: Map Id Operand, builtins :: Map Id Type } addLocal :: (MonadState EmitterState m) => Id @@ -40,3 +41,9 @@ local name = case local of Just op -> return op Nothing -> error $ internalName name ++ "not registered as a local, missing call to addLocal?\n" + +useBuiltin :: (MonadState EmitterState m) => Id -> Type -> m () +useBuiltin id ty = modify $ \s -> s { builtins = Map.insert id ty (builtins s) } + +usedBuiltins :: (MonadState EmitterState m) => m [(Id, Type)] +usedBuiltins = Map.toList <$> gets builtins diff --git a/src/Language/Granule/Codegen/Emit/LowerClosure.hs b/src/Language/Granule/Codegen/Emit/LowerClosure.hs index f7730e5..9dcd8f8 100644 --- a/src/Language/Granule/Codegen/Emit/LowerClosure.hs +++ b/src/Language/Granule/Codegen/Emit/LowerClosure.hs @@ -58,11 +58,6 @@ emitClosureMarker ty maybeParentEnv (MakeClosure ident initializer) = emitClosureMarker ty _ (MakeTrivialClosure identifier) = return $ ConstantOperand $ makeTrivialClosure identifier ty -emitClosureMarker ty _ (MakeBuiltinClosure ident) = do - let functionPtr = ConstantOperand $ C.GlobalReference (ptr $ llvmTopLevelType ty) (functionNameFromId ident) - closure <- insertValue (ConstantOperand $ C.Undef (llvmType ty)) functionPtr [0] - insertValue closure (ConstantOperand $ C.Null (ptr i8)) [1] - emitEnvironmentInit :: (MonadModuleBuilder m, MonadIRBuilder m, MonadState EmitterState m) => [ClosureVariableInit] -> Operand diff --git a/src/Language/Granule/Codegen/Emit/LowerExpression.hs b/src/Language/Granule/Codegen/Emit/LowerExpression.hs index 15295c7..9f3f6fa 100644 --- a/src/Language/Granule/Codegen/Emit/LowerExpression.hs +++ b/src/Language/Granule/Codegen/Emit/LowerExpression.hs @@ -5,7 +5,7 @@ module Language.Granule.Codegen.Emit.LowerExpression where import Language.Granule.Codegen.ClosureFreeDef (ClosureMarker) import Language.Granule.Codegen.MarkGlobals (GlobalMarker, GlobalMarker(..)) import Language.Granule.Codegen.Emit.LowerOperator -import Language.Granule.Codegen.Emit.LowerType (llvmType) +import Language.Granule.Codegen.Emit.LowerType (llvmType, llvmTopLevelType) import Language.Granule.Codegen.Emit.EmitableDef import Language.Granule.Codegen.Emit.LowerPatterns (emitCaseArm) import Language.Granule.Codegen.Emit.LowerClosure (emitClosureMarker) @@ -29,7 +29,7 @@ import LLVM.IRBuilder.Monad import LLVM.IRBuilder.Instruction import LLVM.AST (Operand) -import LLVM.AST.Type (ptr) +import LLVM.AST.Type (ptr, i8) import LLVM.AST.Constant as C import qualified LLVM.IRBuilder.Constant as IC import qualified LLVM.AST as IR @@ -137,7 +137,10 @@ emitValue _ (ExtF a (Left (GlobalVar ty ident))) = do let ref = IR.ConstantOperand $ C.GlobalReference (ptr (llvmType ty)) (definitionNameFromId ident) load ref 4 emitValue _ (ExtF a (Left (BuiltinVar ty ident))) = do - error "TODO?" + useBuiltin ident ty + let functionPtr = IR.ConstantOperand $ C.GlobalReference (ptr $ llvmTopLevelType ty) (IR.mkName $ "fn." ++ internalName ident) + closure <- insertValue (IR.ConstantOperand $ C.Undef (llvmType ty)) functionPtr [0] + insertValue closure (IR.ConstantOperand $ C.Null (ptr i8)) [1] emitValue environment (ExtF ty (Right cm)) = emitClosureMarker ty environment cm {- TODO: Support tagged unions, also affects Case. diff --git a/src/Language/Granule/Codegen/Emit/MainOut.hs b/src/Language/Granule/Codegen/Emit/MainOut.hs index e79efb3..b8a51c1 100644 --- a/src/Language/Granule/Codegen/Emit/MainOut.hs +++ b/src/Language/Granule/Codegen/Emit/MainOut.hs @@ -65,4 +65,5 @@ fmtStrForTy x = (TyApp (TyCon (Id "FloatArray" _)) _) -> "" (TyCon (Id "()" _)) -> "()" (TyExists _ _ (Borrow _ ty)) -> "*" ++ fmtStrForTy ty + (Box _ ty) -> "[" ++ fmtStrForTy ty ++ "]" _ -> error ("Unsupported Main type: " ++ show x) diff --git a/src/Language/Granule/Codegen/MarkGlobals.hs b/src/Language/Granule/Codegen/MarkGlobals.hs index 64dc0bc..4bdc303 100644 --- a/src/Language/Granule/Codegen/MarkGlobals.hs +++ b/src/Language/Granule/Codegen/MarkGlobals.hs @@ -35,7 +35,8 @@ markGlobalsInExpr :: [Id] -> Expr () Type -> Expr GlobalMarker Type markGlobalsInExpr globals = bicata fixMapExtExpr markInValue where markInValue (VarF ty ident) - | ident `elem` builtinIds = Ext ty (BuiltinVar ty ident) + | any (\id -> sourceName ident == sourceName id) builtinIds = + Ext ty (BuiltinVar ty ident) | ident `elem` globals = Ext ty (GlobalVar ty ident) | otherwise = Var ty ident markInValue other = diff --git a/src/Language/Granule/Codegen/Monomorphise.hs b/src/Language/Granule/Codegen/Monomorphise.hs new file mode 100644 index 0000000..b539256 --- /dev/null +++ b/src/Language/Granule/Codegen/Monomorphise.hs @@ -0,0 +1,250 @@ +module Language.Granule.Codegen.Monomorphise (monomorphiseAST) where + +import Control.Monad.Identity (runIdentity) +import Data.Bifunctor (Bifunctor (bimap), second) +import Data.Hashable (hash) +import qualified Data.Map as Map +import Language.Granule.Codegen.Builtins.Builtins (polyBuiltinIds) +import Language.Granule.Syntax.Annotated (annotation) +import Language.Granule.Syntax.Def +import Language.Granule.Syntax.Expr hiding (subst) +import Language.Granule.Syntax.Identifiers +import Language.Granule.Syntax.Pattern +import Language.Granule.Syntax.Type + +-- polymorphic id -> [monomorphic id, [(ty var, ty subst)]] +type PolyInstances = Map.Map Id [(Id, [(Id, Type)])] + +-- polymorphic id -> ty +type PolyFuncs = Map.Map Id Type + +-- TODO: +-- ensure fixed point +-- more tests + +-- create monomorphic versions for each required instance of polymorphic function and rewrite ast +monomorphiseAST :: AST ev Type -> AST ev Type +monomorphiseAST ast = + let polymorphicFuncs = getPolymorphicFunctions ast + env = collectInstances ast polymorphicFuncs + in if null env + then -- we still need to rewrite builtins + let rewritten = rewriteCalls ast Map.empty + in rewritten {definitions = filter (not . isPolymorphic) (definitions rewritten)} + else + let monoDefs = makeMonoDefs ast env + rewritten = rewriteCalls ast env + in monomorphiseAST (rewritten {definitions = definitions rewritten ++ monoDefs}) + +isPolymorphic :: Def ev Type -> Bool +isPolymorphic def = + case defTypeScheme def of + Forall _ bindings _ _ -> any (\(_, t) -> t == Type 0) bindings + +-- e.g. id -> __id_3856 +makeMonoId :: Id -> Type -> Id +makeMonoId (Id id _) ty = + let monoId = "__" ++ id ++ "_" ++ show (abs $ hash $ show ty) + in Id id monoId + +-- create map of polymorphic function id to its ty vars +getPolymorphicFunctions :: AST ev Type -> PolyFuncs +getPolymorphicFunctions ast = + Map.fromList $ map getPolyInfo $ filter isPolymorphic $ definitions ast + where + getPolyInfo :: Def ev Type -> (Id, Type) + getPolyInfo def = + case defTypeScheme def of + Forall _ _ _ ty -> (defId def, ty) + +-- collect all insts of polymorphic functions with their concrete type substitutions +collectInstances :: AST ev Type -> PolyFuncs -> PolyInstances +collectInstances ast fns = + foldl collectDef Map.empty (definitions ast) + where + collectDef env def = + let defInstances = foldl collectEquation Map.empty (equations $ defEquations def) + in Map.unionWith (++) env defInstances + + collectEquation env eq = collectExpr (equationBody eq) + + collectExpr :: Expr ev Type -> PolyInstances + collectExpr (App _ _ _ e1 e2) = + let inst = case getPolymorphicCall fns e1 of + Just (id, tyVarSubsts) -> + Map.singleton id [(makeMonoId id (annotation e2), tyVarSubsts)] + Nothing -> Map.empty + in Map.unionWith (++) (collectExprs [e1, e2]) inst + collectExpr (Val _ _ _ val) = collectVal val + collectExpr (Binop _ _ _ _ e1 e2) = collectExprs [e1, e2] + collectExpr (Case _ _ _ e bs) = collectExprs (e : map snd bs) + collectExpr (AppTy _ _ _ e _) = collectExpr e + collectExpr (LetDiamond _ _ _ _ _ e1 e2) = collectExprs [e1, e2] + collectExpr (TryCatch _ _ _ e1 _ _ e2 e3) = collectExprs [e1, e2, e3] + collectExpr (Unpack _ _ _ _ _ e1 e2) = collectExprs [e1, e2] + collectExpr (Hole {}) = Map.empty + + collectVal :: Value ev Type -> PolyInstances + collectVal (Abs _ _ _ body) = collectExpr body + collectVal (Constr _ _ vals) = foldr (Map.unionWith (++) . collectVal) Map.empty vals + collectVal (Pure _ e) = collectExpr e + collectVal (Promote _ e) = collectExpr e + collectVal (Nec _ e) = collectExpr e + collectVal (Ref _ e) = collectExpr e + collectVal (Pack _ _ _ e _ _ _) = collectExpr e + collectVal (TyAbs _ _ e) = collectExpr e + collectVal _ = Map.empty + + -- combine results from multiple expressions + collectExprs :: [Expr ev Type] -> PolyInstances + collectExprs = foldr (Map.unionWith (++) . collectExpr) Map.empty + +-- identify polymorphic function calls and get substitution info +getPolymorphicCall :: PolyFuncs -> Expr ev Type -> Maybe (Id, [(Id, Type)]) +getPolymorphicCall fns (Val _ _ _ (Var ty id)) = + case Map.lookup id fns of + Just param -> + let substs = match param ty + in if null substs + then Nothing + else Just (id, substs) + Nothing -> Nothing +getPolymorphicCall _ _ = Nothing + +match :: Type -> Type -> [(Id, Type)] +match param arg = case (param, arg) of + (TyVar _, TyVar _) -> [] + (TyVar id, ty) -> [(id, ty)] + (FunTy _ _ a b, FunTy _ _ a' b') -> match2 a a' b b' + (TyApp a b, TyApp a' b') -> match2 a a' b b' + (Box _ a, Box _ a') -> match a a' + (Diamond _ a, Diamond _ a') -> match a a' + (Star _ a, Star _ a') -> match a a' + (Borrow _ a, Borrow _ a') -> match a a' + (TySig a _, TySig a' _) -> match a a' + (TyInfix _ a b, TyInfix _ a' b') -> match2 a a' b b' + (TyExists _ _ a, TyExists _ _ a') -> match a a' + (TyForall _ _ a, TyForall _ _ a') -> match a a' + (TyCase a as, TyCase a' as') -> match a a' ++ concat [match2 a a' b b' | ((a, b), (a', b')) <- zip as as'] + (TySet _ ts, TySet _ ts') -> concat (zipWith match ts ts') + (TyGrade (Just t) _, TyGrade (Just t') _) -> match t t' + _ -> [] + +match2 :: Type -> Type -> Type -> Type -> [(Id, Type)] +match2 a a' b b' = match a a' ++ match b b' + +-- create monomorphised definitions for all polymorphic function insts +makeMonoDefs :: AST ev Type -> PolyInstances -> [Def ev Type] +makeMonoDefs ast env = concatMap (monoDefsForFunc ast) (Map.toList env) + +monoDefsForFunc :: AST ev Type -> (Id, [(Id, [(Id, Type)])]) -> [Def ev Type] +monoDefsForFunc ast (id, instances) = + let og = head (filter (\def -> defId def == id) (definitions ast)) + in map (monoDef og) instances + +monoDef :: Def ev Type -> (Id, [(Id, Type)]) -> Def ev Type +monoDef (Def s _ r spec eqs ts) (id', typeSubsts) = + let subs = Map.fromList typeSubsts + subst = substTy subs + eqs' = monoEqList eqs id' subs subst + ts' = substTypeScheme ts subs subst + in Def s id' r spec eqs' ts' + +monoEqList :: EquationList ev Type -> Id -> Map.Map Id Type -> (Type -> Type) -> EquationList ev Type +monoEqList (EquationList s _ r eqs) id' subs applySubst = + let eqs' = map (monoEq id' subs applySubst) eqs + in EquationList s id' r eqs' + +monoEq :: Id -> Map.Map Id Type -> (Type -> Type) -> Equation ev Type -> Equation ev Type +monoEq id' subs applySubst (Equation s id a r ps b) = + let a' = applySubst a + ps' = map (substPat subs) ps + b' = substExpr subs b + in Equation s id' a' r ps' b' + +substTypeScheme :: TypeScheme -> Map.Map Id Type -> (Type -> Type) -> TypeScheme +substTypeScheme (Forall s bs cs ty) subs applySubst = + let bs' = filter (\(tyVar, _) -> not (Map.member tyVar subs)) bs + ty' = applySubst ty + in Forall s bs' cs ty' + +-- use typeFold with our substitution map +substTy :: Map.Map Id Type -> Type -> Type +substTy subs ty = + runIdentity $ typeFoldM (baseTypeFold {tfTyVar = substVar}) ty + where + substVar id = return $ Map.findWithDefault (TyVar id) id subs + +substPat :: Map.Map Id Type -> Pattern Type -> Pattern Type +substPat subs = + patternFold + (\s ty r id -> PVar s (substTy subs ty) r id) + (\s ty r -> PWild s (substTy subs ty) r) + (\s ty r pat -> PBox s (substTy subs ty) r pat) + (\s ty r i -> PInt s (substTy subs ty) r i) + (\s ty r f -> PFloat s (substTy subs ty) r f) + (\s ty r id ids pats -> PConstr s (substTy subs ty) r id ids pats) + +substExpr :: Map.Map Id Type -> Expr ev Type -> Expr ev Type +substExpr subs expr = + case expr of + App s ty r f arg -> App s (apply ty) r (subExp f) (subExp arg) + Val s ty r val -> Val s (apply ty) r (fmap apply val) + Binop s ty r op e1 e2 -> Binop s (apply ty) r op (subExp e1) (subExp e2) + Case s ty r e ps -> Case s (apply ty) r (subExp e) (map (bimap (substPat subs) subExp) ps) + Hole s ty r ids hs -> Hole s (apply ty) r ids hs + AppTy s ty r e t -> AppTy s (apply ty) r (subExp e) t + TryCatch s ty r e p mt e1 e2 -> TryCatch s (apply ty) r e p mt (subExp e1) (subExp e2) + Unpack s ty r tyVar var e1 e2 -> Unpack s (apply ty) r tyVar var (subExp e1) (subExp e2) + LetDiamond s ty r ps mt e1 e2 -> LetDiamond s (apply ty) r ps mt (subExp e1) (subExp e2) + where + apply = substTy subs + subExp = substExpr subs + +-- rewrite polymorphic function calls to use the monomorphised versions +rewriteCalls :: AST ev Type -> PolyInstances -> AST ev Type +rewriteCalls ast env = ast {definitions = map rewriteDef (definitions ast)} + where + rewriteDef def = def {defEquations = rewriteEqList (defEquations def)} + rewriteEqList eqs = eqs {equations = map rewriteEq (equations eqs)} + rewriteEq eq = eq {equationBody = rewriteExpr (equationBody eq)} + + rewriteExpr :: Expr ev Type -> Expr ev Type + rewriteExpr expr@(App s ty r f arg) = + let rewrittenF = rewriteExpr f + rewrittenArg = rewriteExpr arg + newF = case rewrittenF of + Val s' t' r' (Var vt id) -> + -- Only rewrite if this is a polymorphic function in our map + -- or polymorphic builtin + if Map.member id env || id `elem` polyBuiltinIds + then + let argTy = annotation rewrittenArg + ty' = FunTy Nothing Nothing argTy ty + in Val s' ty' r' (Var ty' (makeMonoId id argTy)) + else rewrittenF + _ -> rewrittenF + in App s ty r newF rewrittenArg + rewriteExpr (Val s ty r val) = Val s ty r (rewriteVal val) + rewriteExpr (Binop s ty r op e1 e2) = Binop s ty r op (rewriteExpr e1) (rewriteExpr e2) + rewriteExpr (Case s ty r e ps) = Case s ty r (rewriteExpr e) (map (second rewriteExpr) ps) + rewriteExpr (Hole s a b ids hs) = Hole s a b ids hs + rewriteExpr (AppTy s a b e t) = AppTy s a b (rewriteExpr e) t + rewriteExpr (TryCatch s a b e p mt e1 e2) = TryCatch s a b (rewriteExpr e) p mt (rewriteExpr e1) (rewriteExpr e2) + rewriteExpr (Unpack s a rf tyVar var e1 e2) = Unpack s a rf tyVar var (rewriteExpr e1) (rewriteExpr e2) + rewriteExpr (LetDiamond s a b ps mt e1 e2) = LetDiamond s a b ps mt (rewriteExpr e1) (rewriteExpr e2) + + rewriteVal :: Value ev Type -> Value ev Type + rewriteVal (Abs a pat mt e) = Abs a pat mt (rewriteExpr e) + rewriteVal (Constr a idv vals) = Constr a idv (map rewriteVal vals) + rewriteVal (Promote a e) = Promote a (rewriteExpr e) + rewriteVal (Pure a e) = Pure a (rewriteExpr e) + rewriteVal (Nec a e) = Nec a (rewriteExpr e) + rewriteVal (Pack s a ty e v k ty') = Pack s a ty (rewriteExpr e) v k ty' + rewriteVal (TyAbs a v e) = TyAbs a v (rewriteExpr e) + rewriteVal (NumInt n) = NumInt n + rewriteVal (NumFloat n) = NumFloat n + rewriteVal (CharLiteral ch) = CharLiteral ch + rewriteVal (StringLiteral str) = StringLiteral str + rewriteVal (Ext a ev) = Ext a ev + rewriteVal (Var a id) = Var a id diff --git a/src/Language/Granule/Codegen/RewriteAST.hs b/src/Language/Granule/Codegen/RewriteAST.hs index 32f710b..48d75be 100644 --- a/src/Language/Granule/Codegen/RewriteAST.hs +++ b/src/Language/Granule/Codegen/RewriteAST.hs @@ -1,17 +1,13 @@ module Language.Granule.Codegen.RewriteAST where -import Data.Bifunctor (bimap) -import Data.List (mapAccumL) -import qualified Data.Map as Map -import Data.Maybe (fromMaybe) import Language.Granule.Syntax.Def import Language.Granule.Syntax.Expr -import Language.Granule.Syntax.Identifiers (Id) import Language.Granule.Syntax.Pattern import Language.Granule.Syntax.Type -- Rewrite Unpack ASTs into App Abs ASTs which our --- compiler already knows how to handle. WIP. +-- compiler already knows how to handle. +-- TODO: handle unpack in compile rewriteAST :: AST ev Type -> AST ev Type rewriteAST ast = ast {definitions = map rewriteDef (definitions ast)} @@ -20,15 +16,12 @@ rewriteAST ast = ast {definitions = map rewriteDef (definitions ast)} rewriteEquationList eqs = eqs {equations = map rewriteEquation (equations eqs)} rewriteEquation eq = eq {equationBody = rewriteExpr (equationBody eq)} +-- TODO: handle not top level rewriteExpr :: Expr ev Type -> Expr ev Type rewriteExpr (Unpack s retTy b tyVar var e1 e2) = - let e1' = e1 - e1Ty = exprTy e1' - e2' = e2 + let e1Ty = exprTy e1 absTy = FunTy Nothing Nothing e1Ty retTy - in fixTypes (App s retTy b (Val s absTy b (Abs absTy (PVar s e1Ty b var) Nothing e2')) e1') - where - fixTypes expr = snd $ substExpr emptyEnv expr + in App s retTy b (Val s absTy b (Abs absTy (PVar s e1Ty b var) Nothing e2)) e1 rewriteExpr exp = exp exprTy :: Expr ev Type -> Type @@ -41,94 +34,3 @@ exprTy (Hole _ ty _ _ _) = ty exprTy (AppTy _ ty _ _ _) = ty exprTy (TryCatch _ ty _ _ _ _ _ _) = ty exprTy (Unpack _ ty _ _ _ _ _) = ty - --- `let (x, y) = ` inside of an Unpack seems to leave TyVars in the AST, and these --- are not already handled by the compiler. Here we find the correct types and substitute --- the TyVars. WIP. - --- val var -> Type, type var -> Type -type Env = (Map.Map Id Type, Map.Map Id Type) - -emptyEnv :: Env -emptyEnv = (Map.empty, Map.empty) - -insertEnv :: Env -> Either Id Id -> Type -> Env -insertEnv (vals, tys) (Left id) ty = (Map.insert id ty vals, tys) -insertEnv (vals, tys) (Right id) ty = (vals, Map.insert id ty tys) - -lookupEnv :: Env -> Either Id Id -> Maybe Type -lookupEnv (vals, tys) (Left id) = Map.lookup id vals -lookupEnv (vals, tys) (Right id) = Map.lookup id tys - -substExpr :: Env -> Expr ev Type -> (Env, Expr ev Type) -substExpr env (App s ty b e1 e2) = - let (env', e2') = substExpr env e2 - (env'', e1') = substExpr env' e1 - ty' = substTy env ty - in (env'', App s ty' b e1' e2') -substExpr env (Val s ty b v) = - let (env', v') = substVal env v - ty' = substTy env' ty - in (env', Val s ty' b v') -substExpr env exp = error "TODO expr" - -substVal :: Env -> Value ev Type -> (Env, Value ev Type) -substVal env (Var (TyVar id) var) = - -- see if we already have it - case lookupEnv env (Right id) of - Just ty -> (env, Var ty var) - Nothing -> - -- see if the value variable has it - case lookupEnv env (Left var) of - -- and update - Just ty -> (insertEnv env (Right id) ty, Var ty var) - -- we wont always win - Nothing -> (env, Var (TyVar id) var) -substVal env (Var ty var) = (insertEnv env (Left var) ty, Var ty var) -substVal env (Abs ty p mt e) = - let (env', p') = substPat env p - (env'', e') = substExpr env' e - ty' = substTy env'' ty - in (env'', Abs ty' p' mt e') -substVal env (Constr ty id vals) = - let (env', vals') = mapAccumL substVal env vals - ty' = substTy env' ty - in (env', Constr ty' id vals') -substVal env (NumInt v) = (env, NumInt v) -substVal env (NumFloat v) = (env, NumFloat v) -substVal env (Promote t v) = (env, Promote t v) -substVal env val = error "TODO val" - -substPat :: Env -> Pattern Type -> (Env, Pattern Type) -substPat env (PVar s (TyVar id) b var) = - case lookupEnv env (Right id) of - Just ty -> (env, PVar s ty b var) - Nothing -> - case lookupEnv env (Left var) of - Just ty -> (insertEnv env (Right id) ty, PVar s ty b var) - Nothing -> (env, PVar s (TyVar id) b var) -substPat env (PVar s ty b var) = (insertEnv env (Left var) ty, PVar s ty b var) -substPat env (PConstr s ty b id ids ps) = - let (env', ps') = mapAccumL substPat env ps - ty' = substTy env' ty - in (env', PConstr s ty' b id ids ps') -substPat env p = error "TODO pat" - -substTy :: Env -> Type -> Type -substTy env (TyVar id) = fromMaybe (TyVar id) (lookupEnv env (Right id)) -substTy env (Type i) = Type i -substTy env (FunTy id mc arg ret) = FunTy id mc (substTy env arg) (substTy env ret) -substTy env (TyCon id) = TyCon id -substTy env (Box c t) = substTy env t -substTy env (Diamond e t) = Diamond (substTy env e) (substTy env t) -substTy env (Star g t) = substTy env t -substTy env (Borrow p t) = substTy env t -substTy env (TyApp t1 t2) = TyApp (substTy env t1) (substTy env t2) -substTy env (TyGrade mt i) = TyGrade mt i -substTy env (TyInfix op t1 t2) = TyInfix op (substTy env t1) (substTy env t2) -substTy env (TySet p ts) = TySet p (map (substTy env) ts) -substTy env (TyCase t tps) = TyCase (substTy env t) (map (bimap (substTy env) (substTy env)) tps) -substTy env (TySig t k) = TySig (substTy env t) (substTy env k) -substTy env (TyExists id k t) = substTy env t -substTy env (TyForall id k t) = substTy env t -substTy env t = t diff --git a/src/Language/Granule/Codegen/SubstituteTypes.hs b/src/Language/Granule/Codegen/SubstituteTypes.hs new file mode 100644 index 0000000..220d031 --- /dev/null +++ b/src/Language/Granule/Codegen/SubstituteTypes.hs @@ -0,0 +1,265 @@ +module Language.Granule.Codegen.SubstituteTypes where + +import Data.Bifunctor (bimap) +import qualified Data.Map as Map +import Debug.Trace +import Language.Granule.Syntax.Def +import Language.Granule.Syntax.Expr +import Language.Granule.Syntax.Identifiers +import Language.Granule.Syntax.Pattern +import Language.Granule.Syntax.Type + +-- Many Typed ASTs contain unnecessary TyVars +-- There are a few strategies to substitue these +-- 1. From parent / sibling node i.e. +-- Val (concrete) (Var (variable)), +-- App (variable) (Val (variable -> concrete)) (Val (concrete)) +-- etc. +-- 2. From value bindings +-- i.e. +-- Somewhere: Var (concrete) x +-- Somewhere else: Var (variable) x +-- 3. From type bindings +-- i.e. +-- Somewhere: Val (concrete) (Var (x)) (using ) +-- Somewhere else: Val (x) (Var (variable)) () +-- We do 1 and 3 here +-- +-- TODO: clean up, do in 1 pass if possible +substituteTypes :: AST ev Type -> AST ev Type +substituteTypes ast = ast {definitions = map retypeDef (definitions ast)} + where + retypeDef def = def {defEquations = retypeEquationList (defEquations def)} + retypeEquationList eqs = eqs {equations = map retypeEquation (equations eqs)} + retypeEquation eq = + let expr = equationBody eq + substs = collect expr + expr' = replace substs expr + in eq {equationBody = expr'} + +type VMap = Map.Map Id Type + +diff :: Map.Map Id Type -> Type -> Type -> Map.Map Id Type +diff env t1 t2 = case (t1, t2) of + (TyVar v1, TyVar v2) -> + case (Map.lookup v1 env, Map.lookup v2 env) of + (Nothing, Nothing) -> env + (Nothing, Just t2') -> Map.insert v1 t2' env + (Just t1', Nothing) -> Map.insert v2 t1' env + (Just t1', Just t2') -> diff env t1' t2' + (TyVar v1, t2) -> + case Map.lookup v1 env of + Nothing -> Map.insert v1 t2 env + Just t1 -> diff env t1 t2 + (t1, TyVar v2) -> + case Map.lookup v2 env of + Nothing -> Map.insert v2 t1 env + Just t2 -> diff env t1 t2 + (FunTy _ _ a b, FunTy _ _ a' b') -> diff (diff env b b') a a' + (TyApp a b, TyApp a' b') -> diff (diff env b b') a a' + (Box _ a, Box _ a') -> diff env a a' + (Diamond _ a, Diamond _ a') -> diff env a a' + (Star _ a, Star _ a') -> diff env a a' + (Borrow _ a, Borrow _ a') -> diff env a a' + (TySig a _, TySig a' _) -> diff env a a' + (TyInfix _ a b, TyInfix _ a' b') -> diff (diff env b b') a a' + (TyExists _ _ a, TyExists _ _ a') -> diff env a a' + (TyForall _ _ a, TyForall _ _ a') -> diff env a a' + (TyCase a as, TyCase a' as') -> + foldl + (\e ((p1, r1), (p2, r2)) -> diff (diff e p1 p2) r1 r2) + (diff env a a') + (zip as as') + (TySet _ ts, TySet _ ts') -> + foldl + (\e (t, t') -> diff e t t') + env + (zip ts ts') + (TyGrade (Just t) _, TyGrade (Just t') _) -> diff env t t' + _ -> env + +collect :: Expr ev Type -> Map.Map Id Type +collect expr = + let (env', _) = inExpr Map.empty expr + in let (env'', _) = inExpr env' expr + in env'' + where + inExpr :: Map.Map Id Type -> Expr ev Type -> (Map.Map Id Type, Type) + inExpr env expr = + case expr of + App s retTy r f arg -> + let (env', argTy) = inExpr env arg + (env'', fTy) = inExpr env' f + env''' = diff env'' (FunTy Nothing Nothing argTy retTy) fTy + in (env''', retTy) + Val s ty r val -> + let (env', ty') = inVal env val + env'' = diff env' ty ty' + in (env'', ty) + Binop s ty r op e1 e2 -> (fst $ inExpr (fst $ inExpr env e2) e1, ty) + Case s ty r e ps -> + ( foldl + (\env (p, e) -> fst $ inExpr (fst $ inPat env p) e) + (fst $ inExpr env e) + ps, + ty + ) + AppTy s ty r e t -> error "TODO: AppTy" + LetDiamond s ty r ps mt e1 e2 -> error "TODO: LetDiamond" + TryCatch s ty r e p mt e1 e2 -> error "TODO: TryCatch" + Unpack s ty r tyVar var e1 e2 -> + let (env', _) = inExpr env e1 + (env'', retTy) = inExpr env' e2 + env''' = diff env'' ty retTy + in (env''', ty) + Hole s ty r ids hs -> error "TODO: Hole" + + inVal :: Map.Map Id Type -> Value ev Type -> (Map.Map Id Type, Type) + inVal env val = + case val of + Abs funTy p mt e -> + let (env', argTy) = inPat env p + (env'', retTy) = inExpr env' e + env''' = diff env'' (FunTy Nothing Nothing argTy retTy) funTy + in (env''', funTy) + Constr a id vs -> + (foldl (\env v -> fst $ inVal env v) env vs, a) + Promote a e -> (fst $ inExpr env e, a) + Pure a e -> (fst $ inExpr env e, a) + Nec a e -> error "TODO: Nec" + Pack s a ty e v k ty' -> + let (env', eTy) = inExpr env e + + env'' = diff env' (TyExists v k eTy) a + in trace + (show a) + trace + (show (TyExists v k eTy)) + trace + "" + (env'', a) + TyAbs a v e -> error "TODO: TyAbs" + NumInt n -> (env, TyCon (Id "Int" "Int")) + NumFloat n -> (env, TyCon (Id "Float" "Float")) + CharLiteral ch -> (env, TyCon (Id "Char" "Char")) + StringLiteral str -> (env, TyCon (Id "String" "String")) + Ext a ev -> error "TODO: Ext" + Var a id -> (env, a) + + inPat :: Map.Map Id Type -> Pattern Type -> (Map.Map Id Type, Type) + inPat env pat = + case pat of + PVar s ty r id -> (env, ty) + PWild s ty r -> (env, ty) + PBox s ty r pat' -> (fst $ inPat env pat', ty) + PInt s ty r i -> (env, ty) + PFloat s ty r f -> (env, ty) + PConstr s ty r id ids ps -> (foldl (\env p -> fst $ inPat env p) env ps, ty) + +replace :: Map.Map Id Type -> Expr ev Type -> Expr ev Type +replace env expr = inExpr env expr + where + inExpr :: Map.Map Id Type -> Expr ev Type -> Expr ev Type + inExpr env expr = + case expr of + App s ty r f arg -> + let ty' = inTy env ty + arg' = inExpr env arg + f' = inExpr env f + in App s ty' r f' arg' + Val s ty r val -> + let ty' = inTy env ty + val' = inVal env val + in Val s ty' r val' + Binop s ty r op e1 e2 -> + let ty' = inTy env ty + e2' = inExpr env e2 + e1' = inExpr env e1 + in Binop s ty' r op e1' e2' + Case s ty r e ps -> + let ty' = inTy env ty + e' = inExpr env e + ps' = map (bimap (inPat env) (inExpr env)) ps + in Case s ty' r e' ps' + AppTy s ty r e t -> error "TODO: AppTy" + LetDiamond s ty r ps mt e1 e2 -> error "TODO: LetDiamond" + TryCatch s ty r e p mt e1 e2 -> error "TODO: TryCatch" + Unpack s ty r tyVar var e1 e2 -> + let ty' = inTy env ty + e1' = inExpr env e1 + e2' = inExpr env e2 + in Unpack s ty' r tyVar var e1' e2' + Hole s ty r ids hs -> error "TODO: Hole" + + inVal :: Map.Map Id Type -> Value ev Type -> Value ev Type + inVal env val = + case val of + Abs a pat mt e -> + let a' = inTy env a + pat' = inPat env pat + e' = inExpr env e + in Abs a' pat' mt e' + Constr a id vs -> + let a' = inTy env a + vs' = map (inVal env) vs + in Constr a' id vs' + Promote a e -> + let a' = inTy env a + e' = inExpr env e + in Promote a' e' + Pure a e -> + let a' = inTy env a + e' = inExpr env e + in Pure a' e' + Nec a e -> error "TODO: Nec" + Pack s a t1 e v k t2 -> + let a' = inTy env a + t1' = inTy env t1 + e' = inExpr env e + t2' = inTy env t2 + in Pack s a' t1' e' v k t2' + TyAbs a v e -> error "TODO: TyAbs" + NumInt n -> val + NumFloat n -> val + CharLiteral ch -> val + StringLiteral str -> val + Ext a ev -> error "TODO: Ext" + Var a id -> Var (inTy env a) id + + inPat :: Map.Map Id Type -> Pattern Type -> Pattern Type + inPat env pat = + case pat of + PVar s ty r id -> PVar s (inTy env ty) r id + PWild s ty r -> PWild s (inTy env ty) r + PBox s ty r p -> PBox s (inTy env ty) r (inPat env p) + PInt s ty r i -> PInt s (inTy env ty) r i + PFloat s ty r f -> PFloat s (inTy env ty) r f + PConstr s ty r id ids ps -> PConstr s (inTy env ty) r id ids (map (inPat env) ps) + + inTy :: Map.Map Id Type -> Type -> Type + inTy env ty = + case ty of + TyVar id -> + case Map.lookup id env of + Nothing -> ty + Just ty' -> inTy env ty' + Type i -> Type i + FunTy id mc arg ret -> FunTy id mc (inTy env arg) (inTy env ret) + Box c t -> Box c (inTy env t) + Diamond e t -> Diamond e (inTy env t) + Star g t -> Star g (inTy env t) + Borrow p t -> Borrow p (inTy env t) + TyApp t1 t2 -> TyApp (inTy env t1) (inTy env t2) + TyInfix op t1 t2 -> TyInfix op (inTy env t1) (inTy env t2) + TyCase t tys -> TyCase (inTy env t) (map (bimap (inTy env) (inTy env)) tys) + TySig t k -> TySig (inTy env t) k + TyExists id k t -> TyExists id k (inTy env t) + TyForall id k t -> TyForall id k (inTy env t) + TyGrade (Just t) i -> TyGrade (Just (inTy env t)) i + TySet p ts -> TySet p (map (inTy env) ts) + TyCon {} -> ty + TyInt {} -> ty + TyRational {} -> ty + TyFraction {} -> ty + TyName {} -> ty + TyGrade {} -> ty diff --git a/tests/golden/positive/poly-builtins.golden b/tests/golden/positive/poly-builtins.golden new file mode 100644 index 0000000..8aa4ae5 --- /dev/null +++ b/tests/golden/positive/poly-builtins.golden @@ -0,0 +1 @@ +([[42.000000]], [[100]]) diff --git a/tests/golden/positive/poly-builtins.gr b/tests/golden/positive/poly-builtins.gr new file mode 100644 index 0000000..11dcbec --- /dev/null +++ b/tests/golden/positive/poly-builtins.gr @@ -0,0 +1,11 @@ +use' : forall {a : Type} . a -> a [1] +use' x = use x + +use2 : forall {a b : Type} . (a, b) -> (a [1], b [1]) +use2 (x, y) = (use x, use y) + +main : ((Float [1]) [1], (Int [1]) [1]) +main = + let a = use' 42.0; + b = use' 100 in + use2 (a, b) diff --git a/tests/golden/positive/poly-curry.golden b/tests/golden/positive/poly-curry.golden new file mode 100644 index 0000000..7093511 --- /dev/null +++ b/tests/golden/positive/poly-curry.golden @@ -0,0 +1 @@ +(100, 42.000000) diff --git a/tests/golden/positive/poly-curry.gr b/tests/golden/positive/poly-curry.gr new file mode 100644 index 0000000..e3c8a8e --- /dev/null +++ b/tests/golden/positive/poly-curry.gr @@ -0,0 +1,27 @@ +curry : forall {a : Type, b : Type, c : Type} . + (a × b -> c) -> a -> b -> c +curry f x y = f (x, y) + +uncurry : forall {a : Type, b : Type, c : Type} . + (a -> b -> c) -> (a × b -> c) +uncurry f (x, y) = f x y + +addInt : (Int, Int) -> Int +addInt (x, y) = x + y + +addFloat : (Float, Float) -> Float +addFloat (x, y) = x + y + +intDrop : forall {a : Type} . Int -> a [0] -> Int +intDrop x [_] = x + +floatDrop : forall {a : Type} . a [0] -> Float -> Float +floatDrop [_] x = x + +swap : forall {a : Type, b : Type} . (a, b) -> (b, a) +swap (x, y) = (y, x) + +main : (Int, Float) +main = let addInt4 = curry addInt 4; + addFloat4 = curry addFloat 4.0 in + swap (swap (addInt4 (intDrop 96 [2000]), addFloat4 (floatDrop [1000] 38.0))) diff --git a/tests/golden/positive/poly-simple.golden b/tests/golden/positive/poly-simple.golden new file mode 100644 index 0000000..d81cc07 --- /dev/null +++ b/tests/golden/positive/poly-simple.golden @@ -0,0 +1 @@ +42 diff --git a/tests/golden/positive/poly-simple.gr b/tests/golden/positive/poly-simple.gr new file mode 100644 index 0000000..9c2104f --- /dev/null +++ b/tests/golden/positive/poly-simple.gr @@ -0,0 +1,5 @@ +id : forall {a : Type} . a -> a +id x = x + +main : Int +main = id 42