diff --git a/input/example.in b/input/example.in index c4e3404..8e09162 100644 --- a/input/example.in +++ b/input/example.in @@ -24,16 +24,3 @@ elm = 0 - - -Name = 20241025_120806 - -Name = 20241025_120806 - -Name = 20241025_123012 - -Name = 20241025_123835 - -Name = 20241025_124525 - -Name = 20241025_125032 diff --git a/routine.py b/routine.py index ed36976..a7e0edf 100644 --- a/routine.py +++ b/routine.py @@ -23,14 +23,21 @@ start = time() now = dt.datetime.now() name = str(now.strftime("%Y%m%d")+'_'+now.strftime("%H%M%S")) -path = source_dir+'/models/CSE_0D/'+name +path = source_dir+'/model/'+name + +print('Model path:', path) ## ================================================== INPUT ======== ## ADJUST THESE PARAMETERS FOR DIFFERENT MODELS ## READ INPUT FILE -arg = sys.argv[1] +try: + arg = sys.argv[1] +except Exception: + print('Please provide an input file.') + print('$ python routine.py example') + sys.exit() infile = source_dir+'/input/'+arg+'.in' @@ -54,9 +61,8 @@ ## Load train & test data sets traindata, testdata, data_loader, test_loader = ds.get_data(dt_fract=input.dt_fract, - nb_samples=input.nb_samples, batch_size=batch_size, - nb_test=input.nb_test,kwargs=kwargs) - + nb_samples=input.nb_samples, batch_size=batch_size, + nb_test=input.nb_test,kwargs=kwargs, inpackage=True) ## Make model model = mace.Solver(n_dim=input.n_dim, p_dim=4,z_dim = input.z_dim, nb_hidden=input.nb_hidden, ae_type=input.ae_type, @@ -79,9 +85,9 @@ ## Train tic = time() -train.train(model, - data_loader, test_loader, - end_epochs = input.ini_epochs, +train.train(model, + data_loader, test_loader, + end_epochs = input.ini_epochs, trainloss=trainloss, testloss=testloss, start_time = start) toc = time() @@ -161,7 +167,7 @@ # print(i+1,end='\r') testpath = traindata.testpath[i] - err_test, err_evol, step_time, evol_time = test.test_model(model,testpath, meta, printing = False) + err_test, err_evol, step_time, evol_time = test.test_model(model,testpath, meta, printing = False, inpackage=True, datapath='train' ) sum_err_step += err_test sum_err_evol += err_evol