# -*- coding: utf-8 -*-
"""A dictionary-like object for holding atoms.
"""
from itertools import zip_longest
import logging
import pprint # noqa: F401
from typing import Any, Dict, TypeVar
import numpy as np
import pandas
from molsystem import elements
from .column import _Column
from .table import _Table
# Logger for this module. Uncomment the setLevel call below to see the debug
# output (for example the generated SQL queries) of this module.
logger = logging.getLogger(__name__)
# logger.setLevel("DEBUG")

# Constrained type variable: an atoms object, a string, or None.
Atoms_tp = TypeVar("Atoms_tp", "_Atoms", str, None)
def grouped(iterable, n, fillvalue=None):
    """Collect the items of *iterable* into tuples of length *n*.

    s -> (s0, s1, ..., sn-1), (sn, sn+1, ..., s2n-1), (s2n, ..., s3n-1), ...

    If the number of items is not a multiple of *n*, the last tuple is padded
    with *fillvalue*.

    Parameters
    ----------
    iterable : iterable
        The items to group.
    n : int
        The number of items in each group.
    fillvalue : any = None
        The value used to pad the final, incomplete group.

    Returns
    -------
    iterator of tuple
        The groups, in order.
    """
    # Repeating the *same* iterator n times makes zip_longest pull n
    # consecutive items for each tuple.
    return zip_longest(*[iter(iterable)] * n, fillvalue=fillvalue)
class _Atoms(_Table):
"""The atoms in a configuration of a system.
Parameters
----------
configuration : _Configuration
The configuration of interest.
"""
def __init__(self, configuration) -> None:
    """Set up the atoms view of *configuration*.

    The atom, coordinates, velocities and gradients tables of the
    configuration's system database are each wrapped in a _Table so that
    they can be presented together as a single table.

    Parameters
    ----------
    configuration : _Configuration
        The configuration whose atoms are wanted.
    """
    self._configuration = configuration
    self._system_db = configuration.system_db
    self._system = None
    self._atomset = configuration.atomset
    db = self.system_db
    self._atom_table = _Table(db, "atom")
    self._coordinates_table = _Table(db, "coordinates")
    self._velocities_table = _Table(db, "velocities")
    self._gradients_table = _Table(db, "gradients")
    # This object itself acts as the "atom" table.
    super().__init__(db, "atom")
def __enter__(self) -> Any:
    """Start a 'with' block by backing up the atomset_atom and atom tables."""
    for table_name in ("atomset_atom", "atom"):
        self.system_db[table_name].__enter__()
    return self
def __exit__(self, etype, value, traceback) -> None:
    """Finish a 'with' block.

    On a clean exit (no exception) the configuration's version is bumped.
    The atomset_atom and atom tables are then exited, in that order, and the
    atom table's result is returned.
    """
    if etype is None:
        configuration = self.configuration
        configuration.version = configuration.version + 1
    db = self.system_db
    db["atomset_atom"].__exit__(etype, value, traceback)
    return db["atom"].__exit__(etype, value, traceback)
def __delitem__(self, key) -> None:
    """Delete the attribute (column) *key*.

    The atom, coordinates, velocities and gradients tables are searched in
    that order, and the column is removed from the first table that has it.

    Raises
    ------
    KeyError
        If no table has an attribute named *key*, as for a dict.
    """
    for table in (
        self._atom_table,
        self._coordinates_table,
        self._velocities_table,
        self._gradients_table,
    ):
        if key in table.attributes:
            del table[key]
            return
    # Previously an unknown key was silently ignored.
    raise KeyError(key)
def __iter__(self) -> iter:
    """Iterate over the attribute (column) names.

    The names are copied into a list first, so the iteration is not affected
    by changes to the attributes made while iterating.
    """
    names = list(self.attributes)
    return iter(names)
def __len__(self) -> int:
    """The number of attributes (columns), as reported by len().

    Note that this counts attributes, not atoms; see ``n_atoms`` for that.
    """
    attributes = self.attributes
    return len(attributes)
def __repr__(self) -> str:
    """The repr() of the DataFrame built by ``to_dataframe()``."""
    return repr(self.to_dataframe())
def __str__(self) -> str:
    """The table of atoms as text, from the DataFrame's ``to_string()``."""
    return self.to_dataframe().to_string()
def __contains__(self, item) -> bool:
    """Return whether *item* is the name of an attribute (column).

    Items that are not strings cannot be attribute names, so False is
    returned for them instead of raising AttributeError.
    """
    if not isinstance(item, str):
        return False
    # Names are sometimes passed as quoted SQL identifiers ('"name"'); strip
    # the double quotes so they compare equal to the plain attribute names.
    return item.strip('"') in self.attributes
def __eq__(self, other) -> Any:
    """Return whether these atoms have no differences from *other*.

    For objects that are not _Atoms, NotImplemented is returned so that
    Python falls back to its default comparison. Previously ``diff`` would
    have failed on them with AttributeError.
    """
    if not isinstance(other, _Atoms):
        return NotImplemented
    diffs = self.diff(other)
    return len(diffs) == 0
@property
def ids(self):
    """The database ids of the atoms in this configuration (see get_ids)."""
    atom_ids = self.get_ids()
    return atom_ids
@property
def asymmetric_atomic_numbers(self):
    """The atomic numbers of the symmetry-unique atoms.

    They come straight from the "atno" column.

    Returns
    -------
    [int]
        The atomic numbers, one per asymmetric atom.
    """
    atnos = self.get_column_data("atno")
    return atnos
@property
def atomic_numbers(self):
    """The atomic numbers of all the atoms, expanded by symmetry.

    Returns
    -------
    [int]
        The atomic numbers, one per atom. Each atom takes the atomic number
        of the asymmetric atom it was generated from.
    """
    if self.configuration.symmetry.n_symops == 1:
        return self.asymmetric_atomic_numbers
    atnos = self.asymmetric_atomic_numbers
    mapping = self.configuration.atom_to_asymmetric_atom
    return [atnos[asym] for asym in mapping]
@property
def asymmetric_atomic_masses(self):
    """The atomic masses of the symmetry-unique (asymmetric) atoms.

    The "mass" attribute is used when it exists. Otherwise the masses of the
    elements are looked up from the asymmetric atomic numbers.

    Returns
    -------
    list
        The masses, one per asymmetric atom.
    """
    if "mass" in self:
        return self.get_column_data("mass")
    # Must use the *asymmetric* atomic numbers: the symmetry-expanded ones
    # give a list of the wrong length, which atomic_masses then indexes
    # incorrectly.
    return elements.masses(self.asymmetric_atomic_numbers)
@property
def atomic_masses(self):
    """The atomic masses of all the atoms, expanded by symmetry.

    Returns
    -------
    list
        The masses, one per atom. Each atom takes the mass of the asymmetric
        atom it was generated from.
    """
    if self.configuration.symmetry.n_symops == 1:
        return self.asymmetric_atomic_masses
    masses = self.asymmetric_atomic_masses
    return [masses[i] for i in self.configuration.atom_to_asymmetric_atom]
@property
def atomset(self):
    """The id of the atomset that holds these atoms."""
    atomset_id = self._atomset
    return atomset_id
@property
def atom_generators(self):
    """For each asymmetric atom, the symmetry operations creating its images.

    Delegates to the configuration's symmetry object.
    """
    symmetry = self.configuration.symmetry
    return symmetry.atom_generators
@property
def attributes(self) -> Dict[str, Any]:
    """The definitions of the attributes.

    Combines the attributes of the atom, coordinates, velocities, and
    gradients tables so that they look like a single larger table. When a
    name occurs in several tables, the later table's definition wins.

    Returns
    -------
    dict
        A new dictionary mapping attribute names to their definitions.
    """
    # Copy rather than update the atom table's dictionary in place, so the
    # combined view can never leak into that table's own attributes.
    result = dict(self._atom_table.attributes)
    for table in (
        self._coordinates_table,
        self._velocities_table,
        self._gradients_table,
    ):
        for key, value in table.attributes.items():
            # The "atom" key links the tables together, so ignore it.
            if key != "atom":
                result[key] = value
    return result
@property
def configuration(self):
    """The configuration that these atoms belong to."""
    config = self._configuration
    return config
@property
def coordinates(self):
    """The coordinates as a list of [x, y, z] lists (get_coordinates defaults)."""
    xyz = self.get_coordinates()
    return xyz

@coordinates.setter
def coordinates(self, xyz):
    """Replace the coordinates (see set_coordinates)."""
    self.set_coordinates(xyz)
