package compiler;
import static d3d11.D3D11.*;
import static d3d11.D3D11.D3D_DRIVER_TYPE.*;
import static d3dcompiler.D3DCompiler.*;
import static org.bridj.Pointer.*;
import static org.junit.Assert.assertEquals;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.channels.FileChannel.MapMode;
import java.util.HashMap;
import java.util.Map;
import org.bridj.Pointer;
import org.bridj.ValuedEnum;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import sun.nio.ch.DirectBuffer;
import d3d11.D3D11.D3D_INCLUDE_TYPE;
import d3d11.core.ID3D10Blob;
import d3d11.core.ID3D11Device;
import d3d11.core.ID3D11DeviceContext;
import d3d11.shader.ID3D11VertexShader;
import d3dcompiler.D3DCompiler.ID3DInclude;
/**
 * Tests shader compilation through {@code D3DCompile} with a custom
 * {@link ID3DInclude} implementation resolving {@code #include} directives.
 *
 * <p>Each test runs against a device created by {@link #init()} and released
 * by {@link #destroy()}.
 */
public class TestInclude {
    ID3D11Device device;
    ID3D11DeviceContext immediateContext;

    /**
     * Releases the device context and the device created by {@link #init()}.
     *
     * <p>JUnit runs this even when {@code init()} failed part-way, so each
     * reference is checked before use; otherwise a NullPointerException here
     * would mask the original set-up failure.
     */
    @After
    public void destroy() {
        if (immediateContext != null) {
            immediateContext.Release();
            immediateContext = null;
        }
        if (device != null) {
            device.Release();
            device = null;
        }
    }

    /**
     * Creates a device with the hardware driver type together with its
     * immediate context, and fails the test if device creation does not
     * return 0.
     */
    @Before
    public void init() {
        Pointer<Pointer<ID3D11Device>> pDevice = allocatePointer(ID3D11Device.class);
        Pointer<Pointer<ID3D11DeviceContext>> pDeviceContext = allocatePointer(ID3D11DeviceContext.class);
        // Create device
        int result = D3D11CreateDevice(null,
                D3D_DRIVER_TYPE_HARDWARE,
                null,
                0,
                null,
                0,
                D3D11_SDK_VERSION,
                pDevice,
                null,
                pDeviceContext);
        assertEquals("D3D11CreateDevice failed", 0, result);
        device = new ID3D11Device(pDevice.get());
        immediateContext = new ID3D11DeviceContext(pDeviceContext.get());
    }

    /**
     * Include handler that serves include files by memory-mapping them, so
     * the file content is handed to the compiler without an extra copy.
     *
     * <p>Every successful {@code Open} is tracked by the address of the data
     * it returned. Includes can nest, in which case a second {@code Open}
     * arrives before the first {@code Close}; tracking per address keeps each
     * mapping strongly referenced until its own {@code Close}.
     */
    public static class IncludeHandler extends ID3DInclude {
        /** Mappings handed out by {@code Open} and not yet closed, keyed by data address. */
        private final Map<Long, MappedInclude> openIncludes = new HashMap<Long, MappedInclude>();

