#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <float.h>
#include <math.h>
#include <assert.h>
#include <mpi.h>
#include "smut.h"
#include "stripesmut.h"
#include "timer.h"
#include "auction.h"
#include "stripe_auction.h"
#if !defined(HAVE_RESTRICT)
#define restrict
#endif
/* Process-wide MPI context, filled in by main(): the communicator in use,
   this process's rank, the communicator size, and whether this rank is the
   root (rank 0).  They have external linkage, so other translation units
   may refer to them (NOTE(review): confirm before making them static). */
MPI_Comm comm;
int rank, csize, am_root;
#if !defined(HAVE_MPI_ALLOC_MEM)
/* Fallback for MPI libraries without MPI-2 memory allocation: both calls
   degrade to the C heap. */

/* Allocate `size` bytes and store the address through `baseptr`, which must
   point to a pointer object.  The `info` hint is ignored.
   Returns MPI_SUCCESS, or MPI_ERR_NO_MEM when malloc fails. */
static int
MPI_Alloc_mem (MPI_Aint size, MPI_Info info, void *baseptr)
{
  void *mem = malloc (size);
  (void) info;
  *(void **) baseptr = mem;
  return (mem != NULL) ? MPI_SUCCESS : MPI_ERR_NO_MEM;
}

/* Release memory obtained from the fallback MPI_Alloc_mem above. */
static int
MPI_Free_mem (void *baseptr)
{
  free (baseptr);
  return MPI_SUCCESS;
}

#if !defined(MPI_INFO_NULL)
#define MPI_INFO_NULL 0
#endif
#endif
/* One per-column entry of the change-exchange buffer.  The field order
   (double, then int) matches MPI_DOUBLE_INT, so arrays of these can be
   combined with MPI_MAXLOC (see comm_give_changes): `price` is the value
   compared and `i` is carried along with the maximum. */
struct cbuf_entry {
double price; /* column price */
int i;        /* global row index matched to the column (gnr when the
                 sender's column is unmatched), or a sentinel value; see
                 init_comm_state / comm_give_changes for the sentinels */
};
/* State for exchanging per-column (price, row) changes among the ranks of a
   communicator.  Set up by init_comm_state, released by free_comm_state. */
struct comm_state_t {
int csize, rank;  /* communicator size and this process's rank */
int phase;        /* reset to 0 by comm_give_changes, incremented by
                     comm_get_changes */
int gnc;          /* global number of columns = length of both buffers */
struct cbuf_entry *fullbuf;      /* result of the MAXLOC all-reduce, read
                                    per global column by the merge step */
struct cbuf_entry *fullbuf_send; /* this rank's contribution to that
                                    reduction */
int g_n_unmatched; /* NOTE(review): not referenced in the visible code of
                      this file; presumably a global unmatched count */
int *g2l;         /* global -> local column map (-1 for columns not held
                     locally), or NULL when the matrix has no column map */
};
/* Hand back the result of the most recent comm_give_changes exchange.
 *
 * Each call advances state->phase.  While the new phase is below the
 * communicator size, the whole reduced buffer (state->gnc entries, indexed
 * by global column) is returned through *clist / *n_changes_out and the
 * function returns 1.  Once the phase reaches csize, an empty list (NULL, 0)
 * is returned together with 0.  With a single rank the first call already
 * returns 0.  `comm` is not used; no communication happens here.
 */
int
comm_get_changes (struct comm_state_t *state, MPI_Comm comm,
                  int *n_changes_out, const struct cbuf_entry **clist)
{
  const int phase = ++state->phase;

  (void) comm;

  if (phase >= state->csize) {
    /* No more rounds for this exchange. */
    *n_changes_out = 0;
    *clist = NULL;
    return 0;
  }

  *n_changes_out = state->gnc;
  *clist = state->fullbuf;
  return 1;
}
/* Release every buffer owned by `state` (allocated by init_comm_state) and
 * null the pointers.  MPI_Free_mem is only called on non-NULL pointers
 * because the MPI standard does not define MPI_Free_mem(NULL); this also
 * makes a second call on the same state harmless.  `comm` is unused.
 */
