Skip to content

tensor

Define IRx's backend-neutral Tensor metadata helpers on top of the canonical buffer/view substrate and the Arrow C++ backed tensor runtime.

Classes:

Functions:

TensorLayout dataclass

TensorLayout(
    shape: tuple[int, ...],
    strides: tuple[int, ...],
    offset_bytes: int = 0,
)

Represent the logical rank, shape, strides, and byte offset of one Tensor value without duplicating the lower-level storage machinery. Attributes: `shape` (`tuple[int, ...]`), `strides` (`tuple[int, ...]`), `offset_bytes` (`int`).

Attributes:

ndim property

ndim: int

TensorOrder

Bases: str, Enum

tensor_buffer_dtype

tensor_buffer_dtype(
    type_: DataType | None,
) -> BufferHandle | None
Source code in packages/irx/src/irx/builtins/collections/tensor.py
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
@public
@typechecked
def tensor_buffer_dtype(type_: astx.DataType | None) -> BufferHandle | None:
    """
    title: Map one Tensor element type to its canonical buffer dtype handle.
    summary: >-
      The primitive storage name is resolved first. None is returned when
      the element type has no primitive storage name.
    parameters:
      type_:
        type: astx.DataType | None
    returns:
      type: BufferHandle | None
    """
    storage_name = tensor_primitive_type_name(type_)
    return None if storage_name is None else buffer_dtype_handle(storage_name)

tensor_buffer_view_metadata

tensor_buffer_view_metadata(
    *,
    data: BufferHandle,
    owner: BufferHandle,
    dtype: BufferHandle,
    layout: TensorLayout,
    ownership: BufferOwnership,
    mutability: BufferMutability,
    has_validity_bitmap: bool = False,
) -> BufferViewMetadata
Source code in packages/irx/src/irx/builtins/collections/tensor.py
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
@public
@typechecked
def tensor_buffer_view_metadata(
    *,
    data: BufferHandle,
    owner: BufferHandle,
    dtype: BufferHandle,
    layout: TensorLayout,
    ownership: BufferOwnership,
    mutability: BufferMutability,
    has_validity_bitmap: bool = False,
) -> BufferViewMetadata:
    """
    title: Build canonical buffer/view metadata from one Tensor layout.
    summary: >-
      Contiguity flags are derived from the layout only when the dtype
      handle resolves to an element width; otherwise both stay False. The
      validity-bitmap flag is added when has_validity_bitmap is true.
    parameters:
      data:
        type: BufferHandle
      owner:
        type: BufferHandle
      dtype:
        type: BufferHandle
      layout:
        type: TensorLayout
      ownership:
        type: BufferOwnership
      mutability:
        type: BufferMutability
      has_validity_bitmap:
        type: bool
    returns:
      type: BufferViewMetadata
    """
    element_width = tensor_element_size_bytes_from_dtype(dtype)
    if element_width is None:
        # Without an element width the strides cannot be compared against
        # the canonical ones, so neither contiguity flag is claimed.
        is_c_order = False
        is_f_order = False
    else:
        is_c_order = tensor_is_c_contiguous(layout, element_width)
        is_f_order = tensor_is_f_contiguous(layout, element_width)

    view_flags = buffer_view_flags(
        ownership,
        mutability,
        c_contiguous=is_c_order,
        f_contiguous=is_f_order,
    )
    if has_validity_bitmap:
        view_flags |= BUFFER_FLAG_VALIDITY_BITMAP

    return BufferViewMetadata(
        data=data,
        owner=owner,
        dtype=dtype,
        ndim=layout.ndim,
        shape=layout.shape,
        strides=layout.strides,
        offset_bytes=layout.offset_bytes,
        flags=view_flags,
    )

tensor_byte_bounds

tensor_byte_bounds(
    layout: TensorLayout,
) -> tuple[int, int] | None

The result is relative to the underlying data pointer. `None` means the logical layout has zero extent and therefore addresses no elements. Parameters: `layout` (`TensorLayout`). Returns: `tuple[int, int] | None`.