@property
def cursor(self):
    """The database cursor of the system database."""
    return self.system_db.cursor
@property
def db(self):
    """The database connection of the system database."""
    system_db = self.system_db
    return system_db.db
@property
def gradients(self):
    """The gradients as a list of [gx, gy, gz] lists (get_gradients defaults)."""
    gxyz = self.get_gradients()
    return gxyz

@gradients.setter
def gradients(self, xyz):
    """Replace the gradients (see set_gradients)."""
    self.set_gradients(xyz)
@property
def group(self):
    """The space or point group of the configuration's symmetry."""
    symmetry = self.configuration.symmetry
    return symmetry.group

@group.setter
def group(self, value):
    """Set the space or point group of the configuration's symmetry."""
    symmetry = self.configuration.symmetry
    symmetry.group = value
@property
def have_gradients(self):
    """Whether any atom of the atomset has gradients in this configuration."""
    # Alias the gradients table "gr" to match the aliases used in atoms().
    sql = (
        "SELECT COUNT(*)"
        " FROM atomset_atom AS aa, gradients AS gr "
        "WHERE aa.atomset = ?"
        " AND gr.atom = aa.atom AND gr.configuration = ?"
    )
    parameters = [self.atomset, self.configuration.id]
    self.cursor.execute(sql, parameters)
    return self.cursor.fetchone()[0] > 0
@property
def have_velocities(self):
    """Whether any atom of the atomset has velocities in this configuration."""
    parameters = [self.atomset, self.configuration.id]
    self.cursor.execute(
        "SELECT COUNT(*)"
        " FROM atomset_atom AS aa, velocities AS ve "
        "WHERE aa.atomset = ?"
        " AND ve.atom = aa.atom AND ve.configuration = ?",
        parameters,
    )
    (count,) = self.cursor.fetchone()
    return count > 0
@property
def loglevel(self):
    """The effective logging level of this module.

    The level is returned by name (e.g. "DEBUG") when it has one, otherwise
    as the integer level.
    """
    level = logger.getEffectiveLevel()
    level_name = logging.getLevelName(level)
    return level if "Level" in level_name else level_name

@loglevel.setter
def loglevel(self, value):
    """Set the logging level of this module's logger."""
    logger.setLevel(value)
@property
def names(self):
    """The names of all the atoms (see get_names)."""
    atom_names = self.get_names()
    return atom_names
@property
def n_asymmetric_atoms(self) -> int:
    """The number of symmetry-unique atoms, i.e. the size of the atomset."""
    cursor = self.cursor
    cursor.execute(
        "SELECT COUNT(*) FROM atomset_atom WHERE atomset = ?", (self.atomset,)
    )
    (count,) = cursor.fetchone()
    return count
@property
def n_atoms(self) -> int:
    """The number of atoms in this configuration, including symmetry images."""
    if self.configuration.symmetry.n_symops != 1:
        # Each asymmetric atom contributes one atom per generating operation.
        return sum(len(symops) for symops in self.atom_generators)
    return self.n_asymmetric_atoms
@property
def n_symops(self):
    """The number of symmetry operations in the configuration's group."""
    symmetry = self.configuration.symmetry
    return symmetry.n_symops
@property
def asymmetric_symbols(self):
    """The element symbols of the symmetry-unique atoms.

    Returns
    -------
    [str]
        The element symbols, one per asymmetric atom.
    """
    atnos = self.asymmetric_atomic_numbers
    return elements.to_symbols(atnos)
@property
def symbols(self):
    """The element symbols of all the atoms, expanded by symmetry.

    Returns
    -------
    [str]
        The element symbols, one per atom.
    """
    if self.configuration.symmetry.n_symops != 1:
        return elements.to_symbols(self.atomic_numbers)
    return self.asymmetric_symbols
@property
def symmetry_matrices(self):
    """The 4x4 matrices of the symmetry operations."""
    symmetry = self.configuration.symmetry
    return symmetry.symmetry_matrices
@property
def symops(self):
    """The symmetry operators as shorthand strings."""
    symmetry = self.configuration.symmetry
    return symmetry.symops

@symops.setter
def symops(self, value):
    """Set the symmetry operators of the configuration's symmetry."""
    symmetry = self.configuration.symmetry
    symmetry.symops = value
@property
def symop_to_atom(self):
    """Lists of symmetry-operation numbers that create the symmetric atoms.

    There is one list per asymmetric atom; the value comes from the
    configuration's symmetry object.
    """
    symmetry = self.configuration.symmetry
    return symmetry.symop_to_atom
@property
def system_db(self):
    """The system database that holds these atoms."""
    database = self._system_db
    return database
@property
def velocities(self):
    """The velocities as a list of [vx, vy, vz] lists (get_velocities defaults)."""
    vxyz = self.get_velocities()
    return vxyz

@velocities.setter
def velocities(self, xyz):
    """Replace the velocities (see set_velocities)."""
    result = self.set_velocities(xyz)
    return result
def add_attribute(
    self,
    name: str,
    coltype: str = "float",
    default: Any = None,
    notnull: bool = False,
    index: bool = False,
    pk: bool = False,
    references: str = None,
    on_delete: str = "cascade",
    on_update: str = "cascade",
    values: Any = None,
    configuration_dependent: bool = False,
) -> None:
    """Add a new attribute (column) to the atoms.

    If the default value is None, you must always provide values wherever
    needed, for example when adding a row.

    Parameters
    ----------
    name : str
        The name of the attribute.
    coltype : str = "float"
        The type of the attribute (column). Must be one of 'int',
        'float', 'str' or 'byte'.
    default : int, float, str or byte
        The default value for the attribute if no value is given.
    notnull : bool = False
        Whether the value must be non-null.
    index : bool = False
        Whether to create an index on the column.
    pk : bool = False
        Whether the column is the primary key.
    references : str = None
        If given, the column is a foreign key referencing this table.
    on_delete : str = 'cascade'
        How to handle deletions of a foreign key.
    on_update : str = 'cascade'
        How to handle updates of a foreign key.
    values : Any
        Either a single value or a list of values, one per row, used to fill
        the column.
    configuration_dependent : bool = False
        If True the attribute is added to the coordinates table, whose rows
        are per configuration; otherwise it is added to the atom table.

    Returns
    -------
    None
    """
    if configuration_dependent:
        table = self._coordinates_table
    else:
        table = self._atom_table
    table.add_attribute(
        name,
        coltype=coltype,
        default=default,
        notnull=notnull,
        index=index,
        pk=pk,
        references=references,
        on_delete=on_delete,
        on_update=on_update,
        values=values,
    )
def append(self, **kwargs: Any) -> Any:
    """Append one or more atoms.

    The keyword names are attribute (column) names and the values are the
    data for the new atoms; the number of new atoms is derived from them.
    Elements may be given as symbols with the key "symbol" instead of atomic
    numbers ("atno"). If an existing attribute is not mentioned, its default
    value is used, unless the default is None, in which case an error is
    thrown.

    The new atoms are added to the atom table. Each also gets a coordinates
    row for this configuration, velocity and gradient rows if any of those
    attributes are given, and membership in this configuration's atomset.
    Keywords that match no attribute are ignored, with a warning.

    Returns
    -------
    The ids of the new atoms, as returned by the atom table's ``append``.
    """
    # Elements can be given as atomic numbers or symbols. By construction
    # the references to elements are identical to their atomic numbers.
    if "symbol" in kwargs:
        symbols = kwargs.pop("symbol")
        kwargs["atno"] = elements.to_atnos(symbols)

    # How many new rows there are
    n_rows, _ = self._get_n_rows(**kwargs)

    # Fill in the atom table
    data = {}
    for column in self._atom_table.attributes:
        if column != "id" and column in kwargs:
            data[column] = kwargs.pop(column)
    if len(data) == 0:
        # Need at least one column to insert the rows.
        data["atno"] = [None] * n_rows
    ids = self._atom_table.append(n=n_rows, **data)

    # Then the per-configuration tables. The coordinates always get rows;
    # velocities and gradients only when some of their columns were given.
    configuration = self.configuration.id
    for table, always in (
        (self._coordinates_table, True),
        (self._velocities_table, False),
        (self._gradients_table, False),
    ):
        data = {"configuration": configuration, "atom": ids}
        found = False
        for column in table.attributes:
            if column != "id" and column in kwargs:
                data[column] = kwargs.pop(column)
                found = True
        if always or found:
            table.append(n=n_rows, **data)

    if len(kwargs) > 0:
        logger.warning(
            "Atoms.append ignored unknown attribute(s): %s",
            ", ".join(str(key) for key in kwargs),
        )

    # And to the atomset
    table = _Table(self.system_db, "atomset_atom")
    table.append(atomset=self.atomset, atom=ids)

    self.configuration.symmetry.reset_atoms()

    return ids
