Skip to content

Commit

Permalink
feat: add descriptor binding logic
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Pollind <[email protected]>
  • Loading branch information
pollend committed Dec 1, 2024
1 parent 4b283c0 commit b4079f3
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 48 deletions.
1 change: 0 additions & 1 deletion Source/Metal/DescriptorMTL.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ struct DescriptorMTL {
Result Create(const Texture3DViewDesc& textureViewDesc);
Result Create(const SamplerDesc& samplerDesc);


private:

DeviceMTL& m_Device;
Expand Down
14 changes: 6 additions & 8 deletions Source/Metal/DescriptorPoolMTL.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,29 @@ struct DescriptorPoolMTL {
, m_AllocatedSets(device.GetStdAllocator()) {
m_AllocatedSets.reserve(64);
}
//
// inline operator VkDescriptorPool() const {
// return m_Handle;
// }
//

inline DeviceMTL& GetDevice() const {
return m_Device;
}

~DescriptorPoolMTL();

Result Create(const DescriptorPoolDesc& descriptorPoolDesc);

//================================================================================================================
// NRI
//================================================================================================================


// size_t GetNumberOfArugmentsAlloc();
void SetDebugName(const char* name);
void Reset();
Result AllocateDescriptorSets(const PipelineLayout& pipelineLayout, uint32_t setIndex, DescriptorSet** descriptorSets, uint32_t instanceNum, uint32_t variableDescriptorNum);

private:
DeviceMTL& m_Device;
size_t m_ArgumentOffset = 0;
Vector<DescriptorSetMTL*> m_AllocatedSets;
//VkDescriptorPool m_Handle = VK_NULL_HANDLE;
id<MTLBuffer> m_ArgumentBuffer;

uint32_t m_UsedSets = 0;
bool m_OwnsNativeObjects = true;
};
Expand Down
45 changes: 44 additions & 1 deletion Source/Metal/DescriptorPoolMTL.mm
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@

#include "SharedMTL.h"
#include "DescriptorPoolMTL.h"
#include "DescriptorSetMTL.h"
#include "PipelineLayoutMTL.h"

using namespace nri;

Expand All @@ -10,9 +12,25 @@
}

Result DescriptorPoolMTL::Create(const DescriptorPoolDesc& descriptorPoolDesc) {

size_t numArgs = descriptorPoolDesc.samplerMaxNum +
descriptorPoolDesc.constantBufferMaxNum +
descriptorPoolDesc.dynamicConstantBufferMaxNum +
descriptorPoolDesc.textureMaxNum +
descriptorPoolDesc.storageTextureMaxNum +
descriptorPoolDesc.bufferMaxNum +
descriptorPoolDesc.storageBufferMaxNum +
descriptorPoolDesc.structuredBufferMaxNum +
descriptorPoolDesc.accelerationStructureMaxNum;

m_ArgumentBuffer = [m_Device
newBufferWithLength: numArgs * sizeof(uint32_t) options:MTLResourceStorageModeShared];
}


//size_t DescriptorPoolMTL::GetNumberOfArugmentsAlloc() {
// return [m_ArgumentBuffer length] / sizeof(uint32_t);
//}

//================================================================================================================
// NRI
//================================================================================================================
Expand All @@ -23,7 +41,32 @@
void DescriptorPoolMTL::Reset() {

}

Result DescriptorPoolMTL::AllocateDescriptorSets(const PipelineLayout& pipelineLayout, uint32_t setIndex, DescriptorSet** descriptorSets, uint32_t instanceNum, uint32_t variableDescriptorNum) {
PipelineLayoutMTL* pipelineLayoutMTL = (PipelineLayoutMTL*)&pipelineLayout;

uint32_t freeSetNum = (uint32_t)m_AllocatedSets.size() - m_UsedSets;
if (freeSetNum < instanceNum) {
uint32_t newSetNum = instanceNum - freeSetNum;
uint32_t prevSetNum = (uint32_t)m_AllocatedSets.size();
m_AllocatedSets.resize(prevSetNum + newSetNum);
for (size_t i = 0; i < newSetNum; i++) {
Construct(m_AllocatedSets[prevSetNum + i], 1, m_Device);
}
}

struct DescriptorSetLayout* setLayoutMTL = pipelineLayoutMTL->GetDescriptorSetLayout(setIndex);
for(uint32_t i = 0; i < instanceNum; i++) {
descriptorSets[i] = (DescriptorSet*)m_AllocatedSets[m_UsedSets++];
((DescriptorSetMTL*)descriptorSets[i])->Create(
m_ArgumentOffset,
m_ArgumentBuffer,
setLayoutMTL->m_ArgumentDescriptors,
&setLayoutMTL->m_DescriptorSetDesc);
m_ArgumentOffset += ((DescriptorSetMTL*)descriptorSets[i])->getDescriptorLength();
}


return Result::SUCCESS;
}

