diff --git a/modules/websocket/wsl_server.cpp b/modules/websocket/wsl_server.cpp index 1e140a716f8..0d09a4d74e3 100644 --- a/modules/websocket/wsl_server.cpp +++ b/modules/websocket/wsl_server.cpp @@ -42,7 +42,7 @@ WSLServer::PendingPeer::PendingPeer() { memset(req_buf, 0, sizeof(req_buf)); } -bool WSLServer::PendingPeer::_parse_request(String &r_key) { +bool WSLServer::PendingPeer::_parse_request(const PoolStringArray p_protocols) { Vector psa = String((char *)req_buf).split("\r\n"); int len = psa.size(); if (len < 4) { @@ -87,11 +87,29 @@ bool WSLServer::PendingPeer::_parse_request(String &r_key) { _WLS_CHECK_EX("connection"); #undef _WLS_CHECK_EX #undef _WLS_CHECK - r_key = headers["sec-websocket-key"]; + key = headers["sec-websocket-key"]; + if (headers.has("sec-websocket-protocol")) { + Vector protos = headers["sec-websocket-protocol"].split(","); + for (int i = 0; i < protos.size(); i++) { + // Check if we have the given protocol + for (int j = 0; j < p_protocols.size(); j++) { + if (protos[i] != p_protocols[j]) + continue; + protocol = protos[i]; + break; + } + // Found a protocol + if (protocol != "") + break; + } + if (protocol == "") // Invalid protocol(s) requested + return false; + } else if (p_protocols.size() > 0) // No protocol requested, but we need one + return false; return true; } -Error WSLServer::PendingPeer::do_handshake() { +Error WSLServer::PendingPeer::do_handshake(PoolStringArray p_protocols) { if (OS::get_singleton()->get_ticks_msec() - time > WSL_SERVER_TIMEOUT) return ERR_TIMEOUT; if (!has_request) { @@ -111,13 +129,15 @@ Error WSLServer::PendingPeer::do_handshake() { int l = req_pos; if (l > 3 && r[l] == '\n' && r[l - 1] == '\r' && r[l - 2] == '\n' && r[l - 3] == '\r') { r[l - 3] = '\0'; - if (!_parse_request(key)) { + if (!_parse_request(p_protocols)) { return FAILED; } String s = "HTTP/1.1 101 Switching Protocols\r\n"; s += "Upgrade: websocket\r\n"; s += "Connection: Upgrade\r\n"; s += "Sec-WebSocket-Accept: " + WSLPeer::compute_key_response(key) + "\r\n"; + if (protocol != "") + s += "Sec-WebSocket-Protocol: " + protocol + "\r\n"; s += "\r\n"; response = s.utf8(); has_request = true; @@ -143,6 +163,7 @@ Error WSLServer::listen(int p_port, PoolVector p_protocols, bool gd_mp_a ERR_FAIL_COND_V(is_listening(), ERR_ALREADY_IN_USE); _is_multiplayer = gd_mp_api; + _protocols = p_protocols; _server->listen(p_port); return OK; @@ -167,7 +188,7 @@ void WSLServer::poll() { List > remove_peers; for (List >::Element *E = _pending.front(); E; E = E->next()) { Ref ppeer = E->get(); - Error err = ppeer->do_handshake(); + Error err = ppeer->do_handshake(_protocols); if (err == ERR_BUSY) { continue; } else if (err != OK) { @@ -188,7 +209,7 @@ void WSLServer::poll() { _peer_map[id] = ws_peer; remove_peers.push_back(ppeer); - _on_connect(id, ""); + _on_connect(id, ppeer->protocol); } for (List >::Element *E = remove_peers.front(); E; E = E->next()) { _pending.erase(E->get()); diff --git a/modules/websocket/wsl_server.h b/modules/websocket/wsl_server.h index b0520bd731c..2ceb9410733 100644 --- a/modules/websocket/wsl_server.h +++ b/modules/websocket/wsl_server.h @@ -49,7 +49,7 @@ private: class PendingPeer : public Reference { private: - bool _parse_request(String &r_key); + bool _parse_request(const PoolStringArray p_protocols); public: Ref connection; @@ -58,13 +58,14 @@ private: uint8_t req_buf[WSL_MAX_HEADER_SIZE]; int req_pos; String key; + String protocol; bool has_request; CharString response; int response_sent; PendingPeer(); - Error do_handshake(); + Error do_handshake(const PoolStringArray p_protocols); }; int _in_buf_size; @@ -74,6 +75,7 @@ private: List > _pending; Ref _server; + PoolStringArray _protocols; public: Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets);