Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 29 additions & 40 deletions rts/System/Matrix44f.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,12 +326,9 @@ CMatrix44f& CMatrix44f::Translate(const float x, const float y, const float z)
return *this;
}



__FORCE_ALIGN_STACK__
static inline void MatrixMatrixMultiplySSE(const CMatrix44f& m1, const CMatrix44f& m2, CMatrix44f* mout)
{
//alignof guarantees 16 byte alignment required by SSE2
const __m128 m1c1 = _mm_load_ps(&m1.md[0][0]);
const __m128 m1c2 = _mm_load_ps(&m1.md[1][0]);
const __m128 m1c3 = _mm_load_ps(&m1.md[2][0]);
Expand All @@ -342,43 +339,35 @@ static inline void MatrixMatrixMultiplySSE(const CMatrix44f& m1, const CMatrix44
assert(m2.m[7] == 0.0f);
// assert(m2.m[11] == 0.0f); in case of a gluPerspective it's -1

const __m128 m2i0 = _mm_load1_ps(&m2.m[0]);
const __m128 m2i1 = _mm_load1_ps(&m2.m[1]);
const __m128 m2i2 = _mm_load1_ps(&m2.m[2]);
//const __m128 m2i3 = _mm_load1_ps(&m2.m[3]);
const __m128 m2i4 = _mm_load1_ps(&m2.m[4]);
const __m128 m2i5 = _mm_load1_ps(&m2.m[5]);
const __m128 m2i6 = _mm_load1_ps(&m2.m[6]);
//const __m128 m2i7 = _mm_load1_ps(&m2.m[7]);
const __m128 m2i8 = _mm_load1_ps(&m2.m[8]);
const __m128 m2i9 = _mm_load1_ps(&m2.m[9]);
const __m128 m2i10 = _mm_load1_ps(&m2.m[10]);
const __m128 m2i11 = _mm_load1_ps(&m2.m[11]);
const __m128 m2i12 = _mm_load1_ps(&m2.m[12]);
const __m128 m2i13 = _mm_load1_ps(&m2.m[13]);
const __m128 m2i14 = _mm_load1_ps(&m2.m[14]);
const __m128 m2i15 = _mm_load1_ps(&m2.m[15]);

__m128 moutc1, moutc2, moutc3, moutc4;
moutc1 = _mm_mul_ps(m1c1, m2i0);
moutc2 = _mm_mul_ps(m1c1, m2i4);
moutc3 = _mm_mul_ps(m1c1, m2i8);
moutc4 = _mm_mul_ps(m1c1, m2i12);

moutc1 = _mm_add_ps(moutc1, _mm_mul_ps(m1c2, m2i1));
moutc2 = _mm_add_ps(moutc2, _mm_mul_ps(m1c2, m2i5));
moutc3 = _mm_add_ps(moutc3, _mm_mul_ps(m1c2, m2i9));
moutc4 = _mm_add_ps(moutc4, _mm_mul_ps(m1c2, m2i13));

moutc1 = _mm_add_ps(moutc1, _mm_mul_ps(m1c3, m2i2));
moutc2 = _mm_add_ps(moutc2, _mm_mul_ps(m1c3, m2i6));
moutc3 = _mm_add_ps(moutc3, _mm_mul_ps(m1c3, m2i10));
moutc4 = _mm_add_ps(moutc4, _mm_mul_ps(m1c3, m2i14));

//moutc1 = _mm_add_ps(moutc1, _mm_mul_ps(m1c4, _mm_load1_ps(&m2.m[3])));
//moutc2 = _mm_add_ps(moutc2, _mm_mul_ps(m1c4, _mm_load1_ps(&m2.m[7])));
moutc3 = _mm_add_ps(moutc3, _mm_mul_ps(m1c4, m2i11));
moutc4 = _mm_add_ps(moutc4, _mm_mul_ps(m1c4, m2i15));
// Load each column of m2 as a full vector, then use _mm_shuffle_ps to broadcast
// each element — avoids the 14 separate scalar broadcast loads done with _mm_load1_ps
const __m128 m2c0 = _mm_load_ps(&m2.m[0]); // [m00, m10, m20, m30=0]
const __m128 m2c1 = _mm_load_ps(&m2.m[4]); // [m01, m11, m21, m31=0]
const __m128 m2c2 = _mm_load_ps(&m2.m[8]); // [m02, m12, m22, m32]
const __m128 m2c3 = _mm_load_ps(&m2.m[12]); // [m03, m13, m23, m33]

#define SPLAT(v, i) _mm_shuffle_ps(v, v, _MM_SHUFFLE(i,i,i,i))

__m128 moutc1 = _mm_mul_ps(m1c1, SPLAT(m2c0, 0));
__m128 moutc2 = _mm_mul_ps(m1c1, SPLAT(m2c1, 0));
__m128 moutc3 = _mm_mul_ps(m1c1, SPLAT(m2c2, 0));
__m128 moutc4 = _mm_mul_ps(m1c1, SPLAT(m2c3, 0));

moutc1 = _mm_add_ps(moutc1, _mm_mul_ps(m1c2, SPLAT(m2c0, 1)));
moutc2 = _mm_add_ps(moutc2, _mm_mul_ps(m1c2, SPLAT(m2c1, 1)));
moutc3 = _mm_add_ps(moutc3, _mm_mul_ps(m1c2, SPLAT(m2c2, 1)));
moutc4 = _mm_add_ps(moutc4, _mm_mul_ps(m1c2, SPLAT(m2c3, 1)));

moutc1 = _mm_add_ps(moutc1, _mm_mul_ps(m1c3, SPLAT(m2c0, 2)));
moutc2 = _mm_add_ps(moutc2, _mm_mul_ps(m1c3, SPLAT(m2c1, 2)));
moutc3 = _mm_add_ps(moutc3, _mm_mul_ps(m1c3, SPLAT(m2c2, 2)));
moutc4 = _mm_add_ps(moutc4, _mm_mul_ps(m1c3, SPLAT(m2c3, 2)));

// m2.m[3] and m2.m[7] are zero — skip those terms
moutc3 = _mm_add_ps(moutc3, _mm_mul_ps(m1c4, SPLAT(m2c2, 3)));
moutc4 = _mm_add_ps(moutc4, _mm_mul_ps(m1c4, SPLAT(m2c3, 3)));

#undef SPLAT

_mm_store_ps(&mout->md[0][0], moutc1);
_mm_store_ps(&mout->md[1][0], moutc2);
Expand Down
86 changes: 84 additions & 2 deletions test/engine/System/testMatrix44f.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
/* This file is part of the Spring engine (GPL v2 or later), see LICENSE.html */

#include <random>

#include "System/simd_compat.h"
#include "System/Matrix44f.h"
#include "System/float4.h"
Expand Down Expand Up @@ -99,7 +101,7 @@ static const int testRuns = 40000000;
}


_noinline static void MatrixMatrixMultiply(CMatrix44f* m1, const CMatrix44f& m2)
_noinline static void MatrixMatrixMultiplySSEOld(CMatrix44f* m1, const CMatrix44f& m2)
{
assert(long(&m1->m[0]) % 16 == 0); // 16byte aligned

Expand Down Expand Up @@ -206,11 +208,12 @@ _noinline static int TestMMSSE()
ScopedOnceTimer timer("Matrix-Matrix-Mult: sse");
CMatrix44f m1(m_);
for (int i = 0; i < testRuns; ++i) {
MatrixMatrixMultiply(&m1, m);
MatrixMatrixMultiplySSEOld(&m1, m);
}
return spring::LiteHash(&m1, sizeof(CMatrix44f), 0);
}


_noinline static int TestSpring()
{
ScopedOnceTimer timer("Matrix-Vector-Mult: spring");
Expand Down Expand Up @@ -287,3 +290,82 @@ TEST_CASE("Matrix44MatrixMultiply")
}
}
}