12 changes: 7 additions & 5 deletions Source/Metal/DescriptorSetMTL.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,21 @@ struct DescriptorSetMTL {
: m_Device(device) {
}


void Create(size_t argumentBufferOffset, id<MTLBuffer> argumentBuffer, NSArray<MTLArgumentDescriptor *>* argDesc, const struct DescriptorSetDesc* desc);
void UpdateDescriptorRanges(uint32_t rangeOffset, uint32_t rangeNum, const DescriptorRangeUpdateDesc* rangeUpdateDescs);

inline id<MTLArgumentEncoder> GetArgumentHandle() {
return m_ArgumentEncoder;
}
size_t getDescriptorLength();

private:
DeviceMTL& m_Device;
id<MTLArgumentEncoder> m_ArgumentEncoder;

id<MTLArgumentEncoder> m_ArgumentEncoder;
id<MTLBuffer> m_ArgumentBuffer;
size_t m_ArgumentBufferOffset;
NSArray<MTLArgumentDescriptor *>* m_ArgumentDescriptor;
const DescriptorSetDesc* m_Desc = nullptr;
};


} // namespace nri

41 changes: 28 additions & 13 deletions Source/Metal/DescriptorSetMTL.mm
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,45 @@
using namespace nri;


void DescriptorSetMTL::Create(size_t argumentBufferOffset, id<MTLBuffer> argumentBuffer, NSArray<MTLArgumentDescriptor *>* argDesc, const struct DescriptorSetDesc* desc) {
m_ArgumentDescriptor = argDesc;
m_ArgumentBuffer = argumentBuffer;
m_ArgumentBufferOffset = argumentBufferOffset;
m_Desc = desc;
m_ArgumentEncoder = [m_Device newArgumentEncoderWithArguments: argDesc];
[m_ArgumentEncoder setArgumentBuffer:m_ArgumentBuffer offset:argumentBufferOffset];
}

size_t DescriptorSetMTL::getDescriptorLength() {
return [m_ArgumentEncoder encodedLength];
}

