1212 SupportsBufferProtocol ,
1313 )
1414 from collections .abc import Sequence
15- from ._dtypes import _all_dtypes
15+ from ._dtypes import _DType , _all_dtypes
1616
1717import numpy as np
1818
1919
2020def _check_valid_dtype (dtype ):
2121 # Note: Only spelling dtypes as the dtype objects is supported.
22-
23- # We use this instead of "dtype in _all_dtypes" because the dtype objects
24- # define equality with the sorts of things we want to disallow.
25- for d in (None ,) + _all_dtypes :
26- if dtype is d :
27- return
28- raise ValueError ("dtype must be one of the supported dtypes" )
22+ if dtype not in (None ,) + _all_dtypes :
23+ raise ValueError ("dtype must be one of the supported dtypes" )
2924
3025
3126def asarray (
@@ -50,10 +45,13 @@ def asarray(
5045 """
5146 # _array_object imports in this file are inside the functions to avoid
5247 # circular imports
53- from ._array_object import Array
48+ from ._array_object import Array , CPU_DEVICE
5449
5550 _check_valid_dtype (dtype )
56- if device not in ["cpu" , None ]:
51+ _np_dtype = None
52+ if dtype is not None :
53+ _np_dtype = dtype ._np_dtype
54+ if device not in [CPU_DEVICE , None ]:
5755 raise ValueError (f"Unsupported device { device !r} " )
5856 if copy in (False , np ._CopyMode .IF_NEEDED ):
5957 # Note: copy=False is not yet implemented in np.asarray
@@ -62,13 +60,13 @@ def asarray(
6260 if dtype is not None and obj .dtype != dtype :
6361 copy = True
6462 if copy in (True , np ._CopyMode .ALWAYS ):
65- return Array ._new (np .array (obj ._array , copy = True , dtype = dtype ))
63+ return Array ._new (np .array (obj ._array , copy = True , dtype = _np_dtype ))
6664 return obj
6765 if dtype is None and isinstance (obj , int ) and (obj > 2 ** 64 or obj < - (2 ** 63 )):
6866 # Give a better error message in this case. NumPy would convert this
6967 # to an object array. TODO: This won't handle large integers in lists.
7068 raise OverflowError ("Integer out of bounds for array dtypes" )
71- res = np .asarray (obj , dtype = dtype )
69+ res = np .asarray (obj , dtype = _np_dtype )
7270 return Array ._new (res )
7371
7472
@@ -86,11 +84,13 @@ def arange(
8684
8785 See its docstring for more information.
8886 """
89- from ._array_object import Array
87+ from ._array_object import Array , CPU_DEVICE
9088
9189 _check_valid_dtype (dtype )
92- if device not in ["cpu" , None ]:
90+ if device not in [CPU_DEVICE , None ]:
9391 raise ValueError (f"Unsupported device { device !r} " )
92+ if dtype is not None :
93+ dtype = dtype ._np_dtype
9494 return Array ._new (np .arange (start , stop = stop , step = step , dtype = dtype ))
9595
9696
@@ -105,11 +105,13 @@ def empty(
105105
106106 See its docstring for more information.
107107 """
108- from ._array_object import Array
108+ from ._array_object import Array , CPU_DEVICE
109109
110110 _check_valid_dtype (dtype )
111- if device not in ["cpu" , None ]:
111+ if device not in [CPU_DEVICE , None ]:
112112 raise ValueError (f"Unsupported device { device !r} " )
113+ if dtype is not None :
114+ dtype = dtype ._np_dtype
113115 return Array ._new (np .empty (shape , dtype = dtype ))
114116
115117
@@ -121,11 +123,13 @@ def empty_like(
121123
122124 See its docstring for more information.
123125 """
124- from ._array_object import Array
126+ from ._array_object import Array , CPU_DEVICE
125127
126128 _check_valid_dtype (dtype )
127- if device not in ["cpu" , None ]:
129+ if device not in [CPU_DEVICE , None ]:
128130 raise ValueError (f"Unsupported device { device !r} " )
131+ if dtype is not None :
132+ dtype = dtype ._np_dtype
129133 return Array ._new (np .empty_like (x ._array , dtype = dtype ))
130134
131135
@@ -143,11 +147,13 @@ def eye(
143147
144148 See its docstring for more information.
145149 """
146- from ._array_object import Array
150+ from ._array_object import Array , CPU_DEVICE
147151
148152 _check_valid_dtype (dtype )
149- if device not in ["cpu" , None ]:
153+ if device not in [CPU_DEVICE , None ]:
150154 raise ValueError (f"Unsupported device { device !r} " )
155+ if dtype is not None :
156+ dtype = dtype ._np_dtype
151157 return Array ._new (np .eye (n_rows , M = n_cols , k = k , dtype = dtype ))
152158
153159
@@ -169,15 +175,17 @@ def full(
169175
170176 See its docstring for more information.
171177 """
172- from ._array_object import Array
178+ from ._array_object import Array , CPU_DEVICE
173179
174180 _check_valid_dtype (dtype )
175- if device not in ["cpu" , None ]:
181+ if device not in [CPU_DEVICE , None ]:
176182 raise ValueError (f"Unsupported device { device !r} " )
177183 if isinstance (fill_value , Array ) and fill_value .ndim == 0 :
178184 fill_value = fill_value ._array
185+ if dtype is not None :
186+ dtype = dtype ._np_dtype
179187 res = np .full (shape , fill_value , dtype = dtype )
180- if res .dtype not in _all_dtypes :
188+ if _DType ( res .dtype ) not in _all_dtypes :
181189 # This will happen if the fill value is not something that NumPy
182190 # coerces to one of the acceptable dtypes.
183191 raise TypeError ("Invalid input to full" )
@@ -197,13 +205,15 @@ def full_like(
197205
198206 See its docstring for more information.
199207 """
200- from ._array_object import Array
208+ from ._array_object import Array , CPU_DEVICE
201209
202210 _check_valid_dtype (dtype )
203- if device not in ["cpu" , None ]:
211+ if device not in [CPU_DEVICE , None ]:
204212 raise ValueError (f"Unsupported device { device !r} " )
213+ if dtype is not None :
214+ dtype = dtype ._np_dtype
205215 res = np .full_like (x ._array , fill_value , dtype = dtype )
206- if res .dtype not in _all_dtypes :
216+ if _DType ( res .dtype ) not in _all_dtypes :
207217 # This will happen if the fill value is not something that NumPy
208218 # coerces to one of the acceptable dtypes.
209219 raise TypeError ("Invalid input to full_like" )
@@ -225,11 +235,13 @@ def linspace(
225235
226236 See its docstring for more information.
227237 """
228- from ._array_object import Array
238+ from ._array_object import Array , CPU_DEVICE
229239
230240 _check_valid_dtype (dtype )
231- if device not in ["cpu" , None ]:
241+ if device not in [CPU_DEVICE , None ]:
232242 raise ValueError (f"Unsupported device { device !r} " )
243+ if dtype is not None :
244+ dtype = dtype ._np_dtype
233245 return Array ._new (np .linspace (start , stop , num , dtype = dtype , endpoint = endpoint ))
234246
235247
@@ -264,11 +276,13 @@ def ones(
264276
265277 See its docstring for more information.
266278 """
267- from ._array_object import Array
279+ from ._array_object import Array , CPU_DEVICE
268280
269281 _check_valid_dtype (dtype )
270- if device not in ["cpu" , None ]:
282+ if device not in [CPU_DEVICE , None ]:
271283 raise ValueError (f"Unsupported device { device !r} " )
284+ if dtype is not None :
285+ dtype = dtype ._np_dtype
272286 return Array ._new (np .ones (shape , dtype = dtype ))
273287
274288
@@ -280,11 +294,13 @@ def ones_like(
280294
281295 See its docstring for more information.
282296 """
283- from ._array_object import Array
297+ from ._array_object import Array , CPU_DEVICE
284298
285299 _check_valid_dtype (dtype )
286- if device not in ["cpu" , None ]:
300+ if device not in [CPU_DEVICE , None ]:
287301 raise ValueError (f"Unsupported device { device !r} " )
302+ if dtype is not None :
303+ dtype = dtype ._np_dtype
288304 return Array ._new (np .ones_like (x ._array , dtype = dtype ))
289305
290306
@@ -327,11 +343,13 @@ def zeros(
327343
328344 See its docstring for more information.
329345 """
330- from ._array_object import Array
346+ from ._array_object import Array , CPU_DEVICE
331347
332348 _check_valid_dtype (dtype )
333- if device not in ["cpu" , None ]:
349+ if device not in [CPU_DEVICE , None ]:
334350 raise ValueError (f"Unsupported device { device !r} " )
351+ if dtype is not None :
352+ dtype = dtype ._np_dtype
335353 return Array ._new (np .zeros (shape , dtype = dtype ))
336354
337355
@@ -343,9 +361,11 @@ def zeros_like(
343361
344362 See its docstring for more information.
345363 """
346- from ._array_object import Array
364+ from ._array_object import Array , CPU_DEVICE
347365
348366 _check_valid_dtype (dtype )
349- if device not in ["cpu" , None ]:
367+ if device not in [CPU_DEVICE , None ]:
350368 raise ValueError (f"Unsupported device { device !r} " )
369+ if dtype is not None :
370+ dtype = dtype ._np_dtype
351371 return Array ._new (np .zeros_like (x ._array , dtype = dtype ))
0 commit comments