Skip to content

Commit

Permalink
Add WebSocket connection backend
Browse files Browse the repository at this point in the history
  • Loading branch information
njbooher committed Sep 9, 2024
1 parent 26166e0 commit e5b73f9
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 35 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ gevent-eventemitter~=2.1
cachetools>=3.0.0
enum34==1.1.2; python_version < '3.4'
win-inet-pton; python_version == '2.7' and sys_platform == 'win32'
wsproto~=1.2.0
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
'protobuf~=3.0; python_version >= "3"',
'protobuf<3.18.0; python_version < "3"',
'gevent-eventemitter~=2.1',
'wsproto~=1.2.0',
],
}

Expand Down
4 changes: 2 additions & 2 deletions steam/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ class SteamClient(CMClient, BuiltinBase):
login_key = None #: can be used for subsequent logins (no 2FA code will be required)
chat_mode = 2 #: chat mode (0=old chat, 2=new chat)

def __init__(self):
CMClient.__init__(self)
def __init__(self, protocol=CMClient.PROTOCOL_TCP):
CMClient.__init__(self, protocol=protocol)

# register listners
self.on(self.EVENT_DISCONNECTED, self._handle_disconnect)
Expand Down
51 changes: 46 additions & 5 deletions steam/core/cm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from steam.enums import EResult, EUniverse
from steam.enums.emsg import EMsg
from steam.core import crypto
from steam.core.connection import TCPConnection
from steam.core.connection import TCPConnection, WebsocketConnection
from steam.core.msg import Msg, MsgProto
from eventemitter import EventEmitter
from steam.utils import ip4_from_int
Expand Down Expand Up @@ -59,6 +59,7 @@ class CMClient(EventEmitter):

PROTOCOL_TCP = 0 #: TCP protocol enum
PROTOCOL_UDP = 1 #: UDP protocol enum
PROTOCOL_WEBSOCKET = 2 #: WEBSOCKET protocol enum
verbose_debug = False #: print message connects in debug

auto_discovery = True #: enables automatic CM discovery
Expand All @@ -83,10 +84,12 @@ class CMClient(EventEmitter):
def __init__(self, protocol=PROTOCOL_TCP):
self.cm_servers = CMServerList()

if protocol == CMClient.PROTOCOL_TCP:
if protocol == CMClient.PROTOCOL_WEBSOCKET:
self.connection = WebsocketConnection()
elif protocol == CMClient.PROTOCOL_TCP:
self.connection = TCPConnection()
else:
raise ValueError("Only TCP is supported")
raise ValueError("Only Websocket and TCP are supported")

self.on(EMsg.ChannelEncryptRequest, self.__handle_encrypt_request),
self.on(EMsg.Multi, self.__handle_multi),
Expand Down Expand Up @@ -132,8 +135,11 @@ def connect(self, retry=0, delay=0):
self._connecting = False
return False

if not self.cm_servers.bootstrap_from_webapi():
self.cm_servers.bootstrap_from_dns()
if isinstance(self.connection, WebsocketConnection):
self.cm_servers.bootstrap_from_webapi_websocket()
elif isinstance(self.connection, TCPConnection):
if not self.cm_servers.bootstrap_from_webapi():
self.cm_servers.bootstrap_from_dns()

for i, server_addr in enumerate(cycle(self.cm_servers), start=next(i)-1):
if retry and i >= retry:
Expand All @@ -154,6 +160,12 @@ def connect(self, retry=0, delay=0):
self.current_server_addr = server_addr
self.connected = True
self.emit(self.EVENT_CONNECTED)

# WebsocketConnection secures itself
if isinstance(self.connection, WebsocketConnection):
self.channel_secured = True
self.emit(self.EVENT_CHANNEL_SECURED)

self._recv_loop = gevent.spawn(self._recv_messages)
self._connecting = False
return True
Expand Down Expand Up @@ -509,7 +521,36 @@ def str_to_tuple(serveraddr):
self.merge_list(map(str_to_tuple, serverlist))

return True

def bootstrap_from_webapi_websocket(self):
"""
Fetches CM server list from WebAPI and replaces the current one
:return: booststrap success
:rtype: :class:`bool`
"""
self._LOG.debug("Attempting bootstrap via WebAPI for websocket")

from steam import webapi
try:
resp = webapi.get('ISteamDirectory', 'GetCMListForConnect', 1, params={'cmtype': 'websockets',
'http_timeout': 3})
except Exception as exp:
self._LOG.error("WebAPI boostrap failed: %s" % str(exp))
return False

serverlist = resp['response']['serverlist']
self._LOG.debug("Received %d servers from WebAPI" % len(serverlist))

def str_to_tuple(serverinfo):
ip, port = serverinfo['endpoint'].split(':')
return str(ip), int(port)

self.clear()
self.merge_list(map(str_to_tuple, serverlist))

return True

def __iter__(self):
def cm_server_iter():
if not self.list:
Expand Down
171 changes: 143 additions & 28 deletions steam/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
from gevent import event
from gevent.select import select as gselect

import ssl
import certifi

from wsproto import WSConnection, events as wsevents
from wsproto.connection import ConnectionType, ConnectionState

logger = logging.getLogger("Connection")


class Connection(object):
MAGIC = b'VT01'
FMT = '<I4s'
FMT_SIZE = struct.calcsize(FMT)

def __init__(self):
self.socket = None
Expand Down Expand Up @@ -48,8 +51,9 @@ def connect(self, server_addr):
self._reader = gevent.spawn(self._reader_loop)
self._writer = gevent.spawn(self._writer_loop)

logger.debug("Connected.")
self.event_connected.set()
# how this gets set is implementation dependent
self.event_connected.wait(timeout=10)

return True

def disconnect(self):
Expand Down Expand Up @@ -80,11 +84,46 @@ def __iter__(self):

