From c088b96f0d321d1df2a99850d12e473c8a226e17 Mon Sep 17 00:00:00 2001 From: Curtis McEnroe Date: Mon, 30 Nov 2015 21:55:07 -0500 Subject: [PATCH] Split Token into AccessToken, RefreshToken, TokenPair --- examples/github.rs | 4 +-- examples/google.rs | 6 ++-- src/client.rs | 58 ++++++++++++++++++------------------ src/lib.rs | 40 +++++++++++++------------ src/token.rs | 73 ++++++++++++++++++++++++++++++---------------- 5 files changed, 104 insertions(+), 77 deletions(-) diff --git a/examples/github.rs b/examples/github.rs index 0580696..0ef6382 100644 --- a/examples/github.rs +++ b/examples/github.rs @@ -18,7 +18,7 @@ fn main() { let mut code = String::new(); io::stdin().read_line(&mut code).unwrap(); - let token = client.request_token(code.trim()).unwrap(); + let token_pair = client.request_token(code.trim()).unwrap(); - println!("{:?}", token); + println!("{:?}", token_pair); } diff --git a/examples/google.rs b/examples/google.rs index a2338df..98b6dfd 100644 --- a/examples/google.rs +++ b/examples/google.rs @@ -21,11 +21,11 @@ fn main() { let mut code = String::new(); io::stdin().read_line(&mut code).unwrap(); - let token = client.request_token(code.trim()).unwrap(); + let token_pair = client.request_token(code.trim()).unwrap(); - println!("{:?}", token); + println!("{:?}", token_pair); - let refreshed = client.refresh_token(&token, None).unwrap(); + let refreshed = client.refresh_token(token_pair.refresh.unwrap(), None).unwrap(); println!("{:?}", refreshed); } diff --git a/src/client.rs b/src/client.rs index ce04620..8f9dc61 100644 --- a/src/client.rs +++ b/src/client.rs @@ -5,7 +5,7 @@ use hyper::{self, header, mime}; use rustc_serialize::json; use url::{Url, form_urlencoded}; -use super::Token; +use super::{TokenPair, AccessToken, RefreshToken}; use super::error::{Error, Result, OAuth2Error, OAuth2ErrorCode}; /// OAuth 2.0 client. @@ -33,14 +33,16 @@ struct TokenResponse { scope: Option, } -impl Into for TokenResponse { - fn into(self) -> Token { - Token { - access_token: self.access_token, - token_type: self.token_type, - expires: self.expires_in.map(|s| UTC::now() + Duration::seconds(s)), - refresh_token: self.refresh_token, - scope: self.scope, +impl Into for TokenResponse { + fn into(self) -> TokenPair { + TokenPair { + access: AccessToken { + token: self.access_token, + token_type: self.token_type, + expires: self.expires_in.map(|s| UTC::now() + Duration::seconds(s)), + scope: self.scope, + }, + refresh: self.refresh_token.map(|t| RefreshToken { token: t }), } } } @@ -179,7 +181,7 @@ impl Client { ]) } - fn token_post(&self, body_pairs: Vec<(&str, &str)>) -> Result { + fn token_post(&self, body_pairs: Vec<(&str, &str)>) -> Result { let post_body = form_urlencoded::serialize(body_pairs); let request = self.http_client.post(&self.token_uri) .header(self.auth_header()) @@ -203,7 +205,7 @@ impl Client { /// Requests an access token using an authorization code. /// /// See [RFC6749 section 4.1.3](http://tools.ietf.org/html/rfc6749#section-4.1.3). - pub fn request_token(&self, code: &str) -> Result { + pub fn request_token(&self, code: &str) -> Result { let mut body_pairs = vec![ ("grant_type", "authorization_code"), ("code", code), @@ -216,27 +218,25 @@ impl Client { /// Refreshes an access token. /// + /// The returned `TokenPair` will always have a `refresh`. + /// /// See [RFC6749 section 6](http://tools.ietf.org/html/rfc6749#section-6). - /// - /// # Panics - /// - /// Panics if `token` does not contain a `refresh_token`. - pub fn refresh_token(&self, token: &Token, scope: Option<&str>) -> Result { - let refresh_token = token.refresh_token.as_ref().unwrap(); + pub fn refresh_token(&self, refresh: RefreshToken, scope: Option<&str>) -> Result { + let mut result = { + let mut body_pairs = vec![ + ("grant_type", "refresh_token"), + ("refresh_token", &refresh.token), + ]; + if let Some(scope) = scope { + body_pairs.push(("scope", scope)); + } - let mut body_pairs = vec![ - ("grant_type", "refresh_token"), - ("refresh_token", refresh_token), - ]; - if let Some(scope) = scope { - body_pairs.push(("scope", scope)); - } + self.token_post(body_pairs) + }; - let mut result = self.token_post(body_pairs); - - if let Ok(ref mut token) = result { - if token.refresh_token.is_none() { - token.refresh_token = Some(refresh_token.clone()); + if let Ok(ref mut pair) = result { + if pair.refresh.is_none() { + pair.refresh = Some(refresh); } } diff --git a/src/lib.rs b/src/lib.rs index 580857c..da4eecd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -70,8 +70,8 @@ //! # use inth_oauth2::Client as OAuth2; //! # let auth = OAuth2::google(Default::default(), "", "", None); //! # let code = String::new(); -//! let token = auth.request_token(&code).unwrap(); -//! println!("{}", token.access_token); +//! let token_pair = auth.request_token(&code).unwrap(); +//! println!("{}", token_pair.access.token); //! ``` //! //! ## Refreshing an access token @@ -81,36 +81,40 @@ //! ```no_run //! # use inth_oauth2::Client as OAuth2; //! # let auth = OAuth2::google(Default::default(), "", "", None); -//! # let mut token = auth.request_token("").unwrap(); -//! if token.expired() { -//! token = auth.refresh_token(&token, None).unwrap(); +//! # let mut token_pair = auth.request_token("").unwrap(); +//! if token_pair.expired() { +//! if let Some(refresh) = token_pair.refresh { +//! token_pair = auth.refresh_token(refresh, None).unwrap(); +//! } //! } //! ``` //! //! ## Persisting tokens //! -//! `Token` implements `Encodable` and `Decodable` from `rustc_serialize`, so can be persisted in -//! JSON. +//! `TokenPair` implements `Encodable` and `Decodable` from `rustc_serialize`, so can be persisted +//! as JSON. //! //! ``` //! # extern crate inth_oauth2; //! # extern crate rustc_serialize; //! # extern crate chrono; -//! use inth_oauth2::Token; +//! use inth_oauth2::{TokenPair, AccessToken, RefreshToken}; //! use rustc_serialize::json; //! # use chrono::{UTC, Timelike}; //! # fn main() { -//! # let token = Token { -//! # access_token: String::from("AAAAAAAA"), -//! # token_type: String::from("bearer"), -//! # expires: Some(UTC::now().with_nanosecond(0).unwrap()), -//! # refresh_token: Some(String::from("BBBBBBB")), -//! # scope: None, +//! # let token_pair = TokenPair { +//! # access: AccessToken { +//! # token: String::from("AAAAAAAA"), +//! # token_type: String::from("bearer"), +//! # expires: Some(UTC::now().with_nanosecond(0).unwrap()), +//! # scope: None, +//! # }, +//! # refresh: Some(RefreshToken { token: String::from("BBBBBBBB") }), //! # }; //! -//! let json = json::encode(&token).unwrap(); -//! let decoded: Token = json::decode(&json).unwrap(); -//! assert_eq!(token, decoded); +//! let json = json::encode(&token_pair).unwrap(); +//! let decoded: TokenPair = json::decode(&json).unwrap(); +//! assert_eq!(token_pair, decoded); //! # } //! ``` @@ -122,7 +126,7 @@ extern crate url; pub use client::Client; pub mod client; -pub use token::Token; +pub use token::{TokenPair, AccessToken, RefreshToken}; pub mod token; pub use error::{Error, Result}; diff --git a/src/token.rs b/src/token.rs index 0cc2635..ee755ab 100644 --- a/src/token.rs +++ b/src/token.rs @@ -1,13 +1,24 @@ +use std::ops::Deref; + use chrono::{DateTime, UTC, TimeZone}; use rustc_serialize::{Encodable, Encoder, Decodable, Decoder}; +/// OAuth 2.0 access token and refresh token pair. +#[derive(Debug, Clone, PartialEq, Eq, RustcEncodable, RustcDecodable)] +pub struct TokenPair { + /// The access token. + pub access: AccessToken, + /// The refresh token. + pub refresh: Option, +} + /// OAuth 2.0 access token. /// /// See [RFC6749 section 5](http://tools.ietf.org/html/rfc6749#section-5). #[derive(Debug, Clone, PartialEq, Eq)] -pub struct Token { +pub struct AccessToken { /// The access token issued by the authorization server. - pub access_token: String, + pub token: String, /// The type of the token issued. /// @@ -17,59 +28,71 @@ pub struct Token { /// The expiry time of the access token. pub expires: Option>, - /// The refresh token, which can be used to obtain new access tokens. - pub refresh_token: Option, - /// The scope of the access token. pub scope: Option, } -impl Token { +/// OAuth 2.0 refresh token. +/// +/// See [RFC6749 section 1.5](http://tools.ietf.org/html/rfc6749#section-1.5). +#[derive(Debug, Clone, PartialEq, Eq, RustcEncodable, RustcDecodable)] +pub struct RefreshToken { + /// The refresh token issued by the authorization server. + pub token: String, +} + +impl AccessToken { /// Returns true if token is expired. pub fn expired(&self) -> bool { self.expires.map_or(false, |dt| dt < UTC::now()) } } +impl Deref for TokenPair { + type Target = AccessToken; + + fn deref<'a>(&'a self) -> &'a AccessToken { + &self.access + } +} + #[derive(RustcEncodable, RustcDecodable)] -struct SerializableToken { - access_token: String, +struct SerializableAccessToken { + token: String, token_type: String, expires: Option, - refresh_token: Option, scope: Option, } -impl SerializableToken { - fn from_token(token: &Token) -> Self { - SerializableToken { - access_token: token.access_token.clone(), - token_type: token.token_type.clone(), - expires: token.expires.as_ref().map(DateTime::timestamp), - refresh_token: token.refresh_token.clone(), - scope: token.scope.clone(), +impl SerializableAccessToken { + fn from_access_token(access: &AccessToken) -> Self { + SerializableAccessToken { + token: access.token.clone(), + token_type: access.token_type.clone(), + expires: access.expires.as_ref().map(DateTime::timestamp), + scope: access.scope.clone(), } } - fn into_token(self) -> Token { - Token { - access_token: self.access_token, + fn into_access_token(self) -> AccessToken { + AccessToken { + token: self.token, token_type: self.token_type, expires: self.expires.map(|t| UTC.timestamp(t, 0)), - refresh_token: self.refresh_token, scope: self.scope, } } } -impl Encodable for Token { +impl Encodable for AccessToken { fn encode(&self, s: &mut S) -> Result<(), S::Error> { - SerializableToken::from_token(self).encode(s) + SerializableAccessToken::from_access_token(self).encode(s) } } -impl Decodable for Token { +impl Decodable for AccessToken { fn decode(d: &mut D) -> Result { - SerializableToken::decode(d).map(SerializableToken::into_token) + SerializableAccessToken::decode(d) + .map(SerializableAccessToken::into_access_token) } }