Skip to content

Commit a4ca160

Browse files
authored
Support for Scala 3 (#18)
* Partial Support for Scala 3 This implements the support for the `@main` annotated methods in Scala 3. While `ParserForMethods` is implemented and passes all tests (except the tests using the old varargs that were moved in the `src-2` directory), `ParserForClass` is not implemented yet. * Add empty artifacts for Mima on Scala 3 * Use MainData.create instead of MainData.apply * Partial support for ParserForClass * Update Mill to support Scala Native on Scala 3 Scala Native is not supported yet since it's blocked by PPrint * Apply code review suggestions * Use Symbol.requiredClass to get annotations * Use cleaner syntax for splice * Match type to avoid casting * Rephrase error message * Remove asInstanceOf * Remove asInstanceOf
1 parent b1bdb8b commit a4ca160

22 files changed

+536
-211
lines changed

.github/workflows/actions.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ name: ci
22

33
on:
44
push:
5+
branches:
6+
- master
7+
tags:
8+
- '*'
59
pull_request:
610
branches:
711
- master

build.sc

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
import mill._, scalalib._, scalajslib._, scalanativelib._, publish._
2+
import mill.scalalib.api.Util.isScala3
23
import scalalib._
3-
import $ivy.`de.tototec::de.tobiasroeser.mill.vcs.version_mill0.9:0.1.1`
4+
import $ivy.`de.tototec::de.tobiasroeser.mill.vcs.version::0.1.4`
45
import de.tobiasroeser.mill.vcs.version.VcsVersion
5-
import $ivy.`com.github.lolgab::mill-mima_mill0.9:0.0.4`
6+
import $ivy.`com.github.lolgab::mill-mima::0.0.9`
67
import com.github.lolgab.mill.mima._
78

89
val scala212 = "2.12.13"
910
val scala213 = "2.13.4"
11+
val scala30 = "3.0.2"
12+
val scala31 = "3.1.1"
13+
14+
val scala2Versions = List(scala212, scala213)
1015

1116
val scalaJSVersions = for {
12-
scalaV <- Seq(scala213, scala212)
13-
scalaJSV <- Seq("1.4.0")
17+
scalaV <- scala30 :: scala2Versions
18+
scalaJSV <- Seq("1.5.1")
1419
} yield (scalaV, scalaJSV)
1520

1621
val scalaNativeVersions = for {
17-
scalaV <- Seq(scala213, scala212)
18-
scalaNativeV <- Seq("0.4.0")
22+
scalaV <- scala2Versions
23+
scalaNativeV <- Seq("0.4.3")
1924
} yield (scalaV, scalaNativeV)
2025

2126
trait MainArgsPublishModule extends PublishModule with CrossScalaModule with Mima {
@@ -26,6 +31,8 @@ trait MainArgsPublishModule extends PublishModule with CrossScalaModule with Mim
2631
.lastTag
2732
.getOrElse(throw new Exception("Missing last tag"))
2833
)
34+
// Remove after Scala 3 artifacts are published
35+
def mimaPreviousArtifacts = T{ if(isScala3(scalaVersion())) Seq() else super.mimaPreviousArtifacts() }
2936
def artifactName = "mainargs"
3037

3138
def pomSettings = PomSettings(
@@ -39,54 +46,57 @@ trait MainArgsPublishModule extends PublishModule with CrossScalaModule with Mim
3946
)
4047
)
4148

42-
def scalacOptions = super.scalacOptions() ++ Seq("-P:acyclic:force")
49+
def scalacOptions = super.scalacOptions() ++ (if (!isScala3(crossScalaVersion)) Seq("-P:acyclic:force") else Seq.empty)
4350

44-
def scalacPluginIvyDeps = super.scalacPluginIvyDeps() ++ Agg(ivy"com.lihaoyi::acyclic:0.2.0")
51+
def scalacPluginIvyDeps = super.scalacPluginIvyDeps() ++ (if (!isScala3(crossScalaVersion)) Agg(ivy"com.lihaoyi::acyclic:0.2.0") else Agg.empty)
4552

46-
def compileIvyDeps = super.compileIvyDeps() ++ Agg(
47-
ivy"com.lihaoyi::acyclic:0.2.0",
48-
ivy"org.scala-lang:scala-reflect:$crossScalaVersion"
49-
)
53+
def compileIvyDeps = super.compileIvyDeps() ++ (if (!isScala3(crossScalaVersion)) Agg(
54+
ivy"com.lihaoyi::acyclic:0.2.0",
55+
ivy"org.scala-lang:scala-reflect:$crossScalaVersion"
56+
) else Agg.empty)
5057

