mpd-public/mpd/utils/one_hot.py
2023-10-23 15:45:14 +02:00

23 lines
487 B
Python

import torch
def to_one_hot(number, max_val, min_val=0):
    """Encode ``number`` as a one-hot vector over ``[min_val, max_val)``.

    The result has ``max_val - min_val`` entries. The entry at offset
    ``number - min_val`` is 1 and every other entry is 0.

    Args:
        number: Value to encode; must satisfy ``min_val <= number < max_val``.
        max_val: Exclusive upper bound of the encodable range.
        min_val: Inclusive lower bound of the encodable range (default 0).

    Returns:
        A 1-D tensor created by ``torch.zeros`` with a single entry set to 1.

    Raises:
        ValueError: If ``max_val`` is not larger than ``min_val``.
        IndexError: If ``number`` lies outside ``[min_val, max_val)``.
    """
    # Raise explicitly instead of asserting: asserts are removed under
    # ``python -O``, which would let a non-positive size reach torch.zeros.
    if max_val <= min_val:
        raise ValueError("illegal parameters, max_val is not larger than min_val")
    # Without this check a value below min_val would become a negative index
    # and silently mark a slot counted from the end of the vector.
    if not min_val <= number < max_val:
        raise IndexError(
            f"number {number} is outside the range [{min_val}, {max_val})"
        )

    hot = torch.zeros(max_val - min_val)
    # Shift by min_val so the first slot corresponds to min_val itself.
    hot[number - min_val] = 1
    return hot
def from_one_hot(hot, min_val=0):
    """Decode a one-hot vector back to the value it encodes.

    Args:
        hot: 1-D tensor with exactly one non-zero entry.
        min_val: Value represented by the first slot of ``hot`` (default 0).
            Pass the same ``min_val`` that was used for encoding to recover
            the original number.

    Returns:
        A one-element integer tensor holding the position of the non-zero
        entry, shifted by ``min_val``.

    Raises:
        ValueError: If ``hot`` is not 1-D, or if it does not contain exactly
            one non-zero entry.
    """
    # torch.where(condition) yields one index tensor per dimension, so the
    # tuple length is the rank of ``hot``, not the count of hot entries.
    index = torch.where(hot)
    if len(index) != 1:
        raise ValueError("one hot vector must be 1-dimensional")

    positions = index[0]
    hot_count = positions.numel()
    if hot_count == 0:
        raise ValueError("no value is 1 in one hot vector")
    if hot_count > 1:
        raise ValueError("more than one value is 1 in one hot vector")
    return positions + min_val
if __name__ == "__main__":
    # Round-trip demo: encode 4 over the range [0, 5) and decode it again.
    # Keyword arguments keep max_val/min_val from being swapped; the previous
    # positional call passed max_val=0, min_val=5 and failed the bounds check.
    hot = to_one_hot(4, max_val=5, min_val=0)
    print(hot)
    val = from_one_hot(hot)
    print(val)