diff options
| -rw-r--r-- | stratified_sampling.hpp | 40 |
1 files changed, 25 insertions, 15 deletions
diff --git a/stratified_sampling.hpp b/stratified_sampling.hpp index e25ac63..21bfea8 100644 --- a/stratified_sampling.hpp +++ b/stratified_sampling.hpp @@ -45,13 +45,14 @@ private: gsl_rng *gen; }; -template <typename Gen> +template <typename L> struct stratified_sampling { - stratified_sampling(vector<double> p, vector<Gen> gen) - :p(p), gen(gen), mean(p.size(), 0), sigma2(p.size(), 0){}; + stratified_sampling(vector<double> p, vector<L> X) + :p(p), X(X), mean(p.size(), 0), sigma2(p.size(), 0), I(p.size()){}; void update(int N); void draw(); vector<double> get_mean(); + vector<double> get_var(); //double estimator(); private: vector<double> p; @@ -59,14 +60,13 @@ private: vector<int> cumM; vector<double> mean; vector<double> sigma2; - vector<Gen> gen; - + vector<L> X; + int I; }; //actualisation du nombre de tirages à faire par strates -template <typename Gen> -void stratified_sampling<Gen>::update(int Nk) { - int I = p.size(); +template <typename L> +void stratified_sampling<L>::update(int Nk) { bool first_step = M.empty(); //reinitialistation du vecteur M du nombre de tirages par strates if (first_step) { @@ -99,26 +99,30 @@ void stratified_sampling<Gen>::update(int Nk) { std::cout<<m[i]<<std::endl; } } + std::cout<<"M[O] avant mis à jour"<<M[0]<<endl; M[0]+=floor(m[0]); + std::cout<<"M[0] après"<<M[0]<<endl; double current = m[0]; + int compteur = M[0]; for (int i=1; i<I; i++){ M[i] += floor(current+m[i]) - floor(current); current += m[i]; + compteur += M[i]; cout<<M[i]<<"\t"; } cout<<endl; + cout<<compteur<<endl; } -template <typename Gen> -void stratified_sampling<Gen>::draw() { - int I = p.size(); +template <typename L> +void stratified_sampling<L>::draw() { double m, s, oldmean; for(int i=0;i<I;i++){ m=0; s=0; for(int j=0;j<M[i];j++){ - m=m+gen[i](); - s=s+gen[i].current()*gen[i].current(); + m=m+X[i](); + s=s+X[i].current()*X[i].current(); } oldmean=mean[i]; mean[i]=(mean[i]*cumM[i]+m)/(cumM[i]+M[i]); @@ -126,7 +130,13 @@ void stratified_sampling<Gen>::draw() { } }; -template <typename Gen> -vector<double> stratified_sampling<Gen>::get_mean() { +template <typename L> +vector<double> stratified_sampling<L>::get_mean() { return mean; }; + +template <typename L> +vector<double> stratified_sampling<L>::get_var() { + return sigma2; +}; + |
