23 lines
487 B
Python
23 lines
487 B
Python
import torch
|
|
|
|
|
|
def to_one_hot(number, max_val, min_val=0):
    """Encode ``number`` as a one-hot vector over the range ``[min_val, max_val)``.

    The returned vector has ``max_val - min_val`` entries. The entry at
    position ``number - min_val`` is 1 and every other entry is 0.

    Args:
        number: Integer value to encode; must satisfy
            ``min_val <= number < max_val``.
        max_val: Exclusive upper bound of the encodable range.
        min_val: Inclusive lower bound of the encodable range (default 0).

    Returns:
        A 1-D tensor of length ``max_val - min_val``, created with
        ``torch.zeros`` (so it uses torch's default floating dtype).

    Raises:
        AssertionError: If ``max_val`` is not larger than ``min_val``.
        IndexError: If ``number`` lies outside ``[min_val, max_val)``.
    """
    # NOTE(review): kept as an assert so callers still see AssertionError;
    # be aware that asserts are stripped when Python runs with -O.
    assert max_val > min_val, "illegal parameters, max_val is not larger than min_val"
    # Check explicitly: a number below min_val would otherwise become a
    # negative index and silently set a slot counted from the end.
    if not min_val <= number < max_val:
        raise IndexError(
            f"number {number} is outside the one hot range [{min_val}, {max_val})"
        )
    hot = torch.zeros(max_val - min_val)
    # Shift by min_val so that the value min_val maps to position 0.
    hot[number - min_val] = 1
    return hot
|
|
|
|
|
|
def from_one_hot(hot, min_val=0):
    """Decode a one-hot vector back to the value it encodes.

    Args:
        hot: 1-D tensor with exactly one nonzero entry.
        min_val: Value represented by position 0 (default 0). Pass the same
            ``min_val`` that was given when the vector was encoded.

    Returns:
        A 1-element tensor holding ``position + min_val``, where ``position``
        is the index of the nonzero entry (an int64 tensor when ``min_val``
        is an int).

    Raises:
        AssertionError: If ``hot`` is not 1-D, or if it does not contain
            exactly one nonzero entry.
    """
    assert hot.dim() == 1, "one hot vector must be 1-D"
    # nonzero(as_tuple=True) returns one index tensor per dimension. For a 1-D
    # input that single tensor lists every nonzero position. The number of hot
    # entries is its length, not the length of the tuple (which is the rank).
    (indices,) = torch.nonzero(hot, as_tuple=True)
    assert indices.numel() == 1, "one hot vector must contain exactly one nonzero value"
    return indices + min_val
|
|
|
|
|
|
if __name__ == "__main__":
    # Round-trip demo: encode 4 in the range [0, 5), then decode it again.
    # Keyword arguments prevent swapping the bounds: a positional call
    # to_one_hot(4, 0, 5) would mean max_val=0, min_val=5 and fail the
    # max_val > min_val precondition.
    hot = to_one_hot(4, max_val=5)
    print(hot)
    val = from_one_hot(hot)
    print(val)
|