#include <iostream>
#include <TNL/Functional.h>
#include <TNL/Containers/Vector.h>
#include <TNL/Algorithms/Segments/CSR.h>
#include <TNL/Algorithms/SegmentsReductionKernels/DefaultKernel.h>
#include <TNL/Devices/Host.h>
#include <TNL/Devices/Cuda.h>
template< typename Device >
void
SegmentsExample()
{
using SegmentsReductionKernel =
const int size( 5 );
SegmentsType segments{ 1, 2, 3, 4, 5 };
auto data_view = data.getView();
segments.forElements( 0,
size,
{
if( localIdx <= segmentIdx )
data_view[ globalIdx ] = segmentIdx;
} );
auto sums_view = sums.getView();
auto fetch_full = [ = ]
__cuda_callable__(
int segmentIdx,
int localIdx,
int globalIdx ) ->
double
{
if( localIdx <= segmentIdx )
return data_view[ globalIdx ];
else
return 0.0;
};
{
return data_view[ globalIdx ];
};
{
sums_view[ globalIdx ] = value;
};
SegmentsReductionKernel kernel;
kernel.init( segments );
kernel.reduceAllSegments( segments, fetch_full,
TNL::Plus{}, keep );
kernel.reduceAllSegments( segments, fetch_brief,
TNL::Plus{}, keep );
}
// Program entry point: runs the segments example on the host and, when the
// translation unit is compiled as CUDA (__CUDACC__ defined), on the GPU too.
int
main( int argc, char* argv[] )
{
   // The example takes no command-line options.
   (void) argc;
   (void) argv;

   SegmentsExample< TNL::Devices::Host >();
#if defined( __CUDACC__ )
   SegmentsExample< TNL::Devices::Cuda >();
#endif

   return EXIT_SUCCESS;
}
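// Referenced TNL symbols (from the generated API documentation):
//  - CSR (CSR.h:27): data structure for the CSR segments format.
//  - Array (Array.h:64): responsible for memory management, access to array
//    elements, and general array operations.
//  - Vector (Vector.h:36): extends Array with algebraic operations.
//  - DefaultKernel.h:21: definition of the segments reduction kernel used
//    above (presumably DefaultKernel; no description given).
//  - Functional.h:17: function object implementing x + y (presumably TNL::Plus).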