diff --git a/avrohugger-core/src/main/scala/Generator.scala b/avrohugger-core/src/main/scala/Generator.scala index 403ab922..46e6bff6 100644 --- a/avrohugger-core/src/main/scala/Generator.scala +++ b/avrohugger-core/src/main/scala/Generator.scala @@ -7,7 +7,9 @@ import avrohugger.input.parsers.{ FileInputParser, StringInputParser } import avrohugger.matchers.TypeMatcher import avrohugger.types.AvroScalaTypes import avrohugger.stores.{ ClassStore, SchemaStore } +import org.apache.avro.Schema.Parser import org.apache.avro.{ Protocol, Schema } + import java.io.File // Unable to overload this class' methods because outDir uses a default value @@ -76,6 +78,7 @@ case class Generator(format: SourceFormat, classStore, schemaStore, fileParser, + schemaParser, typeMatcher, classLoader, restrictedFieldNumber, @@ -92,6 +95,7 @@ case class Generator(format: SourceFormat, classStore, schemaStore, fileParser, + schemaParser, typeMatcher, classLoader, restrictedFieldNumber, @@ -128,6 +132,7 @@ case class Generator(format: SourceFormat, classStore, schemaStore, fileParser, + schemaParser, typeMatcher, classLoader, restrictedFieldNumber, diff --git a/avrohugger-core/src/main/scala/format/specific/SpecificImporter.scala b/avrohugger-core/src/main/scala/format/specific/SpecificImporter.scala index 7207cb50..d057004c 100644 --- a/avrohugger-core/src/main/scala/format/specific/SpecificImporter.scala +++ b/avrohugger-core/src/main/scala/format/specific/SpecificImporter.scala @@ -25,8 +25,7 @@ object SpecificImporter extends Importer { val switchAnnotSymbol = RootClass.newClass("scala.annotation.switch") val switchImport = IMPORT(switchAnnotSymbol) - val topLevelSchemas = - getTopLevelSchemas(schemaOrProtocol, schemaStore, typeMatcher) + val topLevelSchemas = getTopLevelSchemas(schemaOrProtocol, schemaStore, typeMatcher) val recordSchemas = getRecordSchemas(topLevelSchemas) val enumSchemas = getEnumSchemas(topLevelSchemas) val userDefinedDeps = getUserDefinedImports(recordSchemas ++ enumSchemas, currentNamespace, typeMatcher) @@ -42,7 +41,6 @@ object SpecificImporter extends Importer { else libraryDeps ++ userDefinedDeps } case Right(protocol) => { - val types = protocol.getTypes().asScala.toList val messages = protocol.getMessages.asScala.toMap if (messages.isEmpty) switchImport :: libraryDeps ::: userDefinedDeps // for ADT else List.empty // for RPC diff --git a/avrohugger-core/src/main/scala/generators/FileGenerator.scala b/avrohugger-core/src/main/scala/generators/FileGenerator.scala index 60848661..c6974890 100644 --- a/avrohugger-core/src/main/scala/generators/FileGenerator.scala +++ b/avrohugger-core/src/main/scala/generators/FileGenerator.scala @@ -6,6 +6,7 @@ import avrohugger.input.{ DependencyInspector, NestedSchemaExtractor } import avrohugger.input.parsers.{ FileInputParser, StringInputParser } import avrohugger.matchers.TypeMatcher import avrohugger.stores.{ ClassStore, SchemaStore } +import org.apache.avro.Schema.Parser import org.apache.avro.{ Protocol, Schema } import java.io.File @@ -72,11 +73,12 @@ private[avrohugger] class FileGenerator { classStore: ClassStore, schemaStore: SchemaStore, fileParser: FileInputParser, + schemaParser: Parser, typeMatcher: TypeMatcher, classLoader: ClassLoader, restrictedFields: Boolean, targetScalaPartialVersion: String): Unit = { - distinctSchemaOrProtocol(fileParser.getSchemaOrProtocols(inFile, format, classStore, classLoader)) + distinctSchemaOrProtocol(fileParser.getSchemaOrProtocols(inFile, format, classStore, classLoader, schemaParser)) .foreach { case Left(schema) => schemaToFile(schema, outDir, format, classStore, schemaStore, typeMatcher, restrictedFields, targetScalaPartialVersion) @@ -92,11 +94,12 @@ private[avrohugger] class FileGenerator { classStore: ClassStore, schemaStore: SchemaStore, fileParser: FileInputParser, + schemaParser: Parser, typeMatcher: TypeMatcher, classLoader: ClassLoader, restrictedFields: Boolean, targetScalaPartialVersion: String): Unit = { - distinctSchemaOrProtocol(inFiles.flatMap(fileParser.getSchemaOrProtocols(_, format, classStore, classLoader))) + distinctSchemaOrProtocol(inFiles.flatMap(fileParser.getSchemaOrProtocols(_, format, classStore, classLoader, schemaParser))) .foreach { case Left(schema) => schemaToFile(schema, outDir, format, classStore, schemaStore, typeMatcher, restrictedFields, targetScalaPartialVersion) diff --git a/avrohugger-core/src/main/scala/generators/StringGenerator.scala b/avrohugger-core/src/main/scala/generators/StringGenerator.scala index 797cfd6c..443a1e16 100644 --- a/avrohugger-core/src/main/scala/generators/StringGenerator.scala +++ b/avrohugger-core/src/main/scala/generators/StringGenerator.scala @@ -7,9 +7,9 @@ import avrohugger.input.NestedSchemaExtractor import avrohugger.input.parsers.{ FileInputParser, StringInputParser } import avrohugger.matchers.TypeMatcher import avrohugger.stores.{ ClassStore, SchemaStore } +import org.apache.avro.Schema.Parser import java.io.{ File, FileNotFoundException, IOException } - import org.apache.avro.{ Protocol, Schema } // Unable to overload this class' methods because outDir uses a default value @@ -91,12 +91,13 @@ private[avrohugger] class StringGenerator { classStore: ClassStore, schemaStore: SchemaStore, fileParser: FileInputParser, + schemaParser: Parser, typeMatcher: TypeMatcher, classLoader: ClassLoader, restrictedFields: Boolean, targetScalaPartialVersion: String): List[String] = { try { - val schemaOrProtocols: List[Either[Schema, Protocol]] = fileParser.getSchemaOrProtocols(inFile, format, classStore, classLoader) + val schemaOrProtocols: List[Either[Schema, Protocol]] = fileParser.getSchemaOrProtocols(inFile, format, classStore, classLoader, schemaParser) schemaOrProtocols.flatMap { case Left(schema) => schemaToStrings(schema, format, classStore, schemaStore, typeMatcher, restrictedFields, targetScalaPartialVersion) diff --git a/avrohugger-core/src/main/scala/input/NestedSchemaExtractor.scala b/avrohugger-core/src/main/scala/input/NestedSchemaExtractor.scala index f932c5f5..50f1a4a3 100644 --- a/avrohugger-core/src/main/scala/input/NestedSchemaExtractor.scala +++ b/avrohugger-core/src/main/scala/input/NestedSchemaExtractor.scala @@ -6,7 +6,7 @@ import stores.SchemaStore import types.EnumAsScalaString import org.apache.avro.Schema -import org.apache.avro.Schema.Type.{ARRAY, ENUM, FIXED, MAP, RECORD, UNION} +import org.apache.avro.Schema.Type.{ ARRAY, ENUM, FIXED, MAP, RECORD, UNION } import scala.jdk.CollectionConverters._ @@ -16,53 +16,60 @@ object NestedSchemaExtractor { schema: Schema, schemaStore: SchemaStore, typeMatcher: TypeMatcher): List[Schema] = { - def extract( - schema: Schema, - fieldPath: List[String] = List.empty): List[Schema] = { + var visitedSchemas = Set.empty[String] - schema.getType match { - case RECORD => - val fields: List[Schema.Field] = schema.getFields().asScala.toList - val fieldSchemas: List[Schema] = fields.map(field => field.schema) - def flattenSchema(fieldSchema: Schema): List[Schema] = { - fieldSchema.getType match { - case ARRAY => flattenSchema(fieldSchema.getElementType) - case MAP => flattenSchema(fieldSchema.getValueType) - case RECORD => { - // if the field schema is one that has already been stored, use that one - if (schemaStore.schemas.contains(fieldSchema.getFullName)) List() - // if we've already seen this schema (recursive schemas) don't traverse further - else if (fieldPath.contains(fieldSchema.getFullName)) List() - else fieldSchema :: extract(fieldSchema, fieldSchema.getFullName :: fieldPath) - } - case UNION => fieldSchema.getTypes().asScala.toList.flatMap(x => flattenSchema(x)) - case ENUM => { - // if the field schema is one that has already been stored, use that one - if (schemaStore.schemas.contains(fieldSchema.getFullName)) List() - else List(fieldSchema) - } - case FIXED => { - // if the field schema is one that has already been stored, use that one - if (schemaStore.schemas.contains(fieldSchema.getFullName)) List() - else List(fieldSchema) + def extract(schema: Schema): List[Schema] = { + if (visitedSchemas.contains(schema.getFullName)) + List() + else { + visitedSchemas += schema.getFullName + schema.getType match { + case RECORD => + val fields: List[Schema.Field] = schema.getFields().asScala.toList + val fieldSchemas: List[Schema] = fields.map(field => field.schema) + + def flattenSchema(fieldSchema: Schema): List[Schema] = { + fieldSchema.getType match { + case ARRAY => flattenSchema(fieldSchema.getElementType) + case MAP => flattenSchema(fieldSchema.getValueType) + case RECORD => { + // if the field schema is one that has already been stored, use that one + if (schemaStore.schemas.contains(fieldSchema.getFullName)) List() + // if we've already seen this schema (recursive schemas) don't traverse further + else fieldSchema :: extract(fieldSchema) + + } + case UNION => fieldSchema.getTypes().asScala.toList.flatMap(x => flattenSchema(x)) + case ENUM => { + // if the field schema is one that has already been stored, use that one + if (schemaStore.schemas.contains(fieldSchema.getFullName)) List() + else List(fieldSchema) + } + case FIXED => { + // if the field schema is one that has already been stored, use that one + if (schemaStore.schemas.contains(fieldSchema.getFullName)) List() + else List(fieldSchema) + } + case _ => List(fieldSchema) } - case _ => List(fieldSchema) } - } - val flatSchemas = fieldSchemas.flatMap(fieldSchema => flattenSchema(fieldSchema)) - def topLevelTypes(schema: Schema) = { - if (typeMatcher.avroScalaTypes.`enum` == EnumAsScalaString) (schema.getType == RECORD | schema.getType == FIXED) - else (schema.getType == RECORD | schema.getType == ENUM | schema.getType == FIXED) - } - val nestedTopLevelSchemas = flatSchemas.filter(topLevelTypes) - nestedTopLevelSchemas - case ENUM => List(schema) - case FIXED => List(schema) - case _ => Nil - } + + fieldSchemas + .flatMap(flattenSchema) + .filter { schema => + if (typeMatcher.avroScalaTypes.`enum` == EnumAsScalaString) + schema.getType == RECORD | schema.getType == FIXED + else + schema.getType == RECORD | schema.getType == ENUM | schema.getType == FIXED + } + case ENUM => List(schema) + case FIXED => List(schema) + case _ => Nil + } + } } - schema::extract(schema) + schema :: extract(schema) } } diff --git a/avrohugger-core/src/main/scala/input/parsers/FileInputParser.scala b/avrohugger-core/src/main/scala/input/parsers/FileInputParser.scala index 163079d7..38602de7 100644 --- a/avrohugger-core/src/main/scala/input/parsers/FileInputParser.scala +++ b/avrohugger-core/src/main/scala/input/parsers/FileInputParser.scala @@ -4,11 +4,11 @@ package parsers import format.abstractions.SourceFormat import stores.ClassStore -import org.apache.avro.{Protocol, Schema} +import org.apache.avro.{ Protocol, Schema } import org.apache.avro.Schema.Parser -import org.apache.avro.Schema.Type.{ENUM, FIXED, RECORD, UNION} +import org.apache.avro.Schema.Type.{ ENUM, FIXED, RECORD, UNION } import org.apache.avro.compiler.idl.Idl -import org.apache.avro.generic.{GenericDatumReader, GenericRecord} +import org.apache.avro.generic.{ GenericDatumReader, GenericRecord } import org.apache.avro.file.DataFileReader import org.apache.avro.SchemaParseException @@ -17,23 +17,25 @@ import scala.jdk.CollectionConverters._ import scala.util.Try class FileInputParser { - - val schemaParser = new Parser() + + var processedFiles: Set[String] = Set.empty + var processedSchemas: Set[Schema] = Set.empty def getSchemaOrProtocols( infile: File, format: SourceFormat, classStore: ClassStore, classLoader: ClassLoader, - parser: Parser = schemaParser): List[Either[Schema, Protocol]] = { + parser: Parser): List[Either[Schema, Protocol]] = { def unUnion(schema: Schema) = { schema.getType match { case UNION => schema.getTypes().asScala.toList case RECORD => List(schema) case ENUM => List(schema) case FIXED => List(schema) - case _ => sys.error("""Neither a record, enum nor a union of either. - |Nothing to map to a definition.""".trim.stripMargin) + case _ => sys.error( + """Neither a record, enum nor a union of either. + |Nothing to map to a definition.""".trim.stripMargin) } } @@ -57,10 +59,10 @@ class FileInputParser { } } } - + def mightBeRecoverable(e: SchemaParseException): Boolean = { val msg = e.getMessage - msg.contains("Undefined name:") || msg.contains("is not a defined name") + msg.contains("Undefined name:") || msg.contains("is not a defined name") } def tryParse(inFile: File, parser: Schema.Parser): List[Schema] = { @@ -68,68 +70,66 @@ class FileInputParser { val parsed = Try(tempParser.parse(inFile)).map(schema => { copySchemas(tempParser, parser) schema - }).recoverWith { case e: SchemaParseException if mightBeRecoverable(e) => + }).recoverWith { case e: SchemaParseException if mightBeRecoverable(e) => Try(parser.parse(inFile)) } - unUnion(parsed.get)// throw the avro parse exception if Failure + unUnion(parsed.get) // throw the avro parse exception if Failure } - - val schemaOrProtocols: List[Either[Schema, Protocol]] = { - infile.getName.split("\\.").last match { - case "avro" => - val gdr = new GenericDatumReader[GenericRecord] - val dfr = new DataFileReader(infile, gdr) - val schemas = unUnion(dfr.getSchema) - schemas.map(Left(_)) - case "avsc" => - val schemas = tryParse(infile, parser) - schemas.map(Left(_)) - case "avpr" => - val protocol = Protocol.parse(infile) - List(Right(protocol)) - case "avdl" => - val idlParser = new Idl(infile, classLoader) - val protocol = idlParser.CompilationUnit() - /** - * IDLs may refer to types imported from another file. When converted - * to protocols, the imported types that share the IDL's namespace - * cannot be distinguished from types defined within the IDL, yet - * should not be generated as subtypes of the IDL's ADT and should - * instead be generated in its own namespace. So, strip the protocol - * of all imported types and generate them separately. - */ - val importedFiles = IdlImportParser.getImportedFiles(infile, classLoader) - val importedSchemaOrProtocols = importedFiles.flatMap(file => { + val res = infile.getName.split("\\.").last match { + case "avro" => + val gdr = new GenericDatumReader[GenericRecord] + val dfr = new DataFileReader(infile, gdr) + val schemas = unUnion(dfr.getSchema) + schemas.map(Left(_)) + case "avsc" => + val schemas = tryParse(infile, parser) + schemas.map(Left(_)) + case "avpr" => + val protocol = Protocol.parse(infile) + List(Right(protocol)) + case "avdl" => + val idlParser = new Idl(infile, classLoader) + val protocol = idlParser.CompilationUnit() + /** + * IDLs may refer to types imported from another file. When converted + * to protocols, the imported types that share the IDL's namespace + * cannot be distinguished from types defined within the IDL, yet + * should not be generated as subtypes of the IDL's ADT and should + * instead be generated in its own namespace. So, strip the protocol + * of all imported types and generate them separately. + */ + val importedFiles = IdlImportParser.getImportedFiles(infile, classLoader) + val importedSchemaOrProtocols = importedFiles.flatMap { file => + if (!processedFiles.contains(file.getCanonicalPath)) { + processedFiles += file.getCanonicalPath val importParser = new Parser() // else attempts to redefine schemas getSchemaOrProtocols(file, format, classStore, classLoader, importParser) - }) - def stripImports( - protocol: Protocol, - importedSchemaOrProtocols: List[Either[Schema, Protocol]]) = { - val imported = importedSchemaOrProtocols.flatMap(avroDef => { - avroDef match { - case Left(importedSchema) => List(importedSchema) - case Right(importedProtocol) => importedProtocol.getTypes().asScala - } - }) - val types = protocol.getTypes().asScala.toList - val localTypes = imported.foldLeft(types)((remaining, imported) => { - remaining.filterNot(remainingType => remainingType == imported) - }) - protocol.setTypes(localTypes.asJava) - protocol + } else { + List() } - val localProtocol = stripImports(protocol, importedSchemaOrProtocols) - // reverse to dependent classes are generated first - (Right(localProtocol) +: importedSchemaOrProtocols).reverse - case _ => - throw new Exception("""File must end in ".avpr" for protocol files, - |".avsc" for plain text json files, ".avdl" for IDL files, or .avro + } + + def stripImports(protocol: Protocol, imported: Set[Schema]) = { + val types = protocol.getTypes().asScala.toList + val localTypes = types.filterNot(imported.contains) + protocol.setTypes(localTypes.asJava) + protocol + } + + val localProtocol = stripImports(protocol, processedSchemas) + // reverse to dependent classes are generated first + (Right(localProtocol) +: importedSchemaOrProtocols).reverse + case _ => + throw new Exception( + """File must end in ".avpr" for protocol files, + |".avsc" for plain text json files, ".avdl" for IDL files, or .avro |for binary.""".trim.stripMargin) - } } - - schemaOrProtocols + res.foreach { + case Left(importedSchema) => processedSchemas += importedSchema + case Right(importedProtocol) => processedSchemas ++= importedProtocol.getTypes().asScala + } + res } } \ No newline at end of file diff --git a/avrohugger-core/src/main/scala/input/parsers/IdlImportParser.scala b/avrohugger-core/src/main/scala/input/parsers/IdlImportParser.scala index 0508a31e..61289dc0 100644 --- a/avrohugger-core/src/main/scala/input/parsers/IdlImportParser.scala +++ b/avrohugger-core/src/main/scala/input/parsers/IdlImportParser.scala @@ -2,13 +2,11 @@ package avrohugger package input package parsers -import org.apache.avro.{ Protocol, Schema } - import java.io.File import scala.util.matching.Regex.Match object IdlImportParser { - + def stripComments(fileContents: String): String = { val multiLinePattern = """/\*.*\*/""".r val singleLinePattern = """//.*$""".r @@ -29,13 +27,14 @@ object IdlImportParser { // if file is empty, try again, it was there when we read idl if (fileContents.isEmpty && (count < maxTries)) readFile(infile) else fileContents - } catch {// if file is not found, try again, it was there when we read idl + } catch { // if file is not found, try again, it was there when we read idl case e: java.io.FileNotFoundException => { if (count < maxTries) readFile(infile) else sys.error("File to found: " + infile) } } } + val path = infile.getParent + "/" val contents = readFile(infile) val avdlPattern = """import[ \t]+idl[ \t]+"([^"]*\.avdl)"[ \t]*;""".r @@ -45,24 +44,23 @@ object IdlImportParser { val protocolMatches = avprPattern.findAllIn(contents).matchData.toList val schemaMatches = avscPattern.findAllIn(contents).matchData.toList val importMatches = idlMatches ::: protocolMatches ::: schemaMatches - + val (localImports, nonLocalMatches): (List[File], List[Match]) = - importMatches.foldLeft((List.empty[File], List.empty[Match])){ - case ((ai,am), m) => + importMatches.foldLeft((List.empty[File], List.empty[Match])) { + case ((ai, am), m) => val f = new File(path + m.group(1)) - if (f.exists) (ai:+f, am) - else (ai, am:+m) + if (f.exists) (ai :+ f, am) + else (ai, am :+ m) } - - val classpathImports: List[File] = nonLocalMatches.map(m =>{ - - Option(classLoader.getResource(m.group(1))).map(resource =>{ + + val classpathImports: List[File] = nonLocalMatches.flatMap { m => + Option(classLoader.getResource(m.group(1))).map(resource => { new File(resource.getFile) }) - }).flatMap(_.toList).filter(file => file.exists) + }.filter(_.exists) val importedFiles = classpathImports ++ localImports importedFiles } - + } diff --git a/avrohugger-core/src/test/scala/util/Util.scala b/avrohugger-core/src/test/scala/util/Util.scala index 8f627f6d..d6bc2903 100644 --- a/avrohugger-core/src/test/scala/util/Util.scala +++ b/avrohugger-core/src/test/scala/util/Util.scala @@ -37,7 +37,7 @@ object Util { val rightValue: String = right.value val rightAsList = if (rightValue.contains("\r\n")) rightValue.split("\r\n") else rightValue.split("\n") - val res = leftAsList === rightAsList + val res = leftAsList.mkString("\n") === rightAsList.mkString("\n") result(res.isSuccess, res.message, res.message, right) } }