/* jcifs smb client library in Java
* Copyright (C) 2002 "Michael B. Allen" <jcifs at samba dot org>
* "Eric Glass" <jcifs at samba dot org>
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*/
package jcifs.http;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.net.Authenticator;
import java.net.HttpURLConnection;
import java.net.PasswordAuthentication;
import java.net.ProtocolException;
import java.net.URL;
import java.net.URLDecoder;
import java.security.Permission;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import jcifs.Config;
import jcifs.ntlmssp.NtlmFlags;
import jcifs.ntlmssp.NtlmMessage;
import jcifs.ntlmssp.Type1Message;
import jcifs.ntlmssp.Type2Message;
import jcifs.ntlmssp.Type3Message;
import jcifs.util.Base64;
/**
* Wraps an <code>HttpURLConnection</code> to provide NTLM authentication
* services.
*
* Please read <a href="../../../httpclient.html">Using jCIFS NTLM Authentication for HTTP Connections</a>.
*/
public class NtlmHttpURLConnection extends HttpURLConnection {

    /** Upper bound on NTLM negotiation rounds, read from {@code http.maxRedirects} (default 20). */
    private static final int MAX_REDIRECTS =
            Integer.parseInt(System.getProperty("http.maxRedirects", "20"));

    /** {@code jcifs.smb.lmCompatibility}; values above 2 make the Type 1 message request the target. */
    private static final int LM_COMPATIBILITY =
            Config.getInt("jcifs.smb.lmCompatibility", 0);

    /**
     * Domain used when the URL user info names none: {@code http.auth.ntlm.domain},
     * falling back to {@link Type3Message#getDefaultDomain()}.
     */
    private static final String DEFAULT_DOMAIN;

    static {
        String domain = System.getProperty("http.auth.ntlm.domain");
        if (domain == null) domain = Type3Message.getDefaultDomain();
        DEFAULT_DOMAIN = domain;
    }

    /** The wrapped connection; replaced by {@link #reconnect()} for each negotiation leg. */
    private HttpURLConnection connection;

    /** Caller-supplied request headers (String key to List of String values), replayed on reconnect. */
    private Map requestProperties;

    /** Cached unmodifiable header map of the current wrapped connection; cleared on reconnect. */
    private Map headerFields;

    /** Copy of the request body written through {@link #getOutputStream()}, replayed with Type 3. */
    private ByteArrayOutputStream cachedOutput;

    /** Request header that carries the token: "Authorization" or "Proxy-Authorization". */
    private String authProperty;

    /** Scheme name sent with the token: "NTLM" or "Negotiate". */
    private String authMethod;

    /** Set once {@link #doHandshake()} has returned normally. */
    private boolean handshakeComplete;

    /**
     * Wraps {@code connection} so that NTLM authentication is negotiated on demand.
     *
     * <p>The wrapper's request settings are initialised from {@code connection},
     * so settings applied to it before wrapping (request method, doOutput, ...)
     * survive the reconnects performed during negotiation.
     *
     * @param connection the connection to wrap; must not be {@code null}
     */
    public NtlmHttpURLConnection(HttpURLConnection connection) {
        super(connection.getURL());
        this.connection = connection;
        requestProperties = new HashMap();
        // reconnect() rebuilds fresh connections from these inherited fields;
        // leaving them at their defaults would drop e.g. a POST set before wrapping.
        method = connection.getRequestMethod();
        doInput = connection.getDoInput();
        doOutput = connection.getDoOutput();
        allowUserInteraction = connection.getAllowUserInteraction();
        useCaches = connection.getUseCaches();
        ifModifiedSince = connection.getIfModifiedSince();
        instanceFollowRedirects = connection.getInstanceFollowRedirects();
    }

    /** Connects the wrapped connection unless this wrapper is already connected. */
    public void connect() throws IOException {
        if (connected) return;
        connection.connect();
        connected = true;
    }

