From daf52457c580678ce8c4bd86b2aaf1dfb4634bea Mon Sep 17 00:00:00 2001 From: KhemrajSingh Rathore Date: Wed, 3 Jan 2024 12:30:23 +0530 Subject: [PATCH] added exception handling for methodReturn query (#903) * refactor test cases * add exception handling fro methodReturn query --- ...PrivadoRubyTypeRecoveryPassGenerator.scala | 112 ++++++------------ .../MethodFullNameForExternalNodesTest.scala | 55 +++++++++ 2 files changed, 89 insertions(+), 78 deletions(-) create mode 100644 src/test/scala/ai/privado/languageEngine/ruby/passes/MethodFullNameForExternalNodesTest.scala diff --git a/src/main/scala/ai/privado/languageEngine/ruby/passes/PrivadoRubyTypeRecoveryPassGenerator.scala b/src/main/scala/ai/privado/languageEngine/ruby/passes/PrivadoRubyTypeRecoveryPassGenerator.scala index fa9e47ec0..a18f143e3 100644 --- a/src/main/scala/ai/privado/languageEngine/ruby/passes/PrivadoRubyTypeRecoveryPassGenerator.scala +++ b/src/main/scala/ai/privado/languageEngine/ruby/passes/PrivadoRubyTypeRecoveryPassGenerator.scala @@ -15,7 +15,8 @@ import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromNodeExt import org.slf4j.{Logger, LoggerFactory} import scala.annotation.tailrec -import scala.collection.{Seq, mutable} +import scala.collection.{mutable} +import scala.util.Try object SBKeyPrivado { protected val logger: Logger = LoggerFactory.getLogger(getClass) def fromNodeToLocalKey(node: AstNode): Option[LocalKey] = { @@ -102,7 +103,6 @@ private class RecoverForRubyFile( } } - /* override def methodReturnValues(methodFullNames: Seq[String]): Set[String] = { // Check if we have a corresponding member to resolve type val memberTypes = methodFullNames.flatMap { fullName => @@ -110,12 +110,20 @@ private class RecoverForRubyFile( if (memberName.isDefined) { val typeDeclFullName = fullName.stripSuffix(s".${memberName.get}") cpg.typeDecl.fullName(typeDeclFullName).member.nameExact(memberName.get).typeFullName.l - } else - List.empty + } else List.empty }.toSet - if (memberTypes.nonEmpty) memberTypes else super.methodReturnValues(methodFullNames) + if (memberTypes.nonEmpty) memberTypes + else { + val rs = cpg.method + .fullNameExact(methodFullNames.toList: _*) + .flatMap(m => Try(m.methodReturn).toOption) + .flatMap(mr => mr.typeFullName +: mr.dynamicTypeHintFullName) + .filterNot(_.equals("ANY")) + .toSet + if (rs.isEmpty) methodFullNames.map(_.concat(s"$pathSep${XTypeRecovery.DummyReturnType}")).toSet + else rs + } } - */ override def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = { if (c.name.startsWith("")) { @@ -221,7 +229,12 @@ private class RecoverForRubyFile( .map { case x: Call if x.typeFullName != "ANY" => Set(x.typeFullName) case x: Call => - cpg.method.fullNameExact(c.methodFullName).methodReturn.typeFullNameNot("ANY").typeFullName.toSet match { + cpg.method + .fullNameExact(c.methodFullName) + .flatMap(m => Try(m.methodReturn).toOption) + .typeFullNameNot("ANY") + .typeFullName + .toSet match { case xs if xs.nonEmpty => xs case _ => symbolTable @@ -421,8 +434,10 @@ private class RecoverForRubyFile( override def visitReturns(ret: Return): Unit = { val m = ret.method val existingTypes = mutable.HashSet.from( - (m.methodReturn.typeFullName +: m.methodReturn.dynamicTypeHintFullName) - .filterNot(_ == "ANY") + Try( + (m.methodReturn.typeFullName +: m.methodReturn.dynamicTypeHintFullName) + .filterNot(_ == "ANY") + ).toOption.getOrElse(List()) ) @tailrec @@ -467,7 +482,7 @@ private class RecoverForRubyFile( val returnTypes = extractTypes(ret.argumentOut.l) existingTypes.addAll(returnTypes) - builder.setNodeProperty(ret.method.methodReturn, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, existingTypes) + Try(builder.setNodeProperty(ret.method.methodReturn, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, existingTypes)) } override def setTypeInformation(): Unit = { @@ -490,12 +505,12 @@ private class RecoverForRubyFile( val typs = if (state.config.enabledDummyTypes) symbolTable.get(x).toSeq else symbolTable.get(x).filterNot(XTypeRecovery.isDummyType).toSeq - storeCallTypeInfoPrivado(x, typs) + storeCallTypeInfo(x, typs) case x: Call if globalSymbolTable.contains(x) => val typs = if (state.config.enabledDummyTypes) globalSymbolTable.get(x).toSeq else globalSymbolTable.get(x).filterNot(XTypeRecovery.isDummyType).toSeq - storeCallTypeInfoPrivado(x, typs) + storeCallTypeInfo(x, typs) case x: Identifier if (symbolTable .contains(CallAlias(x.name)) || globalSymbolTable.contains(CallAlias(x.name))) && x.inCall.nonEmpty => @@ -509,11 +524,11 @@ private class RecoverForRubyFile( val typs = if (state.config.enabledDummyTypes) symbolTable.get(x).toSeq else symbolTable.get(x).filterNot(XTypeRecovery.isDummyType).toSeq - storeCallTypeInfoPrivado(x, typs) + storeCallTypeInfo(x, typs) case _ => } // Set types in an atomic way - newTypesForMembers.foreach { case (m, ts) => storeDefaultTypeInfoPrivado(m, ts.toSeq) } + newTypesForMembers.foreach { case (m, ts) => storeDefaultTypeInfo(m, ts.toSeq) } } override protected def postSetTypeInformation(): Unit = { @@ -522,7 +537,7 @@ private class RecoverForRubyFile( cu.ast.isCall .nameExact("perform_async") .foreach(c => - storeCallTypeInfoPrivado( + storeCallTypeInfo( c, symbolTable.get(c).flatMap(fullName => List(fullName, fullName.stripSuffix("_async"))).toSeq ) @@ -611,78 +626,19 @@ private class RecoverForRubyFile( case Some(ts) => Some(ts ++ types) case None => Some(types.toSet) } - case i: Identifier => storeIdentifierTypeInfoPrivado(i, types) - case l: Local => storeLocalTypeInfoPrivado(l, types) - case c: Call if !c.name.startsWith("") => storeCallTypeInfoPrivado(c, types) + case i: Identifier => storeIdentifierTypeInfo(i, types) + case l: Local => storeLocalTypeInfo(l, types) + case c: Call if !c.name.startsWith("") => storeCallTypeInfo(c, types) case _: Call => - case n => setTypesPrivado(n, types) + case n => setTypes(n, types) } } } - def storeIdentifierTypeInfoPrivado(i: Identifier, types: Seq[String]): Unit = - storeDefaultTypeInfoPrivado(i, types) - - /** Allows one to modify the types assigned to nodes otherwise. - */ - def storeDefaultTypeInfoPrivado(n: StoredNode, types: Seq[String]): Unit = - if (types.toSet != n.getKnownTypes) { - setTypesPrivado(n, (n.property(PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, Seq.empty) ++ types).distinct) - } - - def setTypesPrivado(n: StoredNode, types: Seq[String]): Unit = - if (types.size == 1) builder.setNodeProperty(n, PropertyNames.TYPE_FULL_NAME, types.head) - else builder.setNodeProperty(n, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, types) - - /** Allows one to modify the types assigned to locals. - */ - def storeLocalTypeInfoPrivado(l: Local, types: Seq[String]): Unit = { - storeDefaultTypeInfoPrivado( - l, - if (state.config.enabledDummyTypes) types else types.filterNot(XTypeRecovery.isDummyType) - ) - } - - def storeCallTypeInfoPrivado(c: Call, types: Seq[String]): Unit = - if (types.nonEmpty) { - builder.setNodeProperty( - c, - PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, - (c.dynamicTypeHintFullName ++ types).distinct - ) - } private def persistMemberType(i: Identifier, types: Set[String]): Unit = { getLocalMember(i) match { case Some(m) => storeNodeTypeInfo(m, types.toSeq) case None => } } - - private def integrateMethodRef(funcPtr: Expression, m: Method, mRef: NewMethodRef, inCall: AstNode) = { - builder.addNode(mRef) - builder.addEdge(mRef, m, EdgeTypes.REF) - builder.addEdge(inCall, mRef, EdgeTypes.AST) - builder.addEdge(funcPtr.method, mRef, EdgeTypes.CONTAINS) - inCall match { - case x: Call => - builder.addEdge(x, mRef, EdgeTypes.ARGUMENT) - mRef.argumentIndex(x.argumentOut.size + 1) - case x => - mRef.argumentIndex(x.astChildren.size + 1) - } - addedNodes.add(s"${funcPtr.id()}${NodeTypes.METHOD_REF}$pathSep${mRef.methodFullName}") - } - private def createMethodRef( - baseName: Option[String], - funcName: String, - methodFullName: String, - lineNo: Option[Integer], - columnNo: Option[Integer] - ): NewMethodRef = - NewMethodRef() - .code(s"${baseName.map(_.appended(pathSep)).getOrElse("")}$funcName") - .methodFullName(methodFullName) - .lineNumber(lineNo) - .columnNumber(columnNo) - } diff --git a/src/test/scala/ai/privado/languageEngine/ruby/passes/MethodFullNameForExternalNodesTest.scala b/src/test/scala/ai/privado/languageEngine/ruby/passes/MethodFullNameForExternalNodesTest.scala new file mode 100644 index 000000000..a732c38bf --- /dev/null +++ b/src/test/scala/ai/privado/languageEngine/ruby/passes/MethodFullNameForExternalNodesTest.scala @@ -0,0 +1,55 @@ +package ai.privado.languageEngine.ruby.passes + +import ai.privado.cache.RuleCache +import org.scalatest.BeforeAndAfterAll +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec +import io.shiftleft.semanticcpg.language.* +import ai.privado.utility.Utilities.resolver +import ai.privado.languageEngine.ruby.RubyTestBase.* +import ai.privado.languageEngine.ruby.passes.download.DownloadDependenciesPass +import io.joern.rubysrc2cpg.RubySrc2Cpg +import io.joern.rubysrc2cpg.deprecated.utils.PackageTable +class MethodFullNameForExternalNodesTest extends AnyWordSpec with Matchers with BeforeAndAfterAll { + + "method fullname for external nodes accessed via scopeResolution" should { + + "be resolved" in { + val (cpg, config) = code( + List( + SourceCodeModel( + """ + |class MyClass + | def foo + | zendesk_user = ZendeskAPI::User.create_or_update!( + | external_id: user.strong_id, + | email: user.email, + | name: user.name, + | user_fields: user_fields(user), + | ) + | end + |end + |""".stripMargin, + "demo.rb" + ), + SourceCodeModel( + """ + |source 'https://rubygems.org' + |gem 'zendesk_api' + |""".stripMargin, + "Gemfile" + ) + ) + ) + + val packageTable = + new DownloadDependenciesPass(new PackageTable(), config.inputPath, RuleCache()).createAndApply() + new RubyExternalTypesPass(cpg, packageTable).createAndApply() + + cpg.call("create_or_update!").dynamicTypeHintFullName.l shouldBe List( + "zendesk_api::program.ZendeskAPI.User.create_or_update!" + ) + + } + } +}