Special Lambda Operations
- sn_fma(a: SambaTensor, b: SambaTensor, c: SambaTensor) → SambaTensor
Performs a fused multiply-add operation, equivalent to
\[a \times b + c\]
Example
>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce, sn_fma
>>> compute_fma = lambda attrs, x, y, z: sn_fma(x, y, z)
>>> samba.set_seed(1)
>>> x = samba.randn((2,5), dtype=torch.bfloat16)
>>> y = samba.randn((2,5), dtype=torch.bfloat16)
>>> z = samba.randn((2,5), dtype=torch.bfloat16)
>>> out = sn_zipmapreduce(compute_fma, [x, y, z])
>>> out.data
tensor([[ 3.0000, -1.8438, 1.5859, -1.8750, 0.1357],
        [-0.5234, 2.0312, 2.3438, 2.0312, -0.9180]], dtype=torch.bfloat16)
Note
This operation is only supported when used inside a function/lambda that is passed as input to sn_zipmapreduce().
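For reference, the per-element arithmetic that sn_fma contributes inside a sn_zipmapreduce() lambda can be modeled in plain Python. This is a conceptual sketch, not SambaFlow code: fma_reference is a hypothetical helper, and it rounds the product and the sum separately, whereas a fused multiply-add is normally defined to round only once, so low-order bits (especially in bfloat16) may differ from device results.

```python
# Conceptual reference model, NOT the SambaFlow API: what a lambda such as
# `lambda attrs, x, y, z: sn_fma(x, y, z)` computes at each element position.
# `fma_reference` is a hypothetical helper written only for illustration.
from typing import List

Matrix = List[List[float]]


def fma_reference(a: Matrix, b: Matrix, c: Matrix) -> Matrix:
    """Return a * b + c element by element for three same-shape 2-D lists."""
    if not (len(a) == len(b) == len(c)):
        raise ValueError("inputs must have the same number of rows")
    out: Matrix = []
    for row_a, row_b, row_c in zip(a, b, c):
        if not (len(row_a) == len(row_b) == len(row_c)):
            raise ValueError("inputs must have the same number of columns")
        # Python rounds x * y and then (+ z) separately; a true fused
        # multiply-add rounds once, so the last bit can differ.
        out.append([x * y + z for x, y, z in zip(row_a, row_b, row_c)])
    return out


print(fma_reference([[1.0, 2.0]], [[3.0, 4.0]], [[0.5, -1.0]]))  # [[3.5, 7.0]]
```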
- sn_imm(input: float | int, dtype: torch.dtype | SNType) → SambaTensor
Creates a tensor containing a constant.
- Parameters:
input – the constant value (a Python float or int).
dtype – data type of the constant tensor.
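Conceptually, a constant created with sn_imm() takes part in the lambda as if it were a tensor filled with that value, as the sn_select(mask, one_imm, zero_imm) call in the sn_iteridx() example suggests. The plain-Python sketch below models that behavior under this assumption; imm_reference and map2d are hypothetical helpers, not SambaFlow functions, and the dtype argument is not modeled.

```python
# Conceptual model only: an immediate behaves like a tensor filled with one
# constant, so it combines with every element of the other operands.
# `imm_reference` and `map2d` are hypothetical helpers, not SambaFlow APIs.
from typing import Callable, List

Matrix = List[List[int]]


def imm_reference(value: int, rows: int, cols: int) -> Matrix:
    """Materialize a constant as a rows x cols tensor (nested lists)."""
    return [[value] * cols for _ in range(rows)]


def map2d(fn: Callable[..., int], *tensors: Matrix) -> Matrix:
    """Apply fn element-wise across same-shape 2-D tensors."""
    return [[fn(*elems) for elems in zip(*rows)] for rows in zip(*tensors)]


x = [[1, 2, 3], [4, 5, 6]]
two = imm_reference(2, rows=2, cols=3)
one = imm_reference(1, rows=2, cols=3)
# Models a lambda body computing x * 2 + 1 with two immediates.
print(map2d(lambda a, b, c: a * b + c, x, two, one))  # [[3, 5, 7], [9, 11, 13]]
```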
- sn_iteridx(attrs: dict, dim: int, dtype: SNType | None) → SambaTensor
Creates an iterator within the body of a sn_zipmapreduce() func. The dim argument specifies the dimension of the broadcast shape that is iterated over. For example, if sn_zipmapreduce() has two inputs of shapes [A,1] and [1,B] and reduce_dim is set to 1, the tensor shape after broadcast will be [A,B] and the tensor shape after reduction will be [A,1]. Two iterators are then available within the func body:
sn_iteridx(dim=0) = (0 until A by 1)
sn_iteridx(dim=1) = (0 until B by 1)
sn_iteridx() can be treated as an input to the sn_zipmapreduce() func and used in the func body.
- Parameters:
attrs – attrs dictionary passed to the calling sn_zipmapreduce().
dim – dimension of the broadcast shape.
dtype (optional) – data type of the iterator. dtype must be either a signed or unsigned integer data type, and its bit-width must match the bit-width of the inputs to the sn_zipmapreduce() that the operation belongs to. If unspecified, it is inferred from the inputs to the func passed to that sn_zipmapreduce(). The sign of the data type is also inferred from the inputs, i.e., for signed inputs the dtype will be int and for unsigned inputs it will be uint. Thus, for int16 inputs the iterator can only count up to \(2^{15} - 1\), so the iterated dimension must have at most \(2^{15}\) elements. The following types are supported: SNType.UINT16, SNType.INT16, SNType.INT32.
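The broadcast and iterator bookkeeping described above can be modeled in plain Python. This is a sketch under stated assumptions: broadcast_shape, iteridx_reference, and max_dim_size are hypothetical helpers, not part of SambaFlow; the reduction over dim 1 is assumed to be a sum purely for illustration; and the unsigned bound in max_dim_size is simple arithmetic extrapolated from the int16 case, not a documented guarantee.

```python
# Conceptual model only; none of these helpers are SambaFlow functions.
from typing import List, Tuple

Matrix = List[List[int]]


def broadcast_shape(s1: Tuple[int, int], s2: Tuple[int, int]) -> Tuple[int, int]:
    """Broadcast two 2-D shapes: a size-1 dimension stretches to match."""
    dims = []
    for d1, d2 in zip(s1, s2):
        if d1 != d2 and 1 not in (d1, d2):
            raise ValueError(f"cannot broadcast {s1} with {s2}")
        dims.append(max(d1, d2))
    return dims[0], dims[1]


def iteridx_reference(shape: Tuple[int, int], dim: int) -> Matrix:
    """Value of a dim-`dim` iterator at every element of a 2-D shape."""
    rows, cols = shape
    return [[(r, c)[dim] for c in range(cols)] for r in range(rows)]


def max_dim_size(bit_width: int, signed: bool) -> int:
    """Largest dimension size an integer iterator can index (0 .. size-1)."""
    return 2 ** (bit_width - 1) if signed else 2 ** bit_width


# Inputs [A, 1] and [1, B] broadcast to [A, B].
A, B = 2, 3
shape = broadcast_shape((A, 1), (1, B))
rows = iteridx_reference(shape, dim=0)  # counts 0 until A
cols = iteridx_reference(shape, dim=1)  # counts 0 until B

# Hypothetical func body (row * 10 + col), then an assumed sum over dim 1,
# giving the reduced shape [A, 1].
reduced = [[sum(r * 10 + c for r, c in zip(rr, cc))] for rr, cc in zip(rows, cols)]
print(shape, reduced)  # (2, 3) [[3], [33]]

# The same iterators on a 3 x 3 shape yield an upper-triangular mask.
r3 = iteridx_reference((3, 3), dim=0)
c3 = iteridx_reference((3, 3), dim=1)
upper = [[1 if r <= c else 0 for r, c in zip(rr, cc)] for rr, cc in zip(r3, c3)]
print(upper)  # [[1, 1, 1], [0, 1, 1], [0, 0, 1]]
print(max_dim_size(16, signed=True))  # 32768, i.e. 2**15
```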
Example
>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce, sn_iteridx, sn_select, sn_imm
>>> from sambaflow.samba.utils import SNType
>>> x = samba.SambaTensor(shape=(3, 3), dtype=torch.int32)
>>> # Example: creating an upper triangular matrix
>>> def upper_tri(attrs, x):
...     # if specified, dtype must match the bit-width of the other dtypes in the lambda
...     dim_r = sn_iteridx(attrs=attrs, dim=0, dtype=SNType.INT32)
...     dim_c = sn_iteridx(attrs=attrs, dim=1, dtype=SNType.INT32)
...     mask = dim_r <= dim_c
...     one_imm = sn_imm(1, dtype=torch.int32)
...     zero_imm = sn_imm(0, dtype=torch.int32)
...     return sn_select(mask, one_imm, zero_imm)
>>> upper_matrix = sn_zipmapreduce(upper_tri, [x])
>>> upper_matrix.data
tensor([[1, 1, 1],
        [0, 1, 1],
        [0, 0, 1]], dtype=torch.int32)
Note
This operation is only supported when used inside a Python function/lambda that is passed as input to sn_zipmapreduce().
- sn_select(cond: SambaTensor, true_val: SambaTensor, false_val: SambaTensor) → SambaTensor
Performs a select operation on tensors, similar to the torch.where() function. This function selects elements from true_val or false_val based on the condition specified in cond.
- Parameters:
cond – a mask tensor where each element is a condition. If the condition is true, the corresponding element from true_val is selected; otherwise the element from false_val is chosen.
true_val – the tensor from which elements are selected when the corresponding condition in cond is true.
false_val – the tensor from which elements are selected when the corresponding condition in cond is false.
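The element-wise rule above can be modeled in plain Python on 1-D data. This is a conceptual sketch, not SambaFlow code: select_reference is a hypothetical helper, and treating any nonzero mask element as true (matching the integer mask in the example that follows) is an assumption rather than documented behavior.

```python
# Conceptual model of sn_select / torch.where semantics on 1-D lists.
# `select_reference` is hypothetical; "nonzero means true" is an assumption.
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def select_reference(
    cond: Sequence[int], true_val: Sequence[T], false_val: Sequence[T]
) -> List[T]:
    """Pick true_val[i] where cond[i] is nonzero, else false_val[i]."""
    if not (len(cond) == len(true_val) == len(false_val)):
        raise ValueError("cond, true_val and false_val must have the same length")
    return [t if c else f for c, t, f in zip(cond, true_val, false_val)]


print(select_reference([1, 0, 1], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]))  # [1.0, 5.0, 3.0]
```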
Example
>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce, sn_select
>>> # Example showing the use of sn_select with tensors.
>>> mask = samba.SambaTensor(torch.tensor([1, 0, 1], dtype=torch.int32))
>>> x = samba.SambaTensor(torch.tensor([1, 2, 3], dtype=torch.float))
>>> y = samba.SambaTensor(torch.tensor([4, 5, 6], dtype=torch.float))
>>> f = lambda attrs, mask, x, y: sn_select(mask, x, y)
>>> result = sn_zipmapreduce(f, [mask, x, y])
>>> result.data
tensor([1., 5., 3.])
>>> # result contains elements from x or y based on the mask.
>>> # Here 1 and 3 are selected from x and 5 from y.
Note
This operation is only supported when used inside a Python function/lambda that is passed as input to sn_zipmapreduce().
See also
For more details, see torch.where().
