/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.messaging;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Scope;
import org.springframework.context.annotation.ScopedProxyMode;
import org.springframework.messaging.handler.annotation.MessageExceptionHandler;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.simp.annotation.SendToUser;
import org.springframework.messaging.simp.annotation.SubscribeMapping;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.support.AbstractSubscribableChannel;
import org.springframework.messaging.support.ExecutorSubscribableChannel;
import org.springframework.stereotype.Controller;
import org.springframework.web.socket.AbstractWebSocketIntegrationTests;
import org.springframework.web.socket.JettyWebSocketTestServer;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.TomcatWebSocketTestServer;
import org.springframework.web.socket.UndertowTestServer;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.jetty.JettyWebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.config.annotation.DelegatingWebSocketMessageBrokerConfiguration;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import static org.junit.Assert.*;
import static org.springframework.web.socket.messaging.StompTextMessageBuilder.*;
/**
* Integration tests with annotated message-handling methods.
*
* @author Rossen Stoyanchev
*/
@RunWith(Parameterized.class)
public class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
/**
 * Supplies the server/client combinations handed to the {@code Parameterized} runner.
 *
 * @return one row per combination, each holding a test server followed by a client
 */
@Parameters
public static Iterable<Object[]> arguments() {
    Object[] jetty = {new JettyWebSocketTestServer(), new JettyWebSocketClient()};
    Object[] tomcat = {new TomcatWebSocketTestServer(), new StandardWebSocketClient()};
    Object[] undertow = {new UndertowTestServer(), new StandardWebSocketClient()};
    return Arrays.asList(jetty, tomcat, undertow);
}
/**
 * Lists the annotated configuration classes for these tests.
 *
 * @return the broker configuration override followed by the broker configurer
 */
@Override
protected Class<?>[] getAnnotatedConfigClasses() {
    Class<?>[] configClasses = new Class<?>[2];
    configClasses[0] = TestMessageBrokerConfiguration.class;
    configClasses[1] = TestMessageBrokerConfigurer.class;
    return configClasses;
}
/**
 * Sends a STOMP SEND frame to {@code /app/simple} and waits until
 * {@link SimpleController#handle()} has counted down the controller's latch.
 *
 * @throws Exception if the handshake fails or the wait is interrupted
 */
@Test
public void sendMessageToController() throws Exception {
    TextMessage message = create(StompCommand.SEND).headers("destination:/app/simple").build();
    WebSocketSession session = doHandshake(new TestClientWebSocketHandler(0, message), "/ws").get();
    try {
        // Look the controller up inside the try block so that a failed lookup
        // still closes the session opened above.
        SimpleController controller = this.wac.getBean(SimpleController.class);
        assertTrue("Timed out waiting for SimpleController to handle /app/simple",
                controller.latch.await(10, TimeUnit.SECONDS));
    }
    finally {
        session.close();
    }
}
/**
 * Subscribes to {@code /topic/increment}, sends "5" to {@code /app/increment}
 * and waits for one frame to arrive back at the client.
 *
 * @throws Exception if the handshake fails or the wait is interrupted
 */
@Test
public void sendMessageToControllerAndReceiveReplyViaTopic() throws Exception {
    TextMessage subscribeFrame = create(StompCommand.SUBSCRIBE)
            .headers("id:subs1", "destination:/topic/increment")
            .build();
    TextMessage sendFrame = create(StompCommand.SEND)
            .headers("destination:/app/increment")
            .body("5")
            .build();

    TestClientWebSocketHandler handler = new TestClientWebSocketHandler(1, subscribeFrame, sendFrame);
    WebSocketSession session = doHandshake(handler, "/ws").get();
    try {
        boolean replyArrived = handler.latch.await(2, TimeUnit.SECONDS);
        assertTrue(replyArrived);
    }
    finally {
        session.close();
    }
}
/**
 * SPR-10930: a frame sent directly to the broker destination {@code /topic/foo}
 * must come back to the subscribed client as a STOMP MESSAGE frame.
 *
 * @throws Exception if the handshake fails or the wait is interrupted
 */