Source code in packages/irx/src/irx/builtins/collections/tensor.py
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
@public
@typechecked
def tensor_byte_bounds(
    layout: TensorLayout,
) -> tuple[int, int] | None:
    """
    title: Return the lowest and highest element-start byte offsets.
    summary: >-
      Both offsets are relative to the underlying data pointer and start
      from layout.offset_bytes. None is returned when the layout's element
      count is zero, because no element is addressed in that case.
    parameters:
      layout:
        type: TensorLayout
    returns:
      type: tuple[int, int] | None
    """
    if tensor_element_count(layout) == 0:
        return None

    lowest = highest = layout.offset_bytes

    for extent, step in zip(layout.shape, layout.strides, strict=True):
        if extent > 1:
            # Distance from the first to the last index along this axis; a
            # negative distance lowers the minimum, any other raises the
            # maximum.
            reach = (extent - 1) * step
            lowest += min(reach, 0)
            highest += max(reach, 0)

    return lowest, highest

tensor_default_strides

tensor_default_strides(
    shape: tuple[int, ...],
    item_size_bytes: int,
    *,
    order: TensorOrder = C,
) -> tuple[int, ...]
Source code in packages/irx/src/irx/builtins/collections/tensor.py
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
@public
@typechecked
def tensor_default_strides(
    shape: tuple[int, ...],
    item_size_bytes: int,
    *,
    order: TensorOrder = TensorOrder.C,
) -> tuple[int, ...]:
    """
    title: Compute canonical byte strides for one contiguous Tensor shape.
    summary: >-
      The fastest-varying axis gets item_size_bytes and every following
      axis gets the running product of the preceding extents. Extents below
      one are counted as one. TensorOrder.C walks the axes last to first;
      any other order walks them first to last. A ValueError is raised when
      item_size_bytes is not positive.
    parameters:
      shape:
        type: tuple[int, Ellipsis]
      item_size_bytes:
        type: int
      order:
        type: TensorOrder
    returns:
      type: tuple[int, Ellipsis]
    """
    if item_size_bytes <= 0:
        raise ValueError("tensor item_size_bytes must be positive")

    row_major = order is TensorOrder.C
    # Visit extents from the fastest-varying axis to the slowest one.
    extents = shape[::-1] if row_major else shape

    computed: list[int] = []
    running = item_size_bytes
    for extent in extents:
        computed.append(running)
        running *= max(extent, 1)

    if row_major:
        # Restore axis order after the reversed walk.
        computed.reverse()

    # An empty shape yields an empty stride tuple.
    return tuple(computed)

tensor_element_count

tensor_element_count(layout: TensorLayout) -> int
Source code in packages/irx/src/irx/builtins/collections/tensor.py
100
101
102
103
104
105
106
107
108
109
110
111
@public
@typechecked
def tensor_element_count(layout: TensorLayout) -> int:
    """
    title: Return the logical number of elements described by one layout.
    summary: >-
      Delegates to the shared shape-extent helper using the layout's shape.
    parameters:
      layout:
        type: TensorLayout
    returns:
      type: int
    """
    dims = layout.shape
    return _shape_extent(dims)

tensor_element_size_bytes

tensor_element_size_bytes(
    type_: DataType | None,
) -> int | None
Source code in packages/irx/src/irx/builtins/collections/tensor.py
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
@public
@typechecked
def tensor_element_size_bytes(type_: astx.DataType | None) -> int | None:
    """
    title: Look up the byte width of one Tensor element type.
    summary: >-
      None is returned when the type has no primitive storage name or when
      that name has no entry in ARRAY_PRIMITIVE_TYPE_SPECS.
    parameters:
      type_:
        type: astx.DataType | None
    returns:
      type: int | None
    """
    storage_name = tensor_primitive_type_name(type_)
    type_spec = (
        None
        if storage_name is None
        else ARRAY_PRIMITIVE_TYPE_SPECS.get(storage_name)
    )
    return None if type_spec is None else type_spec.element_size_bytes

tensor_element_size_bytes_from_dtype

tensor_element_size_bytes_from_dtype(
    dtype: BufferHandle,
) -> int | None
Source code in packages/irx/src/irx/builtins/collections/tensor.py
394
395
396
397
398
399
400
401
402
403
404
405
406
@public
@typechecked
def tensor_element_size_bytes_from_dtype(dtype: BufferHandle) -> int | None:
    """
    title: Return the byte width for one canonical primitive dtype handle.
    summary: >-
      None is returned for a null handle and for any handle that has no
      entry in the dtype-to-width table.
    parameters:
      dtype:
        type: BufferHandle
    returns:
      type: int | None
    """
    # A null handle is rejected before the table lookup.
    if dtype.is_null:
        return None
    return _DTYPE_ELEMENT_SIZE_BYTES.get(dtype)

