1
0
Fork 0
mirror of https://github.com/cosmo-sims/monofonIC.git synced 2024-09-19 17:03:45 +02:00

added convolution class

This commit is contained in:
Oliver Hahn 2019-05-10 04:48:35 +02:00
parent cecd3a2dd5
commit 2f54498fc2
4 changed files with 138 additions and 13 deletions

View file

@ -14,6 +14,122 @@ enum space_t
rspace_id
};
// Forward declaration of the grid class so that code placed before its full
// definition (e.g. OrszagConvolver) can hold pointers/references to it.
template< typename data_t >
class Grid_FFT;
// Declared here, defined in another translation unit. Reads the grid `fp`
// and writes into the grid `f`.
// NOTE(review): presumably extracts the un-padded modes of `fp` into the
// smaller grid `f` — confirm against the definition.
template <typename data_t>
void unpad(const Grid_FFT<data_t> &fp, Grid_FFT<data_t> &f);
// Declared here, defined elsewhere. Reads the grid `f` and writes into the
// grid `fp`.
// NOTE(review): presumably inserts the modes of `f` into the larger padded
// grid `fp` — confirm against the definition.
template <typename data_t>
void pad_insert(const Grid_FFT<data_t> &f, Grid_FFT<data_t> &fp);
template< typename data_t >
class OrszagConvolver
{
protected:
Grid_FFT<data_t> *f1p_, *f2p_;
std::array<size_t,3> np_;
std::array<real_t,3> length_;
ccomplex_t *crecvbuf_;
real_t *recvbuf_;
ptrdiff_t *offsets_;
ptrdiff_t *offsetsp_;
ptrdiff_t *sizes_;
ptrdiff_t *sizesp_;
private:
int get_task( ptrdiff_t index, const ptrdiff_t *offsets, const ptrdiff_t *sizes, const int ntasks ) const
{
int itask = 0;
while( itask < ntasks-1 && offsets[itask+1] <= index ) ++itask;
return itask;
}
// void pad_insert( const Grid_FFT<data_t> & f, Grid_FFT<data_t> & fp );
// void unpad( const Grid_FFT<data_t> & fp, Grid_FFT< data_t > & f );
public:
OrszagConvolver( const std::array<size_t, 3> &N, const std::array<real_t, 3> &L )
: np_({3*N[0]/2,3*N[1]/2,3*N[2]/2}), length_(L)
{
//... create temporaries
f1p_ = new Grid_FFT<data_t>(np_, length_, kspace_id);
f2p_ = new Grid_FFT<data_t>(np_, length_, kspace_id);
#if defined(WITH_MPI)
size_t maxslicesz = f1p_->sizes_[1] * f1p_->sizes_[3] * 2;
crecvbuf_ = new ccomplex_t[maxslicesz / 2];
recvbuf_ = reinterpret_cast<real_t *>(&crecvbuf_[0]);
int ntasks(MPI_Get_size());
offsets_ = new ptrdiff_t[ntasks];
offsetsp_ = new ptrdiff_t[ntasks];
sizes_ = new ptrdiff_t[ntasks];
sizesp_ = new ptrdiff_t[ntasks];
size_t tsize = f.size(0), tsizep = f1p_->size(0);
MPI_Allgather(&f.local_1_start_, 1, MPI_LONG_LONG, &offsets_[0], 1,
MPI_LONG_LONG, MPI_COMM_WORLD);
MPI_Allgather(&f1p_->local_1_start_, 1, MPI_LONG_LONG, &offsetsp_[0], 1,
MPI_LONG_LONG, MPI_COMM_WORLD);
MPI_Allgather(&tsize, 1, MPI_LONG_LONG, &sizes_[0], 1, MPI_LONG_LONG,
MPI_COMM_WORLD);
MPI_Allgather(&tsizep, 1, MPI_LONG_LONG, &sizesp_[0], 1, MPI_LONG_LONG,
MPI_COMM_WORLD);
#endif
}
~OrszagConvolver()
{
delete f1p_;
delete f2p_;
#ifdef WITH_MPI
delete[] crecvbuf_;
delete[] offsets_;
delete[] offsetsp_;
delete[] sizes_;
delete[] sizesp_;
#endif
}
//... inplace interface
void convolve2( const Grid_FFT<data_t> & f1, const Grid_FFT<data_t> & f2, Grid_FFT<data_t> & res )
{
//... copy data 1
f1.to_kspace();
f1p_->to_kspace(false);
pad_insert(f1, *f1p_);
//... copy data 2
f2.to_kspace();
f2p_->to_kspace(false);
pad_insert(f2, *f2p_);
//... convolve
f1p_->to_rspace();
f2p_->to_rspace();
for (size_t i = 0; i < f1p_->ntot_; ++i){
(*f2p_)[i] *= (*f1p_)[i];
}
f2p_->to_kspace();
//... copy data back
res.to_kspace(false);
unpad(*f2p_, res);
}
//... inplace interface
void convolve3( const Grid_FFT<data_t> & f1, const Grid_FFT<data_t> & f2, const Grid_FFT<data_t> & f3, Grid_FFT<data_t> & res )
{
convolve2( f1, f2, res );
convolve2( res, f3, res );
}
};
template <typename data_t>
class Grid_FFT
{
@ -408,8 +524,3 @@ public:
}
};
template <typename data_t>
void unpad(const Grid_FFT<data_t> &fp, Grid_FFT<data_t> &f);
template <typename data_t>
void pad_insert(const Grid_FFT<data_t> &f, Grid_FFT<data_t> &fp);