void
free_comm_state (struct comm_state_t *state, MPI_Comm comm)
{
  (void) comm;

  free (state->g2l);            /* free(NULL) is a no-op */
  state->g2l = NULL;

  if (state->fullbuf)
    MPI_Free_mem (state->fullbuf);
  state->fullbuf = NULL;

  if (state->fullbuf_send)
    MPI_Free_mem (state->fullbuf_send);
  state->fullbuf_send = NULL;
}
/* Allocate and initialise the exchange state for distributed matrix gA on
 * communicator comm.
 *
 *   fullbuf      : gnc entries, price 0.0 and i = -1.
 *   fullbuf_send : gnc entries preset to a rank-specific sentinel
 *                  (price -1000*(rank+1), i = 1000*gnr + rank+1) so columns
 *                  a rank never touched carry a very low price into the
 *                  MAXLOC reduction (presumably real prices are larger;
 *                  NOTE(review): confirm).
 *   g2l          : inverse of gA->colmap.l2g with -1 for columns not held
 *                  locally, or NULL when gA has no column map.
 *
 * Any allocation failure (for a non-empty column set) aborts the job with
 * MPI_Abort instead of continuing with NULL buffers.
 * NOTE(review): 1000*gnr overflows int once gnr exceeds INT_MAX/1000.
 */
void
init_comm_state (struct comm_state_t *state, const struct gcsr_t *gA,
                 MPI_Comm comm)
{
  int rank, csize;
  const int gnr = gA->nr, gnc = gA->nc;
  int j;
  int err_recv, err_send;
  int *g2l = NULL;

  MPI_Comm_size (comm, &csize); state->csize = csize;
  MPI_Comm_rank (comm, &rank); state->rank = rank;
  state->gnc = gnc;
  state->phase = 0;

  err_recv = MPI_Alloc_mem (gnc * sizeof(struct cbuf_entry), MPI_INFO_NULL,
                            &state->fullbuf);
  err_send = MPI_Alloc_mem (gnc * sizeof(struct cbuf_entry), MPI_INFO_NULL,
                            &state->fullbuf_send);
  if (gnc > 0 && (err_recv != MPI_SUCCESS || err_send != MPI_SUCCESS
                  || !state->fullbuf || !state->fullbuf_send)) {
    fprintf (stderr, "%d: init_comm_state: cannot allocate exchange buffers\n",
             rank);
    MPI_Abort (comm, -1);
  }

  for (j = 0; j < gnc; ++j) {
    state->fullbuf[j].price = 0.0;
    state->fullbuf[j].i = -1;
    state->fullbuf_send[j].price = -1000.0*(rank+1);
    state->fullbuf_send[j].i = 1000*gnr + (rank+1);
  }

  if (gA->colmap.l2g) {
    const int lnc = gA->local.nc;
    const int * restrict l2g = gA->colmap.l2g;
    g2l = malloc (gnc * sizeof(int));
    if (!g2l && gnc > 0) {
      fprintf (stderr, "%d: init_comm_state: cannot allocate g2l\n", rank);
      MPI_Abort (comm, -1);
    }
    /* Columns not present locally map to -1. */
    for (j = 0; j < gnc; ++j) g2l[j] = -1;
    for (j = 0; j < lnc; ++j) g2l[l2g[j]] = j;
  }
  state->g2l = g2l;
}
/* Sum (n_unmatched + n_changes) over every rank of comm and return the
   global total on all ranks.  `state` is not used. */
int
comm_g_n_unmatched (struct comm_state_t *state, int n_unmatched, int n_changes,
                    MPI_Comm comm)
{
  int mine = n_unmatched + n_changes;
  int total;

  (void) state;
  MPI_Allreduce (&mine, &total, 1, MPI_INT, MPI_SUM, comm);
  return total;
}
/* Publish this rank's column changes and combine them across comm.
 *
 * For every changed local column j (from changed_col_list), the rank writes
 * (lprice[j], global row of its match, or gnr when unmatched) into the slot
 * of state->fullbuf_send for the column's GLOBAL index gj = l2g[j] (or j when
 * there is no column map).  The global index is used because the reduced
 * buffer is consumed per global column (see merge_cbuf_into_vars).
 * An MPI_MAXLOC all-reduce then leaves in state->fullbuf, for every global
 * column, the highest price and its row (MAXLOC keeps the smaller row on
 * ties).
 *
 * If `infeasible` is set, every entry is sent as (-9999.0, -1) instead and
 * the send buffer is left in that state.
 * Otherwise the slots written by this rank are reset afterwards to price
 * -10.0, i = 1000*gnr + rank.  Presumably this keeps them from winning the
 * next reduction.  NOTE(review): init_comm_state uses rank+1 in the same
 * sentinel; confirm whether the difference is intended.
 *
 * state->phase is reset to 0 for comm_get_changes.  col_changed_flag_in is
 * not used.
 */
