Skip to content

Commit 385ce12

Browse files
feat: 接入CNNL,并添加unary/binary/softmax/batchnorm/reduce/transpose/pooling算子 (feat: integrate CNNL and add unary/binary/softmax/batchnorm/reduce/transpose/pooling operators)
1 parent 7f82d74 commit 385ce12

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

41 files changed

+1924
-9
lines changed

src/02hardware/CMakeLists.txt

+2-6
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,10 @@ project(hardware VERSION 0.0.0 LANGUAGES CXX)
33
message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION})
44

55
# Source files
6-
file(GLOB HARDWARE_SRC src/*.cc src/*.cpp src/devices/cpu/*.cc)
6+
file(GLOB_RECURSE HARDWARE_SRC src/*.cc src/*.cpp)
77

88
if(USE_CUDA)
9-
file(GLOB_RECURSE HARDWARE_CUDA_SRC src/devices/nvidia/*.cu src/devices/nvidia/*.cc)
10-
endif()
11-
12-
if(USE_BANG)
13-
file(GLOB_RECURSE HARDWARE_BANG_SRC src/devices/mlu/*.cc)
9+
file(GLOB_RECURSE HARDWARE_CUDA_SRC src/devices/nvidia/*.cu)
1410
endif()
1511

1612
add_library(hardware STATIC ${HARDWARE_SRC} ${HARDWARE_CUDA_SRC} ${HARDWARE_BANG_SRC})

src/02hardware/src/device_manager.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "hardware/device_manager.h"
22
#include "hardware/devices/cpu.h"
33
#include "hardware/devices/nvidia.h"
4+
#include "hardware/devices/mlu.h"
45

56
namespace refactor::hardware::device {
67

@@ -37,6 +38,7 @@ namespace refactor::hardware::device {
3738
using T = Device::Type;
3839
// clang-format off
3940
auto device = type == T::Nvidia ? std::make_shared<Nvidia>(card)
41+
: type == T::Mlu ? std::make_shared<Mlu>(card)
4042
: UNREACHABLEX(Arc<Device>, "");
4143
// clang-format on
4244
auto [kind, ok] = DEVICES.try_emplace(static_cast<int32_t>(type));

src/02hardware/src/devices/mlu/device.cc

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
1-
#include "functions.cc"
1+
#include "functions.hh"
22
#include "hardware/devices/mlu.h"
33
#include "hardware/mem_pool.h"
44
#include "memory.hh"
55

66
namespace refactor::hardware {
77

88
static Arc<Memory> bangMemory(int32_t card) {
9+
#ifdef USE_BANG
910
ASSERT(0 <= card && card < getDeviceCount(), "Invalid card id: {}", card);
1011
setDevice(card);
1112
auto [free, total] = getMemInfo();
1213
auto size = std::min(free, std::max(5ul << 30, total * 4 / 5));
13-
fmt::println("initializing Nvidia GPU {}, memory {} / {}, alloc {}",
14+
fmt::println("initializing Cambricon MLU {}, memory {} / {}, alloc {}",
1415
card, free, total, size);
1516
return std::make_shared<MemPool>(
1617
std::make_shared<MluMemory>(),
1718
size,
1819
256ul);
20+
#else
21+
return nullptr;
22+
#endif
1923
}
2024

2125
Mlu::Mlu(int32_t card) : Device(card, bangMemory(card)) {}

src/02hardware/src/devices/mlu/functions.cc

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
namespace refactor::hardware {
44

5+
#ifdef USE_BANG
56
int getDeviceCount() {
67
unsigned deviceCount;
78
BANG_ASSERT(cnrtGetDeviceCount(&deviceCount));
@@ -15,5 +16,6 @@ namespace refactor::hardware {
1516
BANG_ASSERT(cnrtMemGetInfo(&memInfo.free, &memInfo.total));
1617
return memInfo;
1718
}
19+
#endif
1820

1921
}// namespace refactor::hardware

src/02hardware/src/devices/mlu/functions.hh

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
#ifndef HARDWARE_DEVICES_MLU_FUNCTIONS_CUH
22
#define HARDWARE_DEVICES_MLU_FUNCTIONS_CUH
33

4-
#include "cnrt.h"
54
#include "common.h"
65

6+
#ifdef USE_BANG
7+
#include "cnrt.h"
8+
79
#define BANG_ASSERT(STATUS) \
810
if (auto status = (STATUS); status != CNRT_RET_SUCCESS) { \
911
RUNTIME_ERROR(fmt::format("bang failed on \"" #STATUS "\" with \"{}\" ({})", \
1012
cnrtGetErrorStr(status), (int) status)); \
1113
}
14+
#endif
1215

1316
namespace refactor::hardware {
1417

src/02hardware/src/devices/mlu/memory.cc

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "functions.hh"
33

44
namespace refactor::hardware {
5+
#ifdef USE_BANG
56
using M = MluMemory;
67

78
void *M::malloc(size_t size) {
@@ -27,5 +28,6 @@ namespace refactor::hardware {
2728
CNRT_MEM_TRANS_DIR_PEER2PEER));
2829
return dst;
2930
}
31+
#endif
3032

3133
}// namespace refactor::hardware

src/02hardware/src/devices/nvidia/device.cc

+4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
namespace refactor::hardware {
77

88
static Arc<Memory> cudaMemory(int32_t card) {
9+
#ifdef USE_CUDA
910
ASSERT(0 <= card && card < getDeviceCount(), "Invalid card id: {}", card);
1011
setDevice(card);
1112
auto [free, total] = getMemInfo();
@@ -16,6 +17,9 @@ namespace refactor::hardware {
1617
std::make_shared<NvidiaMemory>(),
1718
size,
1819
256ul);
20+
#else
21+
return nullptr;
22+
#endif
1923
}
2024

2125
Nvidia::Nvidia(int32_t card) : Device(card, cudaMemory(card)) {}

src/04kernel/src/collectors/batch_normalization.cc

+4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "kernel/collectors/batch_normalization.h"
22
#include "../kernels/batch_normalization/cpu_kernel.hh"
33
#include "../kernels/batch_normalization/cudnn_kernel.hh"
4+
#include "../kernels/batch_normalization/cnnl_kernel.hh"
45

56
namespace refactor::kernel {
67

@@ -20,6 +21,9 @@ namespace refactor::kernel {
2021
case decltype(_target)::Nvidia:
2122
REGISTER(BatchNormalizationCudnn)
2223
break;
24+
case decltype(_target)::Mlu:
25+
REGISTER(BatchNormalizationCnnl)
26+
break;
2327
default:
2428
UNREACHABLEX(void, "Unknown target");
2529
}

src/04kernel/src/collectors/pool.cc

+6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "kernel/collectors/pool.h"
22
#include "../kernels/pool/cudnn_kernel.hh"
3+
#include "../kernels/pool/cnnl_kernel.hh"
34

45
namespace refactor::kernel {
56

@@ -29,6 +30,11 @@ namespace refactor::kernel {
2930
ans.emplace_back(std::move(ptr));
3031
}
3132
break;
33+
case decltype(_target)::Mlu:
34+
if (auto ptr = PoolCnnl::build(type, ceil, kernelShape, attributes, x, y); ptr) {
35+
ans.emplace_back(std::move(ptr));
36+
}
37+
break;
3238
default:
3339
UNREACHABLEX(void, "Unknown target");
3440
}

src/04kernel/src/collectors/reduce.cc

+4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "kernel/collectors/reduce.h"
22
#include "../kernels/reduce/cpu_kernel.hh"
33
#include "../kernels/reduce/cudnn_kernel.hh"
4+
#include "../kernels/reduce/cnnl_kernel.hh"
45

56
namespace refactor::kernel {
67

@@ -27,6 +28,9 @@ namespace refactor::kernel {
2728
case decltype(_target)::Nvidia:
2829
REGISTER(ReduceCudnn)
2930
break;
31+
case decltype(_target)::Mlu:
32+
REGISTER(ReduceCnnl)
33+
break;
3034
default:
3135
UNREACHABLEX(void, "Unknown target");
3236
}

src/04kernel/src/collectors/simple_binary.cc

+4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "../kernels/simple_binary/binary_cudnn.hh"
33
#include "../kernels/simple_binary/cpu_kernel.hh"
44
#include "../kernels/simple_binary/cuda_kernel.hh"
5+
#include "../kernels/simple_binary/binary_cnnl.hh"
56

67
namespace refactor::kernel {
78

@@ -48,6 +49,9 @@ namespace refactor::kernel {
4849
REGISTER_BROCAST(BinaryCudnn)
4950
REGISTER(BinaryCuda)
5051
break;
52+
case decltype(_target)::Mlu:
53+
REGISTER_BROCAST(BinaryCnnl)
54+
break;
5155
default:
5256
UNREACHABLEX(void, "Unknown target");
5357
}

src/04kernel/src/collectors/simple_unary.cc

+6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#include "../kernels/simple_unary/cpu_kernel.hh"
33
#include "../kernels/simple_unary/cuda_kernel.hh"
44
#include "../kernels/simple_unary/cudnn_activation_kernel.hh"
5+
#include "../kernels/simple_unary/cnnl_activation_kernel.hh"
6+
#include "../kernels/simple_unary/cnnl_simple_unary_kernel.hh"
57
#include "common.h"
68

79
namespace refactor::kernel {
@@ -54,6 +56,10 @@ namespace refactor::kernel {
5456
REGISTER(ActivationCudnn)
5557
REGISTER(SimpleUnaryCuda)
5658
break;
59+
case decltype(_target)::Mlu:
60+
REGISTER(ActivationCnnl)
61+
REGISTER(SimpleUnaryCnnl)
62+
break;
5763
default:
5864
UNREACHABLEX(void, "Unknown target");
5965
}

src/04kernel/src/collectors/softmax.cc

+7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "kernel/collectors/softmax.h"
2+
#include "../kernels/softmax/cnnl_kernel.hh"
23
#include "../kernels/softmax/cpu_kernel.hh"
34
#include "../kernels/softmax/cuda_kernel.hh"
45
#include "../kernels/softmax/cudnn_kernel.hh"
@@ -28,6 +29,12 @@ namespace refactor::kernel {
2829
}
2930
break;
3031
}
32+
case decltype(_target)::Mlu: {
33+
if (auto ptr = SoftmaxCnnl::build(cnnl::SoftmaxAlgo::ACCURATE, info); ptr) {
34+
ans.emplace_back(std::move(ptr));
35+
}
36+
break;
37+
}
3138
default:
3239
UNREACHABLEX(void, "Unknown target");
3340
}

src/04kernel/src/collectors/transpose.cc

+6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "kernel/collectors/transpose.h"
22
#include "../kernels/transpose/cpu_kernel.hh"
33
#include "../kernels/transpose/cuda_kernel.hh"
4+
#include "../kernels/transpose/cnnl_kernel.hh"
45

56
namespace refactor::kernel {
67

@@ -25,6 +26,11 @@ namespace refactor::kernel {
2526
ans.emplace_back(std::move(ptr));
2627
}
2728
break;
29+
case decltype(_target)::Mlu:
30+
if (auto ptr = TransposeCnnl::build(data.dataType, data.shape, perm); ptr) {
31+
ans.emplace_back(std::move(ptr));
32+
}
33+
break;
2834
default:
2935
UNREACHABLEX(void, "Unknown target");
3036
}

0 commit comments

Comments
 (0)