29 #ifndef _chemistry_qc_scf_cadf_treemat_h
30 #define _chemistry_qc_scf_cadf_treemat_h
33 #include <type_traits>
36 #include <boost/shared_ptr.hpp>
38 #include <Eigen/Dense>
40 #include <util/misc/assert.h>
41 #include <util/misc/scexception.h>
43 #include "treemat_fwd.h"
46 namespace sc {
namespace cadf {
// TreeBlock: one node of a hierarchical (tree) blocking of index ranges.
// Template parameters:
//   NormContainer - storage for this block's norms. It is used with
//                   Eigen-style calls (::Zero, .array(), .cwiseSqrt()), so it
//                   is presumably an Eigen vector type (TODO confirm against
//                   treemat_fwd.h).
//   Index         - type of the begin/end indices of the block.
//   NormValue     - scalar type of an individual norm.
48 template<
typename NormContainer,
typename Index,
typename NormValue>
52 typedef NormContainer norm_container_t;
53 typedef Index index_t;
54 typedef NormValue norm_value_t;
// Per-entry norms for this block. For blocks built by merge_blocks() these are
// the element-wise square root of the sum of the children's squared norms.
58 norm_container_t norms_;
// Child blocks; empty for a leaf (is_leaf() tests n_children() == 0).
61 std::vector<boost::shared_ptr<TreeBlock>> children_;
// Center (atom) index this block is tagged with. -1 means "not an atom block"
// (is_atom_block() compares against -1).
62 int atom_block_index_ = -1;
// Constructs a block from a vector of norms that covers the half-open index
// range [begin_idx, end_idx). TreeMatrix creates its leaf blocks with this
// constructor. It passes m.row(ish) with begin = ish.bfoff and
// end = ish.bfoff + ish.nbf.
//   norms     - Eigen expression holding the norms (presumably copied into
//               norms_; that initializer is not visible here — TODO confirm).
//   begin_idx - first index covered by the block.
//   end_idx   - one past the last index covered by the block.
68 template<
typename Derived>
70 const Eigen::MatrixBase<Derived>& norms,
71 index_t begin_idx, index_t end_idx
73 begin_index_(begin_idx),
// Builds a new parent block whose children are the blocks in [begin, end).
//   begin, end       - iterators over boost::shared_ptr<TreeBlock>. The range
//                      must be non-empty, because *begin is dereferenced.
//   end_index        - end index of the new block. Its begin index is taken
//                      from the first child.
//   atom_block_index - center index to tag the new block with, or -1 for none.
// Returns the new block. Its norms combine the children element-wise as a
// 2-norm: norms_[k] = sqrt(sum over children of child.norms_[k]^2).
// All children are assumed to have norm vectors of the same length as the
// first child.
77 template<
typename Iterator>
78 static boost::shared_ptr<TreeBlock>
79 merge_blocks(Iterator begin, Iterator end, Index end_index,
int atom_block_index = -1)
81 boost::shared_ptr<TreeBlock> rv(boost::make_shared<TreeBlock>());
// Precondition: begin != end (the first child determines size and begin index).
82 const Index norms_size = (*begin)->norms_.size();
// NOTE(review): this resize() is redundant, because the next line reassigns
// norms_ to a zero vector of the same size.
83 rv->norms_.resize(norms_size);
84 rv->norms_ = NormContainer::Zero(norms_size);
85 rv->begin_index_ = (*begin)->begin_index();
86 rv->end_index_ = end_index;
87 rv->atom_block_index_ = atom_block_index;
// Pre-reserve the children vector only when the range length is cheap to get.
// NOTE(review): this is a run-time `if`, so `end - begin` below must still
// compile for non-random-access iterators. std::distance or tag dispatch
// would remove that restriction.
91 if(std::is_base_of<std::random_access_iterator_tag,
92 typename std::iterator_traits<Iterator>::iterator_category>::value
95 typename std::iterator_traits<Iterator>::difference_type dist = end - begin;
96 rv->children_.reserve(dist);
// Accumulate the children's squared norms and record each child.
99 for(
auto it = begin; it != end; ++it) {
100 rv->norms_.array() += (*it)->norms_.array().square();
101 rv->children_.push_back(*it);
// Finish the element-wise 2-norm.
103 rv->norms_ = rv->norms_.cwiseSqrt();
// First index covered by this block.
107 index_t begin_index()
const {
return begin_index_; }
// One past the last index covered by this block.
109 index_t end_index()
const {
return end_index_; }
// True when the block was tagged with a center index (i.e. it is not -1).
111 bool is_atom_block()
const {
return atom_block_index_ != -1; }
// Center index of the block, or -1 when it is not an atom block.
113 int atom_block_index()
const {
return atom_block_index_; }
// Norm of this block for the entry selected by idx. The return type and body
// are not visible here; presumably this reads norms_ at idx (TODO confirm).
115 template <
typename NormIndex>
117 norm(
const NormIndex idx)
const
// Number of direct children.
122 size_t n_children()
const {
return children_.size(); }
// A leaf is a block without children.
123 bool is_leaf()
const {
return n_children() == 0; }
// Direct children, in the order they were merged.
125 const std::vector<boost::shared_ptr<TreeBlock>>& children()
const {
return children_; }
// TreeMatrix: a tree of BlockType nodes built over the rows of a norm matrix.
// ProductBlock's default template arguments use TreeMatrix<>, so BlockType
// presumably has a default declared outside this view (TODO confirm in
// treemat_fwd.h).
129 template<
typename BlockType>
134 typedef BlockType block_t;
135 typedef typename block_t::index_t index_t;
136 typedef boost::shared_ptr<block_t> block_ptr_t;
137 typedef std::deque<block_ptr_t> block_container_t;
// Every block of the tree. The constructor prepends each finished level, so
// the root sits at the front, followed by successively deeper levels.
142 block_container_t blocks_;
// All blocks, root first.
146 const block_container_t& blocks()
const {
return blocks_; }
// Iterators over blocks_ in stored order (root, then deeper levels).
148 typename block_container_t::const_iterator
149 breadth_first_begin()
const {
150 return blocks_.cbegin();
153 typename block_container_t::const_iterator
154 breadth_first_end()
const {
155 return blocks_.cend();
// Root block (front of blocks_). Requires a non-empty tree.
// NOTE(review): the declared return type is `const block_t`, but
// blocks_.front() is a block_ptr_t (boost::shared_ptr<block_t>). This looks
// like it should return block_ptr_t — confirm.
158 const block_t root()
const {
return blocks_.front(); }
// Builds the tree from the matrix m. Row ish of m holds the norms for shell
// ish, which becomes a leaf block covering [ish.bfoff, ish.bfoff + ish.nbf).
//   block_requirements - successive grouping criteria passed to
//                        shell_block_range. Each criterion builds one new
//                        level by merging consecutive blocks that start before
//                        the end of the same shell block. Blocks built under
//                        SameCenter are tagged with that shell block's center.
// After those levels, runs of up to max_children consecutive blocks (that
// parameter is not visible in this view) are merged repeatedly until one root
// block remains.
160 template<
typename Derived>
162 const Eigen::MatrixBase<Derived>& m,
164 std::vector<int> block_requirements = { SameAngularMomentum|SameCenter, SameCenter },
// curr_blocks: the level currently being grouped.
// new_blocks: the level currently being built.
169 std::deque<block_ptr_t> curr_blocks;
170 std::deque<block_ptr_t> new_blocks;
// One leaf per shell.
173 auto blk = boost::make_shared<block_t>(m.row(ish), ish.bfoff, ish.bfoff+ish.nbf);
174 new_blocks.push_front(blk);
177 for(
auto req : block_requirements) {
// Move the finished level into curr_blocks and record it in blocks_.
// new_blocks was filled with push_front, so moving it front-to-front
// restores creation order.
179 while(!new_blocks.empty()) {
180 curr_blocks.push_front(new_blocks.front());
181 blocks_.push_front(new_blocks.front());
182 new_blocks.pop_front();
185 auto blk_iter = curr_blocks.cbegin();
186 for(
const auto&& iblk :
shell_block_range(basis, 0, 0, NoLastIndex, req, NoMaximumBlockSize)) {
// Gather the consecutive current-level blocks that begin before the end of
// iblk; they become the children of one new block.
187 auto blk_start = blk_iter;
188 index_t end_index = (*blk_start)->end_index();
189 while(blk_iter != curr_blocks.cend()
190 and (*blk_iter)->begin_index() < iblk.bfoff + iblk.nbf
192 end_index = (*blk_iter)->end_index();
// Only SameCenter levels carry a center index; other levels use -1.
195 new_blocks.push_front(block_t::merge_blocks(blk_start, blk_iter, end_index,
196 req==SameCenter ? iblk.center : -1
// Merge up to max_children consecutive blocks per parent until one is left.
// NOTE(review): this presumably requires max_children >= 2. Otherwise the
// block count cannot shrink and this loop would not terminate. Confirm, and
// consider asserting it.
201 while(new_blocks.size() > 1) {
204 while(!new_blocks.empty()) {
205 curr_blocks.push_front(new_blocks.front());
206 blocks_.push_front(new_blocks.front());
207 new_blocks.pop_front();
210 auto blk_iter = curr_blocks.cbegin();
211 while(blk_iter != curr_blocks.cend()) {
212 auto blk_start = blk_iter;
213 index_t end_index = (*blk_start)->end_index();
214 for(
int ichild = 0; ichild < max_children and blk_iter != curr_blocks.end(); ++ichild) {
215 end_index = (*blk_iter)->end_index();
218 new_blocks.push_front(block_t::merge_blocks(blk_start, blk_iter, end_index));
// The single remaining block is the root. Prepending it makes it
// blocks_.front().
223 blocks_.push_front(new_blocks.front());
// ProductBlock: pairs a block of a left tree with the block of a right tree
// that covers the same index range. Its norm is the product of the two
// blocks' norms at the given left/right indices. LeftBlockPtr/RightBlockPtr
// default to TreeMatrix<>::block_ptr_t.
232 typename LeftBlockPtr=
typename TreeMatrix<>::block_ptr_t,
233 typename RightBlockPtr=
typename TreeMatrix<>::block_ptr_t
// The two paired tree blocks.
242 LeftBlockPtr left_block_;
243 RightBlockPtr right_block_;
// Pairs left and right.
//   left_index / right_index - forwarded to each block's norm() to select the
//                              norm entry.
// The index range is taken from the left block. The asserts require both
// blocks to have the same range and the same number of children, which
// enqueue_children() relies on.
247 template<
typename LeftIndex,
typename RightIndex>
249 const LeftBlockPtr& left, LeftIndex left_index,
250 const RightBlockPtr& right, RightIndex right_index
251 ) : norm_(left->norm(left_index) * right->norm(right_index)),
252 begin_index_(left->begin_index()),
253 end_index_(left->end_index()),
254 left_block_(left), right_block_(right)
256 MPQC_ASSERT(left->begin_index() == right->begin_index());
257 MPQC_ASSERT(left->end_index() == right->end_index());
258 MPQC_ASSERT(left->n_children() == right->n_children());
// Number of indices covered, end_index_ - begin_index_.
261 const Index size()
const {
return end_index_ - begin_index_; }
262 const Index begin_index()
const {
return begin_index_; }
263 const Index end_index()
const {
return end_index_; }
// Product of the left and right block norms.
264 const double norm()
const {
return norm_; }
// Ordering by norm per covered index (norm_ / size()). This is presumably
// operator<, used by the std::priority_queue in relevant_product_ranges; the
// operator's signature is not visible here.
267 return norm_ / size() <
other.norm_ /
other.size();
// Pushes one ProductBlock per child pair onto queue. The i-th child of the
// left block is paired with the i-th child of the right block, reusing the
// same left/right norm indices. Only the left iterator is bounds-checked. This
// relies on the child-count equality asserted in the constructor.
270 template<
typename LeftIndex,
typename RightIndex>
273 std::queue<ProductBlock>& queue,
274 const LeftIndex left_index,
275 const RightIndex right_index
278 auto left_iter = left_block_->children().cbegin();
279 auto right_iter = right_block_->children().cbegin();
280 for(; left_iter != left_block_->children().end(); ++left_iter, ++right_iter) {
281 queue.emplace(*left_iter, left_index, *right_iter, right_index);
// Leaf status is taken from the left block.
285 bool is_leaf()
const {
286 return left_block_->is_leaf();
// Atom-block status is taken from the left block only.
289 bool is_atom_block()
const {
290 return left_block_->is_atom_block();
// Center index of the left block (-1 if it is not an atom block).
293 int atom_block_index()
const {
294 return left_block_->atom_block_index();
// Orders by begin_index() only. When used as a std::set comparator, two
// elements with the same begin_index() are equivalent, so only the first one
// inserted is kept.
305 bool operator()(
const T& a,
const T& b)
const {
306 return a.begin_index() < b.begin_index();
// Finds the index ranges where the product of two trees' norms is significant.
// It walks the paired trees breadth-first (FIFO queue) from their roots:
//   - blocks whose norm() is below thresh are discarded;
//   - significant leaf pairs are collected;
//   - significant non-leaf pairs are expanded into their child pairs.
// The collected leaves, ordered by begin index, are then coalesced into
// contiguous [begin, end) ranges.
//   left, right          - trees exposing breadth_first_begin() (root first).
//                          Both are assumed non-empty, because the first
//                          iterator is dereferenced.
//   left_index, right_index - norm entry selectors forwarded to the blocks.
//   thresh               - significance threshold on ProductBlock::norm().
//   exclude_center       - if not -1, atom blocks with this center index get
//                          special handling (presumably skipped; that branch
//                          body is not visible here).
//   guarantee_min_error  - not implemented: FeatureNotImplemented is thrown.
// Throws FeatureNotImplemented when guarantee_min_error is true.
312 template<
typename LeftTreeType,
typename RightTreeType,
typename LeftIndex,
typename RightIndex>
313 inline std::vector<std::pair<
314 typename LeftTreeType::index_t,
315 typename RightTreeType::index_t
317 relevant_product_ranges(
318 const LeftTreeType& left, LeftIndex left_index,
319 const RightTreeType& right, RightIndex right_index,
320 double thresh,
int exclude_center = -1,
324 bool guarantee_min_error=
false
328 typename LeftTreeType::index_t,
329 typename LeftTreeType::block_ptr_t,
330 typename RightTreeType::block_ptr_t
332 typedef std::vector<std::pair<
333 typename LeftTreeType::index_t,
334 typename RightTreeType::index_t
// Discarded blocks, largest norm-per-size on top (ProductBlock ordering).
337 std::priority_queue<product_block_t> discarded_blocks;
// Significant leaf blocks, ordered (and deduplicated) by begin index.
338 std::set<product_block_t, detail::begin_index_less<product_block_t>> sig_blocks;
// Pending blocks for the breadth-first walk.
339 std::queue<product_block_t> working_blocks;
// Running sum of squared norms of discarded blocks.
340 double curr_error_squared = 0.0;
342 auto left_iter = left.breadth_first_begin();
343 auto right_iter = right.breadth_first_begin();
// Start from the pair of roots.
345 working_blocks.emplace(*left_iter, left_index, *right_iter, right_index);
347 while(!working_blocks.empty()) {
// Holding a reference to front() across enqueue_children() is safe. The
// queue's default underlying container is std::deque, and adding elements
// at the back does not invalidate references to existing elements.
349 const auto& curr_block = working_blocks.front();
351 if(exclude_center != -1 and curr_block.is_atom_block() and curr_block.atom_block_index() == exclude_center) {
354 else if(curr_block.norm() < thresh) {
355 if(guarantee_min_error) {
356 discarded_blocks.push(curr_block);
// Track the squared norm being dropped. Its exact nesting relative to the
// guarantee_min_error branch is not visible here.
358 curr_error_squared += curr_block.norm() * curr_block.norm();
// Significant: keep leaves, otherwise descend into the child pairs.
361 if(curr_block.is_leaf()) {
362 sig_blocks.insert(curr_block);
365 curr_block.enqueue_children(working_blocks, left_index, right_index);
368 working_blocks.pop();
// NOTE(review): this check runs after the full traversal. Checking it
// up front would avoid the wasted work before the throw.
371 if(guarantee_min_error) {
372 throw FeatureNotImplemented(
"guarantee_min_error", __FILE__, __LINE__);
// Coalesce the sorted significant blocks into contiguous ranges. A block that
// does not start where the previous range ends opens a new range; otherwise
// it extends the previous range.
377 for(
const auto& sig_block : sig_blocks) {
378 if(rv.empty() or rv.back().second != sig_block.begin_index()) {
379 rv.emplace_back(sig_block.begin_index(), sig_block.end_index());
382 rv.back().second = sig_block.end_index();