Skip to content

Inference error fix #518

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions src/Draco.Compiler.Tests/Semantics/TypeCheckingTests.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using Draco.Compiler.Api.Syntax;
using Draco.Compiler.Internal.Binding;
using Draco.Compiler.Internal.Symbols;
Expand Down Expand Up @@ -2498,4 +2499,76 @@ public void AssigningListToArrayTypeIsIllegal()
Assert.Equal(2, diags.Length);
AssertDiagnostics(diags, TypeCheckingErrors.InferenceIncomplete, TypeCheckingErrors.TypeMismatch);
}

[Fact]
public void OrderIndependenceInTypeInference()
{
// Bug https://github.com/Draco-lang/Compiler/issues/515

// Working code:
// import System;
// import System.Collections.Generic;
//
// func main() {
// val l = List();
// l.Add(ArgumentException());
// l.Add(Exception());
// l.Add(InvalidOperationException());
// }
var workingSyntax = SyntaxTree.Create(CompilationUnit(
ImportDeclaration("System"),
ImportDeclaration("System", "Collections", "Generic"),
FunctionDeclaration(
"main",
ParameterList(),
null,
BlockFunctionBody(
DeclarationStatement(ValDeclaration("l", null, CallExpression(NameExpression("List")))),
ExpressionStatement(CallExpression(MemberExpression(NameExpression("l"), "Add"), CallExpression(NameExpression("ArgumentException")))),
ExpressionStatement(CallExpression(MemberExpression(NameExpression("l"), "Add"), CallExpression(NameExpression("Exception")))),
ExpressionStatement(CallExpression(MemberExpression(NameExpression("l"), "Add"), CallExpression(NameExpression("InvalidOperationException"))))))));

// Flipped code:
// import System;
// import System.Collections.Generic;
//
// func main() {
// val l = List();
// l.Add(ArgumentException());
// l.Add(InvalidOperationException());
// l.Add(Exception());
// }
var flippedSyntax = SyntaxTree.Create(CompilationUnit(
ImportDeclaration("System"),
ImportDeclaration("System", "Collections", "Generic"),
FunctionDeclaration(
"main",
ParameterList(),
null,
BlockFunctionBody(
DeclarationStatement(ValDeclaration("l", null, CallExpression(NameExpression("List")))),
ExpressionStatement(CallExpression(MemberExpression(NameExpression("l"), "Add"), CallExpression(NameExpression("ArgumentException")))),
ExpressionStatement(CallExpression(MemberExpression(NameExpression("l"), "Add"), CallExpression(NameExpression("InvalidOperationException")))),
ExpressionStatement(CallExpression(MemberExpression(NameExpression("l"), "Add"), CallExpression(NameExpression("Exception"))))))));

// Act
var workingCompilation = CreateCompilation(workingSyntax);
var workingSemanticModel = workingCompilation.GetSemanticModel(workingSyntax);
var workingDiags = workingSemanticModel.Diagnostics;
var workingListVariable = GetInternalSymbol<LocalSymbol>(workingSemanticModel.GetDeclaredSymbol(workingSyntax.GetNode<VariableDeclarationSyntax>()));

var flippedCompilation = CreateCompilation(flippedSyntax);
var flippedSemanticModel = flippedCompilation.GetSemanticModel(flippedSyntax);
var flippedDiags = flippedSemanticModel.Diagnostics;
var flippedListVariable = GetInternalSymbol<LocalSymbol>(flippedSemanticModel.GetDeclaredSymbol(flippedSyntax.GetNode<VariableDeclarationSyntax>()));

// Assert
Assert.Empty(workingDiags);
Assert.Empty(flippedDiags);

// NOTE: This is a janky way to compare the types, but currently SymbolEqualityComparer
// can't compare types between different compilation instances
Assert.Equal("List<Exception>", workingListVariable.Type.ToString());
Assert.Equal("List<Exception>", flippedListVariable.Type.ToString());
}
}
13 changes: 8 additions & 5 deletions src/Draco.Compiler/Internal/Binding/Binder_Expression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,8 @@ private async BindingTask<BoundExpression> BindBinaryExpression(BinaryExpression
indexSet.Receiver,
indexSet.Setter,
indexSet.Indices,
await rightTask);
await rightTask,
indexSet.Type);
}
else
{
Expand Down Expand Up @@ -553,8 +554,10 @@ private async BindingTask<BoundExpression> BindBinaryExpression(BinaryExpression
syntax,
indexSet.Receiver,
getter,
indexSet.Indices),
await rightTask));
indexSet.Indices,
indexSet.Type),
await rightTask),
indexSet.Type);
}
else
{
Expand Down Expand Up @@ -706,11 +709,11 @@ private async BindingTask<BoundExpression> BindIndexExpression(IndexExpressionSy
if (receiver.TypeRequired.IsArrayType)
{
// Array getter
return new BoundArrayAccessExpression(syntax, receiver, await BindingTask.WhenAll(argsTask));
return new BoundArrayAccessExpression(syntax, receiver, await BindingTask.WhenAll(argsTask), elementType);
}
else
{
return new BoundIndexGetExpression(syntax, receiver, indexer, await BindingTask.WhenAll(argsTask));
return new BoundIndexGetExpression(syntax, receiver, indexer, await BindingTask.WhenAll(argsTask), elementType);
}
}

Expand Down
6 changes: 4 additions & 2 deletions src/Draco.Compiler/Internal/Binding/Binder_Lvalue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,17 @@ private async BindingTask<BoundLvalue> BindIndexLvalue(IndexExpressionSyntax syn
return new BoundArrayAccessLvalue(
syntax,
receiver,
await BindingTask.WhenAll(argsTask));
await BindingTask.WhenAll(argsTask),
elementType);
}
else
{
return new BoundIndexSetLvalue(
syntax,
receiver,
indexer,
await BindingTask.WhenAll(argsTask));
await BindingTask.WhenAll(argsTask),
elementType);
}
}

