-
Notifications
You must be signed in to change notification settings - Fork 0
/
TestScript.java
316 lines (252 loc) · 13.1 KB
/
TestScript.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
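//Emulates the AArch64 function at the cursor with register x0 marked as tainted and
//reports a DIT violation when an instruction that is not classified as DIT reads tainted data.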
//@author
//@category _NEW_
//@keybinding
//@menupath
//@toolbar
import ghidra.app.emulator.*;
import ghidra.app.plugin.processors.sleigh.SleighLanguage;
import ghidra.app.script.GhidraScript;
import ghidra.app.util.PseudoDisassembler;
import ghidra.pcode.emulate.BreakTable;
import ghidra.pcode.emulate.BreakTableCallBack;
import ghidra.pcode.emulate.Emulate;
import ghidra.pcode.emulate.InstructionDecodeException;
import ghidra.pcode.error.LowlevelError;
import ghidra.pcode.memstate.MemoryState;
import ghidra.pcode.pcoderaw.PcodeOpRaw;
import ghidra.program.model.address.*;
import ghidra.program.model.listing.*;
import ghidra.program.model.pcode.*;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Locale;
import java.util.TreeSet;
public class TestScript extends GhidraScript {
// Number of bytes fetched at the current execute address and handed to the
// disassembler when DitEmulator classifies the instruction about to run.
private static final int MAX_INSTR_LENGTH = 8;
// Upper-case mnemonics that DitEmulator classifies as InstrType.DIT: an
// instruction whose upper-cased mnemonic is in this set may read tainted
// values without a violation being raised.
// NOTE(review): presumably the AArch64 general-purpose instructions covered by
// the DIT (data-independent timing) guarantee - confirm against the Arm ARM.
private static final HashSet<String> DIT_GP_MNEMONICS = new HashSet<>();
static {
DIT_GP_MNEMONICS.addAll(Arrays.asList("ADC", "ADCS", "ADD", "ADDS", "AND", "ANDS", "ASR", "ASRV", "BFC", "BFI", "BFM", "BFXIL", "BIC", "BICS", "CCMN", "CCMP", "CFINV", "CINC", "CINV", "CLS", "CLZ", "CMN", "CMP", "CNEG", "CSEL", "CSET", "CSETM", "CSINC", "CSINV", "CSNEG", "EON", "EOR", "EXTR", "LSL", "LSLV", "LSR", "LSRV", "MADD", "MNEG", "MOV", "MOVK", "MOVN", "MOVZ", "MSUB", "MUL", "MVN", "NEG", "NEGS", "NGC", "NGCS", "NOP", "ORN", "ORR", "RBIT", "RET", "REV", "REV16", "REV32", "REV64", "RMIF", "ROR", "RORV", "SBC", "SBCS", "SBFIZ", "SBFM", "SBFX", "SETF8", "SETF16", "SMADDL", "SMNEGL", "SMSUBL", "SMULH", "SMULL", "SUB", "SUBS", "SXTB", "SXTH", "SXTW", "TST", "UBFIZ", "UBFM", "UBFX", "UMADDL", "UMNEGL", "UMSUBL", "UMULH", "UMULL", "UXTB", "UXTH"));
}
// Upper-case mnemonics that DitEmulator classifies as InstrType.LDST. For these
// instructions violations are only raised while the load/store target address
// is being read; the value that is loaded or stored may be tainted.
private static final HashSet<String> LDST_MNEMONICS = new HashSet<>();
static {
LDST_MNEMONICS.addAll(Arrays.asList("LD64B", "LDADD", "LDADDA", "LDADDAL", "LDADDL", "LDADDB", "LDADDAB", "LDADDALB", "LDADDLB", "LDADDH", "LDADDAH", "LDADDALH", "LDADDLH", "LDAPR", "LDAPRB", "LDAPRH", "LDAPUR", "LDAPURB", "LDAPURH", "LDAPURSB", "LDAPURSH", "LDAPURSW", "LDAR", "LDARB", "LDARH", "LDAXP", "LDAXR", "LDAXRB", "LDAXRH", "LDCLR", "LDCLRA", "LDCLRAL", "LDCLRL", "LDCLRB", "LDCLRAB", "LDCLRALB", "LDCLRLB", "LDCLRH", "LDCLRAH", "LDCLRALH", "LDCLRLH", "LDEOR", "LDEORA", "LDEORAL", "LDEORL", "LDEORB", "LDEORAB", "LDEORALB", "LDEORLB", "LDEORH", "LDEORAH", "LDEORALH", "LDEORLH", "LDLAR", "LDLARB", "LDLARH", "LDNP", "LDP", "LDPSW", "LDR", "LDRAA", "LDRAB", "LDRB", "LDRH", "LDRSB", "LDRSH", "LDRSW", "LDSET", "LDSETA", "LDSETAL", "LDSETL", "LDSETB", "LDSETAB", "LDSETALB", "LDSETLB", "LDSETH", "LDSETAH", "LDSETALH", "LDSETLH", "LDSMAX", "LDSMAXA", "LDSMAXAL", "LDSMAXL", "LDSMAXB", "LDSMAXAB", "LDSMAXALB", "LDSMAXLB", "LDSMAXH", "LDSMAXAH", "LDSMAXALH", "LDSMAXLH", "LDSMIN", "LDSMINA", "LDSMINAL", "LDSMINL", "LDSMINB", "LDSMINAB", "LDSMINALB", "LDSMINLB", "LDSMINH", "LDSMINAH", "LDSMINALH", "LDSMINLH", "LDTR", "LDTRB", "LDTRH", "LDTRSB", "LDTRSH", "LDTRSW", "LDUMAX", "LDUMAXA", "LDUMAXAL", "LDUMAXL", "LDUMAXB", "LDUMAXAB", "LDUMAXALB", "LDUMAXLB", "LDUMAXH", "LDUMAXAH", "LDUMAXALH", "LDUMAXLH", "LDUMIN", "LDUMINA", "LDUMINAL", "LDUMINL", "LDUMINB", "LDUMINAB", "LDUMINALB", "LDUMINLB", "LDUMINH", "LDUMINAH", "LDUMINALH", "LDUMINLH", "LDUR", "LDURB", "LDURH", "LDURSB", "LDURSH", "LDURSW", "LDXP", "LDXR", "LDXRB", "LDXRH", "ST64B", "ST64BV", "ST64BV0", "STADD", "STADDL", "STADDB", "STADDLB", "STADDH", "STADDLH", "STCLR", "STCLRL", "STCLRB", "STCLRLB", "STCLRH", "STCLRLH", "STEOR", "STEORL", "STEORB", "STEORLB", "STEORH", "STEORLH", "STLLR", "STLLRB", "STLLRH", "STLR", "STLRB", "STLRH", "STLUR", "STLURB", "STLURH", "STLXP", "STLXR", "STLXRB", "STLXRH", "STNP", "STP", "STR", "STRB", "STRH", "STSET", "STSETL", "STSETB", "STSETLB", "STSETH", "STSETLH", "STSMAX", "STSMAXL", 
"STSMAXB", "STSMAXLB", "STSMAXH", "STSMAXLH", "STSMIN", "STSMINL", "STSMINB", "STSMINLB", "STSMINH", "STSMINLH", "STTR", "STTRB", "STTRH", "STUMAX", "STUMAXL", "STUMAXB", "STUMAXLB", "STUMAXH", "STUMAXLH", "STUMIN", "STUMINL", "STUMINB", "STUMINLB", "STUMINH", "STUMINLH", "STUR", "STURB", "STURH", "STXP", "STXR", "STXRB", "STXRH"));
}
// How an instruction is allowed to interact with tainted (secret-dependent) data.
private enum InstrType {
DIT, //this instruction is free to read any tainted value
LDST, //the address being accessed must not be tainted, but the value being loaded/stored may be tainted
NONDIT, //this instruction may not read any tainted value
}
/**
 * Memory access filter that records which bytes hold tainted values and flags
 * tainted reads performed by instructions that are not allowed to make them.
 */
class MemTaintTracker extends MemoryAccessFilter {
    /**
     * A single accessed byte, identified by its address space and offset.
     * Tracking ranges rather than individual bytes would be one optimization.
     */
    public class AccessLocation implements Comparable<AccessLocation> {
        AddressSpace space;
        long offset;

        public AccessLocation(AddressSpace spc, long off) {
            space = spc;
            offset = off;
        }

        /** Orders locations by address space first and by offset second. */
        @Override
        public int compareTo(TestScript.MemTaintTracker.AccessLocation o) {
            int bySpace = space.compareTo(o.space);
            return bySpace != 0 ? bySpace : Long.compare(offset, o.offset);
        }

        /** Renders the location as the space followed by the hex offset in parentheses. */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder();
            text.append(space.toString());
            text.append("(").append(Long.toString(offset, 16)).append(")");
            return text.toString();
        }
    }

    // bytes that may contain secret-dependent values
    TreeSet<AccessLocation> taintSet = new TreeSet<>();
    // whether the instruction currently executing has read tainted data yet
    boolean tainted = false;
    // classification of the instruction currently executing
    InstrType currInstrType = InstrType.NONDIT;
    // set once a tainted read has been reported as a violation
    boolean ditViolationDetected = false;
    // whether a tainted read should currently be reported as a violation
    boolean raiseDitViolations = true;

    /**
     * Resets the per-instruction state before an instruction is executed.
     *
     * @param instrType classification of the instruction about to run
     */
    public void beginInstr(InstrType instrType) {
        // DIT instructions never raise violations; LD/ST instructions only do so
        // while their target address is read, which the emulator enables itself.
        raiseDitViolations = (instrType == InstrType.NONDIT);
        tainted = false;
        currInstrType = instrType;
    }

    /** Marks every byte of the given range as tainted. */
    public void taint(AddressSpace spc, long off, int size) {
        int index = 0;
        while (index < size) {
            taintSet.add(new AccessLocation(spc, off + index));
            index++;
        }
    }

    /** Marks every byte of the given range as untainted. */
    public void untaint(AddressSpace spc, long off, int size) {
        int index = 0;
        while (index < size) {
            taintSet.remove(new AccessLocation(spc, off + index));
            index++;
        }
    }

    /**
     * @return true if at least one byte of the given range is tainted
     */
    public boolean checkTaint(AddressSpace spc, long off, int size) {
        int index = 0;
        while (index < size) {
            AccessLocation probe = new AccessLocation(spc, off + index);
            if (taintSet.contains(probe)) {
                return true;
            }
            index++;
        }
        return false;
    }

    /** Lets the current instruction be temporarily treated as DIT or non-DIT. */
    public void setRaiseDitViolations(boolean rdv) {
        raiseDitViolations = rdv;
    }

    /**
     * Marks the current instruction as tainted when it reads a tainted byte and
     * reports a violation if such reads are currently forbidden.
     */
    @Override
    protected void processRead(AddressSpace spc, long off, int size, byte[] values) {
        if (!checkTaint(spc, off, size)) {
            return;
        }
        tainted = true;
        boolean violation = currInstrType != InstrType.DIT && raiseDitViolations;
        if (violation) {
            println("DIT VIOLATION DETECTED");
            ditViolationDetected = true;
        }
    }

    /**
     * Propagates taint to the written range when the current instruction has
     * read tainted data, and clears taint from the range otherwise.
     */
    @Override
    protected void processWrite(AddressSpace spc, long off, int size, byte[] values) {
        if (!tainted) {
            untaint(spc, off, size);
            return;
        }
        taint(spc, off, size);
    }
}
/**
 * Emulator that classifies each instruction by mnemonic before executing it and
 * tells the taint tracker which kind of instruction is running, so that tainted
 * reads can be reported according to that classification.
 */
class DitEmulator extends Emulate {
    AddressFactory addrFactory;
    MemoryState memstate;
    MemTaintTracker taintTrack;

    /**
     * @param lang language used for emulation; also supplies the address factory
     * @param s    memory state the emulator reads from and writes to
     * @param b    break table passed through to the base emulator
     * @param tt   taint tracker notified at the start of every instruction
     */
    public DitEmulator(SleighLanguage lang, MemoryState s, BreakTable b, MemTaintTracker tt) {
        super(lang, s, b);
        addrFactory = lang.getAddressFactory();
        memstate = s;
        taintTrack = tt;
    }

    /**
     * Disassembles the instruction at the current execute address, classifies it
     * as DIT, LDST or NONDIT, informs the taint tracker and then executes it.
     *
     * @throws InstructionDecodeException if the instruction cannot be disassembled
     */
    @Override
    public void executeInstruction(boolean stopAtBreakpoint, TaskMonitor monitor1)
            throws CancelledException, LowlevelError, InstructionDecodeException {
        Address pcAddr = this.getExecuteAddress();
        byte[] instrBytes = new byte[MAX_INSTR_LENGTH];
        memstate.getChunk(instrBytes, pcAddr.getAddressSpace(), pcAddr.getOffset(), MAX_INSTR_LENGTH, false);
        Instruction instr;
        try {
            instr = new PseudoDisassembler(currentProgram).disassemble(pcAddr, instrBytes);
        } catch (Exception e) {
            // InstructionDecodeException takes no cause, so keep the reason in the message.
            throw new InstructionDecodeException(
                "Failed to disassemble instruction to determine DIT status: " + e.getMessage(), pcAddr);
        }
        if (instr == null) {
            throw new InstructionDecodeException(
                "Failed to disassemble instruction to determine DIT status", pcAddr);
        }
        String mnemonic = instr.getMnemonicString();
        println(mnemonic);
        taintTrack.beginInstr(classify(mnemonic));
        super.executeInstruction(stopAtBreakpoint, monitor1);
    }

    /** Maps a mnemonic onto its instruction type; anything unlisted is NONDIT. */
    private InstrType classify(String mnemonic) {
        // Locale.ROOT keeps the lookup independent of the default locale
        // (e.g. "i" must upper-case to "I" for the table lookup to succeed).
        String key = mnemonic.toUpperCase(Locale.ROOT);
        if (DIT_GP_MNEMONICS.contains(key)) {
            return InstrType.DIT;
        }
        if (LDST_MNEMONICS.contains(key)) {
            return InstrType.LDST;
        }
        return InstrType.NONDIT;
    }

    /** Resolves the address space named by the first input of a LOAD/STORE op. */
    private AddressSpace targetSpace(PcodeOpRaw op) {
        return addrFactory.getAddressSpace((int) op.getInput(0).getAddress().getOffset());
    }

    /**
     * Reads the target offset of a LOAD/STORE op and converts it to a byte offset.
     * For LDST instructions violation reporting is enabled only while the offset
     * is read, because the target address must not be secret-dependent.
     */
    private long readTargetByteOffset(PcodeOpRaw op, AddressSpace space) {
        boolean isLdSt = taintTrack.currInstrType == InstrType.LDST;
        if (isLdSt) {
            taintTrack.setRaiseDitViolations(true);
        }
        try {
            long offset = memstate.getValue(op.getInput(1));
            return space.truncateAddressableWordOffset(offset) * space.getAddressableUnitSize();
        } finally {
            if (isLdSt) {
                taintTrack.setRaiseDitViolations(false);
            }
        }
    }

    /**
     * Executes a LOAD op like the base emulator does, except that the target
     * address is explicitly checked for DIT violations.
     */
    @Override
    public void executeLoad(PcodeOpRaw op) {
        AddressSpace space = targetSpace(op); // Space to read from
        long byteOffset = readTargetByteOffset(op, space);
        Varnode outvar = op.getOutput();
        int size = outvar.getSize();
        if (size > 8) {
            BigInteger res = memstate.getBigInteger(space, byteOffset, size, false);
            memstate.setValue(outvar, res);
        } else {
            long res = memstate.getValue(space, byteOffset, size);
            memstate.setValue(outvar, res);
        }
    }

    /**
     * Executes a STORE op like the base emulator does, except that the target
     * address is explicitly checked for DIT violations.
     */
    @Override
    public void executeStore(PcodeOpRaw op) {
        AddressSpace space = targetSpace(op); // Space to store in
        long byteOffset = readTargetByteOffset(op, space);
        // Reporting is only re-suppressed for LDST instructions; a NONDIT instruction
        // keeps reporting tainted reads, including the read of the stored value below.
        Varnode storedVar = op.getInput(2); // Value being stored
        int size = storedVar.getSize();
        if (size > 8) {
            BigInteger val = memstate.getBigInteger(storedVar, false);
            memstate.setValue(space, byteOffset, size, val);
        } else {
            long val = memstate.getValue(storedVar);
            memstate.setValue(space, byteOffset, size, val);
        }
    }
}
/**
 * Script entry point. Emulates the code starting at the cursor with register x0
 * marked as tainted, printing the taint set after every instruction, and stops
 * when a DIT violation is detected or the execute address becomes 0.
 *
 * @throws Exception if emulation of an instruction fails
 */
public void run() throws Exception {
    if (currentProgram == null) {
        printerr("Please open a program you would like to verify.");
        return;
    }
    if (!"AARCH64:LE:64:v8A".equals(currentProgram.getLanguageID().toString())) {
        printerr("Sorry, this PoC script is currently only designed to run for AArch64.");
        return;
    }
    if (currentAddress == null) {
        printerr("Please place your cursor at the beginning of a function you would like to verify.");
        return;
    }
    EmulatorHelper emuHelper = new EmulatorHelper(currentProgram);
    // Dispose of the helper even when emulation throws, so it is never leaked.
    try {
        SleighLanguage language = (SleighLanguage) currentProgram.getLanguage();
        MemTaintTracker taintTrack = new MemTaintTracker();
        Emulator emu = emuHelper.getEmulator();

        // Initial machine state: PC at the cursor, a fixed stack pointer and
        // sample argument values in x0 and x1.
        emuHelper.writeRegister(emuHelper.getPCRegister(), currentAddress.getOffset());
        emuHelper.writeRegister(emuHelper.getStackPointerRegister(), 0x000000002FFF0000);
        emuHelper.writeRegister("x0", 42);
        emuHelper.writeRegister("x1", 50);

        // Mark the 8 bytes of x0 as tainted before the tracker is attached.
        Address taintedRegAddr = emuHelper.getLanguage().getRegister("x0").getAddress();
        taintTrack.taint(taintedRegAddr.getAddressSpace(), taintedRegAddr.getOffset(), 8);
        taintTrack.setFilterOnExecutionOnly(false);
        emu.addMemoryAccessFilter(taintTrack);

        MemoryState memState = emu.getFilteredMemState();
        println(Long.toString(memState.getValue(emuHelper.getPCRegister())));
        DitEmulator ditEmu = new DitEmulator(language, memState, new BreakTableCallBack(language), taintTrack);
        ditEmu.setExecuteAddress(currentAddress);
        println(taintTrack.taintSet.toString());

        while (true) {
            // NOTE(review): an execute address of 0 is treated as "function returned";
            // presumably this is the never-initialised return address - confirm.
            if (ditEmu.getExecuteAddress().getOffset() == 0) {
                println("Successfully finished emulating function with no DIT violations detected.");
                break;
            }
            ditEmu.executeInstruction(false, monitor);
            println(taintTrack.taintSet.toString());
            if (taintTrack.ditViolationDetected) {
                break;
            }
        }
        printf("value of r0 after function call: %x\n", emuHelper.readRegister("x0"));
        println(taintTrack.taintSet.toString());
    } finally {
        emuHelper.dispose();
    }
}
}