package net.djpowell.liverepl.agent;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.lang.instrument.Instrumentation;
import java.lang.reflect.Method;
import java.net.*;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Logger;
import net.djpowell.liverepl.client.Main;
import net.djpowell.liverepl.discovery.ClassLoaderInfo;
import net.djpowell.liverepl.discovery.Discovery;
/**
 * Attach-API agent entry point.
 *
 * <p>{@link #agentmain(String, Instrumentation)} parses the agent arguments and then either
 * dumps the discovered class loaders to the first client that connects on the given port, or
 * starts the REPL server class reflectively under a context class loader that has the
 * required jars appended.
 */
public class Agent {

    private static final Discovery discovery = new Discovery();

    private static final Logger TRC = Logger.getLogger(Agent.class.getName());

    /**
     * A task that is handed a listening server socket and is responsible for setting
     * {@code connected} to {@code true} once a client has connected. If the flag is still
     * {@code false} when the connect timeout elapses, the server socket is closed underneath
     * the task.
     */
    public interface ConnectNotifyingTask {
        void run(ServerSocket server, AtomicBoolean connected);
    }

    /**
     * Starts a watchdog thread that closes {@code server} if no client has connected within
     * {@code connectTimeout} milliseconds.
     *
     * @param connectTimeout time to wait, in milliseconds, before checking the flag
     * @param connected      flag set by the task once a client has connected
     * @param server         server socket to close on timeout
     * @return the started watchdog thread
     */
    private static Thread startKillerThread(final int connectTimeout, final AtomicBoolean connected, final ServerSocket server) {
        Thread killer = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    Thread.sleep(connectTimeout);
                } catch (InterruptedException e) {
                    // An interrupt is treated as an early timeout; keep the flag set so the
                    // interruption is not silently lost.
                    Thread.currentThread().interrupt();
                }
                if (!connected.get()) {
                    // Client connect timeout: terminate the server so a blocked accept() fails.
                    try {
                        server.close();
                    } catch (IOException e) {
                        throw new RuntimeException(e);
                    }
                }
            }
        }, "Killer Thread");
        // The watchdog only sleeps and then closes a socket; it must not keep the host JVM
        // alive for the remainder of the timeout.
        killer.setDaemon(true);
        killer.start();
        return killer;
    }

    /**
     * Opens a server socket bound to {@code Main.LOCALHOST} on {@code port}, then runs
     * {@code task} on a new thread with that socket. A watchdog closes the socket if the task
     * has not flagged a connection within {@code connectTimeout} milliseconds.
     *
     * <p>This method returns as soon as both threads have been started; it does not wait for
     * the task to finish.
     *
     * @param port           port to listen on
     * @param connectTimeout time in milliseconds a client has to connect
     * @param threadName     name given to the thread that runs {@code task}
     * @param task           work to perform with the server socket
     * @throws Exception if the server socket cannot be opened
     */
    public static void runAfterConnect(int port, int connectTimeout, String threadName, final ConnectNotifyingTask task) throws Exception {
        final ServerSocket serverSocket = new ServerSocket(port, 0, Main.LOCALHOST);
        final AtomicBoolean connected = new AtomicBoolean(false);
        Thread taskThread = new Thread(new Runnable() {
            @Override
            public void run() {
                task.run(serverSocket, connected);
            }
        }, threadName);
        startKillerThread(connectTimeout, connected, serverSocket);
        taskThread.start();
    }

    /**
     * Accepts a single client on {@code port} (waiting at most five seconds) and writes the
     * discovered class loader list to it.
     *
     * @param port port to listen on
     */
    private static void printClassLoaderInfo(int port) {
        try {
            runAfterConnect(port, 5000, "ClassLoaderInfoThread", new ConnectNotifyingTask() {
                @Override
                public void run(ServerSocket server, AtomicBoolean connected) {
                    try {
                        try {
                            Socket socket = server.accept();
                            connected.set(true);
                            try {
                                PrintStream out = new PrintStream(socket.getOutputStream());
                                try {
                                    discovery.dumpList(out);
                                } finally {
                                    out.close();
                                }
                            } finally {
                                socket.close();
                            }
                        } finally {
                            // Only one client is ever served. The watchdog closes the server
                            // solely on timeout, so release the listening port here.
                            server.close();
                        }
                    } catch (RuntimeException e) {
                        throw e;
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }
            });
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Installs, as the current thread's context class loader, a new {@link URLClassLoader}
     * over {@code urls} whose parent is the class loader identified by {@code clId}.
     *
     * @param urls  jar URLs to make visible through the new loader
     * @param clId  identifier of the parent class loader, as known to the discovery service
     * @return the previous context class loader, to be passed to
     *         {@link #popClassLoader(ClassLoader)}
     * @throws RuntimeException if no class loader is registered under {@code clId}
     */
    private static ClassLoader pushClassLoader(List<URL> urls, String clId) {
        TRC.fine("Creating new classloader with: " + urls);
        ClassLoader old = Thread.currentThread().getContextClassLoader();
        TRC.fine("Old classloader: " + old);
        ClassLoaderInfo cli = discovery.findClassLoader(clId);
        if (cli == null) {
            throw new RuntimeException("Unknown class loader: " + clId);
        }
        ClassLoader cl = cli.getClassLoader();
        URLClassLoader withClojure = new URLClassLoader(urls.toArray(new URL[urls.size()]), cl); // TODO
        Thread.currentThread().setContextClassLoader(withClojure);
        return old;
    }

    /**
     * Restores the context class loader saved by {@link #pushClassLoader(List, String)}.
     *
     * @param old the class loader to reinstate
     */
    private static void popClassLoader(ClassLoader old) {
        TRC.fine("Restoring old context classloader");
        Thread.currentThread().setContextClassLoader(old);
    }

    /**
     * Reports whether {@code clojure.lang.RT} can be loaded through the current thread's
     * context class loader.
     *
     * @return {@code true} if the class was found, {@code false} otherwise
     */
    private static boolean isClojureLoaded() {
        ClassLoader cl = Thread.currentThread().getContextClassLoader();
        if (cl == null) {
            // getContextClassLoader() may return null; fall back rather than fail with an NPE.
            cl = ClassLoader.getSystemClassLoader();
        }
        try {
            cl.loadClass("clojure.lang.RT");
            return true;
        } catch (ClassNotFoundException e) {
            return false;
        }
    }

    /**
     * Agent entry point.
     *
     * <p>{@code agentArgs} must contain exactly four newline-separated tokens, in order: the
     * port number, the Clojure jar path, the server jar path, and a class loader id. A class
     * loader id of {@code "L"} requests a listing of class loaders instead of starting a REPL.
     *
     * @param agentArgs newline-separated agent arguments
     * @param inst      instrumentation handle, passed through to the REPL
     * @throws RuntimeException if the arguments are missing or malformed
     */
    public static void agentmain(String agentArgs, Instrumentation inst) {
        TRC.fine("Started Attach agent");
        if (agentArgs == null) {
            // Report absent arguments the same way as malformed ones instead of an opaque NPE.
            throw new RuntimeException("Invalid parameters: " + agentArgs);
        }
        StringTokenizer stok = new StringTokenizer(agentArgs, "\n");
        if (stok.countTokens() != 4) {
            throw new RuntimeException("Invalid parameters: " + agentArgs);
        }
        int port = Integer.parseInt(stok.nextToken());
        TRC.fine("Port: " + port);
        String clojurePath = stok.nextToken();
        String serverPath = stok.nextToken();
        String classLoaderId = stok.nextToken();

        if ("L".equals(classLoaderId)) {
            printClassLoaderInfo(port);
            return;
        }

        boolean clojureLoaded = isClojureLoaded();
        TRC.fine("Clojure is " + (clojureLoaded ? "" : "not ") + "loaded");

        // Only add the Clojure jar when Clojure is not already visible.
        List<URL> urls;
        if (clojureLoaded) {
            urls = getJarUrls(serverPath);
        } else {
            urls = getJarUrls(clojurePath, serverPath);
        }

        ClassLoader old = pushClassLoader(urls, classLoaderId);
        try {
            if (!clojureLoaded) { // if clojure wasn't loaded before, print current status
                TRC.fine("Clojure is " + (isClojureLoaded() ? "" : "not ") + "loaded");
            }
            startRepl(port, inst);
        } finally {
            popClassLoader(old);
        }
    }

    /**
     * Converts file system paths to {@code file:} URLs, preserving order.
     *
     * @param paths file system paths
     * @return one URL per path
     * @throws RuntimeException if any path cannot be converted
     */
    private static List<URL> getJarUrls(String... paths) {
        List<URL> urls = new ArrayList<URL>();
        try {
            for (String path : paths) {
                URL url = new File(path).toURI().toURL();
                urls.add(url);
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return urls;
    }

    /**
     * Reflectively invokes the static
     * {@code net.djpowell.liverepl.server.Repl.main(InetAddress, int, Instrumentation)} through
     * the current context class loader.
     *
     * @param port port passed to the REPL
     * @param inst instrumentation handle passed to the REPL
     * @throws RuntimeException if the class or method cannot be found or the invocation fails
     */
    private static void startRepl(int port, Instrumentation inst) {
        // avoids making load-time references to Clojure classes from the system classloader
        try {
            ClassLoader cl = Thread.currentThread().getContextClassLoader();
            Class<?> repl = Class.forName("net.djpowell.liverepl.server.Repl", true, cl);
            Method method = repl.getMethod("main", InetAddress.class, Integer.TYPE, Instrumentation.class);
            method.invoke(null, Main.LOCALHOST, port, inst);
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}