/* * Copyright © 2021 Google * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), * to deal in the Software without restriction, including without limitation * the rights to use, copy, modify, merge, publish, distribute, sublicense, * and/or sell copies of the Software, and to permit persons to whom the * Software is furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice (including the next * paragraph) shall be included in all copies or substantial portions of the * Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS * IN THE SOFTWARE. 
*/ #include "radv_acceleration_structure.h" #include "radv_debug.h" #include "radv_private.h" #include "radv_shader.h" #include "nir/nir.h" #include "nir/nir_builder.h" #include "nir/nir_builtin_builder.h" static VkRayTracingPipelineCreateInfoKHR radv_create_merged_rt_create_info(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo) { VkRayTracingPipelineCreateInfoKHR local_create_info = *pCreateInfo; uint32_t total_stages = pCreateInfo->stageCount; uint32_t total_groups = pCreateInfo->groupCount; if (pCreateInfo->pLibraryInfo) { for (unsigned i = 0; i < pCreateInfo->pLibraryInfo->libraryCount; ++i) { RADV_FROM_HANDLE(radv_pipeline, library, pCreateInfo->pLibraryInfo->pLibraries[i]); total_stages += library->library.stage_count; total_groups += library->library.group_count; } } VkPipelineShaderStageCreateInfo *stages = NULL; VkRayTracingShaderGroupCreateInfoKHR *groups = NULL; local_create_info.stageCount = total_stages; local_create_info.groupCount = total_groups; local_create_info.pStages = stages = malloc(sizeof(VkPipelineShaderStageCreateInfo) * total_stages); local_create_info.pGroups = groups = malloc(sizeof(VkRayTracingShaderGroupCreateInfoKHR) * total_groups); if (!local_create_info.pStages || !local_create_info.pGroups) return local_create_info; total_stages = pCreateInfo->stageCount; total_groups = pCreateInfo->groupCount; for (unsigned j = 0; j < pCreateInfo->stageCount; ++j) stages[j] = pCreateInfo->pStages[j]; for (unsigned j = 0; j < pCreateInfo->groupCount; ++j) groups[j] = pCreateInfo->pGroups[j]; if (pCreateInfo->pLibraryInfo) { for (unsigned i = 0; i < pCreateInfo->pLibraryInfo->libraryCount; ++i) { RADV_FROM_HANDLE(radv_pipeline, library, pCreateInfo->pLibraryInfo->pLibraries[i]); for (unsigned j = 0; j < library->library.stage_count; ++j) stages[total_stages + j] = library->library.stages[j]; for (unsigned j = 0; j < library->library.group_count; ++j) { VkRayTracingShaderGroupCreateInfoKHR *dst = &groups[total_groups + j]; *dst = 
library->library.groups[j]; if (dst->generalShader != VK_SHADER_UNUSED_KHR) dst->generalShader += total_stages; if (dst->closestHitShader != VK_SHADER_UNUSED_KHR) dst->closestHitShader += total_stages; if (dst->anyHitShader != VK_SHADER_UNUSED_KHR) dst->anyHitShader += total_stages; if (dst->intersectionShader != VK_SHADER_UNUSED_KHR) dst->intersectionShader += total_stages; } total_stages += library->library.stage_count; total_groups += library->library.group_count; } } return local_create_info; } static VkResult radv_rt_pipeline_library_create(VkDevice _device, VkPipelineCache _cache, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, const VkAllocationCallbacks *pAllocator, VkPipeline *pPipeline) { RADV_FROM_HANDLE(radv_device, device, _device); struct radv_pipeline *pipeline; pipeline = vk_zalloc2(&device->vk.alloc, pAllocator, sizeof(*pipeline), 8, VK_SYSTEM_ALLOCATION_SCOPE_OBJECT); if (pipeline == NULL) return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY); vk_object_base_init(&device->vk, &pipeline->base, VK_OBJECT_TYPE_PIPELINE); pipeline->type = RADV_PIPELINE_LIBRARY; VkRayTracingPipelineCreateInfoKHR local_create_info = radv_create_merged_rt_create_info(pCreateInfo); if (!local_create_info.pStages || !local_create_info.pGroups) goto fail; if (local_create_info.stageCount) { size_t size = sizeof(VkPipelineShaderStageCreateInfo) * local_create_info.stageCount; pipeline->library.stage_count = local_create_info.stageCount; pipeline->library.stages = malloc(size); if (!pipeline->library.stages) goto fail; memcpy(pipeline->library.stages, local_create_info.pStages, size); } if (local_create_info.groupCount) { size_t size = sizeof(VkRayTracingShaderGroupCreateInfoKHR) * local_create_info.groupCount; pipeline->library.group_count = local_create_info.groupCount; pipeline->library.groups = malloc(size); if (!pipeline->library.groups) goto fail; memcpy(pipeline->library.groups, local_create_info.pGroups, size); } *pPipeline = radv_pipeline_to_handle(pipeline); 
free((void *)local_create_info.pGroups); free((void *)local_create_info.pStages); return VK_SUCCESS; fail: free(pipeline->library.groups); free(pipeline->library.stages); free((void *)local_create_info.pGroups); free((void *)local_create_info.pStages); return VK_ERROR_OUT_OF_HOST_MEMORY; } /* * Global variables for an RT pipeline */ struct rt_variables { /* idx of the next shader to run in the next iteration of the main loop */ nir_variable *idx; /* scratch offset of the argument area relative to stack_ptr */ nir_variable *arg; nir_variable *stack_ptr; /* global address of the SBT entry used for the shader */ nir_variable *shader_record_ptr; /* trace_ray arguments */ nir_variable *accel_struct; nir_variable *flags; nir_variable *cull_mask; nir_variable *sbt_offset; nir_variable *sbt_stride; nir_variable *miss_index; nir_variable *origin; nir_variable *tmin; nir_variable *direction; nir_variable *tmax; /* from the BTAS instance currently being visited */ nir_variable *custom_instance_and_mask; /* Properties of the primitive currently being visited. */ nir_variable *primitive_id; nir_variable *geometry_id_and_flags; nir_variable *instance_id; nir_variable *instance_addr; nir_variable *hit_kind; nir_variable *opaque; /* Safeguard to ensure we don't end up in an infinite loop of non-existing case. Should not be * needed but is extra anti-hang safety during bring-up. */ nir_variable *main_loop_case_visited; /* Output variable for intersection & anyhit shaders. */ nir_variable *ahit_status; /* Array of stack size struct for recording the max stack size for each group. 
*/ struct radv_pipeline_shader_stack_size *stack_sizes; unsigned group_idx; }; static struct rt_variables create_rt_variables(nir_shader *shader, struct radv_pipeline_shader_stack_size *stack_sizes) { struct rt_variables vars = { NULL, }; vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx"); vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg"); vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr"); vars.shader_record_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_record_ptr"); const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3); vars.accel_struct = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "accel_struct"); vars.flags = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "ray_flags"); vars.cull_mask = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "cull_mask"); vars.sbt_offset = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_offset"); vars.sbt_stride = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_stride"); vars.miss_index = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "miss_index"); vars.origin = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_origin"); vars.tmin = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmin"); vars.direction = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_direction"); vars.tmax = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmax"); vars.custom_instance_and_mask = nir_variable_create( shader, nir_var_shader_temp, glsl_uint_type(), "custom_instance_and_mask"); vars.primitive_id = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "primitive_id"); vars.geometry_id_and_flags = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), 
"geometry_id_and_flags"); vars.instance_id = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "instance_id"); vars.instance_addr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr"); vars.hit_kind = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "hit_kind"); vars.opaque = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "opaque"); vars.main_loop_case_visited = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "main_loop_case_visited"); vars.ahit_status = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "ahit_status"); vars.stack_sizes = stack_sizes; return vars; } /* * Remap all the variables between the two rt_variables struct for inlining. */ static void map_rt_variables(struct hash_table *var_remap, struct rt_variables *src, const struct rt_variables *dst) { _mesa_hash_table_insert(var_remap, src->idx, dst->idx); _mesa_hash_table_insert(var_remap, src->arg, dst->arg); _mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr); _mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr); _mesa_hash_table_insert(var_remap, src->accel_struct, dst->accel_struct); _mesa_hash_table_insert(var_remap, src->flags, dst->flags); _mesa_hash_table_insert(var_remap, src->cull_mask, dst->cull_mask); _mesa_hash_table_insert(var_remap, src->sbt_offset, dst->sbt_offset); _mesa_hash_table_insert(var_remap, src->sbt_stride, dst->sbt_stride); _mesa_hash_table_insert(var_remap, src->miss_index, dst->miss_index); _mesa_hash_table_insert(var_remap, src->origin, dst->origin); _mesa_hash_table_insert(var_remap, src->tmin, dst->tmin); _mesa_hash_table_insert(var_remap, src->direction, dst->direction); _mesa_hash_table_insert(var_remap, src->tmax, dst->tmax); _mesa_hash_table_insert(var_remap, src->custom_instance_and_mask, dst->custom_instance_and_mask); _mesa_hash_table_insert(var_remap, src->primitive_id, dst->primitive_id); 
_mesa_hash_table_insert(var_remap, src->geometry_id_and_flags, dst->geometry_id_and_flags); _mesa_hash_table_insert(var_remap, src->instance_id, dst->instance_id); _mesa_hash_table_insert(var_remap, src->instance_addr, dst->instance_addr); _mesa_hash_table_insert(var_remap, src->hit_kind, dst->hit_kind); _mesa_hash_table_insert(var_remap, src->opaque, dst->opaque); _mesa_hash_table_insert(var_remap, src->ahit_status, dst->ahit_status); src->stack_sizes = dst->stack_sizes; src->group_idx = dst->group_idx; } /* * Create a copy of the global rt variables where the primitive/instance related variables are * independent.This is needed as we need to keep the old values of the global variables around * in case e.g. an anyhit shader reject the collision. So there are inner variables that get copied * to the outer variables once we commit to a better hit. */ static struct rt_variables create_inner_vars(nir_builder *b, const struct rt_variables *vars) { struct rt_variables inner_vars = *vars; inner_vars.idx = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_idx"); inner_vars.shader_record_ptr = nir_variable_create( b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "inner_shader_record_ptr"); inner_vars.primitive_id = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_primitive_id"); inner_vars.geometry_id_and_flags = nir_variable_create( b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_geometry_id_and_flags"); inner_vars.tmax = nir_variable_create(b->shader, nir_var_shader_temp, glsl_float_type(), "inner_tmax"); inner_vars.instance_id = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_instance_id"); inner_vars.instance_addr = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "inner_instance_addr"); inner_vars.hit_kind = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_hit_kind"); inner_vars.custom_instance_and_mask = 
nir_variable_create( b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_custom_instance_and_mask"); return inner_vars; } /* The hit attributes are stored on the stack. This is the offset compared to the current stack * pointer of where the hit attrib is stored. */ const uint32_t RADV_HIT_ATTRIB_OFFSET = -(16 + RADV_MAX_HIT_ATTRIB_SIZE); static void insert_rt_return(nir_builder *b, const struct rt_variables *vars) { nir_store_var(b, vars->stack_ptr, nir_iadd(b, nir_load_var(b, vars->stack_ptr), nir_imm_int(b, -16)), 1); nir_store_var(b, vars->idx, nir_load_scratch(b, 1, 32, nir_load_var(b, vars->stack_ptr), .align_mul = 16), 1); } enum sbt_type { SBT_RAYGEN, SBT_MISS, SBT_HIT, SBT_CALLABLE, }; static nir_ssa_def * get_sbt_ptr(nir_builder *b, nir_ssa_def *idx, enum sbt_type binding) { nir_ssa_def *desc = nir_load_sbt_amd(b, 4, .binding = binding); nir_ssa_def *base_addr = nir_pack_64_2x32(b, nir_channels(b, desc, 0x3)); nir_ssa_def *stride = nir_channel(b, desc, 2); nir_ssa_def *ret = nir_imul(b, idx, stride); ret = nir_iadd(b, base_addr, nir_u2u64(b, ret)); return ret; } static void load_sbt_entry(nir_builder *b, const struct rt_variables *vars, nir_ssa_def *idx, enum sbt_type binding, unsigned offset) { nir_ssa_def *addr = get_sbt_ptr(b, idx, binding); nir_ssa_def *load_addr = addr; if (offset) load_addr = nir_iadd(b, load_addr, nir_imm_int64(b, offset)); nir_ssa_def *v_idx = nir_build_load_global(b, 1, 32, load_addr, .align_mul = 4, .align_offset = 0); nir_store_var(b, vars->idx, v_idx, 1); nir_ssa_def *record_addr = nir_iadd(b, addr, nir_imm_int64(b, RADV_RT_HANDLE_SIZE)); nir_store_var(b, vars->shader_record_ptr, record_addr, 1); } static nir_ssa_def * nir_build_vec3_mat_mult(nir_builder *b, nir_ssa_def *vec, nir_ssa_def *matrix[], bool translation) { nir_ssa_def *result_components[3] = { nir_channel(b, matrix[0], 3), nir_channel(b, matrix[1], 3), nir_channel(b, matrix[2], 3), }; for (unsigned i = 0; i < 3; ++i) { for (unsigned j = 0; j < 3; ++j) { 
nir_ssa_def *v = nir_fmul(b, nir_channels(b, vec, 1 << j), nir_channels(b, matrix[i], 1 << j)); result_components[i] = (translation || j) ? nir_fadd(b, result_components[i], v) : v; } } return nir_vec(b, result_components, 3); } static nir_ssa_def * nir_build_vec3_mat_mult_pre(nir_builder *b, nir_ssa_def *vec, nir_ssa_def *matrix[]) { nir_ssa_def *result_components[3] = { nir_channel(b, matrix[0], 3), nir_channel(b, matrix[1], 3), nir_channel(b, matrix[2], 3), }; return nir_build_vec3_mat_mult(b, nir_fsub(b, vec, nir_vec(b, result_components, 3)), matrix, false); } static void nir_build_wto_matrix_load(nir_builder *b, nir_ssa_def *instance_addr, nir_ssa_def **out) { unsigned offset = offsetof(struct radv_bvh_instance_node, wto_matrix); for (unsigned i = 0; i < 3; ++i) { out[i] = nir_build_load_global(b, 4, 32, nir_iadd(b, instance_addr, nir_imm_int64(b, offset + i * 16)), .align_mul = 64, .align_offset = offset + i * 16); } } /* This lowers all the RT instructions that we do not want to pass on to the combined shader and * that we can implement using the variables from the shader we are going to inline into. 
*/ static void lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned call_idx_base) { nir_builder b_shader; nir_builder_init(&b_shader, nir_shader_get_entrypoint(shader)); nir_foreach_block (block, nir_shader_get_entrypoint(shader)) { nir_foreach_instr_safe (instr, block) { switch (instr->type) { case nir_instr_type_intrinsic: { nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr); switch (intr->intrinsic) { case nir_intrinsic_rt_execute_callable: { uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE; uint32_t ret = call_idx_base + nir_intrinsic_call_idx(intr) + 1; b_shader.cursor = nir_instr_remove(instr); nir_store_var(&b_shader, vars->stack_ptr, nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), nir_imm_int(&b_shader, size)), 1); nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret), nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16, .write_mask = 1); nir_store_var(&b_shader, vars->stack_ptr, nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), nir_imm_int(&b_shader, 16)), 1); load_sbt_entry(&b_shader, vars, intr->src[0].ssa, SBT_CALLABLE, 0); nir_store_var( &b_shader, vars->arg, nir_isub(&b_shader, intr->src[1].ssa, nir_imm_int(&b_shader, size + 16)), 1); vars->stack_sizes[vars->group_idx].recursive_size = MAX2(vars->stack_sizes[vars->group_idx].recursive_size, size + 16); break; } case nir_intrinsic_rt_trace_ray: { uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE; uint32_t ret = call_idx_base + nir_intrinsic_call_idx(intr) + 1; b_shader.cursor = nir_instr_remove(instr); nir_store_var(&b_shader, vars->stack_ptr, nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), nir_imm_int(&b_shader, size)), 1); nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret), nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16, .write_mask = 1); nir_store_var(&b_shader, vars->stack_ptr, nir_iadd(&b_shader, nir_load_var(&b_shader, 
vars->stack_ptr), nir_imm_int(&b_shader, 16)), 1); nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 1), 1); nir_store_var( &b_shader, vars->arg, nir_isub(&b_shader, intr->src[10].ssa, nir_imm_int(&b_shader, size + 16)), 1); vars->stack_sizes[vars->group_idx].recursive_size = MAX2(vars->stack_sizes[vars->group_idx].recursive_size, size + 16); /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. */ nir_store_var(&b_shader, vars->accel_struct, intr->src[0].ssa, 0x1); nir_store_var(&b_shader, vars->flags, intr->src[1].ssa, 0x1); nir_store_var(&b_shader, vars->cull_mask, nir_iand(&b_shader, intr->src[2].ssa, nir_imm_int(&b_shader, 0xff)), 0x1); nir_store_var(&b_shader, vars->sbt_offset, nir_iand(&b_shader, intr->src[3].ssa, nir_imm_int(&b_shader, 0xf)), 0x1); nir_store_var(&b_shader, vars->sbt_stride, nir_iand(&b_shader, intr->src[4].ssa, nir_imm_int(&b_shader, 0xf)), 0x1); nir_store_var(&b_shader, vars->miss_index, nir_iand(&b_shader, intr->src[5].ssa, nir_imm_int(&b_shader, 0xffff)), 0x1); nir_store_var(&b_shader, vars->origin, intr->src[6].ssa, 0x7); nir_store_var(&b_shader, vars->tmin, intr->src[7].ssa, 0x1); nir_store_var(&b_shader, vars->direction, intr->src[8].ssa, 0x7); nir_store_var(&b_shader, vars->tmax, intr->src[9].ssa, 0x1); break; } case nir_intrinsic_rt_resume: { uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE; b_shader.cursor = nir_instr_remove(instr); nir_store_var(&b_shader, vars->stack_ptr, nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), nir_imm_int(&b_shader, -size)), 1); break; } case nir_intrinsic_rt_return_amd: { b_shader.cursor = nir_instr_remove(instr); if (shader->info.stage == MESA_SHADER_RAYGEN) { nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 0), 1); break; } insert_rt_return(&b_shader, vars); break; } case nir_intrinsic_load_scratch: { b_shader.cursor = nir_before_instr(instr); nir_instr_rewrite_src_ssa( instr, &intr->src[0], 
nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[0].ssa)); break; } case nir_intrinsic_store_scratch: { b_shader.cursor = nir_before_instr(instr); nir_instr_rewrite_src_ssa( instr, &intr->src[1], nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[1].ssa)); break; } case nir_intrinsic_load_rt_arg_scratch_offset_amd: { b_shader.cursor = nir_instr_remove(instr); nir_ssa_def *ret = nir_load_var(&b_shader, vars->arg); nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret); break; } case nir_intrinsic_load_shader_record_ptr: { b_shader.cursor = nir_instr_remove(instr); nir_ssa_def *ret = nir_load_var(&b_shader, vars->shader_record_ptr); nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret); break; } case nir_intrinsic_load_ray_launch_id: { b_shader.cursor = nir_instr_remove(instr); nir_ssa_def *ret = nir_load_global_invocation_id(&b_shader, 32); nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret); break; } case nir_intrinsic_load_ray_t_min: { b_shader.cursor = nir_instr_remove(instr); nir_ssa_def *ret = nir_load_var(&b_shader, vars->tmin); nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret); break; } case nir_intrinsic_load_ray_t_max: { b_shader.cursor = nir_instr_remove(instr); nir_ssa_def *ret = nir_load_var(&b_shader, vars->tmax); nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret); break; } case nir_intrinsic_load_ray_world_origin: { b_shader.cursor = nir_instr_remove(instr); nir_ssa_def *ret = nir_load_var(&b_shader, vars->origin); nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret); break; } case nir_intrinsic_load_ray_world_direction: { b_shader.cursor = nir_instr_remove(instr); nir_ssa_def *ret = nir_load_var(&b_shader, vars->direction); nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret); break; } case nir_intrinsic_load_ray_instance_custom_index: { b_shader.cursor = nir_instr_remove(instr); nir_ssa_def *ret = nir_load_var(&b_shader, vars->custom_instance_and_mask); ret = nir_iand(&b_shader, ret, nir_imm_int(&b_shader, 0xFFFFFF)); 
nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret); break; } case nir_intrinsic_load_primitive_id: { b_shader.cursor = nir_instr_remove(instr); nir_ssa_def *ret = nir_load_var(&b_shader, vars->primitive_id); nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret); break; } case nir_intrinsic_load_ray_geometry_index: { b_shader.cursor = nir_instr_remove(instr); nir_ssa_def *ret = nir_load_var(&b_shader, vars->geometry_id_and_flags); ret = nir_iand(&b_shader, ret, nir_imm_int(&b_shader, 0xFFFFFFF)); nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret); break; } case nir_intrinsic_load_instance_id: { b_shader.cursor = nir_instr_remove(instr); nir_ssa_def *ret = nir_load_var(&b_shader, vars->instance_id); nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret); break; } case nir_intrinsic_load_ray_flags: { b_shader.cursor = nir_instr_remove(instr); nir_ssa_def *ret = nir_load_var(&b_shader, vars->flags); nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret); break; } case nir_intrinsic_load_ray_hit_kind: { b_shader.cursor = nir_instr_remove(instr); nir_ssa_def *ret = nir_load_var(&b_shader, vars->hit_kind); nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret); break; } case nir_intrinsic_load_ray_world_to_object: { unsigned c = nir_intrinsic_column(intr); nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr); nir_ssa_def *wto_matrix[3]; nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix); nir_ssa_def *vals[3]; for (unsigned i = 0; i < 3; ++i) vals[i] = nir_channel(&b_shader, wto_matrix[i], c); nir_ssa_def *val = nir_vec(&b_shader, vals, 3); if (c == 3) val = nir_fneg(&b_shader, nir_build_vec3_mat_mult(&b_shader, val, wto_matrix, false)); b_shader.cursor = nir_instr_remove(instr); nir_ssa_def_rewrite_uses(&intr->dest.ssa, val); break; } case nir_intrinsic_load_ray_object_to_world: { unsigned c = nir_intrinsic_column(intr); nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr); nir_ssa_def *val; if (c == 3) { nir_ssa_def *wto_matrix[3]; 
nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix); nir_ssa_def *vals[3]; for (unsigned i = 0; i < 3; ++i) vals[i] = nir_channel(&b_shader, wto_matrix[i], c); val = nir_vec(&b_shader, vals, 3); } else { val = nir_build_load_global( &b_shader, 3, 32, nir_iadd(&b_shader, instance_node_addr, nir_imm_int64(&b_shader, 92 + c * 12)), .align_mul = 4, .align_offset = 0); } b_shader.cursor = nir_instr_remove(instr); nir_ssa_def_rewrite_uses(&intr->dest.ssa, val); break; } case nir_intrinsic_load_ray_object_origin: { nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr); nir_ssa_def *wto_matrix[] = { nir_build_load_global( &b_shader, 4, 32, nir_iadd(&b_shader, instance_node_addr, nir_imm_int64(&b_shader, 16)), .align_mul = 64, .align_offset = 16), nir_build_load_global( &b_shader, 4, 32, nir_iadd(&b_shader, instance_node_addr, nir_imm_int64(&b_shader, 32)), .align_mul = 64, .align_offset = 32), nir_build_load_global( &b_shader, 4, 32, nir_iadd(&b_shader, instance_node_addr, nir_imm_int64(&b_shader, 48)), .align_mul = 64, .align_offset = 48)}; nir_ssa_def *val = nir_build_vec3_mat_mult_pre( &b_shader, nir_load_var(&b_shader, vars->origin), wto_matrix); b_shader.cursor = nir_instr_remove(instr); nir_ssa_def_rewrite_uses(&intr->dest.ssa, val); break; } case nir_intrinsic_load_ray_object_direction: { nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr); nir_ssa_def *wto_matrix[3]; nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix); nir_ssa_def *val = nir_build_vec3_mat_mult( &b_shader, nir_load_var(&b_shader, vars->direction), wto_matrix, false); b_shader.cursor = nir_instr_remove(instr); nir_ssa_def_rewrite_uses(&intr->dest.ssa, val); break; } case nir_intrinsic_load_intersection_opaque_amd: { b_shader.cursor = nir_instr_remove(instr); nir_ssa_def *ret = nir_load_var(&b_shader, vars->opaque); nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret); break; } case nir_intrinsic_ignore_ray_intersection: 
{ b_shader.cursor = nir_instr_remove(instr); nir_store_var(&b_shader, vars->ahit_status, nir_imm_int(&b_shader, 1), 1); /* The if is a workaround to avoid having to fix up control flow manually */ nir_push_if(&b_shader, nir_imm_true(&b_shader)); nir_jump(&b_shader, nir_jump_return); nir_pop_if(&b_shader, NULL); break; } case nir_intrinsic_terminate_ray: { b_shader.cursor = nir_instr_remove(instr); nir_store_var(&b_shader, vars->ahit_status, nir_imm_int(&b_shader, 2), 1); /* The if is a workaround to avoid having to fix up control flow manually */ nir_push_if(&b_shader, nir_imm_true(&b_shader)); nir_jump(&b_shader, nir_jump_return); nir_pop_if(&b_shader, NULL); break; } case nir_intrinsic_report_ray_intersection: { b_shader.cursor = nir_instr_remove(instr); nir_push_if( &b_shader, nir_iand( &b_shader, nir_flt(&b_shader, intr->src[0].ssa, nir_load_var(&b_shader, vars->tmax)), nir_fge(&b_shader, intr->src[0].ssa, nir_load_var(&b_shader, vars->tmin)))); { nir_store_var(&b_shader, vars->ahit_status, nir_imm_int(&b_shader, 0), 1); nir_store_var(&b_shader, vars->tmax, intr->src[0].ssa, 1); nir_store_var(&b_shader, vars->hit_kind, intr->src[1].ssa, 1); } nir_pop_if(&b_shader, NULL); break; } default: break; } break; } case nir_instr_type_jump: { nir_jump_instr *jump = nir_instr_as_jump(instr); if (jump->type == nir_jump_halt) { b_shader.cursor = nir_instr_remove(instr); nir_jump(&b_shader, nir_jump_return); } break; } default: break; } } } nir_metadata_preserve(nir_shader_get_entrypoint(shader), nir_metadata_none); } static void insert_rt_case(nir_builder *b, nir_shader *shader, const struct rt_variables *vars, nir_ssa_def *idx, uint32_t call_idx_base, uint32_t call_idx) { struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL); nir_opt_dead_cf(shader); struct rt_variables src_vars = create_rt_variables(shader, vars->stack_sizes); map_rt_variables(var_remap, &src_vars, vars); NIR_PASS_V(shader, lower_rt_instructions, &src_vars, call_idx_base); 
NIR_PASS_V(shader, nir_opt_remove_phis); NIR_PASS_V(shader, nir_lower_returns); NIR_PASS_V(shader, nir_opt_dce); if (b->shader->info.stage == MESA_SHADER_ANY_HIT || b->shader->info.stage == MESA_SHADER_INTERSECTION) { src_vars.stack_sizes[src_vars.group_idx].non_recursive_size = MAX2(src_vars.stack_sizes[src_vars.group_idx].non_recursive_size, shader->scratch_size); } else { src_vars.stack_sizes[src_vars.group_idx].recursive_size = MAX2(src_vars.stack_sizes[src_vars.group_idx].recursive_size, shader->scratch_size); } nir_push_if(b, nir_ieq(b, idx, nir_imm_int(b, call_idx))); nir_store_var(b, vars->main_loop_case_visited, nir_imm_bool(b, true), 1); nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap); nir_pop_if(b, NULL); /* Adopt the instructions from the source shader, since they are merely moved, not cloned. */ ralloc_adopt(ralloc_context(b->shader), ralloc_context(shader)); ralloc_free(var_remap); } static bool lower_rt_derefs(nir_shader *shader) { nir_function_impl *impl = nir_shader_get_entrypoint(shader); bool progress = false; nir_builder b; nir_builder_init(&b, impl); b.cursor = nir_before_cf_list(&impl->body); nir_ssa_def *arg_offset = nir_load_rt_arg_scratch_offset_amd(&b); nir_foreach_block (block, impl) { nir_foreach_instr_safe (instr, block) { switch (instr->type) { case nir_instr_type_deref: { if (instr->type != nir_instr_type_deref) continue; nir_deref_instr *deref = nir_instr_as_deref(instr); if (nir_deref_mode_is(deref, nir_var_shader_call_data)) { deref->modes = nir_var_function_temp; if (deref->deref_type == nir_deref_type_var) { b.cursor = nir_before_instr(&deref->instr); nir_deref_instr *cast = nir_build_deref_cast( &b, arg_offset, nir_var_function_temp, deref->var->type, 0); nir_ssa_def_rewrite_uses(&deref->dest.ssa, &cast->dest.ssa); nir_instr_remove(&deref->instr); } progress = true; } else if (nir_deref_mode_is(deref, nir_var_ray_hit_attrib)) { deref->modes = nir_var_function_temp; if (deref->deref_type == 
nir_deref_type_var) { b.cursor = nir_before_instr(&deref->instr); nir_deref_instr *cast = nir_build_deref_cast(&b, nir_imm_int(&b, RADV_HIT_ATTRIB_OFFSET), nir_var_function_temp, deref->type, 0); nir_ssa_def_rewrite_uses(&deref->dest.ssa, &cast->dest.ssa); nir_instr_remove(&deref->instr); } progress = true; } break; } default: break; } } } if (progress) { nir_metadata_preserve(impl, nir_metadata_block_index | nir_metadata_dominance); } else { nir_metadata_preserve(impl, nir_metadata_all); } return progress; } static gl_shader_stage convert_rt_stage(VkShaderStageFlagBits vk_stage) { switch (vk_stage) { case VK_SHADER_STAGE_RAYGEN_BIT_KHR: return MESA_SHADER_RAYGEN; case VK_SHADER_STAGE_ANY_HIT_BIT_KHR: return MESA_SHADER_ANY_HIT; case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR: return MESA_SHADER_CLOSEST_HIT; case VK_SHADER_STAGE_MISS_BIT_KHR: return MESA_SHADER_MISS; case VK_SHADER_STAGE_INTERSECTION_BIT_KHR: return MESA_SHADER_INTERSECTION; case VK_SHADER_STAGE_CALLABLE_BIT_KHR: return MESA_SHADER_CALLABLE; default: unreachable("Unhandled RT stage"); } } static nir_shader * parse_rt_stage(struct radv_device *device, struct radv_pipeline_layout *layout, const VkPipelineShaderStageCreateInfo *stage) { struct radv_pipeline_key key; memset(&key, 0, sizeof(key)); nir_shader *shader = radv_shader_compile_to_nir( device, vk_shader_module_from_handle(stage->module), stage->pName, convert_rt_stage(stage->stage), stage->pSpecializationInfo, layout, &key); if (shader->info.stage == MESA_SHADER_RAYGEN || shader->info.stage == MESA_SHADER_CLOSEST_HIT || shader->info.stage == MESA_SHADER_CALLABLE || shader->info.stage == MESA_SHADER_MISS) { nir_block *last_block = nir_impl_last_block(nir_shader_get_entrypoint(shader)); nir_builder b_inner; nir_builder_init(&b_inner, nir_shader_get_entrypoint(shader)); b_inner.cursor = nir_after_block(last_block); nir_rt_return_amd(&b_inner); } NIR_PASS_V(shader, nir_lower_vars_to_explicit_types, nir_var_function_temp | nir_var_shader_call_data | 
nir_var_ray_hit_attrib,
              glsl_get_natural_size_align_bytes);

   NIR_PASS_V(shader, lower_rt_derefs);

   NIR_PASS_V(shader, nir_lower_explicit_io, nir_var_function_temp,
              nir_address_format_32bit_offset);

   return shader;
}

