package sparklr.common;
import java.net.URI;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import org.junit.rules.MethodRule;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.Statement;
import org.springframework.boot.test.TestRestTemplate;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.security.oauth2.client.test.RestTemplateHolder;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriTemplate;
/**
* <p>
* A rule that provides HTTP connectivity to test cases on the assumption that the server is available when test methods
* fire.
* </p>
*
* @author Dave Syer
*
*/
public class HttpTestUtils implements MethodRule, RestTemplateHolder {

    private static final int DEFAULT_PORT = 8080;

    private static final String DEFAULT_HOST = "localhost";

    private int port;

    private String hostName = DEFAULT_HOST;

    private RestOperations client;

    private String prefix = "";

    /**
     * @return a new rule that sets up default host and port etc.
     */
    public static HttpTestUtils standard() {
        return new HttpTestUtils();
    }

    private HttpTestUtils() {
        setPort(DEFAULT_PORT);
    }

    /**
     * Sets the path prefix inserted between {@code host:port} and every relative path.
     *
     * @param prefix the prefix to use; {@code null} or blank means no prefix, and any trailing slashes are removed
     */
    public void setPrefix(String prefix) {
        String normalized = StringUtils.hasText(prefix) ? prefix : "";
        while (normalized.endsWith("/")) {
            normalized = normalized.substring(0, normalized.length() - 1);
        }
        this.prefix = normalized;
    }

    /**
     * Sets the server port and lazily creates the default client if none has been set yet.
     *
     * @param port the port to set
     * @return this instance, for chaining
     */
    public HttpTestUtils setPort(int port) {
        this.port = port;
        if (client == null) {
            client = createRestTemplate();
        }
        return this;
    }

    /**
     * @param hostName the hostName to set
     * @return this instance, for chaining
     */
    public HttpTestUtils setHostName(String hostName) {
        this.hostName = hostName;
        return this;
    }

    /**
     * Wraps the test statement without adding any behaviour: the wrapped statement is simply evaluated.
     */
    public Statement apply(final Statement base, FrameworkMethod method, Object target) {
        return new Statement() {
            @Override
            public void evaluate() throws Throwable {
                base.evaluate();
            }
        };
    }

    /**
     * @return {@code http://host:port} followed by the configured prefix
     */
    public String getBaseUrl() {
        return "http://" + hostName + ":" + port + prefix;
    }

    /**
     * Resolves a path against the base URL.
     *
     * @param path a path relative to the base URL (a leading slash is added if missing); anything starting with
     * {@code http} is returned unchanged
     * @return the URL to request
     */
    public String getUrl(String path) {
        if (path.startsWith("http")) {
            return path;
        }
        String relative = path.startsWith("/") ? path : "/" + path;
        return getBaseUrl() + relative;
    }

    /**
     * POSTs the form data with an {@code Accept: application/json} header and returns the body as a String.
     */
    public ResponseEntity<String> postForString(String path, MultiValueMap<String, String> formData) {
        return postForString(path, new HttpHeaders(), formData);
    }

    /**
     * POSTs the form data with a copy of the supplied headers plus {@code Accept: application/json}.
     *
     * @param headers extra request headers; this object is not modified
     */
    public ResponseEntity<String> postForString(String path, HttpHeaders headers, MultiValueMap<String, String> formData) {
        HttpHeaders actualHeaders = new HttpHeaders();
        actualHeaders.putAll(headers);
        // The Accept header must go on the copy that is actually sent, not on the caller's instance.
        actualHeaders.setAccept(Arrays.asList(MediaType.APPLICATION_JSON));
        return client.exchange(getUrl(path), HttpMethod.POST, new HttpEntity<MultiValueMap<String, String>>(formData,
                actualHeaders), String.class);
    }

    /**
     * POSTs the form data and returns the response body converted to a Map.
     */
    @SuppressWarnings("rawtypes")
    public ResponseEntity<Map> postForMap(String path, MultiValueMap<String, String> formData) {
        return postForMap(path, new HttpHeaders(), formData);
    }

    /**
     * POSTs the form data with a copy of the supplied headers, defaulting the content type to form-urlencoded when
     * the caller did not set one.
     *
     * @param headers extra request headers; this object is not modified
     */
    @SuppressWarnings("rawtypes")
    public ResponseEntity<Map> postForMap(String path, HttpHeaders headers, MultiValueMap<String, String> formData) {
        HttpHeaders actualHeaders = new HttpHeaders();
        actualHeaders.putAll(headers);
        if (actualHeaders.getContentType() == null) {
            actualHeaders.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
        }
        return client.exchange(getUrl(path), HttpMethod.POST, new HttpEntity<MultiValueMap<String, String>>(formData,
                actualHeaders), Map.class);
    }

    /**
     * POSTs the form data as form-urlencoded content, requesting no response body.
     */
    public ResponseEntity<Void> postForStatus(String path, MultiValueMap<String, String> formData) {
        return postForStatus(path, new HttpHeaders(), formData);
    }

    /**
     * POSTs the form data with a copy of the supplied headers and a form-urlencoded content type, requesting no
     * response body.
     *
     * @param headers extra request headers; this object is not modified
     */
    public ResponseEntity<Void> postForStatus(String path, HttpHeaders headers, MultiValueMap<String, String> formData) {
        HttpHeaders actualHeaders = new HttpHeaders();
        actualHeaders.putAll(headers);
        actualHeaders.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
        return client.exchange(getUrl(path), HttpMethod.POST, new HttpEntity<MultiValueMap<String, String>>(formData,
                actualHeaders), (Class<Void>) null);
    }

