Source code for bout_runners.database.database_reader

"""Module containing the DatabaseReader class."""


from typing import Iterable, Mapping, Optional, Union

import pandas as pd
from numpy import int64

from bout_runners.database.database_connector import DatabaseConnector


class DatabaseReader:
    r"""
    Class for reading the schema of the database.

    Attributes
    ----------
    db_connector : DatabaseConnector
        The database object to read from

    Methods
    -------
    query(query_str, params=None, **kwargs)
        Make a query to the database
    get_latest_row_id()
        Return the latest row id
    get_entry_id(table_name, entries_dict)
        Get the id of a table entry
    check_tables_created()
        Check if the tables are created in the database

    Examples
    --------
    Import dependencies

    >>> from pathlib import Path
    >>> from bout_runners.executor.bout_paths import BoutPaths
    >>> from bout_runners.parameters.default_parameters import DefaultParameters
    >>> from bout_runners.parameters.final_parameters import FinalParameters
    >>> from bout_runners.database.database_connector import DatabaseConnector
    >>> from bout_runners.database.database_creator import DatabaseCreator
    >>> from bout_runners.database.database_writer import DatabaseWriter

    Create the `bout_paths` object

    >>> project_path = Path().joinpath('path', 'to', 'project')
    >>> bout_inp_src_dir = Path().joinpath('path', 'to', 'source', 'BOUT.inp')
    >>> bout_inp_dst_dir = Path().joinpath('path', 'to', 'destination', 'BOUT.inp')
    >>> bout_paths = BoutPaths(project_path=project_path,
    ...                        bout_inp_src_dir=bout_inp_src_dir,
    ...                        bout_inp_dst_dir=bout_inp_dst_dir)

    Obtain the parameters

    >>> default_parameters = DefaultParameters(bout_paths)
    >>> final_parameters = FinalParameters(default_parameters)
    >>> final_parameters_dict = final_parameters.get_final_parameters()
    >>> final_parameters_as_sql_types = \
    ...     final_parameters.cast_to_sql_type(final_parameters_dict)

    Create the database

    >>> db_connector = DatabaseConnector('name', project_path)
    >>> db_creator = DatabaseCreator(db_connector)
    >>> db_creator.create_all_schema_tables(final_parameters_as_sql_types)

    Write to the database

    >>> db_writer = DatabaseWriter(db_connector)
    >>> dummy_split_dict = {'number_of_processors': 1,
    ...                     'number_of_nodes': 2,
    ...                     'processors_per_node': 3}
    >>> db_writer.create_entry('split', dummy_split_dict)

    Read from the database

    >>> db_reader = DatabaseReader(db_connector)
    >>> db_reader.query('SELECT * FROM split')
       id  number_of_processors  number_of_nodes  processors_per_node
    0   1                     1                2                    3

    >>> db_reader.get_latest_row_id()
    1

    >>> db_reader.get_entry_id('split', dummy_split_dict)
    1

    >>> db_reader.check_tables_created()
    True
    """

    def __init__(self, db_connector: "DatabaseConnector") -> None:
        """
        Set the database to use.

        Parameters
        ----------
        db_connector : DatabaseConnector
            The database object to read from
        """
        self.db_connector = db_connector

    def query(
        self,
        query_str: str,
        params: Optional[Iterable[Union[str, float, int, bool, None]]] = None,
        **kwargs,
    ) -> pd.DataFrame:
        """
        Make a query to the database.

        Parameters
        ----------
        query_str : str
            The query to execute
        params : array_like, optional
            Values to bind to the ``?`` placeholders in ``query_str``
            Used to protect against SQL injection
            Placeholders bind values only (not table or column names), e.g.
            ``self.query("SELECT * FROM table WHERE foo = ?", params=("bar",))``
        kwargs : dict
            Additional keyword parameters to pd.read_sql_query

        Returns
        -------
        table : pd.DataFrame
            The result of a query as a DataFrame
        """
        table = pd.read_sql_query(
            query_str, self.db_connector.connection, params=params, **kwargs
        )
        return table

    def get_latest_row_id(self) -> int64:
        """
        Return the latest row id.

        Returns
        -------
        row_id : int64
            The row id reported by SQLite's ``last_insert_rowid()``
            Note that this is a numpy integer, cast it to ``int`` before
            handing it back to sqlite3 as a parameter
        """
        # https://stackoverflow.com/questions/3442033/sqlite-how-to-get-value-of-auto-increment-primary-key-after-insert-other-than
        # pylint: disable=no-member
        row_id = self.query("SELECT last_insert_rowid() AS id").loc[0, "id"]
        return row_id

    def get_entry_id(
        self, table_name: str, entries_dict: Mapping[str, Union[int, str, float, None]]
    ) -> Optional[int]:
        """
        Get the id of a table entry.

        Parameters
        ----------
        table_name : str
            Name of the table to check
        entries_dict : dict
            Dictionary containing the entries as key value pairs
            A value of None matches rows where the field is NULL
            If the dictionary is empty no filtering is applied, and the id
            of the first row of the table (if any) is returned

        Returns
        -------
        row_id : int or None
            The id of the hit
            If none is found, None is returned
        """
        # NOTE: About checking for existence
        # https://stackoverflow.com/questions/9755860/valid-query-to-check-if-row-exists-in-sqlite3
        # NOTE: About SELECT 1
        # https://stackoverflow.com/questions/7039938/what-does-select-1-from-do
        conditions = []
        where_values = []
        for field, val in entries_dict.items():
            if val is None:
                # NOTE: "field = NULL" is never true in SQL, so NULL must be
                #       matched with IS NULL rather than a bound parameter
                conditions.append(f"{field} IS NULL")
            else:
                conditions.append(f"{field}=?")
                where_values.append(val)

        # NOTE: Protection against SQL injection through the use of ? in
        #       for-loop above
        #       Table and field names cannot be bound as parameters, and must
        #       therefore come from a trusted source
        query_str = f"SELECT id\nFROM {table_name}"  # nosec
        if conditions:
            # NOTE: The clause is built with a join rather than by replacing
            #       "AND" with "WHERE" in the first condition, as the latter
            #       would corrupt field names containing "AND" (like "BRAND")
            joined_conditions = f'\n{" " * 7}AND '.join(conditions)
            query_str += f"\nWHERE {joined_conditions}"

        table = self.query(query_str, params=where_values)

        # NOTE: We explicitly cast to int, as sqlite3 will cast
        #       np.int64 to bytes
        # pylint: disable=no-member
        row_id = None if table.empty else int(table.loc[0, "id"])
        return row_id

    def check_tables_created(self) -> bool:
        """
        Check if the tables are created in the database.

        Returns
        -------
        bool
            Whether or not the tables are created
        """
        # NOTE: String literals are single-quoted in SQL
        #       Double quotes denote identifiers, and are only accepted as
        #       strings through a legacy SQLite fallback which can be disabled
        query_str = "SELECT name FROM sqlite_master WHERE type='table'"
        table = self.query(query_str)
        return not table.empty