    /** Runs the handshake once; if it throws, the next call will try again. */
    private void handshake() throws IOException {
        if (handshakeComplete) return;
        doHandshake();
        handshakeComplete = true;
    }

    /**
     * Runs the handshake, ignoring any {@link IOException}. Used by accessors;
     * after a failure they simply report the wrapped connection's current state.
     */
    private void handshakeQuietly() {
        try {
            handshake();
        } catch (IOException ignored) {
            // Deliberate best effort: the accessor falls through to the wrapped connection.
        }
    }

    /** Returns the URL of the wrapped connection. */
    public URL getURL() {
        return connection.getURL();
    }

    // ---- Response accessors: each completes the handshake first (best effort). ----

    public int getContentLength() {
        handshakeQuietly();
        return connection.getContentLength();
    }

    public String getContentType() {
        handshakeQuietly();
        return connection.getContentType();
    }

    public String getContentEncoding() {
        handshakeQuietly();
        return connection.getContentEncoding();
    }

    public long getExpiration() {
        handshakeQuietly();
        return connection.getExpiration();
    }

    public long getDate() {
        handshakeQuietly();
        return connection.getDate();
    }

    public long getLastModified() {
        handshakeQuietly();
        return connection.getLastModified();
    }

    public String getHeaderField(String header) {
        handshakeQuietly();
        return connection.getHeaderField(header);
    }

    /**
     * Builds and caches an unmodifiable map from header key to an unmodifiable
     * list of values, read from the current wrapped connection by index starting
     * at 0 until both key and value are {@code null}. Does not run the handshake.
     */
    private Map getHeaderFields0() {
        if (headerFields != null) return headerFields;
        Map map = new HashMap();
        String key = connection.getHeaderFieldKey(0);
        String value = connection.getHeaderField(0);
        for (int i = 1; key != null || value != null; i++) {
            List values = (List) map.get(key);
            if (values == null) {
                values = new ArrayList();
                map.put(key, values);
            }
            values.add(value);
            key = connection.getHeaderFieldKey(i);
            value = connection.getHeaderField(i);
        }
        Iterator entries = map.entrySet().iterator();
        while (entries.hasNext()) {
            Map.Entry entry = (Map.Entry) entries.next();
            entry.setValue(Collections.unmodifiableList((List) entry.getValue()));
        }
        return (headerFields = Collections.unmodifiableMap(map));
    }

    /** Returns the response headers of the (authenticated) connection. */
    public Map getHeaderFields() {
        if (headerFields != null) return headerFields;
        handshakeQuietly();
        return getHeaderFields0();
    }

    public int getHeaderFieldInt(String header, int def) {
        handshakeQuietly();
        return connection.getHeaderFieldInt(header, def);
    }

    public long getHeaderFieldDate(String header, long def) {
        handshakeQuietly();
        return connection.getHeaderFieldDate(header, def);
    }

    public String getHeaderFieldKey(int index) {
        handshakeQuietly();
        return connection.getHeaderFieldKey(index);
    }

    public String getHeaderField(int index) {
        handshakeQuietly();
        return connection.getHeaderField(index);
    }

    public Object getContent() throws IOException {
        handshakeQuietly();
        return connection.getContent();
    }

    public Object getContent(Class[] classes) throws IOException {
        handshakeQuietly();
        return connection.getContent(classes);
    }

    /** Returns the wrapped connection's permission; does not trigger the handshake. */
    public Permission getPermission() throws IOException {
        return connection.getPermission();
    }

    public InputStream getInputStream() throws IOException {
        handshakeQuietly();
        return connection.getInputStream();
    }

    /**
     * Returns a stream that writes to the wrapped connection and also records
     * the bytes, so the body can be re-sent with the Type 3 message.
     */
    public OutputStream getOutputStream() throws IOException {
        try {
            connect();
        } catch (IOException ignored) {
            // Any real failure resurfaces from connection.getOutputStream() below.
        }
        OutputStream output = connection.getOutputStream();
        cachedOutput = new ByteArrayOutputStream();
        return new CacheStream(output, cachedOutput);
    }

