Skip to content

Precise apply for enum companion objects #9728

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
26 changes: 10 additions & 16 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,12 @@ object desugar {
// a reference to the class type bound by `cdef`, with type parameters coming from the constructor
val classTypeRef = appliedRef(classTycon)

def applyResultTpt =
if isEnumCase then
classTypeRef
else
TypeTree()

// a reference to `enumClass`, with type parameters coming from the case constructor
lazy val enumClassTypeRef =
if (enumClass.typeParams.isEmpty)
Expand Down Expand Up @@ -605,7 +611,7 @@ object desugar {
cpy.ValDef(vparam)(rhs = copyDefault(vparam)))
val copyRestParamss = derivedVparamss.tail.nestedMap(vparam =>
cpy.ValDef(vparam)(rhs = EmptyTree))
DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree(), creatorExpr)
DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, applyResultTpt, creatorExpr)
.withMods(Modifiers(Synthetic | constr1.mods.flags & copiedAccessFlags, constr1.mods.privateWithin)) :: Nil
}
}
Expand Down Expand Up @@ -656,15 +662,6 @@ object desugar {
// For all other classes, the parent is AnyRef.
val companions =
if (isCaseClass) {
// The return type of the `apply` method, and an (empty or singleton) list
// of widening coercions
val (applyResultTpt, widenDefs) =
if (!isEnumCase)
(TypeTree(), Nil)
else if (parents.isEmpty || enumClass.typeParams.isEmpty)
(enumClassTypeRef, Nil)
else
enumApplyResult(cdef, parents, derivedEnumParams, appliedRef(enumClassRef, derivedEnumParams))

// true if access to the apply method has to be restricted
// i.e. if the case class constructor is either private or qualified private
Expand Down Expand Up @@ -695,8 +692,6 @@ object desugar {
then anyRef
else
constrVparamss.foldRight(classTypeRef)((vparams, restpe) => Function(vparams map (_.tpt), restpe))
def widenedCreatorExpr =
widenDefs.foldLeft(creatorExpr)((rhs, meth) => Apply(Ident(meth.name), rhs :: Nil))
val applyMeths =
if (mods.is(Abstract)) Nil
else {
Expand All @@ -709,9 +704,8 @@ object desugar {
val appParamss =
derivedVparamss.nestedZipWithConserve(constrVparamss)((ap, cp) =>
ap.withMods(ap.mods | (cp.mods.flags & HasDefault)))
val app = DefDef(nme.apply, derivedTparams, appParamss, applyResultTpt, widenedCreatorExpr)
.withMods(appMods)
app :: widenDefs
DefDef(nme.apply, derivedTparams, appParamss, applyResultTpt, creatorExpr)
.withMods(appMods) :: Nil
}
val unapplyMeth = {
val hasRepeatedParam = constrVparamss.head.exists {
Expand All @@ -720,7 +714,7 @@ object desugar {
val methName = if (hasRepeatedParam) nme.unapplySeq else nme.unapply
val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
val unapplyRHS = if (arity == 0) Literal(Constant(true)) else Ident(unapplyParam.name)
val unapplyResTp = if (arity == 0) Literal(Constant(true)) else TypeTree()
val unapplyResTp = if (arity == 0) Literal(Constant(true)) else applyResultTpt
DefDef(methName, derivedTparams, (unapplyParam :: Nil) :: Nil, unapplyResTp, unapplyRHS)
.withMods(synthetic)
}
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,8 @@ object Types {
def widenSingletons(using Context): Type = dealias match {
case tp: SingletonType =>
tp.widen
case tp: (TypeRef | AppliedType) if tp.typeSymbol.isAllOf(EnumCase) =>
tp.parents.head
case tp: OrType =>
val tp1w = tp.widenSingletons
if (tp1w eq tp) this else tp1w
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/parsing/Scanners.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1399,8 +1399,8 @@ object Scanners {

object IndentWidth {
private inline val MaxCached = 40
private val spaces = Array.tabulate(MaxCached + 1)(new Run(' ', _))
private val tabs = Array.tabulate(MaxCached + 1)(new Run('\t', _))
private val spaces = Array.tabulate[Run](MaxCached + 1)(new Run(' ', _)) // TODO: remove new after bootstrap
private val tabs = Array.tabulate[Run](MaxCached + 1)(new Run('\t', _)) // TODO: remove new after bootstrap

def Run(ch: Char, n: Int): Run =
if (n <= MaxCached && ch == ' ') spaces(n)
Expand Down
10 changes: 10 additions & 0 deletions tests/pos/i3935.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
enum Foo3[T](x: T) {
case Bar[S, T](y: T) extends Foo3[y.type](y)
}

val foo: Foo3.Bar[Nothing, 3] = Foo3.Bar(3)
val bar = foo

def baz[T](f: Foo3[T]): f.type = f

val qux = baz(bar) // existentials are back in Dotty?
3 changes: 1 addition & 2 deletions tests/run-macros/i8007/Macro_3.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ object Eq {
$ordx == $ordy && $elements($ordx).asInstanceOf[Eq[Any]].eqv($x, $y)
}
}

'{
eqSum((x: T, y: T) => ${eqSumBody('x, 'y)})
}
Expand All @@ -76,4 +75,4 @@ object Macro3 {
extension [T](x: =>T) inline def === (y: =>T)(using eq: Eq[T]): Boolean = eq.eqv(x, y)

implicit inline def eqGen[T]: Eq[T] = ${ Eq.derived[T] }
}
}
30 changes: 25 additions & 5 deletions tests/run-macros/i8007/Test_4.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,22 @@ import Macro3.eqGen
case class Person(name: String, age: Int)

enum Opt[+T] {
case Sm(t: T)
case Sm[U](t: U) extends Opt[U]
case Nn
}

enum OptInfer[+T] {
case Sm[+U](t: U) extends OptInfer[U]
case Nn
}

// simulation of Opt using case class hierarchy
sealed abstract class OptCase[+T]
object OptCase {
final case class Sm[T](t: T) extends OptCase[T]
case object Nn extends OptCase[Nothing]
}

@main def Test() = {
import Opt._
import Eq.{given _, _}
Expand All @@ -30,15 +42,23 @@ enum Opt[+T] {
println(t4) // false
println

val t5 = Sm(23) === Sm(23)
val t5 = Opt.Sm[Int](23) === Opt.Sm(23) // same behaviour as case class when using apply
println(t5) // true
println

val t6 = Sm(Person("Test", 23)) === Sm(Person("Test", 23))
val t5_2 = OptCase.Sm[Int](23) === OptCase.Sm(23)
println(t5_2) // true
println

val t5_3 = OptInfer.Sm(23) === OptInfer.Sm(23) // covariant `Sm` case means we can avoid explicit type parameter
println(t5_3) // true
println

val t6 = Sm[Person](Person("Test", 23)) === Sm(Person("Test", 23))
println(t6) // true
println

val t7 = Sm(Person("Test", 23)) === Sm(Person("Test", 24))
val t7 = Sm[Person](Person("Test", 23)) === Sm(Person("Test", 24))
println(t7) // false
println
}
}
31 changes: 31 additions & 0 deletions tests/run/enum-precise.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
enum NonEmptyList[+T]:
case Many[+U](head: U, tail: NonEmptyList[U]) extends NonEmptyList[U]
case One [+U](value: U) extends NonEmptyList[U]

enum Ast:
case Binding(name: String, tpe: String)
case Lambda(args: NonEmptyList[Binding], rhs: Ast) // reference to another case of the enum
case Ident(name: String)
case Apply(fn: Ast, args: NonEmptyList[Ast])

import NonEmptyList._
import Ast._

// This example showcases the widening when inferring enum case types.
// With scala 2 case class hierarchies, if One.apply(1) returns One[Int] and Many.apply(2, One(3)) returns Many[Int]
// then the `foldRight` expression below would complain that Many[Binding] is not One[Binding]. With Scala 3 enums,
// .apply on the companion returns the precise class, but type inference will widen to NonEmptyList[Binding] unless
// the precise class is expected.
def Bindings(arg: (String, String), args: (String, String)*): NonEmptyList[Binding] =
def Bind(arg: (String, String)): Binding =
val (name, tpe) = arg
Binding(name, tpe)

args.foldRight(One[Binding](Bind(arg)))((arg, acc) => Many(Bind(arg), acc))

@main def Test: Unit =
val OneOfOne: One[1] = One[1](1)
val True = Lambda(Bindings("x" -> "T", "y" -> "T"), Ident("x"))
val Const = Lambda(One(Binding("x", "T")), Lambda(One(Binding("y", "U")), Ident("x"))) // precise type is forwarded

assert(OneOfOne.value == 1)