Skip to content

Matrix multiplication optimization #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions bench.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/// Author: asakhar
/// Description:
/// I basically changed the order of iteration
/// in matrix multiplication to make it more
/// cache friendly. Here are the benchmarks and test.
///
/// GCC with -O3 gave me 7x improvement (~0.021s -> ~0.003s)
/// MSVC with /O2 has a little bit less difference but never the less:
/// 2x improved (~0.021s -> ~0.01s)

#define NN_IMPLEMENTATION
#include "nn.h"
#include <time.h>
#include <stdio.h>

#define WARM_UP_TIME 3
#define ITERS 500

void mat_dot_old(Mat dst, Mat a, Mat b);

typedef void (*DotFunc)(Mat, Mat, Mat);
void bench(Mat dst, Mat a, Mat b, DotFunc func, char const *name);
void test_against(Mat a, Mat b, DotFunc reference, DotFunc to_test);

int main(void)
{
// setup
size_t R = 300;
size_t K = 200;
size_t C = 400;
Mat a = mat_alloc(R, K);
Mat b = mat_alloc(K, C);
Mat dst = mat_alloc(R, C);
mat_rand(a, 0, 1);
mat_rand(b, 0, 1);

// actual benches
bench(dst, a, b, mat_dot_old, "old");
bench(dst, a, b, mat_dot, "new");

// testing
test_against(a, b, mat_dot_old, mat_dot);
}

void bench(Mat dst, Mat a, Mat b, DotFunc func, char const *name)
{
double start = (double)clock() / CLOCKS_PER_SEC;
double end = start;
printf("Warming up for %d seconds...\n", WARM_UP_TIME);
while (end-start < WARM_UP_TIME)
{
func(dst, a, b);
end = (double)clock() / CLOCKS_PER_SEC;
}
printf("Running bench %s...\n", name);
double total_time = 0;
for (size_t i = 0; i < ITERS; ++i)
{
start = (double)clock() / CLOCKS_PER_SEC;
func(dst, a, b);
end = (double)clock() / CLOCKS_PER_SEC;
total_time += end-start;
}
printf("%s solution took: %fs to process in average among %d iterations\n", name, total_time/(double)ITERS, ITERS);
}

void mat_dot_old(Mat dst, Mat a, Mat b)
{
NN_ASSERT(a.cols == b.rows);
size_t n = a.cols;
NN_ASSERT(dst.rows == a.rows);
NN_ASSERT(dst.cols == b.cols);

for (size_t i = 0; i < dst.rows; ++i) {
for (size_t j = 0; j < dst.cols; ++j) {
MAT_AT(dst, i, j) = 0;
for (size_t k = 0; k < n; ++k) {
MAT_AT(dst, i, j) += MAT_AT(a, i, k) * MAT_AT(b, k, j);
}
}
}
}

void test_against(Mat a, Mat b, DotFunc reference, DotFunc to_test) {
Mat reference_res = mat_alloc(a.rows, b.cols);
Mat test_res = mat_alloc(a.rows, b.cols);
reference(reference_res, a, b);
to_test(test_res, a, b);
size_t total = reference_res.rows * reference_res.cols;
for(size_t i = 0; i < total; ++i) {
if(reference_res.es[i] != test_res.es[i]) {
fputs("Matrices did not match", stderr);
return;
}
}
puts("Matrices are equal");
}
8 changes: 5 additions & 3 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ set -xe
CFLAGS="-O3 -Wall -Wextra -I./thirdparty/ `pkg-config --cflags raylib`"
LIBS="-lm `pkg-config --libs raylib` -lglfw -ldl -lpthread"

clang $CFLAGS -o adder adder.c $LIBS
clang $CFLAGS -o xor xor.c $LIBS
clang $CFLAGS -o img2nn img2nn.c $LIBS
clang $CFLAGS -o adder_gen adder_gen.c $LIBS
clang $CFLAGS `pkg-config --cflags raylib` -o xor xor.c $LIBS `pkg-config --libs raylib` -lglfw -ldl -lpthread
clang $CFLAGS `pkg-config --cflags raylib` -o gym gym.c $LIBS `pkg-config --libs raylib` -lglfw -ldl -lpthread
clang $CFLAGS `pkg-config --cflags raylib` -o img2nn img2nn.c $LIBS `pkg-config --libs raylib` -lglfw -ldl -lpthread
clang $CFLAGS -o bench bench.c $LIBS
6 changes: 3 additions & 3 deletions nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,11 @@ void mat_dot(Mat dst, Mat a, Mat b)
size_t n = a.cols;
NN_ASSERT(dst.rows == a.rows);
NN_ASSERT(dst.cols == b.cols);
mat_fill(dst, 0);

for (size_t i = 0; i < dst.rows; ++i) {
for (size_t j = 0; j < dst.cols; ++j) {
MAT_AT(dst, i, j) = 0;
for (size_t k = 0; k < n; ++k) {
for (size_t k = 0; k < n; ++k) {
for (size_t j = 0; j < dst.cols; ++j) {
MAT_AT(dst, i, j) += MAT_AT(a, i, k) * MAT_AT(b, k, j);
}
}
Expand Down