Segment COO
- torch_scatter.segment_coo(src: Tensor, index: Tensor, out: Tensor | None = None, dim_size: int | None = None, reduce: str = 'sum') -> Tensor
Reduces all values from the `src` tensor into `out` at the indices specified in the `index` tensor along the last dimension of `index`. For each value in `src`, its output index is specified by its index in `src` for dimensions outside of `index.dim() - 1` and by the corresponding value in `index` for dimension `index.dim() - 1`. The applied reduction is defined via the `reduce` argument.

Formally, if `src` and `index` are \(n\)-dimensional and \((m+1)\)-dimensional tensors with sizes \((x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})\) and \((x_0, ..., x_{m-1}, x_m)\), respectively, then `out` must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{m-1}, y, x_{m+1}, ..., x_{n-1})\). Moreover, the values of `index` must lie between \(0\) and \(y - 1\) and appear in ascending order. The `index` tensor supports broadcasting in case its dimensions do not match those of `src`.

For one-dimensional tensors with `reduce="sum"`, the operation computes

\[\mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j\]

where \(\sum_j\) is over \(j\) such that \(\mathrm{index}_j = i\).
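As a sanity check on the formula, the following pure-Python sketch transcribes the one-dimensional `reduce="sum"` case literally on plain lists. The function name `segment_sum_1d_reference` is hypothetical and not part of torch_scatter. It does not exploit sortedness (the formula holds for any order), and when `out` is omitted it sizes the output from `dim_size` or `max(index) + 1`, following the `dim_size` description in the parameter list.

```python
from typing import List, Optional


def segment_sum_1d_reference(
    src: List[float],
    index: List[int],
    out: Optional[List[float]] = None,
    dim_size: Optional[int] = None,
) -> List[float]:
    """Literal transcription of out_i = out_i + sum_{j : index_j = i} src_j.

    Hypothetical helper for illustration only; not the torch_scatter API.
    """
    if len(src) != len(index):
        raise ValueError("src and index must have the same length in 1-D")
    if out is None:
        # Size the output from dim_size, else from the largest index + 1.
        size = dim_size if dim_size is not None else (max(index) + 1 if index else 0)
        out = [0.0] * size
    else:
        out = list(out)  # copy so the caller's list is not mutated
    # Index values must lie in [0, y - 1], where y = len(out).
    if index and (min(index) < 0 or max(index) >= len(out)):
        raise ValueError("index values must lie between 0 and y - 1")
    for i in range(len(out)):
        # Sum every src_j whose index_j equals i and add it to out_i.
        out[i] = out[i] + sum(s for s, k in zip(src, index) if k == i)
    return out


print(segment_sum_1d_reference([1.0, 2.0, 3.0, 4.0, 5.0], [0, 0, 1, 1, 2]))
# [3.0, 7.0, 5.0]
```

This quadratic loop is only meant to mirror the math; it says nothing about how the library computes the result.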
In contrast to `scatter()`, this method expects values in `index` to be sorted along dimension `index.dim() - 1`. Due to the use of sorted indices, `segment_coo()` is usually faster than the more general `scatter()` operation.

Note
This operation is implemented via atomic operations on the GPU and is therefore non-deterministic, since the order of parallel operations on the same output value is undetermined. For floating-point inputs, this makes the result subject to small run-to-run variation.
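The variation comes from floating-point addition not being associative: when parallel additions to the same output land in a different order, the rounding differs. The snippet below shows this in plain Python on IEEE 754 doubles; no torch_scatter code is involved.

```python
# Floating-point addition is not associative, so the order in which
# contributions are accumulated can change the last bits of the result.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left, right, left == right)
# 0.6000000000000001 0.6 False
```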
- Parameters:
src – The source tensor.
index – The sorted indices of elements to segment. The number of dimensions of `index` needs to be less than or equal to that of `src`.
out – The destination tensor.
dim_size – If `out` is not given, automatically create output with size `dim_size` at dimension `index.dim() - 1`. If `dim_size` is not given, a minimally sized output tensor according to `index.max() + 1` is returned.
reduce – The reduce operation (`"sum"`, `"mean"`, `"min"` or `"max"`). (default: `"sum"`)
- Return type:
Tensor
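Because `index` must be sorted, all entries that map to the same output position form one contiguous run, so a reduction can visit the input in a single left-to-right pass and write each output slot once. The sketch below illustrates that idea on plain lists for the four documented `reduce` modes. It is a sequential illustration, not torch_scatter's GPU kernel; the name `segment_reduce_sorted` is hypothetical, and leaving empty segments at `0.0` is a choice made for this sketch only, not a statement about the library's convention.

```python
from typing import List, Optional


def segment_reduce_sorted(
    src: List[float],
    index: List[int],
    dim_size: Optional[int] = None,
    reduce: str = "sum",
) -> List[float]:
    """Reduce contiguous runs of equal, ascending indices in a single pass.

    Hypothetical helper for illustration; empty segments stay at 0.0,
    which is an assumption of this sketch.
    """
    if reduce not in ("sum", "mean", "min", "max"):
        raise ValueError(f"unsupported reduce: {reduce!r}")
    if len(src) != len(index):
        raise ValueError("src and index must have the same length in 1-D")
    if any(a > b for a, b in zip(index, index[1:])):
        raise ValueError("index must be sorted in ascending order")
    # With sorted indices the largest value is the last one.
    size = dim_size if dim_size is not None else (index[-1] + 1 if index else 0)
    out = [0.0] * size

    start = 0
    while start < len(index):
        # Find the end of the run of entries sharing index[start].
        end = start
        while end < len(index) and index[end] == index[start]:
            end += 1
        run = src[start:end]
        if reduce == "sum":
            value = sum(run)
        elif reduce == "mean":
            value = sum(run) / len(run)
        elif reduce == "min":
            value = min(run)
        else:  # "max"
            value = max(run)
        out[index[start]] = value  # each output slot is written exactly once
        start = end
    return out


print(segment_reduce_sorted([1.0, 2.0, 3.0, 4.0, 5.0], [0, 0, 1, 1, 2], reduce="mean"))
# [1.5, 3.5, 5.0]
```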
```python
import torch
from torch_scatter import segment_coo

src = torch.randn(10, 6, 64)
index = torch.tensor([0, 0, 1, 1, 1, 2])
index = index.view(1, -1)  # Broadcasting in the first and last dim.

out = segment_coo(src, index, reduce="sum")

print(out.size())
```

```
torch.Size([10, 3, 64])
```
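The output shape in this example follows directly from the formal size rule: dimension `index.dim() - 1` of `src` is replaced by \(y\), and all other dimensions are kept. The helper below only computes shapes and is hypothetical, not part of torch_scatter. Its broadcasting check (each leading dimension of `index` either equals that of `src` or is 1, and the last one matches `src`) is an assumed reading of the broadcasting remark, not the library's exact validation.

```python
from typing import Optional, Sequence, Tuple


def segment_coo_out_shape(
    src_shape: Sequence[int],
    index_shape: Sequence[int],
    index_max: int,
    dim_size: Optional[int] = None,
) -> Tuple[int, ...]:
    """Output shape for given src/index shapes (hypothetical helper).

    The reduced dimension is index.dim() - 1; its output size is dim_size
    if given, otherwise index_max + 1.
    """
    dim = len(index_shape) - 1
    if dim < 0 or len(index_shape) > len(src_shape):
        raise ValueError("index must satisfy 1 <= index.dim() <= src.dim()")
    # Assumed broadcasting rule: leading index dims equal src's or are 1.
    for i in range(dim):
        if index_shape[i] not in (1, src_shape[i]):
            raise ValueError(f"index dim {i} cannot broadcast to src")
    # Assumed: every src entry along the reduced dim needs its own index.
    if index_shape[dim] != src_shape[dim]:
        raise ValueError("last index dim must match src along that dim")
    y = dim_size if dim_size is not None else index_max + 1
    out = list(src_shape)
    out[dim] = y
    return tuple(out)


# Same shapes as the example above: src (10, 6, 64), index (1, 6), max index 2.
print(segment_coo_out_shape((10, 6, 64), (1, 6), index_max=2))
# (10, 3, 64)
```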