5158
def ivyDeps = Agg(
52-
ivy"org.scala-lang.modules::scala-collection-compat::2.4.0"
53-
)
59+
ivy"org.scala-lang.modules::scala-collection-compat::2.4.4"
60+
) ++ Agg(ivy"com.lihaoyi::pprint:0.6.6")
5461
}
5562

63+
def scalaMajor(scalaVersion: String) = if(isScala3(scalaVersion)) "3" else "2"
64+
5665
trait Common extends CrossScalaModule {
5766
def millSourcePath = build.millSourcePath / "mainargs"
5867
def sources = T.sources(
5968
millSourcePath / "src",
60-
millSourcePath / s"src-$platform"
69+
millSourcePath / s"src-$platform",
70+
millSourcePath / s"src-${scalaMajor(scalaVersion())}",
6171
)
6272
def platform: String
6373
}
6474

65-
trait CommonTestModule extends ScalaModule with TestModule {
66-
def ivyDeps = Agg(ivy"com.lihaoyi::utest::0.7.6")
67-
def testFrameworks = Seq("utest.runner.Framework")
75+
trait CommonTestModule extends ScalaModule with TestModule.Utest {
76+
def ivyDeps = Agg(ivy"com.lihaoyi::utest::0.7.11")
6877
def sources = T.sources(
6978
millSourcePath / "src",
70-
millSourcePath / s"src-$platform"
79+
millSourcePath / s"src-$platform",
80+
millSourcePath / s"src-${scalaMajor(scalaVersion())}",
7181
)
7282
def platform: String
7383
}
7484

7585

7686
object mainargs extends Module {
77-
object jvm extends Cross[JvmMainArgsModule](scala212, scala213)
87+
object jvm extends Cross[JvmMainArgsModule](scala30 :: scala2Versions: _*)
7888
class JvmMainArgsModule(val crossScalaVersion: String)
7989
extends Common with ScalaModule with MainArgsPublishModule {
8090
def platform = "jvm"
8191
object test extends Tests with CommonTestModule{
8292
def platform = "jvm"
83-
def ivyDeps = super.ivyDeps() ++ Agg(ivy"com.lihaoyi::os-lib:0.7.1")
93+
def ivyDeps = super.ivyDeps() ++ Agg(ivy"com.lihaoyi::os-lib:0.7.8")
8494
}
8595
}
8696

8797
object js extends Cross[JSMainArgsModule](scalaJSVersions: _*)
8898
class JSMainArgsModule(val crossScalaVersion: String, crossJSVersion: String)
89-
extends Common with ScalaJSModule with MainArgsPublishModule {
99+
extends Common with MainArgsPublishModule with ScalaJSModule {
90100
def platform = "js"
91101
def scalaJSVersion = crossJSVersion
92102
object test extends Tests with CommonTestModule{
@@ -96,7 +106,7 @@ object mainargs extends Module {
96106

97107
object native extends Cross[NativeMainArgsModule](scalaNativeVersions: _*)
98108
class NativeMainArgsModule(val crossScalaVersion: String, crossScalaNativeVersion: String)
99-
extends Common with ScalaNativeModule with MainArgsPublishModule {
109+
extends Common with MainArgsPublishModule with ScalaNativeModule {
100110
def scalaNativeVersion = crossScalaNativeVersion
101111
def platform = "native"
102112
object test extends Tests with CommonTestModule{
File renamed without changes.
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package mainargs
2+
3+
import acyclic.skipped
4+
5+
import scala.language.experimental.macros
6+
7+
private[mainargs] trait ParserForClassCompanionVersionSpecific {
8+
def apply[T]: ParserForClass[T] = macro Macros.parserForClass[T]
9+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package mainargs
2+
3+
import acyclic.skipped
4+
5+
import scala.language.experimental.macros
6+
7+
private[mainargs] trait ParserForMethodsCompanionVersionSpecific {
8+
def apply[B](base: B): ParserForMethods[B] = macro Macros.parserForMethods[B]
9+
}

mainargs/src-3/Macros.scala

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
package mainargs
2+
3+
import scala.quoted._
4+
5+
object Macros {
6+
private def mainAnnotation(using Quotes) = quotes.reflect.Symbol.requiredClass("mainargs.main")
7+
private def argAnnotation(using Quotes) = quotes.reflect.Symbol.requiredClass("mainargs.arg")
8+
def parserForMethods[B](base: Expr[B])(using Quotes, Type[B]): Expr[ParserForMethods[B]] = {
9+
import quotes.reflect._
10+
val allMethods = TypeRepr.of[B].typeSymbol.memberMethods
11+
val annotatedMethodsWithMainAnnotations = allMethods.flatMap { methodSymbol =>
12+
methodSymbol.getAnnotation(mainAnnotation).map(methodSymbol -> _)
13+
}.sortBy(_._1.pos.map(_.start))
14+
val mainDatas = Expr.ofList(annotatedMethodsWithMainAnnotations.map { (annotatedMethod, mainAnnotationInstance) =>
15+
createMainData[Any, B](annotatedMethod, mainAnnotationInstance)
16+
})
17+
18+
'{
19+
new ParserForMethods[B](
20+
MethodMains[B]($mainDatas, () => $base)
21+
)
22+
}
23+
}
24+
25+
def parserForClass[B](using Quotes, Type[B]): Expr[ParserForClass[B]] = {
26+
import quotes.reflect._
27+
val typeReprOfB = TypeRepr.of[B]
28+
val companionModule = typeReprOfB match {
29+
case TypeRef(a,b) => TermRef(a,b)
30+
}
31+
val typeSymbolOfB = typeReprOfB.typeSymbol
32+
val companionModuleType = typeSymbolOfB.companionModule.tree.asInstanceOf[ValDef].tpt.tpe.asType
33+
val companionModuleExpr = Ident(companionModule).asExpr
34+
val mainAnnotationInstance = typeSymbolOfB.getAnnotation(mainAnnotation).getOrElse {
35+
report.throwError(
36+
s"cannot find @main annotation on ${companionModule.name}",
37+
typeSymbolOfB.pos.get
38+
)
39+
}
40+
val annotatedMethod = TypeRepr.of[B].typeSymbol.companionModule.memberMethod("apply").head
41+
companionModuleType match
42+
case '[bCompanion] =>
43+
val mainData = createMainData[B, Any](annotatedMethod, mainAnnotationInstance)
44+
'{
45+
new ParserForClass[B](
46+
ClassMains[B](${ mainData }, () => ${ Ident(companionModule).asExpr })
47+
)
48+
}
49+
}
50+
51+
def createMainData[T: Type, B: Type](using Quotes)(method: quotes.reflect.Symbol, annotation: quotes.reflect.Term): Expr[MainData[T, B]] = {
52+
import quotes.reflect.*
53+
val params = method.paramSymss.headOption.getOrElse(report.throwError("Multiple parameter lists not supported"))
54+
val defaultParams = getDefaultParams(method)
55+
val argSigs = Expr.ofList(params.map { param =>
56+
val paramTree = param.tree.asInstanceOf[ValDef]
57+
val paramTpe = paramTree.tpt.tpe
58+
val arg = param.getAnnotation(argAnnotation).map(_.asExprOf[mainargs.arg]).getOrElse('{ new mainargs.arg() })
59+
val paramType = paramTpe.asType
60+
paramType match
61+
case '[t] =>
62+
val defaultParam: Expr[Option[B => t]] = defaultParams.get(param) match {
63+
case Some('{ $v: `t`}) => '{ Some(((_: B) => $v)) }
64+
case None => '{ None }
65+
}
66+
val argReader = Expr.summon[mainargs.ArgReader[t]].getOrElse {
67+
report.throwError(
68+
s"No mainargs.ArgReader found for parameter ${param.name}",
69+
param.pos.get
70+
)
71+
}
72+
'{ (ArgSig.create[t, B](${ Expr(param.name) }, ${ arg }, ${ defaultParam })(using ${ argReader })).asInstanceOf[ArgSig[Any, B]] }
73+
})
74+
75+
val invokeRaw: Expr[(B, Seq[Any]) => T] = {
76+
def callOf(args: Expr[Seq[Any]]) = call(method, '{ Seq( ${ args }) }).asExprOf[T]
77+
'{ ((b: B, params: Seq[Any]) => ${ callOf('{ params }) }) }
78+
}
79+
'{ MainData.create[T, B](${ Expr(method.name) }, ${ annotation.asExprOf[mainargs.main] }, ${ argSigs }, ${ invokeRaw }) }
80+
}
81+
82+
/** Call a method given by its symbol.
83+
*
84+
* E.g.
85+
*
86+
* assuming:
87+
*
88+
* def foo(x: Int, y: String)(z: Int)
89+
*
90+
* val argss: List[List[Any]] = ???
91+
*
92+
* then:
93+
*
94+
* call(<symbol of foo>, '{argss})
95+
*
96+
* will expand to:
97+
*
98+
* foo(argss(0)(0), argss(0)(1))(argss(1)(0))
99+
*
100+
*/
101+
private def call(using Quotes)(
102+
method: quotes.reflect.Symbol,
103+
argss: Expr[Seq[Seq[Any]]]
104+
): Expr[_] = {
105+
// Copy pasted from Cask.
106+
// https://github.com/com-lihaoyi/cask/blob/65b9c8e4fd528feb71575f6e5ef7b5e2e16abbd9/cask/src-3/cask/router/Macros.scala#L106
107+
import quotes.reflect._
108+
val paramss = method.paramSymss
109+
110+
if (paramss.isEmpty) {
111+
report.throwError("At least one parameter list must be declared.", method.pos.get)
112+
}
113+
114+
val fct = Ref(method)
115+
116+
val accesses: List[List[Term]] = for (i <- paramss.indices.toList) yield {
117+
for (j <- paramss(i).indices.toList) yield {
118+
val tpe = paramss(i)(j).tree.asInstanceOf[ValDef].tpt.tpe
119+
tpe.asType match
120+
case '[t] => '{ $argss(${Expr(i)})(${Expr(j)}).asInstanceOf[t] }.asTerm
121+
}
122+
}
123+
124+
val base = Apply(fct, accesses.head)
125+
val application: Apply = accesses.tail.foldLeft(base)((lhs, args) => Apply(lhs, args))
126+
val expr = application.asExpr
127+
expr
128+
}
129+
130+
131+
/** Lookup default values for a method's parameters. */
132+
private def getDefaultParams(using Quotes)(method: quotes.reflect.Symbol): Map[quotes.reflect.Symbol, Expr[Any]] = {
133+
// Copy pasted from Cask.
134+
// https://github.com/com-lihaoyi/cask/blob/65b9c8e4fd528feb71575f6e5ef7b5e2e16abbd9/cask/src-3/cask/router/Macros.scala#L38
135+
import quotes.reflect._
136+
137+
val params = method.paramSymss.flatten
138+
val defaults = collection.mutable.Map.empty[Symbol, Expr[Any]]
139+
140+
val Name = (method.name + """\$default\$(\d+)""").r
141+
142+
val idents = method.owner.tree.asInstanceOf[ClassDef].body
143+
idents.foreach{
144+
case deff @ DefDef(Name(idx), _, _, _) =>
145+
val expr = Ref(deff.symbol).asExpr
146+
defaults += (params(idx.toInt - 1) -> expr)
147+
case _ =>
148+
}
149+
150+
defaults.toMap
151+
}
152+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package mainargs
2+
3+
import scala.language.experimental.macros
4+
5+
private [mainargs] trait ParserForClassCompanionVersionSpecific {
6+
inline def apply[T]: ParserForClass[T] = ${ Macros.parserForClass[T] }
7+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package mainargs
2+
3+
private [mainargs] trait ParserForMethodsCompanionVersionSpecific {
4+
inline def apply[B](base: B): ParserForMethods[B] = ${ Macros.parserForMethods[B]('base) }
5+
}

mainargs/src-3/acyclic.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package acyclic
2+
3+
def skipped = ???

mainargs/src/Parser.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package mainargs
2+
3+
import acyclic.skipped
4+
25
import scala.language.experimental.macros
36
import java.io.PrintStream
4-
object ParserForMethods{
5-
def apply[B](base: B): ParserForMethods[B] = macro Macros.parserForMethods[B]
6-
}
7+
8+
object ParserForMethods extends ParserForMethodsCompanionVersionSpecific
79
class ParserForMethods[B](val mains: MethodMains[B]){
810
def helpText(totalWidth: Int = 100,
911
docsOnNewLine: Boolean = false,
@@ -102,9 +104,7 @@ class ParserForMethods[B](val mains: MethodMains[B]){
102104
}
103105
}
104106

105-
object ParserForClass{
106-
def apply[T]: ParserForClass[T] = macro Macros.parserForClass[T]
107-
}
107+
object ParserForClass extends ParserForClassCompanionVersionSpecific
108108
class ParserForClass[T](val mains: ClassMains[T]) extends SubParser[T]{
109109
def helpText(totalWidth: Int = 100,
110110
docsOnNewLine: Boolean = false,

0 commit comments

Comments
 (0)