def atoms(self, *args):
    """Return an iterator over the atoms.

    Parameters
    ----------
    args : [str]
        Added selection criteria for the SQL, as triplets of column,
        operator, and value, e.g. "atno", "==", 6. Columns may be qualified
        with a table alias ("at", "co", "ve" or "gr"), e.g. "at.atno".
        "==" is translated to the SQL "=". The operator is inserted verbatim
        into the SQL, so it must come from trusted code.

    Returns
    -------
    sqlite3.Cursor
        A cursor that returns sqlite3.Row objects for the atoms.

    Raises
    ------
    ValueError
        If the arguments are not complete triplets, a column is not known,
        or velocity/gradient columns are used when the atoms have none.
    """
    if len(args) % 3 != 0:
        raise ValueError(
            "Selection arguments must be triplets of column, operator, value"
        )

    have_velocities = self.have_velocities
    have_gradients = self.have_gradients
    columns = self._columns(velocities=have_velocities, gradients=have_gradients)
    column_defs = ", ".join(columns)

    # The selected columns include the velocities and gradients whenever
    # they exist, so those tables must be joined in that case, with or
    # without extra selection criteria.
    tables = {"at", "co"}
    if have_velocities:
        tables.add("ve")
    if have_gradients:
        tables.add("gr")

    # Check the selection criteria and turn them into SQL conditions with
    # table-qualified column names.
    conditions = []
    if len(args) > 0:
        atom_columns = set(self._atom_table.attributes)
        coordinates_columns = set(self._coordinates_table.attributes)
        velocities_columns = set(self._velocities_table.attributes)
        gradients_columns = set(self._gradients_table.attributes)
        for col, op, value in grouped(args, 3):
            if "." in col:
                alias, column = col.split(".", 1)
            elif col in atom_columns:
                alias, column = "at", col
            elif col in coordinates_columns:
                alias, column = "co", col
            elif col in velocities_columns:
                if not have_velocities:
                    raise ValueError(
                        "Query for atom has velocities, but the atoms don't"
                    )
                alias, column = "ve", col
            elif col in gradients_columns:
                if not have_gradients:
                    raise ValueError(
                        "Query for atom has gradients, but the atoms don't"
                    )
                alias, column = "gr", col
            else:
                raise ValueError(f"Column '{col}' is not available")
            tables.add(alias)
            # Quote only the column name: '"at.atno"' would be a single
            # identifier, which SQLite silently treats as a string literal.
            quoted = column.replace('"', '""')
            if op == "==":
                op = "="
            conditions.append((f'{alias}."{quoted}"', op, value))

    # Build the query based on the tables needed
    from_string = ["atomset_atom AS aa", "atom AS at", "coordinates AS co"]
    if "ve" in tables:
        from_string.append("velocities AS ve")
    if "gr" in tables:
        from_string.append("gradients AS gr")
    from_string = ", ".join(from_string)

    sql = (
        f"SELECT {column_defs}\n"
        f" FROM {from_string}\n"
        "WHERE aa.atomset = ?\n"
        " AND at.id = aa.atom\n"
        " AND co.atom = at.id AND co.configuration = ?"
    )
    parameters = [self.atomset, self.configuration.id]
    if "ve" in tables:
        sql += "\n AND ve.atom = at.id AND ve.configuration = ?"
        parameters.append(self.configuration.id)
    if "gr" in tables:
        sql += "\n AND gr.atom = at.id AND gr.configuration = ?"
        parameters.append(self.configuration.id)

    # And any extra selection criteria
    for condition, op, value in conditions:
        sql += f"\n AND {condition} {op} ?"
        parameters.append(value)

    logger.debug("atoms query:")
    logger.debug(sql)
    logger.debug(parameters)
    logger.debug("---")

    return self.db.execute(sql, parameters)
def diff(self, other):
    """Difference between these atoms and another set of atoms.

    Velocities and gradients are currently ignored; it is not yet decided
    how they should be compared.

    Parameters
    ----------
    other : _Atoms
        The other atoms to diff against.

    Returns
    -------
    result : Dict
        The differences, described in a dictionary. Keys are present only
        when non-empty: "columns added", "columns deleted", "changed",
        "added", "columns in added rows", "deleted", and
        "columns in deleted rows".
    """
    result = {}

    # Check the columns
    columns = self._columns(velocities=False, gradients=False)
    other_columns = other._columns(velocities=False, gradients=False)
    column_defs = ", ".join(columns)
    other_column_defs = ", ".join(other_columns)
    if columns == other_columns:
        column_def = column_defs
    else:
        added = columns - other_columns
        if len(added) > 0:
            result["columns added"] = list(added)
        deleted = other_columns - columns
        if len(deleted) > 0:
            result["columns deleted"] = list(deleted)
        in_common = other_columns & columns
        if len(in_common) > 0:
            column_def = ", ".join(in_common)
        else:
            # No columns shared, so the rows cannot be compared.
            return result

    # Need to check the contents of the tables. See if they are in the same
    # database or if we need to attach the other database temporarily.
    db = self.system_db
    other_db = other.system_db
    detach = False
    schema = self.schema
    if db.filename != other_db.filename:
        if db.is_attached(other_db):
            other_schema = db.attached_as(other_db)
        else:
            # Attach the other system_db in order to do comparisons.
            other_schema = self.system_db.attach(other_db)
            detach = True
    else:
        other_schema = other.schema

    # NOTE(review): the queries below join the coordinates rows of every
    # configuration, not only this configuration's; confirm that is intended.
    try:
        atomset = self.atomset
        other_atomset = other.atomset

        # Rows present on only one side, ordered so that the two versions of
        # a changed atom are adjacent and can be paired by id.
        changed = {}
        last = None
        sql = f"""
        SELECT * FROM
        (
            SELECT {column_def}
              FROM {other_schema}.atom AS at, {other_schema}.coordinates AS co
             WHERE co.atom = at.id
               AND at.id
                IN (
                    SELECT atom
                      FROM {other_schema}.atomset_atom
                     WHERE atomset = {other_atomset}
                   )
            EXCEPT
            SELECT {column_def}
              FROM {schema}.atom AS at, {schema}.coordinates AS co
             WHERE co.atom = at.id
               AND at.id
                IN (
                    SELECT atom
                      FROM {schema}.atomset_atom
                     WHERE atomset = {atomset}
                   )
        )
        UNION ALL
        SELECT * FROM
        (
            SELECT {column_def}
              FROM {schema}.atom AS at, {schema}.coordinates AS co
             WHERE co.atom = at.id
               AND at.id
                IN (
                    SELECT atom
                      FROM {schema}.atomset_atom
                     WHERE atomset = {atomset}
                   )
            EXCEPT
            SELECT {column_def}
              FROM {other_schema}.atom AS at, {other_schema}.coordinates AS co
             WHERE co.atom = at.id
               AND at.id
                IN (
                    SELECT atom
                      FROM {other_schema}.atomset_atom
                     WHERE atomset = {other_atomset}
                   )
        )
        ORDER BY id
        """
        for row in self.db.execute(sql):
            if last is None:
                last = row
            elif row["id"] == last["id"]:
                changes = set()
                for k1, v1, v2 in zip(last.keys(), last, row):
                    if v1 != v2:
                        changes.add((k1, v1, v2))
                changed[row["id"]] = changes
                last = None
            else:
                last = row
        if len(changed) > 0:
            result["changed"] = changed

        # See about the rows added
        added = {}
        sql = f"""
        SELECT {column_defs}
          FROM {schema}.atom AS at, {schema}.coordinates AS co
         WHERE co.atom = at.id
           AND at.id
            IN (
                SELECT atom
                  FROM {schema}.atomset_atom
                 WHERE atomset = {atomset}
               )
           AND at.id
        NOT IN (
                SELECT atom
                  FROM {other_schema}.atomset_atom
                 WHERE atomset = {other_atomset}
               )
        """
        for row in self.db.execute(sql):
            added[row["id"]] = row[1:]
        if len(added) > 0:
            result["columns in added rows"] = row.keys()[1:]
            result["added"] = added

        # See about the rows deleted
        deleted = {}
        sql = f"""
        SELECT {other_column_defs}
          FROM {other_schema}.atom AS at, {other_schema}.coordinates AS co
         WHERE co.atom = at.id
           AND at.id
            IN (
                SELECT atom
                  FROM {other_schema}.atomset_atom
                 WHERE atomset = {other_atomset}
               )
           AND at.id
        NOT IN (
                SELECT atom
                  FROM {schema}.atomset_atom
                 WHERE atomset = {atomset}
               )
        """
        for row in self.db.execute(sql):
            deleted[row["id"]] = row[1:]
        if len(deleted) > 0:
            result["columns in deleted rows"] = row.keys()[1:]
            result["deleted"] = deleted
    finally:
        # Detach the other database if we attached it, even if a query failed.
        if detach:
            self.system_db.detach(other_db)

    return result
