Package nallar.tickthreading.patcher

Source Code of nallar.tickthreading.patcher.Patcher$PatchDescriptor

package nallar.tickthreading.patcher;

import com.google.common.base.Splitter;
import com.google.common.collect.Lists;
import com.google.common.io.Files;
import javassist.CannotCompileException;
import javassist.ClassLoaderPool;
import javassist.ClassPool;
import javassist.CtBehavior;
import javassist.CtClass;
import javassist.CtConstructor;
import javassist.CtMethod;
import javassist.NotFoundException;
import nallar.log.PatchLog;
import nallar.tickthreading.patcher.mappings.ClassDescription;
import nallar.tickthreading.patcher.mappings.FieldDescription;
import nallar.tickthreading.patcher.mappings.MCPMappings;
import nallar.tickthreading.patcher.mappings.Mappings;
import nallar.tickthreading.patcher.mappings.MethodDescription;
import nallar.tickthreading.util.CollectionsUtil;
import nallar.tickthreading.util.DomUtil;
import nallar.unsafe.UnsafeUtil;
import net.minecraft.launchwrapper.LaunchClassLoader;
import org.w3c.dom.Document;
import org.w3c.dom.Element;

import java.io.*;
import java.lang.annotation.*;
import java.lang.reflect.*;
import java.util.*;

public class Patcher {
  private final ClassPool classPool;
  private final ClassPool preSrgClassPool;
  private final Mappings mappings;
  private final Mappings preSrgMappings;
  private static final boolean DEBUG_PATCHED_OUTPUT = System.getProperty("patcher.debug", "false").equals("true");
  private Object patchClassInstance;
  private Object preSrgPatchClassInstance;
  private Map<String, PatchMethodDescriptor> patchMethods = new HashMap<String, PatchMethodDescriptor>();
  private Map<String, PatchGroup> classToPatchGroup = new HashMap<String, PatchGroup>();

  public Patcher(InputStream config, Class<?> patchesClass) {
    for (Method method : patchesClass.getDeclaredMethods()) {
      for (Annotation annotation : method.getDeclaredAnnotations()) {
        if (annotation instanceof Patch) {
          PatchMethodDescriptor patchMethodDescriptor = new PatchMethodDescriptor(method, (Patch) annotation);
          patchMethods.put(patchMethodDescriptor.name, patchMethodDescriptor);
        }
      }
    }
    classPool = new ClassLoaderPool(false);
    preSrgClassPool = new ClassLoaderPool(true);
    try {
      mappings = new MCPMappings(true);
      preSrgMappings = new MCPMappings(false);
    } catch (IOException e) {
      throw new RuntimeException("Failed to load mappings", e);
    }
    try {
      patchClassInstance = patchesClass.getDeclaredConstructors()[0].newInstance(classPool, mappings);
      preSrgPatchClassInstance = patchesClass.getDeclaredConstructors()[0].newInstance(preSrgClassPool, preSrgMappings);
    } catch (Exception e) {
      PatchLog.severe("Failed to instantiate patch class", e);
    }
    try {
      readPatchesFromXmlDocument(DomUtil.readDocumentFromInputStream(config));
    } catch (Throwable t) {
      throw UnsafeUtil.throwIgnoreChecked(t);
    }
  }

  private String env = null;

  private String getEnv() {
    String env = this.env;
    if (env != null) {
      return env;
    }
    try {
      if (LaunchClassLoader.instance.getClassBytes("za.co.mcportcentral.MCPCConfig") == null) {
        throw new IOException();
      }
      env = "mcpc";
    } catch (IOException e) {
      env = "forge";
    }
    return this.env = env;
  }

  private void readPatchesFromXmlDocument(Document document) {
    List<Element> patchGroupElements = DomUtil.elementList(document.getDocumentElement().getChildNodes());
    for (Element patchGroupElement : patchGroupElements) {
      new PatchGroup(patchGroupElement);
    }
  }

  public synchronized byte[] preSrgTransformation(String name, String transformedName, byte[] originalBytes) {
    PatchGroup patchGroup = getPatchGroup(name);
    if (patchGroup != null && patchGroup.preSrg) {
      return patchGroup.getClassBytes(name, originalBytes);
    }
    return originalBytes;
  }

