Skip to content

Polymorphic Functions #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: float-arrays
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions granule-compiler.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ library
Language.Granule.Codegen.Emit.Names
Language.Granule.Codegen.Emit.Primitives
Language.Granule.Codegen.Emit.Types
Language.Granule.Codegen.Monomorphise
Language.Granule.Codegen.RewriteAST
Language.Granule.Codegen.SubstituteTypes
Paths_granule_compiler
hs-source-dirs:
src
Expand All @@ -62,6 +64,7 @@ library
base >=4.10 && <5
, containers
, granule-frontend
, hashable
, llvm-hs ==12.*
, llvm-hs-pure ==12.*
, mtl
Expand Down Expand Up @@ -91,6 +94,7 @@ executable grlc
, gitrev
, granule-compiler
, granule-frontend
, hashable
, llvm-hs ==12.*
, llvm-hs-pure ==12.*
, optparse-applicative
Expand Down Expand Up @@ -123,6 +127,7 @@ test-suite compiler-spec
, filemanip
, granule-compiler
, granule-frontend
, hashable
, hspec
, mtl
, process
Expand All @@ -149,6 +154,7 @@ test-suite golden
, filepath
, granule-compiler
, granule-frontend
, hashable
, llvm-hs ==12.*
, llvm-hs-pure ==12.*
, process
Expand Down
1 change: 1 addition & 0 deletions package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ github: granule-project/granule-compiler-llvm
dependencies:
- base >=4.10 && <5
- process
- hashable
default-extensions:
- LambdaCase
- RecordWildCards
Expand Down
13 changes: 12 additions & 1 deletion src/Language/Granule/Codegen/Builtins/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,16 @@ builtins =
deleteFloatArrayDef
]

specialisable :: [Specialisable]
specialisable =
[ useDef
]

monoBuiltinIds :: [Id]
monoBuiltinIds = map (mkId . builtinId) builtins

polyBuiltinIds :: [Id]
polyBuiltinIds = map (mkId . specialisableId) specialisable

builtinIds :: [Id]
builtinIds = map (mkId . builtinId) builtins
builtinIds = monoBuiltinIds ++ polyBuiltinIds
7 changes: 7 additions & 0 deletions src/Language/Granule/Codegen/Builtins/Extras.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,10 @@ divDef =
args = [TyCon (Id "Int" "Int"), TyCon (Id "Int" "Int")]
ret = TyCon (Id "Int" "Int")
impl [x, y] = sdiv x y

-- use :: a -> a [1]
useDef :: Specialisable
useDef =
Specialisable "use" impl
where
impl _ [val] = return val
13 changes: 13 additions & 0 deletions src/Language/Granule/Codegen/Builtins/Shared.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-# LANGUAGE RankNTypes #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}

module Language.Granule.Codegen.Builtins.Shared where

Expand All @@ -21,6 +22,18 @@ data Builtin = Builtin {
builtinRetTy :: Gr.Type,
builtinImpl :: forall m. (MonadModuleBuilder m, MonadIRBuilder m) => [Operand] -> m Operand}

data Specialisable = Specialisable {
specialisableId :: String,
specialisableImpl :: [Gr.Type] -> forall m. (MonadModuleBuilder m, MonadIRBuilder m) => [Operand] -> m Operand}