@Test
public void sendMessageToBrokerAndReceiveReplyViaTopic() throws Exception {
    TextMessage subscribeFrame = create(StompCommand.SUBSCRIBE)
            .headers("id:subs1", "destination:/topic/foo")
            .build();
    TextMessage sendFrame = create(StompCommand.SEND)
            .headers("destination:/topic/foo")
            .body("5")
            .build();

    TestClientWebSocketHandler handler = new TestClientWebSocketHandler(1, subscribeFrame, sendFrame);
    WebSocketSession session = doHandshake(handler, "/ws").get();
    try {
        boolean replyArrived = handler.latch.await(2, TimeUnit.SECONDS);
        assertTrue(replyArrived);

        TextMessage firstFrame = handler.actual.get(0);
        String payload = firstFrame.getPayload();
        boolean isMessageFrame = payload.startsWith("MESSAGE\n");
        assertTrue("Expected STOMP MESSAGE, got " + payload, isMessageFrame);
    }
    finally {
        session.close();
    }
}
/**
 * SPR-11648: a SUBSCRIBE to {@code /app/number} is answered with a frame that
 * carries the same destination header and the value 42.
 *
 * @throws Exception if the handshake fails or the wait is interrupted
 */
@Test
public void sendSubscribeToControllerAndReceiveReply() throws Exception {
    String destHeader = "destination:/app/number";
    TextMessage subscribeFrame = create(StompCommand.SUBSCRIBE)
            .headers("id:subs1", destHeader)
            .build();

    TestClientWebSocketHandler handler = new TestClientWebSocketHandler(1, subscribeFrame);
    WebSocketSession session = doHandshake(handler, "/ws").get();
    try {
        boolean replyArrived = handler.latch.await(2, TimeUnit.SECONDS);
        assertTrue(replyArrived);

        TextMessage firstFrame = handler.actual.get(0);
        String payload = firstFrame.getPayload();

        boolean hasDestination = payload.contains(destHeader);
        assertTrue("Expected STOMP destination=/app/number, got " + payload, hasDestination);

        boolean hasAnswer = payload.contains("42");
        assertTrue("Expected STOMP Payload=42, got " + payload, hasAnswer);
    }
    finally {
        session.close();
    }
}
/**
 * Subscribes to {@code /user/queue/error}, triggers the failing
 * {@code /app/exception} mapping and checks that the text produced by
 * {@link SimpleController#handleException(IllegalArgumentException)} reaches the client.
 *
 * @throws Exception if the handshake fails or the wait is interrupted
 */
@Test
public void handleExceptionAndSendToUser() throws Exception {
    String destHeader = "destination:/user/queue/error";
    TextMessage m1 = create(StompCommand.SUBSCRIBE).headers("id:subs1", destHeader).build();
    TextMessage m2 = create(StompCommand.SEND).headers("destination:/app/exception").build();
    TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m1, m2);
    WebSocketSession session = doHandshake(clientHandler, "/ws").get();
    try {
        assertTrue("Timed out waiting for the error reply",
                clientHandler.latch.await(2, TimeUnit.SECONDS));
        String payload = clientHandler.actual.get(0).getPayload();
        // Every assertion reports the received frame so a failure can be diagnosed from the log alone.
        assertTrue("Expected STOMP MESSAGE, got " + payload, payload.startsWith("MESSAGE\n"));
        assertTrue("Expected STOMP " + destHeader + ", got " + payload,
                payload.contains(destHeader + "\n"));
        assertTrue("Expected STOMP Payload=Got error: Bad input, got " + payload,
                payload.endsWith("Got error: Bad input\0"));
    }
    finally {
        session.close();
    }
}
/**
 * Subscribes to {@code /topic/scopedBeanValue}, sends a frame to
 * {@code /app/scopedBeanValue} and checks that the reply ends with "55", the
 * value given to the {@code websocket}-scoped {@link ScopedBean}.
 *
 * @throws Exception if the handshake fails or the wait is interrupted
 */
