diff --git a/internal/checker/relater.go b/internal/checker/relater.go index 66a596ee0e..fdae04e56d 100644 --- a/internal/checker/relater.go +++ b/internal/checker/relater.go @@ -86,17 +86,13 @@ type ErrorOutputContainer struct { type ErrorReporter func(message *diagnostics.Message, args ...any) -type RecursionIdKind uint32 - -const ( - RecursionIdKindNode RecursionIdKind = iota - RecursionIdKindSymbol - RecursionIdKindType -) - type RecursionId struct { - kind RecursionIdKind - id uint32 + value any +} + +// This function exists to constrain the types of values that can be used as recursion IDs. +func asRecursionId[T *ast.Node | *ast.Symbol | *Type](value T) RecursionId { + return RecursionId{value: value} } type Relation struct { @@ -836,21 +832,21 @@ func getRecursionIdentity(t *Type) RecursionId { // Deferred type references are tracked through their associated AST node. This gives us finer // granularity than using their associated target because each manifest type reference has a // unique AST node. - return RecursionId{kind: RecursionIdKindNode, id: uint32(ast.GetNodeId(t.AsTypeReference().node))} + return asRecursionId(t.AsTypeReference().node) } if t.symbol != nil && !(t.objectFlags&ObjectFlagsAnonymous != 0 && t.symbol.Flags&ast.SymbolFlagsClass != 0) { // We track object types that have a symbol by that symbol (representing the origin of the type), but // exclude the static side of a class since it shares its symbol with the instance side. - return RecursionId{kind: RecursionIdKindSymbol, id: uint32(ast.GetSymbolId(t.symbol))} + return asRecursionId(t.symbol) } if isTupleType(t) { - return RecursionId{kind: RecursionIdKindType, id: uint32(t.Target().id)} + return asRecursionId(t.Target()) } } if t.flags&TypeFlagsTypeParameter != 0 && t.symbol != nil { // We use the symbol of the type parameter such that all "fresh" instantiations of that type parameter // have the same recursion identity. - return RecursionId{kind: RecursionIdKindSymbol, id: uint32(ast.GetSymbolId(t.symbol))} + return asRecursionId(t.symbol) } if t.flags&TypeFlagsIndexedAccess != 0 { // Identity is the leftmost object type in a chain of indexed accesses, eg, in A[P1][P2][P3] it is A. @@ -858,13 +854,13 @@ func getRecursionIdentity(t *Type) RecursionId { for t.flags&TypeFlagsIndexedAccess != 0 { t = t.AsIndexedAccessType().objectType } - return RecursionId{kind: RecursionIdKindType, id: uint32(t.id)} + return asRecursionId(t) } if t.flags&TypeFlagsConditional != 0 { // The root object represents the origin of the conditional type - return RecursionId{kind: RecursionIdKindNode, id: uint32(ast.GetNodeId(t.AsConditionalType().root.node.AsNode()))} + return asRecursionId(t.AsConditionalType().root.node.AsNode()) } - return RecursionId{kind: RecursionIdKindType, id: uint32(t.id)} + return asRecursionId(t) } func (c *Checker) getBestMatchingType(source *Type, target *Type, isRelatedTo func(source *Type, target *Type) Ternary) *Type {