def get_as_dict(self, *args, asymmetric=False):
    """Return the atom data as a Python dictionary of lists.

    Parameters
    ----------
    args : [str]
        Added selection criteria for the SQL, one word at a time (see
        ``atoms``). They are only honored when *asymmetric* is True.
    asymmetric : bool = False
        If True, return the database rows of the symmetry-unique atoms,
        one list per selected column. Otherwise expand to all the atoms: the
        properties, names and coordinates are expanded by symmetry, and "id"
        is the index of each atom (0, 1, 2, ...).

    Returns
    -------
    dict(str: [])
        A dictionary whose keys are the column names and values as lists
    """
    if asymmetric:
        rows = self.atoms(*args)
        columns = [x[0] for x in rows.description]
        data = {key: [] for key in columns}
        for row in rows:
            for key, value in zip(columns, row):
                data[key].append(value)
        return data

    if len(args) > 0:
        # The symmetry-expanded path cannot apply SQL selection criteria.
        logger.warning(
            "get_as_dict: selection criteria are ignored unless asymmetric=True"
        )

    data = {}
    properties = list(self.attributes)
    # Coordinates, velocities, ids and names are handled specially (or not
    # at all, for the velocities) below.
    skipped = ("id", "x", "y", "z", "vx", "vy", "vz", "name")
    for name in properties:
        if name in skipped:
            continue
        data[name] = self.get_property(name)
    if "name" in properties:
        data["name"] = self.get_names()
    xyz = self.get_coordinates()
    data["x"] = [v[0] for v in xyz]
    data["y"] = [v[1] for v in xyz]
    data["z"] = [v[2] for v in xyz]
    data["id"] = list(range(len(xyz)))
    return data
def get_ids(self, *args):
    """The ids of the selected atoms.

    Any extra arguments are triplets of column, operator, and value that
    add selection criteria. Columns may be given plainly ("atno") or
    qualified with the table alias, which is 'at' for the atom table and
    'co' for the coordinates table ("at.atno").

    For example, with the three arguments "at.atno", "=", "6" only the ids
    of the carbon atoms in the configuration are returned.

    Parameters
    ----------
    args : [str]
        Further selection arguments, in sets of three:
        column, operator, and value, e.g. "at.atno", "=", 6. "==" is
        translated to "=". The operator is inserted verbatim into the SQL,
        so it must come from trusted code.

    Returns
    -------
    [int]
        The ids of the requested atoms.

    Raises
    ------
    ValueError
        If the arguments are not complete triplets or a column is unknown.
    """
    if len(args) % 3 != 0:
        raise ValueError(
            "Selection arguments must be triplets of column, operator, value"
        )

    # Which tables the criteria need, and the criteria as SQL conditions
    # with table-qualified column names.
    tables = set()
    conditions = []
    if len(args) > 0:
        atom_columns = set(self._atom_table.attributes)
        coordinates_columns = set(self._coordinates_table.attributes)
        for col, op, value in grouped(args, 3):
            if "." in col:
                alias, column = col.split(".", 1)
            elif col in atom_columns:
                alias, column = "at", col
            elif col in coordinates_columns:
                alias, column = "co", col
            else:
                raise ValueError(f"Column '{col}' is not available")
            tables.add(alias)
            # Quote only the column name: '"at.atno"' would be a single
            # identifier, which SQLite silently treats as a string literal.
            quoted = column.replace('"', '""')
            if op == "==":
                op = "="
            conditions.append((f'{alias}."{quoted}"', op, value))

    # Build the query based on the tables needed. The coordinates are joined
    # through the atom table, so it is needed for either.
    need_atom = "at" in tables or "co" in tables
    sql = "SELECT aa.atom FROM atomset_atom AS aa"
    if need_atom:
        sql += ", atom AS at"
    if "co" in tables:
        sql += ", coordinates AS co"

    # The WHERE clauses joining the tables
    sql += " WHERE aa.atomset = ?"
    parameters = [self.atomset]
    if need_atom:
        sql += " AND at.id = aa.atom"
    if "co" in tables:
        sql += " AND co.atom = at.id AND co.configuration = ?"
        parameters.append(self.configuration.id)

    # And any extra selection criteria
    for condition, op, value in conditions:
        sql += f" AND {condition} {op} ?"
        parameters.append(value)

    logger.debug("get_id query:")
    logger.debug(sql)
    logger.debug(parameters)
    logger.debug("---")

    return [x[0] for x in self.db.execute(sql, parameters)]
def get_coordinates(
    self,
    fractionals=True,
    in_cell=False,
    as_array=False,
    asymmetric=False,
):
    """Return the coordinates optionally translated back into the principal
    unit cell.

    Parameters
    ----------
    fractionals : bool = True
        Return the coordinates as fractional coordinates for periodic
        systems. Non-periodic systems always use Cartesian coordinates.
    in_cell : bool, str = False
        Whether to translate the atoms into the unit cell, and if so whether
        to do so by molecule (a string containing "molecule"; used only when
        there are bonds) or just by atom (any other true value).
    as_array : bool = False
        Whether to return the results as a np array or as a list of
        lists (the default).
    asymmetric : bool = False
        If true, return coordinates for only the symmetry-unique atoms.
        By default, expand to all atoms in the system.

    Returns
    -------
    abc : [N][float*3]
        The coordinates, either Cartesian or fractional

    Raises
    ------
    NotImplementedError
        For a non-periodic system with point-group symmetry, unless only the
        asymmetric atoms are requested, because expanding point-group
        symmetry is not implemented.
    """
    # The stored coordinates are those of the symmetry-unique atoms.
    xyz = [[row["x"], row["y"], row["z"]] for row in self.atoms()]
    symmetry = self.configuration.symmetry

    if self.configuration.periodicity == 0:
        # The stored coordinates can be returned when the asymmetric atoms are
        # requested. Expanding to all atoms is not implemented for point
        # groups; the check was previously inverted, silently returning too
        # few coordinates.
        if not asymmetric and symmetry.n_symops > 1:
            raise NotImplementedError("Point-group symmetry not handled yet.")
        return np.array(xyz) if as_array else xyz

    cell = self.configuration.cell
    coordinate_system = self.configuration.coordinate_system

    def _stored_as_fractionals():
        """The stored coordinates as a new np array of fractionals."""
        if coordinate_system == "Cartesian":
            return cell.to_fractionals(xyz, as_array=True)
        return np.array(xyz)

    def _finish(UVW):
        """Return fractional array UVW in the requested system and form."""
        if fractionals:
            return UVW if as_array else UVW.tolist()
        return cell.to_cartesians(UVW, as_array=as_array)

    UVW = None
    if not asymmetric and symmetry.n_symops > 1:
        # Asymmetric fractional coordinates, moved into the unit cell
        UVW_asym = _stored_as_fractionals()
        UVW_asym -= np.floor(UVW_asym)

        # Apply the generating symmetry operators (4x4 matrices acting on
        # homogeneous coordinates) to each asymmetric atom.
        op = symmetry.symmetry_matrices
        pieces = []
        for i, indices in enumerate(symmetry.atom_generators):
            uvw4 = np.array([0.0, 0.0, 0.0, 1.0])
            uvw4[0:3] = UVW_asym[i]
            xformed = np.einsum("ijk,k", op[indices, :, :], uvw4)
            pieces.append(xformed[:, 0:3])
        if len(pieces) > 0:
            UVW = np.concatenate(pieces)

    if (
        isinstance(in_cell, str)
        and "molecule" in in_cell
        and self.configuration.n_bonds > 0
    ):
        # Shift whole molecules so that each center lies in the cell.
        if UVW is None:
            UVW = _stored_as_fractionals()
        molecules = self.configuration.find_molecules(as_indices=True)
        for indices in molecules:
            indices = np.array(indices)
            uvw_mol = np.take(UVW, indices, axis=0)
            center = np.average(uvw_mol, axis=0)
            uvw_mol -= np.floor(center)
            np.put_along_axis(UVW, np.expand_dims(indices, axis=1), uvw_mol, axis=0)
        return _finish(UVW)

    if in_cell:
        # Shift each atom into the cell independently.
        if UVW is None:
            UVW = _stored_as_fractionals()
        UVW -= np.floor(UVW)
        return _finish(UVW)

    if UVW is not None:
        return _finish(UVW)

    # No symmetry expansion or translation: convert the stored coordinates
    # only if they are not already in the requested system.
    if fractionals:
        if coordinate_system == "Cartesian":
            return cell.to_fractionals(xyz, as_array=as_array)
    elif coordinate_system == "fractional":
        return cell.to_cartesians(xyz, as_array=as_array)
    return np.array(xyz) if as_array else xyz
