#![warn(missing_docs)]
#[macro_use]
extern crate lazy_static;
mod thread_id;
mod unreachable;
use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
use std::sync::Mutex;
use std::marker::PhantomData;
use std::cell::UnsafeCell;
use std::fmt;
use std::iter::Chain;
use std::option::IntoIter as OptionIter;
use std::panic::UnwindSafe;
use unreachable::{UncheckedOptionExt, UncheckedResultExt};
/// A value that is lazily initialized once per accessing thread.
///
/// Each thread observes its own independent `Box<T>`; entries are stored in
/// a grow-only chain of open-addressed hash tables keyed by thread id.
pub struct ThreadLocal<T: ?Sized + Send> {
    // Pointer to the newest hash table; swapped atomically when it grows.
    table: AtomicPtr<Table<T>>,
    // Serializes inserts and resizes; the guarded value is the number of
    // live entries across all tables.
    lock: Mutex<usize>,
    // Expresses logical ownership of the `T` values held behind raw pointers.
    marker: PhantomData<T>,
}
/// One generation of the open-addressed hash table.
struct Table<T: ?Sized + Send> {
    /// Power-of-two array of slots, probed linearly.
    entries: Box<[TableEntry<T>]>,
    /// log2 of `entries.len()`; feeds the top-bits shift in `hash`.
    hash_bits: usize,
    /// The previous, smaller table; kept alive so outstanding references
    /// into it remain valid, and walked by the slow lookup path.
    prev: Option<Box<Table<T>>>,
}
/// A single hash-table slot.
struct TableEntry<T: ?Sized + Send> {
    /// Id of the thread owning this slot, or 0 if the slot is empty.
    owner: AtomicUsize,
    /// The per-thread value; only the owning thread touches the cell.
    data: UnsafeCell<Option<Box<T>>>,
}
unsafe impl<T: ?Sized + Send> Sync for ThreadLocal<T> {}
impl<T: ?Sized + Send> Default for ThreadLocal<T> {
fn default() -> ThreadLocal<T> {
ThreadLocal::new()
}
}
impl<T: ?Sized + Send> Drop for ThreadLocal<T> {
    fn drop(&mut self) {
        unsafe {
            // SAFETY: `table` always holds a pointer produced by
            // `Box::into_raw`. Rebuilding the Box here drops the newest table
            // and, through its `prev` chain, every older table and all
            // stored values.
            Box::from_raw(self.table.load(Ordering::Relaxed));
        }
    }
}
impl<T: ?Sized + Send> Clone for TableEntry<T> {
    // NOTE: deliberately does NOT copy the entry's contents. `clone` exists
    // only so `vec![entry; n]` can materialize a table of empty slots
    // (see `ThreadLocal::new` and `insert`).
    fn clone(&self) -> TableEntry<T> {
        TableEntry {
            owner: AtomicUsize::new(0),
            data: UnsafeCell::new(None),
        }
    }
}
/// Fibonacci hashing: multiply the id by 2^32 / phi (golden ratio) and keep
/// the top `bits` bits, yielding an index into a table of 2^bits slots.
#[cfg(target_pointer_width = "32")]
#[inline]
fn hash(id: usize, bits: usize) -> usize {
    id.wrapping_mul(0x9E3779B9) >> (32 - bits)
}
/// Fibonacci hashing: multiply the id by 2^64 / phi (golden ratio) and keep
/// the top `bits` bits, yielding an index into a table of 2^bits slots.
#[cfg(target_pointer_width = "64")]
#[inline]
fn hash(id: usize, bits: usize) -> usize {
    id.wrapping_mul(0x9E37_79B9_7F4A_7C15) >> (64 - bits)
}
impl<T: ?Sized + Send> ThreadLocal<T> {
    /// Creates a new, empty `ThreadLocal`.
    pub fn new() -> ThreadLocal<T> {
        // Start with a tiny 2-slot table; `insert` grows it on demand.
        let entry = TableEntry {
            owner: AtomicUsize::new(0),
            data: UnsafeCell::new(None),
        };
        let table = Table {
            entries: vec![entry; 2].into_boxed_slice(),
            hash_bits: 1,
            prev: None,
        };
        ThreadLocal {
            table: AtomicPtr::new(Box::into_raw(Box::new(table))),
            lock: Mutex::new(0),
            marker: PhantomData,
        }
    }