        /**
         * Maps the requested file read-only and reports its address and size.
         *
         * @param IncludeType  include kind supplied by the compiler (unused)
         * @param pFileName    C string holding the file name to open
         * @param pParentData  parent include data supplied by the compiler (unused)
         * @param ppData       receives the pointer to the mapped content
         * @param pBytes       receives the content size in bytes
         * @return 0 on success, -1 if the file does not exist, -2 on any other
         *         I/O failure (including a file too large for a 32-bit size)
         */
        @Override
        public int Open(ValuedEnum<D3D_INCLUDE_TYPE> IncludeType, Pointer<Byte> pFileName,
                Pointer<?> pParentData, Pointer<Pointer<?>> ppData, Pointer<Integer> pBytes) {
            String fileName = pFileName.getCString();
            FileChannel channel = null;
            boolean retained = false;
            try {
                // Map file to memory, avoiding copy
                channel = new FileInputStream(fileName).getChannel();
                long fileSize = channel.size();
                if (fileSize > Integer.MAX_VALUE) {
                    // The size is reported through an int; refuse instead of truncating.
                    throw new IOException("Include file too large: " + fileName + " (" + fileSize + " bytes)");
                }
                ByteBuffer buffer = channel.map(MapMode.READ_ONLY, 0, fileSize);
                Pointer<Byte> data = pointerToBytes(buffer);
                if (data != null) {
                    // Keep channel, buffer and pointer reachable until the matching Close.
                    openIncludes.put(Long.valueOf(data.getPeer()), new MappedInclude(channel, buffer, data));
                    retained = true;
                }
                // Set data and size
                ppData.setPointer(data);
                pBytes.setInt((int) fileSize);
                return 0;
            } catch (FileNotFoundException e) {
                e.printStackTrace();
                return -1;
            } catch (IOException e) {
                e.printStackTrace();
                return -2;
            } finally {
                if (!retained) {
                    // Nothing will call Close for this channel, so release it now.
                    closeAndReport(channel);
                }
            }
        }

        /**
         * Unmaps and closes the file whose content was returned by the
         * {@code Open} call that produced {@code pData}.
         *
         * @param pData pointer previously stored into {@code ppData} by {@code Open}
         * @return always 0; an unknown or null pointer is ignored
         */
        @Override
        public int Close(Pointer<?> pData) {
            if (pData == null) {
                return 0;
            }
            MappedInclude include = openIncludes.remove(Long.valueOf(pData.getPeer()));
            if (include != null) {
                include.release();
            }
            return 0;
        }

        /** Closes {@code channel} if non-null, reporting (not propagating) any I/O error. */
        private static void closeAndReport(FileChannel channel) {
            if (channel == null) {
                return;
            }
            try {
                channel.close();
            } catch (IOException e) {
                // Best effort: a failed close must not abort the native callback.
                e.printStackTrace();
            }
        }

        /** Resources backing one include that is currently open. */
        private static final class MappedInclude {
            private final FileChannel channel;
            private final ByteBuffer buffer;
            /** Held only to keep the pointer given to the compiler strongly reachable. */
            @SuppressWarnings("unused")
            private final Pointer<Byte> data;

            MappedInclude(FileChannel channel, ByteBuffer buffer, Pointer<Byte> data) {
                this.channel = channel;
                this.buffer = buffer;
                this.data = data;
            }

            /** Forces the mapping to be unmapped, then closes the file. */
            void release() {
                // Force unmap of file (relies on JDK-internal API, as the original code did)
                if (buffer instanceof DirectBuffer) {
                    sun.misc.Cleaner cleaner = ((DirectBuffer) buffer).cleaner();
                    // cleaner() can be null when the buffer owns no mapping
                    // (presumably e.g. a zero-length file).
                    if (cleaner != null) {
                        cleaner.clean();
                    }
                }
                // Close file
                closeAndReport(channel);
            }
        }
    }

    /**
     * Include handler used by the test: every requested include resolves to
     * empty source text, regardless of the requested file name.
     */
    public static class TestIncludeHandler extends ID3DInclude {
        /**
         * Buffers handed to the compiler and not yet closed, keyed by address.
         * Holding the owning pointer here keeps the native string from being
         * reclaimed while the compiler may still read it.
         */
        private final Map<Long, Pointer<Byte>> liveBuffers = new HashMap<Long, Pointer<Byte>>();

        /**
         * Returns an empty C string as the include content.
         *
         * @param IncludeType  include kind supplied by the compiler (unused)
         * @param pFileName    requested file name (unused)
         * @param pParentData  parent include data supplied by the compiler (unused)
         * @param ppData       receives the pointer to the include content
         * @param pBytes       receives the content length
         * @return always 0
         */
        @Override
        public int Open(ValuedEnum<D3D_INCLUDE_TYPE> IncludeType, Pointer<Byte> pFileName,
                Pointer<?> pParentData, Pointer<Pointer<?>> ppData, Pointer<Integer> pBytes) {
            // Set data and size
            String code = "";
            Pointer<Byte> data = pointerToCString(code);
            liveBuffers.put(Long.valueOf(data.getPeer()), data);
            ppData.setPointer(data);
            pBytes.setInt(code.length());
            return 0;
        }