def get_names(self, asymmetric=False):
    """Return the names of the atoms, generating defaults when none are stored.

    Default names are the element symbol followed by a running count per
    element (e.g. C1, C2, H1). When expanding by symmetry, each image of an
    asymmetric atom gets a further suffix counting its images (e.g. C1_1, C1_2).

    Parameters
    ----------
    asymmetric : bool = False
        Return only the names of the symmetry-unique atoms.

    Returns
    -------
    [str]
        The names of the atoms.
    """
    if "name" in self:
        names = self.get_column_data("name")
    else:
        per_element = {}
        names = []
        for symbol in self.asymmetric_symbols:
            per_element[symbol] = per_element.get(symbol, 0) + 1
            names.append(f"{symbol}{per_element[symbol]}")

    symmetry = self.configuration.symmetry
    if asymmetric or symmetry.n_symops == 1:
        return names

    # Expand to all the atoms
    images = dict.fromkeys(range(len(names)), 0)
    expanded = []
    for asym in symmetry.atom_to_asymmetric_atom:
        images[asym] += 1
        expanded.append(f"{names[asym]}_{images[asym]}")
    return expanded
def get_n_atoms(self, *args):
    """Return the number of atoms meeting the criteria.

    Parameters
    ----------
    args : [str]
        Added selection criteria for the SQL, as triplets of column,
        operator, and value. Columns may be plain ("atno") or qualified with
        the table alias 'at' or 'co' ("at.atno"). "==" is translated to "=".
        The operator is inserted verbatim into the SQL, so it must come from
        trusted code.

    Returns
    -------
    int
        The number of atoms matching the criteria.

    Raises
    ------
    ValueError
        If the arguments are not complete triplets or a column is unknown.
    """
    if len(args) % 3 != 0:
        raise ValueError(
            "Selection arguments must be triplets of column, operator, value"
        )

    # Which tables the criteria need, and the criteria as SQL conditions
    # with table-qualified column names.
    tables = set()
    conditions = []
    if len(args) > 0:
        atom_columns = set(self._atom_table.attributes)
        coordinates_columns = set(self._coordinates_table.attributes)
        for col, op, value in grouped(args, 3):
            if "." in col:
                alias, column = col.split(".", 1)
            elif col in atom_columns:
                alias, column = "at", col
            elif col in coordinates_columns:
                alias, column = "co", col
            else:
                raise ValueError(f"Column '{col}' is not available")
            tables.add(alias)
            # Quote only the column name: '"at.atno"' would be a single
            # identifier, which SQLite silently treats as a string literal.
            quoted = column.replace('"', '""')
            if op == "==":
                op = "="
            conditions.append((f'{alias}."{quoted}"', op, value))

    # Build the query based on the tables needed. The coordinates are joined
    # through the atom table, so it is needed for either.
    need_atom = "at" in tables or "co" in tables
    sql = "SELECT COUNT(*) FROM atomset_atom AS aa"
    if need_atom:
        sql += ", atom AS at"
    if "co" in tables:
        sql += ", coordinates AS co"

    # The WHERE clauses joining the tables
    sql += " WHERE aa.atomset = ?"
    parameters = [self.atomset]
    if need_atom:
        sql += " AND at.id = aa.atom"
    if "co" in tables:
        sql += " AND co.atom = at.id AND co.configuration = ?"
        parameters.append(self.configuration.id)

    # And any extra selection criteria
    for condition, op, value in conditions:
        sql += f" AND {condition} {op} ?"
        parameters.append(value)

    logger.debug("get_n_atoms query:")
    logger.debug(sql)
    logger.debug(parameters)
    logger.debug("---")

    self.cursor.execute(sql, parameters)
    return self.cursor.fetchone()[0]
def get_property(self, name, asymmetric=False):
    """Return a property (attribute) of the atoms, expanded by symmetry.

    The "name" property is delegated to get_names, so that default names
    are generated when none are stored.

    Parameters
    ----------
    name : str
        The property (attribute) to return.
    asymmetric : bool = False
        Return the values for just the symmetry-unique atoms.

    Returns
    -------
    [any]
        The values of the property, one per atom (or per asymmetric atom).
    """
    if name == "name":
        return self.get_names(asymmetric=asymmetric)
    values = self.get_column_data(name)
    symmetry = self.configuration.symmetry
    if not asymmetric and symmetry.n_symops != 1:
        # Each atom takes the value of the asymmetric atom it came from.
        values = [values[i] for i in symmetry.atom_to_asymmetric_atom]
    return values
def get_gradients(
    self,
    fractionals=True,
    as_array=False,
):
    """Return the gradients of the atoms.

    Symmetry is not supported, since it makes little sense for gradients.

    Parameters
    ----------
    fractionals : bool = True
        For periodic systems, return the gradients in fractional rather than
        Cartesian form. Non-periodic systems always use Cartesian gradients.
    as_array : bool = False
        Return a np array instead of a list of lists (the default).

    Returns
    -------
    abc : [N][float*3]
        The gradients, either Cartesian or fractional.
    """
    gradients = [
        list(g)
        for g in zip(
            self.get_column_data("gx"),
            self.get_column_data("gy"),
            self.get_column_data("gz"),
        )
    ]
    configuration = self.configuration
    if configuration.periodicity != 0:
        cell = configuration.cell
        coordinate_system = configuration.coordinate_system
        # Convert only when the stored system differs from the requested one.
        if fractionals and coordinate_system == "Cartesian":
            return cell.to_fractionals(gradients, as_array=as_array)
        if not fractionals and coordinate_system == "fractional":
            return cell.to_cartesians(gradients, as_array=as_array)
    return np.array(gradients) if as_array else gradients
def get_velocities(
    self,
    fractionals=True,
    as_array=False,
):
    """Return the velocities of the atoms.

    Symmetry is not supported, since it makes little sense for velocities.

    Parameters
    ----------
    fractionals : bool = True
        For periodic systems, return the velocities in fractional rather than
        Cartesian form. Non-periodic systems always use Cartesian velocities.
    as_array : bool = False
        Return a np array instead of a list of lists (the default).

    Returns
    -------
    abc : [N][float*3]
        The velocities, either Cartesian or fractional.
    """
    velocities = [
        list(v)
        for v in zip(
            self.get_column_data("vx"),
            self.get_column_data("vy"),
            self.get_column_data("vz"),
        )
    ]
    configuration = self.configuration
    if configuration.periodicity != 0:
        cell = configuration.cell
        coordinate_system = configuration.coordinate_system
        # Convert only when the stored system differs from the requested one.
        if fractionals and coordinate_system == "Cartesian":
            return cell.to_fractionals(velocities, as_array=as_array)
        if not fractionals and coordinate_system == "fractional":
            return cell.to_cartesians(velocities, as_array=as_array)
    return np.array(velocities) if as_array else velocities