specialise :: Specialisable -> String -> Gr.Type -> Builtin
specialise builtin id ty = Builtin id args ret impl
where
args = Gr.parameterTypes ty
ret = Gr.resultType ty
{-# HLINT ignore "Eta reduce" #-} -- has to be lazy or we need IRBuilder early
impl xs = specialisableImpl builtin args xs

-- LLVM helpers

allocate :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> IR.Type -> m Operand
Expand Down
2 changes: 0 additions & 2 deletions src/Language/Granule/Codegen/ClosureFreeDef.hs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ data ClosureMarker =
CapturedVar Type Id Int
| MakeClosure Id ClosureEnvironmentInit
| MakeTrivialClosure Id
| MakeBuiltinClosure Id
deriving (Show, Eq)

data ClosureFreeAST =
Expand Down Expand Up @@ -84,4 +83,3 @@ instance Pretty ClosureMarker where
"env(ident = \"" ++ envName ++ "\", " ++ intercalate ", " (map prettyEnvVar varInits) ++ ")"
in "make-closure(" ++ pretty ident ++ ", " ++ prettyEnv env ++ ")"
pretty (MakeTrivialClosure ident) = pretty ident
pretty (MakeBuiltinClosure ident) = pretty ident
8 changes: 6 additions & 2 deletions src/Language/Granule/Codegen/Compile.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@ import Language.Granule.Codegen.TopsortDefinitions
import Language.Granule.Codegen.ConvertClosures
import Language.Granule.Codegen.Emit.EmitLLVM
import Language.Granule.Codegen.MarkGlobals
import Language.Granule.Codegen.Monomorphise
import Language.Granule.Codegen.RewriteAST
import Language.Granule.Codegen.SubstituteTypes

import qualified LLVM.AST as IR

compile :: String -> AST () Type -> Either String IR.Module
compile moduleName typedAST =
let rewritten = rewriteAST typedAST
normalised = normaliseDefinitions rewritten
let substituted = substituteTypes typedAST
rewritten = rewriteAST substituted
monomorphised = monomorphiseAST rewritten
normalised = normaliseDefinitions monomorphised
markedGlobals = markGlobals normalised
(Ok topsorted) = topologicallySortDefinitions markedGlobals
closureFree = convertClosures topsorted
Expand Down
3 changes: 0 additions & 3 deletions src/Language/Granule/Codegen/ConvertClosures.hs
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,5 @@ convertClosuresFromValue (_, maybeCurrentEnv, locals) (VarF ty ident)
++ sourceName ident ++ " in environment."
in return $ Ext ty (Right (CapturedVar ty ident indexInEnv))

convertClosuresFromValue (_, maybeCurrentEnv, _) (ExtF ty (BuiltinVar _ ident)) =
return $ Ext ty $ Right $ MakeBuiltinClosure ident

convertClosuresFromValue _ other =
return $ fixMapExtValue (\ty gv -> Ext ty $ Left gv) other
14 changes: 11 additions & 3 deletions src/Language/Granule/Codegen/Emit/EmitBuiltins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Control.Monad (forM)
import LLVM.AST (Operand)
import qualified LLVM.AST as IR
import qualified LLVM.AST.Constant as C
import LLVM.AST.Type hiding (Type)
import LLVM.AST.Type hiding (resultType, Type)
import LLVM.IRBuilder.Constant (int32)
import LLVM.IRBuilder.Instruction
import LLVM.IRBuilder.Module
Expand All @@ -16,9 +16,17 @@ import Language.Granule.Codegen.Builtins.Shared
import Language.Granule.Codegen.Emit.LLVMHelpers
import Language.Granule.Codegen.Emit.LowerClosure (mallocEnvironment)
import Language.Granule.Codegen.Emit.LowerType (llvmType, llvmTypeForClosure, llvmTypeForFunction)
import Language.Granule.Syntax.Identifiers
import Language.Granule.Syntax.Type

emitBuiltins :: (MonadModuleBuilder m) => m [Operand]
emitBuiltins = mapM emitBuiltin builtins

emitBuiltins :: (MonadModuleBuilder m) => [(Id, Type)] -> m ()
emitBuiltins uses = mapM_ emitBuiltin (monos ++ polys)
where
monos =
[b | (id, _) <- uses, b <- builtins, builtinId b == sourceName id]
polys =
[specialise b (internalName id) ty | (id, ty) <- uses, b <- specialisable, specialisableId b == sourceName id]

emitBuiltin :: (MonadModuleBuilder m) => Builtin -> m Operand
emitBuiltin builtin =
Expand Down
9 changes: 6 additions & 3 deletions src/Language/Granule/Codegen/Emit/EmitLLVM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,21 @@ import LLVM.IRBuilder (int32)

emitLLVM :: String -> ClosureFreeAST -> Either String IR.Module
emitLLVM moduleName (ClosureFreeAST dataDecls functionDefs valueDefs) =
let buildModule name m = evalState (buildModuleT name m) (EmitterState { localSymbols = Map.empty })
let buildModule name m = evalState (buildModuleT name m) (EmitterState { localSymbols = Map.empty, builtins = Map.empty })
in Right $ buildModule (fromString moduleName) $ do
_ <- extern (mkName "malloc") [i64] (ptr i8)
_ <- extern (mkName "abort") [] void
_ <- externVarArgs (mkName "printf") [ptr i8] i32
_ <- extern (mkName "llvm.memcpy.p0.p0.i32") [ptr i8, ptr i8, i32, i1] void
_ <- extern (mkName "free") [ptr i8] void
_ <- emitBuiltins
let mainTy = findMainReturnType valueDefs
_ <- emitMainOut mainTy
mapM_ emitDataDecl dataDecls
mapM_ emitEnvironmentType functionDefs
mapM_ emitFunctionDef functionDefs
valueInitPairs <- mapM emitValueDef valueDefs
builtins <- usedBuiltins
_ <- emitBuiltins builtins
emitGlobalInitializer valueInitPairs mainTy

emitGlobalInitializer :: (MonadModuleBuilder m) => [(Operand, Operand)] -> GrType -> m Operand
Expand Down Expand Up @@ -125,7 +126,9 @@ emitFunction _ _ _ _ _ = error "cannot emit function with non function type"
paramName :: Pattern GrType -> ParameterName
paramName (PConstr _ _ _ (Id "," _) _ _) = parameterNameFromId $ mkId "pair"
paramName (PConstr _ _ _ (Id "()" _) _ _) = parameterNameFromId $ mkId "unit"
paramName pat = parameterNameFromId $ head $ boundVars pat
paramName pat = case boundVars pat of
[] -> parameterNameFromId $ mkId "wildcard"
(var : _) -> parameterNameFromId var

emitArg :: (MonadState EmitterState m, MonadModuleBuilder m, MonadIRBuilder m)
=> Pattern GrType
Expand Down
9 changes: 8 additions & 1 deletion src/Language/Granule/Codegen/Emit/EmitterState.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
module Language.Granule.Codegen.Emit.EmitterState where

import Language.Granule.Syntax.Identifiers (Id, internalName)
import Language.Granule.Syntax.Type
import Control.Monad.State.Strict hiding (void)

import LLVM.AST (Operand)

import Data.Map (Map, insertWith)
import qualified Data.Map as Map

data EmitterState = EmitterState { localSymbols :: Map Id Operand }
data EmitterState = EmitterState { localSymbols :: Map Id Operand, builtins :: Map Id Type }

addLocal :: (MonadState EmitterState m)
=> Id
Expand Down Expand Up @@ -40,3 +41,9 @@ local name =
case local of
Just op -> return op
Nothing -> error $ internalName name ++ "not registered as a local, missing call to addLocal?\n"

useBuiltin :: (MonadState EmitterState m) => Id -> Type -> m ()
useBuiltin id ty = modify $ \s -> s { builtins = Map.insert id ty (builtins s) }

usedBuiltins :: (MonadState EmitterState m) => m [(Id, Type)]
usedBuiltins = Map.toList <$> gets builtins
5 changes: 0 additions & 5 deletions src/Language/Granule/Codegen/Emit/LowerClosure.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@ emitClosureMarker ty maybeParentEnv (MakeClosure ident initializer) =
emitClosureMarker ty _ (MakeTrivialClosure identifier) =
return $ ConstantOperand $ makeTrivialClosure identifier ty

emitClosureMarker ty _ (MakeBuiltinClosure ident) = do
let functionPtr = ConstantOperand $ C.GlobalReference (ptr $ llvmTopLevelType ty) (functionNameFromId ident)
closure <- insertValue (ConstantOperand $ C.Undef (llvmType ty)) functionPtr [0]
insertValue closure (ConstantOperand $ C.Null (ptr i8)) [1]

emitEnvironmentInit :: (MonadModuleBuilder m, MonadIRBuilder m, MonadState EmitterState m)
=> [ClosureVariableInit]
-> Operand
Expand Down
9 changes: 6 additions & 3 deletions src/Language/Granule/Codegen/Emit/LowerExpression.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module Language.Granule.Codegen.Emit.LowerExpression where
import Language.Granule.Codegen.ClosureFreeDef (ClosureMarker)
import Language.Granule.Codegen.MarkGlobals (GlobalMarker, GlobalMarker(..))
import Language.Granule.Codegen.Emit.LowerOperator
import Language.Granule.Codegen.Emit.LowerType (llvmType)
import Language.Granule.Codegen.Emit.LowerType (llvmType, llvmTopLevelType)
import Language.Granule.Codegen.Emit.EmitableDef
import Language.Granule.Codegen.Emit.LowerPatterns (emitCaseArm)
import Language.Granule.Codegen.Emit.LowerClosure (emitClosureMarker)
Expand All @@ -29,7 +29,7 @@ import LLVM.IRBuilder.Monad
import LLVM.IRBuilder.Instruction

import LLVM.AST (Operand)
import LLVM.AST.Type (ptr)
import LLVM.AST.Type (ptr, i8)
import LLVM.AST.Constant as C
import qualified LLVM.IRBuilder.Constant as IC
import qualified LLVM.AST as IR
Expand Down Expand Up @@ -137,7 +137,10 @@ emitValue _ (ExtF a (Left (GlobalVar ty ident))) = do
let ref = IR.ConstantOperand $ C.GlobalReference (ptr (llvmType ty)) (definitionNameFromId ident)
load ref 4
emitValue _ (ExtF a (Left (BuiltinVar ty ident))) = do
error "TODO?"
useBuiltin ident ty
let functionPtr = IR.ConstantOperand $ C.GlobalReference (ptr $ llvmTopLevelType ty) (IR.mkName $ "fn." ++ internalName ident)
closure <- insertValue (IR.ConstantOperand $ C.Undef (llvmType ty)) functionPtr [0]
insertValue closure (IR.ConstantOperand $ C.Null (ptr i8)) [1]
emitValue environment (ExtF ty (Right cm)) =
emitClosureMarker ty environment cm
{- TODO: Support tagged unions, also affects Case.
Expand Down
1 change: 1 addition & 0 deletions src/Language/Granule/Codegen/Emit/MainOut.hs
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,5 @@ fmtStrForTy x =
(TyApp (TyCon (Id "FloatArray" _)) _) -> "<array>"
(TyCon (Id "()" _)) -> "()"
(TyExists _ _ (Borrow _ ty)) -> "*" ++ fmtStrForTy ty
(Box _ ty) -> "[" ++ fmtStrForTy ty ++ "]"
_ -> error ("Unsupported Main type: " ++ show x)
3 changes: 2 additions & 1 deletion src/Language/Granule/Codegen/MarkGlobals.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ markGlobalsInExpr :: [Id] -> Expr () Type -> Expr GlobalMarker Type
markGlobalsInExpr globals =
bicata fixMapExtExpr markInValue
where markInValue (VarF ty ident)
| ident `elem` builtinIds = Ext ty (BuiltinVar ty ident)
| any (\id -> sourceName ident == sourceName id) builtinIds =
Ext ty (BuiltinVar ty ident)
| ident `elem` globals = Ext ty (GlobalVar ty ident)
| otherwise = Var ty ident
markInValue other =
Expand Down
Loading