/* Rewrite the entrypoint of an any-hit shader so that it can be inlined into
 * an intersection shader as a function taking three 32-bit parameters:
 *
 *   param 0: pointer to the boolean "commit" flag,
 *   param 1: the hit T value,
 *   param 2: the hit kind.
 *
 * Inside the body:
 *  - ignore_ray_intersection becomes "store false to *commit" followed by a
 *    halt jump,
 *  - load_ray_t_max and load_ray_hit_kind are replaced by params 1 and 2,
 *  - terminate_ray is left untouched,
 *  - halt jumps are turned into return jumps, and returns are then lowered
 *    with nir_lower_returns_impl().
 *
 * any_hit is modified in place; its entrypoint impl is returned.
 */
static nir_function_impl *
lower_any_hit_for_intersection(nir_shader *any_hit)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(any_hit);

   /* Any-hit shaders need three parameters. The entrypoint must not have any
    * yet; they are installed just below.
    */
   assert(impl->function->num_params == 0);
   nir_parameter params[] = {
      {
         /* A pointer to a boolean value for whether or not the hit was
          * accepted.
          */
         .num_components = 1,
         .bit_size = 32,
      },
      {
         /* The hit T value */
         .num_components = 1,
         .bit_size = 32,
      },
      {
         /* The hit kind */
         .num_components = 1,
         .bit_size = 32,
      },
   };
   impl->function->num_params = ARRAY_SIZE(params);
   impl->function->params = ralloc_array(any_hit, nir_parameter, ARRAY_SIZE(params));
   memcpy(impl->function->params, params, sizeof(params));

   nir_builder build;
   nir_builder_init(&build, impl);
   nir_builder *b = &build;

   /* Load the three parameters once, at the very top of the function. */
   b->cursor = nir_before_cf_list(&impl->body);

   nir_ssa_def *commit_ptr = nir_load_param(b, 0);
   nir_ssa_def *hit_t = nir_load_param(b, 1);
   nir_ssa_def *hit_kind = nir_load_param(b, 2);

   /* View param 0 as a function_temp bool so it can be stored through. */
   nir_deref_instr *commit =
      nir_build_deref_cast(b, commit_ptr, nir_var_function_temp, glsl_bool_type(), 0);

   nir_foreach_block_safe (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         switch (instr->type) {
         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            switch (intrin->intrinsic) {
            case nir_intrinsic_ignore_ray_intersection:
               b->cursor = nir_instr_remove(&intrin->instr);
               /* We put the newly emitted code inside a dummy if because it's
                * going to contain a jump instruction and we don't want to
                * deal with that mess here.  It'll get dealt with by our
                * control-flow optimization passes.
                */
               nir_store_deref(b, commit, nir_imm_false(b), 0x1);
               nir_push_if(b, nir_imm_true(b));
               nir_jump(b, nir_jump_halt);
               nir_pop_if(b, NULL);
               break;

            case nir_intrinsic_terminate_ray:
               /* The "normal" handling of terminateRay works fine in
                * intersection shaders.
                */
               break;

            case nir_intrinsic_load_ray_t_max:
               /* Inside the inlined any-hit code the candidate hit T comes
                * from the caller-supplied parameter.
                */
               nir_ssa_def_rewrite_uses(&intrin->dest.ssa, hit_t);
               nir_instr_remove(&intrin->instr);
               break;

            case nir_intrinsic_load_ray_hit_kind:
               /* Same for the hit kind. */
               nir_ssa_def_rewrite_uses(&intrin->dest.ssa, hit_kind);
               nir_instr_remove(&intrin->instr);
               break;

            default:
               break;
            }
            break;
         }
         case nir_instr_type_jump: {
            /* Replace halt jumps with return jumps; they are lowered by
             * nir_lower_returns_impl() below.
             */
            nir_jump_instr *jump = nir_instr_as_jump(instr);
            if (jump->type == nir_jump_halt) {
               b->cursor = nir_instr_remove(instr);
               nir_jump(b, nir_jump_return);
            }
            break;
         }

         default:
            break;
         }
      }
   }

   nir_validate_shader(any_hit, "after initial any-hit lowering");

   nir_lower_returns_impl(impl);

   nir_validate_shader(any_hit, "after lowering returns");

   return impl;
}

