/*
 * Test program that illustrates how to annotate a smart pointer
 * implementation.  In a multithreaded program the following is relevant when
 * working with smart pointers:
 * - whether or not the objects pointed at are shared over threads.
 * - whether or not the methods of the objects pointed at are thread-safe.
 * - whether or not the smart pointer objects are shared over threads.
 * - whether or not the smart pointer object itself is thread-safe.
 *
 * Most smart pointer implemenations are not thread-safe
 * (e.g. boost::shared_ptr<>, tr1::shared_ptr<> and the smart_ptr<>
 * implementation below). This means that it is not safe to modify a shared
 * pointer object that is shared over threads without proper synchronization.
 *
 * Even for non-thread-safe smart pointers it is possible to have different
 * threads access the same object via smart pointers without triggering data
 * races on the smart pointer objects.
 *
 * A smart pointer implementation guarantees that the destructor of the object
 * pointed at is invoked after the last smart pointer that points to that
 * object has been destroyed or reset. Data race detection tools cannot detect
 * this ordering without explicit annotation for smart pointers that track
 * references without invoking synchronization operations recognized by data
 * race detection tools.
 */


#include <cassert>     // assert()
#include <climits>     // PTHREAD_STACK_MIN
#include <iostream>    // std::cerr
#include <stdlib.h>    // atoi()
#ifdef _WIN32
#include <process.h>   // _beginthreadex()
#include <windows.h>   // CRITICAL_SECTION
#else
#include <pthread.h>   // pthread_mutex_t
#endif
#include "unified_annotations.h"


static bool s_enable_annotations;


#ifdef _WIN32

class AtomicInt32
{
public:
  AtomicInt32(const int value = 0) : m_value(value) { }
  ~AtomicInt32() { }
  LONG operator++() { return InterlockedIncrement(&m_value); }
  LONG operator--() { return InterlockedDecrement(&m_value); }

private:
  volatile LONG m_value;
};

class Mutex
{
public:
  Mutex() : m_mutex()
  { InitializeCriticalSection(&m_mutex); }
  ~Mutex()
  { DeleteCriticalSection(&m_mutex); }
  void Lock()
  { EnterCriticalSection(&m_mutex); }
  void Unlock()
  { LeaveCriticalSection(&m_mutex); }

private:
  CRITICAL_SECTION m_mutex;
};

class Thread
{
public:
  Thread() : m_thread(INVALID_HANDLE_VALUE) { }
  ~Thread() { }
  void Create(void* (*pf)(void*), void* arg)
  {
    WrapperArgs* wrapper_arg_p = new WrapperArgs(pf, arg);
    m_thread = reinterpret_cast<HANDLE>(_beginthreadex(NULL, 0, wrapper,
						       wrapper_arg_p, 0, NULL));
  }
  void Join()
  { WaitForSingleObject(m_thread, INFINITE); }

private:
  struct WrapperArgs
  {
    WrapperArgs(void* (*pf)(void*), void* arg) : m_pf(pf), m_arg(arg) { }

    void* (*m_pf)(void*);
    void* m_arg;
  };
  static unsigned int __stdcall wrapper(void* arg)
  {
    WrapperArgs* wrapper_arg_p = reinterpret_cast<WrapperArgs*>(arg);
    WrapperArgs wa = *wrapper_arg_p;
    delete wrapper_arg_p;
    return reinterpret_cast<unsigned>((wa.m_pf)(wa.m_arg));
  }
  HANDLE m_thread;
};

#else // _WIN32

class AtomicInt32
{
public:
  AtomicInt32(const int value = 0) : m_value(value) { }
  ~AtomicInt32() { }
  int operator++() { return __sync_add_and_fetch(&m_value, 1); }
  int operator--() { return __sync_sub_and_fetch(&m_value, 1); }
private:
  volatile int m_value;
};

class Mutex
{
public:
  Mutex() : m_mutex()
  { pthread_mutex_init(&m_mutex, NULL); }
  ~Mutex()
  { pthread_mutex_destroy(&m_mutex); }
  void Lock()
  { pthread_mutex_lock(&m_mutex); }
  void Unlock()
  { pthread_mutex_unlock(&m_mutex); }

private:
  pthread_mutex_t m_mutex;
};

class Thread
{
public:
  Thread() : m_tid() { }
  ~Thread() { }
  void Create(void* (*pf)(void*), void* arg)
  {
    pthread_attr_t attr;
    pthread_attr_init(&attr);
    pthread_attr_setstacksize(&attr, PTHREAD_STACK_MIN + 4096);
    pthread_create(&m_tid, &attr, pf, arg);
    pthread_attr_destroy(&attr);
  }
  void Join()
  { pthread_join(m_tid, NULL); }
private:
  pthread_t m_tid;
};