Expand Down
26 changes: 0 additions & 26 deletions src/Draco.Compiler/Internal/BoundTree/BoundNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,6 @@ internal partial class BoundPropertySetExpression
public override TypeSymbol? Type => this.Value.Type;
}

internal partial class BoundIndexGetExpression
{
public override TypeSymbol? Type => this.Getter.ReturnType;
}

internal partial class BoundIndexSetExpression
{
public override TypeSymbol? Type => this.Value.Type;
}

internal partial class BoundLocalExpression
{
public override TypeSymbol Type => this.Local.Type;
Expand Down Expand Up @@ -146,12 +136,6 @@ internal partial class BoundDelegateCreationExpression
public override TypeSymbol Type => (TypeSymbol)this.DelegateConstructor.ContainingSymbol!;
}

internal partial class BoundArrayAccessExpression
{
public override TypeSymbol Type => this.Array.TypeRequired.GenericArguments.FirstOrDefault()
?? WellKnownTypes.ErrorType;
}

internal partial class BoundCallExpression
{
public override TypeSymbol Type => this.Method.ReturnType;
Expand Down Expand Up @@ -194,17 +178,7 @@ internal partial class BoundFieldLvalue
public override TypeSymbol Type => this.Field.Type;
}

internal partial class BoundArrayAccessLvalue
{
public override TypeSymbol Type => this.Array.TypeRequired.GenericArguments[0];
}

internal partial class BoundPropertySetLvalue
{
public override TypeSymbol Type => ((IPropertyAccessorSymbol)this.Setter).Property.Type;
}

internal partial class BoundIndexSetLvalue
{
public override TypeSymbol Type => ((IPropertyAccessorSymbol)this.Setter).Property.Type;
}
5 changes: 5 additions & 0 deletions src/Draco.Compiler/Internal/BoundTree/BoundNodes.xml
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
<Node Name="BoundArrayAccessExpression" Base="BoundExpression">
<Field Name="Array" Type="BoundExpression" />
<Field Name="Indices" Type="ImmutableArray&lt;BoundExpression&gt;" />
<Field Name="Type" Type="TypeSymbol" Override="true" />
</Node>

<Node Name="BoundModuleExpression" Base="BoundExpression">
Expand Down Expand Up @@ -193,13 +194,15 @@
<Field Name="Receiver" Type="BoundExpression"/>
<Field Name="Getter" Type="FunctionSymbol" />
<Field Name="Indices" Type="ImmutableArray&lt;BoundExpression&gt;" />
<Field Name="Type" Type="TypeSymbol" Override="true" />
</Node>