/* Inline the any_hit shader into the intersection shader so we don't have
 * to implement yet another shader call interface here. Neither do any recursion.
*/
/* intersection: shader to rewrite in place.
 * any_hit:      optional any-hit shader (may be NULL). It is cloned into a
 *               temporary ralloc context, so the caller's shader is not
 *               modified.
 *
 * Every report_ray_intersection intrinsic is replaced by:
 *
 *   commit_tmp = false;
 *   if (hit_t >= ray_t_min && ray_t_max >= hit_t) {
 *      commit_tmp = true;
 *      if (any_hit && !intersection_opaque)
 *         <inlined any-hit>(&commit_tmp, hit_t, hit_kind);
 *      if (commit_tmp)
 *         report_ray_intersection(hit_t, hit_kind);
 *   }
 *
 * and uses of the original intrinsic's result are replaced by commit_tmp.
 */
static void
nir_lower_intersection_shader(nir_shader *intersection, nir_shader *any_hit)
{
   /* Everything allocated on dead_ctx (the any-hit clone and the variable
    * remap table) is freed in one go at the end of this function.
    */
   void *dead_ctx = ralloc_context(intersection);

   nir_function_impl *any_hit_impl = NULL;
   struct hash_table *any_hit_var_remap = NULL;
   if (any_hit) {
      any_hit = nir_shader_clone(dead_ctx, any_hit);
      NIR_PASS_V(any_hit, nir_opt_dce);
      any_hit_impl = lower_any_hit_for_intersection(any_hit);
      any_hit_var_remap = _mesa_pointer_hash_table_create(dead_ctx);
   }

   nir_function_impl *impl = nir_shader_get_entrypoint(intersection);

   nir_builder build;
   nir_builder_init(&build, impl);
   nir_builder *b = &build;

   b->cursor = nir_before_cf_list(&impl->body);

   /* NOTE(review): "ray_commit" is initialised to false here but is not read
    * anywhere else in this function.
    */
   nir_variable *commit = nir_local_variable_create(impl, glsl_bool_type(), "ray_commit");
   nir_store_var(b, commit, nir_imm_false(b), 0x1);

   nir_foreach_block_safe (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         if (intrin->intrinsic != nir_intrinsic_report_ray_intersection)
            continue;

         /* Emit the replacement sequence where the intrinsic used to be. */
         b->cursor = nir_instr_remove(&intrin->instr);
         nir_ssa_def *hit_t = nir_ssa_for_src(b, intrin->src[0], 1);
         nir_ssa_def *hit_kind = nir_ssa_for_src(b, intrin->src[1], 1);
         nir_ssa_def *min_t = nir_load_ray_t_min(b);
         nir_ssa_def *max_t = nir_load_ray_t_max(b);

         /* bool commit_tmp = false; */
         nir_variable *commit_tmp = nir_local_variable_create(impl, glsl_bool_type(), "commit_tmp");
         nir_store_var(b, commit_tmp, nir_imm_false(b), 0x1);

         /* if (min_t <= hit_t && hit_t <= max_t) { */
         nir_push_if(b, nir_iand(b, nir_fge(b, hit_t, min_t), nir_fge(b, max_t, hit_t)));
         {
            /* Any-hit defaults to commit */
            nir_store_var(b, commit_tmp, nir_imm_true(b), 0x1);

            if (any_hit_impl != NULL) {
               /* The any-hit code only runs for non-opaque intersections and
                * may clear commit_tmp through the pointer passed as param 0.
                */
               nir_push_if(b, nir_inot(b, nir_load_intersection_opaque_amd(b)));
               {
                  nir_ssa_def *params[] = {
                     &nir_build_deref_var(b, commit_tmp)->dest.ssa,
                     hit_t,
                     hit_kind,
                  };
                  nir_inline_function_impl(b, any_hit_impl, params, any_hit_var_remap);
               }
               nir_pop_if(b, NULL);
            }

            nir_push_if(b, nir_load_var(b, commit_tmp));
            {
               nir_report_ray_intersection(b,
1, hit_t, hit_kind);
            }
            nir_pop_if(b, NULL);
         }
         nir_pop_if(b, NULL);

         /* The original intrinsic's result becomes the final commit flag. */
         nir_ssa_def *accepted = nir_load_var(b, commit_tmp);
         nir_ssa_def_rewrite_uses(&intrin->dest.ssa, accepted);
      }
   }

   /* We did some inlining; have to re-index SSA defs */
   nir_index_ssa_defs(impl);

   /* Eliminate the casts introduced for the commit return of the any-hit shader. */
   NIR_PASS_V(intersection, nir_opt_deref);

   ralloc_free(dead_ctx);
}

