Skip to content

Commit 1634ba0

Browse files
authored
Merge pull request #48 from devreal/fix-han-gather
COLL HAN: add support for MPI_IN_PLACE to MPI_Gather
2 parents 37f8a93 + e3be005 commit 1634ba0

File tree

2 files changed

+88
-27
lines changed

2 files changed

+88
-27
lines changed

ompi/mca/coll/han/coll_han.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ struct mca_coll_han_gather_args_s {
127127
int root_low_rank;
128128
int w_rank;
129129
bool noop;
130+
bool is_mapbycore;
130131
};
131132
typedef struct mca_coll_han_gather_args_s mca_coll_han_gather_args_t;
132133

ompi/mca/coll/han/coll_han_gather.c

Lines changed: 87 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ mca_coll_han_set_gather_args(mca_coll_han_gather_args_t * args,
3636
int root_low_rank,
3737
struct ompi_communicator_t *up_comm,
3838
struct ompi_communicator_t *low_comm,
39-
int w_rank, bool noop, ompi_request_t * req)
39+
int w_rank, bool noop, bool is_mapbycore, ompi_request_t * req)
4040
{
4141
args->cur_task = cur_task;
4242
args->sbuf = sbuf;
@@ -53,6 +53,7 @@ mca_coll_han_set_gather_args(mca_coll_han_gather_args_t * args,
5353
args->low_comm = low_comm;
5454
args->w_rank = w_rank;
5555
args->noop = noop;
56+
args->is_mapbycore = is_mapbycore;
5657
args->req = req;
5758
}
5859

@@ -70,7 +71,6 @@ mca_coll_han_gather_intra(const void *sbuf, int scount,
7071
int root_low_rank, root_up_rank; /* root ranks for both sub-communicators */
7172
char *reorder_buf = NULL, *reorder_rbuf = NULL;
7273
int i, err, *vranks, low_rank, low_size, *topo;
73-
ptrdiff_t rsize, rgap = 0, rextent;
7474
ompi_request_t *temp_request = NULL;
7575

7676
/* Create the subcommunicators */
@@ -100,6 +100,7 @@ mca_coll_han_gather_intra(const void *sbuf, int scount,
100100
comm, comm->c_coll->coll_gather_module);
101101
}
102102

103+
ompi_datatype_t *dtype = (w_rank == root) ? rdtype : sdtype;
103104
w_rank = ompi_comm_rank(comm);
104105
w_size = ompi_comm_size(comm);
105106
/* Set up request */
@@ -128,7 +129,6 @@ mca_coll_han_gather_intra(const void *sbuf, int scount,
128129
"[%d]: Han Gather root %d root_low_rank %d root_up_rank %d\n",
129130
w_rank, root, root_low_rank, root_up_rank));
130131

131-
ompi_datatype_type_extent(rdtype, &rextent);
132132

133133
/* Allocate reorder buffers */
134134
if (w_rank == root) {
@@ -142,12 +142,25 @@ mca_coll_han_gather_intra(const void *sbuf, int scount,
142142

143143
} else {
144144
/* Need a buffer to store unordered final result */
145+
ptrdiff_t rsize, rgap;
145146
rsize = opal_datatype_span(&rdtype->super,
146147
(int64_t)rcount * w_size,
147148
&rgap);
148149
reorder_buf = (char *)malloc(rsize); //TODO:free
149150
/* rgap is the size of unused space at the start of the datatype */
150151
reorder_rbuf = reorder_buf - rgap;
152+
153+
if (MPI_IN_PLACE == sbuf) {
154+
ptrdiff_t rextent;
155+
ompi_datatype_type_extent(rdtype, &rextent);
156+
ptrdiff_t block_size = rextent * (ptrdiff_t)rcount;
157+
ptrdiff_t src_shift = block_size * w_rank;
158+
ptrdiff_t dest_shift = block_size * w_rank;
159+
ompi_datatype_copy_content_same_ddt(dtype,
160+
(ptrdiff_t)rcount,
161+
(char *)rbuf + dest_shift,
162+
reorder_rbuf + src_shift);
163+
}
151164
}
152165
}
153166