<Node Name="BoundIndexSetExpression" Base="BoundExpression">
<Field Name="Receiver" Type="BoundExpression"/>
<Field Name="Setter" Type="FunctionSymbol" />
<Field Name="Indices" Type="ImmutableArray&lt;BoundExpression&gt;" />
<Field Name="Value" Type="BoundExpression"/>
<Field Name="Type" Type="TypeSymbol" Override="true" />
</Node>

<Node Name="BoundDelegateCreationExpression" Base="BoundExpression">
Expand Down Expand Up @@ -282,11 +285,13 @@
<Node Name="BoundArrayAccessLvalue" Base="BoundLvalue">
<Field Name="Array" Type="BoundExpression" />
<Field Name="Indices" Type="ImmutableArray&lt;BoundExpression&gt;" />
<Field Name="Type" Type="TypeSymbol" Override="true" />
</Node>

<Node Name="BoundIndexSetLvalue" Base="BoundLvalue">
<Field Name="Receiver" Type="BoundExpression" />
<Field Name="Setter" Type="FunctionSymbol" />
<Field Name="Indices" Type="ImmutableArray&lt;BoundExpression&gt;" />
<Field Name="Type" Type="TypeSymbol" Override="true" />
</Node>
</Tree>
6 changes: 4 additions & 2 deletions src/Draco.Compiler/Internal/Lowering/LocalRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ public override BoundNode VisitCallExpression(BoundCallExpression node)
.Select((n, i) => ExpressionStatement(AssignmentExpression(
left: ArrayAccessLvalue(
array: LocalExpression(varArgs),
indices: [this.LiteralExpression(i)]),
indices: [this.LiteralExpression(i)],
type: elementType),
right: n)) as BoundStatement);

return BlockExpression(
Expand Down Expand Up @@ -528,7 +529,8 @@ public override BoundNode VisitStringExpression(BoundStringExpression node)
arrayAssignmentBuilder.Add(ExpressionStatement(AssignmentExpression(
left: ArrayAccessLvalue(
array: LocalExpression(arrayLocal),
indices: [this.LiteralExpression(i)]),
indices: [this.LiteralExpression(i)],
type: this.WellKnownTypes.SystemObject),
right: args[i])));
}

Expand Down
74 changes: 52 additions & 22 deletions src/Draco.Compiler/Internal/Solver/ConstraintSolver_Rules.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,26 +54,6 @@ private IEnumerable<Rule> ConstructRules(DiagnosticBag diagnostics) => [
})
.Named("assignable"),

// If all types are ground-types, common-type constraints are trivial
Simplification(typeof(CommonAncestor))
.Guard((CommonAncestor common) => common.AlternativeTypes.All(t => t.IsGroundType))
.Body((ConstraintStore store, CommonAncestor common) =>
{
foreach (var type in common.AlternativeTypes)
{
if (!common.AlternativeTypes.All(t => SymbolEqualityComparer.Default.IsBaseOf(type, t))) continue;
// Found a good common type
this.Assignable(common.CommonType, type, ConstraintLocator.Constraint(common));
return;
}
// No common type found
common.ReportDiagnostic(diagnostics, builder => builder
.WithFormatArgs(string.Join(", ", common.AlternativeTypes)));
// Stop cascading uninferred type
UnifyWithError(common.CommonType);
})
.Named("common_ancestor"),

// Member constraints are trivial, if the receiver is a ground-type
Simplification(typeof(Member))
.Guard((Member member) => !member.Receiver.Substitution.IsTypeVariable)
Expand Down Expand Up @@ -177,7 +157,7 @@ private IEnumerable<Rule> ConstructRules(DiagnosticBag diagnostics) => [
}

