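This post covers the Jeju Island road traffic prediction AI competition held in October and November 2022.
The competition overlapped with my exam period, so I could not do as much modeling as I wanted.
Algorithms tried: LightGBM, RandomForest, DecisionTree
Submitted algorithm: LightGBM
Algorithm that, contrary to my expectation, received the highest score: DecisionTree
This post contains the prediction code.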
import pandas as pd
import numpy as np
import lightgbm as lgb  # needed for lgb.LGBMRegressor used below
from sklearn.preprocessing import LabelEncoder
import gc
train = pd.read_csv('./train.csv')
test = pd.read_csv('./test.csv')
data_info=pd.read_csv('data_info.csv')
data_info=pd.concat([data_info[:4],data_info[5:]])
data_info=data_info.reset_index()[['변수명','변수 설명']]
data_info
| | 변수명 | 변수 설명 |
|---|---|---|
| 0 | id | ID |
| 1 | base_date | Date |
| 2 | day_of_week | Day of week |
| 3 | base_hour | Hour of day |
| 4 | lane_count | Number of lanes |
| 5 | road_rating | Road grade |
| 6 | multi_linked | Overlapping (shared) section flag |
| 7 | connect_code | Connector road code |
| 8 | maximum_speed_limit | Maximum speed limit |
| 9 | weight_restricted | Weight restriction |
| 10 | height_restricted | Height restriction |
| 11 | road_type | Road type |
| 12 | start_latitude | Latitude of start point |
| 13 | start_longitude | Longitude of start point |
| 14 | start_turn_restricted | Turn restriction at start point (yes/no) |
| 15 | end_latitude | Latitude of end point |
| 16 | end_longitude | Longitude of end point |
| 17 | end_turn_restricted | Turn restriction at end point (yes/no) |
| 18 | road_name | Road name |
| 19 | start_node_name | Start point name |
| 20 | end_node_name | End point name |
| 21 | vehicle_restricted | Vehicle restriction |
| 22 | target | Average speed (km) |
train
| | id | base_date | day_of_week | base_hour | lane_count | road_rating | road_name | multi_linked | connect_code | maximum_speed_limit | ... | road_type | start_node_name | start_latitude | start_longitude | start_turn_restricted | end_node_name | end_latitude | end_longitude | end_turn_restricted | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | TRAIN_0000000 | 20220623 | 목 | 17 | 1 | 106 | 지방도1112호선 | 0 | 0 | 60.0 | ... | 3 | 제3교래교 | 33.427747 | 126.662612 | 없음 | 제3교래교 | 33.427749 | 126.662335 | 없음 | 52.0 |
| 1 | TRAIN_0000001 | 20220728 | 목 | 21 | 2 | 103 | 일반국도11호선 | 0 | 0 | 60.0 | ... | 0 | 광양사거리 | 33.500730 | 126.529107 | 있음 | KAL사거리 | 33.504811 | 126.526240 | 없음 | 30.0 |
| 2 | TRAIN_0000002 | 20211010 | 일 | 7 | 2 | 103 | 일반국도16호선 | 0 | 0 | 80.0 | ... | 0 | 창고천교 | 33.279145 | 126.368598 | 없음 | 상창육교 | 33.280072 | 126.362147 | 없음 | 61.0 |
| 3 | TRAIN_0000003 | 20220311 | 금 | 13 | 2 | 107 | 태평로 | 0 | 0 | 50.0 | ... | 0 | 남양리조트 | 33.246081 | 126.567204 | 없음 | 서현주택 | 33.245565 | 126.566228 | 없음 | 20.0 |
| 4 | TRAIN_0000004 | 20211005 | 화 | 8 | 2 | 103 | 일반국도12호선 | 0 | 0 | 80.0 | ... | 0 | 애월샷시 | 33.462214 | 126.326551 | 없음 | 애월입구 | 33.462677 | 126.330152 | 없음 | 38.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 4701212 | TRAIN_4701212 | 20211104 | 목 | 16 | 1 | 107 | - | 0 | 0 | 50.0 | ... | 0 | 대림사거리 | 33.422145 | 126.278125 | 없음 | 금덕해운 | 33.420955 | 126.273750 | 없음 | 20.0 |
| 4701213 | TRAIN_4701213 | 20220331 | 목 | 2 | 2 | 107 | - | 0 | 0 | 80.0 | ... | 3 | 광삼교 | 33.472505 | 126.424368 | 없음 | 광삼교 | 33.472525 | 126.424890 | 없음 | 65.0 |
| 4701214 | TRAIN_4701214 | 20220613 | 월 | 22 | 2 | 103 | 일반국도12호선 | 0 | 0 | 60.0 | ... | 0 | 고성교차로 | 33.447183 | 126.912579 | 없음 | 성산교차로 | 33.444121 | 126.912948 | 없음 | 30.0 |
| 4701215 | TRAIN_4701215 | 20211020 | 수 | 2 | 2 | 103 | 일반국도95호선 | 0 | 0 | 80.0 | ... | 0 | 제6광령교 | 33.443596 | 126.431817 | 없음 | 관광대학입구 | 33.444996 | 126.433332 | 없음 | 73.0 |
| 4701216 | TRAIN_4701216 | 20211019 | 화 | 6 | 2 | 107 | 경찰로 | 0 | 0 | 60.0 | ... | 0 | 서귀포경찰서 | 33.256785 | 126.508940 | 없음 | 시민공원 | 33.257130 | 126.510364 | 없음 | 35.0 |
4701217 rows × 23 columns
train.describe()
| | base_date | base_hour | lane_count | road_rating | multi_linked | connect_code | maximum_speed_limit | vehicle_restricted | weight_restricted | height_restricted | road_type | start_latitude | start_longitude | end_latitude | end_longitude | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 4.701217e+06 | 4.701217e+06 | 4.701217e+06 | 4.701217e+06 | 4.701217e+06 | 4.701217e+06 | 4.701217e+06 | 4701217.0 | 4.701217e+06 | 4701217.0 | 4.701217e+06 | 4.701217e+06 | 4.701217e+06 | 4.701217e+06 | 4.701217e+06 | 4.701217e+06 |
| mean | 2.021672e+07 | 1.192820e+01 | 1.836651e+00 | 1.049585e+02 | 4.762597e-04 | 2.660218e-01 | 6.125329e+01 | 0.0 | 5.618742e+03 | 0.0 | 6.152237e-01 | 3.338432e+01 | 1.265217e+02 | 3.338432e+01 | 1.265217e+02 | 4.278844e+01 |
| std | 4.555709e+03 | 6.722092e+00 | 6.877513e-01 | 1.840107e+00 | 2.181818e-02 | 5.227760e+00 | 1.213354e+01 | 0.0 | 1.395343e+04 | 0.0 | 1.211268e+00 | 1.012015e-01 | 1.563657e-01 | 1.011948e-01 | 1.563519e-01 | 1.595443e+01 |
| min | 2.021090e+07 | 0.000000e+00 | 1.000000e+00 | 1.030000e+02 | 0.000000e+00 | 0.000000e+00 | 3.000000e+01 | 0.0 | 0.000000e+00 | 0.0 | 0.000000e+00 | 3.324343e+01 | 1.261826e+02 | 3.324343e+01 | 1.261826e+02 | 1.000000e+00 |
| 25% | 2.021110e+07 | 6.000000e+00 | 1.000000e+00 | 1.030000e+02 | 0.000000e+00 | 0.000000e+00 | 5.000000e+01 | 0.0 | 0.000000e+00 | 0.0 | 0.000000e+00 | 3.326422e+01 | 1.264232e+02 | 3.326422e+01 | 1.264232e+02 | 3.000000e+01 |
| 50% | 2.022013e+07 | 1.200000e+01 | 2.000000e+00 | 1.060000e+02 | 0.000000e+00 | 0.000000e+00 | 6.000000e+01 | 0.0 | 0.000000e+00 | 0.0 | 0.000000e+00 | 3.341257e+01 | 1.265112e+02 | 3.341257e+01 | 1.265112e+02 | 4.300000e+01 |
| 75% | 2.022050e+07 | 1.800000e+01 | 2.000000e+00 | 1.070000e+02 | 0.000000e+00 | 0.000000e+00 | 7.000000e+01 | 0.0 | 0.000000e+00 | 0.0 | 0.000000e+00 | 3.347804e+01 | 1.265840e+02 | 3.347804e+01 | 1.265840e+02 | 5.400000e+01 |
| max | 2.022073e+07 | 2.300000e+01 | 3.000000e+00 | 1.070000e+02 | 1.000000e+00 | 1.030000e+02 | 8.000000e+01 | 0.0 | 5.000000e+04 | 0.0 | 3.000000e+00 | 3.355608e+01 | 1.269309e+02 | 3.355608e+01 | 1.269309e+02 | 1.130000e+02 |
import seaborn as sns
# Drop id
df=train.drop(columns='id')
# Encode day of week as integers (Mon=1 ... Sun=7)
df['day_of_week']=df['day_of_week'].map({'월':1,'화':2,'수':3,'목':4,'금':5,'토':6,'일':7})
df
| | base_date | day_of_week | base_hour | lane_count | road_rating | road_name | multi_linked | connect_code | maximum_speed_limit | vehicle_restricted | ... | road_type | start_node_name | start_latitude | start_longitude | start_turn_restricted | end_node_name | end_latitude | end_longitude | end_turn_restricted | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 20220623 | 4 | 17 | 1 | 106 | 지방도1112호선 | 0 | 0 | 60.0 | 0.0 | ... | 3 | 제3교래교 | 33.427747 | 126.662612 | 없음 | 제3교래교 | 33.427749 | 126.662335 | 없음 | 52.0 |
| 1 | 20220728 | 4 | 21 | 2 | 103 | 일반국도11호선 | 0 | 0 | 60.0 | 0.0 | ... | 0 | 광양사거리 | 33.500730 | 126.529107 | 있음 | KAL사거리 | 33.504811 | 126.526240 | 없음 | 30.0 |
| 2 | 20211010 | 7 | 7 | 2 | 103 | 일반국도16호선 | 0 | 0 | 80.0 | 0.0 | ... | 0 | 창고천교 | 33.279145 | 126.368598 | 없음 | 상창육교 | 33.280072 | 126.362147 | 없음 | 61.0 |
| 3 | 20220311 | 5 | 13 | 2 | 107 | 태평로 | 0 | 0 | 50.0 | 0.0 | ... | 0 | 남양리조트 | 33.246081 | 126.567204 | 없음 | 서현주택 | 33.245565 | 126.566228 | 없음 | 20.0 |
| 4 | 20211005 | 2 | 8 | 2 | 103 | 일반국도12호선 | 0 | 0 | 80.0 | 0.0 | ... | 0 | 애월샷시 | 33.462214 | 126.326551 | 없음 | 애월입구 | 33.462677 | 126.330152 | 없음 | 38.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 4701212 | 20211104 | 4 | 16 | 1 | 107 | - | 0 | 0 | 50.0 | 0.0 | ... | 0 | 대림사거리 | 33.422145 | 126.278125 | 없음 | 금덕해운 | 33.420955 | 126.273750 | 없음 | 20.0 |
| 4701213 | 20220331 | 4 | 2 | 2 | 107 | - | 0 | 0 | 80.0 | 0.0 | ... | 3 | 광삼교 | 33.472505 | 126.424368 | 없음 | 광삼교 | 33.472525 | 126.424890 | 없음 | 65.0 |
| 4701214 | 20220613 | 1 | 22 | 2 | 103 | 일반국도12호선 | 0 | 0 | 60.0 | 0.0 | ... | 0 | 고성교차로 | 33.447183 | 126.912579 | 없음 | 성산교차로 | 33.444121 | 126.912948 | 없음 | 30.0 |
| 4701215 | 20211020 | 3 | 2 | 2 | 103 | 일반국도95호선 | 0 | 0 | 80.0 | 0.0 | ... | 0 | 제6광령교 | 33.443596 | 126.431817 | 없음 | 관광대학입구 | 33.444996 | 126.433332 | 없음 | 73.0 |
| 4701216 | 20211019 | 2 | 6 | 2 | 107 | 경찰로 | 0 | 0 | 60.0 | 0.0 | ... | 0 | 서귀포경찰서 | 33.256785 | 126.508940 | 없음 | 시민공원 | 33.257130 | 126.510364 | 없음 | 35.0 |
4701217 rows × 22 columns
train[train.describe().columns]
| | base_date | base_hour | lane_count | road_rating | multi_linked | connect_code | maximum_speed_limit | vehicle_restricted | weight_restricted | height_restricted | road_type | start_latitude | start_longitude | end_latitude | end_longitude | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 20220623 | 17 | 1 | 106 | 0 | 0 | 60.0 | 0.0 | 32400.0 | 0.0 | 3 | 33.427747 | 126.662612 | 33.427749 | 126.662335 | 52.0 |
| 1 | 20220728 | 21 | 2 | 103 | 0 | 0 | 60.0 | 0.0 | 0.0 | 0.0 | 0 | 33.500730 | 126.529107 | 33.504811 | 126.526240 | 30.0 |
| 2 | 20211010 | 7 | 2 | 103 | 0 | 0 | 80.0 | 0.0 | 0.0 | 0.0 | 0 | 33.279145 | 126.368598 | 33.280072 | 126.362147 | 61.0 |
| 3 | 20220311 | 13 | 2 | 107 | 0 | 0 | 50.0 | 0.0 | 0.0 | 0.0 | 0 | 33.246081 | 126.567204 | 33.245565 | 126.566228 | 20.0 |
| 4 | 20211005 | 8 | 2 | 103 | 0 | 0 | 80.0 | 0.0 | 0.0 | 0.0 | 0 | 33.462214 | 126.326551 | 33.462677 | 126.330152 | 38.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 4701212 | 20211104 | 16 | 1 | 107 | 0 | 0 | 50.0 | 0.0 | 0.0 | 0.0 | 0 | 33.422145 | 126.278125 | 33.420955 | 126.273750 | 20.0 |
| 4701213 | 20220331 | 2 | 2 | 107 | 0 | 0 | 80.0 | 0.0 | 43200.0 | 0.0 | 3 | 33.472505 | 126.424368 | 33.472525 | 126.424890 | 65.0 |
| 4701214 | 20220613 | 22 | 2 | 103 | 0 | 0 | 60.0 | 0.0 | 0.0 | 0.0 | 0 | 33.447183 | 126.912579 | 33.444121 | 126.912948 | 30.0 |
| 4701215 | 20211020 | 2 | 2 | 103 | 0 | 0 | 80.0 | 0.0 | 0.0 | 0.0 | 0 | 33.443596 | 126.431817 | 33.444996 | 126.433332 | 73.0 |
| 4701216 | 20211019 | 6 | 2 | 107 | 0 | 0 | 60.0 | 0.0 | 0.0 | 0.0 | 0 | 33.256785 | 126.508940 | 33.257130 | 126.510364 | 35.0 |
4701217 rows × 16 columns
ind=train.describe().columns.drop(['base_date','start_latitude','start_longitude','end_latitude','end_longitude'])
ind
Index(['base_hour', 'lane_count', 'road_rating', 'multi_linked',
'connect_code', 'maximum_speed_limit', 'vehicle_restricted',
'weight_restricted', 'height_restricted', 'road_type', 'target'],
dtype='object')
train[ind].describe()
| | base_hour | lane_count | road_rating | multi_linked | connect_code | maximum_speed_limit | vehicle_restricted | weight_restricted | height_restricted | road_type | target |
|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 4.701217e+06 | 4.701217e+06 | 4.701217e+06 | 4.701217e+06 | 4.701217e+06 | 4.701217e+06 | 4701217.0 | 4.701217e+06 | 4701217.0 | 4.701217e+06 | 4.701217e+06 |
| mean | 1.192820e+01 | 1.836651e+00 | 1.049585e+02 | 4.762597e-04 | 2.660218e-01 | 6.125329e+01 | 0.0 | 5.618742e+03 | 0.0 | 6.152237e-01 | 4.278844e+01 |
| std | 6.722092e+00 | 6.877513e-01 | 1.840107e+00 | 2.181818e-02 | 5.227760e+00 | 1.213354e+01 | 0.0 | 1.395343e+04 | 0.0 | 1.211268e+00 | 1.595443e+01 |
| min | 0.000000e+00 | 1.000000e+00 | 1.030000e+02 | 0.000000e+00 | 0.000000e+00 | 3.000000e+01 | 0.0 | 0.000000e+00 | 0.0 | 0.000000e+00 | 1.000000e+00 |
| 25% | 6.000000e+00 | 1.000000e+00 | 1.030000e+02 | 0.000000e+00 | 0.000000e+00 | 5.000000e+01 | 0.0 | 0.000000e+00 | 0.0 | 0.000000e+00 | 3.000000e+01 |
| 50% | 1.200000e+01 | 2.000000e+00 | 1.060000e+02 | 0.000000e+00 | 0.000000e+00 | 6.000000e+01 | 0.0 | 0.000000e+00 | 0.0 | 0.000000e+00 | 4.300000e+01 |
| 75% | 1.800000e+01 | 2.000000e+00 | 1.070000e+02 | 0.000000e+00 | 0.000000e+00 | 7.000000e+01 | 0.0 | 0.000000e+00 | 0.0 | 0.000000e+00 | 5.400000e+01 |
| max | 2.300000e+01 | 3.000000e+00 | 1.070000e+02 | 1.000000e+00 | 1.030000e+02 | 8.000000e+01 | 0.0 | 5.000000e+04 | 0.0 | 3.000000e+00 | 1.130000e+02 |
train.loc[:500,ind]
| | base_hour | lane_count | road_rating | multi_linked | connect_code | maximum_speed_limit | vehicle_restricted | weight_restricted | height_restricted | road_type | target |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 17 | 1 | 106 | 0 | 0 | 60.0 | 0.0 | 32400.0 | 0.0 | 3 | 52.0 |
| 1 | 21 | 2 | 103 | 0 | 0 | 60.0 | 0.0 | 0.0 | 0.0 | 0 | 30.0 |
| 2 | 7 | 2 | 103 | 0 | 0 | 80.0 | 0.0 | 0.0 | 0.0 | 0 | 61.0 |
| 3 | 13 | 2 | 107 | 0 | 0 | 50.0 | 0.0 | 0.0 | 0.0 | 0 | 20.0 |
| 4 | 8 | 2 | 103 | 0 | 0 | 80.0 | 0.0 | 0.0 | 0.0 | 0 | 38.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 496 | 3 | 3 | 106 | 0 | 0 | 70.0 | 0.0 | 0.0 | 0.0 | 0 | 66.0 |
| 497 | 1 | 2 | 103 | 0 | 0 | 80.0 | 0.0 | 0.0 | 0.0 | 0 | 76.0 |
| 498 | 14 | 1 | 103 | 0 | 0 | 50.0 | 0.0 | 0.0 | 0.0 | 0 | 49.0 |
| 499 | 16 | 1 | 106 | 0 | 0 | 50.0 | 0.0 | 0.0 | 0.0 | 0 | 55.0 |
| 500 | 9 | 1 | 107 | 0 | 0 | 30.0 | 0.0 | 0.0 | 0.0 | 0 | 22.0 |
501 rows × 11 columns
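The describe() output above reports a standard deviation of 0 for vehicle_restricted and height_restricted. Pearson correlation is undefined for a constant column, so those rows and columns of a correlation matrix come out as NaN. A minimal sketch, on a small made-up frame rather than the competition data, of filtering out constant columns before calling corr():

```python
import pandas as pd

# Toy frame with made-up values; vehicle_restricted is constant, as in train.describe()
toy = pd.DataFrame({
    "base_hour": [17, 21, 7, 13, 8],
    "maximum_speed_limit": [60.0, 60.0, 80.0, 50.0, 80.0],
    "vehicle_restricted": [0.0, 0.0, 0.0, 0.0, 0.0],
    "target": [52.0, 30.0, 61.0, 20.0, 38.0],
})

# Keep only columns that take more than one distinct value
varying = toy.loc[:, toy.nunique() > 1]
corr = varying.corr()
print(corr.round(2))
```

The same filter could be applied to train.loc[:500, ind] before plotting, which would leave the heatmap without empty cells.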
import matplotlib.pyplot as plt
plt.figure(figsize=(10,10))
sns.heatmap(train.loc[:500,ind].astype(float).corr(), linewidths = 0.1, vmax = 1.0, fmt=".1f",
            square = True, cmap = 'Blues', linecolor = "white", annot = True, annot_kws = {"size" : 16})
<AxesSubplot:>

# Label-encode the string columns; fit on train, then extend with labels seen only in test
str_col = ['day_of_week','start_turn_restricted','end_turn_restricted']
for i in str_col:
    le = LabelEncoder()
    le=le.fit(train[i])
    train[i]=le.transform(train[i])
    for label in np.unique(test[i]):
        if label not in le.classes_:
            le.classes_ = np.append(le.classes_, label)
    test[i]=le.transform(test[i])
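The loop above fits each encoder on train and appends any label that appears only in test to le.classes_, so that transform does not raise on unseen values. A minimal sketch of that behaviour on made-up string labels (the values "no", "yes" and "partial" are hypothetical and not taken from the data):

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

train_col = np.array(["no", "yes", "no"])
test_col = np.array(["yes", "partial", "no"])  # "partial" never appears in train_col

le = LabelEncoder().fit(train_col)

# Extend the known classes with labels that only occur in the test column
for label in np.unique(test_col):
    if label not in le.classes_:
        le.classes_ = np.append(le.classes_, label)

train_encoded = le.transform(train_col)
test_encoded = le.transform(test_col)
print(le.classes_, train_encoded, test_encoded)
```

An appended label receives the next free integer code rather than a sorted position, and the model has never seen that code during training.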
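# Note: valid is sliced from the already-truncated train and .loc is end-inclusive,
# so valid holds only row 2700000, which is also the last row of train
train=train.loc[:2700000]
valid=train.loc[2700000:]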
y_train = train['target']
X_train = train.drop(['id','base_date', 'target','road_name', 'start_node_name', 'end_node_name','vehicle_restricted'], axis=1)
y_valid = valid['target']
X_valid = valid.drop(['id','base_date', 'target','road_name', 'start_node_name', 'end_node_name','vehicle_restricted'], axis=1)
test = test.drop(['id','base_date', 'road_name', 'start_node_name', 'end_node_name','vehicle_restricted'], axis=1)
print(X_train.shape)
print(y_train.shape)
print(test.shape)
(2700001, 16)
(2700001,)
(291241, 16)
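Because train is reassigned before valid is sliced, and .loc slicing includes both endpoints, valid in the cell above ends up holding only the row labelled 2700000, which is also the last row of train. A sketch of a non-overlapping split, shown on a small made-up frame with a hypothetical cut point `split_at` (with the real data this would be 2700000):

```python
import pandas as pd

# Made-up stand-in for the full training frame
full = pd.DataFrame({"x": range(10), "target": range(10, 20)})
split_at = 7  # hypothetical cut point

# Slice the validation part first, while the full frame is still available
valid = full.iloc[split_at:]
train = full.iloc[:split_at]  # iloc excludes the end position, so no row is shared

print(train.shape, valid.shape)
```

The MAE values reported below were computed with the original one-row valid, so they would need to be re-run on a split like this before the models can be compared.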
LR = lgb.LGBMRegressor(random_state=42).fit(X_train, y_train)
pred = LR.predict(X_valid)
from sklearn.metrics import mean_absolute_error
mae = mean_absolute_error(y_valid,pred)
print(mae)
1.7515819641866202
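For reference, mean_absolute_error is the mean of the absolute differences between the true and predicted values. A small check with made-up numbers:

```python
import numpy as np
from sklearn.metrics import mean_absolute_error

# Made-up speeds, not taken from the competition data
y_true = np.array([52.0, 30.0, 61.0, 20.0])
y_pred = np.array([50.0, 33.0, 61.0, 24.0])

manual = np.mean(np.abs(y_true - y_pred))  # (2 + 3 + 0 + 4) / 4
mae = mean_absolute_error(y_true, y_pred)
print(manual, mae)
```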
from sklearn.tree import DecisionTreeRegressor
tree = DecisionTreeRegressor(random_state=123).fit(X_train, y_train)
pred = tree.predict(X_valid)
from sklearn.metrics import mean_absolute_error
mae = mean_absolute_error(y_valid,pred)
print(mae)
2.6666666666666643
from sklearn.ensemble import RandomForestRegressor
rf=RandomForestRegressor().fit(X_train,y_train)
pred = rf.predict(X_valid)
from sklearn.metrics import mean_absolute_error
mae = mean_absolute_error(y_valid,pred)
print(mae)
2.6569172934459857
grid = {
'n_estimators' : [100,200],
'max_depth' : [6,8,10,12],
'min_samples_leaf' : [3,5,7,10],
'min_samples_split' : [2,3,5,10]
}
from sklearn.model_selection import GridSearchCV
# Note: "accuracy" is a classification metric, so every fold score for this regressor is nan (see the warning below).
# min_samples_leaf and min_samples_split are scikit-learn tree parameters; in this run LightGBM reports them as unknown.
classifier_grid = GridSearchCV(LR, param_grid = grid, scoring="accuracy", n_jobs=-1, verbose =1)
classifier_grid.fit(X_train, y_train)
print("Best mean accuracy : {}".format(classifier_grid.best_score_))
print("Best parameters :", classifier_grid.best_params_)
Fitting 5 folds for each of 128 candidates, totalling 640 fits
/opt/homebrew/Caskroom/miniforge/base/lib/python3.9/site-packages/sklearn/model_selection/_search.py:969: UserWarning: One or more of the test scores are non-finite: [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan]
warnings.warn(
[LightGBM] [Warning] Unknown parameter: min_samples_leaf
[LightGBM] [Warning] Unknown parameter: min_samples_split
[LightGBM] [Warning] Accuracy may be bad since you didn't explicitly set num_leaves OR 2^max_depth > num_leaves. (num_leaves=31).
Best mean accuracy : nan
Best parameters : {'max_depth': 6, 'min_samples_leaf': 3, 'min_samples_split': 2, 'n_estimators': 100}
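The nan scores above come from scoring="accuracy", which is a classification metric and cannot be computed on continuous predictions. A minimal sketch of the same search pattern with a regression scorer follows. It uses synthetic data and a DecisionTreeRegressor so that the min_samples_leaf name is valid; for LGBMRegressor the closest constructor parameter would be min_child_samples. The data and the small grid are assumptions for illustration only.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeRegressor

# Synthetic regression data (assumption, stands in for X_train / y_train)
rng = np.random.default_rng(0)
X = rng.uniform(0, 1, size=(300, 3))
y = 50 * X[:, 0] + 10 * X[:, 1] + rng.normal(0, 1, size=300)

small_grid = {"max_depth": [3, 6], "min_samples_leaf": [3, 10]}

search = GridSearchCV(
    DecisionTreeRegressor(random_state=42),
    param_grid=small_grid,
    scoring="neg_mean_absolute_error",  # scorers are maximised, so MAE is negated
    cv=3,
)
search.fit(X, y)

best_mae = -search.best_score_  # flip the sign back to a plain MAE
print("best CV MAE:", best_mae)
print("best params:", search.best_params_)
```

With a valid scorer, best_params_ reflects an actual comparison between candidates instead of defaulting to the first entry of the grid.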
Grid search took too long to run, so I switched to random search.
grid = {
'n_estimators' : [100,200],
'max_depth' : [6,8,10],
'min_samples_leaf' : [3,5,7],
'min_samples_split' : [2,3,5]
}
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import RandomizedSearchCV
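# RandomizedSearchCV clones and fits the estimator itself, so no separate fit is needed beforehand
rf=RandomForestRegressor()
# No scoring is given, so best_score_ is the mean cross-validated R^2 (the default for regressors)
rand = RandomizedSearchCV(rf, grid, n_jobs=-1, cv=5, n_iter=3, random_state=1)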
rand_result = rand.fit(X_train, y_train)
print(rand_result.best_score_)
print(rand_result.best_params_)
0.809383983301521
{'n_estimators': 100, 'min_samples_split': 3, 'min_samples_leaf': 5, 'max_depth': 10}
# Unfitted estimator; the search fits a clone of it on each fold
LR = lgb.LGBMRegressor(random_state=42)
from sklearn.model_selection import RandomizedSearchCV
rand = RandomizedSearchCV(LR, grid, n_jobs=-1, cv=5, n_iter=3, random_state=1)
rand_result = rand.fit(X_train, y_train)
print(rand_result.best_score_)
print(rand_result.best_params_)
0.8682401067973238
{'n_estimators': 200, 'min_samples_split': 2, 'min_samples_leaf': 10, 'max_depth': 12}
# Load data
train = pd.read_csv('./train.csv')
test = pd.read_csv('./test.csv')
# Preprocessing: label-encode the string columns
str_col = ['day_of_week','start_turn_restricted','end_turn_restricted']
for i in str_col:
    le = LabelEncoder()
    le=le.fit(train[i])
    train[i]=le.transform(train[i])
    for label in np.unique(test[i]):
        if label not in le.classes_:
            le.classes_ = np.append(le.classes_, label)
    test[i]=le.transform(test[i])
# Separate target and features; drop the same unused columns from test
y_train = train['target']
X_train = train.drop(['id','base_date', 'target','road_name', 'start_node_name', 'end_node_name','vehicle_restricted'], axis=1)
test = test.drop(['id','base_date', 'road_name', 'start_node_name', 'end_node_name','vehicle_restricted'], axis=1)
# Train the model and predict (parameters taken from the grid search output above)
LR = lgb.LGBMRegressor(n_estimators=100,min_samples_split=2, min_samples_leaf=3, max_depth=6).fit(X_train,y_train)
pred = LR.predict(test)
sample_submission = pd.read_csv('sample_submission.csv')
sample_submission['target'] = pred
sample_submission.to_csv("./submit.csv", index = False)
sample_submission
| | id | target |
|---|---|---|
| 0 | TEST_000000 | 26.497447 |
| 1 | TEST_000001 | 45.309778 |
| 2 | TEST_000002 | 61.148937 |
| 3 | TEST_000003 | 34.604018 |
| 4 | TEST_000004 | 36.333576 |
| ... | ... | ... |
| 291236 | TEST_291236 | 46.819428 |
| 291237 | TEST_291237 | 51.820056 |
| 291238 | TEST_291238 | 20.387144 |
| 291239 | TEST_291239 | 26.638933 |
| 291240 | TEST_291240 | 40.510140 |
291241 rows × 2 columns