C++ Program to Implement Strassen’s Algorithm

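This is a C++ Program to implement Strassen’s algorithm for matrix multiplication. In the mathematical discipline of linear algebra, the Strassen algorithm, named after Volker Strassen, is an algorithm used for matrix multiplication. It splits each matrix into 2×2 blocks and forms the product with 7 block multiplications instead of 8. Applied recursively, this gives a running time of $O(n^{\log_2 7}) \approx O(n^{2.807})$, which is asymptotically faster than the standard $O(n^3)$ algorithm. It is useful in practice for large matrices, but it would be slower than the asymptotically fastest known algorithms for extremely large matrices. The implementation below multiplies square matrices whose side length is a power of two, $N = 2^M$.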

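Here is the source code of the C++ program implementing Strassen’s algorithm. Although it is presented as C++, the code is written in plain C (it uses only the C standard headers assert.h, stdio.h, stdlib.h and time.h), so it builds with either a C or a C++ compiler. It was successfully compiled and run on a Linux system; the program output is shown below.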

  1. #include <assert.h>
  2. #include <stdio.h>
  3. #include <stdlib.h>
  4. #include <time.h>
  5.  
  // The matrix side length is N = 2^M. Strassen's recursion halves the side at
  // every level, so it has to be a power of two.
  #define M 2
  #define N (1<<M)

  typedef double datatype;            // element type of every matrix
  #define DATATYPE_FORMAT "%4.2g"     // printf conversion used for one element
  typedef datatype mat[N][N];         // storage for one full N x N matrix

  // A rectangular region of a matrix, written A[a] in the comments below.
  // The bounds are half-open: rows ra..rb-1 and columns ca..cb-1.
  // Sub-matrices are described by a corners value instead of being copied out.
  typedef struct
  {
      int ra; // first row (inclusive)
      int rb; // end row (exclusive)
      int ca; // first column (inclusive)
      int cb; // end column (exclusive)
  } corners;
  17.  
  18. // set A[a] = I
  19. void identity(mat A, corners a)
  20. {
  21.     int i, j;
  22.     for (i = a.ra; i < a.rb; i++)
  23.         for (j = a.ca; j < a.cb; j++)
  24.             A[i][j] = (datatype) (i == j);
  25. }
  26.  
  27. // set A[a] = k
  28. void set(mat A, corners a, datatype k)
  29. {
  30.     int i, j;
  31.     for (i = a.ra; i < a.rb; i++)
  32.         for (j = a.ca; j < a.cb; j++)
  33.             A[i][j] = k;
  34. }
  35.  
  36. // set A[a] = [random(l..h)].
  37. void randk(mat A, corners a, double l, double h)
  38. {
  39.     int i, j;
  40.     for (i = a.ra; i < a.rb; i++)
  41.         for (j = a.ca; j < a.cb; j++)
  42.             A[i][j] = (datatype) (l + (h - l) * (rand() / (double) RAND_MAX));
  43. }
  44.  
  45. // Print A[a]
  46. void print(mat A, corners a, char *name)
  47. {
  48.     int i, j;
  49.     printf("%s = {\n", name);
  50.     for (i = a.ra; i < a.rb; i++)
  51.     {
  52.         for (j = a.ca; j < a.cb; j++)
  53.             printf(DATATYPE_FORMAT ", ", A[i][j]);
  54.         printf("\n");
  55.     }
  56.     printf("}\n");
  57. }
  58.  
  59. // C[c] = A[a] + B[b]
  60. void add(mat A, mat B, mat C, corners a, corners b, corners c)
  61. {
  62.     int rd = a.rb - a.ra;
  63.     int cd = a.cb - a.ca;
  64.     int i, j;
  65.     for (i = 0; i < rd; i++)
  66.     {
  67.         for (j = 0; j < cd; j++)
  68.         {
  69.             C[i + c.ra][j + c.ca] = A[i + a.ra][j + a.ca] + B[i + b.ra][j
  70.                     + b.ca];
  71.         }
  72.     }
  73. }
  74.  
  75. // C[c] = A[a] - B[b]
  76. void sub(mat A, mat B, mat C, corners a, corners b, corners c)
  77. {
  78.     int rd = a.rb - a.ra;
  79.     int cd = a.cb - a.ca;
  80.     int i, j;
  81.     for (i = 0; i < rd; i++)
  82.     {
  83.         for (j = 0; j < cd; j++)
  84.         {
  85.             C[i + c.ra][j + c.ca] = A[i + a.ra][j + a.ca] - B[i + b.ra][j
  86.                     + b.ca];
  87.         }
  88.     }
  89. }
  90.  
  91. // Return 1/4 of the matrix: top/bottom , left/right.
  92. void find_corner(corners a, int i, int j, corners *b)
  93. {
  94.     int rm = a.ra + (a.rb - a.ra) / 2;
  95.     int cm = a.ca + (a.cb - a.ca) / 2;
  96.     *b = a;
  97.     if (i == 0)
  98.         b->rb = rm; // top rows
  99.     else
  100.         b->ra = rm; // bot rows
  101.     if (j == 0)
  102.         b->cb = cm; // left cols
  103.     else
  104.         b->ca = cm; // right cols
  105. }
  106.  
  107. // Multiply: A[a] * B[b] => C[c], recursively.
  108. void mul(mat A, mat B, mat C, corners a, corners b, corners c)
  109. {
  110.     corners aii[2][2], bii[2][2], cii[2][2], p;
  111.     mat P[7], S, T;
  112.     int i, j, m, n, k;
  113.  
  114.     // Check: A[m n] * B[n k] = C[m k]
  115.     m = a.rb - a.ra;
  116.     assert(m==(c.rb-c.ra));
  117.     n = a.cb - a.ca;
  118.     assert(n==(b.rb-b.ra));
  119.     k = b.cb - b.ca;
  120.     assert(k==(c.cb-c.ca));
  121.     assert(m>0);
  122.  
  123.     if (n == 1)
  124.     {
  125.         C[c.ra][c.ca] += A[a.ra][a.ca] * B[b.ra][b.ca];
  126.         return;
  127.     }
  128.  
  129.     // Create the 12 smaller matrix indexes:
  130.     //  A00 A01   B00 B01   C00 C01
  131.     //  A10 A11   B10 B11   C10 C11
  132.     for (i = 0; i < 2; i++)
  133.     {
  134.         for (j = 0; j < 2; j++)
  135.         {
  136.             find_corner(a, i, j, &aii[i][j]);
  137.             find_corner(b, i, j, &bii[i][j]);
  138.             find_corner(c, i, j, &cii[i][j]);
  139.         }
  140.     }
  141.  
  142.     p.ra = p.ca = 0;
  143.     p.rb = p.cb = m / 2;
  144.  
  145. #define LEN(A) (sizeof(A)/sizeof(A[0]))
  146.     for (i = 0; i < LEN(P); i++)
  147.         set(P[i], p, 0);
  148.  
  149. #define ST0 set(S,p,0); set(T,p,0)
  150.  
  151.     // (A00 + A11) * (B00+B11) = S * T = P0
  152.     ST0;
  153.     add(A, A, S, aii[0][0], aii[1][1], p);
  154.     add(B, B, T, bii[0][0], bii[1][1], p);
  155.     mul(S, T, P[0], p, p, p);
  156.  
  157.     // (A10 + A11) * B00 = S * B00 = P1
  158.     ST0;
  159.     add(A, A, S, aii[1][0], aii[1][1], p);
  160.     mul(S, B, P[1], p, bii[0][0], p);
  161.  
  162.     // A00 * (B01 - B11) = A00 * T = P2
  163.     ST0;
  164.     sub(B, B, T, bii[0][1], bii[1][1], p);
  165.     mul(A, T, P[2], aii[0][0], p, p);
  166.  
  167.     // A11 * (B10 - B00) = A11 * T = P3
  168.     ST0;
  169.     sub(B, B, T, bii[1][0], bii[0][0], p);
  170.     mul(A, T, P[3], aii[1][1], p, p);
  171.  
  172.     // (A00 + A01) * B11 = S * B11 = P4
  173.     ST0;
  174.     add(A, A, S, aii[0][0], aii[0][1], p);
  175.     mul(S, B, P[4], p, bii[1][1], p);
  176.  
  177.     // (A10 - A00) * (B00 + B01) = S * T = P5
  178.     ST0;
  179.     sub(A, A, S, aii[1][0], aii[0][0], p);
  180.     add(B, B, T, bii[0][0], bii[0][1], p);
  181.     mul(S, T, P[5], p, p, p);
  182.  
  183.     // (A01 - A11) * (B10 + B11) = S * T = P6
  184.     ST0;
  185.     sub(A, A, S, aii[0][1], aii[1][1], p);
  186.     add(B, B, T, bii[1][0], bii[1][1], p);
  187.     mul(S, T, P[6], p, p, p);
  188.  
  189.     // P0 + P3 - P4 + P6 = S - P4 + P6 = T + P6 = C00
  190.     add(P[0], P[3], S, p, p, p);
  191.     sub(S, P[4], T, p, p, p);
  192.     add(T, P[6], C, p, p, cii[0][0]);
  193.  
  194.     // P2 + P4 = C01
  195.     add(P[2], P[4], C, p, p, cii[0][1]);
  196.  
  197.     // P1 + P3 = C10
  198.     add(P[1], P[3], C, p, p, cii[1][0]);
  199.  
  200.     // P0 + P2 - P1 + P5 = S - P1 + P5 = T + P5 = C11
  201.     add(P[0], P[2], S, p, p, p);
  202.     sub(S, P[1], T, p, p, p);
  203.     add(T, P[5], C, p, p, cii[1][1]);
  204.  
  205. }
  206. int main()
  207. {
  208.     mat A, B, C;
  209.     corners ai = { 0, N, 0, N };
  210.     corners bi = { 0, N, 0, N };
  211.     corners ci = { 0, N, 0, N };
  212.     srand(time(0));
  213.     // identity(A,bi); identity(B,bi);
  214.     // set(A,ai,2); set(B,bi,2);
  215.     randk(A, ai, 0, 2);
  216.     randk(B, bi, 0, 2);
  217.     print(A, ai, "A");
  218.     print(B, bi, "B");
  219.     set(C, ci, 0);
  220.     // add(A,B,C, ai, bi, ci);
  221.     mul(A, B, C, ai, bi, ci);
  222.     print(C, ci, "C");
  223.     return 0;
  224. }

Output:

$ g++ StrassenMultiplication.cpp
$ ./a.out
 
A = {
 1.2, 0.83, 0.39, 0.41, 
 1.8,  1.9, 0.49, 0.23, 
0.38, 0.72,  1.8,  1.9, 
0.13,  1.8, 0.48, 0.82, 
}
B = {
 1.2,  1.6,  1.4,  1.6, 
0.27, 0.63,  0.3, 0.79, 
0.58,  1.2,  1.1, 0.07, 
   2,  1.9, 0.47, 0.47, 
}
C = {
 2.7,  3.7,  2.6,  2.9, 
 3.4,    5,  3.7,  4.5, 
 5.3,  6.7,  3.6,  2.2, 
 2.5,  3.5,  1.6,  2.1, 
}

Sanfoundry Global Education & Learning Series – 1000 C++ Programs.

advertisement

Here’s the list of Best Books in C++ Programming, Data Structures and Algorithms.

advertisement
Subscribe to our Newsletters (Subject-wise). Participate in the Sanfoundry Certification to get free Certificate of Merit. Join our social networks below and stay updated with latest contests, videos, internships and jobs!

Youtube | Telegram | LinkedIn | Instagram | Facebook | Twitter | Pinterest
Manish Bhojasia - Founder & CTO at Sanfoundry
I’m Manish - Founder and CTO at Sanfoundry. I’ve been working in tech for over 25 years, with deep focus on Linux kernel, SAN technologies, Advanced C, Full Stack and Scalable website designs.

You can connect with me on LinkedIn, watch my Youtube Masterclasses, or join my Telegram tech discussions.

If you’re in your 20s–40s and exploring new directions in your career, I also offer mentoring. Learn more here.