Package jcifs.http

Source Code of jcifs.http.NtlmHttpURLConnection$CacheStream

/* jcifs smb client library in Java
* Copyright (C) 2002  "Michael B. Allen" <jcifs at samba dot org>
*                   "Eric Glass" <jcifs at samba dot org>
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/

package jcifs.http;

import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.io.IOException;
import java.io.OutputStream;

import java.net.Authenticator;
import java.net.HttpURLConnection;
import java.net.PasswordAuthentication;
import java.net.ProtocolException;
import java.net.URL;
import java.net.URLDecoder;

import java.security.Permission;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import jcifs.Config;

import jcifs.ntlmssp.NtlmFlags;
import jcifs.ntlmssp.NtlmMessage;
import jcifs.ntlmssp.Type1Message;
import jcifs.ntlmssp.Type2Message;
import jcifs.ntlmssp.Type3Message;

import jcifs.util.Base64;

/**
* Wraps an <code>HttpURLConnection</code> to provide NTLM authentication
* services.
*
* Please read <a href="../../../httpclient.html">Using jCIFS NTLM Authentication for HTTP Connections</a>.
*/
public class NtlmHttpURLConnection extends HttpURLConnection {

    private static final int MAX_REDIRECTS =
            Integer.parseInt(System.getProperty("http.maxRedirects", "20"));

    private static final int LM_COMPATIBILITY =
            Config.getInt("jcifs.smb.lmCompatibility", 0);

    private static final String DEFAULT_DOMAIN;

    private HttpURLConnection connection;

    private Map requestProperties;

    private Map headerFields;

    private ByteArrayOutputStream cachedOutput;

    private String authProperty;

    private String authMethod;

    private boolean handshakeComplete;

    static {
        String domain = System.getProperty("http.auth.ntlm.domain");
        if (domain == null) domain = Type3Message.getDefaultDomain();
        DEFAULT_DOMAIN = domain;
    }

    public NtlmHttpURLConnection(HttpURLConnection connection) {
        super(connection.getURL());
        this.connection = connection;
        requestProperties = new HashMap();
    }

    public void connect() throws IOException {
        if (connected) return;
        connection.connect();
        connected = true;
    }

    private void handshake() throws IOException {
        if (handshakeComplete) return;
        doHandshake();
        handshakeComplete = true;
    }

    public URL getURL() {
        return connection.getURL();
    }

    public int getContentLength() {
        try {
            handshake();
        } catch (IOException ex) { }
        return connection.getContentLength();
    }

    public String getContentType() {
        try {
            handshake();
        } catch (IOException ex) { }
        return connection.getContentType();
    }

    public String getContentEncoding() {
        try {
            handshake();
        } catch (IOException ex) { }
        return connection.getContentEncoding();
    }

    public long getExpiration() {
        try {
            handshake();
        } catch (IOException ex) { }
        return connection.getExpiration();
    }

    public long getDate() {
        try {
            handshake();
        } catch (IOException ex) { }
        return connection.getDate();
    }

    public long getLastModified() {
        try {
            handshake();
        } catch (IOException ex) { }
        return connection.getLastModified();
    }

    public String getHeaderField(String header) {
        try {
            handshake();
        } catch (IOException ex) { }
        return connection.getHeaderField(header);
    }

    private Map getHeaderFields0() {
        if (headerFields != null) return headerFields;
        Map map = new HashMap();
        String key = connection.getHeaderFieldKey(0);
        String value = connection.getHeaderField(0);
        for (int i = 1; key != null || value != null; i++) {
            List values = (List) map.get(key);
            if (values == null) {
                values = new ArrayList();
                map.put(key, values);
            }
            values.add(value);
            key = connection.getHeaderFieldKey(i);
            value = connection.getHeaderField(i);
        }
        Iterator entries = map.entrySet().iterator();
        while (entries.hasNext()) {
            Map.Entry entry = (Map.Entry) entries.next();
            entry.setValue(Collections.unmodifiableList((List)
                    entry.getValue()));
        }
        return (headerFields = Collections.unmodifiableMap(map));
    }

    public Map getHeaderFields() {
        if (headerFields != null) return headerFields;
        try {
            handshake();
        } catch (IOException ex) { }
        return getHeaderFields0();
    }