@Test
public void webSocketScope() throws Exception {
    String destHeader = "destination:/topic/scopedBeanValue";
    TextMessage message1 = create(StompCommand.SUBSCRIBE)
            .headers("id:subs1", destHeader).build();
    TextMessage message2 = create(StompCommand.SEND)
            .headers("destination:/app/scopedBeanValue").build();
    TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, message1, message2);
    WebSocketSession session = doHandshake(clientHandler, "/ws").get();
    try {
        assertTrue("Timed out waiting for the scoped bean value",
                clientHandler.latch.await(2, TimeUnit.SECONDS));
        String payload = clientHandler.actual.get(0).getPayload();
        // Every assertion reports the received frame so a failure can be diagnosed from the log alone.
        assertTrue("Expected STOMP MESSAGE, got " + payload, payload.startsWith("MESSAGE\n"));
        assertTrue("Expected STOMP " + destHeader + ", got " + payload,
                payload.contains(destHeader + "\n"));
        assertTrue("Expected STOMP Payload=55, got " + payload, payload.endsWith("55\0"));
    }
    finally {
        session.close();
    }
}
/**
 * Stereotype marking the controllers declared in this test class; the
 * component scan on {@link TestMessageBrokerConfigurer} includes only types
 * annotated with it.
 */
@Controller
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
private @interface IntegrationTestController {
}
/**
 * Controller behind the {@code /simple} and {@code /exception} mappings.
 */
@IntegrationTestController
static class SimpleController {

    // Assigned exactly once: the test awaits this latch while handle() counts it down.
    private final CountDownLatch latch = new CountDownLatch(1);

    /**
     * Signals that a message for {@code /simple} was received.
     */
    @MessageMapping(value="/simple")
    public void handle() {
        this.latch.countDown();
    }

    /**
     * Always fails, so that {@link #handleException(IllegalArgumentException)} is exercised.
     *
     * @throws IllegalArgumentException on every invocation, with the message "Bad input"
     */
    @MessageMapping(value="/exception")
    public void handleWithError() {
        throw new IllegalArgumentException("Bad input");
    }

    /**
     * Turns the failure into a text reply destined for {@code /queue/error}.
     *
     * @param ex the exception thrown by a mapping method
     * @return "Got error: " followed by the exception message
     */
    @MessageExceptionHandler
    @SendToUser("/queue/error")
    public String handleException(IllegalArgumentException ex) {
        return "Got error: " + ex.getMessage();
    }
}
/**
 * Controller behind the {@code /increment} and {@code /number} mappings.
 */
@IntegrationTestController
static class IncrementController {

    /**
     * Adds one to the received number.
     *
     * @param i the number taken from the incoming message
     * @return {@code i + 1}
     */
    @MessageMapping("/increment")
    public int handle(int i) {
        int incremented = i + 1;
        return incremented;
    }

    /**
     * Answers a subscription to {@code /number}.
     *
     * @return always 42
     */
    @SubscribeMapping("/number")
    public int number() {
        final int answer = 42;
        return answer;
    }
}
/**
 * Controller that returns the value held by the injected {@link ScopedBean}.
 */
@IntegrationTestController
static class ScopedBeanController {

    private final ScopedBean scopedBean;

    /**
     * @param scopedBean the bean whose value is returned by {@link #getValue()}
     */
    @Autowired
    public ScopedBeanController(ScopedBean scopedBean) {
        this.scopedBean = scopedBean;
    }

    /**
     * Handles {@code /scopedBeanValue}.
     *
     * @return whatever the injected bean currently reports
     */
    @MessageMapping("/scopedBeanValue")
    public String getValue() {
        String currentValue = this.scopedBean.getValue();
        return currentValue;
    }
}
/**
 * Contract of the bean injected into {@link ScopedBeanController}.
 */
interface ScopedBean {