void
comm_give_changes (struct comm_state_t *state, const struct gcsr_t *gA,
                   const int infeasible,
                   const int n_changes,
                   const int *changed_col_list_in,
                   const int *col_changed_flag_in,
                   const int *lmatch_in, const double *lprice_in,
                   MPI_Comm comm)
{
  struct cbuf_entry * restrict fullbuf = state->fullbuf;
  struct cbuf_entry * restrict fullbuf_send = state->fullbuf_send;
  const int * restrict changed_col_list = changed_col_list_in;
  const int * restrict lmatch = lmatch_in;
  const double * restrict lprice = lprice_in;
  const int gnr = gA->nr;
  const int gnc = gA->nc;
  const int * restrict l2g = gA->colmap.l2g;
  const int lindstart = gA->rowmap.lindstart;
  int k;

  (void) col_changed_flag_in;

  if (infeasible) {
    int gj;
    for (gj = 0; gj < gnc; ++gj) {
      fullbuf_send[gj].price = -9999.0;
      fullbuf_send[gj].i = -1;
    }
  }
  else {
    for (k = 0; k < n_changes; ++k) {
      const int j = changed_col_list[k];     /* local column */
      const int gj = (l2g ? l2g[j] : j);     /* global column */
      fullbuf_send[gj].i = (lmatch[j] >= 0 ? lmatch[j] + lindstart : gnr);
      fullbuf_send[gj].price = lprice[j];
    }
  }

  state->phase = 0;
  MPI_Allreduce (fullbuf_send, fullbuf, gnc, MPI_DOUBLE_INT, MPI_MAXLOC, comm);

  if (!infeasible) {
    /* Retire this round's entries so they do not win the next reduction. */
    for (k = 0; k < n_changes; ++k) {
      const int j = changed_col_list[k];
      const int gj = (l2g ? l2g[j] : j);
      fullbuf_send[gj].i = 1000*gnr + rank;
      fullbuf_send[gj].price = -10.0;
    }
  }
}
/* Merge a reduced per-column buffer (indexed by global column, as left in
 * state->fullbuf by comm_give_changes) into this rank's lmatch / lprice.
 *
 * For every global column j held locally (lj = g2l[j], or j when g2l is
 * NULL; nothing happens when the rank holds no columns):
 *   - a negative row in the buffer means some rank reported infeasibility:
 *     *n_unmatched_out is set to -2*gnr and the function returns at once;
 *   - the buffer row is translated to a local row index, and rows outside
 *     [0, lnind) (other ranks' rows, or the "unmatched" marker) become -1;
 *   - a strictly higher price is adopted, and if the column was matched
 *     locally that row is appended to unmatched_list and the match cleared;
 *   - on an equal price, if the price is HUGE_VAL and the buffer row differs
 *     from the local match, -2*gnr is reported (the loop stops).  Otherwise,
 *     if the buffer row is not local and the column is matched locally, the
 *     local match is given up as above.
 * On normal completion *n_unmatched_out holds the new length of
 * unmatched_list.  n_changes is unused.
 *
 * NOTE(review): unmatched_list must have room for every row that can be
 * unmatched here; its capacity is not checked.
 */
