Skip to content

Commit 5d297f7

Browse files
malfet authored and pytorchmergebot committed
[MPS][BE] Combine two upsample_kernel_out_template into one (pytorch#148211)
- First, by stopping inverting sizes and strides, i.e. passing them as is, but reading them in inverse order in the shader, as the 1st stride of a 4D tensor is the one used for batches, the 2nd for channels, and the 3rd and 4th for spatial coordinates.
- Pass `scales` as float2 even in the linear case.

The above allows one to combine the two flavors of `upsample_kernel_out_template` into one.

Pull Request resolved: pytorch#148211
Approved by: https://github.com/dcci
ghstack dependencies: pytorch#148154, pytorch#148187
1 parent 83fb974 commit 5d297f7

File tree

2 files changed

+65
-107
lines changed

2 files changed

+65
-107
lines changed

aten/src/ATen/native/mps/kernels/UpSample.metal

+50-50
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ scalar_t upsample_get_value_bounded(
123123
int access_y = max(min(y, dim.y - 1), 0L);
124124
int access_x = max(min(x, dim.x - 1), 0L);
125125
return data
126-
[n * strides.w + c * strides.z + access_y * strides.y +
127-
access_x * strides.x];
126+
[n * strides.x + c * strides.y + access_y * strides.z +
127+
access_x * strides.w];
128128
}
129129

130130
template <typename scalar_t>
@@ -136,7 +136,7 @@ scalar_t upsample_get_value_bounded(
136136
long c,
137137
long x) {
138138
int access_x = max(min(x, dim - 1), 0L);
139-
return data[n * strides.z + c * strides.y + access_x * strides.x];
139+
return data[n * strides.x + c * strides.y + access_x * strides.z];
140140
}
141141

142142
template <typename scalar_t>
@@ -153,8 +153,8 @@ void upsample_increment_value_bounded(
153153
int access_x = max(min(x, dim.x - 1), 0L);
154154
AtomicType<scalar_t>::atomic_add(
155155
data,
156-
n * strides.w + c * strides.z + access_y * strides.y +
157-
access_x * strides.x,
156+
n * strides.x + c * strides.y + access_y * strides.z +
157+
access_x * strides.w,
158158
value);
159159
}
160160

@@ -200,24 +200,24 @@ kernel void upsample_linear1d(
200200
constant ulong3& output_strides [[buffer(3)]],
201201
constant long3& input_sizes [[buffer(4)]],
202202
constant long3& output_sizes [[buffer(5)]],
203-
constant float& scale [[buffer(6)]],
203+
constant float2& scales [[buffer(6)]],
204204
constant bool& align_corners [[buffer(7)]],
205205
uint thread_index [[thread_position_in_grid]]) {
206206
auto output_x = thread_index;
207207
auto real_x = area_pixel_compute_source_index(
208-
scale, output_x, align_corners, /*cubic=*/false);
208+
scales.x, output_x, align_corners, /*cubic=*/false);
209209
auto t_x = fract(real_x);
210210

211-
for (int n = 0; n < output_sizes.z; n++) {
211+
for (int n = 0; n < output_sizes.x; n++) {
212212
for (int c = 0; c < output_sizes.y; c++) {
213213
auto i00 = upsample_get_value_bounded<T>(
214-
inputData, input_sizes.x, input_strides, n, c, real_x);
214+
inputData, input_sizes.z, input_strides, n, c, real_x);
215215
auto i01 = upsample_get_value_bounded<T>(
216-
inputData, input_sizes.x, input_strides, n, c, real_x + 1);
216+
inputData, input_sizes.z, input_strides, n, c, real_x + 1);
217217
auto res = linear_interp(i00, i01, t_x);
218218
outputData
219-
[n * output_strides.z + c * output_strides.y +
220-
output_x * output_strides.x] = static_cast<T>(res);
219+
[n * output_strides.x + c * output_strides.y +
220+
output_x * output_strides.z] = static_cast<T>(res);
221221
}
222222
}
223223
}
@@ -232,26 +232,26 @@ kernel void upsample_bilinear2d(
232232
constant float2& scales [[buffer(6)]],
233233
constant bool& align_corners [[buffer(7)]],
234234
uint thread_index [[thread_position_in_grid]]) {
235-
auto output_x = thread_index % output_sizes.x;
236-
auto output_y = thread_index / output_sizes.x;
235+
auto output_x = thread_index % output_sizes.w;
236+
auto output_y = thread_index / output_sizes.w;
237237
auto real_x = area_pixel_compute_source_index(
238238
scales.x, output_x, align_corners, /*cubic=*/false);
239239
auto t_x = fract(real_x);
240240

241241
auto real_y = area_pixel_compute_source_index(
242242
scales.y, output_y, align_corners, /*cubic=*/false);
243243
auto t_y = fract(real_y);
244-
for (int n = 0; n < output_sizes.w; n++) {
245-
for (int c = 0; c < output_sizes.z; c++) {
244+
for (int n = 0; n < output_sizes.x; n++) {
245+
for (int c = 0; c < output_sizes.y; c++) {
246246
auto i00 = upsample_get_value_bounded<T>(
247-
inputData, input_sizes.xy, input_strides, n, c, real_y, real_x);
247+
inputData, input_sizes.wz, input_strides, n, c, real_y, real_x);
248248
auto i01 = upsample_get_value_bounded<T>(
249-
inputData, input_sizes.xy, input_strides, n, c, real_y, real_x + 1);
249+
inputData, input_sizes.wz, input_strides, n, c, real_y, real_x + 1);
250250
auto i10 = upsample_get_value_bounded<T>(
251-
inputData, input_sizes.xy, input_strides, n, c, real_y + 1, real_x);
251+
inputData, input_sizes.wz, input_strides, n, c, real_y + 1, real_x);
252252
auto i11 = upsample_get_value_bounded<T>(
253253
inputData,
254-
input_sizes.xy,
254+
input_sizes.wz,
255255
input_strides,
256256
n,
257257
c,
@@ -261,8 +261,8 @@ kernel void upsample_bilinear2d(
261261
auto i1_l = linear_interp(i10, i11, t_x);
262262
auto res = linear_interp(i0_l, i1_l, t_y);
263263
outputData
264-
[n * output_strides.w + c * output_strides.z +
265-
output_x * output_strides.x + output_y * output_strides.y] =
264+
[n * output_strides.x + c * output_strides.y +
265+
output_y * output_strides.z + output_x * output_strides.w] =
266266
static_cast<T>(res);
267267
}
268268
}
@@ -283,36 +283,36 @@ kernel void upsample_bilinear2d_aa(
283283
constant float2& scales [[buffer(6)]],
284284
constant bool& align_corners [[buffer(7)]],
285285
uint thread_index [[thread_position_in_grid]]) {
286-
auto output_x = thread_index % output_sizes.x;
287-
auto output_y = thread_index / output_sizes.x;
286+
auto output_x = thread_index % output_sizes.w;
287+
auto output_y = thread_index / output_sizes.w;
288288
(void)align_corners; // Align corners is unused for AA algorithm
289289
auto x_center = area_pixel_compute_source_index(
290290
scales.x, output_x, /*align_corners=*/false, /*cubic=*/false);
291291
auto y_center = area_pixel_compute_source_index(
292292
scales.y, output_y, /*align_corners=*/false, /*cubic=*/false);
293293
auto clamped_scales = max(1.0, scales);
294294
auto x_min = max(0L, long(floor(x_center - clamped_scales.x + 1)));
295-
auto x_max = min(input_sizes.x, long(ceil(x_center + clamped_scales.x)));
295+
auto x_max = min(input_sizes.w, long(ceil(x_center + clamped_scales.x)));
296296
auto y_min = max(0L, long(floor(y_center - clamped_scales.y + 1)));
297-
auto y_max = min(input_sizes.y, long(ceil(y_center + clamped_scales.y)));
298-
for (int n = 0; n < output_sizes.w; n++) {
299-
for (int c = 0; c < output_sizes.z; c++) {
297+
auto y_max = min(input_sizes.z, long(ceil(y_center + clamped_scales.y)));
298+
for (int n = 0; n < output_sizes.x; n++) {
299+
for (int c = 0; c < output_sizes.y; c++) {
300300
float res = 0.0;
301301
float ws = 0.0;
302302
constant auto* input =
303-
inputData + n * input_strides.w + c * input_strides.z;
303+
inputData + n * input_strides.x + c * input_strides.y;
304304
for (auto y = y_min; y < y_max; ++y) {
305305
auto dy = bilinear_functor((y - y_center) / clamped_scales.y);
306306
for (auto x = x_min; x < x_max; ++x) {
307307
auto dx = bilinear_functor((x - x_center) / clamped_scales.x);
308-
auto val = input[x * input_strides.x + y * input_strides.y];
308+
auto val = input[x * input_strides.w + y * input_strides.z];
309309
res += val * dx * dy;
310310
ws += dx * dy;
311311
}
312312
}
313313
outputData
314-
[n * output_strides.w + c * output_strides.z +
315-
output_x * output_strides.x + output_y * output_strides.y] =
314+
[n * output_strides.x + c * output_strides.y +
315+
output_y * output_strides.z + output_x * output_strides.w] =
316316
static_cast<T>(res / ws);
317317
}
318318
}
@@ -329,8 +329,8 @@ kernel void upsample_bicubic2d(
329329
constant float2& scales [[buffer(6)]],
330330
constant bool& align_corners [[buffer(7)]],
331331
uint thread_index [[thread_position_in_grid]]) {
332-
auto output_x = thread_index % output_sizes.x;
333-
auto output_y = thread_index / output_sizes.x;
332+
auto output_x = thread_index % output_sizes.w;
333+
auto output_y = thread_index / output_sizes.w;
334334
auto real_x = area_pixel_compute_source_index(
335335
scales.x, output_x, align_corners, /*cubic=*/true);
336336
int in_x = floor(real_x);
@@ -340,38 +340,38 @@ kernel void upsample_bicubic2d(
340340
scales.y, output_y, align_corners, /*cubic=*/true);
341341
int in_y = floor(real_y);
342342
auto t_y = real_y - in_y;
343-
for (int n = 0; n < output_sizes.w; n++) {
344-
for (int c = 0; c < output_sizes.z; c++) {
343+
for (int n = 0; n < output_sizes.x; n++) {
344+
for (int c = 0; c < output_sizes.y; c++) {
345345
float coefficients[4];
346346
for (int k = 0; k < 4; k++) {
347347
coefficients[k] = cubic_interp1d(
348348
upsample_get_value_bounded<T>(
349349
inputData,
350-
input_sizes.xy,
350+
input_sizes.wz,
351351
input_strides,
352352
n,
353353
c,
354354
in_y - 1 + k,
355355
in_x - 1),
356356
upsample_get_value_bounded<T>(
357357
inputData,
358-
input_sizes.xy,
358+
input_sizes.wz,
359359
input_strides,
360360
n,
361361
c,
362362
in_y - 1 + k,
363363
in_x + 0),
364364
upsample_get_value_bounded<T>(
365365
inputData,
366-
input_sizes.xy,
366+
input_sizes.wz,
367367
input_strides,
368368
n,
369369
c,
370370
in_y - 1 + k,
371371
in_x + 1),
372372
upsample_get_value_bounded<T>(
373373
inputData,
374-
input_sizes.xy,
374+
input_sizes.wz,
375375
input_strides,
376376
n,
377377
c,
@@ -386,8 +386,8 @@ kernel void upsample_bicubic2d(
386386
coefficients[3],
387387
t_y));
388388
outputData
389-
[n * output_strides.w + c * output_strides.z +
390-
output_x * output_strides.x + output_y * output_strides.y] = inp;
389+
[n * output_strides.x + c * output_strides.y +
390+
output_y * output_strides.z + output_x * output_strides.w] = inp;
391391
}
392392
}
393393
}
@@ -403,8 +403,8 @@ kernel void upsample_bicubic2d_backward(
403403
constant float2& scales [[buffer(6)]],
404404
constant bool& align_corners [[buffer(7)]],
405405
uint thread_index [[thread_position_in_grid]]) {
406-
auto output_x = thread_index % output_sizes.x;
407-
auto output_y = thread_index / output_sizes.x;
406+
auto output_x = thread_index % output_sizes.w;
407+
auto output_y = thread_index / output_sizes.w;
408408
auto real_x = area_pixel_compute_source_index<float>(
409409
scales.x, output_x, align_corners, /*cubic=*/true);
410410
int input_x = floor(real_x);
@@ -421,16 +421,16 @@ kernel void upsample_bicubic2d_backward(
421421
get_cubic_upsampling_coefficients(x_coeffs, t_x);
422422
get_cubic_upsampling_coefficients(y_coeffs, t_y);
423423

424-
for (int n = 0; n < output_sizes.w; n++) {
425-
for (int c = 0; c < output_sizes.z; ++c) {
424+
for (int n = 0; n < output_sizes.x; n++) {
425+
for (int c = 0; c < output_sizes.y; ++c) {
426426
auto out_value = gradOutputData
427-
[n * output_strides.w + c * output_strides.z +
428-
output_x * output_strides.x + output_y * output_strides.y];
427+
[n * output_strides.x + c * output_strides.y +
428+
output_y * output_strides.z + output_x * output_strides.w];
429429
for (int i = 0; i < 4; i++) {
430430
for (int j = 0; j < 4; j++) {
431431
upsample_increment_value_bounded<T>(
432432
gradInputData,
433-
input_sizes.xy,
433+
input_sizes.wz,
434434
input_strides,
435435
n,
436436
c,
@@ -478,7 +478,7 @@ kernel void upsample_bicubic2d_backward(
478478
constant ulong3 & output_strides [[buffer(3)]], \
479479
constant long3 & input_sizes [[buffer(4)]], \
480480
constant long3 & output_sizes [[buffer(5)]], \
481-
constant float& scale [[buffer(6)]], \
481+
constant float2 & scales [[buffer(6)]], \
482482
constant bool& align_corners [[buffer(7)]], \
483483
uint thread_index [[thread_position_in_grid]])
484484

aten/src/ATen/native/mps/operations/UpSample.mm

+15-57
Original file line numberDiff line numberDiff line change
@@ -263,62 +263,28 @@ static void upsample_kernel_out_template(const Tensor& input,
263263
return;
264264
}
265265
std::array<float, 2> scales = {
266-
area_pixel_compute_scale<float>(input.size(3), output.size(3), align_corners, scale_w_opt),
266+
area_pixel_compute_scale<float>(input.size(-1), output.size(-1), align_corners, scale_w_opt),
267267
area_pixel_compute_scale<float>(input.size(2), output.size(2), align_corners, scale_h_opt)};
268268
auto upsamplePSO = lib.getPipelineStateForFunc(fmt::format("upsample_{}_{}", name, scalarToMetalTypeString(input)));
269269
auto stream = getCurrentMPSStream();
270270
dispatch_sync_with_rethrow(stream->queue(), ^() {
271271
@autoreleasepool {
272-
std::array<int64_t, 4> output_strides = {output.stride(3), output.stride(2), output.stride(1), output.stride(0)};
273-
std::array<int64_t, 4> output_sizes = {output.size(3), output.size(2), output.size(1), output.size(0)};
274-
std::array<int64_t, 4> input_sizes = {input.size(3), input.size(2), input.size(1), input.size(0)};
275-
std::array<int64_t, 4> input_strides = {input.stride(3), input.stride(2), input.stride(1), input.stride(0)};
276272
auto computeEncoder = stream->commandEncoder();
277273
[computeEncoder setComputePipelineState:upsamplePSO];
278274
mtl_setArgs(computeEncoder,
279275
input,
280276
output,
281-
input_strides,
282-
output_strides,
283-
input_sizes,
284-
output_sizes,
277+
input.strides(),
278+
output.strides(),
279+
input.sizes(),
280+
output.sizes(),
285281
scales,
286282
align_corners);
287-
mtl_dispatch1DJob(computeEncoder, upsamplePSO, output_size[0] * output_size[1]);
288-
}
289-
});
290-
}
291-
292-
static void upsample_kernel_out_template(const Tensor& input,
293-
IntArrayRef output_size,
294-
bool align_corners,
295-
std::optional<double> scale_opt,
296-
const Tensor& output,
297-
const std::string name) {
298-
if (output.numel() == 0) {
299-
return;
300-
}
301-
float scale = area_pixel_compute_scale<float>(input.size(2), output.size(2), align_corners, scale_opt);
302-
auto upsamplePSO = lib.getPipelineStateForFunc(fmt::format("upsample_{}_{}", name, scalarToMetalTypeString(input)));
303-
auto stream = getCurrentMPSStream();
304-
dispatch_sync_with_rethrow(stream->queue(), ^() {
305-
@autoreleasepool {
306-
std::array<int64_t, 3> output_strides = {output.stride(2), output.stride(1), output.stride(0)};
307-
std::array<int64_t, 3> output_sizes = {output.size(2), output.size(1), output.size(0)};
308-
std::array<int64_t, 3> input_sizes = {input.size(2), input.size(1), input.size(0)};
309-
std::array<int64_t, 3> input_strides = {input.stride(2), input.stride(1), input.stride(0)};
310-
auto computeEncoder = stream->commandEncoder();
311-
[computeEncoder setComputePipelineState:upsamplePSO];
312-
mtl_setArgs(computeEncoder,
313-
input,
314-
output,
315-
input_strides,
316-
output_strides,
317-
input_sizes,
318-
output_sizes,
319-
scale,
320-
align_corners);
321-
mtl_dispatch1DJob(computeEncoder, upsamplePSO, output_size[0]);
283+
if (output.ndimension() == 4) {
284+
mtl_dispatch1DJob(computeEncoder, upsamplePSO, output_size[0] * output_size[1]);
285+
} else {
286+
mtl_dispatch1DJob(computeEncoder, upsamplePSO, output_size[0]);
287+
}
322288
}
323289
});
324290
}
@@ -343,23 +309,15 @@ static void upsample_kernel_backward_out_template(const Tensor& grad_input,
343309
auto stream = getCurrentMPSStream();
344310
dispatch_sync_with_rethrow(stream->queue(), ^() {
345311
@autoreleasepool {
346-
std::array<int64_t, 4> output_strides = {
347-
grad_output.stride(3), grad_output.stride(2), grad_output.stride(1), grad_output.stride(0)};
348-
std::array<int64_t, 4> output_sizes = {
349-
grad_output.size(3), grad_output.size(2), grad_output.size(1), grad_output.size(0)};
350-
std::array<int64_t, 4> input_sizes = {
351-
grad_input.size(3), grad_input.size(2), grad_input.size(1), grad_input.size(0)};
352-
std::array<int64_t, 4> input_strides = {
353-
grad_input.stride(3), grad_input.stride(2), grad_input.stride(1), grad_input.stride(0)};
354312
auto computeEncoder = stream->commandEncoder();
355313
[computeEncoder setComputePipelineState:upsamplePSO];
356314
mtl_setArgs(computeEncoder,
357315
grad_input,
358316
grad_output,
359-
input_strides,
360-
output_strides,
361-
input_sizes,
362-
output_sizes,
317+
grad_input.strides(),
318+
grad_output.strides(),
319+
grad_input.sizes(),
320+
grad_output.sizes(),
363321
scales,
364322
align_corners);
365323
mtl_dispatch1DJob(computeEncoder, upsamplePSO, output_size[0] * output_size[1]);
@@ -439,7 +397,7 @@ static void upsample_kernel_backward_out_template(const Tensor& grad_input,
439397

440398
TORCH_IMPL_FUNC(upsample_linear1d_out_mps)
441399
(const Tensor& input, IntArrayRef output_size, bool align_corners, std::optional<double> scale, const Tensor& output) {
442-
mps::upsample_kernel_out_template(input, output_size, align_corners, scale, output, "linear1d");
400+
mps::upsample_kernel_out_template(input, output_size, align_corners, scale, scale, output, "linear1d");
443401
}
444402

445403
TORCH_IMPL_FUNC(upsample_linear1d_backward_out_mps)

0 commit comments

Comments (0)