feat: realtime support (#2)

* feat(wip): initial realtime impl

* chore: move sample app to samples folder

* feat: realtime + pencilkit sample

* chore: update readme
This commit is contained in:
Daniel Rochetti
2023-11-27 22:56:40 -08:00
committed by GitHub
parent b56e8d2baa
commit 85d298c698
30 changed files with 864 additions and 475 deletions

View File

@@ -1,3 +1,4 @@
import Dispatch
import Foundation
public struct EmptyInput: Encodable {
@@ -18,7 +19,7 @@ public extension Client {
}
func run<Output: Decodable>(
_ id: String,
_ app: String,
input: (some Encodable) = EmptyInput.empty,
options: RunOptions = DefaultRunOptions
) async throws -> Output {
@@ -27,25 +28,25 @@ public extension Client {
? try JSONSerialization.jsonObject(with: inputData!) as? [String: Any]
: nil
let data = try await sendRequest(id, input: inputData, queryParams: queryParams, options: options)
let url = buildUrl(fromId: app, path: options.path)
let data = try await sendRequest(url, input: inputData, queryParams: queryParams, options: options)
return try decoder.decode(Output.self, from: data)
}
func subscribe<Output: Decodable>(
_ id: String,
to app: String,
input: (some Encodable) = EmptyInput.empty,
pollInterval: FalTimeInterval = .seconds(1),
timeout: FalTimeInterval = .minutes(3),
pollInterval: DispatchTimeInterval = .seconds(1),
timeout: DispatchTimeInterval = .minutes(3),
includeLogs: Bool = false,
options _: RunOptions = DefaultRunOptions,
onQueueUpdate: OnQueueUpdate? = nil
) async throws -> Output {
let requestId = try await queue.submit(id, input: input)
let requestId = try await queue.submit(app, input: input)
let start = Int(Date().timeIntervalSince1970 * 1000)
var elapsed = 0
var isCompleted = false
while elapsed < timeout.milliseconds {
let update = try await queue.status(id, of: requestId, includeLogs: includeLogs)
let update = try await queue.status(app, of: requestId, includeLogs: includeLogs)
if let onQueueUpdateCallback = onQueueUpdate {
onQueueUpdateCallback(update)
}
@@ -59,6 +60,6 @@ public extension Client {
if !isCompleted {
throw FalError.queueTimeout
}
return try await queue.response(id, of: requestId)
return try await queue.response(app, of: requestId)
}
}

View File

@@ -1,8 +1,7 @@
import Foundation
extension Client {
func sendRequest(_ id: String, input: Data?, queryParams: [String: Any]? = nil, options: RequestOptions) async throws -> Data {
let urlString = buildUrl(fromId: id, path: options.path)
func sendRequest(_ urlString: String, input: Data?, queryParams: [String: Any]? = nil, options: RunOptions) async throws -> Data {
guard var url = URL(string: urlString) else {
throw FalError.invalidUrl(url: urlString)
}
@@ -49,6 +48,6 @@ extension Client {
var userAgent: String {
let osVersion = ProcessInfo.processInfo.operatingSystemVersionString
return "fal.ai/swift-client 0.0.1 - \(osVersion)"
return "fal.ai/swift-client 0.1.0 - \(osVersion)"
}
}

View File

@@ -1,26 +1,27 @@
import Dispatch
import Foundation
enum HttpMethod: String {
public enum HttpMethod: String {
case get
case post
case put
case delete
}
protocol RequestOptions {
public protocol RequestOptions {
var httpMethod: HttpMethod { get }
var path: String { get }
}
public struct RunOptions: RequestOptions {
let path: String
let httpMethod: HttpMethod
public let path: String
public let httpMethod: HttpMethod
static func withMethod(_ method: HttpMethod) -> Self {
static func withMethod(_ method: HttpMethod) -> RunOptions {
RunOptions(path: "", httpMethod: method)
}
static func route(_ path: String, withMethod method: HttpMethod = .post) -> Self {
static func route(_ path: String, withMethod method: HttpMethod = .post) -> RunOptions {
RunOptions(path: path, httpMethod: method)
}
}
@@ -34,39 +35,38 @@ public protocol Client {
var queue: Queue { get }
var realtime: Realtime { get }
func run(_ id: String, input: [String: Any]?, options: RunOptions) async throws -> [String: Any]
func subscribe(
_ id: String,
to app: String,
input: [String: Any]?,
pollInterval: FalTimeInterval,
timeout: FalTimeInterval,
pollInterval: DispatchTimeInterval,
timeout: DispatchTimeInterval,
includeLogs: Bool,
options: RunOptions,
onQueueUpdate: OnQueueUpdate?
) async throws -> [String: Any]
}
public extension Client {
func run(_ id: String, input: [String: Any]? = nil, options: RunOptions = DefaultRunOptions) async throws -> [String: Any] {
return try await run(id, input: input, options: options)
func run(_ app: String, input: [String: Any]? = nil, options: RunOptions = DefaultRunOptions) async throws -> [String: Any] {
return try await run(app, input: input, options: options)
}
func subscribe(
_ id: String,
to app: String,
input: [String: Any]? = nil,
pollInterval: FalTimeInterval = .seconds(1),
timeout: FalTimeInterval = .minutes(3),
pollInterval: DispatchTimeInterval = .seconds(1),
timeout: DispatchTimeInterval = .minutes(3),
includeLogs: Bool = false,
options: RunOptions = DefaultRunOptions,
onQueueUpdate: OnQueueUpdate? = nil
) async throws -> [String: Any] {
return try await subscribe(id,
return try await subscribe(to: app,
input: input,
pollInterval: pollInterval,
timeout: timeout,
includeLogs: includeLogs,
options: options,
onQueueUpdate: onQueueUpdate)
}
}

View File

@@ -1,3 +1,4 @@
import Dispatch
import Foundation
func buildUrl(fromId id: String, path: String? = nil) -> String {
@@ -28,10 +29,13 @@ public struct FalClient: Client {
public var queue: Queue { QueueClient(client: self) }
public func run(_ id: String, input: [String: Any]?, options: RunOptions) async throws -> [String: Any] {
public var realtime: Realtime { RealtimeClient(client: self) }
public func run(_ app: String, input: [String: Any]?, options: RunOptions) async throws -> [String: Any] {
let inputData = input != nil ? try JSONSerialization.data(withJSONObject: input as Any) : nil
let queryParams = options.httpMethod == .get ? input : nil
let data = try await sendRequest(id, input: inputData, queryParams: queryParams, options: options)
let url = buildUrl(fromId: app, path: options.path)
let data = try await sendRequest(url, input: inputData, queryParams: queryParams, options: options)
guard let result = try JSONSerialization.jsonObject(with: data) as? [String: Any] else {
throw FalError.invalidResultFormat
}
@@ -39,20 +43,19 @@ public struct FalClient: Client {
}
public func subscribe(
_ id: String,
to app: String,
input: [String: Any]?,
pollInterval: FalTimeInterval,
timeout: FalTimeInterval,
pollInterval: DispatchTimeInterval,
timeout: DispatchTimeInterval,
includeLogs: Bool,
options _: RunOptions,
onQueueUpdate: OnQueueUpdate?
) async throws -> [String: Any] {
let requestId = try await queue.submit(id, input: input)
let requestId = try await queue.submit(app, input: input)
let start = Int(Date().timeIntervalSince1970 * 1000)
var elapsed = 0
var isCompleted = false
while elapsed < timeout.milliseconds {
let update = try await queue.status(id, of: requestId, includeLogs: includeLogs)
let update = try await queue.status(app, of: requestId, includeLogs: includeLogs)
if let onQueueUpdateCallback = onQueueUpdate {
onQueueUpdateCallback(update)
}
@@ -66,12 +69,16 @@ public struct FalClient: Client {
if !isCompleted {
throw FalError.queueTimeout
}
return try await queue.response(id, of: requestId)
return try await queue.response(app, of: requestId)
}
}
public extension FalClient {
static func withProxy(_ url: String) -> FalClient {
static func withProxy(_ url: String) -> Client {
return FalClient(config: ClientConfig(requestProxy: url))
}
static func withCredentials(_ credentials: ClientCredentials) -> Client {
return FalClient(config: ClientConfig(credentials: credentials))
}
}

View File

@@ -1,19 +0,0 @@
public enum FalTimeInterval {
case milliseconds(Int)
case seconds(Int)
case minutes(Int)
case hours(Int)
var milliseconds: Int {
switch self {
case let .milliseconds(value):
return value
case let .seconds(value):
return value * 1000
case let .minutes(value):
return value * 60 * 1000
case let .hours(value):
return value * 60 * 60 * 1000
}
}
}

View File

@@ -0,0 +1,30 @@
import Dispatch
import Foundation
class CodableRealtimeConnection<Input: Encodable>: RealtimeConnection<Input> {
override public func send(_ data: Input) throws {
let json = try JSONEncoder().encode(data)
try sendReference(json)
}
}
public extension Realtime {
func connect<Input: Encodable, Output: Decodable>(
to app: String,
connectionKey: String,
throttleInterval: DispatchTimeInterval,
onResult completion: @escaping (Result<Output, Error>) -> Void
) throws -> RealtimeConnection<Input> {
return handleConnection(
to: app, connectionKey: connectionKey, throttleInterval: throttleInterval,
resultConverter: { data in
let result = try JSONDecoder().decode(Output.self, from: data)
return result
},
connectionFactory: { send, close in
CodableRealtimeConnection(send, close)
},
onResult: completion
)
}
}

View File

@@ -0,0 +1,365 @@
import Dispatch
import Foundation
func throttle<T>(_ function: @escaping (T) -> Void, throttleInterval: DispatchTimeInterval) -> ((T) -> Void) {
var lastExecution = DispatchTime.now()
let throttledFunction: ((T) -> Void) = { input in
if DispatchTime.now() > lastExecution + throttleInterval {
lastExecution = DispatchTime.now()
function(input)
}
}
return throttledFunction
}
public enum FalRealtimeError: Error {
case connectionError
case unauthorized
case invalidResult
}
public class RealtimeConnection<Input> {
var sendReference: SendFunction
var closeReference: CloseFunction
init(_ send: @escaping SendFunction, _ close: @escaping CloseFunction) {
sendReference = send
closeReference = close
}
public func close() {
closeReference()
}
public func send(_: Input) throws {
preconditionFailure("This method must be overridden to handle \(Input.self)")
}
}
typealias SendFunction = (Data) throws -> Void
typealias CloseFunction = () -> Void
class UntypedRealtimeConnection: RealtimeConnection<[String: Any]> {
override public func send(_ data: [String: Any]) throws {
let json = try JSONSerialization.data(withJSONObject: data)
try sendReference(json)
}
}
func buildRealtimeUrl(forApp app: String, host: String, token: String? = nil) -> URL {
var components = URLComponents()
components.scheme = "wss"
components.host = "\(app).\(host)"
components.path = "/ws"
if let token = token {
components.queryItems = [URLQueryItem(name: "fal_jwt_token", value: token)]
}
// swiftlint:disable:next force_unwrapping
return components.url!
}
typealias RefreshTokenFunction = (String, (Result<String, Error>) -> Void) -> Void
private let TokenExpirationInterval: DispatchTimeInterval = .minutes(1)
class WebSocketConnection: NSObject, URLSessionWebSocketDelegate {
let app: String
let client: Client
let onMessage: (Data) -> Void
let onError: (Error) -> Void
private let queue = DispatchQueue(label: "ai.fal.WebSocketConnection.\(UUID().uuidString)")
private let session = URLSession(configuration: .default)
private var enqueuedMessages: [Data] = []
private var task: URLSessionWebSocketTask?
private var token: String?
private var isConnecting = false
private var isRefreshingToken = false
init(
app: String,
client: Client,
onMessage: @escaping (Data) -> Void,
onError: @escaping (Error) -> Void
) {
self.app = app
self.client = client
self.onMessage = onMessage
self.onError = onError
}
func connect() {
if task == nil && !isConnecting && !isRefreshingToken {
isConnecting = true
if token == nil && !isRefreshingToken {
isRefreshingToken = true
refreshToken(app) { result in
switch result {
case let .success(token):
self.token = token
self.isRefreshingToken = false
self.isConnecting = false
// Very simple token expiration handling for now
// Create the deadline 90% of the way through the token's lifetime
let tokenExpirationDeadline: DispatchTime = .now() + TokenExpirationInterval - .seconds(20)
DispatchQueue.main.asyncAfter(deadline: tokenExpirationDeadline) {
self.token = nil
}
self.connect()
case let .failure(error):
self.isConnecting = false
self.isRefreshingToken = false
self.onError(error)
}
}
return
}
// TODO: get host from config
let url = buildRealtimeUrl(forApp: app, host: "gateway.alpha.fal.ai", token: token)
let webSocketTask = session.webSocketTask(with: url)
webSocketTask.delegate = self
task = webSocketTask
// connect and keep the task reference
task?.resume()
isConnecting = false
receiveMessage()
}
}
func refreshToken(_ app: String, completion: @escaping (Result<String, Error>) -> Void) {
Task {
// TODO: improve app alias resolution
let appAlias = app.split(separator: "-").dropFirst().joined(separator: "-")
let url = "https://rest.alpha.fal.ai/tokens/"
let body = try? JSONSerialization.data(withJSONObject: [
"allowed_apps": [appAlias],
"token_expiration": 300,
])
do {
let response = try await self.client.sendRequest(
url,
input: body,
options: .withMethod(.post)
)
if let token = String(data: response, encoding: .utf8) {
completion(.success(token.replacingOccurrences(of: "\"", with: "")))
} else {
completion(.failure(FalRealtimeError.unauthorized))
}
} catch {
completion(.failure(error))
}
}
}
func receiveMessage() {
task?.receive { [weak self] incomingMessage in
switch incomingMessage {
case let .success(message):
do {
let data = try message.data()
guard let parsedMessage = try JSONSerialization.jsonObject(with: data) as? [String: Any] else {
self?.onError(FalRealtimeError.invalidResult)
return
}
if isSuccessResult(parsedMessage) {
self?.onMessage(data)
}
// if (parsedMessage["status"] as? String != "error") {
// self?.task?.cancel()
// }
} catch {
self?.onError(error)
}
case let .failure(error):
self?.onError(error)
}
self?.receiveMessage()
}
}
func send(_ data: Data) throws {
if let task = task {
guard let message = String(data: data, encoding: .utf8) else {
return
}
task.send(.string(message)) { [weak self] error in
if let error = error {
self?.onError(error)
}
}
} else {
enqueuedMessages.append(data)
queue.sync {
if !isConnecting {
connect()
}
}
}
}
func close() {
task?.cancel(with: .normalClosure, reason: "Programmatically closed".data(using: .utf8))
}
func urlSession(
_: URLSession,
webSocketTask _: URLSessionWebSocketTask,
didOpenWithProtocol _: String?
) {
if let lastMessage = enqueuedMessages.last {
do {
try send(lastMessage)
} catch {
onError(error)
}
}
enqueuedMessages.removeAll()
}
func urlSession(
_: URLSession,
webSocketTask _: URLSessionWebSocketTask,
didCloseWith _: URLSessionWebSocketTask.CloseCode,
reason _: Data?
) {
task = nil
}
}
var connectionPool: [String: WebSocketConnection] = [:]
public protocol Realtime {
var client: Client { get }
func connect(
to app: String,
connectionKey: String,
throttleInterval: DispatchTimeInterval,
onResult completion: @escaping (Result<[String: Any], Error>) -> Void
) throws -> RealtimeConnection<[String: Any]>
}
func isSuccessResult(_ message: [String: Any]) -> Bool {
return message["status"] as? String != "error" && message["type"] as? String != "x-fal-message"
}
extension URLSessionWebSocketTask.Message {
func data() throws -> Data {
switch self {
case let .data(data):
return data
case let .string(string):
guard let data = string.data(using: .utf8) else {
throw FalRealtimeError.invalidResult
}
return data
@unknown default:
preconditionFailure("Unknown URLSessionWebSocketTask.Message case")
}
}
}
public struct RealtimeClient: Realtime {
// TODO in the future make this non-public
// External APIs should not use it
public let client: Client
init(client: Client) {
self.client = client
}
public func connect(
to app: String,
connectionKey: String,
throttleInterval: DispatchTimeInterval,
onResult completion: @escaping (Result<[String: Any], Error>) -> Void
) throws -> RealtimeConnection<[String: Any]> {
return handleConnection(
to: app,
connectionKey: connectionKey,
throttleInterval: throttleInterval,
resultConverter: { data in
guard let result = try JSONSerialization.jsonObject(with: data) as? [String: Any] else {
throw FalRealtimeError.invalidResult
}
return result
},
connectionFactory: { send, close in
UntypedRealtimeConnection(send, close)
},
onResult: completion
)
}
}
extension Realtime {
internal func handleConnection<InputType, ResultType>(
to app: String,
connectionKey: String,
throttleInterval: DispatchTimeInterval,
resultConverter convertToResultType: @escaping (Data) throws -> ResultType,
connectionFactory createRealtimeConnection: @escaping (@escaping SendFunction, @escaping CloseFunction) -> RealtimeConnection<InputType>,
onResult completion: @escaping (Result<ResultType, Error>) -> Void
) -> RealtimeConnection<InputType> {
let key = "\(app):\(connectionKey)"
let ws = connectionPool[key] ?? WebSocketConnection(
app: app,
client: self.client,
onMessage: { data in
do {
let result = try convertToResultType(data)
completion(.success(result))
} catch {
completion(.failure(error))
}
},
onError: { error in
completion(.failure(error))
}
)
if connectionPool[key] == nil {
connectionPool[key] = ws
}
let sendData = { (data: Data) in
do {
try ws.send(data)
} catch {
completion(.failure(error))
}
}
let send: SendFunction = throttleInterval.milliseconds > 0 ? throttle(sendData, throttleInterval: throttleInterval) : sendData
let close: CloseFunction = {
ws.close()
}
return createRealtimeConnection(send, close)
}
}
public extension Realtime {
func connect(
to app: String,
connectionKey: String = UUID().uuidString,
throttleInterval: DispatchTimeInterval = .milliseconds(64),
onResult completion: @escaping (Result<[String: Any], Error>) -> Void
) throws -> RealtimeConnection<[String: Any]> {
return try connect(
to: app,
connectionKey: connectionKey,
throttleInterval: throttleInterval,
onResult: completion
)
}
}

View File

@@ -0,0 +1,24 @@
import Dispatch
extension DispatchTimeInterval {
public static func minutes(_ value: Int) -> DispatchTimeInterval {
return .seconds(value * 60)
}
var milliseconds: Int {
switch self {
case let .milliseconds(value):
return value
case let .seconds(value):
return value * 1000
case let .microseconds(value):
return value / 1000
case let .nanoseconds(value):
return value / 1_000_000
case .never:
return 0
@unknown default:
return 0
}
}
}