    public int getHeaderFieldInt(String header, int def) {
        try {
            handshake();
        } catch (IOException ex) { }
        return connection.getHeaderFieldInt(header, def);
    }

    public long getHeaderFieldDate(String header, long def) {
        try {
            handshake();
        } catch (IOException ex) { }
        return connection.getHeaderFieldDate(header, def);
    }

    public String getHeaderFieldKey(int index) {
        try {
            handshake();
        } catch (IOException ex) { }
        return connection.getHeaderFieldKey(index);
    }

    public String getHeaderField(int index) {
        try {
            handshake();
        } catch (IOException ex) { }
        return connection.getHeaderField(index);
    }

    public Object getContent() throws IOException {
        try {
            handshake();
        } catch (IOException ex) { }
        return connection.getContent();
    }

    public Object getContent(Class[] classes) throws IOException {
        try {
            handshake();
        } catch (IOException ex) { }
        return connection.getContent(classes);
    }

    public Permission getPermission() throws IOException {
        return connection.getPermission();
    }

    public InputStream getInputStream() throws IOException {
        try {
            handshake();
        } catch (IOException ex) { }
        return connection.getInputStream();
    }

    public OutputStream getOutputStream() throws IOException {
        try {
            connect();
        } catch (IOException ex) { }
        OutputStream output = connection.getOutputStream();
        cachedOutput = new ByteArrayOutputStream();
        return new CacheStream(output, cachedOutput);
    }

    public String toString() {
        return connection.toString();
    }

    public void setDoInput(boolean doInput) {
        connection.setDoInput(doInput);
        this.doInput = doInput;
    }

    public boolean getDoInput() {
        return connection.getDoInput();
    }

    public void setDoOutput(boolean doOutput) {
        connection.setDoOutput(doOutput);
        this.doOutput = doOutput;
    }

    public boolean getDoOutput() {
        return connection.getDoOutput();
    }

    public void setAllowUserInteraction(boolean allowUserInteraction) {
        connection.setAllowUserInteraction(allowUserInteraction);
        this.allowUserInteraction = allowUserInteraction;
    }

    public boolean getAllowUserInteraction() {
        return connection.getAllowUserInteraction();
    }

    public void setUseCaches(boolean useCaches) {
        connection.setUseCaches(useCaches);
        this.useCaches = useCaches;
    }

    public boolean getUseCaches() {
        return connection.getUseCaches();
    }

    public void setIfModifiedSince(long ifModifiedSince) {
        connection.setIfModifiedSince(ifModifiedSince);
        this.ifModifiedSince = ifModifiedSince;
    }

    public long getIfModifiedSince() {
        return connection.getIfModifiedSince();
    }

    public boolean getDefaultUseCaches() {
        return connection.getDefaultUseCaches();
    }

    public void setDefaultUseCaches(boolean defaultUseCaches) {
        connection.setDefaultUseCaches(defaultUseCaches);
    }

    public void setRequestProperty(String key, String value) {
        if (key == null) throw new NullPointerException();
        List values = new ArrayList();
        values.add(value);
        boolean found = false;
        Iterator entries = requestProperties.entrySet().iterator();
        while (entries.hasNext()) {
            Map.Entry entry = (Map.Entry) entries.next();
            if (key.equalsIgnoreCase((String) entry.getKey())) {
                entry.setValue(values);
                found = true;
                break;
            }
        }
        if (!found) requestProperties.put(key, values);
        connection.setRequestProperty(key, value);
    }

    public void addRequestProperty(String key, String value) {
        if (key == null) throw new NullPointerException();
        List values = null;
        Iterator entries = requestProperties.entrySet().iterator();
        while (entries.hasNext()) {
            Map.Entry entry = (Map.Entry) entries.next();
            if (key.equalsIgnoreCase((String) entry.getKey())) {
                values = (List) entry.getValue();
                values.add(value);
                break;
            }
        }
        if (values == null) {
            values = new ArrayList();
            values.add(value);
            requestProperties.put(key, values);
        }
        // 1.3-compatible.
        StringBuffer buffer = new StringBuffer();
        Iterator propertyValues = values.iterator();
        while (propertyValues.hasNext()) {
            buffer.append(propertyValues.next());
            if (propertyValues.hasNext()) {
                buffer.append(", ");
            }
        }
        connection.setRequestProperty(key, buffer.toString());
    }

