본문 바로가기
카테고리 없음

sklearn - base (변환기 직접 만들기)

by sun__ 2020. 10. 2.

두 특성을 조합해서 유의미한 특성을 만든다던지 할 수 있는 경우들이 많다.

 

핸즈온 2장의 예제로 설명. 지역마다 특성들이 주어진다. 

 

가구 당 방의 수 'rooms_per_households'  :  방 개수 'total_rooms'와 가구 수 'households'를 조합

가구 당 인원 수 'population_per_household' : 인구 'population'과 구 수 'households'를 조합

 

위 두 특성을 추가하는 변환기 클래스를 짜보자.

 

sklearn.base의 TransformerMixin을 상속하면 fit, transform메서드만 만들어도 자동으로 fit_transform()메서드를 자동으로 생성해준다.

 

sklearn.base의 BaseEstimator를 상속하면 하이퍼파라미터 튜닝에 필요한 두 메서드 get_params()와 set_params()를 추가로 얻게 된다. (생성자에 *args나 **kargs 사용하면 안됨)

 

from sklearn.base import BaseEstimator, TransformerMixin

rooms_ix, bedrooms_ix, population_ix, households_ix = 3,4,5,6

class CombinedAttributesAdder(BaseEstimator, TransformerMixin):
    def __init__(self, add_bedrooms_per_room = True): 			#하이퍼 파라미터
        self.add_bedrooms_per_room = add_bedrooms_per_room
    def fit(self, X, y=None):						#fit에서 조정할 변수는 없음
        return self
    def transform(self, X):
        rooms_per_households = X[:, rooms_ix] / X[:, households_ix]
        population_per_household = X[:, population_ix] / X[:, households_ix]
        if self.add_bedrooms_per_room:
            bedrooms_per_room = X[:, bedrooms_ix] / X[:, rooms_ix]
            return np.c_[X, rooms_per_households, population_per_household,
                        bedrooms_per_room]
        else:
            return np.c_[X, rooms_per_households, population_per_household]
        
attr_adder = CombinedAttributesAdder(add_bedrooms_per_room=False)
housing_extra_attribs = attr_adder.transform(housing.values)
#array([[-121.89, 37.29, 38.0, ..., '<1H OCEAN', 4.625368731563422,
#        2.094395280235988],
#       [-121.93, 37.05, 14.0, ..., '<1H OCEAN', 6.008849557522124,
#        2.7079646017699117],
#       [-117.2, 32.77, 31.0, ..., 'NEAR OCEAN', 4.225108225108225,
#        2.0259740259740258],
#       ...,
#       [-116.4, 34.09, 9.0, ..., 'INLAND', 6.34640522875817,
#        2.742483660130719],
#       [-118.01, 33.82, 31.0, ..., '<1H OCEAN', 5.50561797752809,
#        3.808988764044944],
#       [-122.45, 37.77, 52.0, ..., 'NEAR BAY', 4.843505477308295,
#        1.9859154929577465]], dtype=object)