Skip to content

Commit

Permalink
Merge pull request #899 from Privado-Inc/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
khemrajrathore authored Jan 2, 2024
2 parents 750569d + 798b3bb commit af74fe9
Show file tree
Hide file tree
Showing 7 changed files with 424 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import io.joern.x2cpg.Ast
import io.joern.x2cpg.passes.frontend.{CallAlias, LocalKey, LocalVar, SymbolTable, XImportsPass}
import io.shiftleft.codepropertygraph.generated.{Cpg, EdgeTypes}
import io.shiftleft.passes.ConcurrentWriterCpgPass
import io.joern.rubysrc2cpg.deprecated.utils.PackageTable
import io.joern.x2cpg.Ast.storeInDiffGraph
import io.joern.x2cpg.Imports.createImportNodeAndLink
import io.shiftleft.semanticcpg.language.importresolver.{ResolvedMethod, ResolvedTypeDecl}
Expand All @@ -37,77 +36,24 @@ import io.shiftleft.semanticcpg.language.*

import scala.collection.Seq

class GlobalImportPass(cpg: Cpg, packageTable: PackageTable, globalSymbolTable: SymbolTable[LocalKey])
extends PrivadoSimpleCpgPass(cpg) {

/*
lazy val modules: Set[String] =
(packageTable.moduleMapping.keys.l ++ packageTable.typeDeclMapping.keys.l).toSet.filter(_.nonEmpty)
def generateParts(): Array[File] = {
cpg.file.name(".*[.]rb").toArray
}
def runOnPart(builder: DiffGraphBuilder, fileNode: File): Unit = {
modules.foreach { moduleKey =>
if (!fileNode.name.equals(moduleKey)) {
val callNode = NewCall().name(moduleKey)
val importNode = createImportNodeAndLink(moduleKey, "", Some(callNode), builder)
builder.addEdge(fileNode, callNode, EdgeTypes.AST)
val callAst = Ast(callNode).withChild(Ast(importNode))
storeInDiffGraph(callAst, builder)
}
}
}
*/
class GlobalImportPass(cpg: Cpg, globalSymbolTable: SymbolTable[LocalKey]) extends PrivadoSimpleCpgPass(cpg) {

override def run(builder: DiffGraphBuilder): Unit = {

/*
val resolvedModulesExternal = packageTable.moduleMapping.values.flatMap(moduleMappings =>
moduleMappings.map(module => ResolvedTypeDecl(module.fullName))
)
val resolvedTypeDeclExternal = packageTable.typeDeclMapping.values.flatMap(typeDeclMappings =>
typeDeclMappings.flatMap(typeDeclModel =>
Seq(
ResolvedMethod(s"${typeDeclModel.fullName}.new", "new"),
ResolvedMethod(s"${typeDeclModel.fullName}.${typeDeclModel.name}", typeDeclModel.name),
ResolvedTypeDecl(typeDeclModel.fullName)
)
)
)
*/

val resolvedTypeDeclInternal = cpg.typeDecl
.flatMap(typeDecl =>
Seq(
ResolvedTypeDecl(typeDecl.fullName),
ResolvedMethod(s"${typeDecl.fullName}.new", "new"),
ResolvedMethod(s"${typeDecl.fullName}.${typeDecl.name}", typeDecl.name)
)
)

val resolvedModuleInternal = cpg.namespaceBlock
.whereNot(_.nameExact("<global>"))
.flatMap(module => Seq(ResolvedTypeDecl(module.fullName)))

// Expose methods which are directly present in a file, without any module, TypeDecl
val resolvedMethodInternal = cpg.method
.where(_.nameExact(":program"))
.astChildren
.astChildren
.isMethod
.flatMap(method => Seq(ResolvedMethod(method.fullName, method.name)))

// (resolvedModulesExternal ++ resolvedTypeDeclExternal ++
(resolvedTypeDeclInternal ++ resolvedModuleInternal ++ resolvedMethodInternal).toSet
.foreach {
case ResolvedMethod(fullName, alias, receiver, _) =>
globalSymbolTable.append(CallAlias(alias, receiver), fullName)
case ResolvedTypeDecl(fullName, _) =>
globalSymbolTable.append(LocalVar(fullName.split("\\.").lastOption.getOrElse(fullName)), fullName)
.filterNot(_.astParent.isNamespaceBlock)
.flatMap(typeDecl => Seq(ResolvedTypeDecl(typeDecl.fullName)))
.l

val resolvedModuleInternal = cpg.typeDecl
.filter(_.astParent.isNamespaceBlock)
.filter(_.astChildren.size != 1)
.flatMap(typeDecl => Seq(ResolvedTypeDecl(typeDecl.fullName)))
.l

(resolvedTypeDeclInternal ++ resolvedModuleInternal).toSet
.foreach { case ResolvedTypeDecl(fullName, _) =>
globalSymbolTable.append(LocalVar(fullName.split("\\.").lastOption.getOrElse(fullName)), fullName)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,39 @@ import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromNodeExt
import io.shiftleft.semanticcpg.language.*
import io.joern.x2cpg.Defines.{ConstructorMethodName, DynamicCallUnknownFullName}
import io.joern.x2cpg.Defines as XDefines
import io.joern.x2cpg.passes.frontend.SBKey.getClass
import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes, Operators, PropertyNames}
import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.{Assignment, FieldAccess}
import overflowdb.BatchedUpdate.DiffGraphBuilder
import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromNodeExt
import org.slf4j.{Logger, LoggerFactory}

import scala.annotation.tailrec
import scala.collection.{Seq, mutable}
object SBKeyPrivado {
protected val logger: Logger = LoggerFactory.getLogger(getClass)
def fromNodeToLocalKey(node: AstNode): Option[LocalKey] = {
Option(node match {
case n: Identifier => LocalVar(n.name)
case n: Local => LocalVar(n.name)
case n: Call =>
CallAlias(
n.name,
n.argument.collectFirst {
case x: Identifier if x.argumentIndex == 0 => x.name
case c: Call if c.argumentIndex == 0 && c.name == "<operator>.scopeResolution" =>
c.code.stripPrefix("::").trim
}
)
case n: Method => CallAlias(n.name, Option("this"))
case n: MethodRef => CallAlias(n.code)
case n: FieldIdentifier => LocalVar(n.canonicalName)
case n: MethodParameterIn => LocalVar(n.name)
case _ => logger.debug(s"Local node of type ${node.label} is not supported in the type recovery pass."); null
})
}

}

class PrivadoRubyTypeRecoveryPassGenerator(
cpg: Cpg,
Expand Down Expand Up @@ -51,16 +77,18 @@ private class RecoverForRubyFile(

import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromNodeExt

override protected val symbolTable = new SymbolTable[LocalKey](SBKeyPrivado.fromNodeToLocalKey)

/** A heuristic method to determine if a call is a constructor or not.
*/
override protected def isConstructor(c: Call): Boolean = {
isConstructor(c.name) && c.code.charAt(0).isUpper
isConstructor(c.name)
}

/** A heuristic method to determine if a call name is a constructor or not.
*/
override protected def isConstructor(name: String): Boolean =
!name.isBlank && name.equals("new")
!name.isBlank && (name.equals("new") || name.equals("<init>"))

override def visitImport(i: Import): Unit = for {
resolvedImport <- i.call.tag
Expand All @@ -73,17 +101,6 @@ private class RecoverForRubyFile(
case _ => super.visitImport(i)
}
}
override def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = {

def isMatching(cName: String, code: String) = {
val cNameList = cName.split(":program").last.split("\\.").filterNot(_.isEmpty)
val codeList = code.split("\\(").head.split("[:.]").filterNot(_.isEmpty)
cNameList sameElements codeList
}

val constructorPaths = symbolTable.get(c).filter(isMatching(_, c.code)).map(_.stripSuffix(s"${pathSep}new"))
associateTypes(i, constructorPaths)
}

/*
override def methodReturnValues(methodFullNames: Seq[String]): Set[String] = {
Expand All @@ -103,26 +120,14 @@ private class RecoverForRubyFile(
override def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = {
if (c.name.startsWith("<operator>")) {
visitIdentifierAssignedToOperator(i, c, c.name)
} else if ((symbolTable.contains(c) || globalSymbolTable.contains(c)) && isConstructor(c)) {
visitIdentifierAssignedToConstructor(i, c)
} else if (symbolTable.contains(c) || globalSymbolTable.contains(c)) {
visitIdentifierAssignedToCallRetVal(i, c)
} else if (c.argument.headOption.exists(arg => symbolTable.contains(arg) || globalSymbolTable.contains(arg))) {
setCallMethodFullNameFromBase(c)
} else if (isCallHeadArgumentAScopeResolutionAndIsLastArgumentInTable(c)) {
setCallMethodFullNameFromBaseScopeResolution(c)
// Repeat this method now that the call has a type
visitIdentifierAssignedToCall(i, c)
} else if (
c.argument.headOption
.exists(_.isCall) && c.argument.head
.asInstanceOf[Call]
.name
.equals("<operator>.scopeResolution") && c.argument.head
.asInstanceOf[Call]
.argument
.lastOption
.exists(arg => symbolTable.contains(arg) || globalSymbolTable.contains(arg))
) {
setCallMethodFullNameFromBaseScopeResolution(c)
} else if (c.argument.headOption.exists(arg => symbolTable.contains(arg) || globalSymbolTable.contains(arg))) {
setCallMethodFullNameFromBase(c)
// Repeat this method now that the call has a type
visitIdentifierAssignedToCall(i, c)
} else {
Expand All @@ -131,17 +136,56 @@ private class RecoverForRubyFile(
}
}

/** Return true if `methodFullName` after the `:program` part matches the callCode excluding arguments
* @param methodFullName
* @param callCode
* @return
*/
def isCallParentScopeResolutionMatching(methodFullName: String, callCode: String) = {
try {
val cNameList = methodFullName.split(":program").last.split("\\.").filterNot(_.isEmpty)
val codeList = callCode.split("\\(").head.split("[:.]").filterNot(_.isEmpty).dropRight(1).toList
cNameList sameElements codeList
} catch {
case e: Exception => false
}

}

/** Return true if the passed node is `foo` and it is called as `Pay::Braintree.Billable.foo()`,
*
* Here we check whether the head argument is `ScopeResolution` (which is true above) and is `Billable` present in
* symbol table which accessor as `Pay.Braintree`
*
* If the above conditions hold true, return true else false
* @param c
* @return
*/
def isCallHeadArgumentAScopeResolutionAndIsLastArgumentInTable(c: Call): Boolean = c.argument.headOption
.exists(_.isCall) && c.argument.head
.asInstanceOf[Call]
.name
.equals("<operator>.scopeResolution") && c.argument.head
.asInstanceOf[Call]
.argument
.lastOption
.exists(arg =>
symbolTable.get(arg).union(globalSymbolTable.get(arg)).exists(isCallParentScopeResolutionMatching(_, c.code))
)

protected def setCallMethodFullNameFromBaseScopeResolution(c: Call): Set[String] = {
val recTypes = c.argument.headOption
.map {
case x: Call if x.name.equals("<operator>.scopeResolution") =>
x.argument.lastOption
.map(i => symbolTable.get(i).union(globalSymbolTable.get(i)))
.map(i =>
symbolTable.get(i).union(globalSymbolTable.get(i)).filter(isCallParentScopeResolutionMatching(_, c.code))
)
.getOrElse(Set.empty[String])
.map(_.concat(s"$pathSep${c.name}"))
}
.getOrElse(Set.empty[String])
val callTypes = recTypes.map(_.concat(s"$pathSep${c.name}"))
symbolTable.append(c, callTypes)
symbolTable.append(c, recTypes)
}

private def debugLocation(n: AstNode): String = {
Expand Down Expand Up @@ -254,8 +298,14 @@ private class RecoverForRubyFile(
.get(LocalVar(getFieldName(c.asInstanceOf[FieldAccess])))
.union(globalSymbolTable.get(LocalVar(getFieldName(c.asInstanceOf[FieldAccess]))))
case _ if symbolTable.contains(c) => methodReturnValues(symbolTable.get(c).toSeq)
case _ if globalSymbolTable.contains(c) => globalSymbolTable.get(c)
case Operators.indexAccess => getIndexAccessTypes(c)
case _ if globalSymbolTable.contains(c) => methodReturnValues(globalSymbolTable.get(c).toSeq)
case _ if c.argument.headOption.exists(arg => symbolTable.contains(arg) || globalSymbolTable.contains(arg)) =>
setCallMethodFullNameFromBase(c)
methodReturnValues(symbolTable.get(c).toSeq)
case _ if isCallHeadArgumentAScopeResolutionAndIsLastArgumentInTable(c) =>
setCallMethodFullNameFromBaseScopeResolution(c)
methodReturnValues(symbolTable.get(c).toSeq)
case Operators.indexAccess => getIndexAccessTypes(c)
case n =>
logger.debug(s"Unknown RHS call type '$n' @ ${debugLocation(c)}")
Set.empty[String]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,7 @@ class RubyImportResolverPass(cpg: Cpg, packageTableInfo: PackageTable) extends X
val importNodesFromTypeDecl = packageTableInfo
.getTypeDecl(expEntity)
.flatMap { typeDeclModel =>
Seq(
ResolvedMethod(s"${typeDeclModel.fullName}.new", "new"),
ResolvedMethod(s"${typeDeclModel.fullName}.${typeDeclModel.name}", typeDeclModel.name),
ResolvedTypeDecl(typeDeclModel.fullName)
)
Seq(ResolvedTypeDecl(typeDeclModel.fullName))
}
.distinct

Expand All @@ -91,13 +87,7 @@ class RubyImportResolverPass(cpg: Cpg, packageTableInfo: PackageTable) extends X
} else {
val resolvedTypeDecls = cpg.typeDecl
.where(_.file.name(s"${Pattern.quote(expResolvedPath)}\\.?.*"))
.flatMap(typeDecl =>
Seq(
ResolvedTypeDecl(typeDecl.fullName),
ResolvedMethod(s"${typeDecl.fullName}.new", "new"),
ResolvedMethod(s"${typeDecl.fullName}.${typeDecl.name}", typeDecl.name)
)
)
.flatMap(typeDecl => Seq(ResolvedTypeDecl(typeDecl.fullName)))
.toSet

val resolvedModules = cpg.namespaceBlock
Expand All @@ -106,16 +96,7 @@ class RubyImportResolverPass(cpg: Cpg, packageTableInfo: PackageTable) extends X
.flatMap(module => Seq(ResolvedTypeDecl(module.fullName)))
.toSet

// Expose methods which are directly present in a file, without any module, TypeDecl
val resolvedMethods = cpg.method
.where(_.file.name(s"${Pattern.quote(expResolvedPath)}\\.?.*"))
.where(_.nameExact(":program"))
.astChildren
.astChildren
.isMethod
.flatMap(method => Seq(ResolvedMethod(method.fullName, method.name)))
.toSet
resolvedTypeDecls ++ resolvedModules ++ resolvedMethods
resolvedTypeDecls ++ resolvedModules
}
}

Expand Down
Loading

0 comments on commit af74fe9

Please sign in to comment.