def put_message(self, message):
self.send_queue.put(message)

def _new_socket(self):
raise TypeError("{}: _new_socket is unimplemented".format(self.__class__.__name__))

def _connect(self, server_addr):
raise TypeError("{}: _connect is unimplemented".format(self.__class__.__name__))

def _reader_loop(self):
raise TypeError("{}: _reader_loop is unimplemented".format(self.__class__.__name__))

def _writer_loop(self):
raise TypeError("{}: _writer_loop is unimplemented".format(self.__class__.__name__))

class TCPConnection(Connection):

MAGIC = b'VT01'
FMT = '<I4s'
FMT_SIZE = struct.calcsize(FMT)

def _new_socket(self):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

def _connect(self, server_addr):
self.socket.connect(server_addr)
logger.debug("Connected.")
self.event_connected.set()

def _read_data(self):
try:
return self.socket.recv(16384)
except socket.error:
return ''

def _write_data(self, data):
self.socket.sendall(data)

def _writer_loop(self):
while True:
message = self.send_queue.get()
packet = struct.pack(Connection.FMT, len(message), Connection.MAGIC) + message
packet = struct.pack(TCPConnection.FMT, len(message), TCPConnection.MAGIC) + message
try:
self._write_data(packet)
except:
Expand All @@ -108,13 +147,13 @@ def _reader_loop(self):
self._read_packets()

def _read_packets(self):
header_size = Connection.FMT_SIZE
header_size = TCPConnection.FMT_SIZE
buf = self._readbuf

while len(buf) > header_size:
message_length, magic = struct.unpack_from(Connection.FMT, buf)
message_length, magic = struct.unpack_from(TCPConnection.FMT, buf)

if magic != Connection.MAGIC:
if magic != TCPConnection.MAGIC:
logger.debug("invalid magic, got %s" % repr(magic))
self.disconnect()
return
Expand All @@ -131,33 +170,109 @@ def _read_packets(self):

self._readbuf = buf

class WebsocketConnection(Connection):

def __init__(self):
super(WebsocketConnection, self).__init__()
self.ws = WSConnection(ConnectionType.CLIENT)
self.ssl_ctx = ssl.create_default_context(cafile=certifi.where())
self.event_wsdisconnected = event.Event()

class TCPConnection(Connection):
def _new_socket(self):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.raw_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

def _connect(self, server_addr):
self.socket.connect(server_addr)

def _read_data(self):
try:
return self.socket.recv(16384)
except socket.error:
return ''
host, port = server_addr

for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
try:
# tcp socket
_, _, _, _, sa = res
self.raw_socket.connect(sa)
self.socket = self.ssl_ctx.wrap_socket(self.raw_socket, server_hostname=host)
# websocket
ws_host = ':'.join(map(str,server_addr))
ws_send = self.ws.send(wsevents.Request(host=ws_host, target="/cmsocket/"))
self.socket.sendall(ws_send)
return
except socket.error:
if self.socket is not None:
self.socket.close()

def _writer_loop(self):
while True:
message = self.send_queue.get()
try:
logger.debug("sending message of length {}".format(len(message)))
self.socket.sendall(self.ws.send(wsevents.Message(data=message)))
except:
logger.debug("Connection error (writer).")
self.disconnect()
return

def _reader_loop(self):
while True:
rlist, _, _ = gselect([self.socket], [], [])

def _write_data(self, data):
self.socket.sendall(data)
if self.socket in rlist:

try:
data = self.socket.recv(16384)
except socket.error:
data = ''

if not data:
logger.debug("Connection error (reader).")
# A receive of zero bytes indicates the TCP socket has been closed. We
# need to pass None to wsproto to update its internal state.
logger.debug("Received 0 bytes (connection closed)")
self.ws.receive_data(None)
# now disconnect
self.disconnect()
return

logger.debug("Received {} bytes".format(len(data)))
self.ws.receive_data(data)
self._handle_events()

def _handle_events(self):
for event in self.ws.events():
if isinstance(event, wsevents.AcceptConnection):
logger.debug("WebSocket negotiation complete. Connected.")
self.event_connected.set()
elif isinstance(event, wsevents.RejectConnection):
logger.debug("WebSocket connection was rejected. That's probably not good.")
elif isinstance(event, wsevents.TextMessage):
logger.debug("Received websocket text message of length: {}".format(len(event.data)))
elif isinstance(event, wsevents.BytesMessage):
logger.debug("Received websocket bytes message of length: {}".format(len(event.data)))
self.recv_queue.put(event.data)
elif isinstance(event, wsevents.Pong):
logger.debug("Received pong: {}".format(repr(event.payload)))
elif isinstance(event, wsevents.CloseConnection):
logger.debug('Connection closed: code={} reason={}'.format(
event.code, event.reason
))
if self.ws.state == ConnectionState.REMOTE_CLOSING:
self.socket.send(self.ws.send(event.response()))
self.event_wsdisconnected.set()
else:
raise TypeError("Do not know how to handle event: {}".format((event)))

def disconnect(self):
self.event_wsdisconnected.clear()

# WebSocket closing handshake
if self.ws.state == ConnectionState.OPEN:
logger.debug("Disconnect called. Sending CloseConnection message.")
self.socket.sendall(self.ws.send(wsevents.CloseConnection(code=1000, reason="sample reason")))
self.socket.shutdown(socket.SHUT_WR)
# wait for notification from _reader_loop that the closing response was received
self.event_wsdisconnected.wait()

super(WebsocketConnection, self).disconnect()

class UDPConnection(Connection):
def _new_socket(self):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

def _connect(self, server_addr):
pass

def _read_data(self):
pass

def _write_data(self, data):
pass

0 comments on commit e5b73f9

Please sign in to comment.