// Timing run: applies MatrixMatrixMultiplySSEOld testRuns times to a copy of m_.
// Contains no CHECK; it only reports the elapsed time via ScopedOnceTimer.
TEST_CASE("Matrix44MatrixMultiplySSE")
{
	// Both shared matrices receive the same ramp (i + 1) / 31.3125;
	// elements 3 and 7 are forced to zero.
	for (int idx = 0; idx < 16; ++idx) {
		const bool zeroSlot = (idx == 3) || (idx == 7);
		const float value = zeroSlot ? 0.0f : float(idx + 1) / 31.3125f;

		m[idx] = value;
		m_[idx] = value;
	}

	ScopedOnceTimer timer("Matrix-Matrix-Mult: sse");
	CMatrix44f accum(m_);

	for (int run = 0; run < testRuns; ++run)
		MatrixMatrixMultiplySSEOld(&accum, m);

	// Hash the accumulated matrix; the returned value is not used here.
	spring::LiteHash(&accum, sizeof(CMatrix44f), 0);

	// NOTE(review): no matching PushTickRate is visible in this test case — confirm it is pushed elsewhere.
	spring_clock::PopTickRate();
}

// Cross-check: for numTests pseudo-random matrix pairs, the result of operator*
// must compare equal to the result of the in-place MatrixMatrixMultiplySSEOld.
TEST_CASE("Matrix44MatrixMultiplySSEOldVsSSENew")
{
	const int numTests = 100000;

	// Fixed seed, so every run draws the same sequence of inputs.
	std::mt19937 rng(12345);
	std::uniform_real_distribution<float> dist(-10.0f, 10.0f);

	bool allMatch = true;

	for (int t = 0; t < numTests; ++t) {
		// Left operand: all 16 elements are random.
		for (int k = 0; k < 16; ++k)
			m[k] = dist(rng);

		// Right operand: random, except elements 3 and 7 which are pinned to zero
		// (no random value is drawn for those two slots).
		for (int k = 0; k < 16; ++k)
			m_[k] = (k == 3 || k == 7) ? 0.0f : dist(rng);

		assert(m_[3] == 0.0f);
		assert(m_[7] == 0.0f);

		// Take the operator* product first, because the old routine overwrites m.
		CMatrix44f resultNew = m * m_;
		MatrixMatrixMultiplySSEOld(&m, m_);

		if (m == resultNew)
			continue;

		// Stop at the first mismatch.
		allMatch = false;
		break;
	}

	CHECK(allMatch == true);
}

// Timing run: repeatedly multiplies a copy of m_ by m_ through operator*, testRuns times.
// Contains no CHECK; it only reports the elapsed time via ScopedOnceTimer.
TEST_CASE("Matrix44MatrixMultiplySSE_Opt")
{
	// Both shared matrices receive the same ramp (i + 1) / 31.3125;
	// elements 3 and 7 are forced to zero.
	for (int idx = 0; idx < 16; ++idx) {
		const bool zeroSlot = (idx == 3) || (idx == 7);
		const float value = zeroSlot ? 0.0f : float(idx + 1) / 31.3125f;

		m[idx] = value;
		m_[idx] = value;
	}

	ScopedOnceTimer timer("Matrix-Matrix-Mult: sse_new");
	CMatrix44f accum(m_);

	for (int run = 0; run < testRuns; ++run)
		accum = accum * m_;

	// Hash the accumulated matrix; the returned value is not used here.
	spring::LiteHash(&accum, sizeof(CMatrix44f), 0);

	// NOTE(review): no matching PushTickRate is visible in this test case — confirm it is pushed elsewhere.
	spring_clock::PopTickRate();
}
Loading