package org.mortbay.jetty;
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.net.Socket;
import java.security.MessageDigest;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import junit.framework.TestCase;
import org.mortbay.io.ByteArrayBuffer;
import org.mortbay.jetty.client.ContentExchange;
import org.mortbay.jetty.client.HttpClient;
import org.mortbay.jetty.client.security.Realm;
import org.mortbay.jetty.client.security.SimpleRealmResolver;
import org.mortbay.jetty.handler.DefaultHandler;
import org.mortbay.jetty.handler.HandlerCollection;
import org.mortbay.jetty.nio.SelectChannelConnector;
import org.mortbay.jetty.security.Constraint;
import org.mortbay.jetty.security.ConstraintMapping;
import org.mortbay.jetty.security.DigestAuthenticator;
import org.mortbay.jetty.security.HashUserRealm;
import org.mortbay.jetty.security.SecurityHandler;
import org.mortbay.jetty.security.UserRealm;
import org.mortbay.jetty.servlet.Context;
import org.mortbay.log.Log;
import org.mortbay.util.IO;
import org.mortbay.util.StringUtil;
import org.mortbay.util.TypeUtil;
public class DigestPostTest extends TestCase
{
/** Nonce count: sent as the {@code nc} attribute of the Digest Authorization header and hashed into the response by {@code newResponse}. */
private static final String NC = "00000001";
/** Multi-line text used as the POST body by the socket-based tests and the string-content client test. */
public final static String __message =
"0123456789 0123456789 0123456789 0123456789 0123456789 0123456789 0123456789 0123456789 \n"+
"9876543210 9876543210 9876543210 9876543210 9876543210 9876543210 9876543210 9876543210 \n"+
"1234567890 1234567890 1234567890 1234567890 1234567890 1234567890 1234567890 1234567890 \n"+
"0987654321 0987654321 0987654321 0987654321 0987654321 0987654321 0987654321 0987654321 \n"+
"abcdefghijklmnopqrstuvwxyz abcdefghijklmnopqrstuvwxyz abcdefghijklmnopqrstuvwxyz \n"+
"ABCDEFGHIJKLMNOPQRSTUVWXYZ ABCDEFGHIJKLMNOPQRSTUVWXYZ ABCDEFGHIJKLMNOPQRSTUVWXYZ \n"+
"Now is the time for all good men to come to the aid of the party.\n"+
"How now brown cow.\n"+
"The quick brown fox jumped over the lazy dog.\n";
/**
 * Body most recently read by {@code PostServlet.doPost}. Tests reset it to null before each
 * request and compare it afterwards; it is written by the servlet and read by the test
 * methods (presumably on different threads, hence volatile).
 */
public volatile static String _received = null;
/** Server under test; created and started in {@code setUp()}, stopped in {@code tearDown()}. */
private Server _server;
/**
 * Builds and starts the server used by every test in this class.
 * <p>
 * The server gets a single {@link SelectChannelConnector} and a security-enabled context
 * mounted at <code>/test</code> that serves {@link PostServlet}. Every path in the context
 * is protected by a constraint requiring the role "test", checked with a
 * {@link DigestAuthenticator} against a {@link HashUserRealm} named "test" that holds the
 * single user testuser/password.
 *
 * @throws RuntimeException wrapping the original failure if the server cannot be built or
 *         started, so the test is reported as an error at its real cause instead of
 *         continuing against a server that is not running
 */
public void setUp()
{
    try
    {
        _server = new Server();
        _server.setConnectors(new Connector[]
        { new SelectChannelConnector() });

        Context context = new Context(Context.SECURITY);
        context.setContextPath("/test");
        context.addServlet(PostServlet.class,"/");
        context.setAllowNullPathInfo(true);

        HashUserRealm realm = new HashUserRealm("test");
        realm.put("testuser","password");
        realm.addUserToRole("testuser","test");
        _server.setUserRealms(new UserRealm[]{realm});

        SecurityHandler security=context.getSecurityHandler();
        security.setAuthenticator(new DigestAuthenticator());
        security.setUserRealm(realm);

        Constraint constraint = new Constraint("SecureTest","test");
        constraint.setAuthenticate(true);
        ConstraintMapping mapping = new ConstraintMapping();
        mapping.setConstraint(constraint);
        mapping.setPathSpec("/*");
        security.setConstraintMappings(new ConstraintMapping[]{mapping});

        HandlerCollection handlers = new HandlerCollection();
        handlers.setHandlers(new Handler[]
        { context, new DefaultHandler() });
        _server.setHandler(handlers);

        _server.start();
    }
    catch (final Exception e)
    {
        // Fail fast: swallowing the exception here would only surface later as an
        // unrelated connection or null-pointer failure inside the test method.
        throw new RuntimeException("Unable to start the test server", e);
    }
}
/**
 * Stops the server created by {@code setUp()}, if there is one.
 *
 * @throws Exception if stopping the server fails
 * @see junit.framework.TestCase#tearDown()
 */
protected void tearDown() throws Exception
{
    // _server is still null when setUp() failed before the Server was constructed;
    // guard so that failure is not masked by a NullPointerException here.
    if (_server != null)
    {
        _server.stop();
    }
}
/**
 * Posts the message over a raw socket using HTTP/1.0, first without credentials (expecting
 * a 401 challenge and no body delivered to the servlet) and then, on a new connection, with
 * a Digest Authorization header built from the challenge nonce (expecting 200 and the full
 * body delivered).
 * <p>
 * Each socket is closed in a finally block so a failing assertion does not leak it.
 *
 * @throws Exception on any I/O or digest failure
 */
public void testServerDirectlyHTTP10() throws Exception
{
    final int port = _server.getConnectors()[0].getLocalPort();
    final byte[] bytes = __message.getBytes("UTF-8");
    final String nonce;

    // First request: no credentials, the server must answer with a challenge.
    Socket socket = new Socket("127.0.0.1",port);
    try
    {
        _received=null;
        socket.getOutputStream().write(
            ("POST /test HTTP/1.0\r\n"+
             "Host: 127.0.0.1:"+port+"\r\n"+
             "Content-Length: "+bytes.length+"\r\n"+
             "\r\n").getBytes("UTF-8"));
        socket.getOutputStream().write(bytes);
        socket.getOutputStream().flush();

        String result = IO.toString(socket.getInputStream());
        assertTrue(result.startsWith("HTTP/1.1 401 Unauthorized"));
        assertEquals(null,_received);

        int n=result.indexOf("nonce=");
        assertTrue("no nonce in challenge: "+result,n>=0);
        // n+7 skips the text nonce=" so the substring is the quoted value only
        nonce=result.substring(n+7,result.indexOf('"',n+7));
    }
    finally
    {
        socket.close();
    }

    MessageDigest md = MessageDigest.getInstance("MD5");
    byte[] b= md.digest(String.valueOf(System.currentTimeMillis()).getBytes(StringUtil.__ISO_8859_1));
    String cnonce=encode(b);
    String digest="Digest username=\"testuser\" realm=\"test\" nonce=\""+nonce+"\" uri=\"/test\" algorithm=MD5 response=\""+
        newResponse("POST","/test",cnonce,"testuser","test","password",nonce,"auth")+
        "\" qop=auth nc="+NC+" cnonce=\""+cnonce+"\"";

    // Second request: same body, now carrying the digest response.
    socket = new Socket("127.0.0.1",port);
    try
    {
        _received=null;
        socket.getOutputStream().write(
            ("POST /test HTTP/1.0\r\n"+
             "Host: 127.0.0.1:"+port+"\r\n"+
             "Content-Length: "+bytes.length+"\r\n"+
             "Authorization: "+digest+"\r\n"+
             "\r\n").getBytes("UTF-8"));
        socket.getOutputStream().write(bytes);
        socket.getOutputStream().flush();

        String result = IO.toString(socket.getInputStream());
        assertTrue(result.startsWith("HTTP/1.1 200 OK"));
        assertEquals(__message,_received);
    }
    finally
    {
        socket.close();
    }
}
/**
 * Posts the message over a single raw socket: first as HTTP/1.1 without credentials
 * (expecting a 401 challenge and no body delivered to the servlet), then again on the same
 * connection with a Digest Authorization header (expecting 200 and the full body delivered).
 * <p>
 * The socket is closed in a finally block so a failing assertion does not leak it.
 *
 * @throws Exception on any I/O or digest failure
 */
public void testServerDirectlyHTTP11() throws Exception
{
    final int port = _server.getConnectors()[0].getLocalPort();
    final byte[] bytes = __message.getBytes("UTF-8");
    final Socket socket = new Socket("127.0.0.1",port);
    try
    {
        _received=null;
        socket.getOutputStream().write(
            ("POST /test HTTP/1.1\r\n"+
             "Host: 127.0.0.1:"+port+"\r\n"+
             "Content-Length: "+bytes.length+"\r\n"+
             "\r\n").getBytes("UTF-8"));
        socket.getOutputStream().write(bytes);
        socket.getOutputStream().flush();

        // The connection stays open, so read what has arrived instead of reading to EOF.
        Thread.sleep(100);
        byte[] buf=new byte[4096];
        int len=socket.getInputStream().read(buf);
        assertTrue("connection closed before the 401 challenge arrived",len>0);
        String result=new String(buf,0,len,"UTF-8");
        assertTrue(result.startsWith("HTTP/1.1 401 Unauthorized"));
        assertEquals(null,_received);

        int n=result.indexOf("nonce=");
        assertTrue("no nonce in challenge: "+result,n>=0);
        // n+7 skips the text nonce=" so the substring is the quoted value only
        String nonce=result.substring(n+7,result.indexOf('"',n+7));

        MessageDigest md = MessageDigest.getInstance("MD5");
        byte[] b= md.digest(String.valueOf(System.currentTimeMillis()).getBytes(StringUtil.__ISO_8859_1));
        String cnonce=encode(b);
        String digest="Digest username=\"testuser\" realm=\"test\" nonce=\""+nonce+"\" uri=\"/test\" algorithm=MD5 response=\""+
            newResponse("POST","/test",cnonce,"testuser","test","password",nonce,"auth")+
            "\" qop=auth nc="+NC+" cnonce=\""+cnonce+"\"";

        _received=null;
        // NOTE(review): the authenticated request is sent as HTTP/1.0 - presumably so the
        // server closes the connection afterwards and IO.toString() can read to EOF; confirm.
        socket.getOutputStream().write(
            ("POST /test HTTP/1.0\r\n"+
             "Host: 127.0.0.1:"+port+"\r\n"+
             "Content-Length: "+bytes.length+"\r\n"+
             "Authorization: "+digest+"\r\n"+
             "\r\n").getBytes("UTF-8"));
        socket.getOutputStream().write(bytes);
        socket.getOutputStream().flush();

        result = IO.toString(socket.getInputStream());
        assertTrue(result.startsWith("HTTP/1.1 200 OK"));
        assertEquals(__message,_received);
    }
    finally
    {
        socket.close();
    }
}
/**
 * Posts the message as a chunked HTTP/1.1 request over a single raw socket: first without
 * credentials (expecting a 401 challenge and no body delivered to the servlet), then again
 * on the same connection with a Digest Authorization header (expecting 200 and the full
 * body delivered).
 * <p>
 * The socket is closed in a finally block so a failing assertion does not leak it.
 *
 * @throws Exception on any I/O or digest failure
 */
public void testServerDirectlyHTTP11Chunked() throws Exception
{
    final int port = _server.getConnectors()[0].getLocalPort();
    final byte[] bytes = __message.getBytes("UTF-8");
    final Socket socket = new Socket("127.0.0.1",port);
    try
    {
        _received=null;
        socket.getOutputStream().write(
            ("POST /test HTTP/1.1\r\n"+
             "Host: 127.0.0.1:"+port+"\r\n"+
             "Content-Type: text/plain\r\n"+
             "Transfer-Encoding: chunked\r\n"+
             "\r\n").getBytes("UTF-8"));
        socket.getOutputStream().flush();
        sendChunkedBody(socket,bytes);

        // The connection stays open, so read what has arrived instead of reading to EOF.
        Thread.sleep(100);
        byte[] buf=new byte[4096];
        int len=socket.getInputStream().read(buf);
        assertTrue("connection closed before the 401 challenge arrived",len>0);
        String result=new String(buf,0,len,"UTF-8");
        assertTrue(result.startsWith("HTTP/1.1 401 Unauthorized"));
        assertEquals(null,_received);

        int n=result.indexOf("nonce=");
        assertTrue("no nonce in challenge: "+result,n>=0);
        // n+7 skips the text nonce=" so the substring is the quoted value only
        String nonce=result.substring(n+7,result.indexOf('"',n+7));

        MessageDigest md = MessageDigest.getInstance("MD5");
        byte[] b= md.digest(String.valueOf(System.currentTimeMillis()).getBytes(StringUtil.__ISO_8859_1));
        String cnonce=encode(b);
        String digest="Digest username=\"testuser\" realm=\"test\" nonce=\""+nonce+"\" uri=\"/test\" algorithm=MD5 response=\""+
            newResponse("POST","/test",cnonce,"testuser","test","password",nonce,"auth")+
            "\" qop=auth nc="+NC+" cnonce=\""+cnonce+"\"";

        _received=null;
        socket.getOutputStream().write(
            ("POST /test HTTP/1.1\r\n"+
             "Host: 127.0.0.1:"+port+"\r\n"+
             "Content-Type: text/plain\r\n"+
             "Transfer-Encoding: Chunked\r\n"+
             "Authorization: "+digest+"\r\n"+
             "\r\n").getBytes("UTF-8"));
        socket.getOutputStream().flush();
        sendChunkedBody(socket,bytes);

        Thread.sleep(100);
        buf=new byte[4096];
        len=socket.getInputStream().read(buf);
        assertTrue("connection closed before the authenticated response arrived",len>0);
        result=new String(buf,0,len,"UTF-8");
        assertTrue(result.startsWith("HTTP/1.1 200 OK"));
        assertEquals(__message,_received);
    }
    finally
    {
        socket.close();
    }
}

/**
 * Writes <code>bytes</code> to the socket as three chunks (16 bytes, 32 bytes, then the
 * remainder) followed by a zero-length chunk line, flushing and pausing briefly between
 * chunks so they are not all handed to the socket in a single write.
 *
 * @param socket connection on which the request headers have already been written
 * @param bytes request body; must be longer than 48 bytes
 * @throws Exception on I/O failure or if a pause is interrupted
 */
private static void sendChunkedBody(Socket socket, byte[] bytes) throws Exception
{
    Thread.sleep(10);
    socket.getOutputStream().write("10\r\n".getBytes("UTF-8"));
    socket.getOutputStream().write(bytes,0,16);
    socket.getOutputStream().flush();

    Thread.sleep(10);
    socket.getOutputStream().write("\r\n20\r\n".getBytes("UTF-8"));
    socket.getOutputStream().write(bytes,16,32);
    socket.getOutputStream().flush();

    Thread.sleep(10);
    socket.getOutputStream().write(("\r\n"+Integer.toHexString(bytes.length-48)+"\r\n").getBytes("UTF-8"));
    socket.getOutputStream().write(bytes,48,bytes.length-48);
    socket.getOutputStream().write("\r\n0\r\n".getBytes("UTF-8"));
    socket.getOutputStream().flush();
}
/**
 * Posts the message through {@link HttpClient} as in-memory content, with a realm resolver
 * that supplies the test credentials, and checks that the servlet received the whole body
 * and the exchange completed with status 200.
 * <p>
 * The client is stopped in a finally block so it is released even when an assertion fails.
 *
 * @throws Exception if the client cannot be started or stopped, or the exchange fails
 */
public void testServerWithHttpClientStringContent() throws Exception
{
    HttpClient client = new HttpClient();
    client.setConnectorType(HttpClient.CONNECTOR_SELECT_CHANNEL);
    client.setRealmResolver(new SimpleRealmResolver(new TestRealm()));
    client.start();
    try
    {
        String srvUrl = "http://127.0.0.1:" + _server.getConnectors()[0].getLocalPort() + "/test";

        ContentExchange ex = new ContentExchange();
        ex.setMethod(HttpMethods.POST);
        ex.setURL(srvUrl);
        ex.setRequestContent(new ByteArrayBuffer(__message,"UTF-8"));

        _received=null;
        client.send(ex);
        ex.waitForDone();

        assertEquals(__message,_received);
        assertEquals(200,ex.getResponseStatus());
    }
    finally
    {
        client.stop();
    }
}
/**
 * Posts the contents of <code>src/test/resources/message.txt</code> through
 * {@link HttpClient} as a streamed request body, with a realm resolver that supplies the
 * test credentials, and checks that the servlet received exactly the file's text and the
 * exchange completed with status 200.
 * <p>
 * The client and both file streams are released in finally blocks so they are not leaked
 * when an assertion fails.
 *
 * @throws Exception if the file cannot be read, the client cannot be started or stopped,
 *         or the exchange fails
 */
public void testServerWithHttpClientStreamContent() throws Exception
{
    final String messageFile = "src/test/resources/message.txt";

    HttpClient client = new HttpClient();
    client.setConnectorType(HttpClient.CONNECTOR_SELECT_CHANNEL);
    client.setRealmResolver(new SimpleRealmResolver(new TestRealm()));
    client.start();
    try
    {
        String srvUrl = "http://127.0.0.1:" + _server.getConnectors()[0].getLocalPort() + "/test";

        ContentExchange ex = new ContentExchange();
        ex.setMethod(HttpMethods.POST);
        ex.setURL(srvUrl);

        BufferedInputStream content = new BufferedInputStream(new FileInputStream(messageFile));
        try
        {
            ex.setRequestContentSource(content);

            _received=null;
            client.send(ex);
            ex.waitForDone();
        }
        finally
        {
            content.close();
        }

        // Re-read the file independently to obtain the expected body.
        String sent;
        FileInputStream expected = new FileInputStream(messageFile);
        try
        {
            sent = IO.toString(expected);
        }
        finally
        {
            expected.close();
        }

        assertEquals(sent,_received);
        assertEquals(200,ex.getResponseStatus());
    }
    finally
    {
        client.stop();
    }
}
/**
 * {@link Realm} with fixed values: id "test", principal "testuser" and credentials
 * "password". The HttpClient-based tests hand it to a {@link SimpleRealmResolver}.
 */
public static class TestRealm implements Realm
{
    private static final String ID = "test";
    private static final String PRINCIPAL = "testuser";
    private static final String CREDENTIALS = "password";

