inth-oauth2/src/client.rs

246 lines
7.4 KiB
Rust

use std::io::Read;
use chrono::{UTC, Duration};
use hyper::{self, header, mime};
use rustc_serialize::json;
use url::{Url, form_urlencoded};
use super::Token;
use super::error::{Error, Result, OAuth2Error, OAuth2ErrorCode};
/// An OAuth 2.0 client for the authorization code grant.
///
/// All HTTP traffic is sent through the `hyper::Client` supplied at
/// construction time.
///
/// See [RFC6749 section 4.1](http://tools.ietf.org/html/rfc6749#section-4.1).
pub struct Client {
    /// HTTP client used to talk to the token endpoint.
    http_client: hyper::Client,
    /// Authorization endpoint URI; `auth_uri()` builds request URIs from it.
    auth_uri: String,
    /// Token endpoint URI, POSTed to when requesting or refreshing tokens.
    token_uri: String,
    /// Client identifier; sent in authorization URIs and as the Basic auth username.
    client_id: String,
    /// Client secret; sent as the Basic auth password.
    client_secret: String,
    /// Optional redirection URI, included in authorization URIs and code exchanges when set.
    redirect_uri: Option<String>,
}
/// Successful token endpoint response body, decoded from JSON.
///
/// See [RFC6749 section 5.1](http://tools.ietf.org/html/rfc6749#section-5.1).
#[derive(RustcDecodable)]
struct TokenResponse {
    /// The issued access token.
    access_token: String,
    /// Token type as reported by the server.
    token_type: String,
    /// Access token lifetime in seconds, if the server reported one.
    expires_in: Option<i64>,
    /// Refresh token, if the server issued one.
    refresh_token: Option<String>,
    /// Granted scope, if the server reported it.
    scope: Option<String>,
}
/// Converts a decoded token endpoint response into a `Token`.
///
/// Implemented as `From` so that `Into<Token>` (used via `.into()`) comes for
/// free from the standard blanket impl.
impl From<TokenResponse> for Token {
    /// Builds a `Token`, turning the relative `expires_in` (seconds) into an
    /// absolute expiry time based on the current UTC time.
    ///
    /// `expires_in` comes from the server, so it is clamped to +/- ~100 years
    /// first: `Duration::seconds` and `DateTime + Duration` both panic on
    /// out-of-range input, and a malformed or hostile response must not be able
    /// to crash the client. Values inside that range are used unchanged.
    fn from(response: TokenResponse) -> Token {
        // Roughly 100 years in seconds: far beyond any realistic token lifetime,
        // yet well inside the range chrono can represent.
        const MAX_EXPIRES_IN: i64 = 100 * 365 * 24 * 60 * 60;

        Token {
            access_token: response.access_token,
            token_type: response.token_type,
            expires: response.expires_in.map(|secs| {
                let secs = if secs > MAX_EXPIRES_IN {
                    MAX_EXPIRES_IN
                } else if secs < -MAX_EXPIRES_IN {
                    -MAX_EXPIRES_IN
                } else {
                    secs
                };
                UTC::now() + Duration::seconds(secs)
            }),
            refresh_token: response.refresh_token,
            scope: response.scope,
        }
    }
}
/// Error response body returned by the token endpoint, decoded from JSON.
///
/// See [RFC6749 section 5.2](http://tools.ietf.org/html/rfc6749#section-5.2).
#[derive(RustcDecodable)]
struct ErrorResponse {
    /// Error code string as sent by the server.
    error: String,
    /// Human-readable description of the error, if provided.
    error_description: Option<String>,
    /// URI of a page describing the error, if provided.
    error_uri: Option<String>,
}
impl Into<OAuth2Error> for ErrorResponse {
fn into(self) -> OAuth2Error {
let code = match &self.error[..] {
"invalid_request" => OAuth2ErrorCode::InvalidRequest,
"invalid_client" => OAuth2ErrorCode::InvalidClient,
"invalid_grant" => OAuth2ErrorCode::InvalidGrant,
"unauthorized_client" => OAuth2ErrorCode::UnauthorizedClient,
"unsupported_grant_type" => OAuth2ErrorCode::UnsupportedGrantType,
"invalid_scope" => OAuth2ErrorCode::InvalidScope,
_ => OAuth2ErrorCode::Unrecognized(self.error),
};
OAuth2Error {
code: code,
description: self.error_description,
uri: self.error_uri,
}
}
}
/// Generates `Client` constructors for well-known OAuth 2.0 providers.
///
/// Each comma-separated entry has the form
/// `#[attr] name => ("authorization endpoint", "token endpoint")` and expands
/// to `pub fn name(http_client, client_id, client_secret, redirect_uri) -> Self`
/// with the two endpoint URIs filled in. Intended to be invoked inside
/// `impl Client`, since the generated functions return `Self`.
macro_rules! site_constructors {
    (
        $(
            #[$meta:meta]
            $name:ident => ($authorize_endpoint:expr, $token_endpoint:expr)
        ),*
    ) => {
        $(
            #[$meta]
            pub fn $name<S>(
                http_client: hyper::Client,
                client_id: S,
                client_secret: S,
                redirect_uri: Option<S>
            ) -> Self
            where
                S: Into<String>,
            {
                Client {
                    http_client: http_client,
                    auth_uri: String::from($authorize_endpoint),
                    token_uri: String::from($token_endpoint),
                    client_id: client_id.into(),
                    client_secret: client_secret.into(),
                    // Convert the optional redirect URI only when one was given.
                    redirect_uri: redirect_uri.map(|uri| uri.into()),
                }
            }
        )*
    }
}
impl Client {
    /// Creates an OAuth 2.0 client.
    ///
    /// `auth_uri` is the authorization endpoint and `token_uri` the token
    /// endpoint. `redirect_uri`, when given, is sent both in authorization
    /// request URIs and in authorization code exchanges.
    pub fn new<S>(
        http_client: hyper::Client,
        auth_uri: S,
        token_uri: S,
        client_id: S,
        client_secret: S,
        redirect_uri: Option<S>
    ) -> Self where S: Into<String> {
        Client {
            http_client: http_client,
            auth_uri: auth_uri.into(),
            token_uri: token_uri.into(),
            client_id: client_id.into(),
            client_secret: client_secret.into(),
            redirect_uri: redirect_uri.map(Into::into),
        }
    }