tensor_is_c_contiguous

tensor_is_c_contiguous(
    layout: TensorLayout, item_size_bytes: int
) -> bool
Source code in packages/irx/src/irx/builtins/collections/tensor.py
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
@public
@typechecked
def tensor_is_c_contiguous(
    layout: TensorLayout,
    item_size_bytes: int,
) -> bool:
    """
    title: Check whether one layout uses the canonical C-order strides.
    summary: >-
      The layout's strides are compared for exact equality with the default
      strides computed for its shape in TensorOrder.C.
    parameters:
      layout:
        type: TensorLayout
      item_size_bytes:
        type: int
    returns:
      type: bool
    """
    canonical = tensor_default_strides(
        layout.shape,
        item_size_bytes,
        order=TensorOrder.C,
    )
    return layout.strides == canonical

tensor_is_f_contiguous

tensor_is_f_contiguous(
    layout: TensorLayout, item_size_bytes: int
) -> bool
Source code in packages/irx/src/irx/builtins/collections/tensor.py
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
@public
@typechecked
def tensor_is_f_contiguous(
    layout: TensorLayout,
    item_size_bytes: int,
) -> bool:
    """
    title: Check whether one layout uses the canonical Fortran-order strides.
    summary: >-
      The layout's strides are compared for exact equality with the default
      strides computed for its shape in TensorOrder.F.
    parameters:
      layout:
        type: TensorLayout
      item_size_bytes:
        type: int
    returns:
      type: bool
    """
    canonical = tensor_default_strides(
        layout.shape,
        item_size_bytes,
        order=TensorOrder.F,
    )
    return layout.strides == canonical

tensor_primitive_type_name

tensor_primitive_type_name(
    type_: DataType | None,
) -> str | None
Source code in packages/irx/src/irx/builtins/collections/tensor.py
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
@public
@typechecked
def tensor_primitive_type_name(type_: astx.DataType | None) -> str | None:
    """
    title: Resolve the builtin primitive storage name of one Tensor element.
    summary: >-
      The element type is matched with isinstance against the supported
      astx scalar types in a fixed order; the first match wins. None is
      returned when nothing matches, which includes a None input.
    parameters:
      type_:
        type: astx.DataType | None
    returns:
      type: str | None
    """
    # Order matters only if the astx classes overlap; it mirrors the
    # bool, signed, unsigned, float sequence.
    storage_names = (
        (astx.Boolean, "bool"),
        (astx.Int8, "int8"),
        (astx.Int16, "int16"),
        (astx.Int32, "int32"),
        (astx.Int64, "int64"),
        (astx.UInt8, "uint8"),
        (astx.UInt16, "uint16"),
        (astx.UInt32, "uint32"),
        (astx.UInt64, "uint64"),
        (astx.Float32, "float32"),
        (astx.Float64, "float64"),
    )
    for scalar_type, storage_name in storage_names:
        if isinstance(type_, scalar_type):
            return storage_name
    return None

validate_tensor_layout

validate_tensor_layout(
    layout: TensorLayout,
) -> tuple[str, ...]
Source code in packages/irx/src/irx/builtins/collections/tensor.py
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
@public
@typechecked
def validate_tensor_layout(
    layout: TensorLayout,
) -> tuple[str, ...]:
    """
    title: Check one static Tensor layout and report every violation.
    summary: >-
      Returns one message per failed check, in a fixed order: stride count
      versus ndim, negative shape dimensions, negative strides, negative
      byte offset. An empty tuple means every check passed.
    parameters:
      layout:
        type: TensorLayout
    returns:
      type: tuple[str, Ellipsis]
    """
    rank_mismatch = len(layout.strides) != layout.ndim
    negative_dim = any(dim < 0 for dim in layout.shape)
    negative_stride = any(stride < 0 for stride in layout.strides)
    negative_offset = layout.offset_bytes < 0

    # Each entry pairs a failure condition with the message it reports.
    findings = (
        (rank_mismatch, "tensor stride length must match ndim"),
        (negative_dim, "tensor shape dimensions must be non-negative"),
        (negative_stride, "tensor strides must be non-negative"),
        (negative_offset, "tensor offset_bytes must be non-negative"),
    )
    return tuple(message for failed, message in findings if failed)