Skip to content

Commit 590e2cd

Browse files
committed
up
1 parent 1010aa3 commit 590e2cd

2 files changed

Lines changed: 82 additions & 107 deletions

File tree

classification.ipynb

Lines changed: 24 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -460,16 +460,9 @@
460460
"훈련된 숫자-5 감별기(`sgd_clf`)에 처음 10개 샘플을 넣어 결정 함숫값을 확인해보면 다음과 같다. 첫 번째 샘플의 점수만 양수이고 나머지 9개는 모두 음수이므로, 첫 번째 샘플만 '5'로 판별되고 나머지 9개는 '5가 아니다'라고 판별된다.\n",
461461
"\n",
462462
"```python\n",
463-
"array([ 1200.93051237,\n",
464-
" -26883.79202424,\n",
465-
" -33072.03475406,\n",
466-
" -15919.5480689 ,\n",
467-
" -20003.53970191,\n",
468-
" -16652.87731528,\n",
469-
" -14276.86944263,\n",
470-
" -23328.13728948,\n",
471-
" -5172.79611432,\n",
472-
" -13873.5025381 ])\n",
463+
"array([ 1200.93051237, -26883.79202424, -33072.03475406, -15919.5480689 ,\n",
464+
" -20003.53970191, -16652.87731528, -14276.86944263, -23328.13728948,\n",
465+
" -5172.79611432, -13873.5025381 ])\n",
473466
"```"
474467
]
475468
},
@@ -559,20 +552,6 @@
559552
"<p><div align=\"center\"><img src=\"https://github.com/codingalzi/code-workout-ml/blob/master/images/ch03/homl03-04.png?raw=true\" width=\"500\"/></div></p>"
560553
]
561554
},
562-
{
563-
"cell_type": "markdown",
564-
"id": "8bc79216",
565-
"metadata": {},
566-
"source": [
567-
":::{note} 정밀도 90% 분류기 구현\n",
568-
"\n",
569-
"위 그래프에서 검은색 수직 점선은 정밀도는 90%, 재현율은 50% 정도가 되게 하는 결정 임곗값의\n",
570-
"위치를 표시하며, 실제 값은 3,000이다.\n",
571-
"결정 임곗값을 변경하여 원하는 정밀도와 재현율을 갖는 숫자-5 감별기를 구현하는 방법은\n",
572-
"[정밀도 90%분류기 구현](https://colab.research.google.com/github/codingalzi/code-workout-ml/blob/master/notebooks/code-classification.ipynb#scrollTo=cKfpQLyuCHkf)에서 자세히 소개한다.\n",
573-
":::"
574-
]
575-
},
576555
{
577556
"cell_type": "markdown",
578557
"id": "2e50fe7b-754d-4fad-a33e-d23cb38c79bd",
@@ -591,6 +570,19 @@
591570
"<div align=\"center\"><img src=\"https://github.com/codingalzi/code-workout-ml/blob/master/images/ch03/homl03-05.png?raw=true\" width=\"400\"/></div>"
592571
]
593572
},
573+
{
574+
"cell_type": "markdown",
575+
"id": "9f645fc2",
576+
"metadata": {},
577+
"source": [
578+
"**정밀도 90% 분류기 구현**\n",
579+
"\n",
580+
"위 그래프에서 검은색 수직 점선은 정밀도는 90%, 재현율은 50% 정도가 되게 하는 결정 임곗값의\n",
581+
"위치를 표시하며, 실제 값은 3,000이다.\n",
582+
"결정 임곗값을 변경하여 원하는 정밀도와 재현율을 갖는 숫자-5 감별기를 구현하는 방법은\n",
583+
"[정밀도 90%분류기 구현](https://colab.research.google.com/github/codingalzi/code-workout-ml/blob/master/notebooks/code-classification.ipynb#scrollTo=cKfpQLyuCHkf)에서 자세히 소개한다."
584+
]
585+
},
594586
{
595587
"cell_type": "markdown",
596588
"id": "e1c5700b-cc7f-4223-a118-060381026313",
@@ -744,18 +736,8 @@
744736
}
745737
},
746738
"source": [
747-
"**다중 클래스 분류 모델의 혼동 행렬**"
748-
]
749-
},
750-
{
751-
"cell_type": "markdown",
752-
"id": "9772ffb0-9665-4b85-8023-1e1dd3592bfb",
753-
"metadata": {
754-
"slideshow": {
755-
"slide_type": "slide"
756-
}
757-
},
758-
"source": [
739+
"**다중 클래스 분류 모델의 혼동 행렬**\n",
740+
"\n",
759741
"아래 왼쪽 이미지는 훈련된 다중 클래스 분류 모델의 혼동 행렬을 색상으로 시각화한 결과다. 주대각선 상의 색상이 전반적으로 밝은 것은 모델의 분류가 대체로 정확하게 이루어졌음을 의미한다. 다만, 5번과 8번 행의 색상이 다른 곳보다 상대적으로 어두운데, 이는 모델이 '숫자 5'와 '숫자 8'을 분류하는 정확도가 상대적으로 낮다는 것을 보여준다.\n",
760742
"\n",
761743
"반면에 아래 오른쪽 이미지는 오분류의 패턴을 더 명확히 파악하기 위해 혼동 행렬의 값을 각 숫자의 전체 개수 대비 비율로 변환한 결과다. 즉, 각 행(실제 숫자)에 속한 값들의 합이 100%가 되도록 정규화(Normalization)하였다.\n",
@@ -772,14 +754,8 @@
772754
}
773755
},
774756
"source": [
775-
"**오차율 활용**"
776-
]
777-
},
778-
{
779-
"cell_type": "markdown",
780-
"id": "323e178a",
781-
"metadata": {},
782-
"source": [
757+
"**오차율 활용**\n",
758+
"\n",
783759
"위 오른쪽 이미지를 보면 손글씨 이미지를 '숫자 8' 또는 '숫자 9'로 잘못 판별한 비율이 가장 높다. 즉, 8번과 9번 열의 오분류 비율 합이 다른 열에 비해 상대적으로 크다.\n",
784760
"\n",
785761
"올바르게 예측된 샘플을 제외하고 행(실제 범주)을 기준으로 오분류 비율만 시각화하면 아래 왼쪽 이미지와 같다. 8번과 9번 열이 전반적으로 밝게 나타나며, 이는 많은 손글씨 이미지가 8 또는 9로 잘못 판별되었음을 의미한다. \n",
@@ -808,18 +784,8 @@
808784
}
809785
},
810786
"source": [
811-
"**개별 오류 확인**"
812-
]
813-
},
814-
{
815-
"cell_type": "markdown",
816-
"id": "107be02c-37c7-4e0f-92ca-571298d57bb4",
817-
"metadata": {
818-
"slideshow": {
819-
"slide_type": "slide"
820-
}
821-
},
822-
"source": [
787+
"**개별 오류 확인**\n",
788+
"\n",
823789
"위 오른쪽 이미지를 보면 숫자 5로 잘못 판별된 이미지 중에서 실제 숫자가 3인 이미지의 비율이 29%로 상당히 높다는 것을 알 수 있다. \n",
824790
"\n",
825791
"구체적인 오분류 양상을 확인하기 위해 숫자 3과 5 데이터에 대해서만 예측 결과를 혼동 행렬 형태로 시각화하면 아래와 같다. 여기서 행은 실제 범주(라벨)를, 열은 예측 범주(라벨)를 나타낸다."
@@ -875,7 +841,7 @@
875841
"source": [
876842
"(2) MNIST 테스트셋에 대한 정확도가 97% 이상 나오는 MNIST 분류기를 훈련시켜 보아라.\n",
877843
"\n",
878-
"힌트:"
844+
"힌트: [정확도 97% 성능의 MNIST 분류기](https://colab.research.google.com/github/codingalzi/code-workout-ml/blob/master/notebooks/code-classification.ipynb#scrollTo=wsPjQ28VCHkm)"
879845
]
880846
},
881847
{
@@ -885,7 +851,7 @@
885851
"source": [
886852
"(3) [Kaggle](https://www.kaggle.com/c/titanic) 타이타닉 데이터셋으로 생존 여부 분류 문제를 풀어 보아라.\n",
887853
"\n",
888-
"힌트:"
854+
"힌트: [타이타닉 데이터셋 도전](https://colab.research.google.com/github/codingalzi/code-workout-ml/blob/master/notebooks/code-classification.ipynb#scrollTo=D6CIjV_ICHko)"
889855
]
890856
}
891857
],

