MPQC  3.0.0-alpha
contract.h
1 
2 /*
3  * Copyright 2009 Sandia Corporation. Under the terms of Contract
4  * DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government
5  * retains certain rights in this software.
6  *
7  * This file is a part of the MPQC LMP2 library.
8  *
9  * The MPQC LMP2 library is free software: you can redistribute it
10  * and/or modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation, either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * This program is distributed in the hope that it will be useful, but
15  * WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with this program. If not, see
21  * <http://www.gnu.org/licenses/>.
22  *
23  */
24 
25 #ifndef _chemistry_qc_lmp2_contract_h
26 #define _chemistry_qc_lmp2_contract_h
27 
28 #include <chemistry/qc/lmp2/dgemminfo.h>
29 
30 #include <stdexcept>
31 
32 #include <math/scmat/blas.h>
33 
34 #define USE_BOUNDS_IN_CONTRACT_UNION 1
35 
36 namespace sc {
37 
38 namespace sma2 {
39 
46 template <int N>
47 inline bool
48 need_repack(const Array<N> *A, const IndexList &row_indices, const IndexList &col_indices,
49  const IndexList &fixed, const BlockInfo<N> *fixedvals)
50 {
51  // All that matters are indices that have block sizes greater than 1.
52  // If those indices are ordered correctly, then a repack is not needed.
53 
54  std::vector<int> indices;
55 
56  for (int i=0; i<row_indices.n(); i++) {
57  int index = row_indices.i(i);
58  if (A->index(index).max_block_size() > 1) {
59  indices.push_back(index);
60  }
61  }
62 
63  for (int i=0; i<col_indices.n(); i++) {
64  int index = col_indices.i(i);
65  if (A->index(index).max_block_size() > 1) {
66  indices.push_back(index);
67  }
68  }
69 
70  for (int i=1; i<indices.size(); i++) {
71  if (indices[i] < indices[i-1]) return true;
72  }
73 
74  return false;
75 }
76 
83 template <int N>
84 inline double
85 repack_cost(const Array<N> *A,
86  const IndexList &row_indices, const IndexList &col_indices,
87  const IndexList &fixed, const BlockInfo<N> *fixedvals)
88 {
89  // This gives a very rough estimate of the relative repack costs
90  // nindex cost
91  // 0 0
92  // 1 1
93  // 2 3
94  // 3 7
95  // 4 15
96  double cost = A->n_element_allocated();
97  for (int i=0; i<fixed.n(); i++) {
98  cost = cost/A->index(fixed.i(i)).nblock();
99  }
100  return cost;
101 }
102 
104 template <int NC, int NA, int NB>
106  double cost_;
107  bool need_repack_A_, transpose_A_;
108  bool need_repack_B_, transpose_B_;
109  bool need_repack_C_, transpose_C_;
110  IndexList extA_, intA_, fixA_;
111  const BlockInfo<NA> *fixvalA_;
112  IndexList extB_, intB_, fixB_;
113  const BlockInfo<NB> *fixvalB_;
114  IndexList extCA_, extCB_, fixC_;
115  const BlockInfo<NC> *fixvalC_;
116  const Array<NA> *A_;
117  const Array<NB> *B_;
118  const Array<NC> *C_;
119  int n_C_repack_;
120 
121  void init() {
122  cost_ = 0.0;
123  need_repack_A_ = false;
124  transpose_A_ = false;
125  if (need_repack(A_, extA_, intA_, fixA_, fixvalA_)) {
126  if (need_repack(A_, intA_, extA_, fixA_, fixvalA_)) {
127  need_repack_A_ = true;
128  cost_ += 2.0 * repack_cost(A_,extA_,intA_,fixA_,fixvalA_);
129  }
130  else {
131  transpose_A_ = true;
132  }
133  }
134 
135  need_repack_B_ = false;
136  transpose_B_ = false;
137  if (need_repack(B_, intB_, extB_, fixB_, fixvalB_)) {
138  if (need_repack(B_, extB_, intB_, fixB_, fixvalB_)) {
139  need_repack_B_ = true;
140  cost_ += 2.0 * repack_cost(B_,intB_,extB_,fixB_,fixvalB_);
141  }
142  else {
143  transpose_B_ = true;
144  }
145  }
146 
147  need_repack_C_ = false;
148  transpose_C_ = false;
149  if (need_repack(C_, extCA_, extCB_, fixC_, fixvalC_)) {
150  if (need_repack(C_, extCB_, extCA_, fixC_, fixvalC_)) {
151  need_repack_C_ = true;
152  cost_ += n_C_repack_
153  * repack_cost(C_,extCA_,extCB_,fixC_,fixvalC_);
154  }
155  else {
156  transpose_C_ = true;
157  }
158  }
159  }
160  void reorder_indices(IndexList &i1,IndexList &i2) {
161  std::map<int,int> index_map;
162  for (int i=0; i<i1.n(); i++) {
163  index_map[i1.i(i)] = i2.i(i);
164  }
165  int i=0;
166  for (std::map<int,int>::iterator
167  iter=index_map.begin();
168  iter!=index_map.end(); i++,iter++) {
169  i1.i(i) = iter->first;
170  i2.i(i) = iter->second;
171  }
172  }
173  public:
176  const IndexList &extA, const IndexList &intA,
177  const IndexList &fixA, const BlockInfo<NA> *fixvalA,
178  const Array<NB> *B,
179  const IndexList &intB, const IndexList &extB,
180  const IndexList &fixB, const BlockInfo<NB> *fixvalB,
181  const Array<NC> *C,
182  const IndexList &extCA, const IndexList &extCB,
183  const IndexList &fixC, const BlockInfo<NC> *fixvalC,
184  int n_C_repack
185  ):
186  cost_(0.0),
187  extA_(extA), intA_(intA), fixA_(fixA), fixvalA_(fixvalA),
188  extB_(extB), intB_(intB), fixB_(fixB), fixvalB_(fixvalB),
189  extCA_(extCA), extCB_(extCB), fixC_(fixC), fixvalC_(fixvalC),
190  A_(A), B_(B), C_(C),
191  n_C_repack_(n_C_repack) {
192  init();
193  }
196  void driver(int i) {
197  if (i==0) {
198  reorder_indices(extCA_,extA_);
199  reorder_indices(extCB_,extB_);
200  }
201  else if (i==1) {
202  reorder_indices(extA_,extCA_);
203  reorder_indices(intA_,intB_);
204  }
205  else {
206  reorder_indices(extB_,extCB_);
207  reorder_indices(intB_,intA_);
208  }
209  init();
210  }
212  bool need_repack_A() const { return need_repack_A_; }
214  bool transpose_A() const { return transpose_A_; }
216  bool need_repack_B() const { return need_repack_B_; }
218  bool transpose_B() const { return transpose_B_; }
220  bool need_repack_C() const { return need_repack_C_; }
222  bool transpose_C() const { return transpose_C_; }
224  double cost() const { return cost_; }
225 
228  void assign_indices(IndexList &extA, IndexList &intA,
229  IndexList &intB, IndexList &extB,
230  IndexList &extCA,IndexList &extCB) {
231  extA = extA_;
232  intA = intA_;
233  intB = intB_;
234  extB = extB_;
235  extCA = extCA_;
236  extCB = extCB_;
237  }
238 };
239 
247 template <int N>
248 inline void
249 repack(Array<N> &A, const IndexList &row_indices, const IndexList &col_indices,
250  const IndexList &fixed, const BlockInfo<N> &fixedvals,
251  bool reverse = false)
252 {
253  double *tmp_data = 0;
254  int n_tmp_data = 0;
255 
256  const typename Array<N>::blockmap_t &amap = A.blockmap();
257 
258  typename Array<N>::blockmap_t::const_iterator begin, end;
259  if (fixed.n() == 0) {
260  begin = amap.begin();
261  end = amap.end();
262  }
263  else {
264  BlockInfo<N> sbi;
265 
266  sbi.zero();
267  sbi.assign_blocks(fixed, fixedvals);
268 #ifdef USE_BOUND
269  sbi.set_bound(DBL_MAX);
270 #endif
271  begin = amap.lower_bound(sbi);
272 
273  for (int i=0; i<N; i++) sbi.block(i) = A.index(i).nblock();
274  sbi.assign_blocks(fixed, fixedvals);
275 #ifdef USE_BOUND
276  sbi.set_bound(0.0);
277 #endif
278  end = amap.upper_bound(sbi);
279  }
280 
281  for (typename Array<N>::blockmap_t::const_iterator aiter = begin;
282  aiter != end;
283  aiter++) {
284  const BlockInfo<N> &bi = aiter->first;
285  double *data = aiter->second;
286  int ndata = bi.size(A.indices());
287  if (n_tmp_data < ndata) {
288  delete[] tmp_data;
289  tmp_data = new double[ndata];
290  n_tmp_data = ndata;
291  }
292  int nrow = bi.subset_size(A.indices(), row_indices);
293  int ncol = bi.subset_size(A.indices(), col_indices);
294  if (ndata != nrow * ncol) {
295  throw std::length_error("sma::repack: ntotal != nrow * ncol");
296  }
297  memcpy(tmp_data,data,sizeof(double)*ndata);
298  BlockIter<N> a_iter(A.indices(), bi);
299  for (a_iter.start(); a_iter.ready(); a_iter++) {
300  int row_index = a_iter.subset_offset(row_indices);
301  int col_index = a_iter.subset_offset(col_indices);
302  int index = a_iter.offset();
303  int repacked_index = row_index*ncol + col_index;
304  if (!reverse) {
305  data[repacked_index] = tmp_data[index];
306  }
307  else {
308  data[index] = tmp_data[repacked_index];
309  }
310  }
311  }
312 
313  delete[] tmp_data;
314 }
315 
344 template <int NC, int NA, int NB>
345 inline void
346 contract(
347  Array<NC> &C, const IndexList &c_extCA, const IndexList &c_extCB,
348  const IndexList &fixextCA, const IndexList &fixextCB,
349  const IndexList &fixC, const BlockInfo<NC> &fixvalC,
350  Array<NA> &A, const IndexList &c_extA, const IndexList &fixextA,
351  const IndexList &c_intA,
352  const IndexList &fixA, const BlockInfo<NA> &fixvalA,
353  bool clear_A_after_use,
354  Array<NB> &B, const IndexList &c_extB, const IndexList &fixextB,
355  const IndexList &c_intB,
356  const IndexList &fixB, const BlockInfo<NB> &fixvalB,
357  bool clear_B_after_use,
358  double ABfactor,
359  bool C_is_zero_on_entry = false,
360  sc::Ref<sc::RegionTimer> timer = 0)
361 {
362  // Some of the arguments are copied to local variables. This
363  // is to allow modification of those arguments locally to permit
364  // optimization of the cost of repacking the arrays.
365  IndexList extCA(c_extCA), extCB(c_extCB);
366  IndexList extA(c_extA), extB(c_extB);
367  IndexList intA(c_intA), intB(c_intB);
368 
369  if (C.n_element_allocated() == 0
370  || A.n_element_allocated() == 0
371  || B.n_element_allocated() == 0) {
372  return;
373  }
374 
375  for (int i=0; i<fixC.n(); i++)
376  if (fixC.i(i) >= fixC.n())
377  throw std::invalid_argument("contract: C's fixed indices must be first");
378 
379  // Consistency checks
380  if (extA.n() + extB.n() != NC - fixC.n() + fixextCA.n() + fixextCB.n()) {
381  throw std::invalid_argument("contract: Number of externals on A + B != C");
382  }
383  if (extCA.n() + extCB.n() != NC - fixC.n() + fixextCA.n() + fixextCB.n()) {
384  throw std::invalid_argument("contract: Number of indices on C inconsistent");
385  }
386  if (intA.n() != intB.n()) {
387  throw std::invalid_argument(
388  "contract: Number of internals on A and B inconsistent");
389  }
390  if (intA.n() + extA.n() != NA - fixA.n() + fixextA.n()) {
391  throw std::invalid_argument("contract: Number of indices on A inconsistent");
392  }
393  if (intB.n() + extB.n() != NB - fixB.n() + fixextB.n()) {
394  throw std::invalid_argument("contract: Number of indices on B inconsistent");
395  }
396  for (int i=0; i<extA.n(); i++) {
397  if (A.index(extA.i(i)) != C.index(extCA.i(i))) {
398  throw std::invalid_argument("contract: Range conflict between A and C");
399  }
400  }
401  for (int i=0; i<extB.n(); i++) {
402  if (B.index(extB.i(i)) != C.index(extCB.i(i))) {
403  throw std::invalid_argument("contract: Range conflict between B and C");
404  }
405  }
406  for (int i=0; i<intA.n(); i++) {
407  if (A.index(intA.i(i)) != B.index(intB.i(i))) {
408  throw std::invalid_argument("contract: Range conflict between A and B");
409  }
410  }
411 
412  RepackScheme<NC,NA,NB> repack_scheme(&A, extA, intA, fixA, &fixvalA,
413  &B, intB, extB, fixB, &fixvalB,
414  &C, extCA,extCB,fixC, &fixvalC,
415  C_is_zero_on_entry?1:2);
416 
417 // std::cout << "trying to repack: original cost = "
418 // << repack_scheme.cost()
419 // << ", rpk A = " << repack_scheme.need_repack_A()
420 // << " (" << A.n_element_allocated() << ")"
421 // << ", rpk B = " << repack_scheme.need_repack_B()
422 // << " (" << B.n_element_allocated() << ")"
423 // << ", rpk C = " << repack_scheme.need_repack_C()
424 // << " (" << C.n_element_allocated() << ")"
425 // << std::endl;
426 
427  if (repack_scheme.cost() != 0.0) {
428  RepackScheme<NC,NA,NB> tmp_repack_scheme(repack_scheme);
429  for (int driver1 = 0; driver1 < 3; driver1++) {
430  for (int driver2 = 0; driver2 < 3; driver2++) {
431  if (driver1 == driver2) continue;
432  tmp_repack_scheme.driver(driver1);
433  tmp_repack_scheme.driver(driver2);
434 
435 // std::cout << " repack scheme: "
436 // << " d1 = " << driver1
437 // << " d2 = " << driver2
438 // << ", cost = "
439 // << tmp_repack_scheme.cost()
440 // << ", rpk A = "
441 // << tmp_repack_scheme.need_repack_A()
442 // << " (" << A.n_element_allocated() << ")"
443 // << ", rpk B = "
444 // << tmp_repack_scheme.need_repack_B()
445 // << " (" << B.n_element_allocated() << ")"
446 // << ", rpk C = "
447 // << tmp_repack_scheme.need_repack_C()
448 // << " (" << C.n_element_allocated() << ")"
449 // << std::endl;
450 
451  if (tmp_repack_scheme.cost() < repack_scheme.cost()) {
452  repack_scheme = tmp_repack_scheme;
453  repack_scheme.assign_indices(extA, intA,
454  intB, extB,
455  extCA,extCB);
456  //std::cout << "found a more efficient scheme" << std::endl;
457  }
458  if (repack_scheme.cost() == 0.0) break;
459  }
460  if (repack_scheme.cost() == 0.0) break;
461  }
462  }
463 
464  // Remap the blocks of A so that it is sorted by its external indices
465  // and any fixed indices.
466  if (timer) timer->enter("remap A");
467  IndexList cmpAlist(extA,fixA);
468  IndexListLess<NA> cmpA(cmpAlist);
469  typename Array<NA>::cached_blockmap_t remappedAbm_local(cmpA);
470  typename Array<NA>::cached_blockmap_t *remappedAbm_ptr;
471  if (fixA.n() == 0 && A.use_blockmap_cache()) {
472  remappedAbm_ptr = &A.blockmap_cache_entry(cmpAlist);
473  }
474  else {
475  remap(remappedAbm_local, A, fixA, fixvalA);
476  remappedAbm_ptr = &remappedAbm_local;
477  }
478  typename Array<NA>::cached_blockmap_t &remappedAbm=*remappedAbm_ptr;
479  if (timer) timer->exit();
480 // std::cout << "A:" << std::endl << A;
481 // std::cout << "remappedA:" << std::endl << remappedA;
482 
483  // Repack the data of A, B, and C so DGEMM can be used
484  // note: if need_repack is true then the matrix has not been transposed
485  if (timer) timer->enter("repack1");
486  if (repack_scheme.need_repack_A()) repack(A, extA, intA, fixA, fixvalA);
487  if (repack_scheme.need_repack_B()) repack(B, intB, extB, fixB, fixvalB);
488  if (repack_scheme.need_repack_C() && !C_is_zero_on_entry) {
489  repack(C, extCA, extCB, fixC, fixvalC);
490  }
491  if (timer) timer->exit();
492 
493 // std::cout << "tA:" << transpose_A
494 // << " tB:" << transpose_B
495 // << " rA:" << need_repack_A
496 // << " rB:" << need_repack_B
497 // << " rC:" << need_repack_C
498 // << std::endl;
499 
500 #if 0
501  // Fixed indices imply that a loop over those indices is been done
502  // external to this routine. If one array doesn't have all of the fixed
503  // indices that other arrays have, then repacking that array will result
504  // in extra work. The code below detects this case.
505  std::set<int> fixed_all, fixed_A, fixed_B, fixed_C;
506  for (int i=0; i<fixA.n(); i++) {
507  fixed_all.insert(fixvalA.block(i));
508  fixed_A.insert(fixvalA.block(i));
509  }
510  for (int i=0; i<fixB.n(); i++) {
511  fixed_all.insert(fixvalB.block(i));
512  fixed_B.insert(fixvalB.block(i));
513  }
514  for (int i=0; i<fixC.n(); i++) {
515  fixed_all.insert(fixvalC.block(i));
516  fixed_C.insert(fixvalC.block(i));
517  }
518  if (repack_scheme.need_repack_A()
519  && fixed_A.size() > 0
520  && fixed_A != fixed_all) {
521  std::cout << "PERFORMANCE WARNING: contract needed to repack A"
522  << " but B and/or C have different fixed indices"
523  << std::endl;
524  throw std::runtime_error("contract: performance exception");
525  }
526  if (repack_scheme.need_repack_B()
527  && fixed_B.size() > 0
528  && fixed_B != fixed_all) {
529  std::cout << "PERFORMANCE WARNING: contract needed to repack B"
530  << " but A and/or C have different fixed indices"
531  << std::endl;
532  throw std::runtime_error("contract: performance exception");
533  }
534  if (repack_scheme.need_repack_C()
535  && fixed_C.size() > 0
536  && fixed_C != fixed_all) {
537  std::cout << "PERFORMANCE WARNING: contract needed to repack C"
538  << " but A and/or B have different fixed indices"
539  << std::endl;
540  throw std::runtime_error("contract: performance exception");
541  }
542 #endif
543 
544  const typename Array<NB>::blockmap_t &
545  Bbm = B.blockmap();
546 #ifdef USE_HASH
547  const typename Array<NB>::blockhash_t &
548  Bbh = B.blockhash();
549 #endif
550  const typename Array<NC>::blockmap_t &
551  Cbm = C.blockmap();
552 
553  BlockInfo<NA> Abi;
554  Abi.assign_blocks(fixA,fixvalA);
555 
556  BlockInfo<NB> Bbi;
557  Bbi.zero();
558  Bbi.assign_blocks(fixB,fixvalB);
559 
560 #ifndef USE_HASH
561  typename Array<NB>::blockmap_t::const_iterator B_fixed_hint
562  = Bbm.lower_bound(Bbi);
563 #endif
564 
565  typename Array<NC>::blockmap_t::const_iterator C_begin, C_end;
566  BlockInfo<NC> Cbi_lb;
567  BlockInfo<NC> Cbi_ub;
568  for (int i=0; i<NC; i++) {
569  Cbi_lb.block(i) = 0;
570  Cbi_ub.block(i) = C.index(i).nblock();
571  }
572  Cbi_lb.assign_blocks(fixC,fixvalC);
573  Cbi_ub.assign_blocks(fixC,fixvalC);
574  C_begin = Cbm.lower_bound(Cbi_lb);
575  C_end = Cbm.upper_bound(Cbi_ub);
576 
577  if (timer) timer->enter("C loop");
578  for (typename Array<NC>::blockmap_t::const_iterator
579  Citer = C_begin;
580  Citer != C_end;
581  Citer++) {
582  const BlockInfo<NC> &Cbi = Citer->first;
583  double *Cdata = Citer->second;
584  Abi.assign_blocks(extA, Cbi, extCA);
585  std::pair<
586  typename Array<NA>::cached_blockmap_t::const_iterator,
587  typename Array<NA>::cached_blockmap_t::const_iterator >
588  rangeA;
589 #ifdef USE_BOUND
590  // cannot use equal range on remappedA because bound is used to sort
591  // rangeA = remappedAbm.equal_range(Abi);
592  Abi.set_bound(DBL_MAX);
593  rangeA.first = remappedAbm.lower_bound(Abi);
594  Abi.set_bound(0.0);
595  rangeA.second = remappedAbm.upper_bound(Abi);
596 #else
597  rangeA = remappedAbm.equal_range(Abi);
598 #endif
599  typename Array<NA>::cached_blockmap_t::const_iterator
600  firstA = rangeA.first,
601  fenceA = rangeA.second;
602  Bbi.assign_blocks(extB, Cbi, extCB);
603  blasint n_extB = Cbi.subset_size(C.indices(), extCB);
604  blasint n_extA = Cbi.subset_size(C.indices(), extCA);
605 #ifdef USE_HASH
606  typename Array<NB>::blockhash_t::const_iterator Biter;
607 #else
608  typename Array<NB>::blockmap_t::const_iterator Biter = Bbm.begin();
609 #endif
610  if (timer) timer->enter("A loop");
611  for (typename Array<NA>::cached_blockmap_t::const_iterator
612  Aiter = firstA;
613  Aiter != fenceA;
614  Aiter++) {
615  const BlockInfo<NA> &Abi = Aiter->first;
616  double *Adata = Aiter->second;
617  Bbi.assign_blocks(intB, Abi, intA);
618 #ifdef USE_HASH
619  Biter = Bbh.find(Bbi);
620  if (Biter == Bbh.end()) continue;
621 #else
622  if (fixB.n() > 0) {
623 #if USE_STL_MULTIMAP
624  Biter = Bbm.find(Bbi);
625 #else
626  Biter = Bbm.find(B_fixed_hint, Bbi);
627 #endif
628  }
629  else {
630  //blindly using a hint here makes this a bit slower
631  //Biter = Bbm.find(Biter, Bbi);
632  Biter = Bbm.find(Bbi);
633  }
634  if (Biter == Bbm.end()) continue;
635 #endif
636  double *Bdata = Biter->second;
637  blasint n_int = Abi.subset_size(A.indices(), intA);
638 
639  double one = 1.0;
640  if (timer) timer->enter("dgemm");
641 
642  double t0 = cpu_walltime();
643 
644  if (n_extA == 1 && n_int == 1) {
645  double tmp = ABfactor * Adata[0];
646  for (int i=0; i<n_extB; i++) {
647  Cdata[i] += tmp*Bdata[i];
648  }
649  }
650  else if (n_extA == 1 && n_extB == 1) {
651  double tmp = 0.0;
652  for (int i=0; i<n_int; i++) {
653  tmp += Adata[i]*Bdata[i];
654  }
655  Cdata[0] += ABfactor*tmp;
656  }
657  else if (n_int == 1 && n_extB == 1) {
658  double tmp = ABfactor*Bdata[0];
659  for (int i=0; i<n_extA; i++) {
660  Cdata[i] += Adata[i]*tmp;
661  }
662  }
663  else if (n_int == 1) {
664  if (repack_scheme.transpose_C()) {
665  for (int i=0,ij=0; i<n_extB; i++) {
666  for (int j=0; j<n_extA; j++,ij++) {
667  Cdata[ij] += ABfactor*Adata[j]*Bdata[i];
668  }
669  }
670  }
671  else {
672  for (int i=0,ij=0; i<n_extA; i++) {
673  for (int j=0; j<n_extB; j++,ij++) {
674  Cdata[ij] += ABfactor*Adata[i]*Bdata[j];
675  }
676  }
677  }
678  }
679  else if (n_extA == 1) {
680  if (repack_scheme.transpose_B()) {
681  for (int i=0,ij=0; i<n_extB; i++) {
682  double tmp = 0.0;
683  for (int j=0; j<n_int; j++,ij++) {
684  tmp += Adata[j]*Bdata[ij];
685  }
686  Cdata[i] += tmp * ABfactor;
687  }
688  }
689  else {
690  for (int i=0; i<n_extB; i++) {
691  double tmp = 0.0;
692  for (int j=0,ij=i; j<n_int; j++,ij+=n_extB) {
693  tmp += Adata[j]*Bdata[ij];
694  }
695  Cdata[i] += tmp * ABfactor;
696  }
697  }
698  }
699  else if (n_extB == 1) {
700  if (repack_scheme.transpose_A()) {
701  for (int i=0; i<n_extA; i++) {
702  double tmp = 0.0;
703  for (int j=0,ij=i; j<n_int; j++,ij+=n_extA) {
704  tmp += Bdata[j]*Adata[ij];
705  }
706  Cdata[i] += tmp * ABfactor;
707  }
708  }
709  else {
710  for (int i=0,ij=0; i<n_extA; i++) {
711  double tmp = 0.0;
712  for (int j=0; j<n_int; j++,ij++) {
713  tmp += Bdata[j]*Adata[ij];
714  }
715  Cdata[i] += tmp * ABfactor;
716  }
717  }
718  }
719  else if (repack_scheme.transpose_C()) {
720  const char *tA = "T";
721  blasint lda = n_int;
722  if (repack_scheme.transpose_A()) { tA = "N"; lda = n_extA; }
723 
724  const char *tB = "T";
725  blasint ldb = n_extB;
726  if (repack_scheme.transpose_B()) { tB = "N"; ldb = n_int; }
727 
728  blasint ldc = n_extA;
729 
730 // std::cout << " tA: " << tA
731 // << " tB: " << tB
732 // << " nr: " << n_extA
733 // << " nc: " << n_extB
734 // << " nl: " << n_int
735 // << " lda: " << lda
736 // << " ldb: " << ldb
737 // << " ldc: " << ldc
738 // << std::endl;
739 
740  F77_DGEMM(tA, tB, &n_extA, &n_extB, &n_int,
741  &ABfactor,Adata,&lda,Bdata,&ldb,
742  &one,Cdata,&ldc);
743  }
744  else {
745  const char *tA = "N";
746  blasint lda = n_int;
747  if (repack_scheme.transpose_A()) { tA = "T"; lda = n_extA; }
748 
749  const char *tB = "N";
750  blasint ldb = n_extB;
751  if (repack_scheme.transpose_B()) { tB = "T"; ldb = n_int; }
752 
753  blasint ldc = n_extB;
754 
755  F77_DGEMM(tB, tA, &n_extB, &n_extA, &n_int,
756  &ABfactor,Bdata,&ldb,Adata,&lda,
757  &one,Cdata,&ldc);
758  }
759 #ifdef USE_COUNT_DGEMM
760  count_dgemm(n_extA, n_int, n_extB,
761  cpu_walltime()-t0);
762 #endif
763  if (timer) timer->exit();
764  }
765  if (timer) timer->exit();
766  }
767  if (timer) timer->exit();
768 
769  // Repack the data of A, B, and C to the orginal data layout
770  if (timer) timer->enter("repack2");
771  if (clear_A_after_use) A.clear();
772  else {
773  if (repack_scheme.need_repack_A()) {
774  repack(A, extA, intA, fixA, fixvalA, true);
775  }
776  }
777 
778  if (clear_B_after_use) B.clear();
779  else {
780  if (repack_scheme.need_repack_B()) {
781  repack(B, intB, extB, fixB, fixvalB, true);
782  }
783  }
784 
785  if (repack_scheme.need_repack_C()) {
786  repack(C, extCA, extCB, fixC, fixvalC, true);
787  }
788  if (timer) timer->exit();
789 
790  if (timer) timer->enter("bounds");
791  C.compute_bounds();
792  if (timer) timer->exit();
793 }
794 
796 template <int N>
797 inline double
798 scalar_contract(
799  Array<N> &c,
800  Array<N> &a, const IndexList &alist)
801 {
802  // Consistency checks
803  if (alist.n() != N) {
804  throw std::invalid_argument(
805  "sma::scalar_contract: # of indices inconsistent");
806  }
807  for (int i=0; i<N; i++) {
808  if (c.index(i) != a.index(alist.i(i)))
809  throw std::invalid_argument(
810  "sma::scalar_contract: indices don't agree");
811  }
812 
813  bool same_index_order = alist.is_identity();
814 
815  double r = 0.0;
816  const typename Array<N>::blockmap_t &amap = a.blockmap();
817  const typename Array<N>::blockmap_t &cmap = c.blockmap();
818  IndexList clist = alist.reverse_mapping();
819  bool use_hint;
820  if (clist.i(0) == 0) use_hint = true;
821  else use_hint = false;
822  typename Array<N>::blockmap_t::const_iterator citer = cmap.begin();
823  for (typename Array<N>::blockmap_t::const_iterator aiter = amap.begin();
824  aiter != amap.end();
825  aiter++) {
826  BlockInfo<N> cbi(aiter->first,clist);
827 #if USE_STL_MULTIMAP
828  citer = cmap.find(cbi);
829 #else
830  if (use_hint) citer = cmap.find(citer,cbi);
831  else citer = cmap.find(cbi);
832 #endif
833  if (citer == cmap.end()) continue;
834  double *cdata = citer->second;
835  double *adata = aiter->second;
836  if (same_index_order) {
837  int sz = c.block_size(cbi);
838  for (int i=0; i<sz; i++) r += cdata[i] * adata[i];
839  }
840  else {
841  BlockIter<N> cbiter(c.indices(),cbi);
842  int coff = 0;
843  for (cbiter.start(); cbiter.ready(); cbiter++,coff++) {
844  r += cdata[coff] * adata[cbiter.subset_offset(alist)];
845  }
846  }
847  }
848 
849  return r;
850 }
851 
858 template <int NC, int NA, int NB>
859 inline void
860 contract_union(
861  Array<NC> &C, const IndexList &extCA, const IndexList &extCB,
862  const IndexList &fixC, const BlockInfo<NC> &fixvalC,
863  Array<NA> &A, const IndexList &extA, const IndexList &intA,
864  const IndexList &fixA, const BlockInfo<NA> &fixvalA,
865  Array<NB> &B, const IndexList &extB, const IndexList &intB,
866  const IndexList &fixB, const BlockInfo<NB> &fixvalB)
867 {
868  // Consistency checks
869  if (extA.n() + extB.n() != NC - fixC.n()) {
870  std::cerr << "NA = " << NA << std::endl;
871  std::cerr << "intA = " << intA << std::endl;
872  std::cerr << "extA = " << extA << std::endl;
873  std::cerr << "fixA = " << fixA << std::endl;
874  std::cerr << "NB = " << NB << std::endl;
875  std::cerr << "intB = " << intB << std::endl;
876  std::cerr << "extB = " << extB << std::endl;
877  std::cerr << "fixB = " << fixB << std::endl;
878  std::cerr << "NC = " << NC << std::endl;
879  std::cerr << "extCA = " << extCA << std::endl;
880  std::cerr << "extCB = " << extCB << std::endl;
881  std::cerr << "fixC = " << fixC << std::endl;
882  throw std::invalid_argument("contract_union: Number of externals on A + B != C");
883  }
884  if (extCA.n() + extCB.n() != NC - fixC.n()) {
885  throw std::invalid_argument("contract_union: Number of indices on C inconsistent");
886  }
887  if (intA.n() != intB.n()) {
888  throw std::invalid_argument(
889  "contract_union: Number of internals on A and B inconsistent");
890  }
891  if (intA.n() + extA.n() != NA - fixA.n()) {
892  std::cerr << "NA = " << NA << std::endl;
893  std::cerr << "extA.n() = " << extA.n() << std::endl;
894  std::cerr << "fixA.n() = " << fixA.n() << std::endl;
895  std::cerr << "intA.n() = " << intA.n() << " (";
896  for (int i=0; i<intA.n(); i++) {
897  std::cerr << " " << intA.i(i);
898  }
899  std::cerr << ")" << std::endl;
900  throw std::invalid_argument("contract_union: Number of indices on A inconsistent");
901  }
902  if (intB.n() + extB.n() != NB - fixB.n()) {
903  throw std::invalid_argument("contract_union: Number of indices on B inconsistent");
904  }
905  for (int i=0; i<extB.n(); i++) {
906  if (B.index(extB.i(i)) != C.index(extCB.i(i))) {
907  throw std::invalid_argument("contract_union: Range conflict between B and C");
908  }
909  }
910  for (int i=0; i<intA.n(); i++) {
911  if (A.index(intA.i(i)) != B.index(intB.i(i))) {
912  throw std::invalid_argument("contract_union: Range conflict between A and B");
913  }
914  }
915 
916 
917  // Remap the blocks of A so that it is sorted by its internal indices and
918  // any fixed indices. The fixed indices appear first (are most
919  // significant wrt the ordering) so we can get the iterator bounds for
920  // relevant internal indices more easily. Data is not moved.
921  IndexList cmpAlist(fixA, intA);
922  IndexListLess<NA> cmpA(cmpAlist);
923  typename Array<NA>::cached_blockmap_t remappedAbm_local(cmpA);
924  typename Array<NA>::cached_blockmap_t *remappedAbm_ptr;
925  if (fixA.n() == 0 && A.use_blockmap_cache()) {
926  remappedAbm_ptr = &A.blockmap_cache_entry(cmpAlist);
927  }
928  else {
929  remap(remappedAbm_local, A, fixA, fixvalA);
930  remappedAbm_ptr = &remappedAbm_local;
931  }
932  typename Array<NA>::cached_blockmap_t &remappedAbm=*remappedAbm_ptr;
933 
934  // Remap the blocks of B so that it is sorted by its internal indices
935  // and any fixed indices. Data is not moved.
936  IndexList cmpBlist(intB, fixB);
937  IndexListLess<NB> cmpB(cmpBlist);
938  typename Array<NB>::cached_blockmap_t remappedBbm_local(cmpB);
939  typename Array<NB>::cached_blockmap_t *remappedBbm_ptr;
940  if (fixB.n() == 0 && B.use_blockmap_cache()) {
941  remappedBbm_ptr = &B.blockmap_cache_entry(cmpBlist);
942  }
943  else {
944  remap(remappedBbm_local, B, fixB, fixvalB);
945  remappedBbm_ptr = &remappedBbm_local;
946  }
947  typename Array<NB>::cached_blockmap_t &remappedBbm=*remappedBbm_ptr;
948 
949 
950 // std::cout << "beginning loops" << std::endl;
951 // std::cout << "extA = " << extA << std::endl;
952 // std::cout << "extB = " << extB << std::endl;
953 // std::cout << "extCA = " << extCA << std::endl;
954 // std::cout << "extCB = " << extCB << std::endl;
955 
956  BlockInfo<NA> ablockinfo;
957  for (int i=0; i<NA; i++) ablockinfo.block(i) = 0;
958 #ifdef USE_BOUND
959  ablockinfo.set_bound(DBL_MAX);
960 #endif
961  ablockinfo.assign_blocks(fixA, fixvalA);
962  typename Array<NA>::cached_blockmap_t::const_iterator abegin;
963  abegin = remappedAbm.lower_bound(ablockinfo);
964 
965  BlockInfo<NB> bblockinfo;
966  bblockinfo.assign_blocks(fixB, fixvalB);
967 
968  BlockInfo<NC> cblockinfo;
969  cblockinfo.assign_blocks(fixC, fixvalC);
970 
971  while (abegin != remappedAbm.end()) {
972  ablockinfo = abegin->first;
973 
974  // if there are fixed indices, then abegin might be beyond
975  // the fixed indices that we are interested in
976  if (!ablockinfo.equiv_blocks(fixA, fixvalA)) break;
977 
978 #ifdef USE_BOUND
979 #if 0 && USE_BOUNDS_IN_CONTRACT_UNION
980  if (B.bound() < DBL_EPSILON) {
981  ablockinfo.set_bound(C.tolerance()/DBL_EPSILON);
982  }
983  else {
984  ablockinfo.set_bound(C.tolerance()/B.bound());
985  }
986 #else
987  ablockinfo.set_bound(0.0);
988 #endif
989 #endif
990  typename Array<NA>::cached_blockmap_t::const_iterator
991  afence = remappedAbm.upper_bound(ablockinfo);
992  bblockinfo.assign_blocks(intB,ablockinfo,intA);
993  std::pair<typename Array<NB>::cached_blockmap_t::const_iterator,
994  typename Array<NB>::cached_blockmap_t::const_iterator>
995  brange;
996  // cannot use equal_range on remappedB since bounds are used to sort
997  // brange = remappedBbm.equal_range(bblockinfo);
998 #ifdef USE_BOUND
999  bblockinfo.set_bound(DBL_MAX);
1000 #endif
1001  brange.first = remappedBbm.lower_bound(bblockinfo);
1002 #ifdef USE_BOUND
1003 #if 0 && USE_BOUNDS_IN_CONTRACT_UNION
1004  if (A.bound() < DBL_EPSILON) {
1005  bblockinfo.set_bound(C.tolerance()/DBL_EPSILON);
1006  }
1007  else {
1008  bblockinfo.set_bound(C.tolerance()/A.bound());
1009  }
1010 #else
1011  bblockinfo.set_bound(0.0);
1012 #endif
1013 #endif
1014  brange.second = remappedBbm.upper_bound(bblockinfo);
1015  typename Array<NB>::cached_blockmap_t::const_iterator
1016  bbegin = brange.first,
1017  bfence = brange.second;
1018 // std::cout << " in internal loop" << std::endl;
1019  for (typename Array<NA>::cached_blockmap_t::const_iterator
1020  aiter = abegin;
1021  aiter != afence;
1022  aiter++) {
1023 #ifdef USE_BOUND
1024  double a_block_bound = aiter->first.bound();
1025 #endif
1026  cblockinfo.assign_blocks(extCA,aiter->first,extA);
1027 // std::cout << " A blocks: " << aiter->first << std::endl;
1028  for (typename Array<NB>::cached_blockmap_t::const_iterator
1029  biter = bbegin;
1030  biter != bfence;
1031  biter++) {
1032 #ifdef USE_BOUND
1033 #if 0 && USE_BOUNDS_IN_CONTRACT_UNION
1034  if (a_block_bound * biter->first.bound() < C.tolerance()) {
1035  continue;
1036  }
1037 #endif
1038 #endif
1039  cblockinfo.assign_blocks(extCB,biter->first,extB);
1040 // std::cout << " B blocks: " << biter->first
1041 // << " adding " << cblockinfo << std::endl;
1042  C.add_unallocated_block(cblockinfo);
1043  }
1044  }
1045 #ifdef USE_BOUND
1046  ablockinfo.set_bound(0.0);
1047 #endif
1048  abegin = remappedAbm.upper_bound(ablockinfo);
1049  }
1050 }
1051 
1052 }
1053 
1054 }
1055 
1056 #endif
sc::sma2::RepackScheme::need_repack_B
bool need_repack_B() const
Returns true if B needs to be repacked in the current scheme.
Definition: contract.h:216
sc::sma2::RepackScheme::assign_indices
void assign_indices(IndexList &extA, IndexList &intA, IndexList &intB, IndexList &extB, IndexList &extCA, IndexList &extCB)
Assign the contraction indices.
Definition: contract.h:228
sc::Ref< sc::RegionTimer >
sc::sma2::RepackScheme::transpose_C
bool transpose_C() const
Returns true if C needs to be transposed in the current scheme.
Definition: contract.h:222
sc::sma2::RepackScheme::need_repack_A
bool need_repack_A() const
Returns true if A needs to be repacked in the current scheme.
Definition: contract.h:212
sc::sma2::RepackScheme::driver
void driver(int i)
Set the array that determines the index ordering to C (i==0), A (i==1), or B (i==2) This will update ...
Definition: contract.h:196
sc::sma2::BlockInfo< NA >
sc::sma2::Array< NA >
sc::sma2::RepackScheme::cost
double cost() const
Returns the cost of the current scheme.
Definition: contract.h:224
sc::sma2::RepackScheme
Determine the cost of repacking arrays for a contraction.
Definition: contract.h:105
sc::sma2::IndexList
An IndexList is a vector of indices.
Definition: sma.h:160
sc::sma2::RepackScheme::need_repack_C
bool need_repack_C() const
Returns true if C needs to be repacked in the current scheme.
Definition: contract.h:220
sc::sma2::RepackScheme::RepackScheme
RepackScheme(const Array< NA > *A, const IndexList &extA, const IndexList &intA, const IndexList &fixA, const BlockInfo< NA > *fixvalA, const Array< NB > *B, const IndexList &intB, const IndexList &extB, const IndexList &fixB, const BlockInfo< NB > *fixvalB, const Array< NC > *C, const IndexList &extCA, const IndexList &extCB, const IndexList &fixC, const BlockInfo< NC > *fixvalC, int n_C_repack)
Create the RepackScheme for a given contraction.
Definition: contract.h:175
sc::sma2::RepackScheme::transpose_B
bool transpose_B() const
Returns true if B needs to be transposed in the current scheme.
Definition: contract.h:218
sc::sma2::RepackScheme::transpose_A
bool transpose_A() const
Returns true if A needs to be transposed in the current scheme.
Definition: contract.h:214
sc::count_dgemm
void count_dgemm(int n, int l, int m, double t)
Records information about the time taken to perform a DGEMM operation.
sc
Contains all MPQC code up to version 3.
Definition: mpqcin.h:14

Generated at Sun Jan 26 2020 23:23:58 for MPQC 3.0.0-alpha using the documentation package Doxygen 1.8.16.