    site_constructors!{
        #[doc = "Creates a Google OAuth 2.0 client.\n\nSee [Using OAuth 2.0 to Access Google APIs](https://developers.google.com/identity/protocols/OAuth2)."]
        google => (
            "https://accounts.google.com/o/oauth2/auth",
            "https://accounts.google.com/o/oauth2/token"
        ),
        #[doc = "Creates a GitHub OAuth 2.0 client.\n\nSee [OAuth, GitHub API](https://developer.github.com/v3/oauth/)."]
        github => (
            "https://github.com/login/oauth/authorize",
            "https://github.com/login/oauth/access_token"
        )
    }

    /// Constructs an authorization request URI.
    ///
    /// Any query component already present in the configured authorization
    /// endpoint is kept, as RFC6749 section 3.1 requires. It is followed by
    /// `response_type=code`, `client_id`, and, when present, `redirect_uri`,
    /// `scope` and `state`.
    ///
    /// See [RFC6749 section 4.1.1](http://tools.ietf.org/html/rfc6749#section-4.1.1).
    ///
    /// # Errors
    ///
    /// Returns an error if the configured authorization endpoint cannot be
    /// parsed as a URL.
    pub fn auth_uri(&self, scope: Option<&str>, state: Option<&str>) -> Result<String> {
        let mut uri = try!(Url::parse(&self.auth_uri));

        // RFC6749 section 3.1: a query component in the endpoint URI MUST be
        // retained when adding parameters. `set_query_from_pairs` replaces the
        // whole query, so collect the existing pairs and re-emit them first.
        let existing_pairs = uri.query_pairs().unwrap_or_else(Vec::new);
        let mut query_pairs: Vec<(&str, &str)> = existing_pairs
            .iter()
            .map(|&(ref key, ref value)| (&key[..], &value[..]))
            .collect();

        query_pairs.push(("response_type", "code"));
        query_pairs.push(("client_id", &self.client_id[..]));
        if let Some(ref redirect_uri) = self.redirect_uri {
            query_pairs.push(("redirect_uri", &redirect_uri[..]));
        }
        if let Some(scope) = scope {
            query_pairs.push(("scope", scope));
        }
        if let Some(state) = state {
            query_pairs.push(("state", state));
        }

        uri.set_query_from_pairs(query_pairs.iter());
        Ok(uri.serialize())
    }

