Skip to content

Commit

Permalink
added exception handling for methodReturn query (#903)
Browse files Browse the repository at this point in the history
* refactor test cases

* add exception handling fro methodReturn query
  • Loading branch information
khemrajrathore authored Jan 3, 2024
1 parent 09d684f commit daf5245
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromNodeExt
import org.slf4j.{Logger, LoggerFactory}

import scala.annotation.tailrec
import scala.collection.{Seq, mutable}
import scala.collection.{mutable}
import scala.util.Try
object SBKeyPrivado {
protected val logger: Logger = LoggerFactory.getLogger(getClass)
def fromNodeToLocalKey(node: AstNode): Option[LocalKey] = {
Expand Down Expand Up @@ -102,20 +103,27 @@ private class RecoverForRubyFile(
}
}

/*
override def methodReturnValues(methodFullNames: Seq[String]): Set[String] = {
// Check if we have a corresponding member to resolve type
val memberTypes = methodFullNames.flatMap { fullName =>
val memberName = fullName.split("\\.").lastOption
if (memberName.isDefined) {
val typeDeclFullName = fullName.stripSuffix(s".${memberName.get}")
cpg.typeDecl.fullName(typeDeclFullName).member.nameExact(memberName.get).typeFullName.l
} else
List.empty
} else List.empty
}.toSet
if (memberTypes.nonEmpty) memberTypes else super.methodReturnValues(methodFullNames)
if (memberTypes.nonEmpty) memberTypes
else {
val rs = cpg.method
.fullNameExact(methodFullNames.toList: _*)
.flatMap(m => Try(m.methodReturn).toOption)
.flatMap(mr => mr.typeFullName +: mr.dynamicTypeHintFullName)
.filterNot(_.equals("ANY"))
.toSet
if (rs.isEmpty) methodFullNames.map(_.concat(s"$pathSep${XTypeRecovery.DummyReturnType}")).toSet
else rs
}
}
*/

override def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = {
if (c.name.startsWith("<operator>")) {
Expand Down Expand Up @@ -221,7 +229,12 @@ private class RecoverForRubyFile(
.map {
case x: Call if x.typeFullName != "ANY" => Set(x.typeFullName)
case x: Call =>
cpg.method.fullNameExact(c.methodFullName).methodReturn.typeFullNameNot("ANY").typeFullName.toSet match {
cpg.method
.fullNameExact(c.methodFullName)
.flatMap(m => Try(m.methodReturn).toOption)
.typeFullNameNot("ANY")
.typeFullName
.toSet match {
case xs if xs.nonEmpty => xs
case _ =>
symbolTable
Expand Down Expand Up @@ -421,8 +434,10 @@ private class RecoverForRubyFile(
override def visitReturns(ret: Return): Unit = {
val m = ret.method
val existingTypes = mutable.HashSet.from(
(m.methodReturn.typeFullName +: m.methodReturn.dynamicTypeHintFullName)
.filterNot(_ == "ANY")
Try(
(m.methodReturn.typeFullName +: m.methodReturn.dynamicTypeHintFullName)
.filterNot(_ == "ANY")
).toOption.getOrElse(List())
)

@tailrec
Expand Down Expand Up @@ -467,7 +482,7 @@ private class RecoverForRubyFile(

val returnTypes = extractTypes(ret.argumentOut.l)
existingTypes.addAll(returnTypes)
builder.setNodeProperty(ret.method.methodReturn, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, existingTypes)
Try(builder.setNodeProperty(ret.method.methodReturn, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, existingTypes))
}

override def setTypeInformation(): Unit = {
Expand All @@ -490,12 +505,12 @@ private class RecoverForRubyFile(
val typs =
if (state.config.enabledDummyTypes) symbolTable.get(x).toSeq
else symbolTable.get(x).filterNot(XTypeRecovery.isDummyType).toSeq
storeCallTypeInfoPrivado(x, typs)
storeCallTypeInfo(x, typs)
case x: Call if globalSymbolTable.contains(x) =>
val typs =
if (state.config.enabledDummyTypes) globalSymbolTable.get(x).toSeq
else globalSymbolTable.get(x).filterNot(XTypeRecovery.isDummyType).toSeq
storeCallTypeInfoPrivado(x, typs)
storeCallTypeInfo(x, typs)
case x: Identifier
if (symbolTable
.contains(CallAlias(x.name)) || globalSymbolTable.contains(CallAlias(x.name))) && x.inCall.nonEmpty =>
Expand All @@ -509,11 +524,11 @@ private class RecoverForRubyFile(
val typs =
if (state.config.enabledDummyTypes) symbolTable.get(x).toSeq
else symbolTable.get(x).filterNot(XTypeRecovery.isDummyType).toSeq
storeCallTypeInfoPrivado(x, typs)
storeCallTypeInfo(x, typs)
case _ =>
}
// Set types in an atomic way
newTypesForMembers.foreach { case (m, ts) => storeDefaultTypeInfoPrivado(m, ts.toSeq) }
newTypesForMembers.foreach { case (m, ts) => storeDefaultTypeInfo(m, ts.toSeq) }
}

override protected def postSetTypeInformation(): Unit = {
Expand All @@ -522,7 +537,7 @@ private class RecoverForRubyFile(
cu.ast.isCall
.nameExact("perform_async")
.foreach(c =>
storeCallTypeInfoPrivado(
storeCallTypeInfo(
c,
symbolTable.get(c).flatMap(fullName => List(fullName, fullName.stripSuffix("_async"))).toSeq
)
Expand Down Expand Up @@ -611,78 +626,19 @@ private class RecoverForRubyFile(
case Some(ts) => Some(ts ++ types)
case None => Some(types.toSet)
}
case i: Identifier => storeIdentifierTypeInfoPrivado(i, types)
case l: Local => storeLocalTypeInfoPrivado(l, types)
case c: Call if !c.name.startsWith("<operator>") => storeCallTypeInfoPrivado(c, types)
case i: Identifier => storeIdentifierTypeInfo(i, types)
case l: Local => storeLocalTypeInfo(l, types)
case c: Call if !c.name.startsWith("<operator>") => storeCallTypeInfo(c, types)
case _: Call =>
case n => setTypesPrivado(n, types)
case n => setTypes(n, types)
}
}
}

def storeIdentifierTypeInfoPrivado(i: Identifier, types: Seq[String]): Unit =
storeDefaultTypeInfoPrivado(i, types)

/** Allows one to modify the types assigned to nodes otherwise.
*/
def storeDefaultTypeInfoPrivado(n: StoredNode, types: Seq[String]): Unit =
if (types.toSet != n.getKnownTypes) {
setTypesPrivado(n, (n.property(PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, Seq.empty) ++ types).distinct)
}

def setTypesPrivado(n: StoredNode, types: Seq[String]): Unit =
if (types.size == 1) builder.setNodeProperty(n, PropertyNames.TYPE_FULL_NAME, types.head)
else builder.setNodeProperty(n, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, types)

/** Allows one to modify the types assigned to locals.
*/
def storeLocalTypeInfoPrivado(l: Local, types: Seq[String]): Unit = {
storeDefaultTypeInfoPrivado(
l,
if (state.config.enabledDummyTypes) types else types.filterNot(XTypeRecovery.isDummyType)
)
}

def storeCallTypeInfoPrivado(c: Call, types: Seq[String]): Unit =
if (types.nonEmpty) {
builder.setNodeProperty(
c,
PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME,
(c.dynamicTypeHintFullName ++ types).distinct
)
}
private def persistMemberType(i: Identifier, types: Set[String]): Unit = {
getLocalMember(i) match {
case Some(m) => storeNodeTypeInfo(m, types.toSeq)
case None =>
}
}

private def integrateMethodRef(funcPtr: Expression, m: Method, mRef: NewMethodRef, inCall: AstNode) = {
builder.addNode(mRef)
builder.addEdge(mRef, m, EdgeTypes.REF)
builder.addEdge(inCall, mRef, EdgeTypes.AST)
builder.addEdge(funcPtr.method, mRef, EdgeTypes.CONTAINS)
inCall match {
case x: Call =>
builder.addEdge(x, mRef, EdgeTypes.ARGUMENT)
mRef.argumentIndex(x.argumentOut.size + 1)
case x =>
mRef.argumentIndex(x.astChildren.size + 1)
}
addedNodes.add(s"${funcPtr.id()}${NodeTypes.METHOD_REF}$pathSep${mRef.methodFullName}")
}
private def createMethodRef(
baseName: Option[String],
funcName: String,
methodFullName: String,
lineNo: Option[Integer],
columnNo: Option[Integer]
): NewMethodRef =
NewMethodRef()
.code(s"${baseName.map(_.appended(pathSep)).getOrElse("")}$funcName")
.methodFullName(methodFullName)
.lineNumber(lineNo)
.columnNumber(columnNo)

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package ai.privado.languageEngine.ruby.passes

import ai.privado.cache.RuleCache
import org.scalatest.BeforeAndAfterAll
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
import io.shiftleft.semanticcpg.language.*
import ai.privado.utility.Utilities.resolver
import ai.privado.languageEngine.ruby.RubyTestBase.*
import ai.privado.languageEngine.ruby.passes.download.DownloadDependenciesPass
import io.joern.rubysrc2cpg.RubySrc2Cpg
import io.joern.rubysrc2cpg.deprecated.utils.PackageTable
class MethodFullNameForExternalNodesTest extends AnyWordSpec with Matchers with BeforeAndAfterAll {

"method fullname for external nodes accessed via scopeResolution" should {

"be resolved" in {
val (cpg, config) = code(
List(
SourceCodeModel(
"""
|class MyClass
| def foo
| zendesk_user = ZendeskAPI::User.create_or_update!(
| external_id: user.strong_id,
| email: user.email,
| name: user.name,
| user_fields: user_fields(user),
| )
| end
|end
|""".stripMargin,
"demo.rb"
),
SourceCodeModel(
"""
|source 'https://rubygems.org'
|gem 'zendesk_api'
|""".stripMargin,
"Gemfile"
)
)
)

val packageTable =
new DownloadDependenciesPass(new PackageTable(), config.inputPath, RuleCache()).createAndApply()
new RubyExternalTypesPass(cpg, packageTable).createAndApply()

cpg.call("create_or_update!").dynamicTypeHintFullName.l shouldBe List(
"zendesk_api::program.ZendeskAPI.User.create_or_update!"
)

}
}
}

0 comments on commit daf5245

Please sign in to comment.