    /** @return the realm id, always "test" */
    public String getId()
    {
        return ID;
    }

    /** @return the user name, always "testuser" */
    public String getPrincipal()
    {
        return PRINCIPAL;
    }

    /** @return the password, always "password" */
    public String getCredentials()
    {
        return CREDENTIALS;
    }
}
/**
 * Servlet under test: reads the whole POST body, publishes it through
 * {@code DigestPostTest._received} and replies with status 200 and a one-line summary.
 */
public static class PostServlet extends HttpServlet
{
    /**
     * Consumes the request body and reports its length back to the client.
     *
     * @param request the POST whose input stream is read to the end
     * @param response receives status 200 and the line "Received N bytes"
     * @throws IOException if reading the body or writing the reply fails
     */
    public void doPost(HttpServletRequest request, HttpServletResponse response)
        throws IOException
    {
        final String body = IO.toString(request.getInputStream());
        // Publish the body so the test method can compare it with what it sent.
        DigestPostTest._received = body;

        response.setStatus(HttpServletResponse.SC_OK);
        final String summary = "Received "+body.length()+" bytes";
        response.getWriter().println(summary);
    }
}
/**
 * Computes the value of the <code>response</code> attribute for a Digest Authorization
 * header.
 * <p>
 * The result is the hex MD5 of
 * <code>hex(HA1):nonce:NC:cnonce:qop:hex(HA2)</code>, where HA1 is the MD5 of
 * <code>principal:realm:credentials</code> and HA2 is the MD5 of <code>method:uri</code>.
 * All strings are encoded with {@code StringUtil.__ISO_8859_1} before hashing.
 *
 * @param method HTTP method of the request being authorised
 * @param uri request URI, as also sent in the header's uri attribute
 * @param cnonce client nonce
 * @param principal user name
 * @param realm realm name
 * @param credentials password
 * @param nonce server nonce taken from the challenge
 * @param qop quality-of-protection value
 * @return the response digest as lowercase hex
 * @throws Exception if MD5 or the character encoding is unavailable
 */
protected String newResponse(String method, String uri, String cnonce, String principal, String realm, String credentials, String nonce, String qop)
    throws Exception
{
    final MessageDigest md5 = MessageDigest.getInstance("MD5");

    // HA1 and HA2, each rendered as base-16 text for inclusion in the final hash
    final String hexHa1 = TypeUtil.toString(
        digestColonJoined(md5,new String[]{ principal, realm, credentials }),16);
    final String hexHa2 = TypeUtil.toString(
        digestColonJoined(md5,new String[]{ method, uri }),16);

    final byte[] response = digestColonJoined(md5,
        new String[]{ hexHa1, nonce, NC, cnonce, qop, hexHa2 });
    return encode(response);
}

/**
 * Feeds the given parts to the digest with a single ':' byte between consecutive parts
 * and returns the resulting hash. Completing the digest leaves <code>md</code> reset and
 * ready for reuse.
 */
private static byte[] digestColonJoined(MessageDigest md, String[] parts) throws Exception
{
    for (int i=0; i<parts.length; i++)
    {
        if (i>0)
        {
            md.update((byte)':');
        }
        md.update(parts[i].getBytes(StringUtil.__ISO_8859_1));
    }
    return md.digest();
}
/**
 * Renders a byte array as lowercase hexadecimal, two characters per byte
 * (high nibble first).
 *
 * @param data bytes to render
 * @return hex text of length <code>2 * data.length</code>
 */
private static String encode(byte[] data)
{
    final char[] hexDigits = "0123456789abcdef".toCharArray();
    final char[] text = new char[data.length * 2];
    int out = 0;
    for (int in = 0; in < data.length; in++)
    {
        final int unsigned = data[in] & 0xff;
        text[out++] = hexDigits[unsigned >>> 4];
        text[out++] = hexDigits[unsigned & 0x0f];
    }
    return new String(text);
}
}