void
merge_cbuf_into_vars (const int n_changes,
                      const struct cbuf_entry *cbuf_in,
                      const int *g2l_in,
                      const struct gcsr_t* gA,
                      int *lmatch_in, double *lprice_in,
                      int *n_unmatched_out, int *unmatched_list)
{
  const struct cbuf_entry * restrict pricebuf = cbuf_in;
  int n_unmatched = *n_unmatched_out;
  int * restrict unmatched = unmatched_list;
  int * restrict lmatch = lmatch_in;
  double * restrict lprice = lprice_in;
  const int gnr = gA->nr;
  const int gnc = gA->nc;
  const int lnc = gA->local.nc;
  const int * restrict g2l = g2l_in;
  const int lindstart = gA->rowmap.lindstart;
  const int lnind = gA->rowmap.lnind;
  int j;

  (void) n_changes;

  /* No local columns: nothing can change and *n_unmatched_out stays put. */
  if (!lnc) return;

  for (j = 0; j < gnc; ++j) {
    const int lj = (g2l ? g2l[j] : j);
    int matchi;

    if (lj < 0) continue;                /* column not held on this rank */

    matchi = pricebuf[j].i;
    if (matchi < 0) {
      /* Infeasibility flag from some rank. */
      *n_unmatched_out = -2*gnr;
      return;
    }
    matchi -= lindstart;                 /* global -> local row */
    if (matchi >= lnind || matchi < 0)   /* rows not owned here -> -1 */
      matchi = -1;

    if (pricebuf[j].price > lprice[lj]) {
      /* Outbid: the winning row is never one of ours. */
      assert(matchi < 0 || matchi >= lnind);
      lprice[lj] = pricebuf[j].price;
      if (lmatch[lj] >= 0) {
        unmatched[n_unmatched] = lmatch[lj];
        lmatch[lj] = -1;
        ++n_unmatched;
      }
    }
    else if (pricebuf[j].price == lprice[lj]) {
      assert (matchi == lmatch[lj] || matchi < 0);
      if (lprice[lj] == HUGE_VAL && matchi != lmatch[lj]) {
        /* An infinite price is being contested: report failure. */
        n_unmatched = -2*gnr;
        break;
      }
      if (matchi < 0 && lmatch[lj] != -1) {
        /* Tie resolved in favour of a non-local row: drop our match. */
        unmatched[n_unmatched] = lmatch[lj];
        lmatch[lj] = -1;
        ++n_unmatched;
      }
    }
#if !defined(NDEBUG)
    else {
      assert(lmatch[lj] < 0 || matchi != lmatch[lj]);
      assert(matchi < 0);
    }
#endif
  }
  *n_unmatched_out = n_unmatched;
}
/* Accessor for the global-to-local column map stored in `state`
   (NULL when the matrix was set up without a column map). */
int*
comm_state_g2l (struct comm_state_t* state)
{
  int *map = state->g2l;
  return map;
}
/* Driver: load a sparse matrix on the root, distribute it by row stripes,
 * run the distributed auction (stripe_auction_scaling), then verify and
 * report the result.
 *
 * Usage: prog MATRIX_FILE [MU_MIN [ANY]]
 *   MATRIX_FILE  matrix read with spcsr_load_binfile on the root
 *   MU_MIN       value for mu_min (default 1 / global row count)
 *   ANY          any third argument sets relgap = 1
 * Environment:
 *   PRESCALE     use spcsr_lascaling + spcsr_apply_scaling instead of
 *                spcsr_lascale
 *   TOINT        use auction_toexpint instead of auction_toexp on the root
 *
 * Failures on the root before the matrix is distributed abort the whole job
 * with MPI_Abort.  A plain return would leave the other ranks blocked in the
 * first collective.
 */
