Skip to content

Commit f1755de

Browse files
committed
fix: Dispatch to the given servlet when using getNamedDispatcher
1 parent be15cfd commit f1755de

File tree

6 files changed

+111
-18
lines changed

6 files changed

+111
-18
lines changed

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyRequestDispatcher.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ public void forward(ServletRequest servletRequest, ServletResponse servletRespon
8080
}
8181

8282
if (isNamedDispatcher) {
83-
lambdaContainerHandler.doFilter((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse, getServlet((HttpServletRequest)servletRequest));
83+
lambdaContainerHandler.doFilter((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse, getServlet(dispatchTo));
8484
return;
8585
}
8686

@@ -148,4 +148,9 @@ void setRequestPath(ServletRequest req, final String destinationPath) {
148148
private Servlet getServlet(HttpServletRequest req) {
149149
return ((AwsServletContext)lambdaContainerHandler.getServletContext()).getServletForPath(req.getPathInfo());
150150
}
151+
152+
private Servlet getServlet(String servletName) throws ServletException {
153+
return ((AwsServletContext)lambdaContainerHandler.getServletContext()).getServlet(servletName);
154+
}
155+
151156
}

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContext.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ public RequestDispatcher getRequestDispatcher(String s) {
171171

172172
@Override
173173
public RequestDispatcher getNamedDispatcher(String s) {
174-
throw new UnsupportedOperationException();
174+
return new AwsProxyRequestDispatcher(s, true, containerHandler);
175175
}
176176

177177

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainManager.java

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ FilterChainHolder getFilterChain(final HttpServletRequest request, Servlet servl
132132
chainHolder.addFilter(new FilterHolder(new ServletExecutionFilter(servletRegistration), servletContext));
133133
}
134134

135-
putFilterChainCache(type, targetPath, chainHolder);
135+
putFilterChainCache(type, targetPath, servlet, chainHolder);
136136
// update total filter size
137137
if (filtersSize != registrations.size()) {
138138
filtersSize = registrations.size();
@@ -151,13 +151,16 @@ FilterChainHolder getFilterChain(final HttpServletRequest request, Servlet servl
151151
* initialized with the cached list of {@link FilterHolder} objects
152152
* @param type The dispatcher type for the incoming request
153153
* @param targetPath The request path - this is extracted with the <code>getPath</code> method of the request object
154-
* @param servlet Servlet to put at the end of the chain (optional).
154+
* @param servlet The final servlet in the filter chain (if any)
155155
* @return A populated FilterChainHolder
156156
*/
157157
private FilterChainHolder getFilterChainCache(final DispatcherType type, final String targetPath, Servlet servlet) {
158158
TargetCacheKey key = new TargetCacheKey();
159159
key.setDispatcherType(type);
160160
key.setTargetPath(targetPath);
161+
if (servlet != null) {
162+
key.setServletName(servlet.getServletConfig().getServletName());
163+
}
161164

162165
if (!filterCache.containsKey(key)) {
163166
return null;
@@ -174,12 +177,16 @@ private FilterChainHolder getFilterChainCache(final DispatcherType type, final S
174177
* method to retry this.
175178
* @param type DispatcherType from the incoming request
176179
* @param targetPath The target path in the API
177-
* @param holder The FilterChainHolder object to save in the cache
180+
* @param servlet The final servlet in the filter chain (if any)
181+
* @param holder The FilterChainHolder object to save in the cache
178182
*/
179-
private void putFilterChainCache(final DispatcherType type, final String targetPath, final FilterChainHolder holder) {
183+
private void putFilterChainCache(final DispatcherType type, final String targetPath, Servlet servlet, final FilterChainHolder holder) {
180184
TargetCacheKey key = new TargetCacheKey();
181185
key.setDispatcherType(type);
182186
key.setTargetPath(targetPath);
187+
if (servlet != null) {
188+
key.setServletName(servlet.getServletConfig().getServletName());
189+
}
183190

184191
// we couldn't compute the hash code because either the target path or dispatcher type were null
185192
if (key.hashCode() == -1) {
@@ -256,6 +263,7 @@ protected static class TargetCacheKey {
256263

257264
private String targetPath;
258265
private DispatcherType dispatcherType;
266+
private String servletName;
259267

260268

261269
//-------------------------------------------------------------
@@ -295,10 +303,15 @@ public int hashCode() {
295303
}
296304
hashString += ":" + hashDispatcher;
297305

306+
if (servletName != null) {
307+
hashString += ":" + servletName;
308+
}
309+
298310
return hashString.hashCode();
299311
}
300312

301313

314+
302315
@Override
303316
public boolean equals(Object key) {
304317
if (key == null) {
@@ -324,6 +337,11 @@ void setTargetPath(String targetPath) {
324337
void setDispatcherType(DispatcherType dispatcherType) {
325338
this.dispatcherType = dispatcherType;
326339
}
340+
public void setServletName(String servletName) {
341+
this.servletName = servletName;
342+
}
343+
344+
327345
}
328346

329347
@SuppressFBWarnings("URF_UNREAD_FIELD")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package com.amazonaws.serverless.proxy.internal.testutils;
2+
3+
import java.io.IOException;
4+
5+
import javax.servlet.ServletException;
6+
import javax.servlet.http.HttpServlet;
7+
import javax.servlet.http.HttpServletRequest;
8+
import javax.servlet.http.HttpServletResponse;
9+
10+
public class MockServlet extends HttpServlet {
11+
12+
private int serviceCalls = 0;
13+
14+
@Override
15+
protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
16+
super.service(req, resp);
17+
serviceCalls++;
18+
}
19+
20+
public int getServiceCalls() {
21+
return serviceCalls;
22+
}
23+
}

aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsFilterChainManagerTest.java

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder;
44
import com.amazonaws.serverless.proxy.internal.testutils.MockLambdaContext;
5+
import com.amazonaws.serverless.proxy.internal.testutils.MockServlet;
56
import com.amazonaws.services.lambda.runtime.Context;
67
import org.junit.jupiter.api.BeforeAll;
78
import org.junit.jupiter.api.Test;
@@ -17,6 +18,8 @@
1718
import static org.junit.jupiter.api.Assertions.*;
1819

1920
public class AwsFilterChainManagerTest {
21+
private static final String SERVLET1_NAME = "Servlet 1";
22+
private static final String SERVLET2_NAME = "Servlet 2";
2023
private static final String REQUEST_CUSTOM_ATTRIBUTE_NAME = "X-Custom-Attribute";
2124
private static final String REQUEST_CUSTOM_ATTRIBUTE_VALUE = "CustomAttrValue";
2225

@@ -36,6 +39,10 @@ public static void setUp() {
3639
reg2.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/second/*");
3740
FilterRegistration.Dynamic reg3 = servletContext.addFilter("Filter3", new MockFilter());
3841
reg3.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/third/fourth/*");
42+
ServletRegistration.Dynamic firstServlet = servletContext.addServlet(SERVLET1_NAME, new MockServlet());
43+
firstServlet.addMapping("/first/*");
44+
ServletRegistration.Dynamic secondServlet = servletContext.addServlet(SERVLET2_NAME, new MockServlet());
45+
secondServlet.addMapping("/second/*");
3946

4047
chainManager = new AwsFilterChainManager((AwsServletContext) servletContext);
4148
}
@@ -88,6 +95,22 @@ void cacheKey_compare_differentDispatcher() {
8895
assertNotEquals(cacheKey, secondCacheKey);
8996
}
9097

98+
@Test
99+
void cacheKey_compare_differentServlet() {
100+
FilterChainManager.TargetCacheKey cacheKey = new FilterChainManager.TargetCacheKey();
101+
cacheKey.setDispatcherType(DispatcherType.REQUEST);
102+
cacheKey.setTargetPath("/first/path");
103+
cacheKey.setServletName("Dispatcher servlet");
104+
105+
FilterChainManager.TargetCacheKey secondCacheKey = new FilterChainManager.TargetCacheKey();
106+
secondCacheKey.setDispatcherType(DispatcherType.REQUEST);
107+
secondCacheKey.setTargetPath("/first/path");
108+
cacheKey.setServletName("Real servlet");
109+
110+
assertNotEquals(cacheKey.hashCode(), secondCacheKey.hashCode());
111+
assertNotEquals(cacheKey, secondCacheKey);
112+
}
113+
91114
@Test
92115
void cacheKey_compare_additionalChars() {
93116
FilterChainManager.TargetCacheKey cacheKey = new FilterChainManager.TargetCacheKey();
@@ -154,7 +177,7 @@ void filterChain_matchMultipleTimes_expectSameMatch() {
154177
}
155178

156179
@Test
157-
void filerChain_executeMultipleFilters_expectRunEachTime() {
180+
void filterChain_executeMultipleFilters_expectRunEachTime() {
158181
AwsProxyHttpServletRequest req = new AwsProxyHttpServletRequest(
159182
new AwsProxyRequestBuilder("/first/second", "GET").build(), lambdaContext, null
160183
);
@@ -204,6 +227,34 @@ void filerChain_executeMultipleFilters_expectRunEachTime() {
204227
assertEquals(REQUEST_CUSTOM_ATTRIBUTE_VALUE, req2.getAttribute(REQUEST_CUSTOM_ATTRIBUTE_NAME));
205228
}
206229

230+
@Test
231+
void filterChain_multipleServlets_callsCorrectServlet() throws IOException, ServletException {
232+
MockServlet servlet1 = (MockServlet) servletContext.getServlet(SERVLET1_NAME);
233+
ServletConfig servlet1Config = ((AwsServletRegistration) servletContext.getServletRegistration(SERVLET1_NAME)).getServletConfig();
234+
servlet1.init(servlet1Config);
235+
236+
MockServlet servlet2 = (MockServlet) servletContext.getServlet(SERVLET2_NAME);
237+
ServletConfig servlet2Config = ((AwsServletRegistration) servletContext.getServletRegistration(SERVLET2_NAME)).getServletConfig();
238+
servlet2.init(servlet2Config);
239+
240+
AwsProxyHttpServletRequest req = new AwsProxyHttpServletRequest(
241+
new AwsProxyRequestBuilder("/", "GET").build(), lambdaContext, null
242+
);
243+
AwsHttpServletResponse resp = new AwsHttpServletResponse(req, new CountDownLatch(1));
244+
245+
FilterChainHolder servlet1filterChain = chainManager.getFilterChain(req, servlet1);
246+
servlet1filterChain.doFilter(req, resp);
247+
248+
assertEquals(1, servlet1.getServiceCalls());
249+
assertEquals(0, servlet2.getServiceCalls());
250+
251+
FilterChainHolder servlet2filterChain = chainManager.getFilterChain(req, servlet2);
252+
servlet2filterChain.doFilter(req, resp);
253+
254+
assertEquals(1, servlet1.getServiceCalls());
255+
assertEquals(1, servlet2.getServiceCalls());
256+
}
257+
207258
@Test
208259
void filterChain_getFilterChain_multipleFilters() {
209260
AwsProxyHttpServletRequest req = new AwsProxyHttpServletRequest(

aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContextTest.java

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,7 @@
1313
import javax.servlet.http.HttpServletResponse;
1414

1515
import java.io.File;
16-
import java.io.FileNotFoundException;
1716
import java.io.IOException;
18-
import java.io.PrintWriter;
19-
import java.io.UnsupportedEncodingException;
20-
import java.nio.file.Files;
21-
import java.nio.file.Paths;
2217
import java.util.concurrent.CountDownLatch;
2318

2419
import static org.junit.jupiter.api.Assertions.*;
@@ -190,12 +185,7 @@ void unsupportedOperations_expectExceptions() {
190185
} catch (UnsupportedOperationException e) {
191186
exCount++;
192187
}
193-
try {
194-
STATIC_CTX.getNamedDispatcher("1");
195-
} catch (UnsupportedOperationException e) {
196-
exCount++;
197-
}
198-
assertEquals(2, exCount);
188+
assertEquals(1, exCount);
199189

200190
assertNull(STATIC_CTX.getServletRegistration("1"));
201191
}
@@ -232,6 +222,12 @@ void addServlet_callsDefaultConstructor() throws ServletException {
232222
assertEquals("", ((TestServlet)ctx.getServlet("srv1")).getId());
233223
}
234224

225+
@Test
226+
void getNamedDispatcher_returnsDispatcher() {
227+
AwsServletContext ctx = new AwsServletContext(null);
228+
assertNotNull(ctx.getNamedDispatcher("/hello"));
229+
}
230+
235231
public static class TestServlet implements Servlet {
236232
private String id;
237233

0 commit comments

Comments
 (0)