/* Variables only used internally to ray traversal. This is data that describes
 * the current state of the traversal vs. what we'd give to a shader.  e.g. what
 * is the instance we're currently visiting vs. what is the instance of the
 * closest hit. */
struct rt_traversal_vars {
   /* vec3 float: ray origin, direction and 1/direction used by traversal. */
   nir_variable *origin;
   nir_variable *dir;
   nir_variable *inv_dir;
   /* uint: the low 24 bits are added to the SBT index; bits 24 and up are
    * tested as instance flags by the hit code in this file.
    */
   nir_variable *sbt_offset_and_flags;
   /* Copied into the hit variables when a hit is committed. */
   nir_variable *instance_id;
   nir_variable *custom_instance_and_mask;
   nir_variable *instance_addr;
   /* bool: written when a hit is committed (see the triangle/AABB cases). */
   nir_variable *should_return;
   /* uint64 */
   nir_variable *bvh_base;
   /* uint: traversal stack pointer and "top" stack pointer. */
   nir_variable *stack;
   nir_variable *top_stack;
};

/* Create one nir_var_shader_temp variable on b->shader for every field of
 * struct rt_traversal_vars and return the filled-in struct. No initial values
 * are stored here.
 */
static struct rt_traversal_vars
init_traversal_vars(nir_builder *b)
{
   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
   struct rt_traversal_vars ret;

   ret.origin = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_origin");
   ret.dir = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_dir");
   ret.inv_dir =
      nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_inv_dir");
   ret.sbt_offset_and_flags = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
                                                  "traversal_sbt_offset_and_flags");
   ret.instance_id = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
                                         "traversal_instance_id");
   ret.custom_instance_and_mask = nir_variable_create(
      b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_custom_instance_and_mask");
   ret.instance_addr =
      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
   ret.should_return = nir_variable_create(b->shader,
nir_var_shader_temp, glsl_bool_type(),
                                           "traversal_should_return");
   ret.bvh_base = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(),
                                      "traversal_bvh_base");
   ret.stack =
      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_stack_ptr");
   ret.top_stack = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
                                       "traversal_top_stack_ptr");
   return ret;
}

/* Convert a 64-bit address into a node id:
 *
 *    node = (addr >> 3) & ((2^42 - 1) << 3)
 *
 * i.e. the address is shifted right by 3 and the low 3 bits of the result are
 * cleared.
 */
static nir_ssa_def *
build_addr_to_node(nir_builder *b, nir_ssa_def *addr)
{
   const uint64_t bvh_size = 1ull << 42;
   nir_ssa_def *node = nir_ushr(b, addr, nir_imm_int(b, 3));
   return nir_iand(b, node, nir_imm_int64(b, (bvh_size - 1) << 3));
}

/* Inverse of build_addr_to_node(): clear the low 3 bits of the node id, shift
 * left by 3, and on GFX9 and newer set the top 16 address bits.
 */
static nir_ssa_def *
build_node_to_addr(struct radv_device *device, nir_builder *b, nir_ssa_def *node)
{
   /* The low 3 bits of a node id are not part of the address; the traversal
    * code masks them out separately as the node type.
    */
   nir_ssa_def *addr = nir_iand(b, node, nir_imm_int64(b, ~7ull));
   addr = nir_ishl(b, addr, nir_imm_int(b, 3));
   /* Assumes everything is in the top half of address space, which is true in
    * GFX9+ for now. */
   return device->physical_device->rad_info.chip_class >= GFX9
             ? nir_ior(b, addr, nir_imm_int64(b, 0xffffull << 48))
             : addr;
}

/* When a hit is opaque the any_hit shader is skipped for this hit and the hit
 * is assumed to be an actual hit.
*/
/* Returns a boolean SSA value. The inputs are applied in increasing priority
 * (each later nir_bcsel overrides the earlier result):
 *   1. geometry flag, bit 28 of geometry_id_and_flags,
 *   2. instance "force opaque" flag,
 *   3. instance "force no-opaque" flag,
 *   4. ray flag Opaque,
 *   5. ray flag NoOpaque.
 * The instance flags are read from bits 24+ of trav_vars->sbt_offset_and_flags
 * and the ray flags from vars->flags.
 */
static nir_ssa_def *
hit_is_opaque(nir_builder *b, const struct rt_variables *vars,
              const struct rt_traversal_vars *trav_vars, nir_ssa_def *geometry_id_and_flags)
{
   nir_ssa_def *geom_force_opaque = nir_ine(
      b, nir_iand(b, geometry_id_and_flags, nir_imm_int(b, 1u << 28 /* VK_GEOMETRY_OPAQUE_BIT */)),
      nir_imm_int(b, 0));
   nir_ssa_def *instance_force_opaque =
      nir_ine(b,
              nir_iand(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
                       nir_imm_int(b, 4 << 24 /* VK_GEOMETRY_INSTANCE_FORCE_OPAQUE_BIT */)),
              nir_imm_int(b, 0));
   nir_ssa_def *instance_force_non_opaque =
      nir_ine(b,
              nir_iand(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
                       nir_imm_int(b, 8 << 24 /* VK_GEOMETRY_INSTANCE_FORCE_NO_OPAQUE_BIT */)),
              nir_imm_int(b, 0));

   /* Geometry default, overridden by the instance flags. */
   nir_ssa_def *opaque = geom_force_opaque;
   opaque = nir_bcsel(b, instance_force_opaque, nir_imm_bool(b, true), opaque);
   opaque = nir_bcsel(b, instance_force_non_opaque, nir_imm_bool(b, false), opaque);

   nir_ssa_def *ray_force_opaque =
      nir_ine(b, nir_iand(b, nir_load_var(b, vars->flags), nir_imm_int(b, 1 /* RayFlagsOpaque */)),
              nir_imm_int(b, 0));
   nir_ssa_def *ray_force_non_opaque = nir_ine(
      b, nir_iand(b, nir_load_var(b, vars->flags), nir_imm_int(b, 2 /* RayFlagsNoOpaque */)),
      nir_imm_int(b, 0));

   /* Ray flags have the final say. */
   opaque = nir_bcsel(b, ray_force_opaque, nir_imm_bool(b, true), opaque);
   opaque = nir_bcsel(b, ray_force_non_opaque, nir_imm_bool(b, false), opaque);

   return opaque;
}

/* Emit, guarded by "vars->idx != 0", one insert_rt_case() per triangle hit
 * group of the pipeline that has an any-hit shader. Each such shader is
 * compiled with parse_rt_stage() every time this function is called, and
 * vars->group_idx is set to the group's index before its case is inserted.
 */
static void
visit_any_hit_shaders(struct radv_device *device,
                      const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
                      struct rt_variables *vars)
{
   RADV_FROM_HANDLE(radv_pipeline_layout, layout, pCreateInfo->layout);
   nir_ssa_def *sbt_idx = nir_load_var(b, vars->idx);

   /* An index of 0 means there is nothing to call. */
   nir_push_if(b, nir_ine(b, sbt_idx, nir_imm_int(b, 0)));
   for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
      const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
      uint32_t shader_id = VK_SHADER_UNUSED_KHR;

      /* Only triangle hit groups are handled here. */
      switch (group_info->type) {
      case
VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
         shader_id = group_info->anyHitShader;
         break;
      default:
         break;
      }
      if (shader_id == VK_SHADER_UNUSED_KHR)
         continue;

      const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
      nir_shader *nir_stage = parse_rt_stage(device, layout, stage);

      vars->group_idx = i;
      /* presumably "i + 2" is the shader index this case is selected by --
       * confirm against insert_rt_case() */
      insert_rt_case(b, nir_stage, vars, sbt_idx, 0, i + 2);
   }
   nir_pop_if(b, NULL);
}

/* Emit the code that handles a ray/triangle intersection result during
 * traversal.
 *
 * result:   vec4 intersection result; channel 0 divided by channel 1 is the
 *           hit distance, channels 2 and 3 divided by channel 1 are stored as
 *           the two hit attributes.
 * vars:     ray and hit state visible to shaders; updated when the hit is
 *           committed.
 * trav_vars: current traversal state (instance data, flags).
 * bvh_node: node id of the triangle node; used to load the primitive id and
 *           the geometry id/flags.
 *
 * The emitted code may end with a "continue" or "break" jump, so it has to be
 * placed inside a loop.
 */
static void
insert_traversal_triangle_case(struct radv_device *device,
                               const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
                               nir_ssa_def *result, const struct rt_variables *vars,
                               const struct rt_traversal_vars *trav_vars, nir_ssa_def *bvh_node)
{
   nir_ssa_def *dist = nir_vector_extract(b, result, nir_imm_int(b, 0));
   nir_ssa_def *div = nir_vector_extract(b, result, nir_imm_int(b, 1));
   dist = nir_fdiv(b, dist, div);
   /* The sign of the divisor decides which side was hit; the instance flag
    * below flips that decision.
    */
   nir_ssa_def *frontface = nir_flt(b, nir_imm_float(b, 0), div);
   nir_ssa_def *switch_ccw = nir_ine(
      b,
      nir_iand(
         b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
         nir_imm_int(b, 2 << 24 /* VK_GEOMETRY_INSTANCE_TRIANGLE_FRONT_COUNTERCLOCKWISE_BIT */)),
      nir_imm_int(b, 0));
   frontface = nir_ixor(b, frontface, switch_ccw);

   nir_ssa_def *not_cull = nir_ieq(
      b, nir_iand(b, nir_load_var(b, vars->flags), nir_imm_int(b, 256 /* RayFlagsSkipTriangles */)),
      nir_imm_int(b, 0));
   nir_ssa_def *not_facing_cull = nir_ieq(
      b,
      nir_iand(b, nir_load_var(b, vars->flags),
               nir_bcsel(b, frontface, nir_imm_int(b, 32 /* RayFlagsCullFrontFacingTriangles */),
                         nir_imm_int(b, 16 /* RayFlagsCullBackFacingTriangles */))),
      nir_imm_int(b, 0));

   /* Facing-based culling is ignored when the instance disables it. */
   not_cull = nir_iand(
      b, not_cull,
      nir_ior(
         b, not_facing_cull,
         nir_ine(
            b,
            nir_iand(
               b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
               nir_imm_int(b, 1 << 24 /* VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT */)),
            nir_imm_int(b, 0))));

   /* if (tmin <= dist && dist < tmax && not_cull) { */
   nir_push_if(b, nir_iand(b,
                           nir_iand(b, nir_flt(b, dist, nir_load_var(b, vars->tmax)),
                                    nir_fge(b, dist, nir_load_var(b, vars->tmin))),
                           not_cull));
   {

      nir_ssa_def *triangle_info =
nir_build_load_global(
         b, 2, 32,
         nir_iadd(b, build_node_to_addr(device, b, bvh_node),
                  nir_imm_int64(b, offsetof(struct radv_bvh_triangle_node, triangle_id))),
         .align_mul = 4, .align_offset = 0);
      nir_ssa_def *primitive_id = nir_channel(b, triangle_info, 0);
      nir_ssa_def *geometry_id_and_flags = nir_channel(b, triangle_info, 1);
      /* The low 28 bits are the geometry id; hit_is_opaque() tests bit 28. */
      nir_ssa_def *geometry_id = nir_iand(b, geometry_id_and_flags, nir_imm_int(b, 0xfffffff));
      nir_ssa_def *is_opaque = hit_is_opaque(b, vars, trav_vars, geometry_id_and_flags);

      /* Second culling stage based on opacity: ray flag 0x40 is tested for
       * opaque hits and 0x80 for non-opaque hits (presumably
       * RayFlagsCullOpaque / RayFlagsCullNoOpaque -- consistent with the other
       * SPIR-V ray flag values used in this file).
       */
      not_cull =
         nir_ieq(b,
                 nir_iand(b, nir_load_var(b, vars->flags),
                          nir_bcsel(b, is_opaque, nir_imm_int(b, 0x40), nir_imm_int(b, 0x80))),
                 nir_imm_int(b, 0));
      nir_push_if(b, not_cull);
      {
         /* sbt_idx = sbt_offset + instance SBT offset (low 24 bits)
          *           + sbt_stride * geometry_id
          */
         nir_ssa_def *sbt_idx =
            nir_iadd(b,
                     nir_iadd(b, nir_load_var(b, vars->sbt_offset),
                              nir_iand(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
                                       nir_imm_int(b, 0xffffff))),
                     nir_imul(b, nir_load_var(b, vars->sbt_stride), geometry_id));
         nir_ssa_def *divs[2] = {div, div};
         /* Channels 2 and 3 of the result, normalised by the divisor. */
         nir_ssa_def *ij = nir_fdiv(b, nir_channels(b, result, 0xc), nir_vec(b, divs, 2));
         /* presumably 0xFE / 0xFF are the SPIR-V front/back facing triangle
          * hit kinds -- confirm */
         nir_ssa_def *hit_kind =
            nir_bcsel(b, frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF));

         /* Write the two hit attributes to scratch at the hit-attribute
          * offset relative to the current stack pointer.
          */
         nir_store_scratch(
            b, ij,
            nir_iadd(b, nir_load_var(b, vars->stack_ptr), nir_imm_int(b, RADV_HIT_ATTRIB_OFFSET)),
            .align_mul = 16, .write_mask = 3);

         nir_store_var(b, vars->ahit_status, nir_imm_int(b, 0), 1);

         /* Non-opaque hits: run the any-hit shaders on a separate set of
          * "inner" variables filled with the candidate hit.
          */
         nir_push_if(b, nir_ine(b, is_opaque, nir_imm_bool(b, true)));
         {
            struct rt_variables inner_vars = create_inner_vars(b, vars);

            nir_store_var(b, inner_vars.primitive_id, primitive_id, 1);
            nir_store_var(b, inner_vars.geometry_id_and_flags, geometry_id_and_flags, 1);
            nir_store_var(b, inner_vars.tmax, dist, 0x1);
            nir_store_var(b, inner_vars.instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
            nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, trav_vars->instance_addr),
                          0x1);
            nir_store_var(b, inner_vars.hit_kind, hit_kind, 0x1);
            nir_store_var(b, inner_vars.custom_instance_and_mask,
                          nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);

load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4);

            visit_any_hit_shaders(device, pCreateInfo, b, &inner_vars);

            /* ahit_status == 1: the hit is not committed; go on with the next
             * loop iteration.
             */
            nir_push_if(b, nir_ieq(b, nir_load_var(b, vars->ahit_status), nir_imm_int(b, 1)));
            {
               nir_jump(b, nir_jump_continue);
            }
            nir_pop_if(b, NULL);
         }
         nir_pop_if(b, NULL);

         /* Commit the hit into the shader-visible variables. */
         nir_store_var(b, vars->primitive_id, primitive_id, 1);
         nir_store_var(b, vars->geometry_id_and_flags, geometry_id_and_flags, 1);
         nir_store_var(b, vars->tmax, dist, 0x1);
         nir_store_var(b, vars->instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
         nir_store_var(b, vars->instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
         nir_store_var(b, vars->hit_kind, hit_kind, 0x1);
         nir_store_var(b, vars->custom_instance_and_mask,
                       nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);

         load_sbt_entry(b, vars, sbt_idx, SBT_HIT, 0);

         /* should_return = "skip closest-hit" ray flag set, or the index just
          * loaded from the SBT is 0.
          */
         nir_store_var(b, trav_vars->should_return,
                       nir_ior(b,
                               nir_ine(b,
                                       nir_iand(b, nir_load_var(b, vars->flags),
                                                nir_imm_int(b, 8 /* SkipClosestHitShader */)),
                                       nir_imm_int(b, 0)),
                               nir_ieq(b, nir_load_var(b, vars->idx), nir_imm_int(b, 0))),
                       1);

         /* Leave the loop if the ray asks for the first hit only, or if
          * ahit_status is 2.
          */
         nir_ssa_def *terminate_on_first_hit =
            nir_ine(b,
                    nir_iand(b, nir_load_var(b, vars->flags),
                             nir_imm_int(b, 4 /* TerminateOnFirstHitKHR */)),
                    nir_imm_int(b, 0));
         nir_ssa_def *ray_terminated =
            nir_ieq(b, nir_load_var(b, vars->ahit_status), nir_imm_int(b, 2));
         nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
         {
            nir_jump(b, nir_jump_break);
         }
         nir_pop_if(b, NULL);
      }
      nir_pop_if(b, NULL);
   }
   nir_pop_if(b, NULL);
}

