-
Notifications
You must be signed in to change notification settings - Fork 115
Expand file tree
/
Copy pathsolution_onemath_usm_gemm.cpp
More file actions
161 lines (128 loc) · 4.94 KB
/
Copy pathsolution_onemath_usm_gemm.cpp
File metadata and controls
161 lines (128 loc) · 4.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
/*
SYCL Academy (c)
SYCL Academy is licensed under a Creative Commons
Attribution-ShareAlike 4.0 International License.
You should have received a copy of the license along with this
work. If not, see <http://creativecommons.org/licenses/by-sa/4.0/>.
Quick Reference
~~~~~~~~~~~~~~~~~~~~
oneMath execution model:
https://oneapi-spec.uxlfoundation.org/specifications/oneapi/latest/elements/onemath/source/architecture/architecture
oneMath GEMM API:
https://oneapi-spec.uxlfoundation.org/specifications/oneapi/latest/elements/onemath/source/domains/blas/gemm
*/
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <vector>

#include <oneapi/math.hpp>
#include <sycl/sycl.hpp>
// Matrix size constants.
//
// The program multiplies a row-major A (M x N) by a row-major B (N x P) to
// obtain a row-major C (M x P). All three extents are derived from SIZE.
constexpr size_t SIZE = 4800;
// Enforce the documented requirement at compile time instead of relying on a
// comment: SIZE / 8, SIZE / 4 and SIZE / 2 must all divide exactly.
static_assert(SIZE % 8 == 0, "SIZE must be a multiple of 8");
constexpr size_t M = SIZE / 8;  // rows of A and C
constexpr size_t N = SIZE / 4;  // columns of A, rows of B
constexpr size_t P = SIZE / 2;  // columns of B and C

// Element type used for every matrix and for the GEMM scalars.
using T = double;
//////////////////////////////////////////////////////////////////////////////////////////
// Returns true when `a` and `b` differ by strictly less than a fixed absolute
// tolerance of 1.0e-08. If either value is NaN the comparison is false, so
// the values are reported as different.
bool ValueSame(T a, T b) {
  constexpr double kAbsTolerance = 1.0e-08;
  const auto difference = std::fabs(a - b);
  return difference < kAbsTolerance;
}
// Compares two row-major M x P matrices element by element using ValueSame.
//
// Every differing element is reported on stdout, with the value from `c_A`
// labelled "expected" and the value from `c_B` labelled "but got". A summary
// line is printed once the whole matrix has been scanned.
//
// Returns 0 when all elements match, -1 when at least one element differs.
int VerifyResult(T* c_A, T* c_B) {
  size_t mismatches = 0;

  // Walk the matrices as flat arrays; row-major storage means the flat index
  // visits elements in the same (row, column) order as a nested loop would.
  for (size_t idx = 0; idx < M * P; ++idx) {
    if (ValueSame(c_A[idx], c_B[idx])) {
      continue;
    }
    const size_t row = idx / P;
    const size_t col = idx % P;
    std::cout << "fail - The result is incorrect for element: [" << row
              << ", " << col << "], expected: " << c_A[idx]
              << " , but got: " << c_B[idx] << std::endl;
    ++mismatches;
  }

  if (mismatches != 0) {
    std::cout << "FAIL - The results mis-match!" << std::endl;
    return -1;
  }
  std::cout << "SUCCESS - The results are correct!" << std::endl;
  return 0;
}
//////////////////////////////////////////////////////////////////////////////////////////
// Prints the name, version and driver version of the device that queue `Q`
// is bound to, as a single line on stdout.
void print_device_info(sycl::queue& Q) {
  const sycl::device device = Q.get_device();

  const std::string name = device.get_info<sycl::info::device::name>();
  const std::string driver =
      device.get_info<sycl::info::device::driver_version>();
  const std::string version = device.get_info<sycl::info::device::version>();

  // Streamed as C strings, so each field stops at its first NUL character.
  std::cout << "Running on " << name.c_str();
  std::cout << ", version: " << version.c_str();
  std::cout << ", driver version: " << driver.c_str() << std::endl;
}
//////////////////////////////////////////////////////////////////////////////////////////
int main() {
std::random_device
rd; // Will be used to obtain a seed for the random number engine
std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with rd()
std::uniform_real_distribution<> dis(1.0, 2.0);
// matrix data sizes
int m = M;
int n = P;
int k = N;
// leading dimensions of data
int ldA = k;
int ldB = n;
int ldC = n;
// set scalar fp values
T alpha = 1.0;
T beta = 0.0;
// Allocate memory on host
std::vector<T> A(M * N);
std::vector<T> B(N * P);
std::vector<T> C_host(M * P);
std::cout << "Problem size: c(" << M << "," << P << ") = a(" << M << "," << N
<< ") * b(" << N << "," << P << ")" << std::endl;
// A(M, N)
for (size_t i = 0; i < M; i++)
for (size_t j = 0; j < N; j++) A[i * N + j] = dis(gen);
// B(N, P)
for (size_t i = 0; i < N; i++)
for (size_t j = 0; j < P; j++) B[i * P + j] = dis(gen);
// Resultant matrix: C_serial = A*B
for (size_t i = 0; i < M; i++) {
for (size_t j = 0; j < P; j++) {
for (size_t d = 0; d < N; d++) {
C_host[i * P + j] += A[i * N + d] * B[d * P + j];
}
}
}
// Create a SYCL queue
sycl::queue Q;
// Prints some basic info related to the hardware
print_device_info(Q);
// Allocate memory on device, (using sycl::malloc_device APIs)
T* a = sycl::malloc_device<T>((M * N), Q);
T* b = sycl::malloc_device<T>((N * P), Q);
T* c = sycl::malloc_device<T>((M * P), Q);
sycl::event eventCopyA = Q.memcpy(a, A.data(), sizeof(T) * M * N);
sycl::event eventCopyB = Q.memcpy(b, B.data(), sizeof(T) * N * P);
// Use oneMath GEMM USM API
oneapi::math::transpose transA = oneapi::math::transpose::nontrans;
oneapi::math::transpose transB = oneapi::math::transpose::nontrans;
// Pass the synchronisation events to ensure GEMM starts after inputs are
// fully copied to the device
sycl::event eventGEMM = oneapi::math::blas::column_major::gemm(
Q, transA, transB, n, m, k, alpha, b, ldB, a, ldA, beta, c, ldC,
{eventCopyA, eventCopyB}); // row-major
// Copy the results from device to host for verification
std::vector<T> C_device(M * P);
// Pass the synchronisation event for the copy to wait until GEMM is finished
sycl::event eventCopyC =
Q.memcpy(C_device.data(), c, sizeof(T) * M * P, eventGEMM);
// Wait for the copy to finish
eventCopyC.wait();
// Verify results from oneMath
int result = 0;
std::cout << "Verify results between oneMath & serial: ";
result = VerifyResult(C_device.data(), C_host.data());
// Free memory from device
sycl::free(a, Q);
sycl::free(b, Q);
sycl::free(c, Q);
return result;
}