  public synchronized byte[] postSrgTransformation(String name, String transformedName, byte[] originalBytes) {
    PatchGroup patchGroup = getPatchGroup(transformedName);
    if (patchGroup != null && !patchGroup.preSrg) {
      return patchGroup.getClassBytes(transformedName, originalBytes);
    }
    return originalBytes;
  }

  public boolean shouldPostSrgTransform(String transformedName) {
    PatchGroup patchGroup = getPatchGroup(transformedName);
    return patchGroup != null && !patchGroup.preSrg;
  }

  private PatchGroup getPatchGroup(String name) {
    return classToPatchGroup.get(name);
  }

  private static final Splitter idSplitter = Splitter.on("  ").trimResults().omitEmptyStrings();

  private class PatchGroup {
    public final String name;
    public final boolean preSrg;
    public final boolean onDemand;
    public final ClassPool classPool;
    public final Mappings mappings;
    private final Map<String, ClassPatchDescriptor> patches;
    private final Map<String, byte[]> patchedBytes = new HashMap<String, byte[]>();
    private final List<ClassPatchDescriptor> classPatchDescriptors = new ArrayList<ClassPatchDescriptor>();
    private boolean ranPatches = false;

    private PatchGroup(Element element) {
      Map<String, String> attributes = DomUtil.getAttributes(element);
      name = element.getTagName();
      preSrg = attributes.containsKey("preSrg");
      if (preSrg) {
        classPool = preSrgClassPool;
        mappings = preSrgMappings;
      } else {
        classPool = Patcher.this.classPool;
        mappings = Patcher.this.mappings;
      }
      obfuscateAttributesAndTextContent(element);
      onDemand = attributes.containsKey("onDemand");
      patches = onDemand ? new HashMap<String, ClassPatchDescriptor>() : null;

      for (Element classElement : DomUtil.elementList(element.getChildNodes())) {
        ClassPatchDescriptor classPatchDescriptor;
        try {
          classPatchDescriptor = new ClassPatchDescriptor(classElement);
        } catch (Throwable t) {
          throw new RuntimeException("Failed to create class patch for " + classElement.getAttribute("id"), t);
        }
        PatchGroup other = classToPatchGroup.get(classPatchDescriptor.name);
        if (other != null && other.preSrg != this.preSrg) {
          PatchLog.severe("Adding class " + classPatchDescriptor.name + " in patch group " + name + " to patch group with different preSrg setting " + other.name);
        }
        (other == null ? this : other).addClassPatchDescriptor(classPatchDescriptor);
      }
    }

    private void addClassPatchDescriptor(ClassPatchDescriptor classPatchDescriptor) {
      classToPatchGroup.put(classPatchDescriptor.name, this);
      classPatchDescriptors.add(classPatchDescriptor);
      if (onDemand && patches.put(classPatchDescriptor.name, classPatchDescriptor) != null) {
        throw new Error("Duplicate class patch for " + classPatchDescriptor.name + ", but onDemand is set.");
      }
    }

    private void obfuscateAttributesAndTextContent(Element root) {
      for (Element classElement : DomUtil.elementList(root.getChildNodes())) {
        String env = classElement.getAttribute("env");
        if (env != null && !env.isEmpty()) {
          if (!env.equals(getEnv())) {
            root.removeChild(classElement);
          }
        }
      }
      for (Element element : DomUtil.elementList(root.getChildNodes())) {
        if (!DomUtil.elementList(element.getChildNodes()).isEmpty()) {
          obfuscateAttributesAndTextContent(element);
        } else if (element.getTextContent() != null && !element.getTextContent().isEmpty()) {
          element.setTextContent(mappings.obfuscate(element.getTextContent()));
        }
        Map<String, String> attributes = DomUtil.getAttributes(element);
        for (Map.Entry<String, String> attributeEntry : attributes.entrySet()) {
          element.setAttribute(attributeEntry.getKey(), mappings.obfuscate(attributeEntry.getValue()));
        }
      }
      for (Element classElement : DomUtil.elementList(root.getChildNodes())) {
        String id = classElement.getAttribute("id");
        ArrayList<String> list = Lists.newArrayList(idSplitter.split(id));
        if (list.size() > 1) {
          for (String className : list) {
            Element newClassElement = (Element) classElement.cloneNode(true);
            newClassElement.setAttribute("id", className.trim());
            classElement.getParentNode().insertBefore(newClassElement, classElement);
          }
          classElement.getParentNode().removeChild(classElement);
        }
      }
    }

