# Source code for hyperts.framework.wrappers.nas_wrappers

import numpy as np
from hyperts.utils import consts

from hyperts.framework.nas import TSNASEstimator
from hyperts.framework.wrappers._base import EstimatorWrapper, WrapperMixin

from hypernets.utils import logging

# Module-level logger, named after this module via ``__name__``.
logger = logging.get_logger(__name__)


class TSNASWrapper(EstimatorWrapper, WrapperMixin):
    """Estimator wrapper around ``TSNASEstimator``.

    Adapt: forecast, classification and regression.

    The task stored in ``self.init_kwargs['task']`` selects which side of the
    data is (inverse-)transformed and which model method produces predictions.
    """

    def __init__(self, fit_kwargs, **kwargs):
        """Normalise the init options and build the underlying NAS estimator.

        Parameters
        ----------
        fit_kwargs : dict
            Fit options handed to the base-class initialiser (``fit`` later
            merges ``self.fit_kwargs`` with per-call kwargs).
        **kwargs
            Initialisation options; passed through ``update_init_kwargs``
            before reaching the base class. ``TSNASEstimator`` is then
            constructed from ``self.init_kwargs``.
        """
        init_options = self.update_init_kwargs(**kwargs)
        super().__init__(fit_kwargs, **init_options)
        self.update_fit_kwargs()
        self.model = TSNASEstimator(**self.init_kwargs)

    def fit(self, X, y=None, **kwargs):
        """Fit the wrapped NAS estimator.

        Steps: optionally drop historical samples (when ``drop_sample_rate``
        is truthy), merge ``self.fit_kwargs`` with ``kwargs``, apply
        ``fit_transform`` to ``y`` for forecast tasks or to ``X`` for every
        other task, then call ``self.model.fit``.
        """
        if self.drop_sample_rate:
            X, y = self.drop_hist_sample(X, y, **self.init_kwargs)
        merged_kwargs = self._merge_dict(self.fit_kwargs, kwargs)

        # Forecasting transforms the target series; other tasks transform X.
        if self.init_kwargs.get('task') in consts.TASK_LIST_FORECAST:
            y = self.fit_transform(y)
        else:
            X = self.fit_transform(X)

        self.model.fit(X, y, **merged_kwargs)

    def predict(self, X, **kwargs):
        """Predict with the wrapped NAS estimator.

        * forecast tasks: ``model.forecast(X)``, inverse-transformed, then
          clipped when ``self.is_scale`` is not ``None``;
        * classification tasks: ``model.predict`` on the transformed ``X``;
        * any other task: ``model.predict_proba`` on the transformed ``X``.
        """
        task = self.init_kwargs.get('task')

        if task in consts.TASK_LIST_FORECAST:
            forecast = self.inverse_transform(self.model.forecast(X))
            if self.is_scale is not None:
                # Elementwise this is min(max(p, 1e-6), |p|): values >= 1e-6 are
                # kept, values <= -1e-6 become 1e-6, tiny values become |p|.
                # NOTE(review): intent (non-negative targets?) not visible here.
                forecast = np.clip(forecast, a_min=1e-6, a_max=abs(forecast))
            return forecast

        X = self.transform(X)
        if task in consts.TASK_LIST_CLASSIFICATION:
            return self.model.predict(X)
        # NOTE(review): non-classification tasks (e.g. regression) return the
        # model's predict_proba output; confirm this is the intended raw output.
        return self.model.predict_proba(X)

    def predict_proba(self, X, **kwargs):
        """Return ``self.model.predict_proba`` on the transformed ``X``."""
        return self.model.predict_proba(self.transform(X))

    @property
    def classes_(self):
        """``self.model.meta.labels_`` for classification tasks, else ``None``."""
        if self.init_kwargs.get('task') not in consts.TASK_LIST_CLASSIFICATION:
            return None
        return self.model.meta.labels_