notebooks/code-classification.ipynb

Lines changed: 58 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -571,19 +571,8 @@
571571
"source": [
572572
"이미지가 숫자 5를 표현하는지 여부만을 판단하는 이진 분류기의 훈련을 위해 라벨을 0 또는 1로 변경한다.\n",
573573
"\n",
574-
"* 0: 숫자 5 아님.\n",
575-
"* 1: 숫자 5 맞음."
576-
]
577-
},
578-
{
579-
"cell_type": "markdown",
580-
"metadata": {
581-
"id": "d842heVUWNtg"
582-
},
583-
"source": [
584-
"- 라벨(타깃) 재정의\n",
585-
" - 원래 0부터 9까지에서 0과 1로 변환\n",
586-
" - 5일 때만 1, 나머지는 0"
574+
"- 원래 0부터 9까지에서 0과 1로 변환\n",
575+
"- 5일 때만 1, 나머지는 0"
587576
]
588577
},
589578
{
@@ -2083,8 +2072,10 @@
20832072
"source": [
20842073
"`precision_recall_curve()` 함수를 이용하여 결정 임계값의 변화에 따른 정밀도와 재현율을 확인한다.\n",
20852074
"\n",
2086-
"- `precisions`: 60,001 개의 정밀도: 지정된 결정 임계값에 따른 정밀도. 마지막 항목으로 1 추가.\n",
2087-
"- `recalls`: 60,001 개의 재현율: 지정된 결정 임계값에 따른 재현율. 마지막 항목으로 0 추가.\n",
2075+
"- `precisions`: 60,001 개의 정밀도. 지정된 결정 임계값에 따른 정밀도. 마지막 항목은 1이며,\n",
2076+
" 결정 임곗값을 키우면 최종적으로 정밀도가 1이 됨을 반영함.\n",
2077+
"- `recalls`: 60,001 개의 재현율. 지정된 결정 임계값에 따른 재현율. 마지막 항목은 0 이며,\n",
2078+
" 결정 임곗값을 키우면 최종적으로 재현율 0이 됨을 반영함.\n",
20882079
"- `thresholds`: 60,000 개의 결정 임계값. 각각의 결정 임계값에 맞춰 정밀도와 재현율 계산."
20892080
]
20902081
},
@@ -2107,8 +2098,13 @@
21072098
"id": "EUsFjgSZiI_w"
21082099
},
21092100
"source": [
2110-
"`thresholds`는 `y_scores`를 오름차순으로 정렬한 어레이이다.\n",
2111-
"아래 코드가 이점을 확인해준다."
2101+
"`thresholds`는 `y_scores`를 오름차순으로 정렬한 어레이이며,\n",
2102+
"각 항목을 결정 임곗값으로 사용할 때의 정밀도와 재현율의 변화를\n",
2103+
"그래프로 그릴 때 활용된다.\n",
2104+
"\n",
2105+
"훈련셋이 6만개에 대해 계산된 결정 함수값으로 구성된 `y_scores`에 중복값이 포함될 수도 있기에\n",
2106+
"일반적으로 `thresholds`는 6만개보다 적은 값이 포함할 수도 있다.\n",
2107+
"하지만 여기서는 모두 서로 다른 6만개의 값이 결정 함숫값으로 계산되었다."
21122108
]
21132109
},
21142110
{
@@ -2176,7 +2172,11 @@
21762172
"id": "t51JUZZKLcMY"
21772173
},
21782174
"source": [
2179-
"**정밀도/재현율 그래프**"
2175+
"**정밀도/재현율 그래프**\n",
2176+
"\n",
2177+
"위 그래프를 재현율 대 정밀도 그래프로 변환하면 다음과 같다.\n",
2178+
"결정 임곗값(threshold)을 낮춰 재현율을 올리면 정밀도는 떨어지는\n",
2179+
"상호 반비례 관계를 잘 보여준다."
21802180
]
21812181
},
21822182
{
@@ -2240,7 +2240,7 @@
22402240
"id": "cKfpQLyuCHkf"
22412241
},
22422242
"source": [
2243-
"### 90% 정밀도 분류기 구현"
2243+
"### 정밀도 90% 분류기 구현"
22442244
]
22452245
},
22462246
{
@@ -2375,6 +2375,24 @@
23752375
"## 다중 클래스 분류"
23762376
]
23772377
},
2378+
{
2379+
"cell_type": "markdown",
2380+
"metadata": {},
2381+
"source": [
2382+
"### 다중 클래스 분류 지원 모델"
2383+
]
2384+
},
2385+
{
2386+
"cell_type": "markdown",
2387+
"metadata": {},
2388+
"source": [
2389+
"아래에 언급된 모델들을 포함한 많은 모델이 이진 분류와 다중 클래스 분류를 모두 지원한다.\n",
2390+
"\n",
2391+
"* `LogisticRegression` 모델\n",
2392+
"* `RandomForestClassifier` 모델\n",
2393+
"* `SGDClassifier` 모델"
2394+
]
2395+
},
23782396
{
23792397
"cell_type": "markdown",
23802398
"metadata": {
@@ -2462,6 +2480,13 @@
24622480
"np.array([0.87365, 0.85835, 0.8689 ]).mean()"
24632481
]
24642482
},
2483+
{
2484+
"cell_type": "markdown",
2485+
"metadata": {},
2486+
"source": [
2487+
"**스케일링의 중요성**"
2488+
]
2489+
},
24652490
{
24662491
"cell_type": "markdown",
24672492
"metadata": {
@@ -2570,7 +2595,7 @@
25702595
"id": "Wc7da2abCHkj"
25712596
},
25722597
"source": [
2573-
"## 다중 클래스 분류기 모델의 오류 분석"
2598+
"### 다중 클래스 분류기 모델의 오류 분석"
25742599
]
25752600
},
25762601
{
@@ -2606,19 +2631,11 @@
26062631
"cell_type": "markdown",
26072632
"metadata": {},
26082633
"source": [
2609-
"### 다중 클래스 분류 모델의 혼동 행렬"
2610-
]
2611-
},
2612-
{
2613-
"cell_type": "markdown",
2614-
"metadata": {
2615-
"id": "nULw8eG-CHkk"
2616-
},
2617-
"source": [
2618-
"아래 두 이미지는 위 예측값을 이용하여 다중 클래스 분류기의 혼동 행렬을 그린다.\n",
2634+
"**다중 클래스 분류 모델의 혼동 행렬**\n",
2635+
"\n",
2636+
"아래 왼쪽 이미지는 훈련된 다중 클래스 분류 모델의 혼동 행렬을 색상으로 시각화한 결과다. 주대각선 상의 색상이 전반적으로 밝은 것은 모델의 분류가 대체로 정확하게 이루어졌음을 의미한다. 다만, 5번과 8번 행의 색상이 다른 곳보다 상대적으로 어두운데, 이는 모델이 '숫자 5'와 '숫자 8'을 분류하는 정확도가 상대적으로 낮다는 것을 보여준다.\n",
26192637
"\n",
2620-
"- 왼쪽 이미지: 다중 클래스 분류 모델의 혼동 행렬을 색상과 함께 표현한다.\n",
2621-
"- 오른쪽 이미지: 행별로 비율의 합이 100%가 되도록 정규화한다."
2638+
"반면에 아래 오른쪽 이미지는 오분류의 패턴을 더 명확히 파악하기 위해 혼동 행렬의 값을 각 숫자의 전체 개수 대비 비율로 변환한 결과다. 즉, 각 행(실제 숫자)에 속한 값들의 합이 100%가 되도록 정규화(Normalization)하였다."
26222639
]
26232640
},
26242641
{
@@ -2665,19 +2682,15 @@
26652682
},
26662683
{
26672684
"cell_type": "markdown",
2668-
"metadata": {
2669-
"id": "WIEtWr1yCHkk"
2670-
},
2685+
"metadata": {},
26712686
"source": [
26722687
"**오차율 활용**\n",
26732688
"\n",
2674-
"- 왼쪽 이미지: 올바르게 예측된 샘플을 제외한 후에 행별로 오인된 숫자의 비율을 확인하면 다음과 같다\n",
2675-
" - `sample_weight` 키워드 인자 활용\n",
2676-
" - 행별로 합이 1\n",
2677-
" - 많은 숫자가 9로 잘못 예측됨\n",
2678-
"- 오른쪽 이미지: 칸별 정규화 진행 결과를 보여준다.\n",
2679-
" - 열별로 합이 1\n",
2680-
" - 7로 오인된 숫자중에 9가 41% 차지"
2689+
"위 오른쪽 이미지를 보면 손글씨 이미지를 '숫자 8' 또는 '숫자 9'로 잘못 판별한 비율이 가장 높다. 즉, 8번과 9번 열의 오분류 비율 합이 다른 열에 비해 상대적으로 크다.\n",
2690+
"\n",
2691+
"올바르게 예측된 샘플을 제외하고 행(실제 범주)을 기준으로 오분류 비율만 시각화하면 아래 왼쪽 이미지와 같다. 8번과 9번 열이 전반적으로 밝게 나타나며, 이는 많은 손글씨 이미지가 8 또는 9로 잘못 판별되었음을 의미한다. \n",
2692+
"\n",
2693+
"아래 오른쪽 이미지는 열(예측 범주)을 기준으로 오분류 데이터를 정규화한 결과를 보여준다. 예를 들어, 7로 오인된 전체 손글씨 이미지 중에서 실제 숫자 9인 이미지의 비율이 41%에 달한다는 점을 확인할 수 있다."
26812694
]
26822695
},
26832696
{
@@ -2727,17 +2740,13 @@
27272740
},
27282741
{
27292742
"cell_type": "markdown",
2730-
"metadata": {
2731-
"id": "v90KlamQCHkk"
2732-
},
2743+
"metadata": {},
27332744
"source": [
27342745
"**개별 오류 확인**\n",
27352746
"\n",
2736-
"위 오른쪽 이미지에 의하면 5로 오인된 이미지 중에서 숫자 3 이미지의 비율이 38%로 가장 높다.\n",
2737-
"실제로 혼동 행렬과 유사한 행렬을 3과 5에 대해 나타내면 다음과 같다.\n",
2747+
"위 오른쪽 이미지를 보면 숫자 5로 잘못 판별된 이미지 중에서 실제 숫자가 3인 이미지의 비율이 29%로 상당히 높다는 것을 알 수 있다. \n",
27382748
"\n",
2739-
"* 음성: 3으로 판별\n",
2740-
"* 양성: 5로 판별"
2749+
"구체적인 오분류 양상을 확인하기 위해 숫자 3과 5 데이터에 대해서만 예측 결과를 혼동 행렬 형태로 시각화하면 아래와 같다. 여기서 행은 실제 범주(라벨)를, 열은 예측 범주(라벨)를 나타낸다."
27412750
]
27422751
},
27432752
{

0 commit comments

Comments
 (0)