void DescriptorSetMTL::UpdateDescriptorRanges(uint32_t rangeOffset, uint32_t rangeNum, const DescriptorRangeUpdateDesc* rangeUpdateDescs) {

for(size_t j = 0; j < rangeNum; j++) {
const DescriptorRangeUpdateDesc& update = rangeUpdateDescs[j];

// uint32_t offset = update.baseDescriptor + descriptorOffset;

for(size_t descIdx = 0; descIdx < update.descriptorNum; descIdx++) {
DescriptorMTL* descriptorImpl = (DescriptorMTL*)&update.descriptors[descIdx];
const DescriptorRangeDesc& rangeDesc = m_Desc->ranges[rangeOffset + j];

DescriptorMTL& descriptorImpl = *(DescriptorMTL*)&update.descriptors[descIdx];
switch(descriptorImpl.GetType()) {
switch(descriptorImpl->GetType()) {
case DescriptorTypeMTL::IMAGE_VIEW_1D:
[m_ArgumentEncoder setTexture: descriptorImpl.GetTextureHandle() atIndex:0];
// [m_ArgumentEncoder setTextures:<#(id<MTLTexture> _Nullable const * _Nonnull)#> withRange:<#(NSRange)#>]
break;
case DescriptorTypeMTL::IMAGE_VIEW_2D:
[m_ArgumentEncoder setTexture: descriptorImpl.GetTextureHandle() atIndex:0];
case DescriptorTypeMTL::IMAGE_VIEW_3D:
[m_ArgumentEncoder setTexture: descriptorImpl->GetTextureHandle() atIndex: rangeDesc.baseRegisterIndex + descIdx];
break;
case DescriptorTypeMTL::BUFFER_VIEW:
// [m_ArgumentEncoder setBuffer: offset:<#(NSUInteger)#> atIndex:<#(NSUInteger)#>]
case DescriptorTypeMTL::SAMPLER:
[m_ArgumentEncoder setSamplerState: descriptorImpl->GetSamplerStateHandler() atIndex:rangeDesc.baseRegisterIndex + descIdx]; // not sure if this is correct
break;
case DescriptorTypeMTL::BUFFER_VIEW: {
BufferViewDesc* view = &descriptorImpl->BufferView();
[m_ArgumentEncoder setBuffer: descriptorImpl->GetBufferHandle() offset: view->offset atIndex: rangeDesc.baseRegisterIndex + descIdx];
break;
}
default:
break;
}

// update.descriptors[descIdx];

}
}
}
48 changes: 41 additions & 7 deletions Source/Metal/PipelineLayoutMTL.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,57 @@ namespace nri {

struct DeviceMTL;


//
//struct BindingInfo {
// BindingInfo(StdAllocator<uint8_t>& allocator);
//// Vector<PushConstantBindingDesc> pushConstantBindings;
//// Vector<PushDescriptorBindingDesc> pushDescriptorBindings;
//};

struct DescriptorSetLayout {
DescriptorSetDesc m_DescriptorSetDesc;
NSMutableArray<MTLArgumentDescriptor*>* m_ArgumentDescriptors;
};

struct PipelineLayoutMTL {

inline PipelineLayoutMTL (DeviceMTL& device)
: m_Device(device) {
: m_Device(device)
, m_HasVariableDescriptorNum(device.GetStdAllocator())
, m_DescriptorSetRangeDescs(device.GetStdAllocator())
, m_DynamicConstantBufferDescs(device.GetStdAllocator())
, m_DescriptorSets(device.GetStdAllocator())
{
}

~PipelineLayoutMTL();

Result Create(const PipelineLayoutDesc& pipelineLayoutDesc);

struct PipelineDescriptorSet {
NSMutableArray<MTLArgumentDescriptor*>* m_ArgumentDescriptors;
};
inline DeviceMTL& GetDevice() const {
return m_Device;
}

inline struct DescriptorSetLayout* GetDescriptorSetLayout(uint32_t setIndex) {
return &m_DescriptorSets[setIndex];
}

// inline struct DescriptorSetDesc* GetDescriptorSetDesc(uint32_t setIndex) {
// return &m_DescriptorSetDesc[setIndex];
// }

Result Create(const PipelineLayoutDesc& pipelineLayoutDesc);


private:
DeviceMTL& m_Device;

std::vector<PipelineDescriptorSet> m_DescriptorSets;
Vector<bool> m_HasVariableDescriptorNum;
Vector<DescriptorRangeDesc> m_DescriptorSetRangeDescs;
Vector<DynamicConstantBufferDesc> m_DynamicConstantBufferDescs;
// Vector<DescriptorSetDesc> m_DescriptorSetDesc;
Vector<DescriptorSetLayout> m_DescriptorSets;

// BindingInfo m_BindingInfo;
};

}
71 changes: 58 additions & 13 deletions Source/Metal/PipelineLayoutMTL.mm
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,76 @@

using namespace nri;

//
//BindingInfo::BindingInfo(StdAllocator<uint8_t>& allocator)
// : hasVariableDescriptorNum(allocator)
// , descriptorSetRangeDescs(allocator)
// , dynamicConstantBufferDescs(allocator)
// , descriptorSetDescs(allocator) {
//
//}
//

PipelineLayoutMTL::~PipelineLayoutMTL() {

}

Result PipelineLayoutMTL::Create(const PipelineLayoutDesc& pipelineLayoutDesc) {


size_t rangeNum = 0;
size_t dynamicConstantBufferNum = 0;
for (uint32_t i = 0; i < pipelineLayoutDesc.descriptorSetNum; i++) {
rangeNum += pipelineLayoutDesc.descriptorSets[i].rangeNum;
dynamicConstantBufferNum += pipelineLayoutDesc.descriptorSets[i].dynamicConstantBufferNum;
}

m_DescriptorSets.resize(pipelineLayoutDesc.descriptorSetNum);
m_HasVariableDescriptorNum.resize(pipelineLayoutDesc.descriptorSetNum);
m_DescriptorSetRangeDescs.reserve(rangeNum);
m_DynamicConstantBufferDescs.reserve(dynamicConstantBufferNum);

for (uint32_t i = 0; i < pipelineLayoutDesc.descriptorSetNum; i++) {
const DescriptorSetDesc& descriptorSetDesc = pipelineLayoutDesc.descriptorSets[i];

NSMutableArray<MTLArgumentDescriptor*>* argumentDescriptors = [[NSMutableArray alloc] init];
MTLArgumentDescriptor* argDescriptor = [MTLArgumentDescriptor argumentDescriptor];
// Binding info
m_HasVariableDescriptorNum[i] = false;
m_DescriptorSets[i].m_DescriptorSetDesc = descriptorSetDesc;
m_DescriptorSets[i].m_DescriptorSetDesc.ranges = m_DescriptorSetRangeDescs.data() +m_DescriptorSetRangeDescs.size();
m_DescriptorSets[i].m_DescriptorSetDesc.dynamicConstantBuffers = m_DynamicConstantBufferDescs.data() + m_DynamicConstantBufferDescs.size();
m_DescriptorSetRangeDescs.insert(m_DescriptorSetRangeDescs.end(), descriptorSetDesc.ranges, descriptorSetDesc.ranges + descriptorSetDesc.rangeNum);
m_DynamicConstantBufferDescs.insert(m_DynamicConstantBufferDescs.end(), descriptorSetDesc.dynamicConstantBuffers, descriptorSetDesc.dynamicConstantBuffers + descriptorSetDesc.dynamicConstantBufferNum);

NSMutableArray<MTLArgumentDescriptor*>* argumentDescriptors = [[NSMutableArray alloc] init];
for(size_t r = 0; r < descriptorSetDesc.rangeNum; r++) {

MTLArgumentDescriptor* argDescriptor = [MTLArgumentDescriptor argumentDescriptor];
const DescriptorRangeDesc* range = &descriptorSetDesc.ranges[r];
argDescriptor.arrayLength = range->descriptorNum;
argDescriptor.access = MTLBindingAccessReadWrite;
argDescriptor.index = range->baseRegisterIndex;
switch(range->descriptorType) {
case DescriptorType::TEXTURE:
argDescriptor.dataType = MTLDataTypeTexture;
argDescriptor.textureType = MTLTextureType2D; // descriptor type does not have this
break;
case DescriptorType::SAMPLER:
argDescriptor.dataType = MTLDataTypeSampler;
break;
case DescriptorType::CONSTANT_BUFFER:
case DescriptorType::STORAGE_TEXTURE:
case DescriptorType::BUFFER:
case DescriptorType::STORAGE_BUFFER:
case DescriptorType::STRUCTURED_BUFFER:
case DescriptorType::STORAGE_STRUCTURED_BUFFER:
argDescriptor.dataType = MTLDataTypeStruct;
break;
case DescriptorType::ACCELERATION_STRUCTURE:
argDescriptor.dataType = MTLDataTypePrimitiveAccelerationStructure;
break;
default:
break;
}
}


//argDescriptor.access = memberDescriptor.mAccessType;
//argDescriptor.arrayLength = memberDescriptor.mArrayLength;
//argDescriptor.constantBlockAlignment = memberDescriptor.mAlignment;
//argDescriptor.dataType = memberDescriptor.mDataType;
//argDescriptor.index = memberDescriptor.mArgumentIndex;
//argDescriptor.textureType = memberDescriptor.mTextureType;

[argumentDescriptors addObject:argDescriptor];
m_DescriptorSets[i].m_ArgumentDescriptors = argumentDescriptors;

}
Expand Down

0 comments on commit b4079f3

Please sign in to comment.