/********************************************************************** Extract.h Used for extracting elements of Blitz++ vector using an irregularly spaced range. *********************************************************************** */ #ifndef BLITZ_EXTRACT_H #define BLITZ_EXTRACT_H #include #include using std::cerr; using std::endl; #include "blitz/array.h" using namespace blitz; namespace bul { // Extraction using integer index template Array extract( const Array& index, const Array& A); template Array extract( const Array& index, const int dim, const Array& A); // Extraction using Boolean index template Array extract(const Array& index, const Array& A); template Array extract( const Array& index, const int dim, const Array& A); // Function definitions ************************************************* template Array extract( const Array& index, const Array& A) { // check to ensure that index of integers is consistent with A if (any(index < A.lbound(firstDim)) || any(index > A.ubound(firstDim))) { cerr << "Array and its index are incompatible" << endl; assert(all(index >= A.lbound(firstDim)) && all(index <= A.ubound(firstDim)) ); } // extract indexed elements of A Array retVal(index.size()); for (int i=index.lbound(firstDim); i<=index.ubound(firstDim); i++) { retVal(i) = A(index(i)); } return retVal.copy(); } template Array extract( const Array& index, const int dim, const Array& A) { // check to ensure that dim refers to rows or columns of A if (dim < firstDim || dim > secondDim) { cerr << "Dimension " << dim << " is out of bounds" << endl; assert(dim >= firstDim && dim <=secondDim); } // check to ensure that index of integers is consistent with A if (any(index < A.lbound(dim)) || any(index > A.ubound(dim))) { cerr << "Array and its index are incompatible" << endl; assert(all(index >= A.lbound(dim)) && all(index <= A.ubound(dim)) ); } // extract indexed elements of A int i,j; Range all = Range::all(); Array retVal; if (dim==firstDim) { retVal.resize(index.size(), A.columns()); for (i=index.lbound(firstDim); i<=index.ubound(firstDim); i++) { retVal(i,all) = A(index(i), all); } } else { retVal.resize(A.rows(), index.size()); for (j=index.lbound(firstDim); j<=index.ubound(firstDim); j++) { retVal(all,j) = A(all, index(j)); } } return retVal.copy(); } template Array extract(const Array& index, const Array& A) { // check whether A and the boolean index are of equal length if (A.size() != index.size()) { cerr << "Array and its boolean index are not of equal length" << endl; assert(A.size() == index.size()); } // extract indexed elements of A Array retVal(sum(index)); int sizeOfretVal = index.lbound(firstDim); for (int i=index.lbound(firstDim); i<=index.ubound(firstDim); i++) { if (index(i)) { retVal(sizeOfretVal) = A(i); sizeOfretVal++; } } return retVal.copy(); } template Array extract( const Array& index, const int dim, const Array& A) { // check to ensure that dim refers to rows or columns of A if (dim < firstDim || dim > secondDim) { cerr << "Dimension " << dim << " is out of bounds" << endl; assert(dim >= firstDim && dim <=secondDim); } // check whether A and the boolean index are of equal length if (A.extent(dim) != index.size()) { cerr << "Array and its boolean index are not of equal length" << endl; assert(A.extent(dim) == index.size()); } // extract indexed elements of A int i,j, sizeOfretVal; Range all = Range::all(); Array retVal; if (dim==firstDim) { retVal.resize(sum(index), A.columns()); sizeOfretVal = index.lbound(firstDim); for (i=index.lbound(firstDim); i<=index.ubound(firstDim); i++) { if (index(i)) { retVal(sizeOfretVal, all) = A(i, all); sizeOfretVal++; } } } else { retVal.resize(A.rows(), sum(index)); sizeOfretVal = index.lbound(firstDim); for (j=index.lbound(firstDim); j<=index.ubound(firstDim); j++) { if (index(j)) { retVal(all, sizeOfretVal) = A(all, j); sizeOfretVal++; } } } return retVal.copy(); } } /* namespace bul */ #endif // BLITZ_EXTRACT_H