def set_coordinates(self, xyz, fractionals=True):
    """Set the coordinates to new values.

    Parameters
    ----------
    xyz : [N][float*3] or np.ndarray
        The new coordinates, either one per symmetry-unique atom or one per
        atom. In the latter case they are first symmetrized with the
        configuration's symmetry, and the result is stored.
    fractionals : bool = True
        The coordinates are fractional coordinates for periodic
        systems. Ignored for non-periodic systems.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the number of coordinates matches neither the number of
        asymmetric atoms nor the total number of atoms.
    """
    n_coords = xyz.shape[0] if isinstance(xyz, np.ndarray) else len(xyz)

    # May need to handle symmetry
    n_asymmetric = self.n_asymmetric_atoms
    if n_coords != n_asymmetric:
        n_atoms = self.n_atoms
        if n_coords != n_atoms:
            raise ValueError(
                f"Got {n_coords} coordinates, but there are {n_asymmetric} "
                f"asymmetric atoms and {n_atoms} atoms in total."
            )
        # The second value returned by symmetrize_coordinates is not used.
        xyz, _ = self.configuration.symmetry.symmetrize_coordinates(
            xyz, fractionals=fractionals
        )

    # Decide how to iterate only now: symmetrize_coordinates may return a
    # list or an array whatever form was passed in.
    if isinstance(xyz, np.ndarray):
        xyz = xyz.tolist()

    x_column = self.get_column("x")
    y_column = self.get_column("y")
    z_column = self.get_column("z")

    configuration = self.configuration
    periodicity = configuration.periodicity
    coordinate_system = configuration.coordinate_system
    if (
        periodicity == 0
        or (coordinate_system == "Cartesian" and not fractionals)
        or (coordinate_system == "fractional" and fractionals)
    ):
        # No conversion needed
        stored = xyz
    elif coordinate_system == "fractional":
        # Cartesian coordinates given, but fractionals are stored
        stored = configuration.cell.to_fractionals(xyz)
    else:
        # Fractional coordinates given, but Cartesians are stored
        stored = configuration.cell.to_cartesians(xyz)

    xs = []
    ys = []
    zs = []
    for x, y, z in stored:
        xs.append(x)
        ys.append(y)
        zs.append(z)

    x_column[0:] = xs
    y_column[0:] = ys
    z_column[0:] = zs
def set_gradients(self, gxyz, fractionals=False):
    """Set the gradients of the atoms to new values.

    Parameters
    ----------
    gxyz : [[float, float, float]] or numpy.ndarray
        The new gradients, one (gx, gy, gz) triple per atom.
    fractionals : bool = False
        The gradients are fractional gradients for periodic
        systems. Ignored for non-periodic systems.

    Returns
    -------
    None

    Raises
    ------
    RuntimeError
        If the configuration has more than one symmetry operation.
    """
    if self.n_symops > 1:
        raise RuntimeError("Can't handle gradients with symmetry.")

    periodicity = self.configuration.periodicity
    coordinate_system = self.configuration.coordinate_system
    if (
        periodicity == 0
        or (coordinate_system == "Cartesian" and not fractionals)
        or (coordinate_system == "fractional" and fractionals)
    ):
        # Already in the stored coordinate system; no conversion needed.
        rows = gxyz.tolist() if isinstance(gxyz, np.ndarray) else gxyz
    elif coordinate_system == "fractional":
        # NOTE(review): by the chain rule, gradients transform with the
        # transpose of the cell matrix, not its inverse. Reusing the
        # coordinate conversions here may be incorrect for general cells.
        # Confirm the intended meaning of "fractional gradients".
        rows = self.configuration.cell.to_fractionals(gxyz)
    else:
        rows = self.configuration.cell.to_cartesians(gxyz)

    gxs = []
    gys = []
    gzs = []
    for gx, gy, gz in rows:
        gxs.append(gx)
        gys.append(gy)
        gzs.append(gz)

    gx_column = self.get_column("gx")
    if len(gx_column) == 0:
        # No gradients stored for this configuration yet, so insert new
        # rows rather than setting existing ones.
        self._gradients_table.append(
            n=len(gxs),
            gx=gxs,
            gy=gys,
            gz=gzs,
            atom=self.ids,
            configuration=self.configuration.id,
        )
    else:
        gy_column = self.get_column("gy")
        gz_column = self.get_column("gz")
        gx_column[0:] = gxs
        gy_column[0:] = gys
        gz_column[0:] = gzs
def set_velocities(self, vxyz, fractionals=False):
    """Set the velocities of the atoms to new values.

    Parameters
    ----------
    vxyz : [[float, float, float]] or numpy.ndarray
        The new velocities, one (vx, vy, vz) triple per atom.
    fractionals : bool = False
        The velocities are fractional velocities for periodic
        systems. Ignored for non-periodic systems.

    Returns
    -------
    None

    Raises
    ------
    RuntimeError
        If the configuration has more than one symmetry operation.
    """
    if self.n_symops > 1:
        raise RuntimeError("Can't handle velocities with symmetry.")

    periodicity = self.configuration.periodicity
    coordinate_system = self.configuration.coordinate_system
    if (
        periodicity == 0
        or (coordinate_system == "Cartesian" and not fractionals)
        or (coordinate_system == "fractional" and fractionals)
    ):
        # Already in the stored coordinate system; no conversion needed.
        rows = vxyz.tolist() if isinstance(vxyz, np.ndarray) else vxyz
    elif coordinate_system == "fractional":
        # NOTE(review): this relies on cell.to_fractionals being a purely
        # linear map (no translation or wrapping into the cell). Confirm
        # that holds when applied to velocities.
        rows = self.configuration.cell.to_fractionals(vxyz)
    else:
        rows = self.configuration.cell.to_cartesians(vxyz)

    vxs = []
    vys = []
    vzs = []
    for vx, vy, vz in rows:
        vxs.append(vx)
        vys.append(vy)
        vzs.append(vz)

    vx_column = self.get_column("vx")
    if len(vx_column) == 0:
        # No velocities stored for this configuration yet, so insert new
        # rows rather than setting existing ones.
        self._velocities_table.append(
            n=len(vxs),
            vx=vxs,
            vy=vys,
            vz=vzs,
            atom=self.ids,
            configuration=self.configuration.id,
        )
    else:
        vy_column = self.get_column("vy")
        vz_column = self.get_column("vz")
        vx_column[0:] = vxs
        vy_column[0:] = vys
        vz_column[0:] = vzs
def get_column(self, key: str) -> Any:
    """Return a Column object bound to one attribute of these atoms.

    The atom, coordinates, velocities and gradients tables are searched
    in that order and the first table defining the attribute is used.

    Parameters
    ----------
    key : str
        Name of the attribute.

    Returns
    -------
    Column
        A Column object for the attribute, restricted to the atoms of this
        atomset and, for the per-configuration tables, to this
        configuration.

    Raises
    ------
    KeyError
        If no table defines the attribute.
    """
    if key in self._atom_table.attributes:
        query = (
            f'SELECT at.rowid, at."{key}"'
            " FROM atom as at, atomset_atom as aa"
            f" WHERE at.id = aa.atom AND aa.atomset = {self.atomset}"
        )
        return _Column(self._atom_table, key, sql=query)

    # These three tables share the same query shape and differ only in
    # table name and alias.
    per_configuration = (
        (self._coordinates_table, "coordinates", "co"),
        (self._velocities_table, "velocities", "ve"),
        (self._gradients_table, "gradients", "gr"),
    )
    for table, table_name, alias in per_configuration:
        if key not in table.attributes:
            continue
        query = (
            f'SELECT {alias}.rowid, {alias}."{key}"'
            " FROM atom as at,"
            f" {table_name} as {alias},"
            " atomset_atom as aa"
            f" WHERE {alias}.atom = at.id"
            f" AND {alias}.configuration = {self.configuration.id}"
            " AND at.id = aa.atom"
            f" AND aa.atomset = {self.atomset}"
        )
        return _Column(table, key, sql=query)

    raise KeyError(f"'{key}' not in atoms")
def get_column_data(self, key: str) -> Any:
    """Return the values of one attribute for these atoms.

    The atom, coordinates, velocities and gradients tables are searched
    in that order and the first table defining the attribute is used.

    Parameters
    ----------
    key : str
        The attribute to get.

    Returns
    -------
    list
        The values of the attribute for the atoms of this atomset and, for
        the per-configuration tables, this configuration. The query has no
        ORDER BY, so the order is whatever the database returns.

    Raises
    ------
    KeyError
        If no table defines the attribute.
    """
    if key in self._atom_table.attributes:
        sql = (
            f'SELECT at."{key}"'
            " FROM atom as at, atomset_atom as aa"
            f" WHERE at.id = aa.atom AND aa.atomset = {self.atomset}"
        )
        return [row[0] for row in self.db.execute(sql)]

    # These three tables share the same query shape and differ only in
    # table name and alias.
    per_configuration = (
        (self._coordinates_table, "coordinates", "co"),
        (self._velocities_table, "velocities", "ve"),
        (self._gradients_table, "gradients", "gr"),
    )
    for table, table_name, alias in per_configuration:
        if key in table.attributes:
            sql = (
                f'SELECT {alias}."{key}"'
                " FROM atom as at,"
                f" {table_name} as {alias},"
                " atomset_atom as aa"
                f" WHERE {alias}.atom = at.id"
                f" AND {alias}.configuration = {self.configuration.id}"
                " AND at.id = aa.atom"
                f" AND aa.atomset = {self.atomset}"
            )
            return [row[0] for row in self.db.execute(sql)]

    raise KeyError(f"'{key}' not in atoms")
