Skip to content

Commit f3c8eac

Browse files
Add Ascend NPU accelerator support (deepspeedai#3595)
* add Ascend NPU accelerator support * clean code --------- Co-authored-by: jializheng <jializheng@huawei.com> Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
1 parent 52907a6 commit f3c8eac

File tree

1 file changed

+241
-0
lines changed

1 file changed

+241
-0
lines changed

accelerator/npu_accelerator.py

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
from .abstract_accelerator import DeepSpeedAccelerator
7+
# During setup stage torch may not be installed, pass on no torch will
8+
# allow op builder related API to be executed.
9+
try:
10+
import torch.npu
11+
except ImportError:
12+
pass
13+
14+
15+
class NPU_Accelerator(DeepSpeedAccelerator):
16+
17+
def __init__(self):
18+
self._name = 'npu'
19+
self._communication_backend_name = 'hccl'
20+
21+
def is_synchronized_device(self):
22+
return False
23+
24+
# Device APIs
25+
def device_name(self, device_index=None):
26+
if device_index == None:
27+
return 'npu'
28+
return 'npu:{}'.format(device_index)
29+
30+
def device(self, device_index=None):
31+
return torch.npu.device(device_index)
32+
33+
def set_device(self, device_index):
34+
torch.npu.set_device(device_index)
35+
36+
def current_device(self):
37+
return torch.npu.current_device()
38+
39+
def current_device_name(self):
40+
return 'npu:{}'.format(torch.npu.current_device())
41+
42+
def device_count(self):
43+
return torch.npu.device_count()
44+
45+
def synchronize(self, device_index=None):
46+
return torch.npu.synchronize(device_index)
47+
48+
# RNG APIs
49+
def random(self):
50+
return torch.random
51+
52+
def set_rng_state(self, new_state, device_index=None):
53+
if device_index is None:
54+
return torch.npu.set_rng_state(new_state)
55+
56+
return torch.npu.set_rng_state(new_state, device_index)
57+
58+
def get_rng_state(self, device_index=None):
59+
if device_index is None:
60+
return torch.npu.get_rng_state()
61+
62+
return torch.npu.get_rng_state(device_index)
63+
64+
def manual_seed(self, seed):
65+
return torch.npu.manual_seed(seed)
66+
67+
def manual_seed_all(self, seed):
68+
return torch.npu.manual_seed_all(seed)
69+
70+
def initial_seed(self, seed):
71+
return torch.npu.initial_seed(seed)
72+
73+
def default_generator(self, device_index):
74+
return torch.npu.default_generators[device_index]
75+
76+
# Streams/Events
77+
@property
78+
def Stream(self):
79+
return torch.npu.Stream
80+
81+
def stream(self, stream):
82+
return torch.npu.stream(stream)
83+
84+
def current_stream(self, device_index=None):
85+
return torch.npu.current_stream(device_index)
86+
87+
def default_stream(self, device_index=None):
88+
return torch.npu.default_stream(device_index)
89+
90+
@property
91+
def Event(self):
92+
return torch.npu.Event
93+
94+
# Memory management
95+
def empty_cache(self):
96+
return torch.npu.empty_cache()
97+
98+
def memory_allocated(self, device_index=None):
99+
return torch.npu.memory_allocated(device_index)
100+
101+
def max_memory_allocated(self, device_index=None):
102+
return torch.npu.max_memory_allocated(device_index)
103+
104+
def reset_max_memory_allocated(self, device_index=None):
105+
return torch.npu.reset_max_memory_allocated(device_index)
106+
107+
def memory_cached(self, device_index=None):
108+
return torch.npu.memory_cached(device_index)
109+
110+
def max_memory_cached(self, device_index=None):
111+
return torch.npu.max_memory_cached(device_index)
112+
113+
def reset_max_memory_cached(self, device_index=None):
114+
return torch.npu.reset_max_memory_cached(device_index)
115+
116+
def memory_stats(self, device_index=None):
117+
if hasattr(torch.npu, 'memory_stats'):
118+
return torch.npu.memory_stats(device_index)
119+
120+
def reset_peak_memory_stats(self, device_index=None):
121+
if hasattr(torch.npu, 'reset_peak_memory_stats'):
122+
return torch.npu.reset_peak_memory_stats(device_index)
123+
124+
def memory_reserved(self, device_index=None):
125+
if hasattr(torch.npu, 'memory_reserved'):
126+
return torch.npu.memory_reserved(device_index)
127+
128+
def max_memory_reserved(self, device_index=None):
129+
if hasattr(torch.npu, 'max_memory_reserved'):
130+
return torch.npu.max_memory_reserved(device_index)
131+
132+
def total_memory(self, device_index=None):
133+
return torch.npu.get_device_properties(device_index).total_memory
134+
135+
# Data types
136+
def is_bf16_supported(self):
137+
return torch.npu.is_bf16_supported()
138+
139+
def is_fp16_supported(self):
140+
return True
141+
142+
# Misc
143+
def amp(self):
144+
if hasattr(torch.npu, 'amp'):
145+
return torch.npu.amp
146+
return None
147+
148+
def is_available(self):
149+
return torch.npu.is_available()
150+
151+
def range_push(self, msg):
152+
return
153+
154+
def range_pop(self):
155+
return
156+
157+
def lazy_call(self, callback):
158+
return torch.npu._lazy_call(callback)
159+
160+
def communication_backend_name(self):
161+
return self._communication_backend_name
162+
163+
# Tensor operations
164+
165+
@property
166+
def BFloat16Tensor(self):
167+
return torch.npu.BFloat16Tensor
168+
169+
@property
170+
def ByteTensor(self):
171+
return torch.npu.ByteTensor
172+
173+
@property
174+
def DoubleTensor(self):
175+
return torch.npu.DoubleTensor
176+
177+
@property
178+
def FloatTensor(self):
179+
return torch.npu.FloatTensor
180+
181+
@property
182+
def HalfTensor(self):
183+
return torch.npu.HalfTensor
184+
185+
@property
186+
def IntTensor(self):
187+
return torch.npu.IntTensor
188+
189+
@property
190+
def LongTensor(self):
191+
return torch.npu.LongTensor
192+
193+
def pin_memory(self, tensor):
194+
return tensor.pin_memory()
195+
196+
def on_accelerator(self, tensor):
197+
device_str = str(tensor.device)
198+
if device_str.startswith('npu:'):
199+
return True
200+
else:
201+
return False
202+
203+
def op_builder_dir(self):
204+
try:
205+
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
206+
# if successful this also means we're doing a local install and not JIT compile path
207+
from op_builder import __deepspeed__ # noqa: F401
208+
return "op_builder.npu"
209+
except ImportError:
210+
return "deepspeed.ops.op_builder.npu"
211+
212+
# dict that holds class name <--> class type mapping i.e.
213+
# 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
214+
# this dict will be filled at init stage
215+
class_dict = None
216+
217+
def _lazy_init_class_dict(self):
218+
if self.class_dict != None:
219+
return
220+
else:
221+
self.class_dict = {}
222+
223+
# create an instance of op builder and return, name specified by class_name
224+
def create_op_builder(self, class_name):
225+
self._lazy_init_class_dict()
226+
if class_name in self.class_dict:
227+
return self.class_dict[class_name]()
228+
else:
229+
return None
230+
231+
# return an op builder class, name specified by class_name
232+
def get_op_builder(self, class_name):
233+
self._lazy_init_class_dict()
234+
if class_name in self.class_dict:
235+
return self.class_dict[class_name]
236+
else:
237+
return None
238+
239+
def build_extension(self):
240+
from torch.utils.cpp_extension import BuildExtension
241+
return BuildExtension

0 commit comments

Comments
 (0)