cmp.cpp

Go to the documentation of this file.
00001 
00005 /*
00006  * Author: Steven Ludtke, 04/10/2003 (sludtke@bcm.edu)
00007  * Copyright (c) 2000-2006 Baylor College of Medicine
00008  *
00009  * This software is issued under a joint BSD/GNU license. You may use the
00010  * source code in this file under either license. However, note that the
00011  * complete EMAN2 and SPARX software packages have some GPL dependencies,
00012  * so you are responsible for compliance with the licenses of these packages
00013  * if you opt to use BSD licensing. The warranty disclaimer below holds
00014  * in either instance.
00015  *
00016  * This complete copyright notice must be included in any revised version of the
00017  * source code. Additional authorship citations may be added, but existing
00018  * author citations must be preserved.
00019  *
00020  * This program is free software; you can redistribute it and/or modify
00021  * it under the terms of the GNU General Public License as published by
00022  * the Free Software Foundation; either version 2 of the License, or
00023  * (at your option) any later version.
00024  *
00025  * This program is distributed in the hope that it will be useful,
00026  * but WITHOUT ANY WARRANTY; without even the implied warranty of
00027  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
00028  * GNU General Public License for more details.
00029  *
00030  * You should have received a copy of the GNU General Public License
00031  * along with this program; if not, write to the Free Software
00032  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
00033  *
00034  * */
00035 
00036 #include "cmp.h"
00037 #include "emdata.h"
00038 #include "ctf.h"
00039 
00040 using namespace EMAN;
00041 
00042 template <> Factory < Cmp >::Factory()
00043 {
00044         force_add(&CccCmp::NEW);
00045         force_add(&SqEuclideanCmp::NEW);
00046         force_add(&DotCmp::NEW);
00047         force_add(&TomoDotCmp::NEW);
00048         force_add(&QuadMinDotCmp::NEW);
00049         force_add(&OptVarianceCmp::NEW);
00050         force_add(&PhaseCmp::NEW);
00051         force_add(&FRCCmp::NEW);
00052 }
00053 
00054 void Cmp::validate_input_args(const EMData * image, const EMData *with) const
00055 {
00056         if (!image) {
00057                 throw NullPointerException("compared image");
00058         }
00059         if (!with) {
00060                 throw NullPointerException("compare-with image");
00061         }
00062 
00063         if (!EMUtil::is_same_size(image, with)) {
00064                 throw ImageFormatException( "images not same size");
00065         }
00066 
00067         float *d1 = image->get_data();
00068         if (!d1) {
00069                 throw NullPointerException("image contains no data");
00070         }
00071 
00072         float *d2 = with->get_data();
00073         if (!d2) {
00074                 throw NullPointerException("compare-with image data");
00075         }
00076 }
00077 
00078 //  It would be good to add code for complex images!  PAP
00079 float CccCmp::cmp(EMData * image, EMData *with) const
00080 {
00081         ENTERFUNC;
00082         if (image->is_complex() || with->is_complex())
00083                 throw ImageFormatException( "Complex images not supported by CMP::CccCmp");
00084         validate_input_args(image, with);
00085 
00086         const float *const d1 = image->get_const_data();
00087         const float *const d2 = with->get_const_data();
00088 
00089         float negative = (float)params.set_default("negative", 1);
00090         if (negative) negative=-1.0; else negative=1.0;
00091 
00092         double avg1 = 0.0, var1 = 0.0, avg2 = 0.0, var2 = 0.0, ccc = 0.0;
00093         long n = 0;
00094         size_t totsize = image->get_xsize()*image->get_ysize()*image->get_zsize();
00095 
00096         bool has_mask = false;
00097         EMData* mask = 0;
00098         if (params.has_key("mask")) {
00099                 mask = params["mask"];
00100                 if(mask!=0) {has_mask=true;}
00101         }
00102 
00103         if (has_mask) {
00104                 const float *const dm = mask->get_const_data();
00105                 for (size_t i = 0; i < totsize; ++i) {
00106                         if (dm[i] > 0.5) {
00107                                 avg1 += double(d1[i]);
00108                                 var1 += d1[i]*double(d1[i]);
00109                                 avg2 += double(d2[i]);
00110                                 var2 += d2[i]*double(d2[i]);
00111                                 ccc += d1[i]*double(d2[i]);
00112                                 n++;
00113                         }
00114                 }
00115         } else {
00116                 for (size_t i = 0; i < totsize; ++i) {
00117                         avg1 += double(d1[i]);
00118                         var1 += d1[i]*double(d1[i]);
00119                         avg2 += double(d2[i]);
00120                         var2 += d2[i]*double(d2[i]);
00121                         ccc += d1[i]*double(d2[i]);
00122                 }
00123                 n = totsize;
00124         }
00125 
00126         avg1 /= double(n);
00127         var1 = var1/double(n) - avg1*avg1;
00128         avg2 /= double(n);
00129         var2 = var2/double(n) - avg2*avg2;
00130         ccc = ccc/double(n) - avg1*avg2;
00131         ccc /= sqrt(var1*var2);
00132         ccc *= negative;
00133         return static_cast<float>(ccc);
00134         EXITFUNC;
00135 }
00136 
00137 
00138 
00139 float SqEuclideanCmp::cmp(EMData * image, EMData *with) const
00140 {
00141         ENTERFUNC;
00142         validate_input_args(image, with);
00143 
00144         const float *const y_data = with->get_const_data();
00145         const float *const x_data = image->get_const_data();
00146         double result = 0.;
00147         float n = 0.0f;
00148         if(image->is_complex() && with->is_complex()) {
00149         // Implemented by PAP  01/09/06 - please do not change.  If in doubts, write/call me.
00150                 int nx  = with->get_xsize();
00151                 int ny  = with->get_ysize();
00152                 int nz  = with->get_zsize();
00153                 nx = (nx - 2 + with->is_fftodd()); // nx is the real-space size of the input image
00154                 int lsd2 = (nx + 2 - nx%2) ; // Extended x-dimension of the complex image
00155 
00156                 int ixb = 2*((nx+1)%2);
00157                 int iyb = ny%2;
00158                 //
00159                 if(nz == 1) {
00160                 //  it looks like it could work in 3D, but it is not, really.
00161                 for ( int iz = 0; iz <= nz-1; iz++) {
00162                         double part = 0.;
00163                         for ( int iy = 0; iy <= ny-1; iy++) {
00164                                 for ( int ix = 2; ix <= lsd2 - 1 - ixb; ix++) {
00165                                                 size_t ii = ix + (iy  + iz * ny)* lsd2;
00166                                                 part += (x_data[ii] - y_data[ii])*double(x_data[ii] - y_data[ii]);
00167                                 }
00168                         }
00169                         for ( int iy = 1; iy <= ny/2-1 + iyb; iy++) {
00170                                 size_t ii = (iy  + iz * ny)* lsd2;
00171                                 part += (x_data[ii] - y_data[ii])*double(x_data[ii] - y_data[ii]);
00172                                 part += (x_data[ii+1] - y_data[ii+1])*double(x_data[ii+1] - y_data[ii+1]);
00173                         }
00174                         if(nx%2 == 0) {
00175                                 for ( int iy = 1; iy <= ny/2-1 + iyb; iy++) {
00176                                         size_t ii = lsd2 - 2 + (iy  + iz * ny)* lsd2;
00177                                         part += (x_data[ii] - y_data[ii])*double(x_data[ii] - y_data[ii]);
00178                                         part += (x_data[ii+1] - y_data[ii+1])*double(x_data[ii+1] - y_data[ii+1]);
00179                                 }
00180 
00181                         }
00182                         part *= 2;
00183                         part += (x_data[0] - y_data[0])*double(x_data[0] - y_data[0]);
00184                         if(ny%2 == 0) {
00185                                 int ii = (ny/2  + iz * ny)* lsd2;
00186                                 part += (x_data[ii] - y_data[ii])*double(x_data[ii] - y_data[ii]);
00187                         }
00188                         if(nx%2 == 0) {
00189                                 int ii = lsd2 - 2 + (0  + iz * ny)* lsd2;
00190                                 part += (x_data[ii] - y_data[ii])*double(x_data[ii] - y_data[ii]);
00191                                 if(ny%2 == 0) {
00192                                         int ii = lsd2 - 2 +(ny/2  + iz * ny)* lsd2;
00193                                         part += (x_data[ii] - y_data[ii])*double(x_data[ii] - y_data[ii]);
00194                                 }
00195                         }
00196                         result += part;
00197                 }
00198                 n = (float)nx*(float)ny*(float)nz*(float)nx*(float)ny*(float)nz;
00199 
00200                 }else{ //This 3D code is incorrect, but it is the best I can do now 01/09/06 PAP
00201                 int ky, kz;
00202                 int ny2 = ny/2; int nz2 = nz/2;
00203                 for ( int iz = 0; iz <= nz-1; iz++) {
00204                         if(iz>nz2) kz=iz-nz; else kz=iz;
00205                         for ( int iy = 0; iy <= ny-1; iy++) {
00206                                 if(iy>ny2) ky=iy-ny; else ky=iy;
00207                                 for ( int ix = 0; ix <= lsd2-1; ix++) {
00208                                 // Skip Friedel related values
00209                                 if(ix>0 || (kz>=0 && (ky>=0 || kz!=0))) {
00210                                                 size_t ii = ix + (iy  + iz * ny)* lsd2;
00211                                                 result += (x_data[ii] - y_data[ii])*double(x_data[ii] - y_data[ii]);
00212                                         }
00213                                 }
00214                         }
00215                 }
00216                 n = ((float)nx*(float)ny*(float)nz*(float)nx*(float)ny*(float)nz)/2.0f;
00217                 }
00218         } else {
00219                 size_t totsize = image->get_xsize()*image->get_ysize()*image->get_zsize();
00220                 if (params.has_key("mask")) {
00221                   EMData* mask;
00222                   mask = params["mask"];
00223                   const float *const dm = mask->get_const_data();
00224                   for (size_t i = 0; i < totsize; i++) {
00225                            if (dm[i] > 0.5) {
00226                                 double temp = x_data[i]- y_data[i];
00227                                 result += temp*temp;
00228                                 n++;
00229                            }
00230                   }
00231                 } else {
00232                   for (size_t i = 0; i < totsize; i++) {
00233                                 double temp = x_data[i]- y_data[i];
00234                                 result += temp*temp;
00235                    }
00236                    n = (float)totsize;
00237                 }
00238         }
00239         result/=n;
00240 
00241         EXITFUNC;
00242         return static_cast<float>(result);
00243 }
00244 
00245 
00246 // Even though this uses doubles, it might be wise to recode it row-wise
00247 // to avoid numerical errors on large images
00248 float DotCmp::cmp(EMData* image, EMData* with) const
00249 {
00250         ENTERFUNC;
00251         validate_input_args(image, with);
00252 
00253         const float *const x_data = image->get_const_data();
00254         const float *const y_data = with->get_const_data();
00255 
00256         int normalize = params.set_default("normalize", 0);
00257         float negative = (float)params.set_default("negative", 1);
00258 
00259         if (negative) negative=-1.0; else negative=1.0;
00260         double result = 0.;
00261         long n = 0;
00262         if(image->is_complex() && with->is_complex()) {
00263         // Implemented by PAP  01/09/06 - please do not change.  If in doubts, write/call me.
00264                 int nx  = with->get_xsize();
00265                 int ny  = with->get_ysize();
00266                 int nz  = with->get_zsize();
00267                 nx = (nx - 2 + with->is_fftodd()); // nx is the real-space size of the input image
00268                 int lsd2 = (nx + 2 - nx%2) ; // Extended x-dimension of the complex image
00269 
00270                 int ixb = 2*((nx+1)%2);
00271                 int iyb = ny%2;
00272                 //
00273                 if(nz == 1) {
00274                 //  it looks like it could work in 3D, but does not
00275                 for ( int iz = 0; iz <= nz-1; ++iz) {
00276                         double part = 0.;
00277                         for ( int iy = 0; iy <= ny-1; ++iy) {
00278                                 for ( int ix = 2; ix <= lsd2 - 1 - ixb; ++ix) {
00279                                         size_t ii = ix + (iy  + iz * ny)* lsd2;
00280                                         part += x_data[ii] * double(y_data[ii]);
00281                                 }
00282                         }
00283                         for ( int iy = 1; iy <= ny/2-1 + iyb; ++iy) {
00284                                 size_t ii = (iy  + iz * ny)* lsd2;
00285                                 part += x_data[ii] * double(y_data[ii]);
00286                                 part += x_data[ii+1] * double(y_data[ii+1]);
00287                         }
00288                         if(nx%2 == 0) {
00289                                 for ( int iy = 1; iy <= ny/2-1 + iyb; ++iy) {
00290                                         size_t ii = lsd2 - 2 + (iy  + iz * ny)* lsd2;
00291                                         part += x_data[ii] * double(y_data[ii]);
00292                                         part += x_data[ii+1] * double(y_data[ii+1]);
00293                                 }
00294 
00295                         }
00296                         part *= 2;
00297                         part += x_data[0] * double(y_data[0]);
00298                         if(ny%2 == 0) {
00299                                 size_t ii = (ny/2  + iz * ny)* lsd2;
00300                                 part += x_data[ii] * double(y_data[ii]);
00301                         }
00302                         if(nx%2 == 0) {
00303                                 size_t ii = lsd2 - 2 + (0  + iz * ny)* lsd2;
00304                                 part += x_data[ii] * double(y_data[ii]);
00305                                 if(ny%2 == 0) {
00306                                         int ii = lsd2 - 2 +(ny/2  + iz * ny)* lsd2;
00307                                         part += x_data[ii] * double(y_data[ii]);
00308                                 }
00309                         }
00310                         result += part;
00311                 }
00312                 if( normalize ) {
00313                 //  it looks like it could work in 3D, but does not
00314                 double square_sum1 = 0., square_sum2 = 0.;
00315                 for ( int iz = 0; iz <= nz-1; ++iz) {
00316                         for ( int iy = 0; iy <= ny-1; ++iy) {
00317                                 for ( int ix = 2; ix <= lsd2 - 1 - ixb; ++ix) {
00318                                         size_t ii = ix + (iy  + iz * ny)* lsd2;
00319                                         square_sum1 += x_data[ii] * double(x_data[ii]);
00320                                         square_sum2 += y_data[ii] * double(y_data[ii]);
00321                                 }
00322                         }
00323                         for ( int iy = 1; iy <= ny/2-1 + iyb; ++iy) {
00324                                 size_t ii = (iy  + iz * ny)* lsd2;
00325                                 square_sum1 += x_data[ii] * double(x_data[ii]);
00326                                 square_sum1 += x_data[ii+1] * double(x_data[ii+1]);
00327                                 square_sum2 += y_data[ii] * double(y_data[ii]);
00328                                 square_sum2 += y_data[ii+1] * double(y_data[ii+1]);
00329                         }
00330                         if(nx%2 == 0) {
00331                                 for ( int iy = 1; iy <= ny/2-1 + iyb; ++iy) {
00332                                         size_t ii = lsd2 - 2 + (iy  + iz * ny)* lsd2;
00333                                         square_sum1 += x_data[ii] * double(x_data[ii]);
00334                                         square_sum1 += x_data[ii+1] * double(x_data[ii+1]);
00335                                         square_sum2 += y_data[ii] * double(y_data[ii]);
00336                                         square_sum2 += y_data[ii+1] * double(y_data[ii+1]);
00337                                 }
00338 
00339                         }
00340                         square_sum1 *= 2;
00341                         square_sum1 += x_data[0] * double(x_data[0]);
00342                         square_sum2 *= 2;
00343                         square_sum2 += y_data[0] * double(y_data[0]);
00344                         if(ny%2 == 0) {
00345                                 int ii = (ny/2  + iz * ny)* lsd2;
00346                                 square_sum1 += x_data[ii] * double(x_data[ii]);
00347                                 square_sum2 += y_data[ii] * double(y_data[ii]);
00348                         }
00349                         if(nx%2 == 0) {
00350                                 int ii = lsd2 - 2 + (0  + iz * ny)* lsd2;
00351                                 square_sum1 += x_data[ii] * double(x_data[ii]);
00352                                 square_sum2 += y_data[ii] * double(y_data[ii]);
00353                                 if(ny%2 == 0) {
00354                                         int ii = lsd2 - 2 +(ny/2  + iz * ny)* lsd2;
00355                                         square_sum1 += x_data[ii] * double(x_data[ii]);
00356                                         square_sum2 += y_data[ii] * double(y_data[ii]);
00357                                 }
00358                         }
00359                 }
00360                 result /= sqrt(square_sum1*square_sum2);
00361                 } else  result /= ((float)nx*(float)ny*(float)nz*(float)nx*(float)ny*(float)nz);
00362 
00363                 } else { //This 3D code is incorrect, but it is the best I can do now 01/09/06 PAP
00364                 int ky, kz;
00365                 int ny2 = ny/2; int nz2 = nz/2;
00366                 for ( int iz = 0; iz <= nz-1; ++iz) {
00367                         if(iz>nz2) kz=iz-nz; else kz=iz;
00368                         for ( int iy = 0; iy <= ny-1; ++iy) {
00369                                 if(iy>ny2) ky=iy-ny; else ky=iy;
00370                                 for ( int ix = 0; ix <= lsd2-1; ++ix) {
00371                                         // Skip Friedel related values
00372                                         if(ix>0 || (kz>=0 && (ky>=0 || kz!=0))) {
00373                                                 size_t ii = ix + (iy  + iz * ny)* lsd2;
00374                                                 result += x_data[ii] * double(y_data[ii]);
00375                                         }
00376                                 }
00377                         }
00378                 }
00379                 if( normalize ) {
00380                 //  still incorrect
00381                 double square_sum1 = 0., square_sum2 = 0.;
00382                 int ky, kz;
00383                 int ny2 = ny/2; int nz2 = nz/2;
00384                 for ( int iz = 0; iz <= nz-1; ++iz) {
00385                         if(iz>nz2) kz=iz-nz; else kz=iz;
00386                         for ( int iy = 0; iy <= ny-1; ++iy) {
00387                                 if(iy>ny2) ky=iy-ny; else ky=iy;
00388                                 for ( int ix = 0; ix <= lsd2-1; ++ix) {
00389                                         // Skip Friedel related values
00390                                         if(ix>0 || (kz>=0 && (ky>=0 || kz!=0))) {
00391                                                 size_t ii = ix + (iy  + iz * ny)* lsd2;
00392                                                 square_sum1 += x_data[ii] * double(x_data[ii]);
00393                                                 square_sum2 += y_data[ii] * double(y_data[ii]);
00394                                         }
00395                                 }
00396                         }
00397                 }
00398                 result /= sqrt(square_sum1*square_sum2);
00399                 } else result /= ((float)nx*(float)ny*(float)nz*(float)nx*(float)ny*(float)nz/2);
00400                 }
00401         } else {
00402                 size_t totsize = image->get_xsize() * image->get_ysize() * image->get_zsize();
00403 
00404                 double square_sum1 = 0., square_sum2 = 0.;
00405 
00406                 if (params.has_key("mask")) {
00407                         EMData* mask;
00408                         mask = params["mask"];
00409                         const float *const dm = mask->get_const_data();
00410                         if (normalize) {
00411                                 for (size_t i = 0; i < totsize; i++) {
00412                                         if (dm[i] > 0.5) {
00413                                                 square_sum1 += x_data[i]*double(x_data[i]);
00414                                                 square_sum2 += y_data[i]*double(y_data[i]);
00415                                                 result += x_data[i]*double(y_data[i]);
00416                                         }
00417                                 }
00418                         } else {
00419                                 for (size_t i = 0; i < totsize; i++) {
00420                                         if (dm[i] > 0.5) {
00421                                                 result += x_data[i]*double(y_data[i]);
00422                                                 n++;
00423                                         }
00424                                 }
00425                         }
00426                 } else {
00427                         // this little bit of manual loop unrolling makes the dot product as fast as sqeuclidean with -O2
00428                         for (size_t i=0; i<totsize; i++) result+=x_data[i]*y_data[i];
00429 
00430                         if (normalize) {
00431                                 square_sum1 = image->get_attr("square_sum");
00432                                 square_sum2 = with->get_attr("square_sum");
00433                         } else n = totsize;
00434                 }
00435                 if (normalize) result /= (sqrt(square_sum1*square_sum2)); else result /= n;
00436         }
00437 
00438 
00439         EXITFUNC;
00440         return (float) (negative*result);
00441 }
00442 
00443 
00444 float TomoDotCmp::cmp(EMData * image, EMData *with) const
00445 {
00446         ENTERFUNC;
00447         float threshold = params.set_default("threshold",0.f);
00448         if (threshold < 0.0f) throw InvalidParameterException("The threshold parameter must be greater than or equal to zero");
00449 
00450         if ( threshold > 0) {
00451                 EMData* ccf = params.set_default("ccf",(EMData*) NULL);
00452                 bool ccf_ownership = false;
00453                 if (!ccf) {
00454                         ccf = image->calc_ccf(with);
00455                         ccf_ownership = true;
00456                 }
00457                 bool norm = params.set_default("norm",false);
00458                 if (norm) ccf->process_inplace("normalize");
00459                 int tx = params.set_default("tx",0); int ty = params.set_default("ty",0); int tz = params.set_default("tz",0);
00460                 float best_score = ccf->get_value_at_wrap(tx,ty,tz)/static_cast<float>(image->get_size());
00461                 EMData* ccf_fft = ccf->do_fft();// so cuda works, or else we could do an fft_inplace - though honestly doing an fft inplace is less efficient anyhow
00462                 if (ccf_ownership) delete ccf; ccf = 0;
00463                 ccf_fft->process_inplace("threshold.binary.fourier",Dict("value",threshold));
00464                 float map_sum =  ccf_fft->get_attr("mean");
00465                 if (map_sum == 0.0f) throw UnexpectedBehaviorException("The number of voxels in the Fourier image with an amplitude above your threshold is zero. Please adjust your parameters");
00466                 best_score /= map_sum;
00467                 delete ccf_fft; ccf_fft = 0;
00468                 return -best_score;
00469         } else {
00470                 return -image->dot(with);
00471         }
00472 
00473 
00474 }
00475 
00476 // Even though this uses doubles, it might be wise to recode it row-wise
00477 // to avoid numerical errors on large images
00478 float QuadMinDotCmp::cmp(EMData * image, EMData *with) const
00479 {
00480         ENTERFUNC;
00481         validate_input_args(image, with);
00482 
00483         if (image->get_zsize()!=1) throw InvalidValueException(0, "QuadMinDotCmp supports 2D only");
00484 
00485         int nx=image->get_xsize();
00486         int ny=image->get_ysize();
00487 
00488         int normalize = params.set_default("normalize", 0);
00489         float negative = (float)params.set_default("negative", 1);
00490 
00491         if (negative) negative=-1.0; else negative=1.0;
00492 
00493         double result[4] = { 0,0,0,0 }, sq1[4] = { 0,0,0,0 }, sq2[4] = { 0,0,0,0 } ;
00494 
00495         vector<int> image_saved_offsets = image->get_array_offsets();
00496         vector<int> with_saved_offsets = with->get_array_offsets();
00497         image->set_array_offsets(-nx/2,-ny/2);
00498         with->set_array_offsets(-nx/2,-ny/2);
00499         int i,x,y;
00500         for (y=-ny/2; y<ny/2; y++) {
00501                 for (x=-nx/2; x<nx/2; x++) {
00502                         int quad=(x<0?0:1) + (y<0?0:2);
00503                         result[quad]+=(*image)(x,y)*(*with)(x,y);
00504                         if (normalize) {
00505                                 sq1[quad]+=(*image)(x,y)*(*image)(x,y);
00506                                 sq2[quad]+=(*with)(x,y)*(*with)(x,y);
00507                         }
00508                 }
00509         }
00510         image->set_array_offsets(image_saved_offsets);
00511         with->set_array_offsets(with_saved_offsets);
00512 
00513         if (normalize) {
00514                 for (i=0; i<4; i++) result[i]/=sqrt(sq1[i]*sq2[i]);
00515         } else {
00516                 for (i=0; i<4; i++) result[i]/=nx*ny/4;
00517         }
00518 
00519         float worst=static_cast<float>(result[0]);
00520         for (i=1; i<4; i++) if (static_cast<float>(result[i])<worst) worst=static_cast<float>(result[i]);
00521 
00522         EXITFUNC;
00523         return (float) (negative*worst);
00524 }
00525 
00526 float OptVarianceCmp::cmp(EMData * image, EMData *with) const
00527 {
00528         ENTERFUNC;
00529         validate_input_args(image, with);
00530 
00531         int keepzero = params.set_default("keepzero", 1);
00532         int invert = params.set_default("invert",0);
00533         int matchfilt = params.set_default("matchfilt",1);
00534         int matchamp = params.set_default("matchamp",0);
00535         int radweight = params.set_default("radweight",0);
00536         int dbug = params.set_default("debug",0);
00537 
00538         size_t size = image->get_xsize() * image->get_ysize() * image->get_zsize();
00539 
00540 
00541         EMData *with2=NULL;
00542         if (matchfilt) {
00543                 EMData *a = image->do_fft();
00544                 EMData *b = with->do_fft();
00545 
00546                 vector <float> rfa=a->calc_radial_dist(a->get_ysize()/2,0.0f,1.0f,1);
00547                 vector <float> rfb=b->calc_radial_dist(b->get_ysize()/2,0.0f,1.0f,1);
00548 
00549                 float avg=0;
00550                 for (size_t i=0; i<a->get_ysize()/2.0f; i++) {
00551                         rfa[i]=(rfb[i]==0?0.0f:(rfa[i]/rfb[i]));
00552                         avg+=rfa[i];
00553                 }
00554 
00555                 avg/=a->get_ysize()/2.0f;
00556                 for (size_t i=0; i<a->get_ysize()/2.0f; i++) {
00557                         if (rfa[i]>avg*10.0) rfa[i]=10.0;                       // If some particular location has a small but non-zero value, we don't want to overcorrect it
00558                 }
00559                 rfa[0]=0.0;
00560 
00561                 if (dbug) b->write_image("a.hdf",-1);
00562 
00563                 b->apply_radial_func(0.0f,1.0f/a->get_ysize(),rfa);
00564                 with2=b->do_ift();
00565 
00566                 if (dbug) b->write_image("a.hdf",-1);
00567                 if (dbug) a->write_image("a.hdf",-1);
00568 
00569 /*              if (dbug) {
00570                         FILE *out=fopen("a.txt","w");
00571                         for (int i=0; i<a->get_ysize()/2.0; i++) fprintf(out,"%d\t%f\n",i,rfa[i]);
00572                         fclose(out);
00573 
00574                         out=fopen("b.txt","w");
00575                         for (int i=0; i<a->get_ysize()/2.0; i++) fprintf(out,"%d\t%f\n",i,rfb[i]);
00576                         fclose(out);
00577                 }*/
00578 
00579 
00580                 delete a;
00581                 delete b;
00582 
00583                 if (dbug) {
00584                         with2->write_image("a.hdf",-1);
00585                         image->write_image("a.hdf",-1);
00586                 }
00587 
00588 //              with2->process_inplace("matchfilt",Dict("to",this));
00589 //              x_data = with2->get_data();
00590         }
00591 
00592         // This applies the individual Fourier amplitudes from 'image' and
00593         // applies them to 'with'
00594         if (matchamp) {
00595                 EMData *a = image->do_fft();
00596                 EMData *b = with->do_fft();
00597                 size_t size2 = a->get_xsize() * a->get_ysize() * a->get_zsize();
00598 
00599                 a->ri2ap();
00600                 b->ri2ap();
00601 
00602                 const float *const ad=a->get_const_data();
00603                 float * bd=b->get_data();
00604 
00605                 for (size_t i=0; i<size2; i+=2) bd[i]=ad[i];
00606                 b->update();
00607 
00608                 b->ap2ri();
00609                 with2=b->do_ift();
00610 //with2->write_image("a.hdf",-1);
00611                 delete a;
00612                 delete b;
00613         }
00614 
00615         const float * x_data;
00616         if (with2) x_data=with2->get_const_data();
00617         else x_data = with->get_const_data();
00618         const float *const y_data = image->get_const_data();
00619 
00620         size_t nx = image->get_xsize();
00621         float m = 0;
00622         float b = 0;
00623 
00624         // This will write the x vs y file used to calculate the density
00625         // optimization. This behavior may change in the future
00626         if (dbug) {
00627                 FILE *out=fopen("dbug.optvar.txt","w");
00628                 if (out) {
00629                         for (size_t i=0; i<size; i++) {
00630                                 if ( !keepzero || (x_data[i] && y_data[i])) fprintf(out,"%g\t%g\n",x_data[i],y_data[i]);
00631                         }
00632                         fclose(out);
00633                 }
00634         }
00635 
00636 
00637         Util::calc_least_square_fit(size, x_data, y_data, &m, &b, keepzero);
00638         if (m == 0) {
00639                 m = FLT_MIN;
00640         }
00641         b = -b / m;
00642         m = 1.0f / m;
00643 
00644         // While negative slopes are really not a valid comparison in most cases, we
00645         // still want to detect these instances, so this if is removed
00646 /*      if (m < 0) {
00647                 b = 0;
00648                 m = 1000.0;
00649         }*/
00650 
00651         double  result = 0;
00652         int count = 0;
00653 
00654         if (radweight) {
00655                 if (image->get_zsize()!=1) throw ImageDimensionException("radweight option is 2D only");
00656                 if (keepzero) {
00657                         for (size_t i = 0,y=0; i < size; y++) {
00658                                 for (size_t x=0; x<nx; i++,x++) {
00659                                         if (y_data[i] && x_data[i]) {
00660 #ifdef  _WIN32
00661                                                 if (invert) result += Util::square(x_data[i] - (y_data[i]-b)/m)*(_hypot((float)x,(float)y)+nx/4.0);
00662                                                 else result += Util::square((x_data[i] * m) + b - y_data[i])*(_hypot((float)x,(float)y)+nx/4.0);
00663 #else
00664                                                 if (invert) result += Util::square(x_data[i] - (y_data[i]-b)/m)*(hypot((float)x,(float)y)+nx/4.0);
00665                                                 else result += Util::square((x_data[i] * m) + b - y_data[i])*(hypot((float)x,(float)y)+nx/4.0);
00666 #endif
00667                                                 count++;
00668                                         }
00669                                 }
00670                         }
00671                         result/=count;
00672                 }
00673                 else {
00674                         for (size_t i = 0,y=0; i < size; y++) {
00675                                 for (size_t x=0; x<nx; i++,x++) {
00676 #ifdef  _WIN32
00677                                         if (invert) result += Util::square(x_data[i] - (y_data[i]-b)/m)*(_hypot((float)x,(float)y)+nx/4.0);
00678                                         else result += Util::square((x_data[i] * m) + b - y_data[i])*(_hypot((float)x,(float)y)+nx/4.0);
00679 #else
00680                                         if (invert) result += Util::square(x_data[i] - (y_data[i]-b)/m)*(hypot((float)x,(float)y)+nx/4.0);
00681                                         else result += Util::square((x_data[i] * m) + b - y_data[i])*(hypot((float)x,(float)y)+nx/4.0);
00682 #endif
00683                                 }
00684                         }
00685                         result = result / size;
00686                 }
00687         }
00688         else {
00689                 if (keepzero) {
00690                         for (size_t i = 0; i < size; i++) {
00691                                 if (y_data[i] && x_data[i]) {
00692                                         if (invert) result += Util::square(x_data[i] - (y_data[i]-b)/m);
00693                                         else result += Util::square((x_data[i] * m) + b - y_data[i]);
00694                                         count++;
00695                                 }
00696                         }
00697                         result/=count;
00698                 }
00699                 else {
00700                         for (size_t i = 0; i < size; i++) {
00701                                 if (invert) result += Util::square(x_data[i] - (y_data[i]-b)/m);
00702                                 else result += Util::square((x_data[i] * m) + b - y_data[i]);
00703                         }
00704                         result = result / size;
00705                 }
00706         }
00707         scale = m;
00708         shift = b;
00709 
00710         image->set_attr("ovcmp_m",m);
00711         image->set_attr("ovcmp_b",b);
00712         if (with2) delete with2;
00713         EXITFUNC;
00714 
00715 #if 0
00716         return (1 - result);
00717 #endif
00718 
00719         return static_cast<float>(result);
00720 }
00721 
00722 float PhaseCmp::cmp(EMData * image, EMData *with) const
00723 {
00724         ENTERFUNC;
00725 
00726 #ifdef EMAN2_USING_CUDA
00727         if (image->gpu_operation_preferred()) {
00728 //              cout << "Cuda cmp" << endl;
00729                 EXITFUNC;
00730                 return cuda_cmp(image,with);
00731         }
00732 #endif
00733         validate_input_args(image, with);
00734 
00735         static float *dfsnr = 0;
00736         static int nsnr = 0;
00737 
00738 //      if (image->get_zsize() > 1) {
00739 //              throw ImageDimensionException("2D only");
00740 //      }
00741 
00742         //int nx = image->get_xsize();
00743         int ny = image->get_ysize();
00744         //int nz = image->get_zsize();
00745 
00746         int np = (int) ceil(Ctf::CTFOS * sqrt(2.0f) * ny / 2) + 2;
00747 
00748         if (nsnr != np) {
00749                 nsnr = np;
00750                 dfsnr = (float *) realloc(dfsnr, np * sizeof(float));
00751 
00752                 //float w = Util::square(nx / 8.0f); // <- Not used currently
00753 
00754                 for (int i = 0; i < np; i++) {
00755 //                      float x2 = Util::square(i / (float) Ctf::CTFOS);
00756 //                      dfsnr[i] = (1.0f - exp(-x2 / 4.0f)) * exp(-x2 / w);
00757                         float x2 = 10.0f*i/np;
00758                         dfsnr[i] = x2 * exp(-x2);
00759                 }
00760 
00761 //              Util::save_data(0, 1.0f / Ctf::CTFOS, dfsnr, np, "filt.txt");
00762         }
00763 
00764         EMData *image_fft = image->do_fft();
00765         image_fft->ri2ap();
00766         EMData *with_fft = with->do_fft();
00767         with_fft->ri2ap();
00768 
00769         const float *const image_fft_data = image_fft->get_const_data();
00770         const float *const with_fft_data = with_fft->get_const_data();
00771         double sum = 0;
00772         double norm = FLT_MIN;
00773         size_t i = 0;
00774 
00775         for (int z = 0; z < image_fft->get_zsize(); ++z){
00776                 for (int y = 0; y < image_fft->get_ysize(); ++y) {
00777                         for (int x = 0; x < image_fft->get_xsize(); x += 2) {
00778                                 int r;
00779 //                              if ( nz == 1 ) {
00780                                         if (y<ny/2) r = Util::round(Util::hypot_fast(x / 2, y) * Ctf::CTFOS);
00781                                         else r = Util::round(Util::hypot_fast(x / 2, y-ny) * Ctf::CTFOS);
00782 
00783                                 float a = dfsnr[r] * with_fft_data[i];
00784 //                              cout << a << " " << Util::angle_sub_2pi(image_fft_data[i + 1], with_fft_data[i + 1]) << " " <<image_fft_data[i + 1] << " " << with_fft_data[i + 1] << endl;
00785                                 sum += Util::angle_sub_2pi(image_fft_data[i + 1], with_fft_data[i + 1]) * a;
00786                                 norm += a;
00787                                 i += 2;
00788                         }
00789                 }
00790         }
00791         EXITFUNC;
00792 
00793         if( image_fft )
00794         {
00795                 delete image_fft;
00796                 image_fft = 0;
00797         }
00798         if( with_fft )
00799         {
00800                 delete with_fft;
00801                 with_fft = 0;
00802         }
00803 #if 0
00804         return (1.0f - sum / norm);
00805 #endif
00806         return (float)(sum / norm);
00807 }
00808 
00809 #ifdef EMAN2_USING_CUDA
00810 #include "cuda/cuda_cmp.h"
00811 float PhaseCmp::cuda_cmp(EMData * image, EMData *with) const
00812 {
00813         ENTERFUNC;
00814         validate_input_args(image, with);
00815 
00816         typedef vector<EMData*> EMDatas;
00817         static EMDatas hist_pyramid;
00818         static EMDatas norm_pyramid;
00819         static EMData weighting;
00820         static int image_size = 0;
00821 
00822         int size;
00823         EMData::CudaDataLock imagelock(image);
00824         EMData::CudaDataLock withlock(with);
00825 
00826         if (image->is_complex()) {
00827                 size = image->get_xsize();
00828         } else {
00829                 int nx = image->get_xsize()+2;
00830                 nx -= nx%2;
00831                 size = nx*image->get_ysize()*image->get_zsize();
00832         }
00833         if (size != image_size) {
00834                 for(unsigned int i =0; i < hist_pyramid.size(); ++i) {
00835                         delete hist_pyramid[i];
00836                         delete norm_pyramid[i];
00837                 }
00838                 hist_pyramid.clear();
00839                 norm_pyramid.clear();
00840                 int s = size;
00841                 if (s < 1) throw UnexpectedBehaviorException("The image is 0 size");
00842                 int p2 = 1;
00843                 while ( s != 1 ) {
00844                         s /= 2;
00845                         p2 *= 2;
00846                 }
00847                 if ( p2 != size ) {
00848                         p2 *= 2;
00849                         s = p2;
00850                 }
00851                 if (s != 1) s /= 2;
00852                 while (true) {
00853                         EMData* h = new EMData();
00854                         h->set_size_cuda(s); h->to_value(0.0);
00855                         hist_pyramid.push_back(h);
00856                         EMData* n = new EMData();
00857                         n->set_size_cuda(s); n->to_value(0.0);
00858                         norm_pyramid.push_back(n);
00859                         if ( s == 1) break;
00860                         s /= 2;
00861                 }
00862                 int nx = image->get_xsize()+2;
00863                 nx -= nx%2; // for Fourier stuff
00864                 int ny = image->get_ysize();
00865                 int nz = image->get_zsize();
00866                 weighting.set_size_cuda(nx,ny,nz);
00867                 // Size of weighting need only be half this, but does that translate into faster code?
00868                 weighting.set_size_cuda(nx/2,ny,nz);
00869                 float np = (int) ceil(Ctf::CTFOS * sqrt(2.0f) * ny / 2) + 2;
00870                 EMDataForCuda tmp = weighting.get_data_struct_for_cuda();
00871                 calc_phase_weights_cuda(&tmp,np);
00872                 //weighting.write_image("phase_wieghts.hdf");
00873                 image_size = size;
00874         }
00875 
00876         EMDataForCuda hist[hist_pyramid.size()];
00877         EMDataForCuda norm[hist_pyramid.size()];
00878 
00879         EMDataForCuda wt = weighting.get_data_struct_for_cuda();
00880         EMData::CudaDataLock lock1(&weighting);
00881         for(unsigned int i = 0; i < hist_pyramid.size(); ++i ) {
00882                 hist[i] = hist_pyramid[i]->get_data_struct_for_cuda();
00883                 hist_pyramid[i]->cuda_lock();
00884                 norm[i] = norm_pyramid[i]->get_data_struct_for_cuda();
00885                 norm_pyramid[i]->cuda_lock();
00886         }
00887 
00888         EMData *image_fft = image->do_fft_cuda();
00889         EMDataForCuda left = image_fft->get_data_struct_for_cuda();
00890         EMData::CudaDataLock lock2(image_fft);
00891         EMData *with_fft = with->do_fft_cuda();
00892         EMDataForCuda right = with_fft->get_data_struct_for_cuda();
00893         EMData::CudaDataLock lock3(image_fft);
00894 
00895         mean_phase_error_cuda(&left,&right,&wt,hist,norm,hist_pyramid.size());
00896         float result;
00897         float* gpu_result = hist_pyramid[hist_pyramid.size()-1]->get_cuda_data();
00898         cudaError_t error = cudaMemcpy(&result,gpu_result,sizeof(float),cudaMemcpyDeviceToHost);
00899         if ( error != cudaSuccess) throw UnexpectedBehaviorException( "CudaMemcpy (host to device) in the phase comparator failed:" + string(cudaGetErrorString(error)));
00900 
00901         delete image_fft; image_fft=0;
00902         delete with_fft; with_fft=0;
00903 
00904         for(unsigned int i = 0; i < hist_pyramid.size(); ++i ) {
00905 //              hist_pyramid[i]->write_image("hist.hdf",-1); // debug
00906 //              norm_pyramid[i]->write_image("norm.hdf",-1); // debug
00907                 hist_pyramid[i]->cuda_unlock();
00908                 norm_pyramid[i]->cuda_unlock();
00909         }
00910 
00911         EXITFUNC;
00912         return result;
00913 
00914 }
00915 
00916 #endif // EMAN2_USING_CUDA
00917 
00918 
00919 float FRCCmp::cmp(EMData * image, EMData * with) const
00920 {
00921         ENTERFUNC;
00922         validate_input_args(image, with);
00923 
00924         int snrweight = params.set_default("snrweight", 0);
00925         int ampweight = params.set_default("ampweight", 0);
00926         int sweight = params.set_default("sweight", 1);
00927         int nweight = params.set_default("nweight", 0);
00928         int zeromask = params.set_default("zeromask",0);
00929 
00930         if (zeromask) {
00931                 image=image->copy();
00932                 with=with->copy();
00933                 
00934                 int sz=image->get_xsize()*image->get_ysize()*image->get_zsize();
00935                 float *d1=image->get_data();
00936                 float *d2=with->get_data();
00937                 
00938                 for (int i=0; i<sz; i++) {
00939                         if (d1[i]==0.0 || d2[i]==0.0) { d1[i]=0.0; d2[i]=0.0; }
00940                 }
00941                 
00942                 image->update();
00943                 with->update();
00944                 image->do_fft_inplace();
00945                 with->do_fft_inplace();
00946                 image->set_attr("free_me",1); 
00947                 with->set_attr("free_me",1); 
00948         }
00949 
00950 
00951         if (!image->is_complex()) {
00952                 image=image->do_fft(); 
00953                 image->set_attr("free_me",1); 
00954         }
00955         if (!with->is_complex()) { 
00956                 with=with->do_fft(); 
00957                 with->set_attr("free_me",1); 
00958         }
00959 
00960         static vector < float >default_snr;
00961 
00962 //      if (image->get_zsize() > 1) {
00963 //              throw ImageDimensionException("2D only");
00964 //      }
00965 
00966 //      int nx = image->get_xsize();
00967         int ny = image->get_ysize();
00968         int ny2=ny/2+1;
00969 
00970         vector < float >fsc;
00971 
00972                 
00973 
00974         fsc = image->calc_fourier_shell_correlation(with,1);
00975 
00976         // The fast hypot here was supposed to speed things up. Little effect
00977 //      if (image->get_zsize()>1) fsc = image->calc_fourier_shell_correlation(with,1);
00978 //      else {
00979 //              double *sxy = (double *)malloc(ny2*sizeof(double)*4);
00980 //              double *sxx = sxy+ny2;
00981 //              double *syy = sxy+2*ny2;
00982 //              double *norm= sxy+3*ny2;
00983 //
00984 //              float *df1=image->get_data();
00985 //              float *df2=with->get_data();
00986 //              int nx2=image->get_xsize();
00987 //
00988 //              for (int y=-ny/2; y<ny/2; y++) {
00989 //                      for (int x=0; x<nx2/2; x++) {
00990 //                              if (x==0 && y<0) continue;      // skip Friedel pair
00991 //                              short r=Util::hypot_fast_int(x,y);
00992 //                              if (r>ny2-1) continue;
00993 //                              int l=x*2+(y<0?ny+y:y)*nx2;
00994 //                              sxy[r]+=df1[l]*df2[l]+df1[l+1]*df2[l+1];
00995 //                              sxx[r]+=df1[l]*df1[l];
00996 //                              syy[r]+=df2[l]*df2[l];
00997 //                              norm[r]+=1.0;
00998 //                      }
00999 //              }
01000 //              fsc.resize(ny2*3);
01001 //              for (int r=0; r<ny2; r++) {
01002 //                      fsc[r]=r*0.5/ny2;
01003 //                      fsc[ny2+r]=sxy[r]/(sqrt(sxx[r])*sqrt(syy[r]));
01004 //                      fsc[ny2*2+r]=norm[r];
01005 //              }
01006 //              free(sxy);
01007 //      }
01008 
01009         vector<float> snr;
01010         if (snrweight) {
01011                 Ctf *ctf = NULL;
01012                 if (!image->has_attr("ctf")) {
01013                         if (!with->has_attr("ctf")) throw InvalidCallException("SNR weight with no CTF parameters");
01014                         ctf=with->get_attr("ctf");
01015                 }
01016                 else ctf=image->get_attr("ctf");
01017 
01018                 float ds=1.0f/(ctf->apix*ny);
01019                 snr=ctf->compute_1d(ny,ds,Ctf::CTF_SNR);
01020                 if(ctf) {delete ctf; ctf=0;}
01021         }
01022 
01023         vector<float> amp;
01024         if (ampweight) amp=image->calc_radial_dist(ny/2,0,1,0);
01025 
01026         double sum=0.0, norm=0.0;
01027 
01028         for (int i=0; i<ny/2; i++) {
01029                 double weight=1.0;
01030                 if (sweight) weight*=fsc[(ny2)*2+i];
01031                 if (ampweight) weight*=amp[i];
01032                 if (snrweight) weight*=snr[i];
01033                 sum+=weight*fsc[ny2+i];
01034                 norm+=weight;
01035 //              printf("%d\t%f\t%f\n",i,weight,fsc[ny/2+1+i]);
01036         }
01037 
01038         // This performs a weighting that tries to normalize FRC by correcting from the number of particles represented by the average
01039         sum/=norm;
01040         if (nweight && with->get_attr_default("ptcl_repr",0) && sum>=0 && sum<1.0) {
01041                 sum=sum/(1.0-sum);                                                      // convert to SNR
01042                 sum/=(float)with->get_attr_default("ptcl_repr",0);      // divide by ptcl represented
01043                 sum=sum/(1.0+sum);                                                      // convert back to correlation
01044         }
01045 
01046         if (image->has_attr("free_me")) delete image;
01047         if (with->has_attr("free_me")) delete with;
01048 
01049         EXITFUNC;
01050 
01051 
01052         //.Note the negative! This is because EMAN2 follows the convention that
01053         // smaller return values from comparitors indicate higher similarity -
01054         // this enables comparitors to be used in a generic fashion.
01055         return (float)-sum;
01056 }
01057 
01058 void EMAN::dump_cmps()
01059 {
01060         dump_factory < Cmp > ();
01061 }
01062 
01063 map<string, vector<string> > EMAN::dump_cmps_list()
01064 {
01065         return dump_factory_list < Cmp > ();
01066 }
01067 
01068 /* vim: set ts=4 noet: */

Generated on Sat Nov 21 02:19:14 2009 for EMAN2 by  doxygen 1.5.6