    /// Returns the element for the current thread, if it exists.
    pub fn get(&self) -> Option<&T> {
        let id = thread_id::get();
        self.get_fast(id)
    }

    /// Returns the element for the current thread, creating it with `create`
    /// if it does not yet exist.
    pub fn get_or<F>(&self, create: F) -> &T
    where
        F: FnOnce() -> Box<T>,
    {
        unsafe {
            // SAFETY: the closure's error type is `()` and it always returns
            // `Ok`, so the `Result` can never be `Err`.
            self.get_or_try(|| Ok::<Box<T>, ()>(create()))
                .unchecked_unwrap_ok()
        }
    }

    /// Fallible version of `get_or`: if `create` fails, its error is
    /// propagated and nothing is inserted.
    pub fn get_or_try<F, E>(&self, create: F) -> Result<&T, E>
    where
        F: FnOnce() -> Result<Box<T>, E>,
    {
        let id = thread_id::get();
        match self.get_fast(id) {
            Some(x) => Ok(x),
            None => Ok(self.insert(id, try!(create()), true)),
        }
    }

    // Linear-probes `table` for the slot owned by thread `id`. Returns the
    // slot's data cell, or `None` on hitting an empty slot (id not present
    // in this table).
    fn lookup(id: usize, table: &Table<T>) -> Option<&UnsafeCell<Option<Box<T>>>> {
        // `cycle` lets the probe wrap past the end of the slot array.
        for entry in table.entries.iter().cycle().skip(hash(id, table.hash_bits)) {
            let owner = entry.owner.load(Ordering::Relaxed);
            if owner == id {
                return Some(&entry.data);
            }
            if owner == 0 {
                return None;
            }
        }
        // The table grows at 3/4 load (see `insert`), so it is never full
        // and the probe always terminates at an owned or empty slot.
        unreachable!();
    }

    // Fast path: look only in the newest table; fall back to `get_slow` to
    // search older generations.
    fn get_fast(&self, id: usize) -> Option<&T> {
        let table = unsafe { &*self.table.load(Ordering::Relaxed) };
        match Self::lookup(id, table) {
            // SAFETY: a slot owned by the calling thread always holds `Some`,
            // and only that thread accesses the cell's contents.
            Some(x) => unsafe { Some((*x.get()).as_ref().unchecked_unwrap()) },
            None => self.get_slow(id, table),
        }
    }

    // Slow path: walk older tables via `prev`; on a hit, migrate the value
    // into the newest table so future lookups take the fast path.
    #[cold]
    fn get_slow(&self, id: usize, table_top: &Table<T>) -> Option<&T> {
        let mut current = &table_top.prev;
        while let Some(ref table) = *current {
            if let Some(x) = Self::lookup(id, table) {
                // SAFETY: the cell belongs to the calling thread and holds
                // `Some`; `take` leaves `None` behind in the old table.
                let data = unsafe { (*x.get()).take().unchecked_unwrap() };
                // `new == false`: migration does not change the entry count.
                return Some(self.insert(id, data, false));
            }
            current = &table.prev;
        }
        None
    }

    // Inserts `data` for thread `id`, growing the table when it is more
    // than 3/4 full. `new` is true for a first-time insert and false when
    // migrating a value from an older table.
    #[cold]
    fn insert(&self, id: usize, data: Box<T>, new: bool) -> &T {
        // The mutex serializes all inserts and resizes; its guarded value is
        // the total number of live entries.
        let mut count = self.lock.lock().unwrap();
        if new {
            *count += 1;
        }
        let table_raw = self.table.load(Ordering::Relaxed);
        let table = unsafe { &*table_raw };
        // Grow at 3/4 load factor: allocate a table twice as large and keep
        // the old one reachable through `prev` so existing references into it
        // stay valid.
        let table = if *count > table.entries.len() * 3 / 4 {
            let entry = TableEntry {
                owner: AtomicUsize::new(0),
                data: UnsafeCell::new(None),
            };
            let new_table = Box::into_raw(Box::new(Table {
                entries: vec![entry; table.entries.len() * 2].into_boxed_slice(),
                hash_bits: table.hash_bits + 1,
                // SAFETY: we hold the lock, so no concurrent resize can have
                // freed or replaced `table_raw`.
                prev: unsafe { Some(Box::from_raw(table_raw)) },
            }));
            self.table.store(new_table, Ordering::Release);
            unsafe { &*new_table }
        } else {
            table
        };
        // Linear-probe for the first empty slot (or this thread's existing
        // slot) starting at the id's hash bucket.
        for entry in table.entries.iter().cycle().skip(hash(id, table.hash_bits)) {
            let owner = entry.owner.load(Ordering::Relaxed);
            if owner == 0 {
                unsafe {
                    entry.owner.store(id, Ordering::Relaxed);
                    *entry.data.get() = Some(data);
                    return (*entry.data.get()).as_ref().unchecked_unwrap();
                }
            }
            if owner == id {
                // Slot already owned by this thread; return the stored value.
                unsafe {
                    return (*entry.data.get()).as_ref().unchecked_unwrap();
                }
            }
        }
        // Unreachable: the table is grown before it can fill up.
        unreachable!();
    }