View file

@ -4,8 +4,17 @@ template< typename T >
class vec3{
private:
std::array<T,3> data_;
T &x,&y,&z;
public:
vec3()
: x(data_[0]),y(data_[1]),z(data_[2]){}
vec3( const vec3<T> &v)
: data_(v.data_), x(data_[0]),y(data_[1]),z(data_[2]){}
vec3( vec3<T> &&v)
: data_(std::move(v.data_)), x(data_[0]),y(data_[1]),z(data_[2]){}
T &operator[](size_t i){ return data_[i];}
const T &operator[](size_t i) const { return data_[i]; }

View file

@ -55,7 +55,7 @@ void Grid_FFT<data_t>::Setup(void)
}
else
{
csoca::elog.Print("invalid data type in field3d<data_t>::setup_fft_interface\n");
csoca::elog.Print("invalid data type in Grid_FFT<data_t>::setup_fft_interface\n");
}
fft_norm_fac_ = 1.f / sqrtf((float)((size_t)n_[0] * (size_t)n_[1] * (size_t)n_[2]));
@ -135,7 +135,7 @@ void Grid_FFT<data_t>::Setup(void)
}
else
{
csoca::elog.Print("unknown data type in field3d<data_t>::setup_fft_interface\n");
csoca::elog.Print("unknown data type in Grid_FFT<data_t>::setup_fft_interface\n");
abort();
}
@ -214,7 +214,7 @@ void Grid_FFT<data_t>::FourierTransformForward(bool do_transform)
this->ApplyNorm();
wtime = get_wtime() - wtime;
csoca::ilog.Print("[FFT] Completed field3d::to_kspace (%lux%lux%lu), took %f s", sizes_[0], sizes_[1], sizes_[2], wtime);
csoca::ilog.Print("[FFT] Completed Grid_FFT::to_kspace (%lux%lux%lu), took %f s", sizes_[0], sizes_[1], sizes_[2], wtime);
}
sizes_[0] = local_1_size_;
@ -242,7 +242,7 @@ void Grid_FFT<data_t>::FourierTransformBackward(bool do_transform)
this->ApplyNorm();
wtime = get_wtime() - wtime;
csoca::ilog.Print("[FFT] Completed field3d::to_rspace (%dx%dx%d), took %f s\n", sizes_[0], sizes_[1], sizes_[2], wtime);
csoca::ilog.Print("[FFT] Completed Grid_FFT::to_rspace (%dx%dx%d), took %f s\n", sizes_[0], sizes_[1], sizes_[2], wtime);
}
sizes_[0] = local_0_size_;
sizes_[1] = n_[1];
@ -991,6 +991,9 @@ void unpad(const Grid_FFT<data_t> &fp, Grid_FFT<data_t> &f)
#endif /// end of ifdef/ifndef USE_MPI //////////////////////////////////////////////////////////////
}
/********************************************************************************************/
template class Grid_FFT<real_t>;
template class Grid_FFT<ccomplex_t>;

View file

@ -154,6 +154,8 @@ int main( int argc, char** argv )
Grid_FFT<real_t> phi3a({ngrid, ngrid, ngrid}, {boxlen, boxlen, boxlen});
Grid_FFT<real_t> phi3b({ngrid, ngrid, ngrid}, {boxlen, boxlen, boxlen});
OrszagConvolver<real_t> Conv({ngrid, ngrid, ngrid}, {boxlen, boxlen, boxlen});
phi.FillRandomReal(6519);
//======================================================================
@ -224,9 +226,9 @@ int main( int argc, char** argv )
{
size_t idx = phi2.get_idx(i, j, k);
phi2.relem(idx) = ((phi_xx.relem(idx)*phi_yy.relem(idx)-phi_xy.relem(idx)*phi_xy.relem(idx))
+(phi_xx.relem(idx)*phi_zz.relem(idx)-phi_xz.relem(idx)*phi_xz.relem(idx))
+(phi_yy.relem(idx)*phi_zz.relem(idx)-phi_yz.relem(idx)*phi_yz.relem(idx)));
phi2.relem(idx) = phi_xx.relem(idx)*phi_yy.relem(idx)-phi_xy.relem(idx)*phi_xy.relem(idx)
+phi_xx.relem(idx)*phi_zz.relem(idx)-phi_xz.relem(idx)*phi_xz.relem(idx)
+phi_yy.relem(idx)*phi_zz.relem(idx)-phi_yz.relem(idx)*phi_yz.relem(idx);
}
}
}