    /**
     * POSTs the form data, requires a 302 response and then issues a GET to the returned location.
     * <p>
     * If the POST response carries a {@code Set-Cookie} header, its first value is stored as the {@code Cookie}
     * header of the supplied {@code headers} object, which is then sent with the follow-up GET.
     * </p>
     *
     * @param headers request headers; updated in place with the cookie as described above
     * @return the response of the follow-up GET
     * @throws IllegalStateException if the POST response is not a 302 or has no Location header
     */
    public ResponseEntity<Void> postForRedirect(String path, HttpHeaders headers, MultiValueMap<String, String> params) {
        ResponseEntity<Void> exchange = postForStatus(path, headers, params);
        if (exchange.getStatusCode() != HttpStatus.FOUND) {
            throw new IllegalStateException("Expected 302 but server returned status code " + exchange.getStatusCode());
        }
        if (exchange.getHeaders().containsKey("Set-Cookie")) {
            String cookie = exchange.getHeaders().getFirst("Set-Cookie");
            headers.set("Cookie", cookie);
        }
        URI location = exchange.getHeaders().getLocation();
        if (location == null) {
            throw new IllegalStateException("Expected a Location header in the 302 response but there was none");
        }
        return client.exchange(location.toString(), HttpMethod.GET, new HttpEntity<Void>(null, headers),
                (Class<Void>) null);
    }

    /**
     * GETs the path with no extra headers and returns the body as a String.
     */
    public ResponseEntity<String> getForString(String path) {
        return getForString(path, new HttpHeaders());
    }

    /**
     * GETs the path with the supplied headers and returns the body as a String.
     */
    public ResponseEntity<String> getForString(String path, final HttpHeaders headers) {
        return client.exchange(getUrl(path), HttpMethod.GET, new HttpEntity<Void>((Void) null, headers), String.class);
    }

    /**
     * GETs the path with the supplied headers, passing {@code uriVariables} to the client for URI template
     * expansion, and returns the body as a String.
     */
    public ResponseEntity<String> getForString(String path, final HttpHeaders headers, Map<String, String> uriVariables) {
        return client.exchange(getUrl(path), HttpMethod.GET, new HttpEntity<Void>((Void) null, headers), String.class,
                uriVariables);
    }

    /**
     * GETs the path with the supplied headers and URI variables, requesting no response body.
     */
    public ResponseEntity<Void> getForResponse(String path, final HttpHeaders headers, Map<String, String> uriVariables) {
        HttpEntity<Void> request = new HttpEntity<Void>(null, headers);
        return client.exchange(getUrl(path), HttpMethod.GET, request, (Class<Void>) null, uriVariables);
    }

    /**
     * GETs the path with the supplied headers and no URI variables, requesting no response body.
     */
    public ResponseEntity<Void> getForResponse(String path, HttpHeaders headers) {
        return getForResponse(path, headers, Collections.<String, String> emptyMap());
    }

    /**
     * @return the status code of a GET to the path with the supplied headers
     */
    public HttpStatus getStatusCode(String path, final HttpHeaders headers) {
        ResponseEntity<Void> response = getForResponse(path, headers);
        return response.getStatusCode();
    }

    /**
     * @return the status code of a GET to the path with no extra headers
     */
    public HttpStatus getStatusCode(String path) {
        return getStatusCode(path, new HttpHeaders());
    }

    /**
     * Replaces the client used for all requests.
     */
    public void setRestTemplate(RestOperations restTemplate) {
        client = restTemplate;
    }

    /**
     * @return the client currently used for all requests
     */
    public RestOperations getRestTemplate() {
        return client;
    }

    /**
     * @return a new {@link TestRestTemplate}
     */
    public RestOperations createRestTemplate() {
        return new TestRestTemplate();
    }

    /**
     * Starts a {@link UriBuilder} for the given URL, resolved through {@link #getUrl(String)}.
     */
    public UriBuilder buildUri(String url) {
        // getUrl already returns absolute URLs untouched, so no separate scheme check is needed here.
        return UriBuilder.fromUri(getUrl(url));
    }

    /**
     * Collects query parameters for a URL and renders them as a URI template pattern with one
     * <code>{key}</code> placeholder per parameter.
     */
    public static class UriBuilder {

        private final String url;

        private final Map<String, String> params = new LinkedHashMap<String, String>();

        public UriBuilder(String url) {
            this.url = url;
        }

        public static UriBuilder fromUri(String url) {
            return new UriBuilder(url);
        }

        /**
         * Adds (or replaces) a query parameter.
         *
         * @return this builder, for chaining
         */
        public UriBuilder queryParam(String key, String value) {
            params.put(key, value);
            return this;
        }

        /**
         * @return the URL (spaces replaced by {@code +}) followed by <code>key={key}</code> pairs in insertion
         * order; the values themselves are only substituted when the pattern is expanded
         */
        public String pattern() {
            StringBuilder builder = new StringBuilder(url.replace(" ", "+"));
            if (!params.isEmpty()) {
                builder.append("?");
                boolean first = true;
                for (String key : params.keySet()) {
                    if (!first) {
                        builder.append("&");
                    }
                    first = false;
                    builder.append(key).append("={").append(key).append("}");
                }
            }
            return builder.toString();
        }

        /**
         * @return the live map of parameters collected so far, suitable as URI variables for {@link #pattern()}
         */
        public Map<String, String> params() {
            return params;
        }

        /**
         * @return the pattern expanded with the collected parameters
         */
        public URI build() {
            return new UriTemplate(pattern()).expand(params);
        }
    }
}