/* Emit the code that handles an AABB (procedural) leaf node during traversal.
 *
 * If the SBT entry loaded for the hit has a non-zero shader index, the
 * intersection shaders of all procedural hit groups are inserted as cases
 * (with their any-hit shader inlined when the group has one). Otherwise a
 * ray/box test against the node bounds is emitted. When the candidate ends up
 * accepted (ahit_status != 1) it is committed into vars, like in the triangle
 * case.
 *
 * The emitted code may end with a "break" jump, so it has to be placed inside
 * a loop.
 */
static void
insert_traversal_aabb_case(struct radv_device *device,
                           const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
                           const struct rt_variables *vars,
                           const struct rt_traversal_vars *trav_vars, nir_ssa_def *bvh_node)
{
   RADV_FROM_HANDLE(radv_pipeline_layout, layout, pCreateInfo->layout);

   nir_ssa_def *node_addr = build_node_to_addr(device, b, bvh_node);
   /* Two dwords at byte offset 24 of the node: primitive id, then geometry id
    * and flags.
    */
   nir_ssa_def *triangle_info = nir_build_load_global(
      b, 2, 32, nir_iadd(b, node_addr, nir_imm_int64(b, 24)), .align_mul = 4, .align_offset =
0);
   nir_ssa_def *primitive_id = nir_channel(b, triangle_info, 0);
   nir_ssa_def *geometry_id_and_flags = nir_channel(b, triangle_info, 1);
   /* The low 28 bits are the geometry id; hit_is_opaque() tests bit 28. */
   nir_ssa_def *geometry_id = nir_iand(b, geometry_id_and_flags, nir_imm_int(b, 0xfffffff));
   nir_ssa_def *is_opaque = hit_is_opaque(b, vars, trav_vars, geometry_id_and_flags);

   /* Opacity-based culling: ray flag 0x40 is tested for opaque hits and 0x80
    * for non-opaque hits (presumably RayFlagsCullOpaque / RayFlagsCullNoOpaque).
    */
   nir_ssa_def *not_cull =
      nir_ieq(b,
              nir_iand(b, nir_load_var(b, vars->flags),
                       nir_bcsel(b, is_opaque, nir_imm_int(b, 0x40), nir_imm_int(b, 0x80))),
              nir_imm_int(b, 0));
   nir_push_if(b, not_cull);
   {
      /* sbt_idx = sbt_offset + instance SBT offset (low 24 bits)
       *           + sbt_stride * geometry_id
       */
      nir_ssa_def *sbt_idx =
         nir_iadd(b,
                  nir_iadd(b, nir_load_var(b, vars->sbt_offset),
                           nir_iand(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
                                    nir_imm_int(b, 0xffffff))),
                  nir_imul(b, nir_load_var(b, vars->sbt_stride), geometry_id));

      struct rt_variables inner_vars = create_inner_vars(b, vars);

      /* For AABBs the intersection shader writes the hit kind, and only does it if it is the
       * next closest hit candidate. */
      inner_vars.hit_kind = vars->hit_kind;

      nir_store_var(b, inner_vars.primitive_id, primitive_id, 1);
      nir_store_var(b, inner_vars.geometry_id_and_flags, geometry_id_and_flags, 1);
      nir_store_var(b, inner_vars.tmax, nir_load_var(b, vars->tmax), 0x1);
      nir_store_var(b, inner_vars.instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
      nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
      nir_store_var(b, inner_vars.custom_instance_and_mask,
                    nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
      nir_store_var(b, inner_vars.opaque, is_opaque, 1);

      load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4);

      /* Start out with status 1; the commit block further down only runs when
       * something changes it to another value.
       */
      nir_store_var(b, vars->ahit_status, nir_imm_int(b, 1), 1);

      nir_push_if(b, nir_ine(b, nir_load_var(b, inner_vars.idx), nir_imm_int(b, 0)));
      for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
         const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
         uint32_t shader_id = VK_SHADER_UNUSED_KHR;
         uint32_t any_hit_shader_id = VK_SHADER_UNUSED_KHR;

         /* Only procedural hit groups are handled here. */
         switch (group_info->type) {
         case
VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
            shader_id = group_info->intersectionShader;
            any_hit_shader_id = group_info->anyHitShader;
            break;
         default:
            break;
         }
         if (shader_id == VK_SHADER_UNUSED_KHR)
            continue;

         const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
         nir_shader *nir_stage = parse_rt_stage(device, layout, stage);

         /* NOTE(review): nir_lower_intersection_shader() accepts a NULL
          * any-hit shader, but it is only invoked here when the group has an
          * any-hit shader -- confirm that intersection shaders without one
          * are meant to skip that lowering.
          */
         nir_shader *any_hit_stage = NULL;
         if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) {
            stage = &pCreateInfo->pStages[any_hit_shader_id];
            any_hit_stage = parse_rt_stage(device, layout, stage);

            nir_lower_intersection_shader(nir_stage, any_hit_stage);
            ralloc_free(any_hit_stage);
         }

         inner_vars.group_idx = i;
         insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0, i + 2);
      }
      nir_push_else(b, NULL);
      {
         /* Shader index 0: intersect the ray with the node's bounding box.
          * The box minimum is read from byte offset 0 and the maximum from
          * byte offset 12 of the node.
          */
         nir_ssa_def *vec3_zero = nir_channels(b, nir_imm_vec4(b, 0, 0, 0, 0), 0x7);
         nir_ssa_def *vec3_inf =
            nir_channels(b, nir_imm_vec4(b, INFINITY, INFINITY, INFINITY, 0), 0x7);

         nir_ssa_def *bvh_lo =
            nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, 0)),
                                  .align_mul = 4, .align_offset = 0);
         nir_ssa_def *bvh_hi =
            nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, 12)),
                                  .align_mul = 4, .align_offset = 0);

         bvh_lo = nir_fsub(b, bvh_lo, nir_load_var(b, trav_vars->origin));
         bvh_hi = nir_fsub(b, bvh_hi, nir_load_var(b, trav_vars->origin));
         /* Per-axis entry (t_vec) and exit (t2_vec) distances. */
         nir_ssa_def *t_vec = nir_fmin(b, nir_fmul(b, bvh_lo, nir_load_var(b, trav_vars->inv_dir)),
                                       nir_fmul(b, bvh_hi, nir_load_var(b, trav_vars->inv_dir)));
         nir_ssa_def *t2_vec = nir_fmax(b, nir_fmul(b, bvh_lo, nir_load_var(b, trav_vars->inv_dir)),
                                        nir_fmul(b, bvh_hi, nir_load_var(b, trav_vars->inv_dir)));
         /* If we run parallel to one of the edges the range should be [0, inf) not [0,0] */
         t2_vec =
            nir_bcsel(b, nir_feq(b, nir_load_var(b, trav_vars->dir), vec3_zero), vec3_inf, t2_vec);

         /* t_min = largest entry distance, t_max = smallest exit distance. */
         nir_ssa_def *t_min = nir_fmax(b, nir_channel(b, t_vec, 0), nir_channel(b, t_vec, 1));
         t_min = nir_fmax(b, t_min, nir_channel(b, t_vec, 2));

         nir_ssa_def *t_max =
nir_fmin(b, nir_channel(b, t2_vec, 0), nir_channel(b, t2_vec, 1));
         t_max = nir_fmin(b, t_max, nir_channel(b, t2_vec, 2));

         nir_push_if(b, nir_iand(b, nir_flt(b, t_min, nir_load_var(b, vars->tmax)),
                                 nir_fge(b, t_max, nir_load_var(b, vars->tmin))));
         {
            nir_store_var(b, vars->ahit_status, nir_imm_int(b, 0), 1);
            /* NOTE(review): the commit block below stores inner_vars.tmax into
             * vars->tmax, which looks like it supersedes this value unless
             * create_inner_vars() makes the two alias -- confirm.
             */
            nir_store_var(b, vars->tmax, nir_fmax(b, t_min, nir_load_var(b, vars->tmin)), 1);
         }
         nir_pop_if(b, NULL);
      }
      nir_pop_if(b, NULL);

      /* Commit the candidate unless ahit_status is still 1. */
      nir_push_if(b, nir_ine(b, nir_load_var(b, vars->ahit_status), nir_imm_int(b, 1)));
      {
         nir_store_var(b, vars->primitive_id, primitive_id, 1);
         nir_store_var(b, vars->geometry_id_and_flags, geometry_id_and_flags, 1);
         nir_store_var(b, vars->tmax, nir_load_var(b, inner_vars.tmax), 0x1);
         nir_store_var(b, vars->instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
         nir_store_var(b, vars->instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
         nir_store_var(b, vars->custom_instance_and_mask,
                       nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);

         load_sbt_entry(b, vars, sbt_idx, SBT_HIT, 0);

         /* should_return = "skip closest-hit" ray flag set, or the index just
          * loaded from the SBT is 0.
          */
         nir_store_var(b, trav_vars->should_return,
                       nir_ior(b,
                               nir_ine(b,
                                       nir_iand(b, nir_load_var(b, vars->flags),
                                                nir_imm_int(b, 8 /* SkipClosestHitShader */)),
                                       nir_imm_int(b, 0)),
                               nir_ieq(b, nir_load_var(b, vars->idx), nir_imm_int(b, 0))),
                       1);

         /* Leave the loop if the ray asks for the first hit only, or if
          * ahit_status is 2.
          */
         nir_ssa_def *terminate_on_first_hit =
            nir_ine(b,
                    nir_iand(b, nir_load_var(b, vars->flags),
                             nir_imm_int(b, 4 /* TerminateOnFirstHitKHR */)),
                    nir_imm_int(b, 0));
         nir_ssa_def *ray_terminated =
            nir_ieq(b, nir_load_var(b, vars->ahit_status), nir_imm_int(b, 2));
         nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
         {
            nir_jump(b, nir_jump_break);
         }
         nir_pop_if(b, NULL);
      }
      nir_pop_if(b, NULL);
   }
   nir_pop_if(b, NULL);
}

/* One compare-and-swap step of a sorting network: if the distance in channel
 * chan_2 is smaller than the one in channel chan_1, swap those two channels
 * in both var_distances and var_indices. All other channels are left alone
 * (the stores use a write mask covering only chan_1 and chan_2).
 */