    private void saveByteCode(byte[] bytes, String name) {
      if (DEBUG_PATCHED_OUTPUT) {
        name = name.replace('.', '/') + ".class";
        File file = new File("./TTpatched/" + name);
        file.getParentFile().mkdirs();
        try {
          Files.write(bytes, file);
        } catch (IOException e) {
          PatchLog.severe("Failed to save patched bytes for " + name, e);
        }
      }
    }

    public byte[] getClassBytes(String name, byte[] originalBytes) {
      byte[] bytes = patchedBytes.get(name);
      if (bytes != null) {
        return bytes;
      }
      if (onDemand) {
        try {
          bytes = patches.get(name).runPatches().toBytecode();
          patchedBytes.put(name, bytes);
          saveByteCode(bytes, name);
          PatchLog.flush();
          return bytes;
        } catch (Throwable t) {
          PatchLog.severe("Failed to patch " + name + " in patch group " + name + '.', t);
          return originalBytes;
        }
      }
      runPatchesIfNeeded();
      bytes = patchedBytes.get(name);
      if (bytes == null) {
        PatchLog.severe("Got no patched bytes for " + name);
        return originalBytes;
      }
      return bytes;
    }

    private void runPatchesIfNeeded() {
      if (ranPatches) {
        return;
      }
      ranPatches = true;
      Set<CtClass> patchedClasses = new HashSet<CtClass>();
      for (ClassPatchDescriptor classPatchDescriptor : classPatchDescriptors) {
        try {
          try {
            patchedClasses.add(classPatchDescriptor.runPatches());
          } catch (NotFoundException e) {
            if (e.getMessage().contains(classPatchDescriptor.name)) {
              PatchLog.warning("Skipping patch for " + classPatchDescriptor.name + ", not found.");
            } else {
              throw e;
            }
          }
        } catch (Throwable t) {
          PatchLog.severe("Failed to patch " + classPatchDescriptor.name + " in patch group " + name + '.', t);
        }
      }
      for (CtClass ctClass : patchedClasses) {
        String className = ctClass.getName();
        if (!ctClass.isModified()) {
          PatchLog.severe("Failed to get patched bytes for " + className + " as it was never modified in patch group " + name + '.');
          continue;
        }
        try {
          byte[] byteCode =  ctClass.toBytecode();
          patchedBytes.put(className, byteCode);
          saveByteCode(byteCode, className);
        } catch (Throwable t) {
          PatchLog.severe("Failed to get patched bytes for " + className + " in patch group " + name + '.', t);
        }
      }
      PatchLog.flush();
    }

    private class ClassPatchDescriptor {
      private final Map<String, String> attributes;
      public final String name;
      public final List<PatchDescriptor> patches = new ArrayList<PatchDescriptor>();