    /// HTTP Basic `Authorization` header carrying the client credentials.
    fn auth_header(&self) -> header::Authorization<header::Basic> {
        header::Authorization(
            header::Basic {
                username: self.client_id.clone(),
                password: Some(self.client_secret.clone()),
            }
        )
    }

    /// `Accept: application/json` header, asking the token endpoint for JSON.
    fn accept_header(&self) -> header::Accept {
        header::Accept(vec![
            header::qitem(
                mime::Mime(
                    mime::TopLevel::Application,
                    mime::SubLevel::Json,
                    vec![]
                )
            ),
        ])
    }

    /// POSTs `body_pairs` to the token endpoint and decodes the reply.
    ///
    /// The pairs are sent as an `application/x-www-form-urlencoded` body, with
    /// HTTP Basic client authentication and a request for a JSON response. The
    /// response body is first decoded as a token response. If that fails, it is
    /// decoded as an OAuth 2.0 error response.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    ///
    /// - the request fails or the body cannot be read;
    /// - the server returns an OAuth 2.0 error (`Error::OAuth2`);
    /// - the body decodes as neither a token response nor an error response.
    fn token_post(&self, body_pairs: Vec<(&str, &str)>) -> Result<Token> {
        let post_body = form_urlencoded::serialize(body_pairs);
        let request = self.http_client.post(&self.token_uri)
            .header(self.auth_header())
            .header(self.accept_header())
            .header(header::ContentType::form_url_encoded())
            .body(&post_body);

        let mut response = try!(request.send());
        let mut body = String::new();
        try!(response.read_to_string(&mut body));

        match json::decode::<TokenResponse>(&body) {
            Ok(token) => Ok(token.into()),
            Err(_) => {
                // Not a token response; fall back to the error format.
                let error: ErrorResponse = try!(json::decode(&body));
                Err(Error::OAuth2(error.into()))
            }
        }
    }

    /// Requests an access token using an authorization code.
    ///
    /// The configured `redirect_uri`, if any, is included in the request.
    ///
    /// See [RFC6749 section 4.1.3](http://tools.ietf.org/html/rfc6749#section-4.1.3).
    ///
    /// # Errors
    ///
    /// Returns an error if the HTTP exchange fails or the server responds with
    /// an OAuth 2.0 error or an undecodable body.
    pub fn request_token(&self, code: &str) -> Result<Token> {
        let mut body_pairs = vec![
            ("grant_type", "authorization_code"),
            ("code", code),
        ];
        if let Some(ref redirect_uri) = self.redirect_uri {
            body_pairs.push(("redirect_uri", &redirect_uri[..]));
        }
        self.token_post(body_pairs)
    }

    /// Refreshes an access token.
    ///
    /// If the server does not issue a new refresh token, the existing one is
    /// copied into the returned token so the client can keep using it.
    ///
    /// See [RFC6749 section 6](http://tools.ietf.org/html/rfc6749#section-6).
    ///
    /// # Errors
    ///
    /// Returns an error if the HTTP exchange fails or the server responds with
    /// an OAuth 2.0 error or an undecodable body.
    ///
    /// # Panics
    ///
    /// Panics if `token` does not contain a `refresh_token`.
    pub fn refresh_token(&self, token: &Token, scope: Option<&str>) -> Result<Token> {
        let refresh_token = token.refresh_token.as_ref()
            .expect("Client::refresh_token requires a token that has a refresh_token");
        let mut body_pairs = vec![
            ("grant_type", "refresh_token"),
            ("refresh_token", &refresh_token[..]),
        ];
        if let Some(scope) = scope {
            body_pairs.push(("scope", scope));
        }

        let mut result = self.token_post(body_pairs);
        if let Ok(ref mut refreshed) = result {
            // RFC6749 section 6: issuing a new refresh token is optional; when
            // none is returned, carry the old one forward.
            if refreshed.refresh_token.is_none() {
                refreshed.refresh_token = Some(refresh_token.clone());
            }
        }
        result
    }
}