diff --git a/ReactAndroid/src/main/java/com/facebook/react/common/network/OkHttpCallUtil.java b/ReactAndroid/src/main/java/com/facebook/react/common/network/OkHttpCallUtil.java index 7fa7d967a..4da3a5e9b 100644 --- a/ReactAndroid/src/main/java/com/facebook/react/common/network/OkHttpCallUtil.java +++ b/ReactAndroid/src/main/java/com/facebook/react/common/network/OkHttpCallUtil.java @@ -32,8 +32,4 @@ public class OkHttpCallUtil { } } } - - public static void cancelAll(OkHttpClient client) { - client.dispatcher().cancelAll(); - } } diff --git a/ReactAndroid/src/main/java/com/facebook/react/modules/network/NetworkingModule.java b/ReactAndroid/src/main/java/com/facebook/react/modules/network/NetworkingModule.java index 68e6e93b0..1410f085e 100644 --- a/ReactAndroid/src/main/java/com/facebook/react/modules/network/NetworkingModule.java +++ b/ReactAndroid/src/main/java/com/facebook/react/modules/network/NetworkingModule.java @@ -9,6 +9,17 @@ package com.facebook.react.modules.network; +import javax.annotation.Nullable; + +import java.io.IOException; +import java.io.InputStream; +import java.io.Reader; +import java.net.SocketTimeoutException; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.TimeUnit; + import com.facebook.react.bridge.Arguments; import com.facebook.react.bridge.ExecutorToken; import com.facebook.react.bridge.GuardedAsyncTask; @@ -22,15 +33,6 @@ import com.facebook.react.bridge.WritableMap; import com.facebook.react.common.network.OkHttpCallUtil; import com.facebook.react.modules.core.DeviceEventManagerModule; -import java.io.IOException; -import java.io.InputStream; -import java.io.Reader; -import java.net.SocketTimeoutException; -import java.util.List; -import java.util.concurrent.TimeUnit; - -import javax.annotation.Nullable; - import okhttp3.Call; import okhttp3.Callback; import okhttp3.Headers; @@ -61,6 +63,7 @@ public final class NetworkingModule extends ReactContextBaseJavaModule { private final ForwardingCookieHandler mCookieHandler; private final @Nullable String mDefaultUserAgent; private final CookieJarContainer mCookieJarContainer; + private final Set mRequestIds; private boolean mShuttingDown; /* package */ NetworkingModule( @@ -69,7 +72,7 @@ public final class NetworkingModule extends ReactContextBaseJavaModule { OkHttpClient client, @Nullable List networkInterceptorCreators) { super(reactContext); - + if (networkInterceptorCreators != null) { OkHttpClient.Builder clientBuilder = client.newBuilder(); for (NetworkInterceptorCreator networkInterceptorCreator : networkInterceptorCreators) { @@ -83,6 +86,7 @@ public final class NetworkingModule extends ReactContextBaseJavaModule { mCookieJarContainer = (CookieJarContainer) mClient.cookieJar(); mShuttingDown = false; mDefaultUserAgent = defaultUserAgent; + mRequestIds = new HashSet<>(); } /** @@ -138,7 +142,7 @@ public final class NetworkingModule extends ReactContextBaseJavaModule { @Override public void onCatalystInstanceDestroy() { mShuttingDown = true; - OkHttpCallUtil.cancelAll(mClient); + cancelAllRequests(); mCookieHandler.destroy(); mCookieJarContainer.removeCookieJar(); @@ -241,6 +245,7 @@ public final class NetworkingModule extends ReactContextBaseJavaModule { requestBuilder.method(method, RequestBodyUtil.getEmptyBody(method)); } + addRequest(requestId); client.newCall(requestBuilder.build()).enqueue( new Callback() { @Override @@ -248,6 +253,7 @@ public final class NetworkingModule extends ReactContextBaseJavaModule { if (mShuttingDown) { return; } + removeRequest(requestId); onRequestError(executorToken, requestId, e.getMessage(), e); } @@ -256,7 +262,7 @@ public final class NetworkingModule extends ReactContextBaseJavaModule { if (mShuttingDown) { return; } - + removeRequest(requestId); // Before we touch the body send headers to JS onResponseReceived(executorToken, requestId, response); @@ -335,6 +341,21 @@ public final class NetworkingModule extends ReactContextBaseJavaModule { getEventEmitter(ExecutorToken).emit("didReceiveNetworkResponse", args); } + private synchronized void addRequest(int requestId) { + mRequestIds.add(requestId); + } + + private synchronized void removeRequest(int requestId) { + mRequestIds.remove(requestId); + } + + private synchronized void cancelAllRequests() { + for (Integer requestId : mRequestIds) { + cancelRequest(requestId); + } + mRequestIds.clear(); + } + private static WritableMap translateHeaders(Headers headers) { WritableMap responseHeaders = Arguments.createMap(); for (int i = 0; i < headers.size(); i++) { @@ -353,6 +374,11 @@ public final class NetworkingModule extends ReactContextBaseJavaModule { @ReactMethod public void abortRequest(ExecutorToken executorToken, final int requestId) { + cancelRequest(requestId); + removeRequest(requestId); + } + + private void cancelRequest(final int requestId) { // We have to use AsyncTask since this might trigger a NetworkOnMainThreadException, this is an // open issue on OkHttp: https://github.com/square/okhttp/issues/869 new GuardedAsyncTask(getReactApplicationContext()) { diff --git a/ReactAndroid/src/test/java/com/facebook/react/modules/network/NetworkingModuleTest.java b/ReactAndroid/src/test/java/com/facebook/react/modules/network/NetworkingModuleTest.java index b80561183..ec04b5cbd 100644 --- a/ReactAndroid/src/test/java/com/facebook/react/modules/network/NetworkingModuleTest.java +++ b/ReactAndroid/src/test/java/com/facebook/react/modules/network/NetworkingModuleTest.java @@ -15,12 +15,13 @@ import java.util.List; import com.facebook.react.bridge.Arguments; import com.facebook.react.bridge.ExecutorToken; -import com.facebook.react.bridge.ReactApplicationContext; -import com.facebook.react.bridge.ReactContext; import com.facebook.react.bridge.JavaOnlyArray; import com.facebook.react.bridge.JavaOnlyMap; +import com.facebook.react.bridge.ReactApplicationContext; +import com.facebook.react.bridge.ReactContext; import com.facebook.react.bridge.WritableArray; import com.facebook.react.bridge.WritableMap; +import com.facebook.react.common.network.OkHttpCallUtil; import com.facebook.react.modules.core.DeviceEventManagerModule.RCTDeviceEventEmitter; import okhttp3.Call; @@ -31,7 +32,6 @@ import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.RequestBody; import okio.Buffer; - import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -40,8 +40,8 @@ import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.powermock.api.mockito.PowerMockito; -import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.rule.PowerMockRule; import org.robolectric.RobolectricTestRunner; @@ -63,7 +63,8 @@ import static org.mockito.Mockito.when; MultipartBody.class, MultipartBody.Builder.class, NetworkingModule.class, - OkHttpClient.class}) + OkHttpClient.class, + OkHttpCallUtil.class}) @RunWith(RobolectricTestRunner.class) @PowerMockIgnore({"org.mockito.*", "org.robolectric.*", "android.*"}) public class NetworkingModuleTest { @@ -476,4 +477,105 @@ public class NetworkingModuleTest { assertThat(bodyRequestBody.get(1).contentType()).isEqualTo(MediaType.parse("image/jpg")); assertThat(bodyRequestBody.get(1).contentLength()).isEqualTo("imageUri".getBytes().length); } + + @Test + public void testCancelAllCallsOnCatalystInstanceDestroy() throws Exception { + PowerMockito.mockStatic(OkHttpCallUtil.class); + OkHttpClient httpClient = mock(OkHttpClient.class); + final int requests = 3; + final Call[] calls = new Call[requests]; + for (int idx = 0; idx < requests; idx++) { + calls[idx] = mock(Call.class); + } + + when(httpClient.cookieJar()).thenReturn(mock(CookieJarContainer.class)); + when(httpClient.newCall(any(Request.class))).thenAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + Request request = (Request) invocation.getArguments()[0]; + return calls[(Integer) request.tag() - 1]; + } + }); + NetworkingModule networkingModule = new NetworkingModule(null, "", httpClient); + networkingModule.initialize(); + + for (int idx = 0; idx < requests; idx++) { + networkingModule.sendRequest( + mock(ExecutorToken.class), + "GET", + "http://somedomain/foo", + idx + 1, + JavaOnlyArray.of(), + null, + true, + 0); + } + verify(httpClient, times(3)).newCall(any(Request.class)); + + networkingModule.onCatalystInstanceDestroy(); + PowerMockito.verifyStatic(times(3)); + ArgumentCaptor clientArguments = ArgumentCaptor.forClass(OkHttpClient.class); + ArgumentCaptor requestIdArguments = ArgumentCaptor.forClass(Integer.class); + OkHttpCallUtil.cancelTag(clientArguments.capture(), requestIdArguments.capture()); + + assertThat(requestIdArguments.getAllValues().size()).isEqualTo(requests); + for (int idx = 0; idx < requests; idx++) { + assertThat(requestIdArguments.getAllValues().contains(idx + 1)).isTrue(); + } + } + + @Test + public void testCancelSomeCallsOnCatalystInstanceDestroy() throws Exception { + PowerMockito.mockStatic(OkHttpCallUtil.class); + OkHttpClient httpClient = mock(OkHttpClient.class); + final int requests = 3; + final Call[] calls = new Call[requests]; + for (int idx = 0; idx < requests; idx++) { + calls[idx] = mock(Call.class); + } + + when(httpClient.cookieJar()).thenReturn(mock(CookieJarContainer.class)); + when(httpClient.newCall(any(Request.class))).thenAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + Request request = (Request) invocation.getArguments()[0]; + return calls[(Integer) request.tag() - 1]; + } + }); + NetworkingModule networkingModule = new NetworkingModule(null, "", httpClient); + + for (int idx = 0; idx < requests; idx++) { + networkingModule.sendRequest( + mock(ExecutorToken.class), + "GET", + "http://somedomain/foo", + idx + 1, + JavaOnlyArray.of(), + null, + true, + 0); + } + verify(httpClient, times(3)).newCall(any(Request.class)); + + networkingModule.abortRequest(mock(ExecutorToken.class), requests); + PowerMockito.verifyStatic(times(1)); + ArgumentCaptor clientArguments = ArgumentCaptor.forClass(OkHttpClient.class); + ArgumentCaptor requestIdArguments = ArgumentCaptor.forClass(Integer.class); + OkHttpCallUtil.cancelTag(clientArguments.capture(), requestIdArguments.capture()); + assertThat(requestIdArguments.getAllValues().size()).isEqualTo(1); + assertThat(requestIdArguments.getAllValues().get(0)).isEqualTo(requests); + + // verifyStatic actually does not clear all calls so far, so we have to check for all of them. + // If `cancelTag` would've been called again for the aborted call, we would have had + // `requests + 1` calls. + networkingModule.onCatalystInstanceDestroy(); + PowerMockito.verifyStatic(times(requests)); + clientArguments = ArgumentCaptor.forClass(OkHttpClient.class); + requestIdArguments = ArgumentCaptor.forClass(Integer.class); + OkHttpCallUtil.cancelTag(clientArguments.capture(), requestIdArguments.capture()); + assertThat(requestIdArguments.getAllValues().size()).isEqualTo(requests); + for (int idx = 0; idx < requests; idx++) { + assertThat(requestIdArguments.getAllValues().contains(idx + 1)).isTrue(); + } + } }