    public String toString() {
        return connection.toString();
    }

    // ---- Request settings: applied to the wrapped connection and remembered
    //      in the inherited fields so reconnect() can re-apply them. ----

    public void setDoInput(boolean doInput) {
        connection.setDoInput(doInput);
        this.doInput = doInput;
    }

    public boolean getDoInput() {
        return connection.getDoInput();
    }

    public void setDoOutput(boolean doOutput) {
        connection.setDoOutput(doOutput);
        this.doOutput = doOutput;
    }

    public boolean getDoOutput() {
        return connection.getDoOutput();
    }

    public void setAllowUserInteraction(boolean allowUserInteraction) {
        connection.setAllowUserInteraction(allowUserInteraction);
        this.allowUserInteraction = allowUserInteraction;
    }

    public boolean getAllowUserInteraction() {
        return connection.getAllowUserInteraction();
    }

    public void setUseCaches(boolean useCaches) {
        connection.setUseCaches(useCaches);
        this.useCaches = useCaches;
    }

    public boolean getUseCaches() {
        return connection.getUseCaches();
    }

    public void setIfModifiedSince(long ifModifiedSince) {
        connection.setIfModifiedSince(ifModifiedSince);
        this.ifModifiedSince = ifModifiedSince;
    }

    public long getIfModifiedSince() {
        return connection.getIfModifiedSince();
    }

    public boolean getDefaultUseCaches() {
        return connection.getDefaultUseCaches();
    }

    public void setDefaultUseCaches(boolean defaultUseCaches) {
        connection.setDefaultUseCaches(defaultUseCaches);
    }

    /**
     * Returns the stored request-property entry whose key equals {@code key}
     * ignoring case, or {@code null} if there is none.
     */
    private Map.Entry findRequestProperty(String key) {
        Iterator entries = requestProperties.entrySet().iterator();
        while (entries.hasNext()) {
            Map.Entry entry = (Map.Entry) entries.next();
            if (key.equalsIgnoreCase((String) entry.getKey())) return entry;
        }
        return null;
    }

    /** Joins header values with ", " so multi-valued headers fit a single setRequestProperty call. */
    private static String joinValues(List values) {
        StringBuffer buffer = new StringBuffer();
        Iterator iterator = values.iterator();
        while (iterator.hasNext()) {
            buffer.append(iterator.next());
            if (iterator.hasNext()) buffer.append(", ");
        }
        return buffer.toString();
    }

    /**
     * Replaces all values of a request header (case-insensitive key match).
     *
     * @throws NullPointerException if {@code key} is {@code null}
     */
    public void setRequestProperty(String key, String value) {
        if (key == null) throw new NullPointerException();
        List values = new ArrayList();
        values.add(value);
        Map.Entry existing = findRequestProperty(key);
        if (existing != null) {
            existing.setValue(values);
        } else {
            requestProperties.put(key, values);
        }
        connection.setRequestProperty(key, value);
    }

    /**
     * Appends a value to a request header (case-insensitive key match).
     *
     * @throws NullPointerException if {@code key} is {@code null}
     */
    public void addRequestProperty(String key, String value) {
        if (key == null) throw new NullPointerException();
        Map.Entry existing = findRequestProperty(key);
        List values;
        if (existing != null) {
            values = (List) existing.getValue();
        } else {
            values = new ArrayList();
            requestProperties.put(key, values);
        }
        values.add(value);
        // 1.3-compatible: send all values joined, via setRequestProperty.
        connection.setRequestProperty(key, joinValues(values));
    }

    public String getRequestProperty(String key) {
        return connection.getRequestProperty(key);
    }