    /// Returns a mutable iterator over all thread-local values. Taking
    /// `&mut self` guarantees no other thread is accessing them.
    pub fn iter_mut(&mut self) -> IterMut<T> {
        let raw = RawIter {
            remaining: *self.lock.lock().unwrap(),
            index: 0,
            table: self.table.load(Ordering::Relaxed),
        };
        IterMut {
            raw: raw,
            marker: PhantomData,
        }
    }

    /// Removes all thread-local values, resetting the container.
    pub fn clear(&mut self) {
        *self = ThreadLocal::new();
    }
}
impl<T: ?Sized + Send> IntoIterator for ThreadLocal<T> {
    type Item = Box<T>;
    type IntoIter = IntoIter<T>;

    /// Consumes the container, yielding every stored value by value.
    fn into_iter(self) -> IntoIter<T> {
        let remaining = *self.lock.lock().unwrap();
        let table = self.table.load(Ordering::Relaxed);
        IntoIter {
            raw: RawIter {
                remaining: remaining,
                index: 0,
                table: table,
            },
            _thread_local: self,
        }
    }
}
impl<'a, T: ?Sized + Send + 'a> IntoIterator for &'a mut ThreadLocal<T> {
type Item = &'a mut Box<T>;
type IntoIter = IterMut<'a, T>;
fn into_iter(self) -> IterMut<'a, T> {
self.iter_mut()
}
}
impl<T: Send + Default> ThreadLocal<T> {
    /// Returns the element for the current thread, creating it via
    /// `T::default()` if it does not yet exist.
    pub fn get_default(&self) -> &T {
        self.get_or(|| Box::new(Default::default()))
    }
}
impl<T: ?Sized + Send + fmt::Debug> fmt::Debug for ThreadLocal<T> {
    /// Formats only the current thread's value (other threads' values are
    /// not accessible through `&self`).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let local = self.get();
        write!(f, "ThreadLocal {{ local_data: {:?} }}", local)
    }
}
impl<T: ?Sized + Send + UnwindSafe> UnwindSafe for ThreadLocal<T> {}
// Shared iteration machinery: walks every occupied slot across the chain of
// tables, yielding raw pointers. The safe wrappers (`IterMut`, `IntoIter`)
// guarantee exclusive access while it runs.
struct RawIter<T: ?Sized + Send> {
    // Number of occupied slots not yet yielded.
    remaining: usize,
    // Index of the next slot to examine in the current table.
    index: usize,
    // Table currently being walked (newest generation first).
    table: *const Table<T>,
}
impl<T: ?Sized + Send> RawIter<T> {
    // Yields a raw pointer to the next occupied slot's `Option<Box<T>>`,
    // scanning the current table and then following `prev` into older ones.
    fn next(&mut self) -> Option<*mut Option<Box<T>>> {
        if self.remaining == 0 {
            return None;
        }
        loop {
            let entries = unsafe { &(*self.table).entries[..] };
            while self.index < entries.len() {
                let val = entries[self.index].data.get();
                self.index += 1;
                if unsafe { (*val).is_some() } {
                    self.remaining -= 1;
                    return Some(val);
                }
            }
            self.index = 0;
            // SAFETY: `remaining > 0` means some older table still holds a
            // value, so `prev` cannot be `None` here.
            self.table = unsafe { &**(*self.table).prev.as_ref().unchecked_unwrap() };
        }
    }
}
/// Mutable iterator over the values stored in a `ThreadLocal`.
pub struct IterMut<'a, T: ?Sized + Send + 'a> {
    raw: RawIter<T>,
    // Ties the iterator to an exclusive borrow of the `ThreadLocal`,
    // preventing any concurrent access while iteration is in progress.
    marker: PhantomData<&'a mut ThreadLocal<T>>,
}
impl<'a, T: ?Sized + Send + 'a> Iterator for IterMut<'a, T> {
    type Item = &'a mut Box<T>;
    fn next(&mut self) -> Option<&'a mut Box<T>> {
        // SAFETY: `RawIter` only yields pointers to occupied slots, and the
        // exclusive borrow recorded in `marker` makes handing out `&mut`
        // references sound.
        self.raw.next().map(|x| unsafe {
            (*x).as_mut().unchecked_unwrap()
        })
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `remaining` is an exact count, which also backs ExactSizeIterator.
        (self.raw.remaining, Some(self.raw.remaining))
    }
}
impl<'a, T: ?Sized + Send + 'a> ExactSizeIterator for IterMut<'a, T> {}
/// Owning iterator over the values stored in a `ThreadLocal`.
pub struct IntoIter<T: ?Sized + Send> {
    raw: RawIter<T>,
    // Kept alive so the tables that `raw`'s pointers reference are not
    // freed until iteration finishes.
    _thread_local: ThreadLocal<T>,
}
impl<T: ?Sized + Send> Iterator for IntoIter<T> {
    type Item = Box<T>;
    fn next(&mut self) -> Option<Box<T>> {
        // SAFETY: `RawIter` only yields pointers to occupied slots, and we
        // own the container, so moving the value out with `take` is sound.
        self.raw.next().map(
            |x| unsafe { (*x).take().unchecked_unwrap() },
        )
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `remaining` is an exact count, which also backs ExactSizeIterator.
        (self.raw.remaining, Some(self.raw.remaining))
    }
}
impl<T: ?Sized + Send> ExactSizeIterator for IntoIter<T> {}
/// A `ThreadLocal` variant that keeps the first accessing thread's value in
/// a dedicated cached slot for faster repeated access by that thread.
pub struct CachedThreadLocal<T: ?Sized + Send> {
    /// Id of the thread that claimed the cached slot, or 0 if unclaimed.
    owner: AtomicUsize,
    /// The cached value; only the claiming thread accesses this cell.
    local: UnsafeCell<Option<Box<T>>>,
    /// Fallback storage for every other thread.
    global: ThreadLocal<T>,
}
unsafe impl<T: ?Sized + Send> Sync for CachedThreadLocal<T> {}
impl<T: ?Sized + Send> Default for CachedThreadLocal<T> {
fn default() -> CachedThreadLocal<T> {
CachedThreadLocal::new()
}
}
impl<T: ?Sized + Send> CachedThreadLocal<T> {
    /// Creates a new, empty `CachedThreadLocal`.
    pub fn new() -> CachedThreadLocal<T> {
        CachedThreadLocal {
            owner: AtomicUsize::new(0),
            local: UnsafeCell::new(None),
            global: ThreadLocal::new(),
        }
    }

