diff --git a/evaluate.py b/evaluate.py index b4a77a0..9bb8270 100644 --- a/evaluate.py +++ b/evaluate.py @@ -131,7 +131,7 @@ def predict_sliding(net, image, tile_size, classes, recurrence): if isinstance(padded_prediction, list): padded_prediction = padded_prediction[0] padded_prediction = interp(padded_prediction).cpu().numpy().transpose(0,2,3,1) - prediction = padded_prediction[0, 0:img.shape[2], 0:img.shape[3], :] + prediction = padded_prediction[:, 0:img.shape[2], 0:img.shape[3], :] count_predictions[0, y1:y2, x1:x2] += 1 full_probs[:, y1:y2, x1:x2] += prediction # accumulate the predictions also in the overlapping regions