    /**
     * Returns an unmodifiable snapshot of the caller-set request headers. Each
     * value list is copied, so later {@link #addRequestProperty} calls do not
     * change a previously returned map.
     */
    public Map getRequestProperties() {
        Map map = new HashMap();
        Iterator entries = requestProperties.entrySet().iterator();
        while (entries.hasNext()) {
            Map.Entry entry = (Map.Entry) entries.next();
            map.put(entry.getKey(), Collections.unmodifiableList(
                    new ArrayList((List) entry.getValue())));
        }
        return Collections.unmodifiableMap(map);
    }

    public void setInstanceFollowRedirects(boolean instanceFollowRedirects) {
        connection.setInstanceFollowRedirects(instanceFollowRedirects);
        // Remembered so reconnect() can re-apply it to replacement connections.
        this.instanceFollowRedirects = instanceFollowRedirects;
    }

    public boolean getInstanceFollowRedirects() {
        return connection.getInstanceFollowRedirects();
    }

    public void setRequestMethod(String requestMethod)
            throws ProtocolException {
        connection.setRequestMethod(requestMethod);
        this.method = requestMethod;
    }

    public String getRequestMethod() {
        return connection.getRequestMethod();
    }

    public int getResponseCode() throws IOException {
        handshakeQuietly();
        return connection.getResponseCode();
    }

    public String getResponseMessage() throws IOException {
        handshakeQuietly();
        return connection.getResponseMessage();
    }

    /** Disconnects the wrapped connection and resets the handshake state. */
    public void disconnect() {
        connection.disconnect();
        handshakeComplete = false;
        connected = false;
    }

    public boolean usingProxy() {
        return connection.usingProxy();
    }

    public InputStream getErrorStream() {
        handshakeQuietly();
        return connection.getErrorStream();
    }

    /**
     * Parses the three-digit status code that follows the first run of spaces
     * in the wrapped connection's header field 0 (e.g. "HTTP/1.1 401 ..." gives 401).
     *
     * @throws IOException if the status line is missing or malformed; the
     *         underlying parse failure is attached as the cause
     */
    private int parseResponseCode() throws IOException {
        String statusLine = null;
        try {
            statusLine = connection.getHeaderField(0);
            int index = statusLine.indexOf(' ');
            while (statusLine.charAt(index) == ' ') index++;
            return Integer.parseInt(statusLine.substring(index, index + 3));
        } catch (RuntimeException ex) {
            // A null status line (NPE), no code after the protocol
            // (StringIndexOutOfBounds) or a non-numeric code (NumberFormat).
            IOException failure =
                    new IOException("Unable to parse HTTP status line: " + statusLine);
            failure.initCause(ex);
            throw failure;
        }
    }

    /**
     * Performs the NTLM exchange when the initial response is 401 or 407 and
     * the server offers NTLM or Negotiate. It sends a Type 1 message, answers
     * the Type 2 challenge with a Type 3 message, and replays any cached
     * request body. It returns as soon as a response other than 401/407 arrives,
     * or when no NTLM message can be produced.
     *
     * @throws IOException if authentication is still refused after the allowed
     *         attempts (one attempt unless user interaction is allowed), or if
     *         the server sends an unexpected NTLM message
     */
    private void doHandshake() throws IOException {
        connect();
        try {
            int response = parseResponseCode();
            if (response != HTTP_UNAUTHORIZED && response != HTTP_PROXY_AUTH) {
                return;
            }
            NtlmMessage message = attemptNegotiation(response);
            if (message == null) return; // no NTLM
            if (!(message instanceof Type1Message)) {
                // The first challenge already carried a token; this code path
                // cannot start from a Type 2 message.
                throw new IOException(
                        "Unexpected NTLM challenge in initial response.");
            }
            Type1Message type1 = (Type1Message) message;
            int attempt = 0;
            while (attempt < MAX_REDIRECTS) {
                connection.setRequestProperty(authProperty, authMethod + ' ' +
                        Base64.encode(type1.toByteArray()));
                connection.connect(); // send type 1
                response = parseResponseCode();
                if (response != HTTP_UNAUTHORIZED &&
                        response != HTTP_PROXY_AUTH) {
                    return;
                }
                message = attemptNegotiation(response);
                if (message == null) return;
                if (!(message instanceof Type3Message)) {
                    // The server answered Type 1 without a Type 2 challenge.
                    throw new IOException(
                            "Server did not send an NTLM Type 2 challenge.");
                }
                Type3Message type3 = (Type3Message) message;
                connection.setRequestProperty(authProperty, authMethod + ' ' +
                        Base64.encode(type3.toByteArray()));
                connection.connect(); // send type 3
                if (cachedOutput != null && doOutput) {
                    // Re-send the body the caller already wrote on the first connection.
                    OutputStream output = connection.getOutputStream();
                    cachedOutput.writeTo(output);
                    output.flush();
                }
                response = parseResponseCode();
                if (response != HTTP_UNAUTHORIZED &&
                        response != HTTP_PROXY_AUTH) {
                    return;
                }
                attempt++;
                if (allowUserInteraction && attempt < MAX_REDIRECTS) {
                    reconnect();
                } else {
                    break;
                }
            }
            throw new IOException("Unable to negotiate NTLM authentication.");
        } finally {
            cachedOutput = null;
        }
    }

