#include "tradArrayDistGraphR.h"

#ifdef __XGRAPH_DIST__

// Builds the distributed graph from a CSR representation (offsets/adj) that
// is held in memory by the process with rank loading_rank in comm.
// n and m are forwarded to the GraphR base class and to configureParameters().
TradArrayDistGraphR::TradArrayDistGraphR(v_id n,
                                         v_id m,
                                         v_id* offsets,
                                         v_id* adj,
                                         int loading_rank,
                                         MPI_Comm comm)
  : GraphR(n, m)
{
  // configureParameters() computes n_local_, which the distribution step
  // reads, so it has to run first.
  this->configureParameters(comm, n, m);
  this->loadAndDistributeGraphFromMemory(offsets, adj, loading_rank);
}

// Distributes a CSR graph that resides in the memory of one process
// (loading_rank) to all processes of comm_.
//
// Must be called by every process of comm_ (it uses collective MPI
// operations) after configureParameters() has set rank_, num_procs_, n_, m_
// and n_local_.
//
//   offsets      - CSR offsets with n_ + 1 entries; only read on loading_rank.
//   adj          - CSR adjacency array indexed through offsets; only read on
//                  loading_rank.
//   loading_rank - rank in comm_ of the process that holds offsets/adj.
//
// On return, m_local_, offsets_local_ (n_local_ + 1 entries) and adj_local_
// (m_local_ entries) describe the vertices owned by the calling process.
void TradArrayDistGraphR::loadAndDistributeGraphFromMemory(v_id* offsets, v_id* adj, int loading_rank) {
  // TODO: this is deliberately simple and inefficient (one message per remote
  // vertex). It is meant for small graphs and enables reusing the ready-made
  // loading functions in graphIO.

  // Tag of the per-vertex messages [global id, degree, neighbors...].
  const int vertex_msg_tag = 111;

  // Step 1: tell each process how many adjacency entries it owns, so that it
  // can allocate enough memory.
  m_local_ = 0;

  if(rank_ == loading_rank) {
    v_id* m_for_each_proc = new v_id[num_procs_]();
    for(v_id v = 0; v < n_; ++v) {
      v_id v_deg = offsets[v+1] - offsets[v];
      int v_owner = getVertexOwningRank(v);
      m_for_each_proc[v_owner] += v_deg;
    }

    MPI_Scatter(&m_for_each_proc[0], 1, MPI_INT64_T, &m_local_, 1, MPI_INT64_T, loading_rank, comm_);

    delete [] m_for_each_proc;
  }
  else {
    // The send arguments are only significant at the root.
    MPI_Scatter(NULL, 0, MPI_INT64_T, &m_local_, 1, MPI_INT64_T, loading_rank, comm_);
  }

#ifdef __XGRAPH_DEBUGGING__
  // Every edge is expected to be stored twice (once per endpoint).
  v_id debug_m = 0;
  MPI_Allreduce(&m_local_, &debug_m, 1, MPI_INT64_T, MPI_SUM, comm_);
  assert(debug_m == 2*m_);
#endif

  adj_local_ = new v_id[m_local_]();
  offsets_local_ = new v_id[n_local_+1]();

#ifdef __XGRAPH_DEBUGGING__
  // Poison both arrays so that the checks at the end can detect entries that
  // were never written.
  for(v_id v = 0; v < n_local_ + 1; ++v) {
    offsets_local_[v] = -1;
  }

  for(v_id e = 0; e < m_local_; ++e) {
    adj_local_[e] = -1;
  }
#endif

  offsets_local_[0] = 0;

  // Step 2: move the adjacency lists to their owners.
  if(rank_ == loading_rank) {
    // First pass: build the local offsets of the vertices owned by the
    // loading process itself. Local ids are expected to grow by one with
    // every owned vertex encountered in increasing global-id order.
    v_id processed_local_vertex = 0;
    for(v_id v = 0; v < n_; ++v) {
      int owning_rank = getVertexOwningRank(v);
      if(owning_rank == loading_rank) {
        v_id v_deg = offsets[v+1] - offsets[v];
        v_id v_local_id = getVertexLocalID(v);
        // Keep the increment outside of assert() so that it does not vanish
        // when asserts are compiled out.
        assert(v_local_id == processed_local_vertex);
        ++processed_local_vertex;
        offsets_local_[v_local_id + 1] = offsets_local_[v_local_id] + v_deg;
      }
    }
    assert(processed_local_vertex == n_local_);
    (void)processed_local_vertex; // only read by the asserts above

    // Second pass: send remote adjacency lists, copy local ones.
    for(v_id v = 0; v < n_; ++v) {
      v_id v_deg = offsets[v+1] - offsets[v];
      v_id v_offset = offsets[v];
      int owning_rank = getVertexOwningRank(v);
      if(owning_rank != loading_rank) {
        v_id* send_data = new v_id[v_deg + 2]();
        send_data[0] = v;
        send_data[1] = v_deg;
        if(v_deg > 0) {
          memcpy(&send_data[2], &adj[v_offset], sizeof(v_id) * v_deg);
        }
        MPI_Send(&send_data[0], v_deg + 2, MPI_INT64_T, owning_rank, vertex_msg_tag, comm_);
        delete [] send_data;
      }
      else if(v_deg > 0) {
        v_id v_local_id = getVertexLocalID(v);
        memcpy(&adj_local_[ offsets_local_[v_local_id] ], &adj[v_offset], sizeof(v_id) * v_deg);
      }
    }
  }
  else {
    // Exactly one message per locally owned vertex is expected. Adjacency
    // lists are appended in arrival order and each offset is derived from the
    // previous local vertex, so the vertices must arrive in increasing
    // local-id order; the asserts below verify that.
    v_id current_offset = 0;
    v_id prev_v_global = 0;
    v_id prev_v_local = 0;
    bool first_msg = true;
    for(v_id i = 0; i < n_local_; ++i) {
      MPI_Status status;
      MPI_Probe(loading_rank, vertex_msg_tag, comm_, &status);

      int msg_size = -1;
      MPI_Get_count(&status, MPI_INT64_T, &msg_size);
      // Every message carries at least the global id and the degree.
      assert(msg_size >= 2);

      v_id* recv_data = new v_id[msg_size]();
      MPI_Recv(&recv_data[0], msg_size, MPI_INT64_T, loading_rank, vertex_msg_tag, comm_, &status);
      v_id v_global = recv_data[0];
      v_id v_local = getVertexLocalID(v_global);
      assert(getVertexOwningRank(v_global) == rank_);
      if(!first_msg) {
        assert(prev_v_global < v_global);
        assert(prev_v_local == (v_local - 1));
      }
      first_msg = false;

      v_id v_deg = recv_data[1];
      assert(v_deg == (msg_size - 2));

      if(v_deg > 0) {
        memcpy(&adj_local_[current_offset], &recv_data[2], sizeof(v_id) * v_deg);
      }
      current_offset += v_deg;
      offsets_local_[v_local+1] = offsets_local_[v_local] + v_deg;
      delete [] recv_data;

      prev_v_global = v_global;
      prev_v_local = v_local;
    }
  }

  MPI_Barrier(comm_);

#ifdef __XGRAPH_DEBUGGING__
  // All offsets must have been written, be in range and be non-decreasing.
  for(v_id v = 0; v < n_local_ + 1; ++v) {
    assert(offsets_local_[v] >= 0);
    assert(offsets_local_[v] <= m_local_);
    if(v > 0) {
      assert(offsets_local_[v] >= offsets_local_[v-1]);
    }
  }

  // All adjacency entries must have been written. adj_local_ holds exactly
  // m_local_ entries, so the loop must stop before index m_local_.
  for(v_id e = 0; e < m_local_; ++e) {
    assert(adj_local_[e] >= 0);
  }

  // Soft checking, maybe do some full backward data distribution? TODO
  // Each process sums the global ids of its vertices and of all their
  // neighbors; the loading process recomputes the same sums from the
  // original arrays and compares.
  v_id global_ids_sum = 0;
  for(v_id v = 0; v < n_local_; ++v) {
    v_id v_global_id = getVertexGlobalID(v, rank_);
    global_ids_sum += v_global_id;

    v_id v_deg = offsets_local_[v+1] - offsets_local_[v];
    for(v_id e = 0; e < v_deg; ++e) {
      global_ids_sum += adj_local_[ offsets_local_[v] + e ];
    }
  }

  if(rank_ == loading_rank) {
    v_id* received_sums = new v_id[num_procs_]();
    v_id* computed_sums = new v_id[num_procs_]();

    MPI_Gather(&global_ids_sum, 1, MPI_INT64_T, &received_sums[0], 1, MPI_INT64_T, loading_rank, comm_);

    for(v_id v = 0; v < n_; ++v) {
      v_id v_owner = getVertexOwningRank(v);
      computed_sums[v_owner] += v;

      v_id v_deg = offsets[v+1] - offsets[v];
      for(v_id e = 0; e < v_deg; ++e) {
        computed_sums[v_owner] += adj[ offsets[v] + e ];
      }
    }

    for(int proc = 0; proc < num_procs_; ++proc) {
      assert(received_sums[proc] == computed_sums[proc]);
    }

    delete [] received_sums;
    delete [] computed_sums;
  }
  else {
    // The receive arguments are only significant at the root.
    MPI_Gather(&global_ids_sum, 1, MPI_INT64_T, NULL, 0, MPI_INT64_T, loading_rank, comm_);
  }

  debug_m = 0;
  MPI_Allreduce(&m_local_, &debug_m, 1, MPI_INT64_T, MPI_SUM, comm_);
  assert(debug_m == 2*m_);

  // The local -> global -> owner mapping must lead back to this process.
  for(v_id v = 0; v < n_local_; ++v) {
    v_id v_global = getVertexGlobalID(v, rank_);
    v_id v_owner = getVertexOwningRank(v_global);
    assert(v_owner == rank_);
  }
#endif
}

