1 | #include <cstring>
|
---|
2 |
|
---|
3 | #include <d3dcompiler.h>
|
---|
4 | #include <d3d11.h>
|
---|
5 |
|
---|
6 | #include <windows.h>
|
---|
7 | #include <windowsx.h>
|
---|
8 |
|
---|
9 | #include "../test_utils.h"
|
---|
10 |
|
---|
11 | using namespace dxvk;
|
---|
12 |
|
---|
13 | const std::string g_computeShaderCode =
|
---|
14 | "StructuredBuffer<uint> buf_in : register(t0);\n"
|
---|
15 | "RWStructuredBuffer<uint> buf_out : register(u0);\n"
|
---|
16 | "groupshared uint tmp[64];\n"
|
---|
17 | "[numthreads(64,1,1)]\n"
|
---|
18 | "void main(uint localId : SV_GroupIndex, uint3 globalId : SV_DispatchThreadID) {\n"
|
---|
19 | " tmp[localId] = buf_in[2 * globalId.x + 0]\n"
|
---|
20 | " + buf_in[2 * globalId.x + 1];\n"
|
---|
21 | " GroupMemoryBarrierWithGroupSync();\n"
|
---|
22 | " uint activeGroups = 32;\n"
|
---|
23 | " while (activeGroups != 0) {\n"
|
---|
24 | " if (localId < activeGroups)\n"
|
---|
25 | " tmp[localId] += tmp[localId + activeGroups];\n"
|
---|
26 | " GroupMemoryBarrierWithGroupSync();\n"
|
---|
27 | " activeGroups >>= 1;\n"
|
---|
28 | " }\n"
|
---|
29 | " if (localId == 0)\n"
|
---|
30 | " buf_out[0] = tmp[0];\n"
|
---|
31 | "}\n";
|
---|
32 |
|
---|
33 | int WINAPI WinMain(HINSTANCE hInstance,
|
---|
34 | HINSTANCE hPrevInstance,
|
---|
35 | LPSTR lpCmdLine,
|
---|
36 | int nCmdShow) {
|
---|
37 | Com<ID3D11Device> device;
|
---|
38 | Com<ID3D11DeviceContext> context;
|
---|
39 | Com<ID3D11ComputeShader> computeShader;
|
---|
40 |
|
---|
41 | Com<ID3D11Buffer> srcBuffer;
|
---|
42 | Com<ID3D11Buffer> dstBuffer;
|
---|
43 | Com<ID3D11Buffer> readBuffer;
|
---|
44 |
|
---|
45 | Com<ID3D11ShaderResourceView> srcView;
|
---|
46 | Com<ID3D11UnorderedAccessView> dstView;
|
---|
47 |
|
---|
48 | if (FAILED(D3D11CreateDevice(
|
---|
49 | nullptr, D3D_DRIVER_TYPE_HARDWARE,
|
---|
50 | nullptr, 0, nullptr, 0, D3D11_SDK_VERSION,
|
---|
51 | &device, nullptr, &context))) {
|
---|
52 | std::cerr << "Failed to create D3D11 device" << std::endl;
|
---|
53 | return 1;
|
---|
54 | }
|
---|
55 |
|
---|
56 | Com<ID3DBlob> computeShaderBlob;
|
---|
57 |
|
---|
58 | if (FAILED(D3DCompile(
|
---|
59 | g_computeShaderCode.data(),
|
---|
60 | g_computeShaderCode.size(),
|
---|
61 | "Compute shader",
|
---|
62 | nullptr, nullptr,
|
---|
63 | "main", "cs_5_0", 0, 0,
|
---|
64 | &computeShaderBlob,
|
---|
65 | nullptr))) {
|
---|
66 | std::cerr << "Failed to compile compute shader" << std::endl;
|
---|
67 | return 1;
|
---|
68 | }
|
---|
69 |
|
---|
70 | if (FAILED(device->CreateComputeShader(
|
---|
71 | computeShaderBlob->GetBufferPointer(),
|
---|
72 | computeShaderBlob->GetBufferSize(),
|
---|
73 | nullptr, &computeShader))) {
|
---|
74 | std::cerr << "Failed to create compute shader" << std::endl;
|
---|
75 | return 1;
|
---|
76 | }
|
---|
77 |
|
---|
78 | std::array<uint32_t, 128> srcData;
|
---|
79 | for (uint32_t i = 0; i < srcData.size(); i++)
|
---|
80 | srcData[i] = i + 1;
|
---|
81 |
|
---|
82 | D3D11_BUFFER_DESC srcBufferDesc;
|
---|
83 | srcBufferDesc.ByteWidth = sizeof(uint32_t) * srcData.size();
|
---|
84 | srcBufferDesc.Usage = D3D11_USAGE_IMMUTABLE;
|
---|
85 | srcBufferDesc.BindFlags = D3D11_BIND_SHADER_RESOURCE;
|
---|
86 | srcBufferDesc.CPUAccessFlags = 0;
|
---|
87 | srcBufferDesc.MiscFlags = D3D11_RESOURCE_MISC_BUFFER_STRUCTURED;
|
---|
88 | srcBufferDesc.StructureByteStride = sizeof(uint32_t);
|
---|
89 |
|
---|
90 | D3D11_SUBRESOURCE_DATA srcDataInfo;
|
---|
91 | srcDataInfo.pSysMem = srcData.data();
|
---|
92 | srcDataInfo.SysMemPitch = 0;
|
---|
93 | srcDataInfo.SysMemSlicePitch = 0;
|
---|
94 |
|
---|
95 | if (FAILED(device->CreateBuffer(&srcBufferDesc, &srcDataInfo, &srcBuffer))) {
|
---|
96 | std::cerr << "Failed to create source buffer" << std::endl;
|
---|
97 | return 1;
|
---|
98 | }
|
---|
99 |
|
---|
100 | D3D11_BUFFER_DESC dstBufferDesc;
|
---|
101 | dstBufferDesc.ByteWidth = sizeof(uint32_t);
|
---|
102 | dstBufferDesc.Usage = D3D11_USAGE_DEFAULT;
|
---|
103 | dstBufferDesc.BindFlags = D3D11_BIND_UNORDERED_ACCESS;
|
---|
104 | dstBufferDesc.CPUAccessFlags = 0;
|
---|
105 | dstBufferDesc.MiscFlags = D3D11_RESOURCE_MISC_BUFFER_STRUCTURED;
|
---|
106 | dstBufferDesc.StructureByteStride = sizeof(uint32_t);
|
---|
107 |
|
---|
108 | if (FAILED(device->CreateBuffer(&dstBufferDesc, &srcDataInfo, &dstBuffer))) {
|
---|
109 | std::cerr << "Failed to create destination buffer" << std::endl;
|
---|
110 | return 1;
|
---|
111 | }
|
---|
112 |
|
---|
113 | D3D11_BUFFER_DESC readBufferDesc;
|
---|
114 | readBufferDesc.ByteWidth = sizeof(uint32_t);
|
---|
115 | readBufferDesc.Usage = D3D11_USAGE_STAGING;
|
---|
116 | readBufferDesc.BindFlags = 0;
|
---|
117 | readBufferDesc.CPUAccessFlags = D3D11_CPU_ACCESS_READ;
|
---|
118 | readBufferDesc.MiscFlags = 0;
|
---|
119 | readBufferDesc.StructureByteStride = 0;
|
---|
120 |
|
---|
121 | if (FAILED(device->CreateBuffer(&readBufferDesc, nullptr, &readBuffer))) {
|
---|
122 | std::cerr << "Failed to create readback buffer" << std::endl;
|
---|
123 | return 1;
|
---|
124 | }
|
---|
125 |
|
---|
126 | D3D11_SHADER_RESOURCE_VIEW_DESC srcViewDesc;
|
---|
127 | srcViewDesc.Format = DXGI_FORMAT_UNKNOWN;
|
---|
128 | srcViewDesc.ViewDimension = D3D11_SRV_DIMENSION_BUFFEREX;
|
---|
129 | srcViewDesc.BufferEx.FirstElement = 0;
|
---|
130 | srcViewDesc.BufferEx.NumElements = srcData.size();
|
---|
131 | srcViewDesc.BufferEx.Flags = 0;
|
---|
132 |
|
---|
133 | if (FAILED(device->CreateShaderResourceView(srcBuffer.ptr(), &srcViewDesc, &srcView))) {
|
---|
134 | std::cerr << "Failed to create shader resource view" << std::endl;
|
---|
135 | return 1;
|
---|
136 | }
|
---|
137 |
|
---|
138 | D3D11_UNORDERED_ACCESS_VIEW_DESC dstViewDesc;
|
---|
139 | dstViewDesc.Format = DXGI_FORMAT_UNKNOWN;
|
---|
140 | dstViewDesc.ViewDimension = D3D11_UAV_DIMENSION_BUFFER;
|
---|
141 | dstViewDesc.Buffer.FirstElement = 0;
|
---|
142 | dstViewDesc.Buffer.NumElements = 1;
|
---|
143 | dstViewDesc.Buffer.Flags = 0;
|
---|
144 |
|
---|
145 | if (FAILED(device->CreateUnorderedAccessView(dstBuffer.ptr(), &dstViewDesc, &dstView))) {
|
---|
146 | std::cerr << "Failed to create unordered access view" << std::endl;
|
---|
147 | return 1;
|
---|
148 | }
|
---|
149 |
|
---|
150 | // Compute sum of the source buffer values
|
---|
151 | context->CSSetShader(computeShader.ptr(), nullptr, 0);
|
---|
152 | context->CSSetShaderResources(0, 1, &srcView);
|
---|
153 | context->CSSetUnorderedAccessViews(0, 1, &dstView, nullptr);
|
---|
154 | context->Dispatch(1, 1, 1);
|
---|
155 |
|
---|
156 | // Write data to the readback buffer and query the result
|
---|
157 | context->CopyResource(readBuffer.ptr(), dstBuffer.ptr());
|
---|
158 |
|
---|
159 | D3D11_MAPPED_SUBRESOURCE mappedResource;
|
---|
160 | if (FAILED(context->Map(readBuffer.ptr(), 0, D3D11_MAP_READ, 0, &mappedResource))) {
|
---|
161 | std::cerr << "Failed to map readback buffer" << std::endl;
|
---|
162 | return 1;
|
---|
163 | }
|
---|
164 |
|
---|
165 | uint32_t result = 0;
|
---|
166 | std::memcpy(&result, mappedResource.pData, sizeof(result));
|
---|
167 | context->Unmap(readBuffer.ptr(), 0);
|
---|
168 |
|
---|
169 | std::cout << "Sum of the numbers 1 to " << srcData.size() << " = " << result << std::endl;
|
---|
170 | context->ClearState();
|
---|
171 | return 0;
|
---|
172 | }
|
---|