    /**
     * Reads the error stream of the current connection (only if bytes are
     * already available) and then closes it, so the old connection's resources
     * are released before {@link #reconnect()} replaces it.
     */
    private void drainErrorStream() throws IOException {
        InputStream errorStream = connection.getErrorStream();
        if (errorStream == null) return;
        try {
            if (errorStream.available() != 0) {
                byte[] buf = new byte[1024];
                while (errorStream.read(buf, 0, buf.length) != -1) {
                    // discard
                }
            }
        } finally {
            errorStream.close();
        }
    }

    /**
     * Matches an authentication challenge against {@code scheme}, comparing the
     * scheme case-insensitively.
     *
     * @return {@code null} if the challenge is not for {@code scheme}; an empty
     *         string if it carries no token; otherwise the trimmed token text
     */
    private static String challengeToken(String challenge, String scheme) {
        int length = scheme.length();
        if (challenge == null
                || !challenge.regionMatches(true, 0, scheme, 0, length)) {
            return null;
        }
        if (challenge.length() == length) return "";
        if (challenge.charAt(length) != ' ') return null; // e.g. "NTLMv2"
        return challenge.substring(length + 1).trim();
    }

    /**
     * Inspects the 401/407 response for an NTLM or Negotiate challenge and
     * builds the next message to send, replacing the wrapped connection with a
     * fresh one along the way.
     *
     * <p>Sets {@link #authProperty} and {@link #authMethod} as a side effect.
     * With no token it returns a Type 1 message. With a token it treats the
     * token as a Type 2 challenge and returns a Type 3 response, using
     * credentials from the URL user info, the jCIFS defaults, or (when user
     * interaction is allowed) the installed {@link Authenticator}.
     *
     * @param response the status code, 401 or 407
     * @return the message to send, or {@code null} if NTLM is not offered or
     *         no user is available
     */
    private NtlmMessage attemptNegotiation(int response) throws IOException {
        authProperty = null;
        authMethod = null;
        drainErrorStream();
        String authHeader;
        if (response == HTTP_UNAUTHORIZED) {
            authHeader = "WWW-Authenticate";
            authProperty = "Authorization";
        } else {
            authHeader = "Proxy-Authenticate";
            authProperty = "Proxy-Authorization";
        }
        String authorization = null;
        List methods = (List) getHeaderFields0().get(authHeader);
        if (methods == null) return null;
        Iterator iterator = methods.iterator();
        while (iterator.hasNext()) {
            String challenge = (String) iterator.next();
            String token = challengeToken(challenge, "NTLM");
            String scheme = "NTLM";
            if (token == null) {
                token = challengeToken(challenge, "Negotiate");
                scheme = "Negotiate";
            }
            if (token == null) continue;
            authMethod = scheme;
            // An empty token means the server is starting the exchange (send Type 1).
            authorization = (token.length() != 0) ? token : null;
            break;
        }
        if (authMethod == null) return null;
        NtlmMessage message = (authorization != null) ?
                new Type2Message(Base64.decode(authorization)) : null;
        reconnect();
        if (message == null) {
            message = new Type1Message();
            if (LM_COMPATIBILITY > 2) {
                message.setFlag(NtlmFlags.NTLMSSP_REQUEST_TARGET, true);
            }
        } else {
            String domain = DEFAULT_DOMAIN;
            String user = Type3Message.getDefaultUser();
            String password = Type3Message.getDefaultPassword();
            String userInfo = url.getUserInfo();
            if (userInfo != null) {
                // User info has the form [domain\|domain/]user[:password].
                userInfo = URLDecoder.decode(userInfo);
                int index = userInfo.indexOf(':');
                user = (index != -1) ? userInfo.substring(0, index) : userInfo;
                if (index != -1) password = userInfo.substring(index + 1);
                index = user.indexOf('\\');
                if (index == -1) index = user.indexOf('/');
                domain = (index != -1) ? user.substring(0, index) : domain;
                user = (index != -1) ? user.substring(index + 1) : user;
            }
            if (user == null) {
                if (!allowUserInteraction) return null;
                try {
                    URL target = getURL();
                    String protocol = target.getProtocol();
                    int port = target.getPort();
                    if (port == -1) {
                        port = "https".equalsIgnoreCase(protocol) ? 443 : 80;
                    }
                    PasswordAuthentication auth =
                            Authenticator.requestPasswordAuthentication(null,
                                    port, protocol, "", authMethod);
                    if (auth == null) return null;
                    user = auth.getUserName();
                    password = new String(auth.getPassword());
                } catch (Exception ignored) {
                    // Best effort: continue with whatever credentials are set.
                }
            }
            Type2Message type2 = (Type2Message) message;
            message = new Type3Message(type2, password, domain, user,
                    Type3Message.getDefaultWorkstation(), 0);
        }
        return message;
    }

