Skip to content

Commit

Permalink
Further gsplat shader refinements (#7185)
Browse files Browse the repository at this point in the history
  • Loading branch information
slimbuck authored Dec 9, 2024
1 parent 80b4be9 commit fd10688
Show file tree
Hide file tree
Showing 16 changed files with 222 additions and 186 deletions.
28 changes: 16 additions & 12 deletions examples/src/examples/loaders/gsplat-many.shader.vert
Original file line number Diff line number Diff line change
Expand Up @@ -40,37 +40,41 @@ vec4 animateColor(float height, vec4 clr) {

void main(void) {
// read gaussian center
SplatState state;
if (!initState(state)) {
SplatSource source;
if (!initSource(source)) {
gl_Position = discardVec;
return;
}

vec3 center = animatePosition(readCenter(state));
vec3 centerPos = animatePosition(readCenter(source));

SplatCenter center;
initCenter(source, centerPos, center);

// project center to screen space
ProjectedState projState;
if (!projectCenter(state, center, projState)) {
SplatCorner corner;
if (!initCorner(source, center, corner)) {
gl_Position = discardVec;
return;
}

// read color
vec4 clr = readColor(state);
vec4 clr = readColor(source);

// evaluate spherical harmonics
#if SH_BANDS > 0
clr.xyz = max(clr.xyz + evalSH(state, projState), 0.0);
vec3 dir = normalize(center.view * mat3(center.modelView));
clr.xyz += evalSH(state, dir);
#endif

clr = animateColor(center.y, clr);
clr = animateColor(centerPos.y, clr);

applyClipping(projState, clr.w);
clipCorner(corner, clr.w);

// write output
gl_Position = projState.cornerProj;
gaussianUV = projState.cornerUV;
gaussianColor = vec4(prepareOutputFromGamma(clr.xyz), clr.w);
gl_Position = center.proj + vec4(corner.offset, 0.0, 0.0);
gaussianUV = corner.uv;
gaussianColor = vec4(prepareOutputFromGamma(max(clr.xyz, 0.0)), clr.w);

#ifndef DITHER_NONE
id = float(state.id);
Expand Down
5 changes: 4 additions & 1 deletion src/scene/gsplat/gsplat-compressed.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class GSplatCompressed {

numSplats;

numSplatsVisible;

/** @type {BoundingBox} */
aabb;

Expand Down Expand Up @@ -57,6 +59,7 @@ class GSplatCompressed {

this.device = device;
this.numSplats = numSplats;
this.numVisibleSplats = numSplats;

// initialize aabb
this.aabb = new BoundingBox();
Expand Down Expand Up @@ -147,7 +150,7 @@ class GSplatCompressed {
result.setDefine('GSPLAT_COMPRESSED_DATA', true);
result.setParameter('packedTexture', this.packedTexture);
result.setParameter('chunkTexture', this.chunkTexture);
result.setParameter('tex_params', new Float32Array([this.numSplats, this.packedTexture.width, this.chunkTexture.width / 5]));
result.setParameter('numSplats', this.numSplatsVisible);
if (this.shTexture0) {
result.setDefine('SH_BANDS', 3);
result.setParameter('shTexture0', this.shTexture0);
Expand Down
5 changes: 1 addition & 4 deletions src/scene/gsplat/gsplat-instance.js
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,7 @@ class GSplatInstance {
this.meshInstance.instancingCount = Math.ceil(count / splatInstanceSize);

// update splat count on the material
const tex_params = this.material.getParameter('tex_params');
if (tex_params?.data) {
tex_params.data[0] = count;
}
this.material.setParameter('numSplats', count);
});
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/scene/gsplat/gsplat.js
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class GSplat {

numSplats;

numSplatsVisible;

/** @type {Float32Array} */
centers;

Expand Down Expand Up @@ -70,6 +72,7 @@ class GSplat {

this.device = device;
this.numSplats = numSplats;
this.numSplatsVisible = numSplats;

this.centers = new Float32Array(gsplatData.numSplats * 3);
gsplatData.getCenters(this.centers);
Expand Down Expand Up @@ -116,7 +119,7 @@ class GSplat {
result.setParameter('splatColor', this.colorTexture);
result.setParameter('transformA', this.transformATexture);
result.setParameter('transformB', this.transformBTexture);
result.setParameter('tex_params', new Float32Array([this.numSplats, this.colorTexture.width]));
result.setParameter('numSplats', this.numSplatsVisible);
if (this.hasSH) {
result.setDefine('SH_BANDS', 3);
result.setParameter('splatSH_1to3', this.sh1to3Texture);
Expand Down
10 changes: 8 additions & 2 deletions src/scene/shader-lib/chunks/chunks.js
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,17 @@ import gamma2_2PS from './common/frag/gamma2_2.js';
import gles3PS from '../../../platform/graphics/shader-chunks/frag/gles3.js';
import gles3VS from '../../../platform/graphics/shader-chunks/vert/gles3.js';
import glossPS from './standard/frag/gloss.js';
import gsplatCenterVS from './gsplat/vert/gsplatCenter.js';
import gsplatColorVS from './gsplat/vert/gsplatColor.js';
import gsplatCommonVS from './gsplat/vert/gsplatCommon.js';
import gsplatCompressedDataVS from './gsplat/vert/gsplatCompressedData.js';
import gsplatCompressedSHVS from './gsplat/vert/gsplatCompressedSH.js';
import gsplatCornerVS from './gsplat/vert/gsplatCorner.js';
import gsplatDataVS from './gsplat/vert/gsplatData.js';
import gsplatOutputPS from './gsplat/gsplatOutput.js';
import gsplatOutputVS from './gsplat/vert/gsplatOutput.js';
import gsplatPS from './gsplat/frag/gsplat.js';
import gsplatSHVS from './gsplat/vert/gsplatSH.js';
import gsplatSourceVS from './gsplat/vert/gsplatSource.js';
import gsplatVS from './gsplat/vert/gsplat.js';
import iridescenceDiffractionPS from './lit/frag/iridescenceDiffraction.js';
import iridescencePS from './standard/frag/iridescence.js';
Expand Down Expand Up @@ -266,14 +269,17 @@ const shaderChunks = {
gles3PS,
gles3VS,
glossPS,
gsplatCenterVS,
gsplatCornerVS,
gsplatColorVS,
gsplatCommonVS,
gsplatCompressedDataVS,
gsplatCompressedSHVS,
gsplatDataVS,
gsplatOutputPS,
gsplatOutputVS,
gsplatPS,
gsplatSHVS,
gsplatSourceVS,
gsplatVS,
iridescenceDiffractionPS,
iridescencePS,
Expand Down
27 changes: 16 additions & 11 deletions src/scene/shader-lib/chunks/gsplat/vert/gsplat.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,43 @@ mediump vec4 discardVec = vec4(0.0, 0.0, 2.0, 1.0);
void main(void) {
// read gaussian details
SplatState state;
if (!initState(state)) {
SplatSource source;
if (!initSource(source)) {
gl_Position = discardVec;
return;
}
vec3 center = readCenter(state);
vec3 modelCenter = readCenter(source);
SplatCenter center;
initCenter(source, modelCenter, center);
// project center to screen space
ProjectedState projState;
if (!projectCenter(state, center, projState)) {
SplatCorner corner;
if (!initCorner(source, center, corner)) {
gl_Position = discardVec;
return;
}
// read color
vec4 clr = readColor(state);
vec4 clr = readColor(source);
// evaluate spherical harmonics
#if SH_BANDS > 0
clr.xyz += evalSH(state, projState);
// calculate the model-space view direction
vec3 dir = normalize(center.view * mat3(center.modelView));
clr.xyz += evalSH(source, dir);
#endif
applyClipping(projState, clr.w);
clipCorner(corner, clr.w);
// write output
gl_Position = projState.cornerProj;
gaussianUV = projState.cornerUV;
gl_Position = center.proj + vec4(corner.offset, 0, 0);
gaussianUV = corner.uv;
gaussianColor = vec4(prepareOutputFromGamma(max(clr.xyz, 0.0)), clr.w);
#ifndef DITHER_NONE
id = float(state.id);
id = float(source.id);
#endif
}
`;
27 changes: 27 additions & 0 deletions src/scene/shader-lib/chunks/gsplat/vert/gsplatCenter.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
export default /* glsl */`
uniform mat4 matrix_model;
uniform mat4 matrix_view;
uniform mat4 matrix_projection;
// project the model space gaussian center to view and clip space
bool initCenter(SplatSource source, vec3 modelCenter, out SplatCenter center) {
mat4 modelView = matrix_view * matrix_model;
vec4 centerView = modelView * vec4(modelCenter, 1.0);
// early out if splat is behind the camear
if (centerView.z > 0.0) {
return false;
}
vec4 centerProj = matrix_projection * centerView;
// ensure gaussians are not clipped by camera near and far
centerProj.z = clamp(centerProj.z, -abs(centerProj.w), abs(centerProj.w));
center.view = centerView.xyz / centerView.w;
center.proj = centerProj;
center.projMat00 = matrix_projection[0][0];
center.modelView = modelView;
return true;
}
`;
4 changes: 2 additions & 2 deletions src/scene/shader-lib/chunks/gsplat/vert/gsplatColor.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ export default /* glsl */`
uniform mediump sampler2D splatColor;
vec4 readColor(in SplatState state) {
return texelFetch(splatColor, state.uv, 0);
vec4 readColor(in SplatSource source) {
return texelFetch(splatColor, source.uv, 0);
}
`;
Loading

0 comments on commit fd10688

Please sign in to comment.