diff --git a/ember/app/authenticators/cookie.js b/ember/app/authenticators/cookie.js deleted file mode 100644 index cb64ed0adc..0000000000 --- a/ember/app/authenticators/cookie.js +++ /dev/null @@ -1,21 +0,0 @@ -import Ember from 'ember'; -import Base from 'ember-simple-auth/authenticators/base'; - -export default Base.extend({ - ajax: Ember.inject.service(), - - authenticate(email, password) { - return this.get('ajax').request('/api/session', { method: 'PUT', json: { email, password } }) - .then(() => this.get('ajax').request('/api/settings/')) - .then(settings => ({ settings })); - }, - - restore(/* data */) { - return this.get('ajax').request('/api/settings/') - .then(settings => ({ settings })); - }, - - invalidate(/* data */) { - return this.get('ajax').request('/api/session', { method: 'DELETE' }); - }, -}); diff --git a/ember/app/authenticators/oauth2.js b/ember/app/authenticators/oauth2.js new file mode 100644 index 0000000000..4df7833ecd --- /dev/null +++ b/ember/app/authenticators/oauth2.js @@ -0,0 +1,32 @@ +import Ember from 'ember'; +import OAuth2PasswordGrant from 'ember-simple-auth/authenticators/oauth2-password-grant'; + +export default OAuth2PasswordGrant.extend({ + ajax: Ember.inject.service(), + + clientId: 'skylines.aero', + serverTokenEndpoint: '/api/oauth/token', + serverTokenRevocationEndpoint: '/api/oauth/revoke', + + authenticate() { + return this._super(...arguments) + .then(data => this._addSettings(data)); + }, + + restore() { + return this._super(...arguments) + .then(data => this._addSettings(data)); + }, + + _addSettings(data) { + let headers = {}; + if (data.access_token) { + headers['Authorization'] = `Bearer ${data.access_token}`; + } + + return this.get('ajax').request('/api/settings', { headers }).then(settings => { + data.settings = settings; + return data; + }); + }, +}); diff --git a/ember/app/components/login-form.js b/ember/app/components/login-form.js index e2a5be80be..aca6452d3f 100644 --- a/ember/app/components/login-form.js +++ b/ember/app/components/login-form.js @@ -13,7 +13,7 @@ export default Ember.Component.extend({ let { email, password } = this.getProperties('email', 'password'); try { - yield this.get('session').authenticate('authenticator:cookie', email, password); + yield this.get('session').authenticate('authenticator:oauth2', email, password); } catch (error) { this.set('error', error); } diff --git a/ember/app/components/upload-flight-form.js b/ember/app/components/upload-flight-form.js index e35a636ac7..5800bcf58f 100644 --- a/ember/app/components/upload-flight-form.js +++ b/ember/app/components/upload-flight-form.js @@ -44,9 +44,6 @@ export default Ember.Component.extend(Validations, { let data = new FormData(form); try { - let csrfToken = yield this.get('ajax').request('/api/flights/upload/csrf').then(it => it.token); - data.append('csrfToken', csrfToken); - let json = yield this.get('ajax').request('/api/flights/upload/', { method: 'POST', data, contentType: false, processData: false }); this.getWithDefault('onUpload', Ember.K)(json); diff --git a/ember/app/services/ajax.js b/ember/app/services/ajax.js index 16f3dbed72..848885a0e6 100644 --- a/ember/app/services/ajax.js +++ b/ember/app/services/ajax.js @@ -1,6 +1,18 @@ +import Ember from 'ember'; import AjaxService from 'ember-ajax/services/ajax'; export default AjaxService.extend({ + session: Ember.inject.service(), + + headers: Ember.computed('session.data.authenticated.access_token', function() { + let headers = {}; + let authToken = this.get('session.data.authenticated.access_token'); + if (authToken) { + headers['Authorization'] = `Bearer ${authToken}`; + } + return headers; + }), + options(url, options = {}) { if (options.json) { options.contentType = 'application/json'; diff --git a/setup.py b/setup.py index 859f9a73f2..cd7a5f297e 100755 --- a/setup.py +++ b/setup.py @@ -18,7 +18,6 @@ 'flask==0.10.1', 'werkzeug==0.9.6', 'Flask-Babel==0.9', - 'Flask-Login==0.2.9', 'Flask-Cache==0.12', 'Flask-Migrate==1.2.0', 'Flask-Script==0.6.7', diff --git a/skylines/app.py b/skylines/app.py index 69add401e5..f9d9e520f2 100644 --- a/skylines/app.py +++ b/skylines/app.py @@ -30,11 +30,6 @@ def add_cache(self): from skylines.frontend.cache import cache cache.init_app(self) - def add_login_manager(self): - """ Create and attach Login extension """ - from skylines.frontend.login import login_manager - login_manager.init_app(self) - def add_logging_handlers(self): if self.debug: return @@ -87,11 +82,14 @@ def create_http_app(*args, **kw): def create_frontend_app(*args, **kw): + from skylines.frontend.oauth import oauth + app = create_http_app('skylines.frontend', *args, **kw) app.add_cache() app.load_egm96() - app.add_login_manager() + + oauth.init_app(app) import skylines.frontend.views skylines.frontend.views.register(app) diff --git a/skylines/frontend/login.py b/skylines/frontend/login.py deleted file mode 100644 index 8c39f5da43..0000000000 --- a/skylines/frontend/login.py +++ /dev/null @@ -1,23 +0,0 @@ -import base64 - -from flask.ext.login import LoginManager - -from skylines.model import User - -login_manager = LoginManager() - - -@login_manager.user_loader -def load_user(user_id): - return User.get(user_id) - - -@login_manager.header_loader -def load_user_from_header(header_val): - try: - header_val = header_val.replace('Basic ', '', 1) - header_val = base64.b64decode(header_val) - email, password = header_val.split(':', 1) - return User.by_credentials(email, password) - except: - return None diff --git a/skylines/frontend/oauth.py b/skylines/frontend/oauth.py new file mode 100644 index 0000000000..73663e833c --- /dev/null +++ b/skylines/frontend/oauth.py @@ -0,0 +1,170 @@ +import time +from functools import wraps + +from itsdangerous import JSONWebSignatureSerializer +from flask import Blueprint, request, abort, jsonify, current_app + +from flask.ext.oauthlib.provider import OAuth2Provider +from flask_oauthlib.provider import OAuth2RequestValidator +from flask_oauthlib.provider.oauth2 import log +from flask_oauthlib.utils import decode_base64 +from oauthlib.common import to_unicode +from oauthlib.oauth2.rfc6749.tokens import random_token_generator + +from skylines.database import db +from skylines.model import User, Client, RefreshToken, AccessToken + + +class CustomProvider(OAuth2Provider): + def __init__(self, *args, **kwargs): + super(CustomProvider, self).__init__(*args, **kwargs) + self.blueprint = Blueprint('oauth', __name__) + + def init_app(self, app): + super(CustomProvider, self).init_app(app) + app.config.setdefault('OAUTH2_PROVIDER_TOKEN_GENERATOR', self.generate_token) + app.config.setdefault('OAUTH2_PROVIDER_REFRESH_TOKEN_GENERATOR', random_token_generator) + + app.jws = JSONWebSignatureSerializer(app.config.get('SECRET_KEY')) + + app.register_blueprint(self.blueprint) + + def generate_token(self, request): + token = { + 'user': request.user.id, + 'exp': int(time.time() + request.expires_in), + } + + if request.scopes is not None: + token['scope'] = ' '.join(request.scopes) + + return current_app.jws.dumps(token) + + def verify_request(self, scopes): + if request.authorization: + from skylines.model import User + + user = User.by_credentials( + request.authorization.username, + request.authorization.password, + ) + + request.user_id = user.id if user else None + return (user is not None), None + + else: + valid, req = super(CustomProvider, self).verify_request(scopes) + + request.user_id = req.access_token.user_id if valid else None + + return valid, req + + def required(self, *args, **kwargs): + return self.require_oauth(*args, **kwargs) + + def optional(self, *scopes): + """Enhance resource with specified scopes.""" + def wrapper(f): + @wraps(f) + def decorated(*args, **kwargs): + for func in self._before_request_funcs: + func() + + if hasattr(request, 'oauth') and request.oauth: + return f(*args, **kwargs) + + valid, req = self.verify_request(scopes) + + for func in self._after_request_funcs: + valid, req = func(valid, req) + + if not valid and (not req or 'Authorization' in req.headers or req.access_token): + if self._invalid_response: + return self._invalid_response(req) + return abort(401) + request.oauth = req + return f(*args, **kwargs) + return decorated + return wrapper + + +class CustomRequestValidator(OAuth2RequestValidator): + def __init__(self): + super(CustomRequestValidator, self).__init__( + clientgetter=lambda client_id: Client(), + tokengetter=self.tokengetter, + grantgetter=None, + usergetter=User.by_credentials, + tokensetter=self.tokensetter, + ) + + @staticmethod + def tokengetter(access_token=None, refresh_token=None): + """ Retrieve a token record using submitted access token or refresh token. """ + if access_token: + return AccessToken.from_jwt(access_token) + + elif refresh_token: + return RefreshToken.query(refresh_token=refresh_token).first() + + @staticmethod + def tokensetter(token, request, *args, **kwargs): + """ Save a new token to the database. + + :param token: Token dictionary containing access and refresh tokens, plus token type. + :param request: Request dictionary containing information about the client and user. + """ + + if request.grant_type != 'refresh_token': + tok = RefreshToken( + refresh_token=token['refresh_token'], + user_id=request.user.id, + ) + db.session.add(tok) + db.session.commit() + + def rotate_refresh_token(self, request): + return False + + def authenticate_client(self, request, *args, **kwargs): + + auth = request.headers.get('Authorization', None) + if auth: + try: + _, s = auth.split(' ') + client_id, client_secret = decode_base64(s).split(':') + client_id = to_unicode(client_id, 'utf-8') + except Exception as e: + log.debug('Authenticate client failed with exception: %r', e) + return False + else: + client_id = request.client_id + + client = self._clientgetter(client_id) + if not client: + log.debug('Authenticate client failed, client not found.') + return False + + return self.authenticate_client_id(client_id, request) + + +oauth = CustomProvider() +oauth._validator = CustomRequestValidator() + + +@oauth.blueprint.route('/api/oauth/token', methods=['POST']) +@oauth.token_handler +def access_token(*args, **kwargs): + return None + + +@oauth.blueprint.route('/api/oauth/revoke', methods=['POST']) +@oauth.revoke_handler +def revoke_token(): + pass + + +@oauth.invalid_response +def invalid_require_oauth(req): + message = req.error_message if req else 'Unauthorized' + return jsonify(error='invalid_token', message=message), 401 diff --git a/skylines/frontend/views/__init__.py b/skylines/frontend/views/__init__.py index a5f522f2ac..78ff6308bf 100644 --- a/skylines/frontend/views/__init__.py +++ b/skylines/frontend/views/__init__.py @@ -1,6 +1,5 @@ from .errors import register as register_error_handlers from .i18n import register as register_i18n -from .login import register as register_login from .about import about_blueprint from .airport import airport_blueprint @@ -30,7 +29,6 @@ def register(app): register_error_handlers(app) register_i18n(app) - register_login(app) app.register_blueprint(assets_blueprint) app.register_blueprint(files_blueprint) diff --git a/skylines/frontend/views/club.py b/skylines/frontend/views/club.py index ae2b448f3b..9ed0c29f81 100644 --- a/skylines/frontend/views/club.py +++ b/skylines/frontend/views/club.py @@ -1,19 +1,23 @@ -from flask import Blueprint, g, request, jsonify +from flask import Blueprint, request, jsonify from skylines.database import db +from skylines.frontend.oauth import oauth from skylines.lib.dbutil import get_requested_record -from skylines.model import Club +from skylines.model import Club, User from skylines.schemas import ClubSchema, ValidationError club_blueprint = Blueprint('club', 'skylines') @club_blueprint.route('/clubs/', strict_slashes=False) +@oauth.optional() def read(club_id): + current_user = User.get(request.user_id) if request.user_id else None + club = get_requested_record(Club, club_id) json = ClubSchema().dump(club).data - json['isWritable'] = club.is_writable(g.current_user) + json['isWritable'] = club.is_writable(current_user) return jsonify(**json) diff --git a/skylines/frontend/views/flight.py b/skylines/frontend/views/flight.py index 524952af63..ece140add5 100644 --- a/skylines/frontend/views/flight.py +++ b/skylines/frontend/views/flight.py @@ -10,6 +10,7 @@ from datetime import timedelta from skylines.frontend.cache import cache +from skylines.frontend.oauth import oauth from skylines.database import db from skylines.lib import files from skylines.lib.dbutil import get_requested_record @@ -128,13 +129,13 @@ def _get_contest_traces(flight): def mark_flight_notifications_read(flight): - if not g.current_user: + if not request.user_id: return def add_flight_filter(query): return query.filter(Event.flight_id == flight.id) - Notification.mark_all_read(g.current_user, filter_func=add_flight_filter) + Notification.mark_all_read(User.get(request.user_id), filter_func=add_flight_filter) db.session.commit() @@ -151,10 +152,12 @@ class NearFlightSchema(Schema): @flight_blueprint.route('/flights/', strict_slashes=False) +@oauth.optional() def read(flight_id): flight = get_requested_record(Flight, flight_id, joinedload=[Flight.igc_file]) - if not flight.is_viewable(g.current_user): + current_user = User.get(request.user_id) if request.user_id else None + if not flight.is_viewable(current_user): return jsonify(), 404 _reanalyse_if_needed(flight) @@ -226,10 +229,12 @@ def read(flight_id): @flight_blueprint.route('/flights//json') +@oauth.optional() def json(flight_id): flight = get_requested_record(Flight, flight_id, joinedload=[Flight.igc_file]) - if not flight.is_viewable(g.current_user): + current_user = User.get(request.user_id) if request.user_id else None + if not flight.is_viewable(current_user): return jsonify(), 404 # Return HTTP Status code 304 if an upstream or browser cache already @@ -329,10 +334,12 @@ def _get_near_flights(flight, location, time, max_distance=1000): @flight_blueprint.route('/flights//near') +@oauth.optional() def near(flight_id): flight = get_requested_record(Flight, flight_id, joinedload=[Flight.igc_file]) - if not flight.is_viewable(g.current_user): + current_user = User.get(request.user_id) if request.user_id else None + if not flight.is_viewable(current_user): return jsonify(), 404 try: @@ -360,10 +367,12 @@ def add_flight_path(flight): @flight_blueprint.route('/flights/', methods=['POST'], strict_slashes=False) +@oauth.required() def update(flight_id): flight = get_requested_record(Flight, flight_id) - if not flight.is_writable(g.current_user): + current_user = User.get(request.user_id) + if not flight.is_writable(current_user): return jsonify(), 403 json = request.get_json() @@ -385,7 +394,7 @@ def update(flight_id): pilot_club_id = User.get(pilot_id).club_id - if pilot_club_id != g.current_user.club_id or (pilot_club_id is None and pilot_id != g.current_user.id): + if pilot_club_id != current_user.club_id or (pilot_club_id is None and pilot_id != current_user.id): return jsonify(error='pilot-disallowed'), 422 if flight.pilot_id != pilot_id: @@ -411,8 +420,8 @@ def update(flight_id): co_pilot_club_id = User.get(co_pilot_id).club_id - if co_pilot_club_id != g.current_user.club_id \ - or (co_pilot_club_id is None and co_pilot_id != g.current_user.id): + if co_pilot_club_id != current_user.club_id \ + or (co_pilot_club_id is None and co_pilot_id != current_user.id): return jsonify(error='co-pilot-disallowed'), 422 flight.co_pilot_id = co_pilot_id @@ -458,10 +467,12 @@ def update(flight_id): @flight_blueprint.route('/flights/', methods=('DELETE',), strict_slashes=False) +@oauth.required() def delete(flight_id): flight = get_requested_record(Flight, flight_id, joinedload=[Flight.igc_file]) - if not flight.is_writable(g.current_user): + current_user = User.get(request.user_id) + if not flight.is_writable(current_user): abort(403) files.delete_file(flight.igc_file.filename) @@ -473,10 +484,12 @@ def delete(flight_id): @flight_blueprint.route('/flights//comments', methods=('POST',)) +@oauth.required() def add_comment(flight_id): flight = get_requested_record(Flight, flight_id) - if not g.current_user: + current_user = User.get(request.user_id) + if not current_user: return jsonify(), 403 json = request.get_json() @@ -489,7 +502,7 @@ def add_comment(flight_id): return jsonify(error='validation-failed', fields=e.messages), 422 comment = FlightComment() - comment.user = g.current_user + comment.user = current_user comment.flight = flight comment.text = data['text'] diff --git a/skylines/frontend/views/flights.py b/skylines/frontend/views/flights.py index 0cbb086cec..03234f0816 100644 --- a/skylines/frontend/views/flights.py +++ b/skylines/frontend/views/flights.py @@ -1,6 +1,6 @@ from datetime import datetime -from flask import Blueprint, current_app, g, jsonify +from flask import Blueprint, current_app, jsonify, request from sqlalchemy import func from sqlalchemy.sql.expression import or_, and_ @@ -8,6 +8,7 @@ from sqlalchemy.orm.util import aliased from skylines.database import db +from skylines.frontend.oauth import oauth from skylines.lib.table_tools import Pager, Sorter from skylines.lib.dbutil import get_requested_record from skylines.model import ( @@ -21,13 +22,13 @@ def mark_flight_notifications_read(pilot): - if not g.current_user: + if not request.user_id: return def add_flight_filter(query): return query.filter(Event.actor_id == pilot.id) - Notification.mark_all_read(g.current_user, filter_func=add_flight_filter) + Notification.mark_all_read(User.get(request.user_id), filter_func=add_flight_filter) db.session.commit() @@ -42,8 +43,10 @@ def _create_list(date=None, pilot=None, club=None, airport=None, .query(FlightComment.flight_id, func.count('*').label('count')) \ .group_by(FlightComment.flight_id).subquery() + current_user = User.get(request.user_id) if request.user_id else None + flights = db.session.query(Flight, subq.c.count) \ - .filter(Flight.is_listable(g.current_user)) \ + .filter(Flight.is_listable(current_user)) \ .join(Flight.igc_file) \ .options(contains_eager(Flight.igc_file)) \ .join(owner_alias, IGCFile.owner) \ @@ -128,11 +131,13 @@ def _create_list(date=None, pilot=None, club=None, airport=None, @flights_blueprint.route('/flights/all') +@oauth.optional() def all(): return _create_list(default_sorting_column='date', default_sorting_order='desc') @flights_blueprint.route('/flights/date/') +@oauth.optional() def date(date, latest=False): try: if isinstance(date, (str, unicode)): @@ -148,12 +153,15 @@ def date(date, latest=False): @flights_blueprint.route('/flights/latest') +@oauth.optional() def latest(): + current_user = User.get(request.user_id) if request.user_id else None + query = db.session \ .query(func.max(Flight.date_local).label('date')) \ .filter(Flight.takeoff_time < datetime.utcnow()) \ .join(Flight.igc_file) \ - .filter(Flight.is_listable(g.current_user)) + .filter(Flight.is_listable(current_user)) date_ = query.one().date if not date_: @@ -163,6 +171,7 @@ def latest(): @flights_blueprint.route('/flights/pilot/') +@oauth.optional() def pilot(id): pilot = get_requested_record(User, id) @@ -172,6 +181,7 @@ def pilot(id): @flights_blueprint.route('/flights/club/') +@oauth.optional() def club(id): club = get_requested_record(Club, id) @@ -179,6 +189,7 @@ def club(id): @flights_blueprint.route('/flights/airport/') +@oauth.optional() def airport(id): airport = get_requested_record(Airport, id) @@ -186,17 +197,15 @@ def airport(id): @flights_blueprint.route('/flights/unassigned') +@oauth.required() def unassigned(): - if not g.current_user: - return jsonify(), 400 - - f = and_(Flight.pilot_id is None, - IGCFile.owner == g.current_user) + f = and_(Flight.pilot_id is None, IGCFile.owner_id == request.user_id) return _create_list(filter=f, default_sorting_column='date', default_sorting_order='desc') @flights_blueprint.route('/flights/list/') +@oauth.optional() def list(ids): if not ids: return jsonify(), 400 diff --git a/skylines/frontend/views/login.py b/skylines/frontend/views/login.py deleted file mode 100644 index 3a306edbba..0000000000 --- a/skylines/frontend/views/login.py +++ /dev/null @@ -1,46 +0,0 @@ -from flask import request, g, jsonify -from flask.ext.login import login_user, logout_user, current_user - -from skylines.model import User -from skylines.schemas import CurrentUserSchema, ValidationError - - -def register(app): - """ Register the /login and /logout routes on the given app """ - - @app.before_request - def inject_current_user(): - """ - Inject a current_user variable into the global object. current_user is - either None or points to the User that is currently logged in. - """ - - if current_user.is_anonymous(): - g.current_user = None - else: - g.current_user = current_user - - @app.route('/api/session', methods=('PUT',)) - @app.route('/session', methods=('PUT',)) - def create_session(): - json = request.get_json() - if json is None: - return jsonify(error='invalid-request'), 400 - - try: - data = CurrentUserSchema(only=('email', 'password')).load(json).data - except ValidationError, e: - return jsonify(error='validation-failed', fields=e.messages), 422 - - user = User.by_credentials(data['email_address'], data['password']) - if not user or not login_user(user, remember=True): - logout_user() - return jsonify(error='wrong-credentials'), 403 - - return jsonify() - - @app.route('/api/session', methods=('DELETE',)) - @app.route('/session', methods=('DELETE',)) - def delete_session(): - logout_user() - return jsonify() diff --git a/skylines/frontend/views/notifications.py b/skylines/frontend/views/notifications.py index 74c37a1600..2cb9a91683 100644 --- a/skylines/frontend/views/notifications.py +++ b/skylines/frontend/views/notifications.py @@ -1,10 +1,10 @@ -from flask import Blueprint, request, g, jsonify +from flask import Blueprint, request, jsonify from sqlalchemy.orm import subqueryload, contains_eager from sqlalchemy.sql.expression import or_ from skylines.database import db -from skylines.model.event import Event, Notification, Flight -from skylines.lib.decorators import login_required +from skylines.frontend.oauth import oauth +from skylines.model import Event, Notification, Flight, User TYPES = { Event.Type.FLIGHT_COMMENT: 'flight-comment', @@ -30,9 +30,9 @@ def _filter_query(query, args): @notifications_blueprint.route('/notifications', strict_slashes=False) -@login_required("You have to login to read notifications.") +@oauth.required() def list(): - query = Notification.query(recipient=g.current_user) \ + query = Notification.query(recipient_id=request.user_id) \ .join('event') \ .options(contains_eager('event')) \ .options(subqueryload('event.actor')) \ @@ -60,12 +60,13 @@ def get_event(notification): @notifications_blueprint.route('/notifications/clear', methods=('POST',)) -@login_required("You have to login to clear notifications.") +@oauth.required() def clear(): def filter_func(query): return _filter_query(query, request.args) - Notification.mark_all_read(g.current_user, filter_func=filter_func) + current_user = User.get(request.user_id) + Notification.mark_all_read(current_user, filter_func=filter_func) db.session.commit() diff --git a/skylines/frontend/views/settings.py b/skylines/frontend/views/settings.py index ad75d7a97e..6c011c3198 100644 --- a/skylines/frontend/views/settings.py +++ b/skylines/frontend/views/settings.py @@ -1,7 +1,7 @@ -from flask import Blueprint, request, g, jsonify -from flask.ext.login import login_required +from flask import Blueprint, request, jsonify from sqlalchemy.sql.expression import and_, or_ +from skylines.frontend.oauth import oauth from skylines.database import db from skylines.model import User, Club, Flight, IGCFile from skylines.model.event import ( @@ -13,15 +13,23 @@ @settings_blueprint.route('/settings', strict_slashes=False) -@login_required +@oauth.required() def read(): + current_user = User.get(request.user_id) + if not current_user: + return jsonify(error='invalid-token'), 401 + schema = CurrentUserSchema(exclude=('id')) - return jsonify(**schema.dump(g.current_user).data) + return jsonify(**schema.dump(current_user).data) @settings_blueprint.route('/settings', methods=['POST'], strict_slashes=False) -@login_required +@oauth.required() def update(): + current_user = User.get(request.user_id) + if not current_user: + return jsonify(error='invalid-token'), 401 + json = request.get_json() if json is None: return jsonify(error='invalid-request'), 400 @@ -34,61 +42,61 @@ def update(): if 'email_address' in data: email = data.get('email_address') - if email != g.current_user.email_address and User.exists(email_address=email): + if email != current_user.email_address and User.exists(email_address=email): return jsonify(error='email-exists-already'), 422 - g.current_user.email_address = email + current_user.email_address = email if 'first_name' in data: - g.current_user.first_name = data.get('first_name') + current_user.first_name = data.get('first_name') if 'last_name' in data: - g.current_user.last_name = data.get('last_name') + current_user.last_name = data.get('last_name') if 'distance_unit' in data: - g.current_user.distance_unit = data.get('distance_unit') + current_user.distance_unit = data.get('distance_unit') if 'speed_unit' in data: - g.current_user.speed_unit = data.get('speed_unit') + current_user.speed_unit = data.get('speed_unit') if 'lift_unit' in data: - g.current_user.lift_unit = data.get('lift_unit') + current_user.lift_unit = data.get('lift_unit') if 'altitude_unit' in data: - g.current_user.altitude_unit = data.get('altitude_unit') + current_user.altitude_unit = data.get('altitude_unit') if 'tracking_callsign' in data: - g.current_user.tracking_callsign = data.get('tracking_callsign') + current_user.tracking_callsign = data.get('tracking_callsign') if 'tracking_delay' in data: - g.current_user.tracking_delay = data.get('tracking_delay') + current_user.tracking_delay = data.get('tracking_delay') if 'password' in data: if 'currentPassword' not in data: return jsonify(error='current-password-missing'), 422 - if not g.current_user.validate_password(data['currentPassword']): + if not current_user.validate_password(data['currentPassword']): return jsonify(error='wrong-password'), 403 - g.current_user.password = data['password'] - g.current_user.recover_key = None + current_user.password = data['password'] + current_user.recover_key = None - if 'club_id' in data and data['club_id'] != g.current_user.club_id: + if 'club_id' in data and data['club_id'] != current_user.club_id: club_id = data['club_id'] if club_id is not None and not Club.exists(id=club_id): return jsonify(error='unknown-club'), 422 - g.current_user.club_id = club_id + current_user.club_id = club_id - create_club_join_event(club_id, g.current_user) + create_club_join_event(club_id, current_user) # assign the user's new club to all of his flights that have # no club yet flights = Flight.query().join(IGCFile) flights = flights.filter(and_(Flight.club_id == None, - or_(Flight.pilot_id == g.current_user.id, - IGCFile.owner_id == g.current_user.id))) + or_(Flight.pilot_id == current_user.id, + IGCFile.owner_id == current_user.id))) for flight in flights: flight.club_id = club_id @@ -98,27 +106,39 @@ def update(): @settings_blueprint.route('/settings/password/check', methods=['POST']) -@login_required +@oauth.required() def check_current_password(): + current_user = User.get(request.user_id) + if not current_user: + return jsonify(error='invalid-token'), 401 + json = request.get_json() if not json: return jsonify(error='invalid-request'), 400 - return jsonify(result=g.current_user.validate_password(json.get('password', ''))) + return jsonify(result=current_user.validate_password(json.get('password', ''))) @settings_blueprint.route('/settings/tracking/key', methods=['POST']) -@login_required +@oauth.required() def tracking_generate_key(): - g.current_user.generate_tracking_key() + current_user = User.get(request.user_id) + if not current_user: + return jsonify(error='invalid-token'), 401 + + current_user.generate_tracking_key() db.session.commit() - return jsonify(key=g.current_user.tracking_key_hex) + return jsonify(key=current_user.tracking_key_hex) @settings_blueprint.route('/settings/club', methods=['PUT']) -@login_required +@oauth.required() def create_club(): + current_user = User.get(request.user_id) + if not current_user: + return jsonify(error='invalid-token'), 401 + json = request.get_json() if json is None: return jsonify(error='invalid-request'), 400 @@ -133,15 +153,15 @@ def create_club(): # create the new club club = Club(**data) - club.owner_id = g.current_user.id + club.owner_id = current_user.id db.session.add(club) db.session.flush() # assign the user to the new club - g.current_user.club = club + current_user.club = club # create the "user joined club" event - create_club_join_event(club.id, g.current_user) + create_club_join_event(club.id, current_user) db.session.commit() diff --git a/skylines/frontend/views/tracking.py b/skylines/frontend/views/tracking.py index a1774eeee5..6f3d1f4b03 100644 --- a/skylines/frontend/views/tracking.py +++ b/skylines/frontend/views/tracking.py @@ -1,6 +1,7 @@ -from flask import Blueprint, jsonify, g +from flask import Blueprint, jsonify, request from skylines.frontend.cache import cache +from skylines.frontend.oauth import oauth from skylines.lib.decorators import jsonp from skylines.model import TrackingFix, Airport, Follower from skylines.schemas import TrackingFixSchema, AirportSchema @@ -10,6 +11,7 @@ @tracking_blueprint.route('/tracking', strict_slashes=False) @tracking_blueprint.route('/live', strict_slashes=False) +@oauth.optional() def index(): fix_schema = TrackingFixSchema(only=('time', 'location', 'altitude', 'elevation', 'pilot')) airport_schema = AirportSchema(only=('id', 'name', 'countryCode')) @@ -34,8 +36,8 @@ def get_nearest_airport(track): tracks.append(track) - if g.current_user: - followers = [f.destination_id for f in Follower.query(source=g.current_user)] + if request.user_id: + followers = [f.destination_id for f in Follower.query(source_id=request.user_id)] else: followers = [] diff --git a/skylines/frontend/views/upload.py b/skylines/frontend/views/upload.py index 1e672c6101..e9a21a957a 100644 --- a/skylines/frontend/views/upload.py +++ b/skylines/frontend/views/upload.py @@ -7,16 +7,15 @@ from collections import namedtuple -from flask import Blueprint, request, g, current_app, abort, make_response, jsonify -from flask_wtf.csrf import generate_csrf, validate_csrf +from flask import Blueprint, request, current_app, abort, make_response, jsonify from redis.exceptions import ConnectionError from sqlalchemy.sql.expression import func from skylines.frontend.cache import cache +from skylines.frontend.oauth import oauth from skylines.database import db from skylines.lib import files from skylines.lib.util import pressure_alt_to_qnh_alt -from skylines.lib.decorators import login_required from skylines.lib.md5 import file_md5 from skylines.lib.sql import query_to_sql from skylines.lib.xcsoar_ import flight_path, analyse_flight @@ -148,25 +147,13 @@ def _encode_flight_path(fp, qnh): igc_start_time=fp[0].datetime, igc_end_time=fp[-1].datetime) -@upload_blueprint.route('/flights/upload/csrf') -@login_required("You have to login to upload flights.") -def csrf(): - if not g.current_user: - return jsonify(), 403 - - return jsonify(token=generate_csrf()) - - @upload_blueprint.route('/flights/upload', methods=('POST',), strict_slashes=False) +@oauth.required() def index_post(): - if not g.current_user: - return jsonify(error='authentication-required'), 403 + current_user = User.get(request.user_id) form = request.form - if not validate_csrf(form.get('csrfToken')): - return jsonify(error='invalid-csrf-token'), 403 - if form.get('pilotId') == u'': form = form.copy() form.pop('pilotId') @@ -176,13 +163,11 @@ def index_post(): except ValidationError, e: return jsonify(error='validation-failed', fields=e.messages), 422 - user = g.current_user - pilot_id = data.get('pilot_id') pilot = pilot_id and User.get(pilot_id) pilot_id = pilot and pilot.id - club_id = (pilot and pilot.club_id) or user.club_id + club_id = (pilot and pilot.club_id) or current_user.club_id results = [] @@ -204,7 +189,7 @@ def index_post(): continue igc_file = IGCFile() - igc_file.owner = user + igc_file.owner = current_user igc_file.filename = filename igc_file.md5 = md5 igc_file.update_igc_headers() @@ -269,7 +254,7 @@ def index_post(): db.session.flush() # Store data in cache for image creation - cache_key = hashlib.sha1(str(flight.id) + '_' + str(user.id)).hexdigest() + cache_key = hashlib.sha1(str(flight.id) + '_' + str(current_user.id)).hexdigest() cache.set('upload_airspace_infringements_' + cache_key, infringements, timeout=15 * 60) cache.set('upload_airspace_flight_path_' + cache_key, fp, timeout=15 * 60) @@ -287,12 +272,12 @@ def index_post(): results = UploadResultSchema().dump(results, many=True).data club_members = [] - if g.current_user.club_id: + if current_user.club_id: member_schema = UserSchema(only=('id', 'name')) - club_members = User.query(club_id=g.current_user.club_id) \ + club_members = User.query(club_id=current_user.club_id) \ .order_by(func.lower(User.name)) \ - .filter(User.id != g.current_user.id) + .filter(User.id != current_user.id) club_members = member_schema.dump(club_members.all(), many=True).data @@ -307,8 +292,10 @@ def index_post(): @upload_blueprint.route('/flights/upload/verify', methods=('POST',)) -@login_required('You have to login to upload flights.') +@oauth.required() def verify(): + current_user = User.get(request.user_id) + json = request.get_json() if json is None: return jsonify(error='invalid-request'), 400 @@ -333,7 +320,7 @@ def verify(): for d in data: flight = flights.get(d.pop('id')) - if not flight or not flight.is_writable(g.current_user): + if not flight or not flight.is_writable(current_user): return jsonify(error='unknown-flight'), 422 if 'pilot_id' in d and d['pilot_id'] is not None and d['pilot_id'] not in users: diff --git a/skylines/frontend/views/user.py b/skylines/frontend/views/user.py index 9a2077c6a3..37443d5db9 100644 --- a/skylines/frontend/views/user.py +++ b/skylines/frontend/views/user.py @@ -1,12 +1,12 @@ from datetime import date, timedelta -from flask import Blueprint, g, request, jsonify -from flask.ext.login import login_required +from flask import Blueprint, request, jsonify from sqlalchemy import func, and_ from sqlalchemy.orm import contains_eager, subqueryload from skylines.database import db +from skylines.frontend.oauth import oauth from skylines.lib.dbutil import get_requested_record from skylines.model import ( User, Flight, Follower, Location, Notification, Event @@ -78,25 +78,27 @@ def _get_takeoff_locations(user): def mark_user_notifications_read(user): - if not g.current_user: + if not request.user_id: return def add_user_filter(query): return query.filter(Event.actor_id == user.id) - Notification.mark_all_read(g.current_user, filter_func=add_user_filter) + Notification.mark_all_read(User.get(request.user_id), filter_func=add_user_filter) db.session.commit() @user_blueprint.route('/users/', strict_slashes=False) +@oauth.optional() def read(user_id): user = get_requested_record(User, user_id) - user_schema = CurrentUserSchema() if user == g.current_user else UserSchema() + user_schema = CurrentUserSchema() if user_id == request.user_id else UserSchema() user_json = user_schema.dump(user).data - if g.current_user: - user_json['followed'] = g.current_user.follows(user) + if request.user_id: + current_user = User.get(request.user_id) + user_json['followed'] = current_user.follows(user) if 'extended' in request.args: user_json['distanceFlights'] = _distance_flights(user) @@ -109,6 +111,7 @@ def read(user_id): @user_blueprint.route('/users//followers') +@oauth.optional() def followers(user_id): user = get_requested_record(User, user_id) @@ -128,6 +131,7 @@ def followers(user_id): @user_blueprint.route('/users//following') +@oauth.optional() def following(user_id): user = get_requested_record(User, user_id) @@ -154,11 +158,11 @@ def add_current_user_follows(followers): following the pilot """ - if not g.current_user: + if not request.user_id: return # Query list of people that the current user is following - query = Follower.query(source=g.current_user) + query = Follower.query(source_id=request.user_id) current_user_follows = [follower.destination_id for follower in query] for follower in followers: @@ -166,19 +170,21 @@ def add_current_user_follows(followers): @user_blueprint.route('/users//follow') -@login_required +@oauth.required() def follow(user_id): user = get_requested_record(User, user_id) - Follower.follow(g.current_user, user) - create_follower_notification(user, g.current_user) + current_user = User.get(request.user_id) + Follower.follow(current_user, user) + create_follower_notification(user, current_user) db.session.commit() return jsonify() @user_blueprint.route('/users//unfollow') -@login_required +@oauth.required() def unfollow(user_id): user = get_requested_record(User, user_id) - Follower.unfollow(g.current_user, user) + current_user = User.get(request.user_id) + Follower.unfollow(current_user, user) db.session.commit() return jsonify() diff --git a/skylines/frontend/views/users.py b/skylines/frontend/views/users.py index a6671ce20c..42a650c1fa 100644 --- a/skylines/frontend/views/users.py +++ b/skylines/frontend/views/users.py @@ -2,13 +2,14 @@ from email.utils import formatdate import smtplib -from flask import Blueprint, request, current_app, g, jsonify +from flask import Blueprint, request, current_app, jsonify from werkzeug.exceptions import ServiceUnavailable from sqlalchemy import func from sqlalchemy.orm import joinedload from skylines.database import db +from skylines.frontend.oauth import oauth from skylines.model import User from skylines.model.event import create_new_user_event from skylines.schemas import UserSchema, CurrentUserSchema, ValidationError @@ -137,7 +138,10 @@ def recover_step2_post(json): @users_blueprint.route('/users/check-email', methods=['POST']) +@oauth.optional() def check_email(): + current_user = User.get(request.user_id) if request.user_id else None + json = request.get_json() if not json: return jsonify(error='invalid-request'), 400 @@ -145,7 +149,7 @@ def check_email(): email = json.get('email', '') result = 'available' - if g.current_user and email == g.current_user.email_address: + if current_user and email == current_user.email_address: result = 'self' elif User.exists(email_address=email): result = 'unavailable' diff --git a/skylines/lib/decorators.py b/skylines/lib/decorators.py index 13e90c29a1..1ea62668d3 100644 --- a/skylines/lib/decorators.py +++ b/skylines/lib/decorators.py @@ -1,7 +1,6 @@ from functools import wraps -from flask import current_app, request, jsonify -from flask.ext.login import current_user +from flask import current_app, request def jsonp(func): @@ -17,17 +16,3 @@ def decorated_function(*args, **kwargs): else: return func(*args, **kwargs) return decorated_function - - -class login_required: - def __init__(self, msg=None): - self.msg = msg - - def __call__(self, fn): - @wraps(fn) - def decorated_view(*args, **kwargs): - if not current_user.is_authenticated(): - return jsonify(), 401 - - return fn(*args, **kwargs) - return decorated_view