diff --git a/Libraries/WebSocket/RCTSRWebSocket.h b/Libraries/WebSocket/RCTSRWebSocket.h index 0c209c21c..127530434 100644 --- a/Libraries/WebSocket/RCTSRWebSocket.h +++ b/Libraries/WebSocket/RCTSRWebSocket.h @@ -60,10 +60,12 @@ extern NSString *const RCTSRHTTPResponseErrorKey; @property (nonatomic, readonly, copy) NSString *protocol; // Protocols should be an array of strings that turn into Sec-WebSocket-Protocol. -- (instancetype)initWithURLRequest:(NSURLRequest *)request protocols:(NSArray *)protocols NS_DESIGNATED_INITIALIZER; +// options can contain a custom "origin" NSString +- (instancetype)initWithURLRequest:(NSURLRequest *)request protocols:(NSArray *)protocols options:(NSDictionary *)options NS_DESIGNATED_INITIALIZER; - (instancetype)initWithURLRequest:(NSURLRequest *)request; // Some helper constructors. +- (instancetype)initWithURL:(NSURL *)url protocols:(NSArray *)protocols options:(NSDictionary *)options; - (instancetype)initWithURL:(NSURL *)url protocols:(NSArray *)protocols; - (instancetype)initWithURL:(NSURL *)url; diff --git a/Libraries/WebSocket/RCTSRWebSocket.m b/Libraries/WebSocket/RCTSRWebSocket.m index 5f46bfd75..3d8782c87 100644 --- a/Libraries/WebSocket/RCTSRWebSocket.m +++ b/Libraries/WebSocket/RCTSRWebSocket.m @@ -234,6 +234,7 @@ typedef void (^data_callback)(RCTSRWebSocket *webSocket, NSData *data); __strong RCTSRWebSocket *_selfRetain; NSArray *_requestedProtocols; + NSDictionary *_requestedOptions; RCTSRIOConsumerPool *_consumerPool; } @@ -244,7 +245,7 @@ static __strong NSData *CRLFCRLF; CRLFCRLF = [[NSData alloc] initWithBytes:"\r\n\r\n" length:4]; } -- (instancetype)initWithURLRequest:(NSURLRequest *)request protocols:(NSArray *)protocols; +- (instancetype)initWithURLRequest:(NSURLRequest *)request protocols:(NSArray *)protocols options:(NSDictionary *)options { RCTAssertParam(request); @@ -253,6 +254,7 @@ static __strong NSData *CRLFCRLF; _urlRequest = request; _requestedProtocols = [protocols copy]; + _requestedOptions = [options copy]; [self _RCTSR_commonInit]; } @@ -263,18 +265,23 @@ RCT_NOT_IMPLEMENTED(- (instancetype)init) - (instancetype)initWithURLRequest:(NSURLRequest *)request; { - return [self initWithURLRequest:request protocols:nil]; + return [self initWithURLRequest:request protocols:nil options: nil]; } - (instancetype)initWithURL:(NSURL *)URL; { - return [self initWithURL:URL protocols:nil]; + return [self initWithURL:URL protocols:nil options:nil]; } - (instancetype)initWithURL:(NSURL *)URL protocols:(NSArray *)protocols; +{ + return [self initWithURL:URL protocols:protocols options:nil]; +} + +- (instancetype)initWithURL:(NSURL *)URL protocols:(NSArray *)protocols options:(NSDictionary *)options { NSURLRequest *request = URL ? [NSURLRequest requestWithURL:URL] : nil; - return [self initWithURLRequest:request protocols:protocols]; + return [self initWithURLRequest:request protocols:protocols options:options]; } - (void)_RCTSR_commonInit; @@ -465,12 +472,12 @@ RCT_NOT_IMPLEMENTED(- (instancetype)init) CFHTTPMessageSetHeaderFieldValue(request, CFSTR("Sec-WebSocket-Key"), (__bridge CFStringRef)_secKey); CFHTTPMessageSetHeaderFieldValue(request, CFSTR("Sec-WebSocket-Version"), (__bridge CFStringRef)[NSString stringWithFormat:@"%ld", (long)_webSocketVersion]); - CFHTTPMessageSetHeaderFieldValue(request, CFSTR("Origin"), (__bridge CFStringRef)_url.RCTSR_origin); - if (_requestedProtocols) { CFHTTPMessageSetHeaderFieldValue(request, CFSTR("Sec-WebSocket-Protocol"), (__bridge CFStringRef)[_requestedProtocols componentsJoinedByString:@", "]); } + CFHTTPMessageSetHeaderFieldValue(request, CFSTR("Origin"), (__bridge CFStringRef)(_requestedOptions[@"origin"] ?: _url.RCTSR_origin)); + [_urlRequest.allHTTPHeaderFields enumerateKeysAndObjectsUsingBlock:^(id key, id obj, BOOL *stop) { CFHTTPMessageSetHeaderFieldValue(request, (__bridge CFStringRef)key, (__bridge CFStringRef)obj); }]; diff --git a/Libraries/WebSocket/RCTWebSocketModule.m b/Libraries/WebSocket/RCTWebSocketModule.m index e16f3067c..6e057f78b 100644 --- a/Libraries/WebSocket/RCTWebSocketModule.m +++ b/Libraries/WebSocket/RCTWebSocketModule.m @@ -44,9 +44,9 @@ RCT_EXPORT_MODULE() } } -RCT_EXPORT_METHOD(connect:(NSURL *)URL socketID:(nonnull NSNumber *)socketID) +RCT_EXPORT_METHOD(connect:(NSURL *)URL protocols:(NSArray *)protocols options:(NSDictionary *)options socketID:(nonnull NSNumber *)socketID) { - RCTSRWebSocket *webSocket = [[RCTSRWebSocket alloc] initWithURL:URL]; + RCTSRWebSocket *webSocket = [[RCTSRWebSocket alloc] initWithURL:URL protocols:protocols options:options]; webSocket.delegate = self; webSocket.reactTag = socketID; if (!_sockets) { diff --git a/Libraries/WebSocket/WebSocket.js b/Libraries/WebSocket/WebSocket.js index 515a6ad5c..b9eb255cc 100644 --- a/Libraries/WebSocket/WebSocket.js +++ b/Libraries/WebSocket/WebSocket.js @@ -26,15 +26,16 @@ var CLOSE_NORMAL = 1000; * Browser-compatible WebSockets implementation. * * See https://developer.mozilla.org/en-US/docs/Web/API/WebSocket + * See https://github.com/websockets/ws */ class WebSocket extends WebSocketBase { _socketId: number; _subs: any; - connectToSocketImpl(url: string): void { + connectToSocketImpl(url: string, protocols: ?Array, options: ?{origin?: string}): void { this._socketId = WebSocketId++; - RCTWebSocketModule.connect(url, this._socketId); + RCTWebSocketModule.connect(url, protocols, options, this._socketId); this._registerEvents(this._socketId); } diff --git a/Libraries/WebSocket/WebSocketBase.js b/Libraries/WebSocket/WebSocketBase.js index ef52352d4..931b29749 100644 --- a/Libraries/WebSocket/WebSocketBase.js +++ b/Libraries/WebSocket/WebSocketBase.js @@ -33,19 +33,23 @@ class WebSocketBase extends EventTarget { readyState: number; url: ?string; - constructor(url: string, protocols: ?any) { + constructor(url: string, protocols: ?string | ?Array, options: ?{origin?: string}) { super(); this.CONNECTING = 0; this.OPEN = 1; this.CLOSING = 2; this.CLOSED = 3; - if (!protocols) { - protocols = []; + if (typeof protocols === 'string') { + protocols = [protocols]; + } + + if (!Array.isArray(protocols)) { + protocols = null; } this.readyState = this.CONNECTING; - this.connectToSocketImpl(url); + this.connectToSocketImpl(url, protocols, options); } close(): void { diff --git a/ReactAndroid/src/main/java/com/facebook/react/modules/websocket/WebSocketModule.java b/ReactAndroid/src/main/java/com/facebook/react/modules/websocket/WebSocketModule.java index dfa41638c..61e1057bf 100644 --- a/ReactAndroid/src/main/java/com/facebook/react/modules/websocket/WebSocketModule.java +++ b/ReactAndroid/src/main/java/com/facebook/react/modules/websocket/WebSocketModule.java @@ -10,6 +10,7 @@ package com.facebook.react.modules.websocket; import java.io.IOException; +import javax.annotation.Nullable; import com.facebook.common.logging.FLog; import com.facebook.react.bridge.Arguments; @@ -17,6 +18,10 @@ import com.facebook.react.bridge.ReactApplicationContext; import com.facebook.react.bridge.ReactContext; import com.facebook.react.bridge.ReactContextBaseJavaModule; import com.facebook.react.bridge.ReactMethod; +import com.facebook.react.bridge.ReadableArray; +import com.facebook.react.bridge.ReadableMap; +import com.facebook.react.bridge.ReadableMapKeySetIterator; +import com.facebook.react.bridge.ReadableType; import com.facebook.react.bridge.WritableMap; import com.facebook.react.common.ReactConstants; import com.facebook.react.modules.core.DeviceEventManagerModule; @@ -57,7 +62,8 @@ public class WebSocketModule extends ReactContextBaseJavaModule { } @ReactMethod - public void connect(final String url, final int id) { + public void connect(final String url, @Nullable final ReadableArray protocols, @Nullable final ReadableMap options, final int id) { + // ignoring protocols, since OKHttp overrides them. OkHttpClient client = new OkHttpClient(); client.setConnectTimeout(10, TimeUnit.SECONDS); @@ -65,12 +71,21 @@ public class WebSocketModule extends ReactContextBaseJavaModule { // Disable timeouts for read client.setReadTimeout(0, TimeUnit.MINUTES); - Request request = new Request.Builder() + Request.Builder builder = new Request.Builder() .tag(id) - .url(url) - .build(); + .url(url); - WebSocketCall.create(client, request).enqueue(new WebSocketListener() { + if (options != null && options.hasKey("origin")) { + if (ReadableType.String.equals(options.getType("origin"))) { + builder.addHeader("Origin", options.getString("origin")); + } else { + FLog.w( + ReactConstants.TAG, + "Ignoring: requested origin, value not a string"); + } + } + + WebSocketCall.create(client, builder.build()).enqueue(new WebSocketListener() { @Override public void onOpen(WebSocket webSocket, Response response) {