Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typevar count & schema instantiation to Unifier #1684

Merged
merged 3 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 26 additions & 52 deletions pil-analyzer/src/type_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ struct TypeChecker {
/// Current mapping of declared type vars to type. Reset before checking each definition.
declared_type_vars: HashMap<String, Type>,
unifier: Unifier,
/// Last used type variable index.
last_type_var: usize,
/// Keeps track of the kind of lambda we are currently type-checking.
lambda_kind: FunctionKind,
}
Expand All @@ -78,7 +76,6 @@ impl TypeChecker {
declared_types: Default::default(),
declared_type_vars: Default::default(),
unifier: Default::default(),
last_type_var: Default::default(),
lambda_kind: FunctionKind::Constr,
}
}
Expand Down Expand Up @@ -151,7 +148,7 @@ impl TypeChecker {
self.declared_type_vars = declared_type
.vars
.vars()
.map(|v| (v.clone(), self.new_type_var()))
.map(|v| (v.clone(), self.unifier.new_type_var()))
.collect();
self.infer_type_of_expression(value).map(|ty| {
inferred_types.insert(name.to_string(), ty);
Expand Down Expand Up @@ -224,7 +221,7 @@ impl TypeChecker {
(None, Some(type_scheme)) => type_scheme.clone(),
// Store a new (unquantified) type variable for symbols without declared type.
// This forces a single concrete type for them.
(None, None) => self.new_type_var().into(),
(None, None) => self.unifier.new_type_var().into(),
};
(name.clone(), (source_ref, ty))
})
Expand All @@ -250,7 +247,7 @@ impl TypeChecker {
Type::Col => {
// This is a column. It means we prefer `int -> fe`, but `int -> int`
// is also OK if it can be derived directly.
let return_type = self.new_type_var_name();
let return_type = self.unifier.new_type_var_name();
let fun_type = Type::Function(FunctionType {
params: vec![Type::Int],
value: Box::new(Type::TypeVar(return_type.clone())),
Expand All @@ -260,7 +257,7 @@ impl TypeChecker {
Type::Array(ArrayType { base, length: _ }) if base.as_ref() == &Type::Col => {
// An array of columns. We prefer `(int -> fe)[]`, but we also allow `(int -> int)[]`.
// Also we ignore the length.
let return_type = self.new_type_var_name();
let return_type = self.unifier.new_type_var_name();
let fun_type = Type::Function(FunctionType {
params: vec![Type::Int],
value: Box::new(Type::TypeVar(return_type.clone())),
Expand Down Expand Up @@ -515,7 +512,9 @@ impl TypeChecker {
type_args,
}),
) => {
let (ty, args) = self.instantiate_scheme(self.declared_types[name].1.clone());
let (ty, args) = self
.unifier
.instantiate_scheme(self.declared_types[name].1.clone());
if let Some(requested_type_args) = type_args {
if requested_type_args.len() != args.len() {
return Err(source_ref.with_error(format!(
Expand Down Expand Up @@ -550,7 +549,7 @@ impl TypeChecker {
Some(Type::TypeVar(tv)) => Type::TypeVar(tv.clone()),
Some(t) => panic!("Type name annotation for number is not supported: {t}"),
None => {
let tv = self.new_type_var_name();
let tv = self.unifier.new_type_var_name();
*annotated_type = Some(Type::TypeVar(tv.clone()));
Type::TypeVar(tv)
}
Expand Down Expand Up @@ -596,7 +595,7 @@ impl TypeChecker {
})
}
Expression::ArrayLiteral(_, ArrayLiteral { items }) => {
let item_type = self.new_type_var();
let item_type = self.unifier.new_type_var();
for e in items {
self.expect_type(&item_type, e)?;
}
Expand All @@ -608,7 +607,10 @@ impl TypeChecker {
}
Expression::BinaryOperation(source_ref, BinaryOperation { left, op, right }) => {
// TODO at some point, also store the generic args for operators
let fun_type = self.instantiate_scheme(binary_operator_scheme(*op)).0;
let fun_type = self
.unifier
.instantiate_scheme(binary_operator_scheme(*op))
.0;
self.infer_type_of_function_call(
fun_type,
[left, right].into_iter().map(AsMut::as_mut),
Expand All @@ -618,7 +620,10 @@ impl TypeChecker {
}
Expression::UnaryOperation(source_ref, UnaryOperation { op, expr: inner }) => {
// TODO at some point, also store the generic args for operators
let fun_type = self.instantiate_scheme(unary_operator_scheme(*op)).0;
let fun_type = self
.unifier
.instantiate_scheme(unary_operator_scheme(*op))
.0;
self.infer_type_of_function_call(
fun_type,
[inner].into_iter().map(AsMut::as_mut),
Expand All @@ -627,7 +632,7 @@ impl TypeChecker {
)?
}
Expression::IndexAccess(_, IndexAccess { array, index }) => {
let result = self.new_type_var();
let result = self.unifier.new_type_var();
self.expect_type(
&Type::Array(ArrayType {
base: Box::new(result.clone()),
Expand Down Expand Up @@ -657,7 +662,7 @@ impl TypeChecker {
Expression::FreeInput(_, _) => todo!(),
Expression::MatchExpression(_, MatchExpression { scrutinee, arms }) => {
let scrutinee_type = self.infer_type_of_expression(scrutinee)?;
let result = self.new_type_var();
let result = self.unifier.new_type_var();
for MatchArm { pattern, value } in arms {
let local_var_count = self.local_var_types.len();
self.expect_type_of_pattern(&scrutinee_type, pattern)?;
Expand Down Expand Up @@ -738,9 +743,9 @@ impl TypeChecker {
) -> Result<Type, Error> {
let arguments = arguments.collect::<Vec<_>>();
let params = (0..arguments.len())
.map(|_| self.new_type_var())
.map(|_| self.unifier.new_type_var())
.collect::<Vec<_>>();
let result_type = self.new_type_var();
let result_type = self.unifier.new_type_var();
let expected_function_type = Type::Function(FunctionType {
params: params.clone(),
value: Box::new(result_type.clone()),
Expand Down Expand Up @@ -803,9 +808,9 @@ impl TypeChecker {
fn infer_type_of_pattern(&mut self, pattern: &Pattern) -> Result<Type, Error> {
Ok(match pattern {
Pattern::Ellipsis(_) => unreachable!("Should be handled higher up."),
Pattern::CatchAll(_) => self.new_type_var(),
Pattern::CatchAll(_) => self.unifier.new_type_var(),
Pattern::Number(source_ref, _) => {
let ty = self.new_type_var();
let ty = self.unifier.new_type_var();
self.unifier
.ensure_bound(&ty, "FromLiteral".to_string())
.map_err(|e| source_ref.with_error(e))?;
Expand All @@ -819,7 +824,7 @@ impl TypeChecker {
.collect::<Result<_, _>>()?,
}),
Pattern::Array(_, items) => {
let item_type = self.new_type_var();
let item_type = self.unifier.new_type_var();
for item in items {
if !matches!(item, Pattern::Ellipsis(_)) {
self.expect_type_of_pattern(&item_type, item)?;
Expand All @@ -831,14 +836,15 @@ impl TypeChecker {
})
}
Pattern::Variable(_, _) => {
let ty = self.new_type_var();
let ty = self.unifier.new_type_var();
self.local_var_types.push(ty.clone());
ty
}
Pattern::Enum(source_ref, name, data) => {
// We just ignore the generic args here, storing them in the pattern
// is not helpful because the type is obvious from the value.
let (ty, _generic_args) = self
.unifier
.instantiate_scheme(self.declared_types[&name.to_dotted_string()].1.clone());
let ty = type_for_reference(&ty);

Expand Down Expand Up @@ -926,29 +932,6 @@ impl TypeChecker {
self.unifier.substitute(ty);
}

/// Instantiates a type scheme by creating new type variables for the quantified
/// type variables in the scheme and adds the required trait bounds for the
/// new type variables.
/// Returns the new type and a vector of the type variables used for those
/// declared in the scheme.
fn instantiate_scheme(&mut self, scheme: TypeScheme) -> (Type, Vec<Type>) {
let mut ty = scheme.ty;
let vars = scheme
.vars
.bounds()
.map(|(_, bounds)| {
let new_var = self.new_type_var();
for b in bounds {
self.unifier.ensure_bound(&new_var, b.clone()).unwrap();
}
new_var
})
.collect::<Vec<_>>();
let substitutions = scheme.vars.vars().cloned().zip(vars.clone()).collect();
ty.substitute_type_vars(&substitutions);
(ty, vars)
}

fn format_type_with_bounds(&self, ty: Type) -> String {
let scheme = self.to_type_scheme(ty);
let bounds = scheme.vars.format_vars_with_nonempty_bounds();
Expand All @@ -961,15 +944,6 @@ impl TypeChecker {
}
}

fn new_type_var_name(&mut self) -> String {
self.last_type_var += 1;
format!("T{}", self.last_type_var)
}

fn new_type_var(&mut self) -> Type {
Type::TypeVar(self.new_type_var_name())
}

/// Creates a type scheme out of a type by making all unsubstituted
/// type variables generic.
/// TODO this is wrong for mutually recursive generic functions.
Expand Down
39 changes: 38 additions & 1 deletion pil-analyzer/src/type_unifier.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::collections::{HashMap, HashSet};

use powdr_ast::parsed::{types::Type, visitor::Children};
use powdr_ast::parsed::{
types::{Type, TypeScheme},
visitor::Children,
};

use crate::type_builtins::elementary_type_bounds;

Expand All @@ -10,6 +13,8 @@ pub struct Unifier {
type_var_bounds: HashMap<String, HashSet<String>>,
/// Substitutions for type variables
substitutions: HashMap<String, Type>,
/// Last used type variable index.
last_type_var: usize,
}

impl Unifier {
Expand Down Expand Up @@ -115,6 +120,38 @@ impl Unifier {
ty.children_mut().for_each(|t| self.substitute(t));
}

/// Instantiates a type scheme by creating new type variables for the quantified
/// type variables in the scheme and adds the required trait bounds for the
/// new type variables.
/// Returns the new type and a vector of the type variables used for those
/// declared in the scheme.
pub fn instantiate_scheme(&mut self, scheme: TypeScheme) -> (Type, Vec<Type>) {
let mut ty = scheme.ty;
let vars = scheme
.vars
.bounds()
.map(|(_, bounds)| {
let new_var = self.new_type_var();
for b in bounds {
self.ensure_bound(&new_var, b.clone()).unwrap();
}
new_var
})
.collect::<Vec<_>>();
let substitutions = scheme.vars.vars().cloned().zip(vars.clone()).collect();
ty.substitute_type_vars(&substitutions);
(ty, vars)
}

pub fn new_type_var_name(&mut self) -> String {
self.last_type_var += 1;
format!("T{}", self.last_type_var)
}

pub fn new_type_var(&mut self) -> Type {
Type::TypeVar(self.new_type_var_name())
}

fn add_type_var_bound(&mut self, type_var: String, bound: String) {
self.type_var_bounds
.entry(type_var)
Expand Down
Loading