Adapt Arx surface tensor syntax to IRx Tensor nodes while keeping user-facing
shape and indexing rules local to Arx.
Classes:
Functions:
attach_binding
attach_binding(node: AST, binding: TensorBinding) -> None
Source code in packages/arx/src/arx/tensor.py (lines 188–203)
def attach_binding(node: astx.AST, binding: TensorBinding) -> None:
    """
    title: Record static tensor metadata on one AST node.
    parameters:
      node:
        type: astx.AST
      binding:
        type: TensorBinding
    """
    # Reuse the node's SemanticInfo when it already carries one; a missing
    # attribute, a None value, or any non-SemanticInfo object is replaced
    # by a fresh SemanticInfo instance.
    semantic = getattr(node, "semantic", None)
    if not isinstance(semantic, SemanticInfo):
        semantic = SemanticInfo()
        setattr(node, "semantic", semantic)

    # Layout, element type and buffer flags are stored side by side in the
    # extras mapping under the module's tensor extra keys.
    extras = semantic.extras
    extras[TENSOR_LAYOUT_EXTRA] = binding.layout
    extras[TENSOR_ELEMENT_TYPE_EXTRA] = binding.element_type
    extras[TENSOR_FLAGS_EXTRA] = binding.flags
binding_from_type
binding_from_type(data_type: DataType | None) -> TensorBinding | None
Source code in packages/arx/src/arx/tensor.py (lines 150–185)
def binding_from_type(
    data_type: astx.DataType | None,
) -> TensorBinding | None:
    """
    title: Derive one static tensor binding from one declared type.
    parameters:
      data_type:
        type: astx.DataType | None
    returns:
      type: TensorBinding | None
    """
    # Only Arx surface tensor types can have a binding; beyond that, both a
    # known shape and a known element type are required, otherwise None.
    if not is_tensor_type(data_type):
        return None
    shape = tensor_shape(data_type)
    elem = cast(astx.TensorType, data_type).element_type
    if shape is None or elem is None:
        return None

    item_size = _element_size_bytes(elem)
    layout = TensorLayout(
        shape=shape,
        strides=tensor_default_strides(shape, item_size),
        offset_bytes=0,
    )
    c_contig = tensor_is_c_contiguous(layout, item_size)
    f_contig = tensor_is_f_contiguous(layout, item_size)

    # Declared tensors are described as read-only views whose storage is
    # owned externally.
    return TensorBinding(
        element_type=elem,
        layout=layout,
        flags=buffer_view_flags(
            BufferOwnership.EXTERNAL_OWNER,
            BufferMutability.READONLY,
            c_contiguous=c_contig,
            f_contiguous=f_contig,
        ),
    )
build_literal_from_literal
build_literal_from_literal(expr: Literal, target_type: DataType, *, context: str) -> TensorLiteral
Source code in packages/arx/src/arx/tensor.py (lines 252–288)
def build_literal_from_literal(
    expr: astx.Literal,
    target_type: astx.DataType,
    *,
    context: str,
) -> astx.TensorLiteral:
    """
    title: Convert one nested literal value into one tensor literal.
    parameters:
      expr:
        type: astx.Literal
      target_type:
        type: astx.DataType
      context:
        type: str
    returns:
      type: astx.TensorLiteral
    """
    binding = binding_from_type(target_type)
    if binding is None:
        # Tell apart "not a tensor at all" from "a tensor whose shape is not
        # statically known" so the error names the actual problem.
        reason = (
            "tensor literal target requires a static tensor shape"
            if is_tensor_type(target_type)
            else "tensor literal target must be a tensor type"
        )
        raise ValueError(reason)

    actual_shape, scalars = _flatten_literal(expr)
    expected_shape = binding.layout.shape
    if actual_shape != expected_shape:
        got = _format_shape(actual_shape)
        want = _format_shape(expected_shape)
        raise ValueError(
            f"{context} has shape {got} but the declared "
            f"tensor shape is {want}"
        )

    # Every flattened scalar must be compatible with the declared element
    # type before the literal is assembled.
    for scalar in scalars:
        _validate_scalar_literal(scalar, binding.element_type, context=context)
    return _literal_from_values(binding, scalars)
coerce_expression
coerce_expression(expr: Expr, target_type: DataType, *, context: str) -> Expr
Source code in packages/arx/src/arx/tensor.py (lines 206–230)
def coerce_expression(
    expr: astx.Expr,
    target_type: astx.DataType,
    *,
    context: str,
) -> astx.Expr:
    """
    title: Adapt one parsed expression to one declared tensor type.
    parameters:
      expr:
        type: astx.Expr
      target_type:
        type: astx.DataType
      context:
        type: str
    returns:
      type: astx.Expr
    """
    # Only a plain (non-tensor) literal headed for a tensor-typed slot is
    # rebuilt; non-tensor targets, existing tensor literals and any other
    # expression kind are handed back unchanged.
    needs_rebuild = (
        is_tensor_type(target_type)
        and not isinstance(expr, astx.TensorLiteral)
        and isinstance(expr, astx.Literal)
    )
    if not needs_rebuild:
        return expr
    return build_literal_from_literal(expr, target_type, context=context)