    public String getRequestProperty(String key) {
        return connection.getRequestProperty(key);
    }

    public Map getRequestProperties() {
        Map map = new HashMap();
        Iterator entries = requestProperties.entrySet().iterator();
        while (entries.hasNext()) {
            Map.Entry entry = (Map.Entry) entries.next();
            map.put(entry.getKey(),
                    Collections.unmodifiableList((List) entry.getValue()));
        }
        return Collections.unmodifiableMap(map);
    }

    public void setInstanceFollowRedirects(boolean instanceFollowRedirects) {
        connection.setInstanceFollowRedirects(instanceFollowRedirects);
    }

    public boolean getInstanceFollowRedirects() {
        return connection.getInstanceFollowRedirects();
    }

    public void setRequestMethod(String requestMethod)
            throws ProtocolException {
        connection.setRequestMethod(requestMethod);
        this.method = requestMethod;
    }

    public String getRequestMethod() {
        return connection.getRequestMethod();
    }

    public int getResponseCode() throws IOException {
        try {
            handshake();
        } catch (IOException ex) { }
        return connection.getResponseCode();
    }

    public String getResponseMessage() throws IOException {
        try {
            handshake();
        } catch (IOException ex) { }
        return connection.getResponseMessage();
    }

    public void disconnect() {
        connection.disconnect();
        handshakeComplete = false;
        connected = false;
    }

    public boolean usingProxy() {
        return connection.usingProxy();
    }

    public InputStream getErrorStream() {
        try {
            handshake();
        } catch (IOException ex) { }
        return connection.getErrorStream();
    }

    private int parseResponseCode() throws IOException {
        try {
            String response = connection.getHeaderField(0);
            int index = response.indexOf(' ');
            while (response.charAt(index) == ' ') index++;
            return Integer.parseInt(response.substring(index, index + 3));
        } catch (Exception ex) {
            throw new IOException(ex.getMessage());
        }
    }

    private void doHandshake() throws IOException {
        connect();
        try {
            int response = parseResponseCode();
            if (response != HTTP_UNAUTHORIZED && response != HTTP_PROXY_AUTH) {
                return;
            }
            Type1Message type1 = (Type1Message) attemptNegotiation(response);
            if (type1 == null) return; // no NTLM
            int attempt = 0;
            while (attempt < MAX_REDIRECTS) {
                connection.setRequestProperty(authProperty, authMethod + ' ' +
                        Base64.encode(type1.toByteArray()));
                connection.connect(); // send type 1
                response = parseResponseCode();
                if (response != HTTP_UNAUTHORIZED &&
                        response != HTTP_PROXY_AUTH) {
                    return;
                }
                Type3Message type3 = (Type3Message)
                        attemptNegotiation(response);
                if (type3 == null) return;
                connection.setRequestProperty(authProperty, authMethod + ' ' +
                        Base64.encode(type3.toByteArray()));
                connection.connect(); // send type 3
                if (cachedOutput != null && doOutput) {
                    OutputStream output = connection.getOutputStream();
                    cachedOutput.writeTo(output);
                    output.flush();
                }
                response = parseResponseCode();
                if (response != HTTP_UNAUTHORIZED &&
                        response != HTTP_PROXY_AUTH) {
                    return;
                }
                attempt++;
                if (allowUserInteraction && attempt < MAX_REDIRECTS) {
                    reconnect();
                } else {
                    break;
                }
            }
            throw new IOException("Unable to negotiate NTLM authentication.");
        } finally {
            cachedOutput = null;
        }
    }