def _get_n_rows(self, **kwargs):
    """Work out how many rows the keyword values describe.

    Each value may be a scalar (length 0), a sequence of length 1, or a
    sequence whose length is the number of rows. Scalars and length-1
    sequences are compatible with any row count.

    Returns
    -------
    (int or None, {str: int})
        The number of rows (None when no keywords are given) and the
        length of each keyword's value as given by ``length_of_values``.

    Raises
    ------
    KeyError
        If a keyword is not an attribute of the atoms.
    IndexError
        If two values have conflicting lengths greater than one.
    """
    lengths = {}
    n_rows = None
    for attribute, values in kwargs.items():
        if attribute not in self:
            raise KeyError(f'"{attribute}" is not an attribute of the atoms.')
        size = self.length_of_values(values)
        lengths[attribute] = size
        if n_rows is None:
            n_rows = size if size != 0 else 1
        # Scalars, single values and matching lengths need no adjustment.
        if size <= 1 or size == n_rows:
            continue
        if n_rows != 1:
            raise IndexError(
                f'key "{attribute}" has the wrong number of values, '
                f"{size}. Should be 1 or the number of atoms ({n_rows})."
            )
        # Everything seen so far was a single value, so this length wins.
        n_rows = size
    return n_rows, lengths
def length_of_values(self, values: Any) -> int:
    """Return how many values the argument holds.

    Parameters
    ----------
    values : Any
        A single value, a string, or a sized collection such as a list,
        tuple or array.

    Returns
    -------
    length : int
        ``len(values)`` for sized collections. 0 for strings and for
        objects without a length, meaning "a scalar".
    """
    # A string is a sequence of characters, but here it is one value.
    if isinstance(values, str):
        return 0
    try:
        size = len(values)
    except TypeError:
        # No __len__ (numbers, None, numpy scalars, ...): a scalar.
        return 0
    return size
def delete(self, atoms) -> None:
    """Delete the listed atoms.

    Only atoms that belong to this atomset are deleted. Ids that are not
    in the atomset are ignored. The original design relies on the
    deletion from the atom table cascading to coordinates, velocities and
    gradients.

    Parameters
    ----------
    atoms : [int] or str
        The ids of the atoms to delete (any iterable of integers,
        including numpy integer arrays), or 'all' or '*' to delete every
        atom in this atomset.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``atoms`` is a string other than 'all' or '*'.
    """
    # Test the type first: comparing a numpy array to a string is
    # elementwise and cannot be used directly in an `if`.
    if isinstance(atoms, str):
        if atoms not in ("all", "*"):
            raise ValueError(
                f"Invalid atoms specification '{atoms}': "
                "expected a list of ids, 'all' or '*'."
            )
        sql = """
            DELETE FROM atom
             WHERE id IN (SELECT atom FROM atomset_atom WHERE atomset = ?)
        """
        self.db.execute(sql, (self.atomset,))
    else:
        sql = """
            DELETE FROM atom
             WHERE id = ?
               AND id IN (SELECT atom FROM atomset_atom WHERE atomset = ?)
        """
        # int() converts e.g. numpy integers, which sqlite3 cannot bind.
        parameters = [(int(i), self.atomset) for i in atoms]
        self.db.executemany(sql, parameters)
    self.configuration.symmetry.reset_atoms()
def to_dataframe(self):
    """Return the atoms as a pandas DataFrame.

    The first column of the rows returned by ``self.atoms()`` becomes the
    index. The remaining columns become the DataFrame's columns, named
    after the cursor description.
    """
    cursor = self.atoms()
    names = [description[0] for description in cursor.description[1:]]
    by_index = {row[0]: row[1:] for row in cursor}
    return pandas.DataFrame.from_dict(by_index, orient="index", columns=names)
def _columns(self, velocities=True, gradients=True):
    """Build the qualified column list for a SELECT over the atom tables.

    Each entry is an SQL expression such as ``at."atno"``, using the
    aliases 'at' (atom), 'co' (coordinates), 've' (velocities) and
    'gr' (gradients). The 'atom' column of the coordinates, velocities and
    gradients tables is left out, since in the joined queries it equals
    the atom id.

    Parameters
    ----------
    velocities : bool = True
        Include the velocities columns.
    gradients : bool = True
        Include the gradients columns.

    Returns
    -------
    [str]
        The qualified column expressions.
    """
    # (alias, table, whether to drop the 'atom' link column)
    sources = [
        ("at", self._atom_table, False),
        ("co", self._coordinates_table, True),
    ]
    if velocities:
        sources.append(("ve", self._velocities_table, True))
    if gradients:
        sources.append(("gr", self._gradients_table, True))

    result = []
    for alias, table, drop_atom in sources:
        names = [*table.attributes]
        if drop_atom:
            names.remove("atom")
        result.extend(f'{alias}."{name}"' for name in names)
    return result
