Skip to content

Commit

Permalink
Refine condition to send WebSocket binary messages
Browse files Browse the repository at this point in the history
The following two refinements have been added:
1) SockJS doesn't support binary messages so don't even try
2) don't bother if payload.length == 0

Issue: SPR-12475
  • Loading branch information
rstoyanchev committed Dec 29, 2014
1 parent bc075c7 commit 51367de
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
Expand All @@ -52,7 +51,6 @@
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.support.MessageHeaderInitializer;
import org.springframework.util.Assert;
import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
Expand Down Expand Up @@ -356,8 +354,13 @@ else if (StompCommand.CONNECTED.equals(command)) {
}
}
try {
byte[] bytes = this.stompEncoder.encode(stompAccessor.getMessageHeaders(), (byte[]) message.getPayload());
if (MimeTypeUtils.APPLICATION_OCTET_STREAM.isCompatibleWith(stompAccessor.getContentType())) {
byte[] payload = (byte[]) message.getPayload();
byte[] bytes = this.stompEncoder.encode(stompAccessor.getMessageHeaders(), payload);

boolean useBinary = (payload.length > 0 && !(session instanceof SockJsSession) &&
MimeTypeUtils.APPLICATION_OCTET_STREAM.isCompatibleWith(stompAccessor.getContentType()));

if (useBinary) {
session.sendMessage(new BinaryMessage(bytes));
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,20 @@

package org.springframework.web.socket.messaging;

import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -28,7 +42,6 @@
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.messaging.Message;
Expand All @@ -50,16 +63,14 @@
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.MimeTypeUtils;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.handler.TestWebSocketSession;
import org.springframework.web.socket.sockjs.transport.SockJsSession;

import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;

/**
* Test fixture for {@link StompSubProtocolHandler} tests.
*
Expand Down Expand Up @@ -267,6 +278,38 @@ public void handleMessageToClientUserDestination() {
assertFalse(((String) textMessage.getPayload()).contains(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION));
}

// SPR-12475

@Test
public void handleMessageToClientBinaryWebSocketMessage() {

StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE);
headers.setMessageId("mess0");
headers.setSubscriptionId("sub0");
headers.setContentType(MimeTypeUtils.APPLICATION_OCTET_STREAM);
headers.setDestination("/queue/foo");

// Non-empty payload

byte[] payload = new byte[1];
Message<byte[]> message = MessageBuilder.createMessage(payload, headers.getMessageHeaders());
this.protocolHandler.handleMessageToClient(this.session, message);

assertEquals(1, this.session.getSentMessages().size());
WebSocketMessage<?> webSocketMessage = this.session.getSentMessages().get(0);
assertTrue(webSocketMessage instanceof BinaryMessage);

// Empty payload

payload = EMPTY_PAYLOAD;
message = MessageBuilder.createMessage(payload, headers.getMessageHeaders());
this.protocolHandler.handleMessageToClient(this.session, message);

assertEquals(2, this.session.getSentMessages().size());
webSocketMessage = this.session.getSentMessages().get(1);
assertTrue(webSocketMessage instanceof TextMessage);
}

@Test
public void handleMessageFromClient() {

Expand Down

0 comments on commit 51367de

Please sign in to comment.