diff --git a/source/adapters/level_zero/v2/memory.cpp b/source/adapters/level_zero/v2/memory.cpp index 3e885d5f6c..0ea30e8605 100644 --- a/source/adapters/level_zero/v2/memory.cpp +++ b/source/adapters/level_zero/v2/memory.cpp @@ -209,7 +209,7 @@ ur_discrete_mem_handle_t::ur_discrete_mem_handle_t( device_access_mode_t accessMode) : ur_mem_handle_t_(hContext, size, accessMode), deviceAllocations(hContext->getPlatform()->getNumDevices()), - activeAllocationDevice(nullptr), hostAllocations() { + activeAllocationDevice(nullptr), mapToPtr(hostPtr), hostAllocations() { if (hostPtr) { auto initialDevice = hContext->getDevices()[0]; UR_CALL_THROWS(migrateBufferTo(initialDevice, hostPtr, size)); @@ -299,21 +299,30 @@ void *ur_discrete_mem_handle_t::mapHostPtr( TRACK_SCOPE_LATENCY("ur_discrete_mem_handle_t::mapHostPtr"); // TODO: use async alloc? - void *ptr; - UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate( - hContext, nullptr, nullptr, UR_USM_TYPE_HOST, size, &ptr)); + void *ptr = mapToPtr; + if (!ptr) { + UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate( + hContext, nullptr, nullptr, UR_USM_TYPE_HOST, size, &ptr)); + } + + usm_unique_ptr_t mappedPtr = + usm_unique_ptr_t(ptr, [ownsAlloc = bool(mapToPtr), this](void *p) { + if (ownsAlloc) { + UR_CALL_THROWS(hContext->getDefaultUSMPool()->free(p)); + } + }); - hostAllocations.emplace_back(ptr, size, offset, flags); + hostAllocations.emplace_back(std::move(mappedPtr), size, offset, flags); if (activeAllocationDevice && (flags & UR_MAP_FLAG_READ)) { auto srcPtr = ur_cast( deviceAllocations[activeAllocationDevice->Id.value()].get()) + offset; - migrate(srcPtr, hostAllocations.back().ptr, size); + migrate(srcPtr, hostAllocations.back().ptr.get(), size); } - return hostAllocations.back().ptr; + return hostAllocations.back().ptr.get(); } void ur_discrete_mem_handle_t::unmapHostPtr( @@ -322,7 +331,7 @@ void ur_discrete_mem_handle_t::unmapHostPtr( TRACK_SCOPE_LATENCY("ur_discrete_mem_handle_t::unmapHostPtr"); for (auto &hostAllocation : hostAllocations) { - if (hostAllocation.ptr == pMappedPtr) { + if (hostAllocation.ptr.get() == pMappedPtr) { void *devicePtr = nullptr; if (activeAllocationDevice) { devicePtr = @@ -337,11 +346,9 @@ void ur_discrete_mem_handle_t::unmapHostPtr( } if (devicePtr) { - migrate(hostAllocation.ptr, devicePtr, hostAllocation.size); + migrate(hostAllocation.ptr.get(), devicePtr, hostAllocation.size); } - // TODO: use async free here? - UR_CALL_THROWS(hContext->getDefaultUSMPool()->free(hostAllocation.ptr)); return; } } diff --git a/source/adapters/level_zero/v2/memory.hpp b/source/adapters/level_zero/v2/memory.hpp index 575a313e14..1067389280 100644 --- a/source/adapters/level_zero/v2/memory.hpp +++ b/source/adapters/level_zero/v2/memory.hpp @@ -98,11 +98,11 @@ struct ur_integrated_mem_handle_t : public ur_mem_handle_t_ { }; struct host_allocation_desc_t { - host_allocation_desc_t(void *ptr, size_t size, size_t offset, + host_allocation_desc_t(usm_unique_ptr_t ptr, size_t size, size_t offset, ur_map_flags_t flags) - : ptr(ptr), size(size), offset(offset), flags(flags) {} + : ptr(std::move(ptr)), size(size), offset(offset), flags(flags) {} - void *ptr; + usm_unique_ptr_t ptr; size_t size; size_t offset; ur_map_flags_t flags; @@ -146,6 +146,9 @@ struct ur_discrete_mem_handle_t : public ur_mem_handle_t_ { // If not null, copy the buffer content back to this memory on release. void *writeBackPtr = nullptr; + // If not null, mapHostPtr should map memory to this ptr + void *mapToPtr = nullptr; + std::vector hostAllocations; void *allocateOnDevice(ur_device_handle_t hDevice, size_t size); diff --git a/test/conformance/enqueue/urEnqueueMemBufferMap.cpp b/test/conformance/enqueue/urEnqueueMemBufferMap.cpp index eb06724139..1f0404ca6d 100644 --- a/test/conformance/enqueue/urEnqueueMemBufferMap.cpp +++ b/test/conformance/enqueue/urEnqueueMemBufferMap.cpp @@ -212,6 +212,26 @@ TEST_P(urEnqueueMemBufferMapTestWithParam, SuccessMultiMaps) { } } +TEST_P(urEnqueueMemBufferMapTestWithParam, SUCCESS) { + uur::raii::Mem buffer = nullptr; + + void *ptr = new char[4096]; + + ur_buffer_properties_t props; + props.pHost = ptr; + + ASSERT_SUCCESS(urMemBufferCreate(context, 0, 4096, &props, buffer.ptr())); + + void *mappedPtr = nullptr; + ASSERT_SUCCESS(urEnqueueMemBufferMap( + queue, buffer.get(), true, UR_MAP_FLAG_READ | UR_MAP_FLAG_WRITE, 0, + size, 0, nullptr, nullptr, &mappedPtr)); + + ASSERT_EQ(ptr, mappedPtr); + + delete ptr; +} + TEST_P(urEnqueueMemBufferMapTestWithParam, InvalidNullHandleQueue) { void *map = nullptr; ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_HANDLE,