diff --git a/algopy/utpm/utpm.py b/algopy/utpm/utpm.py index ddfafc7..80b0ba3 100644 --- a/algopy/utpm/utpm.py +++ b/algopy/utpm/utpm.py @@ -1976,7 +1976,7 @@ def outer(cls, x, y, out = None): assert x_shp[:2] == y_shp[:2] assert len(y_shp[2:]) == 1 - out_shp = x_shp + x_shp[-1:] + out_shp = x_shp + y_shp[-1:] out = cls(cls.__zeros__(out_shp, dtype = x.data.dtype)) cls._outer( x.data, y.data, out = out.data)