From 1b22199ed59edd30cd04d09afb278c3492d14c8b Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 13 Feb 2023 18:05:37 +0000 Subject: [PATCH] WIP AD --- jax_triton/pallas/pallas_call.py | 179 +++++++++++++++++++++++-------- tests/pallas_test.py | 25 ++++- 2 files changed, 156 insertions(+), 48 deletions(-) diff --git a/jax_triton/pallas/pallas_call.py b/jax_triton/pallas/pallas_call.py index ea846971..d1d2d170 100644 --- a/jax_triton/pallas/pallas_call.py +++ b/jax_triton/pallas/pallas_call.py @@ -156,7 +156,8 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, []) jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts # TODO consts jvp_which_linear = (*which_linear, *(True,) * len(tangents)) - jvp_inshapes = (*in_shapes, *in_shapes) + _, nonzero_tangent_in_shapes = partition_list(nonzero_tangents, in_shapes) + jvp_inshapes = (*in_shapes, *nonzero_tangent_in_shapes) jvp_outshapes = (*out_shapes, *out_shapes) if input_output_aliases: raise NotImplementedError("`input_output_aliases` jvp not supported.") @@ -172,7 +173,8 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, logical_primal_inputs, logical_primal_outputs = split_list(logical_primals, [len(primals)]) logical_tangent_inputs, logical_tangent_outputs = split_list(logical_tangents, [len(tangents)]) in_bms, out_bms = split_list(grid_spec.block_mappings, [len(primals)]) - new_bms = tuple((*in_bms, *in_bms, *out_bms, *out_bms)) + nonzero_in_bms, _ = partition_list(nonzero_tangents, in_bms) + new_bms = tuple((*in_bms, *nonzero_in_bms, *out_bms, *out_bms)) new_grid_spec = grid_spec.replace(block_mappings=new_bms) jvp_jaxpr = jvp_jaxpr.replace(invars=[*logical_primal_inputs, *logical_tangent_inputs, @@ -291,12 +293,13 @@ def _pallas_call_partial_eval( jaxpr_known_resout, jaxpr_unknown_resin_, uk_out, inst_out, num_res = \ pe.partial_eval_jaxpr_custom( jaxpr, - in_inst=all_unknowns, + in_inst=True, in_unknowns=all_unknowns, ensure_out_unknowns=[], ensure_out_inst=[], saveable=_save_everything) - # # `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and + breakpoint() + # `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and # regular valued input/outputs. However, we'd like to bind these jaxprs to a # `for`, which expects only `Ref` inputs and no output. We need to convert # both of these jaxprs into ones that are compatible with `for`. @@ -339,13 +342,13 @@ def _pallas_call_partial_eval( for a in res_avals ] res_block_mappings = [ - BlockMapping((*[None] * len(grid), *a.shape), index_map) + BlockMapping((*[pallas_core.mapped] * len(grid), *a.shape), index_map) for a, index_map in zip(res_avals, res_index_mappings) ] known_grid_spec = GridSpec(grid, (*known_in_block_mappings, *known_out_block_mappings, *res_block_mappings), - grid_spec.mapped_dims) + mapped_dims) unknown_grid_spec = GridSpec(grid, (*res_block_mappings, *unknown_in_block_mappings, *unknown_out_block_mappings), @@ -362,7 +365,7 @@ def _pallas_call_partial_eval( input_output_aliases=(), which_linear=tuple(known_which_linear), **compiler_params) - known_outputs, residuals = split_list(known_out_and_res, [len(known_tracers)]) + known_outputs, residuals = split_list(known_out_and_res, [len(known_out_shapes)]) residuals = map(trace.new_instantiated_const, residuals) unknown_inputs = [*residuals, *unknown_tracers] unknown_outputs = [ @@ -373,8 +376,7 @@ def _pallas_call_partial_eval( source = source_info_util.current().replace(name_stack=name_stack) unknown_params = dict( jaxpr=jaxpr_unknown, - in_shapes=(*(jax.ShapeDtypeStruct(s.shape, s.dtype) for s in res_avals), - *unknown_in_shapes), + in_shapes=(*res_shapes, *unknown_in_shapes), out_shapes=tuple(unknown_out_shapes), grid_spec=unknown_grid_spec, which_linear=(*res_which_linear, *unknown_which_linear), @@ -390,40 +392,6 @@ def _pallas_call_partial_eval( return merge_lists(out_unknowns, known_outputs, unknown_outputs) pe.custom_partial_eval_rules[pallas_call_p] = _pallas_call_partial_eval -def _transpose_jaxpr(jaxpr: jax_core.Jaxpr, which_linear: Sequence[bool] - ) -> jax_core.Jaxpr: - num_inputs = len(which_linear) - num_outputs = len(jaxpr.invars) - num_inputs - def trans(*args): - # First we want to run the computation to read all the residual refs. We can - # do that by using partial evaluation with all linear inputs unknown. - res_jaxpr, tangent_jaxpr_, *_ = \ - pe.partial_eval_jaxpr_custom(jaxpr, - in_unknowns=[*which_linear, *[True] * - num_outputs], - in_inst=[*which_linear, *[True] * - num_outputs], - ensure_out_inst=[], - ensure_out_unknowns=[], - saveable=_save_everything) - res_args = [x for x, lin in zip(args, which_linear) if not lin] - res = jax_core.eval_jaxpr(res_jaxpr, (), *res_args) - - # Now that we have residual values, we run the tangent jaxpr. It takes as - # input the residuals, and all the refs (at least, the ones - # that are used in the body). Luckily, `tangent_jaxpr_` has all known and - # unknown inputs! - breakpoint() - primals_args = [*(r for u, r in zip(used_res, res) if u)] - ct_args = [x for x, u in zip(args, used_ct) if u] - ad.backward_pass( - tangent_jaxpr, (), False, (), (*res, *ct_args), ()) - breakpoint() - return [] - jaxpr_trans, _, _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(trans), [v.aval for v in jaxpr.invars]) - return jaxpr_trans - def _pallas_call_transpose_rule(cts_in, *args, jaxpr: jax_core.Jaxpr, name: str, @@ -592,6 +560,105 @@ def _pallas_call_batching_rule(args, dims, *, return out, (0,) * len(out) batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule +class TritonCompilationResult(NamedTuple): + name: str + asm: Dict[str, str] + shared_mem: int + lowering_result: lowering.TritonLoweringResult + +@weakref_lru_cache +def _compile_jaxpr(jaxpr: jax_core.Jaxpr, in_shapes, grid_spec: GridSpec, + name: str, num_warps: int, num_stages: int + ) -> TritonCompilationResult: + lowering_result = lowering.lower_jaxpr_to_triton_module(jaxpr, in_shapes, grid_spec, name) + backend = tc.runtime.backend.CUDA + device = 0 + name, asm, shared_mem = tc.code_gen.compile_ttir(backend, lowering_result.module, device, + num_warps, num_stages, {}, 0) + return TritonCompilationResult(name, asm, shared_mem, lowering_result) + + +def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes, + jaxpr: jax_core.Jaxpr, + name: str, + in_shapes: Tuple[jax.ShapeDtypeStruct, ...], + out_shapes: Tuple[jax.ShapeDtypeStruct, ...], + which_linear: Tuple[bool, ...], + interpret: bool, + debug: bool, + input_output_aliases: Tuple[Tuple[int, int], ...], + grid_spec: GridSpec, + **compiler_params: Any): + if interpret: + return mlir.lower_fun(_pallas_call_impl, multiple_results=True)( + ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes, + in_shapes=in_shapes, + which_linear=which_linear, + interpret=interpret, debug=debug, + input_output_aliases=input_output_aliases, + grid_spec=grid_spec, **compiler_params) + num_warps = compiler_params.get("num_warps", 4) + num_stages = compiler_params.get("num_stages", 3) + compilation_result = _compile_jaxpr(jaxpr, tuple((*in_shapes, *out_shapes)), + grid_spec, name, num_warps, num_stages) + name = compilation_result.name + asm = compilation_result.asm + shared_mem = compilation_result.shared_mem + ref_effects = state.get_ref_state_effects( + [v.aval for v in jaxpr.invars], jaxpr.effects) + is_accum = [ + all(isinstance(eff, state.AccumEffect) for eff in ref_effect) + for ref_effect in ref_effects + ] + if debug: + print(jaxpr) + print(grid_spec) + lowering_result = compilation_result.lowering_result + if debug: + lowering_result.module.print() + out_type = ir.TupleType.get_tuple([ + ir.RankedTensorType.get(out_shape.shape, mlir.dtype_to_ir_type(out_shape.dtype)) + for out_shape in ctx.avals_out]) + i32_type = ir.IntegerType.get_signless(32) + + kernel = triton_kernel_call_lib.TritonKernel( + asm["cubin"], name, num_warps, shared_mem + ) + + grid = normalize_grid(compilation_result.lowering_result.grid, metaparams={}) + # All arguments are buffers. + all_args = [None] * (len(in_shapes) + len(out_shapes)) + kernel_call = triton_kernel_call_lib.TritonKernelCall( + kernel, grid[0], grid[1], grid[2], all_args, + is_accum, + [s.size for s in [*in_shapes, *out_shapes]] + ) + + ctx.module_context.add_keepalive(kernel_call) + output_operand_aliases = ir.ArrayAttr.get([ + mhlo.OutputOperandAlias.get( + output_tuple_indices=[output], + operand_index=input, + operand_tuple_indices=[]) + for input, output in input_output_aliases + ]) + out = mhlo.CustomCallOp( + [out_type], + in_nodes, + call_target_name=ir.StringAttr.get("triton_kernel_call"), + has_side_effect=ir.BoolAttr.get(False), + backend_config=ir.StringAttr.get(kernel_call.descriptor), + api_version=ir.IntegerAttr.get(i32_type, 1), + called_computations=ir.ArrayAttr.get([]), + operand_layouts=avals_to_layouts(ctx.avals_in), + result_layouts=avals_to_layouts(ctx.avals_out), + output_operand_aliases=output_operand_aliases, + ) + results = [mhlo.GetTupleElementOp(out, mlir.i32_attr(i)).result + for i in range(len(out_shapes))] + return results +mlir.register_lowering(pallas_call_p, pallas_call_lowering, platform="cuda") + @weakref_lru_cache def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals, primitive_name: Optional[str] = None): @@ -633,6 +700,32 @@ def _compute_shape_from_block_spec(block_spec: Optional[BlockSpec], return arg_shape return tuple(s for s in block_spec.block_shape if s is not None) +def _pallas_call_bind(*args, + jaxpr: jax_core.Jaxpr, + name: str, + in_shapes: Tuple[jax.ShapeDtypeStruct, ...], + out_shapes: Tuple[jax.ShapeDtypeStruct, ...], + which_linear: Tuple[bool, ...], + interpret: bool, + debug: bool, + input_output_aliases: Tuple[Tuple[int, int], ...], + grid_spec: GridSpec, + **compiler_params: Any): + num_inputs = len(in_shapes) + num_outputs = len(out_shapes) + assert len(jaxpr.invars) == num_inputs + num_outputs, (len(jaxpr.invars), + num_inputs, + num_outputs) + assert len(grid_spec.block_mappings) == len(jaxpr.invars) + return jax_core.Primitive.bind( + pallas_call_p, *args, + jaxpr=jaxpr, name=name, in_shapes=in_shapes, + out_shapes=out_shapes, which_linear=which_linear, + interpret=interpret, debug=debug, + input_output_aliases=input_output_aliases, + grid_spec=grid_spec, **compiler_params) +pallas_call_p.def_custom_bind(_pallas_call_bind) + def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False, grid: Optional[Grid] = None, in_specs: Optional[Sequence[Optional[BlockSpec]]] = None, diff --git a/tests/pallas_test.py b/tests/pallas_test.py index 49738f8f..49f95840 100644 --- a/tests/pallas_test.py +++ b/tests/pallas_test.py @@ -702,8 +702,7 @@ class PallasCallAutodifferentiationTest(PallasTest): ("square", lambda x: x * x), ("add_one", lambda x: x + 1.), ("exp", jnp.exp), - # ("tanh", jnp.tanh), TODO(sharadmv): re-enable this case when libdevice is - # updated + ("tanh", jnp.tanh), ]) def test_jvp(self, impl): @functools.partial( @@ -728,8 +727,7 @@ def pallas_impl(x_ref, o_ref): ("square", lambda x: x * x), ("add_one", lambda x: x + 1.), ("exp", jnp.exp), - # ("tanh", jnp.tanh), TODO(sharadmv): re-enable this case when libdevice is - # updated + ("tanh", jnp.tanh), ]) def test_jvp_slice(self, impl): @functools.partial( @@ -752,7 +750,6 @@ def pallas_impl(x_ref, o_ref): rtol=1e-5) jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2) - TODO(sharadmv): enable this when we update Triton def test_jvp_matmul(self): k1, k2 = random.split(random.PRNGKey(0)) x = random.normal(k1, (256, 128)) @@ -778,6 +775,24 @@ def add_vectors(x_ref, y_ref, o_ref): out_ref = xy[0] + xy[1] np.testing.assert_allclose(out, out_ref) + @parameterized.named_parameters(*[ + ("square", lambda x: x * x), + ("add_one", lambda x: x + 1.), + ("exp", jnp.exp), + ("tanh", jnp.tanh), + ]) + def test_grad(self, impl): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32)) + def pallas_impl(x_ref, o_ref): + o_ref[...] = impl(x_ref[...]) + + x = random.normal(random.PRNGKey(0)) + g = jax.grad(pallas_impl)(x) + g_ref = jax.grad(impl)(x) + np.testing.assert_allclose(g, g_ref, atol=1e-5, rtol=1e-5) + jtu.check_grads(pallas_impl, (x,), modes=["rev"], order=1) + class PallasCallVmapTest(PallasTest):