Modern Arm Assembly Language Programming: Covers Armv8-A 32-bit, 64-bit, and SIMD

by Daniel Kusswurm
2021.07.28: updated by
Up

Chapter 16: Armv8-64 Advanced SIMD Programming

Vector and Matrix Operations

Matrix-Vector Multiplication

SIMD 命令を使用して、4x4 行列と4次元ベクトルの乗算を行うプログラム

ch16_03

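Neon の SIMD 命令を使って実装した関数 MatVecMulF32_(Vec4x1F32* b, float m[4][4], Vec4x1F32* a, int n) の第2引数である行列 m は、C++ 版 MatVecMulF32Cpp と同じく通常の行優先 (row-major) 形式のまま渡す。関数内部の ld4 命令が読み込み時に要素をデインターリーブ（実質的に転置）し、v0〜v3 レジスタに行列 M の第0〜3列を格納するので、呼び出し側で転置しておく必要はない（main.cpp でも同じ m を両方の関数に渡しており、結果が一致している）。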

ch16_03/mat.h
#pragma once

// Four-component single-precision vector.
// Members are declared in the order W, X, Y, Z; the Neon code in neon.cpp
// loads one vector with a single 128-bit LD1 and treats lane 0 as W,
// lane 1 as X, lane 2 as Y and lane 3 as Z.
struct Vec4x1F32
{
    float W;
    float X;
    float Y;
    float Z;
};
ch16_03/main.cpp
#include <iostream>
#include <iomanip>
#include <random>
#include <cmath>
#include "mat.h"
#include "AlignedMem.h"

using namespace std;

// Absolute tolerance used when comparing float results (see VecCompare).
#define EPS (1.0e-12f)
// Byte alignment of the Vec4x1F32 arrays; MatVecMulF32Cpp rejects
// pointers that are not aligned to this boundary.
#define ARM_ALIGN (16)

// Neon (Advanced SIMD) implementation, defined in neon.cpp.
bool MatVecMulF32_(Vec4x1F32* b, float m[4][4], Vec4x1F32* a, int n);


// Returns true when every component of *p and *q agrees within float
// precision.
//
// EPS (1e-12) on its own is far below the spacing of floats at the
// magnitudes compared here (the ULP of 4e4 is about 0.004), so the old test
// was really an exact-equality check. It would fail as soon as the Neon code
// (fused multiply-add, FMLA) and the C++ code round differently. A relative
// tolerance of a few ULPs is used instead, with EPS kept as an absolute
// floor for values near zero. NaN components still compare unequal.
bool VecCompare(const Vec4x1F32* p, const Vec4x1F32* q)
{
    const auto close = [](float x, float y)
    {
        constexpr float kRelTol = 1.0e-6f;   // roughly 8 float ULPs
        const float diff = fabs(x - y);
        const float scale = fmax(fabs(x), fabs(y));
        return diff < EPS || diff <= kRelTol * scale;
    };
    return close(p->W, q->W) && close(p->X, q->X)
        && close(p->Y, q->Y) && close(p->Z, q->Z);
}
// Fills a[0..n-1] with integer-valued components in [1, 500] drawn from a
// mt19937 seeded with 187, so every run produces the same data.
// When n >= 4, the first four vectors are then overwritten with fixed
// values (5..8, 15..18, 25..28, 35..38) that are easy to check by hand.
void InitVecArray(Vec4x1F32* a, size_t n)
{
    mt19937 engine {187};
    uniform_int_distribution<> dist {1, 500};
    auto draw = [&]() { return static_cast<float>(dist(engine)); };

    for (Vec4x1F32* v = a; v != a + n; ++v)
    {
        // Draw order W, X, Y, Z determines the generated sequence.
        v->W = draw();
        v->X = draw();
        v->Y = draw();
        v->Z = draw();
    }

    if (n < 4)
        return;

    // Known values for test purposes: vector k holds 10k+5 .. 10k+8.
    for (size_t k = 0; k < 4; ++k)
    {
        const float base = 5.0f + 10.0f * static_cast<float>(k);
        a[k].W = base;
        a[k].X = base + 1.0f;
        a[k].Y = base + 2.0f;
        a[k].Z = base + 3.0f;
    }
}
// Scalar reference implementation: b[i] = M * a[i] for i in [0, n),
// where m is indexed m[row][col].
// Returns false without writing anything if n is 0 or not a multiple of 4,
// or if a or b is not ARM_ALIGN-byte aligned; otherwise returns true.
bool MatVecMulF32Cpp(Vec4x1F32* b, float m[4][4], Vec4x1F32* a, size_t n)
{
    if (n == 0 || (n % 4) != 0)
        return false;
    if (!AlignedMem::IsAligned(a, ARM_ALIGN) || !AlignedMem::IsAligned(b, ARM_ALIGN))
        return false;
    for (size_t i = 0; i < n; i++)
    {
        // Copy the input vector before writing any output. The previous code
        // read a[i] after storing b[i].W, so in-place use (b == a) corrupted
        // X/Y/Z. The Neon version loads the whole vector before storing,
        // and this keeps the reference consistent with it.
        const Vec4x1F32 v = a[i];
        Vec4x1F32& r = b[i];
        r.W =  m[0][0] * v.W + m[0][1] * v.X;
        r.W += m[0][2] * v.Y + m[0][3] * v.Z;
        r.X =  m[1][0] * v.W + m[1][1] * v.X;
        r.X += m[1][2] * v.Y + m[1][3] * v.Z;
        r.Y =  m[2][0] * v.W + m[2][1] * v.X;
        r.Y += m[2][2] * v.Y + m[2][3] * v.Z;
        r.Z =  m[3][0] * v.W + m[3][1] * v.X;
        r.Z += m[3][2] * v.Y + m[3][3] * v.Z;
    }
    return true;
}
void MatVecMulF32(void)
{
    const char nl = '\n';
    const size_t num_vec = 8;
    alignas(ARM_ALIGN) float m[4][4]
    {
       10.0, 11.0, 12.0, 13.0,
       20.0, 21.0, 22.0, 23.0,
       30.0, 31.0, 32.0, 33.0,
       40.0, 41.0, 42.0, 43.0
    };
    AlignedArray<Vec4x1F32> a_aa(num_vec, ARM_ALIGN);
    AlignedArray<Vec4x1F32> b1_aa(num_vec, ARM_ALIGN);
    AlignedArray<Vec4x1F32> b2_aa(num_vec, ARM_ALIGN);
    Vec4x1F32* a = a_aa.Data();
    Vec4x1F32* b1 = b1_aa.Data();
    Vec4x1F32* b2 = b2_aa.Data();
    InitVecArray(a, num_vec);
    bool rc1 = MatVecMulF32Cpp(b1, m, a, num_vec);
    bool rc2 = MatVecMulF32_(b2, m, a, num_vec);
    cout << "Results for MatVecMulF32\n";
    if (!rc1 || !rc2)
    {
        cout << "Invalid return code\n";
        cout << "  rc1 = " << boolalpha << rc1 << nl;
        cout << "  rc2 = " << boolalpha << rc2 << nl;
        return;
    }
    const unsigned int w = 8;
    cout << fixed << setprecision(1);
    for (size_t i = 0; i < num_vec; i++)
    {
        cout << "Test case #" << i << '\n';
        cout << "b1: ";
        cout << "  " << setw(w) << b1[i].W << ' ';
        cout << "  " << setw(w) << b1[i].X << ' ';
        cout << "  " << setw(w) << b1[i].Y << ' ';
        cout << "  " << setw(w) << b1[i].Z << nl;
        cout << "b2: ";
        cout << "  " << setw(w) << b2[i].W << ' ';
        cout << "  " << setw(w) << b2[i].X << ' ';
        cout << "  " << setw(w) << b2[i].Y << ' ';
        cout << "  " << setw(w) << b2[i].Z << nl;
        if (!VecCompare(&b1[i], &b2[i]))
        {
            cout << "Error - vector compare failed\n";
            return;
        }
    }
}
// Entry point: run the C++ vs. Neon matrix-vector comparison.
// A benchmark driver (MatVecMulF32_BM) is not called in this build.
int main()
{
    MatVecMulF32();
}
ch16_03/neon.cpp
#include "mat.h"
#include <cstdint>

void MatVecMulF32_(Vec4x1F32* b, float m[4][4], Vec4x1F32* a, int n) { // m must be transposed
  __asm volatile("\n\
	// b = M a                                                                                   \n\
	// v0: matrix M column 0                                                                     \n\
	// v1: matrix M column 1                                                                     \n\
	// v2: matrix M column 2                                                                     \n\
	// v3: matrix M column 3                                                                     \n\
	// [x0]...[x0+12]: vector b                                                                  \n\
	// [x2]...[x2+12]: vector a                                                                  \n\
	.macro Mat4x4MulVec                                                                          \n\
	ld1	{v4.4s}, [x2], 16            // v4 = [x2] ; x2 += 16                                 \n\
	fmul	v5.4s, v0.4s, v4.s[0]        // v5 = M[:,0] a.w                                      \n\
	fmla	v5.4s, v1.4s, v4.s[1]        // v5 += M[:,1] a.x                                     \n\
	fmla	v5.4s, v2.4s, v4.s[2]        // v5 += M[:,1] a.y                                     \n\
	fmla	v5.4s, v3.4s, v4.s[3]        // v5 += M[:,1] a.z                                     \n\
	st1	{v5.4s}, [x0], 16            // [x0] = v5; x0 += 16                                  \n\
	.endm	                                                                                     \n\
		                                                                                     \n\
		                                                                                     \n\
	cbz	x3, LInvalidArg              // if n == 0 goto InvalidArg                            \n\
	tst	x3, 0x3                      // if (n & 3) != 0                                      \n\
	b.ne	LInvalidArg                  //   goto InvalidArg                                    \n\
	tst	x2, 0xf                      // if (n & f) != 0                                      \n\
	b.ne	LInvalidArg                  //   goto InvalidArg                                    \n\
		                                                                                     \n\
	ld4	{v0.4s-v3.4s}, [x1]          // transpose M                                          \n\
LLoop1:		                                                                                     \n\
	Mat4x4MulVec                                                                                 \n\
	Mat4x4MulVec                                                                                 \n\
	Mat4x4MulVec                                                                                 \n\
	Mat4x4MulVec                                                                                 \n\
	subs	x3, x3, 4                   // if ((n -= 4) != 0)                                    \n\
	b.ne	LLoop1                      //   goto Loop1                                          \n\
		                                                                                     \n\
	mov	w0, 1                       // return code: success                                  \n\
	b	LEXIT                                                                                \n\
		                                                                                     \n\
LInvalidArg:	                                                                                     \n\
	mov	w0, 1                       // return code: error                                    \n\
LEXIT:		                                                                                     \n\
"
		 :
		 :
		 : "x0", "x1", "x2", "x3", "v0", "v1", "v2", "v3", "v4", "v5"
		 );
}
ch16_03/main.cpp の実行例
arm64@manet ch16_03 % g++ -I.. -I. -std=c++11 -O -S neon.cpp
arm64@manet ch16_03 % g++ -I.. -I. -std=c++11 -O main.cpp neon.cpp -o a.out
arm64@manet ch16_03 % ./a.out
Results for MatVecMulF32
Test case #0
b1:      304.0      564.0      824.0     1084.0
b2:      304.0      564.0      824.0     1084.0
Test case #1
b1:      764.0     1424.0     2084.0     2744.0
b2:      764.0     1424.0     2084.0     2744.0
Test case #2
b1:     1224.0     2284.0     3344.0     4404.0
b2:     1224.0     2284.0     3344.0     4404.0
Test case #3
b1:     1684.0     3144.0     4604.0     6064.0
b2:     1684.0     3144.0     4604.0     6064.0
Test case #4
b1:    13208.0    24608.0    36008.0    47408.0
b2:    13208.0    24608.0    36008.0    47408.0
Test case #5
b1:    10205.0    19025.0    27845.0    36665.0
b2:    10205.0    19025.0    27845.0    36665.0
Test case #6
b1:     9361.0    17371.0    25381.0    33391.0
b2:     9361.0    17371.0    25381.0    33391.0
Test case #7
b1:     4649.0     9029.0    13409.0    17789.0
b2:     4649.0     9029.0    13409.0    17789.0


http://nw.tsuda.ac.jp/