package play.libs;
import java.net.URI;
import java.net.URLEncoder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.lang.StringUtils;
import org.w3c.dom.Document;
import play.Logger;
import play.exceptions.PlayException;
import play.libs.WS.HttpResponse;
import play.mvc.Router;
import play.mvc.Http.Request;
import play.mvc.Scope.Params;
import play.mvc.results.Redirect;
public class OpenID {
private OpenID(String id) {
this.id = id;
this.returnAction = this.realmAction = Request.current().action;
}
// ~~~ API
String id;
String returnAction;
String realmAction;
List<String> sregRequired = new ArrayList<String>();
List<String> sregOptional = new ArrayList<String>();
Map<String, String> axRequired = new HashMap<String, String>();
Map<String, String> axOptional = new HashMap<String, String>();
public OpenID returnTo(String action) {
this.returnAction = action;
return this;
}
public OpenID forRealm(String action) {
this.realmAction = action;
return this;
}
public OpenID required(String alias) {
this.sregRequired.add(alias);
return this;
}
public OpenID required(String alias, String schema) {
this.axRequired.put(alias, schema);
return this;
}
public OpenID optional(String alias) {
this.sregOptional.add(alias);
return this;
}
public OpenID optional(String alias, String schema) {
this.axOptional.put(alias, schema);
return this;
}
public boolean verify() {
try {
// Normalize
String claimedId = normalize(id);
String server = null;
String delegate = null;
// Discover
HttpResponse response = WS.url(claimedId).get();
// Try HTML (I know it's bad)
String html = response.getString();
server = discoverServer(html);
if (server == null) {
// Try YADIS
Document xrds = null;
if (response.getContentType().contains("application/xrds+xml")) {
xrds = response.getXml();
} else if (response.getHeader("X-XRDS-Location") != null) {
xrds = WS.url(response.getHeader("X-XRDS-Location")).get().getXml();
} else {
return false;
}
// Ok we have the XRDS file
server = XPath.selectText("//Type[text()='http://specs.openid.net/auth/2.0/server']/following-sibling::URI/text()", xrds);
claimedId = XPath.selectText("//Type[text()='http://specs.openid.net/auth/2.0/signon']/following-sibling::LocalID/text()", xrds);
if (claimedId == null) {
claimedId = "http://specs.openid.net/auth/2.0/identifier_select";
} else {
server = XPath.selectText("//Type[text()='http://specs.openid.net/auth/2.0/signon']/following-sibling::URI/text()", xrds);
}
if (server == null) {
return false;
}
} else {
// Delegate
Matcher openid2Localid = Pattern.compile("<link[^>]+openid2[.]local_id[^>]+>", Pattern.CASE_INSENSITIVE).matcher(html);
Matcher openidDelegate = Pattern.compile("<link[^>]+openid[.]delegate[^>]+>", Pattern.CASE_INSENSITIVE).matcher(html);
if (openid2Localid.find()) {
delegate = extractHref(openid2Localid.group());
} else if (openidDelegate.find()) {
delegate = extractHref(openidDelegate.group());
}
}
// Redirect
String url = server;
if (!server.contains("?")) {
url += "?";
}
if (!url.endsWith("?") && !url.endsWith("&")) {
url += "&";
}
url += "openid.ns=" + URLEncoder.encode("http://specs.openid.net/auth/2.0", "UTF-8");
url += "&openid.mode=checkid_setup";
url += "&openid.claimed_id=" + URLEncoder.encode(claimedId, "utf8");
url += "&openid.identity=" + URLEncoder.encode(delegate == null ? claimedId : delegate, "utf8");
if (returnAction != null && returnAction.startsWith("http://")) {
url += "&openid.return_to=" + URLEncoder.encode(returnAction, "utf8");
} else {
url += "&openid.return_to=" + URLEncoder.encode(Request.current().getBase() + Router.reverse(returnAction), "utf8");
}
if (realmAction != null && realmAction.startsWith("http://")) {
url += "&openid.realm=" + URLEncoder.encode(realmAction, "utf8");
} else {
url += "&openid.realm=" + URLEncoder.encode(Request.current().getBase() + Router.reverse(realmAction), "utf8");
}
if (!sregOptional.isEmpty() || !sregRequired.isEmpty()) {
url += "&openid.ns.sreg=" + URLEncoder.encode("http://openid.net/extensions/sreg/1.1", "UTF-8");
}
String sregO = "";
for (String a : sregOptional) {
sregO += URLEncoder.encode(a, "UTF-8") + ",";
}
if (!StringUtils.isEmpty(sregO)) {
url += "&openid.sreg.optional=" + sregO.substring(0, sregO.length() - 1);
}
String sregR = "";
for (String a : sregRequired) {
sregR += URLEncoder.encode(a, "UTF-8") + ",";
}
if (!StringUtils.isEmpty(sregR)) {
url += "&openid.sreg.required=" + sregR.substring(0, sregR.length() - 1);
}
if (!axRequired.isEmpty() || !axOptional.isEmpty()) {
url += "&openid.ns.ax=http%3A%2F%2Fopenid.net%2Fsrv%2Fax%2F1.0";
url += "&openid.ax.mode=fetch_request";
for (String a : axOptional.keySet()) {
url += "&openid.ax.type." + a + "=" + URLEncoder.encode(axOptional.get(a), "UTF-8");
}
for (String a : axRequired.keySet()) {
url += "&openid.ax.type." + a + "=" + URLEncoder.encode(axRequired.get(a), "UTF-8");
}
if (!axRequired.isEmpty()) {
String r = "";
for (String a : axRequired.keySet()) {
r += "," + a;
}
r = r.substring(1);
url += "&openid.ax.required=" + r;
}
if (!axOptional.isEmpty()) {
String r = "";
for (String a : axOptional.keySet()) {
r += "," + a;
}
r = r.substring(1);
url += "&openid.ax.if_available=" + r;
}
}
// Debug
Logger.trace("Send request %s", url);
throw new Redirect(url);
} catch (Redirect e) {
throw e;
} catch (PlayException e) {
throw e;
} catch (Exception e) {
return false;
}
}
// ~~~~ Main API
public static OpenID id(String id) {
return new OpenID(id);
}
/**
* Normalize the given openid as a standard openid
*/
public static String normalize(String openID) {
openID = openID.trim();
if (!openID.startsWith("http://") && !openID.startsWith("https://")) {
openID = "http://" + openID;
}
try {
URI url = new URI(openID);
String frag = url.getRawFragment();
if (frag != null && frag.length() > 0) {
openID = openID.replace("#" + frag, "");
}
if (url.getPath().equals("")) {
openID += "/";
}
openID = new URI(openID).toString();
} catch (Exception e) {
throw new RuntimeException(openID + " is not a valid URL");
}
return openID;
}
/**
* Is the current request an authentication response from the OP ?
*/
public static boolean isAuthenticationResponse() {
return Params.current().get("openid.mode") != null;
}
/**
* Retrieve the verified OpenID
* @return A UserInfo object
*/
public static UserInfo getVerifiedID() {
try {
String mode = Params.current().get("openid.mode");
// Check authentication
if (mode != null && mode.equals("id_res")) {
// id
String id = Params.current().get("openid.claimed_id");
if (id == null) {
id = Params.current().get("openid.identity");
}
id = normalize(id);
// server
String server = Params.current().get("openid.op_endpoint");
if (server == null) {
server = discoverServer(id);
}
String fields = Request.current().querystring.replace("openid.mode=id_res", "openid.mode=check_authentication");
WS.HttpResponse response = WS.url(server).mimeType("application/x-www-form-urlencoded").body(fields).post();
if (response.getStatus() == 200 && response.getString().contains("is_valid:true")) {
UserInfo userInfo = new UserInfo();
userInfo.id = id;
Pattern patternAX = Pattern.compile("^openid[.].+[.]value[.]([^.]+)([.]\\d+)?$");
Pattern patternSREG = Pattern.compile("^openid[.]sreg[.]([^.]+)$");
for (String p : Params.current().allSimple().keySet()) {
Matcher m = patternAX.matcher(p);
if (m.matches()) {
String alias = m.group(1);
userInfo.extensions.put(alias, Params.current().get(p));
}
m = patternSREG.matcher(p);
if (m.matches()) {
String alias = m.group(1);
userInfo.extensions.put(alias, Params.current().get(p));
}
}
return userInfo;
} else {
return null;
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
return null;
}
// ~~~~ Utils
static String extractHref(String link) {
Matcher m = Pattern.compile("href=\"([^\"]*)\"").matcher(link);
if (m.find()) {
return m.group(1).trim();
}
m = Pattern.compile("href=\'([^\']*)\'").matcher(link);
if (m.find()) {
return m.group(1).trim();
}
return null;
}
public static String discoverServer(String openid) {
if (openid.startsWith("http")) {
openid = WS.url(openid).get().getString();
}
Matcher openid2Provider = Pattern.compile("<link[^>]+openid2[.]provider[^>]+>", Pattern.CASE_INSENSITIVE).matcher(openid);
Matcher openidServer = Pattern.compile("<link[^>]+openid[.]server[^>]+>", Pattern.CASE_INSENSITIVE).matcher(openid);
String server = null;
if (openid2Provider.find()) {
server = extractHref(openid2Provider.group());
} else if (openidServer.find()) {
server = extractHref(openidServer.group());
}
return server;
}
// ~~~~ Result class
public static class UserInfo {
/**
* OpenID
*/
public String id;
/**
* Extensions values
*/
public Map<String, String> extensions = new HashMap<String, String>();
@Override
public String toString() {
return id + " " + extensions;
}
}
}