diff --git a/nuxeo-platform-web-common/pom.xml b/nuxeo-platform-web-common/pom.xml index 59d47fafa..c909f3b06 100644 --- a/nuxeo-platform-web-common/pom.xml +++ b/nuxeo-platform-web-common/pom.xml @@ -112,7 +112,8 @@ org.mockito mockito-all + test - \ No newline at end of file + diff --git a/nuxeo-platform-web-common/src/test/java/org/nuxeo/ecm/platform/ui/web/download/TestDownloadServlet.java b/nuxeo-platform-web-common/src/test/java/org/nuxeo/ecm/platform/ui/web/download/TestDownloadServlet.java index d7e2efedc..4a1d48450 100644 --- a/nuxeo-platform-web-common/src/test/java/org/nuxeo/ecm/platform/ui/web/download/TestDownloadServlet.java +++ b/nuxeo-platform-web-common/src/test/java/org/nuxeo/ecm/platform/ui/web/download/TestDownloadServlet.java @@ -1,15 +1,29 @@ package org.nuxeo.ecm.platform.ui.web.download; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.InputStream; import java.io.OutputStream; +import java.io.PrintWriter; +import java.lang.reflect.Method; -import org.junit.Test; -import static org.junit.Assert.assertEquals; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import org.junit.Test; +import org.nuxeo.ecm.core.api.Blob; import org.nuxeo.ecm.core.api.ClientException; +import org.nuxeo.ecm.core.api.impl.blob.InputStreamBlob; +import org.nuxeo.ecm.core.storage.sql.Binary; +import org.nuxeo.ecm.core.storage.sql.coremodel.SQLBlob; import org.nuxeo.ecm.platform.ui.web.download.DownloadServlet.ByteRange; +import org.nuxeo.ecm.platform.web.common.requestcontroller.filter.BufferingServletOutputStream; public class TestDownloadServlet { @@ -70,4 +84,64 @@ public void testWriteStream() throws Exception { DownloadServlet.writeStream(in, out, range); assertEquals("world", out.toString()); } + + @Test + public void testETagHeaderNone() throws Exception { + doTestETagHeader(null); + } + + @Test + public void testETagHeaderNotMatched() throws Exception { + doTestETagHeader(Boolean.FALSE); + } + + @Test + public void testETagHeaderMatched() throws Exception { + doTestETagHeader(Boolean.TRUE); + } + + private void doTestETagHeader(Boolean match) throws Exception { + HttpServletRequest req = mock(HttpServletRequest.class); + HttpServletResponse resp = mock(HttpServletResponse.class); + Binary binary = mock(Binary.class); + String s = "Hello, world!"; + final byte[] bytes = s.getBytes(); + InputStream in = new ByteArrayInputStream(bytes); + String digest = "12345"; + String bogusDigest = "67890"; + SQLBlob blob = new SQLBlob(binary, "myFile.txt", "text/plain", + "UTF-8", digest, bytes.length); + when(binary.getStream()).thenReturn(in); + when(binary.getDigest()).thenReturn(digest); + String ifNoneMatchHeader = null; + if (match != null) { + ifNoneMatchHeader = (match) ? getETag(digest) : getETag(bogusDigest); + } + when(req.getHeader("If-None-Match")).thenReturn(ifNoneMatchHeader); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + BufferingServletOutputStream sos = new BufferingServletOutputStream(baos); + PrintWriter printWriter = new PrintWriter(sos); + when(resp.getOutputStream()).thenReturn(sos); + when(resp.getWriter()).thenReturn(printWriter); + + DownloadServlet servlet = new DownloadServlet(); + Method m = servlet.getClass().getDeclaredMethod("downloadBlob", new Class[] + {HttpServletRequest.class, HttpServletResponse.class, + Blob.class, String.class}); + m.setAccessible(true); + m.invoke(servlet, new Object[] {req, resp, blob, (String) null}); + + verify(req, atLeast(1)).getHeader("If-None-Match"); + if (Boolean.TRUE.equals(match)) { + assertEquals(0, baos.toByteArray().length); + verify(resp).sendError(HttpServletResponse.SC_NOT_MODIFIED); + } else { + assertEquals(s, baos.toString()); + verify(resp).setHeader("ETag", getETag(digest)); + } + } + + private String getETag(String digest) { + return digest; + } }