    /// Returns the element for the current thread, if it exists.
    pub fn get(&self) -> Option<&T> {
        let id = thread_id::get();
        let owner = self.owner.load(Ordering::Relaxed);
        if owner == id {
            // SAFETY: this thread claimed the cached slot, which holds
            // `Some` once claimed, and no other thread touches the cell.
            return unsafe { Some((*self.local.get()).as_ref().unchecked_unwrap()) };
        }
        if owner == 0 {
            // The cached slot is unclaimed; the first insert always claims
            // it before anything reaches `global`, so nothing exists yet.
            return None;
        }
        // The cached slot belongs to another thread; check the shared table.
        self.global.get_fast(id)
    }

    /// Returns the element for the current thread, creating it with `create`
    /// if it does not yet exist.
    #[inline(always)]
    pub fn get_or<F>(&self, create: F) -> &T
    where
        F: FnOnce() -> Box<T>,
    {
        unsafe {
            // SAFETY: the closure's error type is `()` and it always returns
            // `Ok`, so the `Result` can never be `Err`.
            self.get_or_try(|| Ok::<Box<T>, ()>(create()))
                .unchecked_unwrap_ok()
        }
    }

    /// Fallible version of `get_or`: if `create` fails, its error is
    /// propagated and nothing is inserted.
    pub fn get_or_try<F, E>(&self, create: F) -> Result<&T, E>
    where
        F: FnOnce() -> Result<Box<T>, E>,
    {
        let id = thread_id::get();
        let owner = self.owner.load(Ordering::Relaxed);
        if owner == id {
            // Fast path: this thread owns the cached slot.
            return Ok(unsafe { (*self.local.get()).as_ref().unchecked_unwrap() });
        }
        self.get_or_try_slow(id, owner, create)
    }

