diff --git a/pil-analyzer/src/pil_analyzer.rs b/pil-analyzer/src/pil_analyzer.rs index 8b141328b..fab2b5605 100644 --- a/pil-analyzer/src/pil_analyzer.rs +++ b/pil-analyzer/src/pil_analyzer.rs @@ -14,7 +14,7 @@ use powdr_ast::parsed::asm::{ use powdr_ast::parsed::types::Type; use powdr_ast::parsed::visitor::{AllChildren, Children}; use powdr_ast::parsed::{ - self, FunctionKind, LambdaExpression, PILFile, PilStatement, SymbolCategory, + self, FunctionKind, LambdaExpression, PILFile, PilStatement, SourceReference, SymbolCategory, TraitImplementation, TypedExpression, }; use powdr_number::{FieldElement, GoldilocksField}; @@ -29,7 +29,7 @@ use powdr_parser_util::Error; use crate::traits_resolver::TraitsResolver; use crate::type_builtins::constr_function_statement_type; -use crate::type_inference::infer_types; +use crate::type_inference::{infer_types, DeclaredType}; use crate::{side_effect_checker, AnalysisDriver}; use crate::statement_processor::{Counters, PILItem, StatementProcessor}; @@ -303,19 +303,21 @@ impl PILAnalyzer { ) }) .flat_map(|(name, (symbol, value))| { - let (type_scheme, expr) = match (symbol.kind, value) { + let (declared_type, expr) = match (symbol.kind, value) { (SymbolKind::Poly(PolynomialType::Committed), Some(value)) => { // Witness column, move its value (query function) into the expressions to be checked separately. - let type_scheme = type_from_definition(symbol, &None); let FunctionValueDefinition::Expression(TypedExpression { e, .. }) = value else { panic!("Invalid value for query function") }; - + let source = e.source_reference().clone(); expressions.push((e, query_type.clone().into())); - (type_scheme, None) + let declared_type = type_from_definition(symbol, &None) + .map(|ts| ts.into()) + .map(|dec: DeclaredType| dec.with_source(source)); + (declared_type, None) } ( _, @@ -323,19 +325,26 @@ impl PILAnalyzer { type_scheme, e, })), - ) => (type_scheme.clone(), Some(e)), + ) => { + let source = e.source_reference(); + let declared_type = type_scheme + .clone() + .map(|ts| ts.into()) + .map(|dec: DeclaredType| dec.with_source(source.clone())); + (declared_type, Some(e)) + } (_, value) => { - let type_scheme = type_from_definition(symbol, value); + let declared_type = type_from_definition(symbol, value).map(|ts| ts.into()); if let Some(FunctionValueDefinition::Array(items)) = value { // Expect all items in the arrays to be field elements. expressions.extend(items.children_mut().map(|e| (e, Type::Fe.into()))); } - (type_scheme, None) + (declared_type, None) } }; - Some((name.clone(), (type_scheme, expr))) + Some((name.clone(), (declared_type, expr))) }) .collect(); for expr in &mut self.proof_items { diff --git a/pil-analyzer/src/type_inference.rs b/pil-analyzer/src/type_inference.rs index 82ee2c8a4..37d968b69 100644 --- a/pil-analyzer/src/type_inference.rs +++ b/pil-analyzer/src/type_inference.rs @@ -1,3 +1,4 @@ +use core::panic; use std::collections::{BTreeSet, HashMap}; use itertools::Itertools; @@ -29,7 +30,7 @@ use crate::{ /// Sets the generic arguments for references and the literal types in all expressions. /// Returns the types for symbols without explicit type. pub fn infer_types( - definitions: HashMap, Option<&mut Expression>)>, + definitions: HashMap, Option<&mut Expression>)>, expressions: &mut [(&mut Expression, ExpectedType)], ) -> Result, Vec> { TypeChecker::new().infer_types(definitions, expressions) @@ -60,13 +61,70 @@ impl From for ExpectedType { } } +#[derive(Debug, Clone)] +pub struct DeclaredType { + source: SourceRef, + vars: TypeBounds, + ty: DeclaredTypeKind, +} + +impl DeclaredType { + fn scheme(&self) -> TypeScheme { + match &self.ty { + DeclaredTypeKind::Struct(ty, _) | DeclaredTypeKind::Type(ty) => TypeScheme { + vars: self.vars.clone(), + ty: ty.clone(), + }, + } + } + + fn type_mut(&mut self) -> &mut Type { + match &mut self.ty { + DeclaredTypeKind::Struct(ty, _) => ty, + DeclaredTypeKind::Type(ty) => ty, + } + } + + fn declared_type(&self) -> &Type { + match &self.ty { + DeclaredTypeKind::Struct(ty, _) | DeclaredTypeKind::Type(ty) => ty, + } + } + + pub fn with_source(mut self, source: SourceRef) -> Self { + self.source = source; + self + } + + fn is_concrete(&self) -> bool { + self.vars.is_empty() + } +} + +#[derive(Debug, Clone)] +enum DeclaredTypeKind { + #[allow(dead_code)] // Remove when #1910 is merged + Struct(Type, HashMap), + Type(Type), +} + +impl From for DeclaredType { + fn from(scheme: TypeScheme) -> Self { + Self { + source: SourceRef::unknown(), + vars: scheme.vars.clone(), + ty: DeclaredTypeKind::Type(scheme.ty.clone()), + } + } +} + struct TypeChecker { /// Types for local variables, might contain type variables. local_var_types: Vec, /// Declared types for all symbols and their source references. /// Contains the unmodified type scheme for symbols with generic types and newly /// created type variables for symbols without declared type. - declared_types: HashMap, + declared_types: HashMap, /// Current mapping of declared type vars to type. Reset before checking each definition. declared_type_vars: HashMap, unifier: Unifier, @@ -89,7 +147,7 @@ impl TypeChecker { /// returns the types for symbols without explicit type. pub fn infer_types( mut self, - mut definitions: HashMap, Option<&mut Expression>)>, + mut definitions: HashMap, Option<&mut Expression>)>, expressions: &mut [(&mut Expression, ExpectedType)], ) -> Result, Vec> { let type_var_mapping = self @@ -100,11 +158,11 @@ impl TypeChecker { .into_iter() .filter(|(_, (ty, _))| ty.is_none()) .map(|(name, _)| { - let (_, mut scheme) = self.declared_types.remove(&name).unwrap(); - assert!(scheme.vars.is_empty()); - self.substitute(&mut scheme.ty); - assert!(scheme.ty.is_concrete_type()); - (name, scheme.ty) + let mut declared_type = self.declared_types.remove(&name).unwrap(); + assert!(declared_type.is_concrete()); + self.substitute(declared_type.type_mut()); + assert!(declared_type.scheme().ty.is_concrete_type()); + (name, declared_type.scheme().ty) }) .collect()) } @@ -113,7 +171,7 @@ impl TypeChecker { /// the type variables used by the type checker to those used in the declaration. fn infer_types_inner( &mut self, - definitions: &mut HashMap, Option<&mut Expression>)>, + definitions: &mut HashMap, Option<&mut Expression>)>, expressions: &mut [(&mut Expression, ExpectedType)], ) -> Result>, Error> { // TODO in order to fix type inference on recursive functions, we need to: @@ -129,6 +187,14 @@ impl TypeChecker { ); self.setup_declared_types(definitions); + // After we setup declared types, every definition + // related with a Struct Declaration is not nedded any more + let mut definitions: HashMap<_,_> = definitions + .iter_mut() + .filter(|(_, (ty, _))| { + !matches!(ty, Some(declared) if matches!(declared.ty, DeclaredTypeKind::Struct(_, _))) + }) + .collect(); // These are the inferred types for symbols that are declared // as type schemes. They are compared to the declared types @@ -145,12 +211,12 @@ impl TypeChecker { continue; }; - let (_, declared_type) = self.declared_types[&name].clone(); - if declared_type.vars.is_empty() { + let declared_type_scheme = self.type_scheme_from_declared_type(&name); + if declared_type_scheme.vars.is_empty() { self.declared_type_vars.clear(); - self.process_concrete_symbol(declared_type.ty.clone(), value)?; + self.process_concrete_symbol(declared_type_scheme.ty.clone(), value)?; } else { - self.declared_type_vars = declared_type + self.declared_type_vars = declared_type_scheme .vars .vars() .map(|v| (v.clone(), self.unifier.new_type_var())) @@ -168,23 +234,25 @@ impl TypeChecker { // Now we check for all symbols that are not declared as a type scheme that they // can resolve to a concrete type. - for (name, (source_ref, declared_type)) in &self.declared_types { - if declared_type.vars.is_empty() { - // It is not a type scheme, see if we were able to derive a concrete type. - let inferred = self.type_into_substituted(declared_type.ty.clone()); - if !inferred.is_concrete_type() { - let inferred_scheme = self.to_type_scheme(inferred); - return Err(source_ref.with_error( - format!( - "Could not derive a concrete type for symbol {name}.\nInferred type scheme: {}\n", - format_type_scheme_around_name( - name, - &Some(inferred_scheme), - ) - ))); - } - } - } + self.declared_types + .iter() + .filter(|(_, declared_type)| declared_type.is_concrete()) + // It is not a type scheme, see if we were able to derive a concrete type. + .map(|(name, declared_type)| { + ( + name, + declared_type.source.clone(), + self.type_into_substituted(declared_type.declared_type().clone()), + ) + }) + .filter(|(_, _, inferred)| !inferred.is_concrete_type()) + .try_for_each(|(name, source, inferred)| { + let inferred_scheme = self.to_type_scheme(inferred); + Err(source.with_error(format!( + "Could not derive a concrete type for symbol {name}.\nInferred type scheme: {}\n", + format_type_scheme_around_name(name, &Some(inferred_scheme)) + ))) + })?; // We check type schemes last, because only at this point do we know // that other types that should be concrete do not occur as type variables in the @@ -194,41 +262,53 @@ impl TypeChecker { self.verify_type_schemes(inferred_types) } + fn type_scheme_from_declared_type(&mut self, name: &String) -> TypeScheme { + let declared_type = self.declared_types[name].clone(); + + match declared_type.ty { + DeclaredTypeKind::Struct(_, _) => { + unreachable!("Declared types for Structs should have been removed at this point") + } + DeclaredTypeKind::Type(_) => declared_type.scheme(), + } + } + /// Fills self.declared_types and checks that declared builtins have the correct type. fn setup_declared_types( &mut self, - definitions: &mut HashMap, Option<&mut Expression>)>, + definitions: &mut HashMap, Option<&mut Expression>)>, ) { // Add types from declarations. Type schemes are added without instantiating. self.declared_types = definitions .iter() - .map(|(name, (type_scheme, value))| { - let source_ref = value - .as_ref() - .map(|v| v.source_reference()) - .cloned() - .unwrap_or_default(); - // Check if it is a builtin symbol. - let ty = match (builtin_schemes().get(name), type_scheme) { + .map(|(name, (declared_type, _))| { + let declared_type = match (builtin_schemes().get(name), declared_type) { (Some(builtin), declared) => { - if let Some(declared) = declared { + if let Some(declared_inner) = declared { + let declared_scheme = declared_inner.scheme(); assert_eq!( - builtin, - declared, + *builtin, + declared_scheme, "Invalid type for built-in scheme. Got {} but expected {}", - format_type_scheme_around_name(name, &Some(declared.clone())), + format_type_scheme_around_name( + name, + &Some(declared_scheme.clone()) + ), format_type_scheme_around_name(name, &Some(builtin.clone())) ); }; - builtin.clone() + builtin.clone().into() } // Store an (uninstantiated) type scheme for symbols with a declared polymorphic type. - (None, Some(type_scheme)) => type_scheme.clone(), + (None, Some(declared_type)) => declared_type.clone(), // Store a new (unquantified) type variable for symbols without declared type. // This forces a single concrete type for them. - (None, None) => self.unifier.new_type_var().into(), + (None, None) => { + let scheme: TypeScheme = self.unifier.new_type_var().into(); + scheme.into() + } }; - (name.clone(), (source_ref, ty)) + (name.clone(), declared_type) }) .collect(); @@ -237,7 +317,7 @@ impl TypeChecker { for (name, scheme) in builtin_schemes() { self.declared_types .entry(name.clone()) - .or_insert_with(|| (SourceRef::unknown(), scheme.clone())); + .or_insert_with(|| DeclaredType::from(scheme.clone())); definitions.remove(name); } } @@ -327,7 +407,7 @@ impl TypeChecker { /// the type variable names used by the type checker to those from the declaration. fn update_type_args( &mut self, - definitions: &mut HashMap, Option<&mut Expression>)>, + definitions: &mut HashMap, Option<&mut Expression>)>, expressions: &mut [(&mut Expression, ExpectedType)], type_var_mapping: &HashMap>, ) -> Result<(), Vec> { @@ -521,9 +601,7 @@ impl TypeChecker { source_ref, Reference::Poly(PolynomialReference { name, type_args }), ) => { - let (ty, args) = self - .unifier - .instantiate_scheme(self.declared_types[name].1.clone()); + let (ty, args) = self.instantiate_scheme_by_declared_name(name); if let Some(requested_type_args) = type_args { if requested_type_args.len() != args.len() { return Err(source_ref.with_error(format!( @@ -856,9 +934,8 @@ impl TypeChecker { Pattern::Enum(source_ref, name, data) => { // We just ignore the generic args here, storing them in the pattern // is not helpful because the type is obvious from the value. - let (ty, _generic_args) = self - .unifier - .instantiate_scheme(self.declared_types[&name.to_string()].1.clone()); + let (ty, _generic_args) = + self.instantiate_scheme_by_declared_name(&name.to_string()); let ty = type_for_reference(&ty); match data { @@ -912,28 +989,45 @@ impl TypeChecker { &self, inferred_types: HashMap, ) -> Result>, Error> { - inferred_types.into_iter().map(|(name, inferred_type)| { - let (source_ref, declared_type) = self.declared_types[&name].clone(); - let inferred_type = self.type_into_substituted(inferred_type.clone()); - let inferred = self.to_type_scheme(inferred_type.clone()); - let declared = declared_type.clone().simplify_type_vars(); - if inferred != declared { - return Err(source_ref.with_error(format!( - "Inferred type scheme for symbol {name} does not match the declared type.\nInferred: let{}\nDeclared: let{}", - format_type_scheme_around_name(&name, &Some(inferred)), - format_type_scheme_around_name(&name, &Some(declared_type), - )))); - } - let declared_type_vars = declared_type.ty.contained_type_vars(); - let inferred_type_vars = inferred_type.contained_type_vars(); - Ok((name.clone(), - inferred_type_vars - .into_iter() - .cloned() - .zip(declared_type_vars.into_iter().map(|tv| Type::TypeVar(tv.clone()))) - .collect(), - )) - }).collect::>() + inferred_types + .into_iter() + .map(|(name, inferred_type)| { + self.compare_inferred_type_to_declared(&name, inferred_type) + .map(|mapping| (name, mapping)) + }) + .collect() + } + + /// Compares two type schemes and returns a mapping from inferred type vars to declared type vars if they match + fn compare_inferred_type_to_declared( + &self, + name: &str, + inferred_type: Type, + ) -> Result, Error> { + let declared_type = self.declared_types[name].clone(); + let declared_scheme = declared_type.scheme(); + let inferred_type = self.type_into_substituted(inferred_type.clone()); + let inferred = self.to_type_scheme(inferred_type.clone()); + let declared = declared_scheme.clone().simplify_type_vars(); + + if inferred != declared { + return Err(declared_type.source.with_error(format!( + "Inferred type scheme does not match the declared type.\nInferred: let{}\nDeclared: let{}", + format_type_scheme_around_name(&name, &Some(inferred)), + format_type_scheme_around_name(&name, &Some(declared)) + ))); + } + + Ok(inferred_type + .contained_type_vars() + .cloned() + .zip( + declared_scheme + .ty + .contained_type_vars() + .map(|tv| Type::TypeVar(tv.clone())), + ) + .collect()) } fn type_into_substituted(&self, mut ty: Type) -> Type { @@ -977,6 +1071,12 @@ impl TypeChecker { pub fn local_var_type(&self, id: u64) -> Type { self.local_var_types[id as usize].clone() } + + /// Returns a new type scheme as a tuple (type, generic args) for the given name. + fn instantiate_scheme_by_declared_name(&mut self, name: &str) -> (Type, Vec) { + self.unifier + .instantiate_scheme(self.declared_types[name].scheme().clone()) + } } fn update_type_if_literal(