      private ClassPatchDescriptor(Element element) {
        attributes = DomUtil.getAttributes(element);
        ClassDescription deobfuscatedClass = new ClassDescription(attributes.get("id"));
        ClassDescription obfuscatedClass = mappings.map(deobfuscatedClass);
        name = obfuscatedClass == null ? deobfuscatedClass.name : obfuscatedClass.name;
        for (Element patchElement : DomUtil.elementList(element.getChildNodes())) {
          PatchDescriptor patchDescriptor = new PatchDescriptor(patchElement);
          patches.add(patchDescriptor);
          List<MethodDescription> methodDescriptionList = MethodDescription.fromListString(deobfuscatedClass.name, patchDescriptor.getMethods());
          if (!patchDescriptor.getMethods().isEmpty()) {
            patchDescriptor.set("deobf", methodDescriptionList.get(0).getShortName());
            //noinspection unchecked
            patchDescriptor.setMethods(MethodDescription.toListString((List<MethodDescription>) mappings.map(methodDescriptionList)));
          }
          String field = patchDescriptor.get("field"), prefix = "";
          if (field != null && !field.isEmpty()) {
            if (field.startsWith("this.")) {
              field = field.substring("this.".length());
              prefix = "this.";
            }
            String after = "", type = name;
            if (field.indexOf('.') != -1) {
              after = field.substring(field.indexOf('.'));
              field = field.substring(0, field.indexOf('.'));
              if (!field.isEmpty() && (field.charAt(0) == '$') && prefix.isEmpty()) {
                ArrayList<String> parameterList = new ArrayList<String>();
                for (MethodDescription methodDescriptionOriginal : methodDescriptionList) {
                  MethodDescription methodDescription = mappings.rmap(mappings.map(methodDescriptionOriginal));
                  methodDescription = methodDescription == null ? methodDescriptionOriginal : methodDescription;
                  int i = 0;
                  for (String parameter : methodDescription.getParameterList()) {
                    if (parameterList.size() <= i) {
                      parameterList.add(parameter);
                    } else if (!parameterList.get(i).equals(parameter)) {
                      parameterList.set(i, null);
                    }
                    i++;
                  }
                }
                int parameterIndex = Integer.valueOf(field.substring(1)) - 1;
                if (parameterIndex >= parameterList.size()) {
                  if (!parameterList.isEmpty()) {
                    PatchLog.severe("Can not obfuscate parameter field " + patchDescriptor.get("field") + ", index: " + parameterIndex + " but parameter list is: " + CollectionsUtil.join(parameterList));
                  }
                  break;
                }
                type = parameterList.get(parameterIndex);
                if (type == null) {
                  PatchLog.severe("Can not obfuscate parameter field " + patchDescriptor.get("field") + " automatically as this parameter does not have a single type across the methods used in this patch.");
                  break;
                }
                prefix = field + '.';
                field = after.substring(1);
                after = "";
              }
            }
            FieldDescription obfuscatedField = mappings.map(new FieldDescription(type, field));
            if (obfuscatedField != null) {
              patchDescriptor.set("field", prefix + obfuscatedField.name + after);
            }
          }
        }
      }

      public CtClass runPatches() throws NotFoundException {
        CtClass ctClass = classPool.get(name);
        for (PatchDescriptor patchDescriptor : patches) {
          PatchMethodDescriptor patchMethodDescriptor = patchMethods.get(patchDescriptor.getPatch());
          Object result = patchMethodDescriptor.run(patchDescriptor, ctClass, preSrg ? preSrgPatchClassInstance : patchClassInstance);
          if (result instanceof CtClass) {
            ctClass = (CtClass) result;
          }
        }
        return ctClass;
      }
    }
  }

  private static class PatchDescriptor {
    private final Map<String, String> attributes;
    private String methods;
    private final String patch;

    PatchDescriptor(Element element) {
      attributes = DomUtil.getAttributes(element);
      methods = element.getTextContent().trim();
      patch = element.getTagName();
    }

    public String set(String name, String value) {
      return attributes.put(name, value);
    }

    public String get(String name) {
      return attributes.get(name);
    }

    public Map<String, String> getAttributes() {
      return attributes;
    }

    public String getMethods() {
      return methods;
    }

    public String getPatch() {
      return patch;
    }

    public void setMethods(String methods) {
      this.methods = methods;
    }
  }

  public static class PatchMethodDescriptor {
    public final String name;
    public final List<String> requiredAttributes;
    public final Method patchMethod;
    public final boolean isClassPatch;
    public final boolean emptyConstructor;

    public PatchMethodDescriptor(Method method, Patch patch) {
      String name = patch.name();
      if (Arrays.asList(method.getParameterTypes()).contains(Map.class)) {
        this.requiredAttributes = CollectionsUtil.split(patch.requiredAttributes());
      } else {
        this.requiredAttributes = null;
      }
      if (name == null || name.isEmpty()) {
        name = method.getName();
      }
      this.name = name;
      emptyConstructor = patch.emptyConstructor();
      isClassPatch = method.getParameterTypes()[0].equals(CtClass.class);
      patchMethod = method;
    }