    /**
     * Replaces the wrapped connection with a new, unconnected one to the same
     * URL. It re-applies the remembered request method, caller-set headers and
     * request settings, and invalidates the cached header map. Authorization
     * headers set during negotiation are not carried over.
     */
    private void reconnect() throws IOException {
        connection = (HttpURLConnection) connection.getURL().openConnection();
        connection.setRequestMethod(method);
        headerFields = null;
        Iterator properties = requestProperties.entrySet().iterator();
        while (properties.hasNext()) {
            Map.Entry property = (Map.Entry) properties.next();
            connection.setRequestProperty((String) property.getKey(),
                    joinValues((List) property.getValue()));
        }
        connection.setAllowUserInteraction(allowUserInteraction);
        connection.setDoInput(doInput);
        connection.setDoOutput(doOutput);
        connection.setIfModifiedSince(ifModifiedSince);
        connection.setUseCaches(useCaches);
        connection.setInstanceFollowRedirects(instanceFollowRedirects);
    }

    /**
     * Output stream that forwards every write, flush and close to both the
     * real request stream and a collector. The collector keeps a copy of the
     * body for replay.
     */
    private static class CacheStream extends OutputStream {

        private final OutputStream stream;

        private final OutputStream collector;

        public CacheStream(OutputStream stream, OutputStream collector) {
            this.stream = stream;
            this.collector = collector;
        }

        public void close() throws IOException {
            stream.close();
            collector.close();
        }

        public void flush() throws IOException {
            stream.flush();
            collector.flush();
        }

        public void write(byte[] b) throws IOException {
            stream.write(b);
            collector.write(b);
        }

        public void write(byte[] b, int off, int len) throws IOException {
            stream.write(b, off, len);
            collector.write(b, off, len);
        }

        public void write(int b) throws IOException {
            stream.write(b);
            collector.write(b);
        }
    }
}