torch.where(condition, x, y) → Tensor

Return a tensor of elements selected from either x or y, depending on condition.

The operation is defined as:

\text{out}_i = \begin{cases}
    \text{x}_i & \text{if } \text{condition}_i \\
    \text{y}_i & \text{otherwise}
\end{cases}


The tensors condition, x, y must be broadcastable.
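As a minimal sketch of the broadcasting requirement, the snippet below combines a condition of shape (3, 1), an x of shape (2,) and a y of shape (3, 2). The tensor values are made up for illustration and are not from the original documentation.

```python
import torch

# Illustrative inputs with different but broadcast-compatible shapes.
condition = torch.tensor([[True], [False], [True]])  # shape (3, 1)
x = torch.tensor([1.0, 2.0])                          # shape (2,)
y = torch.zeros(3, 2)                                 # shape (3, 2)

# All three inputs are broadcast to the common shape (3, 2).
out = torch.where(condition, x, y)
print(out.shape)
print(out)
```

Rows where condition is True take the broadcast values of x, and the remaining row takes the values of y.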


Currently valid scalar and tensor combinations are:

  1. Scalar of floating dtype and torch.double
  2. Scalar of integral dtype and torch.long
  3. Scalar of complex dtype and torch.complex128
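The integral and complex combinations can be sketched as follows (the floating case appears in the example further down). The variable names and values here are illustrative assumptions. Other PyTorch releases may accept further combinations, so treat this as a sketch of the cases listed in this section only.

```python
import torch

# Integral scalar paired with a torch.long tensor (the default integer dtype).
counts = torch.tensor([3, 0, 5, 1])
clipped = torch.where(counts > 1, counts, 0)

# Complex scalar paired with a torch.complex128 tensor.
z = torch.tensor([1 + 2j, -1 + 0.5j], dtype=torch.complex128)
kept = torch.where(z.real > 0, z, 0j)

print(clipped, clipped.dtype)
print(kept, kept.dtype)
```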

Parameters

  • condition (BoolTensor) – When True (nonzero), yield x, otherwise yield y
  • x (Tensor or Scalar) – value (if x is a scalar) or values selected at indices where condition is True
  • y (Tensor or Scalar) – value (if y is a scalar) or values selected at indices where condition is False

Returns

A tensor of shape equal to the broadcasted shape of condition, x, y

Return type

Tensor

Example:

>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],
        [-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],
        [0.0000, 0.0000]], dtype=torch.float64)

torch.where(condition) → tuple of LongTensor

torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).


See also torch.nonzero().
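A small sketch of the single-argument form, using an illustrative boolean mask (not from the original documentation), shows that it returns one index tensor per dimension and matches torch.nonzero(condition, as_tuple=True):

```python
import torch

# Illustrative 3x2 boolean mask.
mask = torch.tensor([[True, False],
                     [False, True],
                     [True, True]])

# One LongTensor of indices per dimension of the mask.
rows, cols = torch.where(mask)

# The same indices obtained through torch.nonzero.
rows_nz, cols_nz = torch.nonzero(mask, as_tuple=True)

print(rows, cols)
print(torch.equal(rows, rows_nz) and torch.equal(cols, cols_nz))
```

The returned index tensors can be used directly for indexing, for example mask[rows, cols] selects exactly the True entries.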