    // Slow path: try to claim the cached slot; otherwise fall back to the
    // shared `ThreadLocal` table.
    #[cold]
    #[inline(never)]
    fn get_or_try_slow<F, E>(&self, id: usize, owner: usize, create: F) -> Result<&T, E>
    where
        F: FnOnce() -> Result<Box<T>, E>,
    {
        if owner == 0 && self.owner.compare_and_swap(0, id, Ordering::Relaxed) == 0 {
            unsafe {
                // SAFETY: the successful CAS made this thread the sole owner
                // of `local`, so writing the cell is race-free.
                (*self.local.get()) = Some(try!(create()));
                return Ok((*self.local.get()).as_ref().unchecked_unwrap());
            }
        }
        // The cached slot is (or just became) another thread's; use `global`.
        match self.global.get_fast(id) {
            Some(x) => Ok(x),
            None => Ok(self.global.insert(id, try!(create()), true)),
        }
    }

    /// Returns a mutable iterator over all values: the cached slot first,
    /// then the shared table. Taking `&mut self` guarantees exclusivity.
    pub fn iter_mut(&mut self) -> CachedIterMut<T> {
        unsafe {
            // SAFETY: `&mut self` means no other thread can be accessing
            // `local` concurrently.
            (*self.local.get()).as_mut().into_iter().chain(
                self.global
                    .iter_mut(),
            )
        }
    }

    /// Removes all thread-local values, resetting the container.
    pub fn clear(&mut self) {
        *self = CachedThreadLocal::new();
    }
}
impl<T: ?Sized + Send> IntoIterator for CachedThreadLocal<T> {
    type Item = Box<T>;
    type IntoIter = CachedIntoIter<T>;
    fn into_iter(self) -> CachedIntoIter<T> {
        unsafe {
            // SAFETY: `self` is consumed by value, so no other thread can be
            // accessing `local`; `take` moves the cached value out before
            // chaining on the shared table's owning iterator.
            (*self.local.get()).take().into_iter().chain(
                self.global
                    .into_iter(),
            )
        }
    }
}
impl<'a, T: ?Sized + Send + 'a> IntoIterator for &'a mut CachedThreadLocal<T> {
type Item = &'a mut Box<T>;
type IntoIter = CachedIterMut<'a, T>;
fn into_iter(self) -> CachedIterMut<'a, T> {
self.iter_mut()
}
}
impl<T: Send + Default> CachedThreadLocal<T> {
    /// Returns the element for the current thread, creating it via
    /// `T::default()` if it does not yet exist.
    pub fn get_default(&self) -> &T {
        self.get_or(|| Box::new(Default::default()))
    }
}
impl<T: ?Sized + Send + fmt::Debug> fmt::Debug for CachedThreadLocal<T> {
    // NOTE: intentionally prints "ThreadLocal" (not "CachedThreadLocal");
    // the unit tests assert this exact output string.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "ThreadLocal {{ local_data: {:?} }}", self.get())
    }
}
/// Mutable iterator over a `CachedThreadLocal`'s values (cached slot first).
pub type CachedIterMut<'a, T> = Chain<OptionIter<&'a mut Box<T>>, IterMut<'a, T>>;
/// Owning iterator over a `CachedThreadLocal`'s values (cached slot first).
pub type CachedIntoIter<T> = Chain<OptionIter<Box<T>>, IntoIter<T>>;
impl<T: ?Sized + Send + UnwindSafe> UnwindSafe for CachedThreadLocal<T> {}
#[cfg(test)]
mod tests {
    use std::cell::RefCell;
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering::Relaxed;
    use std::thread;
    use super::{ThreadLocal, CachedThreadLocal};