    /**
     * @return the value held by this bean
     */
    String getValue();
}
/**
 * {@link ScopedBean} that returns the value it was constructed with.
 */
static class ScopedBeanImpl implements ScopedBean {

    private final String value;

    /**
     * @param value the value later returned by {@link #getValue()}
     */
    public ScopedBeanImpl(String value) {
        this.value = value;
    }

    /**
     * @return the value passed to the constructor
     */
    @Override
    public String getValue() {
        return this.value;
    }
}
/**
 * Client-side handler that sends a fixed set of frames once connected and
 * records every text frame it receives.
 */
private static class TestClientWebSocketHandler extends TextWebSocketHandler {

    private final int expected;

    private final TextMessage[] messagesToSend;

    private final CountDownLatch latch;

    private final List<TextMessage> actual = new CopyOnWriteArrayList<>();

    /**
     * @param expectedNumberOfMessages initial count of the latch, decremented once per received frame
     * @param messagesToSend frames sent, in order, when the connection is established
     */
    public TestClientWebSocketHandler(int expectedNumberOfMessages, TextMessage... messagesToSend) {
        this.expected = expectedNumberOfMessages;
        this.messagesToSend = messagesToSend;
        this.latch = new CountDownLatch(expectedNumberOfMessages);
    }

    /**
     * Sends the configured frames over the new session, in the order given.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        for (int i = 0; i < this.messagesToSend.length; i++) {
            session.sendMessage(this.messagesToSend[i]);
        }
    }

    /**
     * Records the frame before counting down the latch.
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        this.actual.add(message);
        this.latch.countDown();
    }
}
/**
 * Registers the {@code /ws} STOMP endpoint, configures the {@code /app} prefix
 * and a simple broker for {@code /topic} and {@code /queue}, and scans this test
 * class for types annotated with {@link IntegrationTestController}.
 */
@Configuration
@ComponentScan(
        basePackageClasses = StompWebSocketIntegrationTests.class,
        useDefaultFilters = false,
        includeFilters = @ComponentScan.Filter(IntegrationTestController.class))
static class TestMessageBrokerConfigurer extends AbstractWebSocketMessageBrokerConfigurer {

    @Autowired
    private HandshakeHandler handshakeHandler; // can't rely on classpath for server detection

    /**
     * Adds the {@code /ws} endpoint using the injected handshake handler.
     */
    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
        HandshakeHandler injectedHandler = this.handshakeHandler;
        registry.addEndpoint("/ws").setHandshakeHandler(injectedHandler);
    }

    /**
     * Sets {@code /app} as the application destination prefix and enables the
     * simple broker for {@code /topic} and {@code /queue}.
     */
    @Override
    public void configureMessageBroker(MessageBrokerRegistry configurer) {
        configurer.setApplicationDestinationPrefixes("/app");
        configurer.enableSimpleBroker("/topic", "/queue");
    }

    /**
     * @return a {@code websocket}-scoped bean, proxied by interface, holding "55"
     */
    @Bean
    @Scope(value = "websocket", proxyMode = ScopedProxyMode.INTERFACES)
    public ScopedBean scopedBean() {
        ScopedBean bean = new ScopedBeanImpl("55");
        return bean;
    }
}
/**
 * Overrides both client channels with {@code ExecutorSubscribableChannel}
 * instances created without an executor.
 */
@Configuration
static class TestMessageBrokerConfiguration extends DelegatingWebSocketMessageBrokerConfiguration {

    /**
     * @return the channel for messages coming from clients
     */
    @Bean
    @Override
    public AbstractSubscribableChannel clientInboundChannel() {
        // synchronous
        AbstractSubscribableChannel inbound = new ExecutorSubscribableChannel();
        return inbound;
    }

    /**
     * @return the channel for messages going to clients
     */
    @Bean
    @Override
    public AbstractSubscribableChannel clientOutboundChannel() {
        // synchronous
        AbstractSubscribableChannel outbound = new ExecutorSubscribableChannel();
        return outbound;
    }
}
}