#endif // !defined(_WIN32)


template<class T>
class smart_ptr
{
public:
  typedef AtomicInt32 counter_t;

  template <typename Q> friend class smart_ptr;

  explicit smart_ptr()
    : m_ptr(NULL), m_count_ptr(NULL)
  { }

  explicit smart_ptr(T* const pT)
    : m_ptr(NULL), m_count_ptr(NULL)
  {
    set(pT, pT ? new counter_t(0) : NULL);
  }

  template <typename Q>
  explicit smart_ptr(Q* const q)
    : m_ptr(NULL), m_count_ptr(NULL)
  {
    set(q, q ? new counter_t(0) : NULL);
  }

  ~smart_ptr()
  {
    set(NULL, NULL);
  }

  smart_ptr(const smart_ptr<T>& sp)
    : m_ptr(NULL), m_count_ptr(NULL)
  {
    set(sp.m_ptr, sp.m_count_ptr);
  }

  template <typename Q>
  smart_ptr(const smart_ptr<Q>& sp)
    : m_ptr(NULL), m_count_ptr(NULL)
  {
    set(sp.m_ptr, sp.m_count_ptr);
  }

  smart_ptr& operator=(const smart_ptr<T>& sp)
  {
    set(sp.m_ptr, sp.m_count_ptr);
    return *this;
  }

  smart_ptr& operator=(T* const p)
  {
    set(p, p ? new counter_t(0) : NULL);
    return *this;
  }

  template <typename Q>
  smart_ptr& operator=(Q* const q)
  {
    set(q, q ? new counter_t(0) : NULL);
    return *this;
  }

  T* operator->() const
  {
    assert(m_ptr);
    return m_ptr;
  }

  T& operator*() const
  {
    assert(m_ptr);
    return *m_ptr;
  }

private:
  void set(T* const pT, counter_t* const count_ptr)
  {
    if (m_ptr != pT)
    {
      if (m_count_ptr)
      {
	if (s_enable_annotations)
	  U_ANNOTATE_HAPPENS_BEFORE(m_count_ptr);
	if (--(*m_count_ptr) == 0)
	{
	  if (s_enable_annotations)
	    U_ANNOTATE_HAPPENS_AFTER(m_count_ptr);
	  delete m_ptr;
	  m_ptr = NULL;
	  delete m_count_ptr;
	  m_count_ptr = NULL;
	}
      }
      m_ptr = pT;
      m_count_ptr = count_ptr;
      if (count_ptr)
	++(*m_count_ptr);
    }
  }

  T*         m_ptr;
  counter_t* m_count_ptr;
};

class counter
{
public:
  counter()
    : m_mutex(), m_count()
  { }
  ~counter()
  {
    // Data race detection tools that do not recognize the
    // ANNOTATE_HAPPENS_BEFORE() / ANNOTATE_HAPPENS_AFTER() annotations in the
    // smart_ptr<> implementation will report that the assignment below
    // triggers a data race.
    m_count = -1;
  }
  int get() const
  {
    int result;
    m_mutex.Lock();
    result = m_count;
    m_mutex.Unlock();
    return result;
  }
  int post_increment()
  {
    int result;
    m_mutex.Lock();
    result = m_count++;
    m_mutex.Unlock();
    return result;
  }

private:
  mutable Mutex m_mutex;
  int           m_count;
};

static void* thread_func(void* arg)
{
  smart_ptr<counter>* pp = reinterpret_cast<smart_ptr<counter>*>(arg);
  (*pp)->post_increment();
  *pp = NULL;
  delete pp;
  return NULL;
}

int main(int argc, char** argv)
{
  const int nthreads = std::max(argc > 1 ? atoi(argv[1]) : 1, 1);
  const int iterations = std::max(argc > 2 ? atoi(argv[2]) : 1, 1);
  s_enable_annotations = argc > 3 ? !!atoi(argv[3]) : true;

  for (int j = 0; j < iterations; ++j)
  {
    Thread T[nthreads];

    smart_ptr<counter> p(new counter);
    p->post_increment();
    for (int i = 0; i < nthreads; ++i)
      T[i].Create(thread_func, new smart_ptr<counter>(p));
    {
      // Avoid that counter.m_mutex introduces a false ordering on the
      // counter.m_count accesses.
      const timespec delay = { 0, 100 * 1000 * 1000 };
      nanosleep(&delay, 0);
    }
    p = NULL;
    for (int i = 0; i < nthreads; ++i)
      T[i].Join();
  }
  std::cerr << "Done.\n";
  return 0;
}