static void
nir_sort_hit_pair(nir_builder *b, nir_variable *var_distances, nir_variable *var_indices,
                  uint32_t chan_1, uint32_t chan_2)
{
   nir_ssa_def *ssa_distances = nir_load_var(b, var_distances);
   nir_ssa_def *ssa_indices = nir_load_var(b, var_indices);
   /* if (distances[chan_2] <
distances[chan_1]) { */
   nir_push_if(
      b, nir_flt(b, nir_channel(b, ssa_distances, chan_2), nir_channel(b, ssa_distances, chan_1)));
   {
      /* swap(distances[chan_2], distances[chan_1]); */
      /* The channels that are not swapped stay undef; they are masked out by
       * the write masks of the stores below.
       */
      nir_ssa_def *new_distances[4] = {nir_ssa_undef(b, 1, 32), nir_ssa_undef(b, 1, 32),
                                       nir_ssa_undef(b, 1, 32), nir_ssa_undef(b, 1, 32)};
      nir_ssa_def *new_indices[4] = {nir_ssa_undef(b, 1, 32), nir_ssa_undef(b, 1, 32),
                                     nir_ssa_undef(b, 1, 32), nir_ssa_undef(b, 1, 32)};
      new_distances[chan_2] = nir_channel(b, ssa_distances, chan_1);
      new_distances[chan_1] = nir_channel(b, ssa_distances, chan_2);
      new_indices[chan_2] = nir_channel(b, ssa_indices, chan_1);
      new_indices[chan_1] = nir_channel(b, ssa_indices, chan_2);
      nir_store_var(b, var_distances, nir_vec(b, new_distances, 4),
                    (1u << chan_1) | (1u << chan_2));
      nir_store_var(b, var_indices, nir_vec(b, new_indices, 4), (1u << chan_1) | (1u << chan_2));
   }
   /* } */
   nir_pop_if(b, NULL);
}

/* Software ray/box test for a box32 node with four children.
 *
 * bvh_node: node id of the box node.
 * ray_tmax: current maximum hit distance of the ray.
 * origin, inv_dir: ray origin and reciprocal direction (dir itself is not
 *                  used by this function).
 *
 * Returns a uvec4 with the child node ids of the children that are hit,
 * sorted by ascending entry distance; slots of children that are missed or
 * inactive hold 0xffffffff and, having distance INFINITY, sort to the end.
 */
static nir_ssa_def *
intersect_ray_amd_software_box(struct radv_device *device, nir_builder *b, nir_ssa_def *bvh_node,
                               nir_ssa_def *ray_tmax, nir_ssa_def *origin, nir_ssa_def *dir,
                               nir_ssa_def *inv_dir)
{
   const struct glsl_type *vec4_type = glsl_vector_type(GLSL_TYPE_FLOAT, 4);
   const struct glsl_type *uvec4_type = glsl_vector_type(GLSL_TYPE_UINT, 4);

   nir_ssa_def *node_addr = build_node_to_addr(device, b, bvh_node);

   /* vec4 distances = vec4(INF, INF, INF, INF); */
   nir_variable *distances =
      nir_variable_create(b->shader, nir_var_shader_temp, vec4_type, "distances");
   nir_store_var(b, distances, nir_imm_vec4(b, INFINITY, INFINITY, INFINITY, INFINITY), 0xf);

   /* uvec4 child_indices = uvec4(0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff); */
   nir_variable *child_indices =
      nir_variable_create(b->shader, nir_var_shader_temp, uvec4_type, "child_indices");
   nir_store_var(b, child_indices,
                 nir_imm_ivec4(b, 0xffffffffu, 0xffffffffu, 0xffffffffu, 0xffffffffu), 0xf);

   /* Need to remove infinities here because otherwise we get nasty NaN propagation
    * if the direction has 0s in it.
    */
   /* inv_dir = clamp(inv_dir, -FLT_MAX, FLT_MAX); */
   inv_dir = nir_fclamp(b, inv_dir, nir_imm_float(b, -FLT_MAX), nir_imm_float(b, FLT_MAX));

   /* Test each of the four children; this loop is unrolled at build time. */
   for (int i = 0; i < 4; i++) {
      const uint32_t child_offset = offsetof(struct radv_bvh_box32_node, children[i]);
      const uint32_t coord_offsets[2] = {
         offsetof(struct radv_bvh_box32_node, coords[i][0][0]),
         offsetof(struct radv_bvh_box32_node, coords[i][1][0]),
      };

      /* node->children[i] -> uint */
      nir_ssa_def *child_index =
         nir_build_load_global(b, 1, 32, nir_iadd(b, node_addr, nir_imm_int64(b, child_offset)),
                               .align_mul = 64, .align_offset = child_offset % 64 );
      /* node->coords[i][0], node->coords[i][1] -> vec3 */
      nir_ssa_def *node_coords[2] = {
         nir_build_load_global(b, 3, 32,
                               nir_iadd(b, node_addr, nir_imm_int64(b, coord_offsets[0])),
                               .align_mul = 64, .align_offset = coord_offsets[0] % 64 ),
         nir_build_load_global(b, 3, 32,
                               nir_iadd(b, node_addr, nir_imm_int64(b, coord_offsets[1])),
                               .align_mul = 64, .align_offset = coord_offsets[1] % 64 ),
      };

      /* If x of the aabb min is NaN, then this is an inactive aabb.
       * We don't need to care about any other components being NaN as that is UB.
       * https://www.khronos.org/registry/vulkan/specs/1.2-extensions/html/chap36.html#VkAabbPositionsKHR */
      nir_ssa_def *min_x = nir_channel(b, node_coords[0], 0);
      nir_ssa_def *min_x_is_not_nan = nir_inot(b, nir_fneu(b, min_x, min_x)); /* NaN != NaN -> true */

      /* vec3 bound0 = (node->coords[i][0] - origin) * inv_dir; */
      nir_ssa_def *bound0 = nir_fmul(b, nir_fsub(b, node_coords[0], origin), inv_dir);
      /* vec3 bound1 = (node->coords[i][1] - origin) * inv_dir; */
      nir_ssa_def *bound1 = nir_fmul(b, nir_fsub(b, node_coords[1], origin), inv_dir);

      /* float tmin = max(max(min(bound0.x, bound1.x), min(bound0.y, bound1.y)), min(bound0.z, bound1.z)); */
      nir_ssa_def *tmin =
         nir_fmax(b,
                  nir_fmax(b,
                           nir_fmin(b, nir_channel(b, bound0, 0), nir_channel(b, bound1, 0)),
                           nir_fmin(b, nir_channel(b, bound0, 1), nir_channel(b, bound1, 1))),
                  nir_fmin(b, nir_channel(b, bound0, 2), nir_channel(b, bound1, 2)));

      /* float tmax = min(min(max(bound0.x, bound1.x), max(bound0.y, bound1.y)), max(bound0.z, bound1.z)); */
      nir_ssa_def *tmax =
         nir_fmin(b,
                  nir_fmin(b,
                           nir_fmax(b, nir_channel(b, bound0, 0), nir_channel(b, bound1, 0)),
                           nir_fmax(b, nir_channel(b, bound0, 1), nir_channel(b, bound1, 1))),
                  nir_fmax(b, nir_channel(b, bound0, 2), nir_channel(b, bound1, 2)));

      /* if (!isnan(node->coords[i][0].x) && tmax >= max(0.0f, tmin) && tmin < ray_tmax) { */
      nir_push_if(b,
                  nir_iand(b, min_x_is_not_nan,
                           nir_iand(b,
                                    nir_fge(b, tmax, nir_fmax(b, nir_imm_float(b, 0.0f), tmin)),
                                    nir_flt(b, tmin, ray_tmax))));
      {
         /* Only channel i is written (write mask 1u << i); the value is
          * replicated to all four channels of the temporary vector.
          */
         /* child_indices[i] = node->children[i]; */
         nir_ssa_def *new_child_indices[4] = {child_index, child_index, child_index, child_index};
         nir_store_var(b, child_indices, nir_vec(b, new_child_indices, 4), 1u << i);

         /* distances[i] = tmin; */
         nir_ssa_def *new_distances[4] = {tmin, tmin, tmin, tmin};
         nir_store_var(b, distances, nir_vec(b, new_distances, 4), 1u << i);
      }
      /* } */
      nir_pop_if(b, NULL);
   }

   /* Sort our distances with a sorting network.
    */
   nir_sort_hit_pair(b, distances, child_indices, 0, 1);
   nir_sort_hit_pair(b, distances, child_indices, 2, 3);
   nir_sort_hit_pair(b, distances, child_indices, 0, 2);
   nir_sort_hit_pair(b, distances, child_indices, 1, 3);
   nir_sort_hit_pair(b, distances, child_indices, 1, 2);

   return nir_load_var(b, child_indices);
}

/* Software ray/triangle test for a triangle node.
 *
 * bvh_node: node id of the triangle node.
 * origin, dir: ray origin and direction (ray_tmax and inv_dir are not used by
 *              this function).
 *
 * Returns a vec4. On a miss it is (INFINITY, 1, 0, 0). On a hit it is
 * (t, det, v, w), where t and the edge functions v, w are not yet divided by
 * det.
 */
