Skip to content

Commit

Permalink
Using caller tree checking also for reuse, and optimize early outs
Browse files Browse the repository at this point in the history
  • Loading branch information
aardappel committed Jan 21, 2025
1 parent 995e629 commit 9e3e658
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 56 deletions.
2 changes: 1 addition & 1 deletion dev/src/lobster/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ struct CodeGen {
assert(IsUDT(dispatch_type->t));
auto de = dispatch_type->udt->dispatch_table[call.vtable_idx].get();
assert(de->dispatch_root && !de->returntype.Null() && de->subudts_size);
if (de->returned_thru_to_max >= 0) {
if (de->returned_thru_to_max >= 0) {
// This works because all overloads of a DD sit under a single Function.
GenUnwind(sf, de->returned_thru_to_max, outw);
}
Expand Down
1 change: 1 addition & 0 deletions dev/src/lobster/idents.h
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ struct SubFunction {
bool optimized = false;
bool explicit_generics = false;
int returned_thru_to_max = -1; // >=0: there exist return statements that may skip the caller.
vector<int> returned_thru_function_ids;
UDT *method_of = nullptr;
int numcallers = 0;
Type thistype { V_FUNCTION, this }; // convenient place to store the type corresponding to this
Expand Down
137 changes: 82 additions & 55 deletions dev/src/lobster/typecheck.h
Original file line number Diff line number Diff line change
Expand Up @@ -1058,19 +1058,6 @@ struct TypeChecker {
return true;
}

void CheckReturnPast(SubFunction *sf, int nretslots, const SubFunction *sf_to, const Node &context) {
// Special case for returning out of top level, which is always allowed.
if (sf_to != st.toplevel) {
if (sf->isdynamicfunctionvalue) {
// This is because the function has been typechecked against one context, but
// can be called again in a different context that does not have the same callers.
Error(context, "cannot return through dynamic function value (",
"return statement tries to return from ", Q(sf_to->parent->name), ")");
}
}
sf->returned_thru_to_max = std::max(sf->returned_thru_to_max, nretslots);
}

TypeRef TypeCheckMatchingCall(SubFunction *sf, List &call_args, bool static_dispatch,
bool first_dynamic, bool may_have_lambda_args,
DispatchEntry *de) {
Expand All @@ -1085,8 +1072,15 @@ struct TypeChecker {
}
}
sf->numcallers++;
sf->callers.push_back(
Caller{ scopes.empty() ? nullptr : scopes.back().sf, de, this_call_is_recursive });
auto parent_sf = scopes.empty() ? nullptr : scopes.back().sf;
for (auto &caller : sf->callers) {
if (caller.caller == parent_sf && caller.de == de &&
caller.is_recursive == this_call_is_recursive) {
goto existing_caller;
}
}
sf->callers.push_back(Caller{ parent_sf, de, this_call_is_recursive });
existing_caller:
Function &f = *sf->parent;
if (may_have_lambda_args && (static_dispatch || first_dynamic)) {
for (auto [i, c] : enumerate(call_args.children)) {
Expand Down Expand Up @@ -1159,17 +1153,86 @@ struct TypeChecker {
"reused return value");
goto destination_found;
}
auto nretslots = ValWidthMulti(isf->returntype, isf->returntype->NumValues());
CheckReturnPast(isc.sf, nretslots, isf, call_context);
}
// This error should hopefully be rare, but still possible if this call is in
// a very different context.
Error(call_context, "return out of call to ", Q(sf->parent->name),
" can\'t find destination ", Q(isf->parent->name));
destination_found:;
}
vector<SubFunction *> rec_sfs;
for (auto [isf, type] : sf->reuse_return_events) {
auto start_sf = scopes.back().sf;
auto nretslots = ValWidthMulti(isf->returntype, isf->returntype->NumValues());
if (!RecursiveCheckReturns(start_sf, nretslots, isf, rec_sfs, call_context))
Error(call_context, "return from ", Q(isf->parent->name), " called out of context");
assert(rec_sfs.empty());
}
};

// This more complex iteration is needed for recursion, see below in Return::TypeCheck
// and TypeCheckCallStatic
bool RecursiveCheckReturns(SubFunction *sf, int nretslots,
const SubFunction *dest_sf, vector<SubFunction *>rec_dest_sf,
const Node &context) {
if (sf->parent == dest_sf->parent) {
// Reached destination for this particular trace.
return true;
}
// Special case for returning out of top level, which is always allowed.
if (dest_sf != st.toplevel && sf->isdynamicfunctionvalue) {
// This is because the function has been typechecked against one context, but
// can be called again in a different context that does not have the same
// callers.
Error(context, "cannot return through dynamic function value (",
"return statement tries to return from ", Q(dest_sf->parent->name), ")");
}
if (sf->returned_thru_to_max >= nretslots) {
// We already have something returning thru here that is at least as big, check if
// its the same function because then we're done.
// This is purely an early-out optimization.
for (auto idx : sf->returned_thru_function_ids) {
if (idx == dest_sf->parent->idx) {
return true;
}
}
}
sf->returned_thru_to_max = std::max(sf->returned_thru_to_max, nretslots);
sf->returned_thru_function_ids.push_back(dest_sf->parent->idx);
for (auto rsf : rec_dest_sf) if (rsf == sf) {
// We were following a chain from a recursive call, and have arrived at the recursion
// entry point. We can't continue with callers here, which includes the call that set
// rec_dest_sf. We rely on the non-recursive paths to trace beyond this entry point.
return true;
}
// Now we step into the callers. This will typically only have 1 element in it in the
// non-recursive case, and 2 for a normal active recursive call.
for (auto &caller : sf->callers) {
if (!caller.caller) {
return false; // Arrived at root call.
}
if (caller.de) {
caller.de->returned_thru_to_max =
std::max(caller.de->returned_thru_to_max, nretslots);
for (auto udt : caller.de->dispatch_root->subudts) {
// If any SubFunction in the dispatch generates an unwind check, all of them
// must return assuming one.
auto dsf = udt->dispatch_table[caller.de->vtable_idx]->sf;
if (!dsf) continue;
dsf->returned_thru_to_max = std::max(dsf->returned_thru_to_max, nretslots);
}
}
if (caller.is_recursive) rec_dest_sf.push_back(sf);
auto reached =
RecursiveCheckReturns(caller.caller, nretslots, dest_sf, rec_dest_sf, context);
if (caller.is_recursive) rec_dest_sf.pop_back();
if (!reached) return false;
}
return true;
}



void UnWrapBoth(TypeRef &otype, TypeRef &atype) {
while (otype->Wrapped() && otype->t == atype->t) {
otype = otype->Element();
Expand Down Expand Up @@ -3534,43 +3597,6 @@ Node *DynCall::TypeCheck(TypeChecker &tc, size_t reqret) {
return tc.TypeCheckDynCall(this, reqret);
}

// This more complex iteration is needed for recursion, see below in Return::TypeCheck
bool RecursiveCheckReturns(TypeChecker &tc, SubFunction *sf, int nretslots, SubFunction *dest_sf,
SubFunction *rec_dest_sf, const Node &context) {
if (sf->parent == dest_sf->parent) {
// Reached destination for this particular trace.
return true;
}
tc.CheckReturnPast(sf, nretslots, dest_sf, context);
if (rec_dest_sf == sf) {
// We were following a chain from a recursive call, and have arrived at the recursion entry point.
// We can't continue with callers here, which includes the call that set rec_dest_sf.
// We rely on the non-recursive paths to trace beyond this entry point.
return true;
}
// Now we step into the callers. This will typically only have 1 element in it in the non-recursive
// case, and 2 for a normal active recursive call.
// TODO: would be good to check this never "blows up" with lots of callers.
for (auto &caller : sf->callers) {
if (!caller.caller) {
return false; // Arrived at root call.
}
if (caller.de) {
caller.de->returned_thru_to_max = std::max(caller.de->returned_thru_to_max, nretslots);
for (auto udt : caller.de->dispatch_root->subudts) {
// If any SubFunction in the dispatch generates an unwind check, all of them must
// return assuming one.
auto dsf = udt->dispatch_table[caller.de->vtable_idx]->sf;
dsf->returned_thru_to_max = std::max(dsf->returned_thru_to_max, nretslots);
}
}
if (!RecursiveCheckReturns(tc, caller.caller, nretslots, dest_sf,
caller.is_recursive ? sf : rec_dest_sf, context))
return false;
}
return true;
}

Node *Return::TypeCheck(TypeChecker &tc, size_t /*reqret*/) {
exptype = type_void;
lt = LT_ANY;
Expand Down Expand Up @@ -3664,7 +3690,8 @@ Node *Return::TypeCheck(TypeChecker &tc, size_t /*reqret*/) {
// See also reuse code in TypeCheckCallStatic
auto start_sf = tc.scopes.back().sf;
auto nretslots = ValWidthMulti(sf->returntype, sf->returntype->NumValues());
if (!RecursiveCheckReturns(tc, start_sf, nretslots, sf, nullptr, *this))
vector<SubFunction *> rec_sfs;
if (!tc.RecursiveCheckReturns(start_sf, nretslots, sf, rec_sfs, *this))
tc.Error(*this, "return from ", Q(sf->parent->name), " called out of context");
return this;
}
Expand Down

0 comments on commit 9e3e658

Please sign in to comment.