diff --git a/ast.model.boo b/ast.model.boo index e54186723..6f7dbe859 100644 --- a/ast.model.boo +++ b/ast.model.boo @@ -551,6 +551,12 @@ class CastExpression(Expression): class TypeofExpression(Expression): Type as TypeReference +class AsyncBlockExpression(BlockExpression): + Block as BlockExpression + +class AwaitExpression(Expression): + BaseExpression as Expression + class CustomStatement(Statement): pass diff --git a/src/Boo.Lang.Compiler/Ast/AsyncBlockExpression.cs b/src/Boo.Lang.Compiler/Ast/AsyncBlockExpression.cs new file mode 100644 index 000000000..5b0d0ed5f --- /dev/null +++ b/src/Boo.Lang.Compiler/Ast/AsyncBlockExpression.cs @@ -0,0 +1,52 @@ +#region license +// Copyright (c) 2009 Rodrigo B. de Oliveira (rbo@acm.org) +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Rodrigo B. de Oliveira nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF +// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#endregion + +namespace Boo.Lang.Compiler.Ast +{ + using System; + + public partial class AsyncBlockExpression + { + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public AsyncBlockExpression() + { + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public AsyncBlockExpression(LexicalInfo lexicalInfo) : base(lexicalInfo) + { + } + + public AsyncBlockExpression(BlockExpression value) : base(value.LexicalInfo) + { + _block = value; + _block.InitializeParent(this); + } + } +} + diff --git a/src/Boo.Lang.Compiler/Ast/AwaitExpression.cs b/src/Boo.Lang.Compiler/Ast/AwaitExpression.cs new file mode 100644 index 000000000..a1a34f6e1 --- /dev/null +++ b/src/Boo.Lang.Compiler/Ast/AwaitExpression.cs @@ -0,0 +1,25 @@ +using Boo.Lang.Compiler.TypeSystem; + +namespace Boo.Lang.Compiler.Ast +{ + public partial class AwaitExpression + { + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public AwaitExpression() + { + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public AwaitExpression(LexicalInfo lexicalInfo) + : base(lexicalInfo) + { + } + + public AwaitExpression(Expression baseExpression) + : base(baseExpression.LexicalInfo) + { + _baseExpression = baseExpression; + _baseExpression.InitializeParent(this); + } + } +} diff --git a/src/Boo.Lang.Compiler/Ast/IAstVisitor.Generated.cs b/src/Boo.Lang.Compiler/Ast/IAstVisitor.Generated.cs index de3e847b9..caa607c75 100644 --- a/src/Boo.Lang.Compiler/Ast/IAstVisitor.Generated.cs +++ b/src/Boo.Lang.Compiler/Ast/IAstVisitor.Generated.cs @@ -124,6 +124,8 @@ public interface IAstVisitor void OnTryCastExpression(TryCastExpression node); void OnCastExpression(CastExpression node); void OnTypeofExpression(TypeofExpression node); + void OnAsyncBlockExpression(AsyncBlockExpression node); + void OnAwaitExpression(AwaitExpression node); void OnCustomStatement(CustomStatement node); void OnCustomExpression(CustomExpression node); void OnStatementTypeMember(StatementTypeMember node); diff --git a/src/Boo.Lang.Compiler/Ast/Impl/AsyncBlockExpressionImpl.cs b/src/Boo.Lang.Compiler/Ast/Impl/AsyncBlockExpressionImpl.cs new file mode 100644 index 000000000..f050f81ed --- /dev/null +++ b/src/Boo.Lang.Compiler/Ast/Impl/AsyncBlockExpressionImpl.cs @@ -0,0 +1,208 @@ +#region license +// Copyright (c) 2009 Rodrigo B. de Oliveira (rbo@acm.org) +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Rodrigo B. de Oliveira nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF +// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#endregion + +// +// DO NOT EDIT THIS FILE! +// +// This file was generated automatically by astgen.boo. +// +namespace Boo.Lang.Compiler.Ast +{ + using System.Collections; + using System.Runtime.Serialization; + + [System.Serializable] + public partial class AsyncBlockExpression : BlockExpression + { + protected BlockExpression _block; + + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + new public AsyncBlockExpression CloneNode() + { + return (AsyncBlockExpression)Clone(); + } + + /// + /// + /// + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + new public AsyncBlockExpression CleanClone() + { + return (AsyncBlockExpression)base.CleanClone(); + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + override public NodeType NodeType + { + get { return NodeType.AsyncBlockExpression; } + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + override public void Accept(IAstVisitor visitor) + { + visitor.OnAsyncBlockExpression(this); + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + override public bool Matches(Node node) + { + if (node == null) return false; + if (NodeType != node.NodeType) return false; + var other = ( AsyncBlockExpression)node; + if (!Node.AllMatch(_parameters, other._parameters)) return NoMatch("AsyncBlockExpression._parameters"); + if (!Node.Matches(_returnType, other._returnType)) return NoMatch("AsyncBlockExpression._returnType"); + if (!Node.Matches(_body, other._body)) return NoMatch("AsyncBlockExpression._body"); + if (!Node.Matches(_block, other._block)) return NoMatch("AsyncBlockExpression._block"); + return true; + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + override public bool Replace(Node existing, Node newNode) + { + if (base.Replace(existing, newNode)) + { + return true; + } + if (_parameters != null) + { + ParameterDeclaration item = existing as ParameterDeclaration; + if (null != item) + { + ParameterDeclaration newItem = (ParameterDeclaration)newNode; + if (_parameters.Replace(item, newItem)) + { + return true; + } + } + } + if (_returnType == existing) + { + this.ReturnType = (TypeReference)newNode; + return true; + } + if (_body == existing) + { + this.Body = (Block)newNode; + return true; + } + if (_block == existing) + { + this.Block = (BlockExpression)newNode; + return true; + } + return false; + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + override public object Clone() + { + + AsyncBlockExpression clone = new AsyncBlockExpression(); + clone._lexicalInfo = _lexicalInfo; + clone._endSourceLocation = _endSourceLocation; + clone._documentation = _documentation; + clone._isSynthetic = _isSynthetic; + clone._entity = _entity; + if (_annotations != null) clone._annotations = (Hashtable)_annotations.Clone(); + clone._expressionType = _expressionType; + if (null != _parameters) + { + clone._parameters = _parameters.Clone() as ParameterDeclarationCollection; + clone._parameters.InitializeParent(clone); + } + if (null != _returnType) + { + clone._returnType = _returnType.Clone() as TypeReference; + clone._returnType.InitializeParent(clone); + } + if (null != _body) + { + clone._body = _body.Clone() as Block; + clone._body.InitializeParent(clone); + } + if (null != _block) + { + clone._block = _block.Clone() as BlockExpression; + clone._block.InitializeParent(clone); + } + return clone; + + + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + override internal void ClearTypeSystemBindings() + { + _annotations = null; + _entity = null; + _expressionType = null; + if (null != _parameters) + { + _parameters.ClearTypeSystemBindings(); + } + if (null != _returnType) + { + _returnType.ClearTypeSystemBindings(); + } + if (null != _body) + { + _body.ClearTypeSystemBindings(); + } + if (null != _block) + { + _block.ClearTypeSystemBindings(); + } + + } + + + [System.Xml.Serialization.XmlElement] + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public BlockExpression Block + { + + get { return _block; } + set + { + if (_block != value) + { + _block = value; + if (null != _block) + { + _block.InitializeParent(this); + } + } + } + + } + + + } +} + diff --git a/src/Boo.Lang.Compiler/Ast/Impl/AwaitExpressionImpl.cs b/src/Boo.Lang.Compiler/Ast/Impl/AwaitExpressionImpl.cs new file mode 100644 index 000000000..a960d8971 --- /dev/null +++ b/src/Boo.Lang.Compiler/Ast/Impl/AwaitExpressionImpl.cs @@ -0,0 +1,156 @@ +#region license +// Copyright (c) 2009 Rodrigo B. de Oliveira (rbo@acm.org) +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Rodrigo B. de Oliveira nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF +// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#endregion + +// +// DO NOT EDIT THIS FILE! +// +// This file was generated automatically by astgen.boo. +// +namespace Boo.Lang.Compiler.Ast +{ + using System.Collections; + using System.Runtime.Serialization; + + [System.Serializable] + public partial class AwaitExpression : Expression + { + protected Expression _baseExpression; + + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + new public AwaitExpression CloneNode() + { + return (AwaitExpression)Clone(); + } + + /// + /// + /// + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + new public AwaitExpression CleanClone() + { + return (AwaitExpression)base.CleanClone(); + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + override public NodeType NodeType + { + get { return NodeType.AwaitExpression; } + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + override public void Accept(IAstVisitor visitor) + { + visitor.OnAwaitExpression(this); + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + override public bool Matches(Node node) + { + if (node == null) return false; + if (NodeType != node.NodeType) return false; + var other = ( AwaitExpression)node; + if (!Node.Matches(_baseExpression, other._baseExpression)) return NoMatch("AwaitExpression._baseExpression"); + return true; + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + override public bool Replace(Node existing, Node newNode) + { + if (base.Replace(existing, newNode)) + { + return true; + } + if (_baseExpression == existing) + { + this.BaseExpression = (Expression)newNode; + return true; + } + return false; + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + override public object Clone() + { + + AwaitExpression clone = new AwaitExpression(); + clone._lexicalInfo = _lexicalInfo; + clone._endSourceLocation = _endSourceLocation; + clone._documentation = _documentation; + clone._isSynthetic = _isSynthetic; + clone._entity = _entity; + if (_annotations != null) clone._annotations = (Hashtable)_annotations.Clone(); + clone._expressionType = _expressionType; + if (null != _baseExpression) + { + clone._baseExpression = _baseExpression.Clone() as Expression; + clone._baseExpression.InitializeParent(clone); + } + return clone; + + + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + override internal void ClearTypeSystemBindings() + { + _annotations = null; + _entity = null; + _expressionType = null; + if (null != _baseExpression) + { + _baseExpression.ClearTypeSystemBindings(); + } + + } + + + [System.Xml.Serialization.XmlElement] + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public Expression BaseExpression + { + + get { return _baseExpression; } + set + { + if (_baseExpression != value) + { + _baseExpression = value; + if (null != _baseExpression) + { + _baseExpression.InitializeParent(this); + } + } + } + + } + + + } +} + diff --git a/src/Boo.Lang.Compiler/Ast/Impl/CodeSerializer.cs b/src/Boo.Lang.Compiler/Ast/Impl/CodeSerializer.cs index a6aec1441..fb5a444e0 100755 --- a/src/Boo.Lang.Compiler/Ast/Impl/CodeSerializer.cs +++ b/src/Boo.Lang.Compiler/Ast/Impl/CodeSerializer.cs @@ -2787,6 +2787,61 @@ override public void OnTypeofExpression(Boo.Lang.Compiler.Ast.TypeofExpression n Push(mie); } + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + override public void OnAsyncBlockExpression(Boo.Lang.Compiler.Ast.AsyncBlockExpression node) + { + MethodInvocationExpression mie = new MethodInvocationExpression( + node.LexicalInfo, + CreateReference(node, "Boo.Lang.Compiler.Ast.AsyncBlockExpression")); + mie.Arguments.Add(Serialize(node.LexicalInfo)); + if (ShouldSerialize(node.Parameters)) + { + mie.NamedArguments.Add( + new ExpressionPair( + CreateReference(node, "Parameters"), + SerializeCollection(node, "Boo.Lang.Compiler.Ast.ParameterDeclarationCollection", node.Parameters))); + } + if (ShouldSerialize(node.ReturnType)) + { + mie.NamedArguments.Add( + new ExpressionPair( + CreateReference(node, "ReturnType"), + Serialize(node.ReturnType))); + } + if (ShouldSerialize(node.Body)) + { + mie.NamedArguments.Add( + new ExpressionPair( + CreateReference(node, "Body"), + Serialize(node.Body))); + } + if (ShouldSerialize(node.Block)) + { + mie.NamedArguments.Add( + new ExpressionPair( + CreateReference(node, "Block"), + Serialize(node.Block))); + } + Push(mie); + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + override public void OnAwaitExpression(Boo.Lang.Compiler.Ast.AwaitExpression node) + { + MethodInvocationExpression mie = new MethodInvocationExpression( + node.LexicalInfo, + CreateReference(node, "Boo.Lang.Compiler.Ast.AwaitExpression")); + mie.Arguments.Add(Serialize(node.LexicalInfo)); + if (ShouldSerialize(node.BaseExpression)) + { + mie.NamedArguments.Add( + new ExpressionPair( + CreateReference(node, "BaseExpression"), + Serialize(node.BaseExpression))); + } + Push(mie); + } + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] override public void OnCustomStatement(Boo.Lang.Compiler.Ast.CustomStatement node) { diff --git a/src/Boo.Lang.Compiler/Ast/Impl/DepthFirstGuide.cs b/src/Boo.Lang.Compiler/Ast/Impl/DepthFirstGuide.cs index 5e6eae5ed..18df61b2f 100755 --- a/src/Boo.Lang.Compiler/Ast/Impl/DepthFirstGuide.cs +++ b/src/Boo.Lang.Compiler/Ast/Impl/DepthFirstGuide.cs @@ -2141,6 +2141,54 @@ void IAstVisitor.OnTypeofExpression(Boo.Lang.Compiler.Ast.TypeofExpression node) if (handler != null) handler(node); } + public event NodeEvent OnAsyncBlockExpression; + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + void IAstVisitor.OnAsyncBlockExpression(Boo.Lang.Compiler.Ast.AsyncBlockExpression node) + { + { + var parameters = node.Parameters; + if (parameters != null) + { + var innerList = parameters.InnerList; + var count = innerList.Count; + for (var i=0; i OnAwaitExpression; + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + void IAstVisitor.OnAwaitExpression(Boo.Lang.Compiler.Ast.AwaitExpression node) + { + { + var baseExpression = node.BaseExpression; + if (baseExpression != null) + baseExpression.Accept(this); + } + var handler = OnAwaitExpression; + if (handler != null) + handler(node); + } public event NodeEvent OnCustomStatement; [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] diff --git a/src/Boo.Lang.Compiler/Ast/Impl/DepthFirstTransformer.cs b/src/Boo.Lang.Compiler/Ast/Impl/DepthFirstTransformer.cs index df1f6296b..019033db0 100755 --- a/src/Boo.Lang.Compiler/Ast/Impl/DepthFirstTransformer.cs +++ b/src/Boo.Lang.Compiler/Ast/Impl/DepthFirstTransformer.cs @@ -2707,6 +2707,85 @@ public virtual void LeaveTypeofExpression(Boo.Lang.Compiler.Ast.TypeofExpression { } + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public virtual void OnAsyncBlockExpression(Boo.Lang.Compiler.Ast.AsyncBlockExpression node) + { + if (EnterAsyncBlockExpression(node)) + { + Visit(node.Parameters); + TypeReference currentReturnTypeValue = node.ReturnType; + if (null != currentReturnTypeValue) + { + TypeReference newValue = (TypeReference)VisitNode(currentReturnTypeValue); + if (!object.ReferenceEquals(newValue, currentReturnTypeValue)) + { + node.ReturnType = newValue; + } + } + Block currentBodyValue = node.Body; + if (null != currentBodyValue) + { + Block newValue = (Block)VisitNode(currentBodyValue); + if (!object.ReferenceEquals(newValue, currentBodyValue)) + { + node.Body = newValue; + } + } + BlockExpression currentBlockValue = node.Block; + if (null != currentBlockValue) + { + BlockExpression newValue = (BlockExpression)VisitNode(currentBlockValue); + if (!object.ReferenceEquals(newValue, currentBlockValue)) + { + node.Block = newValue; + } + } + + LeaveAsyncBlockExpression(node); + } + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public virtual bool EnterAsyncBlockExpression(Boo.Lang.Compiler.Ast.AsyncBlockExpression node) + { + return true; + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public virtual void LeaveAsyncBlockExpression(Boo.Lang.Compiler.Ast.AsyncBlockExpression node) + { + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public virtual void OnAwaitExpression(Boo.Lang.Compiler.Ast.AwaitExpression node) + { + if (EnterAwaitExpression(node)) + { + Expression currentBaseExpressionValue = node.BaseExpression; + if (null != currentBaseExpressionValue) + { + Expression newValue = (Expression)VisitNode(currentBaseExpressionValue); + if (!object.ReferenceEquals(newValue, currentBaseExpressionValue)) + { + node.BaseExpression = newValue; + } + } + + LeaveAwaitExpression(node); + } + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public virtual bool EnterAwaitExpression(Boo.Lang.Compiler.Ast.AwaitExpression node) + { + return true; + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public virtual void LeaveAwaitExpression(Boo.Lang.Compiler.Ast.AwaitExpression node) + { + } + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] public virtual void OnCustomStatement(Boo.Lang.Compiler.Ast.CustomStatement node) { diff --git a/src/Boo.Lang.Compiler/Ast/Impl/DepthFirstVisitor.cs b/src/Boo.Lang.Compiler/Ast/Impl/DepthFirstVisitor.cs index 08ccf00a3..54f9006c3 100755 --- a/src/Boo.Lang.Compiler/Ast/Impl/DepthFirstVisitor.cs +++ b/src/Boo.Lang.Compiler/Ast/Impl/DepthFirstVisitor.cs @@ -1694,6 +1694,51 @@ public virtual void LeaveTypeofExpression(Boo.Lang.Compiler.Ast.TypeofExpression { } + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public virtual void OnAsyncBlockExpression(Boo.Lang.Compiler.Ast.AsyncBlockExpression node) + { + if (EnterAsyncBlockExpression(node)) + { + Visit(node.Parameters); + Visit(node.ReturnType); + Visit(node.Body); + Visit(node.Block); + LeaveAsyncBlockExpression(node); + } + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public virtual bool EnterAsyncBlockExpression(Boo.Lang.Compiler.Ast.AsyncBlockExpression node) + { + return true; + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public virtual void LeaveAsyncBlockExpression(Boo.Lang.Compiler.Ast.AsyncBlockExpression node) + { + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public virtual void OnAwaitExpression(Boo.Lang.Compiler.Ast.AwaitExpression node) + { + if (EnterAwaitExpression(node)) + { + Visit(node.BaseExpression); + LeaveAwaitExpression(node); + } + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public virtual bool EnterAwaitExpression(Boo.Lang.Compiler.Ast.AwaitExpression node) + { + return true; + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public virtual void LeaveAwaitExpression(Boo.Lang.Compiler.Ast.AwaitExpression node) + { + } + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] public virtual void OnCustomStatement(Boo.Lang.Compiler.Ast.CustomStatement node) { diff --git a/src/Boo.Lang.Compiler/Ast/Impl/FastDepthFirstVisitor.cs b/src/Boo.Lang.Compiler/Ast/Impl/FastDepthFirstVisitor.cs index c443c7e9a..c9684317e 100755 --- a/src/Boo.Lang.Compiler/Ast/Impl/FastDepthFirstVisitor.cs +++ b/src/Boo.Lang.Compiler/Ast/Impl/FastDepthFirstVisitor.cs @@ -1785,6 +1785,46 @@ public virtual void OnTypeofExpression(Boo.Lang.Compiler.Ast.TypeofExpression no } } + [System.CodeDom.Compiler.GeneratedCodeAttribute("astgen.boo", "1")] + public virtual void OnAsyncBlockExpression(Boo.Lang.Compiler.Ast.AsyncBlockExpression node) + { + { + var parameters = node.Parameters; + if (parameters != null) + { + var innerList = parameters.InnerList; + var count = innerList.Count; + for (var i=0; iFalse False false - v3.5 + v4.5 @@ -32,6 +32,7 @@ prompt 4 + false pdbonly @@ -40,6 +41,7 @@ TRACE;$(DefineConstants) prompt 4 + false true @@ -49,6 +51,7 @@ TRACE;DEBUG;IGNOREKEYFILE prompt 4 + false False @@ -71,6 +74,7 @@ full AnyCPU prompt + false bin\Micro-Release\ @@ -80,6 +84,7 @@ pdbonly AnyCPU prompt + false @@ -100,8 +105,10 @@ + + @@ -161,8 +168,10 @@ + + @@ -418,6 +427,13 @@ + + + + + + + @@ -426,6 +442,7 @@ + @@ -456,6 +473,8 @@ + + @@ -490,6 +509,7 @@ + @@ -503,8 +523,11 @@ + + + @@ -539,6 +562,7 @@ + @@ -632,6 +656,7 @@ + diff --git a/src/Boo.Lang.Compiler/CompilerErrorFactory.cs b/src/Boo.Lang.Compiler/CompilerErrorFactory.cs index 854145027..dc3f59f78 100644 --- a/src/Boo.Lang.Compiler/CompilerErrorFactory.cs +++ b/src/Boo.Lang.Compiler/CompilerErrorFactory.cs @@ -967,7 +967,32 @@ public static CompilerError TypeExpected(Node node) return Instantiate("BCE0177", node); } - public static CompilerError Instantiate(string code, Exception error, params object[] args) + public static CompilerError InvalidAsyncType(TypeReference tr) + { + return Instantiate("BCE0178", tr); + } + + public static CompilerError InvalidAwaitType(Expression e) + { + return Instantiate("BCE0179", e); + } + + public static CompilerError RestrictedAwaitType(Node n, IType t) + { + return Instantiate("BCE0180", n, t); + } + + public static CompilerError UnsafeReturnInAsync(Expression e) + { + return Instantiate("BCE0181", e); + } + + public static CompilerError MissingGetAwaiter(Expression e) + { + return Instantiate("BCE0182", e.LexicalInfo, e.ExpressionType); + } + + public static CompilerError Instantiate(string code, Exception error, params object[] args) { return new CompilerError(code, error, args); } diff --git a/src/Boo.Lang.Compiler/CompilerWarningFactory.cs b/src/Boo.Lang.Compiler/CompilerWarningFactory.cs index 7c6e8f3c8..27d155d3e 100644 --- a/src/Boo.Lang.Compiler/CompilerWarningFactory.cs +++ b/src/Boo.Lang.Compiler/CompilerWarningFactory.cs @@ -222,6 +222,11 @@ public static CompilerWarning MethodHidesInheritedNonVirtual(Node anchor, IMetho return Instantiate("BCW0029", AstUtil.SafeLexicalInfo(anchor), hidingMethod, hiddenMethod); } + public static CompilerWarning AsyncNoAwait(Method anchor) + { + return Instantiate("BCW0030", AstUtil.SafeLexicalInfo(anchor)); + } + private static CompilerWarning Instantiate(string code, LexicalInfo location, params object[] args) { return new CompilerWarning(code, location, Array.ConvertAll(args, CompilerErrorFactory.DisplayStringFor)); diff --git a/src/Boo.Lang.Compiler/Pipelines/Compile.cs b/src/Boo.Lang.Compiler/Pipelines/Compile.cs index 0eed1b049..bb607f8cf 100644 --- a/src/Boo.Lang.Compiler/Pipelines/Compile.cs +++ b/src/Boo.Lang.Compiler/Pipelines/Compile.cs @@ -49,7 +49,6 @@ public Compile() Add(new CheckIdentifiers()); Add(new CheckSlicingExpressions()); Add(new StricterErrorChecking()); - Add(new DetectNotImplementedFeatureUsage()); Add(new CheckAttributesUsage()); Add(new ExpandDuckTypedExpressions()); @@ -65,11 +64,13 @@ public Compile() Add(new ProcessSharedLocals()); Add(new ProcessClosures()); - Add(new ProcessGenerators()); + Add(new ProcessGeneratorsAndAsyncMethods()); + Add(new DetectNotImplementedFeatureUsage()); Add(new ExpandVarArgsMethodInvocations()); Add(new InjectCallableConversions()); + Add(new CallableTypeElision()); Add(new ImplementICallableOnCallableDefinitions()); Add(new RemoveDeadCode()); diff --git a/src/Boo.Lang.Compiler/Steps/AstAnnotations.cs b/src/Boo.Lang.Compiler/Steps/AstAnnotations.cs index e1fe5ca66..b077657e0 100644 --- a/src/Boo.Lang.Compiler/Steps/AstAnnotations.cs +++ b/src/Boo.Lang.Compiler/Steps/AstAnnotations.cs @@ -33,6 +33,9 @@ namespace Boo.Lang.Compiler.Steps public class AstAnnotations { private static object TryBlockDepthKey = new object(); + + private static object AmbiguousSigatureKey = new object(); + public const string RawArrayIndexing = "rawarrayindexing"; public const string Checked = "checked"; @@ -79,5 +82,15 @@ public static int GetTryBlockDepth(Node node) { return (int)node[TryBlockDepthKey]; } + + public static void MarkAmbiguousSignature(Expression node) + { + node.Annotate(AmbiguousSigatureKey); + } + + public static bool HasAmbiguousSignature(Expression node) + { + return node.ContainsAnnotation(AmbiguousSigatureKey); + } } } diff --git a/src/Boo.Lang.Compiler/Steps/AsyncAwait/AsyncExceptionHandlerRewriter.cs b/src/Boo.Lang.Compiler/Steps/AsyncAwait/AsyncExceptionHandlerRewriter.cs new file mode 100644 index 000000000..939501855 --- /dev/null +++ b/src/Boo.Lang.Compiler/Steps/AsyncAwait/AsyncExceptionHandlerRewriter.cs @@ -0,0 +1,982 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Boo.Lang.Compiler.Ast; +using Boo.Lang.Compiler.TypeSystem; +using Boo.Lang.Compiler.TypeSystem.Internal; +using Boo.Lang.Environments; + +namespace Boo.Lang.Compiler.Steps.AsyncAwait +{ + /// + /// The purpose of this rewriter is to replace await-containing catch and finally handlers + /// with surrogate replacements that keep actual handler code in regular code blocks. + /// That allows these constructs to be further lowered at the async lowering pass. + /// + /// Adapted from Microsoft.CodeAnalysis.CSharp.AsyncExceptionHandlerRewriter in the Roslyn codebase + /// + public sealed class AsyncExceptionHandlerRewriter : DepthFirstTransformer + { + private static readonly IMethod _exceptionDispatchInfoCapture; + private static readonly IMethod _exceptionDispatchInfoThrow; + + private readonly BooCodeBuilder _F; + private readonly TypeSystemServices _tss; + private readonly AwaitInFinallyAnalysis _analysis; + private readonly Method _containingMethod; + + private AwaitCatchFrame _currentAwaitCatchFrame; + private AwaitFinallyFrame _currentAwaitFinallyFrame; + + private int _tryDepth = 1; + + static AsyncExceptionHandlerRewriter() + { + var tss = My.Instance; + var exceptionDispatchInfo = tss.Map(typeof(System.Runtime.ExceptionServices.ExceptionDispatchInfo)); + var methods = exceptionDispatchInfo.GetMembers().OfType().ToArray(); + _exceptionDispatchInfoCapture = methods.SingleOrDefault(m => m.Name.Equals("Capture")); + _exceptionDispatchInfoThrow = methods.SingleOrDefault(m => m.Name.Equals("Throw")); + } + + private AsyncExceptionHandlerRewriter( + Method containingMethod, + BooCodeBuilder factory, + AwaitInFinallyAnalysis analysis) + { + _F = factory; + _analysis = analysis; + _containingMethod = containingMethod; + _currentAwaitFinallyFrame = new AwaitFinallyFrame(factory); + _tss = My.Instance; + } + + /// + /// Lower a block of code by performing local rewritings. + /// The goal is to not have exception handlers that contain awaits in them. + /// + /// 1) Await containing ensure blocks: + /// The general strategy is to rewrite await containing handlers into synthetic handlers. + /// Synthetic handlers are not handlers in IL sense so it is ok to have awaits in them. + /// Since synthetic handlers are just blocks, we have to deal with pending exception/branch/return manually + /// (this is the hard part of the rewrite). + /// + /// try: + /// code + /// ensure: + /// handler + /// + /// Into ===> + /// + /// ex as Exception = null + /// pendingBranch as int = 0 + /// + /// try: + /// code // any gotos/returns are rewritten to code that pends the necessary info and goes to finallyLabel + /// goto finallyLabel + /// except ex: // essentially pend the currently active exception + /// pass + /// + /// :finallyLabel + /// handler + /// if ex != null: raise ex // unpend the exception + /// unpend branches/return + /// + /// 2) Await containing catches: + /// try: + /// code + /// except ex as Exception: + /// handler + /// raise + /// + /// + /// Into ===> + /// + /// pendingException as Object + /// pendingCatch as int = 0 + /// + /// try: + /// code + /// except temp as Exception: // essentially pend the currently active exception + /// pendingException = temp + /// pendingCatch = 1 + /// + /// if pendingCatch == 1: + /// var ex = pendingException cast Exception + /// handler + /// raise pendingException + /// + public static void Rewrite(Method containingMethod) + { + if (containingMethod == null) throw new ArgumentNullException("containingMethod"); + + var body = containingMethod.Body; + var analysis = new AwaitInFinallyAnalysis(body); + if (analysis.ContainsAwaitInHandlers()) + { + var factory = CompilerContext.Current.CodeBuilder; + var rewriter = new AsyncExceptionHandlerRewriter(containingMethod, factory, analysis); + containingMethod.Body = (Block)rewriter.Visit(body); + } + } + + public override void OnTryStatement(TryStatement node) + { + Statement finalizedRegion; + Block rewrittenFinally; + + var finallyContainsAwaits = _analysis.FinallyContainsAwaits(node); + if (!finallyContainsAwaits) + { + finalizedRegion = RewriteFinalizedRegion(node); + ++_tryDepth; + rewrittenFinally = (Block)Visit(node.EnsureBlock); + --_tryDepth; + + if (rewrittenFinally == null) + { + ReplaceCurrentNode(finalizedRegion); + return; + } + + var asTry = finalizedRegion as TryStatement; + if (asTry != null) + { + // since finalized region is a try we can just attach finally to it + Debug.Assert(asTry.EnsureBlock == null); + asTry.EnsureBlock = rewrittenFinally; + ReplaceCurrentNode(asTry); + return; + } + // wrap finalizedRegion into a Try with a finally. + ReplaceCurrentNode(new TryStatement(finalizedRegion.LexicalInfo) {EnsureBlock = rewrittenFinally}); + return; + } + + // rewrite finalized region (try and catches) in the current frame + var frame = PushFrame(node); + finalizedRegion = RewriteFinalizedRegion(node); + rewrittenFinally = (Block)Visit(node.EnsureBlock); + PopFrame(); + + var context = CompilerContext.Current; + var exceptionType = _tss.ObjectType; + var pendingExceptionLocal = _F.DeclareTempLocal(_containingMethod, exceptionType); + var finallyLabel = _F.CreateLabel(node.EnsureBlock, context.GetUniqueName("finallyLabel"), _tryDepth); + var pendingBranchVar =_F.DeclareTempLocal(_containingMethod, _tss.IntType); + + var catchAll = new ExceptionHandler + { + Declaration = new Declaration(pendingExceptionLocal.Name, _F.CreateTypeReference(exceptionType)) + {Entity = pendingExceptionLocal}, + Block = new Block(), + IsSynthetic = true + }; + + var tryBlock = new Block( + finalizedRegion, + _F.CreateGoto(finallyLabel, _tryDepth), + PendBranches(frame, pendingBranchVar, finallyLabel)) + ; + var catchAndPendException = new TryStatement {ProtectedBlock = tryBlock}; + catchAndPendException.ExceptionHandlers.Add(catchAll); + + var syntheticFinally = new Block( + finallyLabel.LabelStatement, + rewrittenFinally, + UnpendException(pendingExceptionLocal), + UnpendBranches( + frame, + pendingBranchVar, + pendingExceptionLocal)); + + var locals = _containingMethod.Locals; + var statements = new Block(); + + locals.Add(pendingExceptionLocal.Local); + statements.Add(_F.CreateAssignment( + _F.CreateLocalReference(pendingExceptionLocal), + _F.CreateDefaultInvocation(LexicalInfo.Empty, pendingExceptionLocal.Type))); + locals.Add(pendingBranchVar.Local); + statements.Add(_F.CreateAssignment( + _F.CreateLocalReference(pendingBranchVar), + _F.CreateDefaultInvocation(LexicalInfo.Empty, pendingBranchVar.Type))); + + var returnLocal = frame.returnValue; + if (returnLocal != null) + { + locals.Add(returnLocal.Local); + } + + statements.Add(catchAndPendException); + statements.Add(syntheticFinally); + + ReplaceCurrentNode(statements); + } + + private Block PendBranches( + AwaitFinallyFrame frame, + InternalLocal pendingBranchVar, + InternalLabel finallyLabel) + { + var bodyStatements = new Block(); + + // handle proxy labels if have any + var proxiedLabels = frame.proxiedLabels; + var proxyLabels = frame.proxyLabels; + + // skip 0 - it means we took no explicit branches + int i = 1; + if (proxiedLabels != null) + { + for (int cnt = proxiedLabels.Count; i <= cnt; i++) + { + var proxied = proxiedLabels[i - 1]; + var proxy = proxyLabels[proxied]; + + PendBranch(bodyStatements, proxy, i, pendingBranchVar, finallyLabel); + } + } + + var returnProxy = frame.returnProxyLabel; + if (returnProxy != null) + { + PendBranch(bodyStatements, returnProxy, i, pendingBranchVar, finallyLabel); + } + + return bodyStatements; + } + + private void PendBranch( + Block bodyStatements, + InternalLabel proxy, + int i, + InternalLocal pendingBranchVar, + InternalLabel finallyLabel) + { + // branch lands here + bodyStatements.Add(proxy.LabelStatement); + + // pend the branch + bodyStatements.Add(_F.CreateAssignment( + _F.CreateLocalReference(pendingBranchVar), + _F.CreateIntegerLiteral(i))); + + // skip other proxies + bodyStatements.Add(_F.CreateGoto(finallyLabel, _tryDepth)); + } + + private Statement UnpendBranches( + AwaitFinallyFrame frame, + InternalLocal pendingBranchVar, + InternalLocal pendingException) + { + var parent = frame.ParentOpt; + + // handle proxy labels if have any + var proxiedLabels = frame.proxiedLabels; + + // skip 0 - it means we took no explicit branches + int i = 1; + var cases = new List(); + + if (proxiedLabels != null) + { + for (int cnt = proxiedLabels.Count; i <= cnt; i++) + { + var target = proxiedLabels[i - 1]; + var parentProxy = parent.ProxyLabelIfNeeded(target); + cases.Add(_F.CreateGoto(parentProxy, _tryDepth)); + } + } + + if (frame.returnProxyLabel != null) + { + Local pendingValue = null; + if (frame.returnValue != null) + { + pendingValue = frame.returnValue.Local; + } + + InternalLocal returnValue; + Statement unpendReturn; + + var returnLabel = parent.ProxyReturnIfNeeded(_containingMethod, pendingValue, out returnValue); + + if (returnLabel == null) + { + unpendReturn = new ReturnStatement(_F.CreateLocalReference((InternalLocal)pendingValue.Entity)); + } + else + { + if (pendingValue == null) + { + unpendReturn = _F.CreateGoto(returnLabel, _tryDepth); + } + else + { + unpendReturn = new Block( + new ExpressionStatement( + _F.CreateAssignment(_F.CreateLocalReference(returnValue), + _F.CreateLocalReference((InternalLocal)pendingValue.Entity))), + _F.CreateGoto(returnLabel, _tryDepth)); + } + } + + cases.Add(unpendReturn); + } + + var defaultLabel = _F.CreateLabel(_containingMethod, CompilerContext.Current.GetUniqueName("default"), _tryDepth); + cases.Insert(0, _F.CreateGoto(defaultLabel, _tryDepth)); + return CreateSwitch(cases, defaultLabel, pendingBranchVar); + } + + private Block SwitchBlock(Statement body, int ordinal, InternalLabel endpoint) + { + var result = new Block(); + result.Add(_F.CreateLabel(result, CompilerContext.Current.GetUniqueName("L" + ordinal), _tryDepth).LabelStatement); + result.Add(body); + result.Add(_F.CreateGoto(endpoint, _tryDepth)); + return result; + } + + private Block CreateSwitch(List handlers, InternalLabel endLabel, InternalLocal switchVar) + { + var blocks = handlers.Zip(Enumerable.Range(1, handlers.Count), (s, i1) => SwitchBlock(s, i1, endLabel)).ToArray(); + var resultSwitch = _F.CreateSwitch(_F.CreateLocalReference(switchVar), + blocks.Select(b => b.FirstStatement).Cast()); + var result = new Block(resultSwitch); + result.Statements.AddRange(blocks); + result.Statements.Add(endLabel.LabelStatement); + return result; + } + + public override void OnGotoStatement(GotoStatement node) + { + base.OnGotoStatement(node); + var proxyLabel = _currentAwaitFinallyFrame.ProxyLabelIfNeeded((InternalLabel) node.Label.Entity); + node.Label.Entity = proxyLabel; + } + + public override void OnReturnStatement(ReturnStatement node) + { + InternalLocal returnValue; + var returnLabel = _currentAwaitFinallyFrame.ProxyReturnIfNeeded( + _containingMethod, + node.Expression, + out returnValue); + + if (returnLabel == null) + { + base.OnReturnStatement(node); + return; + } + + var returnExpr = Visit(node.Expression); + Statement result; + if (returnExpr != null) + { + result = new Block( + new ExpressionStatement( + _F.CreateAssignment( + _F.CreateLocalReference(returnValue), + returnExpr)), + _F.CreateGoto(returnLabel, _tryDepth)); + } + else + { + result = _F.CreateGoto(returnLabel, _tryDepth); + } + ReplaceCurrentNode(result); + } + + private Statement UnpendException(InternalLocal pendingExceptionLocal) + { + // create a temp. + // pendingExceptionLocal will certainly be captured, no need to access it over and over. + InternalLocal obj = _F.DeclareTempLocal(_containingMethod, _tss.ObjectType); + var objInit = _F.CreateAssignment(_F.CreateLocalReference(obj), _F.CreateLocalReference(pendingExceptionLocal)); + + // throw pendingExceptionLocal; + Statement rethrow = Rethrow(obj); + + return new Block( + new ExpressionStatement(objInit), + new IfStatement( + pendingExceptionLocal.Local.LexicalInfo, + _F.CreateBoundBinaryExpression( + _tss.BoolType, + BinaryOperatorType.ReferenceInequality, + _F.CreateLocalReference(obj), + new NullLiteralExpression()), + new Block(rethrow), + null)); + } + + private Statement Rethrow(InternalLocal obj) + { + // conservative rethrow + Statement rethrow = new RaiseStatement(_F.CreateLocalReference(obj)); + + // if these helpers are available, we can rethrow with original stack info + // as long as it derives from Exception + if (_exceptionDispatchInfoCapture != null && _exceptionDispatchInfoThrow != null) + { + var ex = _F.DeclareTempLocal(_containingMethod, _tss.ExceptionType); + var assignment = _F.CreateAssignment( + _F.CreateLocalReference(ex), + _F.CreateAsCast(ex.Type, _F.CreateLocalReference(obj))); + + // better rethrow + rethrow = new Block( + new ExpressionStatement(assignment), + new IfStatement( + _F.CreateBoundBinaryExpression( + _tss.BoolType, + BinaryOperatorType.ReferenceEquality, + _F.CreateLocalReference(ex), + new NullLiteralExpression()), + new Block(rethrow), + null), + // ExceptionDispatchInfo.Capture(pendingExceptionLocal).Throw() + new ExpressionStatement( + _F.CreateMethodInvocation( + _F.CreateMethodInvocation( + _exceptionDispatchInfoCapture, + _F.CreateLocalReference(ex)), + _exceptionDispatchInfoThrow)) + ); + } + + return rethrow; + } + + /// + /// Rewrites Try/Catch part of the Try/Catch/Finally + /// + private Statement RewriteFinalizedRegion(TryStatement node) + { + var rewrittenTry = (Block)Visit(node.ProtectedBlock); + + var catches = node.ExceptionHandlers; + if (catches.IsEmpty) + { + return rewrittenTry; + } + + var origAwaitCatchFrame = _currentAwaitCatchFrame; + _currentAwaitCatchFrame = null; + + Visit(node.ExceptionHandlers); + Statement tryWithCatches = new TryStatement{EnsureBlock = rewrittenTry}; + ((TryStatement)tryWithCatches).ExceptionHandlers = node.ExceptionHandlers; + + var currentAwaitCatchFrame = _currentAwaitCatchFrame; + if (currentAwaitCatchFrame != null) + { + var handledLabel = _F.CreateLabel(tryWithCatches, CompilerContext.Current.GetUniqueName("handled"), _tryDepth); + var handlersList = currentAwaitCatchFrame.handlers; + _tryDepth = node.GetAncestors().Count() + 1; + var handlers = new List { _F.CreateGoto(handledLabel, _tryDepth) }; + for (int i = 0, l = handlersList.Count; i < l; i++) + { + handlers.Add( + new Block( + handlersList[i], + _F.CreateGoto(handledLabel, _tryDepth))); + } + + _containingMethod.Locals.Add(currentAwaitCatchFrame.pendingCaughtException.Local); + _containingMethod.Locals.Add(currentAwaitCatchFrame.pendingCatch.Local); + _containingMethod.Locals.AddRange(currentAwaitCatchFrame.GetHoistedLocals().Select(l => l.Local)); + + tryWithCatches = new Block( + new ExpressionStatement( + _F.CreateAssignment( + _F.CreateLocalReference(currentAwaitCatchFrame.pendingCatch), + _F.CreateDefaultInvocation(LexicalInfo.Empty, currentAwaitCatchFrame.pendingCatch.Type))), + tryWithCatches, + CreateSwitch(handlers, handledLabel, currentAwaitCatchFrame.pendingCatch)); + } + + _currentAwaitCatchFrame = origAwaitCatchFrame; + + return tryWithCatches; + } + + public override void OnExceptionHandler(ExceptionHandler node) + { + if (!_analysis.CatchContainsAwait(node)) + { + var origCurrentAwaitCatchFrame = _currentAwaitCatchFrame; + _currentAwaitCatchFrame = null; + + var result = Visit(node); + _currentAwaitCatchFrame = origCurrentAwaitCatchFrame; + ReplaceCurrentNode(result); + } + + var currentAwaitCatchFrame = _currentAwaitCatchFrame ?? + (_currentAwaitCatchFrame = new AwaitCatchFrame(_tss, _F, _containingMethod)); + + var catchType = node.Declaration != null ? (IType)node.Declaration.Type.Entity : _tss.ObjectType; + var catchTemp = _F.DeclareTempLocal(_containingMethod, catchType); + + var storePending = _F.CreateAssignment( + _F.CreateLocalReference(currentAwaitCatchFrame.pendingCaughtException), + _F.CreateCast(currentAwaitCatchFrame.pendingCaughtException.Type, + _F.CreateLocalReference(catchTemp))); + + var setPendingCatchNum = _F.CreateAssignment( + _F.CreateLocalReference(currentAwaitCatchFrame.pendingCatch), + _F.CreateIntegerLiteral(currentAwaitCatchFrame.handlers.Count + 1)); + + // catch (ExType exTemp) + // { + // pendingCaughtException = exTemp; + // catchNo = X; + // } + ExceptionHandler catchAndPend; + + var filterOpt = node.FilterCondition; + if (filterOpt == null) + { + // store pending exception + // as the first statement in a catch + catchAndPend = new ExceptionHandler + { + Declaration = new Declaration(catchTemp.Name, _F.CreateTypeReference(catchType)) { Entity = catchTemp }, + Block = new Block( + new ExpressionStatement(storePending), + new ExpressionStatement(setPendingCatchNum)) + }; + // catch locals live on the synthetic catch handler block + } + else + { + // catch locals move up into hoisted locals + // since we might need to access them from both the filter and the catch + foreach (var local in LocalsUsedIn(node)) + { + currentAwaitCatchFrame.HoistLocal(local, _F); + } + + // store pending exception + // as the first expression in a filter + var decl = node.Declaration; + var sourceOpt = decl != null && !string.IsNullOrEmpty(decl.Name) ? decl : null; + var rewrittenFilter = Visit(filterOpt); + var newFilter = sourceOpt == null ? + _F.CreateEvalInvocation( + storePending.LexicalInfo, + storePending, + rewrittenFilter) : + _F.CreateEvalInvocation( + storePending.LexicalInfo, + storePending, + AssignCatchSource((Declaration)Visit(sourceOpt), currentAwaitCatchFrame), + rewrittenFilter); + + catchAndPend = new ExceptionHandler + { + Declaration = new Declaration(catchTemp.Name, _F.CreateTypeReference(catchType)) { Entity = catchTemp }, + FilterCondition = newFilter, + Block = new Block(new ExpressionStatement(setPendingCatchNum)) + }; + } + if (node.ContainsAnnotation("isSynthesizedAsyncCatchAll")) + catchAndPend.Annotate("isSynthesizedAsyncCatchAll"); + + var handlerStatements = new List(); + + if (filterOpt == null) + { + var sourceOpt = node.Declaration; + if (sourceOpt != null && sourceOpt.Entity != null) + { + Expression assignSource = AssignCatchSource((Declaration)Visit(sourceOpt), currentAwaitCatchFrame); + handlerStatements.Add(new ExpressionStatement(assignSource)); + } + } + + handlerStatements.Add(Visit(node.Block)); + + var handler = new Block(handlerStatements.ToArray()); + + currentAwaitCatchFrame.handlers.Add(handler); + + ReplaceCurrentNode(catchAndPend); + } + + private IEnumerable LocalsUsedIn(ExceptionHandler handler) + { + var rc = new LocalReferenceCollector(_containingMethod.Locals.Select(l => (InternalLocal) l.Entity)); + handler.Accept(rc); + return rc.Result; + } + + private Expression AssignCatchSource(Declaration rewrittenSource, AwaitCatchFrame currentAwaitCatchFrame) + { + Expression assignSource = null; + if (rewrittenSource != null) + { + // exceptionSource = (exceptionSourceType)pendingCaughtException; + assignSource = _F.CreateAssignment( + _F.CreateLocalReference((InternalLocal)rewrittenSource.Entity), + _F.CreateCast( + (IType) rewrittenSource.Type.Entity, + _F.CreateLocalReference(currentAwaitCatchFrame.pendingCaughtException))); + } + + return assignSource; + } + + public override void OnDeclaration(Declaration node) + { + if (node.Entity == null) + return; + var catchFrame = _currentAwaitCatchFrame; + InternalLocal hoistedLocal = null; + if (catchFrame == null || !catchFrame.TryGetHoistedLocal((InternalLocal)node.Entity, out hoistedLocal)) + { + base.OnDeclaration(node); + } + + ReplaceCurrentNode(new Declaration(node.Name, _F.CreateTypeReference(hoistedLocal.Type)) {Entity = hoistedLocal}); + } + + public override void OnRaiseStatement(RaiseStatement node) + { + if (node.Exception != null || _currentAwaitCatchFrame == null) + { + base.OnRaiseStatement(node); + } + + ReplaceCurrentNode(Rethrow(_currentAwaitCatchFrame.pendingCaughtException)); + } + + public override void OnBlockExpression(BlockExpression node) + { + throw new NotImplementedException(); //should have been rewritten already + } + + private AwaitFinallyFrame PushFrame(TryStatement statement) + { + var newFrame = new AwaitFinallyFrame(_currentAwaitFinallyFrame, _analysis.Labels(statement), statement, _F, _tryDepth); + _currentAwaitFinallyFrame = newFrame; + return newFrame; + } + + private void PopFrame() + { + var result = _currentAwaitFinallyFrame; + _currentAwaitFinallyFrame = result.ParentOpt; + } + + /// + /// Analyzes method body for try blocks with awaits in finally blocks + /// Also collects labels that such blocks contain. + /// + private sealed class AwaitInFinallyAnalysis : LabelCollector + { + // all try blocks with yields in them and complete set of labels inside those try blocks + // NOTE: non-yielding try blocks are transparently ignored - i.e. their labels are included + // in the label set of the nearest yielding-try parent + private Dictionary> _labelsInInterestingTry; + + private HashSet _awaitContainingCatches; + + // transient accumulators. + private bool _seenAwait; + + public AwaitInFinallyAnalysis(Statement body) + { + _seenAwait = false; + Visit(body); + } + + /// + /// Returns true if a finally of the given try contains awaits + /// + public bool FinallyContainsAwaits(TryStatement statement) + { + return _labelsInInterestingTry != null && _labelsInInterestingTry.ContainsKey(statement); + } + + /// + /// Returns true if a catch contains awaits + /// + internal bool CatchContainsAwait(ExceptionHandler node) + { + return _awaitContainingCatches != null && _awaitContainingCatches.Contains(node); + } + + /// + /// Returns true if body contains await in a finally block. + /// + public bool ContainsAwaitInHandlers() + { + return _labelsInInterestingTry != null || _awaitContainingCatches != null; + } + + /// + /// Labels reachable from within this frame without invoking its finally. + /// null if there are no such labels. + /// + internal HashSet Labels(TryStatement statement) + { + return _labelsInInterestingTry[statement]; + } + + public override void OnTryStatement(TryStatement node) + { + var origLabels = _currentLabels; + _currentLabels = null; + Visit(node.ProtectedBlock); + Visit(node.ExceptionHandlers); + + var origSeenAwait = _seenAwait; + _seenAwait = false; + Visit(node.EnsureBlock); + + if (_seenAwait) + { + // this try has awaits in the finally ! + var labelsInInterestingTry = _labelsInInterestingTry; + if (labelsInInterestingTry == null) + { + _labelsInInterestingTry = labelsInInterestingTry = new Dictionary>(); + } + + labelsInInterestingTry.Add(node, _currentLabels); + _currentLabels = origLabels; + } + else + { + // this is a boring try without awaits in finally + + // currentLabels = currentLabels U origLabels ; + if (_currentLabels == null) + { + _currentLabels = origLabels; + } + else if (origLabels != null) + { + _currentLabels.UnionWith(origLabels); + } + } + + _seenAwait = _seenAwait | origSeenAwait; + } + + public override void OnExceptionHandler(ExceptionHandler node) + { + var origSeenAwait = _seenAwait; + _seenAwait = false; + + base.OnExceptionHandler(node); + + if (_seenAwait) + { + if (_awaitContainingCatches == null) + { + _awaitContainingCatches = new HashSet(); + } + + _awaitContainingCatches.Add(node); + } + + _seenAwait |= origSeenAwait; + } + + public override void OnAwaitExpression(AwaitExpression node) + { + _seenAwait = true; + base.OnAwaitExpression(node); + } + } + + // storage of various information about a given finally frame + private sealed class AwaitFinallyFrame + { + // Enclosing frame. Root frame does not have parent. + public readonly AwaitFinallyFrame ParentOpt; + + // labels within this frame (branching to these labels does not go through finally). + private readonly HashSet _labelsOpt; + + private readonly BooCodeBuilder _builder; + + // the try statement the frame is associated with + private readonly TryStatement _tryStatementOpt; + + // proxy labels for branches leaving the frame. + // we build this on demand once we encounter leaving branches. + // subsequent leaves to an already proxied label redirected to the proxy. + // At the proxy label we will execute finally and forward the control flow + // to the actual destination. (which could be proxied again in the parent) + public Dictionary proxyLabels; + + public List proxiedLabels; + + public InternalLabel returnProxyLabel; + public InternalLocal returnValue; + + private int _tryDepth; + + public AwaitFinallyFrame(BooCodeBuilder builder) + { + _builder = builder; + } + + public AwaitFinallyFrame(AwaitFinallyFrame parent, HashSet labelsOpt, TryStatement TryStatement, BooCodeBuilder builder, int depth) + { + Debug.Assert(parent != null); + Debug.Assert(TryStatement != null); + + ParentOpt = parent; + _labelsOpt = labelsOpt; + _tryStatementOpt = TryStatement; + _builder = builder; + _tryDepth = depth; + } + + private bool IsRoot() + { + return ParentOpt == null; + } + + // returns a proxy for a label if branch must be hijacked to run finally + // otherwise returns same label back + public InternalLabel ProxyLabelIfNeeded(InternalLabel label) + { + // no need to proxy a label in the current frame or when we are at the root + if (IsRoot() || (_labelsOpt != null && _labelsOpt.Contains(label))) + { + return label; + } + + var proxyLabels = this.proxyLabels; + var proxiedLabels = this.proxiedLabels; + if (proxyLabels == null) + { + this.proxyLabels = proxyLabels = new Dictionary(); + this.proxiedLabels = proxiedLabels = new List(); + } + + InternalLabel proxy; + if (!proxyLabels.TryGetValue(label, out proxy)) + { + proxy = _builder.CreateLabel(label.LabelStatement, "proxy" + label.Name, _tryDepth); + proxyLabels.Add(label, proxy); + proxiedLabels.Add(label); + } + + return proxy; + } + + public InternalLabel ProxyReturnIfNeeded( + Method containingMethod, + Node valueOpt, + out InternalLocal retVal) + { + retVal = null; + + // no need to proxy returns at the root + if (IsRoot()) + { + return null; + } + + var returnProxy = returnProxyLabel; + if (returnProxy == null) + { + returnProxyLabel = returnProxy = _builder.CreateLabel(valueOpt ?? containingMethod, "returnProxy", _tryDepth); + } + + if (valueOpt != null) + { + retVal = returnValue; + if (retVal == null) + { + Debug.Assert(_tryStatementOpt != null); + returnValue = retVal = _builder.DeclareTempLocal( + containingMethod, + ((ITypedEntity)valueOpt.Entity).Type); + } + } + + return returnProxy; + } + } + + private sealed class AwaitCatchFrame + { + // object, stores the original caught exception + // used to initialize the exception source inside the handler + // also used in rethrow statements + public readonly InternalLocal pendingCaughtException; + + // int, stores the number of pending catch + // 0 - means no catches are pending. + public readonly InternalLocal pendingCatch; + + // synthetic handlers produced by catch rewrite. + // they will become switch sections when pending exception is dispatched. + public readonly List handlers; + + // when catch local must be used from a filter + // we need to "hoist" it up to ensure that both the filter + // and the catch access the same variable. + // NOTE: it must be the same variable, not just same value. + // The difference would be observable if filter mutates the variable + // or/and if a variable gets lifted into a closure. + private readonly Dictionary _hoistedLocals; + private readonly List _orderedHoistedLocals; + + private readonly Method _currentMethod; + + public AwaitCatchFrame(TypeSystemServices tss, BooCodeBuilder builder, Method currentMethod) + { + pendingCaughtException = builder.DeclareTempLocal(currentMethod, tss.ObjectType); + pendingCatch = builder.DeclareTempLocal(currentMethod, tss.IntType); + + handlers = new List(); + _hoistedLocals = new Dictionary(); + _orderedHoistedLocals = new List(); + _currentMethod = currentMethod; + } + + public void HoistLocal(InternalLocal local, BooCodeBuilder F) + { + if (!_hoistedLocals.Keys.Any(l => l.Name == local.Name && l.Type == local.Type)) + { + _hoistedLocals.Add(local, local); + _orderedHoistedLocals.Add(local); + return; + } + + // code uses "await" in two sibling catches with exception filters + // locals with same names and types may cause problems if they are lifted + // and become fields with identical signatures. + // To avoid such problems we will mangle the name of the second local. + // This will only affect debugging of this extremely rare case. + var newLocal = F.DeclareTempLocal(_currentMethod, local.Type); + + _hoistedLocals.Add(local, newLocal); + _orderedHoistedLocals.Add(newLocal); + } + + public IEnumerable GetHoistedLocals() + { + return _orderedHoistedLocals; + } + + public bool TryGetHoistedLocal(InternalLocal originalLocal, out InternalLocal hoistedLocal) + { + return _hoistedLocals.TryGetValue(originalLocal, out hoistedLocal); + } + } + } +} diff --git a/src/Boo.Lang.Compiler/Steps/AsyncAwait/AsyncMethodBuilderMemberCollection.cs b/src/Boo.Lang.Compiler/Steps/AsyncAwait/AsyncMethodBuilderMemberCollection.cs new file mode 100644 index 000000000..b4cd6e71d --- /dev/null +++ b/src/Boo.Lang.Compiler/Steps/AsyncAwait/AsyncMethodBuilderMemberCollection.cs @@ -0,0 +1,171 @@ +using System.Linq; +using Boo.Lang.Compiler.Ast; +using Boo.Lang.Compiler.TypeSystem; + +namespace Boo.Lang.Compiler.Steps.AsyncAwait +{ + /// + /// Async methods have both a return type (void, Task, or Task<T>) and a 'result' type, which is the + /// operand type of any return expressions in the async method. The result type is void in the case of + /// Task-returning and void-returning async methods, and T in the case of Task<T>-returning async + /// methods. + /// + /// System.Runtime.CompilerServices provides a collection of async method builders that are used in the + /// generated code of async methods to create and manipulate the async method's task. There are three + /// distinct async method builder types, one of each async return type: AsyncVoidMethodBuilder, + /// AsyncTaskMethodBuilder, and AsyncTaskMethodBuilder<T>. + /// + /// AsyncMethodBuilderMemberCollection provides a common mechanism for accessing the well-known members of + /// each async method builder type. This avoids having to inspect the return style of the current async method + /// to pick the right async method builder member during async rewriting. + /// + /// Adapted from Microsoft.CodeAnalysis.CSharp.AsyncMethodBuilderMemberCollection in the Roslyn codebase + /// + public struct AsyncMethodBuilderMemberCollection + { + /// + /// The builder's constructed type. + /// + public readonly IType BuilderType; + + /// + /// The result type of the constructed task: T for Task<T>, void otherwise. + /// + public readonly IType ResultType; + + /// + /// Create an instance of the method builder. + /// + public readonly IMethod CreateBuilder; + + /// + /// Binds an exception to the method builder. + /// + public readonly IMethod SetException; + + /// + /// Marks the method builder as successfully completed, and sets the result if method is Task<T>-returning. + /// + public readonly IMethod SetResult; + + /// + /// Schedules the state machine to proceed to the next action when the specified awaiter completes. + /// + public readonly IMethod AwaitOnCompleted; + + /// + /// Schedules the state machine to proceed to the next action when the specified awaiter completes. This method can be called from partially trusted code. + /// + public readonly IMethod AwaitUnsafeOnCompleted; + + /// + /// Begins running the builder with the associated state machine. + /// + public readonly IMethod Start; + + /// + /// Associates the builder with the specified state machine. + /// + public readonly IMethod SetStateMachine; + + /// + /// Get the constructed task for a Task-returning or Task<T>-returning async method. + /// + public readonly IProperty Task; + + private AsyncMethodBuilderMemberCollection( + IType builderType, + IType resultType, + IMethod createBuilder, + IMethod setException, + IMethod setResult, + IMethod awaitOnCompleted, + IMethod awaitUnsafeOnCompleted, + IMethod start, + IMethod setStateMachine, + IProperty task) + { + BuilderType = builderType; + ResultType = resultType; + CreateBuilder = createBuilder; + SetException = setException; + SetResult = setResult; + AwaitOnCompleted = awaitOnCompleted; + AwaitUnsafeOnCompleted = awaitUnsafeOnCompleted; + Start = start; + SetStateMachine = setStateMachine; + Task = task; + } + + public static bool TryCreate(TypeSystemServices tss, Method method, IType genericArg, + out AsyncMethodBuilderMemberCollection collection) + { + if (ContextAnnotations.IsAsync(method)) + { + var returnType = (IType) method.ReturnType.Entity; + if (returnType == tss.VoidType) + return TryCreateVoid(tss, out collection); + if (returnType == tss.TaskType) + return TryCreateTask(tss, out collection); + if (returnType.ConstructedInfo != null && + returnType.ConstructedInfo.GenericDefinition == tss.GenericTaskType) + return TryCreateGenericTask(tss, genericArg, out collection); + } + + throw CompilerErrorFactory.InvalidAsyncType(method.ReturnType); + } + + private static bool TryCreate(IType builderType, IType resultType, + out AsyncMethodBuilderMemberCollection collection) + { + var members = builderType.GetMembers().OfType().Where(m => m.IsPublic).ToDictionary(m => m.Name); + var task = members.ContainsKey("Task") ? (IProperty) members["Task"] : null; + collection = new AsyncMethodBuilderMemberCollection( + builderType, + resultType, + (IMethod) members["Create"], + (IMethod) members["SetException"], + (IMethod) members["SetResult"], + (IMethod) members["AwaitOnCompleted"], + (IMethod) members["AwaitUnsafeOnCompleted"], + (IMethod) members["Start"], + (IMethod) members["SetStateMachine"], + task); + + return true; + } + + private static bool TryCreateGenericTask(TypeSystemServices tss, + IType genericArg, out AsyncMethodBuilderMemberCollection collection) + { + var builderType = tss.AsyncGenericTaskMethodBuilderType.GenericInfo.ConstructType(genericArg); + return TryCreate( + builderType, + genericArg, + out collection); + } + + private static bool TryCreateTask(TypeSystemServices tss, + out AsyncMethodBuilderMemberCollection collection) + { + var builderType = tss.AsyncTaskMethodBuilderType; + var resultType = tss.VoidType; + return TryCreate( + builderType, + resultType, + out collection); + } + + private static bool TryCreateVoid(TypeSystemServices tss, + out AsyncMethodBuilderMemberCollection collection) + { + var builderType = tss.AsyncVoidMethodBuilderType; + var resultType = tss.VoidType; + return TryCreate( + builderType, + resultType, + out collection); + } + + } +} diff --git a/src/Boo.Lang.Compiler/Steps/AsyncAwait/AsyncMethodProcessor.cs b/src/Boo.Lang.Compiler/Steps/AsyncAwait/AsyncMethodProcessor.cs new file mode 100644 index 000000000..dff102370 --- /dev/null +++ b/src/Boo.Lang.Compiler/Steps/AsyncAwait/AsyncMethodProcessor.cs @@ -0,0 +1,744 @@ +using System; +using System.Diagnostics; +using System.Linq; +using System.Runtime.CompilerServices; +using Boo.Lang.Compiler.Ast; +using Boo.Lang.Compiler.Steps.StateMachine; +using Boo.Lang.Compiler.TypeSystem; +using Boo.Lang.Compiler.TypeSystem.Builders; +using Boo.Lang.Compiler.TypeSystem.Generics; +using Boo.Lang.Compiler.TypeSystem.Internal; +using Boo.Lang.Environments; + +namespace Boo.Lang.Compiler.Steps.AsyncAwait +{ + using System.Collections.Generic; + + internal sealed class AsyncMethodProcessor : MethodToStateMachineTransformer + { + /// + /// The field of the generated async class used to store the async method builder: an instance of + /// , , or depending on the + /// return type of the async method. + /// + private Field _asyncMethodBuilderField; + + /// + /// A collection of well-known members for the current async method builder. + /// + private AsyncMethodBuilderMemberCollection _asyncMethodBuilderMemberCollection; + + /// + /// The exprReturnLabel is used to label the return handling code at the end of the async state-machine + /// method. Return expressions are rewritten as unconditional branches to exprReturnLabel. + /// + private InternalLabel _exprReturnLabel; + + /// + /// The label containing a return from the method when the async method has not completed. + /// + private InternalLabel _exitLabel; + + /// + /// The field of the generated async class used in generic task returning async methods to store the value + /// of rewritten return expressions. The return-handling code then uses SetResult on the async method builder + /// to make the result available to the caller. + /// + private InternalLocal _exprRetValue; + + /// + /// Cached "state" of the state machine within the MoveNext method. We work with a copy of + /// the state to avoid shared mutable state between threads. (Two threads can be executing + /// in a Task's MoveNext method because an awaited task may complete after the awaiter has + /// tested whether the subtask is complete but before the awaiter has returned) + /// + private InternalLocal _cachedState; + + /// + /// Used to track whether or not a method contains await expressions, for the purpose of emitting a + /// warning if it does not contain any. + /// + private bool _seenAwait; + + private readonly Dictionary _awaiterFields; + private int _nextAwaiterId; + + private bool _isGenericTask; + + internal AsyncMethodProcessor( + CompilerContext context, + InternalMethod method) + : base(context, method) + { + _awaiterFields = new Dictionary(); + _nextAwaiterId = 0; + } + + public override void Run() + { + base.Run(); + FixAsyncMethodBody(_stateMachineConstructorInvocation); + } + + private void FixAsyncMethodBody(MethodInvocationExpression stateMachineConstructorInvocation) + { + var method = _method.Method; + // If the async method's result type is a type parameter of the method, then the AsyncTaskMethodBuilder + // needs to use the method's type parameters inside the rewritten method body. All other methods generated + // during async rewriting are members of the synthesized state machine struct, and use the type parameters + // structs type parameters. + AsyncMethodBuilderMemberCollection methodScopeAsyncMethodBuilderMemberCollection; + if (!AsyncMethodBuilderMemberCollection.TryCreate( + TypeSystemServices, + method, + MethodGenArg(false), + out methodScopeAsyncMethodBuilderMemberCollection)) + { + throw new NotImplementedException("Custom async patterns are not supported"); + } + + var bodyBuilder = new Block(); + var builderVariable = CodeBuilder.DeclareTempLocal(method, methodScopeAsyncMethodBuilderMemberCollection.BuilderType); + + var stateMachineType = stateMachineConstructorInvocation.ExpressionType; + + var stateMachineVariable = CodeBuilder.DeclareLocal( + method, + UniqueName("async"), + stateMachineType); + + bodyBuilder.Add(CodeBuilder.CreateAssignment( + CodeBuilder.CreateLocalReference(stateMachineVariable), + stateMachineConstructorInvocation)); + + // local.$builder = System.Runtime.CompilerServices.AsyncTaskMethodBuilder.Create(); + bodyBuilder.Add( + CodeBuilder.CreateAssignment( + CodeBuilder.CreateMemberReference( + CodeBuilder.CreateLocalReference(stateMachineVariable), + ExternalFieldEntity((IField)_asyncMethodBuilderField.Entity, stateMachineType)), + CodeBuilder.CreateMethodInvocation(methodScopeAsyncMethodBuilderMemberCollection.CreateBuilder))); + + // local.$stateField = NotStartedStateMachine + bodyBuilder.Add( + CodeBuilder.CreateAssignment( + CodeBuilder.CreateMemberReference( + CodeBuilder.CreateLocalReference(stateMachineVariable), + ExternalFieldEntity(_state, stateMachineType)), + CodeBuilder.CreateIntegerLiteral(StateMachineStates.NotStartedStateMachine))); + + bodyBuilder.Add( + CodeBuilder.CreateAssignment( + CodeBuilder.CreateLocalReference(builderVariable), + CodeBuilder.CreateMemberReference( + CodeBuilder.CreateLocalReference(stateMachineVariable), + ExternalFieldEntity((IField)_asyncMethodBuilderField.Entity, stateMachineType)))); + + // local.$builder.Start(ref local) -- binding to the method AsyncTaskMethodBuilder.Start() + bodyBuilder.Add( + CodeBuilder.CreateMethodInvocation( + CodeBuilder.CreateLocalReference(builderVariable), + methodScopeAsyncMethodBuilderMemberCollection.Start.GenericInfo.ConstructMethod(stateMachineType), + CodeBuilder.CreateLocalReference(stateMachineVariable))); + + var methodBuilderField = stateMachineType.ConstructedInfo == null + ? (IField) _asyncMethodBuilderField.Entity + : stateMachineType.ConstructedInfo.Map((IField) _asyncMethodBuilderField.Entity); + bodyBuilder.Add(method.ReturnType.Entity == TypeSystemServices.VoidType + ? new ReturnStatement() + : new ReturnStatement( + CodeBuilder.CreateMethodInvocation( + CodeBuilder.CreateMemberReference( + CodeBuilder.CreateLocalReference(stateMachineVariable), + methodBuilderField), + methodScopeAsyncMethodBuilderMemberCollection.Task.GetGetMethod()))); + + _method.Method.Body = bodyBuilder; + } + + private static IField ExternalFieldEntity(IField field, IType stateMachineType) + { + if (stateMachineType.ConstructedInfo != null) + { + field = (IField)stateMachineType.ConstructedInfo.Map(field); + } + return field; + } + + private Field GetAwaiterField(IType awaiterType) + { + Field result; + + // Awaiters of the same type always share the same slot, regardless of what await expressions they belong to. + // Even in case of nested await expressions only one awaiter is active. + // So we don't need to tie the awaiter variable to a particular await expression and only use its type + // to find the previous awaiter field. + if (!_awaiterFields.TryGetValue(awaiterType, out result)) + { + int slotIndex = _nextAwaiterId++; + + string fieldName = Context.GetUniqueName("_awaiter", slotIndex.ToString()); + + result = _stateMachineClass.AddField(fieldName, awaiterType); + _awaiterFields.Add(awaiterType, result); + } + + return result; + } + + /// + /// Generate the body for MoveNext(). + /// + protected override void CreateMoveNext() + { + Method asyncMethod = _method.Method; + + BooMethodBuilder methodBuilder = _stateMachineClass.AddVirtualMethod("MoveNext", TypeSystemServices.VoidType); + methodBuilder.Method.LexicalInfo = asyncMethod.LexicalInfo; + _moveNext = methodBuilder.Entity; + + TransformLocalsIntoFields(asyncMethod); + TransformParametersIntoFieldsInitializedByConstructor(_method.Method); + + _exprReturnLabel = CodeBuilder.CreateLabel(methodBuilder.Method, "exprReturn", 0); + _exitLabel = CodeBuilder.CreateLabel(methodBuilder.Method, "exitLabel", 0); + _isGenericTask = _method.ReturnType.ConstructedInfo != null; + _exprRetValue = _isGenericTask + ? CodeBuilder.DeclareTempLocal(_moveNext.Method, _methodToStateMachineMapper.MapType(_asyncMethodBuilderMemberCollection.ResultType)) + : null; + + _cachedState = CodeBuilder.DeclareLocal(methodBuilder.Method, UniqueName("state"), + TypeSystemServices.IntType); + + _seenAwait = false; + var rewrittenBody = (Block)Visit(_method.Method.Body); + if (!_seenAwait) + Context.Warnings.Add(CompilerWarningFactory.AsyncNoAwait(_method.Method)); + + var bodyBuilder = methodBuilder.Body; + + bodyBuilder.Add(CodeBuilder.CreateAssignment( + CodeBuilder.CreateLocalReference(_cachedState), + CodeBuilder.CreateMemberReference(_state))); + + CheckForTryExcept(); + + Block bodyBlock; + if (_labels.Count > 0) + { + var dispatch = + CodeBuilder.CreateSwitch( + this.LexicalInfo, + CodeBuilder.CreateLocalReference(_cachedState), + _labels); + CheckTryJumps((MethodInvocationExpression)((ExpressionStatement)dispatch).Expression); + bodyBlock = new Block(dispatch, rewrittenBody); + } + else bodyBlock = rewrittenBody; + InternalLocal exceptionLocal; + bodyBuilder.Add(CodeBuilder.CreateTryExcept( + this.LexicalInfo, + bodyBlock, + new ExceptionHandler + { + Declaration = CodeBuilder.CreateDeclaration( + methodBuilder.Method, + UniqueName("exception"), + TypeSystemServices.ExceptionType, + out exceptionLocal), + Block = new Block( + CodeBuilder.CreateFieldAssignment( + this.LexicalInfo, + _state, + CodeBuilder.CreateIntegerLiteral(StateMachineStates.FinishedStateMachine)), + new ExpressionStatement( + CodeBuilder.CreateMethodInvocation( + CodeBuilder.CreateMemberReference( + CodeBuilder.CreateSelfReference(_stateMachineClass.Entity), + (IField)_asyncMethodBuilderField.Entity), + _asyncMethodBuilderMemberCollection.SetException, + CodeBuilder.CreateLocalReference(exceptionLocal))), + GenerateReturn()) + })); + + // ReturnLabel (for the rewritten return expressions in the user's method body) + bodyBuilder.Add(_exprReturnLabel.LabelStatement); + + // this.state = finishedState + bodyBuilder.Add(CodeBuilder.CreateFieldAssignment( + this.LexicalInfo, + _state, + CodeBuilder.CreateIntegerLiteral(StateMachineStates.FinishedStateMachine))); + + // builder.SetResult([RetVal]) + var setResultInvocation = CodeBuilder.CreateMethodInvocation( + CodeBuilder.CreateMemberReference( + CodeBuilder.CreateSelfReference(_stateMachineClass.Entity), + (IField)_asyncMethodBuilderField.Entity), + _asyncMethodBuilderMemberCollection.SetResult); + if (_isGenericTask) + setResultInvocation.Arguments.Add(CodeBuilder.CreateLocalReference(_exprRetValue)); + + bodyBuilder.Add(new ExpressionStatement(setResultInvocation)); + + // this code is hidden behind a hidden sequence point. + bodyBuilder.Add(_exitLabel.LabelStatement); + bodyBuilder.Add(new ReturnStatement()); + } + + private void CheckForTryExcept() + { + var handlerBlocks = _convertedTryStatements.Where(ct => ct._handlers.Count > 0).ToArray(); + foreach (var handlerBlock in handlerBlocks) + { + var rep = handlerBlock._replacement; + var parent = rep.ParentNode; + var tryBlock = new TryStatement + { + ProtectedBlock = rep, + ExceptionHandlers = handlerBlock._handlers + }; + parent.Replace(rep, tryBlock); + } + } + + private void CheckTryJumps(MethodInvocationExpression dispatch) + { + var stateArg = dispatch.Arguments[0]; + IEnumerable> labels; + do + { + labels = dispatch.Arguments + .Skip(1) + .Select(arg => (InternalLabel) arg.Entity) + .Where(l => l.LabelStatement.GetAncestor() != null) + .GroupBy(l => l.LabelStatement.GetAncestor()); + foreach (var labelGroup in labels) + { + var newLabel = CodeBuilder.CreateLabel(dispatch, UniqueName("TryLabel")); + var parent = (Block) labelGroup.Key.ParentNode; + parent.Insert(parent.Statements.IndexOf(labelGroup.Key), newLabel.LabelStatement); + AstAnnotations.SetTryBlockDepth(newLabel.LabelStatement, parent.GetAncestors().Count()); + IfStatement innerDispatch = null; + foreach (var label in labelGroup) + { + var depth = label.LabelStatement.GetAncestors().Count(); + var switchArg = dispatch.Arguments.First(arg => arg.Entity == label); + switchArg.Entity = newLabel; + innerDispatch = new IfStatement( + CodeBuilder.CreateBoundBinaryExpression( + TypeSystemServices.BoolType, + BinaryOperatorType.Equality, + stateArg.CloneNode(), + new IntegerLiteralExpression(dispatch.Arguments.IndexOf(switchArg) - 1)), + new Block(CodeBuilder.CreateGoto(label, depth)), + innerDispatch != null ? new Block(innerDispatch) : null); + labelGroup.Key.ProtectedBlock.Insert(0, innerDispatch); + } + } + } while (labels.Any()); + } + + private Statement GenerateReturn() + { + return CodeBuilder.CreateGoto(_exitLabel, 1); + } + + #region Visitors + + public override void OnExpressionStatement(ExpressionStatement node) + { + if (node.Expression.NodeType == NodeType.AwaitExpression) + { + ReplaceCurrentNode(VisitAwaitExpression((AwaitExpression)node.Expression, null)); + return; + } + + if (node.Expression.NodeType == NodeType.BinaryExpression + && ((BinaryExpression)node.Expression).Operator == BinaryOperatorType.Assign) + { + var expression = (BinaryExpression)node.Expression; + if (expression.Right.NodeType == NodeType.AwaitExpression) + { + ReplaceCurrentNode(VisitAwaitExpression((AwaitExpression)expression.Right, expression.Left)); + return; + } + } + base.OnExpressionStatement(node); + } + + public override void OnAwaitExpression(AwaitExpression node) + { + // await expressions must, by now, have been moved to the top level. + throw new ArgumentException("Should be unreachable"); + } + + private Block VisitAwaitExpression(AwaitExpression node, Expression resultPlace) + { + _seenAwait = true; + var expression = Visit(node.BaseExpression); + resultPlace = Visit(resultPlace); + var getAwaiter = (IMethod)node["$GetAwaiter"]; + var getResult = (IMethod)node["$GetResult"]; + if (getAwaiter == null) + { + var resolveList = new List(); + if (expression.ExpressionType.Resolve(resolveList, "GetAwaiter", EntityType.Method)) + getAwaiter = resolveList.Cast().First(m => m.GetParameters().Length == 0 && m.IsPublic); + else + throw CompilerErrorFactory.MissingGetAwaiter(expression); + getResult = getAwaiter.ReturnType.GetMembers().OfType().Single(m => m.Name.Equals("GetResult")); + } + Debug.Assert(getAwaiter != null && getResult != null); + var isCompletedProp = getAwaiter.ReturnType.GetMembers().OfType().SingleOrDefault(p => p.Name.Equals("IsCompleted")); + if (isCompletedProp == null) + { + var resolveList = new List(); + if (getAwaiter.ReturnType.Resolve(resolveList, "IsCompleted", EntityType.Property)) + isCompletedProp = resolveList.Cast().First(p => p.GetParameters().Length == 0 && p.IsPublic); + if (isCompletedProp == null) + throw new ArgumentException("No valid IsCompleted property found"); + } + var isCompletedMethod = isCompletedProp.GetGetMethod(); + IType type; + if (IsCustomTaskType(expression.ExpressionType)) + type = getResult.ReturnType; + else type = expression.ExpressionType.ConstructedInfo == null + ? TypeSystemServices.VoidType + : expression.ExpressionType.ConstructedInfo.GenericArguments[0]; + + // The awaiter temp facilitates EnC method remapping and thus have to be long-lived. + // It transfers the awaiter objects from the old version of the MoveNext method to the new one. + var awaiterType = getAwaiter.ReturnType; + var awaiterTemp = CodeBuilder.DeclareTempLocal(_moveNext.Method, awaiterType); + + var getAwaiterInvocation = getAwaiter.IsExtension ? + CodeBuilder.CreateMethodInvocation(getAwaiter, expression) : + CodeBuilder.CreateMethodInvocation(expression, getAwaiter); + var awaitIfIncomplete = new Block( + // temp $awaiterTemp = .GetAwaiter(); + new ExpressionStatement( + CodeBuilder.CreateAssignment( + CodeBuilder.CreateLocalReference(awaiterTemp), + getAwaiterInvocation)), + + // if(!($awaiterTemp.IsCompleted)) { ... } + new IfStatement( + new UnaryExpression( + UnaryOperatorType.LogicalNot, + GenerateGetIsCompleted(awaiterTemp, isCompletedMethod)), + GenerateAwaitForIncompleteTask(awaiterTemp), + null)); + + TryStatementInfo currentTry = _tryStatementStack.Count > 0 ? _tryStatementStack.Peek() : null; + if (currentTry != null) + ConvertTryStatement(currentTry); + + var getResultCall = CodeBuilder.CreateMethodInvocation( + CodeBuilder.CreateLocalReference(awaiterTemp), + getResult); + + var nullAwaiter = CodeBuilder.CreateAssignment( + CodeBuilder.CreateLocalReference(awaiterTemp), + CodeBuilder.CreateDefaultInvocation(this.LexicalInfo, awaiterTemp.Type)); + if (resultPlace != null && type != TypeSystemServices.VoidType) + { + // $resultTemp = $awaiterTemp.GetResult(); + // $awaiterTemp = null; + // $resultTemp + InternalLocal resultTemp = CodeBuilder.DeclareTempLocal(_moveNext.Method, type); + return new Block( + awaitIfIncomplete, + new ExpressionStatement( + CodeBuilder.CreateAssignment(CodeBuilder.CreateLocalReference(resultTemp), getResultCall)), + new ExpressionStatement(nullAwaiter), + new ExpressionStatement( + CodeBuilder.CreateAssignment(resultPlace, CodeBuilder.CreateLocalReference(resultTemp)))); + } + + // $awaiterTemp.GetResult(); + // $awaiterTemp = null; + return new Block( + awaitIfIncomplete, + new ExpressionStatement(getResultCall), + new ExpressionStatement(nullAwaiter)); + } + + private bool IsCustomTaskType(IType type) + { + if (type == TypeSystemServices.VoidType || type == TypeSystemServices.TaskType) + return false; + if (type.ConstructedInfo != null && type.ConstructedInfo.GenericDefinition == TypeSystemServices.GenericTaskType) + return false; + return true; + } + + private Expression GenerateGetIsCompleted(InternalLocal awaiterTemp, IMethod getIsCompletedMethod) + { + return CodeBuilder.CreateMethodInvocation(CodeBuilder.CreateLocalReference(awaiterTemp), getIsCompletedMethod); + } + + private Block GenerateAwaitForIncompleteTask(InternalLocal awaiterTemp) + { + var stateNumber = _labels.Count; + var resumeLabel = CreateLabel(awaiterTemp.Node); + + IType awaiterFieldType = awaiterTemp.Type.IsVerifierReference() + ? TypeSystemServices.ObjectType + : awaiterTemp.Type; + + Field awaiterField = GetAwaiterField(awaiterFieldType); + + var blockBuilder = new Block(); + + // this.state = _cachedState = stateForLabel + blockBuilder.Add(new ExpressionStatement(SetStateTo(stateNumber))); + + blockBuilder.Add( + // this.<>t__awaiter = $awaiterTemp + CodeBuilder.CreateFieldAssignment( + awaiterField, + awaiterField.Type == awaiterTemp.Type + ? CodeBuilder.CreateLocalReference(awaiterTemp) + : CodeBuilder.CreateCast(awaiterFieldType, CodeBuilder.CreateLocalReference(awaiterTemp)))); + + blockBuilder.Add(GenerateAwaitOnCompleted(awaiterTemp.Type, awaiterTemp)); + + blockBuilder.Add(GenerateReturn()); + + blockBuilder.Add(resumeLabel); + AstAnnotations.SetTryBlockDepth(resumeLabel, blockBuilder.GetAncestors().Count()); + + var awaiterFieldRef = CodeBuilder.CreateMemberReference( + CodeBuilder.CreateSelfReference(_stateMachineClass.Entity), + (IField)awaiterField.Entity); + blockBuilder.Add( + // $awaiterTemp = this.<>t__awaiter or $awaiterTemp = (AwaiterType)this.<>t__awaiter + // $this.<>t__awaiter = null; + CodeBuilder.CreateAssignment( + CodeBuilder.CreateLocalReference(awaiterTemp), + awaiterTemp.Type == awaiterField.Type + ? awaiterFieldRef + : CodeBuilder.CreateCast(awaiterTemp.Type, awaiterFieldRef))); + + blockBuilder.Add( + CodeBuilder.CreateFieldAssignment( + awaiterField, + CodeBuilder.CreateDefaultInvocation(LexicalInfo.Empty, ((ITypedEntity)awaiterField.Entity).Type))); + + // this.state = _cachedState = NotStartedStateMachine + blockBuilder.Add(new ExpressionStatement(SetStateTo(StateMachineStates.NotStartedStateMachine))); + + return blockBuilder; + } + + private readonly IType ICriticalNotifyCompletionType = + My.Instance.Map(typeof(ICriticalNotifyCompletion)); + + private Statement GenerateAwaitOnCompleted(IType loweredAwaiterType, InternalLocal awaiterTemp) + { + // this.builder.AwaitOnCompleted(ref $awaiterTemp, ref this) + // or + // this.builder.AwaitOnCompleted(ref $awaiterArrayTemp[0], ref this) + var localEntity = MapNestedType(_stateMachineClass.Entity); + + InternalLocal selfTemp = _stateMachineClass.Entity.IsValueType ? null : CodeBuilder.DeclareTempLocal(_moveNext.Method, _stateMachineClass.Entity); + + var useUnsafeOnCompleted = loweredAwaiterType.IsAssignableFrom(ICriticalNotifyCompletionType); + + var onCompleted = (useUnsafeOnCompleted ? + _asyncMethodBuilderMemberCollection.AwaitUnsafeOnCompleted : + _asyncMethodBuilderMemberCollection.AwaitOnCompleted) + .GenericInfo.ConstructMethod(loweredAwaiterType, localEntity); + + var result = + CodeBuilder.CreateMethodInvocation( + CodeBuilder.CreateMemberReference( + CodeBuilder.CreateSelfReference(localEntity), + (IField)_asyncMethodBuilderField.Entity), + onCompleted, + CodeBuilder.CreateLocalReference(awaiterTemp), + selfTemp != null ? + CodeBuilder.CreateLocalReference(selfTemp) : + (Expression)CodeBuilder.CreateSelfReference(localEntity)); + + if (selfTemp != null) + { + result = CodeBuilder.CreateEvalInvocation( + LexicalInfo.Empty, + CodeBuilder.CreateAssignment( + CodeBuilder.CreateLocalReference(selfTemp), + CodeBuilder.CreateSelfReference(localEntity)), + result); + } + + return new ExpressionStatement(result); + } + + private static IType MapNestedType(IType entity) + { + if (entity.DeclaringEntity == null || entity.DeclaringEntity.EntityType != EntityType.Type) + return entity; + + var parentType = MapNestedType((IType) entity.DeclaringEntity); + GenericConstructedType constructedParent = null; + if (parentType.GenericInfo != null && parentType.ConstructedInfo == null) + constructedParent = (GenericConstructedType)parentType.GenericInfo.ConstructType(parentType.GenericInfo.GenericParameters); + constructedParent = constructedParent ?? parentType as GenericConstructedType; + IType result = constructedParent == null || entity.DeclaringEntity == constructedParent + ? entity + : GenericMappedType.Create(entity, constructedParent); + + if (result.GenericInfo != null && result.ConstructedInfo == null) + result = result.GenericInfo.ConstructType(result.GenericInfo.GenericParameters); + return result; + } + + public override void OnReturnStatement(ReturnStatement node) + { + Statement result = CodeBuilder.CreateGoto(_exprReturnLabel, _tryStatementStack.Count + 1); + if (node.Expression != null) + { + Debug.Assert(_isGenericTask || node.Expression.ExpressionType == TypeSystemServices.TaskType); + if (_exprRetValue == null) + _exprRetValue = CodeBuilder.DeclareTempLocal(_moveNext.Method, TypeSystemServices.TaskType); + + result = new Block( + new ExpressionStatement( + CodeBuilder.CreateAssignment( + CodeBuilder.CreateLocalReference(_exprRetValue), + Visit(node.Expression))), + result); + } + ReplaceCurrentNode(result); + } + + #endregion Visitors + + #region AbstractMemberImplementation + + protected override void PropagateReferences() + { + var ctor = _stateMachineConstructor; + // propagate the necessary parameters from the original method to the state machine + foreach (var parameter in _method.Method.Parameters) + { + var myParam = MapParamType(parameter); + + var entity = (InternalParameter)myParam.Entity; + if (entity.IsUsed) + { + _stateMachineConstructorInvocation.Arguments.Add(CodeBuilder.CreateReference(myParam)); + } + } + // propagate the external self reference if necessary + if (_externalSelfField != null) + { + var type = _method.DeclaringType; + _stateMachineConstructorInvocation.Arguments.Add( + CodeBuilder.CreateSelfReference(TypeSystemServices.SelfMapGenericType(type))); + } + } + + protected override BinaryExpression SetStateTo(int num) + { + // this.state = _cachedState = NotStartedStateMachine + return (BinaryExpression) CodeBuilder.CreateFieldAssignmentExpression( + _state, + CodeBuilder.CreateAssignment( + CodeBuilder.CreateLocalReference(_cachedState), + CodeBuilder.CreateIntegerLiteral(num))); + } + + private string _className; + + protected override string StateMachineClassName + { + get { return _className ?? (_className = Context.GetUniqueName("Async")); } + } + + protected override void SaveStateMachineClass(ClassDefinition cd) + { + _method.Method.DeclaringType.Members.Add(cd); + } + + protected override void SetupStateMachine() + { + _stateMachineClass.AddBaseType(TypeSystemServices.ValueTypeType); + _stateMachineClass.AddBaseType(TypeSystemServices.IAsyncStateMachineType); + _stateMachineClass.Modifiers |= TypeMemberModifiers.Final; + var ctr = TypeSystemServices.Map(typeof(AsyncStateMachineAttribute)).GetConstructors().Single(); + _method.Method.Attributes.Add( + CodeBuilder.CreateAttribute( + ctr, + CodeBuilder.CreateTypeofExpression(_stateMachineClass.Entity))); + AsyncMethodBuilderMemberCollection.TryCreate( + TypeSystemServices, + _method.Method, + MethodGenArg(true), + out _asyncMethodBuilderMemberCollection); + + _state = (IField)_stateMachineClass.AddInternalField(UniqueName("State"), TypeSystemServices.IntType).Entity; + _asyncMethodBuilderField = _stateMachineClass.AddInternalField( + UniqueName("Builder"), + _asyncMethodBuilderMemberCollection.BuilderType); + CreateSetStateMachine(); + PreprocessMethod(); + } + + private IType MethodGenArg(bool remap) + { + var rt = _method.ReturnType; + if (rt.ConstructedInfo != null) + { + var result = rt.ConstructedInfo.GenericArguments[0]; + return remap ? _methodToStateMachineMapper.MapType(result) : result; + } + if (rt.GenericInfo != null) + { + return rt.GenericInfo.GenericParameters[0]; + } + return null; + } + + private void PreprocessMethod() + { + if (ContextAnnotations.AwaitInExceptionHandler(_method.Method)) + { + AsyncExceptionHandlerRewriter.Rewrite(_method.Method); + } + AwaitExpressionSpiller.Rewrite(_method.Method); + } + + private void CreateSetStateMachine() + { + var method = _stateMachineClass.AddMethod("SetStateMachine", TypeSystemServices.VoidType); + var stateMachineIntfType = TypeSystemServices.IAsyncStateMachineType; + var input = method.AddParameter("stateMachine", stateMachineIntfType, false); + method.Modifiers |= TypeMemberModifiers.Virtual | TypeMemberModifiers.Final; + method.Method.ExplicitInfo = new ExplicitMemberInfo + { + InterfaceType = (SimpleTypeReference)CodeBuilder.CreateTypeReference(stateMachineIntfType), + Entity = TypeSystemServices.IAsyncStateMachineType.GetMembers().OfType() + .Single(m => m.Name.Equals("SetStateMachine")) + }; + method.Body.Add( + CodeBuilder.CreateMethodInvocation( + CodeBuilder.CreateMemberReference( + CodeBuilder.CreateSelfReference(_stateMachineClass.Entity), + (IField) _asyncMethodBuilderField.Entity), + _asyncMethodBuilderMemberCollection.SetStateMachine, + CodeBuilder.CreateReference(input))); + } + + protected override BooMethodBuilder CreateConstructor(BooClassBuilder builder) + { + BooMethodBuilder constructor = builder.AddConstructor(); + return constructor; + } + + #endregion + + } +} + diff --git a/src/Boo.Lang.Compiler/Steps/AsyncAwait/AsyncTypeHelper.cs b/src/Boo.Lang.Compiler/Steps/AsyncAwait/AsyncTypeHelper.cs new file mode 100644 index 000000000..4f14fcf3e --- /dev/null +++ b/src/Boo.Lang.Compiler/Steps/AsyncAwait/AsyncTypeHelper.cs @@ -0,0 +1,34 @@ +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Boo.Lang.Compiler.TypeSystem; +using Boo.Lang.Environments; + +namespace Boo.Lang.Compiler.Steps.AsyncAwait +{ + internal static class AsyncTypeHelper + { + private static readonly IType _typeReferenceType; + private static readonly IType _argIteratorType; + private static readonly IType _runtimeArgumentHandleType; + + static AsyncTypeHelper() + { + var tss = My.Instance; + _typeReferenceType = tss.Map(typeof(System.TypedReference)); + _argIteratorType = tss.Map(typeof(System.ArgIterator)); + _runtimeArgumentHandleType = tss.Map(typeof(System.RuntimeArgumentHandle)); + } + + public static bool IsRestrictedType(this IType type) + { + return type == _typeReferenceType || type == _argIteratorType || type == _runtimeArgumentHandleType; + } + + internal static bool IsVerifierReference(this IType type) + { + return !type.IsValueType && type.EntityType != EntityType.GenericParameter; + } + } +} diff --git a/src/Boo.Lang.Compiler/Steps/AsyncAwait/AwaitExpressionSpiller.cs b/src/Boo.Lang.Compiler/Steps/AsyncAwait/AwaitExpressionSpiller.cs new file mode 100644 index 000000000..1d1fd7384 --- /dev/null +++ b/src/Boo.Lang.Compiler/Steps/AsyncAwait/AwaitExpressionSpiller.cs @@ -0,0 +1,853 @@ +using System; +using System.Diagnostics; +using System.Linq; +using System.Text; +using Boo.Lang.Compiler.Ast; +using Boo.Lang.Compiler.Ast.Visitors; +using Boo.Lang.Compiler.TypeSystem; +using Boo.Lang.Compiler.TypeSystem.Internal; +using Boo.Lang.Environments; + +namespace Boo.Lang.Compiler.Steps.AsyncAwait +{ + using System.Collections.Generic; + + internal sealed class AwaitExpressionSpiller : DepthFirstTransformer + { + private const NodeType SpillSequenceBuilder = NodeType.CustomExpression; // NOTE: this node type is hijacked during this phase to represent BoundSpillSequenceBuilder + + private readonly BooCodeBuilder _F; + private readonly Method _currentMethod; + private readonly TypeSystemServices _tss; + + private AwaitExpressionSpiller(Method method, TypeSystemServices tss) + { + _F = My.Instance; + _currentMethod = method; + _tss = tss; + } + + private sealed class BoundSpillSequenceBuilder : CustomExpression + { + public readonly Expression Value; + + private List _locals; + private List _statements; + + public BoundSpillSequenceBuilder(Expression value = null) + { + Debug.Assert(value == null || value.NodeType != SpillSequenceBuilder); + Value = value; + ExpressionType = value == null ? null : value.ExpressionType; + } + + public bool HasStatements + { + get + { + return _statements != null; + } + } + + public bool HasLocals + { + get + { + return _locals != null; + } + } + + public override void Accept(IAstVisitor visitor) + { + throw new InvalidOperationException("Should be unreachable"); + } + + public List GetStatements() + { + return _statements; + } + + internal BoundSpillSequenceBuilder Update(Expression value) + { + var result = new BoundSpillSequenceBuilder(value) + { + _locals = _locals, + _statements = _statements + }; + return result; + } + + internal void Include(BoundSpillSequenceBuilder other) + { + if (other != null) + { + Extend(ref _locals, ref other._locals); + Extend(ref _statements, ref other._statements); + } + } + + private static void Extend(ref List left, ref List right) + { + if (right == null) + { + return; + } + + if (left == null) + { + left = right; + return; + } + + left.AddRange(right.Except(left)); + } + + public void AddLocal(InternalLocal local) + { + if (_locals == null) + { + _locals = new List(); + } + + if (local.Type.IsRestrictedType()) + { + CompilerContext.Current.Errors.Add( + CompilerErrorFactory.RestrictedAwaitType(local.Local, local.Type)); + } + + _locals.Add(local); + } + + public void AddStatement(Statement statement) + { + if (_statements == null) + { + _statements = new List(); + } + + _statements.Add(statement); + } + + internal void AddExpressions(IEnumerable expressions) + { + var existingHash = new HashSet(_statements.OfType().Select(es => es.Expression)); + foreach (var expression in expressions) + { + if (expression.NodeType == SpillSequenceBuilder) + { + var sb = (BoundSpillSequenceBuilder) expression; + Include(sb); + if (!existingHash.Contains(sb.Value)) + AddStatement(new ExpressionStatement(sb.LexicalInfo, sb.Value) { IsSynthetic = true }); + } + else if (!existingHash.Contains(expression)) + AddStatement(new ExpressionStatement(expression.LexicalInfo, expression) { IsSynthetic = true }); + } + } + } + + internal static Statement Rewrite(Method method) + { + var spiller = new AwaitExpressionSpiller(method, My.Instance); + var result = spiller.Visit(method.Body); + return result; + } + + private Expression VisitExpression(ref BoundSpillSequenceBuilder builder, Expression expression) + { + // wrap the node in a spill sequence to mark the fact that it must be moved up the tree. + // The caller will handle this node type if the result is discarded. + if (expression != null && expression.NodeType == NodeType.AwaitExpression) + { + // we force the await expression to be assigned to a temp variable + var awaitExpression = (AwaitExpression)expression; + awaitExpression.BaseExpression = VisitExpression(ref builder, awaitExpression.BaseExpression); + + var local = _F.DeclareTempLocal(_currentMethod, awaitExpression.ExpressionType); + var replacement = _F.CreateAssignment( + awaitExpression.LexicalInfo, + _F.CreateLocalReference(local), + awaitExpression); + if (builder == null) + { + builder = new BoundSpillSequenceBuilder(); + } + + builder.AddLocal(local); + builder.AddStatement(new ExpressionStatement(replacement)); + return _F.CreateLocalReference(local); + } + + var e = Visit(expression); + if (e == null || e.NodeType != SpillSequenceBuilder) + { + return e; + } + + var newBuilder = (BoundSpillSequenceBuilder)e; + if (builder == null) + { + builder = newBuilder.Update(null); + } + else + { + builder.Include(newBuilder); + } + + return newBuilder.Value; + } + + private static Expression UpdateExpression(BoundSpillSequenceBuilder builder, Expression expression) + { + if (builder == null) + { + return expression; + } + + Debug.Assert(builder.Value == null); + if (!builder.HasLocals && !builder.HasStatements) + { + return expression; + } + + return builder.Update(expression); + } + + private static void UpdateConditionalStatement(ConditionalStatement stmt) + { + var builder = (BoundSpillSequenceBuilder) stmt.Condition; + Debug.Assert(stmt.ParentNode.NodeType == NodeType.Block); + Debug.Assert(builder.Value.ExpressionType == My.Instance.BoolType); + var spills = new Block(builder.GetStatements().ToArray()); + var statements = ((Block)stmt.ParentNode).Statements; + statements.Insert(statements.IndexOf(stmt), spills); + stmt.Condition = builder.Value; + } + + private static Statement UpdateStatement(BoundSpillSequenceBuilder builder, Statement statement) + { + if (builder == null) + { + // statement doesn't contain any await + Debug.Assert(statement != null); + return statement; + } + + Debug.Assert(builder.Value == null); + if (statement != null) + { + builder.AddStatement(statement); + } + + var result = new Block(builder.GetStatements().ToArray()); + + return result; + } + + private readonly object AWAIT_SPILL_MARKER = new object(); + + private Expression Spill( + BoundSpillSequenceBuilder builder, + Expression expression, + bool isRef = false, + bool sideEffectsOnly = false) + { + Debug.Assert(builder != null); + + while (true) + { + switch (expression.NodeType) + { + case NodeType.ListLiteralExpression: + case NodeType.ArrayLiteralExpression: + Debug.Assert(!isRef); + Debug.Assert(!sideEffectsOnly); + var arrayInitialization = (ListLiteralExpression)expression; + var newInitializers = VisitExpressionList(ref builder, arrayInitialization.Items, forceSpill: true); + arrayInitialization.Items = newInitializers; + return arrayInitialization; + + case NodeType.HashLiteralExpression: + Debug.Assert(!isRef); + Debug.Assert(!sideEffectsOnly); + var hashInitialization = (HashLiteralExpression)expression; + var newInitializerPairs = VisitExpressionPairList(ref builder, hashInitialization.Items, forceSpill: true); + hashInitialization.Items = newInitializerPairs; + return hashInitialization; + + case SpillSequenceBuilder: + var sequenceBuilder = (BoundSpillSequenceBuilder)expression; + builder.Include(sequenceBuilder); + expression = sequenceBuilder.Value; + continue; + + case NodeType.SelfLiteralExpression: + case NodeType.SuperLiteralExpression: + if (isRef || !expression.ExpressionType.IsValueType) + { + return expression; + } + goto default; + + case NodeType.ParameterDeclaration: + if (isRef) + { + return expression; + } + goto default; + + case NodeType.ReferenceExpression: + var local = expression.Entity as InternalLocal; + if (local != null) + { + if (local.Local["SynthesizedKind"] == AWAIT_SPILL_MARKER || isRef) + { + return expression; + } + } + goto default; + + case NodeType.MemberReferenceExpression: + if (expression.Entity.EntityType == EntityType.Field) + { + var field = (IField)expression.Entity; + if (field.IsInitOnly) + { + if (field.IsStatic) return expression; + if (!field.DeclaringType.IsValueType) + { + // save the receiver; can get the field later. + var target = Spill(builder, + ((MemberReferenceExpression) expression).Target, + isRef && !field.Type.IsValueType, + sideEffectsOnly); + return _F.CreateMemberReference(target, field); + } + } + } + goto default; + + case NodeType.MethodInvocationExpression: + var mie = (MethodInvocationExpression) expression; + if (expression.Entity == BuiltinFunction.Eval) + { + builder.AddExpressions(mie.Arguments.Where(a => a != mie.Arguments.Last)); + expression = mie.Arguments.Last; + continue; + } + if (isRef) + { + Debug.Assert(mie.ExpressionType.IsPointer); + CompilerContext.Current.Errors.Add(CompilerErrorFactory.UnsafeReturnInAsync(mie)); + } + goto default; + + case NodeType.BoolLiteralExpression: + case NodeType.CharLiteralExpression: + case NodeType.DoubleLiteralExpression: + case NodeType.IntegerLiteralExpression: + case NodeType.NullLiteralExpression: + case NodeType.RELiteralExpression: + case NodeType.StringLiteralExpression: + case NodeType.TypeofExpression: + return expression; + + default: + if (expression.ExpressionType == _tss.VoidType || sideEffectsOnly) + { + builder.AddStatement(new ExpressionStatement(expression)); + return null; + } + var replacement = _F.DeclareTempLocal(_currentMethod, expression.ExpressionType); + + var assignToTemp = _F.CreateAssignment( + _F.CreateLocalReference(replacement), + expression); + + builder.AddLocal(replacement); + builder.AddStatement(new ExpressionStatement(assignToTemp)); + return _F.CreateLocalReference(replacement); + } + } + } + + private ExpressionCollection VisitExpressionList( + ref BoundSpillSequenceBuilder builder, + ExpressionCollection args, + IList refKinds = default(IList), + bool forceSpill = false, + bool sideEffectsOnly = false) + { + Visit(args); + + int lastSpill; + if (forceSpill) + { + lastSpill = args.Count - 1; + } + else + { + lastSpill = -1; + for (int i = args.Count - 1; i >= 0; i--) + { + if (args[i].NodeType == SpillSequenceBuilder) + { + lastSpill = i; + break; + } + } + } + + if (lastSpill == -1) + { + return args; + } + + if (builder == null) + { + builder = new BoundSpillSequenceBuilder(); + } + + for (int i = 0; i <= lastSpill; i++) + { + var refKind = refKinds != null && refKinds.Count > i && refKinds[i]; + var replacement = Spill(builder, args[i], refKind, sideEffectsOnly); + + Debug.Assert(sideEffectsOnly || replacement != null); + if (!sideEffectsOnly) + { + args[i] = replacement; + } + } + + return args; + } + + private ExpressionPairCollection VisitExpressionPairList( + ref BoundSpillSequenceBuilder builder, + ExpressionPairCollection args, + bool forceSpill = false, + bool sideEffectsOnly = false) + { + var args1 = new ExpressionCollection(); + args1.AddRange(args.Select(p => p.First)); + var args2 = new ExpressionCollection(); + args2.AddRange(args.Select(p => p.Second)); + args1 = VisitExpressionList(ref builder, args1, null, forceSpill, sideEffectsOnly); + args2 = VisitExpressionList(ref builder, args2, null, forceSpill, sideEffectsOnly); + args.Clear(); + args.AddRange(args1.Zip(args2, (l, r) => new ExpressionPair(l.LexicalInfo, l, r))); + return args; + } + + private SliceCollection VisitSliceCollection( + ref BoundSpillSequenceBuilder builder, + SliceCollection args) + { + foreach (var arg in args) + { + if (arg.Begin != null) + VisitExpression(ref builder, arg.Begin); + if (arg.End != null) + VisitExpression(ref builder, arg.End); + if (arg.Step != null) + VisitExpression(ref builder, arg.Step); + } + return args; + } + + #region Statement Visitors + + public override void OnRaiseStatement(RaiseStatement node) + { + BoundSpillSequenceBuilder builder = null; + node.Exception = VisitExpression(ref builder, node.Exception); + ReplaceCurrentNode(UpdateStatement(builder, node)); + } + + public override void OnExpressionStatement(ExpressionStatement node) + { + BoundSpillSequenceBuilder builder = null; + Expression expr; + + if (node.Expression.NodeType == NodeType.AwaitExpression) + { + // await expression with result discarded + var awaitExpression = (AwaitExpression)node.Expression; + awaitExpression.BaseExpression = VisitExpression(ref builder, awaitExpression.BaseExpression); + expr = awaitExpression; + } + else if (node.Expression.NodeType == NodeType.MethodInvocationExpression && + ((MethodInvocationExpression)node.Expression).Target.Entity == BuiltinFunction.Switch) + { + OnSwitch((MethodInvocationExpression)node.Expression, ref builder); + expr = node.Expression; + } + else + { + expr = VisitExpression(ref builder, node.Expression); + } + + Debug.Assert(expr != null); + Debug.Assert(builder == null || builder.Value == null); + node.Expression = expr; + ReplaceCurrentNode(UpdateStatement(builder, node)); + } + + public override void OnReturnStatement(ReturnStatement node) + { + BoundSpillSequenceBuilder builder = null; + node.Expression = VisitExpression(ref builder, node.Expression); + ReplaceCurrentNode(UpdateStatement(builder, node)); + } + + public override void OnIfStatement(IfStatement node) + { + base.OnIfStatement(node); + if (node.Condition.NodeType == SpillSequenceBuilder) + UpdateConditionalStatement(node); + } + + #endregion + + #region Expression Visitors + + public override void OnAwaitExpression(AwaitExpression node) + { + var builder = new BoundSpillSequenceBuilder(); + var replacement = VisitExpression(ref builder, node); + ReplaceCurrentNode(builder.Update(replacement)); + } + + public override void OnSlicingExpression(SlicingExpression node) + { + BoundSpillSequenceBuilder builder = null; + var target = VisitExpression(ref builder, node.Target); + + BoundSpillSequenceBuilder indicesBuilder = null; + var indices = VisitSliceCollection(ref indicesBuilder, node.Indices); + + if (indicesBuilder != null) + { + // spill the array if there were await expressions in the indices + if (builder == null) + { + builder = new BoundSpillSequenceBuilder(); + } + + target = Spill(builder, target); + } + + if (builder != null) + { + builder.Include(indicesBuilder); + indicesBuilder = builder; + } + node.Target = target; + node.Indices = indices; + ReplaceCurrentNode(UpdateExpression(indicesBuilder, node)); + } + + public override void OnListLiteralExpression(ListLiteralExpression node) + { + BoundSpillSequenceBuilder builder = null; + node.Items = VisitExpressionList(ref builder, node.Items); + ReplaceCurrentNode(UpdateExpression(builder, node)); + } + + public override void OnArrayLiteralExpression(ArrayLiteralExpression node) + { + OnListLiteralExpression(node); + } + + public override void OnHashLiteralExpression(HashLiteralExpression node) + { + + BoundSpillSequenceBuilder builder = null; + node.Items = VisitExpressionPairList(ref builder, node.Items); + ReplaceCurrentNode(UpdateExpression(builder, node)); + } + + public override void OnTryCastExpression(TryCastExpression node) + { + BoundSpillSequenceBuilder builder = null; + node.Target = VisitExpression(ref builder, node.Target); + ReplaceCurrentNode(UpdateExpression(builder, node)); + } + + private void OnAssignment(BinaryExpression node) + { + BoundSpillSequenceBuilder builder = null; + var right = VisitExpression(ref builder, node.Right); + Expression left; + if (builder == null || node.Left.Entity.EntityType == EntityType.Local) + { + left = VisitExpression(ref builder, node.Left); + } + else + { + // if the right-hand-side has await, spill the left + var leftBuilder = new BoundSpillSequenceBuilder(); + left = VisitExpression(ref leftBuilder, node.Left); + if (left.Entity.EntityType == EntityType.Local) + { + left = Spill(leftBuilder, left, true); + } + + leftBuilder.Include(builder); + builder = leftBuilder; + } + node.Left = left; + node.Right = right; + ReplaceCurrentNode(UpdateExpression(builder, node)); + } + + private void OnIsaOperator(BinaryExpression node) + { + BoundSpillSequenceBuilder builder = null; + node.Left = VisitExpression(ref builder, node.Left); + ReplaceCurrentNode(UpdateExpression(builder, node)); + } + + public override void OnBinaryExpression(BinaryExpression node) + { + if (node.Operator == BinaryOperatorType.Assign) + { + OnAssignment(node); + return; + } + if (node.Operator == BinaryOperatorType.TypeTest) + { + OnIsaOperator(node); + return; + } + + BoundSpillSequenceBuilder builder = null; + var right = VisitExpression(ref builder, node.Right); + Expression left; + if (builder == null) + { + left = VisitExpression(ref builder, node.Left); + } + else + { + var leftBuilder = new BoundSpillSequenceBuilder(); + left = VisitExpression(ref leftBuilder, node.Left); + left = Spill(leftBuilder, left); + if (node.Operator == BinaryOperatorType.Or || node.Operator == BinaryOperatorType.And) + { + var tmp = _F.DeclareTempLocal(_currentMethod, node.ExpressionType); + tmp.Local["SynthesizedKind"] = AWAIT_SPILL_MARKER; + leftBuilder.AddLocal(tmp); + leftBuilder.AddStatement(new ExpressionStatement(_F.CreateAssignment(_F.CreateLocalReference(tmp), left))); + var trueBlock = new Block(); + trueBlock.Add(UpdateExpression(builder, _F.CreateAssignment(_F.CreateLocalReference(tmp), right))); + leftBuilder.AddStatement( + new IfStatement(left.LexicalInfo, + node.Operator == BinaryOperatorType.And ? + _F.CreateLocalReference(tmp) : + (Expression)_F.CreateNotExpression(_F.CreateLocalReference(tmp)), + trueBlock, + null)); + + ReplaceCurrentNode(UpdateExpression(leftBuilder, _F.CreateLocalReference(tmp))); + return; + } + // if the right-hand-side has await, spill the left + leftBuilder.Include(builder); + builder = leftBuilder; + } + + node.Left = left; + node.Right = right; + ReplaceCurrentNode(UpdateExpression(builder, node)); + } + + private void OnEval(MethodInvocationExpression node) + { + BoundSpillSequenceBuilder valueBuilder = null; + var value = VisitExpression(ref valueBuilder, node.Arguments.Last); + + BoundSpillSequenceBuilder builder = null; + var seCollection = new ExpressionCollection(); + seCollection.AddRange(node.Arguments.Where(a => a != node.Arguments.Last)); + var sideEffects = VisitExpressionList(ref builder, seCollection, forceSpill: valueBuilder != null, sideEffectsOnly: true); + + if (builder == null && valueBuilder == null) + { + node.Arguments = sideEffects; + node.Arguments.Add(value); + return; + } + + if (builder == null) + { + builder = new BoundSpillSequenceBuilder(); + } + + builder.AddExpressions(sideEffects); + builder.Include(valueBuilder); + + ReplaceCurrentNode(builder.Update(value)); + } + + public override void OnMethodInvocationExpression(MethodInvocationExpression node) + { + BoundSpillSequenceBuilder builder = null; + var entity = node.Target.Entity; + if (entity.EntityType == EntityType.BuiltinFunction) + { + if (entity == BuiltinFunction.Eval) + { + OnEval(node); + } + else if (entity == BuiltinFunction.Switch) + { + throw new ArgumentException("Should be unreachable: Await spiller on switch"); + } + else base.OnMethodInvocationExpression(node); + return; + } + + var method = (IMethod) entity; + var refs = method.GetParameters().Select(p => p.IsByRef).ToList(); + node.Arguments = VisitExpressionList(ref builder, node.Arguments, refs); + + if (builder == null) + { + node.Target = VisitExpression(ref builder, node.Target); + } + else if (!method.IsStatic) + { + // spill the receiver if there were await expressions in the arguments + var targetBuilder = new BoundSpillSequenceBuilder(); + + var target = node.Target; + var isRef = TargetSpillRefKind(target); + + node.Target = Spill(targetBuilder, VisitExpression(ref targetBuilder, target), isRef); + targetBuilder.Include(builder); + builder = targetBuilder; + } + + ReplaceCurrentNode(UpdateExpression(builder, node)); + } + + private void OnSwitch(MethodInvocationExpression node, ref BoundSpillSequenceBuilder builder) + { + node.Arguments = VisitExpressionList(ref builder, node.Arguments); + } + + private static bool TargetSpillRefKind(Expression target) + { + if (target.ExpressionType.IsValueType) + { + switch (target.NodeType) + { + case NodeType.SlicingExpression: + case NodeType.SelfLiteralExpression: + case NodeType.SuperLiteralExpression: + return true; + + case NodeType.ReferenceExpression: + case NodeType.MemberReferenceExpression: + switch (target.Entity.EntityType) + { + case EntityType.Field: + case EntityType.Parameter: + case EntityType.Local: + return true; + default: + return false; + } + + case NodeType.UnaryExpression: + return ((UnaryExpression) target).Operator == UnaryOperatorType.Indirection; + + case NodeType.MethodInvocationExpression: + return ((IMethod)target.Entity).Type.IsPointer; + } + } + return false; + } + + public override void OnConditionalExpression(ConditionalExpression node) + { + BoundSpillSequenceBuilder conditionBuilder = null; + var condition = VisitExpression(ref conditionBuilder, node.Condition); + + BoundSpillSequenceBuilder trueBuilder = null; + var trueValue = VisitExpression(ref trueBuilder, node.TrueValue); + + BoundSpillSequenceBuilder falseBuilder = null; + var falseValue = VisitExpression(ref falseBuilder, node.FalseValue); + + if (trueBuilder == null && falseBuilder == null) + { + node.Condition = condition; + node.TrueValue = trueValue; + node.FalseValue = falseValue; + ReplaceCurrentNode(UpdateExpression(conditionBuilder, node)); + return; + } + + if (conditionBuilder == null) conditionBuilder = new BoundSpillSequenceBuilder(); + if (trueBuilder == null) trueBuilder = new BoundSpillSequenceBuilder(); + if (falseBuilder == null) falseBuilder = new BoundSpillSequenceBuilder(); + + if (node.ExpressionType == _tss.VoidType) + { + conditionBuilder.AddStatement( + new IfStatement( + condition.LexicalInfo, + condition, + new Block(UpdateStatement(trueBuilder, new ExpressionStatement(trueValue))), + new Block(UpdateStatement(falseBuilder, new ExpressionStatement(falseValue))))); + + ReplaceCurrentNode(conditionBuilder.Update(_F.CreateDefaultInvocation(node.LexicalInfo, node.ExpressionType))); + } + else + { + var tmp = _F.DeclareTempLocal(_currentMethod, node.ExpressionType); + tmp.Local["SynthesizedKind"] = AWAIT_SPILL_MARKER; + + conditionBuilder.AddLocal(tmp); + var trueBlock = new Block(new ExpressionStatement( + UpdateExpression(trueBuilder, _F.CreateAssignment(_F.CreateLocalReference(tmp), trueValue)))); + var falseBlock = new Block(new ExpressionStatement( + UpdateExpression(falseBuilder, _F.CreateAssignment(_F.CreateLocalReference(tmp), falseValue)))); + conditionBuilder.AddStatement(new IfStatement(condition, trueBlock, falseBlock)); + + ReplaceCurrentNode(conditionBuilder.Update(_F.CreateLocalReference(tmp))); + } + } + + public override void OnCastExpression(CastExpression node) + { + BoundSpillSequenceBuilder builder = null; + node.Target = VisitExpression(ref builder, node.Target); + ReplaceCurrentNode(UpdateExpression(builder, node)); + } + + public override void OnMemberReferenceExpression(MemberReferenceExpression node) + { + if (node.Entity.EntityType == EntityType.Field || node.Entity.EntityType == EntityType.Method) + { + BoundSpillSequenceBuilder builder = null; + node.Target = VisitExpression(ref builder, node.Target); + ReplaceCurrentNode(UpdateExpression(builder, node)); + return; + } + base.OnMemberReferenceExpression(node); + } + + public override void OnUnaryExpression(UnaryExpression node) + { + BoundSpillSequenceBuilder builder = null; + node.Operand = VisitExpression(ref builder, node.Operand); + ReplaceCurrentNode(UpdateExpression(builder, node)); + } + + #endregion + } +} diff --git a/src/Boo.Lang.Compiler/Steps/AsyncAwait/LabelCollector.cs b/src/Boo.Lang.Compiler/Steps/AsyncAwait/LabelCollector.cs new file mode 100644 index 000000000..947c0ac5e --- /dev/null +++ b/src/Boo.Lang.Compiler/Steps/AsyncAwait/LabelCollector.cs @@ -0,0 +1,32 @@ +using System.Collections.Generic; +using Boo.Lang.Compiler.Ast; +using Boo.Lang.Compiler.TypeSystem.Internal; + +namespace Boo.Lang.Compiler.Steps.AsyncAwait +{ + /// + /// Analyzes method body for labels. + /// + /// Adapted from Microsoft.CodeAnalysis.CSharp.IteratorMethodToStateMachineRewriter.LabelCollector + /// in the Roslyn codebase + /// + internal abstract class LabelCollector : FastDepthFirstVisitor + { + // transient accumulator. + protected HashSet _currentLabels; + + public override void OnLabelStatement(LabelStatement node) + { + if (node != null) + { + var cl = _currentLabels; + if (cl == null) + { + _currentLabels = cl = new HashSet(); + } + cl.Add((InternalLabel) node.Entity); + } + } + + } +} diff --git a/src/Boo.Lang.Compiler/Steps/AsyncAwait/LocalReferenceCollector.cs b/src/Boo.Lang.Compiler/Steps/AsyncAwait/LocalReferenceCollector.cs new file mode 100644 index 000000000..60329957f --- /dev/null +++ b/src/Boo.Lang.Compiler/Steps/AsyncAwait/LocalReferenceCollector.cs @@ -0,0 +1,40 @@ +using System.Collections.Generic; +using Boo.Lang.Compiler.Ast; +using Boo.Lang.Compiler.TypeSystem.Internal; + +namespace Boo.Lang.Compiler.Steps.AsyncAwait +{ + public class LocalReferenceCollector : FastDepthFirstVisitor + { + private readonly HashSet _locals; + private readonly HashSet _results = new HashSet(); + + public LocalReferenceCollector(IEnumerable locals) + { + _locals = new HashSet(locals); + } + + public override void OnReferenceExpression(ReferenceExpression node) + { + CheckReference(node); + } + + public override void OnMemberReferenceExpression(MemberReferenceExpression node) + { + CheckReference(node); + base.OnMemberReferenceExpression(node); + } + + private void CheckReference(ReferenceExpression node) + { + var local = node.Entity as InternalLocal; + if (local != null && _locals.Contains(local)) + _results.Add(local); + } + + public IEnumerable Result + { + get { return _results; } + } + } +} diff --git a/src/Boo.Lang.Compiler/Steps/BindTypeMembers.cs b/src/Boo.Lang.Compiler/Steps/BindTypeMembers.cs index b5f7ef3fe..fe24bd4d6 100644 --- a/src/Boo.Lang.Compiler/Steps/BindTypeMembers.cs +++ b/src/Boo.Lang.Compiler/Steps/BindTypeMembers.cs @@ -324,9 +324,11 @@ private Method CreateEventRaiseMethod(Event node, Field backingField) { modifiers |= TypeMemberModifiers.Protected | TypeMemberModifiers.Internal; } - + + var returnType = ((ICallableType)node.Type.Entity).GetSignature().ReturnType; + var method = CodeBuilder.CreateMethod("raise_" + node.Name, - TypeSystemServices.VoidType, + returnType, modifiers); var type = GetEntity(node.Type) as ICallableType; @@ -366,7 +368,9 @@ private Method CreateEventRaiseMethod(Event node, Field backingField) Condition = CodeBuilder.CreateReference(local), TrueBlock = new Block() }; - stmt.TrueBlock.Add(mie); + if (returnType == TypeSystemServices.VoidType) + stmt.TrueBlock.Add(mie); + else stmt.TrueBlock.Add(new ReturnStatement(mie)); method.Body.Add(stmt); return method; diff --git a/src/Boo.Lang.Compiler/Steps/CallableTypeElision.cs b/src/Boo.Lang.Compiler/Steps/CallableTypeElision.cs new file mode 100644 index 000000000..8795bd3c4 --- /dev/null +++ b/src/Boo.Lang.Compiler/Steps/CallableTypeElision.cs @@ -0,0 +1,43 @@ +using System.Collections.Generic; +using System.Linq; +using Boo.Lang.Compiler.Ast; +using Boo.Lang.Compiler.TypeSystem; + +namespace Boo.Lang.Compiler.Steps +{ + class CallableTypeElision : AbstractFastVisitorCompilerStep + { + public override void Run() + { + if (!TypeSystemServices.CompilerGeneratedTypesModuleExists()) + return; + + var cgm = TypeSystemServices.GetCompilerGeneratedTypesModule(); + var callableFinder = new TypeFinder(new TypeCollector(type => type.ParentNamespace == cgm.Entity)); + foreach (var module in CompileUnit.Modules) + { + if (module != cgm) + module.Accept(callableFinder); + } + + var foundSet = new HashSet(callableFinder.Results); + var count = 0; + while (foundSet.Count > count) + { + count = foundSet.Count; + var sweeper = new TypeFinder(new TypeCollector(type => foundSet.Contains(type))); + cgm.Accept(sweeper); + foreach (var swept in sweeper.Results) + foundSet.Add(swept); + } + + var rejects = cgm.Members + .Cast() + .Where(td => !td.Name.Contains("$adaptor$") && !foundSet.Contains(td.Entity)); + foreach (var type in rejects) + { + cgm.Members.Remove(type); + } + } + } +} diff --git a/src/Boo.Lang.Compiler/Steps/ClosureSignatureInferrer.cs b/src/Boo.Lang.Compiler/Steps/ClosureSignatureInferrer.cs index 68c00833e..5d816f33f 100644 --- a/src/Boo.Lang.Compiler/Steps/ClosureSignatureInferrer.cs +++ b/src/Boo.Lang.Compiler/Steps/ClosureSignatureInferrer.cs @@ -29,6 +29,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; using Boo.Lang.Compiler.Ast; @@ -39,12 +40,21 @@ namespace Boo.Lang.Compiler.Steps { class ClosureSignatureInferrer { - private BlockExpression _closure; + private readonly BlockExpression _closure; + private readonly Node _parent; + private readonly Expression _asyncParent; private IType[] _inputTypes; public ClosureSignatureInferrer(BlockExpression closure) { _closure = closure; + var parent = closure.ParentNode; + if (parent.NodeType == NodeType.AsyncBlockExpression) + { + _asyncParent = (Expression)parent; + parent = parent.ParentNode; + } + _parent = parent; InitializeInputTypes(); } @@ -67,8 +77,10 @@ public MethodInvocationExpression MethodInvocationContext { get { - MethodInvocationExpression mie = Closure.ParentNode as MethodInvocationExpression; - if (mie != null && mie.Arguments.Contains(Closure)) return mie; + MethodInvocationExpression mie = _parent as MethodInvocationExpression; + if (mie != null && + (mie.Arguments.Contains(Closure) || (_asyncParent != null && mie.Arguments.Contains(_asyncParent)))) + return mie; return null; } } @@ -86,14 +98,42 @@ public ICallableType InferCallableType() GetTypeFromMethodInvocationContext() ?? GetTypeFromDeclarationContext() ?? GetTypeFromBinaryExpressionContext() ?? - GetTypeFromCastContext()) as ICallableType; + GetTypeFromCastContext() /*?? + GetTypeFromAsyncContext() */) as ICallableType; return contextType; } +/* private IType GetTypeFromAsyncContext() + { + var async = Closure.ParentNode as AsyncBlockExpression; + if (async == null) + return null; + + var oldClosure = _closure; + _closure = async; + try + { + var unwrappedType = InferCallableType(); + if (unwrappedType == null) + return null; + async.WrappedType = WrapType(unwrappedType); + return unwrappedType; + } + finally + { + _closure = oldClosure; + } + } + + private static IType WrapType(ICallableType unwrappedType) + { + return My.Instance.GenericTaskType.GenericInfo.ConstructType(unwrappedType); + } + */ private IType GetTypeFromBinaryExpressionContext() { - BinaryExpression binary = Closure.ParentNode as BinaryExpression; + BinaryExpression binary = _parent as BinaryExpression; if (binary == null || Closure != binary.Right) return null; return binary.Left.ExpressionType; } @@ -101,13 +141,13 @@ private IType GetTypeFromBinaryExpressionContext() private IType GetTypeFromDeclarationContext() { TypeReference tr = null; - DeclarationStatement ds = Closure.ParentNode as DeclarationStatement; + DeclarationStatement ds = _parent as DeclarationStatement; if (ds != null) { tr = ds.Declaration.Type; } - - Field fd = Closure.ParentNode as Field; + + Field fd = _parent as Field; if (fd != null) { tr = fd.Type; @@ -121,10 +161,31 @@ private IType GetTypeFromMethodInvocationContext() { if (MethodInvocationContext == null) return null; - IMethod method = MethodInvocationContext.Target.Entity as IMethod; - if (method == null) return null; - int argumentIndex = MethodInvocationContext.Arguments.IndexOf(Closure); + if (argumentIndex == -1 && _asyncParent != null) + argumentIndex = MethodInvocationContext.Arguments.IndexOf(_asyncParent); + + var entity = MethodInvocationContext.Target.Entity; + var method = entity as IMethodBase; + if (method == null) + { + if (entity.EntityType == EntityType.Type) + { + var ctors = ((IType)MethodInvocationContext.Target.Entity).GetConstructors().ToArray(); + if (ctors.Length == 1) + method = ctors[0]; + else if (ctors.Length == 0) + return null; + else entity = new Ambiguous(ctors); + } + if (entity.EntityType == EntityType.Ambiguous) + { + method = ResolveAmbiguousInvocationContext((Ambiguous) entity, argumentIndex); + } + if (method == null) + return null; + } + IParameter[] parameters = method.GetParameters(); if (argumentIndex < parameters.Length) return parameters[argumentIndex].Type; @@ -132,12 +193,38 @@ private IType GetTypeFromMethodInvocationContext() return null; } + // Sometimes, for the purpose of overload resolution, it doesn't matter which overload + // you pick, because they all take the same callable type and only differ in the rest of + // the param list + private IMethod ResolveAmbiguousInvocationContext(Ambiguous entity, int argumentIndex) + { + var candidates = entity.Entities + .OfType() + .Where(m => m.GetParameters().Length > argumentIndex && m.GetParameters()[argumentIndex].Type is ICallableType) + .ToArray(); + if (candidates.Length > 0) + { + var first = candidates[0]; + if (candidates.Length == 1) + return first; + var correspondingType = first.GetParameters()[argumentIndex].Type; + if (candidates.Skip(1).All(m => m.GetParameters()[argumentIndex].Type == correspondingType)) + return first; + var returnType = ((ICallableType)first.GetParameters()[argumentIndex].Type).GetSignature().ReturnType; + if (candidates.Skip(1).All(m => ((ICallableType)m.GetParameters()[argumentIndex].Type).GetSignature().ReturnType == returnType)) + _closure["$InferredReturnType"] = returnType; + } + AstAnnotations.MarkAmbiguousSignature(MethodInvocationContext); + AstAnnotations.MarkAmbiguousSignature(_closure); + return null; + } + private IType GetTypeFromCastContext() { - TryCastExpression tryCast = Closure.ParentNode as TryCastExpression; + TryCastExpression tryCast = _parent as TryCastExpression; if (tryCast != null) return tryCast.Type.Entity as IType; - CastExpression cast = Closure.ParentNode as CastExpression; + CastExpression cast = _parent as CastExpression; if (cast != null) return cast.Type.Entity as IType; return null; diff --git a/src/Boo.Lang.Compiler/Steps/ConstantFolding.cs b/src/Boo.Lang.Compiler/Steps/ConstantFolding.cs index 5c9747bdd..ccf50734e 100644 --- a/src/Boo.Lang.Compiler/Steps/ConstantFolding.cs +++ b/src/Boo.Lang.Compiler/Steps/ConstantFolding.cs @@ -143,6 +143,10 @@ override public void LeaveBinaryExpression(BinaryExpression node) : GetFoldedIntegerLiteral(node.Operator, Convert.ToUInt64(lhs), Convert.ToUInt64(rhs)); } } + else if (node.Operator == BinaryOperatorType.TypeTest && lhsType.IsValueType) + { + folded = GetFoldedValueTypeTest(node, lhsType, (IType) node.Right.Entity); + } if (null != folded) { @@ -153,6 +157,12 @@ override public void LeaveBinaryExpression(BinaryExpression node) } } + private BoolLiteralExpression GetFoldedValueTypeTest(Expression node, IType leftType, IType rightType) + { + Context.Warnings.Add(CompilerWarningFactory.ConstantExpression(node)); + return CodeBuilder.CreateBoolLiteral(rightType.IsAssignableFrom(leftType)); + } + override public void LeaveUnaryExpression(UnaryExpression node) { if (node.Operator == UnaryOperatorType.Explode @@ -188,6 +198,19 @@ override public void LeaveUnaryExpression(UnaryExpression node) } } + public override void LeaveTryCastExpression(TryCastExpression node) + { + base.LeaveTryCastExpression(node); + var target = GetExpressionType(node.Target); + if (target.IsValueType) + { + var toType = GetType(node.Type); + ReplaceCurrentNode(toType.IsAssignableFrom(target) + ? CodeBuilder.CreateCast(toType, node.Target) + : CodeBuilder.CreateNullLiteral()); + } + } + static BoolLiteralExpression GetFoldedBoolLiteral(BinaryOperatorType @operator, bool lhs, bool rhs) { bool result; diff --git a/src/Boo.Lang.Compiler/Steps/ContextAnnotations.cs b/src/Boo.Lang.Compiler/Steps/ContextAnnotations.cs index cdc99498d..e17e2dc4d 100755 --- a/src/Boo.Lang.Compiler/Steps/ContextAnnotations.cs +++ b/src/Boo.Lang.Compiler/Steps/ContextAnnotations.cs @@ -28,6 +28,7 @@ using System; +using System.Collections.Generic; using Boo.Lang.Compiler.Ast; namespace Boo.Lang.Compiler.Steps @@ -38,6 +39,12 @@ public class ContextAnnotations private static readonly object AssemblyBuilderKey = new object(); + private static readonly object AsyncKey = new object(); + + private static readonly object AwaitInExceptionHandlerKey = new object(); + + private static readonly object FieldInvocationKey = new object(); + public static Method GetEntryPoint(CompilerContext context) { if (null == context) @@ -88,5 +95,43 @@ public static void SetAssemblyBuilder(CompilerContext context, System.Reflection private ContextAnnotations() { } + + public static void MarkAsync(INodeWithBody node) + { + ((Node)node).Annotate(AsyncKey); + } + + public static bool IsAsync(INodeWithBody node) + { + return ((Node) node).ContainsAnnotation(AsyncKey); + } + + public static void MarkAwaitInExceptionHandler(INodeWithBody node) + { + ((Node)node).Annotate(AwaitInExceptionHandlerKey); + } + + public static bool AwaitInExceptionHandler(INodeWithBody node) + { + return ((Node)node).ContainsAnnotation(AwaitInExceptionHandlerKey); + } + + public static void AddFieldInvocation(MethodInvocationExpression node) + { + var context = CompilerContext.Current; + var list = context[FieldInvocationKey] as List; + if (list == null) + { + list = new List(); + context[FieldInvocationKey] = list; + } + list.Add(node); + } + + public static List GetFieldInvocations() + { + var context = CompilerContext.Current; + return context[FieldInvocationKey] as List; + } } } \ No newline at end of file diff --git a/src/Boo.Lang.Compiler/Steps/ForeignReferenceCollector.cs b/src/Boo.Lang.Compiler/Steps/ForeignReferenceCollector.cs index e0f910df0..30f5bf2f3 100644 --- a/src/Boo.Lang.Compiler/Steps/ForeignReferenceCollector.cs +++ b/src/Boo.Lang.Compiler/Steps/ForeignReferenceCollector.cs @@ -26,12 +26,10 @@ // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #endregion -using System; using System.Linq; -using Boo.Lang; -using Boo.Lang.Compiler; using Boo.Lang.Compiler.Ast; using Boo.Lang.Compiler.Services; +using Boo.Lang.Compiler.Steps.Generators; using Boo.Lang.Compiler.TypeSystem; using Boo.Lang.Compiler.TypeSystem.Builders; using Boo.Lang.Compiler.TypeSystem.Internal; @@ -39,23 +37,25 @@ namespace Boo.Lang.Compiler.Steps { - public class ForeignReferenceCollector : FastDepthFirstVisitor + using System.Collections.Generic; + + public class ForeignReferenceCollector : FastDepthFirstVisitor { - IType _currentType; - - List _references; + private IType _currentType; - List _recursiveReferences; + private readonly List _references; + + private readonly List _recursiveReferences; - Hash _referencedEntities; + private readonly Dictionary _referencedEntities; - SelfEntity _selfEntity; + private SelfEntity _selfEntity; public ForeignReferenceCollector() { - _references = new List(); - _recursiveReferences = new List(); - _referencedEntities = new Hash(); + _references = new List(); + _recursiveReferences = new List(); + _referencedEntities = new Dictionary(); } public Node SourceNode { get; set; } @@ -73,13 +73,13 @@ public IType CurrentType _selfEntity.Type = value; } } - - public List References + + public List References { get { return _references; } } - public Hash ReferencedEntities + public Dictionary ReferencedEntities { get { return _referencedEntities; } } @@ -108,7 +108,7 @@ protected BooCodeBuilder CodeBuilder get { return _codeBuilder; } } - private EnvironmentProvision _codeBuilder = new EnvironmentProvision(); + private readonly EnvironmentProvision _codeBuilder = new EnvironmentProvision(); public BooClassBuilder CreateSkeletonClass(string name, LexicalInfo lexicalInfo) { @@ -123,28 +123,57 @@ public BooClassBuilder CreateSkeletonClass(string name, LexicalInfo lexicalInfo) public void DeclareFieldsAndConstructor(BooClassBuilder builder) { + var keys = _referencedEntities.Keys.Cast().ToArray(); + foreach (var entity in keys) + _collector.Visit(entity.Type); + if (_collector.Matches.Any()) + BuildTypeMap(builder.ClassDefinition); + // referenced entities turn into fields - foreach (ITypedEntity entity in Builtins.array(_referencedEntities.Keys)) + foreach (var entity in keys) { - Field field = builder.AddInternalField(GetUniqueName(entity.Name), entity.Type); - _referencedEntities[entity] = field.Entity; + Field field = builder.AddInternalField(GetUniqueName(entity.Name), _mapper.MapType(entity.Type)); + _referencedEntities[entity] = (InternalField) field.Entity; } // single constructor taking all referenced entities BooMethodBuilder constructor = builder.AddConstructor(); constructor.Modifiers = TypeMemberModifiers.Public; constructor.Body.Add(CodeBuilder.CreateSuperConstructorInvocation(builder.Entity.BaseType)); - foreach (ITypedEntity entity in _referencedEntities.Keys) + foreach (var entity in _referencedEntities.Keys) { - InternalField field = (InternalField)_referencedEntities[entity]; - ParameterDeclaration parameter = constructor.AddParameter(field.Name, entity.Type); + InternalField field = _referencedEntities[entity]; + ParameterDeclaration parameter = constructor.AddParameter(field.Name, ((ITypedEntity)entity).Type); constructor.Body.Add( CodeBuilder.CreateAssignment(CodeBuilder.CreateReference(field), CodeBuilder.CreateReference(parameter))); } } - private string GetUniqueName(string name) + private readonly TypeCollector _collector = new TypeCollector(type => type is IGenericParameter); + + private readonly GeneratorTypeReplacer _mapper = new GeneratorTypeReplacer(); + + private void BuildTypeMap(ClassDefinition newClass) + { + string lastName = null; + IType lastType = null; + int i = 0; + foreach (var newParam in _collector.Matches.Cast().OrderBy(t => t.GenericParameterPosition)) + { + if (!newParam.Name.Equals(lastName)) + { + lastName = newParam.Name; + var genParam = CodeBuilder.CreateGenericParameterDeclaration(i, newParam.Name); + newClass.GenericParameters.Add(genParam); + lastType = (IType) genParam.Entity; + ++i; + } + _mapper.Replace(newParam, lastType); + } + } + + private string GetUniqueName(string name) { return _uniqueNameProvider.Instance.GetUniqueName(name); } @@ -155,12 +184,17 @@ public void AdjustReferences() { foreach (Expression reference in _references) { - var entity = (InternalField)_referencedEntities[reference.Entity]; - if (null != entity) - reference.ParentNode.Replace(reference, CodeBuilder.CreateReference(entity)); + InternalField entity; + _referencedEntities.TryGetValue(reference.Entity, out entity); + if (null != entity) + { + reference.ParentNode.Replace(reference, CodeBuilder.CreateReference(entity)); + if (reference.ParentNode.NodeType == NodeType.MemberReferenceExpression) + Remap((MemberReferenceExpression)reference.ParentNode); + } } - foreach (ReferenceExpression reference in _recursiveReferences) + foreach (var reference in _recursiveReferences) { reference.ParentNode.Replace( reference, @@ -169,12 +203,28 @@ public void AdjustReferences() CurrentMethod.Entity)); } } - - public MethodInvocationExpression CreateConstructorInvocationWithReferencedEntities(IType type) - { - MethodInvocationExpression mie = CodeBuilder.CreateConstructorInvocation(type.GetConstructors().First()); - foreach (ITypedEntity entity in _referencedEntities.Keys) + + private static void Remap(MemberReferenceExpression mre) + { + var parentType = ((ITypedEntity)mre.Target.Entity).Type as TypeSystem.Generics.GenericConstructedType; + if (parentType == null) return; + var entity = (IMember)mre.Entity; + var gmm = entity as TypeSystem.Generics.IGenericMappedMember; + if (gmm != null) + entity = gmm.SourceMember; + mre.Entity = parentType.ConstructedInfo.Map(entity); + mre.ExpressionType = ((ITypedEntity)mre.Entity).Type; + } + + public MethodInvocationExpression CreateConstructorInvocationWithReferencedEntities(IType type, Method containingMethod) + { + GeneratorTypeReplacer mapper; + type = GeneratorTypeReplacer.MapTypeInMethodContext(type, containingMethod, out mapper); + MethodInvocationExpression mie = CodeBuilder.CreateConstructorInvocation(type.GetConstructors().First()); + foreach (var entity in _referencedEntities.Keys) mie.Arguments.Add(CreateForeignReference(entity)); + if (mapper != null) + mie.Accept(new GenericTypeMapper(mapper)); return mie; } @@ -193,7 +243,7 @@ public override void OnMemberReferenceExpression(MemberReferenceExpression node) Visit(node.Target); } - override public void OnReferenceExpression(ReferenceExpression node) + public override void OnReferenceExpression(ReferenceExpression node) { if (IsForeignReference(node)) { @@ -202,7 +252,7 @@ override public void OnReferenceExpression(ReferenceExpression node) } } - override public void OnSelfLiteralExpression(SelfLiteralExpression node) + public override void OnSelfLiteralExpression(SelfLiteralExpression node) { var entity = GetSelfEntity(); node.Entity = entity; @@ -210,9 +260,9 @@ override public void OnSelfLiteralExpression(SelfLiteralExpression node) _referencedEntities[entity] = null; } - private bool IsRecursiveReference(Node node) + private bool IsRecursiveReference(MemberReferenceExpression node) { - return (CurrentMethod != null && node.Entity == CurrentMethod.Entity); + return CurrentMethod != null && node.Entity == CurrentMethod.Entity; } bool IsForeignReference(ReferenceExpression node) diff --git a/src/Boo.Lang.Compiler/Steps/Generators/GeneratorExpressionProcessor.cs b/src/Boo.Lang.Compiler/Steps/Generators/GeneratorExpressionProcessor.cs index 37136853e..3fbf9a966 100644 --- a/src/Boo.Lang.Compiler/Steps/Generators/GeneratorExpressionProcessor.cs +++ b/src/Boo.Lang.Compiler/Steps/Generators/GeneratorExpressionProcessor.cs @@ -28,6 +28,7 @@ using System; using System.Collections; +using System.Collections.Generic; using Boo.Lang.Compiler.Ast; using Boo.Lang.Compiler.TypeSystem; using Boo.Lang.Compiler.TypeSystem.Builders; @@ -35,27 +36,26 @@ using Boo.Lang.Compiler.TypeSystem.Services; using Boo.Lang.Compiler.Util; using Boo.Lang.Environments; -using Boo.Lang.Runtime; namespace Boo.Lang.Compiler.Steps.Generators { class GeneratorExpressionProcessor : AbstractCompilerComponent { - GeneratorExpression _generator; - - BooClassBuilder _enumerator; + private readonly GeneratorExpression _generator; - Field _current; - - Field _enumeratorField; - - ForeignReferenceCollector _collector; - - IType _sourceItemType; - IType _sourceEnumeratorType; - IType _sourceEnumerableType; - IType _resultEnumeratorType; - GeneratorSkeleton _skeleton; + private BooClassBuilder _enumerator; + + private Field _current; + + private Field _enumeratorField; + + private readonly ForeignReferenceCollector _collector; + + private IType _sourceItemType; + private IType _sourceEnumeratorType; + private IType _sourceEnumerableType; + private IType _resultEnumeratorType; + private readonly GeneratorSkeleton _skeleton; public GeneratorExpressionProcessor(CompilerContext context, ForeignReferenceCollector collector, @@ -72,15 +72,15 @@ public void Run() RemoveReferencedDeclarations(); CreateAnonymousGeneratorType(); } - - void RemoveReferencedDeclarations() + + private void RemoveReferencedDeclarations() { - Hash referencedEntities = _collector.ReferencedEntities; + Dictionary referencedEntities = _collector.ReferencedEntities; foreach (Declaration d in _generator.Declarations) referencedEntities.Remove(d.Entity); } - - void CreateAnonymousGeneratorType() + + private void CreateAnonymousGeneratorType() { // Set up some important types _sourceItemType = TypeSystemServices.ObjectType; @@ -132,33 +132,35 @@ void CreateAnonymousGeneratorType() public MethodInvocationExpression CreateEnumerableConstructorInvocation() { - return _collector.CreateConstructorInvocationWithReferencedEntities(_skeleton.GeneratorClassBuilder.Entity); + return _collector.CreateConstructorInvocationWithReferencedEntities( + _skeleton.GeneratorClassBuilder.Entity, + _generator.GetAncestor()); } - - void EnumeratorConstructorMustCallReset() + + private void EnumeratorConstructorMustCallReset() { Constructor constructor = _enumerator.ClassDefinition.GetConstructor(0); constructor.Body.Add(CreateMethodInvocation(_enumerator.ClassDefinition, "Reset")); } - - IMethod GetMemberwiseCloneMethod() + + private IMethod GetMemberwiseCloneMethod() { return TypeSystemServices.Map( typeof(object).GetMethod("MemberwiseClone", System.Reflection.BindingFlags.NonPublic|System.Reflection.BindingFlags.Instance)); } - - MethodInvocationExpression CreateMethodInvocation(ClassDefinition cd, string name) + + private MethodInvocationExpression CreateMethodInvocation(ClassDefinition cd, string name) { - IMethod method = (IMethod)((Method)cd.Members[name]).Entity; + var method = (IMethod)((Method)cd.Members[name]).Entity; return CodeBuilder.CreateMethodInvocation( CodeBuilder.CreateSelfReference(method.DeclaringType), method); } - - void CreateCurrent() + + private void CreateCurrent() { - Property property = _enumerator.AddReadOnlyProperty("Current", TypeSystemServices.ObjectType); + var property = _enumerator.AddReadOnlyProperty("Current", TypeSystemServices.ObjectType); property.Getter.Modifiers |= TypeMemberModifiers.Virtual; property.Getter.Body.Add( new ReturnStatement( @@ -169,19 +171,20 @@ void CreateCurrent() // Since enumerator is generic, this object-typed property should be the // explicit interface implementation for the non-generic IEnumerator interface - property.ExplicitInfo = new ExplicitMemberInfo(); - property.ExplicitInfo.InterfaceType = - (SimpleTypeReference)CodeBuilder.CreateTypeReference(TypeSystemServices.IEnumeratorType); - - // ...and now we create a typed property for the generic IEnumerator<> interface + property.ExplicitInfo = new ExplicitMemberInfo + { + InterfaceType = (SimpleTypeReference) CodeBuilder.CreateTypeReference(TypeSystemServices.IEnumeratorType) + }; + + // ...and now we create a typed property for the generic IEnumerator<> interface Property typedProperty = _enumerator.AddReadOnlyProperty("Current", _skeleton.GeneratorItemType); typedProperty.Getter.Modifiers |= TypeMemberModifiers.Virtual; typedProperty.Getter.Body.Add( new ReturnStatement( CodeBuilder.CreateReference(_current))); } - - void CreateGetEnumerator() + + private void CreateGetEnumerator() { BooMethodBuilder method = _skeleton.GetEnumeratorBuilder; @@ -197,8 +200,8 @@ void CreateGetEnumerator() method.Body.Add(new ReturnStatement(mie)); } - - void CreateClone() + + private void CreateClone() { BooMethodBuilder method = _enumerator.AddVirtualMethod("Clone", TypeSystemServices.ObjectType); method.Body.Add( @@ -207,8 +210,8 @@ void CreateClone() CodeBuilder.CreateSelfReference(_enumerator.Entity), GetMemberwiseCloneMethod()))); } - - void CreateReset() + + private void CreateReset() { // Find GetEnumerator method on the source type IMethod getEnumerator = (IMethod)GetMember(_sourceEnumerableType, "GetEnumerator", EntityType.Method); @@ -220,8 +223,8 @@ void CreateReset() CodeBuilder.CreateReference((InternalField)_enumeratorField.Entity), CodeBuilder.CreateMethodInvocation(_generator.Iterator, getEnumerator))); } - - void CreateMoveNext() + + private void CreateMoveNext() { BooMethodBuilder method = _enumerator.AddVirtualMethod("MoveNext", TypeSystemServices.BoolType); @@ -234,9 +237,9 @@ void CreateMoveNext() ((IProperty)GetMember(_sourceEnumeratorType, "Current", EntityType.Property)).GetGetMethod()); Statement filter = null; - Statement stmt = null; - Block outerBlock = null; - Block innerBlock = null; + Statement stmt; + Block outerBlock; + Block innerBlock; if (null == _generator.Filter) { diff --git a/src/Boo.Lang.Compiler/Steps/Generators/GeneratorMethodProcessor.cs b/src/Boo.Lang.Compiler/Steps/Generators/GeneratorMethodProcessor.cs index 8172cba45..15618ffe1 100755 --- a/src/Boo.Lang.Compiler/Steps/Generators/GeneratorMethodProcessor.cs +++ b/src/Boo.Lang.Compiler/Steps/Generators/GeneratorMethodProcessor.cs @@ -26,10 +26,10 @@ // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #endregion -using System.Collections; using System.Collections.Generic; using System.Linq; using Boo.Lang.Compiler.Ast; +using Boo.Lang.Compiler.Steps.StateMachine; using Boo.Lang.Compiler.TypeSystem; using Boo.Lang.Compiler.TypeSystem.Builders; using Boo.Lang.Compiler.TypeSystem.Generics; @@ -38,159 +38,122 @@ namespace Boo.Lang.Compiler.Steps.Generators { - internal class GeneratorMethodProcessor : AbstractTransformerCompilerStep - { - private readonly InternalMethod _generator; - - private InternalMethod _moveNext; - + internal class GeneratorMethodProcessor : MethodToStateMachineTransformer + { private readonly BooClassBuilder _enumerable; - private BooClassBuilder _enumerator; + private readonly IType _generatorItemType; - private BooMethodBuilder _enumeratorConstructor; + private readonly BooMethodBuilder _getEnumeratorBuilder; - private BooMethodBuilder _enumerableConstructor; + private readonly GeneratorTypeReplacer _methodToEnumerableMapper; - private IField _state; + private BooMethodBuilder _enumerableConstructor; private IMethod _yield; private IMethod _yieldDefault; - private Field _externalEnumeratorSelf; - - private readonly List _labels; + public GeneratorMethodProcessor(CompilerContext context, InternalMethod method) + : base(context, method) + { + var skeleton = My.Instance.SkeletonFor(method); + _generatorItemType = skeleton.GeneratorItemType; + _enumerable = skeleton.GeneratorClassBuilder; + _getEnumeratorBuilder = skeleton.GetEnumeratorBuilder; + _methodToEnumerableMapper = skeleton.GeneratorClassTypeReplacer; + } - private readonly System.Collections.Generic.List _tryStatementInfoForLabels = new System.Collections.Generic.List(); + private MethodInvocationExpression _enumerableConstructorInvocation; - private readonly Hashtable _mapping; + public override void Run() + { + base.Run(); + CreateGetEnumeratorBody(_stateMachineConstructorInvocation); + FixGeneratorMethodBody(_enumerableConstructorInvocation); + } - private readonly IType _generatorItemType; + protected override void PrepareConstructorCalls() + { + base.PrepareConstructorCalls(); + _enumerableConstructorInvocation = CodeBuilder.CreateGenericConstructorInvocation( + (IType)_enumerable.ClassDefinition.Entity, + _genericParams); + } - private readonly BooMethodBuilder _getEnumeratorBuilder; + protected override void PropagateReferences() + { + // propagate the necessary parameters from + // the enumerable to the enumerator + foreach (ParameterDeclaration parameter in _method.Method.Parameters) + { + var myParam = MapParamType(parameter); - private readonly GeneratorTypeReplacer _methodToEnumerableMapper; + var entity = (InternalParameter)myParam.Entity; + if (entity.IsUsed) + { + _enumerableConstructorInvocation.Arguments.Add(CodeBuilder.CreateReference(myParam)); - private readonly GeneratorTypeReplacer _methodToEnumeratorMapper = new GeneratorTypeReplacer(); + PropagateFromEnumerableToEnumerator(_stateMachineConstructorInvocation, + entity.Name, + entity.Type); + } + } - private readonly Dictionary _entityMapper = new Dictionary(); + // propagate the external self reference if necessary + if (_externalSelfField != null) + { + var type = (IType)_externalSelfField.Type.Entity; + _enumerableConstructorInvocation.Arguments.Add( + CodeBuilder.CreateSelfReference(_methodToEnumerableMapper.MapType(type))); - public GeneratorMethodProcessor(CompilerContext context, InternalMethod method) - { - _labels = new List(); - _mapping = new Hashtable(); - _generator = method; - - var skeleton = My.Instance.SkeletonFor(method); - _generatorItemType = skeleton.GeneratorItemType; - _enumerable = skeleton.GeneratorClassBuilder; - _getEnumeratorBuilder = skeleton.GetEnumeratorBuilder; - _methodToEnumerableMapper = skeleton.GeneratorClassTypeReplacer; + PropagateFromEnumerableToEnumerator(_stateMachineConstructorInvocation, + "self_", + _methodToStateMachineMapper.MapType(type)); + } - Initialize(context); - } + } - private LexicalInfo LexicalInfo - { - get { return _generator.Method.LexicalInfo; } - } + private void PropagateFromEnumerableToEnumerator(MethodInvocationExpression enumeratorConstructorInvocation, + string parameterName, + IType parameterType) + { + Field field = DeclareFieldInitializedFromConstructorParameter( + _enumerable, + _enumerableConstructor, + parameterName, + parameterType, + _methodToEnumerableMapper); + enumeratorConstructorInvocation.Arguments.Add(CodeBuilder.CreateReference(field)); + } - private GenericParameterDeclaration[] _genericParams; + protected override void CreateStateMachine() + { + _enumerableConstructor = CreateConstructor(_enumerable); + base.CreateStateMachine(); + } - public override void Run() - { - _genericParams = _generator.Method.DeclaringType.GenericParameters.Concat(_generator.Method.GenericParameters).ToArray(); - CreateEnumerableConstructor(); - CreateEnumerator(); - var enumerableConstructorInvocation = CodeBuilder.CreateGenericConstructorInvocation( - (IType)_enumerable.ClassDefinition.Entity, - _genericParams); - var enumeratorConstructorInvocation = CodeBuilder.CreateGenericConstructorInvocation( - (IType)_enumerator.ClassDefinition.Entity, - _enumerable.ClassDefinition.GenericParameters); - PropagateReferences(enumerableConstructorInvocation, enumeratorConstructorInvocation); - CreateGetEnumeratorBody(enumeratorConstructorInvocation); - FixGeneratorMethodBody(enumerableConstructorInvocation); - } + private void CreateGetEnumeratorBody(Expression enumeratorExpression) + { + _getEnumeratorBuilder.Body.Add( + new ReturnStatement(enumeratorExpression)); + } private void FixGeneratorMethodBody(MethodInvocationExpression enumerableConstructorInvocation) - { - var body = _generator.Method.Body; - body.Clear(); - - body.Add( - new ReturnStatement( - _generator.Method.LexicalInfo, - GeneratorReturnsIEnumerator() - ? CreateGetEnumeratorInvocation(enumerableConstructorInvocation) - : enumerableConstructorInvocation)); - } - - private ParameterDeclaration MapParamType(ParameterDeclaration parameter) - { - if (parameter.Type.NodeType == NodeType.GenericTypeReference) - { - var gen = (GenericTypeReference)parameter.Type; - var genEntityType = gen.Entity as IConstructedTypeInfo; - if (genEntityType == null) - return parameter; - var trc = new TypeReferenceCollection(); - foreach (var genArg in gen.GenericArguments) - { - var replacement = genArg; - foreach (var genParam in _enumerable.ClassDefinition.GenericParameters) - if (genParam.Name.Equals(genArg.Entity.Name)) - { - replacement = new SimpleTypeReference(genParam.Name) {Entity = genParam.Entity}; - break; - } - trc.Add(replacement); - } - parameter = parameter.CloneNode(); - gen = (GenericTypeReference)parameter.Type; - gen.GenericArguments = trc; - gen.Entity = new GenericConstructedType(genEntityType.GenericDefinition, trc.Select(a => a.Entity).Cast().ToArray()); - } - return parameter; - } - - private void PropagateReferences(MethodInvocationExpression enumerableConstructorInvocation, - MethodInvocationExpression enumeratorConstructorInvocation) - { - // propagate the necessary parameters from - // the enumerable to the enumerator - foreach (ParameterDeclaration parameter in _generator.Method.Parameters) - { - var myParam = MapParamType(parameter); + { + var body = _method.Method.Body; + body.Clear(); + + body.Add( + new ReturnStatement( + _method.Method.LexicalInfo, + GeneratorReturnsIEnumerator() + ? CreateGetEnumeratorInvocation(enumerableConstructorInvocation) + : enumerableConstructorInvocation)); + } - var entity = (InternalParameter)myParam.Entity; - if (entity.IsUsed) - { - enumerableConstructorInvocation.Arguments.Add(CodeBuilder.CreateReference(myParam)); - - PropagateFromEnumerableToEnumerator(enumeratorConstructorInvocation, - entity.Name, - entity.Type); - } - } - - // propagate the external self reference if necessary - if (null != _externalEnumeratorSelf) - { - var type = (IType)_externalEnumeratorSelf.Type.Entity; - enumerableConstructorInvocation.Arguments.Add( - CodeBuilder.CreateSelfReference(_methodToEnumerableMapper.MapType(type))); - - PropagateFromEnumerableToEnumerator(enumeratorConstructorInvocation, - "self_", - _methodToEnumeratorMapper.MapType(type)); - } - - } - - private MethodInvocationExpression CreateGetEnumeratorInvocation(MethodInvocationExpression enumerableConstructorInvocation) - { + private MethodInvocationExpression CreateGetEnumeratorInvocation(MethodInvocationExpression enumerableConstructorInvocation) + { IMethod enumeratorEntity = GetGetEnumeratorEntity(); var enumeratorInfo = enumeratorEntity.DeclaringType.GenericInfo; if (enumeratorInfo != null && _genericParams.Length > 0) @@ -201,547 +164,181 @@ private MethodInvocationExpression CreateGetEnumeratorInvocation(MethodInvocatio var replacement = _genericParams.SingleOrDefault(gp => gp.Name.Equals(param.Name)); argList.Add(replacement == null ? param : (IType)replacement.Entity); } - var baseType = (IConstructedTypeInfo) new GenericConstructedType(enumeratorEntity.DeclaringType, argList.ToArray()); - enumeratorEntity = (IMethod) baseType.Map(enumeratorEntity); + var baseType = (IConstructedTypeInfo)new GenericConstructedType(enumeratorEntity.DeclaringType, argList.ToArray()); + enumeratorEntity = (IMethod)baseType.Map(enumeratorEntity); } - return CodeBuilder.CreateMethodInvocation(enumerableConstructorInvocation, enumeratorEntity); - } - - private InternalMethod GetGetEnumeratorEntity() - { - return _getEnumeratorBuilder.Entity; - } - - private bool GeneratorReturnsIEnumerator() - { - bool returnsEnumerator = _generator.ReturnType == TypeSystemServices.IEnumeratorType; - returnsEnumerator |= - _generator.ReturnType.ConstructedInfo != null && - _generator.ReturnType.ConstructedInfo.GenericDefinition == TypeSystemServices.IEnumeratorGenericType; - - return returnsEnumerator; - } + return CodeBuilder.CreateMethodInvocation(enumerableConstructorInvocation, enumeratorEntity); + } - private void CreateGetEnumeratorBody(Expression enumeratorExpression) - { - _getEnumeratorBuilder.Body.Add( - new ReturnStatement(enumeratorExpression)); - } - - private void CreateEnumerableConstructor() - { - _enumerableConstructor = CreateConstructor(_enumerable); - } - - private void CreateEnumeratorConstructor() - { - _enumeratorConstructor = CreateConstructor(_enumerator); - } - - private void CreateEnumerator() - { - _enumerator = CodeBuilder.CreateClass("$Enumerator"); - _enumerator.AddAttribute(CodeBuilder.CreateAttribute(typeof(System.Runtime.CompilerServices.CompilerGeneratedAttribute))); - _enumerator.Modifiers |= _enumerable.Modifiers; - _enumerator.LexicalInfo = this.LexicalInfo; - foreach (var param in _genericParams) - { - var replacement = _enumerator.AddGenericParameter(param.Name); - _methodToEnumeratorMapper.Replace((IType)param.Entity, (IType)replacement.Entity); - } + private InternalMethod GetGetEnumeratorEntity() + { + return _getEnumeratorBuilder.Entity; + } + + private bool GeneratorReturnsIEnumerator() + { + bool returnsEnumerator = _method.ReturnType == TypeSystemServices.IEnumeratorType; + returnsEnumerator |= + _method.ReturnType.ConstructedInfo != null && + _method.ReturnType.ConstructedInfo.GenericDefinition == TypeSystemServices.IEnumeratorGenericType; + + return returnsEnumerator; + } + + protected override void SetupStateMachine() + { + _stateMachineClass.Modifiers |= _enumerable.Modifiers; var abstractEnumeratorType = TypeSystemServices.Map(typeof(GenericGeneratorEnumerator<>)). - GenericInfo.ConstructType(_methodToEnumeratorMapper.MapType(_generatorItemType)); + GenericInfo.ConstructType(_methodToStateMachineMapper.MapType(_generatorItemType)); _state = NameResolutionService.ResolveField(abstractEnumeratorType, "_state"); _yield = NameResolutionService.ResolveMethod(abstractEnumeratorType, "Yield"); _yieldDefault = NameResolutionService.ResolveMethod(abstractEnumeratorType, "YieldDefault"); - _enumerator.AddBaseType(abstractEnumeratorType); - _enumerator.AddBaseType(TypeSystemServices.IEnumeratorType); - - CreateEnumeratorConstructor(); - CreateMoveNext(); - - _enumerable.ClassDefinition.Members.Add(_enumerator.ClassDefinition); - } - - private void CreateMoveNext() - { - Method generator = _generator.Method; - - BooMethodBuilder methodBuilder = _enumerator.AddVirtualMethod("MoveNext", TypeSystemServices.BoolType); - methodBuilder.Method.LexicalInfo = generator.LexicalInfo; - _moveNext = methodBuilder.Entity; - - TransformLocalsIntoFields(generator); - - TransformParametersIntoFieldsInitializedByConstructor(generator); - - methodBuilder.Body.Add(CreateLabel(generator)); - - // Visit() needs to know the number of the finished state - _finishedStateNumber = _labels.Count; - LabelStatement finishedLabel = CreateLabel(generator); - methodBuilder.Body.Add(generator.Body); - generator.Body.Clear(); - - Visit(methodBuilder.Body); - - methodBuilder.Body.Add(CreateYieldInvocation(LexicalInfo.Empty, _finishedStateNumber, null)); - methodBuilder.Body.Add(finishedLabel); - - methodBuilder.Body.Insert(0, - CodeBuilder.CreateSwitch( - this.LexicalInfo, - CodeBuilder.CreateMemberReference(_state), - _labels)); - - // if the method contains converted try statements, put it in a try/failure block - if (_convertedTryStatements.Count > 0) - { - IMethod dispose = CreateDisposeMethod(); - - var tryFailure = new TryStatement(); - tryFailure.ProtectedBlock.Add(methodBuilder.Body); - tryFailure.FailureBlock = new Block(); - tryFailure.FailureBlock.Add(CallMethodOnSelf(dispose)); - methodBuilder.Body.Clear(); - methodBuilder.Body.Add(tryFailure); - } - } - - private void TransformParametersIntoFieldsInitializedByConstructor(Method generator) - { - foreach (ParameterDeclaration parameter in generator.Parameters) - { - var entity = (InternalParameter)parameter.Entity; - if (entity.IsUsed) - { - var field = DeclareFieldInitializedFromConstructorParameter(_enumerator, - _enumeratorConstructor, - entity.Name, - entity.Type, - _methodToEnumeratorMapper); - _mapping[entity] = field.Entity; - } - } - } - - private void TransformLocalsIntoFields(Method generator) - { - foreach (var local in generator.Locals) - { - var entity = (InternalLocal)local.Entity; - if (IsExceptionHandlerVariable(entity)) - { - AddToMoveNextMethod(local); - continue; - } - - AddInternalFieldFor(entity); - } - generator.Locals.Clear(); - } - - private void AddToMoveNextMethod(Local local) - { - var newLocal = new InternalLocal(local, _methodToEnumerableMapper.MapType(((InternalLocal)local.Entity).Type)); - _entityMapper.Add(local.Entity, newLocal); - local.Entity = newLocal; - _moveNext.Method.Locals.Add(local); - } - - private void AddInternalFieldFor(InternalLocal entity) - { - Field field = _enumerator.AddInternalField(UniqueName(entity.Name), _methodToEnumeratorMapper.MapType(entity.Type)); - _mapping[entity] = field.Entity; - } - - private bool IsExceptionHandlerVariable(InternalLocal local) - { - Declaration originalDeclaration = local.OriginalDeclaration; - if (originalDeclaration == null) return false; - return originalDeclaration.ParentNode is ExceptionHandler; - } - - MethodInvocationExpression CallMethodOnSelf(IMethod method) - { - var entity = _enumerator.Entity; - var genParams = _enumerator.ClassDefinition.GenericParameters; - if (!genParams.IsEmpty) - { - var args = genParams.Select(gpd => gpd.Entity).Cast().ToArray(); - entity = new GenericConstructedType(entity, args); - var mapping = new InternalGenericMapping(entity, args); - method = mapping.Map(method); - } - return CodeBuilder.CreateMethodInvocation( - CodeBuilder.CreateSelfReference(entity), - method); - } + _stateMachineClass.AddBaseType(abstractEnumeratorType); + _stateMachineClass.AddBaseType(TypeSystemServices.IEnumeratorType); - private IMethod CreateDisposeMethod() - { - BooMethodBuilder mn = _enumerator.AddVirtualMethod("Dispose", TypeSystemServices.VoidType); - mn.Method.LexicalInfo = this.LexicalInfo; - - LabelStatement noEnsure = CodeBuilder.CreateLabel(_generator.Method, "noEnsure").LabelStatement; - mn.Body.Add(noEnsure); - mn.Body.Add(SetStateTo(_finishedStateNumber)); - mn.Body.Add(new ReturnStatement()); - - // Create a section calling all ensure methods for each converted try block - LabelStatement[] disposeLabels = new LabelStatement[_labels.Count]; - for (int i = 0; i < _convertedTryStatements.Count; i++) { - TryStatementInfo info = _convertedTryStatements[i]; - disposeLabels[info._stateNumber] = CodeBuilder.CreateLabel(_generator.Method, "$ensure_" + info._stateNumber).LabelStatement; - mn.Body.Add(disposeLabels[info._stateNumber]); - mn.Body.Add(SetStateTo(_finishedStateNumber)); - Block block = mn.Body; - while (info._parent != null) { - TryStatement ts = new TryStatement(); - block.Add(ts); - ts.ProtectedBlock.Add(CallMethodOnSelf(info._ensureMethod)); - block = ts.EnsureBlock = new Block(); - info = info._parent; - } - block.Add(CallMethodOnSelf(info._ensureMethod)); - mn.Body.Add(new ReturnStatement()); - } - - // now map the labels of the suspended states to the labels we just created - for (int i = 0; i < _labels.Count; i++) { - if (_tryStatementInfoForLabels[i] == null) - disposeLabels[i] = noEnsure; - else - disposeLabels[i] = disposeLabels[_tryStatementInfoForLabels[i]._stateNumber]; - } - - mn.Body.Insert(0, CodeBuilder.CreateSwitch( - this.LexicalInfo, - CodeBuilder.CreateMemberReference(_state), - disposeLabels)); - return mn.Entity; - } + } - private void PropagateFromEnumerableToEnumerator(MethodInvocationExpression enumeratorConstructorInvocation, - string parameterName, - IType parameterType) - { - Field field = DeclareFieldInitializedFromConstructorParameter( - _enumerable, - _enumerableConstructor, - parameterName, - parameterType, - _methodToEnumerableMapper); - enumeratorConstructorInvocation.Arguments.Add(CodeBuilder.CreateReference(field)); - } - - private Field DeclareFieldInitializedFromConstructorParameter(BooClassBuilder type, - BooMethodBuilder constructor, - string parameterName, - IType parameterType, - TypeReplacer replacer) + protected override string StateMachineClassName { - parameterType = replacer.MapType(parameterType); - Field field = type.AddInternalField(UniqueName(parameterName), parameterType); - InitializeFieldFromConstructorParameter(constructor, field, parameterName, parameterType); - return field; - } - - private void InitializeFieldFromConstructorParameter(BooMethodBuilder constructor, - Field field, - string parameterName, - IType parameterType) - { - ParameterDeclaration parameter = constructor.AddParameter(parameterName, parameterType); - constructor.Body.Add( - CodeBuilder.CreateAssignment( - CodeBuilder.CreateReference(field), - CodeBuilder.CreateReference(parameter))); - } - - private void OnTypeReference(TypeReference node) - { - var type = (IType)node.Entity; - node.Entity = _methodToEnumeratorMapper.MapType(type); - } - - public override void OnSimpleTypeReference(SimpleTypeReference node) - { - OnTypeReference(node); - } - - public override void OnArrayTypeReference(ArrayTypeReference node) + get { return "$Enumerator"; } + } + + protected override void SaveStateMachineClass(ClassDefinition cd) { - base.OnArrayTypeReference(node); - OnTypeReference(node); + _enumerable.ClassDefinition.Members.Add(cd); } - public override void OnCallableTypeReference(CallableTypeReference node) + private MethodInvocationExpression CreateYieldInvocation(LexicalInfo sourceLocation, int newState, Expression value) { - base.OnCallableTypeReference(node); - OnTypeReference(node); + MethodInvocationExpression invocation = CodeBuilder.CreateMethodInvocation( + CodeBuilder.CreateSelfReference(_stateMachineClass.Entity), + value != null ? _yield : _yieldDefault, + CodeBuilder.CreateIntegerLiteral(newState)); + if (value != null) invocation.Arguments.Add(value); + invocation.LexicalInfo = sourceLocation; + return invocation; } - public override void OnGenericTypeReference(GenericTypeReference node) - { - base.OnGenericTypeReference(node); - OnTypeReference(node); + protected override void CreateMoveNext() + { + Method generator = _method.Method; + + BooMethodBuilder methodBuilder = _stateMachineClass.AddVirtualMethod("MoveNext", TypeSystemServices.BoolType); + methodBuilder.Method.LexicalInfo = generator.LexicalInfo; + _moveNext = methodBuilder.Entity; + + TransformLocalsIntoFields(generator); + + TransformParametersIntoFieldsInitializedByConstructor(generator); + + methodBuilder.Body.Add(CreateLabel(generator)); + + // Visit() needs to know the number of the finished state + _finishedStateNumber = _labels.Count; + LabelStatement finishedLabel = CreateLabel(generator); + methodBuilder.Body.Add(generator.Body); + generator.Body.Clear(); + + Visit(methodBuilder.Body); + + methodBuilder.Body.Add(CreateYieldInvocation(LexicalInfo.Empty, _finishedStateNumber, null)); + methodBuilder.Body.Add(finishedLabel); + + methodBuilder.Body.Insert(0, + CodeBuilder.CreateSwitch( + this.LexicalInfo, + CodeBuilder.CreateMemberReference(_state), + _labels)); + + // if the method contains converted try statements, put it in a try/failure block + if (_convertedTryStatements.Count > 0) + { + IMethod dispose = CreateDisposeMethod(); + + var tryFailure = new TryStatement(); + tryFailure.ProtectedBlock.Add(methodBuilder.Body); + tryFailure.FailureBlock = new Block(); + tryFailure.FailureBlock.Add(CallMethodOnSelf(dispose)); + methodBuilder.Body.Clear(); + methodBuilder.Body.Add(tryFailure); + } } - public override void OnGenericTypeDefinitionReference(GenericTypeDefinitionReference node) + public override void LeaveYieldStatement(YieldStatement node) { - base.OnGenericTypeDefinitionReference(node); - OnTypeReference(node); + TryStatementInfo currentTry = _tryStatementStack.Count > 0 ? _tryStatementStack.Peek() : null; + if (currentTry != null) + { + ConvertTryStatement(currentTry); + } + var block = new Block(); + block.Add( + new ReturnStatement( + node.LexicalInfo, + CreateYieldInvocation(node.LexicalInfo, _labels.Count, node.Expression), + null)); + block.Add(CreateLabel(node)); + // setting the state back to the "running" state not required, as that state has the same ensure blocks + // as the state we are currently in. + // if (currentTry != null) { + // block.Add(SetStateTo(currentTry._stateNumber)); + // } + ReplaceCurrentNode(block); } - public override void OnReferenceExpression(ReferenceExpression node) - { - InternalField mapped = (InternalField)_mapping[node.Entity]; - if (null != mapped) - { - ReplaceCurrentNode( - CodeBuilder.CreateMemberReference( - node.LexicalInfo, - CodeBuilder.CreateSelfReference(_enumerator.Entity), - mapped)); - } - } - - public override void OnSelfLiteralExpression(SelfLiteralExpression node) - { - ReplaceCurrentNode(CodeBuilder.CreateReference(node.LexicalInfo, ExternalEnumeratorSelf())); - } - - public override void OnSuperLiteralExpression(SuperLiteralExpression node) - { - var externalSelf = CodeBuilder.CreateReference(node.LexicalInfo, ExternalEnumeratorSelf()); - if (AstUtil.IsTargetOfMethodInvocation(node)) // super(...) - ReplaceCurrentNode(CodeBuilder.CreateMemberReference(externalSelf, (IMethod)GetEntity(node))); - else // super.Method(...) - ReplaceCurrentNode(externalSelf); - } - - private IMethod RemapMethod(Node node, GenericMappedMethod gmm, GenericParameterDeclarationCollection genParams) - { - var sourceMethod = gmm.SourceMember; - if (sourceMethod.GenericInfo != null) - throw new CompilerError(node, "Mapping generic methods in generators is not implemented yet"); - - var baseType = sourceMethod.DeclaringType; - var genericInfo = baseType.GenericInfo; - if (genericInfo == null) - throw new CompilerError(node, "Mapping generic nested types in generators is not implemented yet"); - - var genericArgs = ((IGenericArgumentsProvider)gmm.DeclaringType).GenericArguments; - var mapList = new List(); - foreach (var arg in genericArgs) - { - var mappedArg = genParams.SingleOrDefault(gp => gp.Name == arg.Name); - if (mappedArg != null) - mapList.Add((IType)mappedArg.Entity); - else mapList.Add(arg); - } - var newType = (IConstructedTypeInfo)new GenericConstructedType(baseType, mapList.ToArray()); - return (IMethod)newType.Map(sourceMethod); - } - - public override void OnMemberReferenceExpression(MemberReferenceExpression node) + private IMethod CreateDisposeMethod() { - base.OnMemberReferenceExpression(node); - var gmm = node.Entity as GenericMappedMethod; - if (gmm != null) + BooMethodBuilder mn = _stateMachineClass.AddVirtualMethod("Dispose", TypeSystemServices.VoidType); + mn.Method.LexicalInfo = this.LexicalInfo; + + var noEnsure = CodeBuilder.CreateLabel(_method.Method, "noEnsure").LabelStatement; + mn.Body.Add(noEnsure); + mn.Body.Add(SetStateTo(_finishedStateNumber)); + mn.Body.Add(new ReturnStatement()); + + // Create a section calling all ensure methods for each converted try block + var disposeLabels = new LabelStatement[_labels.Count]; + foreach (var t in _convertedTryStatements) { - var genParams = _enumerator.ClassDefinition.GenericParameters; - if (genParams.IsEmpty) - return; - node.Entity = RemapMethod(node, gmm, genParams); + var info = t; + disposeLabels[info._stateNumber] = CodeBuilder.CreateLabel(_method.Method, "$ensure_" + info._stateNumber).LabelStatement; + mn.Body.Add(disposeLabels[info._stateNumber]); + mn.Body.Add(SetStateTo(_finishedStateNumber)); + var block = mn.Body; + while (info._parent != null) + { + TryStatement ts = new TryStatement(); + block.Add(ts); + ts.ProtectedBlock.Add(CallMethodOnSelf(info._ensureMethod)); + block = ts.EnsureBlock = new Block(); + info = info._parent; + } + block.Add(CallMethodOnSelf(info._ensureMethod)); + mn.Body.Add(new ReturnStatement()); + } + + // now map the labels of the suspended states to the labels we just created + for (var i = 0; i < _labels.Count; i++) + { + if (_tryStatementInfoForLabels[i] == null) + disposeLabels[i] = noEnsure; + else + disposeLabels[i] = disposeLabels[_tryStatementInfoForLabels[i]._stateNumber]; } + + mn.Body.Insert(0, CodeBuilder.CreateSwitch( + this.LexicalInfo, + CodeBuilder.CreateMemberReference(_state), + disposeLabels)); + return mn.Entity; } - public override void OnDeclaration(Declaration node) + protected override IEnumerable GetStateMachineGenericParams() { - base.OnDeclaration(node); - if (_entityMapper.ContainsKey(node.Entity)) - node.Entity = _entityMapper[node.Entity]; + return _enumerable.ClassDefinition.GenericParameters; } - public override void OnMethodInvocationExpression(MethodInvocationExpression node) - { - var superInvocation = IsInvocationOnSuperMethod(node); - base.OnMethodInvocationExpression(node); - if (!superInvocation) - return; - - var accessor = CreateAccessorForSuperMethod(node.Target); - Bind(node.Target, accessor); - } - - private IEntity CreateAccessorForSuperMethod(Expression target) - { - var superMethod = (IMethod)GetEntity(target); - var accessor = CodeBuilder.CreateMethodFromPrototype(target.LexicalInfo, superMethod, TypeMemberModifiers.Internal, UniqueName(superMethod.Name)); - var accessorEntity = (IMethod)GetEntity(accessor); - var superMethodInvocation = CodeBuilder.CreateSuperMethodInvocation(superMethod); - foreach (var p in accessorEntity.GetParameters()) - superMethodInvocation.Arguments.Add(CodeBuilder.CreateReference(p)); - accessor.Body.Add(new ReturnStatement(superMethodInvocation)); - - DeclaringTypeDefinition.Members.Add(accessor); - return GetEntity(accessor); - } - - private string UniqueName(string name) - { - return Context.GetUniqueName(name); - } - - protected TypeDefinition DeclaringTypeDefinition - { - get { return _generator.Method.DeclaringType; } - } - - private static bool IsInvocationOnSuperMethod(MethodInvocationExpression node) - { - if (node.Target is SuperLiteralExpression) - return true; - - var target = node.Target as MemberReferenceExpression; - return target != null && target.Target is SuperLiteralExpression; - } - - private Field ExternalEnumeratorSelf() - { - if (null == _externalEnumeratorSelf) - { - _externalEnumeratorSelf = DeclareFieldInitializedFromConstructorParameter( - _enumerator, - _enumeratorConstructor, - "self_", - _generator.DeclaringType, - _methodToEnumeratorMapper); - } - - return _externalEnumeratorSelf; - } - - private sealed class TryStatementInfo - { - internal TryStatement _statement; - internal TryStatementInfo _parent; - - internal bool _containsYield; - internal int _stateNumber = -1; - internal Block _replacement; - - internal IMethod _ensureMethod; - } - - private readonly System.Collections.Generic.List _convertedTryStatements - = new System.Collections.Generic.List(); - private readonly Stack _tryStatementStack = new Stack(); - private int _finishedStateNumber; - - public override bool EnterTryStatement(TryStatement node) - { - var info = new TryStatementInfo(); - info._statement = node; - if (_tryStatementStack.Count > 0) - info._parent = _tryStatementStack.Peek(); - _tryStatementStack.Push(info); - return true; - } - - private BinaryExpression SetStateTo(int num) - { - return CodeBuilder.CreateAssignment(CodeBuilder.CreateMemberReference(_state), - CodeBuilder.CreateIntegerLiteral(num)); - } - - public override void LeaveTryStatement(TryStatement node) - { - TryStatementInfo info = _tryStatementStack.Pop(); - if (info._containsYield) { - ReplaceCurrentNode(info._replacement); - TryStatementInfo currentTry = (_tryStatementStack.Count > 0) ? _tryStatementStack.Peek() : null; - info._replacement.Add(node.ProtectedBlock); - if (currentTry != null) { - ConvertTryStatement(currentTry); - info._replacement.Add(SetStateTo(currentTry._stateNumber)); - } else { - // leave try block, reset state to prevent ensure block from being called again - info._replacement.Add(SetStateTo(_finishedStateNumber)); - } - BooMethodBuilder ensureMethod = _enumerator.AddMethod("$ensure" + info._stateNumber, TypeSystemServices.VoidType, TypeMemberModifiers.Private); - ensureMethod.Body.Add(info._statement.EnsureBlock); - info._ensureMethod = ensureMethod.Entity; - info._replacement.Add(CallMethodOnSelf(ensureMethod.Entity)); - _convertedTryStatements.Add(info); - } - } - - private void ConvertTryStatement(TryStatementInfo currentTry) - { - if (currentTry._containsYield) - return; - currentTry._containsYield = true; - currentTry._stateNumber = _labels.Count; - var tryReplacement = new Block(); - //tryReplacement.Add(CreateLabel(tryReplacement)); - // when the MoveNext() is called while the enumerator is still in running state, don't jump to the - // try block, but handle it like MoveNext() calls when the enumerator is in the finished state. - _labels.Add(_labels[_finishedStateNumber]); - _tryStatementInfoForLabels.Add(currentTry); - tryReplacement.Add(SetStateTo(currentTry._stateNumber)); - currentTry._replacement = tryReplacement; - } - - public override void LeaveYieldStatement(YieldStatement node) - { - TryStatementInfo currentTry = _tryStatementStack.Count > 0 ? _tryStatementStack.Peek() : null; - if (currentTry != null) { - ConvertTryStatement(currentTry); - } - var block = new Block(); - block.Add( - new ReturnStatement( - node.LexicalInfo, - CreateYieldInvocation(node.LexicalInfo, _labels.Count, node.Expression), - null)); - block.Add(CreateLabel(node)); - // setting the state back to the "running" state not required, as that state has the same ensure blocks - // as the state we are currently in. -// if (currentTry != null) { -// block.Add(SetStateTo(currentTry._stateNumber)); -// } - ReplaceCurrentNode(block); - } - - private MethodInvocationExpression CreateYieldInvocation(LexicalInfo sourceLocation, int newState, Expression value) - { - MethodInvocationExpression invocation = CodeBuilder.CreateMethodInvocation( - CodeBuilder.CreateSelfReference(_enumerator.Entity), - value != null ? _yield : _yieldDefault, - CodeBuilder.CreateIntegerLiteral(newState)); - if (value != null) invocation.Arguments.Add(value); - invocation.LexicalInfo = sourceLocation; - return invocation; - } - - private LabelStatement CreateLabel(Node sourceNode) - { - InternalLabel label = CodeBuilder.CreateLabel(sourceNode, "$state$" + _labels.Count); - _labels.Add(label.LabelStatement); - _tryStatementInfoForLabels.Add(_tryStatementStack.Count > 0 ? _tryStatementStack.Peek() : null); - return label.LabelStatement; - } - - private BooMethodBuilder CreateConstructor(BooClassBuilder builder) - { - BooMethodBuilder constructor = builder.AddConstructor(); - constructor.Body.Add(CodeBuilder.CreateSuperConstructorInvocation(builder.Entity.BaseType)); - return constructor; - } - } + } } \ No newline at end of file diff --git a/src/Boo.Lang.Compiler/Steps/Generators/GeneratorTypeReplacer.cs b/src/Boo.Lang.Compiler/Steps/Generators/GeneratorTypeReplacer.cs index 2466bcf81..7a932dbdd 100644 --- a/src/Boo.Lang.Compiler/Steps/Generators/GeneratorTypeReplacer.cs +++ b/src/Boo.Lang.Compiler/Steps/Generators/GeneratorTypeReplacer.cs @@ -41,11 +41,43 @@ private IType ConstructType(IType sourceType) var match = TypeMap.Keys.FirstOrDefault(t => t.Name.Equals(param.Name)); if (match == null) break; - typeMap.Add(match); + typeMap.Add(TypeMap[match]); } if (typeMap.Count > 0) return sourceType.GenericInfo.ConstructType(typeMap.ToArray()); return sourceType; } + + public static IType MapTypeInMethodContext(IType type, Ast.Method method) + { + GeneratorTypeReplacer mapper; + return MapTypeInMethodContext(type, method, out mapper); + } + + public bool ContainsType(IType type) + { + return TypeMap.ContainsKey(type); + } + + public bool Any + { + get { return TypeMap.Count > 0; } + } + + public static IType MapTypeInMethodContext(IType type, Ast.Method method, out GeneratorTypeReplacer mapper) + { + if (type.GenericInfo != null && type.ConstructedInfo == null) + { + var td = method.GetAncestor(); + var allGenParams = td.GenericParameters.Concat(method.GenericParameters) + .Select(gp => (IGenericParameter) gp.Entity).ToArray(); + mapper = new GeneratorTypeReplacer(); + foreach (var genParam in type.GenericInfo.GenericParameters) + mapper.Replace(genParam, allGenParams.First(gp => gp.Name.Equals(genParam.Name))); + return mapper.MapType(type); + } + mapper = null; + return type; + } } } diff --git a/src/Boo.Lang.Compiler/Steps/GenericTypeFinder.cs b/src/Boo.Lang.Compiler/Steps/GenericTypeFinder.cs new file mode 100644 index 000000000..84de2a390 --- /dev/null +++ b/src/Boo.Lang.Compiler/Steps/GenericTypeFinder.cs @@ -0,0 +1,41 @@ +using Boo.Lang.Compiler.Ast; +using Boo.Lang.Compiler.TypeSystem; +using Boo.Lang.Compiler.TypeSystem.Internal; + +namespace Boo.Lang.Compiler.Steps +{ + public class GenericTypeFinder : TypeFinder + { + private bool _localOnly; + + public GenericTypeFinder() : base(new TypeCollector(type => type is IGenericParameter)) + { + } + + public GenericTypeFinder(bool localOnly) : this() + { + _localOnly = localOnly; + } + + public override void OnReferenceExpression(ReferenceExpression node) + { + if (!_localOnly || IsLocal(node.Entity, node)) + base.OnReferenceExpression(node); + } + + private bool IsLocal(IEntity entity, Node node) + { + if (entity.EntityType == EntityType.Local) + { + var local = (InternalLocal) entity; + return local.Local.GetAncestor() == node.GetAncestor(); + } + if (entity.EntityType == EntityType.Parameter) + { + var param = (InternalParameter) entity; + return param.Node.ParentNode == node.GetAncestor(); + } + return false; + } + } +} diff --git a/src/Boo.Lang.Compiler/Steps/GenericTypeMapper.cs b/src/Boo.Lang.Compiler/Steps/GenericTypeMapper.cs new file mode 100644 index 000000000..02fc680f8 --- /dev/null +++ b/src/Boo.Lang.Compiler/Steps/GenericTypeMapper.cs @@ -0,0 +1,165 @@ +using System.Linq; +using Boo.Lang.Compiler.Ast; +using Boo.Lang.Compiler.Steps.Generators; +using Boo.Lang.Compiler.TypeSystem; +using Boo.Lang.Compiler.TypeSystem.Generics; +using Boo.Lang.Compiler.TypeSystem.Internal; + +namespace Boo.Lang.Compiler.Steps +{ + class GenericTypeMapper : AbstractFastVisitorCompilerStep + { + private readonly GeneratorTypeReplacer _replacer; + + public GenericTypeMapper(GeneratorTypeReplacer replacer) + { + _replacer = replacer; + } + + private void OnTypeReference(TypeReference node) + { + var type = (IType)node.Entity; + node.Entity = _replacer.MapType(type); + } + + public override void OnReferenceExpression(ReferenceExpression node) + { + var local = node.Entity as InternalLocal; + if (local != null) + { + var type = local.Type; + var mappedType = _replacer.MapType(type); + if (mappedType != type) + { + node.Entity = UpdateLocal(local.Local, mappedType); + } + } + var te = node.Entity as ITypedEntity; + if (te != null) + { + if (node.Entity is IGenericMappedMember) + ReplaceMappedEntity(node, te.Type); + else if (node.Entity.EntityType == EntityType.Type) + { + var type = (IType)node.Entity; + node.Entity = _replacer.MapType(type); + } + node.ExpressionType = ((ITypedEntity)node.Entity).Type; + } + } + + private static IEntity UpdateLocal(Local local, IType type) + { + if (type != ((ITypedEntity) local.Entity).Type) + { + local.Entity = new InternalLocal(local, type); + } + return local.Entity; + } + + public override void OnMemberReferenceExpression(MemberReferenceExpression node) + { + base.OnMemberReferenceExpression(node); + var member = node.Entity as IMember; + if (member != null) + { + var type = member.Type; + var mappedType = _replacer.MapType(type); + if (mappedType != type) + { + _replacer.Replace(type, mappedType); + node.ExpressionType = mappedType; + ReplaceMappedEntity(node, mappedType); + } + } + } + + public override void OnGenericReferenceExpression(GenericReferenceExpression node) + { + base.OnGenericReferenceExpression(node); + node.ExpressionType = _replacer.MapType(node.ExpressionType); + } + + public override void OnSelfLiteralExpression(SelfLiteralExpression node) + { + base.OnSelfLiteralExpression(node); + node.ExpressionType = _replacer.MapType(node.ExpressionType); + } + + private void ReplaceMappedEntity(MemberReferenceExpression node, IType mappedType) + { + var entity = (IMember)node.Entity; + var targetType = node.Target.ExpressionType; + var newEntity = (IMember)NameResolutionService.ResolveMember(targetType, entity.Name, entity.EntityType); + node.Entity = newEntity; + if (!newEntity.Type.Equals(mappedType)) + { + var gmi = newEntity as IGenericMethodInfo; + if (gmi != null) + { + var args = ((IConstructedMethodInfo)entity).GenericArguments.Select(_replacer.MapType).ToArray(); + newEntity = gmi.ConstructMethod(args); + if (newEntity.Type.Equals(mappedType)) + { + node.Entity = newEntity; + return; + } + } + throw new System.NotImplementedException("Incorrect mapped type for " + node.ToCodeString()); + } + } + + private void ReplaceMappedEntity(ReferenceExpression node, IType mappedType) + { + var entity = (IMember)node.Entity; + var targetType = _replacer.MapType(entity.DeclaringType); + node.Entity = NameResolutionService.ResolveMember(targetType, entity.Name, entity.EntityType); + } + + public override void OnMethodInvocationExpression(MethodInvocationExpression node) + { + base.OnMethodInvocationExpression(node); + node.ExpressionType = _replacer.MapType(node.ExpressionType); + } + + public override void OnAwaitExpression(AwaitExpression node) + { + base.OnAwaitExpression(node); + node.ExpressionType = _replacer.MapType(node.ExpressionType); + } + + public override void OnField(Field node) + { + base.OnField(node); + } + + public override void OnSimpleTypeReference(SimpleTypeReference node) + { + OnTypeReference(node); + } + + public override void OnArrayTypeReference(ArrayTypeReference node) + { + base.OnArrayTypeReference(node); + OnTypeReference(node); + } + + public override void OnCallableTypeReference(CallableTypeReference node) + { + base.OnCallableTypeReference(node); + OnTypeReference(node); + } + + public override void OnGenericTypeReference(GenericTypeReference node) + { + base.OnGenericTypeReference(node); + OnTypeReference(node); + } + + public override void OnGenericTypeDefinitionReference(GenericTypeDefinitionReference node) + { + base.OnGenericTypeDefinitionReference(node); + OnTypeReference(node); + } + } +} diff --git a/src/Boo.Lang.Compiler/Steps/ImplementICallableOnCallableDefinitions.cs b/src/Boo.Lang.Compiler/Steps/ImplementICallableOnCallableDefinitions.cs index 038577d8e..d64479d69 100644 --- a/src/Boo.Lang.Compiler/Steps/ImplementICallableOnCallableDefinitions.cs +++ b/src/Boo.Lang.Compiler/Steps/ImplementICallableOnCallableDefinitions.cs @@ -26,11 +26,13 @@ // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #endregion +using Boo.Lang.Compiler.Steps.Generators; using Boo.Lang.Compiler.TypeSystem.Internal; namespace Boo.Lang.Compiler.Steps { using System.Diagnostics; + using System.Linq; using Boo.Lang.Compiler.Ast; using Boo.Lang.Compiler.TypeSystem; @@ -207,8 +209,31 @@ void ImplementRegularICallableCall( { call.Body.Add(new ReturnStatement(mie)); } + CheckMethodGenerics(call, node); } - + + private static void CheckMethodGenerics(Method call, ClassDefinition node) + { + if (node.GenericParameters.Count > 0) + { + var tf = new GenericTypeFinder(); + call.Body.Accept(tf); + var results = tf.Results.ToArray(); + if (results.Length > 0) + { + var replacer = new GeneratorTypeReplacer(); + foreach (var gtype in results) + { + var replacement = node.GenericParameters.FirstOrDefault(gp => gp.Name == gtype.Name); + if (replacement != null) + replacer.Replace(gtype, (IType) replacement.Entity); + } + if (replacer.Any) + call.Body.Accept(new GenericTypeMapper(replacer)); + } + } + } + MethodInvocationExpression CreateInvokeInvocation(InternalCallableType type) { return CodeBuilder.CreateMethodInvocation( diff --git a/src/Boo.Lang.Compiler/Steps/InjectCallableConversions.cs b/src/Boo.Lang.Compiler/Steps/InjectCallableConversions.cs index 26dcbb7a4..36de08ca9 100644 --- a/src/Boo.Lang.Compiler/Steps/InjectCallableConversions.cs +++ b/src/Boo.Lang.Compiler/Steps/InjectCallableConversions.cs @@ -52,6 +52,28 @@ override public void Run() return; Visit(CompileUnit); + + CheckFieldInvocations(); + } + + private void CheckFieldInvocations() + { + var invocations = ContextAnnotations.GetFieldInvocations(); + if (invocations == null) return; + + foreach (var node in invocations) + { + var et = node.Target.ExpressionType; + if (et is AnonymousCallableType) + { + et = ((AnonymousCallableType) et).ConcreteType; + node.Target.ExpressionType = et; + } + var invoke = NameResolutionService.Resolve(et, "Invoke") as IMethod; + if (invoke == null) + throw new System.NotSupportedException("Invoke method on callable field not found"); + node.Target = CodeBuilder.CreateMemberReference(node.Target.LexicalInfo, node.Target, invoke); + } } override public void LeaveExpressionStatement(ExpressionStatement node) @@ -290,14 +312,39 @@ private Expression ConvertMethodReference(IType expectedType, Expression argumen if (expectedCallable != null) { var argumentType = (ICallableType) GetExpressionType(argument); - if (argumentType.GetSignature() != expectedCallable.GetSignature()) - return Adapt(expectedCallable, CreateDelegate(GetConcreteExpressionType(argument), argument)); + var expectedSig = expectedCallable.GetSignature(); + var argSig = argumentType.GetSignature(); + if (argSig != expectedSig) + { + if (TypeSystemServices.CompatibleSignatures(argSig, expectedSig) || + (TypeSystemServices.CompatibleGenericSignatures(argSig, expectedSig) /*&& IsUnspecializedGenericMethodReference(argument)*/) + ) + { + argument.ExpressionType = expectedType; + return CreateDelegate(expectedType, argument); + } + return Adapt(expectedCallable, CreateDelegate(GetConcreteExpressionType(argument), argument)); + } return CreateDelegate(expectedType, argument); } return CreateDelegate(GetConcreteExpressionType(argument), argument); } - Expression Adapt(ICallableType expected, Expression callable) + private static bool IsUnspecializedGenericMethodReference(Expression argument) + { + if (argument.NodeType != NodeType.MemberReferenceExpression) + return false; + var target = ((MemberReferenceExpression) argument).Target; + if (target.NodeType != NodeType.MethodInvocationExpression) + return false; + target = ((MethodInvocationExpression)target).Target; + if (target.Entity.EntityType != EntityType.Constructor) + return false; + var cls = ((IConstructor) target.Entity).DeclaringType; + return cls.GenericInfo != null && (cls.ConstructedInfo == null || !cls.ConstructedInfo.FullyConstructed); + } + + Expression Adapt(ICallableType expected, Expression callable) { ICallableType actual = GetExpressionType(callable) as ICallableType; if (null == actual) @@ -436,7 +483,9 @@ Expression CreateDelegate(IType type, Expression source) ? CodeBuilder.CreateNullLiteral() : ((MemberReferenceExpression)source).Target; - return CodeBuilder.CreateConstructorInvocation(GetConcreteType(type).GetConstructors().First(), + var cType = GetConcreteType(type) ?? + TypeSystemServices.GetConcreteCallableType(source, (AnonymousCallableType) type); + return CodeBuilder.CreateConstructorInvocation(cType.GetConstructors().First(), target, CodeBuilder.CreateAddressOfExpression(method)); } diff --git a/src/Boo.Lang.Compiler/Steps/NormalizeExpressions.cs b/src/Boo.Lang.Compiler/Steps/NormalizeExpressions.cs index f918227fc..08ad972c4 100644 --- a/src/Boo.Lang.Compiler/Steps/NormalizeExpressions.cs +++ b/src/Boo.Lang.Compiler/Steps/NormalizeExpressions.cs @@ -67,11 +67,11 @@ public override void OnCollectionInitializationExpression(CollectionInitializati if (node.Initializer is ListLiteralExpression) foreach (var item in ((ListLiteralExpression)node.Initializer).Items) // temp.Add(item) - initialization.Arguments.Add(NewAddInvocation(item.LexicalInfo, temp, item)); + initialization.Arguments.Add(NewAddInvocation(item.LexicalInfo, temp.CloneNode(), item)); else foreach (var pair in ((HashLiteralExpression)node.Initializer).Items) // temp.Add(key, value) - initialization.Arguments.Add(NewAddInvocation(pair.LexicalInfo, temp, pair.First, pair.Second)); + initialization.Arguments.Add(NewAddInvocation(pair.LexicalInfo, temp.CloneNode(), pair.First, pair.Second)); // return temp initialization.Arguments.Add(temp.CloneNode()); diff --git a/src/Boo.Lang.Compiler/Steps/ProcessClosures.cs b/src/Boo.Lang.Compiler/Steps/ProcessClosures.cs index e16d47a20..81d63fa24 100755 --- a/src/Boo.Lang.Compiler/Steps/ProcessClosures.cs +++ b/src/Boo.Lang.Compiler/Steps/ProcessClosures.cs @@ -26,23 +26,33 @@ // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #endregion +using System.Diagnostics; using System.Linq; +using Boo.Lang.Compiler.Ast; +using Boo.Lang.Compiler.Steps.Generators; +using Boo.Lang.Compiler.TypeSystem; using Boo.Lang.Compiler.TypeSystem.Builders; using Boo.Lang.Compiler.TypeSystem.Internal; namespace Boo.Lang.Compiler.Steps { - using Boo.Lang.Compiler.Ast; - using Boo.Lang.Compiler.TypeSystem; public class ProcessClosures : AbstractTransformerCompilerStep { - override public void Run() + public override void Run() { Visit(CompileUnit); } - override public void LeaveBlockExpression(BlockExpression node) + public override void OnAsyncBlockExpression(AsyncBlockExpression node) + { + var result = Visit(node.Block); + ReplaceCurrentNode(result); + } + + private GeneratorTypeReplacer _mapper; + + public override void LeaveBlockExpression(BlockExpression node) { var closureEntity = GetEntity(node) as InternalMethod; if (closureEntity == null) @@ -51,36 +61,60 @@ override public void LeaveBlockExpression(BlockExpression node) var collector = new ForeignReferenceCollector(); { collector.CurrentMethod = closureEntity.Method; - collector.CurrentType = (IType)closureEntity.DeclaringType; + collector.CurrentType = closureEntity.DeclaringType; closureEntity.Method.Body.Accept(collector); if (collector.ContainsForeignLocalReferences) { BooClassBuilder closureClass = CreateClosureClass(collector, closureEntity); + if (closureEntity is InternalGenericMethod) + closureEntity = GetEntity(closureEntity.Method) as InternalMethod; closureClass.ClassDefinition.LexicalInfo = node.LexicalInfo; collector.AdjustReferences(); - - ReplaceCurrentNode( + + if (_mapper != null) + { + closureClass.ClassDefinition.Accept(new GenericTypeMapper(_mapper)); + } + + ReplaceCurrentNode( CodeBuilder.CreateMemberReference( collector.CreateConstructorInvocationWithReferencedEntities( - closureClass.Entity), + closureClass.Entity, + node.GetAncestor()), closureEntity)); } else { - Expression expression = CodeBuilder.CreateMemberReference(closureEntity); + _mapper = closureEntity.Method["GenericMapper"] as GeneratorTypeReplacer; + if (_mapper != null) + closureEntity.Method.Accept(new GenericTypeMapper(_mapper)); + IMethod entity = closureEntity; + if (entity.GenericInfo != null) + { + entity = MapGenericMethod(entity, node.GetAncestor().GenericParameters); + } + Expression expression = CodeBuilder.CreateMemberReference(entity); expression.LexicalInfo = node.LexicalInfo; TypeSystemServices.GetConcreteExpressionType(expression); ReplaceCurrentNode(expression); } } } - + + private static IMethod MapGenericMethod(IMethod method, GenericParameterDeclarationCollection genericArgs) + { + var args = method.GenericInfo.GenericParameters + .Select(gp => (IType)genericArgs.First(ga => ga.Name == gp.Name).Entity).ToArray(); + return method.GenericInfo.ConstructMethod(args); + } + BooClassBuilder CreateClosureClass(ForeignReferenceCollector collector, InternalMethod closure) { Method method = closure.Method; TypeDefinition parent = method.DeclaringType; parent.Members.Remove(method); + _mapper = method["GenericMapper"] as GeneratorTypeReplacer; BooClassBuilder builder = collector.CreateSkeletonClass(closure.Name, method.LexicalInfo); parent.Members.Add(builder.ClassDefinition); @@ -99,7 +133,43 @@ BooClassBuilder CreateClosureClass(ForeignReferenceCollector collector, Internal method.Modifiers = TypeMemberModifiers.Public; var coll = new GenericTypeCollector(CodeBuilder); coll.Process(builder.ClassDefinition); + if (!method.GenericParameters.IsEmpty) + { + MapMethodGenerics(builder, method); + } + if (builder.ClassDefinition.GenericParameters.Count > 0) + MapGenerics(builder.ClassDefinition); return builder; } + + private void MapMethodGenerics(BooClassBuilder builder, Method method) + { + Debug.Assert(_mapper != null); + var classParams = builder.ClassDefinition.GenericParameters; + foreach (var genParam in method.GenericParameters) + { + var replacement = classParams.FirstOrDefault(p => p.Name.Equals(genParam.Name)); + if (replacement != null && genParam != replacement.Entity) + _mapper.Replace((IType) genParam.Entity, (IType) replacement.Entity); + } + method.GenericParameters.Clear(); + method.Entity = new InternalMethod(Environments.My.Instance, method); + } + + private void MapGenerics(ClassDefinition cd) + { + var finder = new GenericTypeFinder(); + foreach (var member in cd.Members) + member.Accept(finder); + + _mapper = _mapper ?? new GeneratorTypeReplacer(); + var genParams = cd.GenericParameters; + foreach (var genType in finder.Results) + { + var replacement = genParams.FirstOrDefault(p => p.Name.Equals(genType.Name)); + if (replacement != null && genType != replacement.Entity) + _mapper.Replace(genType, (IType)replacement.Entity); + } + } } } diff --git a/src/Boo.Lang.Compiler/Steps/ProcessGenerators.cs b/src/Boo.Lang.Compiler/Steps/ProcessGenerators.cs index 169697b4b..c73e6d6b1 100644 --- a/src/Boo.Lang.Compiler/Steps/ProcessGenerators.cs +++ b/src/Boo.Lang.Compiler/Steps/ProcessGenerators.cs @@ -39,8 +39,6 @@ public class ProcessGenerators : AbstractTransformerCompilerStep { public static readonly System.Reflection.ConstructorInfo List_IEnumerableConstructor = Methods.ConstructorOf(() => new List(default(IEnumerable))); - private Method _current; - public override void Run() { if (Errors.Count > 0) return; @@ -62,18 +60,6 @@ public override void OnField(Field node) // ignore } - public override void OnConstructor(Constructor method) - { - _current = method; - Visit(_current.Body); - } - - public override bool EnterMethod(Method method) - { - _current = method; - return true; - } - public override void LeaveMethod(Method method) { var entity = (InternalMethod)method.Entity; diff --git a/src/Boo.Lang.Compiler/Steps/ProcessGeneratorsAndAsyncMethods.cs b/src/Boo.Lang.Compiler/Steps/ProcessGeneratorsAndAsyncMethods.cs new file mode 100644 index 000000000..c57e69c8b --- /dev/null +++ b/src/Boo.Lang.Compiler/Steps/ProcessGeneratorsAndAsyncMethods.cs @@ -0,0 +1,20 @@ +using Boo.Lang.Compiler.Ast; +using Boo.Lang.Compiler.Steps.AsyncAwait; +using Boo.Lang.Compiler.TypeSystem.Internal; + +namespace Boo.Lang.Compiler.Steps +{ + public class ProcessGeneratorsAndAsyncMethods : ProcessGenerators + { + public override void LeaveMethod(Method method) + { + if (ContextAnnotations.IsAsync(method)) + { + var entity = (InternalMethod)method.Entity; + var processor = new AsyncMethodProcessor(Context, entity); + processor.Run(); + } + else base.LeaveMethod(method); + } + } +} diff --git a/src/Boo.Lang.Compiler/Steps/ProcessMethodBodies.cs b/src/Boo.Lang.Compiler/Steps/ProcessMethodBodies.cs index b90dc5175..5b35e390d 100755 --- a/src/Boo.Lang.Compiler/Steps/ProcessMethodBodies.cs +++ b/src/Boo.Lang.Compiler/Steps/ProcessMethodBodies.cs @@ -74,6 +74,9 @@ public class ProcessMethodBodies : AbstractNamespaceSensitiveVisitorCompilerStep const string TempInitializerName = "$___temp_initializer"; + private bool _inExceptionHandler; + private bool _seenAwaitInExceptionHandler; + public override void Initialize(CompilerContext context) { base.Initialize(context); @@ -721,12 +724,37 @@ bool CheckDeclarationType(TypeReference type) override public void OnBlockExpression(BlockExpression node) { if (WasVisited(node)) return; - if (ShouldDeferClosureProcessing(node)) return; + if (ShouldDeferClosureProcessing(node)) + { + node.Annotate("$Deferred$"); + return; + } InferClosureSignature(node); ProcessClosureBody(node); } + public override void OnAwaitExpression(AwaitExpression node) + { + Visit(node.BaseExpression); + node.ExpressionType = AsyncHelper.GetAwaitType(node.BaseExpression); + if (node.ExpressionType == null) + Context.Errors.Add(CompilerErrorFactory.MissingGetAwaiter(node.BaseExpression)); + else + { + node["$GetAwaiter"] = node.BaseExpression["$GetAwaiter"]; + node["$GetResult"] = node.BaseExpression["$GetResult"]; + } + _seenAwaitInExceptionHandler |= _inExceptionHandler; + } + + public override void OnAsyncBlockExpression(AsyncBlockExpression node) + { + Visit(node.Block); + node.Entity = node.Block.Entity; + node.ExpressionType = node.Block.ExpressionType; + } + private void InferClosureSignature(BlockExpression node) { ClosureSignatureInferrer inferrer = new ClosureSignatureInferrer(node); @@ -810,8 +838,16 @@ void ProcessClosureBody(BlockExpression node) if (explicitClosureName != null) ns.DelegateTo(new AliasedNamespace(explicitClosureName, closureEntity)); + if (ContextAnnotations.IsAsync(node)) + ContextAnnotations.MarkAsync(closure); + ProcessMethodBody(closureEntity, ns); + if (!_currentMethod.Method.GenericParameters.IsEmpty) + { + CheckForGenericClosure(closure, ref closureEntity); + } + if (closureEntity.ReturnType is Unknown) TryToResolveReturnType(closureEntity); @@ -819,6 +855,36 @@ void ProcessClosureBody(BlockExpression node) node.Entity = closureEntity; } + private void CheckForGenericClosure(Method closure, ref InternalMethod closureEntity) + { + var finder = new GenericTypeFinder(true); + closure.Accept(finder); + var genParams = + finder.Results.OfType().Where(gp => gp.DeclaringEntity == _currentMethod).ToArray(); + if (genParams.Length > 0) + { + var mapper = new GeneratorTypeReplacer(); + foreach (var param in genParams) + { + var clone = ((GenericParameterDeclaration) param.Node).CleanClone(); + closure.GenericParameters.Add(clone); + clone.Entity = new InternalGenericParameter(TypeSystemServices, clone); + mapper.Replace(param, (IGenericParameter) clone.Entity); + } + var newClosureEntity = new InternalGenericMethod(My.Instance, closure); + var rets = closureEntity.ReturnExpressions; + if (rets != null) + foreach (var ret in rets) + newClosureEntity.AddReturnExpression(ret); + if (closureEntity.IsGenerator) + foreach (var yld in closureEntity.YieldExpressions) + newClosureEntity.AddYieldStatement((YieldStatement) yld.ParentNode); + closure.Entity = newClosureEntity; + closureEntity = newClosureEntity; + closure["GenericMapper"] = mapper; + } + } + protected Method CurrentMethod { get { return _currentMethod.Method; } @@ -1278,9 +1344,24 @@ void ProcessMethodBody(InternalMethod entity) ProcessMethodBody(entity, entity); } - void ProcessMethodBody(InternalMethod entity, INamespace ns) + private void ProcessMethodBody(InternalMethod entity, INamespace ns) { - ProcessNodeInMethodContext(entity, ns, entity.Method.Body); + var ieh = _inExceptionHandler; + var seenAwaitInExceptionHandler = _seenAwaitInExceptionHandler; + _inExceptionHandler = false; + _seenAwaitInExceptionHandler = false; + + try + { + ProcessNodeInMethodContext(entity, ns, entity.Method.Body); + if (_seenAwaitInExceptionHandler) + ContextAnnotations.MarkAwaitInExceptionHandler(entity.Method); + } + finally + { + _inExceptionHandler = ieh; + _seenAwaitInExceptionHandler = seenAwaitInExceptionHandler; + } } void ProcessNodeInMethodContext(InternalMethod entity, INamespace ns, Node node) @@ -1376,15 +1457,23 @@ static bool CanResolveReturnType(InternalMethod method) void ResolveReturnType(InternalMethod entity) { var method = entity.Method; - method.ReturnType = entity.ReturnExpressions == null + if (ContextAnnotations.IsAsync(method)) + { + method.ReturnType = entity.ReturnExpressions == null + ? CodeBuilder.CreateTypeReference(TypeSystemServices.TaskType) + : GetMostGenericTypeReference(entity.ReturnExpressions, true); + } + else method.ReturnType = entity.ReturnExpressions == null ? CodeBuilder.CreateTypeReference(TypeSystemServices.VoidType) - : GetMostGenericTypeReference(entity.ReturnExpressions); + : GetMostGenericTypeReference(entity.ReturnExpressions, false); TraceReturnType(method, entity); } - private TypeReference GetMostGenericTypeReference(ExpressionCollection expressions) + private TypeReference GetMostGenericTypeReference(ExpressionCollection expressions, bool isAsync) { var type = MapWildcardType(GetMostGenericType(expressions)); + if (isAsync && type != TypeSystemServices.TaskType) + type = TypeSystemServices.GenericTaskType.GenericInfo.ConstructType(type); return CodeBuilder.CreateTypeReference(type); } @@ -1434,7 +1523,24 @@ override public void OnCharLiteralExpression(CharLiteralExpression node) BindExpressionType(node, TypeSystemServices.CharType); } - private void CheckCharLiteralValue(CharLiteralExpression node) + public override void OnTryStatement(TryStatement node) + { + Visit(node.ProtectedBlock); + var ieh = _inExceptionHandler; + _inExceptionHandler = true; + try + { + Visit(node.ExceptionHandlers); + Visit(node.FailureBlock); + Visit(node.EnsureBlock); + } + finally + { + _inExceptionHandler = ieh; + } + } + + private void CheckCharLiteralValue(CharLiteralExpression node) { var value = node.Value; if (value == null || value.Length != 1) @@ -1772,6 +1878,11 @@ private void EnsureDeclarationType(DeclarationStatement node) var declaration = node.Declaration; if (declaration.Type != null) return; declaration.Type = CodeBuilder.CreateTypeReference(declaration.LexicalInfo, InferDeclarationType(node)); + var typeEntity = (IType)declaration.Type.Entity; + if (typeEntity.GenericInfo != null && typeEntity.ConstructedInfo == null) + { + declaration.Type.Entity = GeneratorTypeReplacer.MapTypeInMethodContext(typeEntity, node.GetAncestor()); + } } private IType InferDeclarationType(DeclarationStatement node) @@ -1867,14 +1978,9 @@ override public void LeaveCastExpression(CastExpression node) override public void LeaveTryCastExpression(TryCastExpression node) { - var target = GetExpressionType(node.Target); var toType = GetType(node.Type); - - if (target.IsValueType) - Error(CompilerErrorFactory.CantCastToValueType(node.Target, target)); - else if (toType.IsValueType) + if (toType.IsValueType) Error(CompilerErrorFactory.CantCastToValueType(node.Type, toType)); - BindExpressionType(node, toType); } @@ -2005,12 +2111,6 @@ override public void OnReferenceExpression(ReferenceExpression node) PostProcessReferenceExpression(node); } - private static IType SelfMapGenericType(IType type) - { - return type.GenericInfo.ConstructType( - Array.ConvertAll(type.GenericInfo.GenericParameters, gp => (IType)gp)); - } - private static bool AlreadyBound(ReferenceExpression node) { return null != node.ExpressionType; @@ -2506,6 +2606,23 @@ override public void LeaveConditionalExpression(ConditionalExpression node) var trueType = GetExpressionType(node.TrueValue); var falseType = GetExpressionType(node.FalseValue); BindExpressionType(node, GetMostGenericType(trueType, falseType)); + + // special-case handling for nullable types + var genBase = node.ExpressionType.ConstructedInfo; + if (genBase != null && + genBase.GenericDefinition == TypeSystemServices.NullableGenericType && + trueType != falseType) + { + var ctor = node.ExpressionType.GetConstructors().First(c => c.GetParameters().Length == 1); + var genType = genBase.GenericArguments[0]; + Expression baseExpr = null; + if (trueType == genType) + baseExpr = node.TrueValue; + else if (falseType == genType) + baseExpr = node.FalseValue; + if (baseExpr != null) + node.Replace(baseExpr, CodeBuilder.CreateConstructorInvocation(ctor, baseExpr)); + } } override public void LeaveYieldStatement(YieldStatement node) @@ -2536,7 +2653,13 @@ override public void LeaveReturnStatement(ReturnStatement node) return; } - IType returnType = _currentMethod.ReturnType; + // Keep async returns from erroring out + if (ContextAnnotations.IsAsync(_currentMethod.Method)) + { + expressionType = GetAsyncReturnExpressionType(expressionType); + } + + IType returnType = _currentMethod.ReturnType; if (TypeSystemServices.IsUnknown(returnType)) _currentMethod.AddReturnExpression(node.Expression); else @@ -2552,6 +2675,27 @@ override public void LeaveReturnStatement(ReturnStatement node) } } + private IType GetAsyncReturnExpressionType(IType expressionType) + { + if (expressionType == TypeSystemServices.VoidType || expressionType == TypeSystemServices.TaskType) + return TypeSystemServices.TaskType; + + var newExpressionType = TypeSystemServices.GenericTaskType.GenericInfo.ConstructType(expressionType); + + //covariance check + var cRet = _currentMethod.ReturnType; + if (cRet != newExpressionType && + !TypeSystemServices.IsUnknown(cRet) && + cRet.ConstructedInfo != null && + cRet.ConstructedInfo.GenericDefinition == TypeSystemServices.GenericTaskType) + { + var cRetArg = cRet.ConstructedInfo.GenericArguments[0]; + if (cRetArg.IsAssignableFrom(expressionType)) + newExpressionType = cRet; + } + return newExpressionType; + } + protected Expression GetCorrectIterator(Expression iterator) { IType type = GetExpressionType(iterator); @@ -3735,6 +3879,18 @@ private IEntity ResolveCallableReference(MethodInvocationExpression node, Ambigu return m.GenericInfo.ConstructMethod(arguments); }).Where(m => m != null).ToArray(); + //check for unprocessed deferred closures + var orphanClosures = node.Arguments + .OfType() + .Where(b => b.ExpressionType == null && b.ContainsAnnotation("$Deferred$")).ToArray(); + foreach (var closure in orphanClosures) + { + InferClosureSignature(closure); + ProcessClosureBody(closure); + } + if (orphanClosures.Length > 0) + return ResolveCallableReference(node, entity); + var resolved = CallableResolutionService.ResolveCallableReference(node.Arguments, methods); if (null == resolved) return null; @@ -4019,6 +4175,9 @@ protected virtual bool ProcessMethodInvocationWithInvalidParameters(MethodInvoca protected virtual void ProcessMethodInvocation(MethodInvocationExpression node, IMethod method) { + if (AstAnnotations.HasAmbiguousSignature(node)) + FixAmbiguousSignatures(node); + if (ResolvedAsExtension(node)) PostNormalizeExtensionInvocation(node, method); var targetMethod = InferGenericMethodInvocation(node, method); @@ -4041,6 +4200,33 @@ protected virtual void ProcessMethodInvocation(MethodInvocationExpression node, ApplyBuiltinMethodTypeInference(node, targetMethod); } + private void FixAmbiguousSignatures(MethodInvocationExpression node) + { + foreach (var be in node.Arguments.OfType()) + { + var expr = be.NodeType == NodeType.AsyncBlockExpression ? ((AsyncBlockExpression) be).Block : be; + if (AstAnnotations.HasAmbiguousSignature(expr)) + { + expr.ExpressionType = null; + InferClosureSignature(expr); + if (be != expr) + be.ExpressionType = expr.ExpressionType; + var associatedMethod = ((InternalMethod)expr.Entity).Method; + var sig = ((ICallableType)expr.ExpressionType).GetSignature(); + if (associatedMethod.ReturnType.Entity != sig.ReturnType) + associatedMethod.ReturnType = CreateTypeReference(associatedMethod.ReturnType.LexicalInfo, sig.ReturnType); + var parameters = sig.Parameters; + for (int i = 0; i < associatedMethod.Parameters.Count; ++i) + { + var param = associatedMethod.Parameters[i]; + if (param.Type.Entity != parameters[i].Type && + (param.Type.Entity == null || !parameters[i].Type.IsAssignableFrom((IType)param.Type.Entity))) + param.Type = CreateTypeReference(param.Type.LexicalInfo, parameters[i].Type); + } + } + } + } + private IMethod InferGenericMethodInvocation(MethodInvocationExpression node, IMethod targetMethod) { if (targetMethod.GenericInfo == null) return targetMethod; @@ -4329,7 +4515,7 @@ void ProcessTypeInvocation(MethodInvocationExpression node) if (type.GenericInfo != null && !(type is IGenericArgumentsProvider)) { - type = SelfMapGenericType(type); + type = TypeSystemServices.SelfMapGenericType(type); } var ctor = GetCorrectConstructor(node, type, node.Arguments); @@ -4922,7 +5108,7 @@ private Expression CreateNullableGetValueOrDefaultExpression(Expression target) void BindTypeTest(BinaryExpression node) { - if (CheckIsNotValueType(node, node.Left) && CheckIsaArgument(node.Right)) + if (CheckIsaArgument(node.Right)) BindExpressionType(node, TypeSystemServices.BoolType); else Error(node); @@ -5797,6 +5983,7 @@ protected virtual bool HasSideEffect(Expression node) { return node.NodeType == NodeType.MethodInvocationExpression || + node.NodeType == NodeType.AwaitExpression || AstUtil.IsAssignment(node) || AstUtil.IsIncDec(node); } diff --git a/src/Boo.Lang.Compiler/Steps/ProcessSharedLocals.cs b/src/Boo.Lang.Compiler/Steps/ProcessSharedLocals.cs index 1e54c1ab8..bc52d7ff8 100755 --- a/src/Boo.Lang.Compiler/Steps/ProcessSharedLocals.cs +++ b/src/Boo.Lang.Compiler/Steps/ProcessSharedLocals.cs @@ -27,194 +27,228 @@ #endregion using System.Linq; +using Boo.Lang.Compiler.Steps.Generators; using Boo.Lang.Compiler.TypeSystem.Builders; +using Boo.Lang.Compiler.TypeSystem.Generics; using Boo.Lang.Compiler.TypeSystem.Internal; +using Boo.Lang.Compiler.Ast; +using Boo.Lang.Compiler.TypeSystem; namespace Boo.Lang.Compiler.Steps { - using System.Collections; - using Boo.Lang; - using Boo.Lang.Compiler.Ast; - using Boo.Lang.Compiler.TypeSystem; - - public class ProcessSharedLocals : AbstractTransformerCompilerStep - { - Method _currentMethod; - - ClassDefinition _sharedLocalsClass; - - Hashtable _mappings = new Hashtable(); - - readonly List _references = new List(); - - readonly List _shared = new List(); - - int _closureDepth; - - override public void Dispose() - { - _shared.Clear(); - _references.Clear(); - _mappings.Clear(); - base.Dispose(); - } - - override public void OnField(Field node) - { - } - - override public void OnInterfaceDefinition(InterfaceDefinition node) - { - } - - override public void OnEnumDefinition(EnumDefinition node) - { - } - - override public void OnConstructor(Constructor node) - { - OnMethod(node); - } - - override public void OnMethod(Method node) - { - _references.Clear(); - _mappings.Clear(); - _currentMethod = node; - _sharedLocalsClass = null; - _closureDepth = 0; - - Visit(node.Body); - - CreateSharedLocalsClass(); - if (null != _sharedLocalsClass) - { - node.DeclaringType.Members.Add(_sharedLocalsClass); - Map(); - } - } - - override public void OnBlockExpression(BlockExpression node) - { - ++_closureDepth; - Visit(node.Body); - --_closureDepth; - } - - override public void OnGeneratorExpression(GeneratorExpression node) - { - ++_closureDepth; - Visit(node.Iterator); - Visit(node.Expression); - Visit(node.Filter); - --_closureDepth; - } - - override public void OnReferenceExpression(ReferenceExpression node) - { - ILocalEntity local = node.Entity as ILocalEntity; - if (null == local) return; - if (local.IsPrivateScope) return; - - _references.Add(node); - - if (_closureDepth == 0) return; - - local.IsShared = _currentMethod.Locals.ContainsEntity(local) - || _currentMethod.Parameters.ContainsEntity(local); - - } - - void Map() - { - IType type = (IType)_sharedLocalsClass.Entity; - InternalLocal locals = CodeBuilder.DeclareLocal(_currentMethod, "$locals", type); - - foreach (ReferenceExpression reference in _references) - { - IField mapped = (IField)_mappings[reference.Entity]; - if (null == mapped) continue; - - reference.ParentNode.Replace( - reference, - CodeBuilder.CreateMemberReference( - CodeBuilder.CreateReference(locals), - mapped)); - } - - Block initializationBlock = new Block(); - initializationBlock.Add(CodeBuilder.CreateAssignment( - CodeBuilder.CreateReference(locals), - CodeBuilder.CreateConstructorInvocation(type.GetConstructors().First()))); - InitializeSharedParameters(initializationBlock, locals); - _currentMethod.Body.Statements.Insert(0, initializationBlock); - - foreach (IEntity entity in _mappings.Keys) - { - _currentMethod.Locals.RemoveByEntity(entity); - } - } - - void InitializeSharedParameters(Block block, InternalLocal locals) - { - foreach (Node node in _currentMethod.Parameters) - { - InternalParameter param = (InternalParameter)node.Entity; - if (param.IsShared) - { - block.Add( - CodeBuilder.CreateAssignment( - CodeBuilder.CreateMemberReference( - CodeBuilder.CreateReference(locals), - (IField)_mappings[param]), - CodeBuilder.CreateReference(param))); - } - } - } - - void CreateSharedLocalsClass() - { - _shared.Clear(); - - CollectSharedLocalEntities(_currentMethod.Locals); - CollectSharedLocalEntities(_currentMethod.Parameters); - - if (_shared.Count > 0) - { - BooClassBuilder builder = CodeBuilder.CreateClass(Context.GetUniqueName(_currentMethod.Name, "locals")); - builder.Modifiers |= TypeMemberModifiers.Internal; - builder.AddBaseType(TypeSystemServices.ObjectType); - - var genericsSet = new System.Collections.Generic.HashSet(); - foreach (ILocalEntity local in _shared) - { - Field field = builder.AddInternalField( - string.Format("${0}", local.Name), - local.Type); - if (local.Type is IGenericParameter && !genericsSet.Contains(local.Type.Name)) - { - builder.AddGenericParameter(local.Type.Name); - genericsSet.Add(local.Type.Name); - } - - _mappings[local] = field.Entity; - } - - builder.AddConstructor().Body.Add( - CodeBuilder.CreateSuperConstructorInvocation(TypeSystemServices.ObjectType)); - - _sharedLocalsClass = builder.ClassDefinition; - } - } - - void CollectSharedLocalEntities(System.Collections.Generic.IEnumerable nodes) where T : Node - { - foreach (T node in nodes) - { - var local = (ILocalEntity)node.Entity; - if (local.IsShared) - _shared.Add(local); - } - } - } + using System.Collections.Generic; + + public class ProcessSharedLocals : AbstractTransformerCompilerStep + { + private Method _currentMethod; + + private ClassDefinition _sharedLocalsClass; + + private readonly Dictionary _mappings = new Dictionary(); + + private readonly List _references = new List(); + + private readonly List _shared = new List(); + + private int _closureDepth; + + public override void Dispose() + { + _shared.Clear(); + _references.Clear(); + _mappings.Clear(); + base.Dispose(); + } + + public override void OnField(Field node) + { + } + + public override void OnInterfaceDefinition(InterfaceDefinition node) + { + } + + public override void OnEnumDefinition(EnumDefinition node) + { + } + + public override void OnConstructor(Constructor node) + { + OnMethod(node); + } + + public override void OnMethod(Method node) + { + _references.Clear(); + _mappings.Clear(); + _currentMethod = node; + _sharedLocalsClass = null; + _closureDepth = 0; + + Visit(node.Body); + + CreateSharedLocalsClass(); + if (null != _sharedLocalsClass) + { + node.DeclaringType.Members.Add(_sharedLocalsClass); + Map(); + } + } + + public override void OnBlockExpression(BlockExpression node) + { + ++_closureDepth; + Visit(node.Body); + --_closureDepth; + } + + public override void OnGeneratorExpression(GeneratorExpression node) + { + ++_closureDepth; + Visit(node.Iterator); + Visit(node.Expression); + Visit(node.Filter); + --_closureDepth; + } + + public override void OnReferenceExpression(ReferenceExpression node) + { + var local = node.Entity as ILocalEntity; + if (null == local) return; + if (local.IsPrivateScope) return; + + _references.Add(node); + + if (_closureDepth == 0) return; + + local.IsShared = _currentMethod.Locals.ContainsEntity(local) + || _currentMethod.Parameters.ContainsEntity(local); + + } + + private void Map() + { + var type = GeneratorTypeReplacer.MapTypeInMethodContext((IType)_sharedLocalsClass.Entity, _currentMethod); + var conType = type as GenericConstructedType; + if (conType != null) + { + foreach (var key in _mappings.Keys.ToArray()) + _mappings[key] = (IField)conType.ConstructedInfo.Map(_mappings[key]); + } + var locals = CodeBuilder.DeclareLocal(_currentMethod, "$locals", type); + + foreach (var reference in _references) + { + IField mapped; + if (!_mappings.TryGetValue(reference.Entity, out mapped)) continue; + + reference.ParentNode.Replace( + reference, + CodeBuilder.CreateMemberReference( + CodeBuilder.CreateReference(locals), + mapped)); + } + + var initializationBlock = new Block(); + initializationBlock.Add(CodeBuilder.CreateAssignment( + CodeBuilder.CreateReference(locals), + CodeBuilder.CreateConstructorInvocation(type.GetConstructors().First()))); + InitializeSharedParameters(initializationBlock, locals); + _currentMethod.Body.Statements.Insert(0, initializationBlock); + + foreach (IEntity entity in _mappings.Keys) + { + _currentMethod.Locals.RemoveByEntity(entity); + } + } + + private void InitializeSharedParameters(Block block, InternalLocal locals) + { + foreach (var node in _currentMethod.Parameters) + { + var param = (InternalParameter)node.Entity; + if (param.IsShared) + { + block.Add( + CodeBuilder.CreateAssignment( + CodeBuilder.CreateMemberReference( + CodeBuilder.CreateReference(locals), + _mappings[param]), + CodeBuilder.CreateReference(param))); + } + } + } + + private void CreateSharedLocalsClass() + { + _shared.Clear(); + + CollectSharedLocalEntities(_currentMethod.Locals); + CollectSharedLocalEntities(_currentMethod.Parameters); + + if (_shared.Count > 0) + { + BooClassBuilder builder = CodeBuilder.CreateClass(Context.GetUniqueName(_currentMethod.Name, "locals")); + builder.Modifiers |= TypeMemberModifiers.Internal; + builder.AddBaseType(TypeSystemServices.ObjectType); + + var genericsSet = new HashSet(); + var replacer = new GeneratorTypeReplacer(); + foreach (ILocalEntity local in _shared) + { + CheckTypeForGenericParams(local.Type, genericsSet, builder, replacer); + Field field = builder.AddInternalField( + string.Format("${0}", local.Name), + replacer.MapType(local.Type)); + + _mappings[local] = (IField)field.Entity; + } + + builder.AddConstructor().Body.Add( + CodeBuilder.CreateSuperConstructorInvocation(TypeSystemServices.ObjectType)); + + _sharedLocalsClass = builder.ClassDefinition; + } + } + + private static void CheckTypeForGenericParams( + IType type, + HashSet genericsSet, + BooClassBuilder builder, + GeneratorTypeReplacer mapper) + { + if (type is IGenericParameter) + { + if (!genericsSet.Contains(type.Name)) + { + builder.AddGenericParameter(type.Name); + genericsSet.Add(type.Name); + } + if (!mapper.ContainsType(type)) + { + mapper.Replace( + type, + (IType)builder.ClassDefinition.GenericParameters + .First(gp => gp.Name.Equals(type.Name)).Entity); + } + } + else + { + var genType = type as IConstructedTypeInfo; + if (genType != null) + foreach (var arg in genType.GenericArguments) + CheckTypeForGenericParams(arg, genericsSet, builder, mapper); + } + } + + private void CollectSharedLocalEntities(IEnumerable nodes) where T : Node + { + foreach (T node in nodes) + { + var local = (ILocalEntity)node.Entity; + if (local.IsShared) + _shared.Add(local); + } + } + } } diff --git a/src/Boo.Lang.Compiler/Steps/StateMachine/MethodToStateMachineTransformer.cs b/src/Boo.Lang.Compiler/Steps/StateMachine/MethodToStateMachineTransformer.cs new file mode 100644 index 000000000..30ac2d6ed --- /dev/null +++ b/src/Boo.Lang.Compiler/Steps/StateMachine/MethodToStateMachineTransformer.cs @@ -0,0 +1,670 @@ +#region license +// Copyright (c) 2003-2017 Rodrigo B. de Oliveira (rbo@acm.org), Mason Wheeler +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Rodrigo B. de Oliveira nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF +// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#endregion +using System.Linq; +using Boo.Lang.Compiler.Ast; +using Boo.Lang.Compiler.Steps.Generators; +using Boo.Lang.Compiler.TypeSystem; +using Boo.Lang.Compiler.TypeSystem.Builders; +using Boo.Lang.Compiler.TypeSystem.Generics; +using Boo.Lang.Compiler.TypeSystem.Internal; + +namespace Boo.Lang.Compiler.Steps.StateMachine +{ + using System.Collections.Generic; + + internal abstract class MethodToStateMachineTransformer : AbstractTransformerCompilerStep + { + + protected readonly InternalMethod _method; + + protected InternalMethod _moveNext; + + protected IField _state; + + protected readonly GeneratorTypeReplacer _methodToStateMachineMapper = new GeneratorTypeReplacer(); + + protected BooClassBuilder _stateMachineClass; + + protected BooMethodBuilder _stateMachineConstructor; + + protected Field _externalSelfField; + + protected readonly List _labels; + + protected readonly List _tryStatementInfoForLabels = new List(); + + private readonly Dictionary _mapping = new Dictionary(); + + private readonly Dictionary _entityMapper = new Dictionary(); + + protected int _finishedStateNumber; + + protected MethodToStateMachineTransformer(CompilerContext context, InternalMethod method) + { + _labels = new List(); + _method = method; + + Initialize(context); + } + + protected LexicalInfo LexicalInfo + { + get { return _method.Method.LexicalInfo; } + } + + protected GenericParameterDeclaration[] _genericParams; + + protected MethodInvocationExpression _stateMachineConstructorInvocation; + + public override void Run() + { + _genericParams = _method.Method.DeclaringType.GenericParameters.Concat(_method.Method.GenericParameters).ToArray(); + CreateStateMachine(); + PrepareConstructorCalls(); + PropagateReferences(); + } + + protected virtual IEnumerable GetStateMachineGenericParams() + { + return _genericParams; + } + + protected virtual void PrepareConstructorCalls() + { + _stateMachineConstructorInvocation = CodeBuilder.CreateGenericConstructorInvocation( + (IType)_stateMachineClass.ClassDefinition.Entity, + GetStateMachineGenericParams()); + } + + protected ParameterDeclaration MapParamType(ParameterDeclaration parameter) + { + if (parameter.Type.NodeType == NodeType.GenericTypeReference) + { + var gen = (GenericTypeReference)parameter.Type; + var genEntityType = gen.Entity as IConstructedTypeInfo; + if (genEntityType == null) + return parameter; + var trc = new TypeReferenceCollection(); + foreach (var genArg in gen.GenericArguments) + { + var replacement = genArg; + foreach (var genParam in _genericParams) + if (genParam.Name.Equals(genArg.Entity.Name)) + { + replacement = new SimpleTypeReference(genParam.Name) {Entity = genParam.Entity}; + break; + } + trc.Add(replacement); + } + parameter = parameter.CloneNode(); + gen = (GenericTypeReference)parameter.Type; + gen.GenericArguments = trc; + gen.Entity = new GenericConstructedType(genEntityType.GenericDefinition, trc.Select(a => a.Entity).Cast().ToArray()); + } + return parameter; + } + + protected abstract void PropagateReferences(); + + private void CreateStateMachineConstructor() + { + _stateMachineConstructor = CreateConstructor(_stateMachineClass); + } + + protected abstract void SetupStateMachine(); + + protected abstract string StateMachineClassName { + get; + } + + protected virtual void CreateStateMachine() + { + _stateMachineClass = CodeBuilder.CreateClass(StateMachineClassName); + _stateMachineClass.AddAttribute(CodeBuilder.CreateAttribute(typeof(System.Runtime.CompilerServices.CompilerGeneratedAttribute))); + _stateMachineClass.LexicalInfo = this.LexicalInfo; + foreach (var param in _genericParams) + { + var replacement = _stateMachineClass.AddGenericParameter(param); + _methodToStateMachineMapper.Replace((IType)param.Entity, (IType)replacement.Entity); + } + + SetupStateMachine(); + CreateStateMachineConstructor(); + + SaveStateMachineClass(_stateMachineClass.ClassDefinition); + CreateMoveNext(); + } + + protected abstract void SaveStateMachineClass(ClassDefinition cd); + + protected abstract void CreateMoveNext(); + + protected void TransformParametersIntoFieldsInitializedByConstructor(Method generator) + { + foreach (ParameterDeclaration parameter in generator.Parameters) + { + var entity = (InternalParameter)parameter.Entity; + if (entity.IsUsed) + { + var field = DeclareFieldInitializedFromConstructorParameter(_stateMachineClass, + _stateMachineConstructor, + entity.Name, + entity.Type, + _methodToStateMachineMapper); + _mapping[entity] = (InternalField)field.Entity; + } + } + } + + protected void TransformLocalsIntoFields(Method stateMachine) + { + foreach (var local in stateMachine.Locals) + { + var entity = (InternalLocal)local.Entity; + if (IsExceptionHandlerVariable(entity)) + { + AddToMoveNextMethod(local); + continue; + } + + AddInternalFieldFor(entity); + } + stateMachine.Locals.Clear(); + } + + private void AddToMoveNextMethod(Local local) + { + var newLocal = new InternalLocal(local, _methodToStateMachineMapper.MapType(((InternalLocal)local.Entity).Type)); + _entityMapper.Add(local.Entity, newLocal); + local.Entity = newLocal; + _moveNext.Method.Locals.Add(local); + } + + private void AddInternalFieldFor(InternalLocal entity) + { + Field field = _stateMachineClass.AddInternalField(UniqueName(entity.Name), _methodToStateMachineMapper.MapType(entity.Type)); + _mapping[entity] = (InternalField)field.Entity; + } + + private bool IsExceptionHandlerVariable(InternalLocal local) + { + Declaration originalDeclaration = local.OriginalDeclaration; + if (originalDeclaration == null) return false; + return originalDeclaration.ParentNode is ExceptionHandler; + } + + protected MethodInvocationExpression CallMethodOnSelf(IMethod method) + { + var entity = _stateMachineClass.Entity; + var genParams = _stateMachineClass.ClassDefinition.GenericParameters; + if (!genParams.IsEmpty) + { + var args = genParams.Select(gpd => gpd.Entity).Cast().ToArray(); + entity = new GenericConstructedType(entity, args); + var mapping = new InternalGenericMapping(entity, args); + method = mapping.Map(method); + } + return CodeBuilder.CreateMethodInvocation( + CodeBuilder.CreateSelfReference(entity), + method); + } + + protected Field DeclareFieldInitializedFromConstructorParameter(BooClassBuilder type, + BooMethodBuilder constructor, + string parameterName, + IType parameterType, + TypeReplacer replacer) + { + var internalFieldType = replacer.MapType(parameterType); + Field field = type.AddInternalField(UniqueName(parameterName), internalFieldType); + InitializeFieldFromConstructorParameter(constructor, field, parameterName, internalFieldType); + return field; + } + + private void InitializeFieldFromConstructorParameter(BooMethodBuilder constructor, + Field field, + string parameterName, + IType parameterType) + { + ParameterDeclaration parameter = constructor.AddParameter(parameterName, parameterType); + constructor.Body.Add( + CodeBuilder.CreateAssignment( + CodeBuilder.CreateReference(field), + CodeBuilder.CreateReference(parameter))); + } + + private void OnTypeReference(TypeReference node) + { + var type = (IType)node.Entity; + node.Entity = _methodToStateMachineMapper.MapType(type); + } + + public override void OnSimpleTypeReference(SimpleTypeReference node) + { + OnTypeReference(node); + } + + public override void OnArrayTypeReference(ArrayTypeReference node) + { + base.OnArrayTypeReference(node); + OnTypeReference(node); + } + + public override void OnCallableTypeReference(CallableTypeReference node) + { + base.OnCallableTypeReference(node); + OnTypeReference(node); + } + + public override void OnGenericTypeReference(GenericTypeReference node) + { + base.OnGenericTypeReference(node); + OnTypeReference(node); + } + + public override void OnGenericTypeDefinitionReference(GenericTypeDefinitionReference node) + { + base.OnGenericTypeDefinitionReference(node); + OnTypeReference(node); + } + + public override void OnReferenceExpression(ReferenceExpression node) + { + InternalField mapped; + if (_mapping.TryGetValue(node.Entity, out mapped)) + { + ReplaceCurrentNode( + CodeBuilder.CreateMemberReference( + node.LexicalInfo, + CodeBuilder.CreateSelfReference(_stateMachineClass.Entity), + mapped)); + } + else if (node.Entity is IGenericMappedMember || node.Entity is IGenericParameter || node.Entity is InternalLocal) + { + node.Accept(new GenericTypeMapper(_methodToStateMachineMapper)); + } + } + + public override void OnSelfLiteralExpression(SelfLiteralExpression node) + { + var newNode = CodeBuilder.CreateMappedReference( + node.LexicalInfo, + ExternalEnumeratorSelf(), + _stateMachineClass.Entity); + ReplaceCurrentNode(newNode); + } + + public override void OnSuperLiteralExpression(SuperLiteralExpression node) + { + var externalSelf = CodeBuilder.CreateReference(node.LexicalInfo, ExternalEnumeratorSelf()); + if (AstUtil.IsTargetOfMethodInvocation(node)) // super(...) + ReplaceCurrentNode(CodeBuilder.CreateMemberReference(externalSelf, (IMethod)GetEntity(node))); + else // super.Method(...) + ReplaceCurrentNode(externalSelf); + } + + private static IMethod RemapMethod(Node node, GenericMappedMethod gmm, IType[] genParams) + { + var sourceMethod = gmm.SourceMember; + if (sourceMethod.GenericInfo != null) + throw new CompilerError(node, "Mapping generic methods in generators is not implemented yet"); + + var baseType = sourceMethod.DeclaringType; + var genericInfo = baseType.GenericInfo; + if (genericInfo == null) + throw new CompilerError(node, "Mapping generic nested types in generators is not implemented yet"); + + var genericArgs = ((IGenericArgumentsProvider)gmm.DeclaringType).GenericArguments; + var collector = new TypeCollector(type => type is IGenericParameter); + foreach (var arg in genericArgs) + collector.Visit(arg); + var mapper = new GeneratorTypeReplacer(); + foreach (var genParam in collector.Matches) + { + var mappedArg = genParams.SingleOrDefault(gp => gp.Name == genParam.Name); + if (mappedArg != null) + mapper.Replace(genParam, mappedArg); + } + var newType = (IConstructedTypeInfo)new GenericConstructedType( + baseType, + genericArgs.Select(mapper.MapType).ToArray()); + return (IMethod)newType.Map(sourceMethod); + } + + public override void OnMemberReferenceExpression(MemberReferenceExpression node) + { + base.OnMemberReferenceExpression(node); + + var genParams = GetGenericParams(node); + if (genParams != null) + { + var gmm = node.Entity as GenericMappedMethod; + if (gmm != null) + { + node.Entity = RemapMethod(node, gmm, genParams); + node.ExpressionType = ((IMethod)node.Entity).CallableType; + return; + } + } + var member = node.Entity as IMember; + if (member != null) + MapMember(node, member); + } + + private void MapMember(MemberReferenceExpression node, IMember member) + { + var baseType = member.DeclaringType; + var mapped = member as IGenericMappedMember; + if (mapped != null) + { + if (baseType == node.Target.ExpressionType) + return; + member = mapped.SourceMember; + } + var didMap = false; + if (node.Target.ExpressionType != null) + { + var newType = node.Target.ExpressionType.ConstructedInfo; + if (newType != null) + { + member = newType.Map(member); + didMap = true; + } + else if (node.Target.ExpressionType.GenericInfo != null) + throw new System.InvalidOperationException("Bad target type"); + } + if (!didMap && member.EntityType == EntityType.Method) + { + var gen = member as IGenericMethodInfo; + if (gen != null) + { + foreach (var gp in gen.GenericParameters) + if (!_methodToStateMachineMapper.ContainsType(gp)) + { + var replacement = this._genericParams.FirstOrDefault(p => p.Name == gp.Name); + if (replacement != null) + _methodToStateMachineMapper.Replace(gp, _methodToStateMachineMapper.MapType((IType)replacement.Entity)); + } + member = gen.ConstructMethod( + gen.GenericParameters.Select(_methodToStateMachineMapper.MapType).ToArray()); + } + else + { + var con = member as IConstructedMethodInfo; + if (con != null) + { + var gd = (IGenericMethodInfo)con.GenericDefinition; + member = gd.ConstructMethod(con.GenericArguments.Select(_methodToStateMachineMapper.MapType).ToArray()); + } + } + } + node.Entity = member; + node.ExpressionType = member.Type; + } + + private void MapMember(GenericReferenceExpression node, IMember member) + { + if (member.EntityType == EntityType.Constructor) + { + //if this is an External constructor, we don't care about mapping it here. + var mappedCtor = member as GenericMappedConstructor; + if (mappedCtor != null) + MapConstructor(node, mappedCtor); + return; + } + var genArgs = node.GenericArguments + .Select(ga => _methodToStateMachineMapper.MapType((IType) ga.Entity)) + .ToArray(); + var mapped = member as IGenericMappedMember; + if (mapped != null) + { + var source = mapped.SourceMember; + member = source.DeclaringType.GenericInfo.ConstructType(genArgs).ConstructedInfo.Map(source); + } + else + { + var method = ((IMethod)member).ConstructedInfo; + member = method.GenericDefinition.GenericInfo.ConstructMethod(genArgs); + } + + node.Entity = member; + node.ExpressionType = member.Type; + } + + private void MapConstructor(GenericReferenceExpression node, GenericMappedConstructor member) + { + var source = member.SourceMember; + var genArgs = node.GenericArguments + .Select(ga => _methodToStateMachineMapper.MapType((IType) ga.Entity)) + .ToArray(); + var result = source.DeclaringType.GenericInfo.ConstructType(genArgs).ConstructedInfo.Map(source); + node.Entity = result; + node.ExpressionType = result.Type; + } + + private static IType[] GetGenericParams(MemberReferenceExpression node) + { + var target = node.Target.Entity ?? node.Target.ExpressionType; + IType targetType; + if (target is IMember) + targetType = ((IMember)target).DeclaringType; + else if (target.EntityType == EntityType.Type) + targetType = (IType)target; + else return null; + + IType[] genParams; + if (targetType.ConstructedInfo != null) + genParams = targetType.ConstructedInfo.GenericArguments; + else if (targetType.GenericInfo != null) + genParams = System.Array.ConvertAll(targetType.GenericInfo.GenericParameters, igp => (IType) igp); + else genParams = null; + return genParams; + } + + public override void OnDeclaration(Declaration node) + { + base.OnDeclaration(node); + if (node.Entity != null && _entityMapper.ContainsKey(node.Entity)) + node.Entity = _entityMapper[node.Entity]; + } + + public override void OnMethodInvocationExpression(MethodInvocationExpression node) + { + var superInvocation = IsInvocationOnSuperMethod(node); + var et = node.ExpressionType; + base.OnMethodInvocationExpression(node); + if (node.Target.Entity.EntityType == EntityType.Field) + ContextAnnotations.AddFieldInvocation(node); + else if (node.Target.Entity.EntityType == EntityType.BuiltinFunction) + { + if (node.Target.Entity == BuiltinFunction.Default) + node.ExpressionType = node.Arguments[0].ExpressionType; + } + if (et != null && + ((et.GenericInfo != null || + (et.ConstructedInfo != null && !et.ConstructedInfo.FullyConstructed)) + && node.Target.ExpressionType != null)) + node.ExpressionType = node.Target.Entity.EntityType == EntityType.Constructor ? + ((IConstructor)node.Target.Entity).DeclaringType: + ((ICallableType)node.Target.ExpressionType).GetSignature().ReturnType; + if (!superInvocation) + return; + + var accessor = CreateAccessorForSuperMethod(node.Target); + Bind(node.Target, accessor); + } + + public override void OnGenericReferenceExpression(GenericReferenceExpression node) + { + base.OnGenericReferenceExpression(node); + node.ExpressionType = _methodToStateMachineMapper.MapType(node.ExpressionType); + var member = node.Entity as IMember; + if (member != null) + MapMember(node, member); + } + + private IEntity CreateAccessorForSuperMethod(Expression target) + { + var superMethod = (IMethod)GetEntity(target); + var accessor = CodeBuilder.CreateMethodFromPrototype(target.LexicalInfo, superMethod, TypeMemberModifiers.Internal, UniqueName(superMethod.Name)); + var accessorEntity = (IMethod)GetEntity(accessor); + var superMethodInvocation = CodeBuilder.CreateSuperMethodInvocation(superMethod); + foreach (var p in accessorEntity.GetParameters()) + superMethodInvocation.Arguments.Add(CodeBuilder.CreateReference(p)); + accessor.Body.Add(new ReturnStatement(superMethodInvocation)); + + DeclaringTypeDefinition.Members.Add(accessor); + return GetEntity(accessor); + } + + protected string UniqueName(string name) + { + return Context.GetUniqueName(name); + } + + protected TypeDefinition DeclaringTypeDefinition + { + get { return _method.Method.DeclaringType; } + } + + private static bool IsInvocationOnSuperMethod(MethodInvocationExpression node) + { + if (node.Target is SuperLiteralExpression) + return true; + + var target = node.Target as MemberReferenceExpression; + return target != null && target.Target is SuperLiteralExpression; + } + + private Field ExternalEnumeratorSelf() + { + if (null == _externalSelfField) + { + _externalSelfField = DeclareFieldInitializedFromConstructorParameter( + _stateMachineClass, + _stateMachineConstructor, + "self_", + TypeSystemServices.SelfMapGenericType(_method.DeclaringType), + _methodToStateMachineMapper); + } + + return _externalSelfField; + } + + protected sealed class TryStatementInfo + { + internal TryStatement _statement; + internal TryStatementInfo _parent; + + internal bool _containsYield; + internal int _stateNumber = -1; + internal Block _replacement; + + internal IMethod _ensureMethod; + internal ExceptionHandlerCollection _handlers; + } + + protected readonly System.Collections.Generic.List _convertedTryStatements + = new System.Collections.Generic.List(); + protected readonly Stack _tryStatementStack = new Stack(); + + public override bool EnterTryStatement(TryStatement node) + { + var info = new TryStatementInfo(); + info._statement = node; + if (_tryStatementStack.Count > 0) + info._parent = _tryStatementStack.Peek(); + _tryStatementStack.Push(info); + return true; + } + + protected virtual BinaryExpression SetStateTo(int num) + { + return CodeBuilder.CreateAssignment(CodeBuilder.CreateMemberReference(_state), + CodeBuilder.CreateIntegerLiteral(num)); + } + + public override void LeaveTryStatement(TryStatement node) + { + TryStatementInfo info = _tryStatementStack.Pop(); + if (info._containsYield) { + ReplaceCurrentNode(info._replacement); + info._handlers = node.ExceptionHandlers; + TryStatementInfo currentTry = (_tryStatementStack.Count > 0) ? _tryStatementStack.Peek() : null; + info._replacement.Add(node.ProtectedBlock); + if (currentTry != null) { + ConvertTryStatement(currentTry); + info._replacement.Add(SetStateTo(currentTry._stateNumber)); + } else { + // leave try block, reset state to prevent ensure block from being called again + info._replacement.Add(SetStateTo(_finishedStateNumber)); + } + if (info._statement.EnsureBlock != null) + { + BooMethodBuilder ensureMethod = _stateMachineClass.AddMethod("$ensure" + info._stateNumber, + TypeSystemServices.VoidType, TypeMemberModifiers.Private); + ensureMethod.Body.Add(info._statement.EnsureBlock); + info._ensureMethod = ensureMethod.Entity; + info._replacement.Add(CallMethodOnSelf(ensureMethod.Entity)); + } + _convertedTryStatements.Add(info); + } + } + + protected void ConvertTryStatement(TryStatementInfo currentTry) + { + if (currentTry._containsYield) + return; + currentTry._containsYield = true; + currentTry._stateNumber = _labels.Count; + var tryReplacement = new Block(); + //tryReplacement.Add(CreateLabel(tryReplacement)); + // when the MoveNext() is called while the enumerator is still in running state, don't jump to the + // try block, but handle it like MoveNext() calls when the enumerator is in the finished state. + _labels.Add(_labels[_finishedStateNumber]); + _tryStatementInfoForLabels.Add(currentTry); + tryReplacement.Add(SetStateTo(currentTry._stateNumber)); + currentTry._replacement = tryReplacement; + } + + protected LabelStatement CreateLabel(Node sourceNode) + { + InternalLabel label = CodeBuilder.CreateLabel(sourceNode, "$state$" + _labels.Count); + _labels.Add(label.LabelStatement); + _tryStatementInfoForLabels.Add(_tryStatementStack.Count > 0 ? _tryStatementStack.Peek() : null); + return label.LabelStatement; + } + + protected virtual BooMethodBuilder CreateConstructor(BooClassBuilder builder) + { + BooMethodBuilder constructor = builder.AddConstructor(); + constructor.Body.Add(CodeBuilder.CreateSuperConstructorInvocation(builder.Entity.BaseType)); + return constructor; + } + } +} diff --git a/src/Boo.Lang.Compiler/Steps/StateMachine/StateMachineStates.cs b/src/Boo.Lang.Compiler/Steps/StateMachine/StateMachineStates.cs new file mode 100644 index 000000000..e28187b20 --- /dev/null +++ b/src/Boo.Lang.Compiler/Steps/StateMachine/StateMachineStates.cs @@ -0,0 +1,9 @@ +namespace Boo.Lang.Compiler.Steps.StateMachine +{ + internal static class StateMachineStates + { + internal static readonly int FinishedStateMachine = -2; + internal static readonly int NotStartedStateMachine = -1; + internal static readonly int FirstUnusedState = 0; + } +} diff --git a/src/Boo.Lang.Compiler/Steps/TypeFinder.cs b/src/Boo.Lang.Compiler/Steps/TypeFinder.cs new file mode 100644 index 000000000..a81dfe1d1 --- /dev/null +++ b/src/Boo.Lang.Compiler/Steps/TypeFinder.cs @@ -0,0 +1,73 @@ +using System.Collections.Generic; +using Boo.Lang.Compiler.Ast; +using Boo.Lang.Compiler.TypeSystem; + +namespace Boo.Lang.Compiler.Steps +{ + public class TypeFinder : AbstractFastVisitorCompilerStep + { + protected TypeCollector _collector; + + public TypeFinder(TypeCollector collector) + { + _collector = collector; + } + + public IEnumerable Results + { + get { return _collector.Matches; } + } + + private void OnTypeReference(TypeReference node) + { + var type = (IType)node.Entity; + _collector.Visit(type); + } + + public override void OnReferenceExpression(ReferenceExpression node) + { + var te = node.Entity as ITypedEntity; + if (te != null) + { + _collector.Visit(te.Type); + if (te.EntityType == EntityType.Constructor) + _collector.Visit(((IConstructor) te).DeclaringType); + } + } + + public override void OnMemberReferenceExpression(MemberReferenceExpression node) + { + base.OnMemberReferenceExpression(node); + OnReferenceExpression(node); + } + + public override void OnSimpleTypeReference(SimpleTypeReference node) + { + OnTypeReference(node); + } + + public override void OnArrayTypeReference(ArrayTypeReference node) + { + base.OnArrayTypeReference(node); + OnTypeReference(node); + } + + public override void OnCallableTypeReference(CallableTypeReference node) + { + base.OnCallableTypeReference(node); + OnTypeReference(node); + } + + public override void OnGenericTypeReference(GenericTypeReference node) + { + base.OnGenericTypeReference(node); + OnTypeReference(node); + } + + public override void OnGenericTypeDefinitionReference(GenericTypeDefinitionReference node) + { + base.OnGenericTypeDefinitionReference(node); + OnTypeReference(node); + } + } +} \ No newline at end of file diff --git a/src/Boo.Lang.Compiler/TypeSystem/Builders/BooClassBuilder.cs b/src/Boo.Lang.Compiler/TypeSystem/Builders/BooClassBuilder.cs index 98c37e15f..e3b2eaf8a 100755 --- a/src/Boo.Lang.Compiler/TypeSystem/Builders/BooClassBuilder.cs +++ b/src/Boo.Lang.Compiler/TypeSystem/Builders/BooClassBuilder.cs @@ -153,5 +153,16 @@ public GenericParameterDeclaration AddGenericParameter(string name) genericParameters.Add(declaration); return declaration; } + + public GenericParameterDeclaration AddGenericParameter(GenericParameterDeclaration source) + { + var result = AddGenericParameter(source.Name); + result.Constraints = source.Constraints; + foreach (var baseType in source.BaseTypes) + { + result.BaseTypes.Add(baseType.CloneNode()); + } + return result; + } } } \ No newline at end of file diff --git a/src/Boo.Lang.Compiler/TypeSystem/GenericTypeCollector.cs b/src/Boo.Lang.Compiler/TypeSystem/GenericTypeCollector.cs index 190a50bf9..7b631cf9a 100644 --- a/src/Boo.Lang.Compiler/TypeSystem/GenericTypeCollector.cs +++ b/src/Boo.Lang.Compiler/TypeSystem/GenericTypeCollector.cs @@ -83,8 +83,11 @@ public void Process(ClassDefinition cd) var parameters = GenericParameters.ToArray(); for (var i = 0; i < parameters.Length; ++i) { - var param = parameters[i]; - var gen = _codeBuilder.CreateGenericParameterDeclaration(i, param.Name); + var param = parameters[i]; + var gen = cd.GenericParameters.FirstOrDefault(gp => gp.Name.Equals(param.Name)); + var found = gen != null; + if (!found) + gen = _codeBuilder.CreateGenericParameterDeclaration(i, param.Name); foreach (IType baseType in param.GetTypeConstraints()) { gen.BaseTypes.Add(_codeBuilder.CreateTypeReference(baseType)); @@ -99,14 +102,14 @@ public void Process(ClassDefinition cd) gen.Constraints |= GenericParameterConstraints.Covariant; else if (param.Variance == Variance.Contravariant) gen.Constraints |= GenericParameterConstraints.Contravariant; - cd.GenericParameters.Add(gen); + if (!found) + cd.GenericParameters.Add(gen); } } private IEnumerable GenericParameters { - get { return _matches.Cast().Distinct(new DistinctGenericComparer()). - OrderBy(p => p.GenericParameterPosition); } + get { return _matches.Distinct(new DistinctGenericComparer()).OrderBy(p => p.GenericParameterPosition); } } } diff --git a/src/Boo.Lang.Compiler/TypeSystem/Generics/GenericConstructedType.cs b/src/Boo.Lang.Compiler/TypeSystem/Generics/GenericConstructedType.cs index 06e7587f9..df315c36c 100755 --- a/src/Boo.Lang.Compiler/TypeSystem/Generics/GenericConstructedType.cs +++ b/src/Boo.Lang.Compiler/TypeSystem/Generics/GenericConstructedType.cs @@ -29,6 +29,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text; using Boo.Lang.Compiler.TypeSystem.Core; using Boo.Lang.Compiler.TypeSystem.Internal; using Boo.Lang.Compiler.TypeSystem.Services; @@ -68,11 +69,43 @@ protected bool IsFullyConstructed() } protected string BuildFullName() - { + {/* + var sb = new StringBuilder(); + sb.Append(_definition.FullName); + sb.Append("[of "); + for (var i = 0; i < _arguments.Length; ++i) + { + sb.Append(_arguments[i].FullName); + if (i < _arguments.Length - 1) + sb.Append(", "); + } + sb.Append("]"); + return sb.ToString(); + */ return _definition.FullName; } - protected GenericMapping GenericMapping + public override bool Equals(object other) + { + if (other == null) + return false; + + if (other.GetType() != this.GetType()) + return false; + + var otherType = (GenericConstructedType) other; + if (otherType._definition != _definition) + return false; + + for (var i = 0; i < _arguments.Length; ++i) + { + if (!_arguments[i].Equals(otherType._arguments[i])) + return false; + } + return true; + } + + protected internal GenericMapping GenericMapping { get { return _genericMapping; } } diff --git a/src/Boo.Lang.Compiler/TypeSystem/Generics/GenericMappedType.cs b/src/Boo.Lang.Compiler/TypeSystem/Generics/GenericMappedType.cs new file mode 100644 index 000000000..50244b0cb --- /dev/null +++ b/src/Boo.Lang.Compiler/TypeSystem/Generics/GenericMappedType.cs @@ -0,0 +1,241 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Boo.Lang.Compiler.Ast; +using Boo.Lang.Compiler.TypeSystem.Core; +using Boo.Lang.Compiler.TypeSystem.Services; + +namespace Boo.Lang.Compiler.TypeSystem.Generics +{ + public interface IGenericMappedType : IType + { + IType SourceType { get; } + } + + public class GenericMappedType : IGenericMappedType + { + private readonly IType _sourceType; + private readonly GenericConstructedType _containingType; + + private static Dictionary, GenericMappedType> _cache = + new Dictionary, GenericMappedType>(); + + public static GenericMappedType Create(IType sourceType, GenericConstructedType containingType) + { + var pair = new KeyValuePair(sourceType, containingType); + GenericMappedType result; + if (!_cache.TryGetValue(pair, out result)) + { + result = new GenericMappedType(sourceType, containingType); + _cache[pair] = result; + } + return result; + } + + protected GenericMappedType(IType sourceType, GenericConstructedType containingType) + { + if (sourceType.DeclaringEntity != ((IConstructedTypeInfo)containingType).GenericDefinition) + throw new ArgumentException("Mapping type onto invalid container"); + + _sourceType = sourceType; + _containingType = containingType; + } + + public IType SourceType { get { return _sourceType; } } + + public GenericMapping GenericMapping + { + get { return _containingType.GenericMapping; } + } + + public string Name + { + get { return _sourceType.Name; } + } + + public string FullName + { + get { return _sourceType.FullName; } + } + + public EntityType EntityType + { + get { return EntityType.Type; } + } + + public IGenericTypeInfo GenericInfo + { + get { return _sourceType.GenericInfo; } + } + + public IConstructedTypeInfo ConstructedInfo + { + get { return null; } + } + + public IType Type + { + get { return this; } + } + + public IEntity DeclaringEntity + { + get { return _containingType; } + } + + public bool IsClass + { + get { return _sourceType.IsClass; } + } + + public bool IsAbstract + { + get { return _sourceType.IsAbstract; } + } + + public bool IsInterface + { + get { return _sourceType.IsInterface; } + } + + public bool IsEnum + { + get { return _sourceType.IsEnum; } + } + + public bool IsByRef + { + get { return _sourceType.IsByRef; } + } + + public bool IsValueType + { + get { return _sourceType.IsValueType; } + } + + public bool IsFinal + { + get { return _sourceType.IsFinal; } + } + + public bool IsArray + { + get { return _sourceType.IsArray; } + } + + public bool IsPointer + { + get { return _sourceType.IsPointer; } + } + + public int GetTypeDepth() + { + return _sourceType.GetTypeDepth(); + } + + public bool IsDefined(IType attributeType) + { + return _sourceType.IsDefined(GenericMapping.MapType(attributeType)); + } + + public IEnumerable GetMembers() + { + return _sourceType.GetMembers().Select(GenericMapping.Map); + } + + public INamespace ParentNamespace + { + get + { + return GenericMapping.Map(_sourceType.ParentNamespace) as INamespace; + } + } + + public bool Resolve(ICollection resultingSet, string name, EntityType typesToConsider) + { + var definitionMatches = new HashSet(); + if (!_sourceType.Resolve(definitionMatches, name, typesToConsider)) + return false; + foreach (var match in definitionMatches) + resultingSet.Add(GenericMapping.Map(match)); + return true; + } + + public IType BaseType + { + get { return GenericMapping.MapType(_sourceType.BaseType); } + } + + public IType ElementType + { + get { return GenericMapping.MapType(_sourceType.ElementType); } + } + + public IEntity GetDefaultMember() + { + IEntity definitionDefaultMember = _sourceType.GetDefaultMember(); + if (definitionDefaultMember != null) return GenericMapping.Map(definitionDefaultMember); + return null; + } + + public IType[] GetInterfaces() + { + return Array.ConvertAll( + _sourceType.GetInterfaces(), + GenericMapping.MapType); + } + + public virtual bool IsAssignableFrom(IType other) + { + if (other == null) + return false; + + if (other == this || other.IsSubclassOf(this) || (other.IsNull() && !IsValueType) || IsGenericAssignableFrom(other)) + return true; + + return false; + } + + public bool IsGenericAssignableFrom(IType other) + { + var gmt = other as GenericMappedType; + if (gmt == null) + return false; + + if (!this._containingType.IsGenericAssignableFrom(gmt._containingType)) + return false; + + var st = _sourceType; + return (st != null && st.IsAssignableFrom(gmt._sourceType)); + } + + public bool IsSubclassOf(IType other) + { + if (null == other) + return false; + + if (BaseType != null && (BaseType == other || BaseType.IsSubclassOf(other))) + { + return true; + } + + return other.IsInterface && Array.Exists( + GetInterfaces(), + i => TypeCompatibilityRules.IsAssignableFrom(other, i)); + } + + private ArrayTypeCache _arrayTypes; + + public IArrayType MakeArrayType(int rank) + { + if (null == _arrayTypes) + _arrayTypes = new ArrayTypeCache(this); + return _arrayTypes.MakeArrayType(rank); + } + + public IType MakePointerType() + { + return null; + } + } +} diff --git a/src/Boo.Lang.Compiler/TypeSystem/Generics/GenericMapping.cs b/src/Boo.Lang.Compiler/TypeSystem/Generics/GenericMapping.cs index 9862c89ac..0c787b272 100755 --- a/src/Boo.Lang.Compiler/TypeSystem/Generics/GenericMapping.cs +++ b/src/Boo.Lang.Compiler/TypeSystem/Generics/GenericMapping.cs @@ -27,6 +27,7 @@ #endregion using System; +using System.Linq; using System.Collections.Generic; using Boo.Lang.Compiler.TypeSystem.Generics; @@ -139,18 +140,27 @@ private IMember MapMember(IMember source) return CacheMember(source, CreateMappedMember(source)); } + IType[] genArgs = null; // If member is declared on a basetype of our source, that is itself constructed, let its own mapper map it IType declaringType = source.DeclaringType; if (declaringType.ConstructedInfo != null) { + var genMethod = source as IConstructedMethodInfo; + if (genMethod != null) + { + genArgs = genMethod.GenericArguments; + source = genMethod.GenericDefinition; + } source = declaringType.ConstructedInfo.UnMap(source); } IType mappedDeclaringType = MapType(declaringType); if (mappedDeclaringType.ConstructedInfo != null) { - return mappedDeclaringType.ConstructedInfo.Map(source); + source = mappedDeclaringType.ConstructedInfo.Map(source); } + if (genArgs != null) + source = ((IMethod)source).GenericInfo.ConstructMethod(genArgs.Select(MapType).ToArray()); return source; } diff --git a/src/Boo.Lang.Compiler/TypeSystem/Generics/GenericsServices.cs b/src/Boo.Lang.Compiler/TypeSystem/Generics/GenericsServices.cs index 24d057180..d04e23f4d 100755 --- a/src/Boo.Lang.Compiler/TypeSystem/Generics/GenericsServices.cs +++ b/src/Boo.Lang.Compiler/TypeSystem/Generics/GenericsServices.cs @@ -29,6 +29,7 @@ using Boo.Lang.Compiler.Ast; using System.Collections.Generic; using System; +using System.Linq; using Boo.Lang.Compiler.TypeSystem.Core; namespace Boo.Lang.Compiler.TypeSystem.Generics @@ -311,6 +312,24 @@ public static IGenericParameter[] GetGenericParameters(IEntity definition) return null; } + public static int GetTypeConcreteness(IType type) + { + if (type.IsByRef || type.IsArray) + return GetTypeConcreteness(type.ElementType); + + if (type is IGenericParameter) + return 0; + + if (type.ConstructedInfo != null) + { + var result = 0; + foreach (IType typeArg in type.ConstructedInfo.GenericArguments) + result += GetTypeConcreteness(typeArg); + return result; + } + return 1; + } + /// /// Determines the number of open generic parameters in the specified type. /// diff --git a/src/Boo.Lang.Compiler/TypeSystem/Generics/TypeMapper.cs b/src/Boo.Lang.Compiler/TypeSystem/Generics/TypeMapper.cs index 3968002c8..c8fd19cb0 100755 --- a/src/Boo.Lang.Compiler/TypeSystem/Generics/TypeMapper.cs +++ b/src/Boo.Lang.Compiler/TypeSystem/Generics/TypeMapper.cs @@ -63,8 +63,9 @@ public virtual IType MapType(IType sourceType) if (sourceType.ConstructedInfo != null) return MapConstructedType(sourceType); - // TODO: Map nested types - // GenericType[of T].NestedType => GenericType[of int].NestedType + var de = sourceType.DeclaringEntity; + if (de != null && de.EntityType == EntityType.Type && !(sourceType is IGenericParameter)) + return MapNestedType(sourceType); var array = sourceType as IArrayType; if (array != null) @@ -76,7 +77,7 @@ public virtual IType MapType(IType sourceType) : sourceType; } - public virtual IType MapByRefType(IType sourceType) + public virtual IType MapByRefType(IType sourceType) { var et = sourceType.ElementType; if (sourceType.IsAssignableFrom(et)) @@ -111,12 +112,31 @@ public virtual IType MapConstructedType(IType sourceType) sourceType.ConstructedInfo.GenericArguments, MapType); - IType mapped = mappedDefinition.GenericInfo.ConstructType(mappedArguments); + var genericInfo = mappedDefinition.GenericInfo ?? mappedDefinition.ConstructedInfo.GenericDefinition.GenericInfo; + IType mapped = genericInfo.ConstructType(mappedArguments); return mapped; } - internal IParameter[] MapParameters(IParameter[] parameters) + public virtual IType MapNestedType(IType sourceType) + { + var containingType = (IType)sourceType.DeclaringEntity; + var mappedContainingType = MapType(containingType); + if (containingType == mappedContainingType) + { + return sourceType; + } + + var mt = sourceType as IGenericMappedType; + if (mt != null) + { + sourceType = mt.SourceType; + } + + return GenericMappedType.Create(sourceType, (GenericConstructedType)mappedContainingType); + } + + internal IParameter[] MapParameters(IParameter[] parameters) { return Array.ConvertAll(parameters, MapParameter); } diff --git a/src/Boo.Lang.Compiler/TypeSystem/Reflection/ExternalType.cs b/src/Boo.Lang.Compiler/TypeSystem/Reflection/ExternalType.cs index e1c9b805f..49953768f 100755 --- a/src/Boo.Lang.Compiler/TypeSystem/Reflection/ExternalType.cs +++ b/src/Boo.Lang.Compiler/TypeSystem/Reflection/ExternalType.cs @@ -219,6 +219,16 @@ public virtual bool IsAssignableFrom(IType other) { return !IsValueType; } + if (other.ConstructedInfo != null && this.ConstructedInfo != null + && ConstructedInfo.GenericDefinition == other.ConstructedInfo.GenericDefinition) + { + for (int i = 0; i < ConstructedInfo.GenericArguments.Length; ++i) + { + if (!ConstructedInfo.GenericArguments[i].IsAssignableFrom(other.ConstructedInfo.GenericArguments[i])) + return false; + } + return true; + } return other.IsSubclassOf(this); } if (other == _provider.Map(Types.Void)) diff --git a/src/Boo.Lang.Compiler/TypeSystem/Services/AnonymousCallablesManager.cs b/src/Boo.Lang.Compiler/TypeSystem/Services/AnonymousCallablesManager.cs index aa3b38c9c..4a2ff466c 100755 --- a/src/Boo.Lang.Compiler/TypeSystem/Services/AnonymousCallablesManager.cs +++ b/src/Boo.Lang.Compiler/TypeSystem/Services/AnonymousCallablesManager.cs @@ -108,16 +108,17 @@ private IType CreateConcreteCallableType(Node sourceNode, AnonymousCallableType cd.Members.Add(beginInvoke); cd.Members.Add(CreateEndInvokeMethod(anonymousType)); - AddGenericTypes(cd); + AddGenericTypes(cd, sourceNode.NodeType != NodeType.BlockExpression); module.Members.Add(cd); return (IType)cd.Entity; } - private void AddGenericTypes(ClassDefinition cd) + private void AddGenericTypes(ClassDefinition cd, bool adaptInnerGenerics) { var collector = new GenericTypeCollector(this.CodeBuilder); collector.Process(cd); - + if (!adaptInnerGenerics) return; + var counter = cd.GenericParameters.Count; var innerCollector = new DetectInnerGenerics(); cd.Accept(innerCollector); diff --git a/src/Boo.Lang.Compiler/TypeSystem/Services/AsyncHelper.cs b/src/Boo.Lang.Compiler/TypeSystem/Services/AsyncHelper.cs new file mode 100644 index 000000000..9e5a310d4 --- /dev/null +++ b/src/Boo.Lang.Compiler/TypeSystem/Services/AsyncHelper.cs @@ -0,0 +1,84 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Boo.Lang.Compiler.Ast; +using Boo.Lang.Compiler.Steps; +using Boo.Lang.Compiler.TypeSystem.Services; +using Boo.Lang.Environments; + +namespace Boo.Lang.Compiler.TypeSystem +{ + public static class AsyncHelper + { + public static bool ValidAsyncTypeUnbound(Method value) + { + var ret = value.ReturnType; + if (ret == null) + return true; + var typeRef = ret as SimpleTypeReference; + return typeRef != null; //trying to resolve this too early has too many edge cases + } + + internal static bool ValidAsyncTypeBound(Method value) + { + var ret = (IType)value.ReturnType.Entity; + var tss = My.Instance; + return ret == tss.VoidType || + ret == tss.TaskType || + TypeCompatibilityRules.IsAssignableFrom(tss.GenericTaskType, ret); + } + + private static IMethod GetNoArgs(IEnumerable value, TypeSystemServices tss) + { + return value.Cast().SingleOrDefault(m => m.GetParameters().Length == 0); + } + + private static IMethod GetNoArgsNoVoid(IEnumerable value, TypeSystemServices tss) + { + return value.Cast().SingleOrDefault(m => + ((m.GetParameters().Length == 0) || (m.IsExtension && m.GetParameters().Length == 1)) && + m.ReturnType != tss.VoidType); + } + + public static IType GetAwaitType(Expression value) + { + var type = value.ExpressionType; + var tss = My.Instance; + if (type == tss.TaskType) + return type; + if (type.ConstructedInfo != null && type.ConstructedInfo.GenericDefinition == tss.GenericTaskType) + return type.ConstructedInfo.GenericArguments[0]; + + var awaiterSet = new List(); + IEntity[] candidates = type.Resolve(awaiterSet, "GetAwaiter", EntityType.Method) ? + awaiterSet.ToArray() : + tss.FindExtension(type, "GetAwaiter"); + + var awaiter = GetNoArgsNoVoid(candidates , tss); + if (awaiter == null) + return null; + value["$GetAwaiter"] = awaiter; + var awaiterType = awaiter.ReturnType; + if (awaiterType == null || awaiterType == tss.VoidType) + return null; + awaiterSet.Clear(); + if (awaiterType.Resolve(awaiterSet, "GetResult", EntityType.Method)) + { + var getResult = GetNoArgs(awaiterSet, tss); + if (getResult == null) + return null; + value["$GetResult"] = getResult; + if (getResult.ReturnType == tss.VoidType) + return tss.TaskType; + return getResult.ReturnType; + } + return null; + } + + internal static bool InAsyncMethod(Expression value) + { + INodeWithBody ancestor = value.GetAncestor() ?? (INodeWithBody) value.GetAncestor(); + return ContextAnnotations.IsAsync(ancestor); + } + } +} diff --git a/src/Boo.Lang.Compiler/TypeSystem/Services/BooCodeBuilder.cs b/src/Boo.Lang.Compiler/TypeSystem/Services/BooCodeBuilder.cs index de943c186..e74071437 100644 --- a/src/Boo.Lang.Compiler/TypeSystem/Services/BooCodeBuilder.cs +++ b/src/Boo.Lang.Compiler/TypeSystem/Services/BooCodeBuilder.cs @@ -33,6 +33,7 @@ using System.Reflection; using Boo.Lang.Compiler.Ast; using Boo.Lang.Compiler.Services; +using Boo.Lang.Compiler.Steps; using Boo.Lang.Compiler.TypeSystem.Builders; using Boo.Lang.Compiler.TypeSystem.Generics; using Boo.Lang.Compiler.TypeSystem.Internal; @@ -164,7 +165,21 @@ public Expression CreateCast(IType type, Expression target) return expression; } - public TypeofExpression CreateTypeofExpression(IType type) + public Expression CreateAsCast(IType type, Expression target) + { + if (type == target.ExpressionType) + return target; + + var expression = new TryCastExpression(target.LexicalInfo) + { + Type = CreateTypeReference(type), + Target = target, + ExpressionType = type + }; + return expression; + } + + public TypeofExpression CreateTypeofExpression(IType type) { return new TypeofExpression { @@ -348,12 +363,22 @@ public SelfLiteralExpression CreateSelfReference(IType expressionType) return new SelfLiteralExpression { ExpressionType = expressionType }; } - public ReferenceExpression CreateLocalReference(string name, InternalLocal entity) + public SelfLiteralExpression CreateSelfReference(IMethod method, IType expressionType) + { + return new SelfLiteralExpression { ExpressionType = expressionType, Entity = method }; + } + + public ReferenceExpression CreateLocalReference(string name, InternalLocal entity) { return CreateTypedReference(name, entity); } - public ReferenceExpression CreateTypedReference(string name, ITypedEntity entity) + public ReferenceExpression CreateLocalReference(InternalLocal entity) + { + return CreateTypedReference(entity.Name, entity); + } + + public ReferenceExpression CreateTypedReference(string name, ITypedEntity entity) { ReferenceExpression expression = new ReferenceExpression(name); expression.Entity = entity; @@ -386,6 +411,18 @@ public MemberReferenceExpression CreateReference(LexicalInfo li, Field field) return e; } + public MemberReferenceExpression CreateMappedReference(LexicalInfo nodeLexicalInfo, + Field field, + IType type) + { + if (type.GenericInfo != null && type.ConstructedInfo == null) + type = TypeSystemServices.SelfMapGenericType(type); + var entity = type.ConstructedInfo != null ? + (IField)type.ConstructedInfo.Map((IField)field.Entity) : + (IField)field.Entity; + return CreateReference(entity); + } + public MemberReferenceExpression CreateReference(Field field) { return CreateReference((IField)field.Entity); @@ -443,6 +480,16 @@ public MemberReferenceExpression MemberReferenceForEntity(Expression target, IEn MemberReferenceExpression reference = new MemberReferenceExpression(target.LexicalInfo); reference.Target = target; reference.Name = entity.Name; + var genType = target.ExpressionType as IConstructedTypeInfo; + if (genType != null && entity is IMember && !(entity is IGenericMappedMember)) + { + var gcm = entity as GenericConstructedMethod; + if (gcm != null) + entity = gcm.GenericDefinition; + entity = genType.Map((IMember)entity); + if (gcm != null) + entity = ((GenericMappedMethod)entity).GenericInfo.ConstructMethod(gcm.GenericArguments); + } reference.Entity = entity; return reference; } @@ -485,7 +532,9 @@ public ReferenceExpression CreateReference(IType type) { if (type.DeclaringEntity is GenericConstructedType) return CreateNestedReference(type); - return new ReferenceExpression(type.FullName) {Entity = type, IsSynthetic = true}; + if (type.GenericInfo != null) + type = type.GenericInfo.ConstructType(type.GenericInfo.GenericParameters); + return new ReferenceExpression(type.FullName) {Entity = type, ExpressionType = type, IsSynthetic = true}; } public MethodInvocationExpression CreateEvalInvocation(LexicalInfo li) @@ -512,6 +561,27 @@ public MethodInvocationExpression CreateEvalInvocation(LexicalInfo li, Expressio return eval; } + public MethodInvocationExpression CreateEvalInvocation(LexicalInfo li, params Expression[] args) + { + MethodInvocationExpression eval = CreateEvalInvocation(li); + IType et = null; + foreach (var arg in args) + { + eval.Arguments.Add(arg); + et = arg.ExpressionType; + } + eval.ExpressionType = et; + return eval; + } + + public MethodInvocationExpression CreateDefaultInvocation(LexicalInfo li, IType type) + { + var result = CreateBuiltinInvocation(li, BuiltinFunction.Default); + result.Arguments.Add(CreateTypeofExpression(type)); + result.ExpressionType = type; + return result; + } + public UnpackStatement CreateUnpackStatement(DeclarationCollection declarations, Expression expression) { UnpackStatement unpack = new UnpackStatement(expression.LexicalInfo); @@ -686,14 +756,13 @@ public MethodInvocationExpression CreateGenericConstructorInvocation(IType class return CreateConstructorInvocation(constructor); } - classType = new Generics.GenericConstructedType( + classType = new GenericConstructedType( classType, genericArgs.Select(a => a.Entity).Cast().ToArray()); constructor = classType.GetConstructors().First(); - var result = new MethodInvocationExpression(); - result.Target = CreateGenericReference(CreateReference(constructor.DeclaringType), genericArgs); - result.Target.Entity = constructor; + var result = new MethodInvocationExpression {Target = CreateReference(constructor.DeclaringType)}; + result.Target.Entity = constructor; result.ExpressionType = constructor.DeclaringType; return result; @@ -797,7 +866,18 @@ public Method CreateMethod(string name, TypeReference returnType, TypeMemberModi return method; } - public Property CreateProperty(string name, IType type) + public Method CreateGenericMethod(string name, TypeReference returnType, TypeMemberModifiers modifiers, GenericParameterDeclaration[] genParams) + { + Method method = new Method(name); + method.Modifiers = modifiers; + method.ReturnType = returnType; + method.IsSynthetic = true; + method.GenericParameters.AddRange(genParams); + EnsureEntityFor(method); + return method; + } + + public Property CreateProperty(string name, IType type) { Property property = new Property(name); property.Modifiers = TypeMemberModifiers.Public; @@ -961,6 +1041,19 @@ public RaiseStatement RaiseException(LexicalInfo lexicalInfo, IConstructor excep return new RaiseStatement(lexicalInfo, CreateConstructorInvocation(lexicalInfo, exceptionConstructor, args)); } + public TryStatement CreateTryExcept(LexicalInfo lexicalInfo, Block protecteBlock, + params ExceptionHandler[] handlers) + { + var result = new TryStatement(lexicalInfo) {ProtectedBlock = protecteBlock}; + result.ExceptionHandlers.AddRange(handlers); + return result; + } + + public ExceptionHandler CreateExceptionHandler(LexicalInfo lexicalInfo, Declaration definition, Block body) + { + return new ExceptionHandler(lexicalInfo){ Declaration = definition, Block = body}; + } + public InternalLocal DeclareTempLocal(Method node, IType type) { var local = DeclareLocal(node, My.Instance.GetUniqueName(), type); @@ -977,7 +1070,16 @@ public InternalLocal DeclareLocal(Method node, string name, IType type) return entity; } - public void BindParameterDeclarations(bool isStatic, INodeWithParameters node) + public Declaration CreateDeclaration(Method method, string name, IType type, out InternalLocal local) + { + var result = new Declaration(name, CreateTypeReference(type)); + local = this.DeclareLocal(method, name, type); + method.Locals.Add(local.Local); + result.Entity = local; + return result; + } + + public void BindParameterDeclarations(bool isStatic, INodeWithParameters node) { // arg0 is the this pointer when member is not static int delta = isStatic ? 0 : 1; @@ -1005,7 +1107,31 @@ public InternalLabel CreateLabel(Node sourceNode, string name) return new InternalLabel(new LabelStatement(sourceNode.LexicalInfo, name)); } - public TypeMember CreateStub(ClassDefinition node, IMember member) + public InternalLabel CreateLabel(Node sourceNode, string name, int depth) + { + var result = CreateLabel(sourceNode, name); + AstAnnotations.SetTryBlockDepth(result.LabelStatement, depth); + return result; + } + + public GotoStatement CreateGoto(LexicalInfo li, InternalLabel target) + { + return new GotoStatement(li, CreateLabelReference(target.LabelStatement)); + } + + public GotoStatement CreateGoto(InternalLabel target) + { + return CreateGoto(LexicalInfo.Empty, target); + } + + public GotoStatement CreateGoto(InternalLabel target, int depth) + { + var result = CreateGoto(LexicalInfo.Empty, target); + AstAnnotations.SetTryBlockDepth(result, depth); + return result; + } + + public TypeMember CreateStub(ClassDefinition node, IMember member) { IMethod baseMethod = member as IMethod; if (null != baseMethod) diff --git a/src/Boo.Lang.Compiler/TypeSystem/Services/CallableResolutionService.cs b/src/Boo.Lang.Compiler/TypeSystem/Services/CallableResolutionService.cs index 37cfb0df7..6f6497a97 100755 --- a/src/Boo.Lang.Compiler/TypeSystem/Services/CallableResolutionService.cs +++ b/src/Boo.Lang.Compiler/TypeSystem/Services/CallableResolutionService.cs @@ -432,11 +432,15 @@ private int MoreSpecific(IType t1, IType t2) if (t1.IsArray && t2.IsArray || t1.IsByRef && t2.IsByRef) return MoreSpecific(t1.ElementType, t2.ElementType); - // The less-generic type is more specific. + // A more concrete type is more specfic + int result = GenericsServices.GetTypeConcreteness(t1) - GenericsServices.GetTypeConcreteness(t2); + if (result != 0) return result; + + // With equal concreteness, the more generic type is more specific. //First search for open args, then for all args - int result = GenericsServices.GetTypeGenerity(t2) - GenericsServices.GetTypeGenerity(t1); + result = GenericsServices.GetTypeGenerity(t1) - GenericsServices.GetTypeGenerity(t2); if (result != 0) return result; - result = GenericsServices.GetTypeGenericDepth(t2) - GenericsServices.GetTypeGenericDepth(t1); + result = GenericsServices.GetTypeGenericDepth(t1) - GenericsServices.GetTypeGenericDepth(t2); if (result != 0) return result; // If both types have the same genrity, the deeper-nested type is more specific diff --git a/src/Boo.Lang.Compiler/TypeSystem/Services/TypeSystemServices.cs b/src/Boo.Lang.Compiler/TypeSystem/Services/TypeSystemServices.cs index 1a6bbc03f..9844616eb 100755 --- a/src/Boo.Lang.Compiler/TypeSystem/Services/TypeSystemServices.cs +++ b/src/Boo.Lang.Compiler/TypeSystem/Services/TypeSystemServices.cs @@ -82,6 +82,8 @@ public class TypeSystemServices public IType IListGenericType; public IType IListType; + public IType NullableGenericType; + public IType IEnumeratorGenericType; public IType IEnumeratorType; public IType IQuackFuType; @@ -113,6 +115,14 @@ public class TypeSystemServices public IType ValueTypeType; public IType VoidType; + public IType TaskType; + public IType GenericTaskType; + public IType AsyncGenericTaskMethodBuilderType; + public IType AsyncTaskMethodBuilderType; + public IType AsyncVoidMethodBuilderType; + public IType IAsyncStateMachineType; + public IType GenericFuncType; + private Module _compilerGeneratedTypesModule; private readonly Set _literalPrimitives = new Set(); private readonly Dictionary _primitives = new Dictionary(StringComparer.Ordinal); @@ -188,8 +198,16 @@ public TypeSystemServices(CompilerContext context) ICollectionGenericType = Map(typeof(ICollection<>)); IListGenericType = Map(typeof (IList<>)); IListType = Map(typeof (IList)); + NullableGenericType = Map(Types.Nullable); IAstMacroType = Map(typeof(IAstMacro)); IAstGeneratorMacroType = Map(typeof(IAstGeneratorMacro)); + TaskType = Map(typeof(System.Threading.Tasks.Task)); + GenericTaskType = Map(typeof(System.Threading.Tasks.Task<>)); + AsyncGenericTaskMethodBuilderType = Map(typeof(System.Runtime.CompilerServices.AsyncTaskMethodBuilder<>)); + AsyncTaskMethodBuilderType = Map(typeof(System.Runtime.CompilerServices.AsyncTaskMethodBuilder)); + AsyncVoidMethodBuilderType = Map(typeof(System.Runtime.CompilerServices.AsyncVoidMethodBuilder)); + IAsyncStateMachineType = Map(typeof(System.Runtime.CompilerServices.IAsyncStateMachine)); + GenericFuncType = Map(typeof(System.Func<>)); ObjectArrayType = ObjectType.MakeArrayType(1); @@ -515,6 +533,11 @@ public Module GetCompilerGeneratedTypesModule() return _compilerGeneratedTypesModule ?? (_compilerGeneratedTypesModule = NewModule("CompilerGenerated")); } + public bool CompilerGeneratedTypesModuleExists() + { + return _compilerGeneratedTypesModule != null; + } + private Module NewModule(string nameSpace) { return NewModule(nameSpace, nameSpace); @@ -586,7 +609,7 @@ private IMethod FindConversionOperator(string name, IType fromType, IType toType return null; } - private IEntity[] FindExtension(IType fromType, string name) + internal IEntity[] FindExtension(IType fromType, string name) { IEntity extension = NameResolutionService.ResolveExtension(fromType, name); if (null == extension) return Ambiguous.NoEntities; @@ -1125,5 +1148,80 @@ public IType MapWildcardType(IType type) return ObjectArrayType; return type; } + + private static bool SameOrEquivalentGenericTypes(IType t1, IType t2, ref bool genericType) + { + if (t1 == t2) return true; + var g1 = t1 as IGenericParameter; + var g2 = t2 as IGenericParameter; + if (g1 == null || g2 == null) + { + var c1 = t1 as IConstructedTypeInfo; + var c2 = t2 as IConstructedTypeInfo; + if (c1 == null || c2 == null) + return false; + if (c1.GenericDefinition != c2.GenericDefinition) + return false; + for (var i = 0; i < c1.GenericArguments.Length; ++i) + { + if (!SameOrEquivalentGenericTypes(c1.GenericArguments[i], c2.GenericArguments[i], ref genericType)) + return false; + } + return true; + } + genericType = true; + var constraints = g2.GetTypeConstraints(); + if (constraints.Length > 0 && !constraints.Any(c => TypeCompatibilityRules.IsAssignableFrom(g1, c))) + return false; + return (g1.Variance == g2.Variance && g1.MustHaveDefaultConstructor == g2.MustHaveDefaultConstructor); + } + + public static bool CompatibleSignatures(CallableSignature sig1, CallableSignature sig2) + { + if (sig1.Parameters.Length != sig2.Parameters.Length) + return false; + if (sig1.AcceptVarArgs != sig2.AcceptVarArgs) + return false; + for (var i = 0; i < sig1.Parameters.Length; ++i) + { + var p1 = sig1.Parameters[i]; + var p2 = sig2.Parameters[i]; + if (p1.IsByRef != p2.IsByRef) + return false; + if (p1.Type != p2.Type) + return false; + } + return sig1.ReturnType == sig2.ReturnType; + } + + public static bool CompatibleGenericSignatures(CallableSignature sig1, CallableSignature sig2) + { + if (sig1.Parameters.Length != sig2.Parameters.Length) + return false; + if (sig1.AcceptVarArgs != sig2.AcceptVarArgs) + return false; + var seenGeneric = false; + for (var i = 0; i < sig1.Parameters.Length; ++i) + { + var p1 = sig1.Parameters[i]; + var p2 = sig2.Parameters[i]; + if (p1.IsByRef != p2.IsByRef) + return false; + if (!SameOrEquivalentGenericTypes(p1.Type, p2.Type, ref seenGeneric)) + return false; + } + if (!SameOrEquivalentGenericTypes(sig1.ReturnType, sig2.ReturnType, ref seenGeneric)) + return false; + return seenGeneric; + } + + public static IType SelfMapGenericType(IType type) + { + if (type.GenericInfo != null && type.ConstructedInfo == null) + return type.GenericInfo.ConstructType( + Array.ConvertAll(type.GenericInfo.GenericParameters, gp => (IType)gp)); + return type; + } + } } \ No newline at end of file diff --git a/src/Boo.Lang.Compiler/TypeSystem/TypeVisitor.cs b/src/Boo.Lang.Compiler/TypeSystem/TypeVisitor.cs index e27b0fca8..99a2d30c4 100644 --- a/src/Boo.Lang.Compiler/TypeSystem/TypeVisitor.cs +++ b/src/Boo.Lang.Compiler/TypeSystem/TypeVisitor.cs @@ -42,7 +42,13 @@ public virtual void Visit(IType type) if (type.IsByRef) VisitByRefType(type); - if (type.ConstructedInfo != null) VisitConstructedType(type); + if (type.ConstructedInfo != null) + VisitConstructedType(type); + else if (type.GenericInfo != null) + { + foreach (var gp in type.GenericInfo.GenericParameters) + Visit(gp); + } ICallableType callableType = type as ICallableType; if (callableType != null) VisitCallableType(callableType); @@ -67,14 +73,27 @@ public virtual void VisitConstructedType(IType constructedType) } } - public virtual void VisitCallableType(ICallableType callableType) - { - CallableSignature sig = callableType.GetSignature(); - foreach (IParameter parameter in sig.Parameters) - { - Visit(parameter.Type); - } - Visit(sig.ReturnType); + private Stack _inCallableTypes = new Stack(); + + public virtual void VisitCallableType(ICallableType callableType) + { + if (_inCallableTypes.Contains(callableType)) + return; + + _inCallableTypes.Push(callableType); + try + { + CallableSignature sig = callableType.GetSignature(); + foreach (IParameter parameter in sig.Parameters) + { + Visit(parameter.Type); + } + Visit(sig.ReturnType); + } + finally + { + _inCallableTypes.Pop(); + } } } } diff --git a/src/Boo.Lang.Extensions/Attributes/AsyncAttribute.boo b/src/Boo.Lang.Extensions/Attributes/AsyncAttribute.boo new file mode 100644 index 000000000..db758d859 --- /dev/null +++ b/src/Boo.Lang.Extensions/Attributes/AsyncAttribute.boo @@ -0,0 +1,52 @@ +#region license +// Copyright (c) 2017, Mason Wheeler +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Rodrigo B. de Oliveira nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF +// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#endregion + +namespace Boo.Lang.Extensions + +import System +import Boo.Lang.Compiler +import Boo.Lang.Compiler.Ast +import Boo.Lang.Compiler.Steps +import Boo.Lang.Compiler.TypeSystem + +class AsyncAttribute(AbstractAstAttribute): + + def constructor(): + pass + + override def Apply(node as Node): + if node.NodeType != NodeType.Method: + InvalidNodeForAttribute('Method') + return + + method as Method = node cast Method + if not AsyncHelper.ValidAsyncTypeUnbound(method): + Errors.Add(CompilerErrorFactory.InvalidAsyncType(method.ReturnType)) + return + + ContextAnnotations.MarkAsync(method) \ No newline at end of file diff --git a/src/Boo.Lang.Extensions/MetaMethods/Async.boo b/src/Boo.Lang.Extensions/MetaMethods/Async.boo new file mode 100644 index 000000000..1abef978d --- /dev/null +++ b/src/Boo.Lang.Extensions/MetaMethods/Async.boo @@ -0,0 +1,38 @@ +#region license +// Copyright (c) 2017, Mason Wheeler +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Rodrigo B. de Oliveira nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF +// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#endregion + +namespace Boo.Lang.Extensions + +import System +import Boo.Lang.Compiler.Ast +import Boo.Lang.Compiler.Steps + +[meta] +def async(block as BlockExpression): + ContextAnnotations.MarkAsync(block) + return AsyncBlockExpression(block) \ No newline at end of file diff --git a/src/Boo.Lang.Extensions/MetaMethods/Await.boo b/src/Boo.Lang.Extensions/MetaMethods/Await.boo new file mode 100644 index 000000000..2c9b3a7d9 --- /dev/null +++ b/src/Boo.Lang.Extensions/MetaMethods/Await.boo @@ -0,0 +1,36 @@ +#region license +// Copyright (c) 2017, Mason Wheeler +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Rodrigo B. de Oliveira nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF +// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#endregion + +namespace Boo.Lang.Extensions + +import System +import Boo.Lang.Compiler.Ast + +[meta] +def await(e as Expression): + return AwaitExpression(e) \ No newline at end of file diff --git a/src/Boo.Lang.Parser/Boo.Lang.Parser.csproj b/src/Boo.Lang.Parser/Boo.Lang.Parser.csproj index 5dc8490c0..d6c1d7bf3 100644 --- a/src/Boo.Lang.Parser/Boo.Lang.Parser.csproj +++ b/src/Boo.Lang.Parser/Boo.Lang.Parser.csproj @@ -12,7 +12,8 @@ Boo.Lang.Parser true ..\boo.snk - v3.5 + v4.5 + true @@ -27,6 +28,7 @@ 4194304 4096 AnyCPU + false pdbonly @@ -40,6 +42,7 @@ 4194304 4096 AnyCPU + false true @@ -54,6 +57,7 @@ 4194304 4096 AnyCPU + false False @@ -73,6 +77,7 @@ prompt 4 false + false bin\Micro-Release\ @@ -81,6 +86,7 @@ pdbonly prompt 4 + false diff --git a/src/Boo.Lang/Resources/StringResources.cs b/src/Boo.Lang/Resources/StringResources.cs index e7dae958c..c8c7527b6 100644 --- a/src/Boo.Lang/Resources/StringResources.cs +++ b/src/Boo.Lang/Resources/StringResources.cs @@ -180,6 +180,11 @@ public static class StringResources public const string BCE0175 = "Nested type '{0}' cannot extend enclosing type '{1}'."; public const string BCE0176 = "Incompatible partial definition for type '{0}', expecting '{1}' but got '{2}'."; public const string BCE0177 = "Default expression requires a type."; + public const string BCE0178 = "Async methods must return void, Task, or Task"; + public const string BCE0179 = "Await requires an expression of type Task or Task"; + public const string BCE0180 = "Type '{0}' is not valid in an async method"; + public const string BCE0181 = "Unsafe method calls returning a pointer are not valid in an async method"; + public const string BCE0182 = "Type {0} does not contain a valid GetAwaiter method"; public const string BCW0000 = "WARNING: {0}"; public const string BCW0001 = "WARNING: Type '{0}' does not provide an implementation for '{1}' and will be marked abstract."; public const string BCW0002 = "WARNING: Statement modifiers have no effect in labels."; @@ -210,6 +215,7 @@ public static class StringResources public const string BCW0027 = "WARNING: Obsolete syntax '{0}'. Use '{1}'."; public const string BCW0028 = "WARNING: Implicit downcast from '{0}' to '{1}'."; public const string BCW0029 = "WARNING: Method '{0}' hides inherited non virtual method '{1}'. Declare '{0}' as a 'new' method."; + public const string BCW0030 = "WARNING: This async method lacks \'await\' operators and will run synchronously. Consider using the \'await\' operator to await non-blocking API calls, or \'await Task.Run(...)\' to do CPU-bound work on a background thread."; public const string BCE0500 = "Response file '{0}' listed more than once."; public const string BCE0501 = "Response file '{0}' could not be found."; public const string BCE0502 = "An error occurred while loading response file '{0}'."; diff --git a/src/booc/app.config b/src/booc/app.config index cb2586beb..b7a7ef166 100755 --- a/src/booc/app.config +++ b/src/booc/app.config @@ -1,3 +1,3 @@ - + diff --git a/src/booc/booc.csproj b/src/booc/booc.csproj index 7e3e60e27..327de505b 100755 --- a/src/booc/booc.csproj +++ b/src/booc/booc.csproj @@ -11,8 +11,9 @@ 4194304 AnyCPU 4096 - v3.5 + v4.5 9.0.30729 + 4 @@ -20,6 +21,7 @@ DEBUG;TRACE ..\..\ide-build none + false 4 @@ -27,12 +29,14 @@ TRACE ..\..\ide-build none + false bin\Micro-Debug\ DEBUG;TRACE 4096 AnyCPU + false bin\Micro-Release\ @@ -40,6 +44,7 @@ true 4096 AnyCPU + false diff --git a/tests/BooCompiler.Tests/AsyncTestFixture.cs b/tests/BooCompiler.Tests/AsyncTestFixture.cs new file mode 100644 index 000000000..e41933d05 --- /dev/null +++ b/tests/BooCompiler.Tests/AsyncTestFixture.cs @@ -0,0 +1,275 @@ +// Test suite ported from Roslyn tests found at +// https://github.com/dotnet/roslyn/blob/master/src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenAsyncTests.cs + +namespace BooCompiler.Tests +{ + using NUnit.Framework; + + [TestFixture] + public class AsyncTestFixture : AbstractCompilerTestCase + { + [Test] + public void async_conformance_awaiting_indexer() + { + RunCompilerTestCase(@"async-conformance-awaiting-indexer.boo"); + } + + [Test] + public void async_delegates() + { + RunCompilerTestCase(@"async-delegates.boo"); + } + + [Test] + public void async_extension_add_method() + { + RunCompilerTestCase(@"async-extension-add-method.boo"); + } + + [Test] + public void async_hello_world() + { + RunCompilerTestCase(@"async-hello-world.boo"); + } + + [Test] + public void async_method_only_writes_to_enclosing_struct() + { + RunCompilerTestCase(@"async-method-only-writes-to-enclosing-struct.boo"); + } + + [Test] + public void async_state_machine_struct_task_t() + { + RunCompilerTestCase(@"async-state-machine-struct-task-t.boo"); + } + + [Test] + public void await_in_delegate_constructor() + { + RunCompilerTestCase(@"await-in-delegate-constructor.boo"); + } + + [Test] + public void await_in_obj_initializer() + { + RunCompilerTestCase(@"await-in-obj-initializer.boo"); + } + + [Test] + public void await_in_using_and_for() + { + RunCompilerTestCase(@"await-in-using-and-for.boo"); + } + + [Test] + public void await_switch() + { + RunCompilerTestCase(@"await-switch.boo"); + } + + [Test] + public void await_void() + { + RunCompilerTestCase(@"await-void.boo"); + } + + [Test] [Ignore("Requires better closure signature inferring")] + public void better_conversion_from_async_lambda() + { + RunCompilerTestCase(@"better-conversion-from-async-lambda.boo"); + } + + [Test] + public void conformance_awaiting_methods_accessible() + { + RunCompilerTestCase(@"conformance-awaiting-methods-accessible.boo"); + } + + [Test] + public void conformance_awaiting_methods_generic() + { + RunCompilerTestCase(@"conformance-awaiting-methods-generic.boo"); + } + + [Test][Ignore("This will fail until Run and RunEx are merged back together")] + public void conformance_awaiting_methods_method() + { + RunCompilerTestCase(@"conformance-awaiting-methods-method.boo"); + } + + [Test] + public void conformance_awaiting_methods_method02() + { + RunCompilerTestCase(@"conformance-awaiting-methods-method02.boo"); + } + + [Test] + public void conformance_awaiting_methods_parameter() + { + RunCompilerTestCase(@"conformance-awaiting-methods-parameter.boo"); + } + + [Test] + public void conformance_exceptions_async_await_names() + { + RunCompilerTestCase(@"conformance-exceptions-async-await-names.boo"); + + } + + [Test][Ignore("Requires better closure signature inferring")] + public void conformance_overload_resolution_class_generic_regular_method() + { + RunCompilerTestCase(@"conformance-overload-resolution-class-generic-regular-method.boo"); + } + + [Test] + public void cs_bug_602246() + { + RunCompilerTestCase(@"cs-bug-602246.boo"); + } + + [Test] + public void cs_bug_748527() + { + RunCompilerTestCase(@"cs-bug-748527.boo"); + } + + [Test] + public void delegate_async() + { + RunCompilerTestCase(@"delegate-async.boo"); + } + + [Test] + public void generic_async_lambda() + { + RunCompilerTestCase(@"generic-async-lambda.boo"); + } + + [Test] + public void generic_task_returning_async() + { + RunCompilerTestCase(@"generic-task-returning-async.boo"); + } + + [Test] + public void generic() + { + RunCompilerTestCase(@"generic.boo"); + } + + [Test] + public void hoist_structure() + { + RunCompilerTestCase(@"hoist-structure.boo"); + } + + [Test] + public void hoist_using_1() + { + RunCompilerTestCase(@"hoist-using-1.boo"); + } + + [Test] + public void hoist_using_2() + { + RunCompilerTestCase(@"hoist-using-2.boo"); + } + + [Test] + public void hoist_using_3() + { + RunCompilerTestCase(@"hoist-using-3.boo"); + } + + [Test] + public void infer_from_async_lambda() + { + RunCompilerTestCase(@"infer-from-async-lambda.boo"); + } + + [Test] + public void inference() + { + RunCompilerTestCase(@"inference.boo"); + } + + [Test] + public void init() + { + RunCompilerTestCase(@"init.boo"); + } + + [Test] + public void is_and_as_operators() + { + RunCompilerTestCase(@"is-and-as-operators.boo"); + } + + [Test] + public void mutating_array_of_structs() + { + RunCompilerTestCase(@"mutating-array-of-structs.boo"); + } + + [Test] + public void mutating_struct_with_using() + { + RunCompilerTestCase(@"mutating-struct-with-using.boo"); + } + + [Test] + public void my_task_2() + { + RunCompilerTestCase(@"my-task-2.boo"); + } + + [Test] + public void my_task() + { + RunCompilerTestCase(@"my-task.boo"); + } + + [Test] + public void premature_null() + { + RunCompilerTestCase(@"premature-null.boo"); + } + + [Test] + public void property() + { + RunCompilerTestCase(@"property.boo"); + } + + [Test] + public void struct_async() + { + RunCompilerTestCase(@"struct-async.boo"); + } + + [Test] + public void switch_on_awaited_value_async() + { + RunCompilerTestCase(@"switch-on-awaited-value-async.boo"); + } + + [Test] + public void task_returning_async() + { + RunCompilerTestCase(@"task-returning-async.boo"); + } + + [Test] + public void void_returning_async() + { + RunCompilerTestCase(@"void-returning-async.boo"); + } + + protected override string GetRelativeTestCasesPath() + { + return "async"; + } + } +} \ No newline at end of file diff --git a/tests/BooCompiler.Tests/BooCompiler.Tests.csproj b/tests/BooCompiler.Tests/BooCompiler.Tests.csproj index be4cb6984..74bdc6303 100644 --- a/tests/BooCompiler.Tests/BooCompiler.Tests.csproj +++ b/tests/BooCompiler.Tests/BooCompiler.Tests.csproj @@ -127,6 +127,7 @@ + diff --git a/tests/BooCompiler.Tests/CompilerErrorsTestFixture.cs b/tests/BooCompiler.Tests/CompilerErrorsTestFixture.cs index 3c0012231..3df0d4e9f 100644 --- a/tests/BooCompiler.Tests/CompilerErrorsTestFixture.cs +++ b/tests/BooCompiler.Tests/CompilerErrorsTestFixture.cs @@ -326,12 +326,6 @@ public void BCE0046_1() RunCompilerTestCase(@"BCE0046-1.boo"); } - [Test] - public void BCE0046_2() - { - RunCompilerTestCase(@"BCE0046-2.boo"); - } - [Test] public void BCE0047_1() { diff --git a/tests/BooCompiler.Tests/GenericsTestFixture.cs b/tests/BooCompiler.Tests/GenericsTestFixture.cs index 4a9e56d76..c714de43f 100644 --- a/tests/BooCompiler.Tests/GenericsTestFixture.cs +++ b/tests/BooCompiler.Tests/GenericsTestFixture.cs @@ -175,8 +175,26 @@ public void generic_array_3() { RunCompilerTestCase(@"generic-array-3.boo"); } - - [Test] + + [Test] + public void generic_closures() + { + RunCompilerTestCase(@"generic-closures.boo"); + } + + [Test] + public void generic_closures_2() + { + RunCompilerTestCase(@"generic-closures-2.boo"); + } + + [Test] + public void generic_closures_3() + { + RunCompilerTestCase(@"generic-closures-3.boo"); + } + + [Test] public void generic_extension_1() { RunCompilerTestCase(@"generic-extension-1.boo"); @@ -301,7 +319,14 @@ public void generic_overload_6() { RunCompilerTestCase(@"generic-overload-6.boo"); } - + + [Test] + public void generic_overload_7() + { + RunCompilerTestCase(@"generic-overload-7.boo"); + } + + [Test] public void generic_ref_parameter() { diff --git a/tests/testcases/async/async-conformance-awaiting-indexer.boo b/tests/testcases/async/async-conformance-awaiting-indexer.boo new file mode 100644 index 000000000..721651ce1 --- /dev/null +++ b/tests/testcases/async/async-conformance-awaiting-indexer.boo @@ -0,0 +1,52 @@ +""" +0 +""" + +import System.Threading +import System.Threading.Tasks +import System + +struct MyStruct[of T(Task[of Func[of int]])]: + property t as T + + public self[index as T] as T: + get: + return t + set: + t = value + +struct TestCase: + public static Count = 0 + private tests as int + + [async] public def Run(): + self.tests = 0 + var ms = MyStruct[of Task[of Func[of int]]]() + try: + ms[null] = Task.Run[of Func[of int]](async({ await(Task.Delay(1)); Interlocked.Increment(TestCase.Count); return {123} })) + self.tests++ + var x = await(ms[await(Foo(null))]) + if x() == 123: + self.tests++ + ensure: + Driver.Result = TestCase.Count - self.tests + //When test complete, set the flag. + Driver.CompletedSignal.Set() + + [async] public def Foo(d as Task[of Func[of int]]) as Task[of Task[of Func[of int]]]: + await Task.Delay(1) + Interlocked.Increment(TestCase.Count) + return d + +class Driver: + public static Result = -1 + public static CompletedSignal = AutoResetEvent(false) + +def Main(): + var t = TestCase() + t.Run() + Driver.CompletedSignal.WaitOne() + // 0 - success + // 1 - failed (test completed) + // -1 - failed (test incomplete - deadlock, etc) + Console.WriteLine(Driver.Result) diff --git a/tests/testcases/async/async-delegates.boo b/tests/testcases/async/async-delegates.boo new file mode 100644 index 000000000..aa7ae92bc --- /dev/null +++ b/tests/testcases/async/async-delegates.boo @@ -0,0 +1,47 @@ +""" +0 +2 +""" + +import System +import System.Threading.Tasks + +static class Program: + def test1(): + value = async() do(): + if 0.ToString().Length == 0: + await Task.Yield() + else: + System.Console.WriteLine(0.ToString()) + Invoke(value) + + def test2() as string: + value = async() do(): + if 0.ToString().Length == 0: + await Task.Yield() + return 1.ToString() + else: + System.Console.WriteLine(2.ToString()) + return null + return Invoke(value); + + def Invoke(method as Action): + method() + + def Invoke(method as Func[of Task]): + method().Wait(); + + def Invoke[of TResult](method as Func[of TResult]) as TResult: + return method() + + internal def Invoke[of TResult](method as Func[of Task[of TResult]]) as TResult: + if method != null: + return Invoke1(async({ await(Task.Yield()); return await(method()) })) + return Default(TResult) + + internal static def Invoke1[of TResult](method as Func[of Task[of TResult]]) as TResult: + return method().Result + +def Main(args as (string)): + Program.test1() + Program.test2() diff --git a/tests/testcases/async/async-extension-add-method.boo b/tests/testcases/async/async-extension-add-method.boo new file mode 100644 index 000000000..52fc0521c --- /dev/null +++ b/tests/testcases/async/async-extension-add-method.boo @@ -0,0 +1,35 @@ +""" +GetVal 1 +Add 1 +Add 2 +Add 3 +""" + +import System +import System.Collections.Generic +import System.Threading +import System.Threading.Tasks + +[Extension] +public def Add[of T](stack as Stack[of T], item as T): + Console.WriteLine("Add $item") + stack.Push(item) + +class TestCase: + public handle = AutoResetEvent(false) + + [async] private def GetVal[of T](x as T) as Task[of T]: + await Task.Delay(1) + Console.WriteLine("GetVal $x") + return x + + [async] public def Run(): + try: + var stack = Stack[of int]() {await(GetVal(1)), 2, 3 } + ensure: + handle.Set() + +public def Main(args as (string)): + var tc = TestCase() + tc.Run() + tc.handle.WaitOne(1000) \ No newline at end of file diff --git a/tests/testcases/async/async-hello-world.boo b/tests/testcases/async/async-hello-world.boo new file mode 100644 index 000000000..010e44dd8 --- /dev/null +++ b/tests/testcases/async/async-hello-world.boo @@ -0,0 +1,11 @@ +""" +Hello, World! +""" + +import System.Threading +import System.Threading.Tasks + +[async] public static def F(a as string) as Task: + await(Task.Factory.StartNew({ System.Console.WriteLine(a) })) + +F('Hello, World!').Wait() diff --git a/tests/testcases/async/async-method-only-writes-to-enclosing-struct.boo b/tests/testcases/async/async-method-only-writes-to-enclosing-struct.boo new file mode 100644 index 000000000..46218137b --- /dev/null +++ b/tests/testcases/async/async-method-only-writes-to-enclosing-struct.boo @@ -0,0 +1,14 @@ +""" +1 +""" + +public struct GenC[of T(struct)]: + public valueN as T? + [async] public def Test(t as T): + valueN = t; + +public static def Main(): + var test = 12 + var _int = GenC[of int]() + _int.Test(test) + System.Console.WriteLine((_int.valueN if _int.valueN.HasValue else 1)) diff --git a/tests/testcases/async/async-state-machine-struct-task-t.boo b/tests/testcases/async/async-state-machine-struct-task-t.boo new file mode 100644 index 000000000..ec0899d12 --- /dev/null +++ b/tests/testcases/async/async-state-machine-struct-task-t.boo @@ -0,0 +1,15 @@ +""" +42 +""" + +import System +import System.Threading.Tasks + +class Test: + [async] public static def F() as Task of int: + return await(Task.Factory.StartNew({42})) + +public def Main(): + var t = Test.F() + t.Wait() + Console.WriteLine(t.Result) diff --git a/tests/testcases/async/await-in-delegate-constructor.boo b/tests/testcases/async/await-in-delegate-constructor.boo new file mode 100644 index 000000000..9a60c8ed9 --- /dev/null +++ b/tests/testcases/async/await-in-delegate-constructor.boo @@ -0,0 +1,41 @@ +""" +0 +""" + +import System +import System.Collections.Generic +import System.Text +import System.Threading +import System.Threading.Tasks + +class TestCase: + static test = 0 + static count = 0 + + [async] public static def Run() as Task: + try: + test++ + checked: + var f = Func[of int, object](await(Bar())) + var x = f(1) + if (x cast string) != "1": + count-- + ensure: + Driver.Result = test - count + Driver.CompleteSignal.Set() + + [async] static def Bar() as Task[of Converter[of int, string]]: + count++ + await Task.Delay(1) + return {p1 as int | return p1.ToString()}; + +class Driver: + static public CompleteSignal = AutoResetEvent(false) + public static Result as int = -1 + + public static def Main(): + TestCase.Run() + CompleteSignal.WaitOne() + Console.Write(Result) + +Driver.Main() \ No newline at end of file diff --git a/tests/testcases/async/await-in-obj-initializer.boo b/tests/testcases/async/await-in-obj-initializer.boo new file mode 100644 index 000000000..e7a48bc61 --- /dev/null +++ b/tests/testcases/async/await-in-obj-initializer.boo @@ -0,0 +1,23 @@ +""" +0 +""" + +namespace CompilerCrashRepro2 + +import System +import System.Threading.Tasks + +public class Item[of T]: + [Property(Value)] + private _value as T + +public static class Crasher: + public def Build[of T]() as Func[of Task[of Item[of T]]]: + return async({ Item[of T](Value: await(GetValue[of T]())) }) + + public def GetValue[of T]() as Task[of T]: + return Task.FromResult(Default(T)) + +public def Main(): + var r = Crasher.Build[of int]()().Result.Value + System.Console.WriteLine(r) diff --git a/tests/testcases/async/await-in-using-and-for.boo b/tests/testcases/async/await-in-using-and-for.boo new file mode 100644 index 000000000..26fb04b5c --- /dev/null +++ b/tests/testcases/async/await-in-using-and-for.boo @@ -0,0 +1,16 @@ +import System.Threading.Tasks +import System + +class Program: + ien as System.Collections.Generic.IEnumerable[of int] = null + [async] def Test(id as IDisposable, task as Task[of int]) as Task[of int]: + try: + for i in ien: + return await(task) + using id: + return await(task) + except as Exception: + return await(task) + +public static def Main(): + pass diff --git a/tests/testcases/async/await-switch.boo b/tests/testcases/async/await-switch.boo new file mode 100644 index 000000000..2fdc1bf87 --- /dev/null +++ b/tests/testcases/async/await-switch.boo @@ -0,0 +1,32 @@ +""" +0 +""" + +import System +import System.Threading +import System.Threading.Tasks + +class TestCase: + [async] public def Run() as void: + test as int = 0 + result as int = 0 + try: + test++ + __switch__(await (async ({ await(Task.Delay(1)); return 5 })()), d, d, d, d, d, r) + goto d + :r + result++ + :d + ensure: + Driver.Result = test - result + Driver.CompleteSignal.Set() + +class Driver: + static public CompleteSignal = AutoResetEvent(false) + public static Result = -1 + +public static def Main(): + var tc = TestCase() + tc.Run() + Driver.CompleteSignal.WaitOne() + Console.WriteLine(Driver.Result) diff --git a/tests/testcases/async/await-void.boo b/tests/testcases/async/await-void.boo new file mode 100644 index 000000000..e1359839e --- /dev/null +++ b/tests/testcases/async/await-void.boo @@ -0,0 +1,20 @@ +""" +42 +""" + +import System +import System.Threading +import System.Threading.Tasks + +class Test: + internal static i = 0 + + [async] public static def F(handle as AutoResetEvent): + await Task.Factory.StartNew({ Test.i = 42 }) + handle.Set() + +public static def Main(): + var handle = AutoResetEvent(false) + Test.F(handle) + handle.WaitOne(1000) + Console.WriteLine(Test.i) diff --git a/tests/testcases/async/better-conversion-from-async-lambda.boo b/tests/testcases/async/better-conversion-from-async-lambda.boo new file mode 100644 index 000000000..db224540d --- /dev/null +++ b/tests/testcases/async/better-conversion-from-async-lambda.boo @@ -0,0 +1,16 @@ +""" +12 +""" + +import System.Threading +import System.Threading.Tasks +import System + +static class TestCase: + public def Foo(f as Func[of Task[of double]]): + return 12 + public def Foo(f as Func[of Task[of object]]): + return 13 + +public def Main(): + Console.WriteLine(TestCase.Foo(async({ return 14 }))) diff --git a/tests/testcases/async/conformance-awaiting-methods-accessible.boo b/tests/testcases/async/conformance-awaiting-methods-accessible.boo new file mode 100644 index 000000000..89dc571cc --- /dev/null +++ b/tests/testcases/async/conformance-awaiting-methods-accessible.boo @@ -0,0 +1,33 @@ +""" +0 +""" + +import System +import System.Collections.Generic +import System.Threading.Tasks +import System.Threading + +class TestCase(Test): + public static Count = 0 + + [async] public static def Run(): + try: + x as int = await(Test.GetValue[of int](1)) + if x != 1: + Count++ + ensure: + Driver.CompletedSignal.Set() + +class Test: + [async] protected static def GetValue[of T](t as T) as Task[of T]: + await Task.Delay(1) + return t + +class Driver: + public static CompletedSignal = AutoResetEvent(false) + +static def Main(): + TestCase.Run() + Driver.CompletedSignal.WaitOne() + // 0 - success + Console.WriteLine(TestCase.Count) diff --git a/tests/testcases/async/conformance-awaiting-methods-generic.boo b/tests/testcases/async/conformance-awaiting-methods-generic.boo new file mode 100644 index 000000000..1605b6aef --- /dev/null +++ b/tests/testcases/async/conformance-awaiting-methods-generic.boo @@ -0,0 +1,52 @@ +""" +0 +""" + +import System; +import System.Runtime.CompilerServices; +import System.Threading; + +//Implementation of your own async pattern + +public class MyTask[of T]: + + public def GetAwaiter() as MyTaskAwaiter[of T]: + return MyTaskAwaiter[of T]() + + [async] public def Run[of U(MyTask[of int], constructor)](u as U) as void: + try: + tests as int = 0 + tests++ + var rez = await(u) + if rez == 0: + Driver.Count++ + Driver.Result = Driver.Count - tests + ensure: + //When test complete, set the flag. + Driver.CompletedSignal.Set() + +public class MyTaskAwaiter[of T](INotifyCompletion): + + public def OnCompleted(continuationAction as Action) as void: + pass + + public def GetResult() as T: + return Default(T) + + public IsCompleted as bool: + get: + return true + +//------------------------------------- +class Driver: + public static Result as int = -1 + public static Count as int = 0 + public static CompletedSignal = AutoResetEvent(false) + +static def Main(): + MyTask[of int]().Run[of MyTask[of int]](MyTask[of int]()) + Driver.CompletedSignal.WaitOne() + // 0 - success + // 1 - failed (test completed) + // -1 - failed (test incomplete - deadlock, etc) + Console.WriteLine(Driver.Result) diff --git a/tests/testcases/async/conformance-awaiting-methods-method.boo b/tests/testcases/async/conformance-awaiting-methods-method.boo new file mode 100644 index 000000000..2cb67dc2c --- /dev/null +++ b/tests/testcases/async/conformance-awaiting-methods-method.boo @@ -0,0 +1,44 @@ +""" +0 +""" + +import System.Threading +import System.Threading.Tasks +import System + +public interface IExplicit: + def Method(x as int) as Task + +class C1(IExplicit): + [async] def IExplicit.Method(x as int) as Task: + //This will fail until Run and RunEx are merged back together + return Task.Run() do(): + await Task.Delay(1) + Driver.Count++ + +class TestCase: + [async] public def Run(): + try: + tests as int = 0 + tests++ + var c = C1() + var e = c cast IExplicit + await e.Method(4) + Driver.Result = Driver.Count - tests + ensure: + //When test complete, set the flag. + Driver.CompletedSignal.Set() + +class Driver: + public static Result = -1 + public static Count = 0 + public static CompletedSignal = AutoResetEvent(false) + +static def Main(): + var t = TestCase() + t.Run() + Driver.CompletedSignal.WaitOne() + // 0 - success + // 1 - failed (test completed) + // -1 - failed (test incomplete - deadlock, etc) + Console.WriteLine(Driver.Result) diff --git a/tests/testcases/async/conformance-awaiting-methods-method02.boo b/tests/testcases/async/conformance-awaiting-methods-method02.boo new file mode 100644 index 000000000..e4467d8d8 --- /dev/null +++ b/tests/testcases/async/conformance-awaiting-methods-method02.boo @@ -0,0 +1,50 @@ +""" +0 +""" + +import System.Threading +import System.Threading.Tasks +import System + +class C: + public Status as int + public def constructor(): + pass + +interface IImplicit: + def Method[of T(Task[of C])](*d as (decimal)) as T + +class Impl(IImplicit): + public def Method[of T(Task[of C])](*d as (decimal)) as T: + //this will fail until Run and RunEx are merged + aTask = async() do(): + await Task.Delay(1) + Driver.Count++ + return C(Status: 1) + return Task.Run(aTask) cast T + +class TestCase: + [async] public def Run(): + try: + var tests = 0 + var i = Impl() + tests++ + await i.Method[of Task[of C]](3, 4) + Driver.Result = Driver.Count - tests + ensure: + //When test complete, set the flag. + Driver.CompletedSignal.Set() + +class Driver: + public static Result = -1 + public static Count = 0 + public static CompletedSignal = AutoResetEvent(false) + +static def Main(): + var t = TestCase() + t.Run() + Driver.CompletedSignal.WaitOne() + // 0 - success + // 1 - failed (test completed) + // -1 - failed (test incomplete - deadlock, etc) + Console.WriteLine(Driver.Result) diff --git a/tests/testcases/async/conformance-awaiting-methods-parameter.boo b/tests/testcases/async/conformance-awaiting-methods-parameter.boo new file mode 100644 index 000000000..2c81e6dc8 --- /dev/null +++ b/tests/testcases/async/conformance-awaiting-methods-parameter.boo @@ -0,0 +1,39 @@ +""" +0 +""" + +import System +import System.Threading.Tasks +import System.Collections.Generic +import System.Threading + +class TestCase: + public static Count = 0 + + public static def Foo[of T](t as T) as T: + return t + + [async] public static def Bar[of T](t as T) as Task[of T]: + await Task.Delay(1) + return t + + [async] public static def Run(): + try: + var x1 = Foo(await(Bar(4))) + t as Task[of int] = Bar(5) + x2 as int = Foo(await(t)) + if x1 != 4: + Count++ + if x2 != 5: + Count++ + ensure: + Driver.CompletedSignal.Set() + +class Driver: + public static CompletedSignal = AutoResetEvent(false) + +static def Main(): + TestCase.Run() + Driver.CompletedSignal.WaitOne() + // 0 - success + Console.WriteLine(TestCase.Count) diff --git a/tests/testcases/async/conformance-exceptions-async-await-names.boo b/tests/testcases/async/conformance-exceptions-async-await-names.boo new file mode 100644 index 000000000..e65408e39 --- /dev/null +++ b/tests/testcases/async/conformance-exceptions-async-await-names.boo @@ -0,0 +1,27 @@ +""" +0 +""" +import System +class TestCase: + public def Run(): + Driver.Tests++ + try: + raise ArgumentException() + except await as Exception: + if await isa ArgumentException: + Driver.Count++ + Driver.Tests++ + try: + raise ArgumentException() + except async as Exception: + if async isa ArgumentException: + Driver.Count++ + +class Driver: + public static Tests as int + public static Count as int + +static def Main(): + var t = TestCase() + t.Run() + Console.WriteLine(Driver.Tests - Driver.Count) diff --git a/tests/testcases/async/conformance-overload-resolution-class-generic-regular-method.boo b/tests/testcases/async/conformance-overload-resolution-class-generic-regular-method.boo new file mode 100644 index 000000000..27e735ef8 --- /dev/null +++ b/tests/testcases/async/conformance-overload-resolution-class-generic-regular-method.boo @@ -0,0 +1,45 @@ +""" +0 +""" + +import System.Threading +import System.Threading.Tasks +import System + +struct Test[of U, V, W]: + //Regular methods + public def Foo(f as Func[of Task[of U]]): + return 1 + public def Foo(f as Func[of Task[of V]]): + return 2 + public def Foo(f as Func[of Task[of W]]): + return 3 + +class TestCase: + //where there is a conversion between types (int->double) + [async] public def Run(): + var test = Test[of decimal, string, object]() + var rez = 0 + // Pick double + Driver.Tests++ + rez = test.Foo(async({ return 1.0 })) + if rez == 3: Driver.Count++ + //pick int + Driver.Tests++ + rez = test.Foo(async({ return 1 })) + if rez == 1: Driver.Count++ + // The best overload is Func[of Task[of object>> + Driver.Tests++ + rez = test.Foo(async({ return ""; })) + if rez == 2: Driver.Count++ + +class Driver: + public static Count = 0 + public static Tests = 0 + +static def Main() as int: + var t = TestCase() + t.Run() + var ret = Driver.Tests - Driver.Count + Console.WriteLine(ret) + return ret diff --git a/tests/testcases/async/cs-bug-602246.boo b/tests/testcases/async/cs-bug-602246.boo new file mode 100644 index 000000000..d9c624322 --- /dev/null +++ b/tests/testcases/async/cs-bug-602246.boo @@ -0,0 +1,18 @@ +""" +12 +""" + +import System +import System.Threading.Tasks + +public static class TestCase: + [async] public def Run[of T](t as T) as Task[of T]: + await Task.Delay(1) + f as Func[of Func[of Task[of T]], Task[of T]] = async({x | return await(x()) }) + var rez = await(f(async({ await(Task.Delay(1)); return t }))) + return rez; + +public def Main() as void: + var t = TestCase.Run[of int](12) + if not t.Wait(1000): raise Exception() + Console.Write(t.Result) diff --git a/tests/testcases/async/cs-bug-748527.boo b/tests/testcases/async/cs-bug-748527.boo new file mode 100644 index 000000000..7d07b36c6 --- /dev/null +++ b/tests/testcases/async/cs-bug-748527.boo @@ -0,0 +1,23 @@ +""" +""" + +namespace A + +import System.Threading.Tasks +import System + +public struct TestClass: + [async] public def IntRet(IntI as int) as System.Threading.Tasks.Task[of int]: + return await(async({ await(Task.Yield()); return IntI })()); + +static public class B: + [async] public def MainMethod() as System.Threading.Tasks.Task[of int]: + var MyRet = 0 + var TC = TestClass() + if (await((async({ await(Task.Yield()); return (await(TestClass().IntRet(await(async({ await(Task.Yield()); return 3 })()) ))) }) )()) ) != await(async({ await(Task.Yield()); return 3 } )()): + MyRet = 1 + return await(async({await(Task.Yield()); return MyRet})()) + +def Main(): + B.MainMethod() + return diff --git a/tests/testcases/async/delegate-async.boo b/tests/testcases/async/delegate-async.boo new file mode 100644 index 000000000..6ee8173ae --- /dev/null +++ b/tests/testcases/async/delegate-async.boo @@ -0,0 +1,57 @@ +""" +0 +""" + +import System.Threading +import System.Threading.Tasks +import System + +callable MyDel[of U](ref u as U) as Task + +class MyClass[of T]: + public static def Meth(ref t as T) as Task: + t = Default(T) + return Task.Run(async ({ await(Task.Delay(1)); TestCase.Count++ })) + + public myDel as MyDel[of T] + + public event myEvent as MyDel[of T] + + [async] public def TriggerEvent(p as T) as Task: + try: + await myEvent(p) + except: + TestCase.Count += 5 + +struct TestCase: + public static Count = 0 + private tests as int + + [async] public def Run(): + tests = 0 + try: + tests++; + var ms = MyClass[of string]() + ms.myDel = MyClass[of string].Meth; + var str = "" + await ms.myDel(str) + tests++ + ms.myEvent += MyClass[of string].Meth + await ms.TriggerEvent(str) + ensure: + Driver.Result = TestCase.Count - self.tests + //When test complete, set the flag. + Driver.CompletedSignal.Set() + +class Driver: + public static Result = -1 + public static CompletedSignal = AutoResetEvent(false) + +static def Main(): + var t = TestCase() + t.Run() + Driver.CompletedSignal.WaitOne() + // 0 - success + // 1 - failed (test completed) + // -1 - failed (test incomplete - deadlock, etc) + Console.WriteLine(Driver.Result) diff --git a/tests/testcases/async/generic-async-lambda.boo b/tests/testcases/async/generic-async-lambda.boo new file mode 100644 index 000000000..bd29d5fba --- /dev/null +++ b/tests/testcases/async/generic-async-lambda.boo @@ -0,0 +1,32 @@ +""" +12 +""" + +import System +import System.Diagnostics +import System.Threading +import System.Threading.Tasks + +class G[of T]: + t as T + + public def constructor(t as T, action as Func[of T, Task[of T]]): + var tt = action(t) + var completed = tt.Wait(1000) + Debug.Assert(completed) + self.t = tt.Result + + public override def ToString() as string: + return t.ToString() + +static class Test: + def M[of U](t as U) as G[of U]: + return G[of U](t, async({x | return await(IdentityAsync(x)) })) + + [async] def IdentityAsync[of V](x as V) as Task[of V]: + await Task.Delay(1) + return x + +public def Main(): + var g = Test.M(12) + Console.WriteLine(g) diff --git a/tests/testcases/async/generic-task-returning-async.boo b/tests/testcases/async/generic-task-returning-async.boo new file mode 100644 index 000000000..3e0daa5b4 --- /dev/null +++ b/tests/testcases/async/generic-task-returning-async.boo @@ -0,0 +1,15 @@ +""" +O brave new world... +""" + +import System +import System.Diagnostics +import System.Threading.Tasks + +[async] public static def F() as Task[of string]: + return await(Task.Factory.StartNew({ return "O brave new world..." })) + +public static def Main(): + t as Task[of string] = F() + t.Wait(1000 * 3) + Console.WriteLine(t.Result) diff --git a/tests/testcases/async/generic.boo b/tests/testcases/async/generic.boo new file mode 100644 index 000000000..338c4f8d7 --- /dev/null +++ b/tests/testcases/async/generic.boo @@ -0,0 +1,36 @@ +""" +0 +""" + +import System +import System.Collections.Generic +import System.Text +import System.Threading +import System.Threading.Tasks + +class TestCase: + static test as int = 0 + static count as int = 0 + + [async] public static def Run() as Task: + try: + test++ + Qux(async({ return 1 })) + await Task.Delay(50) + ensure: + Driver.Result = test - count + Driver.CompleteSignal.Set() + + [async] static def Qux[of T](x as Func[of Task[of T]]): + var y = await(x()) + if (y cast object) cast int == 1: + count++ + +class Driver: + static public CompleteSignal = AutoResetEvent(false) + public static Result as int = -1 + +public static def Main(): + TestCase.Run() + Driver.CompleteSignal.WaitOne() + Console.WriteLine(Driver.Result) diff --git a/tests/testcases/async/hoist-structure.boo b/tests/testcases/async/hoist-structure.boo new file mode 100644 index 000000000..d762d4d3b --- /dev/null +++ b/tests/testcases/async/hoist-structure.boo @@ -0,0 +1,24 @@ +""" +Before 12 +After 12 +""" + +namespace ConsoleApp + +import System +import System.Threading.Tasks + +struct TestStruct: + public i as long + public j as long + +static class Program: + [async] def TestAsync() as Task: + t as TestStruct + t.i = 12 + Console.WriteLine("Before {0}", t.i); // emits "Before 12" + await Task.Delay(100); + Console.WriteLine("After {0}", t.i); // emits "After 0" expecting "After 12" + +def Main(args as (string)): + Program.TestAsync().Wait() diff --git a/tests/testcases/async/hoist-using-1.boo b/tests/testcases/async/hoist-using-1.boo new file mode 100644 index 000000000..84c62bd6e --- /dev/null +++ b/tests/testcases/async/hoist-using-1.boo @@ -0,0 +1,30 @@ +""" +Pre +show +disposed +Post +result +""" + +import System.Threading.Tasks +import System + +class Program: + class D(IDisposable): + public def Dispose(): + print "disposed" + + [async] static def M(input as int) as Task[of string]: + print "Pre" + var window = D() + try: + print "show" + for i in range(2): + await Task.Delay(100) + ensure: + window.Dispose() + print "Post" + return "result" + +static def Main(): + System.Console.WriteLine(Program.M(0).Result) diff --git a/tests/testcases/async/hoist-using-2.boo b/tests/testcases/async/hoist-using-2.boo new file mode 100644 index 000000000..0cb2ddba0 --- /dev/null +++ b/tests/testcases/async/hoist-using-2.boo @@ -0,0 +1,27 @@ +""" +Pre +show +disposed +Post +result +""" + +import System.Threading.Tasks +import System + +class Program: + class D(IDisposable): + public def Dispose(): + print "disposed" + + [async] static def M(input as int) as Task[of string]: + print "Pre" + using window = D(): + print "show" + for i in range(2): + await Task.Delay(100) + print "Post" + return "result" + +static def Main(): + System.Console.WriteLine(Program.M(0).Result) diff --git a/tests/testcases/async/hoist-using-3.boo b/tests/testcases/async/hoist-using-3.boo new file mode 100644 index 000000000..16bb99a5f --- /dev/null +++ b/tests/testcases/async/hoist-using-3.boo @@ -0,0 +1,31 @@ +""" +Pre +show +show +disposed +disposed +Post +result +""" + +import System.Threading.Tasks +import System + +class Program: + class D(IDisposable): + public def Dispose(): + print "disposed" + + [async] static def M(input as int) as Task[of string]: + print "Pre" + using window = D(): + print "show" + using window = D(): + print "show" + for i in range(2): + await Task.Delay(100) + print "Post" + return "result" + +static def Main(): + System.Console.WriteLine(Program.M(0).Result) diff --git a/tests/testcases/async/infer-from-async-lambda.boo b/tests/testcases/async/infer-from-async-lambda.boo new file mode 100644 index 000000000..0df5106e4 --- /dev/null +++ b/tests/testcases/async/infer-from-async-lambda.boo @@ -0,0 +1,23 @@ +""" +System.Threading.Tasks.Task +""" + +import System +import System.Threading.Tasks + +static class Program: + public def CallWithCatch[of T](func as Func of T) as T: + Console.WriteLine(typeof(T).ToString()) + return func() + + [async] public def LoadTestDataAsync() as Task: + await CallWithCatch(async({await(LoadTestData())})) + + [async] private def LoadTestData() as Task: + nullLambda = do(): + pass + await Task.Run(nullLambda) + +public def Main(args as (string)): + var t = Program.LoadTestDataAsync() + t.Wait(1000) diff --git a/tests/testcases/async/inference.boo b/tests/testcases/async/inference.boo new file mode 100644 index 000000000..bc21df674 --- /dev/null +++ b/tests/testcases/async/inference.boo @@ -0,0 +1,49 @@ +""" +0 +""" + +import System +import System.Collections.Generic +import System.Threading +import System.Threading.Tasks + +struct Test: + public Foo as Task[of string]: + get: return Task.Run[of string](async ({ await(Task.Delay(1)); return "abc" })) + +class TestCase[of U]: + [async] public static def GetValue(x as object) as Task[of object]: + await Task.Delay(1) + return x + + public static def GetValue1[of T(Task[of U])](t as T) as T: + return t + + [async] public def Run(): + tests as int = 0 + var t = Test() + tests++ + var x1 = await(TestCase[of string].GetValue(await(t.Foo))) + if x1 == "abc": + Driver.Count++ + tests++ + var x2 = await(TestCase[of string].GetValue1(t.Foo)) + if x2 == "abc": + Driver.Count++ + Driver.Result = Driver.Count - tests + //When test completes, set the flag. + Driver.CompletedSignal.Set() + +class Driver: + public static Result = -1 + public static Count = 0 + public static CompletedSignal = AutoResetEvent(false) + +static def Main(): + var t = TestCase[of int]() + t.Run() + Driver.CompletedSignal.WaitOne() + // 0 - success + // 1 - failed (test completed) + // -1 - failed (test incomplete - deadlock, etc) + Console.WriteLine(Driver.Result) diff --git a/tests/testcases/async/init.boo b/tests/testcases/async/init.boo new file mode 100644 index 000000000..abcdcfe4e --- /dev/null +++ b/tests/testcases/async/init.boo @@ -0,0 +1,63 @@ +""" +0 +""" + +import System +import System.Collections.Generic +import System.Threading +import System.Threading.Tasks + +class ObjInit: + public async as int + public t as Task + public l as long + +class TestCase: + + private def Throw[of T](i as T) as T: + MethodCount++ + raise OverflowException() + + [async] private def GetVal[of T](x as T) as Task[of T]: + await Task.Delay(1) + Throw(x) + return x + + [property(MyProperty)] + private _myProperty as Task[of long] + + [async] public def Run(): + var tests = 0 + t as Task[of int] = Task.Run[of int](async({ await(Task.Delay(1)); raise FieldAccessException(); return 1 })) + //object type init + tests++ + try: + MyProperty = Task.Run[of long](async ({ await(Task.Delay(1)); raise DataMisalignedException(); return 1L })) + var obj = ObjInit( + async: await(t), + t: GetVal(Task.Run(async({ await(Task.Delay(1))}))), + l: await(MyProperty) ) + await obj.t + except as FieldAccessException: + Driver.Count++ + except: + Driver.Count-- + Driver.Result = Driver.Count - tests + //When test complete, set the flag. + Driver.CompletedSignal.Set() + + public MethodCount = 0 + +class Driver: + public static Result = -1 + public static Count = 0 + public static CompletedSignal = AutoResetEvent(false) + +static def Main(): + var t = TestCase() + t.Run() + Driver.CompletedSignal.WaitOne() + // 0 - success + // 1 - failed (test completed) + // -1 - failed (test incomplete - deadlock, etc) + Console.WriteLine(Driver.Result) diff --git a/tests/testcases/async/is-and-as-operators.boo b/tests/testcases/async/is-and-as-operators.boo new file mode 100644 index 000000000..3ce5615b1 --- /dev/null +++ b/tests/testcases/async/is-and-as-operators.boo @@ -0,0 +1,46 @@ +""" +0 +""" + +import System.Threading +import System.Threading.Tasks +import System + +class TestCase: + public static Count = 0 + + [async] public def Run(): + var tests = 0 + var x1 = await(Foo1()) isa object + var x2 = await(Foo2()) as string + if x1 == true: + tests++ + if x2 == "string": + tests++ + Driver.Result = TestCase.Count - tests + //When test complete, set the flag. + Driver.CompletedSignal.Set() + + [async] public def Foo1() as Task[of int]: + await Task.Delay(1) + TestCase.Count++ + var i = 0 + return i + + [async] public def Foo2() as Task[of object]: + await Task.Delay(1) + TestCase.Count++ + return "string" + +class Driver: + public static Result = -1 + public static CompletedSignal = AutoResetEvent(false) + +def Main(): + var t = TestCase() + t.Run() + Driver.CompletedSignal.WaitOne() + // 0 - success + // 1 - failed (test completed) + // -1 - failed (test incomplete - deadlock, etc) + Console.Write(Driver.Result) diff --git a/tests/testcases/async/mutating-array-of-structs.boo b/tests/testcases/async/mutating-array-of-structs.boo new file mode 100644 index 000000000..5251eca9e --- /dev/null +++ b/tests/testcases/async/mutating-array-of-structs.boo @@ -0,0 +1,25 @@ +""" +""" + +import System +import System.Diagnostics +import System.Threading +import System.Threading.Tasks + +struct S: + public A as int + + public def Mutate(b as int) as int: + A += b + return 1 + +static class Test: + i = 0 + + public def G() as Task[of int]: + return null + + [async] public def F() as Task[of int]: + var arr = array(S, 10) + + return arr[1].Mutate(await(G())) diff --git a/tests/testcases/async/mutating-struct-with-using.boo b/tests/testcases/async/mutating-struct-with-using.boo new file mode 100644 index 000000000..1a4a7b276 --- /dev/null +++ b/tests/testcases/async/mutating-struct-with-using.boo @@ -0,0 +1,20 @@ +""" +True +1 +""" + +import System +import System.Collections.Generic +import System.Threading +import System.Threading.Tasks + +class Program: + [async] public def Test() as Task: + var list = List[of int]() {1, 2, 3} + using enumerator = list.GetEnumerator(): + Console.WriteLine(enumerator.MoveNext()); + Console.WriteLine(enumerator.Current); + await Task.Delay(1); + +public static def Main(): + Program().Test().Wait() diff --git a/tests/testcases/async/my-task-2.boo b/tests/testcases/async/my-task-2.boo new file mode 100644 index 000000000..069b5d6d2 --- /dev/null +++ b/tests/testcases/async/my-task-2.boo @@ -0,0 +1,51 @@ +""" +0 +""" + +import System +import System.Threading +import System.Threading.Tasks + +//Implementation of you own async pattern +public class MyTask: + public def GetAwaiter() as MyTaskAwaiter: + return MyTaskAwaiter() + + [async]public def Run(): + tests as int = 0 + try: + tests++ + var myTask = MyTask() + var x = await(myTask) + if x == 123: Driver.Count++ + ensure: + Driver.Result = Driver.Count - tests + //When test complete, set the flag. + Driver.CompletedSignal.Set() + +public class MyTaskBaseAwaiter(System.Runtime.CompilerServices.INotifyCompletion): + public def OnCompleted(continuationAction as Action): + pass + + public def GetResult() as int: + return 123 + + public IsCompleted as bool: + get: return true + +public class MyTaskAwaiter(MyTaskBaseAwaiter): + pass + +//------------------------------------- +class Driver: + public static Result = -1 + public static Count = 0 + public static CompletedSignal = AutoResetEvent(false) + +def Main(): + MyTask().Run() + Driver.CompletedSignal.WaitOne() + // 0 - success + // 1 - failed (test completed) + // -1 - failed (test incomplete - deadlock, etc) + Console.WriteLine(Driver.Result) diff --git a/tests/testcases/async/my-task.boo b/tests/testcases/async/my-task.boo new file mode 100644 index 000000000..c1e774835 --- /dev/null +++ b/tests/testcases/async/my-task.boo @@ -0,0 +1,49 @@ +""" +0 +""" + +import System +import System.Threading +import System.Threading.Tasks + +//Implementation of you own async pattern +public class MyTask: + [async] public def Run(): + tests as int = 0 + try: + tests++ + var myTask = MyTask() + var x = await(myTask) + if x == 123: Driver.Count++ + ensure: + Driver.Result = Driver.Count - tests + //When test complete, set the flag. + Driver.CompletedSignal.Set() + +public class MyTaskAwaiter(System.Runtime.CompilerServices.INotifyCompletion): + public def OnCompleted(continuationAction as Action): + pass + + public def GetResult() as int: + return 123 + + public IsCompleted as bool: + get: return true + +[Extension] +public static def GetAwaiter(my as MyTask) as MyTaskAwaiter: + return MyTaskAwaiter() + +//------------------------------------- +class Driver: + public static Result = -1 + public static Count = 0 + public static CompletedSignal = AutoResetEvent(false) + +static def Main(): + MyTask().Run() + Driver.CompletedSignal.WaitOne() + // 0 - success + // 1 - failed (test completed) + // -1 - failed (test incomplete - deadlock, etc) + Console.WriteLine(Driver.Result) diff --git a/tests/testcases/async/premature-null.boo b/tests/testcases/async/premature-null.boo new file mode 100644 index 000000000..abd632a3c --- /dev/null +++ b/tests/testcases/async/premature-null.boo @@ -0,0 +1,46 @@ +""" +in FindReferencesInDocumentAsync +in GetTokensWithIdentifierAsync +in FindReferencesInTokensAsync +tokens were fine +document was fine +done! +""" + +import System +import System.Collections.Generic +import System.Diagnostics +import System.Linq +import System.Text +import System.Threading +import System.Threading.Tasks + +static class Program: + [async] internal def GetTokensWithIdentifierAsync() as Task[of string]: + Console.WriteLine("in GetTokensWithIdentifierAsync") + return "GetTokensWithIdentifierAsync" + + [async] protected def FindReferencesInTokensAsync(document as string, tokens as string) as Task[of string]: + Console.WriteLine("in FindReferencesInTokensAsync") + if tokens is null: raise NullReferenceException("tokens") + Console.WriteLine("tokens were fine") + if document is null: raise NullReferenceException("document") + Console.WriteLine("document was fine") + return "FindReferencesInTokensAsync" + + [async] public def FindReferencesInDocumentAsync(document as string) as Task[of string]: + Console.WriteLine("in FindReferencesInDocumentAsync") + if document is null: raise NullReferenceException("document") + var nonAliasReferences = await(FindReferencesInTokensAsync( + document, + await(GetTokensWithIdentifierAsync()) + ).ConfigureAwait(true)) + return "done!" + +public def Main(args as (string)): + try: + var ar = Program.FindReferencesInDocumentAsync("Document") + ar.Wait(1000) + Console.WriteLine(ar.Result) + except ex as Exception: + Console.WriteLine(ex) diff --git a/tests/testcases/async/property.boo b/tests/testcases/async/property.boo new file mode 100644 index 000000000..af32a5735 --- /dev/null +++ b/tests/testcases/async/property.boo @@ -0,0 +1,33 @@ +""" +0 +""" + +import System.Threading +import System.Threading.Tasks +import System + +class Base: + private _myProp as int + + public virtual MyProp as int: + get: return _myProp + private set: _myProp = value + +class TestClass(Base): + [async] def getBaseMyProp() as Task[of int]: + await Task.Delay(1) + return super.MyProp + + [async] public def Run(): + Driver.Result = await(getBaseMyProp()) + Driver.CompleteSignal.Set() + +class Driver: + public static CompleteSignal = AutoResetEvent(false) + public static Result = -1 + +public static def Main(): + var tc = TestClass() + tc.Run() + Driver.CompleteSignal.WaitOne() + Console.WriteLine(Driver.Result) diff --git a/tests/testcases/async/struct-async.boo b/tests/testcases/async/struct-async.boo new file mode 100644 index 000000000..3ba0cf86e --- /dev/null +++ b/tests/testcases/async/struct-async.boo @@ -0,0 +1,40 @@ +""" +0 +""" + +import System +import System.Threading +import System.Threading.Tasks + +struct TestCase: + private t as Task[of int] + + [async] public def Run(): + tests as int = 0 + try: + tests++ + t = Task.Run(async({ await(Task.Delay(1)); return 1 })) + var x = await(t) + if x == 1: Driver.Count++; + tests++ + t = Task.Run(async({ await(Task.Delay(1)); return 1 })) + var x2 = await(self.t) + if x2 == 1: Driver.Count++ + ensure: + Driver.Result = Driver.Count - tests + //When test complete, set the flag. + Driver.CompletedSignal.Set() + +class Driver: + public static Result = -1 + public static Count = 0 + public static CompletedSignal = AutoResetEvent(false) + +static def Main(): + var t = TestCase() + t.Run() + Driver.CompletedSignal.WaitOne() + // 0 - success + // 1 - failed (test completed) + // -1 - failed (test incomplete - deadlock, etc) + Console.Write(Driver.Result) diff --git a/tests/testcases/async/switch-on-awaited-value-async.boo b/tests/testcases/async/switch-on-awaited-value-async.boo new file mode 100644 index 000000000..c380c9f55 --- /dev/null +++ b/tests/testcases/async/switch-on-awaited-value-async.boo @@ -0,0 +1,14 @@ +import System.Threading.Tasks +import System + +static class Program: + [async] def M(input as int) as Task: + var value = 1 + __switch__(value, c0, c1) + :c0 + return + :c1 + return + +def Main(): + Program.M(0).Wait() diff --git a/tests/testcases/async/task-returning-async.boo b/tests/testcases/async/task-returning-async.boo new file mode 100644 index 000000000..281aa795c --- /dev/null +++ b/tests/testcases/async/task-returning-async.boo @@ -0,0 +1,17 @@ +""" +42 +""" + +import System +import System.Diagnostics +import System.Threading.Tasks + +class Test: + public static i as int = 0 + [async] public static def F() as Task: + await(Task.Factory.StartNew({ Test.i = 42} )) + +public static def Main(): + t as Task = Test.F() + t.Wait(1000) + Console.WriteLine(Test.i) diff --git a/tests/testcases/async/void-returning-async.boo b/tests/testcases/async/void-returning-async.boo new file mode 100644 index 000000000..b363dc5b3 --- /dev/null +++ b/tests/testcases/async/void-returning-async.boo @@ -0,0 +1,24 @@ +""" +1 +""" + +import System +import System.Diagnostics +import System.Threading +import System.Threading.Tasks + +class Test: + + public static i as int = 0 + + [async] public static def F(handle as AutoResetEvent) as void: + try: + await Task.Factory.StartNew({ Interlocked.Increment(Test.i) }) + ensure: + handle.Set() + +public static def Main(): + var handle = AutoResetEvent(false) + Test.F(handle) + handle.WaitOne(1000 * 3) + Console.WriteLine(Test.i) diff --git a/tests/testcases/errors/BCE0006-1.boo b/tests/testcases/errors/BCE0006-1.boo index 701112906..18df53389 100644 --- a/tests/testcases/errors/BCE0006-1.boo +++ b/tests/testcases/errors/BCE0006-1.boo @@ -1,6 +1,5 @@ """ -BCE0006-1.boo(5,5): BCE0006: 'int' is a value type. The 'as' operator can only be used with reference types. -BCE0006-1.boo(6,17): BCE0006: 'long' is a value type. The 'as' operator can only be used with reference types. +BCE0006-1.boo(5,17): BCE0006: 'long' is a value type. The 'as' operator can only be used with reference types. """ a = 3 as object b = object() as long diff --git a/tests/testcases/errors/BCE0046-2.boo b/tests/testcases/errors/BCE0046-2.boo deleted file mode 100644 index 92ad67af7..000000000 --- a/tests/testcases/errors/BCE0046-2.boo +++ /dev/null @@ -1,17 +0,0 @@ -""" -BCE0046-2.boo(7,8): BCE0046: 'isa' can't be used with a value type ('T') -BCE0046-2.boo(11,8): BCE0046: 'isa' can't be used with a value type ('T') -""" - -def Foo[of T(struct)](x as T): - if x isa string: - pass - -def FooInt[of T(int)](x as T): - if x isa string: - pass - - -Foo[of int](0) -FooInt[of int](0) - diff --git a/tests/testcases/integration/types/fields-10.boo b/tests/testcases/integration/types/fields-10.boo new file mode 100644 index 000000000..bdaf50d48 --- /dev/null +++ b/tests/testcases/integration/types/fields-10.boo @@ -0,0 +1,6 @@ +class Fields10: + + public Foo as System.Action + + def DoSomething(): + Foo() \ No newline at end of file diff --git a/tests/testcases/net2/generics/generic-closures-2.boo b/tests/testcases/net2/generics/generic-closures-2.boo new file mode 100644 index 000000000..a79c70067 --- /dev/null +++ b/tests/testcases/net2/generics/generic-closures-2.boo @@ -0,0 +1,12 @@ +""" +pass +""" +import System + +static class Foo[of T]: + + def Bar(baz as T): + var method = {return baz} + print method() + +Foo[of string].Bar('pass') \ No newline at end of file diff --git a/tests/testcases/net2/generics/generic-closures-3.boo b/tests/testcases/net2/generics/generic-closures-3.boo new file mode 100644 index 000000000..1bd29ce71 --- /dev/null +++ b/tests/testcases/net2/generics/generic-closures-3.boo @@ -0,0 +1,13 @@ +""" +pass +""" +import System +import System.Collections.Generic + +static class Foo: + + def Bar[of TResult](method as Func[of List[of TResult]]) as TResult: + return {return method()[0]}() + +var list = List[of string](){'pass'} +print Foo.Bar[of string]({return list}) \ No newline at end of file diff --git a/tests/testcases/net2/generics/generic-closures.boo b/tests/testcases/net2/generics/generic-closures.boo new file mode 100644 index 000000000..264f1d8e3 --- /dev/null +++ b/tests/testcases/net2/generics/generic-closures.boo @@ -0,0 +1,12 @@ +""" +pass +""" +import System + +static class Foo: + + def Bar[of T](baz as T): + var method = {return baz} + print method() + +Foo.Bar('pass') \ No newline at end of file diff --git a/tests/testcases/net2/generics/generic-overload-7.boo b/tests/testcases/net2/generics/generic-overload-7.boo new file mode 100644 index 000000000..569475c38 --- /dev/null +++ b/tests/testcases/net2/generics/generic-overload-7.boo @@ -0,0 +1,15 @@ +""" +1 +""" + +import System + +def Foo[of T](method as Func[of T]) as T: + return method() + +def Foo[of T](method as Func[of List[of T]]) as T: + return method()[0] + +bar = {return List[of int]((1, 2))} +x as int = Foo(bar) +print x \ No newline at end of file