    private NtlmMessage attemptNegotiation(int response) throws IOException {
        authProperty = null;
        authMethod = null;
        InputStream errorStream = connection.getErrorStream();
        if (errorStream != null && errorStream.available() != 0) {
            int count;
            byte[] buf = new byte[1024];
            while ((count = errorStream.read(buf, 0, 1024)) != -1);
        }
        String authHeader;
        if (response == HTTP_UNAUTHORIZED) {
            authHeader = "WWW-Authenticate";
            authProperty = "Authorization";
        } else {
            authHeader = "Proxy-Authenticate";
            authProperty = "Proxy-Authorization";
        }
        String authorization = null;
        List methods = (List) getHeaderFields0().get(authHeader);
        if (methods == null) return null;
        Iterator iterator = methods.iterator();
        while (iterator.hasNext()) {
            String currentAuthMethod = (String) iterator.next();
            if (currentAuthMethod.startsWith("NTLM")) {
                if (currentAuthMethod.length() == 4) {
                    authMethod = "NTLM";
                    break;
                }
                if (currentAuthMethod.indexOf(' ') != 4) continue;
                authMethod = "NTLM";
                authorization = currentAuthMethod.substring(5).trim();
                break;
            } else if (currentAuthMethod.startsWith("Negotiate")) {
                if (currentAuthMethod.length() == 9) {
                    authMethod = "Negotiate";
                    break;
                }
                if (currentAuthMethod.indexOf(' ') != 9) continue;
                authMethod = "Negotiate";
                authorization = currentAuthMethod.substring(10).trim();
                break;
            }
        }
        if (authMethod == null) return null;
        NtlmMessage message = (authorization != null) ?
                new Type2Message(Base64.decode(authorization)) : null;
        reconnect();
        if (message == null) {
            message = new Type1Message();
            if (LM_COMPATIBILITY > 2) {
                message.setFlag(NtlmFlags.NTLMSSP_REQUEST_TARGET, true);
            }
        } else {
            String domain = DEFAULT_DOMAIN;
            String user = Type3Message.getDefaultUser();
            String password = Type3Message.getDefaultPassword();
            String userInfo = url.getUserInfo();
            if (userInfo != null) {
                userInfo = URLDecoder.decode(userInfo);
                int index = userInfo.indexOf(':');
                user = (index != -1) ? userInfo.substring(0, index) : userInfo;
                if (index != -1) password = userInfo.substring(index + 1);
                index = user.indexOf('\\');
                if (index == -1) index = user.indexOf('/');
                domain = (index != -1) ? user.substring(0, index) : domain;
                user = (index != -1) ? user.substring(index + 1) : user;
            }
            if (user == null) {
                if (!allowUserInteraction) return null;
                try {
                    URL url = getURL();
                    String protocol = url.getProtocol();
                    int port = url.getPort();
                    if (port == -1) {
                        port = "https".equalsIgnoreCase(protocol) ? 443 : 80;
                    }
                    PasswordAuthentication auth =
                            Authenticator.requestPasswordAuthentication(null,
                                    port, protocol, "", authMethod);
                    if (auth == null) return null;
                    user = auth.getUserName();
                    password = new String(auth.getPassword());
                } catch (Exception ex) { }
            }
            Type2Message type2 = (Type2Message) message;
            message = new Type3Message(type2, password, domain, user,
                    Type3Message.getDefaultWorkstation(), 0);
        }
        return message;
    }

    private void reconnect() throws IOException {
        connection = (HttpURLConnection) connection.getURL().openConnection();
        connection.setRequestMethod(method);
        headerFields = null;
        Iterator properties = requestProperties.entrySet().iterator();
        while (properties.hasNext()) {
            Map.Entry property = (Map.Entry) properties.next();
            String key = (String) property.getKey();
            StringBuffer value = new StringBuffer();
            Iterator values = ((List) property.getValue()).iterator();
            while (values.hasNext()) {
                value.append(values.next());
                if (values.hasNext()) value.append(", ");
            }
            connection.setRequestProperty(key, value.toString());
        }
        connection.setAllowUserInteraction(allowUserInteraction);
        connection.setDoInput(doInput);
        connection.setDoOutput(doOutput);
        connection.setIfModifiedSince(ifModifiedSince);
        connection.setUseCaches(useCaches);
    }

    private static class CacheStream extends OutputStream {

        private final OutputStream stream;

        private final OutputStream collector;

        public CacheStream(OutputStream stream, OutputStream collector) {
            this.stream = stream;
            this.collector = collector;
        }

        public void close() throws IOException {
            stream.close();
            collector.close();
        }

        public void flush() throws IOException {
            stream.flush();
            collector.flush();
        }

        public void write(byte[] b) throws IOException {
            stream.write(b);
            collector.write(b);
        }

        public void write(byte[] b, int off, int len) throws IOException {
            stream.write(b, off, len);
            collector.write(b, off, len);
        }

        public void write(int b) throws IOException {
            stream.write(b);
            collector.write(b);
        }

    }

}
TOP

Related Classes of jcifs.http.NtlmHttpURLConnection$CacheStream

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.