Skip to content

Commit

Permalink
..
Browse files Browse the repository at this point in the history
  • Loading branch information
Bharath2 committed Dec 28, 2024
1 parent d92de18 commit e937cde
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
1 change: 1 addition & 0 deletions include/siren_nerf.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class SirenLayer : public torch::nn::Module {
explicit SirenLayer(int64_t dim_in,
int64_t dim_out,
bool is_first = false,
float w0,
bool use_bias = true,
float c = 6.0f);

Expand Down
8 changes: 4 additions & 4 deletions src/siren_nerf.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#include "siren_nerf.h"

SirenLayer::SirenLayer(int64_t dim_in, int64_t dim_out,
bool is_first, bool use_bias, float c)
: dim_in_(dim_in), is_first_(is_first), w0_(is_first ? 120.0f : 1.0f) {
bool is_first, float w0, bool use_bias, float c)
: dim_in_(dim_in), is_first_(is_first), w0_(w0) {
// Initialize weight and bias
weight_ = register_parameter("weight", torch::zeros({dim_out, dim_in}));
float w_std = is_first_ ? (1.0f / dim_in_) : (std::sqrt(c / dim_in_) / w0_);
Expand All @@ -28,12 +28,12 @@ SirenNeRF::SirenNeRF(torch::Device device, int W, int D): device_(device) {
D = std::max(D, 2);

// Create position encoder SIREN layers
pos_siren_ = std::make_shared<SirenLayer>(1, 64, true);
pos_siren_ = std::make_shared<SirenLayer>(1, 64, true, 120);
register_module("pos_siren", pos_siren_);
pos_siren_->to(device_);

// Create view direction encoder SIREN layers
view_siren_ = std::make_shared<SirenLayer>(1, 32, true);
view_siren_ = std::make_shared<SirenLayer>(1, 32, true, 20);
register_module("view_siren", view_siren_);
view_siren_->to(device_);

Expand Down

0 comments on commit e937cde

Please sign in to comment.