int
main (int argc, char **argv)
{
  FILE *f;
  int err;
  double mu_min;
  int relgap = 0;
  struct spcsr_t A, Acopy;
  struct gcsr_t gA;
  int lnr, lnc, lnent;
  int gnent = 0;
  double *R = NULL, *C = NULL;
  double *expint = NULL, *lexpint = NULL, *expint_check = NULL;
  int *rdisps = NULL, *rcounts = NULL;
  int *match = NULL;
  double *price = NULL;
  int *lmatch = NULL;
  double *lprice = NULL;
  int errflg;
  double primal, dual;
  struct Timer timer;
  struct comm_state_t comm_state;

  MPI_Init (&argc, &argv);
  comm = MPI_COMM_WORLD;
  MPI_Comm_size (comm, &csize);
  MPI_Comm_rank (comm, &rank);
  am_root = (0 == rank);

  assert(sizeof(int) == 4);

  spcsr_init_clear (&A);
  if (am_root) {
    if (argc <= 1) {
      printf("No files\n");
      fflush (stdout);
      MPI_Abort (comm, -1);
      return -1;
    }
    f = fopen(argv[1], "rb");
    if (!f) {
      perror("Error opening file: ");
      MPI_Abort (comm, -1);
      return -1;
    }
    err = spcsr_load_binfile (f, &A);
    fclose(f);                          /* close on both success and failure */
    if (err) {
      printf ("Error reading: %d\n", err);
      fflush (stdout);
      MPI_Abort (comm, -1);
      return -1;
    }
    R = malloc (A.nr * sizeof(double));
    C = malloc (A.nc * sizeof(double));
    gnent = A.nent;
    printf ("nr = %d nc = %d nent = %d\n", A.nr, A.nc, A.nent);
    if (getenv ("PRESCALE")) {
      spcsr_lascaling (&A, R, C);
      spcsr_apply_scaling (&A, R, C);
    }
    else
      spcsr_lascale (&A, R, C);
    spcsr_copy (&A, &Acopy);
    expint_check = malloc (Acopy.nent * sizeof(double));
    /* NOTE(review): expint is still NULL here (it is only allocated after
       the auction, and later filled by MPI_Gatherv).  Confirm that
       auction_toexp / auction_toexpint accept a NULL output, or drop these
       calls. */
    if (getenv ("TOINT"))
      auction_toexpint (&A, expint);
    else
      auction_toexp (&A, expint);
    auction_shift (&Acopy, expint_check);
  }

  gcsr_take_csr_root (&A, &gA, MPI_COMM_WORLD);
  gcsr_redist_ents (&gA, 0, MPI_COMM_WORLD);

  /* Print each rank's local shape in rank order. */
  {
    int p;
    for (p = 0; p < csize; ++p) {
      MPI_Barrier(MPI_COMM_WORLD);
      if (p == rank) {
        printf("++++ %d: gA(%d, %d), lA(%d, %d; %d) rows [%d, %d)\n",
               rank, gA.nr, gA.nc, gA.local.nr, gA.local.nc, gA.local.nent,
               gA.rowmap.lindstart, gA.rowmap.lindstart + gA.rowmap.lnind);
        printf("++++ %d: local rowoff[1] = %d\n", rank,
               (gA.local.rowoff? gA.local.rowoff[1] : -1));
      }
      MPI_Barrier(MPI_COMM_WORLD);
    }
  }

  lnr = gA.local.nr;
  lnc = gA.local.nc;
  lnent = gA.local.nent;

  lexpint = malloc (lnent * sizeof(double));
  auction_toexpint (&gA.local, lexpint);
  stripe_auction_shift (&gA, lexpint, MPI_COMM_WORLD);

  /* Root sanity check: its stripe must agree with the untouched copy. */
  if (am_root) {
    int k;
    int die = 0;
    for (k = 0; k < gA.local.nent; ++k) {
      if (gA.local.entry[k] != Acopy.entry[k]) {
        fprintf (stderr, "gA.local.entry[%d] = %g Acopy.entry[%d] = %g\n",
                 k, gA.local.entry[k], k, Acopy.entry[k]);
        die = 1;
      }
      if (gA.local.colind[k] != Acopy.colind[k]) {
        fprintf (stderr, "gA.local.colind[%d] = %d Acopy.colind[%d] = %d\n",
                 k, gA.local.colind[k], k, Acopy.colind[k]);
        die = 1;
      }
      if (lexpint[k] != expint_check[k]) {
        fprintf (stderr, "lexpint[%d] = %g expint_check[%d] = %g\n",
                 k, lexpint[k], k, expint_check[k]);
        die = 1;
      }
      if (die) MPI_Abort (comm, -9999);
    }
  }

  mu_min = 1.0 / gA.nr;
  if (argc > 2)
    mu_min = strtod(argv[2], NULL);
  if (argc > 3)
    relgap = 1;

  lmatch = malloc (lnc * sizeof (int));
  lprice = malloc (lnc * sizeof (double));

  MPI_Barrier(MPI_COMM_WORLD);
  initialize_timer (&timer);
  start_timer (&timer);
#if 0
  errflg = stripe_auction_simple (&gA, lexpint, lmatch, lprice, mu_min,
                                  &comm_state, MPI_COMM_WORLD);
#else
  errflg = stripe_auction_scaling (&gA, lexpint, lmatch, lprice,
                                   mu_min, relgap,
                                   &comm_state, MPI_COMM_WORLD);
#endif
  stop_timer(&timer);
  MPI_Barrier(MPI_COMM_WORLD);
  (void) errflg;
  (void) lnr;

  if (am_root) {
    price = malloc (gA.nc * sizeof(double));
    match = malloc (gA.nc * sizeof(int));
    expint = malloc (Acopy.nent * sizeof(double));
    rcounts = malloc (csize * sizeof(int));
    rdisps = malloc (csize * sizeof(int));
  }
  stripe_auction_vars_l2g (&gA, lmatch, lprice, match, price, comm);

  /* With one rank the local and gathered solutions must be identical. */
  if (1 == csize) {
    int j;
    for (j = 0; j < gA.nc; ++j) {
      assert(lmatch[j] == match[j]);
      if (lprice[j] != price[j]) {
        fprintf(stderr, "prices at %d don't match: %g %g\n",
                j, lprice[j], price[j]);
      }
      assert(lprice[j] == price[j]);
    }
  }

  /* now get expint */
  MPI_Gather (&gA.local.nent, 1, MPI_INT,
              rcounts, 1, MPI_INT,
              0, comm);
  if (am_root) {
    int p;
    rdisps[0] = 0;
    for (p = 1; p < csize; ++p)
      rdisps[p] = rdisps[p-1] + rcounts[p-1];
  }
  MPI_Gatherv (lexpint, gA.local.nent, MPI_DOUBLE,
               expint, rcounts, rdisps, MPI_DOUBLE,
               0, comm);

  auction_eval_primal_mdual (gA.local, lexpint, lmatch, lprice,
                             &primal, &dual);
  {
    double pd_in[2];
    double pd[2];
    pd_in[0] = primal; pd_in[1] = dual;
    MPI_Allreduce (&pd_in[0], &pd[0], 2, MPI_DOUBLE, MPI_SUM, comm);
    if (am_root) {
      printf ("Reduced primal: %g dual %g\n", pd[0], pd[1]);
    }
  }

  gcsr_redist_root (&gA, comm);
#if 1
  if (am_root) {
    int k;
    for (k = 0; k < Acopy.nent; ++k) {
      if (expint[k] != expint_check[k]) {
        fprintf (stderr, "expint[%d] = %g expint_check[%d] = %g le = %g\n",
                 k, expint[k], k, expint_check[k], lexpint[0]);
      }
      assert(expint[k] == expint_check[k]);
      assert(gA.local.entry[k] == Acopy.entry[k]);
      assert(gA.local.colind[k] == Acopy.colind[k]);
    }
  }
#endif

  if (am_root) {
    auction_eval_primal_mdual (Acopy, expint, match, price,
                               &primal, &dual);
    printf ("Primal: %20g\nDual: %20g\nTime: %20g\nmu_min: %20g\n",
            primal, dual, timer_duration(timer), mu_min);
    auction_eval_primal_mdual (gA.local, expint_check, match, price,
                               &primal, &dual);
    printf ("Primal: %20g\nDual: %20g\n", primal, dual);
#if 1
    {
      int i, j;
      int *invmatch;
      invmatch = malloc(Acopy.nr * sizeof(int));
      memset (invmatch, -1, Acopy.nr * sizeof(int));
      for (j = 0; j < Acopy.nc; ++j) {
        /* Report bad matches and skip them: they are not valid
           indices into invmatch. */
        if (match[j] < 0) {
          printf("column %d is unmatched\n", j);
          continue;
        }
        if (match[j] >= Acopy.nr) {
          printf("column %d's match out of range (%d/%d)\n",
                 j, match[j], Acopy.nr);
          continue;
        }
        invmatch[ match[j] ] = j;
      }
      for (i = 0; i < Acopy.nr; ++i) {
        int k;
        int found_match = 0;
        if (invmatch[i] < 0)
          printf("row %d isn't matched\n", i);
        for (k = Acopy.rowoff[i]; k < Acopy.rowoff[i+1]; ++k) {
          const int j = Acopy.colind[k];
          if (i == match[j]) ++found_match;
        }
        if (!found_match) printf ("row %d is unmatched\n", i);
        if (found_match > 1) printf ("row %d is multiply matched\n", i);
      }
      free(invmatch);
    }
#endif
  }

  {
    int gcomm = -1;
    extern int comm_ngives;
    MPI_Reduce (&comm_ngives, &gcomm, 1, MPI_INT, MPI_MAX, 0, comm);
    if (am_root) printf ("Max number of gives: %d\n", gcomm);
  }
  {
    int gscanned = -1;
    extern int nent_scanned;
    MPI_Reduce (&nent_scanned, &gscanned, 1, MPI_INT, MPI_MAX, 0, comm);
    if (am_root) printf ("nent-scanned: %d %g\n", gscanned, (double)gscanned/(double)gnent);
  }

  /* ------------ */
  gcsr_free (&gA, comm);
  spcsr_free (&A);
  free (lexpint);
  free (lmatch);
  free (lprice);
  if (am_root) {
    spcsr_free (&Acopy);              /* only initialised on the root */
    free (rdisps);
    free (rcounts);
    free (expint_check);
    free (expint);
    free (price);
    free (match);
    free (R);
    free (C);
  }
  MPI_Finalize ();
  return 0;
}