fft.c
上传用户:jsljixie
上传日期:2013-08-15
资源大小:827k
文件大小:10k
源码类别:

并行计算

开发平台:

Visual C++

  1. #include <stdio.h>
  2. #include <stdlib.h>
  3. #include <math.h>
  4. #include "mpi.h"
  5. #define MAX_N 50
  6. #define PI    3.1415926535897932
  7. #define EPS   10E-8
  8. #define V_TAG 99
  9. #define P_TAG 100
  10. #define Q_TAG 101
  11. #define R_TAG 102
  12. #define S_TAG 103
  13. #define S_TAG2 104
  14. typedef enum {FALSE, TRUE}
  15. BOOL;
  16. typedef struct
  17. {
  18.     double r;
  19.     double i;
  20. } complex_t;
  21. complex_t p[MAX_N], q[MAX_N], s[2*MAX_N], r[2*MAX_N];
  22. complex_t w[2*MAX_N];
  23. int variableNum;
  24. double transTime = 0, totalTime = 0, beginTime;
  25. MPI_Status status;
  26. void comp_add(complex_t* result, const complex_t* c1, const complex_t* c2)
  27. {
  28.     result->r = c1->r + c2->r;
  29.     result->i = c1->i + c2->i;
  30. }
  31. void comp_multiply(complex_t* result, const complex_t* c1, const complex_t* c2)
  32. {
  33.     result->r = c1->r * c2->r - c1->i * c2->i;
  34.     result->i = c1->r * c2->i + c2->r * c1->i;
  35. }
  36. /*
  37.  * Function:    shuffle
  38.  * Description: 移动f中从beginPos到endPos位置的元素,使之按位置奇偶
  39.  *              重新排列。举例说明:假设数组f,beginPos=2, endPos=5
  40.  *              则shuffle函数的运行结果为f[2..5]重新排列,排列后各个
  41.  *              位置对应的原f的元素为: f[2],f[4],f[3],f[5]
  42.  * Parameters:  f为被操作数组首地址
  43.  *              beginPos, endPos为操作的下标范围
  44.  */
  45. void shuffle(complex_t* f, int beginPos, int endPos)
  46. {
  47.     int i;
  48.     complex_t temp[2*MAX_N];
  49.     for(i = beginPos; i <= endPos; i ++)
  50.     {
  51.         temp[i] = f[i];
  52.     }
  53.     int j = beginPos;
  54.     for(i = beginPos; i <= endPos; i +=2)
  55.     {
  56.         f[j] = temp[i];
  57.         j++;
  58.     }
  59.     for(i = beginPos +1; i <= endPos; i += 2)
  60.     {
  61.         f[j] = temp[i];
  62.         j++;
  63.     }
  64. }
  65. /*
  66.  * Function: evaluate
  67.  * Description: 对复数序列f进行FFT或者IFFT(由x决定),结果序列为y,
  68.  *  产生leftPos 到 rightPos之间的结果元素
  69.  * Parameters: f : 原始序列数组首地址
  70.  *  beginPos : 原始序列在数组f中的第一个下标
  71.  *  endPos : 原始序列在数组f中的最后一个下标
  72.  *  x : 存放单位根的数组,其元素为w,w^2,w^3...
  73.  *  y : 输出序列
  74.  *  leftPos : 所负责计算输出的y的片断的起始下标
  75.  *  rightPos : 所负责计算输出的y的片断的终止下标
  76.  *  totalLength : y的长度
  77.  */
  78. void evaluate(complex_t* f, int beginPos, int endPos,
  79. const complex_t* x, complex_t* y,
  80. int leftPos, int rightPos, int totalLength)
  81. {
  82.     int i;
  83.     if ((beginPos > endPos)||(leftPos > rightPos))
  84.     {
  85.         printf("Error in use Polynomial!n");
  86.         exit(-1);
  87.     }
  88.     else if(beginPos == endPos)
  89.     {
  90.         for(i = leftPos; i <= rightPos; i ++)
  91.         {
  92.             y[i] = f[beginPos];
  93.         }
  94.     }
  95.     else if(beginPos + 1 == endPos)
  96.     {
  97.         for(i = leftPos; i <= rightPos; i ++)
  98.         {
  99.             complex_t temp;
  100.             comp_multiply(&temp, &f[endPos], &x[i]);
  101.             comp_add(&y[i], &f[beginPos], &temp);
  102.         }
  103.     }
  104.     else
  105.     {
  106.         complex_t tempX[2*MAX_N],tempY1[2*MAX_N], tempY2[2*MAX_N];
  107.         int midPos = (beginPos + endPos)/2;
  108.         shuffle(f, beginPos, endPos);
  109.         for(i = leftPos; i <= rightPos; i ++)
  110.         {
  111.             comp_multiply(&tempX[i], &x[i], &x[i]);
  112.         }
  113.         evaluate(f, beginPos, midPos, tempX, tempY1,
  114.             leftPos, rightPos, totalLength);
  115.         evaluate(f, midPos+1, endPos, tempX, tempY2,
  116.             leftPos, rightPos, totalLength);
  117.         for(i = leftPos; i <= rightPos; i ++)
  118.         {
  119.             complex_t temp;
  120.             comp_multiply(&temp, &x[i], &tempY2[i]);
  121.             comp_add(&y[i], &tempY1[i], &temp);
  122.         }
  123.     }
  124. }
  125. /*
  126.  * Function:    print
  127.  * Description: 打印数组元素的实部
  128.  * Parameters:  f为待打印数组的首地址
  129.  *              fLength为数组的长度
  130.  */
  131. void print(const complex_t* f, int fLength)
  132. {
  133.     BOOL isPrint = FALSE;
  134.     int i;
  135.     /* f[0] */
  136.     if (abs(f[0].r) > EPS)
  137.     {
  138.         printf("%f", f[0].r);
  139.         isPrint = TRUE;
  140.     }
  141.     for(i = 1; i < fLength; i ++)
  142.     {
  143.         if (f[i].r > EPS)
  144.         {
  145.             if (isPrint)
  146.                 printf(" + ");
  147.             else
  148.                 isPrint = TRUE;
  149.             printf("%ft^%d", f[i].r, i);
  150.         }
  151.         else if (f[i].r < - EPS)
  152.         {
  153.             if(isPrint)
  154.                 printf(" - ");
  155.             else
  156.                 isPrint = TRUE;
  157.             printf("%ft^%d", -f[i].r, i);
  158.         }
  159.     }
  160.     if (isPrint == FALSE)
  161.         printf("0");
  162.     printf("n");
  163. }
  164. /*
  165.  * Function:    myprint
  166.  * Description: 完整打印复数数组元素,包括实部和虚部
  167.  * Parameters:  f为待打印数组的首地址
  168.  *              fLength为数组的长度
  169.  */
  170. void myprint(const complex_t* f, int fLength)
  171. {
  172.     int i;
  173.     for(i=0;i<fLength;i++)
  174.     {
  175.         printf("%f+%fi , ", f[i].r, f[i].i);
  176.     }
  177.     printf("n");
  178. }
  179. /*
  180.  * Function:   addTransTime
  181.  * Description:累计发送数据所耗费的时间
  182.  * Parameters: toAdd为累加的时间
  183.  */
  184. void addTransTime(double toAdd)
  185. {
  186.     transTime += toAdd;
  187. }
  188. /*
  189.  * Function:    readFromFile
  190.  * Description: 从dataIn.txt读取数据
  191.  */
  192. BOOL readFromFile()
  193. {
  194.     int i;
  195.     FILE* fin = fopen("dataIn.txt", "r");
  196.     if (fin == NULL)
  197.     {
  198.         printf("Cannot find input data filen"
  199.             "Please create a file "dataIn.txt"n"
  200.             "2n"
  201.             "1.0  2n"
  202.             "2.0  -1n"
  203.             );
  204.         return(FALSE);
  205.     }
  206.     fscanf(fin, "%dn", &variableNum);
  207.     if ((variableNum < 1)||(variableNum > MAX_N))
  208.     {
  209.         printf("variableNum out of range!n");
  210.         return(FALSE);
  211.     }
  212.     for(i = 0; i < variableNum; i ++)
  213.     {
  214.         fscanf(fin, "%lf", &p[i].r);
  215.         p[i].i = 0.0;
  216.     }
  217.     for(i = 0; i < variableNum; i ++)
  218.     {
  219.         fscanf(fin, "%lf", &q[i].r);
  220.         q[i].i = 0.0;
  221.     }
  222.     fclose(fin);
  223.     printf("Read from data file "dataIn.txt"n");
  224.     printf("p(t) = ");
  225.     print(p, variableNum);
  226.     printf("q(t) = ");
  227.     print(q, variableNum);
  228.     return(TRUE);
  229. }
  230. /*
  231.  * Function:    sendOrigData
  232.  * Description: 把原始数据发送给其它进程
  233.  * Parameters:  size为集群中进程的数目
  234.  */
  235. void sendOrigData(int size)
  236. {
  237.     int i;
  238.     for(i = 1; i < size; i ++)
  239.     {
  240.         MPI_Send(&variableNum, 1, MPI_INT, i, V_TAG, MPI_COMM_WORLD);
  241.         MPI_Send(p, variableNum * 2, MPI_DOUBLE, i, P_TAG, MPI_COMM_WORLD);
  242.         MPI_Send(q, variableNum * 2, MPI_DOUBLE, i, Q_TAG, MPI_COMM_WORLD);
  243.     }
  244. }
  245. /*
  246.  * Function:    recvOrigData
  247.  * Description: 接受原始数据
  248.  */
  249. void recvOrigData()
  250. {
  251.     MPI_Recv(&variableNum, 1, MPI_INT, 0, V_TAG, MPI_COMM_WORLD, &status);
  252.     MPI_Recv(p, variableNum * 2, MPI_DOUBLE, 0, P_TAG, MPI_COMM_WORLD, &status);
  253.     MPI_Recv(q, variableNum * 2, MPI_DOUBLE, 0, Q_TAG,MPI_COMM_WORLD, &status);
  254. }
  255. int main(int argc, char *argv[])
  256. {
  257.     int rank,size, i;
  258.     MPI_Init(&argc,&argv);
  259.     MPI_Comm_rank(MPI_COMM_WORLD,&rank);
  260.     MPI_Comm_size(MPI_COMM_WORLD,&size);
  261.     if(rank == 0)
  262.     {
  263.         /* 0# 进程从文件dataIn.txt读入多项式p,q的阶数和系数序列 */
  264.         if(!readFromFile())
  265.             exit(-1);
  266.         /* 进程数目太多,造成每个进程平均分配不到一个元素,异常退出 */
  267.         if(size > 2*variableNum)
  268.         {
  269.             printf("Too many Processors , reduce your -np valuen");
  270.             MPI_Abort(MPI_COMM_WORLD, 1);
  271.         }
  272.         beginTime = MPI_Wtime();
  273.         /* 0#进程把多项式的阶数、p、q发送给其它进程 */
  274.         sendOrigData(size);
  275.         /* 累计传输时间 */
  276.         addTransTime(MPI_Wtime() - beginTime);
  277.     }
  278.     else                                          /* 其它进程接收进程0发送来的数据,包括variableNum、数组p和q */
  279.     {
  280.         recvOrigData();
  281.     }
  282.     /* 初始化数组w,用于进行傅立叶变换 */
  283.     int wLength = 2*variableNum;
  284.     for(i = 0; i < wLength; i ++)
  285.     {
  286.         w[i].r = cos(i*2*PI/wLength);
  287.         w[i].i = sin(i*2*PI/wLength);
  288.     }
  289.     /* 划分各个进程的工作范围 startPos ~ stopPos */
  290.     int everageLength = wLength / size;
  291.     int moreLength = wLength % size;
  292.     int startPos = moreLength + rank * everageLength;
  293.     int stopPos  = startPos + everageLength - 1;
  294.     if(rank == 0)
  295.     {
  296.         startPos = 0;
  297.         stopPos  = moreLength+everageLength - 1;
  298.     }
  299.     /* 对p作FFT,输出序列为s,每个进程仅负责计算出序列中 */
  300.     /* 位置为startPos 到 stopPos的元素 */
  301.     evaluate(p, 0, variableNum - 1, w, s, startPos, stopPos, wLength);
  302.     /* 对q作FFT,输出序列为r,每个进程仅负责计算出序列中 */
  303.     /* 位置为startPos 到 stopPos的元素 */
  304.     evaluate(q, 0, variableNum - 1, w, r, startPos, stopPos, wLength);
  305.     /* s和r作点积,结果保存在s中,同样,每个进程只计算自己范围内的部分 */
  306.     for(i = startPos; i <= stopPos ; i ++)
  307.     {
  308.         complex_t temp;
  309.         comp_multiply(&temp, &s[i], &r[i]);
  310.         s[i] = temp;
  311.         s[i].r /= wLength * 1.0;
  312.         s[i].i /= wLength * 1.0;
  313.     }
  314.     /* 各个进程都把s中自己负责计算出来的部分发送给进程0,并从进程0接收汇总的s */
  315.     if (rank > 0)
  316.     {
  317.         MPI_Send(s + startPos, everageLength * 2, MPI_DOUBLE, 0, S_TAG, MPI_COMM_WORLD);
  318.         MPI_Recv(s, wLength * 2, MPI_DOUBLE, 0, S_TAG2, MPI_COMM_WORLD, &status);
  319.     }
  320.     else
  321.     {
  322.         /* 进程0接收s片断,向其余进程发送完整的s */
  323.         double tempTime = MPI_Wtime();
  324.         for(i = 1; i < size; i ++)
  325.         {
  326.             MPI_Recv(s + moreLength + i * everageLength, everageLength * 2,
  327.                 MPI_DOUBLE, i, S_TAG, MPI_COMM_WORLD,&status);
  328.         }
  329.         for(i = 1; i < size; i ++)
  330.         {
  331.             MPI_Send(s, wLength * 2,
  332.                 MPI_DOUBLE, i,
  333.                 S_TAG2, MPI_COMM_WORLD);
  334.         }
  335.         addTransTime(MPI_Wtime() - tempTime);
  336.     }
  337.     /* swap(w[i],w[(wLength-i)%wLength]) */
  338.     /* 重新设置w,用于作逆傅立叶变换 */
  339.     complex_t temp;
  340.     for(i = 1; i < wLength/2; i ++)
  341.     {
  342.         temp = w[i];
  343.         w[i] = w[wLength - i];
  344.         w[wLength - i] = temp;
  345.     }
  346.     /* 各个进程对s作逆FFT,输出到r的相应部分 */
  347.     evaluate(s, 0, wLength - 1, w, r, startPos, stopPos, wLength);
  348.     /* 各进程把自己负责的部分的r的片断发送到进程0 */
  349.     if (rank > 0)
  350.     {
  351.         MPI_Send(r + startPos, everageLength * 2, MPI_DOUBLE,
  352.             0,R_TAG, MPI_COMM_WORLD);
  353.     }
  354.     else
  355.     {
  356.         /* 进程0接收各个片断得到完整的r,此时r就是两多项式p,q相乘的结果多项式了 */
  357.         double tempTime = MPI_Wtime();
  358.         for(i = 1; i < size; i ++)
  359.         {
  360.             MPI_Recv((r+moreLength+i*everageLength), everageLength * 2,
  361.                 MPI_DOUBLE, i, R_TAG, MPI_COMM_WORLD, &status);
  362.         }
  363.         totalTime = MPI_Wtime();
  364.         addTransTime(totalTime - tempTime);
  365.         totalTime -= beginTime;
  366.         /* 输出结果信息以及时间统计信息 */
  367.         printf("nAfter FFT r(t)=p(t)q(t)n");
  368.         printf("r(t) = ");
  369.         print(r, wLength - 1);
  370.         printf("nUse prossor size = %dn", size);
  371.         printf("Total running time = %f(s)n", totalTime);
  372.         printf("Distribute data time = %f(s)n", transTime);
  373.         printf("Parallel compute time = %f(s)n", totalTime - transTime);
  374.     }
  375.     MPI_Finalize();
  376. }