#include "tradArrayDistG500GraphR.h"

#ifdef __XGRAPH_DIST__

#include "external_generators/g500_rmat/common.h"
#include "external_generators/generator/make_graph.h"
#include "external_generators/generator/utils.h"
//#include "external_generators/generator/utils.c"

/*
   static int compare_doubles(const void* a, const void* b) {
   double aa = *(const double*)a;
   double bb = *(const double*)b;
   return (aa < bb) ? -1 : (aa == bb) ? 0 : 1;
   }

   enum {s_minimum, s_firstquartile, s_median, s_thirdquartile, s_maximum, s_mean, s_std, s_LAST};
   static void get_statistics(const double x[], int n, double r[s_LAST]) {
   double temp;
   int i;
// Compute mean. 
temp = 0;
for (i = 0; i < n; ++i) temp += x[i];
temp /= n;
r[s_mean] = temp;
// Compute std. dev. 
temp = 0;
for (i = 0; i < n; ++i) temp += (x[i] - r[s_mean]) * (x[i] - r[s_mean]);
temp /= n - 1;
r[s_std] = sqrt(temp);
// Sort x.
double* xx = (double*)xmalloc(n * sizeof(double));
memcpy(xx, x, n * sizeof(double));
qsort(xx, n, sizeof(double), compare_doubles);
// Get order statistics. 
r[s_minimum] = xx[0];
r[s_firstquartile] = (xx[(n - 1) / 4] + xx[n / 4]) * .5;
r[s_median] = (xx[(n - 1) / 2] + xx[n / 2]) * .5;
r[s_thirdquartile] = (xx[n - 1 - (n - 1) / 4] + xx[n - 1 - n / 4]) * .5;
r[s_maximum] = xx[n - 1];

free(xx);
}
 */

//int _g500_rank;
//int _g500_size;

// File-static Graph500 one-dimensional CSR graph. It is filled by
// make_graph_data_structure(), released by free_graph_data_structure(), and
// read by TradArrayDistG500GraphR::loadAndDistributeGraphFromMemory().
// Being a single static object, it is shared by every graph instance created
// in this translation unit.
static oned_csr_graph g;

void make_graph_data_structure(const tuple_graph* const tg) {
  convert_graph_to_oned_csr(tg, &g);
}

void free_graph_data_structure(void) {
  free_oned_csr_graph(&g);
}

//TradArrayDistG500GraphR::TradArrayDistG500GraphR(v_id n, v_id m, v_id* offsets, v_id* adj, int loading_rank, MPI_Comm comm) : GraphR(n,m)
/**
 * Builds a distributed graph from a Graph500 Kronecker edge list.
 *
 * The graph has 2^SCALE vertices and edgefactor * 2^SCALE generated edges.
 * Edges are produced in chunks of FILE_CHUNKSIZE by generate_kronecker_range()
 * from the fixed seed pair (2, 3), converted into the file-static CSR graph by
 * make_graph_data_structure(), and then copied into this object's local
 * arrays by loadAndDistributeGraphFromMemory().
 *
 * If the environment variable TMPFILE is set, the raw edge list is staged in
 * that file through MPI-IO; otherwise it is kept in memory. Either backing
 * store is released before the constructor returns.
 *
 * @param SCALE        log2 of the global number of vertices.
 * @param edgefactor   number of generated edges per vertex.
 * @param loading_rank forwarded to loadAndDistributeGraphFromMemory().
 * @param comm         communicator handed to configureParameters().
 *                     NOTE(review): the generator below works on
 *                     MPI_COMM_WORLD; presumably comm is expected to be
 *                     MPI_COMM_WORLD as well — confirm against callers.
 */