static nir_ssa_def *
intersect_ray_amd_software_tri(struct radv_device *device, nir_builder *b, nir_ssa_def *bvh_node,
                               nir_ssa_def *ray_tmax, nir_ssa_def *origin, nir_ssa_def *dir,
                               nir_ssa_def *inv_dir)
{
   const struct glsl_type *vec4_type = glsl_vector_type(GLSL_TYPE_FLOAT, 4);

   nir_ssa_def *node_addr = build_node_to_addr(device, b, bvh_node);

   const uint32_t coord_offsets[3] = {
      offsetof(struct radv_bvh_triangle_node, coords[0]),
      offsetof(struct radv_bvh_triangle_node, coords[1]),
      offsetof(struct radv_bvh_triangle_node, coords[2]),
   };

   /* node->coords[0], node->coords[1], node->coords[2] -> vec3 */
   nir_ssa_def *node_coords[3] = {
      nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, coord_offsets[0])),
                            .align_mul = 64, .align_offset = coord_offsets[0] % 64 ),
      nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, coord_offsets[1])),
                            .align_mul = 64, .align_offset = coord_offsets[1] % 64 ),
      nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, coord_offsets[2])),
                            .align_mul = 64, .align_offset = coord_offsets[2] % 64 ),
   };

   /* Default result: a miss. */
   nir_variable *result = nir_variable_create(b->shader, nir_var_shader_temp, vec4_type, "result");
   nir_store_var(b, result, nir_imm_vec4(b, INFINITY, 1.0f, 0.0f, 0.0f), 0xf);

   /* Based on watertight Ray/Triangle intersection from
    * http://jcgt.org/published/0002/01/05/paper.pdf */

   /* Calculate the dimension where the ray direction is largest */
   nir_ssa_def *abs_dir = nir_fabs(b, dir);

   nir_ssa_def *abs_dirs[3] = {
      nir_channel(b, abs_dir, 0),
      nir_channel(b, abs_dir, 1),
      nir_channel(b, abs_dir, 2),
   };
   /* Find index of greatest value of abs_dir and put that as kz.
*/
   nir_ssa_def *kz = nir_bcsel(b, nir_fge(b, abs_dirs[0], abs_dirs[1]),
                               nir_bcsel(b, nir_fge(b, abs_dirs[0], abs_dirs[2]),
                                         nir_imm_int(b, 0), nir_imm_int(b, 2)),
                               nir_bcsel(b, nir_fge(b, abs_dirs[1], abs_dirs[2]),
                                         nir_imm_int(b, 1), nir_imm_int(b, 2)));
   /* kx and ky are the two remaining axes, in cyclic order after kz. */
   nir_ssa_def *kx = nir_imod(b, nir_iadd(b, kz, nir_imm_int(b, 1)), nir_imm_int(b, 3));
   nir_ssa_def *ky = nir_imod(b, nir_iadd(b, kx, nir_imm_int(b, 1)), nir_imm_int(b, 3));
   nir_ssa_def *k_indices[3] = { kx, ky, kz };
   nir_ssa_def *k = nir_vec(b, k_indices, 3);

   /* Swap kx and ky dimensions to preserve winding order */
   unsigned swap_xy_swizzle[4] = {1, 0, 2, 3};
   k = nir_bcsel(b, nir_flt(b, nir_vector_extract(b, dir, kz), nir_imm_float(b, 0.0f)),
                 nir_swizzle(b, k, swap_xy_swizzle, 3), k);

   kx = nir_channel(b, k, 0);
   ky = nir_channel(b, k, 1);
   kz = nir_channel(b, k, 2);

   /* Calculate shear constants */
   nir_ssa_def *sz = nir_frcp(b, nir_vector_extract(b, dir, kz));
   nir_ssa_def *sx = nir_fmul(b, nir_vector_extract(b, dir, kx), sz);
   nir_ssa_def *sy = nir_fmul(b, nir_vector_extract(b, dir, ky), sz);

   /* Calculate vertices relative to ray origin */
   nir_ssa_def *v_a = nir_fsub(b, node_coords[0], origin);
   nir_ssa_def *v_b = nir_fsub(b, node_coords[1], origin);
   nir_ssa_def *v_c = nir_fsub(b, node_coords[2], origin);

   /* Perform shear and scale */
   nir_ssa_def *ax =
      nir_fsub(b, nir_vector_extract(b, v_a, kx), nir_fmul(b, sx, nir_vector_extract(b, v_a, kz)));
   nir_ssa_def *ay =
      nir_fsub(b, nir_vector_extract(b, v_a, ky), nir_fmul(b, sy, nir_vector_extract(b, v_a, kz)));
   nir_ssa_def *bx =
      nir_fsub(b, nir_vector_extract(b, v_b, kx), nir_fmul(b, sx, nir_vector_extract(b, v_b, kz)));
   nir_ssa_def *by =
      nir_fsub(b, nir_vector_extract(b, v_b, ky), nir_fmul(b, sy, nir_vector_extract(b, v_b, kz)));
   nir_ssa_def *cx =
      nir_fsub(b, nir_vector_extract(b, v_c, kx), nir_fmul(b, sx, nir_vector_extract(b, v_c, kz)));
   nir_ssa_def *cy =
      nir_fsub(b, nir_vector_extract(b, v_c, ky), nir_fmul(b, sy, nir_vector_extract(b, v_c, kz)));

   /* Edge functions of the sheared triangle, first in single precision. */
   nir_ssa_def *u = nir_fsub(b,
nir_fmul(b, cx, by), nir_fmul(b, cy, bx));
   nir_ssa_def *v = nir_fsub(b, nir_fmul(b, ax, cy), nir_fmul(b, ay, cx));
   nir_ssa_def *w = nir_fsub(b, nir_fmul(b, bx, ay), nir_fmul(b, by, ax));

   /* u/v/w go through variables because the double-precision fallback below
    * conditionally overwrites them.
    */
   nir_variable *u_var = nir_variable_create(b->shader, nir_var_shader_temp, glsl_float_type(), "u");
   nir_variable *v_var = nir_variable_create(b->shader, nir_var_shader_temp, glsl_float_type(), "v");
   nir_variable *w_var = nir_variable_create(b->shader, nir_var_shader_temp, glsl_float_type(), "w");
   nir_store_var(b, u_var, u, 0x1);
   nir_store_var(b, v_var, v, 0x1);
   nir_store_var(b, w_var, w, 0x1);

   /* Fallback to testing edges with double precision...
    *
    * The Vulkan spec states it only needs single precision watertightness
    * but we fail dEQP-VK.ray_tracing_pipeline.watertightness.closedFan2.1024 with
    * failures = 1 without doing this. :(
    */
   /* Only redo the computation when one of the edge functions is exactly 0. */
   nir_ssa_def *cond_retest = nir_ior(b,
                                      nir_ior(b, nir_feq(b, u, nir_imm_float(b, 0.0f)),
                                              nir_feq(b, v, nir_imm_float(b, 0.0f))),
                                      nir_feq(b, w, nir_imm_float(b, 0.0f)));

   nir_push_if(b, cond_retest);
   {
      ax = nir_f2f64(b, ax);
      ay = nir_f2f64(b, ay);
      bx = nir_f2f64(b, bx);
      by = nir_f2f64(b, by);
      cx = nir_f2f64(b, cx);
      cy = nir_f2f64(b, cy);

      nir_store_var(b, u_var, nir_f2f32(b, nir_fsub(b, nir_fmul(b, cx, by), nir_fmul(b, cy, bx))),
                    0x1);
      nir_store_var(b, v_var, nir_f2f32(b, nir_fsub(b, nir_fmul(b, ax, cy), nir_fmul(b, ay, cx))),
                    0x1);
      nir_store_var(b, w_var, nir_f2f32(b, nir_fsub(b, nir_fmul(b, bx, ay), nir_fmul(b, by, ax))),
                    0x1);
   }
   nir_pop_if(b, NULL);

   u = nir_load_var(b, u_var);
   v = nir_load_var(b, v_var);
   w = nir_load_var(b, w_var);

   /* Perform edge tests.
    */
   /* The ray can only hit when u, v and w do not have mixed signs: cond is
    * false exactly when at least one is negative and at least one is positive.
    */
   nir_ssa_def *cond_back = nir_ior(b,
                                    nir_ior(b, nir_flt(b, u, nir_imm_float(b, 0.0f)),
                                            nir_flt(b, v, nir_imm_float(b, 0.0f))),
                                    nir_flt(b, w, nir_imm_float(b, 0.0f)));

   nir_ssa_def *cond_front = nir_ior(b,
                                     nir_ior(b, nir_flt(b, nir_imm_float(b, 0.0f), u),
                                             nir_flt(b, nir_imm_float(b, 0.0f), v)),
                                     nir_flt(b, nir_imm_float(b, 0.0f), w));

   nir_ssa_def *cond = nir_inot(b, nir_iand(b, cond_back, cond_front));

   nir_push_if(b, cond);
   {
      nir_ssa_def *det = nir_fadd(b, u, nir_fadd(b, v, w));

      nir_ssa_def *az = nir_fmul(b, sz, nir_vector_extract(b, v_a, kz));
      nir_ssa_def *bz = nir_fmul(b, sz, nir_vector_extract(b, v_b, kz));
      nir_ssa_def *cz = nir_fmul(b, sz, nir_vector_extract(b, v_c, kz));

      /* Hit distance scaled by det. */
      nir_ssa_def *t =
         nir_fadd(b, nir_fadd(b, nir_fmul(b, u, az), nir_fmul(b, v, bz)), nir_fmul(b, w, cz));

      /* Reject hits whose distance, once the sign of det is taken into
       * account, is negative.
       */
      nir_ssa_def *t_signed = nir_fmul(b, nir_fsign(b, det), t);

      nir_ssa_def *det_cond_front = nir_inot(b, nir_flt(b, t_signed, nir_imm_float(b, 0.0f)));

      nir_push_if(b, det_cond_front);
      {
         nir_ssa_def *indices[4] = { t, det, v, w };
         nir_store_var(b, result, nir_vec(b, indices, 4), 0xf);
      }
      nir_pop_if(b, NULL);
   }
   nir_pop_if(b, NULL);

   return nir_load_var(b, result);
}