default_value
default_value(target_type: DataType) -> TensorLiteral
Source code in packages/arx/src/arx/tensor.py (lines 233–249)
def default_value(target_type: astx.DataType) -> astx.TensorLiteral:
    """
    title: Build one default tensor literal for one declared type.
    parameters:
      target_type:
        type: astx.DataType
    returns:
      type: astx.TensorLiteral
    """
    binding = binding_from_type(target_type)
    if binding is None:
        # Mirror build_literal_from_literal: a non-tensor target must not be
        # reported as a missing static shape.
        if not is_tensor_type(target_type):
            raise ValueError("default tensor value requires a tensor type")
        raise ValueError("default tensor value requires a static tensor shape")

    # One zero scalar per element; a zero-sized dimension yields no values.
    count = prod(binding.layout.shape)
    zero = _zero_literal(binding.element_type)
    # Each position gets its own clone, presumably so elements do not share
    # a single AST node instance.
    values = tuple(_clone_scalar_literal(zero) for _ in range(count))
    return _literal_from_values(binding, values)
infer_literal
infer_literal(expr: Literal) -> TensorLiteral
Source code in packages/arx/src/arx/tensor.py (lines 291–318)
def infer_literal(expr: astx.Literal) -> astx.TensorLiteral:
    """
    title: Infer one tensor literal directly from one literal value.
    parameters:
      expr:
        type: astx.Literal
    returns:
      type: astx.TensorLiteral
    """
    shape, values = _flatten_literal(expr)
    if not values:
        raise ValueError(
            "cannot infer a tensor element type from an empty literal"
        )

    # The first scalar fixes the element type; every scalar (including the
    # first) is then validated against it.
    element_type = _infer_element_type(values[0])
    for value in values:
        _validate_scalar_literal(
            value,
            element_type,
            context="tensor literal",
        )

    # tensor_type() validates the inferred shape and marks it as static, so a
    # binding is expected here. Narrow explicitly instead of cast() so that a
    # None cannot reach _literal_from_values and fail there obscurely.
    binding = binding_from_type(tensor_type(element_type, shape))
    if binding is None:
        raise ValueError(
            "could not derive a static tensor binding for the inferred "
            "tensor literal"
        )
    return _literal_from_values(binding, values)
is_tensor_type
is_tensor_type(data_type: DataType | None) -> bool
Source code in packages/arx/src/arx/tensor.py (lines 86–98)
def is_tensor_type(data_type: astx.DataType | None) -> bool:
    """
    title: Report whether one type is an Arx tensor surface type.
    parameters:
      data_type:
        type: astx.DataType | None
    returns:
      type: bool
    """
    if not isinstance(data_type, astx.TensorType):
        return False
    # A bare astx.TensorType is not enough: the Arx surface marker attribute
    # must be present and be exactly True (identity check, not truthiness).
    return getattr(data_type, TENSOR_SURFACE_ATTR, False) is True
literal_values
literal_values(node: TensorLiteral) -> tuple[AST, ...]
Source code in packages/arx/src/arx/tensor.py (lines 321–332)
def literal_values(
    node: astx.TensorLiteral,
) -> tuple[astx.AST, ...]:
    """
    title: Expose the flattened scalar payload of one tensor literal.
    parameters:
      node:
        type: astx.TensorLiteral
    returns:
      type: tuple[astx.AST, Ellipsis]
    """
    # Materialize as a tuple so callers receive an immutable sequence
    # regardless of the container type stored on the node.
    payload: tuple[astx.AST, ...] = tuple(node.values)
    return payload
runtime_tensor_type
runtime_tensor_type(element_type: DataType) -> TensorType
Source code in packages/arx/src/arx/tensor.py (lines 137–147)
def runtime_tensor_type(element_type: astx.DataType) -> astx.TensorType:
    """
    title: Build one tensor surface type whose shape is known only at runtime.
    parameters:
      element_type:
        type: astx.DataType
    returns:
      type: astx.TensorType
    """
    # The size lookup's result is discarded; the call is kept for its
    # validation effect (presumably it raises for unsupported element types).
    _element_size_bytes(element_type)
    surface = astx.TensorType(element_type)
    # Passing None as the shape marks this tensor as runtime-shaped.
    return _mark_tensor_type(surface, None)
tensor_shape
tensor_shape(data_type: DataType | None) -> tuple[int, ...] | None
Source code in packages/arx/src/arx/tensor.py (lines 101–112)
def tensor_shape(data_type: astx.DataType | None) -> tuple[int, ...] | None:
    """
    title: Look up the declared tensor shape, if one is available.
    parameters:
      data_type:
        type: astx.DataType | None
    returns:
      type: tuple[int, Ellipsis] | None
    """
    # A missing type has no shape; otherwise defer to the module's shape
    # reader.
    return None if data_type is None else _shape_of(data_type)
tensor_type
tensor_type(element_type: DataType, shape: tuple[int, ...]) -> TensorType
Source code in packages/arx/src/arx/tensor.py (lines 115–134)
def tensor_type(
    element_type: astx.DataType,
    shape: tuple[int, ...],
) -> astx.TensorType:
    """
    title: Build one static-shape tensor surface type.
    parameters:
      element_type:
        type: astx.DataType
      shape:
        type: tuple[int, Ellipsis]
    returns:
      type: astx.TensorType
    """
    import operator

    # Validates the element type; the size itself is not needed here.
    _element_size_bytes(element_type)

    # Normalize to a tuple of plain ints. Lists and integer-like values
    # (e.g. numpy integers) are accepted, and shapes stored on the type then
    # compare equal to tuple shapes elsewhere. Non-integral dimensions are
    # rejected here rather than failing later in layout computations.
    try:
        dims = tuple(operator.index(dim) for dim in shape)
    except TypeError as err:
        raise ValueError("tensor dimensions must be integers") from err

    if not dims:
        raise ValueError("tensor shapes must include at least one dimension")
    if any(dim < 0 for dim in dims):
        raise ValueError("tensor dimensions must be non-negative")
    return _mark_tensor_type(astx.TensorType(element_type), dims)
|