TradArrayDistG500GraphR::TradArrayDistG500GraphR(v_id SCALE, v_id edgefactor, int loading_rank, MPI_Comm comm) : GraphR( (int64_t)(1) << SCALE, (int64_t)(edgefactor) << SCALE ) {
  v_id n = (int64_t)(1) << SCALE;
  v_id m = (int64_t)(edgefactor) << SCALE;

  // NOTE(review): rank_ is read here before configureParameters() assigns it
  // below; this is only well-defined if the GraphR base constructor already
  // initialised rank_ — confirm.
  if(rank_ == 0) { cout << "X..." << endl;} 

  configureParameters(comm, n, m);
  if(rank_ == 0) { cout << "Y..." << endl;} 
  setup_globals(rank_, num_procs_);

  if(rank_ == 0) { cout << "Start to create the RMAT graph..." << endl;} 

  uint64_t seed1 = 2, seed2 = 3;

  const char* filename = getenv("TMPFILE");
  /* If filename is NULL, store data in memory */

  tuple_graph tg;
  tg.nglobaledges = (int64_t)(edgefactor) << SCALE;
  int64_t nglobalverts = (int64_t)(1) << SCALE;

  tg.data_in_file = (filename != NULL);

  if (tg.data_in_file) {
    /* Make MPI-IO errors fatal by default, then create the staging file. It is
     * opened with MPI_MODE_DELETE_ON_CLOSE, so it disappears when it is closed
     * at the end of this constructor. */
    MPI_File_set_errhandler(MPI_FILE_NULL, MPI_ERRORS_ARE_FATAL);
    MPI_File_open(MPI_COMM_WORLD, (char*)filename, MPI_MODE_RDWR | MPI_MODE_CREATE | MPI_MODE_EXCL | MPI_MODE_DELETE_ON_CLOSE | MPI_MODE_UNIQUE_OPEN, MPI_INFO_NULL, &tg.edgefile);
    MPI_File_set_size(tg.edgefile, tg.nglobaledges * sizeof(packed_edge));
    MPI_File_set_view(tg.edgefile, 0, packed_edge_mpi_type, packed_edge_mpi_type, "native", MPI_INFO_NULL);
    MPI_File_set_atomicity(tg.edgefile, 0);
  }

  /* Make the raw graph edges. */
  /* Get roots for BFS runs, plus maximum vertex with non-zero degree (used by
   * validator). */
  /*int num_bfs_roots = 64;
    int64_t* bfs_roots = (int64_t*)xmalloc(num_bfs_roots * sizeof(int64_t));*/
  //int64_t max_used_vertex = 0;

  double make_graph_start = MPI_Wtime();
  {
    /* Spread the two 64-bit numbers into five nonzero values in the correct
     * range. */
    uint_fast32_t seed[5];
    make_mrg_seed(seed1, seed2, seed);

    /* As the graph is being generated, also keep a bitmap of vertices with
     * incident edges.  We keep a grid of processes, each row of which has a
     * separate copy of the bitmap (distributed among the processes in the
     * row), and then do an allreduce at the end.  This scheme is used to avoid
     * non-local communication and reading the file separately just to find BFS
     * roots. */
    /* NOTE: in this constructor the bitmap (has_edge) is filled but never
     * consumed: the allreduce is compiled out (#if 0) and the buffer is freed
     * right after generation. */
    MPI_Offset nchunks_in_file = (tg.nglobaledges + FILE_CHUNKSIZE - 1) / FILE_CHUNKSIZE;
    int64_t bitmap_size_in_bytes = int64_min(BITMAPSIZE, (nglobalverts + CHAR_BIT - 1) / CHAR_BIT);
    /* Grow the per-rank bitmap slice if all ranks together could not cover
     * every vertex with the capped size. */
    if (bitmap_size_in_bytes * _g500_size * CHAR_BIT < nglobalverts) {
      bitmap_size_in_bytes = (nglobalverts + _g500_size * CHAR_BIT - 1) / (_g500_size * CHAR_BIT);
    }
    int ranks_per_row = ((nglobalverts + CHAR_BIT - 1) / CHAR_BIT + bitmap_size_in_bytes - 1) / bitmap_size_in_bytes;
    int nrows = _g500_size / ranks_per_row;
    int my_row = -1, my_col = -1;
    unsigned char* __restrict__ has_edge = NULL;
    MPI_Comm cart_comm;
    {
      /* nrows x ranks_per_row process grid; ranks that do not fit into the
       * grid receive MPI_COMM_NULL and skip generation. */
      int dims[2] = {_g500_size / ranks_per_row, ranks_per_row};
      int periods[2] = {0, 0};
      MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 1, &cart_comm);
    }
    int in_generating_rectangle = 0;
    if (cart_comm != MPI_COMM_NULL) {
      in_generating_rectangle = 1;
      {
        int dims[2], periods[2], coords[2];
        MPI_Cart_get(cart_comm, 2, dims, periods, coords);
        my_row = coords[0];
        my_col = coords[1];
      }
      MPI_Comm this_col;
      MPI_Comm_split(cart_comm, my_col, my_row, &this_col);
      MPI_Comm_free(&cart_comm);
      has_edge = (unsigned char*)xMPI_Alloc_mem(bitmap_size_in_bytes);
      memset(has_edge, 0, bitmap_size_in_bytes);
      /* Every rank in a given row creates the same vertices (for updating the
       * bitmap); only one writes them to the file (or final memory buffer). */
      packed_edge* buf = (packed_edge*)xmalloc(FILE_CHUNKSIZE * sizeof(packed_edge));
      MPI_Offset block_limit = (nchunks_in_file + nrows - 1) / nrows;
      /* fprintf(stderr, "%d: nchunks_in_file = %" PRId64 ", block_limit = %" PRId64 " in grid of %d rows, %d cols\n", rank, (int64_t)nchunks_in_file, (int64_t)block_limit, nrows, ranks_per_row); */
      if (tg.data_in_file) {
        tg.edgememory_size = 0;
        tg.edgememory = NULL;
      } else {
        /* Chunks are dealt round-robin over the grid positions; compute how
         * many edges end up on this rank, including the possibly partial
         * last chunk. */
        int my_pos = my_row + my_col * nrows;
        int last_pos = (tg.nglobaledges % ((int64_t)FILE_CHUNKSIZE * nrows * ranks_per_row) != 0) ?
          (tg.nglobaledges / FILE_CHUNKSIZE) % (nrows * ranks_per_row) :
          -1;
        int64_t edges_left = tg.nglobaledges % FILE_CHUNKSIZE;
        int64_t nedges = FILE_CHUNKSIZE * (tg.nglobaledges / ((int64_t)FILE_CHUNKSIZE * nrows * ranks_per_row)) +
          FILE_CHUNKSIZE * (my_pos < (tg.nglobaledges / FILE_CHUNKSIZE) % (nrows * ranks_per_row)) +
          (my_pos == last_pos ? edges_left : 0);
        /* fprintf(stderr, "%d: nedges = %" PRId64 " of %" PRId64 "\n", rank, (int64_t)nedges, (int64_t)tg.nglobaledges); */
        tg.edgememory_size = nedges;
        tg.edgememory = (packed_edge*)xmalloc(nedges * sizeof(packed_edge));
      }
      MPI_Offset block_idx;

      for (block_idx = 0; block_idx < block_limit; ++block_idx) {
        /* fprintf(stderr, "%d: On block %d of %d\n", rank, (int)block_idx, (int)block_limit); */
        MPI_Offset start_edge_index = int64_min(FILE_CHUNKSIZE * (block_idx * nrows + my_row), tg.nglobaledges);
        MPI_Offset edge_count = int64_min(tg.nglobaledges - start_edge_index, FILE_CHUNKSIZE);
        /* Generate straight into the final in-memory buffer when this rank is
         * the one keeping the chunk; otherwise use the scratch buffer. */
        packed_edge* actual_buf = (!tg.data_in_file && block_idx % ranks_per_row == my_col) ?
          tg.edgememory + FILE_CHUNKSIZE * (block_idx / ranks_per_row) :
          buf;
        /* fprintf(stderr, "%d: My range is [%" PRId64 ", %" PRId64 ") %swriting into index %" PRId64 "\n", rank, (int64_t)start_edge_index, (int64_t)(start_edge_index + edge_count), (my_col == (block_idx % ranks_per_row)) ? "" : "not ", (int64_t)(FILE_CHUNKSIZE * (block_idx / ranks_per_row))); */
        if (!tg.data_in_file && block_idx % ranks_per_row == my_col) {
          assert (FILE_CHUNKSIZE * (block_idx / ranks_per_row) + edge_count <= tg.edgememory_size);
        }
        generate_kronecker_range(seed, SCALE, start_edge_index, start_edge_index + edge_count, actual_buf);
        if (tg.data_in_file && my_col == (block_idx % ranks_per_row)) { /* Try to spread writes among ranks */
          MPI_File_write_at(tg.edgefile, start_edge_index, actual_buf, edge_count, packed_edge_mpi_type, MPI_STATUS_IGNORE);
        }
        ptrdiff_t i;
#ifdef _OPENMP
#pragma omp parallel for
#endif
        for (i = 0; i < edge_count; ++i) {
          int64_t src = get_v0_from_edge(&actual_buf[i]);
          int64_t tgt = get_v1_from_edge(&actual_buf[i]);
          if (src == tgt) continue; /* self-loops do not mark a vertex */
          if (src / bitmap_size_in_bytes / CHAR_BIT == my_col) {
#ifdef _OPENMP
#pragma omp atomic
#endif
            has_edge[(src / CHAR_BIT) % bitmap_size_in_bytes] |= (1 << (src % CHAR_BIT));
          }
          if (tgt / bitmap_size_in_bytes / CHAR_BIT == my_col) {
#ifdef _OPENMP
#pragma omp atomic
#endif
            has_edge[(tgt / CHAR_BIT) % bitmap_size_in_bytes] |= (1 << (tgt % CHAR_BIT));
          }
        }
      }
      free(buf);
#if 0
      /* The allreduce for each root acts like we did this: */
      MPI_Allreduce(MPI_IN_PLACE, has_edge, bitmap_size_in_bytes, MPI_UNSIGNED_CHAR, MPI_BOR, this_col);
#endif
      MPI_Comm_free(&this_col);
    } else {
      tg.edgememory = NULL;
      tg.edgememory_size = 0;
    }
    MPI_Allreduce(&tg.edgememory_size, &tg.max_edgememory_size, 1, MPI_INT64_T, MPI_MAX, MPI_COMM_WORLD);
    if (in_generating_rectangle) {
      MPI_Free_mem(has_edge);
    }
    if (tg.data_in_file) {
      MPI_File_sync(tg.edgefile);
    }
  } 

  double make_graph_stop = MPI_Wtime();
  double make_graph_time = make_graph_stop - make_graph_start;
  if (_g500_rank == 0) { /* Not an official part of the results */
    fprintf(stderr, "graph_generation:               %f s\n", make_graph_time);
  }

  /* Make user's graph data structure. */
  double data_struct_start = MPI_Wtime();
  make_graph_data_structure(&tg);
  double data_struct_stop = MPI_Wtime();
  double data_struct_time = data_struct_stop - data_struct_start;
  if (_g500_rank == 0) { /* Not an official part of the results */
    fprintf(stderr, "construction_time:              %f s\n", data_struct_time);
  }

  //loadAndDistributeGraphFromMemory(offsets, adj, loading_rank);
  loadAndDistributeGraphFromMemory(loading_rank);

  /* The raw edge list is not used past this point, so release whichever
   * backing store held it. Closing the file matters: without it the MPI file
   * handle leaks and the MPI_MODE_DELETE_ON_CLOSE staging file is never
   * removed. In file mode tg.edgememory is NULL, so there is nothing to free. */
  if (tg.data_in_file) {
    MPI_File_close(&tg.edgefile);
  } else {
    free(tg.edgememory); tg.edgememory = NULL;
  }


}

