Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support better type checking of array sections #14

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 69 additions & 22 deletions src/Language/Fortran/Vars/TypeCheck.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeApplications #-}

module Language.Fortran.Vars.TypeCheck
( Kind
Expand All @@ -12,10 +11,11 @@ import Prelude hiding ( GT
, EQ
, LT
)
import Data.List.NonEmpty ( NonEmpty( (:|) ) )
import qualified Data.Map as M
import Data.Data ( toConstr )
import Data.Maybe ( fromJust )
import Data.List.NonEmpty ( NonEmpty(..) )
import qualified Data.List.NonEmpty as NE
import Language.Fortran.AST ( Expression(..)
, Value(..)
, AList(..)
Expand Down Expand Up @@ -60,6 +60,7 @@ import Language.Fortran.Vars.Types ( SymbolTableEntry(..)
, typeError
, Dim(..)
, Dims(..)
, dimsLength
)
import Language.Fortran.Vars.Kind ( getTypeKind
, setTypeKind
Expand All @@ -70,7 +71,9 @@ import Language.Fortran.Vars.StructureTable
( lookupField )

import Language.Fortran.Analysis.SemanticTypes
( charLenConcat )
( Dimensions
, charLenConcat
)


-- | Given 'SymbolTable' of a 'ProgramUnit', and an 'Expression' within
Expand All @@ -96,11 +99,13 @@ typeOf strTable symTable expr = case expr of
Right _ ->
Left . typeError s $ "Tried to access elements of scalar"
err -> err
ExpSubscript _ s arr (AList _ _ (r@IxRange{} : _)) -> do
ExpSubscript _ s arr (AList _ _ (r@IxRange{} : rs)) -> do
ty <- typeOf strTable symTable arr
case ty of
TCharacter _ _ -> typeOfSubString s symTable strTable ty r
_ -> Right ty
TArray ety dims ->
typeOfArraySection s symTable strTable dims (r :| rs) ety
_ -> Right ty
ExpImpliedDo _ _ es doSpec -> do
dim <- specToDim symTable doSpec
ty <- typeOf strTable symTable . head $ aStrip es
Expand Down Expand Up @@ -264,6 +269,41 @@ typeOfBinaryExp' sp op t1 t2
k1 = getTypeKind t1
k2 = getTypeKind t2

typeOfArraySection
:: SrcSpan
-> SymbolTable
-> StructureTable
-> Dimensions
-> NonEmpty (Index a)
-> Type
-> Either TypeError Type
typeOfArraySection ss symt strt dims ((IxRange _ _ lower upper _) :| ixs) ty
| any isIxRange ixs = Left . typeError ss $ "Unexpected ranges"
| (length ixs + 1) /= dimsLength dims = Left . typeError ss $ "Mismatched indices"
| otherwise = do
isInteger ss $ traverse (typeOf strt symt) lower
isInteger ss $ traverse (typeOf strt symt) upper
case calcLen of
Nothing -> Right $ TArray ty (DimsAssumedSize Nothing lowerIndex)
-- Always return lower 1 because it's not specified
Just len -> Right $ TArray ty (DimsExplicitShape $ Dim (Just 1) (Just len) :| [])
where
calcLen = (\x y -> y - x + 1) <$> lowerIndex <*> upperIndex
isIxRange = \case
IxRange{} -> True
_ -> False
upperIndex :: Maybe Int
upperIndex = case dims of
DimsExplicitShape dims' -> do
upperIx <- dimUpper $ NE.head dims'
getIndex symt upperIx upper
DimsAssumedSize (Just (Dim _ u :| _)) _ -> do
upperIx <- u
getIndex symt upperIx upper
DimsAssumedSize Nothing u -> u
DimsAssumedShape _ -> Nothing
lowerIndex = getIndex symt 1 lower
typeOfArraySection _ _ _ _ _ _ = error "Internal error"

-- | Internal function to determine the type of a substring
-- If either of the indexes cannot be evaluated then we return a dynamically
Expand All @@ -277,27 +317,34 @@ typeOfSubString
-> Index a
-> Either TypeError Type
typeOfSubString sp symt strt ty (IxRange _ _ lower upper _) = do
isInteger $ traverse (typeOf strt symt) lower
isInteger $ traverse (typeOf strt symt) upper
isInteger sp $ traverse (typeOf strt symt) lower
isInteger sp $ traverse (typeOf strt symt) upper
pure $ TCharacter calcLen 1
where
calcLen = case (\x y -> y - x + 1) <$> lowerIndex <*> upperIndex of
Nothing -> CharLenStar
Just len -> CharLenInt len
isInteger = \case
Right (Just (TInteger _)) -> Right ()
Right Nothing -> Right ()
_ -> Left . typeError sp $ "Index wasn't an integer type"
upperIndex = let Just k = getTypeKind ty in getIndex k upper
lowerIndex = getIndex 1 lower
getIndex :: Int -> Maybe (Expression a) -> Maybe Int
getIndex dflt Nothing = Just dflt
getIndex _ (Just e) = case eval' symt e of
Right (Int i) -> Just i
_ -> Nothing

calcLen = maybe CharLenStar
CharLenInt
((\x y -> y - x + 1) <$> lowerIndex <*> upperIndex)
upperIndex = do
k <- getTypeKind ty
getIndex symt k upper
lowerIndex = getIndex symt 1 lower
typeOfSubString _ _ _ _ idx = Left $ UnknownType (getSpan idx)


-- | Aux function used in typeOfSubstring and typeOfArraySection
getIndex :: SymbolTable -> Int -> Maybe (Expression a) -> Maybe Int
getIndex _ dflt Nothing = Just dflt
getIndex symt _ (Just e) = case eval' symt e of
Right (Int i) -> Just i
_ -> Nothing

isInteger :: SrcSpan -> Either TypeError (Maybe SemType) -> Either TypeError ()
isInteger sp = \case
Right (Just (TInteger _)) -> Right ()
Right Nothing -> Right ()
Left err -> Left err
_ -> Left . typeError sp $ "Index wasn't an integer type"

-- | determine the return type of a function call
typeOfFunctionCall
:: SrcSpan
Expand Down
8 changes: 5 additions & 3 deletions test/TypeCheckSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import Language.Fortran.Vars.SymbolTable
( collectSymbols )
import Language.Fortran.Vars.StructureTable
( collectStructures )
import Language.Fortran.Analysis.SemanticTypes (Dims(DimsAssumedSize))

-- | Given a varaible name, 'RHSFunc' search assignment statements within a program
-- unit and returns the RHS of first assignment statement whose LHS matches the
Expand Down Expand Up @@ -183,9 +184,10 @@ spec = do

it "Index ranges" $ do
(typeof, rhs) <- helper path puName
typeof (rhs "i1") `shouldBe` Right (TArray (TInteger 4) (dess1 1 10))
typeof (rhs "i2") `shouldBe` Right (TArray (TInteger 4) (dess1 1 10))
typeof (rhs "i3") `shouldBe` Right (TArray (TInteger 4) (DimsExplicitShape (Dim (Just 1) Nothing :| [])))
typeof (rhs "i1") `shouldBe` Right (TArray (TInteger 4) (dess1 1 3))
typeof (rhs "i2") `shouldBe` Right (TArray (TInteger 4) (dess1 1 1))
typeof (rhs "i3") `shouldBe` Right (TArray (TInteger 4) (DimsAssumedSize Nothing (Just 3)))
typeof (rhs "i4") `shouldBe` Right (TArray (TInteger 2) (dess1 1 6))

it "Erroneous expressions" $ do
-- These expressions aren't valid but any subscript can be assumed to
Expand Down
2 changes: 2 additions & 0 deletions test/type_check/array_and_substring.f
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ subroutine array(c,d,N)
integer N
integer c(N) ! adjustable
integer d(10, *) ! assumed-size
integer*2 e(10, 10)

C test the types of RHS expressions
arr1 = a(1)
Expand All @@ -16,6 +17,7 @@ subroutine array(c,d,N)
i1 = a(3:5)
i2 = a(1:1)
i3 = c(3:)
i4 = e(5:, 2)

c test erroneous expressions where we have too many indices
err1 = a(2, 3)
Expand Down