Source code for tornadio2.session

# -*- coding: utf-8 -*-
#
# Copyright: (c) 2011 by the Serge S. Koval, see AUTHORS for more details.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

"""
    tornadio2.session
    ~~~~~~~~~~~~~~~~~

    Active TornadIO2 connection session.
"""

import urlparse
import logging


logger = logging.getLogger('tornadio2.session')

from tornado.web import HTTPError

from tornadio2 import sessioncontainer, proto, periodic, stats


[docs]class ConnectionInfo(object): """Connection information object. Will be passed to the ``on_open`` handler of your connection class. Has few properties: `ip` Caller IP address `cookies` Collection of cookies `arguments` Collection of the query string arguments """ def __init__(self, ip, arguments, cookies): self.ip = ip self.cookies = cookies self.arguments = arguments
[docs] def get_argument(self, name): """Return single argument by name""" val = self.arguments.get(name) if val: return val[0] return None
[docs]class Session(sessioncontainer.SessionBase): """Socket.IO session implementation. Session has some publicly accessible properties: `server` Server association. Server contains io_loop instance, settings, etc. `remote_ip` Remote IP `is_closed` Check if session is closed or not. """
[docs] def __init__(self, conn, server, request, expiry=None): """Session constructor. `conn` Default connection class `server` Associated server `handler` Request handler that created new session `expiry` Session expiry """ # Initialize session super(Session, self).__init__(None, expiry) self.server = server self.send_queue = [] self.handler = None # Stats server.stats.session_opened() self.remote_ip = request.remote_ip # Create connection instance self.conn = conn(self) # Call on_open. self.info = ConnectionInfo(request.remote_ip, request.arguments, request.cookies) # If everything is fine - continue self.send_message(proto.connect()) # Heartbeat related stuff self._heartbeat_timer = None self._heartbeat_interval = self.server.settings['heartbeat_interval'] * 1000 self._missed_heartbeats = 0 # Endpoints self.endpoints = dict() result = self.conn.on_open(self.info) if result is not None and not result: raise HTTPError(401) # Session callbacks
[docs] def on_delete(self, forced): """Session expiration callback `forced` If session item explicitly deleted, forced will be set to True. If item expired, will be set to False. """ # Do not remove connection if it was not forced and there's running connection if not forced and self.handler is not None and not self.is_closed: self.promote() else: self.close() # Add session
[docs] def set_handler(self, handler): """Set active handler for the session `handler` Associate active Tornado handler with the session """ # Check if session already has associated handler if self.handler is not None: return False # If IP address don't match - refuse connection if self.server.settings['verify_remote_ip'] and handler.request.remote_ip != self.remote_ip: logger.error('Attempted to attach to session %s (%s) from different IP (%s)' % ( self.session_id, self.remote_ip, handler.request.remote_ip )) return False # Associate handler and promote self.handler = handler self.promote() # Stats self.server.stats.connection_opened() return True
[docs] def remove_handler(self, handler): """Remove active handler from the session `handler` Handler to remove """ # Attempt to remove another handler if self.handler != handler: raise Exception('Attempted to remove invalid handler') self.handler = None self.promote() self.server.stats.connection_closed()
[docs] def send_message(self, pack): """Send socket.io encoded message `pack` Encoded socket.io message """ logger.debug('<<< ' + pack) # TODO: Possible optimization if there's on-going connection - there's no # need to queue messages? self.send_queue.append(pack) self.flush()
[docs] def flush(self): """Flush message queue if there's an active connection running""" if self.handler is None: return if not self.send_queue: return self.handler.send_messages(self.send_queue) self.send_queue = [] # If session was closed, detach connection if self.is_closed and self.handler is not None: self.handler.session_closed() # Close connection with all endpoints or just one endpoint
[docs] def close(self, endpoint=None): """Close session or endpoint connection. `endpoint` If endpoint is passed, will close open endpoint connection. Otherwise will close whole socket. """ if endpoint is None: if not self.conn.is_closed: # Close child connections for k in self.endpoints.keys(): self.disconnect_endpoint(k) # Close parent connections try: self.conn.on_close() finally: self.conn.is_closed = True # Stats self.server.stats.session_closed() # Stop heartbeats self.stop_heartbeat() # Send disconnection message self.send_message(proto.disconnect()) # Notify transport that session was closed if self.handler is not None: self.handler.session_closed() else: # Disconnect endpoint self.disconnect_endpoint(endpoint)
@property def is_closed(self): """Check if session was closed""" return self.conn.is_closed # Heartbeats
[docs] def reset_heartbeat(self): """Reset hearbeat timer""" self.stop_heartbeat() self._heartbeat_timer = periodic.Callback(self._heartbeat, self._heartbeat_interval, self.server.io_loop) self._heartbeat_timer.start()
[docs] def stop_heartbeat(self): """Stop active heartbeat""" if self._heartbeat_timer is not None: self._heartbeat_timer.stop() self._heartbeat_timer = None
[docs] def delay_heartbeat(self): """Delay active heartbeat""" if self._heartbeat_timer is not None: self._heartbeat_timer.delay()
[docs] def _heartbeat(self): """Heartbeat callback""" self.send_message(proto.heartbeat()) self._missed_heartbeats += 1 # TODO: Configurable if self._missed_heartbeats > 2: self.close() # Endpoints
[docs] def connect_endpoint(self, url): """Connect endpoint from URL. `url` socket.io endpoint URL. """ urldata = urlparse.urlparse(url) endpoint = urldata.path conn = self.endpoints.get(endpoint, None) if conn is None: conn_class = self.conn.get_endpoint(endpoint) if conn_class is None: logger.error('There is no handler for endpoint %s' % endpoint) return conn = conn_class(self, endpoint) self.endpoints[endpoint] = conn self.send_message(proto.connect(endpoint)) if conn.on_open(self.info) == False: self.disconnect_endpoint(endpoint)
[docs] def disconnect_endpoint(self, endpoint): """Disconnect endpoint `endpoint` endpoint name """ if endpoint not in self.endpoints: logger.error('Invalid endpoint for disconnect %s' % endpoint) return conn = self.endpoints[endpoint] del self.endpoints[endpoint] conn.on_close() self.send_message(proto.disconnect(endpoint))
def get_connection(self, endpoint): """Get connection object. `endpoint` Endpoint name. If set to None, will return default connection object. """ if endpoint: return self.endpoints.get(endpoint) else: return self.conn # Message handler
[docs] def raw_message(self, msg): """Socket.IO message handler. `msg` Raw socket.io message to handle """ try: logger.debug('>>> ' + msg) parts = msg.split(':', 3) if len(parts) == 3: msg_type, msg_id, msg_endpoint = parts msg_data = None else: msg_type, msg_id, msg_endpoint, msg_data = parts # Packets that don't require valid endpoint if msg_type == proto.DISCONNECT: if not msg_endpoint: self.close() else: self.disconnect_endpoint(msg_endpoint) return elif msg_type == proto.CONNECT: if msg_endpoint: self.connect_endpoint(msg_endpoint) else: # TODO: Disconnect? logger.error('Invalid connect without endpoint') return # All other packets need endpoints conn = self.get_connection(msg_endpoint) if conn is None: logger.error('Invalid endpoint: %s' % msg_endpoint) return if msg_type == proto.HEARTBEAT: self._missed_heartbeats = 0 elif msg_type == proto.MESSAGE: # Handle text message conn.on_message(msg_data) if msg_id: self.send_message(proto.ack(msg_endpoint, msg_id)) elif msg_type == proto.JSON: # Handle json message conn.on_message(proto.json_load(msg_data)) if msg_id: self.send_message(proto.ack(msg_endpoint, msg_id)) elif msg_type == proto.EVENT: # Javascript event event = proto.json_load(msg_data) # TODO: Verify if args = event.get('args', []) won't be slower. args = event.get('args') if args is None: args = [] ack_response = None # It is kind of magic - if there's only one parameter # and it is dict, unpack dictionary. Otherwise, pass # in args if len(args) == 1 and isinstance(args[0], dict): # Fix for the http://bugs.python.org/issue4978 for older Python versions str_args = dict((str(x), y) for x, y in args[0].iteritems()) ack_response = conn.on_event(event['name'], kwargs=str_args) else: ack_response = conn.on_event(event['name'], args=args) if msg_id: if msg_id.endswith('+'): msg_id = msg_id[:-1] self.send_message(proto.ack(msg_endpoint, msg_id, ack_response)) elif msg_type == proto.ACK: # Handle ACK ack_data = msg_data.split('+', 2) data = None if len(ack_data) > 1: data = proto.json_load(ack_data[1]) conn.deque_ack(int(ack_data[0]), data) elif msg_type == proto.ERROR: # TODO: Pass it to handler? logger.error('Incoming error: %s' % msg_data) elif msg_type == proto.NOOP: pass except Exception, ex: logger.exception(ex) # TODO: Add global exception callback? raise