//void TradArrayDistG500GraphR::loadAndDistributeGraphFromMemory(v_id* offsets, v_id* adj, int loading_rank)
void TradArrayDistG500GraphR::loadAndDistributeGraphFromMemory(int loading_rank) {
  // TODO: now this thing is really inefficient, but I also plan to use it for small
  // graphs; this code enables reusing the ready functions in graphIO that load graphs.

  // Distribute the degrees so that each proc can allocate enough memory:
  m_local_ = g.nlocaledges;
  n_local_ = g.nlocalverts;
  assert(n_ == g.nlocalverts * num_procs_);

#ifdef __XGRAPH_DEBUGGING__
  v_id debug_m = 0;
  MPI_Allreduce(&m_local_, &debug_m, 1, MPI_INT64_T, MPI_SUM, comm_);
  //if(rank_ == 0) cout << "debug_m: " << debug_m << ", m_local_: " << m_local_ << ", "
  cout << rank_  << "] " << "n_local_: " << n_local_ << ", debug_m: " << debug_m << ", m_local_: " << m_local_ << ", 2*m_: " << 2*m_ << endl; 




  //assert(debug_m == 2*m_);
#endif

  adj_local_ = new v_id[m_local_]();
  offsets_local_ = new v_id[n_local_+1]();

#ifdef __XGRAPH_DEBUGGING__
  for(v_id v = 0; v < n_local_ + 1; ++v) {
    offsets_local_[v] = -1;
  }

  for(v_id e = 0; e < m_local_; ++e) {
    adj_local_[e] = -1;
  }
#endif

  offsets_local_[0] = 0; 
  v_id processed_local_vertex = 0;
  v_id current_offset = 0;

  cout << rank_ << "] " << "g.nlocalverts: " << g.nlocalverts << ", n_local_: " << n_local_ << endl;
  assert(g.nlocalverts == n_local_); 

  //g.rowstarts[VERTEX_LOCAL(v)] ... g.rowstarts[VERTEX_LOCAL(v) + 1] - 1

  v_id offset = 0;
  for(v_id v = 0; v < n_local_; ++v) {
    v_id v_global = VERTEX_TO_GLOBAL(rank_, v);
    v_id v_local_g500 = VERTEX_LOCAL(v_global);

    assert(v == v_local_g500);

    v_id owner = VERTEX_OWNER(v_global);
    assert(rank_ == owner);
    assert(_g500_rank == owner);

    offsets_local_[v] = offset;
    //assert(offset == g.rowstarts[VERTEX_LOCAL(v)]);
    assert(offset == g.rowstarts[v]);

    v_id v_deg = g.rowstarts[v + 1] - g.rowstarts[v];
    //v_id v_deg = (g.rowstarts[VERTEX_LOCAL(v) + 1] - 1) - g.rowstarts[VERTEX_LOCAL(v)];

    //cout << rank_ << "] " << "v: " << v << ", rows + 1: " << g.rowstarts[v+1] << ", rows: " << g.rowstarts[v] << endl;

    assert(v_deg >= 0);

    for(v_id x = 0; x < v_deg; ++x) {
      v_id next_neighbor = g.column[ offset + x ];
      adj_local_[offset + x] = next_neighbor;
    }
    offset += v_deg;
  }
  offsets_local_[n_local_] = offset;

  MPI_Barrier(comm_);

#ifdef __XGRAPH_DEBUGGING__
  for(v_id v = 0; v < n_local_ + 1; ++v) {
    assert(offsets_local_[v] >= 0);
    assert(offsets_local_[v] <= m_local_);
    if(v > 0) {
      assert(offsets_local_[v] >= offsets_local_[v-1]);
    }
  }

  for(v_id e = 0; e < m_local_ + 1; ++e) {
    assert(adj_local_[e] >= 0);
  }

  //TODO: more debugging and checks!

#endif
  debug_m = 0;
  MPI_Allreduce(&m_local_, &debug_m, 1, MPI_INT64_T, MPI_SUM, comm_);
  //assert(debug_m == 2*m_);

  // TODO: make sure about that.
  m_ = debug_m / 2; 


  for(v_id v = 0; v < n_local_; ++v) {
    v_id v_global = getVertexGlobalID(rank_, v);
    v_id v_owner = getVertexOwningRank(v_global);
    assert(v_owner == rank_);
  }
}