// Stores the communicator and the global graph size, queries the number of
// processes and the rank of the calling process, and counts how many vertices
// the calling process owns (n_local_). m_local_ is reset to 0.
//
//   comm - communicator the graph is distributed over.
//   n    - global number of vertices (stored in n_).
//   m    - global number of edges (stored in m_).
//
// With __XGRAPH_DEBUGGING__ defined this performs an MPI_Allreduce and
// therefore has to be called by all processes of comm.
void TradArrayDistGraphR::configureParameters(MPI_Comm comm, v_id n, v_id m) {
  comm_ = comm;

  MPI_Comm_size(comm, &num_procs_);
  MPI_Comm_rank(comm, &rank_);

  // Validate the values before anything below depends on them.
  assert(num_procs_ > 0);
  assert(rank_ >= 0);

  n_ = n;
  m_ = m;

#ifdef NUM_OF_PROCS_IS_A_POWER_OF_TWO
  if((num_procs_ & (num_procs_ - 1)) != 0) {
    // The argument is built with stream insertion, so the process count is
    // inserted directly rather than through a printf-style "%d" placeholder.
    exitMsg("The number of processes (" << num_procs_ << ") should be a power of two!\n");
  }

  // Find log2(num_procs_) by searching for the matching power of two.
  for (log_of_num_procs_ = 0; log_of_num_procs_ < num_procs_; ++log_of_num_procs_) {
    if ((1 << log_of_num_procs_) == num_procs_) break;
  }
  assert (log_of_num_procs_ < num_procs_);
#endif

  m_local_ = 0;

  // Count the vertices owned by this process.
  n_local_ = 0;
  for(v_id v = 0; v < n_; ++v) {
    if(rank_ == getVertexOwningRank(v)) {
      ++n_local_;
    }
  }

#ifdef __XGRAPH_DEBUGGING__
  // The local counts of all processes must add up to n ...
  v_id debug_n = 0;
  MPI_Allreduce(&n_local_, &debug_n, 1, MPI_INT64_T, MPI_SUM, comm_);
  assert(debug_n == n);

  // ... and so must the per-owner counts computed locally for every rank.
  v_id* n_locals = new v_id[num_procs_]();
  for(v_id v = 0; v < n_; ++v) {
    ++n_locals[getVertexOwningRank(v)];
  }

  v_id n_sum = 0;
  for(int p = 0; p < num_procs_; ++p) {
    n_sum += n_locals[p];
  }

  assert(n_sum == n);

  delete [] n_locals;
#endif
}

// Releases the local CSR arrays that loadAndDistributeGraphFromMemory()
// allocated with new[].
TradArrayDistGraphR::~TradArrayDistGraphR()
{
  delete [] adj_local_;
  delete [] offsets_local_;
}


#endif