    public Object run(PatchDescriptor patchDescriptor, CtClass ctClass, Object patchClassInstance) {
      String methods = patchDescriptor.getMethods();
      Map<String, String> attributes = patchDescriptor.getAttributes();
      Map<String, String> attributesClean = new HashMap<String, String>(attributes);
      attributesClean.remove("code");
      PatchLog.fine("Patching " + ctClass.getName() + " with " + this.name + '(' + CollectionsUtil.joinMap(attributesClean) + ')' + (methods.isEmpty() ? "" : " {" + methods + '}'));
      if (requiredAttributes != null && !attributes.keySet().containsAll(requiredAttributes)) {
        PatchLog.severe("Missing required attributes " + requiredAttributes.toString() + " when patching " + ctClass.getName());
        return null;
      }
      if ("^all^".equals(methods)) {
        patchDescriptor.set("silent", "true");
        List<CtBehavior> ctBehaviors = new ArrayList<CtBehavior>();
        Collections.addAll(ctBehaviors, ctClass.getDeclaredMethods());
        Collections.addAll(ctBehaviors, ctClass.getDeclaredConstructors());
        CtBehavior initializer = ctClass.getClassInitializer();
        if (initializer != null) {
          ctBehaviors.add(initializer);
        }
        for (CtBehavior ctBehavior : ctBehaviors) {
          run(ctBehavior, attributes, patchClassInstance);
        }
      } else if (isClassPatch || (!emptyConstructor && methods.isEmpty())) {
        return run(ctClass, attributes, patchClassInstance);
      } else if (methods.isEmpty()) {
        for (CtConstructor ctConstructor : ctClass.getDeclaredConstructors()) {
          run(ctConstructor, attributes, patchClassInstance);
        }
      } else if ("^static^".equals(methods)) {
        CtConstructor ctBehavior = ctClass.getClassInitializer();
        if (ctBehavior == null) {
          PatchLog.severe("No static initializer found patching " + ctClass.getName() + " with " + toString());
        } else {
          run(ctBehavior, attributes, patchClassInstance);
        }
      } else {
        List<MethodDescription> methodDescriptions = MethodDescription.fromListString(ctClass.getName(), methods);
        for (MethodDescription methodDescription : methodDescriptions) {
          CtMethod ctMethod;
          try {
            ctMethod = methodDescription.inClass(ctClass);
          } catch (Throwable t) {
            if (!attributes.containsKey("allowMissing")) {
              PatchLog.warning("", t);
            }
            continue;
          }
          run(ctMethod, attributes, patchClassInstance);
        }
      }
      return null;
    }

    private Object run(CtClass ctClass, Map<String, String> attributes, Object patchClassInstance) {
      try {
        if (requiredAttributes == null) {
          return patchMethod.invoke(patchClassInstance, ctClass);
        } else {
          return patchMethod.invoke(patchClassInstance, ctClass, attributes);
        }
      } catch (Throwable t) {
        if (t instanceof InvocationTargetException) {
          t = t.getCause();
        }
        if (t instanceof CannotCompileException && attributes.containsKey("code")) {
          PatchLog.severe("Code: " + attributes.get("code"));
        }
        PatchLog.severe("Error patching " + ctClass.getName() + " with " + toString(), t);
        return null;
      }
    }

    private Object run(CtBehavior ctBehavior, Map<String, String> attributes, Object patchClassInstance) {
      try {
        if (requiredAttributes == null) {
          return patchMethod.invoke(patchClassInstance, ctBehavior);
        } else {
          return patchMethod.invoke(patchClassInstance, ctBehavior, attributes);
        }
      } catch (Throwable t) {
        if (t instanceof InvocationTargetException) {
          t = t.getCause();
        }
        if (t instanceof CannotCompileException && attributes.containsKey("code")) {
          PatchLog.severe("Code: " + attributes.get("code"));
        }
        PatchLog.severe("Error patching " + ctBehavior.getName() + " in " + ctBehavior.getDeclaringClass().getName() + " with " + toString(), t);
        return null;
      }
    }

    @Override
    public String toString() {
      return name;
    }
  }
}
TOP

Related Classes of nallar.tickthreading.patcher.Patcher$PatchDescriptor

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.