/**
 * Records the communicator and global sizes, queries the process count and
 * this process's rank, and counts how many vertices this rank owns.
 *
 * @param comm communicator stored in comm_ and used for size/rank queries.
 * @param n    global vertex count, stored in n_.
 * @param m    global edge count, stored in m_.
 *
 * After the call n_local_ is the number of v in [0, n_) for which
 * getVertexOwningRank(v) equals rank_, and m_local_ is 0.
 */
void TradArrayDistG500GraphR::configureParameters(MPI_Comm comm, v_id n, v_id m) {
  comm_ = comm;
  MPI_Comm_size(comm, &num_procs_);
  MPI_Comm_rank(comm, &rank_);

  n_ = n;
  m_ = m;

#ifdef NUM_OF_PROCS_IS_A_POWER_OF_TWO 
  // A power of two has exactly one bit set, so clearing the lowest set bit
  // must leave zero.
  const bool is_power_of_two = ((num_procs_ & (num_procs_ - 1)) == 0);
  if(!is_power_of_two) {
    exitMsg("The number of processes (%d) should be a power of two!\n" << num_procs_);
  }

  // Search for the exponent e with 2^e == num_procs_.
  log_of_num_procs_ = 0;
  while (log_of_num_procs_ < num_procs_ && (1 << log_of_num_procs_) != num_procs_) {
    ++log_of_num_procs_;
  }
  assert (log_of_num_procs_ < num_procs_);
#endif

  assert(num_procs_ > 0);
  assert(rank_ >= 0);

  m_local_ = 0;
  n_local_ = 0;

  // Count the vertices whose owning rank is this process.
  for(v_id vertex = 0; vertex < n_; ++vertex) {
    if(getVertexOwningRank(vertex) == rank_) {
      n_local_ += 1;
    }
  }

#ifdef __XGRAPH_DEBUGGING__
  // Cross-check 1: the per-rank counts must add up to n across comm_.
  v_id global_vertex_total = 0;
  MPI_Allreduce(&n_local_, &global_vertex_total, 1, MPI_INT64_T, MPI_SUM, comm_);
  assert(global_vertex_total == n);

  // Cross-check 2: recompute every rank's count locally and sum them up.
  v_id* per_rank_counts = new v_id[num_procs_]();
  for(v_id vertex = 0; vertex < n_; ++vertex) {
    per_rank_counts[getVertexOwningRank(vertex)] += 1;
  }

  v_id recomputed_total = 0;
  for(int proc = 0; proc < num_procs_; ++proc) {
    recomputed_total += per_rank_counts[proc];
  }

  assert(recomputed_total == n);

  delete [] per_rank_counts;
#endif
}

/**
 * Releases the file-static Graph500 CSR graph, calls cleanup_globals(), and
 * frees the local CSR arrays owned by this object.
 */
TradArrayDistG500GraphR::~TradArrayDistG500GraphR() {
  // Same call that free_graph_data_structure() makes on the static graph.
  free_oned_csr_graph(&g);
  cleanup_globals();

  // The two arrays are independent allocations; order does not matter.
  delete [] adj_local_;
  delete [] offsets_local_;
}


#endif
