1- import tensorflow as tf
2- import os
31import argparse
42import glob
5- import numpy as np
3+ import os
4+
65import keras_tuner as kt
6+ import tensorflow as tf
7+
78from ds_mesh_to_pc import read_off
89
10+
911def create_model (hp ):
1012 model = tf .keras .Sequential ()
1113 model .add (tf .keras .layers .InputLayer (input_shape = (2048 , 3 )))
12-
14+
1315 for i in range (hp .Int ('num_layers' , 1 , 5 )):
1416 model .add (tf .keras .layers .Dense (
1517 hp .Int (f'layer_{ i } _units' , min_value = 64 , max_value = 1024 , step = 64 ),
1618 activation = 'relu'
1719 ))
18-
20+
1921 model .add (tf .keras .layers .Dense (3 , activation = 'sigmoid' ))
20-
22+
2123 model .compile (
2224 optimizer = tf .keras .optimizers .Adam (
2325 learning_rate = hp .Float ('learning_rate' , 1e-5 , 1e-3 , sampling = 'log' )
@@ -28,7 +30,7 @@ def create_model(hp):
2830
2931def load_and_preprocess_data (input_dir , batch_size ):
3032 file_paths = glob .glob (os .path .join (input_dir , "*.ply" ))
31-
33+
3234 def parse_ply_file (file_path ):
3335 mesh_data = read_off (file_path )
3436 return mesh_data .vertices
@@ -47,7 +49,7 @@ def data_generator():
4749 dataset = dataset .shuffle (buffer_size = len (file_paths ))
4850 dataset = dataset .batch (batch_size )
4951 dataset = dataset .prefetch (tf .data .experimental .AUTOTUNE )
50-
52+
5153 return dataset
5254
5355def tune_hyperparameters (input_dir , output_dir , num_epochs = 10 ):
@@ -63,10 +65,10 @@ def tune_hyperparameters(input_dir, output_dir, num_epochs=10):
6365
6466 dataset = load_and_preprocess_data (input_dir , batch_size = 32 )
6567 tuner .search (dataset , epochs = num_epochs , validation_data = dataset )
66-
68+
6769 best_model = tuner .get_best_models (num_models = 1 )[0 ]
6870 best_hps = tuner .get_best_hyperparameters (num_trials = 1 )[0 ]
69-
71+
7072 print ("Best Hyperparameters:" , best_hps .values )
7173 best_model .save (os .path .join (output_dir , 'best_model' ))
7274
@@ -95,4 +97,4 @@ def main():
9597 model .save (os .path .join (args .output_dir , 'trained_model' ))
9698
9799if __name__ == "__main__" :
98- main ()
100+ main ()
0 commit comments