        /**
         * Releases the buffer allocated by the {@code Open} call that
         * produced {@code pData}.
         *
         * @param pData pointer previously stored into {@code ppData} by {@code Open}
         * @return always 0; an unknown or null pointer is ignored
         */
        @Override
        public int Close(Pointer<?> pData) {
            if (pData == null) {
                return 0;
            }
            // Release the pointer that owns the allocation, not the wrapper
            // the callback received for the same address.
            Pointer<Byte> owned = liveBuffers.remove(Long.valueOf(pData.getPeer()));
            if (owned != null) {
                owned.release();
            }
            return 0;
        }
    }

    /**
     * Compiles a vertex shader whose source contains an {@code #include}
     * resolved by {@link TestIncludeHandler}, creates a vertex shader object
     * from the resulting bytecode, and checks that releasing that object
     * returns 0.
     */
    @Test
    public void compileShaders() {
        // Create include callbacks
        TestIncludeHandler handler = new TestIncludeHandler();
        String vertexShader =
                "#include \"test.h\" \n" +
                " \n" +
                "float4 VS( float4 Pos : POSITION ) : SV_POSITION \n" +
                "{ \n" +
                " return Pos; \n" +
                "} \n";
        //String pixelShader = "float4 PS( float4 Pos : SV_POSITION ) : SV_Target { return float4( 1.0f, 1.0f, 0.0f, 1.0f ); // Yellow, with Alpha = 1}";
        Pointer<Pointer<ID3D10Blob>> ppCode = allocatePointer(ID3D10Blob.class);
        Pointer<Pointer<ID3D10Blob>> ppErrorMsgs = allocatePointer(ID3D10Blob.class);
        int result = D3DCompile(pointerToCString(vertexShader), vertexShader.length(), null, null, pointerTo(handler), pointerToCString("VS"), pointerToCString("vs_5_0"), 0, 0, ppCode, ppErrorMsgs);
        // Always drain the message blob so it is released, and surface its
        // text if compilation failed.
        String compilerMessages = takeCompilerMessages(ppErrorMsgs);
        assertEquals("D3DCompile failed: " + compilerMessages, 0, result);
        ID3D10Blob vsCode = ppCode.get().getNativeObject(ID3D10Blob.class);
        try {
            Pointer<Pointer<ID3D11VertexShader>> ppVS = allocatePointer(ID3D11VertexShader.class);
            result = device.CreateVertexShader(vsCode.GetBufferPointer(), (int) vsCode.GetBufferSize(), null, ppVS);
            assertEquals(0, result);
            ID3D11VertexShader vs = ppVS.get().getNativeObject(ID3D11VertexShader.class);
            result = vs.Release();
            assertEquals(0, result);
        } finally {
            // The bytecode blob is no longer needed once the shader has been created.
            vsCode.Release();
        }
    }

    /**
     * Reads the text of the compiler message blob, if one was produced, and
     * releases the blob.
     *
     * @param ppErrorMsgs output slot that was passed to {@code D3DCompile}
     * @return the message text, or an empty string when no blob was returned
     */
    private static String takeCompilerMessages(Pointer<Pointer<ID3D10Blob>> ppErrorMsgs) {
        Pointer<ID3D10Blob> pBlob = ppErrorMsgs.get();
        if (pBlob == null) {
            return "";
        }
        ID3D10Blob blob = pBlob.getNativeObject(ID3D10Blob.class);
        try {
            Pointer<?> text = blob.GetBufferPointer();
            return text == null ? "" : text.getCString();
        } finally {
            blob.Release();
        }
    }
}