Skip to content

Commit

Permalink
apache#13441 [Clojure] Add Spec Validations for the Random namespace (
Browse files Browse the repository at this point in the history
  • Loading branch information
hellonico authored and zhaoyao73 committed Dec 9, 2018
1 parent ad7ba29 commit c47b8c1
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@
;;;train

;;initialize with random noise
img (ndarray/- (random/uniform 0 255 content-np-shape dev) 128)
img (ndarray/- (random/uniform 0 255 content-np-shape {:ctx dev}) 128)
;;; img (random/uniform -0.1 0.1 content-np-shape dev)
;; img content-np
lr-sched (lr-scheduler/factor-scheduler 10 0.9)
Expand Down
26 changes: 13 additions & 13 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
(org.apache.mxnet.optimizer SGD DCASGD NAG AdaDelta RMSProp AdaGrad Adam SGLD)
(org.apache.mxnet FactorScheduler)))

(s/def ::learning-rate float?)
(s/def ::momentum float?)
(s/def ::wd float?)
(s/def ::clip-gradient float?)
(s/def ::lr-scheduler #(instance? FactorScheduler))
(s/def ::learning-rate number?)
(s/def ::momentum number?)
(s/def ::wd number?)
(s/def ::clip-gradient number?)
(s/def ::lr-scheduler #(instance? FactorScheduler %))
(s/def ::sgd-opts (s/keys :opt-un [::learning-rate ::momentum ::wd ::clip-gradient ::lr-scheduler]))

(defn sgd
Expand All @@ -43,7 +43,7 @@
([]
(sgd {})))

(s/def ::lambda float?)
(s/def ::lambda number?)
(s/def ::dcasgd-opts (s/keys :opt-un [::learning-rate ::momentum ::lambda ::wd ::clip-gradient ::lr-scheduler]))

(defn dcasgd
Expand Down Expand Up @@ -77,9 +77,9 @@
([]
(nag {})))

(s/def ::rho float?)
(s/def ::rescale-gradient float?)
(s/def ::epsilon float?)
(s/def ::rho number?)
(s/def ::rescale-gradient number?)
(s/def ::epsilon number?)
(s/def ::ada-delta-opts (s/keys :opt-un [::rho ::rescale-gradient ::epsilon ::wd ::clip-gradient]))

(defn ada-delta
Expand All @@ -96,8 +96,8 @@
([]
(ada-delta {})))

(s/def gamma1 float?)
(s/def gamma2 float?)
(s/def gamma1 number?)
(s/def gamma2 number?)
(s/def ::rms-prop-opts (s/keys :opt-un [::learning-rate ::rescale-gradient ::gamma1 ::gamma2 ::wd ::clip-gradient]))

(defn rms-prop
Expand Down Expand Up @@ -144,8 +144,8 @@
([]
(ada-grad {})))

(s/def ::beta1 float?)
(s/def ::beta2 float?)
(s/def ::beta1 number?)
(s/def ::beta2 number?)
(s/def ::adam-opts (s/keys :opt-un [::learning-rate ::beta1 ::beta2 ::epsilon ::decay-factor ::wd ::clip-gradient ::lr-scheduler]))

(defn adam
Expand Down
30 changes: 27 additions & 3 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,18 @@
;;

(ns org.apache.clojure-mxnet.random
(:require [org.apache.clojure-mxnet.shape :as mx-shape])
(:import (org.apache.mxnet Random)))
(:require
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.context :as context]
[clojure.spec.alpha :as s]
[org.apache.clojure-mxnet.util :as util])
(:import (org.apache.mxnet Context Random)))

(s/def ::low number?)
(s/def ::high number?)
(s/def ::shape-vec (s/coll-of pos-int? :kind vector?))
(s/def ::ctx #(instance? Context %))
(s/def ::uniform-opts (s/keys :opt-un [::ctx]))

(defn uniform
"Generate uniform distribution in [low, high) with shape.
Expand All @@ -29,10 +39,18 @@
out: Output place holder}
returns: The result ndarray with generated result./"
([low high shape-vec {:keys [ctx out] :as opts}]
(util/validate! ::uniform-opts opts "Incorrect random uniform parameters")
(util/validate! ::low low "Incorrect random uniform parameter")
(util/validate! ::high high "Incorrect random uniform parameters")
(util/validate! ::shape-vec shape-vec "Incorrect random uniform parameters")
(Random/uniform (float low) (float high) (mx-shape/->shape shape-vec) ctx out))
([low high shape-vec]
(uniform low high shape-vec {})))

(s/def ::loc number?)
(s/def ::scale number?)
(s/def ::normal-opts (s/keys :opt-un [::ctx]))

(defn normal
"Generate normal(Gaussian) distribution N(mean, stdvar^^2) with shape.
loc: The standard deviation of the normal distribution
Expand All @@ -43,10 +61,15 @@
out: Output place holder}
returns: The result ndarray with generated result./"
([loc scale shape-vec {:keys [ctx out] :as opts}]
(util/validate! ::normal-opts opts "Incorrect random normal parameters")
(util/validate! ::loc loc "Incorrect random normal parameters")
(util/validate! ::scale scale "Incorrect random normal parameters")
(util/validate! ::shape-vec shape-vec "Incorrect random uniform parameters")
(Random/normal (float loc) (float scale) (mx-shape/->shape shape-vec) ctx out))
([loc scale shape-vec]
(normal loc scale shape-vec {})))

(s/def ::seed-state number?)
(defn seed
" Seed the random number generators in mxnet.
This seed will affect behavior of functions in this module,
Expand All @@ -58,4 +81,5 @@
This means if you set the same seed, the random number sequence
generated from GPU0 can be different from CPU."
[seed-state]
(Random/seed (int seed-state)))
(util/validate! ::seed-state seed-state "Incorrect seed parameters")
(Random/seed (int seed-state)))
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@
test (sym/transpose data)
shape-vec [3 4]
ctx (context/default-context)
arr-data (random/uniform 0 100 shape-vec ctx)
arr-data (random/uniform 0 100 shape-vec {:ctx ctx})
trans (ndarray/transpose (ndarray/copy arr-data))
exec-test (sym/bind test ctx {"data" arr-data})
out (-> (executor/forward exec-test)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
(let [[a b] [-10 10]
shape [100 100]
_ (random/seed 128)
un1 (random/uniform a b shape {:context ctx})
un1 (random/uniform a b shape {:ctx ctx})
_ (random/seed 128)
un2 (random/uniform a b shape {:context ctx})]
un2 (random/uniform a b shape {:ctx ctx})]
(is (= un1 un2))
(is (< (Math/abs
(/ (/ (apply + (ndarray/->vec un1))
Expand All @@ -52,3 +52,16 @@
(is (< (Math/abs (- mean mu)) 0.1))
(is (< (Math/abs (- stddev sigma)) 0.1)))))

(defn random-or-normal [fn_]
(is (thrown? Exception (fn_ 'a 2 [])))
(is (thrown? Exception (fn_ 1 'b [])))
(is (thrown? Exception (fn_ 1 2 [-1])))
(is (thrown? Exception (fn_ 1 2 [2 3 0])))
(is (thrown? Exception (fn_ 1 2 [10 10] {:ctx "a"})))
(let [ctx (context/default-context)]
(is (not (nil? (fn_ 1 1 [100 100] {:ctx ctx}))))))

(deftest test-random-parameters-specs
(random-or-normal random/normal)
(random-or-normal random/uniform)
(is (thrown? Exception (random/seed "a"))))

0 comments on commit c47b8c1

Please sign in to comment.