class _SubsetAtoms(_Atoms):
    """The atoms in a subset.

    Parameters
    ----------
    configuration : _Configuration
        The configuration of interest.
    subset_id : int
        The id of the subset.

    Attributes
    ----------
    template_order : bool
        Whether to return atoms and properties in the order of the template
        atoms if the template is full. Set to True by the constructor.
    """

    def __init__(self, configuration, subset_id) -> None:
        self._sid = subset_id
        self.template_order = True

        # Cached template id and template object, looked up on first use.
        self._template_id = None
        self._template = None

        super().__init__(configuration)

    def __eq__(self, other) -> Any:
        """Return a boolean if this object is equal to another"""
        raise NotImplementedError("Not implemented for subsets, yet!")

    @property
    def subset_id(self):
        """The id of the subset for these atoms."""
        return self._sid

    @property
    def atomic_numbers(self):
        """The atomic numbers of the subset atoms.

        Note that subsets refer to asymmetric atoms!

        Returns
        -------
        [int]
            The atomic numbers.
        """
        return self.get_column_data("atno")

    @property
    def atomic_masses(self):
        """The atomic masses of the atoms in the subset.

        Note that subsets refer to asymmetric atoms!

        The stored 'mass' attribute is used when the atoms have one.
        Otherwise the masses are looked up from the atomic numbers.

        Returns
        -------
        list
            The atomic masses, one per atom.
        """
        if "mass" in self:
            result = self.get_column_data("mass")
        else:
            atnos = self.atomic_numbers
            result = elements.masses(atnos)
        return result

    @property
    def have_velocities(self):
        """Whether any subset atom has velocities in this configuration."""
        sql = (
            "SELECT COUNT(*)"
            " FROM subset_atom AS sa, velocities AS ve "
            "WHERE sa.subset = ?"
            " AND ve.atom = sa.atom AND ve.configuration = ?"
        )
        parameters = [self.subset_id, self.configuration.id]
        self.cursor.execute(sql, parameters)
        return self.cursor.fetchone()[0] > 0

    @property
    def ids(self):
        """The ids of atoms in this subset.

        Ordered by template atom when the template is full and
        ``template_order`` is set.
        """
        sql = """
            SELECT atom
              FROM subset_atom
             WHERE subset = ?
        """
        if self.template.is_full and self.template_order:
            sql += "ORDER BY templateatom"
        return [x[0] for x in self.db.execute(sql, (self.subset_id,))]

    @property
    def n_atoms(self) -> int:
        """The number of atoms in this subset that are in this atomset."""
        sql = """
            SELECT COUNT(*)
              FROM atomset_atom
             WHERE atomset = ?
               AND atom IN (SELECT atom FROM subset_atom WHERE subset = ?)
        """
        self.cursor.execute(sql, (self.atomset, self.subset_id))
        return self.cursor.fetchone()[0]

    @property
    def template(self):
        """The template for this subset (cached after the first lookup)."""
        if self._template is None:
            self._template = self.system_db.templates.get(self.template_id)
        return self._template

    @property
    def template_id(self):
        """The id of the template for this subset (cached)."""
        if self._template_id is None:
            sql = "SELECT template FROM SUBSET WHERE id = ?"
            self.cursor.execute(sql, (self.subset_id,))
            self._template_id = self.cursor.fetchone()[0]
        return self._template_id

    def add(self, ids):
        """Add atoms to the subset.

        Parameters
        ----------
        ids : [int]
            The ids of the atoms to add. Ids already in the subset, and
            repeated ids, are silently ignored. The argument itself is not
            modified, so any iterable of ids may be passed.

        Raises
        ------
        ValueError
            If this subset has a full template.

        NOTE(review): earlier documentation also promised a ValueError for
        atoms not in the configuration, but no such check is made here.
        """
        # Atoms cannot be added to a subset whose template is full.
        if self.template.is_full:
            raise ValueError("Cannot add atoms to a subset for a full template.")

        # Keep only new ids, preserving the caller's order and dropping
        # duplicates. A new list is built so the caller's sequence is left
        # untouched (and tuples or arrays work too).
        existing = set(self.ids)
        new_ids = []
        for aid in ids:
            if aid not in existing:
                existing.add(aid)
                new_ids.append(aid)

        # Nothing new: do not pass an empty list to the table's append.
        if not new_ids:
            return

        sa = self.system_db["subset_atom"]
        sa.append(subset=self.subset_id, atom=new_ids)

    def append(self, **kwargs: Dict[str, Any]) -> None:
        """Append one or more atoms.

        This is not allowed for a subset.

        Raises
        ------
        RuntimeError
            Adding atoms to a configuration is not allowed from a subset.
        """
        raise RuntimeError(
            "Adding atoms to a configuration is not allowed from a subset."
        )

    def atoms(self, *args):
        """Return an iterator over the atoms.

        Parameters
        ----------
        args : [str]
            Added selection criteria for the SQL, given as consecutive
            (column, operator, value) triples.

        Returns
        -------
        sqlite3.Cursor
            A cursor over the rows for the atoms in the subset.
        """
        have_velocities = self.have_velocities
        have_gradients = self.have_gradients
        columns = self._columns(velocities=have_velocities, gradients=have_gradients)
        column_defs = ", ".join(columns)

        from_string = ["atom as at", "coordinates as co", "subset_atom as sa"]
        if have_velocities:
            from_string.append("velocities AS ve")
        if have_gradients:
            from_string.append("gradients AS gr")
        from_string = ", ".join(from_string)

        sql = f"""
            SELECT {column_defs}
              FROM {from_string}
             WHERE co.atom = at.id
               AND co.configuration = ?
               AND at.id = sa.atom
               AND sa.subset = ?
        """
        parameters = [self.configuration.id, self.subset_id]
        if have_velocities:
            sql += "\n AND ve.atom = at.id AND ve.configuration = ?"
            parameters.append(self.configuration.id)
        if have_gradients:
            sql += "\n AND gr.atom = at.id AND gr.configuration = ?"
            parameters.append(self.configuration.id)

        if len(args) > 0:
            for col, op, value in grouped(args, 3):
                if op == "==":
                    op = "="
                sql += f'\n AND "{col}" {op} ?'
                parameters.append(value)

        if self.template.is_full and self.template_order:
            sql += "\nORDER BY sa.templateatom"

        return self.db.execute(sql, parameters)

    def get_n_atoms(self, *args):
        """Return the number of atoms meeting the criteria.

        Parameters
        ----------
        args : [str]
            Added selection criteria for the SQL, given as consecutive
            (column, operator, value) triples.

        Returns
        -------
        int
            The number of atoms matching the criteria.
        """
        sql = """
            SELECT COUNT(*)
              FROM atom as at, coordinates as co, subset_atom as sa
             WHERE co.atom = at.id
               AND co.configuration = ?
               AND at.id = sa.atom
               AND sa.subset = ?
        """
        parameters = [self.configuration.id, self.subset_id]
        if len(args) > 0:
            for col, op, value in grouped(args, 3):
                if op == "==":
                    op = "="
                sql += f' AND "{col}" {op} ?'
                parameters.append(value)
        self.cursor.execute(sql, parameters)
        return self.cursor.fetchone()[0]

    def diff(self, other):
        """Difference between these atoms and another

        Parameters
        ----------
        other : _Atoms
            The other atoms to diff against

        Result
        ------
        result : Dict
            The differences, described in a dictionary
        """
        raise NotImplementedError()

    def get_column(self, key: str) -> Any:
        """Get a Column object with the requested data.

        Only attributes of the atom and coordinates tables are supported
        for subsets.

        Parameters
        ----------
        key : str
            The attribute to get.

        Returns
        -------
        Column
            A Column object containing the data.

        Raises
        ------
        KeyError
            If the attribute is not in the atom or coordinates tables.
        """
        if key in self._atom_table.attributes:
            sql = f"""
                SELECT at.rowid, at."{key}"
                  FROM atom as at, subset_atom as sa
                 WHERE at.id = sa.atom
                   AND sa.subset = {self.subset_id}
            """
            if self.template.is_full and self.template_order:
                sql += "ORDER BY sa.templateatom"
            return _Column(self._atom_table, key, sql=sql)
        elif key in self._coordinates_table.attributes:
            sql = f"""
                SELECT co.rowid, co."{key}"
                  FROM coordinates as co, subset_atom as sa
                 WHERE co.atom = sa.atom
                   AND co.configuration = {self.configuration.id}
                   AND sa.subset = {self.subset_id}
            """
            if self.template.is_full and self.template_order:
                sql += "ORDER BY sa.templateatom"
            return _Column(self._coordinates_table, key, sql=sql)
        else:
            raise KeyError(f"'{key}' not in atoms")

    def get_column_data(self, key: str) -> Any:
        """Return a column of data for the subset atoms.

        Only attributes of the atom and coordinates tables are supported
        for subsets.

        Parameters
        ----------
        key : str
            The attribute to get.

        Returns
        -------
        list
            The values of the attribute, one per subset atom. They are in
            template order when the template is full and ``template_order``
            is set.

        Raises
        ------
        KeyError
            If the attribute is not in the atom or coordinates tables.
        """
        if key in self._atom_table.attributes:
            sql = f"""
                SELECT at."{key}"
                  FROM atom as at, subset_atom as sa
                 WHERE at.id = sa.atom
                   AND sa.subset = {self.subset_id}
            """
            if self.template.is_full and self.template_order:
                sql += "ORDER BY sa.templateatom"
            return [row[0] for row in self.db.execute(sql)]
        elif key in self._coordinates_table.attributes:
            sql = f"""
                SELECT co."{key}"
                  FROM coordinates as co, subset_atom as sa
                 WHERE co.atom = sa.atom
                   AND co.configuration = {self.configuration.id}
                   AND sa.subset = {self.subset_id}
            """
            if self.template.is_full and self.template_order:
                sql += "ORDER BY sa.templateatom"
            return [row[0] for row in self.db.execute(sql)]
        else:
            raise KeyError(f"'{key}' not in atoms")

    def get_ids(self, *args):
        """The ids of the atoms.

        Parameters
        ----------
        args : [str]
            Added selection criteria for the SQL, given as consecutive
            (column, operator, value) triples.

        Returns
        -------
        [int]
            The ids of the requested atoms.
        """
        # co.atom = at.id is required: without it every subset atom is
        # paired with every coordinate row of the configuration, repeating
        # each id once per atom in the configuration.
        sql = """
            SELECT at.id
              FROM atom as at, coordinates as co, subset_atom as sa
             WHERE at.id = sa.atom
               AND co.atom = at.id
               AND co.configuration = ?
               AND sa.subset = ?
        """
        parameters = [self.configuration.id, self.subset_id]
        if len(args) > 0:
            for col, op, value in grouped(args, 3):
                if op == "==":
                    op = "="
                sql += f'\n AND "{col}" {op} ?'
                parameters.append(value)
        if self.template.is_full and self.template_order:
            # Leading newline keeps ORDER BY separate from a trailing "?".
            sql += "\nORDER BY sa.templateatom"
        return [x[0] for x in self.db.execute(sql, parameters)]

    def delete(self, atoms) -> int:
        """Delete the atoms listed

        Not implemented for subsets.

        Parameters
        ----------
        atoms : [int]
            The list of atoms to delete, or 'all' or '*'

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError()

    def remove(self, ids):
        """Remove the given atoms from the subset.

        Parameters
        ----------
        ids : [int]
            The ids of the atoms to remove.

        Raises
        ------
        ValueError
            If this subset has a full template or if an atom is not in the
            subset. Nothing is removed in either case.
        """
        # Atoms cannot be removed from a subset whose template is full.
        if self.template.is_full:
            raise ValueError(
                "Cannot remove atoms from a subset for a full template."
            )

        # Materialize once so generators can be both checked and used.
        ids = list(ids)

        # Check that all the atoms are in this subset before deleting any.
        atom_ids = set(self.ids)
        for aid in ids:
            if aid not in atom_ids:
                raise ValueError(f"Atom id={aid} is not in the subset.")

        parameters = [(_id, self.subset_id) for _id in ids]
        self.db.executemany(
            "DELETE FROM subset_atom WHERE atom=? AND subset=?", parameters
        )