Skip to content

Commit

Permalink
Support interactive mode for taint analysis (#110)
Browse files Browse the repository at this point in the history
Co-authored-by: Teng Zhang <[email protected]>
Co-authored-by: Tian Tan <[email protected]>
  • Loading branch information
3 people authored May 30, 2024
1 parent 03fe7ef commit c8e4361
Show file tree
Hide file tree
Showing 10 changed files with 215 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
- Taint analysis
- Support specifying IndexRef (e.g., `index: "0[*]"` and `index: "0.f"`) in call sources and parameter sources.
- Support specifying IndexRef in sinks.
- Support interactive mode, allowing users to modify the taint configuration file and re-run taint analysis without needing to re-run the whole program analysis.

### Breaking Changes
- API changes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ public PointerFlowEdge addEdge(PointerFlowEdge edge) {
return null;
}

@Override
public void removeEdgesIf(Predicate<PointerFlowEdge> filter) {
outEdges.removeIf(filter);
}

@Override
public Set<PointerFlowEdge> getOutEdges() {
return Collections.unmodifiableSet(new ArraySet<>(outEdges, true));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,16 @@ public interface Pointer extends Indexable {
*/
PointerFlowEdge addEdge(PointerFlowEdge edge);

/**
* Removes out edges of this pointer if they satisfy the filter.
* <p>
* <strong>Note:</strong> This method should not be called outside of
* {@link pascal.taie.analysis.pta.plugin.Plugin#onPhaseFinish()},
* otherwise it may break the monotonicity of pointer analysis.
* </p>
*/
void removeEdgesIf(Predicate<PointerFlowEdge> filter);

/**
* @return out edges of this pointer in pointer flow graph.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;

/**
* Composite plugin which allows multiple independent plugins
Expand Down Expand Up @@ -89,6 +90,13 @@ private void addPlugin(Plugin plugin, List<Plugin> plugins,
}
}

public void clearPlugins() {
Stream.of(allPlugins,
onNewPointsToSetPlugins, onNewCallEdgePlugins, onNewMethodPlugins,
onNewStmtPlugins, onNewCSMethodPlugins, onUnresolvedCallPlugins
).forEach(List::clear);
}

@Override
public void setSolver(Solver solver) {
allPlugins.forEach(p -> p.setSolver(solver));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
import pascal.taie.analysis.graph.callgraph.CallKind;
import pascal.taie.analysis.graph.callgraph.Edge;
import pascal.taie.analysis.pta.PointerAnalysisResult;
import pascal.taie.analysis.pta.core.cs.element.ArrayIndex;
import pascal.taie.analysis.pta.core.cs.element.CSObj;
import pascal.taie.analysis.pta.core.cs.element.InstanceField;
import pascal.taie.analysis.pta.core.cs.element.Pointer;
import pascal.taie.analysis.pta.core.heap.Obj;
import pascal.taie.analysis.pta.plugin.util.InvokeUtils;
import pascal.taie.ir.exp.Var;
Expand Down Expand Up @@ -59,7 +63,7 @@ Set<TaintFlow> collectTaintFlows() {
// TODO: handle other call edges
.filter(e -> e.getKind() != CallKind.OTHER)
.map(Edge::getCallSite)
.map(sinkCall -> collectTaintFlows(result, sinkCall, sink))
.map(sinkCall -> collectTaintFlows(sinkCall, sink))
.forEach(taintFlows::addAll);
}
if (callSiteMode) {
Expand All @@ -73,7 +77,7 @@ Set<TaintFlow> collectTaintFlows() {
JMethod callee = callSite.getMethodRef().resolveNullable();
if (callee != null) {
for (Sink sink : sinkMap.get(callee)) {
taintFlows.addAll(collectTaintFlows(result, callSite, sink));
taintFlows.addAll(collectTaintFlows(callSite, sink));
}
}
});
Expand All @@ -82,15 +86,31 @@ Set<TaintFlow> collectTaintFlows() {
}

private Set<TaintFlow> collectTaintFlows(
PointerAnalysisResult result, Invoke sinkCall, Sink sink) {
Invoke sinkCall, Sink sink) {
IndexRef indexRef = sink.indexRef();
Var arg = InvokeUtils.getVar(sinkCall, indexRef.index());
SinkPoint sinkPoint = new SinkPoint(sinkCall, indexRef);
// obtain objects to check for different IndexRef.Kind
Set<Obj> objs = switch (indexRef.kind()) {
case VAR -> result.getPointsToSet(arg);
case ARRAY -> result.getPointsToSet(arg, (Var) null);
case FIELD -> result.getPointsToSet(arg, indexRef.field());
case VAR -> csManager.getCSVarsOf(arg)
.stream()
.flatMap(Pointer::objects)
.map(CSObj::getObject)
.collect(Collectors.toUnmodifiableSet());
case ARRAY -> csManager.getCSVarsOf(arg)
.stream()
.flatMap(Pointer::objects)
.map(csManager::getArrayIndex)
.flatMap(ArrayIndex::objects)
.map(CSObj::getObject)
.collect(Collectors.toUnmodifiableSet());
case FIELD -> csManager.getCSVarsOf(arg)
.stream()
.flatMap(Pointer::objects)
.map(o -> csManager.getInstanceField(o, indexRef.field()))
.flatMap(InstanceField::objects)
.map(CSObj::getObject)
.collect(Collectors.toUnmodifiableSet());
};
return objs.stream()
.filter(manager::isTaint)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,60 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import pascal.taie.World;
import pascal.taie.analysis.graph.callgraph.CallGraph;
import pascal.taie.analysis.pta.core.cs.context.Context;
import pascal.taie.analysis.pta.core.cs.element.CSCallSite;
import pascal.taie.analysis.pta.core.cs.element.CSManager;
import pascal.taie.analysis.pta.core.cs.element.CSMethod;
import pascal.taie.analysis.pta.core.cs.element.CSVar;
import pascal.taie.analysis.pta.core.solver.Solver;
import pascal.taie.analysis.pta.plugin.CompositePlugin;
import pascal.taie.analysis.pta.pts.PointsToSet;
import pascal.taie.ir.IR;
import pascal.taie.language.classes.JMethod;
import pascal.taie.util.Timer;

import javax.annotation.Nullable;
import java.io.File;
import java.util.Set;

/**
* Taint Analysis composites plugins {@link SourceHandler}, {@link TransferHandler}
* and {@link SanitizerHandler} to handle the logic associated with {@link Source},
* {@link TaintTransfer}, and {@link Sanitizer} respectively.
* The analysis finally gathers taint flows from {@link Sink} through {@link SinkHandler}
* and generates reports.
*/
public class TaintAnalysis extends CompositePlugin {

private static final Logger logger = LogManager.getLogger(TaintAnalysis.class);

private static final String TAINT_FLOW_GRAPH_FILE = "taint-flow-graph.dot";

private Solver solver;

private boolean isInteractive;

/**
* Indicates whether the taint analysis result has been reported.
* It is used to ensures that {@link #reportTaintFlows()} executes only once
* during a taint analysis.
*/
private boolean isReported;

private HandlerContext context;

@Override
public void setSolver(Solver solver) {
this.solver = solver;
isInteractive = solver.getOptions().getBoolean("taint-interactive-mode");
initialize();
}

private void initialize() {
isReported = false;
// reset composited plugins
clearPlugins();
TaintManager manager = new TaintManager(solver.getHeapModel());
TaintConfig config = TaintConfig.loadConfig(
solver.getOptions().getString("taint-config"),
Expand All @@ -52,15 +89,102 @@ public void setSolver(Solver solver) {
addPlugin(new SourceHandler(context),
new TransferHandler(context),
new SanitizerHandler(context));
// clear all taint objects and taint edges
CSManager csManager = solver.getCSManager();
csManager.pointers().forEach(p -> {
PointsToSet pts = p.getPointsToSet();
if (pts != null) {
pts.removeIf(csObj -> manager.isTaint(csObj.getObject()));
}
p.removeEdgesIf(TaintTransferEdge.class::isInstance);
});
// trigger the creation of taint objects
CallGraph<CSCallSite, CSMethod> cg = solver.getCallGraph();
if (cg != null) {
boolean handleStmt = context.config().callSiteMode()
|| context.config().sources().stream().anyMatch(FieldSource.class::isInstance);
cg.reachableMethods().forEach(csMethod -> {
JMethod method = csMethod.getMethod();
Context ctxt = csMethod.getContext();
IR ir = csMethod.getMethod().getIR();
if (handleStmt) {
ir.forEach(stmt -> onNewStmt(stmt, method));
}
this.onNewCSMethod(csMethod);
csMethod.getEdges().forEach(this::onNewCallEdge);
ir.getParams().forEach(param -> {
CSVar csParam = csManager.getCSVar(ctxt, param);
onNewPointsToSet(csParam, csParam.getPointsToSet());
});
});
}
}

@Override
public void onPhaseFinish() {
if (isInteractive) {
while (true) {
reportTaintFlows();
System.out.println("Taint Analysis is in interactive mode,"
+ " you can modify the taint configuration and run the analysis again.\n"
+ "Enter 'r' to run, 'e' to exit: ");
String input = readLineFromConsole();
if (input == null) {
break;
}
input = input.strip();
System.out.println("You have entered: '" + input + "'");
if ("r".equals(input)) {
initialize();
if (!context.manager().getTaintObjs().isEmpty()) {
break;
}
} else if ("e".equals(input)) {
isInteractive = false;
break;
}
}
}
}

/**
* A utility method for reading one line from the console using {@code System.in}.
* This method does not use buffering to ensure it does not read more than necessary.
* <br>
*
* @return one line line read from the console,
* or {@code null} if no line is available
*/
@Nullable
private static String readLineFromConsole() {
StringBuilder sb = new StringBuilder();
try {
int c;
while ((c = System.in.read()) != -1) {
if (c == '\r' || c == '\n') {
return sb.toString();
}
sb.append((char) c);
}
} catch (Exception e) {
logger.error("Error reading from console", e);
}
return sb.isEmpty() ? null : sb.toString();
}

@Override
public void onFinish() {
reportTaintFlows();
}

private void reportTaintFlows() {
if (isReported) {
return;
}
isReported = true;
Set<TaintFlow> taintFlows = new SinkHandler(context).collectTaintFlows();
logger.info("Detected {} taint flow(s):", taintFlows.size());
taintFlows.forEach(logger::info);
Solver solver = context.solver();
solver.getResult().storeResult(getClass().getName(), taintFlows);
TaintManager manager = context.manager();
Timer.runAndCount(() -> new TFGDumper().dump(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import java.util.Collections;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Stream;

/**
Expand Down Expand Up @@ -58,6 +59,11 @@ public boolean addAll(PointsToSet pts) {
}
}

@Override
public void removeIf(Predicate<CSObj> filter) {
set.removeIf(filter);
}

@Override
public boolean contains(CSObj obj) {
return set.contains(obj);
Expand Down
11 changes: 11 additions & 0 deletions src/main/java/pascal/taie/analysis/pta/pts/PointsToSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import java.util.Iterator;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Stream;

/**
Expand Down Expand Up @@ -57,6 +58,16 @@ public interface PointsToSet extends Iterable<CSObj>, Copyable<PointsToSet> {
*/
PointsToSet addAllDiff(PointsToSet pts);

/**
* Removes objects from this set if they satisfy the filter.
* <p>
* <strong>Note:</strong> This method should not be called outside of
* {@link pascal.taie.analysis.pta.plugin.Plugin#onPhaseFinish()},
* otherwise it may break the monotonicity of pointer analysis.
* </p>
*/
void removeIf(Predicate<CSObj> filter);

/**
* @return true if this set contains given object, otherwise false.
*/
Expand Down
1 change: 1 addition & 0 deletions src/main/resources/tai-e-analyses.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
reflection-log: null # path to reflection log, required when reflection option is log
taint-config: null # path to config file of taint analysis,
# when this file is given, taint analysis will be enabled
taint-interactive-mode: false # whether enable interactive mode for taint analysis
plugins: [ ] # | [ pluginClass, ... ]
time-limit: -1 # set time limit (in seconds) for pointer analysis,
# -1 means no time limit
Expand Down
22 changes: 22 additions & 0 deletions src/test/java/pascal/taie/analysis/pta/TaintTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
import pascal.taie.analysis.Tests;
import pascal.taie.util.MultiStringsSource;

import java.io.ByteArrayInputStream;
import java.io.InputStream;

public class TaintTest {

static final String DIR = "taint";
Expand Down Expand Up @@ -67,7 +70,26 @@ public class TaintTest {
@MultiStringsSource({"CallSiteMode",
TAINT_CONFIG_PREFIX + "taint-config-call-site-model.yml"})
void test(String mainClass, String... opts) {
testInNonInteractiveMode(mainClass, opts);
testInInteractiveMode(mainClass, opts);
}

private void testInNonInteractiveMode(String mainClass, String... opts) {
Tests.testPTA(DIR, mainClass, opts);
}

private void testInInteractiveMode(String mainClass, String... opts) {
InputStream originalSystemIn = System.in;
try {
String simulatedInput = "r\ne\n";
System.setIn(new ByteArrayInputStream(simulatedInput.getBytes()));
String[] newOpts = new String[opts.length + 1];
System.arraycopy(opts, 0, newOpts, 0, opts.length);
newOpts[opts.length] = "taint-interactive-mode:true";
Tests.testPTA(DIR, mainClass, newOpts);
} finally {
System.setIn(originalSystemIn);
}
}

}

0 comments on commit c8e4361

Please sign in to comment.