-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathone_hot_encoder_decoder.py
More file actions
70 lines (58 loc) · 2.18 KB
/
one_hot_encoder_decoder.py
File metadata and controls
70 lines (58 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
# Ordered label vocabulary used by the encoder/decoder functions in this
# module: entry i of an encoded vector refers to LAND_USE_CLASSES[i], so
# reordering or inserting names changes the meaning of existing encodings.
# NOTE(review): these 19 names appear to follow the BigEarthNet 19-class
# nomenclature — confirm against the dataset before editing.
LAND_USE_CLASSES = [
    'Urban fabric',
    'Industrial or commercial units',
    'Arable land',
    'Permanent crops',
    'Pastures',
    'Complex cultivation patterns',
    'Land principally occupied by agriculture, with significant areas of natural vegetation',
    'Agro-forestry areas',
    'Broad-leaved forest',
    'Coniferous forest',
    'Mixed forest',
    'Natural grassland and sparsely vegetated areas',
    'Moors, heathland and sclerophyllous vegetation',
    'Transitional woodland, shrub',
    'Beaches, dunes, sands',
    'Inland wetlands',
    'Coastal wetlands',
    'Inland waters',
    'Marine waters',
]
def one_hot_encode_land_use(class_list, *, classes=None, dtype=torch.long):
    """
    Encode land use class names as a single multi-hot vector using PyTorch.

    Despite the "one-hot" name, every class present in ``class_list`` is set
    to 1, so several positions can be 1 at once (multi-label encoding).
    Position ``i`` of the result corresponds to ``classes[i]``.

    Args:
        class_list (str or iterable of str): Land use class names to mark.
            A single name passed as a bare string is treated as a
            one-element list (iterating the string would otherwise yield
            its individual characters). Duplicates and order do not matter.
        classes (sequence of str, optional): Ordered label vocabulary.
            Defaults to ``LAND_USE_CLASSES``.
        dtype (torch.dtype, optional): dtype of the returned tensor.
            Defaults to ``torch.long``; pass e.g. ``torch.float32`` if a
            floating-point vector is needed.

    Returns:
        torch.Tensor: 1D tensor of length ``len(classes)`` holding 1 at the
        index of every class in ``class_list`` and 0 everywhere else.

    Raises:
        ValueError: If a name in ``class_list`` is not in ``classes``.
    """
    if classes is None:
        classes = LAND_USE_CLASSES
    if isinstance(class_list, str):
        class_list = [class_list]
    # Rebuilt on every call so it always reflects the current contents of
    # ``classes`` (the default vocabulary is a mutable module-level list).
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    output = torch.zeros(len(classes), dtype=dtype)
    for class_name in class_list:
        try:
            idx = class_to_idx[class_name]
        except KeyError:
            raise ValueError(f"Unknown class name: {class_name}") from None
        output[idx] = 1
    return output
def decode_land_use(one_hot_tensor, *, classes=None):
    """
    Convert a multi-hot vector back into the list of class names it marks.

    Args:
        one_hot_tensor (torch.Tensor or array-like): 1D vector whose
            position ``i`` corresponds to ``classes[i]``. Non-tensor input
            (e.g. a Python list) is converted with ``torch.as_tensor``.
        classes (sequence of str, optional): Ordered label vocabulary.
            Defaults to ``LAND_USE_CLASSES``.

    Returns:
        list of str: Names of the classes whose entry equals exactly 1, in
        vocabulary order. Entries with any other value (e.g. 0.7) are not
        selected, so threshold probabilities before decoding.

    Raises:
        ValueError: If the input is not 1D or its length differs from
            ``len(classes)``.
    """
    if classes is None:
        classes = LAND_USE_CLASSES
    one_hot_tensor = torch.as_tensor(one_hot_tensor)
    # torch.where on an N-D tensor returns one index tensor per dimension;
    # taking only [0] of a 2D (batched) input would yield row numbers and
    # silently map them to the wrong class names, so reject that up front.
    if one_hot_tensor.dim() != 1:
        raise ValueError(
            f"Expected a 1D tensor, got shape {tuple(one_hot_tensor.shape)}"
        )
    if one_hot_tensor.shape[0] != len(classes):
        raise ValueError(
            f"Expected {len(classes)} entries, got {one_hot_tensor.shape[0]}"
        )
    # .tolist() gives plain Python ints for indexing the vocabulary.
    indices = torch.where(one_hot_tensor == 1)[0].tolist()
    return [classes[idx] for idx in indices]