static void
insert_traversal(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                 nir_builder *b, const struct rt_variables *vars)
{
   unsigned stack_entry_size = 4;
   unsigned lanes = b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] *
                    b->shader->info.workgroup_size[2];
   unsigned stack_entry_stride = stack_entry_size * lanes;
   nir_ssa_def *stack_entry_stride_def = nir_imm_int(b, stack_entry_stride);
   nir_ssa_def *stack_base =
      nir_iadd(b, nir_imm_int(b, b->shader->info.shared_size),
               nir_imul(b, nir_load_subgroup_invocation(b), nir_imm_int(b, stack_entry_size)));

   /*
    * A top-level AS can contain 2^24 children and a bottom-level AS can contain 2^24 triangles. At
    * a branching factor of 4, that means we may need up to 24 levels of box nodes + 1 triangle node
    * + 1 instance node.
Furthermore, when processing a box node, worst case we actually push all 4 * children and remove one, so the DFS stack depth is box nodes * 3 + 2. */ b->shader->info.shared_size += stack_entry_stride * 76; assert(b->shader->info.shared_size <= 32768); nir_ssa_def *accel_struct = nir_load_var(b, vars->accel_struct); struct rt_traversal_vars trav_vars = init_traversal_vars(b); /* Initialize the follow-up shader idx to 0, to be replaced by the miss shader * if we actually miss. */ nir_store_var(b, vars->idx, nir_imm_int(b, 0), 1); nir_store_var(b, trav_vars.should_return, nir_imm_bool(b, false), 1); nir_push_if(b, nir_ine(b, accel_struct, nir_imm_int64(b, 0))); { nir_store_var(b, trav_vars.bvh_base, build_addr_to_node(b, accel_struct), 1); nir_ssa_def *bvh_root = nir_build_load_global(b, 1, 32, accel_struct, .access = ACCESS_NON_WRITEABLE, .align_mul = 64, .align_offset = 0); /* We create a BVH descriptor that covers the entire memory range. That way we can always * use the same descriptor, which avoids divergence when different rays hit different * instances at the cost of having to use 64-bit node ids. 
*/ const uint64_t bvh_size = 1ull << 42; nir_ssa_def *desc = nir_imm_ivec4( b, 0, 1u << 31 /* Enable box sorting */, (bvh_size - 1) & 0xFFFFFFFFu, ((bvh_size - 1) >> 32) | (1u << 24 /* Return IJ for triangles */) | (1u << 31)); nir_ssa_def *vec3ones = nir_channels(b, nir_imm_vec4(b, 1.0, 1.0, 1.0, 1.0), 0x7); nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7); nir_store_var(b, trav_vars.dir, nir_load_var(b, vars->direction), 7); nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7); nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_imm_int(b, 0), 1); nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1); nir_store_var(b, trav_vars.stack, nir_iadd(b, stack_base, stack_entry_stride_def), 1); nir_store_shared(b, bvh_root, stack_base, .base = 0, .write_mask = 0x1, .align_mul = stack_entry_size, .align_offset = 0); nir_store_var(b, trav_vars.top_stack, nir_imm_int(b, 0), 1); nir_push_loop(b); nir_push_if(b, nir_ieq(b, nir_load_var(b, trav_vars.stack), stack_base)); nir_jump(b, nir_jump_break); nir_pop_if(b, NULL); nir_push_if( b, nir_uge(b, nir_load_var(b, trav_vars.top_stack), nir_load_var(b, trav_vars.stack))); nir_store_var(b, trav_vars.top_stack, nir_imm_int(b, 0), 1); nir_store_var(b, trav_vars.bvh_base, build_addr_to_node(b, nir_load_var(b, vars->accel_struct)), 1); nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7); nir_store_var(b, trav_vars.dir, nir_load_var(b, vars->direction), 7); nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7); nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1); nir_pop_if(b, NULL); nir_store_var(b, trav_vars.stack, nir_isub(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def), 1); nir_ssa_def *bvh_node = nir_load_shared(b, 1, 32, nir_load_var(b, trav_vars.stack), .base = 0, .align_mul = stack_entry_size, .align_offset = 0); nir_ssa_def *bvh_node_type = nir_iand(b, bvh_node, 
nir_imm_int(b, 7)); bvh_node = nir_iadd(b, nir_load_var(b, trav_vars.bvh_base), nir_u2u(b, bvh_node, 64)); nir_ssa_def *intrinsic_result = NULL; if (device->physical_device->rad_info.chip_class >= GFX10_3 && !(device->instance->perftest_flags & RADV_PERFTEST_FORCE_EMULATE_RT)) { intrinsic_result = nir_bvh64_intersect_ray_amd( b, 32, desc, nir_unpack_64_2x32(b, bvh_node), nir_load_var(b, vars->tmax), nir_load_var(b, trav_vars.origin), nir_load_var(b, trav_vars.dir), nir_load_var(b, trav_vars.inv_dir)); } nir_push_if(b, nir_ine(b, nir_iand(b, bvh_node_type, nir_imm_int(b, 4)), nir_imm_int(b, 0))); { nir_push_if(b, nir_ine(b, nir_iand(b, bvh_node_type, nir_imm_int(b, 2)), nir_imm_int(b, 0))); { /* custom */ nir_push_if( b, nir_ine(b, nir_iand(b, bvh_node_type, nir_imm_int(b, 1)), nir_imm_int(b, 0))); { insert_traversal_aabb_case(device, pCreateInfo, b, vars, &trav_vars, bvh_node); } nir_push_else(b, NULL); { /* instance */ nir_ssa_def *instance_node_addr = build_node_to_addr(device, b, bvh_node); nir_ssa_def *instance_data = nir_build_load_global( b, 4, 32, instance_node_addr, .align_mul = 64, .align_offset = 0); nir_ssa_def *wto_matrix[] = { nir_build_load_global(b, 4, 32, nir_iadd(b, instance_node_addr, nir_imm_int64(b, 16)), .align_mul = 64, .align_offset = 16), nir_build_load_global(b, 4, 32, nir_iadd(b, instance_node_addr, nir_imm_int64(b, 32)), .align_mul = 64, .align_offset = 32), nir_build_load_global(b, 4, 32, nir_iadd(b, instance_node_addr, nir_imm_int64(b, 48)), .align_mul = 64, .align_offset = 48)}; nir_ssa_def *instance_id = nir_build_load_global( b, 1, 32, nir_iadd(b, instance_node_addr, nir_imm_int64(b, 88)), .align_mul = 4, .align_offset = 0); nir_ssa_def *instance_and_mask = nir_channel(b, instance_data, 2); nir_ssa_def *instance_mask = nir_ushr(b, instance_and_mask, nir_imm_int(b, 24)); nir_push_if(b, nir_ieq(b, nir_iand(b, instance_mask, nir_load_var(b, vars->cull_mask)), nir_imm_int(b, 0))); nir_jump(b, nir_jump_continue); nir_pop_if(b, NULL); 
nir_store_var(b, trav_vars.top_stack, nir_load_var(b, trav_vars.stack), 1); nir_store_var(b, trav_vars.bvh_base, build_addr_to_node( b, nir_pack_64_2x32(b, nir_channels(b, instance_data, 0x3))), 1); nir_store_shared(b, nir_iand(b, nir_channel(b, instance_data, 0), nir_imm_int(b, 63)), nir_load_var(b, trav_vars.stack), .base = 0, .write_mask = 0x1, .align_mul = stack_entry_size, .align_offset = 0); nir_store_var(b, trav_vars.stack, nir_iadd(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def), 1); nir_store_var( b, trav_vars.origin, nir_build_vec3_mat_mult_pre(b, nir_load_var(b, vars->origin), wto_matrix), 7); nir_store_var( b, trav_vars.dir, nir_build_vec3_mat_mult(b, nir_load_var(b, vars->direction), wto_matrix, false), 7); nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7); nir_store_var(b, trav_vars.custom_instance_and_mask, instance_and_mask, 1); nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_channel(b, instance_data, 3), 1); nir_store_var(b, trav_vars.instance_id, instance_id, 1); nir_store_var(b, trav_vars.instance_addr, instance_node_addr, 1); } nir_pop_if(b, NULL); } nir_push_else(b, NULL); { /* box */ nir_ssa_def *result = intrinsic_result; if (!result) { /* If we didn't run the intrinsic cause the hardware didn't support it, * emulate ray/box intersection here */ result = intersect_ray_amd_software_box(device, b, bvh_node, nir_load_var(b, vars->tmax), nir_load_var(b, trav_vars.origin), nir_load_var(b, trav_vars.dir), nir_load_var(b, trav_vars.inv_dir)); } for (unsigned i = 4; i-- > 0; ) { nir_ssa_def *new_node = nir_vector_extract(b, result, nir_imm_int(b, i)); nir_push_if(b, nir_ine(b, new_node, nir_imm_int(b, 0xffffffff))); { nir_store_shared(b, new_node, nir_load_var(b, trav_vars.stack), .base = 0, .write_mask = 0x1, .align_mul = stack_entry_size, .align_offset = 0); nir_store_var( b, trav_vars.stack, nir_iadd(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def), 1); } nir_pop_if(b, 
NULL); } } nir_pop_if(b, NULL); } nir_push_else(b, NULL); { nir_ssa_def *result = intrinsic_result; if (!result) { /* If we didn't run the intrinsic cause the hardware didn't support it, * emulate ray/tri intersection here */ result = intersect_ray_amd_software_tri(device, b, bvh_node, nir_load_var(b, vars->tmax), nir_load_var(b, trav_vars.origin), nir_load_var(b, trav_vars.dir), nir_load_var(b, trav_vars.inv_dir)); } insert_traversal_triangle_case(device, pCreateInfo, b, result, vars, &trav_vars, bvh_node); } nir_pop_if(b, NULL); nir_pop_loop(b, NULL); } nir_pop_if(b, NULL); /* should_return is set if we had a hit but we won't be calling the closest hit shader and hence * need to return immediately to the calling shader. */ nir_push_if(b, nir_load_var(b, trav_vars.should_return)); { insert_rt_return(b, vars); } nir_push_else(b, NULL); { /* Only load the miss shader if we actually miss, which we determining by not having set * a closest hit shader. It is valid to not specify an SBT pointer for miss shaders if none * of the rays miss. 
*/ nir_push_if(b, nir_ieq(b, nir_load_var(b, vars->idx), nir_imm_int(b, 0))); { load_sbt_entry(b, vars, nir_load_var(b, vars->miss_index), SBT_MISS, 0); } nir_pop_if(b, NULL); } nir_pop_if(b, NULL); } static unsigned compute_rt_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, const struct radv_pipeline_shader_stack_size *stack_sizes) { unsigned raygen_size = 0; unsigned callable_size = 0; unsigned chit_size = 0; unsigned miss_size = 0; unsigned non_recursive_size = 0; for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) { non_recursive_size = MAX2(stack_sizes[i].non_recursive_size, non_recursive_size); const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i]; uint32_t shader_id = VK_SHADER_UNUSED_KHR; unsigned size = stack_sizes[i].recursive_size; switch (group_info->type) { case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR: shader_id = group_info->generalShader; break; case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR: case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR: shader_id = group_info->closestHitShader; break; default: break; } if (shader_id == VK_SHADER_UNUSED_KHR) continue; const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id]; switch (stage->stage) { case VK_SHADER_STAGE_RAYGEN_BIT_KHR: raygen_size = MAX2(raygen_size, size); break; case VK_SHADER_STAGE_MISS_BIT_KHR: miss_size = MAX2(miss_size, size); break; case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR: chit_size = MAX2(chit_size, size); break; case VK_SHADER_STAGE_CALLABLE_BIT_KHR: callable_size = MAX2(callable_size, size); break; default: unreachable("Invalid stage type in RT shader"); } } return raygen_size + MIN2(pCreateInfo->maxPipelineRayRecursionDepth, 1) * MAX2(MAX2(chit_size, miss_size), non_recursive_size) + MAX2(0, (int)(pCreateInfo->maxPipelineRayRecursionDepth) - 1) * MAX2(chit_size, miss_size) + 2 * callable_size; } bool radv_rt_pipeline_has_dynamic_stack_size(const VkRayTracingPipelineCreateInfoKHR 
*pCreateInfo) { if (!pCreateInfo->pDynamicState) return false; for (unsigned i = 0; i < pCreateInfo->pDynamicState->dynamicStateCount; ++i) { if (pCreateInfo->pDynamicState->pDynamicStates[i] == VK_DYNAMIC_STATE_RAY_TRACING_PIPELINE_STACK_SIZE_KHR) return true; } return false; } static nir_shader * create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, struct radv_pipeline_shader_stack_size *stack_sizes) { RADV_FROM_HANDLE(radv_pipeline_layout, layout, pCreateInfo->layout); struct radv_pipeline_key key; memset(&key, 0, sizeof(key)); nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_COMPUTE, NULL, "rt_combined"); b.shader->info.workgroup_size[0] = 8; b.shader->info.workgroup_size[1] = 8; b.shader->info.workgroup_size[2] = 1; struct rt_variables vars = create_rt_variables(b.shader, stack_sizes); load_sbt_entry(&b, &vars, nir_imm_int(&b, 0), SBT_RAYGEN, 0); nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1); nir_store_var(&b, vars.main_loop_case_visited, nir_imm_bool(&b, true), 1); nir_loop *loop = nir_push_loop(&b); nir_push_if(&b, nir_ior(&b, nir_ieq(&b, nir_load_var(&b, vars.idx), nir_imm_int(&b, 0)), nir_ine(&b, nir_load_var(&b, vars.main_loop_case_visited), nir_imm_bool(&b, true)))); nir_jump(&b, nir_jump_break); nir_pop_if(&b, NULL); nir_store_var(&b, vars.main_loop_case_visited, nir_imm_bool(&b, false), 1); nir_push_if(&b, nir_ieq(&b, nir_load_var(&b, vars.idx), nir_imm_int(&b, 1))); nir_store_var(&b, vars.main_loop_case_visited, nir_imm_bool(&b, true), 1); insert_traversal(device, pCreateInfo, &b, &vars); nir_pop_if(&b, NULL); nir_ssa_def *idx = nir_load_var(&b, vars.idx); /* We do a trick with the indexing of the resume shaders so that the first * shader of group x always gets id x and the resume shader ids then come after * groupCount. This makes the shadergroup handles independent of compilation. 
*/ unsigned call_idx_base = pCreateInfo->groupCount + 1; for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) { const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i]; uint32_t shader_id = VK_SHADER_UNUSED_KHR; switch (group_info->type) { case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR: shader_id = group_info->generalShader; break; case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR: case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR: shader_id = group_info->closestHitShader; break; default: break; } if (shader_id == VK_SHADER_UNUSED_KHR) continue; const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id]; nir_shader *nir_stage = parse_rt_stage(device, layout, stage); b.shader->options = nir_stage->options; uint32_t num_resume_shaders = 0; nir_shader **resume_shaders = NULL; nir_lower_shader_calls(nir_stage, nir_address_format_32bit_offset, 16, &resume_shaders, &num_resume_shaders, nir_stage); vars.group_idx = i; insert_rt_case(&b, nir_stage, &vars, idx, call_idx_base, i + 2); for (unsigned j = 0; j < num_resume_shaders; ++j) { insert_rt_case(&b, resume_shaders[j], &vars, idx, call_idx_base, call_idx_base + 1 + j); } call_idx_base += num_resume_shaders; } nir_pop_loop(&b, loop); if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo)) { /* Put something so scratch gets enabled in the shader. */ b.shader->scratch_size = 16; } else b.shader->scratch_size = compute_rt_stack_size(pCreateInfo, stack_sizes); /* Deal with all the inline functions. 
*/ nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader)); nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none); return b.shader; } static VkResult radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, const VkAllocationCallbacks *pAllocator, VkPipeline *pPipeline) { RADV_FROM_HANDLE(radv_device, device, _device); VkResult result; struct radv_pipeline *pipeline = NULL; struct radv_pipeline_shader_stack_size *stack_sizes = NULL; uint8_t hash[20]; nir_shader *shader = NULL; bool keep_statistic_info = (pCreateInfo->flags & VK_PIPELINE_CREATE_CAPTURE_STATISTICS_BIT_KHR) || (device->instance->debug_flags & RADV_DEBUG_DUMP_SHADER_STATS) || device->keep_shader_info; if (pCreateInfo->flags & VK_PIPELINE_CREATE_LIBRARY_BIT_KHR) return radv_rt_pipeline_library_create(_device, _cache, pCreateInfo, pAllocator, pPipeline); VkRayTracingPipelineCreateInfoKHR local_create_info = radv_create_merged_rt_create_info(pCreateInfo); if (!local_create_info.pStages || !local_create_info.pGroups) { result = VK_ERROR_OUT_OF_HOST_MEMORY; goto fail; } radv_hash_rt_shaders(hash, &local_create_info, radv_get_hash_flags(device, keep_statistic_info)); struct vk_shader_module module = {.base.type = VK_OBJECT_TYPE_SHADER_MODULE}; VkComputePipelineCreateInfo compute_info = { .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, .pNext = NULL, .flags = pCreateInfo->flags | VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT_EXT, .stage = { .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, .stage = VK_SHADER_STAGE_COMPUTE_BIT, .module = vk_shader_module_to_handle(&module), .pName = "main", }, .layout = pCreateInfo->layout, }; /* First check if we can get things from the cache before we take the expensive step of * generating the nir. 
*/ result = radv_compute_pipeline_create(_device, _cache, &compute_info, pAllocator, hash, stack_sizes, local_create_info.groupCount, pPipeline); if (result == VK_PIPELINE_COMPILE_REQUIRED_EXT) { stack_sizes = calloc(sizeof(*stack_sizes), local_create_info.groupCount); if (!stack_sizes) { result = VK_ERROR_OUT_OF_HOST_MEMORY; goto fail; } shader = create_rt_shader(device, &local_create_info, stack_sizes); module.nir = shader; compute_info.flags = pCreateInfo->flags; result = radv_compute_pipeline_create(_device, _cache, &compute_info, pAllocator, hash, stack_sizes, local_create_info.groupCount, pPipeline); stack_sizes = NULL; if (result != VK_SUCCESS) goto shader_fail; } pipeline = radv_pipeline_from_handle(*pPipeline); pipeline->compute.rt_group_handles = calloc(sizeof(*pipeline->compute.rt_group_handles), local_create_info.groupCount); if (!pipeline->compute.rt_group_handles) { result = VK_ERROR_OUT_OF_HOST_MEMORY; goto shader_fail; } pipeline->compute.dynamic_stack_size = radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo); for (unsigned i = 0; i < local_create_info.groupCount; ++i) { const VkRayTracingShaderGroupCreateInfoKHR *group_info = &local_create_info.pGroups[i]; switch (group_info->type) { case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR: if (group_info->generalShader != VK_SHADER_UNUSED_KHR) pipeline->compute.rt_group_handles[i].handles[0] = i + 2; break; case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR: if (group_info->intersectionShader != VK_SHADER_UNUSED_KHR) pipeline->compute.rt_group_handles[i].handles[1] = i + 2; FALLTHROUGH; case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR: if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR) pipeline->compute.rt_group_handles[i].handles[0] = i + 2; if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR) pipeline->compute.rt_group_handles[i].handles[1] = i + 2; break; case VK_SHADER_GROUP_SHADER_MAX_ENUM_KHR: unreachable("VK_SHADER_GROUP_SHADER_MAX_ENUM_KHR"); } } shader_fail: if 
(result != VK_SUCCESS && pipeline) radv_pipeline_destroy(device, pipeline, pAllocator); ralloc_free(shader); fail: free((void *)local_create_info.pGroups); free((void *)local_create_info.pStages); free(stack_sizes); return result; } VkResult radv_CreateRayTracingPipelinesKHR(VkDevice _device, VkDeferredOperationKHR deferredOperation, VkPipelineCache pipelineCache, uint32_t count, const VkRayTracingPipelineCreateInfoKHR *pCreateInfos, const VkAllocationCallbacks *pAllocator, VkPipeline *pPipelines) { VkResult result = VK_SUCCESS; unsigned i = 0; for (; i < count; i++) { VkResult r; r = radv_rt_pipeline_create(_device, pipelineCache, &pCreateInfos[i], pAllocator, &pPipelines[i]); if (r != VK_SUCCESS) { result = r; pPipelines[i] = VK_NULL_HANDLE; if (pCreateInfos[i].flags & VK_PIPELINE_CREATE_EARLY_RETURN_ON_FAILURE_BIT_EXT) break; } } for (; i < count; ++i) pPipelines[i] = VK_NULL_HANDLE; return result; } VkResult radv_GetRayTracingShaderGroupHandlesKHR(VkDevice device, VkPipeline _pipeline, uint32_t firstGroup, uint32_t groupCount, size_t dataSize, void *pData) { RADV_FROM_HANDLE(radv_pipeline, pipeline, _pipeline); char *data = pData; STATIC_ASSERT(sizeof(*pipeline->compute.rt_group_handles) <= RADV_RT_HANDLE_SIZE); memset(data, 0, groupCount * RADV_RT_HANDLE_SIZE); for (uint32_t i = 0; i < groupCount; ++i) { memcpy(data + i * RADV_RT_HANDLE_SIZE, &pipeline->compute.rt_group_handles[firstGroup + i], sizeof(*pipeline->compute.rt_group_handles)); } return VK_SUCCESS; } VkDeviceSize radv_GetRayTracingShaderGroupStackSizeKHR(VkDevice device, VkPipeline _pipeline, uint32_t group, VkShaderGroupShaderKHR groupShader) { RADV_FROM_HANDLE(radv_pipeline, pipeline, _pipeline); const struct radv_pipeline_shader_stack_size *stack_size = &pipeline->compute.rt_stack_sizes[group]; if (groupShader == VK_SHADER_GROUP_SHADER_ANY_HIT_KHR || groupShader == VK_SHADER_GROUP_SHADER_INTERSECTION_KHR) return stack_size->non_recursive_size; else return stack_size->recursive_size; }