/*
* Copyright (C) 2014 Zhang,Yuexiang (xfeep)
* All rights reserved.
*/
package nginx.clojure.wave;
import java.math.BigInteger;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.List;
import nginx.clojure.SuspendableConstructorUtilStack;
import nginx.clojure.asm.Label;
import nginx.clojure.asm.MethodVisitor;
import nginx.clojure.asm.Opcodes;
import nginx.clojure.asm.Type;
import nginx.clojure.asm.tree.AbstractInsnNode;
import nginx.clojure.asm.tree.InsnList;
import nginx.clojure.asm.tree.LabelNode;
import nginx.clojure.asm.tree.LocalVariableNode;
import nginx.clojure.asm.tree.MethodInsnNode;
import nginx.clojure.asm.tree.MethodNode;
import nginx.clojure.asm.tree.TryCatchBlockNode;
import nginx.clojure.asm.tree.analysis.Analyzer;
import nginx.clojure.asm.tree.analysis.AnalyzerException;
import nginx.clojure.asm.tree.analysis.BasicValue;
import nginx.clojure.asm.tree.analysis.Frame;
import nginx.clojure.asm.tree.analysis.Value;
import nginx.clojure.wave.MethodDatabase.ClassEntry;
public class InstrumentConstructorMethod {
private static final String CSTACK_NAME = Type.getInternalName(SuspendableConstructorUtilStack.class);
private static final String CSTACK_PUSH_DOUBLE_VALUE_DESC = "(DL"+CSTACK_NAME+";I)V";
private static final String CSTACK_PUSH_LONG_VALUE_DESC = "(JL"+CSTACK_NAME+";I)V";
private static final String CSTACK_PUSH_FLOAT_VALUE_DESC = "(FL"+CSTACK_NAME+";I)V";
private static final String CSTACK_PUSH_INT_VALUE_DESC = "(IL"+CSTACK_NAME+";I)V";
private static final String CSTACK_PUSH_OBJECT_VALUE_DESC = "(Ljava/lang/Object;L"+CSTACK_NAME+";I)V";
private final MethodDatabase db;
private final String className;
private final MethodNode mn;
private final Frame[] frames;
private final int lvarCStack;
private final int firstLocal;
public InstrumentConstructorMethod(MethodDatabase db, String className, MethodNode mn) throws AnalyzerException {
this.db = db;
this.className = className;
this.mn = mn;
try {
Analyzer a = MethodDatabaseUtil.buildAnalyzer(db);
this.frames = a.analyze(className, mn);
this.lvarCStack = mn.maxLocals+2;
this.firstLocal = ((mn.access & Opcodes.ACC_STATIC) == Opcodes.ACC_STATIC) ? 0 : 1;
} catch (UnsupportedOperationException ex) {
throw new AnalyzerException(null, ex.getMessage(), ex);
}
}
public static String getMD5(String desc) {
try {
byte[] dg = MessageDigest.getInstance("MD5").digest(desc.getBytes());
return new BigInteger(1,dg).toString(16);
} catch (NoSuchAlgorithmException ex) {
throw new RuntimeException(ex.getMessage(), ex);
}
}
public void trySplitIntoTwoNewMethods(InstrumentClass cv) throws AnalyzerException {
int numIns = mn.instructions.size();
int splitPos = -1;
MethodInsnNode invokedInitInsn = null;
for(int i = 0 ; i < numIns ; i++) {
AbstractInsnNode insn = mn.instructions.get(i);
if (db.meetTraceTargetClassMethod(className, mn.name)) {
db.debug(InstrumentClass.insnToString(insn));
}
// System.out.println(insnToString(insn));
if (insn instanceof MethodInsnNode) {
MethodInsnNode minsn = (MethodInsnNode) insn;
Frame mif = frames[i];
if (minsn.name.equals("<init>") && mif.getLocal(0) == mif.getStack(mif.getStackSize()-1-TypeAnalyzer.getNumArguments(minsn.desc)) ) {
splitPos = i+1;
invokedInitInsn = minsn;
break;
}
}
}
if (splitPos == -1) {
splitPos = 0;
}
Frame f = frames[splitPos];
FrameInfo fi = new FrameInfo(f, firstLocal, splitPos, mn.instructions, db);
emitShrinkedInitMethod(cv, splitPos, f, fi, invokedInitInsn);
emitInitHelpMethod(cv, numIns, splitPos, f, fi, invokedInitInsn);
}
public void emitInitHelpMethod(InstrumentClass cv, int numIns, int splitPos, Frame f, FrameInfo fi, MethodInsnNode invokedInitInsn) throws AnalyzerException {
mn.instructions.resetLabels();
List<String> exps = new ArrayList<String>(mn.exceptions);
if (!exps.contains(CheckInstrumentationVisitor.EXCEPTION_NAME)) {
exps.add(CheckInstrumentationVisitor.EXCEPTION_NAME);
}
String[] expss = MethodDatabase.toStringArray(exps);
MethodNode mv = new MethodNode(Opcodes.ACC_PUBLIC, buildInitHelpMethodName(mn.desc), "()V", null, expss);
Label invokedInitInsnStart = null;
Label invokedInitInsnEnd = null;
Label invokedInitInsnCatchAll = null;
boolean needWaveInvokedInitInsn = invokedInitInsn != null
&& db.checkMethodSuspendType(invokedInitInsn.owner, ClassEntry.key(invokedInitInsn.name, invokedInitInsn.desc), false) == MethodDatabase.SUSPEND_NORMAL;
if (needWaveInvokedInitInsn) {
invokedInitInsnStart = new Label();
invokedInitInsnEnd = new Label();
invokedInitInsnCatchAll = new Label();
mv.visitTryCatchBlock(invokedInitInsnStart, invokedInitInsnEnd, invokedInitInsnCatchAll, null);
}
for(TryCatchBlockNode tcb : mn.tryCatchBlocks) {
tcb.accept(mv);
}
mv.visitMethodInsn(Opcodes.INVOKESTATIC, CSTACK_NAME, "getStack", "()L"+CSTACK_NAME+";");
mv.visitVarInsn(Opcodes.ASTORE, lvarCStack);
if (needWaveInvokedInitInsn) {
mv.visitLabel(invokedInitInsnStart);
mv.visitVarInsn(Opcodes.ALOAD, 0);
mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, invokedInitInsn.owner, "inch_" + getMD5(invokedInitInsn.desc), "()V");
mv.visitLabel(invokedInitInsnEnd);
}
for(int i=firstLocal ; i<f.getLocals() ; i++) {
BasicValue v = (BasicValue) f.getLocal(i);
if(!isNullType(v)) {
int slotIdx = fi.localSlotIndices[i];
assert slotIdx >= 0 && slotIdx < fi.numSlots;
emitRestoreValue(mv, v, lvarCStack, slotIdx);
mv.visitVarInsn(v.getType().getOpcode(Opcodes.ISTORE), i);
} else if(v != BasicValue.UNINITIALIZED_VALUE) {
mv.visitInsn(Opcodes.ACONST_NULL);
mv.visitVarInsn(Opcodes.ASTORE, i);
}
}
for(int i=0 ; i<f.getStackSize() ; i++) {
BasicValue v = (BasicValue) f.getStack(i);
if(!isOmitted(v)) {
if(!isNullType(v)) {
int slotIdx = fi.stackSlotIndices[i];
assert slotIdx >= 0 && slotIdx < fi.numSlots;
emitRestoreValue(mv, v, lvarCStack, slotIdx);
} else {
mv.visitInsn(Opcodes.ACONST_NULL);
}
}
}
mv.visitVarInsn(Opcodes.ALOAD,lvarCStack);
emitConst(mv, fi.numSlots);
mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, CSTACK_NAME, "release", "(I)V");
int maxStack = mn.maxStack;
for (int i = splitPos; i < numIns; i++) {
AbstractInsnNode insn = mn.instructions.get(i);
if (insn instanceof MethodInsnNode) {
MethodInsnNode misn = (MethodInsnNode) insn;
String name = misn.name;
if (name.charAt(0) == '<' && name.charAt(1) == 'i' && db != null && db.checkMethodSuspendType(misn.owner, ClassEntry.key(name, misn.desc), false, false) == MethodDatabase.SUSPEND_NORMAL) {
mv.visitInsn(Opcodes.ACONST_NULL);
mv.visitMethodInsn(misn.getOpcode(), misn.owner, name, InstrumentConstructorMethod.buildShrinkedInitMethodDesc(misn.desc));
mv.visitInsn(Opcodes.DUP);
mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, misn.owner, InstrumentConstructorMethod.buildInitHelpMethodName(misn.desc), "()V");
maxStack = mn.maxStack + 1;
continue;
}
}
insn.accept(mv);
}
if (needWaveInvokedInitInsn) {
mv.visitLabel(invokedInitInsnCatchAll);
mv.visitVarInsn(Opcodes.ALOAD,lvarCStack);
emitConst(mv, fi.numSlots);
mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, CSTACK_NAME, "release", "(I)V");
mv.visitInsn(Opcodes.ATHROW);
}
if(mn.localVariables != null && !mn.localVariables.isEmpty()) {
for (int i = 0; i < f.getLocals(); i++) {
mn.localVariables.get(i).accept(mv);
}
}
mv.visitMaxs(maxStack+3, mn.maxLocals+3);
mv.visitEnd();
InstrumentMethod im = new InstrumentMethod(db, className, mv);
if(im.collectCodeBlocks()) {
im.accept(cv.makeOutMV(mv));
}else {
throw new RuntimeException("no suspendable method in constructor : " + mn.name + mn.desc);
}
}
public void emitShrinkedInitMethod(InstrumentClass cv, int splitPos, Frame f, FrameInfo fi, MethodInsnNode invokedInitInsn)
throws InternalError {
String desc = buildShrinkedInitMethodDesc(mn.desc);
String[] exps = MethodDatabase.toStringArray(mn.exceptions);
MethodVisitor cmv = cv.makeOutMV(mn.access, "<init>", desc, null, exps);
cmv.visitCode();
for (int i = 0; i < splitPos -1; i++) {
mn.instructions.get(i).accept(cmv);
}
if (invokedInitInsn != null) {
if (db.checkMethodSuspendType(invokedInitInsn.owner, ClassEntry.key(invokedInitInsn.name, invokedInitInsn.desc), false, false) == MethodDatabase.SUSPEND_NORMAL) {
Type[] tps = Type.getArgumentTypes(invokedInitInsn.desc);
Type[] ntps = new Type[tps.length + 1];
System.arraycopy(tps, 0, ntps, 0, tps.length);
ntps[tps.length] = Type.getType(CheckInstrumentationVisitor.EXCEPTION_DESC);
cmv.visitInsn(Opcodes.ACONST_NULL);
cmv.visitMethodInsn(invokedInitInsn.getOpcode(), invokedInitInsn.owner,invokedInitInsn.name, Type.getMethodDescriptor(Type.VOID_TYPE, ntps));
}else {
invokedInitInsn.accept(cmv);
}
}
cmv.visitMethodInsn(Opcodes.INVOKESTATIC, CSTACK_NAME, "getStack", "()L"+CSTACK_NAME+";");
cmv.visitVarInsn(Opcodes.ASTORE, lvarCStack);
cmv.visitVarInsn(Opcodes.ALOAD,lvarCStack);
emitConst(cmv, fi.numSlots);
cmv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, CSTACK_NAME, "incRefsAndReserveSpace", "(I)V");
for (int i = f.getStackSize(); i-->0 ;) {
BasicValue v = (BasicValue) f.getStack(i);
if(!isOmitted(v)) {
if(!isNullType(v)) {
int slotIdx = fi.stackSlotIndices[i];
assert slotIdx >= 0 && slotIdx < fi.numSlots;
emitStoreValue(cmv, v, lvarCStack, slotIdx);
} else {
db.trace("NULL stack entry: type=%s size=%d", v.getType(), v.getSize());
cmv.visitInsn(Opcodes.POP);
}
}
}
for(int i=firstLocal; i < f.getLocals() ; i++) {
BasicValue v = (BasicValue) f.getLocal(i);
if(!isNullType(v)) {
cmv.visitVarInsn(v.getType().getOpcode(Opcodes.ILOAD), i);
int slotIdx = fi.localSlotIndices[i];
assert slotIdx >= 0 && slotIdx < fi.numSlots;
emitStoreValue(cmv, v, lvarCStack, slotIdx);
}
}
if(mn.localVariables != null) {
for(LocalVariableNode var : mn.localVariables) {
if (invokedInitInsn != null) {
if (mn.instructions.indexOf(var.start) <= splitPos) {
var.accept(cmv);
}
}else if (var.start.getPrevious() == null) {
var.accept(cmv);
}
}
}
cmv.visitInsn(Opcodes.RETURN);
cmv.visitMaxs(mn.maxStack+3, mn.maxLocals+3);
cmv.visitEnd();
}
public static String buildShrinkedInitMethodDesc(String desc) {
Type[] argsTypes = Type.getArgumentTypes(desc);
Type[] fargsTypes = new Type[argsTypes.length + 1];
System.arraycopy(argsTypes, 0, fargsTypes, 0, argsTypes.length);
fargsTypes[argsTypes.length] = Type.getType(CheckInstrumentationVisitor.EXCEPTION_DESC);
desc = Type.getMethodDescriptor(Type.VOID_TYPE, fargsTypes);
return desc;
}
public static String buildInitHelpMethodName(String desc) {
return "inch_" + getMD5(desc);
}
public void accept(InstrumentClass cv) throws AnalyzerException {
trySplitIntoTwoNewMethods(cv);
}
private static void emitConst(MethodVisitor mv, int value) {
if(value >= -1 && value <= 5) {
mv.visitInsn(Opcodes.ICONST_0 + value);
} else if((byte)value == value) {
mv.visitIntInsn(Opcodes.BIPUSH, value);
} else if((short)value == value) {
mv.visitIntInsn(Opcodes.SIPUSH, value);
} else {
mv.visitLdcInsn(value);
}
}
private void emitStoreValue(MethodVisitor mv, BasicValue v, int lvarStack, int idx) throws InternalError, IndexOutOfBoundsException {
String desc;
switch(v.getType().getSort()) {
case Type.OBJECT:
case Type.ARRAY:
desc = CSTACK_PUSH_OBJECT_VALUE_DESC;
break;
case Type.BOOLEAN:
case Type.BYTE:
case Type.SHORT:
case Type.CHAR:
case Type.INT:
desc = CSTACK_PUSH_INT_VALUE_DESC;
break;
case Type.FLOAT:
desc = CSTACK_PUSH_FLOAT_VALUE_DESC;
break;
case Type.LONG:
desc = CSTACK_PUSH_LONG_VALUE_DESC;
break;
case Type.DOUBLE:
desc = CSTACK_PUSH_DOUBLE_VALUE_DESC;
break;
default:
throw new InternalError("Unexpected type: " + v.getType());
}
mv.visitVarInsn(Opcodes.ALOAD, lvarStack);
emitConst(mv, idx);
mv.visitMethodInsn(Opcodes.INVOKESTATIC, CSTACK_NAME, "push", desc);
}
private void emitRestoreValue(MethodVisitor mv, BasicValue v, int lvarStack, int idx) {
mv.visitVarInsn(Opcodes.ALOAD, lvarStack);
emitConst(mv, idx);
switch(v.getType().getSort()) {
case Type.OBJECT:
String internalName = v.getType().getInternalName();
mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, CSTACK_NAME, "getObject", "(I)Ljava/lang/Object;");
if(!internalName.equals("java/lang/Object")) { // don't cast to Object ;)
mv.visitTypeInsn(Opcodes.CHECKCAST, internalName);
}
break;
case Type.ARRAY:
mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, CSTACK_NAME, "getObject", "(I)Ljava/lang/Object;");
mv.visitTypeInsn(Opcodes.CHECKCAST, v.getType().getDescriptor());
break;
case Type.BYTE:
mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, CSTACK_NAME, "getInt", "(I)I");
mv.visitInsn(Opcodes.I2B);
break;
case Type.SHORT:
mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, CSTACK_NAME, "getInt", "(I)I");
mv.visitInsn(Opcodes.I2S);
break;
case Type.CHAR:
mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, CSTACK_NAME, "getInt", "(I)I");
mv.visitInsn(Opcodes.I2C);
break;
case Type.BOOLEAN:
case Type.INT:
mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, CSTACK_NAME, "getInt", "(I)I");
break;
case Type.FLOAT:
mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, CSTACK_NAME, "getFloat", "(I)F");
break;
case Type.LONG:
mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, CSTACK_NAME, "getLong", "(I)J");
break;
case Type.DOUBLE:
mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, CSTACK_NAME, "getDouble", "(I)D");
break;
default:
throw new InternalError("Unexpected type: " + v.getType());
}
}
static boolean isNullType(BasicValue v) {
return (v == BasicValue.UNINITIALIZED_VALUE) ||
(v.isReference() && v.getType().getInternalName().equals("null"));
}
static boolean isOmitted(BasicValue v) {
if(v instanceof NewValue) {
return ((NewValue)v).omitted;
}
return false;
}
static boolean isNewValue(Value v, boolean dupped) {
if(v instanceof NewValue) {
return ((NewValue)v).isDupped == dupped;
}
return false;
}
static class BlockLabelNode extends LabelNode {
final int idx;
BlockLabelNode(int idx) {
this.idx = idx;
}
}
static class FrameInfo {
static final FrameInfo FIRST = new FrameInfo(null, 0, 0, null, null);
final int endInstruction;
final int numSlots;
final int numObjSlots;
final int[] localSlotIndices;
final int[] stackSlotIndices;
BlockLabelNode lBefore;
BlockLabelNode lAfter;
FrameInfo(Frame f, int firstLocal, int endInstruction, InsnList insnList, MethodDatabase db) {
this.endInstruction = endInstruction;
int idxObj = 0;
int idxPrim = 0;
if(f != null) {
stackSlotIndices = new int[f.getStackSize()];
for(int i=0 ; i<f.getStackSize() ; i++) {
BasicValue v = (BasicValue)f.getStack(i);
if(v instanceof NewValue) {
NewValue newValue = (NewValue)v;
if(db.isDebug()) {
db.trace("Omit value from stack idx %d at instruction %d with type %s generated by %s",
i, endInstruction, v, newValue.formatInsn());
}
if(!newValue.omitted) {
newValue.omitted = true;
if(db.isDebug()) {
// need to log index before replacing instruction
db.trace("Omitting instruction %d: %s", insnList.indexOf(newValue.insn), newValue.formatInsn());
}
insnList.set(newValue.insn, new OmittedInstruction(newValue.insn));
}
stackSlotIndices[i] = -666; // an invalid index ;)
} else if(!isNullType(v)) {
if(v.isReference()) {
stackSlotIndices[i] = idxObj++;
} else {
stackSlotIndices[i] = idxPrim++;
}
} else {
stackSlotIndices[i] = -666; // an invalid index ;)
}
}
localSlotIndices = new int[f.getLocals()];
for(int i=firstLocal ; i<f.getLocals() ; i++) {
BasicValue v = (BasicValue)f.getLocal(i);
if(!isNullType(v)) {
if(v.isReference()) {
localSlotIndices[i] = idxObj++;
} else {
localSlotIndices[i] = idxPrim++;
}
} else {
localSlotIndices[i] = -666; // an invalid index ;)
}
}
} else {
stackSlotIndices = null;
localSlotIndices = null;
}
numSlots = Math.max(idxPrim, idxObj);
numObjSlots = idxObj;
}
public LabelNode createBeforeLabel() {
if(lBefore == null) {
lBefore = new BlockLabelNode(endInstruction);
}
return lBefore;
}
public LabelNode createAfterLabel() {
if(lAfter == null) {
lAfter = new BlockLabelNode(endInstruction);
}
return lAfter;
}
}
}