// If there is a single indexer, we check visibility
// This is because in this case overload resolution will skip hecking visibility
// This is because in this case overload resolution will skip checking visibility
if (indexers.Length == 1)
{
this.Context.CheckVisibility(indexer.Locator, indexers[0], "indexer", diagnostics);
Expand All @@ -199,7 +179,7 @@ private IEnumerable<Rule> ConstructRules(DiagnosticBag diagnostics) => [
else
{
// Setter
// We allocate a type var for the return type, but we don't care about it
// We allocate a type var for the return type, but we don't care about it as it's generally just void
var returnType = this.AllocateTypeVariable();
store.Add(new Overload(
locator: ConstraintLocator.Constraint(indexer),
Expand Down Expand Up @@ -370,6 +350,36 @@ private IEnumerable<Rule> ConstructRules(DiagnosticBag diagnostics) => [
})
.Named("overload"),

// This is basically an accumulation case of merged assignables (see below)
// Once we merged two, we'll have an additional common-ancestor constraint we'll need to maintain,
// if more assignables are merged. This is basically that.
// Example:
//
// var x = Derived1();
// x = Derived2();
// x = Base();
Simplification(typeof(Assignable), typeof(Assignable), typeof(CommonAncestor))
.Guard((Assignable a1, Assignable a2, CommonAncestor comm) =>
SymbolEqualityComparer.AllowTypeVariables.Equals(a1.TargetType, a2.TargetType)
&& ( SymbolEqualityComparer.AllowTypeVariables.Equals(a1.AssignedType, comm.CommonType)
|| SymbolEqualityComparer.AllowTypeVariables.Equals(a2.AssignedType, comm.CommonType)))
.Body((ConstraintStore store, Assignable a1, Assignable a2, CommonAncestor comm) =>
{
var targetType = a1.TargetType;
var alternative = SymbolEqualityComparer.AllowTypeVariables.Equals(a1.AssignedType, comm.CommonType)
? a2.AssignedType
: a1.AssignedType;
store.Add(new CommonAncestor(
locator: ConstraintLocator.Constraint(a2),
commonType: comm.CommonType,
alternativeTypes: [alternative, ..comm.AlternativeTypes]));
store.Add(new Assignable(
locator: ConstraintLocator.Constraint(a2),
targetType: targetType,
assignedType: comm.CommonType));
})
.Named("merge_assignables_accumulate"),

// As a last resort, we try to drive forward the solver by trying to merge assignable constraints with the same target
// This is a common situation for things like this:
//
Expand All @@ -395,6 +405,26 @@ private IEnumerable<Rule> ConstructRules(DiagnosticBag diagnostics) => [
})
.Named("merge_assignables"),

// If all types are ground-types, common-type constraints are trivial
Simplification(typeof(CommonAncestor))
.Guard((CommonAncestor common) => common.AlternativeTypes.All(t => t.IsGroundType))
.Body((ConstraintStore store, CommonAncestor common) =>
{
foreach (var type in common.AlternativeTypes)
{
if (!common.AlternativeTypes.All(t => SymbolEqualityComparer.Default.IsBaseOf(type, t))) continue;
// Found a good common type
this.Assignable(common.CommonType, type, ConstraintLocator.Constraint(common));
return;
}
// No common type found
common.ReportDiagnostic(diagnostics, builder => builder
.WithFormatArgs(string.Join(", ", common.AlternativeTypes)));
// Stop cascading uninferred type
UnifyWithError(common.CommonType);
})
.Named("common_ancestor"),

// As a last-last effort, we assume that a singular assignment means exact matching types
Simplification(typeof(Assignable))
.Guard((Assignable assignable) => CanAssign(assignable.TargetType, assignable.AssignedType))
Expand Down
21 changes: 12 additions & 9 deletions src/Draco.Examples.Tests/ExamplesTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,20 @@ public ExamplesTests()
public async Task RunExample(string projectFile, string verifiedFile)
{
// Invoke 'dotnet run' on the project
var startInfo = new ProcessStartInfo
{
FileName = "dotnet",
ArgumentList = { "run", "--project", projectFile },
RedirectStandardOutput = true,
RedirectStandardError = true,
UseShellExecute = false,
CreateNoWindow = true,
};
// Skip first-time message
startInfo.EnvironmentVariables["DOTNET_SKIP_FIRST_TIME_EXPERIENCE"] = "1";
var process = new Process
{
StartInfo = new ProcessStartInfo
{
FileName = "dotnet",
ArgumentList = { "run", "--project", projectFile },
RedirectStandardOutput = true,
RedirectStandardError = true,
UseShellExecute = false,
CreateNoWindow = true,
},
StartInfo = startInfo,
};
process.Start();

Expand Down
Loading