diff --git a/tests/backend_tests.py b/tests/backend_tests.py index f41d8b8..cf6b6d7 100644 --- a/tests/backend_tests.py +++ b/tests/backend_tests.py @@ -6,20 +6,14 @@ from shutil import copyfile from pdc_config import TestConfig from app import create_app, db_mgr, db from app.auth.models import User +from tests.common_db_feed import resources_to_instancedb class BaseTestCase(unittest.TestCase): def setUp(self): # initialise app self.app = create_app(TestConfig) - - # copy resource demo db to test file - appdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir) - sqlite_file_name = os.path.abspath(os.path.join(appdir, 'resources', 'lesia-btp.sqlite')) - if not os.path.isdir(self.app.instance_path): - os.mkdir(self.app.instance_path) - self.db_path = os.path.join(self.app.instance_path, 'test.db') - copyfile(sqlite_file_name, self.db_path) + self.db_path = resources_to_instancedb(self.app) # force db path to newly create file self.app.config.update( @@ -59,4 +53,3 @@ class DbMgrTestCase(BaseTestCase): stacked_charges = db_mgr.charges_by_project_stacked(60) # Waiting for 17 periods + headers line self.assertEqual(18, len(stacked_charges)) - diff --git a/tests/common_db_feed.py b/tests/common_db_feed.py index dbe76db..09c658d 100644 --- a/tests/common_db_feed.py +++ b/tests/common_db_feed.py @@ -1,5 +1,20 @@ +import os +from shutil import copyfile + from app.models import Category, db, Label, Project, ProjectLabel + +def resources_to_instancedb(app): + # copy resource demo db to test file + appdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir) + sqlite_file_name = os.path.abspath(os.path.join(appdir, 'resources', 'lesia-btp.sqlite')) + if not os.path.isdir(app.instance_path): + os.mkdir(app.instance_path) + db_path = os.path.join(app.instance_path, 'test.db') + copyfile(sqlite_file_name, db_path) + return db_path + + categorized_labels = {'pole': ['Spatial', 'Sol'], 'domaine': ['soleil-terre', 'atmosphere', 'r&t', 'géologie']} projects_categories = {'ChemCam': {'pole': 'Spatial', 'domaine': 'géologie'}, 'Pilot': {'pole': 'Spatial', 'domaine': 'atmosphere'}, @@ -30,7 +45,7 @@ def feed_projects(): n_l = db.session.query(Label).filter(Label.name == l_name).one() n_pc = ProjectLabel(project=n_p, category=n_c, label=n_l) db.session.add(n_pc) - + db.session.commit() -- libgit2 0.21.2