    // Returns a shared closure yielding 0, 1, 2, ... on successive calls, so
    // each thread's freshly created value is distinguishable.
    fn make_create() -> Arc<Fn() -> Box<usize> + Send + Sync> {
        let count = AtomicUsize::new(0);
        Arc::new(move || Box::new(count.fetch_add(1, Relaxed)))
    }

    #[test]
    fn same_thread() {
        let create = make_create();
        let mut tls = ThreadLocal::new();
        assert_eq!(None, tls.get());
        assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls));
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        // Repeated calls must reuse the existing value, not create another.
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls));
        tls.clear();
        assert_eq!(None, tls.get());
    }

    #[test]
    fn same_thread_cached() {
        let create = make_create();
        let mut tls = CachedThreadLocal::new();
        assert_eq!(None, tls.get());
        assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls));
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        // Repeated calls must reuse the existing value, not create another.
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls));
        tls.clear();
        assert_eq!(None, tls.get());
    }

    #[test]
    fn different_thread() {
        let create = make_create();
        let tls = Arc::new(ThreadLocal::new());
        assert_eq!(None, tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        // A second thread must get its own, distinct value.
        let tls2 = tls.clone();
        let create2 = create.clone();
        thread::spawn(move || {
            assert_eq!(None, tls2.get());
            assert_eq!(1, *tls2.get_or(|| create2()));
            assert_eq!(Some(&1), tls2.get());
        }).join()
            .unwrap();
        // The original thread's value is unaffected.
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
    }

    #[test]
    fn different_thread_cached() {
        let create = make_create();
        let tls = Arc::new(CachedThreadLocal::new());
        assert_eq!(None, tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        // A second thread must get its own, distinct value.
        let tls2 = tls.clone();
        let create2 = create.clone();
        thread::spawn(move || {
            assert_eq!(None, tls2.get());
            assert_eq!(1, *tls2.get_or(|| create2()));
            assert_eq!(Some(&1), tls2.get());
        }).join()
            .unwrap();
        // The original thread's value is unaffected.
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
    }

    #[test]
    fn iter() {
        // Populate from three threads, then iterate (mutably and by value)
        // once all other handles are gone.
        let tls = Arc::new(ThreadLocal::new());
        tls.get_or(|| Box::new(1));
        let tls2 = tls.clone();
        thread::spawn(move || {
            tls2.get_or(|| Box::new(2));
            let tls3 = tls2.clone();
            thread::spawn(move || { tls3.get_or(|| Box::new(3)); })
                .join()
                .unwrap();
        }).join()
            .unwrap();
        let mut tls = Arc::try_unwrap(tls).unwrap();
        let mut v = tls.iter_mut().map(|x| **x).collect::<Vec<i32>>();
        // Iteration order is unspecified, so sort before comparing.
        v.sort();
        assert_eq!(vec![1, 2, 3], v);
        let mut v = tls.into_iter().map(|x| *x).collect::<Vec<i32>>();
        v.sort();
        assert_eq!(vec![1, 2, 3], v);
    }

    #[test]
    fn iter_cached() {
        // Same as `iter`, but exercising the cached-slot fast path too.
        let tls = Arc::new(CachedThreadLocal::new());
        tls.get_or(|| Box::new(1));
        let tls2 = tls.clone();
        thread::spawn(move || {
            tls2.get_or(|| Box::new(2));
            let tls3 = tls2.clone();
            thread::spawn(move || { tls3.get_or(|| Box::new(3)); })
                .join()
                .unwrap();
        }).join()
            .unwrap();
        let mut tls = Arc::try_unwrap(tls).unwrap();
        let mut v = tls.iter_mut().map(|x| **x).collect::<Vec<i32>>();
        // Iteration order is unspecified, so sort before comparing.
        v.sort();
        assert_eq!(vec![1, 2, 3], v);
        let mut v = tls.into_iter().map(|x| *x).collect::<Vec<i32>>();
        v.sort();
        assert_eq!(vec![1, 2, 3], v);
    }

    #[test]
    fn is_sync() {
        // Compile-time check: both containers are Sync even when T is not.
        fn foo<T: Sync>() {}
        foo::<ThreadLocal<String>>();
        foo::<ThreadLocal<RefCell<String>>>();
        foo::<CachedThreadLocal<String>>();
        foo::<CachedThreadLocal<RefCell<String>>>();
    }
}