Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse each file only once #200

Merged
merged 7 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions avrohugger-core/src/main/scala/Generator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import avrohugger.input.parsers.{ FileInputParser, StringInputParser }
import avrohugger.matchers.TypeMatcher
import avrohugger.types.AvroScalaTypes
import avrohugger.stores.{ ClassStore, SchemaStore }
import org.apache.avro.Schema.Parser
import org.apache.avro.{ Protocol, Schema }

import java.io.File

// Unable to overload this class' methods because outDir uses a default value
Expand Down Expand Up @@ -76,6 +78,7 @@ case class Generator(format: SourceFormat,
classStore,
schemaStore,
fileParser,
schemaParser,
typeMatcher,
classLoader,
restrictedFieldNumber,
Expand All @@ -92,6 +95,7 @@ case class Generator(format: SourceFormat,
classStore,
schemaStore,
fileParser,
schemaParser,
typeMatcher,
classLoader,
restrictedFieldNumber,
Expand Down Expand Up @@ -128,6 +132,7 @@ case class Generator(format: SourceFormat,
classStore,
schemaStore,
fileParser,
schemaParser,
typeMatcher,
classLoader,
restrictedFieldNumber,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ object SpecificImporter extends Importer {

val switchAnnotSymbol = RootClass.newClass("scala.annotation.switch")
val switchImport = IMPORT(switchAnnotSymbol)
val topLevelSchemas =
getTopLevelSchemas(schemaOrProtocol, schemaStore, typeMatcher)
val topLevelSchemas = getTopLevelSchemas(schemaOrProtocol, schemaStore, typeMatcher)
val recordSchemas = getRecordSchemas(topLevelSchemas)
val enumSchemas = getEnumSchemas(topLevelSchemas)
val userDefinedDeps = getUserDefinedImports(recordSchemas ++ enumSchemas, currentNamespace, typeMatcher)
Expand All @@ -42,7 +41,6 @@ object SpecificImporter extends Importer {
else libraryDeps ++ userDefinedDeps
}
case Right(protocol) => {
val types = protocol.getTypes().asScala.toList
val messages = protocol.getMessages.asScala.toMap
if (messages.isEmpty) switchImport :: libraryDeps ::: userDefinedDeps // for ADT
else List.empty // for RPC
Expand Down
7 changes: 5 additions & 2 deletions avrohugger-core/src/main/scala/generators/FileGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import avrohugger.input.{ DependencyInspector, NestedSchemaExtractor }
import avrohugger.input.parsers.{ FileInputParser, StringInputParser }
import avrohugger.matchers.TypeMatcher
import avrohugger.stores.{ ClassStore, SchemaStore }
import org.apache.avro.Schema.Parser
import org.apache.avro.{ Protocol, Schema }

import java.io.File
Expand Down Expand Up @@ -72,11 +73,12 @@ private[avrohugger] class FileGenerator {
classStore: ClassStore,
schemaStore: SchemaStore,
fileParser: FileInputParser,
schemaParser: Parser,
typeMatcher: TypeMatcher,
classLoader: ClassLoader,
restrictedFields: Boolean,
targetScalaPartialVersion: String): Unit = {
distinctSchemaOrProtocol(fileParser.getSchemaOrProtocols(inFile, format, classStore, classLoader))
distinctSchemaOrProtocol(fileParser.getSchemaOrProtocols(inFile, format, classStore, classLoader, schemaParser))
.foreach {
case Left(schema) =>
schemaToFile(schema, outDir, format, classStore, schemaStore, typeMatcher, restrictedFields, targetScalaPartialVersion)
Expand All @@ -92,11 +94,12 @@ private[avrohugger] class FileGenerator {
classStore: ClassStore,
schemaStore: SchemaStore,
fileParser: FileInputParser,
schemaParser: Parser,
typeMatcher: TypeMatcher,
classLoader: ClassLoader,
restrictedFields: Boolean,
targetScalaPartialVersion: String): Unit = {
distinctSchemaOrProtocol(inFiles.flatMap(fileParser.getSchemaOrProtocols(_, format, classStore, classLoader)))
distinctSchemaOrProtocol(inFiles.flatMap(fileParser.getSchemaOrProtocols(_, format, classStore, classLoader, schemaParser)))
.foreach {
case Left(schema) =>
schemaToFile(schema, outDir, format, classStore, schemaStore, typeMatcher, restrictedFields, targetScalaPartialVersion)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import avrohugger.input.NestedSchemaExtractor
import avrohugger.input.parsers.{ FileInputParser, StringInputParser }
import avrohugger.matchers.TypeMatcher
import avrohugger.stores.{ ClassStore, SchemaStore }
import org.apache.avro.Schema.Parser

import java.io.{ File, FileNotFoundException, IOException }

import org.apache.avro.{ Protocol, Schema }

// Unable to overload this class' methods because outDir uses a default value
Expand Down Expand Up @@ -91,12 +91,13 @@ private[avrohugger] class StringGenerator {
classStore: ClassStore,
schemaStore: SchemaStore,
fileParser: FileInputParser,
schemaParser: Parser,
typeMatcher: TypeMatcher,
classLoader: ClassLoader,
restrictedFields: Boolean,
targetScalaPartialVersion: String): List[String] = {
try {
val schemaOrProtocols: List[Either[Schema, Protocol]] = fileParser.getSchemaOrProtocols(inFile, format, classStore, classLoader)
val schemaOrProtocols: List[Either[Schema, Protocol]] = fileParser.getSchemaOrProtocols(inFile, format, classStore, classLoader, schemaParser)
schemaOrProtocols.flatMap {
case Left(schema) =>
schemaToStrings(schema, format, classStore, schemaStore, typeMatcher, restrictedFields, targetScalaPartialVersion)
Expand Down
128 changes: 64 additions & 64 deletions avrohugger-core/src/main/scala/input/parsers/FileInputParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ package parsers

import format.abstractions.SourceFormat
import stores.ClassStore
import org.apache.avro.{Protocol, Schema}
import org.apache.avro.{ Protocol, Schema }
import org.apache.avro.Schema.Parser
import org.apache.avro.Schema.Type.{ENUM, FIXED, RECORD, UNION}
import org.apache.avro.Schema.Type.{ ENUM, FIXED, RECORD, UNION }
import org.apache.avro.compiler.idl.Idl
import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
import org.apache.avro.generic.{ GenericDatumReader, GenericRecord }
import org.apache.avro.file.DataFileReader
import org.apache.avro.SchemaParseException

Expand All @@ -17,23 +17,25 @@ import scala.jdk.CollectionConverters._
import scala.util.Try

class FileInputParser {

val schemaParser = new Parser()

var processedFiles: Set[String] = Set.empty
var processedSchemas: Set[Schema] = Set.empty

def getSchemaOrProtocols(
infile: File,
format: SourceFormat,
classStore: ClassStore,
classLoader: ClassLoader,
parser: Parser = schemaParser): List[Either[Schema, Protocol]] = {
parser: Parser): List[Either[Schema, Protocol]] = {
def unUnion(schema: Schema) = {
schema.getType match {
case UNION => schema.getTypes().asScala.toList
case RECORD => List(schema)
case ENUM => List(schema)
case FIXED => List(schema)
case _ => sys.error("""Neither a record, enum nor a union of either.
|Nothing to map to a definition.""".trim.stripMargin)
case _ => sys.error(
"""Neither a record, enum nor a union of either.
|Nothing to map to a definition.""".trim.stripMargin)
}
}

Expand All @@ -57,79 +59,77 @@ class FileInputParser {
}
}
}

def mightBeRecoverable(e: SchemaParseException): Boolean = {
val msg = e.getMessage
msg.contains("Undefined name:") || msg.contains("is not a defined name")
msg.contains("Undefined name:") || msg.contains("is not a defined name")
}

def tryParse(inFile: File, parser: Schema.Parser): List[Schema] = {
val tempParser = new Parser()
val parsed = Try(tempParser.parse(inFile)).map(schema => {
copySchemas(tempParser, parser)
schema
}).recoverWith { case e: SchemaParseException if mightBeRecoverable(e) =>
}).recoverWith { case e: SchemaParseException if mightBeRecoverable(e) =>
Try(parser.parse(inFile))
}
unUnion(parsed.get)// throw the avro parse exception if Failure
unUnion(parsed.get) // throw the avro parse exception if Failure
}

val schemaOrProtocols: List[Either[Schema, Protocol]] = {

infile.getName.split("\\.").last match {
case "avro" =>
val gdr = new GenericDatumReader[GenericRecord]
val dfr = new DataFileReader(infile, gdr)
val schemas = unUnion(dfr.getSchema)
schemas.map(Left(_))
case "avsc" =>
val schemas = tryParse(infile, parser)
schemas.map(Left(_))
case "avpr" =>
val protocol = Protocol.parse(infile)
List(Right(protocol))
case "avdl" =>
val idlParser = new Idl(infile, classLoader)
val protocol = idlParser.CompilationUnit()
/**
* IDLs may refer to types imported from another file. When converted
* to protocols, the imported types that share the IDL's namespace
* cannot be distinguished from types defined within the IDL, yet
* should not be generated as subtypes of the IDL's ADT and should
* instead be generated in its own namespace. So, strip the protocol
* of all imported types and generate them separately.
*/
val importedFiles = IdlImportParser.getImportedFiles(infile, classLoader)
val importedSchemaOrProtocols = importedFiles.flatMap(file => {
val res = infile.getName.split("\\.").last match {
case "avro" =>
val gdr = new GenericDatumReader[GenericRecord]
val dfr = new DataFileReader(infile, gdr)
val schemas = unUnion(dfr.getSchema)
schemas.map(Left(_))
case "avsc" =>
val schemas = tryParse(infile, parser)
schemas.map(Left(_))
case "avpr" =>
val protocol = Protocol.parse(infile)
List(Right(protocol))
case "avdl" =>
val idlParser = new Idl(infile, classLoader)
val protocol = idlParser.CompilationUnit()
/**
* IDLs may refer to types imported from another file. When converted
* to protocols, the imported types that share the IDL's namespace
* cannot be distinguished from types defined within the IDL, yet
* should not be generated as subtypes of the IDL's ADT and should
* instead be generated in its own namespace. So, strip the protocol
* of all imported types and generate them separately.
*/
val importedFiles = IdlImportParser.getImportedFiles(infile, classLoader)
val importedSchemaOrProtocols = importedFiles.flatMap { file =>
if (!processedFiles.contains(file.getCanonicalPath)) {
processedFiles += file.getCanonicalPath
val importParser = new Parser() // else attempts to redefine schemas
getSchemaOrProtocols(file, format, classStore, classLoader, importParser)
})
Comment on lines -103 to -106
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here the code potentially revisits the same file multiple times.

def stripImports(
protocol: Protocol,
importedSchemaOrProtocols: List[Either[Schema, Protocol]]) = {
val imported = importedSchemaOrProtocols.flatMap(avroDef => {
avroDef match {
case Left(importedSchema) => List(importedSchema)
case Right(importedProtocol) => importedProtocol.getTypes().asScala
}
})
val types = protocol.getTypes().asScala.toList
val localTypes = imported.foldLeft(types)((remaining, imported) => {
remaining.filterNot(remainingType => remainingType == imported)
})
protocol.setTypes(localTypes.asJava)
protocol
} else {
List()
}
val localProtocol = stripImports(protocol, importedSchemaOrProtocols)
// reverse so that dependent classes are generated first
(Right(localProtocol) +: importedSchemaOrProtocols).reverse
case _ =>
throw new Exception("""File must end in ".avpr" for protocol files,
|".avsc" for plain text json files, ".avdl" for IDL files, or .avro
}

/**
 * Removes from a protocol every declared type that is contained in `imported`.
 *
 * The protocol is mutated in place through `setTypes`, and the same instance
 * is handed back to the caller.
 *
 * @param protocol the protocol whose type list is narrowed
 * @param imported schemas to exclude from the protocol's declared types
 * @return `protocol`, now holding only the types not found in `imported`
 */
def stripImports(protocol: Protocol, imported: Set[Schema]) = {
  val keptTypes = protocol
    .getTypes()
    .asScala
    .toList
    .filterNot(declared => imported.contains(declared))
  protocol.setTypes(keptTypes.asJava)
  protocol
}

val localProtocol = stripImports(protocol, processedSchemas)
// reverse so that dependent classes are generated first
(Right(localProtocol) +: importedSchemaOrProtocols).reverse
case _ =>
throw new Exception(
"""File must end in ".avpr" for protocol files,
|".avsc" for plain text json files, ".avdl" for IDL files, or .avro
|for binary.""".trim.stripMargin)
}
}

schemaOrProtocols
res.foreach {
case Left(importedSchema) => processedSchemas += importedSchema
case Right(importedProtocol) => processedSchemas ++= importedProtocol.getTypes().asScala
}
res
}
}
28 changes: 13 additions & 15 deletions avrohugger-core/src/main/scala/input/parsers/IdlImportParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@ package avrohugger
package input
package parsers

import org.apache.avro.{ Protocol, Schema }

import java.io.File
import scala.util.matching.Regex.Match

object IdlImportParser {

def stripComments(fileContents: String): String = {
val multiLinePattern = """/\*.*\*/""".r
val singleLinePattern = """//.*$""".r
Expand All @@ -29,13 +27,14 @@ object IdlImportParser {
// if file is empty, try again, it was there when we read idl
if (fileContents.isEmpty && (count < maxTries)) readFile(infile)
else fileContents
} catch {// if file is not found, try again, it was there when we read idl
} catch { // if file is not found, try again, it was there when we read idl
case e: java.io.FileNotFoundException => {
if (count < maxTries) readFile(infile)
else sys.error("File to found: " + infile)
}
}
}

val path = infile.getParent + "/"
val contents = readFile(infile)
val avdlPattern = """import[ \t]+idl[ \t]+"([^"]*\.avdl)"[ \t]*;""".r
Expand All @@ -45,24 +44,23 @@ object IdlImportParser {
val protocolMatches = avprPattern.findAllIn(contents).matchData.toList
val schemaMatches = avscPattern.findAllIn(contents).matchData.toList
val importMatches = idlMatches ::: protocolMatches ::: schemaMatches

val (localImports, nonLocalMatches): (List[File], List[Match]) =
importMatches.foldLeft((List.empty[File], List.empty[Match])){
case ((ai,am), m) =>
importMatches.foldLeft((List.empty[File], List.empty[Match])) {
case ((ai, am), m) =>
val f = new File(path + m.group(1))
if (f.exists) (ai:+f, am)
else (ai, am:+m)
if (f.exists) (ai :+ f, am)
else (ai, am :+ m)
}

val classpathImports: List[File] = nonLocalMatches.map(m =>{

Option(classLoader.getResource(m.group(1))).map(resource =>{

val classpathImports: List[File] = nonLocalMatches.flatMap { m =>
Option(classLoader.getResource(m.group(1))).map(resource => {
new File(resource.getFile)
})
}).flatMap(_.toList).filter(file => file.exists)
}.filter(_.exists)

val importedFiles = classpathImports ++ localImports
importedFiles
}

}
2 changes: 1 addition & 1 deletion avrohugger-core/src/test/scala/util/Util.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ object Util {
val rightValue: String = right.value
val rightAsList = if (rightValue.contains("\r\n")) rightValue.split("\r\n") else rightValue.split("\n")

val res = leftAsList === rightAsList
val res = leftAsList.mkString("\n") === rightAsList.mkString("\n")
result(res.isSuccess, res.message, res.message, right)
}
}
Expand Down
Loading