-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot.py
More file actions
76 lines (57 loc) · 2.03 KB
/
plot.py
File metadata and controls
76 lines (57 loc) · 2.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import json
import sys
import os
import matplotlib
matplotlib.use("Agg") # for servers (no display)
import matplotlib.pyplot as plt
def moving_average(data, window_size):
    """Return the trailing moving average of ``data``.

    Element ``i`` of the result is the mean of
    ``data[max(0, i - window_size + 1) : i + 1]``. The first
    ``window_size - 1`` points therefore average over the shorter prefix that
    is available instead of being dropped, and the output has the same length
    as the input.

    Args:
        data: Iterable of numbers.
        window_size: Number of trailing points per average. Values ``<= 1``
            disable smoothing.

    Returns:
        A new list. When smoothing is disabled it is a plain copy of ``data``.
        It is never the input object itself.
    """
    values = list(data)
    if window_size <= 1:
        # Return a copy so the result never aliases the caller's list.
        return values

    smoothed = []
    running_total = 0
    for i, value in enumerate(values):
        running_total += value
        if i >= window_size:
            # Drop the point that just slid out of the trailing window.
            # A running total keeps this O(n) instead of re-summing each window.
            running_total -= values[i - window_size]
        smoothed.append(running_total / min(i + 1, window_size))
    return smoothed
def plot_accuracy(log_path, max_step=None, smooth=1, out_dir="logs"):
    """Plot train/test accuracy from a JSON training log and save it as a PDF.

    Args:
        log_path: Path to a JSON file holding a list of entries. Each entry is
            read for its ``"step"``, ``"train_acc"`` and ``"test_acc"`` keys.
        max_step: If given, only entries with ``entry["step"] <= max_step``
            are plotted.
        smooth: Trailing moving-average window applied to both curves. ``1``
            or less means no smoothing.
        out_dir: Directory the PDF is written to; created if it is missing.
            Defaults to ``"logs"``, the previously hard-coded location.

    Returns:
        The path of the saved PDF.
    """
    # Load logs.
    with open(log_path, "r", encoding="utf-8") as f:
        logs = json.load(f)

    # Filter by step.
    if max_step is not None:
        logs = [entry for entry in logs if entry["step"] <= max_step]

    # Extract and smooth the series.
    steps = [entry["step"] for entry in logs]
    train_acc = [entry["train_acc"] for entry in logs]
    test_acc = [entry["test_acc"] for entry in logs]
    train_acc_smooth = moving_average(train_acc, smooth)
    test_acc_smooth = moving_average(test_acc, smooth)

    # Build the output path: <base>_accuracy[_upto_N][_smooth_W].pdf
    base = os.path.splitext(os.path.basename(log_path))[0]
    suffix = ""
    if max_step is not None:
        suffix += f"_upto_{max_step}"
    if smooth > 1:
        suffix += f"_smooth_{smooth}"
    if out_dir:
        # savefig does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)
    save_path = os.path.join(out_dir, f"{base}_accuracy{suffix}.pdf")

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.plot(steps, train_acc_smooth, label=f"Train Accuracy (smooth={smooth})")
        ax.plot(steps, test_acc_smooth, label=f"Test Accuracy (smooth={smooth})")
        ax.set_xlabel("Step")
        ax.set_ylabel("Accuracy")
        ax.set_title("Training vs Test Accuracy")
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # pyplot keeps every open figure alive; close it so repeated calls
        # do not leak memory.
        plt.close(fig)

    print(f"Plot saved to {save_path}")
    return save_path
def parse_cli_args(argv):
    """Parse a ``sys.argv``-style list into ``(log_path, max_step, smooth)``.

    Expected form: ``[prog, log_path, [max_step], [smooth_window]]``.
    ``max_step`` may be the word ``none`` (case-insensitive), so a smoothing
    window can be given without a step cutoff.

    Raises:
        ValueError: On a wrong argument count or a non-integer numeric argument.
    """
    if len(argv) not in (2, 3, 4):
        raise ValueError(f"expected 1 to 3 arguments, got {len(argv) - 1}")
    log_path = argv[1]
    max_step = None
    if len(argv) >= 3 and argv[2].strip().lower() != "none":
        max_step = int(argv[2])
    smooth = int(argv[3]) if len(argv) == 4 else 1
    return log_path, max_step, smooth


if __name__ == "__main__":
    try:
        log_path, max_step, smooth = parse_cli_args(sys.argv)
    except ValueError as err:
        # Report bad input cleanly on stderr instead of dumping a traceback.
        print(f"Error: {err}", file=sys.stderr)
        print(
            "Usage: python plot.py logs/your_log.json [max_step|none] [smooth_window]",
            file=sys.stderr,
        )
        sys.exit(1)
    plot_accuracy(log_path, max_step, smooth)