# -*- coding: utf-8 -*-
"""
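Utilities for configuring and running lmfit-based fits in qudi: discovery of the available fit
models, named fit configurations, a Qt list model holding these configurations and a container
object that performs fits and keeps track of the last fit result.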
Copyright (c) 2021, the qudi developers. See the AUTHORS.md file at the top-level directory of this
distribution and on <https://github.com/Ulm-IQO/qudi-core/>
This file is part of qudi.
Qudi is free software: you can redistribute it and/or modify it under the terms of
the GNU Lesser General Public License as published by the Free Software Foundation,
either version 3 of the License, or (at your option) any later version.
Qudi is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with qudi.
If not, see <https://www.gnu.org/licenses/>.
"""
# Names exported by this module when using a wildcard import.
__all__ = ('is_fit_model', 'get_all_fit_models', 'FitConfiguration', 'FitConfigurationsModel',
'FitContainer')
import importlib
import logging
import inspect
import lmfit
import numpy as np
from PySide2 import QtCore
from typing import Iterable, Optional, Mapping, Union
import qudi.util.fit_models as _fit_models_ns
from qudi.util.mutex import Mutex
from qudi.util.units import create_formatted_output
from qudi.util.helpers import iter_modules_recursive
from qudi.util.fit_models.model import FitModelBase
_log = logging.getLogger(__name__)
def is_fit_model(cls):
    """Check whether the given object is a concrete fit model class.

    Parameters
    ----------
    cls : object
        Object to inspect.

    Returns
    -------
    bool
        True if ``cls`` is a class derived from ``FitModelBase`` other than ``FitModelBase``
        itself, False otherwise.
    """
    # Anything that is not a class can never be a fit model
    if not inspect.isclass(cls):
        return False
    derives_from_base = issubclass(cls, FitModelBase)
    return derives_from_base and cls is not FitModelBase
# Upon import of this module the global attribute _fit_models is initialized with a dict
# containing all importable fit model objects with names as keys.
_fit_models = dict()
for mod_finder in iter_modules_recursive(_fit_models_ns.__path__, _fit_models_ns.__name__ + '.'):
    try:
        _fit_models.update(
            {name: cls for name, cls in
             inspect.getmembers(importlib.import_module(mod_finder.name), is_fit_model)}
        )
    except Exception:
        # A broken sub-module must not prevent the remaining fit models from being collected.
        # Only regular exceptions are handled here so that KeyboardInterrupt and SystemExit
        # raised during import still propagate.
        _log.exception(
            f'Exception while importing qudi.util.fit_models sub-module "{mod_finder.name}":'
        )
def get_all_fit_models():
    """Return the fit model classes collected at import time of this module.

    Returns
    -------
    dict
        Shallow copy of the module-level registry with fit model class names as keys and the
        fit model classes as values. Changing the returned dict does not alter the registry.
    """
    return dict(_fit_models)
class FitConfiguration:
    """Named combination of a fit model, an optional estimator and optional custom parameters.

    The fit model is referenced by its name in the module-level fit model registry. Name and
    model are fixed at creation time, estimator and custom parameters can be changed later via
    the respective properties.
    """

    def __init__(self, name: str, model: str, estimator: Optional[str] = None,
                 custom_parameters: Optional[lmfit.Parameters] = None):
        """
        Parameters
        ----------
        name : str
            Non-empty name of this configuration. "No Fit" is reserved and not allowed.
        model : str
            Name of the fit model. Must be a key of the module-level fit model registry.
        estimator : str, optional
            Name of the estimator to use. Must be one of ``available_estimators`` or None.
        custom_parameters : lmfit.Parameters, optional
            Parameters overriding the model parameters. Only names also present in
            ``default_parameters`` are allowed.

        Raises
        ------
        AssertionError
            If any of the arguments is invalid.
        """
        assert isinstance(name, str), 'FitConfiguration name must be str type.'
        assert name, 'FitConfiguration name must be non-empty string.'
        assert model in _fit_models, f'Invalid fit model name encountered: "{model}".'
        assert name != 'No Fit', '"No Fit" is a reserved name for fit configs. Choose another.'
        self._name = name
        self._model = model
        self._estimator = None
        self._custom_parameters = None
        # Assign through the property setters in order to validate both values
        self.estimator = estimator
        self.custom_parameters = custom_parameters

    @property
    def name(self):
        """str: Name of this fit configuration (read-only)."""
        return self._name

    @property
    def model(self):
        """str: Name of the fit model used by this configuration (read-only)."""
        return self._model

    @property
    def estimator(self):
        """str or None: Name of the selected estimator, None if no estimator is selected."""
        return self._estimator

    @estimator.setter
    def estimator(self, value: Union[str, None]):
        if value is not None:
            assert value in self.available_estimators, \
                f'Invalid fit model estimator encountered: "{value}"'
        self._estimator = value

    @property
    def available_estimators(self):
        """tuple: Content of the ``estimators`` attribute of a fresh instance of the model."""
        return tuple(_fit_models[self._model]().estimators)

    @property
    def default_parameters(self):
        """lmfit.Parameters: Result of ``make_params()`` of a fresh instance of the model.

        An empty ``lmfit.Parameters`` object is returned if ``make_params()`` yields None.
        """
        params = _fit_models[self._model]().make_params()
        return lmfit.Parameters() if params is None else params

    @property
    def custom_parameters(self):
        """lmfit.Parameters or None: Copy of the custom parameters, None if none are set."""
        return self._custom_parameters.copy() if self._custom_parameters is not None else None

    @custom_parameters.setter
    def custom_parameters(self, value: Union[lmfit.Parameters, None]):
        if value is None:
            self._custom_parameters = None
            return
        # Check the type first. Otherwise a wrong type would either fail while being iterated
        # or be reported with a misleading "invalid parameter names" message.
        assert isinstance(value, lmfit.Parameters), \
            'Property custom_parameters must be of type <lmfit.Parameters>.'
        invalid = set(value).difference(self.default_parameters)
        assert not invalid, f'Invalid model parameters encountered: {invalid}'
        # Store a copy to stay independent of later changes made by the caller
        self._custom_parameters = value.copy()

    def to_dict(self):
        """Return a dictionary representation of this configuration.

        Returns
        -------
        dict
            Dict with the keys 'name', 'model', 'estimator' and 'custom_parameters'. The custom
            parameters are given as the string returned by ``lmfit.Parameters.dumps()`` or as
            None if no custom parameters are set.
        """
        return {
            'name': self._name,
            'model': self._model,
            'estimator': self._estimator,
            'custom_parameters': None if self._custom_parameters is None else self._custom_parameters.dumps()
        }

    @classmethod
    def from_dict(cls, dict_repr):
        """Create a FitConfiguration from a dictionary representation (see ``to_dict``).

        Parameters
        ----------
        dict_repr : dict
            Dict with exactly the keys 'name', 'model', 'estimator' and 'custom_parameters'.
            A str value for 'custom_parameters' is loaded using ``lmfit.Parameters.loads()``.
            The dict itself is left unchanged.

        Returns
        -------
        FitConfiguration
            The newly created configuration.

        Raises
        ------
        AssertionError
            If the set of keys does not match or the contained values are invalid.
        """
        assert set(dict_repr) == {'name', 'model', 'estimator', 'custom_parameters'}
        # Work on a copy so the dict handed in by the caller is not altered
        init_kwargs = dict(dict_repr)
        if isinstance(init_kwargs['custom_parameters'], str):
            init_kwargs['custom_parameters'] = lmfit.Parameters().loads(
                init_kwargs['custom_parameters']
            )
        return cls(**init_kwargs)
class FitConfigurationsModel(QtCore.QAbstractListModel):
    """Qt list model holding an ordered collection of FitConfiguration objects.

    The signal ``sigFitConfigurationsChanged`` is emitted with the tuple of all configuration
    names whenever a configuration is added or removed or when all configurations are replaced.
    """

    sigFitConfigurationsChanged = QtCore.Signal(tuple)

    def __init__(self, *args, configurations=None, **kwargs):
        """
        Parameters
        ----------
        configurations : iterable of FitConfiguration, optional
            Initial fit configurations. Defaults to no configurations.
        *args, **kwargs
            Passed on to the ``QAbstractListModel`` constructor.
        """
        # Materialize first, a one-shot iterable would otherwise be exhausted by the type check
        initial_configs = list() if configurations is None else list(configurations)
        assert all(isinstance(c, FitConfiguration) for c in initial_configs)
        super().__init__(*args, **kwargs)
        self._fit_configurations = initial_configs

    @property
    def model_names(self):
        """tuple of str: Names of all fit models found in the module-level registry."""
        return tuple(_fit_models)

    @property
    def model_estimators(self):
        """dict: Fit model names mapped to a tuple built from the model's ``estimators``."""
        return {name: tuple(model().estimators) for name, model in _fit_models.items()}

    @property
    def model_default_parameters(self):
        """dict: Fit model names mapped to the result of the model's ``make_params()``."""
        return {name: model().make_params() for name, model in _fit_models.items()}

    @property
    def configuration_names(self):
        """tuple of str: Names of all held fit configurations in list order."""
        return tuple(fc.name for fc in self._fit_configurations)

    @property
    def configurations(self):
        """list of FitConfiguration: Shallow copy of the list of held fit configurations."""
        return self._fit_configurations.copy()

    @QtCore.Slot(str, str)
    def add_configuration(self, name: str, model: str, estimator: Optional[str] = None,
                          custom_parameters: Optional[lmfit.Parameters] = None):
        """Create a new FitConfiguration and append it to the model.

        Parameters
        ----------
        name : str
            Name of the new configuration. Must not be in use already and must not be "No Fit".
        model : str
            Name of the fit model to use.
        estimator : str, optional
            Name of the estimator to use.
        custom_parameters : lmfit.Parameters, optional
            Custom parameters of the new configuration.

        Raises
        ------
        AssertionError
            If the name is already in use or reserved or if FitConfiguration rejects the
            arguments.
        """
        assert name not in self.configuration_names, f'Fit config "{name}" already defined.'
        assert name != 'No Fit', '"No Fit" is a reserved name for fit configs. Choose another.'
        config = FitConfiguration(name, model, estimator, custom_parameters)
        new_row = len(self._fit_configurations)
        # Rows of a flat list model are children of the invalid root index
        self.beginInsertRows(QtCore.QModelIndex(), new_row, new_row)
        self._fit_configurations.append(config)
        self.endInsertRows()
        self.sigFitConfigurationsChanged.emit(self.configuration_names)

    @QtCore.Slot(str)
    def remove_configuration(self, name):
        """Remove the fit configuration with the given name. Unknown names are ignored.

        Parameters
        ----------
        name : str
            Name of the configuration to remove.
        """
        try:
            row_index = self.configuration_names.index(name)
        except ValueError:
            return
        # Rows of a flat list model are children of the invalid root index
        self.beginRemoveRows(QtCore.QModelIndex(), row_index, row_index)
        self._fit_configurations.pop(row_index)
        self.endRemoveRows()
        self.sigFitConfigurationsChanged.emit(self.configuration_names)

    def get_configuration_by_name(self, name):
        """Return the held FitConfiguration object with the given name.

        Parameters
        ----------
        name : str
            Name of the configuration to look up.

        Returns
        -------
        FitConfiguration
            The stored configuration object itself (not a copy).

        Raises
        ------
        ValueError
            If no configuration with the given name is held.
        """
        try:
            row_index = self.configuration_names.index(name)
        except ValueError:
            raise ValueError(f'No fit configuration found with name "{name}".')
        return self._fit_configurations[row_index]

    def flags(self, index):
        """Return the item flags for the given index.

        Valid indexes are editable and enabled, invalid indexes carry no flags at all.
        """
        if index.isValid():
            return QtCore.Qt.ItemIsEditable | QtCore.Qt.ItemIsEnabled
        return QtCore.Qt.NoItemFlags

    def rowCount(self, parent=QtCore.QModelIndex()):
        """Return the number of held fit configurations. The parent index is not evaluated."""
        return len(self._fit_configurations)

    def headerData(self, section, orientation, role=QtCore.Qt.DisplayRole):
        """Return the header data for the display role, None for everything else.

        The horizontal header of section 0 is 'Fit Configurations'. The vertical header of a
        section is the name of the configuration in that row, None if the row does not exist.
        """
        if role == QtCore.Qt.DisplayRole:
            if (orientation == QtCore.Qt.Horizontal) and (section == 0):
                return 'Fit Configurations'
            elif orientation == QtCore.Qt.Vertical:
                try:
                    return self.configuration_names[section]
                except IndexError:
                    pass
        return None

    def data(self, index=QtCore.QModelIndex(), role=QtCore.Qt.DisplayRole):
        """Return the FitConfiguration object in the row of the given index.

        None is returned for roles other than the display role, for invalid indexes and for
        rows that do not exist.
        """
        if (role == QtCore.Qt.DisplayRole) and (index.isValid()):
            try:
                return self._fit_configurations[index.row()]
            except IndexError:
                pass
        return None

    def setData(self, index, value, role=QtCore.Qt.EditRole):
        """Update estimator and custom parameters of the configuration at the given index.

        Parameters
        ----------
        index : QtCore.QModelIndex
            Index of the configuration to edit.
        value : sequence
            ``value[0]`` is the estimator name, any falsy value selects no estimator.
            ``value[1]`` maps parameter names to a sequence (vary, value, min, max). Model
            parameters that are missing in this mapping are not stored as custom parameters.
        role : QtCore.Qt.ItemDataRole, optional
            Not evaluated.

        Returns
        -------
        bool
            True if the configuration has been updated, False if the index is invalid or does
            not point to a configuration.
        """
        if index.isValid():
            config = index.data(QtCore.Qt.DisplayRole)
            if config is None:
                return False
            new_params = value[1]
            # Start out with the model parameters and drop those that have not been customized
            params = config.default_parameters
            for name in [p for p in params if p not in new_params]:
                del params[name]
            for name, p in params.items():
                value_tuple = new_params[name]
                p.set(vary=value_tuple[0],
                      value=value_tuple[1],
                      min=value_tuple[2],
                      max=value_tuple[3])
            config.estimator = None if not value[0] else value[0]
            config.custom_parameters = None if not params else params
            self.dataChanged.emit(self.createIndex(index.row(), 0),
                                  self.createIndex(index.row(), 0))
            return True
        return False

    def dump_configs(self):
        """
        Returns all currently held fit configurations as dictionary representations containing only
        data types that can be dumped as YAML in the Qudi app status.

        Returns
        -------
        list of dict
            List of fit configuration dictionary representations.
        """
        return [cfg.to_dict() for cfg in self._fit_configurations]

    def load_configs(self, configs):
        """Initializes/overwrites all currently held fit configurations by a given iterable of dict
        representations (see also: FitConfigurationsModel.dump_configs).

        This method will reset the list model. Representations that can not be converted into a
        FitConfiguration are skipped with a warning.

        Parameters
        ----------
        configs : iterable
            Iterable of FitConfiguration dictionary representations.
            See also: FitConfigurationsModel.dump_configs.
        """
        config_objects = list()
        for cfg in configs:
            try:
                config_objects.append(FitConfiguration.from_dict(cfg))
            except Exception:
                # Best effort: skip the faulty entry but let KeyboardInterrupt and SystemExit
                # propagate
                _log.warning(f'Unable to load fit configuration:\n{cfg}')
        self.beginResetModel()
        self._fit_configurations = config_objects
        self.endResetModel()
        self.sigFitConfigurationsChanged.emit(self.configuration_names)
class FitContainer(QtCore.QObject):
    """Object performing fits based on the configurations of a FitConfigurationsModel.

    The name of the last used fit configuration and the corresponding fit result are stored and
    announced via ``sigLastFitResultChanged``. ``sigFitConfigurationsChanged`` of the
    configuration model is forwarded through the signal of the same name of this object.
    """

    sigFitConfigurationsChanged = QtCore.Signal(tuple)  # config_names
    sigLastFitResultChanged = QtCore.Signal(str, object)  # (fit_config name, lmfit.ModelResult)

    def __init__(self, *args, config_model, **kwargs):
        """
        Parameters
        ----------
        config_model : FitConfigurationsModel
            Model providing the fit configurations to choose from.
        *args, **kwargs
            Passed on to the ``QObject`` constructor.
        """
        assert isinstance(config_model, FitConfigurationsModel)
        super().__init__(*args, **kwargs)
        self._access_lock = Mutex()
        self._configuration_model = config_model
        self._last_fit_result = None
        self._last_fit_config = 'No Fit'
        # Forward changes of the available configurations to listeners of this object
        self._configuration_model.sigFitConfigurationsChanged.connect(
            self.sigFitConfigurationsChanged
        )

    @property
    def fit_configurations(self):
        """list of FitConfiguration: The ``configurations`` of the configuration model."""
        return self._configuration_model.configurations

    @property
    def fit_configuration_names(self):
        """tuple of str: The ``configuration_names`` of the configuration model."""
        return self._configuration_model.configuration_names

    @property
    def last_fit(self):
        """tuple: Name of the last used fit configuration and the last fit result.

        Initially and after a "No Fit" request this is ``('No Fit', None)``.
        """
        with self._access_lock:
            return self._last_fit_config, self._last_fit_result

    @QtCore.Slot(str, object, object)
    def fit_data(self, fit_config, x, data):
        """Fit the given data using the fit configuration with the given name.

        The outcome is stored as last fit and ``sigLastFitResultChanged`` is emitted.

        Parameters
        ----------
        fit_config : str
            Name of the fit configuration to use. 'No Fit' clears the last fit result. Nothing
            is done for an empty name.
        x : sequence
            Independent variable values. First and last element and the length are used to
            create the x values of the high-resolution fit curve.
        data : sequence
            Data values to fit.

        Returns
        -------
        tuple
            Name of the used fit configuration and the fit result, which is None for 'No Fit'.
            ``('', None)`` is returned for an empty configuration name. The fit result carries
            the additional attribute ``high_res_best_fit`` holding a tuple of x values and fit
            curve values with ten times as many points as ``x``.

        Raises
        ------
        ValueError
            If no fit configuration with the given name exists.
        """
        with self._access_lock:
            if not fit_config:
                return '', None

            # Handle "No Fit" case
            if fit_config == 'No Fit':
                self._last_fit_result = None
                self._last_fit_config = 'No Fit'
            else:
                config = self._configuration_model.get_configuration_by_name(fit_config)
                model = _fit_models[config.model]()
                estimator = config.estimator
                add_parameters = config.custom_parameters
                if estimator is None:
                    parameters = model.make_params()
                else:
                    parameters = model.estimators[estimator](data, x)
                # Custom parameters take precedence over model/estimator parameters
                if add_parameters is not None:
                    for name, param in add_parameters.items():
                        parameters[name] = param
                result = model.fit(data, parameters, x=x)
                # Mutate lmfit.ModelResult object to include high-resolution result curve
                high_res_x = np.linspace(x[0], x[-1], len(x) * 10)
                result.high_res_best_fit = (high_res_x,
                                            model.eval(**result.best_values, x=high_res_x))
                self._last_fit_result = result
                self._last_fit_config = fit_config
            # NOTE(review): the signal is emitted while the access lock is held. A directly
            # connected slot reading "last_fit" would block if Mutex is not recursive - confirm.
            self.sigLastFitResultChanged.emit(self._last_fit_config, self._last_fit_result)
            return self._last_fit_config, self._last_fit_result

    @staticmethod
    def formatted_result(fit_result: Union[None, lmfit.model.ModelResult],
                         parameters_units: Optional[Mapping[str, str]] = None) -> str:
        """Create a formatted text representation of the parameters of a fit result.

        Parameters
        ----------
        fit_result : lmfit.model.ModelResult or None
            Fit result to format.
        parameters_units : mapping, optional
            Units by parameter name. Parameters without an entry get an empty unit.

        Returns
        -------
        str
            Output of ``create_formatted_output`` for all parameters of the fit result, an
            empty string if ``fit_result`` is None.
        """
        if fit_result is None:
            return ''
        if parameters_units is None:
            parameters_units = dict()
        parameters_to_format = dict()
        for name, param in fit_result.params.items():
            # Fixed parameters have no error. A varied parameter without stderr gets NaN.
            stderr = param.stderr if param.vary else None
            stderr = np.nan if param.vary and stderr is None else stderr
            parameters_to_format[name] = {'value': param.value,
                                          'error': stderr,
                                          'unit': parameters_units.get(name, '')}
        return create_formatted_output(parameters_to_format)

    @staticmethod
    def dict_result(fit_result: Union[None, lmfit.model.ModelResult],
                    parameters_units: Optional[Mapping[str, str]] = None,
                    export_keys: Optional[Iterable[str]] = ('value', 'stderr')) -> dict:
        """Export the parameters of a fit result as nested dictionary.

        Parameters
        ----------
        fit_result : lmfit.model.ModelResult or None
            Fit result to export.
        parameters_units : mapping, optional
            Units by parameter name. Parameters without an entry get an empty unit.
        export_keys : iterable of str, optional
            Names of the parameter attributes to export. None selects the default
            ('value', 'stderr').

        Returns
        -------
        dict
            Empty dict if ``fit_result`` is None. Otherwise a dict with the model name stored
            under 'model' and one dict per parameter name containing the exported attributes
            as well as the parameter unit under 'unit'.
        """
        if fit_result is None:
            return dict()
        if parameters_units is None:
            parameters_units = dict()
        if export_keys is None:
            export_keys = ('value', 'stderr')
        else:
            # The keys are iterated once per parameter, so one-shot iterables must be stored
            export_keys = tuple(export_keys)
        export_dict = {'model': fit_result.model.name}
        for param_name, param in fit_result.result.params.items():
            param_dict = {attr: getattr(param, attr) for attr in export_keys}
            param_dict['unit'] = parameters_units.get(param_name, '')
            export_dict[param_name] = param_dict
        return export_dict