# -*- coding: utf-8 -*- """ flask_oauth ~~~~~~~~~~~ Implements basic OAuth support for Flask. :copyright: (c) 2010 by Armin Ronacher. :license: BSD, see LICENSE for more details. """ import httplib2 from functools import wraps from urlparse import urljoin from flask import request, session, json, redirect, Response from werkzeug import url_decode, url_encode, url_quote, \ parse_options_header, Headers import oauth2 _etree = None def get_etree(): """Return an elementtree implementation. Prefers lxml""" global _etree if _etree is None: try: from lxml import etree as _etree except ImportError: try: from xml.etree import cElementTree as _etree except ImportError: try: from xml.etree import ElementTree as _etree except ImportError: raise TypeError('lxml or etree not found') return _etree def parse_response(resp, content, strict=False): ct, options = parse_options_header(resp['content-type']) if ct in ('application/json', 'text/javascript'): return json.loads(content) elif ct in ('application/xml', 'text/xml'): # technically, text/xml is ascii based but because many # implementations get that wrong and utf-8 is a superset # of utf-8 anyways, so there is not much harm in assuming # utf-8 here charset = options.get('charset', 'utf-8') return get_etree().fromstring(content.decode(charset)) elif ct != 'application/x-www-form-urlencoded': if strict: return content charset = options.get('charset', 'utf-8') return url_decode(content, charset=charset).to_dict() def add_query(url, args): if not args: return url return url + ('?' in url and '&' or '?') + url_encode(args) def encode_request_data(data, format): if format is None: return data, None elif format == 'json': return json.dumps(data or {}), 'application/json' elif format == 'urlencoded': return url_encode(data or {}), 'application/x-www-form-urlencoded' raise TypeError('Unknown format %r' % format) class OAuthResponse(object): """Contains the response sent back from an OAuth protected remote application. """ def __init__(self, resp, content): #: a :class:`~werkzeug.Headers` object with the response headers #: the application sent. self.headers = Headers(resp) #: the raw, unencoded content from the server self.raw_data = content #: the parsed content from the server self.data = parse_response(resp, content, strict=True) @property def status(self): """The status code of the response.""" return self.headers.get('status', type=int) class OAuthClient(oauth2.Client): def request_new_token(self, uri, callback=None, params={}): if callback is not None: params['oauth_callback'] = callback req = oauth2.Request.from_consumer_and_token( self.consumer, token=self.token, http_method='POST', http_url=uri, parameters=params, is_form_encoded=True) req.sign_request(self.method, self.consumer, self.token) body = req.to_postdata() headers = { 'Content-Type': 'application/x-www-form-urlencoded', 'Content-Length': str(len(body)) } return httplib2.Http.request(self, uri, method='POST', body=body, headers=headers) class OAuthException(RuntimeError): """Raised if authorization fails for some reason.""" message = None type = None def __init__(self, message, type=None, data=None): #: A helpful error message for debugging self.message = message #: A unique type for this exception if available. self.type = type #: If available, the parsed data from the remote API that can be #: used to pointpoint the error. self.data = data def __str__(self): return self.message.encode('utf-8') def __unicode__(self): return self.message class OAuth(object): """Registry for remote applications. In the future this will also be the central class for OAuth provider functionality. """ def __init__(self): self.remote_apps = {} def remote_app(self, name, register=True, **kwargs): """Registers a new remote applicaton. If `param` register is set to `False` the application is not registered in the :attr:`remote_apps` dictionary. The keyword arguments are forwarded to the :class:`OAuthRemoteApp` consturctor. """ app = OAuthRemoteApp(self, name, **kwargs) if register: assert name not in self.remote_apps, \ 'application already registered' self.remote_apps[name] = app return app class OAuthRemoteApp(object): """Represents a remote application. :param oauth: the associated :class:`OAuth` object. :param name: then name of the remote application :param request_token_url: the URL for requesting new tokens :param access_token_url: the URL for token exchange :param authorize_url: the URL for authorization :param consumer_key: the application specific consumer key :param consumer_secret: the application specific consumer secret :param request_token_params: an optional dictionary of parameters to forward to the request token URL or authorize URL depending on oauth version. :param access_token_params: an option diction of parameters to forward to the access token URL :param access_token_method: the HTTP method that should be used for the access_token_url. Defaults to ``'GET'``. """ def __init__(self, oauth, name, base_url, request_token_url, access_token_url, authorize_url, consumer_key, consumer_secret, request_token_params=None, access_token_params=None, access_token_method='GET'): self.oauth = oauth #: the `base_url` all URLs are joined with. self.base_url = base_url self.name = name self.request_token_url = request_token_url self.access_token_url = access_token_url self.authorize_url = authorize_url self.consumer_key = consumer_key self.consumer_secret = consumer_secret self.tokengetter_func = None self.request_token_params = request_token_params or {} self.access_token_params = access_token_params or {} self.access_token_method = access_token_method self._consumer = oauth2.Consumer(self.consumer_key, self.consumer_secret) self._client = OAuthClient(self._consumer) def status_okay(self, resp): """Given request data, checks if the status is okay.""" try: return int(resp['status']) in (200, 201) except ValueError: return False def get(self, *args, **kwargs): """Sends a ``GET`` request. Accepts the same parameters as :meth:`request`. """ kwargs['method'] = 'GET' return self.request(*args, **kwargs) def post(self, *args, **kwargs): """Sends a ``POST`` request. Accepts the same parameters as :meth:`request`. """ kwargs['method'] = 'POST' return self.request(*args, **kwargs) def put(self, *args, **kwargs): """Sends a ``PUT`` request. Accepts the same parameters as :meth:`request`. """ kwargs['method'] = 'PUT' return self.request(*args, **kwargs) def delete(self, *args, **kwargs): """Sends a ``DELETE`` request. Accepts the same parameters as :meth:`request`. """ kwargs['method'] = 'DELETE' return self.request(*args, **kwargs) def make_client(self, token=None): """Creates a new `oauth2` Client object with the token attached. Usually you don't have to do that but use the :meth:`request` method instead. """ return oauth2.Client(self._consumer, self.get_request_token(token)) def request(self, url, data="", headers=None, format='urlencoded', method='GET', content_type=None, token=None): """Sends a request to the remote server with OAuth tokens attached. The `url` is joined with :attr:`base_url` if the URL is relative. .. versionadded:: 0.12 added the `token` parameter. :param url: where to send the request to :param data: the data to be sent to the server. If the request method is ``GET`` the data is appended to the URL as query parameters, otherwise encoded to `format` if the format is given. If a `content_type` is provided instead, the data must be a string encoded for the given content type and used as request body. :param headers: an optional dictionary of headers. :param format: the format for the `data`. Can be `urlencoded` for URL encoded data or `json` for JSON. :param method: the HTTP request method to use. :param content_type: an optional content type. If a content type is provided, the data is passed as it and the `format` parameter is ignored. :param token: an optional token to pass to tokengetter. Use this if you want to support sending requests using multiple tokens. If you set this to anything not None, `tokengetter_func` will receive the given token as an argument, in which case the tokengetter should return the `(token, secret)` tuple for the given token. :return: an :class:`OAuthResponse` object. """ headers = dict(headers or {}) client = self.make_client(token) url = self.expand_url(url) if method == 'GET': assert format == 'urlencoded' if data: url = add_query(url, data) data = "" else: if content_type is None: data, content_type = encode_request_data(data, format) if content_type is not None: headers['Content-Type'] = content_type return OAuthResponse(*client.request(url, method=method, body=data or '', headers=headers)) def expand_url(self, url): return urljoin(self.base_url, url) def generate_request_token(self, callback=None): if callback is not None: callback = urljoin(request.url, callback) resp, content = self._client.request_new_token( self.expand_url(self.request_token_url), callback, self.request_token_params) if not self.status_okay(resp): raise OAuthException('Failed to generate request token', type='token_generation_failed') data = parse_response(resp, content) if data is None: raise OAuthException('Invalid token response from ' + self.name, type='token_generation_failed') tup = (data['oauth_token'], data['oauth_token_secret']) session[self.name + '_oauthtok'] = tup return tup def get_request_token(self, token=None): assert self.tokengetter_func is not None, 'missing tokengetter function' # Don't pass the token if the token is None to support old # tokengetter functions. rv = self.tokengetter_func(*(token and (token,) or ())) if rv is None: rv = session.get(self.name + '_oauthtok') if rv is None: raise OAuthException('No token available', type='token_missing') return oauth2.Token(*rv) def free_request_token(self): session.pop(self.name + '_oauthtok', None) session.pop(self.name + '_oauthredir', None) def authorize(self, callback=None): """Returns a redirect response to the remote authorization URL with the signed callback given. The callback must be `None` in which case the application will most likely switch to PIN based authentication or use a remotely stored callback URL. Alternatively it's an URL on the system that has to be decorated as :meth:`authorized_handler`. """ if self.request_token_url: token = self.generate_request_token(callback)[0] url = '%s?oauth_token=%s' % (self.expand_url(self.authorize_url), url_quote(token)) else: assert callback is not None, 'Callback is required OAuth2' # This is for things like facebook's oauth. Since we need the # callback for the access_token_url we need to keep it in the # session. params = dict(self.request_token_params) params['redirect_uri'] = callback params['client_id'] = self.consumer_key session[self.name + '_oauthredir'] = callback url = add_query(self.expand_url(self.authorize_url), params) return redirect(url) def tokengetter(self, f): """Registers a function as tokengetter. The tokengetter has to return a tuple of ``(token, secret)`` with the user's token and token secret. If the data is unavailable, the function must return `None`. If the `token` parameter is passed to the request function it's forwarded to the tokengetter function:: @oauth.tokengetter def get_token(token='user'): if token == 'user': return find_the_user_token() elif token == 'app': return find_the_app_token() raise RuntimeError('invalid token') """ self.tokengetter_func = f return f def handle_oauth1_response(self): """Handles an oauth1 authorization response. The return value of this method is forwarded as first argument to the handling view function. """ client = self.make_client() resp, content = client.request('%s?oauth_verifier=%s' % ( self.expand_url(self.access_token_url), request.args['oauth_verifier'] ), self.access_token_method) data = parse_response(resp, content) if not self.status_okay(resp): raise OAuthException('Invalid response from ' + self.name, type='invalid_response', data=data) return data def handle_oauth2_response(self): """Handles an oauth2 authorization response. The return value of this method is forwarded as first argument to the handling view function. """ remote_args = { 'code': request.args.get('code'), 'client_id': self.consumer_key, 'client_secret': self.consumer_secret, 'redirect_uri': session.get(self.name + '_oauthredir') } remote_args.update(self.access_token_params) if self.access_token_method == 'POST': resp, content = self._client.request(self.expand_url(self.access_token_url), self.access_token_method, url_encode(remote_args)) elif self.access_token_method == 'GET': url = add_query(self.expand_url(self.access_token_url), remote_args) resp, content = self._client.request(url, self.access_token_method) else: raise OAuthException('Unsupported access_token_method: ' + self.access_token_method) data = parse_response(resp, content) if not self.status_okay(resp): raise OAuthException('Invalid response from ' + self.name, type='invalid_response', data=data) return data def handle_unknown_response(self): """Called if an unknown response came back from the server. This usually indicates a denied response. The default implementation just returns `None`. """ return None def authorized_handler(self, f): """Injects additional authorization functionality into the function. The function will be passed the response object as first argument if the request was allowed, or `None` if access was denied. When the authorized handler is called, the temporary issued tokens are already destroyed. """ @wraps(f) def decorated(*args, **kwargs): if 'oauth_verifier' in request.args: data = self.handle_oauth1_response() elif 'code' in request.args: data = self.handle_oauth2_response() else: data = self.handle_unknown_response() self.free_request_token() return f(*((data,) + args), **kwargs) return decorated