blob: 7eb537d54286fbf1323f310eb3395197b563c4c7 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
diff --git a/aesara/tensor/nnet/corr.py b/aesara/tensor/nnet/corr.py
index e89054d..77ed344 100644
--- a/aesara/tensor/nnet/corr.py
+++ b/aesara/tensor/nnet/corr.py
@@ -692,12 +692,12 @@ class CorrMM(BaseCorrMM):
if kern.type.ndim != 4:
raise TypeError("kern must be 4D tensor")
- out_shape = tuple(
+ out_shape = tuple([
1 if img.type.shape[0] == 1 else None,
1 if kern.type.shape[0] == 1 else None,
None,
None,
- )
+ ])
dtype = img.type.dtype
return Apply(self, [img, kern], [TensorType(dtype, shape=out_shape)()])
|