@@ -158,7 +171,7 @@ mca_coll_han_gather_intra(const void *sbuf, int scount,
158171
mca_coll_han_gather_args_t *lg_args = malloc(sizeof(mca_coll_han_gather_args_t));
159172
mca_coll_han_set_gather_args(lg_args, lg, (char *) sbuf, NULL, scount, sdtype, reorder_rbuf,
160173
rcount, rdtype, root, root_up_rank, root_low_rank, up_comm,
161-
low_comm, w_rank, low_rank != root_low_rank, temp_request);
174+
low_comm, w_rank, low_rank != root_low_rank, han_module->is_mapbycore, temp_request);
162175
/* Init lg task */
163176
init_task(lg, mca_coll_han_gather_lg_task, (void *) (lg_args));
164177
/* Issue lg task */
@@ -176,6 +189,8 @@ mca_coll_han_gather_intra(const void *sbuf, int scount,
176189
*/
177190
/* reorder rbuf based on rank */
178191
if (w_rank == root && !han_module->is_mapbycore) {
192+
ptrdiff_t rextent;
193+
ompi_datatype_type_extent(rdtype, &rextent);
179194
for (i=0; i<w_size; i++) {
180195
OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output,
181196
"[%d]: Han Gather copy from %d to %d\n",
@@ -202,6 +217,15 @@ int mca_coll_han_gather_lg_task(void *task_args)
202217
mca_coll_han_gather_args_t *t = (mca_coll_han_gather_args_t *) task_args;
203218
OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, "[%d] Han Gather: lg\n",
204219
t->w_rank));
220+
ompi_datatype_t *dtype;
221+
size_t count;
222+
if (t->w_rank == t->root) {
223+
dtype = t->rdtype;
224+
count = t->rcount;
225+
} else {
226+
dtype = t->sdtype;
227+
count = t->scount;
228+
}
205229

206230
/* If the process is one of the node leaders */
207231
char *tmp_buf = NULL;
@@ -210,21 +234,35 @@ int mca_coll_han_gather_lg_task(void *task_args)
210234
/* if the process is one of the node leaders, allocate the intermediary
211235
* buffer to gather on the low sub communicator */
212236
int low_size = ompi_comm_size(t->low_comm);
237+
int low_rank = ompi_comm_rank(t->low_comm);
213238
ptrdiff_t rsize, rgap = 0;
214-
rsize = opal_datatype_span(&t->rdtype->super,
215-
(int64_t)t->rcount * low_size,
239+
rsize = opal_datatype_span(&dtype->super,
240+
count * low_size,
216241
&rgap);
217242
tmp_buf = (char *) malloc(rsize);
218243
tmp_rbuf = tmp_buf - rgap;
244+
if (t->w_rank == t->root) {
245+
if (t->is_mapbycore && MPI_IN_PLACE == t->sbuf) {
246+
ptrdiff_t rextent;
247+
ompi_datatype_type_extent(dtype, &rextent);
248+
ptrdiff_t block_size = rextent * (ptrdiff_t)count;
249+
ptrdiff_t src_shift = block_size * t->w_rank;
250+
ptrdiff_t dest_shift = block_size * low_rank;
251+
ompi_datatype_copy_content_same_ddt(dtype,
252+
(ptrdiff_t)count,
253+
tmp_rbuf + dest_shift,
254+
(char *)t->rbuf + src_shift);
255+
}
256+
}
219257
}
220258

221259
/* Low level (usually intra-node or shared memory) node gather */
222260
t->low_comm->c_coll->coll_gather((char *)t->sbuf,
223-
t->scount,
224-
t->sdtype,
261+
count,
262+
dtype,
225263
tmp_rbuf,
226-
t->rcount,
227-
t->rdtype,
264+
count,
265+
dtype,
228266
t->root_low_rank,
229267
t->low_comm,
230268
t->low_comm->c_coll->coll_gather_module);
@@ -253,14 +291,25 @@ int mca_coll_han_gather_ug_task(void *task_args)
253291
OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output,
254292
"[%d] Han Gather: ug noop\n", t->w_rank));
255293
} else {
294+
ompi_datatype_t *dtype;
295+
size_t count;
296+
if (t->w_rank == t->root) {
297+
dtype = t->rdtype;
298+
count = t->rcount;
299+
} else {
300+
dtype = t->sdtype;
301+
count = t->scount;
302+
}
303+
304+
256305
int low_size = ompi_comm_size(t->low_comm);
257306
/* inter node gather */
258307
t->up_comm->c_coll->coll_gather((char *)t->sbuf,
259-
t->scount*low_size,
260-
t->sdtype,
308+
count*low_size,
309+
dtype,
261310
(char *)t->rbuf,
262-
t->rcount*low_size,
263-
t->rdtype,
311+
count*low_size,
312+
dtype,
264313
t->root_up_rank,
265314
t->up_comm,
266315
t->up_comm->c_coll->coll_gather_module);
@@ -320,6 +369,17 @@ mca_coll_han_gather_intra_simple(const void *sbuf, int scount,
320369

321370
ompi_communicator_t *low_comm = han_module->sub_comm[INTRA_NODE];
322371
ompi_communicator_t *up_comm = han_module->sub_comm[INTER_NODE];
372+
ompi_datatype_t *dtype;
373+
size_t count;
374+
375+
if (w_rank == root) {
376+
dtype = rdtype;
377+
count = rcount;
378+
} else {
379+
dtype = sdtype;
380+
count = scount;
381+
}
382+
323383

324384
/* Get the 'virtual ranks' mapping corresponding to the communicators */
325385
int *vranks = han_module->cached_vranks;
@@ -359,32 +419,32 @@ mca_coll_han_gather_intra_simple(const void *sbuf, int scount,
359419
char *tmp_buf_start = NULL; // start of the data
360420
if (low_rank == root_low_rank) {
361421
ptrdiff_t rsize, rgap = 0;
362-
rsize = opal_datatype_span(&rdtype->super,
363-
(int64_t)rcount * low_size,
422+
rsize = opal_datatype_span(&dtype->super,
423+
count * low_size,
364424
&rgap);
365425
tmp_buf = (char *) malloc(rsize);
366426
tmp_buf_start = tmp_buf - rgap;
367427
}
368428

369429
/* 1. low gather on node leaders */
370430
low_comm->c_coll->coll_gather((char *)sbuf,
371-
scount,
372-
sdtype,
431+
count,
432+
dtype,
373433
tmp_buf_start,
374-
rcount,
375-
rdtype,
434+
count,
435+
dtype,
376436
root_low_rank,
377437
low_comm,
378438
low_comm->c_coll->coll_gather_module);
379439

380440
/* 2. upper gather (inter-node) between node leaders */
381441
if (low_rank == root_low_rank) {
382442
up_comm->c_coll->coll_gather((char *)tmp_buf_start,
383-
scount*low_size,
384-
sdtype,
443+
count*low_size,
444+
dtype,
385445
(char *)reorder_buf_start,
386-
rcount*low_size,
387-
rdtype,
446+
count*low_size,
447+
dtype,
388448
root_up_rank,
389449
up_comm,
390450
up_comm->c_coll->coll_gather_module);
@@ -425,15 +485,15 @@ mca_coll_han_gather_intra_simple(const void *sbuf, int scount,
425485
void
426486
ompi_coll_han_reorder_gather(const void *sbuf,
427487
void *rbuf, int rcount,
428-
struct ompi_datatype_t *rdtype,
488+
struct ompi_datatype_t *dtype,
429489
struct ompi_communicator_t *comm,
430490
int * topo)
431491
{
432492
int i, topolevel = 2; // always 2 levels in topo
433493
int w_rank = ompi_comm_rank(comm);
434494
int w_size = ompi_comm_size(comm);
435495
ptrdiff_t rextent;
436-
ompi_datatype_type_extent(rdtype, &rextent);
496+
ompi_datatype_type_extent(dtype, &rextent);
437497
for ( i = 0; i < w_size; i++ ) {
438498
OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output,
439499
"[%d]: Future reorder from %d to %d\n",
@@ -443,7 +503,7 @@ ompi_coll_han_reorder_gather(const void *sbuf,
443503
ptrdiff_t block_size = rextent * (ptrdiff_t)rcount;
444504
ptrdiff_t src_shift = block_size * i;
445505
ptrdiff_t dest_shift = block_size * (ptrdiff_t)topo[i * topolevel + 1];
446-
ompi_datatype_copy_content_same_ddt(rdtype,
506+
ompi_datatype_copy_content_same_ddt(dtype,
447507
(ptrdiff_t)rcount,
448508
(char *)rbuf + dest_shift,
449509
(char *)sbuf + src_shift);

0 commit comments

Comments
 (0)