Skip to content

Commit

Permalink
Forward dialect to ScalaMtags to avoid parsing fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
tgodzik committed May 15, 2021
1 parent 71d89fa commit 7ee694e
Show file tree
Hide file tree
Showing 10 changed files with 44 additions and 34 deletions.
8 changes: 6 additions & 2 deletions metals-bench/src/main/scala/bench/MetalsBench.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ class MetalsBench {
@Benchmark
@BenchmarkMode(Array(Mode.SingleShotTime))
def mtagsScalaIndex(): Unit = {
scalaDependencySources.inputs.foreach { input => Mtags.index(input) }
scalaDependencySources.inputs.foreach { input =>
Mtags.index(input, dialects.Scala213)
}
}

@Benchmark
Expand Down Expand Up @@ -138,7 +140,9 @@ class MetalsBench {
@Benchmark
@BenchmarkMode(Array(Mode.SingleShotTime))
def mtagsJavaParse(): Unit = {
javaDependencySources.inputs.foreach { input => Mtags.index(input) }
javaDependencySources.inputs.foreach { input =>
Mtags.index(input, dialects.Scala213)
}
}

@Benchmark
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,11 @@ class DestinationProvider(
// Read text file from disk instead of editor buffers because the file
// on disk is more likely to parse.
lazy val parsed =
mtags.index(symbolDefinition.path.toLanguage, defnRevisedInput)
mtags.index(
symbolDefinition.path.toLanguage,
defnRevisedInput,
symbolDefinition.dialect
)

if (symbolDefinition.path.isAmmoniteScript || parsed.occurrences.isEmpty) {
// Fall back to SemanticDB on disk, if any
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2211,7 +2211,7 @@ class MetalsLanguageServer(

// Since the `symbols` here are toplevel symbols,
// we cannot use `symbols` for expiring the cache for all symbols in the source.
symbolDocs.expireSymbolDefinition(sourceToIndex0)
symbolDocs.expireSymbolDefinition(sourceToIndex0, dialect)
}
} catch {
case NonFatal(e) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import java.util.logging.Logger
import scala.collection.concurrent.TrieMap
import scala.util.control.NonFatal

import scala.meta.Dialect
import scala.meta.inputs.Input
import scala.meta.internal.mtags.GlobalSymbolIndex
import scala.meta.internal.mtags.MtagsEnrichments._
Expand Down Expand Up @@ -54,10 +55,10 @@ class Docstrings(index: GlobalSymbolIndex) {
*
* @param path the absolute path for the source file to update.
*/
def expireSymbolDefinition(path: AbsolutePath): Unit = {
def expireSymbolDefinition(path: AbsolutePath, dialect: Dialect): Unit = {
path.toLanguage match {
case Language.SCALA =>
new Deindexer(path.toInput).indexRoot()
new Deindexer(path.toInput, dialect).indexRoot()
case _ =>
}
}
Expand Down Expand Up @@ -86,14 +87,15 @@ class Docstrings(index: GlobalSymbolIndex) {
.foreach(defn.path.toInput)(cacheSymbol)
case Language.SCALA =>
ScaladocIndexer
.foreach(defn.path.toInput)(cacheSymbol)
.foreach(defn.path.toInput, defn.dialect)(cacheSymbol)
case _ =>
}
}

private class Deindexer(
input: Input.VirtualFile
) extends ScalaMtags(input) {
input: Input.VirtualFile,
dialect: Dialect
) extends ScalaMtags(input, dialect) {
override def visitOccurrence(
occ: SymbolOccurrence,
sinfo: SymbolInformation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package scala.meta.internal.metals
import scala.collection.mutable

import scala.meta._
import scala.meta.dialects.Scala213
import scala.meta.internal.docstrings._
import scala.meta.internal.jdk.CollectionConverters._
import scala.meta.internal.mtags.ScalaMtags
Expand All @@ -18,8 +19,9 @@ import scala.meta.pc.SymbolDocumentation
*/
class ScaladocIndexer(
input: Input.VirtualFile,
fn: SymbolDocumentation => Unit
) extends ScalaMtags(input) {
fn: SymbolDocumentation => Unit,
dialect: Dialect
) extends ScalaMtags(input, dialect) {
val defines: mutable.Map[String, String] = mutable.Map.empty[String, String]
override def visitOccurrence(
occ: SymbolOccurrence,
Expand Down Expand Up @@ -132,9 +134,10 @@ object ScaladocIndexer {
* @param fn callback function for calculated SymbolDocumentation
*/
def foreach(
input: Input.VirtualFile
input: Input.VirtualFile,
dialect: Dialect
)(fn: SymbolDocumentation => Unit): Unit = {
new ScaladocIndexer(input, fn).indexRoot()
new ScaladocIndexer(input, fn, dialect).indexRoot()
}

/**
Expand Down
12 changes: 8 additions & 4 deletions mtags/src/main/scala/scala/meta/internal/mtags/Mtags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,17 @@ final class Mtags {
}
}

def index(language: Language, input: Input.VirtualFile): TextDocument = {
def index(
language: Language,
input: Input.VirtualFile,
dialect: Dialect
): TextDocument = {
addLines(language, input.text)
val result =
if (language.isJava) {
JavaMtags.index(input).index()
} else if (language.isScala) {
ScalaMtags.index(input).index()
ScalaMtags.index(input, dialect).index()
} else {
TextDocument()
}
Expand All @@ -66,8 +70,8 @@ final class Mtags {
}
}
object Mtags {
def index(input: Input.VirtualFile): TextDocument = {
new Mtags().index(input.toLanguage, input)
def index(input: Input.VirtualFile, dialect: Dialect): TextDocument = {
new Mtags().index(input.toLanguage, input, dialect)
}

def toplevels(document: TextDocument): List[String] = {
Expand Down
18 changes: 5 additions & 13 deletions mtags/src/main/scala/scala/meta/internal/mtags/ScalaMtags.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package scala.meta.internal.mtags

import scala.meta._
import scala.meta.dialects.Scala213
import scala.meta.dialects.Scala3
import scala.meta.inputs.Input
import scala.meta.inputs.Position
import scala.meta.internal.mtags.MtagsEnrichments._
Expand All @@ -14,22 +12,16 @@ import scala.meta.internal.trees._
import scala.meta.transversers.SimpleTraverser

object ScalaMtags {
def index(input: Input.VirtualFile): MtagsIndexer = {
new ScalaMtags(input)
def index(input: Input.VirtualFile, dialect: Dialect): MtagsIndexer = {
new ScalaMtags(input, dialect)
}
}
class ScalaMtags(val input: Input.VirtualFile)
class ScalaMtags(val input: Input.VirtualFile, dialect: Dialect)
extends SimpleTraverser
with MtagsIndexer {

// This needs to be further improved with tests
// https://github.com/scalameta/metals/issues/2493
private val root: Parsed[Source] = {
Scala213(input).parse[Source] match {
case r @ Parsed.Success(_) => r
case _ => Scala3(input).parse[Source]
}
}
private val root: Parsed[Source] =
dialect(input).parse[Source]
def source: Source = root.get
override def language: Language = Language.SCALA
override def indexRoot(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class SymbolIndexBucket(
val language = file.toLanguage
val toIndexSource0 = toIndexSource(file).getOrElse(file)
val input = toIndexSource0.toInput
val document = mtags.index(language, input)
val document = mtags.index(language, input, dialect)
s.TextDocuments(List(document))
case _ =>
s.TextDocuments(Nil)
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/src/test/scala/tests/MtagsSuite.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package tests

import scala.meta.dialects
import scala.meta.internal.inputs._
import scala.meta.internal.mtags.Mtags
import scala.meta.internal.mtags.MtagsEnrichments._
Expand Down Expand Up @@ -30,7 +31,7 @@ class MtagsSuite extends DirectoryExpectSuite("mtags") {
file,
{ () =>
val input = file.input
val mtags = Mtags.index(input)
val mtags = Mtags.index(input, dialects.Scala213)
val obtained = Semanticdbs.printTextDocument(mtags)
val unknownSymbols = mtags.occurrences.collect {
case occ if symtab.info(occ.symbol).isEmpty =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ScalaToplevelLibrarySuite extends BaseSuite {
test(entry.toNIO.getFileName.toString) {
forAllFilesInJar(entry) { file =>
val input = file.toInput
val scalaMtags = Mtags.toplevels(Mtags.index(input))
val scalaMtags = Mtags.toplevels(Mtags.index(input, dialects.Scala213))
val scalaToplevelMtags = Mtags.toplevels(input)

val obtained = scalaToplevelMtags.mkString("\n")
Expand All @@ -61,7 +61,7 @@ class ScalaToplevelLibrarySuite extends BaseSuite {
forAllFilesInJar(entry) { file =>
if (!scala3ExclusionList.contains(file.toString)) {
val input = file.toInput
val scalaMtags = Mtags.toplevels(Mtags.index(input))
val scalaMtags = Mtags.toplevels(Mtags.index(input, dialects.Scala3))
val scalaToplevelMtags = Mtags.toplevels(input, dialects.Scala3)
val obtained = scalaToplevelMtags.mkString("\n")
val expected = scalaMtags.mkString("\n")
Expand Down

0 comments on commit 7ee694e

Please sign in to comment.