Skip to content

Commit

Permalink
Merge pull request #10 from jiaaodong/hotfix/allow_targetmode
Browse files Browse the repository at this point in the history
use the same logic for high-level features to modify test script as well
  • Loading branch information
andraspalffy authored Jul 9, 2021
2 parents 2b1541f + 85539fb commit f32a197
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions RTCnet/test_RTC_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

# Modules from current project
from TargetLoader import TargetModeDataset, ToTensor
from RTCnet import RTCnet
from RTCnet import RTCnet, TargetMLP
from RTCnet_utils import Tester, Tester_ensemble
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
result_DIR = osp.join(BASE_DIR, osp.pardir, 'results', 'RTCtrain_info')
Expand Down Expand Up @@ -199,7 +199,8 @@ def main():
else:
num_classes = 2
weights_all = np.multiply(weights_factor, weights_all)
feature_type = "low"
model_version = 4
chosen_feature_type = 'high' if model_version == 3 else 'low'
cfg['testset'] = test_data_path
with open(osp.join(result_folder, 'info.json'), 'w') as fp:
json.dump(cfg, fp, sort_keys=True, indent=4, separators=(',', ': '))
Expand All @@ -213,7 +214,7 @@ def main():
test_data = TargetModeDataset(
test_data_path, composed_trans,
mode='test', high_dims=high_dims,
normalize=True, feature_type= feature_type,
normalize=True, feature_type= chosen_feature_type,
norms_path=result_folder,
speed_limit=speed_limit,
dist_near=dist_near,
Expand All @@ -230,12 +231,17 @@ def main():
np.save(osp.join(result_folder, "valid_indx_test"), test_data.indx_valid)

################### Define model ###################
model = RTCnet(
num_classes=2,
Doppler_dims=32,
high_dims = high_dims,
dropout= dropout,
input_size = input_size)
if model_version == 4:
model = RTCnet(
num_classes=2,
Doppler_dims=32,
high_dims = high_dims,
dropout= dropout,
input_size = input_size)
elif model_version == 3:
model = TargetMLP(
num_classes=2,
high_dims=high_dims)

################### Ped VS ALL ###########
scores_Ped_vs_All = test_ova(
Expand Down

0 comments on commit f32a197

Please sign in to comment.